├── .gitignore
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   ├── base_config.py
│   ├── stylegan2_config.py
│   ├── stylegan2_finetune_config.py
│   ├── stylegan3_config.py
│   ├── stylegan_config.py
│   └── volumegan_ffhq_config.py
├── convert_model.py
├── converters
│   ├── __init__.py
│   ├── base_converter.py
│   ├── pggan_converter.py
│   ├── stylegan2_converter.py
│   ├── stylegan2ada_pth_converter.py
│   ├── stylegan2ada_tf_converter.py
│   ├── stylegan3_converter.py
│   └── stylegan_converter.py
├── datasets
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── data_loaders
│   │   ├── __init__.py
│   │   ├── base_data_loader.py
│   │   ├── dali_batch_iterator.py
│   │   ├── dali_data_loader.py
│   │   ├── dali_pipeline.py
│   │   ├── distributed_sampler.py
│   │   └── iter_data_loader.py
│   ├── file_readers
│   │   ├── __init__.py
│   │   ├── base_reader.py
│   │   ├── directory_reader.py
│   │   ├── lmdb_reader.py
│   │   ├── tar_reader.py
│   │   └── zip_reader.py
│   ├── image_dataset.py
│   ├── paired_dataset.py
│   └── transformations
│       ├── __init__.py
│       ├── affine_transform.py
│       ├── base_transformation.py
│       ├── blur_and_sharpen.py
│       ├── crop.py
│       ├── decode.py
│       ├── flip.py
│       ├── hsv_jittering.py
│       ├── identity.py
│       ├── jpeg_compress.py
│       ├── misc.py
│       ├── normalize.py
│       ├── region_brightness.py
│       ├── resize.py
│       └── utils
│           ├── __init__.py
│           ├── affine_transform.py
│           └── polygon.py
├── docs
│   ├── assets
│   │   ├── bootstrap.min.css
│   │   ├── comparison.png
│   │   ├── font.css
│   │   ├── framework.png
│   │   ├── freezed.png
│   │   ├── genforce.png
│   │   ├── giraffe.png
│   │   ├── graf.png
│   │   ├── hologan.png
│   │   ├── pigan.png
│   │   ├── style.css
│   │   └── teaser.png
│   └── index.html
├── dump_command_args.py
├── metrics
│   ├── __init__.py
│   ├── base_gan_metric.py
│   ├── base_metric.py
│   ├── equivariance.py
│   ├── fid.py
│   ├── gan_pr.py
│   ├── gan_snapshot.py
│   ├── inception_score.py
│   ├── intra_class_fid.py
│   ├── kid.py
│   └── utils.py
├── models
│   ├── __init__.py
│   ├── ghfeat_encoder.py
│   ├── inception_model.py
│   ├── perceptual_model.py
│   ├── pggan_discriminator.py
│   ├── pggan_generator.py
│   ├── rendering
│   │   ├── __init__.py
│   │   ├── hierarchicle_sampling.py
│   │   ├── points_sampling.py
│   │   ├── renderer.py
│   │   └── utils.py
│   ├── stylegan2_discriminator.py
│   ├── stylegan2_generator.py
│   ├── stylegan3_generator.py
│   ├── stylegan_discriminator.py
│   ├── stylegan_generator.py
│   ├── test.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── ops.py
│   ├── volumegan_discriminator.py
│   └── volumegan_generator.py
├── prepare_dataset.py
├── render.py
├── requirements
│   ├── convert.txt
│   ├── develop.txt
│   └── minimal.txt
├── runners
│   ├── __init__.py
│   ├── augmentations
│   │   ├── __init__.py
│   │   ├── ada_aug.py
│   │   └── no_aug.py
│   ├── base_runner.py
│   ├── controllers
│   │   ├── __init__.py
│   │   ├── ada_aug_controller.py
│   │   ├── base_controller.py
│   │   ├── batch_visualizer.py
│   │   ├── cache_cleaner.py
│   │   ├── checkpointer.py
│   │   ├── dataset_visualizer.py
│   │   ├── evaluator.py
│   │   ├── lr_scheduler.py
│   │   ├── progress_scheduler.py
│   │   ├── running_logger.py
│   │   └── timer.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── base_loss.py
│   │   ├── stylegan2_loss.py
│   │   ├── stylegan3_loss.py
│   │   ├── stylegan_loss.py
│   │   └── volumegan_loss.py
│   ├── stylegan2_runner.py
│   ├── stylegan3_runner.py
│   ├── stylegan_runner.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── freezer.py
│   │   ├── optimizer.py
│   │   ├── profiler.py
│   │   └── running_stats.py
│   └── volumegan_runner.py
├── scripts
│   ├── dist_train.sh
│   ├── kill_zombies.sh
│   ├── test_converters.sh
│   ├── test_metrics.sh
│   └── training_demos
│       ├── stylegan2_ffhq1024.sh
│       ├── stylegan2_ffhq256.sh
│       ├── stylegan2_ffhq512.sh
│       ├── stylegan2_lsun_bedroom256.sh
│       ├── stylegan2ada_afhq512.sh
│       ├── stylegan2ada_cifar10.sh
│       ├── stylegan2ada_ffhq1024.sh
│       ├── stylegan2ada_ffhq256.sh
│       ├── stylegan3r_afhq512.sh
│       ├── stylegan3r_ffhqu1024.sh
│       ├── stylegan3r_ffhqu256.sh
│       ├── stylegan3t_afhq512.sh
│       ├── stylegan3t_ffhqu1024.sh
│       ├── stylegan3t_ffhqu256.sh
│       ├── stylegan_ffhq1024.sh
│       ├── stylegan_ffhq256.sh
│       ├── stylegan_ffhq512.sh
│       ├── stylegan_lsun_bedroom256.sh
│       └── volumegan_ffhq256.sh
├── test_metrics.py
├── third_party
│   ├── __init__.py
│   ├── stylegan2_official_ops
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── custom_ops.py
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── misc.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   └── stylegan3_official_ops
│       ├── README.md
│       ├── __init__.py
│       ├── bias_act.cpp
│       ├── bias_act.cu
│       ├── bias_act.h
│       ├── bias_act.py
│       ├── conv2d_gradfix.py
│       ├── conv2d_resample.py
│       ├── custom_ops.py
│       ├── filtered_lrelu.cpp
│       ├── filtered_lrelu.cu
│       ├── filtered_lrelu.h
│       ├── filtered_lrelu.py
│       ├── filtered_lrelu_ns.cu
│       ├── filtered_lrelu_rd.cu
│       ├── filtered_lrelu_wr.cu
│       ├── fma.py
│       ├── grid_sample_gradfix.py
│       ├── misc.py
│       ├── upfirdn2d.cpp
│       ├── upfirdn2d.cu
│       ├── upfirdn2d.h
│       └── upfirdn2d.py
├── train.py
├── unit_tests.py
└── utils
    ├── __init__.py
    ├── dist_utils.py
    ├── file_transmitters
    │   ├── __init__.py
    │   ├── base_file_transmitter.py
    │   ├── dummy_file_transmitter.py
    │   └── local_file_transmitter.py
    ├── formatting_utils.py
    ├── image_utils.py
    ├── loggers
    │   ├── __init__.py
    │   ├── base_logger.py
    │   ├── dummy_logger.py
    │   ├── normal_logger.py
    │   ├── rich_logger.py
    │   └── test.py
    ├── misc.py
    ├── parsing_utils.py
    ├── tf_utils.py
    └── visualizers
        ├── __init__.py
        ├── gif_visualizer.py
        ├── grid_visualizer.py
        ├── html_visualizer.py
        ├── test.py
        └── video_visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore compiled files.
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # Ignore files created by IDEs.
6 | /.vscode/
7 | /.idea/
8 | .ipynb_*/
9 | *.ipynb
10 | .DS_Store
11 | *.sw[pon]
12 |
13 | # Ignore result files within default working directory.
14 | /work_dirs/
15 |
16 | # Ignore data files.
17 | data/
18 | *.npy
19 | *.tar
20 | *.zip
21 | *.mdb
22 |
23 | # Ignore network files.
24 | checkpoints/
25 | *.pth
26 | *.pt
27 | *.pkl
28 | *.h5
29 | *.dat
30 |
31 | # Ignore media files.
32 | results/
33 | *.jpg
34 | *.png
35 | *.jpeg
36 | *.gif
37 | *.avi
38 | *.mp4
39 |
40 | # Ignore log files.
41 | resources/
42 | events/
43 | profile/
44 | *.txt
45 | *.json
46 | *.log
47 | *.html
48 | events.*
49 |
50 | # Files that should not be ignored.
51 | !/requirements/*
52 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Shen"
5 | given-names: "Yujun"
6 | - family-names: "Zhang"
7 | given-names: "Zhiyi"
8 | - family-names: "Yang"
9 | given-names: "Dingdong"
10 | - family-names: "Xu"
11 | given-names: "Yinghao"
12 | - family-names: "Yang"
13 | given-names: "Ceyuan"
14 | - family-names: "Zhu"
15 | given-names: "Jiapeng"
16 | title: "Hammer: An Efficient Toolkit for Training Deep Models"
17 | version: 1.0.0
18 | date-released: 2022-02-08
19 | url: "https://github.com/bytedance/Hammer"
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VolumeGAN - 3D-aware Image Synthesis via Learning Structural and Textural Representations
2 |
3 | 
4 | **Figure:** *Framework of VolumeGAN.*
5 |
6 | > **3D-aware Image Synthesis via Learning Structural and Textural Representations**
7 | > Yinghao Xu, Sida Peng, Ceyuan Yang, Yujun Shen, Bolei Zhou
8 | > *Computer Vision and Pattern Recognition (CVPR), 2022*
9 |
10 | [[Paper](https://arxiv.org/pdf/2112.10759.pdf)]
11 | [[Project Page](https://genforce.github.io/volumegan/)]
12 | [[Demo](https://www.youtube.com/watch?v=p85TVGJBMFc)]
13 |
14 | This paper aims at achieving high-fidelity 3D-aware image synthesis. We propose a novel framework, termed VolumeGAN, for synthesizing images under different camera views through explicitly learning a structural representation and a textural representation. We first learn a feature volume to represent the underlying structure, which is then converted to a feature field using a NeRF-like model. The feature field is further accumulated into a 2D feature map as the textural representation, followed by a neural renderer for appearance synthesis. Such a design enables independent control of the shape and the appearance. Extensive experiments on a wide range of datasets show that our approach achieves substantially higher image quality and better 3D control than previous methods.
15 |
16 | ## Usage
17 |
18 | ### Setup
19 |
20 | This repository is based on [Hammer](https://github.com/bytedance/Hammer), where you can find detailed instructions on environmental setup.
21 |
22 | ### Test Demo
23 |
24 | ```shell
25 | python render.py \
26 | --work_dir ${WORK_DIR} \
27 | --checkpoint ${MODEL_PATH} \
28 | --num ${NUM} \
29 | --seed ${SEED} \
30 | --render_mode ${RENDER_MODE} \
31 | --generate_html ${SAVE_HTML} \
32 | volumegan-ffhq
33 | ```
34 |
35 | where
36 |
37 | - `WORK_DIR` refers to the path to save the results.
38 | - `MODEL_PATH` refers to the path of the pretrained model, regarding which we provide
39 | - [FFHQ-256](https://www.dropbox.com/s/ygwhufzwi2vb2t8/volumegan_ffhq256.pth?dl=0)
40 | - `NUM` refers to the number of samples to synthesize.
41 | - `SEED` refers to the random seed used for sampling.
42 | - `RENDER_MODE` refers to the type of the rendered results, including `video` and `shape`.
43 | - `SAVE_HTML` controls whether to save images as an HTML for better visualization when rendering videos.
44 |
45 | ### Training
46 |
47 | For example, users can use the following command to train VolumeGAN on FFHQ at the resolution of 256x256:
48 |
49 | ```shell
50 | ./scripts/training_demos/volumegan_ffhq256.sh \
51 | ${NUM_GPUS} \
52 | ${DATA_PATH} \
53 | [OPTIONS]
54 | ```
55 |
56 | where
57 |
58 | - `NUM_GPUS` refers to the number of GPUs used for training.
59 | - `DATA_PATH` refers to the path to the dataset (`zip` format is strongly recommended).
60 | - `[OPTIONS]` refers to any additional option to pass. Detailed instructions on available options can be found via `python train.py volumegan-ffhq --help`.
61 |
62 | **NOTE:** This demo script uses `volumegan_ffhq256` as the default `job_name`, which is used to identify the experiment. Concretely, a directory named after `job_name` will be created under the root working directory, which is set to `work_dirs/` by default. To prevent overwriting previous experiments, an exception will be raised to interrupt the training if the `job_name` directory already exists. Please use the `--job_name=${JOB_NAME}` option to specify a new job name.
63 |
64 | ### Evaluation
65 |
66 | Users can use the following command to evaluate a well-trained model:
67 |
68 | ```shell
69 | ./scripts/test_metrics.sh \
70 | ${NUM_GPUS} \
71 | ${DATA_PATH} \
72 | ${MODEL_PATH} \
73 | fid \
74 | --G_kwargs '{"ps_kwargs":'{"perturb_mode":"none"}'}' \
75 | [OPTIONS]
76 | ```
77 |
78 | ## BibTeX
79 |
80 | ```bibtex
81 | @inproceedings{xu2021volumegan,
82 | title = {3D-aware Image Synthesis via Learning Structural and Textural Representations},
83 | author = {Xu, Yinghao and Peng, Sida and Yang, Ceyuan and Shen, Yujun and Zhou, Bolei},
84 | booktitle = {CVPR},
85 | year = {2022}
86 | }
87 | ```
88 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all configs."""
3 |
4 | from .stylegan_config import StyleGANConfig
5 | from .stylegan2_config import StyleGAN2Config
6 | from .stylegan2_finetune_config import StyleGAN2FineTuneConfig
7 | from .stylegan3_config import StyleGAN3Config
8 | from .volumegan_ffhq_config import VolumeGANFFHQConfig
9 | __all__ = ['CONFIG_POOL', 'build_config']
10 |
11 | CONFIG_POOL = [
12 | StyleGANConfig,
13 | StyleGAN2Config,
14 | StyleGAN2FineTuneConfig,
15 | StyleGAN3Config,
16 | VolumeGANFFHQConfig,
17 | ]
18 |
19 |
20 | def build_config(invoked_command, kwargs):
21 | """Builds a configuration based on the invoked command.
22 |
23 | Args:
24 | invoked_command: The command that is invoked.
25 | kwargs: Keyword arguments passed from command line, which will be used
26 | to build the configuration.
27 |
28 | Raises:
29 | ValueError: If the `invoked_command` is not registered in `CONFIG_POOL`.
30 | """
31 | for config in CONFIG_POOL:
32 | if config.name == invoked_command:
33 | return config(kwargs)
34 | raise ValueError(f'Invalid invoked command: `{invoked_command}`!\n')
35 |
--------------------------------------------------------------------------------
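
A minimal usage sketch of `build_config` follows (hedged: the command name `volumegan-ffhq` is assumed to be `VolumeGANFFHQConfig.name`, matching the README's `python train.py volumegan-ffhq --help`, and the keyword dictionary is illustrative):

```python
# Hypothetical sketch of building a config from a parsed command.
# 'volumegan-ffhq' is assumed to match `VolumeGANFFHQConfig.name`; the
# kwargs dictionary stands for options parsed from the command line.
from configs import build_config

config = build_config('volumegan-ffhq', {'job_name': 'my_volumegan_run'})
```
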
/convert_model.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Script to convert officially released models to match this repository."""
3 |
4 | import os
5 | import argparse
6 |
7 | from converters import build_converter
8 |
9 |
10 | def parse_args():
11 | """Parses arguments."""
12 | parser = argparse.ArgumentParser(description='Convert pre-trained models.')
13 | parser.add_argument('model_type', type=str,
14 | choices=['pggan', 'stylegan', 'stylegan2',
15 | 'stylegan2ada_tf', 'stylegan2ada_pth',
16 | 'stylegan3'],
17 | help='Type of the model to convert.')
18 | parser.add_argument('--source_model_path', type=str, required=True,
19 | help='Path to load the model for conversion.')
20 | parser.add_argument('--target_model_path', type=str, required=True,
21 | help='Path to save the converted model.')
22 | parser.add_argument('--forward_test_num', type=int, default=10,
23 | help='Number of samples used for forward test. '
24 | '(default: %(default)s)')
25 | parser.add_argument('--backward_test_num', type=int, default=0,
26 | help='Number of samples used for backward test. '
27 | '(default: %(default)s)')
28 | parser.add_argument('--save_test_image', action='store_true',
29 | help='Whether to save the intermediate image in '
30 | 'forward test. (default: False)')
31 | parser.add_argument('--learning_rate', type=float, default=0.01,
32 | help='Learning rate used in backward test. '
33 | '(default: %(default)s)')
34 | parser.add_argument('--verbose_log', action='store_true',
35 | help='Whether to print verbose log. (default: False)')
36 | return parser.parse_args()
37 |
38 |
39 | def main():
40 | """Main function."""
41 | args = parse_args()
42 |
43 | if os.path.exists(args.target_model_path):
44 | raise SystemExit(f'File `{args.target_model_path}` already '
45 | f'exists!\n'
46 | f'Please specify another path.')
47 |
48 | converter = build_converter(args.model_type, verbose_log=args.verbose_log)
49 | converter.run(src_path=args.source_model_path,
50 | dst_path=args.target_model_path,
51 | forward_test_num=args.forward_test_num,
52 | backward_test_num=args.backward_test_num,
53 | save_test_image=args.save_test_image,
54 | learning_rate=args.learning_rate)
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/converters/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all model converters."""
3 |
4 | from .pggan_converter import PGGANConverter
5 | from .stylegan_converter import StyleGANConverter
6 | from .stylegan2_converter import StyleGAN2Converter
7 | from .stylegan2ada_tf_converter import StyleGAN2ADATFConverter
8 | from .stylegan2ada_pth_converter import StyleGAN2ADAPTHConverter
9 | from .stylegan3_converter import StyleGAN3Converter
10 |
11 | __all__ = ['build_converter']
12 |
13 | _CONVERTERS = {
14 | 'pggan': PGGANConverter,
15 | 'stylegan': StyleGANConverter,
16 | 'stylegan2': StyleGAN2Converter,
17 | 'stylegan2ada_tf': StyleGAN2ADATFConverter,
18 | 'stylegan2ada_pth': StyleGAN2ADAPTHConverter,
19 | 'stylegan3': StyleGAN3Converter
20 | }
21 |
22 |
23 | def build_converter(model_type, verbose_log=False):
24 | """Builds a converter based on the model type.
25 |
26 | Args:
27 | model_type: Type of the model that the converter serves, which is case
28 | sensitive.
29 | verbose_log: Whether to print verbose log messages. (default: False)
30 |
31 | Raises:
32 | ValueError: If the `model_type` is not supported.
33 | """
34 | if model_type not in _CONVERTERS:
35 | raise ValueError(f'Invalid model type: `{model_type}`!\n'
36 | f'Types allowed: {list(_CONVERTERS)}.')
37 |
38 | return _CONVERTERS[model_type](verbose_log=verbose_log)
39 |
--------------------------------------------------------------------------------
/datasets/data_loaders/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all data loaders."""
3 |
4 | import warnings
5 |
6 | from .iter_data_loader import IterDataLoader
7 | try:
8 | from .dali_data_loader import DALIDataLoader
9 | except ImportError:
10 | DALIDataLoader = None
11 |
12 | __all__ = ['build_data_loader']
13 |
14 | _DATA_LOADERS_ALLOWED = ['iter', 'dali']
15 |
16 |
17 | def build_data_loader(data_loader_type,
18 | dataset,
19 | batch_size,
20 | repeat=1,
21 | shuffle=True,
22 | seed=0,
23 | drop_last_sample=False,
24 | drop_last_batch=True,
25 | num_workers=0,
26 | prefetch_factor=2,
27 | pin_memory=False,
28 | num_threads=1):
29 | """Builds a data loader with given dataset.
30 |
31 | Args:
32 | data_loader_type: Class type to which the data loader belongs, which is
33 | case insensitive.
34 | dataset: The dataset to load data from.
35 | batch_size: The batch size of the data produced by each replica.
36 | repeat: Repeating number of the entire dataset. (default: 1)
37 | shuffle: Whether to shuffle the samples within each epoch.
38 | (default: True)
39 | seed: Random seed used for shuffling. (default: 0)
40 | drop_last_sample: Whether to drop the tailing samples that cannot be
41 | evenly distributed. (default: False)
42 | drop_last_batch: Whether to drop the last incomplete batch.
43 | (default: True)
44 | num_workers: Number of workers to prefetch data for each replica.
45 | (default: 0)
46 | prefetch_factor: Number of samples loaded in advance by each worker.
47 | `N` means there will be a total of `N * num_workers` samples
48 | prefetched across all workers. (default: 2)
49 | pin_memory: Whether to use pinned memory for loaded data. This field is
50 | particularly used for `IterDataLoader`. (default: False)
51 | num_threads: Number of threads for each replica. This field is
52 | particularly used for `DALIDataLoader`. (default: 1)
53 |
54 | Raises:
55 | ValueError: If `data_loader_type` is not supported.
56 | NotImplementedError: If `data_loader_type` is not implemented yet.
57 | """
58 | data_loader_type = data_loader_type.lower()
59 | if data_loader_type not in _DATA_LOADERS_ALLOWED:
60 | raise ValueError(f'Invalid data loader type: `{data_loader_type}`!\n'
61 | f'Types allowed: {_DATA_LOADERS_ALLOWED}.')
62 |
63 | if data_loader_type == 'dali' and DALIDataLoader is None:
64 | warnings.warn('DALI is not supported on the current environment! '
65 | 'Fall back to `IterDataLoader`.')
66 | data_loader_type = 'iter'
67 |
68 | if data_loader_type == 'dali' and not dataset.support_dali:
69 | warnings.warn('DALI is not supported by some transformation node of '
70 | 'the dataset! Fall back to `IterDataLoader`.')
71 | data_loader_type = 'iter'
72 |
73 | if data_loader_type == 'iter':
74 | return IterDataLoader(dataset=dataset,
75 | batch_size=batch_size,
76 | repeat=repeat,
77 | shuffle=shuffle,
78 | seed=seed,
79 | drop_last_sample=drop_last_sample,
80 | drop_last_batch=drop_last_batch,
81 | num_workers=num_workers,
82 | prefetch_factor=prefetch_factor,
83 | pin_memory=pin_memory)
84 | if data_loader_type == 'dali':
85 | return DALIDataLoader(dataset=dataset,
86 | batch_size=batch_size,
87 | repeat=repeat,
88 | shuffle=shuffle,
89 | seed=seed,
90 | drop_last_sample=drop_last_sample,
91 | drop_last_batch=drop_last_batch,
92 | num_workers=num_workers,
93 | prefetch_factor=prefetch_factor,
94 | num_threads=num_threads)
95 | raise NotImplementedError(f'Not implemented data loader type '
96 | f'`{data_loader_type}`!')
97 |
--------------------------------------------------------------------------------
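
A minimal sketch of `build_data_loader`, assuming `dataset` is an already-built dataset object (e.g., from `datasets/image_dataset.py`, whose constructor is not shown in this dump). If DALI is unavailable, or some transformation node of the dataset does not support it, the call transparently falls back to `IterDataLoader`:

```python
# Sketch only: `dataset` stands for any dataset object exposing the
# attributes used above (e.g., `support_dali`); its construction is assumed.
from datasets.data_loaders import build_data_loader

data_loader = build_data_loader('dali',  # falls back to 'iter' if unsupported
                                dataset=dataset,
                                batch_size=32,
                                shuffle=True,
                                num_workers=4,
                                num_threads=2)
```
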
/datasets/data_loaders/dali_batch_iterator.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Wraps a batch-based iterator introduced in DALI.
3 |
4 | For more details, please refer to
5 |
6 | https://docs.nvidia.com/deeplearning/dali/user-guide/docs/
7 | """
8 |
9 | try:
10 | from nvidia.dali.plugin.base_iterator import LastBatchPolicy
11 | from nvidia.dali.plugin.pytorch import DALIGenericIterator
12 | except ImportError as e:
13 | raise ImportError('DALI is not supported! Please install first.') from e
14 |
15 | __all__ = ['DALIBatchIterator']
16 |
17 |
18 | class DALIBatchIterator(DALIGenericIterator):
19 | """Defines the batch iterator for DALI data pre-processing.
20 |
21 | Args:
22 | pipeline: The pre-defined pipeline for data pre-processing.
23 | batch_size: Number of samples for each batch.
24 | drop_last_batch: Whether to drop the last incomplete batch.
25 | (default: True)
26 | """
27 | def __init__(self, pipeline, batch_size, drop_last_batch=True):
28 | self.batch_size = batch_size
29 | self.drop_last_batch = drop_last_batch
30 |
31 | if self.drop_last_batch:
32 | last_batch_padded = True
33 | last_batch_policy = LastBatchPolicy.DROP
34 | self.num_batches = len(pipeline) // batch_size
35 | else:
36 | last_batch_padded = False
37 | last_batch_policy = LastBatchPolicy.FILL
38 | self.num_batches = (len(pipeline) - 1) // batch_size + 1
39 |
40 | super().__init__(pipelines=pipeline,
41 | size=-1,
42 | auto_reset=False,
43 | output_map=pipeline.dataset.output_keys,
44 | last_batch_padded=last_batch_padded,
45 | last_batch_policy=last_batch_policy,
46 | prepare_first_batch=True)
47 |
48 | def __next__(self):
49 | # [0] means the first GPU. In the distributed case, each replica only
50 | # has one GPU.
51 | return super().__next__()[0]
52 |
53 | def __len__(self):
54 | return self.num_batches
55 |
--------------------------------------------------------------------------------
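
A quick numeric check of the batch-count arithmetic above (values illustrative): with 10 samples and a batch size of 4, dropping the last batch yields `10 // 4 = 2` batches (policy `DROP`), while keeping it yields `(10 - 1) // 4 + 1 = 3` batches, the last one padded under the `FILL` policy:

```python
# Worked check of the two `num_batches` formulas in `DALIBatchIterator`.
num_samples, batch_size = 10, 4  # illustrative values
print(num_samples // batch_size)             # 2 -> drop_last_batch=True
print((num_samples - 1) // batch_size + 1)   # 3 -> last batch padded
```
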
/datasets/data_loaders/dali_data_loader.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of DALI-based data loader.
3 |
4 | For more details, please refer to
5 |
6 | https://docs.nvidia.com/deeplearning/dali/user-guide/docs/
7 | """
8 |
9 | from .dali_batch_iterator import DALIBatchIterator
10 | from .dali_pipeline import DALIPipeline
11 | from .distributed_sampler import DistributedSampler
12 | from .base_data_loader import BaseDataLoader
13 |
14 | __all__ = ['DALIDataLoader']
15 |
16 |
17 | class DALIDataLoader(BaseDataLoader):
18 | """Defines the DALI-based data loader."""
19 |
20 | def __init__(self,
21 | dataset,
22 | batch_size,
23 | repeat=1,
24 | shuffle=True,
25 | seed=0,
26 | drop_last_sample=False,
27 | drop_last_batch=True,
28 | num_workers=0,
29 | prefetch_factor=2,
30 | num_threads=1):
31 | """Initializes the data loader.
32 |
33 | Args:
34 | num_threads: Number of threads used for each replica. (default: 1)
35 | """
36 | self.num_threads = num_threads
37 | self._pipeline = None
38 | super().__init__(dataset=dataset,
39 | batch_size=batch_size,
40 | repeat=repeat,
41 | shuffle=shuffle,
42 | seed=seed,
43 | drop_last_sample=drop_last_sample,
44 | drop_last_batch=drop_last_batch,
45 | num_workers=num_workers,
46 | prefetch_factor=prefetch_factor)
47 |
48 | def __len__(self):
49 | return len(self.iter_loader)
50 |
51 | def build(self):
52 | self._sampler = DistributedSampler(
53 | dataset=self._dataset,
54 | shuffle=self.shuffle,
55 | repeat=self.repeat,
56 | seed=self.seed,
57 | drop_last_sample=self.drop_last_sample,
58 | for_dali=True)
59 | prefetch_queue_depth = max(1, self.num_workers * self.prefetch_factor)
60 | self._pipeline = DALIPipeline(dataset=self._dataset,
61 | sampler=self._sampler,
62 | batch_size=self.batch_size,
63 | seed=self.seed,
64 | num_workers=self.num_workers,
65 | num_threads=self.num_threads,
66 | prefetch_queue_depth=prefetch_queue_depth)
67 | self._iter_loader = DALIBatchIterator(
68 | pipeline=self._pipeline,
69 | batch_size=self.batch_size,
70 | drop_last_batch=self.drop_last_batch)
71 |
72 | def reset_iter_loader(self):
73 | self._iter_loader.reset()
74 |
75 | def info(self):
76 | data_loader_info = super().info()
77 | data_loader_info['Num threads'] = self.num_threads
78 | return data_loader_info
79 |
--------------------------------------------------------------------------------
/datasets/data_loaders/dali_pipeline.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Wraps the data pre-processing pipeline introduced in DALI.
3 |
4 | DALI deploys the data pre-processing pipeline on GPU instead of CPU for
5 | acceleration. It relies on a pre-compiling graph. This file wraps this pipeline
6 | to fit the data loader.
7 |
8 | For more details, please refer to
9 |
10 | https://docs.nvidia.com/deeplearning/dali/user-guide/docs/
11 | """
12 |
13 | try:
14 | import dill
15 |
16 | import nvidia.dali.ops as ops
17 | from nvidia.dali.pipeline import Pipeline
18 | except ImportError as e:
19 | raise ImportError('DALI is not supported! Please install first.') from e
20 |
21 | __all__ = ['DALIPipeline']
22 |
23 |
24 | class DALIPipeline(Pipeline):
25 | """Defines the pipeline for DALI data pre-processing.
26 |
27 | Args:
28 | dataset: Dataset to load data from.
29 | sampler: Index sampler.
30 | batch_size: Number of samples for each batch.
31 | seed: Seed for randomness in data pre-processing. (default: 0)
32 | num_workers: Number of workers to pre-fetch data (on CPU) from the
33 | dataset. (default: 0)
34 | num_threads: Number of threads used for data pre-processing by the
35 | current replica. (default: 1)
36 | prefetch_queue_depth: Prefetch queue depth. (default: 1)
37 | """
38 |
39 | def __init__(self,
40 | dataset,
41 | sampler,
42 | batch_size,
43 | seed=0,
44 | num_workers=0,
45 | num_threads=1,
46 | prefetch_queue_depth=1):
47 | self._dataset = dataset
48 | self._sampler = sampler
49 |
50 | # Starting node of the data pre-processing graph.
51 | self.get_raw_data = ops.ExternalSource(
52 | source=self.sampler,
53 | num_outputs=self.dataset.num_raw_outputs,
54 | parallel=True,
55 | prefetch_queue_depth=prefetch_queue_depth)
56 |
57 | if seed >= 0:
58 | seed = seed * self.sampler.world_size + self.sampler.rank
59 | else:
60 | seed = -1
61 |
62 | if dataset.has_customized_function_for_dali:
63 | exec_pipelined = False
64 | exec_async = False
65 | num_workers = 0
66 | else:
67 | exec_pipelined = True
68 | exec_async = True
69 | super().__init__(batch_size=batch_size,
70 | num_threads=num_threads,
71 | device_id=self.sampler.rank,
72 | seed=seed,
73 | exec_pipelined=exec_pipelined,
74 | exec_async=exec_async,
75 | py_num_workers=num_workers,
76 | py_start_method='spawn',
77 | py_callback_pickler=dill)
78 |
79 | @property
80 | def dataset(self):
81 | """Returns the dataset."""
82 | return self._dataset
83 |
84 | @property
85 | def sampler(self):
86 | """Returns the sampler."""
87 | return self._sampler
88 |
89 | def define_graph(self):
90 | """Builds a static graph for data transformation."""
91 | return self.dataset.define_dali_graph(self.get_raw_data())
92 |
93 | def __len__(self):
94 | return len(self.sampler)
95 |
--------------------------------------------------------------------------------
/datasets/data_loaders/iter_data_loader.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of iteration-based data loader."""
3 |
4 | from torch.utils.data import DataLoader
5 |
6 | from .distributed_sampler import DistributedSampler
7 | from .base_data_loader import BaseDataLoader
8 |
9 | __all__ = ['IterDataLoader']
10 |
11 |
12 | class IterDataLoader(BaseDataLoader):
13 | """Defines the iteration-based data loader."""
14 |
15 | def __init__(self,
16 | dataset,
17 | batch_size,
18 | repeat=1,
19 | shuffle=True,
20 | seed=0,
21 | drop_last_sample=False,
22 | drop_last_batch=True,
23 | num_workers=0,
24 | prefetch_factor=2,
25 | pin_memory=False):
26 | """Initializes the data loader.
27 |
28 | Args:
29 | pin_memory: Whether to use pinned memory for loaded data. If `True`,
30 | it will be faster to move data from CPU to GPU, however, it may
31 | require a high-performance computing system. (default: False)
32 | """
33 | self.pin_memory = pin_memory
34 | self._batch_grouper = None
35 | super().__init__(dataset=dataset,
36 | batch_size=batch_size,
37 | repeat=repeat,
38 | shuffle=shuffle,
39 | seed=seed,
40 | drop_last_sample=drop_last_sample,
41 | drop_last_batch=drop_last_batch,
42 | num_workers=num_workers,
43 | prefetch_factor=prefetch_factor)
44 |
45 | def __len__(self):
46 | return len(self._batch_grouper)
47 |
48 | def build(self):
49 | self._sampler = DistributedSampler(
50 | dataset=self._dataset,
51 | repeat=self.repeat,
52 | shuffle=self.shuffle,
53 | seed=self.seed,
54 | drop_last_sample=self.drop_last_sample,
55 | for_dali=False)
56 | self._batch_grouper = DataLoader(dataset=self._dataset,
57 | batch_size=self.batch_size,
58 | sampler=self._sampler,
59 | shuffle=False,
60 | drop_last=self.drop_last_batch,
61 | num_workers=self.num_workers,
62 | pin_memory=self.pin_memory,
63 | prefetch_factor=self.prefetch_factor)
64 | self._iter_loader = iter(self._batch_grouper)
65 |
66 | def reset_iter_loader(self):
67 | self._iter_loader = iter(self._batch_grouper)
68 |
69 | def info(self):
70 | data_loader_info = super().info()
71 | data_loader_info['Pin memory'] = self.pin_memory
72 | return data_loader_info
73 |
--------------------------------------------------------------------------------
/datasets/file_readers/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all file readers."""
3 |
4 | from .directory_reader import DirectoryReader
5 | from .lmdb_reader import LmdbReader
6 | from .tar_reader import TarReader
7 | from .zip_reader import ZipReader
8 |
9 | __all__ = ['build_file_reader']
10 |
11 | _READERS = {
12 | 'dir': DirectoryReader,
13 | 'lmdb': LmdbReader,
14 | 'tar': TarReader,
15 | 'zip': ZipReader
16 | }
17 |
18 |
19 | def build_file_reader(reader_type='zip'):
20 | """Builds a file reader.
21 |
22 | Args:
23 | reader_type: Type of the file reader, which is case insensitive.
24 | (default: `zip`)
25 |
26 | Raises:
27 | ValueError: If the `reader_type` is not supported.
28 | """
29 | reader_type = reader_type.lower()
30 | if reader_type not in _READERS:
31 | raise ValueError(f'Invalid reader type: `{reader_type}`!\n'
32 | f'Types allowed: {list(_READERS)}.')
33 | return _READERS[reader_type]
34 |
--------------------------------------------------------------------------------
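
Note that `build_file_reader` returns the reader class itself rather than an instance (the readers below are static classes), so methods are called directly on the returned class. A sketch with a hypothetical archive path:

```python
# Sketch: list and fetch images from a ZIP archive.
# 'data/ffhq256.zip' is a hypothetical path.
from datasets.file_readers import build_file_reader

reader = build_file_reader('zip')  # returns the `ZipReader` class
image_names = reader.get_image_list('data/ffhq256.zip')
image_bytes = reader.fetch_file('data/ffhq256.zip', image_names[0])
reader.close('data/ffhq256.zip')   # release the cached handle
```
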
/datasets/file_readers/base_reader.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the base class to read files.
3 |
4 | A file reader reads data from a given file and caches the file if possible.
5 | Typically, file readers are designed to read files from zip, lmdb, or directory.
6 | """
7 |
8 | from utils.misc import IMAGE_EXTENSIONS
9 | from utils.misc import check_file_ext
10 |
11 | __all__ = ['BaseReader']
12 |
13 |
14 | class BaseReader(object):
15 | """Defines the base file reader.
16 |
17 | A reader should have the following functions:
18 |
19 | (1) open(): The function to open (and cache) a given file/directory.
20 | (2) close(): The function to close a given file/directory.
21 | (3) open_anno_file(): The function to open a specific annotation file inside
22 | the given file/directory.
23 | (4) get_file_list(): The function to get the list of all files inside the
24 | given file/directory. The returned list is already sorted.
25 | (5) get_file_list_with_ext(): The function to get the list of files with
26 | expected file extensions inside the given file/directory. The returned
27 | list is already sorted.
28 | (6) get_image_list(): The function to get the list of all images inside the
29 | given file/directory. The returned list is already sorted.
30 | (7) fetch_file(): The function to fetch the bytes of member file inside the
31 | given file/directory.
32 | """
33 |
34 | @staticmethod
35 | def open(path):
36 | """Opens the given path."""
37 | raise NotImplementedError('Should be implemented in derived class!')
38 |
39 | @staticmethod
40 | def close(path):
41 | """Closes the given path."""
42 | raise NotImplementedError('Should be implemented in derived class!')
43 |
44 | @staticmethod
45 | def open_anno_file(path, anno_filename=None):
46 | """Opens the annotation file in `path` and returns a file pointer.
47 |
48 | If the annotation file does not exist, return `None`.
49 | """
50 | raise NotImplementedError('Should be implemented in derived class!')
51 |
52 | @staticmethod
53 | def _get_file_list(path):
54 | """Gets the list of all files inside `path`."""
55 | raise NotImplementedError('Should be implemented in derived class!')
56 |
57 | @classmethod
58 | def get_file_list(cls, path):
59 | """Gets the sorted list of all files inside `path`."""
60 | return sorted(cls._get_file_list(path))
61 |
62 | @classmethod
63 | def get_file_list_with_ext(cls, path, ext=None):
64 | """Gets the sorted list of files with expected extensions.
65 |
66 | NOTE: If no `ext` is specified, the list of all files will be returned.
67 | """
68 | ext = ext or []
69 | assert isinstance(ext, (list, tuple))
70 | if len(ext) == 0:
71 | return cls.get_file_list(path)
72 | return [f for f in cls.get_file_list(path) if check_file_ext(f, *ext)]
73 |
74 | @classmethod
75 | def get_image_list(cls, path):
76 | """Gets the sorted list of image files inside `path`."""
77 | return cls.get_file_list_with_ext(path, IMAGE_EXTENSIONS)
78 |
79 | @staticmethod
80 | def fetch_file(path, filename):
81 | """Fetches the bytes of file `filename` inside `path`.
82 |
83 | Example:
84 |
85 | >>> f = BaseReader.fetch_file('data', 'face.obj')
86 | >>> obj = f.decode('utf-8') # convert `bytes` to `str`
87 | """
88 | raise NotImplementedError('Should be implemented in derived class!')
89 |
--------------------------------------------------------------------------------
/datasets/file_readers/directory_reader.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of directory reader.
3 |
4 | This reader can summarize file list or fetch bytes of files inside a directory.
5 | """
6 |
7 | import os.path
8 |
9 | from .base_reader import BaseReader
10 |
11 | __all__ = ['DirectoryReader']
12 |
13 |
14 | class DirectoryReader(BaseReader):
15 | """Defines a class to load directory."""
16 |
17 | @staticmethod
18 | def open(path):
19 | assert os.path.isdir(path), f'Directory `{path}` is invalid!'
20 | return path
21 |
22 | @staticmethod
23 | def close(path):
24 | _ = path # Dummy function.
25 |
26 | @staticmethod
27 | def open_anno_file(path, anno_filename=None):
28 | path = DirectoryReader.open(path)
29 | if not anno_filename:
30 | return None
31 | anno_path = os.path.join(path, anno_filename)
32 | if not os.path.isfile(anno_path):
33 | return None
34 | # File will be closed after parsed in dataset.
35 | return open(anno_path, 'r')
36 |
37 | @staticmethod
38 | def _get_file_list(path):
39 | path = DirectoryReader.open(path)
40 | return os.listdir(path)
41 |
42 | @staticmethod
43 | def fetch_file(path, filename):
44 | path = DirectoryReader.open(path)
45 | with open(os.path.join(path, filename), 'rb') as f:
46 | file_bytes = f.read()
47 | return file_bytes
48 |
--------------------------------------------------------------------------------
/datasets/file_readers/lmdb_reader.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of LMDB database reader.
3 |
4 | This reader can summarize file list or fetch bytes of files inside a LMDB
5 | database.
6 | """
7 |
8 | import lmdb
9 |
10 | from .base_reader import BaseReader
11 |
12 | __all__ = ['LmdbReader']
13 |
14 |
15 | class LmdbReader(BaseReader):
16 | """Defines a class to load LMDB file.
17 |
18 | This is a static class, designed to work around the fact that different
19 | data workers cannot share the same memory.
20 | """
21 |
22 | reader_cache = dict()
23 |
24 | @staticmethod
25 | def open(path):
26 | """Opens a lmdb file."""
27 | lmdb_files = LmdbReader.reader_cache
28 | if path not in lmdb_files:
29 | env = lmdb.open(path,
30 | max_readers=1,
31 | readonly=True,
32 | lock=False,
33 | readahead=False,
34 | meminit=False)
35 | with env.begin(write=False) as txn:
36 | num_samples = txn.stat()['entries']
37 | keys = [key for key, _ in txn.cursor()]
38 | file_info = {'env': env,
39 | 'num_samples': num_samples,
40 | 'keys': keys}
41 | lmdb_files[path] = file_info
42 | return lmdb_files[path]
43 |
44 | @staticmethod
45 | def close(path):
46 | lmdb_files = LmdbReader.reader_cache
47 | lmdb_file = lmdb_files.pop(path, None)
48 | if lmdb_file is not None:
49 | lmdb_file['env'].close()
50 | lmdb_file.clear()
51 |
52 | @staticmethod
53 | def open_anno_file(path, anno_filename=None):
54 | # TODO: Support loading annotation file from LMDB.
55 | return None
56 |
57 | @staticmethod
58 | def _get_file_list(path):
59 | lmdb_file = LmdbReader.open(path)
60 | return lmdb_file['keys']
61 |
62 | @classmethod
63 | def get_file_list_with_ext(cls, path, ext=None):
64 | # NOTE: In LMDB, keys do not reveal file extension.
65 | return cls.get_file_list(path)
66 |
67 | @classmethod
68 | def get_image_list(cls, path):
69 | # NOTE: In LMDB, keys do not reveal file extension.
70 | return cls.get_file_list(path)
71 |
72 | @staticmethod
73 | def fetch_file(path, filename):
74 | lmdb_file = LmdbReader.open(path)
75 | env = lmdb_file['env']
76 | with env.begin(write=False) as txn:
77 | file_bytes = txn.get(filename)
78 | return file_bytes
79 |
--------------------------------------------------------------------------------
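
One caveat worth illustrating: LMDB keys are raw `bytes` and do not reveal file extensions, so the names returned by `get_file_list` are passed back to `fetch_file` unchanged. A sketch with a hypothetical database path:

```python
# Sketch: keys come back as `bytes` and are used verbatim for lookup.
# 'data/ffhq.lmdb' is a hypothetical path.
from datasets.file_readers.lmdb_reader import LmdbReader

keys = LmdbReader.get_file_list('data/ffhq.lmdb')  # list of bytes keys
raw = LmdbReader.fetch_file('data/ffhq.lmdb', keys[0])
LmdbReader.close('data/ffhq.lmdb')
```
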
/datasets/file_readers/tar_reader.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of TAR file reader.
3 |
4 | Basically, a TAR file will be first extracted, under the same root directory as
5 | the source TAR file, and with the same base name as the source TAR file. For
6 | example, the TAR file `/home/data/test_data.tar.gz` will be extracted to
7 | `/home/data/test_data/`. Then, this file reader degenerates into
8 | `DirectoryReader`.
9 |
10 | NOTE: TAR files are not recommended. Please use ZIP files instead.
11 | """
12 |
13 | import os.path
14 | import shutil
15 | import tarfile
16 |
17 | from .base_reader import BaseReader
18 |
19 | __all__ = ['TarReader']
20 |
21 |
22 | class TarReader(BaseReader):
23 | """Defines a class to load TAR file.
24 |
25 | This is a static class, designed to work around the fact that different
26 | data workers cannot share the same memory.
27 | """
28 |
29 | reader_cache = dict()
30 |
31 | @staticmethod
32 | def open(path):
33 | tar_files = TarReader.reader_cache
34 | if path not in tar_files:
35 | root_dir = os.path.dirname(path)
36 | base_dir = os.path.basename(path).split('.tar')[0]
37 | extract_dir = os.path.join(root_dir, base_dir)
38 | filenames = []
39 | with tarfile.open(path, 'r') as f:
40 | for member in f.getmembers():
41 | if member.isfile():
42 | filenames.append(member.name)
43 | f.extractall(extract_dir)
44 | file_info = {'extract_dir': extract_dir,
45 | 'filenames': filenames}
46 | tar_files[path] = file_info
47 | return tar_files[path]
48 |
49 | @staticmethod
50 | def close(path):
51 | tar_files = TarReader.reader_cache
52 | tar_file = tar_files.pop(path, None)
53 | if tar_file is not None:
54 | extract_dir = tar_file['extract_dir']
55 | shutil.rmtree(extract_dir)
56 | tar_file.clear()
57 |
58 | @staticmethod
59 | def open_anno_file(path, anno_filename=None):
60 | tar_file = TarReader.open(path)
61 | if not anno_filename:
62 | return None
63 | anno_path = os.path.join(tar_file['extract_dir'], anno_filename)
64 | if not os.path.isfile(anno_path):
65 | return None
66 | # File will be closed after parsed in dataset.
67 | return open(anno_path, 'r')
68 |
69 | @staticmethod
70 | def _get_file_list(path):
71 | tar_file = TarReader.open(path)
72 | return tar_file['filenames']
73 |
74 | @staticmethod
75 | def fetch_file(path, filename):
76 | tar_file = TarReader.open(path)
77 | with open(os.path.join(tar_file['extract_dir'], filename), 'rb') as f:
78 | file_bytes = f.read()
79 | return file_bytes
80 |
--------------------------------------------------------------------------------
/datasets/file_readers/zip_reader.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of ZIP file reader.
3 |
4 | This reader can summarize file list or fetch bytes of files inside a ZIP.
5 | """
6 |
7 | import zipfile
8 |
9 | from .base_reader import BaseReader
10 |
11 | __all__ = ['ZipReader']
12 |
13 |
14 | class ZipReader(BaseReader):
15 | """Defines a class to load ZIP file.
16 |
17 | This is a static class, designed to work around the fact that different
18 | data workers cannot share the same memory.
19 | """
20 |
21 | reader_cache = dict()
22 |
23 | @staticmethod
24 | def open(path):
25 | zip_files = ZipReader.reader_cache
26 | if path not in zip_files:
27 | # File will be closed by calling `cls.close()`.
28 | zip_files[path] = zipfile.ZipFile(path, 'r') # pylint: disable=consider-using-with
29 | return zip_files[path]
30 |
31 | @staticmethod
32 | def close(path):
33 | zip_files = ZipReader.reader_cache
34 | zip_file = zip_files.pop(path, None)
35 | if zip_file is not None:
36 | zip_file.close()
37 |
38 | @staticmethod
39 | def open_anno_file(path, anno_filename=None):
40 | zip_file = ZipReader.open(path)
41 | if not anno_filename:
42 | return None
43 | if anno_filename not in zip_file.namelist():
44 | return None
45 | # File will be closed after parsed in dataset.
46 | return zip_file.open(anno_filename, 'r')
47 |
48 | @staticmethod
49 | def _get_file_list(path):
50 | zip_file = ZipReader.open(path)
51 | return [
52 | info.filename for info in zip_file.infolist() if not info.is_dir()]
53 |
54 | @staticmethod
55 | def fetch_file(path, filename):
56 | zip_file = ZipReader.open(path)
57 | return zip_file.read(filename)
58 |
--------------------------------------------------------------------------------
/datasets/transformations/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all transformations for data pre-processing."""
3 |
4 | from .affine_transform import AffineTransform
5 | from .blur_and_sharpen import BlurAndSharpen
6 | from .crop import CenterCrop
7 | from .crop import RandomCrop
8 | from .crop import LongSideCrop
9 | from .decode import Decode
10 | from .flip import Flip
11 | from .hsv_jittering import HSVJittering
12 | from .identity import Identity
13 | from .jpeg_compress import JpegCompress
14 | from .normalize import Normalize
15 | from .region_brightness import RegionBrightness
16 | from .resize import Resize
17 | from .resize import ProgressiveResize
18 | from .resize import ResizeAug
19 |
20 | __all__ = ['build_transformation']
21 |
22 |
23 | _TRANSFORMATIONS = {
24 | 'AffineTransform': AffineTransform,
25 | 'BlurAndSharpen': BlurAndSharpen,
26 | 'CenterCrop': CenterCrop,
27 | 'RandomCrop': RandomCrop,
28 | 'LongSideCrop': LongSideCrop,
29 | 'Decode': Decode,
30 | 'Flip': Flip,
31 | 'HSVJittering': HSVJittering,
32 | 'Identity': Identity,
33 | 'JpegCompress': JpegCompress,
34 | 'Normalize': Normalize,
35 | 'RegionBrightness': RegionBrightness,
36 | 'Resize': Resize,
37 | 'ProgressiveResize': ProgressiveResize,
38 | 'ResizeAug': ResizeAug
39 | }
40 |
41 |
42 | def build_transformation(transform_type, **kwargs):
43 | """Builds a transformation based on its class type.
44 |
45 | Args:
46 | transform_type: Class type to which the transformation belongs,
47 | which is case sensitive.
48 | **kwargs: Additional arguments to build the transformation.
49 |
50 | Raises:
51 | ValueError: If the `transform_type` is not supported.
52 | """
53 | if transform_type not in _TRANSFORMATIONS:
54 | raise ValueError(f'Invalid transformation type: '
55 | f'`{transform_type}`!\n'
56 | f'Types allowed: {list(_TRANSFORMATIONS)}.')
57 | return _TRANSFORMATIONS[transform_type](**kwargs)
58 |
--------------------------------------------------------------------------------
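
A short sketch of `build_transformation` using names registered above. The `Flip` and `Normalize` arguments match the constructors shown later in this dump; `Decode` is built with its defaults since its signature is not shown here:

```python
# Sketch: build individual transformation nodes by their registered names.
from datasets.transformations import build_transformation

decode = build_transformation('Decode')  # defaults assumed; signature not shown
flip = build_transformation('Flip', horizontal_prob=0.5)
normalize = build_transformation('Normalize', min_val=-1.0, max_val=1.0)
```
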
/datasets/transformations/flip.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Implements image flipping."""
3 |
4 | import numpy as np
5 |
6 | try:
7 | import nvidia.dali.fn as fn
8 | except ImportError:
9 | fn = None
10 |
11 | from .base_transformation import BaseTransformation
12 |
13 | __all__ = ['Flip']
14 |
15 |
16 | class Flip(BaseTransformation):
17 | """Applies random flipping to images.
18 |
19 | Args:
20 | horizontal_prob: Probability of flipping images horizontally.
21 | (default: 0.0)
22 | vertical_prob: Probability of flipping images vertically. (default: 0.0)
23 | """
24 |
25 | def __init__(self, horizontal_prob=0.0, vertical_prob=0.0):
26 | super().__init__(support_dali=(fn is not None))
27 |
28 | self.horizontal_prob = np.clip(horizontal_prob, 0, 1)
29 | self.vertical_prob = np.clip(vertical_prob, 0, 1)
30 |
31 | def _CPU_forward(self, data):
32 | do_horizontal = np.random.uniform() < self.horizontal_prob
33 | do_vertical = np.random.uniform() < self.vertical_prob
34 |
35 | # Early return if no flipping is applied.
36 | if not do_horizontal and not do_vertical:
37 | return data
38 |
39 | outputs = []
40 | for image in data:
41 | if do_horizontal:
42 | image = image[:, ::-1]
43 | if do_vertical:
44 | image = image[::-1, :]
45 | outputs.append(np.ascontiguousarray(image))
46 | return outputs
47 |
48 | def _DALI_forward(self, data):
49 | do_horizontal = fn.random.coin_flip(probability=self.horizontal_prob)
50 | do_vertical = fn.random.coin_flip(probability=self.vertical_prob)
51 | return fn.flip(data, horizontal=do_horizontal, vertical=do_vertical)
52 |
--------------------------------------------------------------------------------
/datasets/transformations/hsv_jittering.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Implements image color jittering from the HSV space."""
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | try:
8 | import nvidia.dali.fn as fn
9 | import nvidia.dali.types as types
10 | except ImportError:
11 | fn = None
12 |
13 | from utils.formatting_utils import format_range
14 | from .base_transformation import BaseTransformation
15 |
16 | __all__ = ['HSVJittering']
17 |
18 |
19 | class HSVJittering(BaseTransformation):
20 | """Applies random color jittering to images from the HSV space.
21 |
22 | Args:
23 | h_range: The range within which to uniformly sample a hue. Use `(0, 0)`
24 | to disable the hue jittering. (default: (0, 0))
25 | s_range: The range within which to uniformly sample a saturation. Use
26 | `(1, 1)` to disable the saturation jittering. (default: (1, 1))
27 | v_range: The range within which to uniformly sample a brightness value.
28 | Use `(1, 1)` to disable the brightness jittering. (default: (1, 1))
29 | """
30 |
31 | def __init__(self, h_range=(0, 0), s_range=(1, 1), v_range=(1, 1)):
32 | super().__init__(support_dali=(fn is not None))
33 |
34 | self.h_range = format_range(h_range)
35 | self.s_range = format_range(s_range, min_val=0)
36 | self.v_range = format_range(v_range, min_val=0)
37 |
38 | def _CPU_forward(self, data):
39 | # Early return if no jittering is needed.
40 | if (self.h_range == (0, 0) and self.s_range == (1, 1) and
41 | self.v_range == (1, 1)):
42 | return data
43 |
44 | # Get random jittering value for hue, saturation, and brightness.
45 | hue = np.random.uniform(*self.h_range)
46 | sat = np.random.uniform(*self.s_range)
47 | val = np.random.uniform(*self.v_range)
48 |
49 | # Perform color jittering.
50 | outputs = []
51 | for image in data:
52 | assert image.shape[2] == 3, 'RGB image is expected!'
53 | h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2HSV))
54 | h = ((h + hue) % 180).astype(np.uint8)
55 | s = np.clip(s * sat, 0, 255).astype(np.uint8)
56 | v = np.clip(v * val, 0, 255).astype(np.uint8)
57 | new_image = cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2RGB)
58 | outputs.append(new_image)
59 | return outputs
60 |
61 | def _DALI_forward(self, data):
62 | # Early return if no jittering is needed.
63 | if (self.h_range == (0, 0) and self.s_range == (1, 1) and
64 | self.v_range == (1, 1)):
65 | return data
66 |
67 | # Get random jittering value for hue, saturation, and brightness.
68 | if self.h_range[0] == self.h_range[1]:
69 | hue = self.h_range[0]
70 | else:
71 | hue = fn.random.uniform(range=self.h_range)
72 | hue = fn.cast(hue, dtype=types.FLOAT)
73 | if self.s_range[0] == self.s_range[1]:
74 | sat = self.s_range[0]
75 | else:
76 | sat = fn.random.uniform(range=self.s_range)
77 | sat = fn.cast(sat, dtype=types.FLOAT)
78 | if self.v_range[0] == self.v_range[1]:
79 | val = self.v_range[0]
80 | else:
81 | val = fn.random.uniform(range=self.v_range)
82 | val = fn.cast(val, dtype=types.FLOAT)
83 |
84 | # Perform color jittering.
85 | return fn.hsv(data, hue=hue, saturation=sat, value=val)
86 |
--------------------------------------------------------------------------------
/datasets/transformations/identity.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Implements an identity transformation, which can be used as a placeholder."""
3 |
4 | from .base_transformation import BaseTransformation
5 |
6 | __all__ = ['Identity']
7 |
8 |
9 | class Identity(BaseTransformation):
10 | """Applies no transformation by directly returning the input."""
11 |
12 | def __init__(self):
13 | super().__init__(support_dali=True)
14 |
15 | def _CPU_forward(self, data):
16 | return data
17 |
18 | def _DALI_forward(self, data):
19 | return data
20 |
--------------------------------------------------------------------------------
/datasets/transformations/jpeg_compress.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Implements JPEG compression on images."""
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | try:
8 | import nvidia.dali.fn as fn
9 | import nvidia.dali.types as types
10 | except ImportError:
11 | fn = None
12 |
13 | from utils.formatting_utils import format_range
14 | from .base_transformation import BaseTransformation
15 |
16 | __all__ = ['JpegCompress']
17 |
18 |
19 | class JpegCompress(BaseTransformation):
20 | """Applies random JPEG compression to images.
21 |
22 | This transformation can be used as an augmentation by distorting images.
23 | In other words, the input image(s) will be first compressed (i.e., encoded)
24 | with a random quality ratio, and then decoded back to the image space.
25 | The distortion is introduced in the encoding process.
26 |
27 | Args:
28 | quality_range: The range within which to uniformly sample the quality
29 | used for compression, where 100 means highest and 0 means lowest.
30 | (default: (40, 60))
31 | prob: Probability of applying JPEG compression. (default: 0.5)
32 | """
33 |
34 | def __init__(self, prob=0.5, quality_range=(40, 60)):
35 | super().__init__(support_dali=(fn is not None))
36 |
37 | self.prob = np.clip(prob, 0, 1)
38 | self.quality_range = format_range(quality_range, min_val=0, max_val=100)
39 |
40 | def _CPU_forward(self, data):
41 | # Early return if no compression is needed.
42 | if np.random.uniform() >= self.prob:
43 | return data
44 |
45 | # Set compression quality.
46 | quality = np.random.randint(*self.quality_range)
47 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
48 |
49 | # Compress images.
50 | outputs = []
51 | for image in data:
52 | _, encoded_image = cv2.imencode('.jpg', image, encode_param)
53 | decoded_image = cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED)
54 | if decoded_image.ndim == 2:
55 | decoded_image = decoded_image[:, :, np.newaxis]
56 | outputs.append(decoded_image)
57 | return outputs
58 |
59 | def _DALI_forward(self, data):
60 | # Set compression quality.
61 | if self.quality_range[0] == self.quality_range[1]:
62 | quality = self.quality_range[0]
63 | else:
64 | quality = fn.random.uniform(range=self.quality_range)
65 | quality = fn.cast(quality, dtype=types.INT32)
66 |
67 | # Compress images.
68 | compressed_data = fn.jpeg_compression_distortion(
69 | data, quality=quality)
70 | if not isinstance(compressed_data, (list, tuple)):
71 | compressed_data = [compressed_data]
72 |
73 | # Determine whether the transformation should be applied.
74 | cond = fn.random.coin_flip(dtype=types.BOOL, probability=self.prob)
75 | outputs = []
76 | for image, compressed_image in zip(data, compressed_data):
77 | outputs.append(compressed_image * cond + image * (cond ^ True))
78 | return outputs
79 |
--------------------------------------------------------------------------------
/datasets/transformations/misc.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Helper functions for data transformation."""
3 |
4 | try:
5 | import nvidia.dali.fn as fn
6 | import nvidia.dali.types as types
7 | except ImportError:
8 | fn = None
9 |
10 | __all__ = ['switch_between', 'FunctionOp']
11 |
12 |
13 | def switch_between(cond, cond_true, cond_false, use_dali=False):
14 | """Switches between two transformation nodes for data pre-processing.
15 |
16 | Args:
17 | cond: Condition to switch between two alternatives.
18 | cond_true: The returned value if the condition fulfills.
19 | cond_false: The returned value if the condition fails.
20 | use_dali: Whether the nodes are from DALI pre-processing pipeline.
21 | (default: False)
22 |
23 | Returns:
24 | One of `cond_true` and `cond_false`, depending on `cond`.
25 | """
26 | if use_dali and fn is None:
27 | raise NotImplementedError('DALI is not supported! '
28 | 'Please install first.')
29 |
30 | if not use_dali:
31 | return cond_true if cond else cond_false
32 |
33 | # Record whether any input (cond_true/cond_false) is not a list. If that is
34 | # the case, the returned value will be a single node. Otherwise, the
35 | # returned value will also be a list of nodes.
36 | is_input_list = True
37 | if not isinstance(cond_true, (list, tuple)):
38 | is_input_list = False
39 | cond_true = [cond_true]
40 | if not isinstance(cond_false, (list, tuple)):
41 | is_input_list = False
42 | cond_false = [cond_false]
43 | assert len(cond_true) == len(cond_false)
44 |
45 | cond = fn.cast(cond, dtype=types.BOOL)
46 | outputs = []
47 | for sample_true, sample_false in zip(cond_true, cond_false):
48 | outputs.append(sample_true * cond + sample_false * (cond ^ True))
49 | return outputs if is_input_list else outputs[0]
50 |
51 |
52 | class FunctionOp(object):
53 | """Contains the class to turn a function as an operator.
54 |
55 | DALI supports creating a data node, which is populated with data from an
56 | external source function. This function should be callable via accepting one
57 | positional argument. This class is particularly designed to turn a function,
58 | with default settings, into a DALI compatible operator. Please refer to
59 | `nvidia.dali.fn.external_source()` for more details.
60 |
61 | More concretely, a function `f(a, b)` with desired arguments `a=1, b=2` can
62 | be wrapped with
63 |
64 | ```
65 | op = FunctionOp(f, a=1, b=2)
66 | ```
67 |
68 | Then, it can be used with
69 |
70 | ```
71 | dali_node = fn.external_source(source=op,
72 | parallel=True,
73 | prefetch_queue_depth=32,
74 | batch=False)
75 | ```
76 |
77 | OR, it can be directly called with
78 |
79 | ```
80 | result = op()
81 | ```
82 | """
83 |
84 | def __init__(self, function, **kwargs):
85 | """Initializes with the function to wrap and the default arguments."""
86 | self.function = function
87 | self.kwargs = kwargs
88 |
89 | def __call__(self, _dali_arg=None):
90 | return self.function(**self.kwargs)
91 |
--------------------------------------------------------------------------------
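
On the CPU path, `switch_between` reduces to a plain conditional; only the DALI path builds the `a * cond + b * (cond ^ True)` select arithmetic on graph nodes. A tiny sketch of the CPU behavior:

```python
# Sketch: with `use_dali=False` the function simply returns one input.
from datasets.transformations.misc import switch_between

assert switch_between(cond=True, cond_true='a', cond_false='b') == 'a'
assert switch_between(cond=False, cond_true='a', cond_false='b') == 'b'
```
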
/datasets/transformations/normalize.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Implements image normalization."""
3 |
4 | import numpy as np
5 |
6 | try:
7 | import nvidia.dali.fn as fn
8 | import nvidia.dali.types as types
9 | except ImportError:
10 | fn = None
11 |
12 | from .base_transformation import BaseTransformation
13 |
14 | __all__ = ['Normalize']
15 |
16 |
17 | class Normalize(BaseTransformation):
18 | """Normalizes images.
19 |
20 |     The input images are expected to have pixel range [0, 255].
21 |
22 |     The output images will have data format `CHW` and dtype `float32`.
23 |
24 | Args:
25 | min_val: The minimum value after normalization. (default: -1.0)
26 | max_val: The maximum value after normalization. (default: 1.0)
27 | """
28 |
29 | def __init__(self, min_val=-1.0, max_val=1.0):
30 | super().__init__(support_dali=(fn is not None))
31 |
32 | self.min_val = float(min_val)
33 | self.max_val = float(max_val)
34 |
35 | def _CPU_forward(self, data):
36 | outputs = []
37 | for image in data:
38 | image = image.astype(np.float32)
39 | image = image / 255 * (self.max_val - self.min_val) + self.min_val
40 | image = image.transpose(2, 0, 1)
41 | outputs.append(image)
42 | return outputs
43 |
44 | def _DALI_forward(self, data):
45 | return fn.crop_mirror_normalize(
46 | data,
47 | dtype=types.FLOAT,
48 | output_layout=types.NCHW,
49 | scale=(self.max_val - self.min_val) / 255.0,
50 | shift=self.min_val)
51 |
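# Editor's sketch (assuming `BaseTransformation` needs no setup beyond the
# constructor call shown above): with the defaults, pixel 0 maps to -1.0,
# pixel 255 maps to 1.0, and an HWC uint8 image becomes a CHW float32 array.
#
# >>> img = np.zeros((4, 4, 3), dtype=np.uint8)
# >>> out = Normalize()._CPU_forward([img])[0]
# >>> out.shape, float(out.min()), float(out.max())
# ((3, 4, 4), -1.0, -1.0)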
--------------------------------------------------------------------------------
/datasets/transformations/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects dataset related utility functions."""
3 |
4 | from .affine_transform import generate_affine_transformation
5 | from .polygon import generate_polygon_contour
6 | from .polygon import generate_polygon_mask
7 |
8 | __all__ = [
9 | 'generate_affine_transformation', 'generate_polygon_contour',
10 | 'generate_polygon_mask'
11 | ]
12 |
--------------------------------------------------------------------------------
/datasets/transformations/utils/affine_transform.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the function to generate a random affine transformation."""
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | from utils.formatting_utils import format_range
8 | from utils.formatting_utils import format_image_size
9 |
10 | __all__ = ['generate_affine_transformation']
11 |
12 |
13 | def generate_affine_transformation(image_size,
14 | rotation_range,
15 | scale_range,
16 | tx_range,
17 | ty_range):
18 | """Generates a random affine transformation matrix.
19 |
20 | Args:
21 | image_size: Size of the image, which is used as a reference of the
22 |             transformation. The size is assumed to be in (height, width) order.
23 | rotation_range: The range (in degrees) within which to uniformly sample
24 | a rotation angle.
25 | scale_range: The range within which to uniformly sample a scaling
26 | factor.
27 |         tx_range: The range (as a fraction of the image width) within which
28 |             to uniformly sample an X translation.
29 |         ty_range: The range (as a fraction of the image height) within
30 |             which to uniformly sample a Y translation.
31 |
32 | Returns:
33 | A transformation matrix, with shape [2, 3] and dtype `numpy.float32`.
34 | """
35 | # Regularize inputs.
36 | height, width = format_image_size(image_size)
37 | rotation_range = format_range(rotation_range)
38 | scale_range = format_range(scale_range, min_val=0)
39 | tx_range = format_range(tx_range)
40 | ty_range = format_range(ty_range)
41 |
42 | # Sample parameters for the affine transformation.
43 | rotation = np.random.uniform(*rotation_range)
44 | scale = np.random.uniform(*scale_range)
45 | tx = np.random.uniform(*tx_range)
46 | ty = np.random.uniform(*ty_range)
47 |
48 | # Get the transformation matrix.
49 | matrix = cv2.getRotationMatrix2D(center=(width // 2, height // 2),
50 | angle=rotation,
51 | scale=scale)
52 | matrix[:, 2] += (tx * width, ty * height)
53 |
54 | return matrix.astype(np.float32)
55 |
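# Editor's sketch: sample a random matrix and warp an image with OpenCV. The
# exact accepted forms of the range arguments depend on `format_range` in
# `utils/formatting_utils.py`; (min, max) pairs are assumed here.
#
# >>> image = np.zeros((256, 256, 3), dtype=np.uint8)
# >>> matrix = generate_affine_transformation(image_size=(256, 256),
# ...                                         rotation_range=(-10, 10),
# ...                                         scale_range=(0.9, 1.1),
# ...                                         tx_range=(-0.1, 0.1),
# ...                                         ty_range=(-0.1, 0.1))
# >>> warped = cv2.warpAffine(image, matrix, dsize=(256, 256))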
--------------------------------------------------------------------------------
/docs/assets/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/comparison.png
--------------------------------------------------------------------------------
/docs/assets/font.css:
--------------------------------------------------------------------------------
1 | /* Homepage Font */
2 |
3 | /* latin-ext */
4 | @font-face {
5 | font-family: 'Lato';
6 | font-style: normal;
7 | font-weight: 400;
8 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2');
9 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
10 | }
11 |
12 | /* latin */
13 | @font-face {
14 | font-family: 'Lato';
15 | font-style: normal;
16 | font-weight: 400;
17 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2');
18 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
19 | }
20 |
21 | /* latin-ext */
22 | @font-face {
23 | font-family: 'Lato';
24 | font-style: normal;
25 | font-weight: 700;
26 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2');
27 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
28 | }
29 |
30 | /* latin */
31 | @font-face {
32 | font-family: 'Lato';
33 | font-style: normal;
34 | font-weight: 700;
35 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2');
36 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
37 | }
38 |
--------------------------------------------------------------------------------
/docs/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/framework.png
--------------------------------------------------------------------------------
/docs/assets/freezed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/freezed.png
--------------------------------------------------------------------------------
/docs/assets/genforce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/genforce.png
--------------------------------------------------------------------------------
/docs/assets/giraffe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/giraffe.png
--------------------------------------------------------------------------------
/docs/assets/graf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/graf.png
--------------------------------------------------------------------------------
/docs/assets/hologan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/hologan.png
--------------------------------------------------------------------------------
/docs/assets/pigan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/pigan.png
--------------------------------------------------------------------------------
/docs/assets/style.css:
--------------------------------------------------------------------------------
1 | /* Body */
2 | body {
3 | background: #e3e5e8;
4 | color: #ffffff;
5 | font-family: 'Lato', Verdana, Helvetica, sans-serif;
6 | font-weight: 300;
7 | font-size: 14pt;
8 | }
9 |
10 | /* Hyperlinks */
11 | a {text-decoration: none;}
12 | a:link {color: #1772d0;}
13 | a:visited {color: #1772d0;}
14 | a:active {color: red;}
15 | a:hover {color: #f09228;}
16 |
17 | /* Pre-formatted Text */
18 | pre {
19 | margin: 5pt 0;
20 | border: 0;
21 | font-size: 12pt;
22 | background: #fcfcfc;
23 | }
24 |
25 | /* Project Page Style */
26 | /* Section */
27 | .section {
28 | width: 768pt;
29 | min-height: 100pt;
30 | margin: 15pt auto;
31 | padding: 20pt 30pt;
32 | border: 1pt hidden #000;
33 | text-align: justify;
34 | color: #000000;
35 | background: #ffffff;
36 | }
37 |
38 | /* Header (Title and Logo) */
39 | .section .header {
40 | min-height: 80pt;
41 | margin-top: 30pt;
42 | }
43 | .section .header .logo {
44 | width: 80pt;
45 | margin-left: 10pt;
46 | float: left;
47 | }
48 | .section .header .logo img {
49 | width: 80pt;
50 | object-fit: cover;
51 | }
52 | .section .header .title {
53 | margin: 0 120pt;
54 | text-align: center;
55 | font-size: 22pt;
56 | }
57 |
58 | /* Author */
59 | .section .author {
60 | margin: 5pt 0;
61 | text-align: center;
62 | font-size: 16pt;
63 | }
64 |
65 | /* Institution */
66 | .section .institution {
67 | margin: 5pt 0;
68 | text-align: center;
69 | font-size: 16pt;
70 | }
71 |
72 | /* Hyperlink (such as Paper and Code) */
73 | .section .link {
74 | margin: 5pt 0;
75 | text-align: center;
76 | font-size: 16pt;
77 | }
78 |
79 | /* Teaser */
80 | .section .teaser {
81 | margin: 20pt 0;
82 | text-align: center;
83 | }
84 | .section .teaser img {
85 | width: 95%;
86 | }
87 |
88 | /* Section Title */
89 | .section .title {
90 | text-align: center;
91 | font-size: 22pt;
92 | margin: 5pt 0 15pt 0; /* top right bottom left */
93 | }
94 |
95 | /* Section Body */
96 | .section .body {
97 | margin-bottom: 15pt;
98 | text-align: justify;
99 | font-size: 14pt;
100 | }
101 |
102 | /* BibTeX */
103 | .section .bibtex {
104 | margin: 5pt 0;
105 | text-align: left;
106 | font-size: 22pt;
107 | }
108 |
109 | /* Related Work */
110 | .section .ref {
111 | margin: 20pt 0 10pt 0; /* top right bottom left */
112 | text-align: left;
113 | font-size: 18pt;
114 | font-weight: bold;
115 | }
116 |
117 | /* Citation */
118 | .section .citation {
119 | min-height: 60pt;
120 | margin: 10pt 0;
121 | }
122 | .section .citation .image {
123 | width: 120pt;
124 | float: left;
125 | }
126 | .section .citation .image img {
127 | max-height: 60pt;
128 | width: 120pt;
129 | object-fit: cover;
130 | }
131 | .section .citation .comment{
132 | margin-left: 130pt;
133 | text-align: left;
134 | font-size: 14pt;
135 | }
136 |
--------------------------------------------------------------------------------
/docs/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/teaser.png
--------------------------------------------------------------------------------
/dump_command_args.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Dumps available arguments of all commands (configurations).
3 |
4 | This file parses the arguments of all commands provided in `configs/` and dumps
5 | the results as a JSON file. Each parsed argument includes the name, argument
6 | type, default value, and the help message (description). The dumped file looks
7 | like
8 |
9 | {
10 | "command_1": {
11 | "type": "object",
12 | "properties": {
13 | "arg_group_1": {
14 | "type": "object",
15 | "properties": {
16 | "arg_1": {
17 | "is_recommended": # true / false
18 | "type": # int / float / bool / str / json-string /
19 | # index-string
20 | "default":
21 | "description":
22 | },
23 | "arg_2": {
24 | "is_recommended":
25 | "type":
26 | "default":
27 | "description":
28 | }
29 | }
30 | },
31 | "arg_group_2": {
32 | "type": "object",
33 | "properties": {
34 | "arg_3": {
35 | "is_recommended":
36 | "type":
37 | "default":
38 | "description":
39 | },
40 | "arg_4": {
41 | "is_recommended":
42 | "type":
43 | "default":
44 | "description":
45 | }
46 | }
47 | }
48 | }
49 | },
50 | "command_2": {
51 | "type": "object",
52 |     "properties": {
53 | "arg_group_1": {
54 | "type": "object",
55 | "properties": {
56 | "arg_1": {
57 | "is_recommended":
58 | "type":
59 | "default":
60 | "description":
61 | }
62 | }
63 | }
64 | }
65 | }
66 | }
67 | """
68 |
69 | import sys
70 | import json
71 |
72 | from configs import CONFIG_POOL
73 |
74 |
75 | def parse_args_from_config(config):
76 | """Parses available arguments from a configuration class.
77 |
78 | Args:
79 | config: The configuration class to parse arguments from, which is
80 | defined in `configs/`. This class is supposed to derive from
81 | `BaseConfig` defined in `configs/base_config.py`.
82 | """
83 | recommended_opts = config.get_recommended_options()
84 | args = dict()
85 | for opt_group, opts in config.get_options().items():
86 | args[opt_group] = dict(
87 | type='object',
88 | properties=dict()
89 | )
90 | for opt in opts:
91 | arg = config.inspect_option(opt)
92 | args[opt_group]['properties'][arg.name] = dict(
93 | is_recommended=arg.name in recommended_opts,
94 | type=arg.type,
95 | default=arg.default,
96 | description=arg.help
97 | )
98 | return args
99 |
100 |
101 | def dump(configs, save_path):
102 | """Dumps available arguments from given configurations to target file.
103 |
104 | Args:
105 | configs: A list of configurations, each of which should be a
106 | class derived from `BaseConfig` defined in `configs/base_config.py`.
107 | save_path: The path to save the dumped results.
108 | """
109 | args = dict()
110 | for config in configs:
111 | args[config.name] = dict(type='object',
112 | properties=parse_args_from_config(config))
113 | with open(save_path, 'w') as f:
114 | json.dump(args, f, indent=4)
115 |
116 |
117 | if __name__ == '__main__':
118 | if len(sys.argv) != 2:
119 | sys.exit(f'Usage: python {sys.argv[0]} SAVE_PATH')
120 | dump(CONFIG_POOL, sys.argv[1])
121 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all metrics."""
3 |
4 | from .gan_snapshot import GANSnapshot
5 | from .fid import FIDMetric as FID
6 | from .fid import FID50K
7 | from .fid import FID50KFull
8 | from .inception_score import ISMetric as IS
9 | from .inception_score import IS50K
10 | from .intra_class_fid import ICFIDMetric as ICFID
11 | from .intra_class_fid import ICFID50K
12 | from .intra_class_fid import ICFID50KFull
13 | from .kid import KIDMetric as KID
14 | from .kid import KID50K
15 | from .kid import KID50KFull
16 | from .gan_pr import GANPRMetric as GANPR
17 | from .gan_pr import GANPR50K
18 | from .gan_pr import GANPR50KFull
19 | from .equivariance import EquivarianceMetric
20 | from .equivariance import EQTMetric
21 | from .equivariance import EQT50K
22 | from .equivariance import EQTFracMetric
23 | from .equivariance import EQTFrac50K
24 | from .equivariance import EQRMetric
25 | from .equivariance import EQR50K
26 |
27 | __all__ = ['build_metric']
28 |
29 | _METRICS = {
30 | 'GANSnapshot': GANSnapshot,
31 | 'FID': FID,
32 | 'FID50K': FID50K,
33 | 'FID50KFull': FID50KFull,
34 | 'IS': IS,
35 | 'IS50K': IS50K,
36 | 'ICFID': ICFID,
37 | 'ICFID50K': ICFID50K,
38 | 'ICFID50KFull': ICFID50KFull,
39 | 'KID': KID,
40 | 'KID50K': KID50K,
41 | 'KID50KFull': KID50KFull,
42 | 'GANPR': GANPR,
43 | 'GANPR50K': GANPR50K,
44 | 'GANPR50KFull': GANPR50KFull,
45 | 'Equivariance': EquivarianceMetric,
46 | 'EQT': EQTMetric,
47 | 'EQT50K': EQT50K,
48 | 'EQTFrac': EQTFracMetric,
49 | 'EQTFrac50K': EQTFrac50K,
50 | 'EQR': EQRMetric,
51 | 'EQR50K': EQR50K
52 | }
53 |
54 |
55 | def build_metric(metric_type, **kwargs):
56 | """Builds a metric evaluator based on its class type.
57 |
58 | Args:
59 | metric_type: Type of the metric, which is case sensitive.
60 | **kwargs: Configurations used to build the metric.
61 |
62 | Raises:
63 | ValueError: If the `metric_type` is not supported.
64 | """
65 | if metric_type not in _METRICS:
66 | raise ValueError(f'Invalid metric type: `{metric_type}`!\n'
67 | f'Types allowed: {list(_METRICS)}.')
68 | return _METRICS[metric_type](**kwargs)
69 |
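# Editor's sketch: each metric class defines its own constructor arguments, so
# `metric_kwargs` below is an illustrative placeholder, not a real signature.
#
#     metric = build_metric('FID50K', **metric_kwargs)
#
# An unknown type fails fast with the list of allowed types:
#
# >>> build_metric('NotAMetric')  # doctest: +IGNORE_EXCEPTION_DETAIL
# Traceback (most recent call last):
#     ...
# ValueError: Invalid metric type: `NotAMetric`!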
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all models."""
3 |
4 | from .pggan_generator import PGGANGenerator
5 | from .pggan_discriminator import PGGANDiscriminator
6 | from .stylegan_generator import StyleGANGenerator
7 | from .stylegan_discriminator import StyleGANDiscriminator
8 | from .stylegan2_generator import StyleGAN2Generator
9 | from .stylegan2_discriminator import StyleGAN2Discriminator
10 | from .stylegan3_generator import StyleGAN3Generator
11 | from .volumegan_generator import VolumeGANGenerator
12 | from .volumegan_discriminator import VolumeGANDiscriminator
13 | from .ghfeat_encoder import GHFeatEncoder
14 | from .perceptual_model import PerceptualModel
15 | from .inception_model import InceptionModel
16 |
17 | __all__ = ['build_model']
18 |
19 | _MODELS = {
20 | 'PGGANGenerator': PGGANGenerator,
21 | 'PGGANDiscriminator': PGGANDiscriminator,
22 | 'StyleGANGenerator': StyleGANGenerator,
23 | 'StyleGANDiscriminator': StyleGANDiscriminator,
24 | 'StyleGAN2Generator': StyleGAN2Generator,
25 | 'StyleGAN2Discriminator': StyleGAN2Discriminator,
26 | 'StyleGAN3Generator': StyleGAN3Generator,
27 | 'VolumeGANGenerator': VolumeGANGenerator,
28 | 'VolumeGANDiscriminator': VolumeGANDiscriminator,
29 | 'GHFeatEncoder': GHFeatEncoder,
30 | 'PerceptualModel': PerceptualModel.build_model,
31 | 'InceptionModel': InceptionModel.build_model
32 | }
33 |
34 |
35 | def build_model(model_type, **kwargs):
36 | """Builds a model based on its class type.
37 |
38 | Args:
39 | model_type: Class type to which the model belongs, which is case
40 | sensitive.
41 | **kwargs: Additional arguments to build the model.
42 |
43 | Raises:
44 | ValueError: If the `model_type` is not supported.
45 | """
46 | if model_type not in _MODELS:
47 | raise ValueError(f'Invalid model type: `{model_type}`!\n'
48 | f'Types allowed: {list(_MODELS)}.')
49 | return _MODELS[model_type](**kwargs)
50 |
--------------------------------------------------------------------------------
/models/rendering/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all functions for rendering."""
3 | from .points_sampling import PointsSampling
4 | from .hierarchicle_sampling import HierarchicalSampling
5 | from .renderer import Renderer
6 | from .utils import interpolate_feature
7 |
8 | __all__ = ['PointsSampling', 'HierarchicalSampling', 'Renderer', 'interpolate_feature']
9 |
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/models/utils/__init__.py
--------------------------------------------------------------------------------
/models/utils/ops.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains operators for neural networks."""
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 | __all__ = ['all_gather']
8 |
9 |
10 | def all_gather(tensor):
11 |     """Gathers a tensor from all devices and returns the average."""
12 | if not dist.is_initialized():
13 | return tensor
14 |
15 | world_size = dist.get_world_size()
16 | tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
17 | dist.all_gather(tensor_list, tensor, async_op=False)
18 | return torch.stack(tensor_list, dim=0).mean(dim=0)
19 |
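# Editor's sketch: in a single (non-distributed) process the tensor is
# returned unchanged; under an initialized `torch.distributed` group, every
# rank receives the mean of the per-rank tensors instead.
#
# >>> t = torch.tensor([1.0, 2.0])
# >>> all_gather(t)
# tensor([1., 2.])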
--------------------------------------------------------------------------------
/requirements/convert.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1
2 | tensorflow-gpu==1.15
3 | ninja==1.10.2
4 | scikit-video==1.1.11
5 | pillow==9.0.0
6 | opencv-python-headless==4.5.5.62
7 | requests
8 | bs4
9 | tqdm
10 | rich
11 | easydict
12 |
--------------------------------------------------------------------------------
/requirements/develop.txt:
--------------------------------------------------------------------------------
1 | bpytop # Monitor system resources.
2 | gpustat # Monitor GPU usage.
3 | pylint # Check coding style.
4 |
--------------------------------------------------------------------------------
/requirements/minimal.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
2 | torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
3 | tensorboard==2.7.0
4 | torch-tb-profiler==0.3.1
5 | ninja==1.10.2
6 | numpy==1.21.5
7 | scipy==1.7.3
8 | scikit-learn==1.0.2
9 | scikit-video==1.1.11
10 | pillow==9.0.0
11 | opencv-python-headless==4.5.5.62
12 | requests
13 | bs4
14 | tqdm
15 | rich
16 | click
17 | cloup
18 | psutil
19 | easydict
20 | lmdb
21 | matplotlib
22 | mrcfile
23 | pymcubes
24 | trimesh
25 | einops
26 |
--------------------------------------------------------------------------------
/runners/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all runners."""
3 |
4 | from .stylegan_runner import StyleGANRunner
5 | from .stylegan2_runner import StyleGAN2Runner
6 | from .stylegan3_runner import StyleGAN3Runner
7 | from .volumegan_runner import VolumeGANRunner
8 | __all__ = ['build_runner']
9 |
10 | _RUNNERS = {
11 | 'StyleGANRunner': StyleGANRunner,
12 | 'StyleGAN2Runner': StyleGAN2Runner,
13 | 'StyleGAN3Runner': StyleGAN3Runner,
14 | 'VolumeGANRunner': VolumeGANRunner,
15 | }
16 |
17 |
18 | def build_runner(config):
19 | """Builds a runner with given configuration.
20 |
21 | Args:
22 | config: Configurations used to build the runner.
23 |
24 | Raises:
25 | ValueError: If the `config.runner_type` is not supported.
26 | """
27 | if not isinstance(config, dict) or 'runner_type' not in config:
28 | raise ValueError('`runner_type` is missing from configuration!')
29 |
30 | runner_type = config['runner_type']
31 | if runner_type not in _RUNNERS:
32 | raise ValueError(f'Invalid runner type: `{runner_type}`!\n'
33 | f'Types allowed: {list(_RUNNERS)}.')
34 | return _RUNNERS[runner_type](config)
35 |
--------------------------------------------------------------------------------
/runners/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all augmentation pipelines."""
3 |
4 | from .no_aug import NoAug
5 | from .ada_aug import AdaAug
6 |
7 | __all__ = ['build_aug']
8 |
9 | _AUGMENTATIONS = {
10 | 'NoAug': NoAug,
11 | 'AdaAug': AdaAug
12 | }
13 |
14 |
15 | def build_aug(aug_type, **kwargs):
16 | """Builds a differentiable augmentation pipeline based on its class type.
17 |
18 | Args:
19 | aug_type: Class type to which the augmentation belongs, which is case
20 | sensitive.
21 | **kwargs: Additional arguments to build the aug.
22 |
23 | Raises:
24 | ValueError: If the `aug_type` is not supported.
25 | """
26 | if aug_type not in _AUGMENTATIONS:
27 | raise ValueError(f'Invalid augmentation type: `{aug_type}`!\n'
28 | f'Types allowed: {list(_AUGMENTATIONS)}.')
29 | return _AUGMENTATIONS[aug_type](**kwargs)
30 |
--------------------------------------------------------------------------------
/runners/augmentations/no_aug.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Defines the dummy augmentation pipeline that executes no augmentation."""
3 |
4 | import torch.nn as nn
5 |
6 | __all__ = ['NoAug']
7 |
8 |
9 | NoAug = nn.Identity
10 |
--------------------------------------------------------------------------------
/runners/controllers/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all controllers."""
3 |
4 | from .ada_aug_controller import AdaAugController
5 | from .batch_visualizer import BatchVisualizer
6 | from .cache_cleaner import CacheCleaner
7 | from .checkpointer import Checkpointer
8 | from .dataset_visualizer import DatasetVisualizer
9 | from .evaluator import Evaluator
10 | from .lr_scheduler import LRScheduler
11 | from .progress_scheduler import ProgressScheduler
12 | from .running_logger import RunningLogger
13 | from .timer import Timer
14 |
15 | __all__ = ['build_controller']
16 |
17 | _CONTROLLERS = {
18 | 'AdaAugController': AdaAugController,
19 | 'BatchVisualizer': BatchVisualizer,
20 | 'CacheCleaner': CacheCleaner,
21 | 'Checkpointer': Checkpointer,
22 | 'DatasetVisualizer': DatasetVisualizer,
23 | 'Evaluator': Evaluator,
24 | 'LRScheduler': LRScheduler,
25 | 'ProgressScheduler': ProgressScheduler,
26 | 'RunningLogger': RunningLogger,
27 | 'Timer': Timer
28 | }
29 |
30 |
31 | def build_controller(controller_type, config=None):
32 | """Builds a controller based on its class type.
33 |
34 | Args:
35 | controller_type: Class type to which the controller belongs, which is
36 | case sensitive.
37 | config: Configuration of the controller. (default: None)
38 |
39 | Raises:
40 | ValueError: If the `controller_type` is not supported.
41 | """
42 | if controller_type not in _CONTROLLERS:
43 | raise ValueError(f'Invalid controller type: `{controller_type}`!\n'
44 | f'Types allowed: {list(_CONTROLLERS)}.')
45 | return _CONTROLLERS[controller_type](config)
46 |
--------------------------------------------------------------------------------
/runners/controllers/ada_aug_controller.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the running controller to control augmentation strength."""
3 |
4 | import numpy as np
5 |
6 | import torch
7 |
8 | from ..utils.running_stats import SingleStats
9 | from ..augmentations.ada_aug import AdaAug
10 | from .base_controller import BaseController
11 |
12 | __all__ = ['AdaAugController']
13 |
14 |
15 | class AdaAugController(BaseController):
16 | """Defines the running controller to adjust the strength of augmentations.
17 |
18 |     This controller works together with the augmentation pipeline introduced by
19 |     StyleGAN2-ADA (https://arxiv.org/pdf/2006.06676.pdf). Concretely, AdaAug,
20 | which is defined in `runners/augmentations/ada_aug.py`, augments the data
21 | based on an adjustable probability. This controller controls how this
22 | probability is adjusted.
23 |
24 | NOTE: The controller is set to `FIRST` priority.
25 |
26 |     Basically, the aug_config is expected to contain the following settings:
27 |
28 |     (1) init_p: The initial probability of augmentations. (default: 0.0)
29 |     (2) target_p: The target (final) probability. (default: 0.6)
30 | (3) every_n_iters: How often to adjust the probability. (default: 4)
31 | (4) speed_img: Speed to adjust the probability, which is measured in number
32 | of images it takes for the probability to increase/decrease by one unit.
33 | (default: 500_000)
34 | (5) strategy: The strategy to adjust the probability. Support `fixed`,
35 | `linear`, and `adaptive`. (default: `adaptive`)
36 | """
37 |
38 | def __init__(self, config):
39 | assert isinstance(config, dict)
40 | config.setdefault('priority', 'First')
41 | config.setdefault('every_n_iters', 4)
42 | config.setdefault('first_iter', False)
43 | super().__init__(config)
44 |
45 | self._init_p = config.get('init_p', 0.0)
46 | self._target_p = config.get('target_p', 0.6)
47 | self._speed_img = config.get('speed_img', 500_000)
48 | self._strategy = config.get('strategy', 'adaptive').lower()
49 | assert self._strategy in ['fixed', 'linear', 'adaptive']
50 |
51 | def setup(self, runner):
52 | """Sets the initial augmentation strength before training."""
53 | if not isinstance(runner.augment, AdaAug):
54 |             raise ValueError(f'`{self.name}` only works together with the '
55 |                              'adaptive augmentation pipeline `AdaAug`!')
56 |
57 | if self._strategy == 'fixed':
58 | aug_prob = self._target_p
59 | else:
60 | aug_prob = self._init_p
61 | runner.augment.p = torch.as_tensor(aug_prob).relu()
62 | runner.augment.prob_tracker = SingleStats('Aug Prob Tracker',
63 | log_format=None,
64 | log_strategy='AVERAGE',
65 | requires_sync=True)
66 | runner.running_stats.add('Misc/Aug Prob',
67 | log_name='aug_prob',
68 | log_format='.3f',
69 | log_strategy='CURRENT',
70 | requires_sync=False,
71 | keep_previous=True)
72 |
73 | runner.logger.info('Adaptive augmentation settings:', indent_level=2)
74 | runner.logger.info(f'Strategy: {self._strategy}', indent_level=3)
75 | runner.logger.info(f'Initial probability: {self._init_p}',
76 | indent_level=3)
77 | runner.logger.info(f'Target probability : {self._target_p}',
78 | indent_level=3)
79 |         runner.logger.info(f'Adjustment speed: {self._speed_img} images',
80 | indent_level=3)
81 | super().setup(runner)
82 |
83 | def execute_after_iteration(self, runner):
84 | if self._strategy == 'fixed':
85 | aug_prob = self._target_p
86 | elif self._strategy == 'linear':
87 | slope = runner.iter / runner.total_iters
88 | aug_prob = self._init_p + (self._target_p - self._init_p) * slope
89 | else:
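            # Editor's note (worked example, not in the original file): the
            # adaptive rule below nudges the probability toward the point
            # where the tracked overfitting criterion equals `target_p`. With
            # an effective minibatch of 128 (batch size 32 on 4 GPUs) and the
            # defaults `every_n_iters=4` and `speed_img=500_000`, each
            # adjustment changes the probability by 128 * 4 / 500_000 ≈ 0.001.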
90 | minibatch = runner.batch_size * runner.world_size
91 | slope = (minibatch * self.every_n_iters) / self._speed_img
92 | criterion = runner.augment.prob_tracker.summarize()
93 | adjust = np.sign(criterion - self._target_p) * slope
94 | aug_prob = runner.augment.p + adjust
95 |
96 | runner.augment.p = torch.as_tensor(aug_prob).relu()
97 | runner.running_stats.update({'Misc/Aug Prob': runner.augment.p})
98 |
--------------------------------------------------------------------------------
/runners/controllers/cache_cleaner.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the running controller to clean cache."""
3 |
4 | import torch
5 |
6 | from .base_controller import BaseController
7 |
8 | __all__ = ['CacheCleaner']
9 |
10 |
11 | class CacheCleaner(BaseController):
12 | """Defines the running controller to clean cache.
13 |
14 | This controller is used to empty the GPU cache after each iteration.
15 |
16 | NOTE: The controller is set to `LAST` priority by default.
17 | """
18 |
19 | def __init__(self, config=None):
20 | config = config or dict()
21 | config.setdefault('priority', 'LAST')
22 | config.setdefault('every_n_iters', 1)
23 | super().__init__(config)
24 |
25 | def setup(self, runner):
26 | torch.cuda.empty_cache()
27 | super().setup(runner)
28 |
29 | def close(self, runner):
30 | torch.cuda.empty_cache()
31 |
32 | def execute_after_iteration(self, runner):
33 | torch.cuda.empty_cache()
34 |
--------------------------------------------------------------------------------
/runners/controllers/timer.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the running controller to record time."""
3 |
4 | import time
5 |
6 | from .base_controller import BaseController
7 |
8 | __all__ = ['Timer']
9 |
10 |
11 | class Timer(BaseController):
12 | """Defines the running controller to record running time.
13 |
14 | This controller will be executed every iteration (both before and after) to
15 | summarize the data preparation time as well as the model running time.
16 | Besides, this controller will also mark the start and end time of the
17 | running process.
18 |
19 | NOTE: This controller is set to `LOW` priority by default and will only be
20 | executed on the chief worker.
21 | """
22 |
23 | def __init__(self, config=None):
24 | config = config or dict()
25 | config.setdefault('priority', 'LOW')
26 | config.setdefault('every_n_iters', 1)
27 | config.setdefault('chief_only', True)
28 | super().__init__(config)
29 |
30 | self.time = time.time()
31 |
32 | def setup(self, runner):
33 | runner.running_stats.add(
34 | 'data time', log_format='time', requires_sync=False)
35 | runner.running_stats.add(
36 | 'iter time', log_format='time', requires_sync=False)
37 | runner.running_stats.add('run time',
38 | log_format='time',
39 | log_strategy='CURRENT',
40 | requires_sync=False)
41 | self.time = time.time()
42 | runner.start_time = self.time
43 |
44 | def close(self, runner):
45 | runner.end_time = time.time()
46 |
47 | def execute_before_iteration(self, runner):
48 | start_time = time.time()
49 | runner.running_stats.update({'data time': start_time - self.time})
50 |
51 | def execute_after_iteration(self, runner):
52 | end_time = time.time()
53 | runner.running_stats.update({'iter time': end_time - self.time})
54 | runner.running_stats.update({'run time': end_time - runner.start_time})
55 | self.time = end_time
56 |
--------------------------------------------------------------------------------
/runners/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all loss functions."""
3 |
4 | from .stylegan_loss import StyleGANLoss
5 | from .stylegan2_loss import StyleGAN2Loss
6 | from .stylegan3_loss import StyleGAN3Loss
7 | from .volumegan_loss import VolumeGANLoss
8 | __all__ = ['build_loss']
9 |
10 | _LOSSES = {
11 | 'StyleGANLoss': StyleGANLoss,
12 | 'StyleGAN2Loss': StyleGAN2Loss,
13 | 'StyleGAN3Loss': StyleGAN3Loss,
14 | 'VolumeGANLoss': VolumeGANLoss
15 | }
16 |
17 |
18 | def build_loss(runner, loss_type, **kwargs):
19 | """Builds a loss based on its class type.
20 |
21 | Args:
22 | runner: The runner on which the loss is built.
23 | loss_type: Class type to which the loss belongs, which is case
24 | sensitive.
25 | **kwargs: Additional arguments to build the loss.
26 |
27 | Raises:
28 | ValueError: If the `loss_type` is not supported.
29 | """
30 | if loss_type not in _LOSSES:
31 | raise ValueError(f'Invalid loss type: `{loss_type}`!\n'
32 | f'Types allowed: {list(_LOSSES)}.')
33 | return _LOSSES[loss_type](runner, **kwargs)
34 |
--------------------------------------------------------------------------------
/runners/losses/base_loss.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the base class to implement loss."""
3 |
4 | import torch
5 |
6 | __all__ = ['BaseLoss']
7 |
8 |
9 | class BaseLoss(object):
10 | """Base loss class.
11 |
12 |     The derived class can easily serialize its members to a `dict`, and load
13 |     from such a `dict` to resume the saved loss.
14 |
15 | NOTE: By default, the derived class will save ALL members. To ensure members
16 |     are saved and loaded as expected, you may need to override the
17 | `state_dict()` or `load_state_dict()` method.
18 | """
19 |
20 | @property
21 | def name(self):
22 | """Returns the class name of the loss."""
23 | return self.__class__.__name__
24 |
25 | def state_dict(self):
26 | """Returns a serialized `dict` that records all members.
27 |
28 | The returned `dict` maps attribute names to their values.
29 |
30 | NOTE: Override this method if such default behavior is unexpected.
31 | """
32 | return vars(self)
33 |
34 | def load_state_dict(self, state_dict):
35 | """Loads parameters from the `state_dict`.
36 |
37 | By default, this method directly sets all attributes from the given
38 | `state_dict`.
39 |
40 | NOTE: Override this method if such default behavior is unexpected.
41 | """
42 | for key, val in state_dict.items():
43 | if not hasattr(self, key): # current loss does not init with `key`
44 | continue
45 | origin_attr = getattr(self, key)
46 | if isinstance(origin_attr, torch.nn.Module):
47 | origin_attr.load_state_dict(val.state_dict())
48 | continue
49 | if isinstance(origin_attr, torch.Tensor):
50 | val = val.to(device=origin_attr.device)
51 | setattr(self, key, val)
52 |
--------------------------------------------------------------------------------
/runners/stylegan2_runner.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the runner for StyleGAN2."""
3 |
4 | from copy import deepcopy
5 |
6 | from .base_runner import BaseRunner
7 |
8 | __all__ = ['StyleGAN2Runner']
9 |
10 |
11 | class StyleGAN2Runner(BaseRunner):
12 | """Defines the runner for StyleGAN2."""
13 |
14 | def build_models(self):
15 | super().build_models()
16 |
17 | self.g_ema_img = self.config.models['generator'].get(
18 | 'g_ema_img', 10_000)
19 | self.g_ema_rampup = self.config.models['generator'].get(
20 | 'g_ema_rampup', 0)
21 | if 'generator_smooth' not in self.models:
22 | self.models['generator_smooth'] = deepcopy(self.models['generator'])
23 | self.model_kwargs_init['generator_smooth'] = deepcopy(
24 | self.model_kwargs_init['generator'])
25 | if 'generator_smooth' not in self.model_kwargs_val:
26 | self.model_kwargs_val['generator_smooth'] = deepcopy(
27 | self.model_kwargs_val['generator'])
28 |
29 | def build_loss(self):
30 | super().build_loss()
31 | self.running_stats.add('Misc/Gs Beta',
32 | log_name='Gs_beta',
33 | log_format='.4f',
34 | log_strategy='CURRENT')
35 |
36 | def train_step(self, data):
37 | # Update generator.
38 | self.models['discriminator'].requires_grad_(False)
39 | self.models['generator'].requires_grad_(True)
40 |
41 | # Update with adversarial loss.
42 | g_loss = self.loss.g_loss(self, data, sync=True)
43 | self.zero_grad_optimizer('generator')
44 | g_loss.backward()
45 | self.step_optimizer('generator')
46 |
47 | # Update with perceptual path length regularization if needed.
48 | pl_penalty = self.loss.g_reg(self, data, sync=True)
49 | if pl_penalty is not None:
50 | self.zero_grad_optimizer('generator')
51 | pl_penalty.backward()
52 | self.step_optimizer('generator')
53 |
54 | # Update discriminator.
55 | self.models['discriminator'].requires_grad_(True)
56 | self.models['generator'].requires_grad_(False)
57 |
58 | # Update with adversarial loss.
59 | self.zero_grad_optimizer('discriminator')
60 |         # Update with fake images (gradients sync together with the real loss).
61 | d_fake_loss = self.loss.d_fake_loss(self, data, sync=False)
62 | d_fake_loss.backward()
63 | # Update with real images.
64 | d_real_loss = self.loss.d_real_loss(self, data, sync=True)
65 | d_real_loss.backward()
66 | self.step_optimizer('discriminator')
67 |
68 | # Update with gradient penalty.
69 | r1_penalty = self.loss.d_reg(self, data, sync=True)
70 | if r1_penalty is not None:
71 | self.zero_grad_optimizer('discriminator')
72 | r1_penalty.backward()
73 | self.step_optimizer('discriminator')
74 |
75 |         # Life-long (EMA) update of the generator.
76 | if self.g_ema_rampup is not None and self.g_ema_rampup > 0:
77 | g_ema_img = min(self.g_ema_img, self.seen_img * self.g_ema_rampup)
78 | else:
79 | g_ema_img = self.g_ema_img
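        # Editor's note: `beta` below is an EMA decay with a half-life of
        # `g_ema_img` images, i.e., the smoothed generator forgets half of its
        # history every `g_ema_img` images. E.g., minibatch 32 with
        # g_ema_img 10_000 gives beta = 0.5 ** (32 / 10_000) ≈ 0.9978.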
80 | beta = 0.5 ** (self.minibatch / max(g_ema_img, 1e-8))
81 | self.running_stats.update({'Misc/Gs Beta': beta})
82 | self.smooth_model(src=self.models['generator'],
83 | avg=self.models['generator_smooth'],
84 | beta=beta)
85 |
--------------------------------------------------------------------------------
/runners/stylegan3_runner.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the runner for StyleGAN3."""
3 |
4 | from copy import deepcopy
5 |
6 | from .base_runner import BaseRunner
7 |
8 | __all__ = ['StyleGAN3Runner']
9 |
10 |
11 | class StyleGAN3Runner(BaseRunner):
12 | """Defines the runner for StyleGAN3."""
13 |
14 | def build_models(self):
15 | super().build_models()
16 |
17 | g_ema_img = self.config.models['generator'].get('g_ema_img', 10_000)
18 | self.g_ema_img = g_ema_img * self.minibatch / 32
19 | self.g_ema_rampup = self.config.models['generator'].get(
20 | 'g_ema_rampup', 0)
21 | if 'generator_smooth' not in self.models:
22 | self.models['generator_smooth'] = deepcopy(self.models['generator'])
23 | self.model_kwargs_init['generator_smooth'] = deepcopy(
24 | self.model_kwargs_init['generator'])
25 | if 'generator_smooth' not in self.model_kwargs_val:
26 | self.model_kwargs_val['generator_smooth'] = deepcopy(
27 | self.model_kwargs_val['generator'])
28 |
29 | def build_loss(self):
30 | super().build_loss()
31 | self.running_stats.add('Misc/Gs Beta',
32 | log_name='Gs_beta',
33 | log_format='.4f',
34 | log_strategy='CURRENT')
35 |
36 | def train_step(self, data):
37 | # Update generator.
38 | self.models['discriminator'].requires_grad_(False)
39 | self.models['generator'].requires_grad_(True)
40 |
41 | # Update with adversarial loss.
42 | g_loss = self.loss.g_loss(self, data, sync=True)
43 | self.zero_grad_optimizer('generator')
44 | g_loss.backward()
45 | self.step_optimizer('generator')
46 |
47 | # Update with perceptual path length regularization if needed.
48 | pl_penalty = self.loss.g_reg(self, data, sync=True)
49 | if pl_penalty is not None:
50 | self.zero_grad_optimizer('generator')
51 | pl_penalty.backward()
52 | self.step_optimizer('generator')
53 |
54 | # Update discriminator.
55 | self.models['discriminator'].requires_grad_(True)
56 | self.models['generator'].requires_grad_(False)
57 |
58 | # Update with adversarial loss.
59 | self.zero_grad_optimizer('discriminator')
60 |         # Update with fake images (gradients sync together with the real loss).
61 | d_fake_loss = self.loss.d_fake_loss(self, data, sync=False)
62 | d_fake_loss.backward()
63 | # Update with real images.
64 | d_real_loss = self.loss.d_real_loss(self, data, sync=True)
65 | d_real_loss.backward()
66 | self.step_optimizer('discriminator')
67 |
68 | # Update with gradient penalty.
69 | r1_penalty = self.loss.d_reg(self, data, sync=True)
70 | if r1_penalty is not None:
71 | self.zero_grad_optimizer('discriminator')
72 | r1_penalty.backward()
73 | self.step_optimizer('discriminator')
74 |
75 |         # Life-long (EMA) update of the generator.
76 | if self.g_ema_rampup is not None and self.g_ema_rampup > 0:
77 | g_ema_img = min(self.g_ema_img, self.seen_img * self.g_ema_rampup)
78 | else:
79 | g_ema_img = self.g_ema_img
80 | beta = 0.5 ** (self.minibatch / max(g_ema_img, 1e-8))
81 | self.running_stats.update({'Misc/Gs Beta': beta})
82 | self.smooth_model(src=self.models['generator'],
83 | avg=self.models['generator_smooth'],
84 | beta=beta)
85 |
--------------------------------------------------------------------------------
/runners/stylegan_runner.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the runner for StyleGAN."""
3 |
4 | from copy import deepcopy
5 |
6 | from .base_runner import BaseRunner
7 |
8 | __all__ = ['StyleGANRunner']
9 |
10 |
11 | class StyleGANRunner(BaseRunner):
12 | """Defines the runner for StyleGAN."""
13 |
14 | def __init__(self, config):
15 | super().__init__(config)
16 | self.lod = getattr(self, 'lod', 0.0)
17 | self.D_repeats = self.config.get('D_repeats', 1)
18 |
19 | def build_models(self):
20 | super().build_models()
21 | self.g_ema_img = self.config.models['generator'].get(
22 | 'g_ema_img', 10_000)
23 | if 'generator_smooth' not in self.models:
24 | self.models['generator_smooth'] = deepcopy(self.models['generator'])
25 | self.model_kwargs_init['generator_smooth'] = deepcopy(
26 | self.model_kwargs_init['generator'])
27 | if 'generator_smooth' not in self.model_kwargs_val:
28 | self.model_kwargs_val['generator_smooth'] = deepcopy(
29 | self.model_kwargs_val['generator'])
30 |
31 | def build_loss(self):
32 | super().build_loss()
33 | self.running_stats.add('Misc/Gs Beta',
34 | log_name='Gs_beta',
35 | log_format='.4f',
36 | log_strategy='CURRENT')
37 |
38 | def train_step(self, data):
39 | # Set level-of-details.
40 | G = self.models['generator']
41 | D = self.models['discriminator']
42 | Gs = self.models['generator_smooth']
43 | G.synthesis.lod.data.fill_(self.lod)
44 | D.lod.data.fill_(self.lod)
45 | Gs.synthesis.lod.data.fill_(self.lod)
46 |
47 | # Update discriminator.
48 | self.models['discriminator'].requires_grad_(True)
49 | self.models['generator'].requires_grad_(False)
50 | d_loss = self.loss.d_loss(self, data, sync=True)
51 | self.zero_grad_optimizer('discriminator')
52 | d_loss.backward()
53 | self.step_optimizer('discriminator')
54 |
55 |         # Life-long (EMA) update of the generator.
56 | beta = 0.5 ** (self.minibatch / self.g_ema_img)
57 | self.running_stats.update({'Misc/Gs Beta': beta})
58 | self.smooth_model(src=self.models['generator'],
59 | avg=self.models['generator_smooth'],
60 | beta=beta)
61 |
62 | # Update generator.
63 | if self.iter % self.D_repeats == 0:
64 | self.models['discriminator'].requires_grad_(False)
65 | self.models['generator'].requires_grad_(True)
66 | g_loss = self.loss.g_loss(self, data, sync=True)
67 | self.zero_grad_optimizer('generator')
68 | g_loss.backward()
69 | self.step_optimizer('generator')
70 |
71 | # Update automatic mixed-precision scaler.
72 | self.amp_scaler.update()
73 |
--------------------------------------------------------------------------------
/runners/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/runners/utils/__init__.py
--------------------------------------------------------------------------------
/runners/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the function to build optimizer for a model."""
3 |
4 | import torch
5 |
6 | __all__ = ['build_optimizer']
7 |
8 | _ALLOWED_OPT_TYPES = ['sgd', 'adam']
9 |
10 |
11 | def build_optimizer(config, model):
12 | """Builds an optimizer for the given model.
13 |
14 |     Basically, the configuration is expected to contain the following settings:
15 |
16 | (1) opt_type: The type of the optimizer. (required)
17 | (2) base_lr: The base learning rate for all parameters. (required)
18 | (3) base_wd: The base weight decay for all parameters. (default: 0.0)
19 | (4) bias_lr_multiplier: The learning rate multiplier for bias parameters.
20 | (default: 1.0)
21 | (5) bias_wd_multiplier: The weight decay multiplier for bias parameters.
22 | (default: 1.0)
23 | (6) **kwargs: Additional settings for the optimizer, such as `momentum`.
24 |
25 | Args:
26 | config: The configuration used to build the optimizer.
27 | model: The model which the optimizer serves.
28 |
29 | Returns:
30 | A `torch.optim.Optimizer`.
31 |
32 | Raises:
33 |         ValueError: If the `opt_type` is not supported.
34 | NotImplementedError: If `opt_type` is not implemented.
35 | """
36 | assert isinstance(config, dict)
37 | opt_type = config['opt_type'].lower()
38 | base_lr = config['base_lr']
39 | base_wd = config.get('base_wd', 0.0)
40 | bias_lr_multiplier = config.get('bias_lr_multiplier', 1.0)
41 | bias_wd_multiplier = config.get('bias_wd_multiplier', 1.0)
42 |
43 | if opt_type not in _ALLOWED_OPT_TYPES:
44 | raise ValueError(f'Invalid optimizer type `{opt_type}`!'
45 | f'Allowed types: {_ALLOWED_OPT_TYPES}.')
46 |
47 | model_params = []
48 | for param_name, param in model.named_parameters():
49 | param_group = {'params': [param]}
50 | if 'bias' in param_name:
51 | param_group['lr'] = base_lr * bias_lr_multiplier
52 | param_group['weight_decay'] = base_wd * bias_wd_multiplier
53 | else:
54 | param_group['lr'] = base_lr
55 | param_group['weight_decay'] = base_wd
56 | model_params.append(param_group)
57 |
58 | if opt_type == 'sgd':
59 | return torch.optim.SGD(params=model_params,
60 | lr=base_lr,
61 | momentum=config.get('momentum', 0.9),
62 | dampening=config.get('dampening', 0),
63 | weight_decay=base_wd,
64 | nesterov=config.get('nesterov', False))
65 | if opt_type == 'adam':
66 | return torch.optim.Adam(params=model_params,
67 | lr=base_lr,
68 | betas=config.get('betas', (0.9, 0.999)),
69 | eps=config.get('eps', 1e-8),
70 | weight_decay=base_wd,
71 | amsgrad=config.get('amsgrad', False))
72 | raise NotImplementedError(f'Not implemented optimizer type `{opt_type}`!')
73 |
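# Editor's sketch: build an Adam optimizer in which biases learn twice as
# fast and skip weight decay.
#
# >>> import torch.nn as nn
# >>> net = nn.Linear(4, 4)
# >>> opt = build_optimizer(dict(opt_type='adam',
# ...                            base_lr=1e-3,
# ...                            base_wd=1e-5,
# ...                            bias_lr_multiplier=2.0,
# ...                            bias_wd_multiplier=0.0), net)
# >>> sorted({group['lr'] for group in opt.param_groups})
# [0.001, 0.002]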
--------------------------------------------------------------------------------
/runners/utils/profiler.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class for profiling."""
3 |
4 | import torch
5 | from torch import distributed as dist
6 |
7 | from utils.tf_utils import import_tb_writer
8 |
9 | SummaryWriter = import_tb_writer()
10 |
11 | __all__ = ['Profiler']
12 |
13 |
14 | class Profiler(object):
15 | """Defines the profiler.
16 |
17 | Essentially, this is a wrapper of `torch.profiler.profile`.
18 | If `enable` is set to `False`, this profiler becomes a dummy context manager
19 | with a dummy `step()`.
20 |
21 | Args:
22 | enable: `bool`, whether to enable `torch.profiler`.
23 | tb_dir: `str`, path to save profiler's TensorBoard events file.
24 | logger: `utils.loggers` or `logging.Logger`, the event logging system.
25 | **schedule_kwargs: settings to the `schedule` of
26 | `torch.profiler.profile`. The profiler will skip the first
27 | `skip_first` steps, then wait for `wait` steps, then do the warmup
28 | for the next `warmup` steps, then do the active recording for the
29 | next `active` steps and then repeat the cycle starting with `wait`
30 | steps. (default: dict(wait=1, warmup=1, active=3, repeat=2))
31 | """
32 | def __init__(self,
33 | enable=False,
34 | tb_dir='.',
35 | logger=None,
36 | **schedule_kwargs):
37 |         self.enable = enable and SummaryWriter is not None
38 | rank = dist.get_rank() if dist.is_initialized() else 0
39 | if enable and SummaryWriter is not None:
40 | try: # In case the PyTorch version is outdated.
41 | if rank == 0:
42 | SummaryWriter(tb_dir)
43 | if dist.is_initialized():
44 | dist.barrier() # Only create one TensorBoard event.
45 |                 if not schedule_kwargs:
46 | schedule_kwargs = dict(wait=1, warmup=1, active=3, repeat=2)
47 | # Profile CPU and each GPU.
48 | self.profiler = torch.profiler.profile(
49 | schedule=torch.profiler.schedule(**schedule_kwargs),
50 | on_trace_ready=torch.profiler.tensorboard_trace_handler(
51 | tb_dir),
52 | record_shapes=True,
53 | with_stack=True)
54 | if logger:
55 | logger.info(f'Enable profiler with schedule: '
56 | f'{schedule_kwargs}.\n')
57 |             except AttributeError as error:
58 |                 if logger:
59 |                     logger.warning(f'Skipping profiler due to {error}!\n'
60 |                                    f'Update PyTorch to 1.8.1+ to enable it.\n')
61 | self.enable = False
62 |
63 | def __enter__(self):
64 | if self.enable:
65 | return self.profiler.__enter__()
66 | return self
67 |
68 | def __exit__(self, exc_type, exc_val, exc_tb):
69 | if self.enable:
70 | self.profiler.__exit__(exc_type, exc_val, exc_tb)
71 |
72 | def step(self):
73 | """Executes the profiler for one step."""
74 | if self.enable:
75 | self.profiler.step()
76 |
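# Editor's sketch: wrap the training loop so traces land next to the other
# TensorBoard events. `loader` and `train_step` are placeholder names.
#
#     with Profiler(enable=True, tb_dir='work_dirs/profile') as prof:
#         for batch in loader:
#             train_step(batch)
#             prof.step()  # advances the wait/warmup/active schedule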
--------------------------------------------------------------------------------
/runners/volumegan_runner.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the runner for VolumeGAN."""
3 |
4 | from copy import deepcopy
5 |
6 | from .base_runner import BaseRunner
7 |
8 | __all__ = ['VolumeGANRunner']
9 |
10 |
11 | class VolumeGANRunner(BaseRunner):
12 | """Defines the runner for VolumeGAN."""
13 | def __init__(self, config):
14 | super().__init__(config)
15 | self.lod = getattr(self, 'lod', 0.0)
16 |
17 | def build_models(self):
18 | super().build_models()
19 |
20 | self.g_ema_img = self.config.models['generator'].get(
21 | 'g_ema_img', 10_000)
22 | self.g_ema_rampup = self.config.models['generator'].get(
23 | 'g_ema_rampup', 0)
24 | if 'generator_smooth' not in self.models:
25 | self.models['generator_smooth'] = deepcopy(self.models['generator'])
26 | self.model_kwargs_init['generator_smooth'] = deepcopy(
27 | self.model_kwargs_init['generator'])
28 | if 'generator_smooth' not in self.model_kwargs_val:
29 | self.model_kwargs_val['generator_smooth'] = deepcopy(
30 | self.model_kwargs_val['generator'])
31 |
32 | def build_loss(self):
33 | super().build_loss()
34 | self.running_stats.add('Misc/Gs Beta',
35 | log_name='Gs_beta',
36 | log_format='.4f',
37 | log_strategy='CURRENT')
38 |
39 | def train_step(self, data):
40 | # Set lod for progressive training
41 | self.models['generator'].synthesis.lod.data.fill_(self.lod)
42 | self.models['discriminator'].lod.data.fill_(self.lod)
43 | self.models['generator_smooth'].synthesis.lod.data.fill_(self.lod)
44 |
45 | # Update generator.
46 | self.models['discriminator'].requires_grad_(False)
47 | self.models['generator'].requires_grad_(True)
48 |
49 | # Update with adversarial loss.
50 | g_loss = self.loss.g_loss(self, data, sync=True)
51 | self.zero_grad_optimizer('generator')
52 | g_loss.backward()
53 | self.step_optimizer('generator')
54 |
55 | # Update with perceptual path length regularization if needed.
56 | pl_penalty = self.loss.g_reg(self, data, sync=True)
57 | if pl_penalty is not None:
58 | self.zero_grad_optimizer('generator')
59 | pl_penalty.backward()
60 | self.step_optimizer('generator')
61 |
62 | # Update discriminator.
63 | self.models['discriminator'].requires_grad_(True)
64 | self.models['generator'].requires_grad_(False)
65 |
66 | # Update with adversarial loss.
67 | self.zero_grad_optimizer('discriminator')
68 |         # Update with fake images (gradients sync together with the real loss).
69 | d_fake_loss = self.loss.d_fake_loss(self, data, sync=False)
70 | d_fake_loss.backward()
71 | # Update with real images.
72 | d_real_loss = self.loss.d_real_loss(self, data, sync=True)
73 | d_real_loss.backward()
74 | self.step_optimizer('discriminator')
75 |
76 | # Update with gradient penalty.
77 | r1_penalty = self.loss.d_reg(self, data, sync=True)
78 | if r1_penalty is not None:
79 | self.zero_grad_optimizer('discriminator')
80 | r1_penalty.backward()
81 | self.step_optimizer('discriminator')
82 |
83 |         # Life-long (EMA) update of the generator.
84 | if self.g_ema_rampup is not None and self.g_ema_rampup > 0:
85 | g_ema_img = min(self.g_ema_img, self.seen_img * self.g_ema_rampup)
86 | else:
87 | g_ema_img = self.g_ema_img
88 | beta = 0.5 ** (self.minibatch / max(g_ema_img, 1e-8))
89 | self.running_stats.update({'Misc/Gs Beta': beta})
90 | self.smooth_model(src=self.models['generator'],
91 | avg=self.models['generator_smooth'],
92 | beta=beta)
93 |
--------------------------------------------------------------------------------
/scripts/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Detect `python3` command.
4 | # This workaround addresses a common issue:
5 | # `python` points to `python2`, which is deprecated.
6 | export PYTHONS
7 | export RVAL
8 |
9 | PYTHONS=$(compgen -c | grep "^python3$")
10 |
11 | # `$?` is a built-in variable in bash, which is the exit status of the most
12 | # recently-executed command; by convention, 0 means success and anything else
13 | # indicates failure.
14 | RVAL=$?
15 |
16 | if [[ $RVAL -eq 0 ]]; then # if `python3` exists
17 | PYTHON="python3"
18 | else
19 | PYTHON="python"
20 | fi
21 |
22 | # Help message.
23 | if [[ $# -lt 2 ]]; then
24 |     echo "This script helps launch a distributed training job on the local machine."
25 | echo
26 | echo "Usage: $0 GPUS COMMAND [ARGS]"
27 | echo
28 | echo "Example: $0 8 stylegan2 [--help]"
29 | echo
30 | echo "Detailed instruction on available commands:"
31 | echo "--------------------------------------------------"
32 | ${PYTHON} ./train.py --help
33 | echo
34 | exit 0
35 | fi
36 |
37 | GPUS=$1
38 | COMMAND=$2
39 |
40 | # Help message for a particular command.
41 | if [[ $# -lt 3 || ${*: -1} == "--help" ]]; then
42 | echo "Detailed instruction on the arguments for command \`"${COMMAND}"\`:"
43 | echo "--------------------------------------------------"
44 | ${PYTHON} ./train.py ${COMMAND} --help
45 | echo
46 | exit 0
47 | fi
48 |
49 | # Switch memory allocator if available.
50 | # Search order: jemalloc.so -> tcmalloc.so.
51 | # According to https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html,
52 | # reusing memory as much as possible can yield better performance than the
53 | # default malloc function.
54 | JEMALLOC=$(ldconfig -p | grep -i "libjemalloc.so$" | tr " " "\n" | grep "/" \
55 | | head -n 1)
56 | TCMALLOC=$(ldconfig -p | grep -i "libtcmalloc.so.4$" | tr " " "\n" | grep "/" \
57 | | head -n 1)
58 | if [ -n "$JEMALLOC" ]; then # if found the path to libjemalloc.so
59 | echo "Switch memory allocator to jemalloc."
60 | export LD_PRELOAD=$JEMALLOC:$LD_PRELOAD
61 | elif [ -n "$TCMALLOC" ]; then # if found the path to libtcmalloc.so.4
62 | echo "Switch memory allocator to tcmalloc."
63 | export LD_PRELOAD=$TCMALLOC:$LD_PRELOAD
64 | fi
65 |
66 | # Get an available port for launching distributed training.
67 | # Credit to https://superuser.com/a/1293762.
68 | export DEFAULT_FREE_PORT
69 | DEFAULT_FREE_PORT=$(comm -23 <(seq 49152 65535 | sort) \
70 | <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) \
71 | | shuf | head -n 1)
72 |
73 | PORT=${PORT:-$DEFAULT_FREE_PORT}
74 |
75 | ${PYTHON} -m torch.distributed.launch \
76 | --nproc_per_node=${GPUS} \
77 | --master_port=${PORT} \
78 | ./train.py \
79 | --launcher="pytorch" \
80 | --backend="nccl" \
81 | ${COMMAND} \
82 | ${@:3} \
83 | || exit 1 # Stop the script on any exception thrown by Python.
84 |
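The `comm -23` pipeline above takes the dynamic port range (49152-65535), removes every port that `ss` reports as in use, and picks a survivor at random. Where `ss` is unavailable, the same goal can be reached by letting the kernel hand out an unused port; a small Python sketch, not part of this repository:

```python
import socket

def find_free_port():
    """Returns a TCP port that is free at the time of the call."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 asks the OS to pick any free port
        return s.getsockname()[1]

if __name__ == "__main__":
    print(find_free_port())  # e.g. PORT=$(python3 find_free_port.py)
```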
--------------------------------------------------------------------------------
/scripts/kill_zombies.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help information.
4 | if [[ $# -le 0 || ${*: -1} == "-h" || ${*: -1} == "--help" ]]; then
5 | echo "This script kills processes launched with" \
6 | "\`./scripts/dist_train.sh\`, with arguments as keywords to filter."
7 | echo
8 | echo "Note: It does NOT check whether they are zombies. Hence," \
9 | "to ensure killing the desired processes rather than innocent ones," \
10 | "you MUST provide sufficient arguments to identified targets."
11 | echo
12 | echo "Usage: $0 [any arguments pass to your \`dist_train.sh\`]."
13 | echo
14 | echo "Example: $0 configs/stylegan2_config.py --work_dir work_dirs/debug"
15 | echo
16 | exit 0
17 | fi
18 |
19 | kill -9 $(ps ux | grep "$*" | grep -v grep | awk '{print $2}')
20 |
--------------------------------------------------------------------------------
/scripts/test_metrics.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help information.
4 | if [[ $# -lt 4 || ${*: -1} == "-h" || ${*: -1} == "--help" ]]; then
5 | echo "This script tests metrics defined in \`./metrics/\`."
6 | echo
7 | echo "Usage: $0 GPUS DATASET MODEL METRICS"
8 | echo
9 | echo "Note: More than one metric should be separated by comma." \
10 | "Also, all metrics assume using all samples from the real dataset" \
11 | "and 50000 fake samples for GAN-related metrics."
12 | echo
13 | echo "Example: $0 1 ~/data/ffhq1024.zip ~/checkpoints/ffhq1024.pth" \
14 | "fid,is,kid,gan_pr,snapshot,equivariance"
15 | echo
16 | exit 0
17 | fi
18 |
19 | # Get an available port for launching distributed training.
20 | # Credit to https://superuser.com/a/1293762.
21 | export DEFAULT_FREE_PORT
22 | DEFAULT_FREE_PORT=$(comm -23 <(seq 49152 65535 | sort) \
23 | <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) \
24 | | shuf | head -n 1)
25 |
26 | GPUS=$1
27 | DATASET=$2
28 | MODEL=$3
29 | PORT=${PORT:-$DEFAULT_FREE_PORT}
30 |
31 | # Parse metrics to test.
32 | METRICS=$4
33 | TEST_FID="false"
34 | TEST_IS="false"
35 | TEST_KID="false"
36 | TEST_GAN_PR="false"
37 | TEST_SNAPSHOT="false"
38 | TEST_EQUIVARIANCE="false"
39 | if [[ ${METRICS} == "all" ]]; then
40 | TEST_FID="true"
41 | TEST_IS="true"
42 | TEST_KID="true"
43 | TEST_GAN_PR="true"
44 | TEST_SNAPSHOT="true"
45 | TEST_EQUIVARIANCE="true"
46 | else
47 | array=(${METRICS//,/ })
48 | for var in ${array[@]}; do
49 | if [[ ${var} == "fid" ]]; then
50 | TEST_FID="true"
51 | fi
52 | if [[ ${var} == "is" ]]; then
53 | TEST_IS="true"
54 | fi
55 | if [[ ${var} == "kid" ]]; then
56 | TEST_KID="true"
57 | fi
58 | if [[ ${var} == "gan_pr" ]]; then
59 | TEST_GAN_PR="true"
60 | fi
61 | if [[ ${var} == "snapshot" ]]; then
62 | TEST_SNAPSHOT="true"
63 | fi
64 | if [[ ${var} == "equivariance" ]]; then
65 | TEST_EQUIVARIANCE="true"
66 | fi
67 | done
68 | fi
69 |
70 | # Detect `python3` command.
71 | # This workaround addresses a common issue:
72 | # `python` points to python2, which is deprecated.
73 | export PYTHONS
74 | export RVAL
75 |
76 | PYTHONS=$(compgen -c | grep "^python3$")
77 |
78 | # `$?` is a built-in variable in bash, which is the exit status of the most
79 | # recently-executed command; by convention, 0 means success and anything else
80 | # indicates failure.
81 | RVAL=$?
82 |
83 | if [ $RVAL -eq 0 ]; then # if `python3` exists
84 | PYTHON="python3"
85 | else
86 | PYTHON="python"
87 | fi
88 |
89 | ${PYTHON} -m torch.distributed.launch \
90 | --nproc_per_node=${GPUS} \
91 | --master_port=${PORT} \
92 | ./test_metrics.py \
93 | --launcher="pytorch" \
94 | --backend="nccl" \
95 | --dataset ${DATASET} \
96 | --model ${MODEL} \
97 | --real_num -1 \
98 | --fake_num 50000 \
99 | --test_fid ${TEST_FID} \
100 | --test_is ${TEST_IS} \
101 | --test_kid ${TEST_KID} \
102 | --test_gan_pr ${TEST_GAN_PR} \
103 | --test_snapshot ${TEST_SNAPSHOT} \
104 | --test_equivariance ${TEST_EQUIVARIANCE} \
105 | ${@:5}
106 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2_ffhq1024.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2 on FFHQ-1024."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq1024.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2_ffhq1024' \
23 | --seed=0 \
24 | --resolution=1024 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=4 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.002 \
45 | --g_lr=0.002 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=10.0 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=10_000 \
56 | --g_ema_rampup=0.0 \
57 | --eval_at_start=true \
58 | --eval_interval=6400 \
59 | --ckpt_interval=6400 \
60 | --log_interval=128 \
61 | --enable_amp=false \
62 | --use_ada=false \
63 | --num_fp16_res=0 \
64 | ${@:3}
65 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2_ffhq256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2 on FFHQ-256."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq256.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2_ffhq256' \
23 | --seed=0 \
24 | --resolution=256 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=16 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.002 \
45 | --g_lr=0.002 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=10.0 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=10_000 \
56 | --g_ema_rampup=0.0 \
57 | --eval_at_start=true \
58 | --eval_interval=6400 \
59 | --ckpt_interval=6400 \
60 | --log_interval=128 \
61 | --enable_amp=false \
62 | --use_ada=false \
63 | --num_fp16_res=0 \
64 | ${@:3}
65 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2_ffhq512.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2 on FFHQ-512."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq512.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2_ffhq512' \
23 | --seed=0 \
24 | --resolution=512 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=8 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.002 \
45 | --g_lr=0.002 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=10.0 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=10_000 \
56 | --g_ema_rampup=0.0 \
57 | --eval_at_start=true \
58 | --eval_interval=6400 \
59 | --ckpt_interval=6400 \
60 | --log_interval=128 \
61 | --enable_amp=false \
62 | --use_ada=false \
63 | --num_fp16_res=0 \
64 | ${@:3}
65 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2_lsun_bedroom256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2 on LSUN-Bedroom-256."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/LSUN/bedroom_train_lmdb [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2_lsun_bedroom256' \
23 | --seed=0 \
24 | --resolution=256 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=100_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=16 \
32 | --train_data_mirror=false \
33 | --data_loader_type='iter' \
34 | --data_repeat=30 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.002 \
45 | --g_lr=0.002 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=10.0 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=10_000 \
56 | --g_ema_rampup=0.0 \
57 | --eval_at_start=true \
58 | --eval_interval=6400 \
59 | --ckpt_interval=6400 \
60 | --log_interval=128 \
61 | --enable_amp=false \
62 | --use_ada=false \
63 | --num_fp16_res=0 \
64 | ${@:3}
65 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2ada_afhq512.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2-ADA on AFHQ-512."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/afhq512.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2ada_afhq512' \
23 | --seed=0 \
24 | --resolution=512 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=8 \
31 | --val_batch_size=8 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=2000 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=8 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.0025 \
45 | --g_lr=0.0025 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=0.5 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=20_000 \
56 | --g_ema_rampup=0.0 \
57 | --eval_at_start=true \
58 | --eval_interval=3200 \
59 | --ckpt_interval=3200 \
60 | --log_interval=64 \
61 | --enable_amp=false \
62 | --use_ada=true \
63 | --num_fp16_res=4 \
64 | ${@:3}
65 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2ada_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2-ADA on CIFAR10."
6 | echo
7 | echo "Note: All settings are already preset for training with 2 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 2 /data/cifar10.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2ada_cifar10' \
23 | --seed=0 \
24 | --resolution=32 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --train_anno_meta='annotation.json' \
28 | --val_dataset=${DATASET} \
29 | --val_anno_meta='annotation.json' \
30 | --val_max_samples=-1 \
31 | --total_img=100_000_000 \
32 | --batch_size=32 \
33 | --val_batch_size=128 \
34 | --train_data_mirror=false \
35 | --data_loader_type='iter' \
36 | --data_repeat=500 \
37 | --data_workers=3 \
38 | --data_prefetch_factor=2 \
39 | --data_pin_memory=true \
40 | --g_init_res=4 \
41 | --latent_dim=512 \
42 | --label_dim=10 \
43 | --d_fmaps_factor=1.0 \
44 | --g_fmaps_factor=1.0 \
45 | --d_mbstd_groups=8 \
46 | --g_num_mappings=2 \
47 | --g_architecture='origin' \
48 | --d_lr=0.0025 \
49 | --g_lr=0.0025 \
50 | --w_moving_decay=0.995 \
51 | --sync_w_avg=false \
52 | --style_mixing_prob=0.0 \
53 | --r1_interval=16 \
54 | --r1_gamma=0.01 \
55 | --pl_interval=4 \
56 | --pl_weight=0.0 \
57 | --pl_decay=0.0 \
58 | --pl_batch_shrink=2 \
59 | --g_ema_img=500_000 \
60 | --g_ema_rampup=0.05 \
61 | --eval_at_start=true \
62 | --eval_interval=3200 \
63 | --ckpt_interval=3200 \
64 | --log_interval=64 \
65 | --enable_amp=false \
66 | --use_ada=true \
67 | --num_fp16_res=4 \
68 | ${@:3}
69 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2ada_ffhq1024.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2-ADA on FFHQ-1024."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq1024.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2ada_ffhq1024' \
23 | --seed=0 \
24 | --resolution=1024 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=4 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.002 \
45 | --g_lr=0.002 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=2.0 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=10_000 \
56 | --g_ema_rampup=0.0 \
57 | --eval_at_start=true \
58 | --eval_interval=6400 \
59 | --ckpt_interval=6400 \
60 | --log_interval=128 \
61 | --enable_amp=false \
62 | --use_ada=true \
63 | --num_fp16_res=4 \
64 | ${@:3}
65 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan2ada_ffhq256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN2-ADA on FFHQ-256."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq256.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \
22 | --job_name='stylegan2ada_ffhq256' \
23 | --seed=0 \
24 | --resolution=256 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=8 \
31 | --val_batch_size=16 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=0.5 \
41 | --g_fmaps_factor=0.5 \
42 | --d_mbstd_groups=8 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.0025 \
45 | --g_lr=0.0025 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=1.0 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=20_000 \
56 | --g_ema_rampup=0.0 \
57 | --eval_at_start=true \
58 | --eval_interval=3200 \
59 | --ckpt_interval=3200 \
60 | --log_interval=64 \
61 | --enable_amp=false \
62 | --use_ada=true \
63 | --num_fp16_res=4 \
64 | ${@:3}
65 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan3r_afhq512.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN3 (config R) on" \
6 | "AFHQ-512."
7 | echo
8 | echo "Note: All settings are already preset for training with 8 GPUs." \
9 | "Please pass addition options, which will overwrite the original" \
10 | "settings, if needed."
11 | echo
12 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
13 | echo
14 | echo "Example: $0 8 /data/afhq512.zip [--help]"
15 | echo
16 | exit 0
17 | fi
18 |
19 | GPUS=$1
20 | DATASET=$2
21 |
22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \
23 | --job_name='stylegan3r_afhq512' \
24 | --seed=0 \
25 | --resolution=512 \
26 | --image_channels=3 \
27 | --train_dataset=${DATASET} \
28 | --val_dataset=${DATASET} \
29 | --val_max_samples=-1 \
30 | --total_img=25_000_000 \
31 | --batch_size=4 \
32 | --val_batch_size=8 \
33 | --train_data_mirror=true \
34 | --data_loader_type='iter' \
35 | --data_repeat=2000 \
36 | --data_workers=3 \
37 | --data_prefetch_factor=2 \
38 | --data_pin_memory=true \
39 | --g_kernel_size=1 \
40 | --latent_dim=512 \
41 | --d_fmaps_factor=1.0 \
42 | --g_fmaps_factor=1.0 \
43 | --d_mbstd_groups=4 \
44 | --g_num_mappings=2 \
45 | --d_lr=0.002 \
46 | --g_lr=0.0025 \
47 | --w_moving_decay=0.998 \
48 | --sync_w_avg=false \
49 | --style_mixing_prob=0.0 \
50 | --r1_interval=16 \
51 | --r1_gamma=16.4 \
52 | --blur_init_sigma=10.0 \
53 | --blur_fade_img=200_000 \
54 | --pl_interval=0 \
55 | --pl_weight=0.0 \
56 | --pl_decay=0.01 \
57 | --pl_batch_shrink=2 \
58 | --g_ema_img=10_000 \
59 | --g_ema_rampup=0.05 \
60 | --eval_at_start=true \
61 | --eval_interval=6400 \
62 | --ckpt_interval=6400 \
63 | --log_interval=128 \
64 | --enable_amp=false \
65 | --use_ada=true \
66 | --num_fp16_res=4 \
67 | ${@:3}
68 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan3r_ffhqu1024.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN3 (config R) on" \
6 | "FFHQ-U-1024."
7 | echo
8 | echo "Note: All settings are already preset for training with 8 GPUs." \
9 | "Please pass addition options, which will overwrite the original" \
10 | "settings, if needed."
11 | echo
12 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
13 | echo
14 | echo "Example: $0 8 /data/ffhqu1024.zip [--help]"
15 | echo
16 | exit 0
17 | fi
18 |
19 | GPUS=$1
20 | DATASET=$2
21 |
22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \
23 | --job_name='stylegan3r_ffhqu1024' \
24 | --seed=0 \
25 | --resolution=1024 \
26 | --image_channels=3 \
27 | --train_dataset=${DATASET} \
28 | --val_dataset=${DATASET} \
29 | --val_max_samples=-1 \
30 | --total_img=25_000_000 \
31 | --batch_size=4 \
32 | --val_batch_size=4 \
33 | --train_data_mirror=true \
34 | --data_loader_type='iter' \
35 | --data_repeat=200 \
36 | --data_workers=3 \
37 | --data_prefetch_factor=2 \
38 | --data_pin_memory=true \
39 | --g_kernel_size=1 \
40 | --latent_dim=512 \
41 | --d_fmaps_factor=1.0 \
42 | --g_fmaps_factor=1.0 \
43 | --d_mbstd_groups=4 \
44 | --g_num_mappings=2 \
45 | --d_lr=0.002 \
46 | --g_lr=0.0025 \
47 | --w_moving_decay=0.998 \
48 | --sync_w_avg=false \
49 | --style_mixing_prob=0.0 \
50 | --r1_interval=16 \
51 | --r1_gamma=32.8 \
52 | --blur_init_sigma=10.0 \
53 | --blur_fade_img=200_000 \
54 | --pl_interval=0 \
55 | --pl_weight=0.0 \
56 | --pl_decay=0.01 \
57 | --pl_batch_shrink=2 \
58 | --g_ema_img=10_000 \
59 | --g_ema_rampup=0.05 \
60 | --eval_at_start=true \
61 | --eval_interval=6400 \
62 | --ckpt_interval=6400 \
63 | --log_interval=128 \
64 | --enable_amp=false \
65 | --use_ada=true \
66 | --num_fp16_res=4 \
67 | ${@:3}
68 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan3r_ffhqu256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN3 (config R) on" \
6 | "FFHQ-U-256."
7 | echo
8 | echo "Note: All settings are already preset for training with 8 GPUs." \
9 | "Please pass addition options, which will overwrite the original" \
10 | "settings, if needed."
11 | echo
12 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
13 | echo
14 | echo "Example: $0 8 /data/ffhqu256.zip [--help]"
15 | echo
16 | exit 0
17 | fi
18 |
19 | GPUS=$1
20 | DATASET=$2
21 |
22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \
23 | --job_name='stylegan3r_ffhqu256' \
24 | --seed=0 \
25 | --resolution=256 \
26 | --image_channels=3 \
27 | --train_dataset=${DATASET} \
28 | --val_dataset=${DATASET} \
29 | --val_max_samples=-1 \
30 | --total_img=25_000_000 \
31 | --batch_size=8 \
32 | --val_batch_size=16 \
33 | --train_data_mirror=true \
34 | --data_loader_type='iter' \
35 | --data_repeat=200 \
36 | --data_workers=3 \
37 | --data_prefetch_factor=2 \
38 | --data_pin_memory=true \
39 | --g_kernel_size=1 \
40 | --latent_dim=512 \
41 | --d_fmaps_factor=0.5 \
42 | --g_fmaps_factor=0.5 \
43 | --d_mbstd_groups=4 \
44 | --g_num_mappings=2 \
45 | --d_lr=0.0025 \
46 | --g_lr=0.0025 \
47 | --w_moving_decay=0.998 \
48 | --sync_w_avg=false \
49 | --style_mixing_prob=0.0 \
50 | --r1_interval=16 \
51 | --r1_gamma=1.0 \
52 | --blur_init_sigma=10.0 \
53 | --blur_fade_img=200_000 \
54 | --pl_interval=0 \
55 | --pl_weight=0.0 \
56 | --pl_decay=0.01 \
57 | --pl_batch_shrink=2 \
58 | --g_ema_img=10_000 \
59 | --g_ema_rampup=0.05 \
60 | --eval_at_start=true \
61 | --eval_interval=3200 \
62 | --ckpt_interval=3200 \
63 | --log_interval=64 \
64 | --enable_amp=false \
65 | --use_ada=true \
66 | --num_fp16_res=4 \
67 | ${@:3}
68 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan3t_afhq512.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN3 (config T) on" \
6 | "AFHQ-512."
7 | echo
8 | echo "Note: All settings are already preset for training with 8 GPUs." \
9 | "Please pass addition options, which will overwrite the original" \
10 | "settings, if needed."
11 | echo
12 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
13 | echo
14 | echo "Example: $0 8 /data/afhq512.zip [--help]"
15 | echo
16 | exit 0
17 | fi
18 |
19 | GPUS=$1
20 | DATASET=$2
21 |
22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \
23 | --job_name='stylegan3t_afhq512' \
24 | --seed=0 \
25 | --resolution=512 \
26 | --image_channels=3 \
27 | --train_dataset=${DATASET} \
28 | --val_dataset=${DATASET} \
29 | --val_max_samples=-1 \
30 | --total_img=25_000_000 \
31 | --batch_size=4 \
32 | --val_batch_size=8 \
33 | --train_data_mirror=true \
34 | --data_loader_type='iter' \
35 | --data_repeat=2000 \
36 | --data_workers=3 \
37 | --data_prefetch_factor=2 \
38 | --data_pin_memory=true \
39 | --g_kernel_size=3 \
40 | --latent_dim=512 \
41 | --d_fmaps_factor=1.0 \
42 | --g_fmaps_factor=1.0 \
43 | --d_mbstd_groups=4 \
44 | --g_num_mappings=2 \
45 | --d_lr=0.002 \
46 | --g_lr=0.0025 \
47 | --w_moving_decay=0.998 \
48 | --sync_w_avg=false \
49 | --style_mixing_prob=0.0 \
50 | --r1_interval=16 \
51 | --r1_gamma=8.2 \
52 | --blur_init_sigma=0.0 \
53 | --blur_fade_img=0 \
54 | --pl_interval=0 \
55 | --pl_weight=0.0 \
56 | --pl_decay=0.01 \
57 | --pl_batch_shrink=2 \
58 | --g_ema_img=10_000 \
59 | --g_ema_rampup=0.05 \
60 | --eval_at_start=true \
61 | --eval_interval=6400 \
62 | --ckpt_interval=6400 \
63 | --log_interval=128 \
64 | --enable_amp=false \
65 | --use_ada=true \
66 | --num_fp16_res=4 \
67 | ${@:3}
68 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan3t_ffhqu1024.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN3 (config T) on" \
6 | "FFHQ-U-1024."
7 | echo
8 | echo "Note: All settings are already preset for training with 8 GPUs." \
9 | "Please pass addition options, which will overwrite the original" \
10 | "settings, if needed."
11 | echo
12 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
13 | echo
14 | echo "Example: $0 8 /data/ffhqu1024.zip [--help]"
15 | echo
16 | exit 0
17 | fi
18 |
19 | GPUS=$1
20 | DATASET=$2
21 |
22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \
23 | --job_name='stylegan3t_ffhqu1024' \
24 | --seed=0 \
25 | --resolution=1024 \
26 | --image_channels=3 \
27 | --train_dataset=${DATASET} \
28 | --val_dataset=${DATASET} \
29 | --val_max_samples=-1 \
30 | --total_img=25_000_000 \
31 | --batch_size=4 \
32 | --val_batch_size=4 \
33 | --train_data_mirror=true \
34 | --data_loader_type='iter' \
35 | --data_repeat=200 \
36 | --data_workers=3 \
37 | --data_prefetch_factor=2 \
38 | --data_pin_memory=true \
39 | --g_kernel_size=3 \
40 | --latent_dim=512 \
41 | --d_fmaps_factor=1.0 \
42 | --g_fmaps_factor=1.0 \
43 | --d_mbstd_groups=4 \
44 | --g_num_mappings=2 \
45 | --d_lr=0.002 \
46 | --g_lr=0.0025 \
47 | --w_moving_decay=0.998 \
48 | --sync_w_avg=false \
49 | --style_mixing_prob=0.0 \
50 | --r1_interval=16 \
51 | --r1_gamma=32.8 \
52 | --blur_init_sigma=0.0 \
53 | --blur_fade_img=0 \
54 | --pl_interval=0 \
55 | --pl_weight=0.0 \
56 | --pl_decay=0.01 \
57 | --pl_batch_shrink=2 \
58 | --g_ema_img=10_000 \
59 | --g_ema_rampup=0.05 \
60 | --eval_at_start=true \
61 | --eval_interval=6400 \
62 | --ckpt_interval=6400 \
63 | --log_interval=128 \
64 | --enable_amp=false \
65 | --use_ada=true \
66 | --num_fp16_res=4 \
67 | ${@:3}
68 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan3t_ffhqu256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN3 (config T) on" \
6 | "FFHQ-U-256."
7 | echo
8 | echo "Note: All settings are already preset for training with 8 GPUs." \
9 | "Please pass addition options, which will overwrite the original" \
10 | "settings, if needed."
11 | echo
12 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
13 | echo
14 | echo "Example: $0 8 /data/ffhqu256.zip [--help]"
15 | echo
16 | exit 0
17 | fi
18 |
19 | GPUS=$1
20 | DATASET=$2
21 |
22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \
23 | --job_name='stylegan3t_ffhqu256' \
24 | --seed=0 \
25 | --resolution=256 \
26 | --image_channels=3 \
27 | --train_dataset=${DATASET} \
28 | --val_dataset=${DATASET} \
29 | --val_max_samples=-1 \
30 | --total_img=25_000_000 \
31 | --batch_size=8 \
32 | --val_batch_size=16 \
33 | --train_data_mirror=true \
34 | --data_loader_type='iter' \
35 | --data_repeat=200 \
36 | --data_workers=3 \
37 | --data_prefetch_factor=2 \
38 | --data_pin_memory=true \
39 | --g_kernel_size=3 \
40 | --latent_dim=512 \
41 | --d_fmaps_factor=0.5 \
42 | --g_fmaps_factor=0.5 \
43 | --d_mbstd_groups=4 \
44 | --g_num_mappings=2 \
45 | --d_lr=0.0025 \
46 | --g_lr=0.0025 \
47 | --w_moving_decay=0.998 \
48 | --sync_w_avg=false \
49 | --style_mixing_prob=0.0 \
50 | --r1_interval=16 \
51 | --r1_gamma=1.0 \
52 | --blur_init_sigma=0.0 \
53 | --blur_fade_img=0 \
54 | --pl_interval=0 \
55 | --pl_weight=0.0 \
56 | --pl_decay=0.01 \
57 | --pl_batch_shrink=2 \
58 | --g_ema_img=10_000 \
59 | --g_ema_rampup=0.05 \
60 | --eval_at_start=true \
61 | --eval_interval=3200 \
62 | --ckpt_interval=3200 \
63 | --log_interval=64 \
64 | --enable_amp=false \
65 | --use_ada=true \
66 | --num_fp16_res=4 \
67 | ${@:3}
68 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan_ffhq1024.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN on FFHQ-1024."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq1024.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan \
22 | --job_name='stylegan_ffhq1024' \
23 | --seed=0 \
24 | --resolution=1024 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=4 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.001 \
45 | --g_lr=0.001 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_gamma=10.0 \
50 | --g_ema_img=10_000 \
51 | --eval_at_start=true \
52 | --eval_interval=6400 \
53 | --ckpt_interval=6400 \
54 | --log_interval=128 \
55 | --use_ada=false \
56 | --enable_amp=true \
57 | -o controllers.ProgressScheduler.init_res=8 \
58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \
59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \
60 | ${@:3}
61 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan_ffhq256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN on FFHQ-256."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq256.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan \
22 | --job_name='stylegan_ffhq256' \
23 | --seed=0 \
24 | --resolution=256 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=16 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.001 \
45 | --g_lr=0.001 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_gamma=10.0 \
50 | --g_ema_img=10_000 \
51 | --eval_at_start=true \
52 | --eval_interval=6400 \
53 | --ckpt_interval=6400 \
54 | --log_interval=128 \
55 | --use_ada=false \
56 | --enable_amp=true \
57 | -o controllers.ProgressScheduler.init_res=8 \
58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \
59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \
60 | ${@:3}
61 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan_ffhq512.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN on FFHQ-512."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/ffhq512.zip [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan \
22 | --job_name='stylegan_ffhq512' \
23 | --seed=0 \
24 | --resolution=512 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=8 \
32 | --train_data_mirror=true \
33 | --data_loader_type='iter' \
34 | --data_repeat=200 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.001 \
45 | --g_lr=0.001 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_gamma=10.0 \
50 | --g_ema_img=10_000 \
51 | --eval_at_start=true \
52 | --eval_interval=6400 \
53 | --ckpt_interval=6400 \
54 | --log_interval=128 \
55 | --use_ada=false \
56 | --enable_amp=true \
57 | -o controllers.ProgressScheduler.init_res=8 \
58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \
59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \
60 | ${@:3}
61 |
--------------------------------------------------------------------------------
/scripts/training_demos/stylegan_lsun_bedroom256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help message.
4 | if [[ $# -lt 2 ]]; then
5 | echo "This script launches a job of training StyleGAN on LSUN-Bedroom-256."
6 | echo
7 | echo "Note: All settings are already preset for training with 8 GPUs." \
8 | "Please pass addition options, which will overwrite the original" \
9 | "settings, if needed."
10 | echo
11 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
12 | echo
13 | echo "Example: $0 8 /data/LSUN/bedroom_train_lmdb [--help]"
14 | echo
15 | exit 0
16 | fi
17 |
18 | GPUS=$1
19 | DATASET=$2
20 |
21 | ./scripts/dist_train.sh ${GPUS} stylegan \
22 | --job_name='stylegan_lsun_bedroom256' \
23 | --seed=0 \
24 | --resolution=256 \
25 | --image_channels=3 \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=100_000_000 \
30 | --batch_size=4 \
31 | --val_batch_size=16 \
32 | --train_data_mirror=false \
33 | --data_loader_type='iter' \
34 | --data_repeat=30 \
35 | --data_workers=3 \
36 | --data_prefetch_factor=2 \
37 | --data_pin_memory=true \
38 | --g_init_res=4 \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.001 \
45 | --g_lr=0.001 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_gamma=10.0 \
50 | --g_ema_img=10_000 \
51 | --eval_at_start=true \
52 | --eval_interval=6400 \
53 | --ckpt_interval=6400 \
54 | --log_interval=128 \
55 | --use_ada=false \
56 | --enable_amp=true \
57 | -o controllers.ProgressScheduler.init_res=8 \
58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \
59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \
60 | ${@:3}
61 |
--------------------------------------------------------------------------------
/scripts/training_demos/volumegan_ffhq256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | # Help message.
5 | if [[ $# -lt 2 ]]; then
6 | echo "This script launches a job of training VolumeGAN on FFHQ-256."
7 | echo
8 | echo "Note: All settings are already preset for training with 8 GPUs." \
9 | "Please pass addition options, which will overwrite the original" \
10 | "settings, if needed."
11 | echo
12 | echo "Usage: $0 GPUS DATASET [OPTIONS]"
13 | echo
14 | echo "Example: $0 8 /data/ffhq256.zip [--help]"
15 | echo
16 | exit 0
17 | fi
18 |
19 | GPUS=$1
20 | DATASET=$2
21 | NeRFRES=32
22 | RES=256
23 | ./scripts/dist_train.sh ${GPUS} volumegan-ffhq \
24 | --seed=0 \
25 | --resolution=${RES} \
26 | --train_dataset=${DATASET} \
27 | --val_dataset=${DATASET} \
28 | --val_max_samples=-1 \
29 | --total_img=25_000_000 \
30 | --batch_size=8 \
31 | --val_batch_size=16 \
32 | --train_data_mirror=true \
33 | --data_workers=3 \
34 | --data_pin_memory=true \
35 | --data_prefetch_factor=2 \
36 | --data_repeat=500 \
38 | --g_init_res=${NeRFRES} \
39 | --latent_dim=512 \
40 | --d_fmaps_factor=1.0 \
41 | --g_fmaps_factor=1.0 \
42 | --d_mbstd_groups=4 \
43 | --g_num_mappings=8 \
44 | --d_lr=0.002 \
45 | --g_lr=0.002 \
46 | --w_moving_decay=0.995 \
47 | --sync_w_avg=false \
48 | --style_mixing_prob=0.9 \
49 | --r1_interval=16 \
50 | --r1_gamma=10.0 \
51 | --pl_interval=4 \
52 | --pl_weight=2.0 \
53 | --pl_decay=0.01 \
54 | --pl_batch_shrink=2 \
55 | --g_ema_img=20000 \
56 | --eval_at_start=false \
57 | --eval_interval=6400 \
58 | --ckpt_interval=6400 \
59 | --log_interval=128 \
60 | --enable_amp=false \
61 | --use_ada=false \
62 | --num_fp16_res=128 \
63 | --logger_type=normal \
64 | --use_pg=true \
65 | ${@:3}
66 |
--------------------------------------------------------------------------------
/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/third_party/__init__.py
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/README.md:
--------------------------------------------------------------------------------
1 | # Operators for StyleGAN2
2 |
3 | All files in this directory are borrowed from repository [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including
4 |
5 | - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator.
6 | - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering.
7 | - `upfirdn2d.filter2d()`: Filtering a 2D feature map with a given kernel.
8 | - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map.
9 | - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map.
10 | - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map.
11 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
12 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
13 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`).
14 |
15 | We make the following slight modifications, beyond disabling some lint warnings:
16 |
17 | - Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
18 | - Line 35 of file `custom_ops.py`: Disable log message when setting up customized operators.
19 | - Line 53/89 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binaries are not located at `/usr/local/cuda/bin`, please specify the correct path in function `_find_compiler_bindir_posix()`.*)
20 | - Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
21 | - Line 32 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default.
22 | - Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator.
23 | - Line 33 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
24 | - Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators.
25 | - Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators.
26 | - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator.
27 |
28 | Please pass `ref` or `cuda` to choose which implementation to use: `ref` refers to the native PyTorch operators, while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
29 |
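As a quick illustration of the `impl` switch described above, the sketch below calls two of the listed operators with both backends. It assumes the call signatures of the official stylegan2-ada-pytorch repository and a CUDA device; treat it as a sketch rather than a verbatim recipe:

```python
import torch
from third_party.stylegan2_official_ops import bias_act, upfirdn2d

x = torch.randn(4, 512, 32, 32, device='cuda')
b = torch.randn(512, device='cuda')

# Fused bias-add + leaky ReLU; `impl='ref'` falls back to native PyTorch ops.
y_cuda = bias_act.bias_act(x, b, act='lrelu', impl='cuda')
y_ref = bias_act.bias_act(x, b, act='lrelu', impl='ref')
assert torch.allclose(y_cuda, y_ref, atol=1e-4)

# 2x upsampling through a low-pass FIR filter:
# (4, 512, 32, 32) -> (4, 512, 64, 64).
f = upfirdn2d.setup_filter([1, 3, 3, 1], device='cuda')
y_up = upfirdn2d.upsample2d(x, f, up=2, impl='cuda')
```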
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/third_party/stylegan2_official_ops/__init__.py
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/fma.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14 | """
15 |
16 | # pylint: disable=line-too-long
17 | # pylint: disable=missing-function-docstring
18 |
19 | import torch
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | def fma(a, b, c, impl='cuda'): # => a * b + c
24 | if impl == 'cuda':
25 | return _FusedMultiplyAdd.apply(a, b, c)
26 | return torch.addcmul(c, a, b)
27 |
28 | #----------------------------------------------------------------------------
29 |
30 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
31 | @staticmethod
32 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
33 | out = torch.addcmul(c, a, b)
34 | ctx.save_for_backward(a, b)
35 | ctx.c_shape = c.shape
36 | return out
37 |
38 | @staticmethod
39 | def backward(ctx, dout): # pylint: disable=arguments-differ
40 | a, b = ctx.saved_tensors
41 | c_shape = ctx.c_shape
42 | da = None
43 | db = None
44 | dc = None
45 |
46 | if ctx.needs_input_grad[0]:
47 | da = _unbroadcast(dout * b, a.shape)
48 |
49 | if ctx.needs_input_grad[1]:
50 | db = _unbroadcast(dout * a, b.shape)
51 |
52 | if ctx.needs_input_grad[2]:
53 | dc = _unbroadcast(dout, c_shape)
54 |
55 | return da, db, dc
56 |
57 | #----------------------------------------------------------------------------
58 |
59 | def _unbroadcast(x, shape):
60 | extra_dims = x.ndim - len(shape)
61 | assert extra_dims >= 0
62 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
63 | if len(dim):
64 | x = x.sum(dim=dim, keepdim=True)
65 | if extra_dims:
66 | x = x.reshape(-1, *x.shape[extra_dims+1:])
67 | assert x.shape == shape
68 | return x
69 |
70 | #----------------------------------------------------------------------------
71 |
72 | # pylint: enable=line-too-long
73 | # pylint: enable=missing-function-docstring
74 |
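A short sanity check of the operator above, including the `_unbroadcast()` behavior on gradients (import path assumed from this repository's layout):

```python
import torch
from third_party.stylegan2_official_ops.fma import fma

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)
c = torch.randn(1, 8, requires_grad=True)  # broadcast along the batch dim

out = fma(a, b, c)                  # a * b + c via the custom Function
assert torch.allclose(out, a * b + c)

out.sum().backward()
assert c.grad.shape == c.shape      # gradient is unbroadcast back to c's shape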
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Custom replacement for `torch.nn.functional.grid_sample`.
12 |
13 | This is useful for differentiable augmentation. This customized operator
14 | supports arbitrarily high order gradients between the input and output. Only
15 | works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and
16 | `align_corners=False`.
17 |
18 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
19 | """
20 |
21 | # pylint: disable=redefined-builtin
22 | # pylint: disable=arguments-differ
23 | # pylint: disable=protected-access
24 | # pylint: disable=line-too-long
25 | # pylint: disable=missing-function-docstring
26 |
27 | import warnings
28 | import torch
29 | from distutils.version import LooseVersion
30 |
31 | #----------------------------------------------------------------------------
32 |
33 | enabled = True # Enable the custom op by setting this to true.
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def grid_sample(input, grid, impl='cuda'):
38 | if impl == 'cuda' and _should_use_custom_op():
39 | return _GridSample2dForward.apply(input, grid)
40 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | def _should_use_custom_op():
45 | if not enabled:
46 | return False
47 | if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
48 | return True
49 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
50 | return False
51 |
52 | #----------------------------------------------------------------------------
53 |
54 | class _GridSample2dForward(torch.autograd.Function):
55 | @staticmethod
56 | def forward(ctx, input, grid):
57 | assert input.ndim == 4
58 | assert grid.ndim == 4
59 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
60 | ctx.save_for_backward(input, grid)
61 | return output
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | input, grid = ctx.saved_tensors
66 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
67 | return grad_input, grad_grid
68 |
69 | #----------------------------------------------------------------------------
70 |
71 | class _GridSample2dBackward(torch.autograd.Function):
72 | @staticmethod
73 | def forward(ctx, grad_output, input, grid):
74 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
75 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
76 | ctx.save_for_backward(grid)
77 | return grad_input, grad_grid
78 |
79 | @staticmethod
80 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
81 | _ = grad2_grad_grid # unused
82 | grid, = ctx.saved_tensors
83 | grad2_grad_output = None
84 | grad2_input = None
85 | grad2_grid = None
86 |
87 | if ctx.needs_input_grad[0]:
88 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
89 |
90 | assert not ctx.needs_input_grad[2]
91 | return grad2_grad_output, grad2_input, grad2_grid
92 |
93 | #----------------------------------------------------------------------------
94 |
95 | # pylint: enable=redefined-builtin
96 | # pylint: enable=arguments-differ
97 | # pylint: enable=protected-access
98 | # pylint: enable=line-too-long
99 | # pylint: enable=missing-function-docstring
100 |
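The reason for this wrapper is second-order differentiation: a gradient penalty computed on top of `grid_sample` (as happens when differentiable augmentation sits inside the discriminator pass) must backpropagate through the first backward. A minimal sketch with an identity sampling grid (import path assumed):

```python
import torch
import torch.nn.functional as F
from third_party.stylegan2_official_ops import grid_sample_gradfix

x = torch.randn(1, 3, 8, 8, requires_grad=True)
theta = torch.eye(2, 3).unsqueeze(0)  # identity affine transform
grid = F.affine_grid(theta, list(x.shape), align_corners=False)

y = grid_sample_gradfix.grid_sample(x, grid)
(grad_x,) = torch.autograd.grad(y.sum(), x, create_graph=True)
penalty = grad_x.square().sum()  # R1-style term built on the first gradient
penalty.backward()               # requires the custom double backward above
```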
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/README.md:
--------------------------------------------------------------------------------
1 | # Operators for StyleGAN3
2 |
3 | All files in this directory are borrowed from repository [stylegan3](https://github.com/NVlabs/stylegan3). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including
4 |
5 | - `bias_act.bias_act()`: Fuses adding bias and performing activation into one operator.
6 | - `upfirdn2d.setup_filter()`: Sets up the kernel used for filtering.
7 | - `upfirdn2d.filter2d()`: Filters a 2D feature map with a given kernel.
8 | - `upfirdn2d.upsample2d()`: Upsamples a 2D feature map.
9 | - `upfirdn2d.downsample2d()`: Downsamples a 2D feature map.
10 | - `upfirdn2d.upfirdn2d()`: Upsamples, filters, and then downsamples a 2D feature map.
11 | - `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing.
12 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high-order gradients and fixing the gradient when computing a penalty.
13 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high-order gradients and fixing the gradient when computing a penalty.
14 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`).
15 |
16 | We make the following slight modifications beyond disabling some lint warnings:
17 |
18 | - Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace the one from `dnnlib` of [stylegan3](https://github.com/NVlabs/stylegan3).
19 | - Line 36 of file `custom_ops.py`: Disable the log message when setting up customized operators.
20 | - Line 54/109 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binaries are not located at `/usr/local/cuda/bin`, please specify their location in function `_find_compiler_bindir_posix()`.*)
21 | - Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace the one from `dnnlib` of [stylegan3](https://github.com/NVlabs/stylegan3).
22 | - Line 162-165 of file `filtered_lrelu.py`: Change some implementations in `_filtered_lrelu_ref()` to `ref`.
23 | - Line 31 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default.
24 | - Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator.
25 | - Line 34 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
26 | - Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators.
27 | - Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators.
28 | - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator.
29 |
30 | Please pass `ref` or `cuda` via the `impl` argument to choose which implementation to use. `ref` refers to the native PyTorch operators, while `cuda` refers to the customized operators from the official repository. `cuda` is used by default, as shown in the example below.
31 |
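As a quick usage sketch (hedged: the import path assumes this repository's layout, and the keyword arguments follow the official stylegan3 signature of `bias_act()`), the two implementations can be compared like this:

```python
# Hedged sketch: compare the customized `cuda` implementation against the
# native PyTorch `ref` implementation of the fused bias + activation op.
# Assumes a CUDA-capable machine and this repository's import layout.
import torch

from third_party.stylegan3_official_ops import bias_act

x = torch.randn(4, 8, 16, 16, device='cuda')
b = torch.randn(8, device='cuda')

y_cuda = bias_act.bias_act(x, b, act='lrelu', impl='cuda')  # customized op
y_ref = bias_act.bias_act(x, b, act='lrelu', impl='ref')    # native PyTorch
print((y_cuda - y_ref).abs().max().item())  # expected to be close to 0
```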
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/third_party/stylegan3_official_ops/__init__.py
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct filtered_lrelu_kernel_params
15 | {
16 | // These parameters decide which kernel to use.
17 | int up; // upsampling ratio (1, 2, 4)
18 | int down; // downsampling ratio (1, 2, 4)
19 | int2 fuShape; // [size, 1] | [size, size]
20 | int2 fdShape; // [size, 1] | [size, size]
21 |
22 | int _dummy; // Alignment.
23 |
24 | // Rest of the parameters.
25 | const void* x; // Input tensor.
26 | void* y; // Output tensor.
27 | const void* b; // Bias tensor.
28 | unsigned char* s; // Sign tensor in/out. NULL if unused.
29 | const float* fu; // Upsampling filter.
30 | const float* fd; // Downsampling filter.
31 |
32 | int2 pad0; // Left/top padding.
33 | float gain; // Additional gain factor.
34 | float slope; // Leaky ReLU slope on negative side.
35 | float clamp; // Clamp after nonlinearity.
36 | int flip; // Filter kernel flip for gradient computation.
37 |
38 | int tilesXdim; // Original number of horizontal output tiles.
39 | int tilesXrep; // Number of horizontal tiles per CTA.
40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 |
42 | int4 xShape; // [width, height, channel, batch]
43 | int4 yShape; // [width, height, channel, batch]
44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 | int swLimit; // Active width of sign tensor in bytes.
47 |
48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49 | longlong4 yStride; //
50 | int64_t bStride; //
51 | longlong3 fuStride; //
52 | longlong3 fdStride; //
53 | };
54 |
55 | struct filtered_lrelu_act_kernel_params
56 | {
57 | void* x; // Input/output, modified in-place.
58 | unsigned char* s; // Sign tensor in/out. NULL if unused.
59 |
60 | float gain; // Additional gain factor.
61 | float slope; // Leaky ReLU slope on negative side.
62 | float clamp; // Clamp after nonlinearity.
63 |
64 | int4 xShape; // [width, height, channel, batch]
65 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68 | };
69 |
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 |
73 | struct filtered_lrelu_kernel_spec
74 | {
75 | void* setup; // Function for filter kernel setup.
76 | void* exec; // Function for main operation.
77 | int2 tileOut; // Width/height of launch tile.
78 | int numWarps; // Number of warps per thread block, determines launch block size.
79 | int xrep; // For processing multiple horizontal tiles per thread block.
80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81 | };
82 |
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 |
86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89 |
90 | //------------------------------------------------------------------------
91 |
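These structs only describe the CUDA-side parameters; the op is normally driven from Python. Below is a hedged usage sketch of the Python wrapper from this directory (the keyword names follow the official stylegan3 `filtered_lrelu.py` and should be treated as assumptions):

```python
# Hedged sketch: drive the kernels declared above via the Python wrapper.
# `up`/`down`, the filters `fu`/`fd`, and `slope` map onto the fields of
# `filtered_lrelu_kernel_params`.
import torch

from third_party.stylegan3_official_ops import filtered_lrelu, upfirdn2d

x = torch.randn(1, 3, 64, 64, device='cuda')
f = upfirdn2d.setup_filter([1, 3, 3, 1], device=x.device)  # separable low-pass

# Upsample 2x, apply leaky ReLU, then downsample 2x, all anti-aliased.
y = filtered_lrelu.filtered_lrelu(x, fu=f, fd=f, up=2, down=2,
                                  slope=0.2, impl='cuda')
print(y.shape)  # spatial size depends on the filter taps and padding
```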
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/fma.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan3
14 | """
15 |
16 | # pylint: disable=line-too-long
17 | # pylint: disable=missing-function-docstring
18 |
19 | import torch
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | def fma(a, b, c, impl='cuda'): # => a * b + c
24 | if impl == 'cuda':
25 | return _FusedMultiplyAdd.apply(a, b, c)
26 | return torch.addcmul(c, a, b)
27 |
28 | #----------------------------------------------------------------------------
29 |
30 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
31 | @staticmethod
32 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
33 | out = torch.addcmul(c, a, b)
34 | ctx.save_for_backward(a, b)
35 | ctx.c_shape = c.shape
36 | return out
37 |
38 | @staticmethod
39 | def backward(ctx, dout): # pylint: disable=arguments-differ
40 | a, b = ctx.saved_tensors
41 | c_shape = ctx.c_shape
42 | da = None
43 | db = None
44 | dc = None
45 |
46 | if ctx.needs_input_grad[0]:
47 | da = _unbroadcast(dout * b, a.shape)
48 |
49 | if ctx.needs_input_grad[1]:
50 | db = _unbroadcast(dout * a, b.shape)
51 |
52 | if ctx.needs_input_grad[2]:
53 | dc = _unbroadcast(dout, c_shape)
54 |
55 | return da, db, dc
56 |
57 | #----------------------------------------------------------------------------
58 |
59 | def _unbroadcast(x, shape):
60 | extra_dims = x.ndim - len(shape)
61 | assert extra_dims >= 0
62 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
63 | if len(dim):
64 | x = x.sum(dim=dim, keepdim=True)
65 | if extra_dims:
66 | x = x.reshape(-1, *x.shape[extra_dims+1:])
67 | assert x.shape == shape
68 | return x
69 |
70 | #----------------------------------------------------------------------------
71 |
72 | # pylint: enable=line-too-long
73 | # pylint: enable=missing-function-docstring
74 |
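A minimal sketch (with hypothetical shapes) verifying that `fma()` matches `a * b + c` and that `_unbroadcast()` reduces gradients back to the broadcast input shapes; note that the default `impl='cuda'` path only wraps `torch.addcmul()`, so it also runs on CPU:

```python
import torch

from third_party.stylegan3_official_ops import fma

a = torch.randn(4, 1, 8, requires_grad=True)
b = torch.randn(1, 6, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)

out = fma.fma(a, b, c)                 # custom autograd path (impl='cuda')
assert torch.allclose(out, a * b + c)  # numerically identical result

out.sum().backward()
# `_unbroadcast()` sums gradients over the broadcast dimensions.
assert a.grad.shape == a.shape
assert b.grad.shape == b.shape
assert c.grad.shape == c.shape
```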
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Custom replacement for `torch.nn.functional.grid_sample`.
12 |
13 | This is useful for differentiable augmentation. This customized operator
14 | supports arbitrarily high order gradients between the input and output. Only
15 | works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and
16 | `align_corners=False`.
17 |
18 | Please refer to https://github.com/NVlabs/stylegan3
19 | """
20 |
21 | # pylint: disable=redefined-builtin
22 | # pylint: disable=arguments-differ
23 | # pylint: disable=protected-access
24 | # pylint: disable=line-too-long
25 | # pylint: disable=missing-function-docstring
26 |
27 | import torch
28 |
29 | #----------------------------------------------------------------------------
30 |
31 | enabled = True # Enable the custom op by setting this to true.
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def grid_sample(input, grid, impl='cuda'):
36 | if impl == 'cuda' and _should_use_custom_op():
37 | return _GridSample2dForward.apply(input, grid)
38 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
39 |
40 | #----------------------------------------------------------------------------
41 |
42 | def _should_use_custom_op():
43 | return enabled
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | class _GridSample2dForward(torch.autograd.Function):
48 | @staticmethod
49 | def forward(ctx, input, grid):
50 | assert input.ndim == 4
51 | assert grid.ndim == 4
52 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
53 | ctx.save_for_backward(input, grid)
54 | return output
55 |
56 | @staticmethod
57 | def backward(ctx, grad_output):
58 | input, grid = ctx.saved_tensors
59 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
60 | return grad_input, grad_grid
61 |
62 | #----------------------------------------------------------------------------
63 |
64 | class _GridSample2dBackward(torch.autograd.Function):
65 | @staticmethod
66 | def forward(ctx, grad_output, input, grid):
67 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
68 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
69 | ctx.save_for_backward(grid)
70 | return grad_input, grad_grid
71 |
72 | @staticmethod
73 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
74 | _ = grad2_grad_grid # unused
75 | grid, = ctx.saved_tensors
76 | grad2_grad_output = None
77 | grad2_input = None
78 | grad2_grid = None
79 |
80 | if ctx.needs_input_grad[0]:
81 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
82 |
83 | assert not ctx.needs_input_grad[2]
84 | return grad2_grad_output, grad2_input, grad2_grid
85 |
86 | #----------------------------------------------------------------------------
87 |
88 | # pylint: enable=redefined-builtin
89 | # pylint: enable=arguments-differ
90 | # pylint: enable=protected-access
91 | # pylint: enable=line-too-long
92 | # pylint: enable=missing-function-docstring
93 |
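A hedged sketch of the motivating use case: an R1-style gradient penalty through a differentiable augmentation, which exercises the double-backward path implemented by `_GridSample2dBackward`. A small conv stands in for a discriminator, and the snippet assumes a PyTorch version compatible with this vendored code:

```python
import torch

from third_party.stylegan3_official_ops import grid_sample_gradfix

x = torch.randn(2, 3, 8, 8, requires_grad=True)
grid = torch.rand(2, 8, 8, 2) * 2 - 1        # normalized coords in [-1, 1]
conv = torch.nn.Conv2d(3, 1, kernel_size=3)  # stand-in "discriminator"

score = conv(grid_sample_gradfix.grid_sample(x, grid)).sum()
(grad_x,) = torch.autograd.grad(score, x, create_graph=True)

# Differentiating the gradient norm w.r.t. the conv weights exercises the
# second-order path (`_GridSample2dForward` applied to the incoming grad).
penalty = grad_x.square().sum()
penalty.backward()
print(conv.weight.grad.shape)
```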
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Main function for model training."""
3 |
4 | import click
5 |
6 | from configs import CONFIG_POOL
7 | from configs import build_config
8 | from runners import build_runner
9 | from utils.dist_utils import init_dist
10 | from utils.dist_utils import exit_dist
11 |
12 |
13 | @click.group(name='Distributed Training',
14 | help='Train a deep model by choosing a command (configuration).',
15 | context_settings={'show_default': True, 'max_content_width': 180})
16 | @click.option('--launcher', default='pytorch',
17 | type=click.Choice(['pytorch', 'slurm']),
18 | help='Distributed launcher.')
19 | @click.option('--backend', default='nccl',
20 | type=click.Choice(['nccl', 'gloo', 'mpi']),
21 | help='Distributed backend.')
22 | @click.option('--local_rank', type=int, default=0, hidden=True,
23 | help='Replica rank on the current node. This field is required '
24 | 'by `torch.distributed.launch`.')
25 | def command_group(launcher, backend, local_rank): # pylint: disable=unused-argument
26 | """Defines a command group for launching distributed jobs.
27 |
28 | This function mainly interacts with the command line. The real launching
29 | is executed by the `main()` function through the `result_callback()`
30 | decorator. In other words, the arguments obtained from the command line
31 | are passed to the `main()` function. How the arguments are passed is the
32 | responsibility of each command in this command group. Please refer to
33 | `BaseConfig.get_command()` in `configs/base_config.py` for more details.
34 | """
35 |
36 |
37 | @command_group.result_callback()
38 | @click.pass_context
39 | def main(ctx, kwargs, launcher, backend, local_rank):
40 | """Main function for distributed training.
41 |
42 | Basically, this function first initializes a distributed environment, then
43 | parses configuration from the command line, and finally sets up the runner
44 | with the parsed configuration for training.
45 | """
46 | _ = local_rank # unused variable
47 |
48 | # Initialize distributed environment.
49 | init_dist(launcher=launcher, backend=backend)
50 |
51 | # Build configurations and runner.
52 | config = build_config(ctx.invoked_subcommand, kwargs).get_config()
53 | runner = build_runner(config)
54 |
55 | # Start training.
56 | runner.train()
57 | runner.close()
58 |
59 | # Exit distributed environment.
60 | exit_dist()
61 |
62 |
63 | if __name__ == '__main__':
64 | # Append all available commands (from `configs/`) into the command group.
65 | for cfg in CONFIG_POOL:
66 | command_group.add_command(cfg.get_command())
67 | # Run by interacting with command line.
68 | command_group() # pylint: disable=no-value-for-parameter
69 |
--------------------------------------------------------------------------------
/unit_tests.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all unit tests."""
3 |
4 | import argparse
5 |
6 | from models.test import test_model
7 | from utils.loggers.test import test_logger
8 | from utils.visualizers.test import test_visualizer
9 | from utils.parsing_utils import parse_bool
10 |
11 |
12 | def parse_args():
13 | """Parses arguments."""
14 | parser = argparse.ArgumentParser(description='Run unit tests.')
15 | parser.add_argument('--result_dir', type=str,
16 | default='work_dirs/unit_tests',
17 | help='Path to save the test results. (default: '
18 | '%(default)s)')
19 | parser.add_argument('--test_all', type=parse_bool, default=False,
20 | help='Whether to run all unit tests. (default: '
21 | '%(default)s)')
22 | parser.add_argument('--test_model', type=parse_bool, default=False,
23 | help='Whether to run unit test on models. (default: '
24 | '%(default)s)')
25 | parser.add_argument('--test_logger', type=parse_bool, default=False,
26 | help='Whether to run unit test on loggers. (default: '
27 | '%(default)s)')
28 | parser.add_argument('--test_visualizer', type=parse_bool, default=False,
29 | help='Whether to do unit test on visualizers. '
30 | '(default: %(default)s)')
31 | return parser.parse_args()
32 |
33 |
34 | def main():
35 | """Main function."""
36 | args = parse_args()
37 |
38 | if args.test_all or args.test_model:
39 | test_model()
40 |
41 | if args.test_all or args.test_logger:
42 | test_logger(args.result_dir)
43 |
44 | if args.test_all or args.test_visualizer:
45 | test_visualizer(args.result_dir)
46 |
47 |
48 | if __name__ == '__main__':
49 | main()
50 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains utility functions used for distribution."""
3 |
4 | import contextlib
5 | import os
6 | import subprocess
7 |
8 | import torch
9 | import torch.distributed as dist
10 | import torch.multiprocessing as mp
11 |
12 | __all__ = ['init_dist', 'exit_dist', 'ddp_sync', 'get_ddp_module']
13 |
14 |
15 | def init_dist(launcher, backend='nccl', **kwargs):
16 | """Initializes distributed environment."""
17 | if mp.get_start_method(allow_none=True) is None:
18 | mp.set_start_method('spawn')
19 | if launcher == 'pytorch':
20 | rank = int(os.environ['RANK'])
21 | num_gpus = torch.cuda.device_count()
22 | torch.cuda.set_device(rank % num_gpus)
23 | dist.init_process_group(backend=backend, **kwargs)
24 | elif launcher == 'slurm':
25 | proc_id = int(os.environ['SLURM_PROCID'])
26 | ntasks = int(os.environ['SLURM_NTASKS'])
27 | node_list = os.environ['SLURM_NODELIST']
28 | num_gpus = torch.cuda.device_count()
29 | torch.cuda.set_device(proc_id % num_gpus)
30 | addr = subprocess.getoutput(
31 | f'scontrol show hostname {node_list} | head -n1')
32 | port = os.environ.get('PORT', 29500)
33 | os.environ['MASTER_PORT'] = str(port)
34 | os.environ['MASTER_ADDR'] = addr
35 | os.environ['WORLD_SIZE'] = str(ntasks)
36 | os.environ['RANK'] = str(proc_id)
37 | dist.init_process_group(backend=backend)
38 | else:
39 | raise NotImplementedError(f'Not implemented launcher type: '
40 | f'`{launcher}`!')
41 |
42 |
43 | def exit_dist():
44 | """Exits the distributed environment."""
45 | if dist.is_initialized():
46 | dist.destroy_process_group()
47 |
48 |
49 | @contextlib.contextmanager
50 | def ddp_sync(model, sync):
51 | """Controls whether the `DistributedDataParallel` model should be synced."""
52 | assert isinstance(model, torch.nn.Module)
53 | is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
54 | if sync or not is_ddp:
55 | yield
56 | else:
57 | with model.no_sync():
58 | yield
59 |
60 |
61 | def get_ddp_module(model):
62 | """Gets the module from `DistributedDataParallel`."""
63 | assert isinstance(model, torch.nn.Module)
64 | is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
65 | if is_ddp:
66 | return model.module
67 | return model
68 |
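A hedged sketch of how `ddp_sync()` supports gradient accumulation, skipping all-reduce on every micro-batch except the last (`model`, `loss_fn`, and `batches` are hypothetical placeholders):

```python
from utils.dist_utils import ddp_sync, get_ddp_module

def accumulate_gradients(model, loss_fn, batches):
    """Accumulates gradients, syncing across replicas only on the last step."""
    for i, batch in enumerate(batches):
        is_last = (i == len(batches) - 1)
        with ddp_sync(model, sync=is_last):  # `no_sync()` unless last step
            loss_fn(model(batch)).backward()

# When checkpointing, unwrap `DistributedDataParallel` first:
# state = get_ddp_module(model).state_dict()
```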
--------------------------------------------------------------------------------
/utils/file_transmitters/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all file transmitters."""
3 |
4 | from .local_file_transmitter import LocalFileTransmitter
5 | from .dummy_file_transmitter import DummyFileTransmitter
6 |
7 | __all__ = ['build_file_transmitter']
8 |
9 | _TRANSMITTERS = {
10 | 'local': LocalFileTransmitter,
11 | 'dummy': DummyFileTransmitter,
12 | }
13 |
14 |
15 | def build_file_transmitter(transmitter_type='local', **kwargs):
16 | """Builds a file transmitter.
17 |
18 | Args:
19 | transmitter_type: Type of the file transmitter, which is case
20 | insensitive. (default: `local`)
21 | **kwargs: Additional arguments to build the file transmitter.
22 |
23 | Raises:
24 | ValueError: If the `transmitter_type` is not supported.
25 | """
26 | transmitter_type = transmitter_type.lower()
27 | if transmitter_type not in _TRANSMITTERS:
28 | raise ValueError(f'Invalid transmitter type: `{transmitter_type}`!\n'
29 | f'Types allowed: {list(_TRANSMITTERS)}.')
30 | return _TRANSMITTERS[transmitter_type](**kwargs)
31 |
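A hedged usage sketch of the builder (the paths are hypothetical; with the `local` transmitter, each call simply shells out to the command noted in the comment):

```python
from utils.file_transmitters import build_file_transmitter

transmitter = build_file_transmitter('local')    # or 'dummy' on non-chief ranks
transmitter.make_remote_dir('work_dirs/backup')  # executes `mkdir -p`
transmitter.push('work_dirs/log.txt', 'work_dirs/backup/log.txt')  # `cp`
transmitter.pull('work_dirs/backup/log.txt', 'work_dirs/log_copy.txt',
                 hard=False)                     # soft mode: `ln -s`
```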
--------------------------------------------------------------------------------
/utils/file_transmitters/base_file_transmitter.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the base class to transmit files across file systems.
3 |
4 | Basically, a file transmitter connects the local file system, on which the
5 | programme runs, to a remote file system. This is particularly used for
6 | (1) pulling files that are required by the programme from remote, and
7 | (2) pushing results that are produced by the programme to remote. In this way,
8 | the programme can focus on local file system only.
9 |
10 | NOTE: The remote file system can be the same as the local file system, since
11 | users may want to transmit files across directories.
12 | """
13 |
14 | import warnings
15 |
16 | __all__ = ['BaseFileTransmitter']
17 |
18 |
19 | class BaseFileTransmitter(object):
20 | """Defines the base file transmitter.
21 |
22 | A transmitter should have the following functions:
23 |
24 | (1) pull(): The function to pull a file/directory from remote to local.
25 | (2) push(): The function to push a file/directory from local to remote.
26 | (3) remove(): The function to remove a file/directory.
27 | (4) make_remote_dir(): The function to make a directory remotely.
28 |
29 |
30 | To simplify, each derived class just needs to implement the following helper
31 | functions:
32 |
33 | (1) download_hard(): Hard download a file/directory from remote to local.
34 | (2) download_soft(): Soft download a file/directory from remote to local.
35 | This is especially used to save space (e.g., soft link).
36 | (3) upload(): Upload a file/directory from local to remote.
37 | (4) delete(): Delete a file/directory according to given path.
38 | """
39 |
40 | def __init__(self):
41 | pass
42 |
43 | @property
44 | def name(self):
45 | """Returns the class name of the file transmitter."""
46 | return self.__class__.__name__
47 |
48 | @staticmethod
49 | def download_hard(src, dst):
50 | """Downloads (in hard mode) a file/directory from remote to local."""
51 | raise NotImplementedError('Should be implemented in derived class!')
52 |
53 | @staticmethod
54 | def download_soft(src, dst):
55 | """Downloads (in soft mode) a file/directory from local to remote."""
56 | raise NotImplementedError('Should be implemented in derived class!')
57 |
58 | @staticmethod
59 | def upload(src, dst):
60 | """Uploads a file/directory from local to remote."""
61 | raise NotImplementedError('Should be implemented in derived class!')
62 |
63 | @staticmethod
64 | def delete(path):
65 | """Deletes the given path."""
66 | # TODO: should we secure the path to avoid mis-removing / attacks?
67 | raise NotImplementedError('Should be implemented in derived class!')
68 |
69 | def pull(self, src, dst, hard=False):
70 | """Pulls a file/directory from remote to local.
71 |
72 | The argument `hard` controls the download mode (hard or soft). For
73 | example, hard mode may make a physical copy of the file, while soft
74 | mode may create a symbolic link to it instead.
75 | """
76 | if hard:
77 | self.download_hard(src, dst)
78 | else:
79 | self.download_soft(src, dst)
80 |
81 | def push(self, src, dst):
82 | """Pushes a file/directory from local to remote."""
83 | self.upload(src, dst)
84 |
85 | def remove(self, path):
86 | """Removes the given path."""
87 | warnings.warn(f'`{path}` will be removed!')
88 | self.delete(path)
89 |
90 | def make_remote_dir(self, directory):
91 | """Makes a directory on the remote system."""
92 | raise NotImplementedError('Should be implemented in derived class!')
93 |
--------------------------------------------------------------------------------
/utils/file_transmitters/dummy_file_transmitter.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of dummy file transmitter.
3 |
4 | This file transmitter has all expected data transmission functions but behaves
5 | silently, which is very useful in multi-processing mode. Only the chief process
6 | can have the file transmitter with normal behavior.
7 | """
8 |
9 | from .base_file_transmitter import BaseFileTransmitter
10 |
11 | __all__ = ['DummyFileTransmitter']
12 |
13 |
14 | class DummyFileTransmitter(BaseFileTransmitter):
15 | """Implements a dummy transmitter which transmits nothing."""
16 |
17 | @staticmethod
18 | def download_hard(src, dst):
19 | return
20 |
21 | @staticmethod
22 | def download_soft(src, dst):
23 | return
24 |
25 | @staticmethod
26 | def upload(src, dst):
27 | return
28 |
29 | @staticmethod
30 | def delete(path):
31 | return
32 |
33 | def make_remote_dir(self, directory):
34 | return
35 |
--------------------------------------------------------------------------------
/utils/file_transmitters/local_file_transmitter.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of local file transmitter.
3 |
4 | The transmitter connects the local file system to itself, which can be used
5 | to transmit files from one directory to another. Consequently,
6 | `remote` in this file also means `local`.
7 | """
8 |
9 | from utils.misc import print_and_execute
10 | from .base_file_transmitter import BaseFileTransmitter
11 |
12 | __all__ = ['LocalFileTransmitter']
13 |
14 |
15 | class LocalFileTransmitter(BaseFileTransmitter):
16 | """Implements the transmitter connecting local file system to itself."""
17 |
18 | @staticmethod
19 | def download_hard(src, dst):
20 | print_and_execute(f'cp {src} {dst}')
21 |
22 | @staticmethod
23 | def download_soft(src, dst):
24 | print_and_execute(f'ln -s {src} {dst}')
25 |
26 | @staticmethod
27 | def upload(src, dst):
28 | print_and_execute(f'cp {src} {dst}')
29 |
30 | @staticmethod
31 | def delete(path):
32 | print_and_execute(f'rm -r {path}')
33 |
34 | def make_remote_dir(self, directory):
35 | print_and_execute(f'mkdir -p {directory}')
36 |
--------------------------------------------------------------------------------
/utils/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all loggers."""
3 |
4 | from .normal_logger import NormalLogger
5 | from .rich_logger import RichLogger
6 | from .dummy_logger import DummyLogger
7 |
8 | __all__ = ['build_logger']
9 |
10 | _LOGGERS = {
11 | 'normal': NormalLogger,
12 | 'rich': RichLogger,
13 | 'dummy': DummyLogger
14 | }
15 |
16 |
17 | def build_logger(logger_type='normal', **kwargs):
18 | """Builds a logger.
19 |
20 | Args:
21 | logger_type: Type of logger, which is case insensitive.
22 | (default: `normal`)
23 | **kwargs: Additional arguments to build the logger.
24 |
25 | Raises:
26 | ValueError: If the `logger_type` is not supported.
27 | """
28 | logger_type = logger_type.lower()
29 | if logger_type not in _LOGGERS:
30 | raise ValueError(f'Invalid logger type: `{logger_type}`!\n'
31 | f'Types allowed: {list(_LOGGERS)}.')
32 | return _LOGGERS[logger_type](**kwargs)
33 |
--------------------------------------------------------------------------------
/utils/loggers/dummy_logger.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of dummy logger.
3 |
4 | This logger has all expected logging functions but behaves silently, which is
5 | very useful in multi-processing mode. Only the chief process can have the logger
6 | with normal behavior.
7 | """
8 |
9 | from .base_logger import BaseLogger
10 |
11 | __all__ = ['DummyLogger']
12 |
13 |
14 | class DummyLogger(BaseLogger):
15 | """Implements a dummy logger which logs nothing."""
16 |
17 | def __init__(self,
18 | logger_name='logger',
19 | logfile=None,
20 | screen_level=None,
21 | file_level=None,
22 | indent_space=4,
23 | verbose_log=False):
24 | super().__init__(logger_name=logger_name,
25 | logfile=logfile,
26 | screen_level=screen_level,
27 | file_level=file_level,
28 | indent_space=indent_space,
29 | verbose_log=verbose_log)
30 |
31 | def _log(self, message, **kwargs):
32 | return
33 |
34 | def _debug(self, message, **kwargs):
35 | return
36 |
37 | def _info(self, message, **kwargs):
38 | return
39 |
40 | def _warning(self, message, **kwargs):
41 | return
42 |
43 | def _error(self, message, **kwargs):
44 | return
45 |
46 | def _exception(self, message, **kwargs):
47 | return
48 |
49 | def _critical(self, message, **kwargs):
50 | return
51 |
52 | def _print(self, *messages, **kwargs):
53 | return
54 |
55 | def init_pbar(self, leave=False):
56 | return
57 |
58 | def add_pbar_task(self, name, total, **kwargs):
59 | return -1
60 |
61 | def update_pbar(self, task_id, advance=1):
62 | return
63 |
64 | def close_pbar(self):
65 | return
66 |
--------------------------------------------------------------------------------
/utils/loggers/test.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Unit test for logger."""
3 |
4 | import os
5 | import time
6 |
7 | from . import build_logger
8 |
9 | __all__ = ['test_logger']
10 |
11 | _TEST_DIR = 'logger_test'
12 |
13 |
14 | def test_logger(test_dir=_TEST_DIR):
15 | """Tests loggers."""
16 | print('========== Start Logger Test ==========')
17 |
18 | os.makedirs(test_dir, exist_ok=True)
19 |
20 | for logger_type in ['normal', 'rich', 'dummy']:
21 | for indent_space in [2, 4]:
22 | for verbose_log in [False, True]:
23 | if logger_type == 'normal':
24 | class_name = 'NormalLogger'
25 | elif logger_type == 'rich':
26 | class_name = 'RichLogger'
27 | elif logger_type == 'dummy':
28 | class_name = 'DummyLogger'
29 |
30 | print(f'===== '
31 | f'Testing `utils.loggers.{class_name}` '
32 | f'(indent: {indent_space}, verbose: {verbose_log}) '
33 | f'=====')
34 | logger_name = (f'{logger_type}_logger_'
35 | f'indent_{indent_space}_'
36 | f'verbose_{verbose_log}')
37 | logger = build_logger(
38 | logger_type,
39 | logger_name=logger_name,
40 | logfile=os.path.join(test_dir, f'test_{logger_name}.log'),
41 | verbose_log=verbose_log,
42 | indent_space=indent_space)
43 | logger.print('print log')
44 | logger.print('print log,', 'log 2')
45 | logger.print('print log (indent level 0)', indent_level=0)
46 | logger.print('print log (indent level 1)', indent_level=1)
47 | logger.print('print log (indent level 2)', indent_level=2)
48 | logger.print('print log (verbose `False`)', is_verbose=False)
49 | logger.print('print log (verbose `True`)', is_verbose=True)
50 | logger.debug('debug log')
51 | logger.info('info log')
52 | logger.warning('warning log')
53 | logger.init_pbar()
54 | task_1 = logger.add_pbar_task('Task 1', 500)
55 | task_2 = logger.add_pbar_task('Task 2', 1000)
56 | for _ in range(1000):
57 | logger.update_pbar(task_1, 1)
58 | logger.update_pbar(task_2, 1)
59 | time.sleep(0.002)
60 | logger.close_pbar()
61 | print('Success!')
62 |
63 | print('========== Finish Logger Test ==========')
64 |
--------------------------------------------------------------------------------
/utils/tf_utils.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the utility functions to handle import TensorFlow modules.
3 |
4 | Basically, TensorFlow may not be supported in the current environment, or may
5 | cause some warnings. This file provides functions to help ease TensorFlow
6 | related imports, such as TensorBoard.
7 | """
8 |
9 | import warnings
10 |
11 | __all__ = ['import_tf', 'import_tb_writer']
12 |
13 |
14 | def import_tf():
15 | """Imports TensorFlow module if possible.
16 |
17 | If `ImportError` is raised, `None` will be returned. Otherwise, the module
18 | `tensorflow` will be returned.
19 | """
20 | warnings.filterwarnings('ignore', category=FutureWarning)
21 | try:
22 | import tensorflow as tf # pylint: disable=import-outside-toplevel
23 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
24 | module = tf
25 | except ImportError:
26 | module = None
27 | warnings.filterwarnings('default', category=FutureWarning)
28 | return module
29 |
30 |
31 | def import_tb_writer():
32 | """Imports the SummaryWriter of TensorBoard.
33 |
34 | If `ImportError` is raised, `None` will be returned. Otherwise, the class
35 | `SummaryWriter` will be returned.
36 |
37 | NOTE: This function attempts to import `SummaryWriter` from
38 | `torch.utils.tensorboard`, but the import does not necessarily succeed,
39 | since installing TensorBoard is not the responsibility of `PyTorch`.
40 | """
41 | warnings.filterwarnings('ignore', category=FutureWarning)
42 | try:
43 | from torch.utils.tensorboard import SummaryWriter # pylint: disable=import-outside-toplevel
44 | except ImportError: # In case TensorBoard is not supported.
45 | SummaryWriter = None
46 | warnings.filterwarnings('default', category=FutureWarning)
47 | return SummaryWriter
48 |
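A hedged usage sketch: both helpers return `None` when the import fails, so callers only need a `None` check (the log directory is hypothetical):

```python
from utils.tf_utils import import_tb_writer

SummaryWriter = import_tb_writer()
if SummaryWriter is not None:
    writer = SummaryWriter(log_dir='work_dirs/events')
    writer.add_scalar('loss', 0.123, global_step=0)
    writer.close()
else:
    print('TensorBoard is unavailable; skipping event logging.')
```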
--------------------------------------------------------------------------------
/utils/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all visualizers."""
3 |
4 | from .grid_visualizer import GridVisualizer
5 | from .gif_visualizer import GifVisualizer
6 | from .html_visualizer import HtmlVisualizer
7 | from .html_visualizer import HtmlReader
8 | from .video_visualizer import VideoVisualizer
9 | from .video_visualizer import VideoReader
10 |
11 | __all__ = [
12 | 'GridVisualizer', 'GifVisualizer', 'HtmlVisualizer', 'HtmlReader',
13 | 'VideoVisualizer', 'VideoReader'
14 | ]
15 |
--------------------------------------------------------------------------------
/utils/visualizers/gif_visualizer.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the visualizer to visualize images as a GIF."""
3 |
4 | from PIL import Image
5 |
6 | from ..image_utils import parse_image_size
7 | from ..image_utils import load_image
8 | from ..image_utils import resize_image
9 | from ..image_utils import list_images_from_dir
10 |
11 | __all__ = ['GifVisualizer']
12 |
13 |
14 | class GifVisualizer(object):
15 | """Defines the visualizer that visualizes an image collection as GIF."""
16 |
17 | def __init__(self, image_size=None, duration=100, loop=0):
18 | """Initializes the GIF visualizer.
19 |
20 | Args:
21 | image_size: Size for image visualization. (default: None)
22 | duration: Duration between two frames, in milliseconds.
23 | (default: 100)
24 | loop: How many times to loop the GIF. `0` means infinite.
25 | (default: 0)
26 | """
27 | self.set_image_size(image_size)
28 | self.set_duration(duration)
29 | self.set_loop(loop)
30 |
31 | def set_image_size(self, image_size=None):
32 | """Sets the image size of the GIF."""
33 | height, width = parse_image_size(image_size)
34 | self.image_height = height
35 | self.image_width = width
36 |
37 | def set_duration(self, duration=100):
38 | """Sets the GIF duration."""
39 | self.duration = duration
40 |
41 | def set_loop(self, loop=0):
42 | """Sets how many times the GIF will be looped. `0` means infinite."""
43 | self.loop = loop
44 |
45 | def visualize_collection(self, images, save_path):
46 | """Visualizes a collection of images one by one."""
47 | height, width = images[0].shape[0:2]
48 | height = self.image_height or height
49 | width = self.image_width or width
50 | pil_images = []
51 | for image in images:
52 | if image.shape[0:2] != (height, width):
53 | image = resize_image(image, (width, height))
54 | pil_images.append(Image.fromarray(image))
55 | pil_images[0].save(save_path, format='GIF', save_all=True,
56 | append_images=pil_images[1:],
57 | duration=self.duration,
58 | loop=self.loop)
59 |
60 | def visualize_list(self, image_list, save_path):
61 | """Visualizes a list of image files."""
62 | height, width = load_image(image_list[0]).shape[0:2]
63 | height = self.image_height or height
64 | width = self.image_width or width
65 | pil_images = []
66 | for filename in image_list:
67 | image = load_image(filename)
68 | if image.shape[0:2] != (height, width):
69 | image = resize_image(image, (width, height))
70 | pil_images.append(Image.fromarray(image))
71 | pil_images[0].save(save_path, format='GIF', save_all=True,
72 | append_images=pil_images[1:],
73 | duration=self.duration,
74 | loop=self.loop)
75 |
76 | def visualize_directory(self, directory, save_path):
77 | """Visualizes all images under a directory."""
78 | image_list = list_images_from_dir(directory)
79 | self.visualize_list(image_list, save_path)
80 |
--------------------------------------------------------------------------------
/utils/visualizers/test.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Unit test for visualizer."""
3 |
4 | import os
5 | import skvideo.datasets
6 |
7 | from ..image_utils import save_image
8 | from . import GridVisualizer
9 | from . import HtmlVisualizer
10 | from . import HtmlReader
11 | from . import GifVisualizer
12 | from . import VideoVisualizer
13 | from . import VideoReader
14 |
15 | __all__ = ['test_visualizer']
16 |
17 | _TEST_DIR = 'visualizer_test'
18 |
19 |
20 | def test_visualizer(test_dir=_TEST_DIR):
21 | """Tests visualizers."""
22 | print('========== Start Visualizer Test ==========')
23 |
24 | frame_dir = os.path.join(test_dir, 'test_frames')
25 | os.makedirs(frame_dir, exist_ok=True)
26 |
27 | print('===== Testing `VideoReader` =====')
28 | # Total 132 frames, with size (720, 1080).
29 | video_reader = VideoReader(skvideo.datasets.bigbuckbunny())
30 | frame_height = video_reader.frame_height
31 | frame_width = video_reader.frame_width
32 | frame_size = (frame_height, frame_width)
33 | half_size = (frame_height // 2, frame_width // 2)
34 | # Save frames as the test set.
35 | for idx in range(80):
36 | frame = video_reader.read()
37 | save_image(os.path.join(frame_dir, f'{idx:02d}.png'), frame)
38 |
39 | print('===== Testing `GridVisualizer` =====')
40 | grid_visualizer = GridVisualizer()
41 | grid_visualizer.set_row_spacing(30)
42 | grid_visualizer.set_col_spacing(30)
43 | grid_visualizer.set_background(use_black=True)
44 | path = os.path.join(test_dir, 'portrait_row_major_ori_space30_black.png')
45 | grid_visualizer.visualize_directory(frame_dir, path,
46 | is_portrait=True, is_row_major=True)
47 | path = os.path.join(
48 | test_dir, 'landscape_col_major_downsample_space15_white.png')
49 | grid_visualizer.set_image_size(half_size)
50 | grid_visualizer.set_row_spacing(15)
51 | grid_visualizer.set_col_spacing(15)
52 | grid_visualizer.set_background(use_black=False)
53 | grid_visualizer.visualize_directory(frame_dir, path,
54 | is_portrait=False, is_row_major=False)
55 |
56 | print('===== Testing `HtmlVisualizer` =====')
57 | html_visualizer = HtmlVisualizer()
58 | path = os.path.join(test_dir, 'portrait_col_major_ori.html')
59 | html_visualizer.visualize_directory(frame_dir, path,
60 | is_portrait=True, is_row_major=False)
61 | path = os.path.join(test_dir, 'landscape_row_major_downsample.html')
62 | html_visualizer.set_image_size(half_size)
63 | html_visualizer.visualize_directory(frame_dir, path,
64 | is_portrait=False, is_row_major=True)
65 |
66 | print('===== Testing `HtmlReader` =====')
67 | path = os.path.join(test_dir, 'landscape_row_major_downsample.html')
68 | html_reader = HtmlReader(path)
69 | for j in range(html_reader.num_cols):
70 | assert html_reader.get_header(j) == ''
71 | parsed_dir = os.path.join(test_dir, 'parsed_frames')
72 | os.makedirs(parsed_dir, exist_ok=True)
73 | for i in range(html_reader.num_rows):
74 | for j in range(html_reader.num_cols):
75 | idx = i * html_reader.num_cols + j
76 | assert html_reader.get_text(i, j).endswith(f'(index {idx:03d})')
77 | image = html_reader.get_image(i, j, image_size=frame_size)
78 | assert image.shape[0:2] == frame_size
79 | save_image(os.path.join(parsed_dir, f'{idx:02d}.png'), image)
80 |
81 | print('===== Testing `GifVisualizer` =====')
82 | gif_visualizer = GifVisualizer()
83 | path = os.path.join(test_dir, 'gif_ori.gif')
84 | gif_visualizer.visualize_directory(frame_dir, path)
85 | gif_visualizer.set_image_size(half_size)
86 | path = os.path.join(test_dir, 'gif_downsample.gif')
87 | gif_visualizer.visualize_directory(frame_dir, path)
88 |
89 | print('===== Testing `VideoVisualizer` =====')
90 | video_visualizer = VideoVisualizer()
91 | path = os.path.join(test_dir, 'video_ori.mp4')
92 | video_visualizer.visualize_directory(frame_dir, path)
93 | path = os.path.join(test_dir, 'video_downsample.mp4')
94 | video_visualizer.set_frame_size(half_size)
95 | video_visualizer.visualize_directory(frame_dir, path)
96 |
97 | print('========== Finish Visualizer Test ==========')
98 |
--------------------------------------------------------------------------------