├── .gitignore ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── base_config.py ├── stylegan2_config.py ├── stylegan2_finetune_config.py ├── stylegan3_config.py ├── stylegan_config.py └── volumegan_ffhq_config.py ├── convert_model.py ├── converters ├── __init__.py ├── base_converter.py ├── pggan_converter.py ├── stylegan2_converter.py ├── stylegan2ada_pth_converter.py ├── stylegan2ada_tf_converter.py ├── stylegan3_converter.py └── stylegan_converter.py ├── datasets ├── __init__.py ├── base_dataset.py ├── data_loaders │ ├── __init__.py │ ├── base_data_loader.py │ ├── dali_batch_iterator.py │ ├── dali_data_loader.py │ ├── dali_pipeline.py │ ├── distributed_sampler.py │ └── iter_data_loader.py ├── file_readers │ ├── __init__.py │ ├── base_reader.py │ ├── directory_reader.py │ ├── lmdb_reader.py │ ├── tar_reader.py │ └── zip_reader.py ├── image_dataset.py ├── paired_dataset.py └── transformations │ ├── __init__.py │ ├── affine_transform.py │ ├── base_transformation.py │ ├── blur_and_sharpen.py │ ├── crop.py │ ├── decode.py │ ├── flip.py │ ├── hsv_jittering.py │ ├── identity.py │ ├── jpeg_compress.py │ ├── misc.py │ ├── normalize.py │ ├── region_brightness.py │ ├── resize.py │ └── utils │ ├── __init__.py │ ├── affine_transform.py │ └── polygon.py ├── docs ├── assets │ ├── bootstrap.min.css │ ├── comparison.png │ ├── font.css │ ├── framework.png │ ├── freezed.png │ ├── genforce.png │ ├── giraffe.png │ ├── graf.png │ ├── hologan.png │ ├── pigan.png │ ├── style.css │ └── teaser.png └── index.html ├── dump_command_args.py ├── metrics ├── __init__.py ├── base_gan_metric.py ├── base_metric.py ├── equivariance.py ├── fid.py ├── gan_pr.py ├── gan_snapshot.py ├── inception_score.py ├── intra_class_fid.py ├── kid.py └── utils.py ├── models ├── __init__.py ├── ghfeat_encoder.py ├── inception_model.py ├── perceptual_model.py ├── pggan_discriminator.py ├── pggan_generator.py ├── rendering │ ├── __init__.py │ ├── hierarchicle_sampling.py │ ├── points_sampling.py │ ├── renderer.py │ └── utils.py ├── stylegan2_discriminator.py ├── stylegan2_generator.py ├── stylegan3_generator.py ├── stylegan_discriminator.py ├── stylegan_generator.py ├── test.py ├── utils │ ├── __init__.py │ └── ops.py ├── volumegan_discriminator.py └── volumegan_generator.py ├── prepare_dataset.py ├── render.py ├── requirements ├── convert.txt ├── develop.txt └── minimal.txt ├── runners ├── __init__.py ├── augmentations │ ├── __init__.py │ ├── ada_aug.py │ └── no_aug.py ├── base_runner.py ├── controllers │ ├── __init__.py │ ├── ada_aug_controller.py │ ├── base_controller.py │ ├── batch_visualizer.py │ ├── cache_cleaner.py │ ├── checkpointer.py │ ├── dataset_visualizer.py │ ├── evaluator.py │ ├── lr_scheduler.py │ ├── progress_scheduler.py │ ├── running_logger.py │ └── timer.py ├── losses │ ├── __init__.py │ ├── base_loss.py │ ├── stylegan2_loss.py │ ├── stylegan3_loss.py │ ├── stylegan_loss.py │ └── volumegan_loss.py ├── stylegan2_runner.py ├── stylegan3_runner.py ├── stylegan_runner.py ├── utils │ ├── __init__.py │ ├── freezer.py │ ├── optimizer.py │ ├── profiler.py │ └── running_stats.py └── volumegan_runner.py ├── scripts ├── dist_train.sh ├── kill_zombies.sh ├── test_converters.sh ├── test_metrics.sh └── training_demos │ ├── stylegan2_ffhq1024.sh │ ├── stylegan2_ffhq256.sh │ ├── stylegan2_ffhq512.sh │ ├── stylegan2_lsun_bedroom256.sh │ ├── stylegan2ada_afhq512.sh │ ├── stylegan2ada_cifar10.sh │ ├── stylegan2ada_ffhq1024.sh │ ├── stylegan2ada_ffhq256.sh │ ├── stylegan3r_afhq512.sh │ ├── 
stylegan3r_ffhqu1024.sh │ ├── stylegan3r_ffhqu256.sh │ ├── stylegan3t_afhq512.sh │ ├── stylegan3t_ffhqu1024.sh │ ├── stylegan3t_ffhqu256.sh │ ├── stylegan_ffhq1024.sh │ ├── stylegan_ffhq256.sh │ ├── stylegan_ffhq512.sh │ ├── stylegan_lsun_bedroom256.sh │ └── volumegan_ffhq256.sh ├── test_metrics.py ├── third_party ├── __init__.py ├── stylegan2_official_ops │ ├── README.md │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── custom_ops.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── misc.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py └── stylegan3_official_ops │ ├── README.md │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── custom_ops.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── misc.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── train.py ├── unit_tests.py └── utils ├── __init__.py ├── dist_utils.py ├── file_transmitters ├── __init__.py ├── base_file_transmitter.py ├── dummy_file_transmitter.py └── local_file_transmitter.py ├── formatting_utils.py ├── image_utils.py ├── loggers ├── __init__.py ├── base_logger.py ├── dummy_logger.py ├── normal_logger.py ├── rich_logger.py └── test.py ├── misc.py ├── parsing_utils.py ├── tf_utils.py └── visualizers ├── __init__.py ├── gif_visualizer.py ├── grid_visualizer.py ├── html_visualizer.py ├── test.py └── video_visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore compiled files. 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # Ignore files created by IDEs. 6 | /.vscode/ 7 | /.idea/ 8 | .ipynb_*/ 9 | *.ipynb 10 | .DS_Store 11 | *.sw[pon] 12 | 13 | # Ignore result files within default working directory. 14 | /work_dirs/ 15 | 16 | # Ignore data files. 17 | data/ 18 | *.npy 19 | *.tar 20 | *.zip 21 | *.mdb 22 | 23 | # Ignore network files. 24 | checkpoints/ 25 | *.pth 26 | *.pt 27 | *.pkl 28 | *.h5 29 | *.dat 30 | 31 | # Ignore media files. 32 | results/ 33 | *.jpg 34 | *.png 35 | *.jpeg 36 | *.gif 37 | *.avi 38 | *.mp4 39 | 40 | # Ignore log files. 41 | resources/ 42 | events/ 43 | profile/ 44 | *.txt 45 | *.json 46 | *.log 47 | *.html 48 | events.* 49 | 50 | # Files that should not be ignored. 51 | !/requirements/* 52 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Shen" 5 | given-names: "Yujun" 6 | - family-names: "Zhang" 7 | given-names: "Zhiyi" 8 | - family-names: "Yang" 9 | given-names: "Dingdong" 10 | - family-names: "Xu" 11 | given-names: "Yinghao" 12 | - family-names: "Yang" 13 | given-names: "Ceyuan" 14 | - family-names: "Zhu" 15 | given-names: "Jiapeng" 16 | title: "Hammer: An Efficient Toolkit for Training Deep Models" 17 | version: 1.0.0 18 | date-released: 2022-02-08 19 | url: "https://github.com/bytedance/Hammer" 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VolumeGAN - 3D-aware Image Synthesis via Learning Structural and Textural Representations 2 | 3 | ![image](./docs/assets/framework.png) 4 | **Figure:** *Framework of VolumeGAN.* 5 | 6 | > **3D-aware Image Synthesis via Learning Structural and Textural Representations**
7 | > Yinghao Xu, Sida Peng, Ceyuan Yang, Yujun Shen, Bolei Zhou
8 | > *Computer Vision and Pattern Recognition (CVPR), 2022* 9 | 10 | [[Paper](https://arxiv.org/pdf/2112.10759.pdf)] 11 | [[Project Page](https://genforce.github.io/volumegan/)] 12 | [[Demo](https://www.youtube.com/watch?v=p85TVGJBMFc)] 13 | 14 | This paper aims at achieving high-fidelity 3D-aware image synthesis. We propose a novel framework, termed VolumeGAN, for synthesizing images under different camera views through explicitly learning a structural representation and a textural representation. We first learn a feature volume to represent the underlying structure, which is then converted to a feature field using a NeRF-like model. The feature field is further accumulated into a 2D feature map as the textural representation, followed by a neural renderer for appearance synthesis. Such a design enables independent control of shape and appearance. Extensive experiments on a wide range of datasets show that our approach achieves substantially higher image quality and better 3D control than previous methods. 15 | 16 | ## Usage 17 | 18 | ### Setup 19 | 20 | This repository is based on [Hammer](https://github.com/bytedance/Hammer), where you can find detailed instructions on setting up the environment. 21 | 22 | ### Test Demo 23 | 24 | ```shell 25 | python render.py \ 26 | --work_dir ${WORK_DIR} \ 27 | --checkpoint ${MODEL_PATH} \ 28 | --num ${NUM} \ 29 | --seed ${SEED} \ 30 | --render_mode ${RENDER_MODE} \ 31 | --generate_html ${SAVE_HTML} \ 32 | volumegan-ffhq 33 | ``` 34 | 35 | where 36 | 37 | - `WORK_DIR` refers to the path to save the results. 38 | - `MODEL_PATH` refers to the path of the pretrained model, for which we provide 39 | - [FFHQ-256](https://www.dropbox.com/s/ygwhufzwi2vb2t8/volumegan_ffhq256.pth?dl=0) 40 | - `NUM` refers to the number of samples to synthesize. 41 | - `SEED` refers to the random seed used for sampling. 42 | - `RENDER_MODE` refers to the type of the rendered results, which can be either `video` or `shape`. 43 | - `SAVE_HTML` controls whether to save images as an HTML page for better visualization when rendering videos. 44 | 45 | ### Training 46 | 47 | For example, users can use the following command to train VolumeGAN on FFHQ at a resolution of 256x256 48 | 49 | ```shell 50 | ./scripts/training_demos/volumegan_ffhq256.sh \ 51 | ${NUM_GPUS} \ 52 | ${DATA_PATH} \ 53 | [OPTIONS] 54 | ``` 55 | 56 | where 57 | 58 | - `NUM_GPUS` refers to the number of GPUs used for training. 59 | - `DATA_PATH` refers to the path to the dataset (`zip` format is strongly recommended). 60 | - `[OPTIONS]` refers to any additional options to pass. Detailed instructions on available options can be found via `python train.py volumegan-ffhq --help`. 61 | 62 | **NOTE:** This demo script uses `volumegan_ffhq256` as the default `job_name`, which is particularly used to identify experiments. Concretely, a directory named `job_name` will be created under the root working directory, which is set as `work_dirs/` by default. To prevent overwriting previous experiments, an exception will be raised to interrupt the training if the `job_name` directory already exists. Please use the `--job_name=${JOB_NAME}` option to specify a new job name.
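For instance, assuming 8 GPUs and a dataset archive located at `data/ffhq256.zip` (both values are illustrative, not shipped with this repository), a complete invocation may look like the sketch below, where a custom job name avoids clashing with a previous run:

```shell
# A hypothetical invocation of the training demo script: 8 GPUs, a local
# FFHQ archive, and a new job name to keep earlier results untouched.
./scripts/training_demos/volumegan_ffhq256.sh \
    8 \
    data/ffhq256.zip \
    --job_name=volumegan_ffhq256_rerun
```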
63 | 64 | ### Evaluation 65 | 66 | Users can use the following command to evaluate a well-trained model 67 | 68 | ```shell 69 | ./scripts/test_metrics.sh \ 70 | ${NUM_GPUS} \ 71 | ${DATA_PATH} \ 72 | ${MODEL_PATH} \ 73 | fid \ 74 | --G_kwargs '{"ps_kwargs":{"perturb_mode":"none"}}' \ 75 | [OPTIONS] 76 | ``` 77 | 78 | ## BibTeX 79 | 80 | ```bibtex 81 | @inproceedings{xu2021volumegan, 82 | title = {3D-aware Image Synthesis via Learning Structural and Textural Representations}, 83 | author = {Xu, Yinghao and Peng, Sida and Yang, Ceyuan and Shen, Yujun and Zhou, Bolei}, 84 | booktitle = {CVPR}, 85 | year = {2022} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all configs.""" 3 | 4 | from .stylegan_config import StyleGANConfig 5 | from .stylegan2_config import StyleGAN2Config 6 | from .stylegan2_finetune_config import StyleGAN2FineTuneConfig 7 | from .stylegan3_config import StyleGAN3Config 8 | from .volumegan_ffhq_config import VolumeGANFFHQConfig 9 | __all__ = ['CONFIG_POOL', 'build_config'] 10 | 11 | CONFIG_POOL = [ 12 | StyleGANConfig, 13 | StyleGAN2Config, 14 | StyleGAN2FineTuneConfig, 15 | StyleGAN3Config, 16 | VolumeGANFFHQConfig, 17 | ] 18 | 19 | 20 | def build_config(invoked_command, kwargs): 21 | """Builds a configuration based on the invoked command. 22 | 23 | Args: 24 | invoked_command: The command that is invoked. 25 | kwargs: Keyword arguments passed from command line, which will be used 26 | to build the configuration. 27 | 28 | Raises: 29 | ValueError: If the `invoked_command` is not supported. 30 | """ 31 | for config in CONFIG_POOL: 32 | if config.name == invoked_command: 33 | return config(kwargs) 34 | raise ValueError(f'Invoked command `{invoked_command}` is not supported!\n') 35 | -------------------------------------------------------------------------------- /convert_model.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Script to convert officially released models to match this repository.""" 3 | 4 | import os 5 | import argparse 6 | 7 | from converters import build_converter 8 | 9 | 10 | def parse_args(): 11 | """Parses arguments.""" 12 | parser = argparse.ArgumentParser(description='Convert pre-trained models.') 13 | parser.add_argument('model_type', type=str, 14 | choices=['pggan', 'stylegan', 'stylegan2', 15 | 'stylegan2ada_tf', 'stylegan2ada_pth', 16 | 'stylegan3'], 17 | help='Type of the model to convert.') 18 | parser.add_argument('--source_model_path', type=str, required=True, 19 | help='Path to load the model for conversion.') 20 | parser.add_argument('--target_model_path', type=str, required=True, 21 | help='Path to save the converted model.') 22 | parser.add_argument('--forward_test_num', type=int, default=10, 23 | help='Number of samples used for forward test. ' 24 | '(default: %(default)s)') 25 | parser.add_argument('--backward_test_num', type=int, default=0, 26 | help='Number of samples used for backward test. ' 27 | '(default: %(default)s)') 28 | parser.add_argument('--save_test_image', action='store_true', 29 | help='Whether to save the intermediate image in ' 30 | 'forward test. (default: False)') 31 | parser.add_argument('--learning_rate', type=float, default=0.01, 32 | help='Learning rate used in backward test. 
' 33 | '(default: %(default)s)') 34 | parser.add_argument('--verbose_log', action='store_true', 35 | help='Whether to print verbose log. (default: False)') 36 | return parser.parse_args() 37 | 38 | 39 | def main(): 40 | """Main function.""" 41 | args = parse_args() 42 | 43 | if os.path.exists(args.target_model_path): 44 | raise SystemExit(f'File `{args.target_model_path}` already ' 45 | f'exists!\n' 46 | f'Please specify another path.') 47 | 48 | converter = build_converter(args.model_type, verbose_log=args.verbose_log) 49 | converter.run(src_path=args.source_model_path, 50 | dst_path=args.target_model_path, 51 | forward_test_num=args.forward_test_num, 52 | backward_test_num=args.backward_test_num, 53 | save_test_image=args.save_test_image, 54 | learning_rate=args.learning_rate) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /converters/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all model converters.""" 3 | 4 | from .pggan_converter import PGGANConverter 5 | from .stylegan_converter import StyleGANConverter 6 | from .stylegan2_converter import StyleGAN2Converter 7 | from .stylegan2ada_tf_converter import StyleGAN2ADATFConverter 8 | from .stylegan2ada_pth_converter import StyleGAN2ADAPTHConverter 9 | from .stylegan3_converter import StyleGAN3Converter 10 | 11 | __all__ = ['build_converter'] 12 | 13 | _CONVERTERS = { 14 | 'pggan': PGGANConverter, 15 | 'stylegan': StyleGANConverter, 16 | 'stylegan2': StyleGAN2Converter, 17 | 'stylegan2ada_tf': StyleGAN2ADATFConverter, 18 | 'stylegan2ada_pth': StyleGAN2ADAPTHConverter, 19 | 'stylegan3': StyleGAN3Converter 20 | } 21 | 22 | 23 | def build_converter(model_type, verbose_log=False): 24 | """Builds a converter based on the model type. 25 | 26 | Args: 27 | model_type: Type of the model that the converter serves, which is case 28 | sensitive. 29 | verbose_log: Whether to print verbose log messages. (default: False) 30 | 31 | Raises: 32 | ValueError: If the `model_type` is not supported. 33 | """ 34 | if model_type not in _CONVERTERS: 35 | raise ValueError(f'Invalid model type: `{model_type}`!\n' 36 | f'Types allowed: {list(_CONVERTERS)}.') 37 | 38 | return _CONVERTERS[model_type](verbose_log=verbose_log) 39 | -------------------------------------------------------------------------------- /datasets/data_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all data loaders.""" 3 | 4 | import warnings 5 | 6 | from .iter_data_loader import IterDataLoader 7 | try: 8 | from .dali_data_loader import DALIDataLoader 9 | except ImportError: 10 | DALIDataLoader = None 11 | 12 | __all__ = ['build_data_loader'] 13 | 14 | _DATA_LOADERS_ALLOWED = ['iter', 'dali'] 15 | 16 | 17 | def build_data_loader(data_loader_type, 18 | dataset, 19 | batch_size, 20 | repeat=1, 21 | shuffle=True, 22 | seed=0, 23 | drop_last_sample=False, 24 | drop_last_batch=True, 25 | num_workers=0, 26 | prefetch_factor=2, 27 | pin_memory=False, 28 | num_threads=1): 29 | """Builds a data loader with the given dataset. 30 | 31 | Args: 32 | data_loader_type: Class type to which the data loader belongs, which is 33 | case insensitive. 34 | dataset: The dataset to load data from. 35 | batch_size: The batch size of the data produced by each replica. 36 | repeat: Number of times to repeat the entire dataset.
(default: 1) 37 | shuffle: Whether to shuffle the samples within each epoch. 38 | (default: True) 39 | seed: Random seed used for shuffling. (default: 0) 40 | drop_last_sample: Whether to drop the trailing samples that cannot be 41 | evenly distributed. (default: False) 42 | drop_last_batch: Whether to drop the last incomplete batch. 43 | (default: True) 44 | num_workers: Number of workers to prefetch data for each replica. 45 | (default: 0) 46 | prefetch_factor: Number of samples loaded in advance by each worker. 47 | `N` means there will be a total of `N * num_workers` samples 48 | prefetched across all workers. (default: 2) 49 | pin_memory: Whether to use pinned memory for loaded data. This field is 50 | particularly used for `IterDataLoader`. (default: False) 51 | num_threads: Number of threads for each replica. This field is 52 | particularly used for `DALIDataLoader`. (default: 1) 53 | 54 | Raises: 55 | ValueError: If `data_loader_type` is not supported. 56 | NotImplementedError: If `data_loader_type` is not implemented yet. 57 | """ 58 | data_loader_type = data_loader_type.lower() 59 | if data_loader_type not in _DATA_LOADERS_ALLOWED: 60 | raise ValueError(f'Invalid data loader type: `{data_loader_type}`!\n' 61 | f'Types allowed: {_DATA_LOADERS_ALLOWED}.') 62 | 63 | if data_loader_type == 'dali' and DALIDataLoader is None: 64 | warnings.warn('DALI is not supported in the current environment! ' 65 | 'Falling back to `IterDataLoader`.') 66 | data_loader_type = 'iter' 67 | 68 | if data_loader_type == 'dali' and not dataset.support_dali: 69 | warnings.warn('DALI is not supported by some transformation node of ' 70 | 'the dataset! Falling back to `IterDataLoader`.') 71 | data_loader_type = 'iter' 72 | 73 | if data_loader_type == 'iter': 74 | return IterDataLoader(dataset=dataset, 75 | batch_size=batch_size, 76 | repeat=repeat, 77 | shuffle=shuffle, 78 | seed=seed, 79 | drop_last_sample=drop_last_sample, 80 | drop_last_batch=drop_last_batch, 81 | num_workers=num_workers, 82 | prefetch_factor=prefetch_factor, 83 | pin_memory=pin_memory) 84 | if data_loader_type == 'dali': 85 | return DALIDataLoader(dataset=dataset, 86 | batch_size=batch_size, 87 | repeat=repeat, 88 | shuffle=shuffle, 89 | seed=seed, 90 | drop_last_sample=drop_last_sample, 91 | drop_last_batch=drop_last_batch, 92 | num_workers=num_workers, 93 | prefetch_factor=prefetch_factor, 94 | num_threads=num_threads) 95 | raise NotImplementedError(f'Not implemented data loader type ' 96 | f'`{data_loader_type}`!') 97 | -------------------------------------------------------------------------------- /datasets/data_loaders/dali_batch_iterator.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Wraps a batch-based iterator introduced in DALI. 3 | 4 | For more details, please refer to 5 | 6 | https://docs.nvidia.com/deeplearning/dali/user-guide/docs/ 7 | """ 8 | 9 | try: 10 | from nvidia.dali.plugin.base_iterator import LastBatchPolicy 11 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 12 | except ImportError as e: 13 | raise ImportError('DALI is not supported! Please install first.') from e 14 | 15 | __all__ = ['DALIBatchIterator'] 16 | 17 | 18 | class DALIBatchIterator(DALIGenericIterator): 19 | """Defines the batch iterator for DALI data pre-processing. 20 | 21 | Args: 22 | pipeline: The pre-defined pipeline for data pre-processing. 23 | batch_size: Number of samples for each batch. 24 | drop_last_batch: Whether to drop the last incomplete batch.
25 | (default: True) 26 | """ 27 | def __init__(self, pipeline, batch_size, drop_last_batch=True): 28 | self.batch_size = batch_size 29 | self.drop_last_batch = drop_last_batch 30 | 31 | if self.drop_last_batch: 32 | last_batch_padded = True 33 | last_batch_policy = LastBatchPolicy.DROP 34 | self.num_batches = len(pipeline) // batch_size 35 | else: 36 | last_batch_padded = False 37 | last_batch_policy = LastBatchPolicy.FILL 38 | self.num_batches = (len(pipeline) - 1) // batch_size + 1 39 | 40 | super().__init__(pipelines=pipeline, 41 | size=-1, 42 | auto_reset=False, 43 | output_map=pipeline.dataset.output_keys, 44 | last_batch_padded=last_batch_padded, 45 | last_batch_policy=last_batch_policy, 46 | prepare_first_batch=True) 47 | 48 | def __next__(self): 49 | # [0] means the first GPU. In the distributed case, each replica only 50 | # has one GPU. 51 | return super().__next__()[0] 52 | 53 | def __len__(self): 54 | return self.num_batches 55 | -------------------------------------------------------------------------------- /datasets/data_loaders/dali_data_loader.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of DALI-based data loader. 3 | 4 | For more details, please refer to 5 | 6 | https://docs.nvidia.com/deeplearning/dali/user-guide/docs/ 7 | """ 8 | 9 | from .dali_batch_iterator import DALIBatchIterator 10 | from .dali_pipeline import DALIPipeline 11 | from .distributed_sampler import DistributedSampler 12 | from .base_data_loader import BaseDataLoader 13 | 14 | __all__ = ['DALIDataLoader'] 15 | 16 | 17 | class DALIDataLoader(BaseDataLoader): 18 | """Defines the DALI-based data loader.""" 19 | 20 | def __init__(self, 21 | dataset, 22 | batch_size, 23 | repeat=1, 24 | shuffle=True, 25 | seed=0, 26 | drop_last_sample=False, 27 | drop_last_batch=True, 28 | num_workers=0, 29 | prefetch_factor=2, 30 | num_threads=1): 31 | """Initializes the data loader. 32 | 33 | Args: 34 | num_threads: Number of threads used for each replica.
(default: 1) 35 | """ 36 | self.num_threads = num_threads 37 | self._pipeline = None 38 | super().__init__(dataset=dataset, 39 | batch_size=batch_size, 40 | repeat=repeat, 41 | shuffle=shuffle, 42 | seed=seed, 43 | drop_last_sample=drop_last_sample, 44 | drop_last_batch=drop_last_batch, 45 | num_workers=num_workers, 46 | prefetch_factor=prefetch_factor) 47 | 48 | def __len__(self): 49 | return len(self.iter_loader) 50 | 51 | def build(self): 52 | self._sampler = DistributedSampler( 53 | dataset=self._dataset, 54 | shuffle=self.shuffle, 55 | repeat=self.repeat, 56 | seed=self.seed, 57 | drop_last_sample=self.drop_last_sample, 58 | for_dali=True) 59 | prefetch_queue_depth = max(1, self.num_workers * self.prefetch_factor) 60 | self._pipeline = DALIPipeline(dataset=self._dataset, 61 | sampler=self._sampler, 62 | batch_size=self.batch_size, 63 | seed=self.seed, 64 | num_workers=self.num_workers, 65 | num_threads=self.num_threads, 66 | prefetch_queue_depth=prefetch_queue_depth) 67 | self._iter_loader = DALIBatchIterator( 68 | pipeline=self._pipeline, 69 | batch_size=self.batch_size, 70 | drop_last_batch=self.drop_last_batch) 71 | 72 | def reset_iter_loader(self): 73 | self._iter_loader.reset() 74 | 75 | def info(self): 76 | data_loader_info = super().info() 77 | data_loader_info['Num threads'] = self.num_threads 78 | return data_loader_info 79 | -------------------------------------------------------------------------------- /datasets/data_loaders/dali_pipeline.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Wraps the data pre-processing pipeline introduced in DALI. 3 | 4 | DALI deploys the data pre-processing pipeline on GPU instead of CPU for 5 | acceleration. It relies on a pre-compiled graph. This file wraps this pipeline 6 | to fit the data loader. 7 | 8 | For more details, please refer to 9 | 10 | https://docs.nvidia.com/deeplearning/dali/user-guide/docs/ 11 | """ 12 | 13 | try: 14 | import dill 15 | 16 | import nvidia.dali.ops as ops 17 | from nvidia.dali.pipeline import Pipeline 18 | except ImportError as e: 19 | raise ImportError('DALI is not supported! Please install first.') from e 20 | 21 | __all__ = ['DALIPipeline'] 22 | 23 | 24 | class DALIPipeline(Pipeline): 25 | """Defines the pipeline for DALI data pre-processing. 26 | 27 | Args: 28 | dataset: Dataset to load data from. 29 | sampler: Index sampler. 30 | batch_size: Number of samples for each batch. 31 | seed: Seed for randomness in data pre-processing. (default: 0) 32 | num_workers: Number of workers to pre-fetch data (on CPU) from the 33 | dataset. (default: 0) 34 | num_threads: Number of threads used for data pre-processing by the 35 | current replica. (default: 1) 36 | prefetch_queue_depth: Prefetch queue depth. (default: 1) 37 | """ 38 | 39 | def __init__(self, 40 | dataset, 41 | sampler, 42 | batch_size, 43 | seed=0, 44 | num_workers=0, 45 | num_threads=1, 46 | prefetch_queue_depth=1): 47 | self._dataset = dataset 48 | self._sampler = sampler 49 | 50 | # Starting node of the data pre-processing graph.
51 | self.get_raw_data = ops.ExternalSource( 52 | source=self.sampler, 53 | num_outputs=self.dataset.num_raw_outputs, 54 | parallel=True, 55 | prefetch_queue_depth=prefetch_queue_depth) 56 | 57 | if seed >= 0: 58 | seed = seed * self.sampler.world_size + self.sampler.rank 59 | else: 60 | seed = -1 61 | 62 | if dataset.has_customized_function_for_dali: 63 | exec_pipelined = False 64 | exec_async = False 65 | num_workers = 0 66 | else: 67 | exec_pipelined = True 68 | exec_async = True 69 | super().__init__(batch_size=batch_size, 70 | num_threads=num_threads, 71 | device_id=self.sampler.rank, 72 | seed=seed, 73 | exec_pipelined=exec_pipelined, 74 | exec_async=exec_async, 75 | py_num_workers=num_workers, 76 | py_start_method='spawn', 77 | py_callback_pickler=dill) 78 | 79 | @property 80 | def dataset(self): 81 | """Returns the dataset.""" 82 | return self._dataset 83 | 84 | @property 85 | def sampler(self): 86 | """Returns the sampler.""" 87 | return self._sampler 88 | 89 | def define_graph(self): 90 | """Builds a static graph for data transformation.""" 91 | return self.dataset.define_dali_graph(self.get_raw_data()) 92 | 93 | def __len__(self): 94 | return len(self.sampler) 95 | -------------------------------------------------------------------------------- /datasets/data_loaders/iter_data_loader.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of iteration-based data loader.""" 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from .distributed_sampler import DistributedSampler 7 | from .base_data_loader import BaseDataLoader 8 | 9 | __all__ = ['IterDataLoader'] 10 | 11 | 12 | class IterDataLoader(BaseDataLoader): 13 | """Defines the iteration-based data loader.""" 14 | 15 | def __init__(self, 16 | dataset, 17 | batch_size, 18 | repeat=1, 19 | shuffle=True, 20 | seed=0, 21 | drop_last_sample=False, 22 | drop_last_batch=True, 23 | num_workers=0, 24 | prefetch_factor=2, 25 | pin_memory=False): 26 | """Initializes the data loader. 27 | 28 | Args: 29 | pin_memory: Whether to use pinned memory for loaded data. If `True`, 30 | it will be faster to move data from CPU to GPU; however, it may 31 | require a high-performance computing system.
(default: False) 32 | """ 33 | self.pin_memory = pin_memory 34 | self._batch_grouper = None 35 | super().__init__(dataset=dataset, 36 | batch_size=batch_size, 37 | repeat=repeat, 38 | shuffle=shuffle, 39 | seed=seed, 40 | drop_last_sample=drop_last_sample, 41 | drop_last_batch=drop_last_batch, 42 | num_workers=num_workers, 43 | prefetch_factor=prefetch_factor) 44 | 45 | def __len__(self): 46 | return len(self._batch_grouper) 47 | 48 | def build(self): 49 | self._sampler = DistributedSampler( 50 | dataset=self._dataset, 51 | repeat=self.repeat, 52 | shuffle=self.shuffle, 53 | seed=self.seed, 54 | drop_last_sample=self.drop_last_sample, 55 | for_dali=False) 56 | self._batch_grouper = DataLoader(dataset=self._dataset, 57 | batch_size=self.batch_size, 58 | sampler=self._sampler, 59 | shuffle=False, 60 | drop_last=self.drop_last_batch, 61 | num_workers=self.num_workers, 62 | pin_memory=self.pin_memory, 63 | prefetch_factor=self.prefetch_factor) 64 | self._iter_loader = iter(self._batch_grouper) 65 | 66 | def reset_iter_loader(self): 67 | self._iter_loader = iter(self._batch_grouper) 68 | 69 | def info(self): 70 | data_loader_info = super().info() 71 | data_loader_info['Pin memory'] = self.pin_memory 72 | return data_loader_info 73 | -------------------------------------------------------------------------------- /datasets/file_readers/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all file readers.""" 3 | 4 | from .directory_reader import DirectoryReader 5 | from .lmdb_reader import LmdbReader 6 | from .tar_reader import TarReader 7 | from .zip_reader import ZipReader 8 | 9 | __all__ = ['build_file_reader'] 10 | 11 | _READERS = { 12 | 'dir': DirectoryReader, 13 | 'lmdb': LmdbReader, 14 | 'tar': TarReader, 15 | 'zip': ZipReader 16 | } 17 | 18 | 19 | def build_file_reader(reader_type='zip'): 20 | """Builds a file reader. 21 | 22 | Args: 23 | reader_type: Type of the file reader, which is case insensitive. 24 | (default: `zip`) 25 | 26 | Raises: 27 | ValueError: If the `reader_type` is not supported. 28 | """ 29 | reader_type = reader_type.lower() 30 | if reader_type not in _READERS: 31 | raise ValueError(f'Invalid reader type: `{reader_type}`!\n' 32 | f'Types allowed: {list(_READERS)}.') 33 | return _READERS[reader_type] 34 | -------------------------------------------------------------------------------- /datasets/file_readers/base_reader.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the base class to read files. 3 | 4 | A file reader reads data from a given file and caches the file if possible. 5 | Typically, file readers are designed to read files from a ZIP archive, an LMDB database, or a directory. 6 | """ 7 | 8 | from utils.misc import IMAGE_EXTENSIONS 9 | from utils.misc import check_file_ext 10 | 11 | __all__ = ['BaseReader'] 12 | 13 | 14 | class BaseReader(object): 15 | """Defines the base file reader. 16 | 17 | A reader should have the following functions: 18 | 19 | (1) open(): The function to open (and cache) a given file/directory. 20 | (2) close(): The function to close a given file/directory. 21 | (3) open_anno_file(): The function to open a specific annotation file inside 22 | the given file/directory. 23 | (4) get_file_list(): The function to get the list of all files inside the 24 | given file/directory. The returned list is already sorted.
25 | (5) get_file_list_with_ext(): The function to get the list of files with 26 | expected file extensions inside the given file/directory. The returned 27 | list is already sorted. 28 | (6) get_image_list(): The function to get the list of all images inside the 29 | given file/directory. The returned list is already sorted. 30 | (7) fetch_file(): The function to fetch the bytes of a member file inside the 31 | given file/directory. 32 | """ 33 | 34 | @staticmethod 35 | def open(path): 36 | """Opens the given path.""" 37 | raise NotImplementedError('Should be implemented in derived class!') 38 | 39 | @staticmethod 40 | def close(path): 41 | """Closes the given path.""" 42 | raise NotImplementedError('Should be implemented in derived class!') 43 | 44 | @staticmethod 45 | def open_anno_file(path, anno_filename=None): 46 | """Opens the annotation file in `path` and returns a file pointer. 47 | 48 | If the annotation file does not exist, return `None`. 49 | """ 50 | raise NotImplementedError('Should be implemented in derived class!') 51 | 52 | @staticmethod 53 | def _get_file_list(path): 54 | """Gets the list of all files inside `path`.""" 55 | raise NotImplementedError('Should be implemented in derived class!') 56 | 57 | @classmethod 58 | def get_file_list(cls, path): 59 | """Gets the sorted list of all files inside `path`.""" 60 | return sorted(cls._get_file_list(path)) 61 | 62 | @classmethod 63 | def get_file_list_with_ext(cls, path, ext=None): 64 | """Gets the sorted list of files with expected extensions. 65 | 66 | NOTE: If no `ext` is specified, the list of all files will be returned. 67 | """ 68 | ext = ext or [] 69 | assert isinstance(ext, (list, tuple)) 70 | if len(ext) == 0: 71 | return cls.get_file_list(path) 72 | return [f for f in cls.get_file_list(path) if check_file_ext(f, *ext)] 73 | 74 | @classmethod 75 | def get_image_list(cls, path): 76 | """Gets the sorted list of image files inside `path`.""" 77 | return cls.get_file_list_with_ext(path, IMAGE_EXTENSIONS) 78 | 79 | @staticmethod 80 | def fetch_file(path, filename): 81 | """Fetches the bytes of file `filename` inside `path`. 82 | 83 | Example: 84 | 85 | >>> f = BaseReader.fetch_file('data', 'face.obj') 86 | >>> obj = f.decode('utf-8') # convert `bytes` to `str` 87 | """ 88 | raise NotImplementedError('Should be implemented in derived class!') 89 | -------------------------------------------------------------------------------- /datasets/file_readers/directory_reader.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of directory reader. 3 | 4 | This reader can summarize the file list or fetch bytes of files inside a directory. 5 | """ 6 | 7 | import os.path 8 | 9 | from .base_reader import BaseReader 10 | 11 | __all__ = ['DirectoryReader'] 12 | 13 | 14 | class DirectoryReader(BaseReader): 15 | """Defines a class to load a directory.""" 16 | 17 | @staticmethod 18 | def open(path): 19 | assert os.path.isdir(path), f'Directory `{path}` is invalid!' 20 | return path 21 | 22 | @staticmethod 23 | def close(path): 24 | _ = path # Dummy function. 25 | 26 | @staticmethod 27 | def open_anno_file(path, anno_filename=None): 28 | path = DirectoryReader.open(path) 29 | if not anno_filename: 30 | return None 31 | anno_path = os.path.join(path, anno_filename) 32 | if not os.path.isfile(anno_path): 33 | return None 34 | # File will be closed after being parsed in the dataset.
35 | return open(anno_path, 'r') 36 | 37 | @staticmethod 38 | def _get_file_list(path): 39 | path = DirectoryReader.open(path) 40 | return os.listdir(path) 41 | 42 | @staticmethod 43 | def fetch_file(path, filename): 44 | path = DirectoryReader.open(path) 45 | with open(os.path.join(path, filename), 'rb') as f: 46 | file_bytes = f.read() 47 | return file_bytes 48 | -------------------------------------------------------------------------------- /datasets/file_readers/lmdb_reader.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of LMDB database reader. 3 | 4 | This reader can summarize the file list or fetch bytes of files inside an LMDB 5 | database. 6 | """ 7 | 8 | import lmdb 9 | 10 | from .base_reader import BaseReader 11 | 12 | __all__ = ['LmdbReader'] 13 | 14 | 15 | class LmdbReader(BaseReader): 16 | """Defines a class to load an LMDB file. 17 | 18 | This is a static class, which is used to solve the problem that different 19 | data workers cannot share the same memory. 20 | """ 21 | 22 | reader_cache = dict() 23 | 24 | @staticmethod 25 | def open(path): 26 | """Opens an LMDB file.""" 27 | lmdb_files = LmdbReader.reader_cache 28 | if path not in lmdb_files: 29 | env = lmdb.open(path, 30 | max_readers=1, 31 | readonly=True, 32 | lock=False, 33 | readahead=False, 34 | meminit=False) 35 | with env.begin(write=False) as txn: 36 | num_samples = txn.stat()['entries'] 37 | keys = [key for key, _ in txn.cursor()] 38 | file_info = {'env': env, 39 | 'num_samples': num_samples, 40 | 'keys': keys} 41 | lmdb_files[path] = file_info 42 | return lmdb_files[path] 43 | 44 | @staticmethod 45 | def close(path): 46 | lmdb_files = LmdbReader.reader_cache 47 | lmdb_file = lmdb_files.pop(path, None) 48 | if lmdb_file is not None: 49 | lmdb_file['env'].close() 50 | lmdb_file.clear() 51 | 52 | @staticmethod 53 | def open_anno_file(path, anno_filename=None): 54 | # TODO: Support loading annotation file from LMDB. 55 | return None 56 | 57 | @staticmethod 58 | def _get_file_list(path): 59 | lmdb_file = LmdbReader.open(path) 60 | return lmdb_file['keys'] 61 | 62 | @classmethod 63 | def get_file_list_with_ext(cls, path, ext=None): 64 | # NOTE: In LMDB, keys do not reveal file extension. 65 | return cls.get_file_list(path) 66 | 67 | @classmethod 68 | def get_image_list(cls, path): 69 | # NOTE: In LMDB, keys do not reveal file extension. 70 | return cls.get_file_list(path) 71 | 72 | @staticmethod 73 | def fetch_file(path, filename): 74 | lmdb_file = LmdbReader.open(path) 75 | env = lmdb_file['env'] 76 | with env.begin(write=False) as txn: 77 | file_bytes = txn.get(filename) 78 | return file_bytes 79 | -------------------------------------------------------------------------------- /datasets/file_readers/tar_reader.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of TAR file reader. 3 | 4 | Basically, a TAR file will first be extracted under the same root directory as 5 | the source TAR file, and with the same base name as the source TAR file. For 6 | example, the TAR file `/home/data/test_data.tar.gz` will be extracted to 7 | `/home/data/test_data/`. Then, this file reader degenerates into 8 | `DirectoryReader`. 9 | 10 | NOTE: Using TAR files is not recommended. Please use ZIP files instead.
11 | """ 12 | 13 | import os.path 14 | import shutil 15 | import tarfile 16 | 17 | from .base_reader import BaseReader 18 | 19 | __all__ = ['TarReader'] 20 | 21 | 22 | class TarReader(BaseReader): 23 | """Defines a class to load TAR file. 24 | 25 | This is a static class, which is used to solve the problem that different 26 | data workers cannot share the same memory. 27 | """ 28 | 29 | reader_cache = dict() 30 | 31 | @staticmethod 32 | def open(path): 33 | tar_files = TarReader.reader_cache 34 | if path not in tar_files: 35 | root_dir = os.path.dirname(path) 36 | base_dir = os.path.basename(path).split('.tar')[0] 37 | extract_dir = os.path.join(root_dir, base_dir) 38 | filenames = [] 39 | with tarfile.open(path, 'r') as f: 40 | for member in f.getmembers(): 41 | if member.isfile(): 42 | filenames.append(member.name) 43 | f.extractall(extract_dir) 44 | file_info = {'extract_dir': extract_dir, 45 | 'filenames': filenames} 46 | tar_files[path] = file_info 47 | return tar_files[path] 48 | 49 | @staticmethod 50 | def close(path): 51 | tar_files = TarReader.reader_cache 52 | tar_file = tar_files.pop(path, None) 53 | if tar_file is not None: 54 | extract_dir = tar_file['extract_dir'] 55 | shutil.rmtree(extract_dir) 56 | tar_file.clear() 57 | 58 | @staticmethod 59 | def open_anno_file(path, anno_filename=None): 60 | tar_file = TarReader.open(path) 61 | if not anno_filename: 62 | return None 63 | anno_path = os.path.join(tar_file['extract_dir'], anno_filename) 64 | if not os.path.isfile(anno_path): 65 | return None 66 | # File will be closed after parsed in dataset. 67 | return open(anno_path, 'r') 68 | 69 | @staticmethod 70 | def _get_file_list(path): 71 | tar_file = TarReader.open(path) 72 | return tar_file['filenames'] 73 | 74 | @staticmethod 75 | def fetch_file(path, filename): 76 | tar_file = TarReader.open(path) 77 | with open(os.path.join(tar_file['extract_dir'], filename), 'rb') as f: 78 | file_bytes = f.read() 79 | return file_bytes 80 | -------------------------------------------------------------------------------- /datasets/file_readers/zip_reader.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of ZIP file reader. 3 | 4 | This reader can summarize file list or fetch bytes of files inside a ZIP. 5 | """ 6 | 7 | import zipfile 8 | 9 | from .base_reader import BaseReader 10 | 11 | __all__ = ['ZipReader'] 12 | 13 | 14 | class ZipReader(BaseReader): 15 | """Defines a class to load ZIP file. 16 | 17 | This is a static class, which is used to solve the problem that different 18 | data workers cannot share the same memory. 19 | """ 20 | 21 | reader_cache = dict() 22 | 23 | @staticmethod 24 | def open(path): 25 | zip_files = ZipReader.reader_cache 26 | if path not in zip_files: 27 | # File will be closed by calling `cls.close()`. 28 | zip_files[path] = zipfile.ZipFile(path, 'r') # pylint: disable=consider-using-with 29 | return zip_files[path] 30 | 31 | @staticmethod 32 | def close(path): 33 | zip_files = ZipReader.reader_cache 34 | zip_file = zip_files.pop(path, None) 35 | if zip_file is not None: 36 | zip_file.close() 37 | 38 | @staticmethod 39 | def open_anno_file(path, anno_filename=None): 40 | zip_file = ZipReader.open(path) 41 | if not anno_filename: 42 | return None 43 | if anno_filename not in zip_file.namelist(): 44 | return None 45 | # File will be closed after parsed in dataset. 
46 | return zip_file.open(anno_filename, 'r') 47 | 48 | @staticmethod 49 | def _get_file_list(path): 50 | zip_file = ZipReader.open(path) 51 | return [ 52 | info.filename for info in zip_file.infolist() if not info.is_dir()] 53 | 54 | @staticmethod 55 | def fetch_file(path, filename): 56 | zip_file = ZipReader.open(path) 57 | return zip_file.read(filename) 58 | -------------------------------------------------------------------------------- /datasets/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all transformations for data pre-processing.""" 3 | 4 | from .affine_transform import AffineTransform 5 | from .blur_and_sharpen import BlurAndSharpen 6 | from .crop import CenterCrop 7 | from .crop import RandomCrop 8 | from .crop import LongSideCrop 9 | from .decode import Decode 10 | from .flip import Flip 11 | from .hsv_jittering import HSVJittering 12 | from .identity import Identity 13 | from .jpeg_compress import JpegCompress 14 | from .normalize import Normalize 15 | from .region_brightness import RegionBrightness 16 | from .resize import Resize 17 | from .resize import ProgressiveResize 18 | from .resize import ResizeAug 19 | 20 | __all__ = ['build_transformation'] 21 | 22 | 23 | _TRANSFORMATIONS = { 24 | 'AffineTransform': AffineTransform, 25 | 'BlurAndSharpen': BlurAndSharpen, 26 | 'CenterCrop': CenterCrop, 27 | 'RandomCrop': RandomCrop, 28 | 'LongSideCrop': LongSideCrop, 29 | 'Decode': Decode, 30 | 'Flip': Flip, 31 | 'HSVJittering': HSVJittering, 32 | 'Identity': Identity, 33 | 'JpegCompress': JpegCompress, 34 | 'Normalize': Normalize, 35 | 'RegionBrightness': RegionBrightness, 36 | 'Resize': Resize, 37 | 'ProgressiveResize': ProgressiveResize, 38 | 'ResizeAug': ResizeAug 39 | } 40 | 41 | 42 | def build_transformation(transform_type, **kwargs): 43 | """Builds a transformation based on its class type. 44 | 45 | Args: 46 | transform_type: Class type to which the transformation belongs, 47 | which is case sensitive. 48 | **kwargs: Additional arguments to build the transformation. 49 | 50 | Raises: 51 | ValueError: If the `transform_type` is not supported. 52 | """ 53 | if transform_type not in _TRANSFORMATIONS: 54 | raise ValueError(f'Invalid transformation type: ' 55 | f'`{transform_type}`!\n' 56 | f'Types allowed: {list(_TRANSFORMATIONS)}.') 57 | return _TRANSFORMATIONS[transform_type](**kwargs) 58 | -------------------------------------------------------------------------------- /datasets/transformations/flip.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Implements image flipping.""" 3 | 4 | import numpy as np 5 | 6 | try: 7 | import nvidia.dali.fn as fn 8 | except ImportError: 9 | fn = None 10 | 11 | from .base_transformation import BaseTransformation 12 | 13 | __all__ = ['Flip'] 14 | 15 | 16 | class Flip(BaseTransformation): 17 | """Applies random flipping to images. 18 | 19 | Args: 20 | horizontal_prob: Probability of flipping images horizontally. 21 | (default: 0.0) 22 | vertical_prob: Probability of flipping images vertically. 
(default: 0.0) 23 | """ 24 | 25 | def __init__(self, horizontal_prob=0.0, vertical_prob=0.0): 26 | super().__init__(support_dali=(fn is not None)) 27 | 28 | self.horizontal_prob = np.clip(horizontal_prob, 0, 1) 29 | self.vertical_prob = np.clip(vertical_prob, 0, 1) 30 | 31 | def _CPU_forward(self, data): 32 | do_horizontal = np.random.uniform() < self.horizontal_prob 33 | do_vertical = np.random.uniform() < self.vertical_prob 34 | 35 | # Early return if no flipping is applied. 36 | if not do_horizontal and not do_vertical: 37 | return data 38 | 39 | outputs = [] 40 | for image in data: 41 | if do_horizontal: 42 | image = image[:, ::-1] 43 | if do_vertical: 44 | image = image[::-1, :] 45 | outputs.append(np.ascontiguousarray(image)) 46 | return outputs 47 | 48 | def _DALI_forward(self, data): 49 | do_horizontal = fn.random.coin_flip(probability=self.horizontal_prob) 50 | do_vertical = fn.random.coin_flip(probability=self.vertical_prob) 51 | return fn.flip(data, horizontal=do_horizontal, vertical=do_vertical) 52 | -------------------------------------------------------------------------------- /datasets/transformations/hsv_jittering.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Implements image color jittering from the HSV space.""" 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | try: 8 | import nvidia.dali.fn as fn 9 | import nvidia.dali.types as types 10 | except ImportError: 11 | fn = None 12 | 13 | from utils.formatting_utils import format_range 14 | from .base_transformation import BaseTransformation 15 | 16 | __all__ = ['HSVJittering'] 17 | 18 | 19 | class HSVJittering(BaseTransformation): 20 | """Applies random color jittering to images from the HSV space. 21 | 22 | Args: 23 | h_range: The range within which to uniformly sample a hue. Use `(0, 0)` 24 | to disable the hue jittering. (default: (0, 0)) 25 | s_range: The range within which to uniformly sample a saturation. Use 26 | `(1, 1)` to disable the saturation jittering. (default: (1, 1)) 27 | v_range: The range within which to uniformly sample a brightness value. 28 | Use `(1, 1)` to disable the brightness jittering. (default: (1, 1)) 29 | """ 30 | 31 | def __init__(self, h_range=(0, 0), s_range=(1, 1), v_range=(1, 1)): 32 | super().__init__(support_dali=(fn is not None)) 33 | 34 | self.h_range = format_range(h_range) 35 | self.s_range = format_range(s_range, min_val=0) 36 | self.v_range = format_range(v_range, min_val=0) 37 | 38 | def _CPU_forward(self, data): 39 | # Early return if no jittering is needed. 40 | if (self.h_range == (0, 0) and self.s_range == (1, 1) and 41 | self.v_range == (1, 1)): 42 | return data 43 | 44 | # Get random jittering value for hue, saturation, and brightness. 45 | hue = np.random.uniform(*self.h_range) 46 | sat = np.random.uniform(*self.s_range) 47 | val = np.random.uniform(*self.v_range) 48 | 49 | # Perform color jittering. 50 | outputs = [] 51 | for image in data: 52 | assert image.shape[2] == 3, 'RGB image is expected!' 53 | h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2HSV)) 54 | h = ((h + hue) % 180).astype(np.uint8) 55 | s = np.clip(s * sat, 0, 255).astype(np.uint8) 56 | v = np.clip(v * val, 0, 255).astype(np.uint8) 57 | new_image = cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2RGB) 58 | outputs.append(new_image) 59 | return outputs 60 | 61 | def _DALI_forward(self, data): 62 | # Early return if no jittering is needed.
63 | if (self.h_range == (0, 0) and self.s_range == (1, 1) and 64 | self.v_range == (1, 1)): 65 | return data 66 | 67 | # Get random jittering value for hue, saturation, and brightness. 68 | if self.h_range[0] == self.h_range[1]: 69 | hue = self.h_range[0] 70 | else: 71 | hue = fn.random.uniform(range=self.h_range) 72 | hue = fn.cast(hue, dtype=types.FLOAT) 73 | if self.s_range[0] == self.s_range[1]: 74 | sat = self.s_range[0] 75 | else: 76 | sat = fn.random.uniform(range=self.s_range) 77 | sat = fn.cast(sat, dtype=types.FLOAT) 78 | if self.v_range[0] == self.v_range[1]: 79 | val = self.v_range[0] 80 | else: 81 | val = fn.random.uniform(range=self.v_range) 82 | val = fn.cast(val, dtype=types.FLOAT) 83 | 84 | # Perform color jittering. 85 | return fn.hsv(data, hue=hue, saturation=sat, value=val) 86 | -------------------------------------------------------------------------------- /datasets/transformations/identity.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Implements an identity transformation, which can be used as a placeholder.""" 3 | 4 | from .base_transformation import BaseTransformation 5 | 6 | __all__ = ['Identity'] 7 | 8 | 9 | class Identity(BaseTransformation): 10 | """Applies no transformation by directly returning the input.""" 11 | 12 | def __init__(self): 13 | super().__init__(support_dali=True) 14 | 15 | def _CPU_forward(self, data): 16 | return data 17 | 18 | def _DALI_forward(self, data): 19 | return data 20 | -------------------------------------------------------------------------------- /datasets/transformations/jpeg_compress.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Implements JPEG compression on images.""" 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | try: 8 | import nvidia.dali.fn as fn 9 | import nvidia.dali.types as types 10 | except ImportError: 11 | fn = None 12 | 13 | from utils.formatting_utils import format_range 14 | from .base_transformation import BaseTransformation 15 | 16 | __all__ = ['JpegCompress'] 17 | 18 | 19 | class JpegCompress(BaseTransformation): 20 | """Applies random JPEG compression to images. 21 | 22 | This transformation can be used as an augmentation by distorting images. 23 | In other words, the input image(s) will be first compressed (i.e., encoded) 24 | with a random quality ratio, and then decoded back to the image space. 25 | The distortion is introduced in the encoding process. 26 | 27 | Args: 28 | quality_range: The range within which to uniformly sample a quality 29 | value after compression. 100 means highest and 0 means lowest. 30 | (default: (40, 60)) 31 | prob: Probability of applying JPEG compression. (default: 0.5) 32 | """ 33 | 34 | def __init__(self, prob=0.5, quality_range=(40, 60)): 35 | super().__init__(support_dali=(fn is not None)) 36 | 37 | self.prob = np.clip(prob, 0, 1) 38 | self.quality_range = format_range(quality_range, min_val=0, max_val=100) 39 | 40 | def _CPU_forward(self, data): 41 | # Early return if no compression is needed. 42 | if np.random.uniform() >= self.prob: 43 | return data 44 | 45 | # Set compression quality. 46 | quality = np.random.randint(*self.quality_range) 47 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] 48 | 49 | # Compress images. 
50 | outputs = [] 51 | for image in data: 52 | _, encoded_image = cv2.imencode('.jpg', image, encode_param) 53 | decoded_image = cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED) 54 | if decoded_image.ndim == 2: 55 | decoded_image = decoded_image[:, :, np.newaxis] 56 | outputs.append(decoded_image) 57 | return outputs 58 | 59 | def _DALI_forward(self, data): 60 | # Set compression quality. 61 | if self.quality_range[0] == self.quality_range[1]: 62 | quality = self.quality_range[0] 63 | else: 64 | quality = fn.random.uniform(range=self.quality_range) 65 | quality = fn.cast(quality, dtype=types.INT32) 66 | 67 | # Compress images. 68 | compressed_data = fn.jpeg_compression_distortion( 69 | data, quality=quality) 70 | if not isinstance(compressed_data, (list, tuple)): 71 | compressed_data = [compressed_data] 72 | 73 | # Determine whether the transformation should be applied. 74 | cond = fn.random.coin_flip(dtype=types.BOOL, probability=self.prob) 75 | outputs = [] 76 | for image, compressed_image in zip(data, compressed_data): 77 | outputs.append(compressed_image * cond + image * (cond ^ True)) 78 | return outputs 79 | -------------------------------------------------------------------------------- /datasets/transformations/misc.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Helper functions for data transformation.""" 3 | 4 | try: 5 | import nvidia.dali.fn as fn 6 | import nvidia.dali.types as types 7 | except ImportError: 8 | fn = None 9 | 10 | __all__ = ['switch_between', 'FunctionOp'] 11 | 12 | 13 | def switch_between(cond, cond_true, cond_false, use_dali=False): 14 | """Switches between two transformation nodes for data pre-processing. 15 | 16 | Args: 17 | cond: Condition to switch between two alternatives. 18 | cond_true: The returned value if the condition is fulfilled. 19 | cond_false: The returned value if the condition fails. 20 | use_dali: Whether the nodes are from the DALI pre-processing pipeline. 21 | (default: False) 22 | 23 | Returns: 24 | One of `cond_true` and `cond_false`, depending on `cond`. 25 | """ 26 | if use_dali and fn is None: 27 | raise NotImplementedError('DALI is not supported! ' 28 | 'Please install first.') 29 | 30 | if not use_dali: 31 | return cond_true if cond else cond_false 32 | 33 | # Record whether any input (cond_true/cond_false) is not a list. If that is 34 | # the case, the returned value will be a single node. Otherwise, the 35 | # returned value will also be a list of nodes. 36 | is_input_list = True 37 | if not isinstance(cond_true, (list, tuple)): 38 | is_input_list = False 39 | cond_true = [cond_true] 40 | if not isinstance(cond_false, (list, tuple)): 41 | is_input_list = False 42 | cond_false = [cond_false] 43 | assert len(cond_true) == len(cond_false) 44 | 45 | cond = fn.cast(cond, dtype=types.BOOL) 46 | outputs = [] 47 | for sample_true, sample_false in zip(cond_true, cond_false): 48 | outputs.append(sample_true * cond + sample_false * (cond ^ True)) 49 | return outputs if is_input_list else outputs[0] 50 | 51 | 52 | class FunctionOp(object): 53 | """Contains the class to turn a function into an operator. 54 | 55 | DALI supports creating a data node, which is populated with data from an 56 | external source function. This function should be callable, accepting one 57 | positional argument. This class is particularly designed to turn a function, 58 | with default settings, into a DALI-compatible operator. Please refer to 59 | `nvidia.dali.fn.external_source()` for more details.
60 | 61 | More concretely, a function `f(a, b)` with desired arguments `a=1, b=2` can 62 | be wrapped with 63 | 64 | ``` 65 | op = FunctionOp(f, a=1, b=2) 66 | ``` 67 | 68 | Then, it can be used with 69 | 70 | ``` 71 | dali_node = fn.external_source(source=op, 72 | parallel=True, 73 | prefetch_queue_depth=32, 74 | batch=False) 75 | ``` 76 | 77 | Or, it can be called directly with 78 | 79 | ``` 80 | result = op() 81 | ``` 82 | """ 83 | 84 | def __init__(self, function, **kwargs): 85 | """Initializes with the function to wrap and the default arguments.""" 86 | self.function = function 87 | self.kwargs = kwargs 88 | 89 | def __call__(self, _dali_arg=None): 90 | return self.function(**self.kwargs) 91 | -------------------------------------------------------------------------------- /datasets/transformations/normalize.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Implements image normalization.""" 3 | 4 | import numpy as np 5 | 6 | try: 7 | import nvidia.dali.fn as fn 8 | import nvidia.dali.types as types 9 | except ImportError: 10 | fn = None 11 | 12 | from .base_transformation import BaseTransformation 13 | 14 | __all__ = ['Normalize'] 15 | 16 | 17 | class Normalize(BaseTransformation): 18 | """Normalizes images. 19 | 20 | The input images are expected to have pixel values within range [0, 255]. 21 | 22 | The output images will have data format `CHW` and dtype `float32`. 23 | 24 | Args: 25 | min_val: The minimum value after normalization. (default: -1.0) 26 | max_val: The maximum value after normalization. (default: 1.0) 27 | """ 28 | 29 | def __init__(self, min_val=-1.0, max_val=1.0): 30 | super().__init__(support_dali=(fn is not None)) 31 | 32 | self.min_val = float(min_val) 33 | self.max_val = float(max_val) 34 | 35 | def _CPU_forward(self, data): 36 | outputs = [] 37 | for image in data: 38 | image = image.astype(np.float32) 39 | image = image / 255 * (self.max_val - self.min_val) + self.min_val 40 | image = image.transpose(2, 0, 1) 41 | outputs.append(image) 42 | return outputs 43 | 44 | def _DALI_forward(self, data): 45 | return fn.crop_mirror_normalize( 46 | data, 47 | dtype=types.FLOAT, 48 | output_layout=types.NCHW, 49 | scale=(self.max_val - self.min_val) / 255.0, 50 | shift=self.min_val) 51 | -------------------------------------------------------------------------------- /datasets/transformations/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects dataset-related utility functions.""" 3 | 4 | from .affine_transform import generate_affine_transformation 5 | from .polygon import generate_polygon_contour 6 | from .polygon import generate_polygon_mask 7 | 8 | __all__ = [ 9 | 'generate_affine_transformation', 'generate_polygon_contour', 10 | 'generate_polygon_mask' 11 | ] 12 | -------------------------------------------------------------------------------- /datasets/transformations/utils/affine_transform.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the function to generate a random affine transformation.""" 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from utils.formatting_utils import format_range 8 | from utils.formatting_utils import format_image_size 9 | 10 | __all__ = ['generate_affine_transformation'] 11 | 12 | 13 | def generate_affine_transformation(image_size, 14 | rotation_range, 15 | scale_range, 16 | tx_range, 17 | ty_range): 18 | """Generates a random affine
transformation matrix. 19 | 20 | Args: 21 | image_size: Size of the image, which is used as a reference of the 22 | transformation. The size is assumed to be in (height, width) order. 23 | rotation_range: The range (in degrees) within which to uniformly sample 24 | a rotation angle. 25 | scale_range: The range within which to uniformly sample a scaling 26 | factor. 27 | tx_range: The range (in units of image width) within which to uniformly 28 | sample an X translation. 29 | ty_range: The range (in units of image height) within which to 30 | uniformly sample a Y translation. 31 | 32 | Returns: 33 | A transformation matrix, with shape [2, 3] and dtype `numpy.float32`. 34 | """ 35 | # Regularize inputs. 36 | height, width = format_image_size(image_size) 37 | rotation_range = format_range(rotation_range) 38 | scale_range = format_range(scale_range, min_val=0) 39 | tx_range = format_range(tx_range) 40 | ty_range = format_range(ty_range) 41 | 42 | # Sample parameters for the affine transformation. 43 | rotation = np.random.uniform(*rotation_range) 44 | scale = np.random.uniform(*scale_range) 45 | tx = np.random.uniform(*tx_range) 46 | ty = np.random.uniform(*ty_range) 47 | 48 | # Get the transformation matrix. 49 | matrix = cv2.getRotationMatrix2D(center=(width // 2, height // 2), 50 | angle=rotation, 51 | scale=scale) 52 | matrix[:, 2] += (tx * width, ty * height) 53 | 54 | return matrix.astype(np.float32) 55 | -------------------------------------------------------------------------------- /docs/assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/comparison.png -------------------------------------------------------------------------------- /docs/assets/font.css: -------------------------------------------------------------------------------- 1 | /* Homepage Font */ 2 | 3 | /* latin-ext */ 4 | @font-face { 5 | font-family: 'Lato'; 6 | font-style: normal; 7 | font-weight: 400; 8 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2'); 9 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 10 | } 11 | 12 | /* latin */ 13 | @font-face { 14 | font-family: 'Lato'; 15 | font-style: normal; 16 | font-weight: 400; 17 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2'); 18 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 19 | } 20 | 21 | /* latin-ext */ 22 | @font-face { 23 | font-family: 'Lato'; 24 | font-style: normal; 25 | font-weight: 700; 26 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2'); 27 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 28 | } 29 | 30 | /* latin */ 31 | @font-face { 32 | font-family: 'Lato'; 33 | font-style: normal; 34 | font-weight: 700; 35 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2'); 36 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193,
U+2212, U+2215, U+FEFF, U+FFFD; 37 | } 38 | -------------------------------------------------------------------------------- /docs/assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/framework.png -------------------------------------------------------------------------------- /docs/assets/freezed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/freezed.png -------------------------------------------------------------------------------- /docs/assets/genforce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/genforce.png -------------------------------------------------------------------------------- /docs/assets/giraffe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/giraffe.png -------------------------------------------------------------------------------- /docs/assets/graf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/graf.png -------------------------------------------------------------------------------- /docs/assets/hologan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/hologan.png -------------------------------------------------------------------------------- /docs/assets/pigan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/pigan.png -------------------------------------------------------------------------------- /docs/assets/style.css: -------------------------------------------------------------------------------- 1 | /* Body */ 2 | body { 3 | background: #e3e5e8; 4 | color: #ffffff; 5 | font-family: 'Lato', Verdana, Helvetica, sans-serif; 6 | font-weight: 300; 7 | font-size: 14pt; 8 | } 9 | 10 | /* Hyperlinks */ 11 | a {text-decoration: none;} 12 | a:link {color: #1772d0;} 13 | a:visited {color: #1772d0;} 14 | a:active {color: red;} 15 | a:hover {color: #f09228;} 16 | 17 | /* Pre-formatted Text */ 18 | pre { 19 | margin: 5pt 0; 20 | border: 0; 21 | font-size: 12pt; 22 | background: #fcfcfc; 23 | } 24 | 25 | /* Project Page Style */ 26 | /* Section */ 27 | .section { 28 | width: 768pt; 29 | min-height: 100pt; 30 | margin: 15pt auto; 31 | padding: 20pt 30pt; 32 | border: 1pt hidden #000; 33 | text-align: justify; 34 | color: #000000; 35 | background: #ffffff; 36 | } 37 | 38 | /* Header (Title and Logo) */ 39 | .section .header { 40 | min-height: 80pt; 41 | margin-top: 30pt; 42 | } 43 | .section .header .logo { 44 | width: 80pt; 45 | margin-left: 10pt; 46 | float: left; 47 | } 48 | .section .header .logo img { 49 | width: 80pt; 50 | object-fit: cover; 51 | } 52 | .section .header .title { 53 | margin: 0 120pt; 54 | text-align: center; 55 | font-size: 22pt; 56 | } 57 | 58 | /* 
Author */ 59 | .section .author { 60 | margin: 5pt 0; 61 | text-align: center; 62 | font-size: 16pt; 63 | } 64 | 65 | /* Institution */ 66 | .section .institution { 67 | margin: 5pt 0; 68 | text-align: center; 69 | font-size: 16pt; 70 | } 71 | 72 | /* Hyperlink (such as Paper and Code) */ 73 | .section .link { 74 | margin: 5pt 0; 75 | text-align: center; 76 | font-size: 16pt; 77 | } 78 | 79 | /* Teaser */ 80 | .section .teaser { 81 | margin: 20pt 0; 82 | text-align: center; 83 | } 84 | .section .teaser img { 85 | width: 95%; 86 | } 87 | 88 | /* Section Title */ 89 | .section .title { 90 | text-align: center; 91 | font-size: 22pt; 92 | margin: 5pt 0 15pt 0; /* top right bottom left */ 93 | } 94 | 95 | /* Section Body */ 96 | .section .body { 97 | margin-bottom: 15pt; 98 | text-align: justify; 99 | font-size: 14pt; 100 | } 101 | 102 | /* BibTeX */ 103 | .section .bibtex { 104 | margin: 5pt 0; 105 | text-align: left; 106 | font-size: 22pt; 107 | } 108 | 109 | /* Related Work */ 110 | .section .ref { 111 | margin: 20pt 0 10pt 0; /* top right bottom left */ 112 | text-align: left; 113 | font-size: 18pt; 114 | font-weight: bold; 115 | } 116 | 117 | /* Citation */ 118 | .section .citation { 119 | min-height: 60pt; 120 | margin: 10pt 0; 121 | } 122 | .section .citation .image { 123 | width: 120pt; 124 | float: left; 125 | } 126 | .section .citation .image img { 127 | max-height: 60pt; 128 | width: 120pt; 129 | object-fit: cover; 130 | } 131 | .section .citation .comment{ 132 | margin-left: 130pt; 133 | text-align: left; 134 | font-size: 14pt; 135 | } 136 | -------------------------------------------------------------------------------- /docs/assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/docs/assets/teaser.png -------------------------------------------------------------------------------- /dump_command_args.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Dumps available arguments of all commands (configurations). 3 | 4 | This file parses the arguments of all commands provided in `configs/` and dumps 5 | the results as a JSON file. Each parsed argument includes the name, argument 6 | type, default value, and the help message (description).
The dumped file looks 7 | like 8 | 9 | { 10 | "command_1": { 11 | "type": "object", 12 | "properties": { 13 | "arg_group_1": { 14 | "type": "object", 15 | "properties": { 16 | "arg_1": { 17 | "is_recommended": # true / false 18 | "type": # int / float / bool / str / json-string / 19 | # index-string 20 | "default": 21 | "description": 22 | }, 23 | "arg_2": { 24 | "is_recommended": 25 | "type": 26 | "default": 27 | "description": 28 | } 29 | } 30 | }, 31 | "arg_group_2": { 32 | "type": "object", 33 | "properties": { 34 | "arg_3": { 35 | "is_recommended": 36 | "type": 37 | "default": 38 | "description": 39 | }, 40 | "arg_4": { 41 | "is_recommended": 42 | "type": 43 | "default": 44 | "description": 45 | } 46 | } 47 | } 48 | } 49 | }, 50 | "command_2": { 51 | "type": "object", 52 | "properties": { 53 | "arg_group_1": { 54 | "type": "object", 55 | "properties": { 56 | "arg_1": { 57 | "is_recommended": 58 | "type": 59 | "default": 60 | "description": 61 | } 62 | } 63 | } 64 | } 65 | } 66 | } 67 | """ 68 | 69 | import sys 70 | import json 71 | 72 | from configs import CONFIG_POOL 73 | 74 | 75 | def parse_args_from_config(config): 76 | """Parses available arguments from a configuration class. 77 | 78 | Args: 79 | config: The configuration class to parse arguments from, which is 80 | defined in `configs/`. This class is supposed to derive from 81 | `BaseConfig` defined in `configs/base_config.py`. 82 | """ 83 | recommended_opts = config.get_recommended_options() 84 | args = dict() 85 | for opt_group, opts in config.get_options().items(): 86 | args[opt_group] = dict( 87 | type='object', 88 | properties=dict() 89 | ) 90 | for opt in opts: 91 | arg = config.inspect_option(opt) 92 | args[opt_group]['properties'][arg.name] = dict( 93 | is_recommended=arg.name in recommended_opts, 94 | type=arg.type, 95 | default=arg.default, 96 | description=arg.help 97 | ) 98 | return args 99 | 100 | 101 | def dump(configs, save_path): 102 | """Dumps available arguments from given configurations to a target file. 103 | 104 | Args: 105 | configs: A list of configurations, each of which should be a 106 | class derived from `BaseConfig` defined in `configs/base_config.py`. 107 | save_path: The path to save the dumped results.
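Example (a minimal sketch, mirroring the `__main__` entry point below; the output filename here is arbitrary):

    dump(CONFIG_POOL, 'command_args.json')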
108 | """ 109 | args = dict() 110 | for config in configs: 111 | args[config.name] = dict(type='object', 112 | properties=parse_args_from_config(config)) 113 | with open(save_path, 'w') as f: 114 | json.dump(args, f, indent=4) 115 | 116 | 117 | if __name__ == '__main__': 118 | if len(sys.argv) != 2: 119 | sys.exit(f'Usage: python {sys.argv[0]} SAVE_PATH') 120 | dump(CONFIG_POOL, sys.argv[1]) 121 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all metrics.""" 3 | 4 | from .gan_snapshot import GANSnapshot 5 | from .fid import FIDMetric as FID 6 | from .fid import FID50K 7 | from .fid import FID50KFull 8 | from .inception_score import ISMetric as IS 9 | from .inception_score import IS50K 10 | from .intra_class_fid import ICFIDMetric as ICFID 11 | from .intra_class_fid import ICFID50K 12 | from .intra_class_fid import ICFID50KFull 13 | from .kid import KIDMetric as KID 14 | from .kid import KID50K 15 | from .kid import KID50KFull 16 | from .gan_pr import GANPRMetric as GANPR 17 | from .gan_pr import GANPR50K 18 | from .gan_pr import GANPR50KFull 19 | from .equivariance import EquivarianceMetric 20 | from .equivariance import EQTMetric 21 | from .equivariance import EQT50K 22 | from .equivariance import EQTFracMetric 23 | from .equivariance import EQTFrac50K 24 | from .equivariance import EQRMetric 25 | from .equivariance import EQR50K 26 | 27 | __all__ = ['build_metric'] 28 | 29 | _METRICS = { 30 | 'GANSnapshot': GANSnapshot, 31 | 'FID': FID, 32 | 'FID50K': FID50K, 33 | 'FID50KFull': FID50KFull, 34 | 'IS': IS, 35 | 'IS50K': IS50K, 36 | 'ICFID': ICFID, 37 | 'ICFID50K': ICFID50K, 38 | 'ICFID50KFull': ICFID50KFull, 39 | 'KID': KID, 40 | 'KID50K': KID50K, 41 | 'KID50KFull': KID50KFull, 42 | 'GANPR': GANPR, 43 | 'GANPR50K': GANPR50K, 44 | 'GANPR50KFull': GANPR50KFull, 45 | 'Equivariance': EquivarianceMetric, 46 | 'EQT': EQTMetric, 47 | 'EQT50K': EQT50K, 48 | 'EQTFrac': EQTFracMetric, 49 | 'EQTFrac50K': EQTFrac50K, 50 | 'EQR': EQRMetric, 51 | 'EQR50K': EQR50K 52 | } 53 | 54 | 55 | def build_metric(metric_type, **kwargs): 56 | """Builds a metric evaluator based on its class type. 57 | 58 | Args: 59 | metric_type: Type of the metric, which is case sensitive. 60 | **kwargs: Configurations used to build the metric. 61 | 62 | Raises: 63 | ValueError: If the `metric_type` is not supported. 
64 | """ 65 | if metric_type not in _METRICS: 66 | raise ValueError(f'Invalid metric type: `{metric_type}`!\n' 67 | f'Types allowed: {list(_METRICS)}.') 68 | return _METRICS[metric_type](**kwargs) 69 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all models.""" 3 | 4 | from .pggan_generator import PGGANGenerator 5 | from .pggan_discriminator import PGGANDiscriminator 6 | from .stylegan_generator import StyleGANGenerator 7 | from .stylegan_discriminator import StyleGANDiscriminator 8 | from .stylegan2_generator import StyleGAN2Generator 9 | from .stylegan2_discriminator import StyleGAN2Discriminator 10 | from .stylegan3_generator import StyleGAN3Generator 11 | from .volumegan_generator import VolumeGANGenerator 12 | from .volumegan_discriminator import VolumeGANDiscriminator 13 | from .ghfeat_encoder import GHFeatEncoder 14 | from .perceptual_model import PerceptualModel 15 | from .inception_model import InceptionModel 16 | 17 | __all__ = ['build_model'] 18 | 19 | _MODELS = { 20 | 'PGGANGenerator': PGGANGenerator, 21 | 'PGGANDiscriminator': PGGANDiscriminator, 22 | 'StyleGANGenerator': StyleGANGenerator, 23 | 'StyleGANDiscriminator': StyleGANDiscriminator, 24 | 'StyleGAN2Generator': StyleGAN2Generator, 25 | 'StyleGAN2Discriminator': StyleGAN2Discriminator, 26 | 'StyleGAN3Generator': StyleGAN3Generator, 27 | 'VolumeGANGenerator': VolumeGANGenerator, 28 | 'VolumeGANDiscriminator': VolumeGANDiscriminator, 29 | 'GHFeatEncoder': GHFeatEncoder, 30 | 'PerceptualModel': PerceptualModel.build_model, 31 | 'InceptionModel': InceptionModel.build_model 32 | } 33 | 34 | 35 | def build_model(model_type, **kwargs): 36 | """Builds a model based on its class type. 37 | 38 | Args: 39 | model_type: Class type to which the model belongs, which is case 40 | sensitive. 41 | **kwargs: Additional arguments to build the model. 42 | 43 | Raises: 44 | ValueError: If the `model_type` is not supported. 
45 | """ 46 | if model_type not in _MODELS: 47 | raise ValueError(f'Invalid model type: `{model_type}`!\n' 48 | f'Types allowed: {list(_MODELS)}.') 49 | return _MODELS[model_type](**kwargs) 50 | -------------------------------------------------------------------------------- /models/rendering/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all functions for rendering.""" 3 | from .points_sampling import PointsSampling 4 | from .hierarchicle_sampling import HierarchicalSampling 5 | from .renderer import Renderer 6 | from .utils import interpolate_feature 7 | 8 | __all__ = ['PointsSampling', 'HierarchicalSampling', 'Renderer', 'interpolate_feature'] 9 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/models/utils/__init__.py -------------------------------------------------------------------------------- /models/utils/ops.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains operators for neural networks.""" 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | __all__ = ['all_gather'] 8 | 9 | 10 | def all_gather(tensor): 11 | """Gathers tensor from all devices and executes averaging.""" 12 | if not dist.is_initialized(): 13 | return tensor 14 | 15 | world_size = dist.get_world_size() 16 | tensor_list = [torch.ones_like(tensor) for _ in range(world_size)] 17 | dist.all_gather(tensor_list, tensor, async_op=False) 18 | return torch.stack(tensor_list, dim=0).mean(dim=0) 19 | -------------------------------------------------------------------------------- /requirements/convert.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | tensorflow-gpu==1.15 3 | ninja==1.10.2 4 | scikit-video==1.1.11 5 | pillow==9.0.0 6 | opencv-python-headless==4.5.5.62 7 | requests 8 | bs4 9 | tqdm 10 | rich 11 | easydict 12 | -------------------------------------------------------------------------------- /requirements/develop.txt: -------------------------------------------------------------------------------- 1 | bpytop # Monitor system resources. 2 | gpustat # Monitor GPU usage. 3 | pylint # Check coding style. 
4 | -------------------------------------------------------------------------------- /requirements/minimal.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html 2 | torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html 3 | tensorboard==2.7.0 4 | torch-tb-profiler==0.3.1 5 | ninja==1.10.2 6 | numpy==1.21.5 7 | scipy==1.7.3 8 | scikit-learn==1.0.2 9 | scikit-video==1.1.11 10 | pillow==9.0.0 11 | opencv-python-headless==4.5.5.62 12 | requests 13 | bs4 14 | tqdm 15 | rich 16 | click 17 | cloup 18 | psutil 19 | easydict 20 | lmdb 21 | matplotlib 22 | mrcfile 23 | pymcubes 24 | trimesh 25 | einops 26 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all runners.""" 3 | 4 | from .stylegan_runner import StyleGANRunner 5 | from .stylegan2_runner import StyleGAN2Runner 6 | from .stylegan3_runner import StyleGAN3Runner 7 | from .volumegan_runner import VolumeGANRunner 8 | __all__ = ['build_runner'] 9 | 10 | _RUNNERS = { 11 | 'StyleGANRunner': StyleGANRunner, 12 | 'StyleGAN2Runner': StyleGAN2Runner, 13 | 'StyleGAN3Runner': StyleGAN3Runner, 14 | 'VolumeGANRunner': VolumeGANRunner, 15 | } 16 | 17 | 18 | def build_runner(config): 19 | """Builds a runner with given configuration. 20 | 21 | Args: 22 | config: Configurations used to build the runner. 23 | 24 | Raises: 25 | ValueError: If the `config.runner_type` is not supported. 26 | """ 27 | if not isinstance(config, dict) or 'runner_type' not in config: 28 | raise ValueError('`runner_type` is missing from configuration!') 29 | 30 | runner_type = config['runner_type'] 31 | if runner_type not in _RUNNERS: 32 | raise ValueError(f'Invalid runner type: `{runner_type}`!\n' 33 | f'Types allowed: {list(_RUNNERS)}.') 34 | return _RUNNERS[runner_type](config) 35 | -------------------------------------------------------------------------------- /runners/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all augmentation pipelines.""" 3 | 4 | from .no_aug import NoAug 5 | from .ada_aug import AdaAug 6 | 7 | __all__ = ['build_aug'] 8 | 9 | _AUGMENTATIONS = { 10 | 'NoAug': NoAug, 11 | 'AdaAug': AdaAug 12 | } 13 | 14 | 15 | def build_aug(aug_type, **kwargs): 16 | """Builds a differentiable augmentation pipeline based on its class type. 17 | 18 | Args: 19 | aug_type: Class type to which the augmentation belongs, which is case 20 | sensitive. 21 | **kwargs: Additional arguments to build the aug. 22 | 23 | Raises: 24 | ValueError: If the `aug_type` is not supported. 
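Example (grounded by `runners/augmentations/no_aug.py` below, where `NoAug` is an alias of `nn.Identity` and hence needs no extra arguments):

    aug = build_aug('NoAug')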
25 | """ 26 | if aug_type not in _AUGMENTATIONS: 27 | raise ValueError(f'Invalid augmentation type: `{aug_type}`!\n' 28 | f'Types allowed: {list(_AUGMENTATIONS)}.') 29 | return _AUGMENTATIONS[aug_type](**kwargs) 30 | -------------------------------------------------------------------------------- /runners/augmentations/no_aug.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Defines the dummy augmentation pipeline that executes no augmentation.""" 3 | 4 | import torch.nn as nn 5 | 6 | __all__ = ['NoAug'] 7 | 8 | 9 | NoAug = nn.Identity 10 | -------------------------------------------------------------------------------- /runners/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all controllers.""" 3 | 4 | from .ada_aug_controller import AdaAugController 5 | from .batch_visualizer import BatchVisualizer 6 | from .cache_cleaner import CacheCleaner 7 | from .checkpointer import Checkpointer 8 | from .dataset_visualizer import DatasetVisualizer 9 | from .evaluator import Evaluator 10 | from .lr_scheduler import LRScheduler 11 | from .progress_scheduler import ProgressScheduler 12 | from .running_logger import RunningLogger 13 | from .timer import Timer 14 | 15 | __all__ = ['build_controller'] 16 | 17 | _CONTROLLERS = { 18 | 'AdaAugController': AdaAugController, 19 | 'BatchVisualizer': BatchVisualizer, 20 | 'CacheCleaner': CacheCleaner, 21 | 'Checkpointer': Checkpointer, 22 | 'DatasetVisualizer': DatasetVisualizer, 23 | 'Evaluator': Evaluator, 24 | 'LRScheduler': LRScheduler, 25 | 'ProgressScheduler': ProgressScheduler, 26 | 'RunningLogger': RunningLogger, 27 | 'Timer': Timer 28 | } 29 | 30 | 31 | def build_controller(controller_type, config=None): 32 | """Builds a controller based on its class type. 33 | 34 | Args: 35 | controller_type: Class type to which the controller belongs, which is 36 | case sensitive. 37 | config: Configuration of the controller. (default: None) 38 | 39 | Raises: 40 | ValueError: If the `controller_type` is not supported. 41 | """ 42 | if controller_type not in _CONTROLLERS: 43 | raise ValueError(f'Invalid controller type: `{controller_type}`!\n' 44 | f'Types allowed: {list(_CONTROLLERS)}.') 45 | return _CONTROLLERS[controller_type](config) 46 | -------------------------------------------------------------------------------- /runners/controllers/ada_aug_controller.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to control augmentation strength.""" 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from ..utils.running_stats import SingleStats 9 | from ..augmentations.ada_aug import AdaAug 10 | from .base_controller import BaseController 11 | 12 | __all__ = ['AdaAugController'] 13 | 14 | 15 | class AdaAugController(BaseController): 16 | """Defines the running controller to adjust the strength of augmentations. 17 | 18 | This controller works together with the augmentation pipeline introduces by 19 | StyleGAN2-ADA (https://arxiv.org/pdf/2006.06676.pdf). Concretely, AadAug, 20 | which is defined in `runners/augmentations/ada_aug.py`, augments the data 21 | based on an adjustable probability. This controller controls how this 22 | probability is adjusted. 23 | 24 | NOTE: The controller is set to `FIRST` priority. 
25 | 26 | Basically, the aug_config is expected to contain the following settings: 27 | 28 | (1) init_p: The initial probability of augmentations. (default: 0.0) 29 | (2) target_p: The target (final) probability of augmentations. (default: 0.6) 30 | (3) every_n_iters: How often to adjust the probability. (default: 4) 31 | (4) speed_img: Speed to adjust the probability, which is measured in number 32 | of images it takes for the probability to increase/decrease by one unit. 33 | (default: 500_000) 34 | (5) strategy: The strategy to adjust the probability. Supported: `fixed`, 35 | `linear`, and `adaptive`. (default: `adaptive`) 36 | """ 37 | 38 | def __init__(self, config): 39 | assert isinstance(config, dict) 40 | config.setdefault('priority', 'First') 41 | config.setdefault('every_n_iters', 4) 42 | config.setdefault('first_iter', False) 43 | super().__init__(config) 44 | 45 | self._init_p = config.get('init_p', 0.0) 46 | self._target_p = config.get('target_p', 0.6) 47 | self._speed_img = config.get('speed_img', 500_000) 48 | self._strategy = config.get('strategy', 'adaptive').lower() 49 | assert self._strategy in ['fixed', 'linear', 'adaptive'] 50 | 51 | def setup(self, runner): 52 | """Sets the initial augmentation strength before training.""" 53 | if not isinstance(runner.augment, AdaAug): 54 | raise ValueError(f'`{self.name}` only works together with ' 55 | f'adaptive augmentation pipeline `AdaAug`!\n') 56 | 57 | if self._strategy == 'fixed': 58 | aug_prob = self._target_p 59 | else: 60 | aug_prob = self._init_p 61 | runner.augment.p = torch.as_tensor(aug_prob).relu() 62 | runner.augment.prob_tracker = SingleStats('Aug Prob Tracker', 63 | log_format=None, 64 | log_strategy='AVERAGE', 65 | requires_sync=True) 66 | runner.running_stats.add('Misc/Aug Prob', 67 | log_name='aug_prob', 68 | log_format='.3f', 69 | log_strategy='CURRENT', 70 | requires_sync=False, 71 | keep_previous=True) 72 | 73 | runner.logger.info('Adaptive augmentation settings:', indent_level=2) 74 | runner.logger.info(f'Strategy: {self._strategy}', indent_level=3) 75 | runner.logger.info(f'Initial probability: {self._init_p}', 76 | indent_level=3) 77 | runner.logger.info(f'Target probability : {self._target_p}', 78 | indent_level=3) 79 | runner.logger.info(f'Adjustment speed: {self._speed_img} images', 80 | indent_level=3) 81 | super().setup(runner) 82 | 83 | def execute_after_iteration(self, runner): 84 | if self._strategy == 'fixed': 85 | aug_prob = self._target_p 86 | elif self._strategy == 'linear': 87 | slope = runner.iter / runner.total_iters 88 | aug_prob = self._init_p + (self._target_p - self._init_p) * slope 89 | else: 90 | minibatch = runner.batch_size * runner.world_size 91 | slope = (minibatch * self.every_n_iters) / self._speed_img 92 | criterion = runner.augment.prob_tracker.summarize() 93 | adjust = np.sign(criterion - self._target_p) * slope 94 | aug_prob = runner.augment.p + adjust 95 | 96 | runner.augment.p = torch.as_tensor(aug_prob).relu() 97 | runner.running_stats.update({'Misc/Aug Prob': runner.augment.p}) 98 | -------------------------------------------------------------------------------- /runners/controllers/cache_cleaner.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to clean cache.""" 3 | 4 | import torch 5 | 6 | from .base_controller import BaseController 7 | 8 | __all__ = ['CacheCleaner'] 9 | 10 | 11 | class CacheCleaner(BaseController): 12 | """Defines the running controller to clean cache.
13 | 14 | This controller is used to empty the GPU cache after each iteration. 15 | 16 | NOTE: The controller is set to `LAST` priority by default. 17 | """ 18 | 19 | def __init__(self, config=None): 20 | config = config or dict() 21 | config.setdefault('priority', 'LAST') 22 | config.setdefault('every_n_iters', 1) 23 | super().__init__(config) 24 | 25 | def setup(self, runner): 26 | torch.cuda.empty_cache() 27 | super().setup(runner) 28 | 29 | def close(self, runner): 30 | torch.cuda.empty_cache() 31 | 32 | def execute_after_iteration(self, runner): 33 | torch.cuda.empty_cache() 34 | -------------------------------------------------------------------------------- /runners/controllers/timer.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the running controller to record time.""" 3 | 4 | import time 5 | 6 | from .base_controller import BaseController 7 | 8 | __all__ = ['Timer'] 9 | 10 | 11 | class Timer(BaseController): 12 | """Defines the running controller to record running time. 13 | 14 | This controller will be executed every iteration (both before and after) to 15 | summarize the data preparation time as well as the model running time. 16 | Besides, this controller will also mark the start and end time of the 17 | running process. 18 | 19 | NOTE: This controller is set to `LOW` priority by default and will only be 20 | executed on the chief worker. 21 | """ 22 | 23 | def __init__(self, config=None): 24 | config = config or dict() 25 | config.setdefault('priority', 'LOW') 26 | config.setdefault('every_n_iters', 1) 27 | config.setdefault('chief_only', True) 28 | super().__init__(config) 29 | 30 | self.time = time.time() 31 | 32 | def setup(self, runner): 33 | runner.running_stats.add( 34 | 'data time', log_format='time', requires_sync=False) 35 | runner.running_stats.add( 36 | 'iter time', log_format='time', requires_sync=False) 37 | runner.running_stats.add('run time', 38 | log_format='time', 39 | log_strategy='CURRENT', 40 | requires_sync=False) 41 | self.time = time.time() 42 | runner.start_time = self.time 43 | 44 | def close(self, runner): 45 | runner.end_time = time.time() 46 | 47 | def execute_before_iteration(self, runner): 48 | start_time = time.time() 49 | runner.running_stats.update({'data time': start_time - self.time}) 50 | 51 | def execute_after_iteration(self, runner): 52 | end_time = time.time() 53 | runner.running_stats.update({'iter time': end_time - self.time}) 54 | runner.running_stats.update({'run time': end_time - runner.start_time}) 55 | self.time = end_time 56 | -------------------------------------------------------------------------------- /runners/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all loss functions.""" 3 | 4 | from .stylegan_loss import StyleGANLoss 5 | from .stylegan2_loss import StyleGAN2Loss 6 | from .stylegan3_loss import StyleGAN3Loss 7 | from .volumegan_loss import VolumeGANLoss 8 | __all__ = ['build_loss'] 9 | 10 | _LOSSES = { 11 | 'StyleGANLoss': StyleGANLoss, 12 | 'StyleGAN2Loss': StyleGAN2Loss, 13 | 'StyleGAN3Loss': StyleGAN3Loss, 14 | 'VolumeGANLoss': VolumeGANLoss 15 | } 16 | 17 | 18 | def build_loss(runner, loss_type, **kwargs): 19 | """Builds a loss based on its class type. 20 | 21 | Args: 22 | runner: The runner on which the loss is built. 23 | loss_type: Class type to which the loss belongs, which is case 24 | sensitive. 25 | **kwargs: Additional arguments to build the loss. 
26 | 27 | Raises: 28 | ValueError: If the `loss_type` is not supported. 29 | """ 30 | if loss_type not in _LOSSES: 31 | raise ValueError(f'Invalid loss type: `{loss_type}`!\n' 32 | f'Types allowed: {list(_LOSSES)}.') 33 | return _LOSSES[loss_type](runner, **kwargs) 34 | -------------------------------------------------------------------------------- /runners/losses/base_loss.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the base class to implement loss.""" 3 | 4 | import torch 5 | 6 | __all__ = ['BaseLoss'] 7 | 8 | 9 | class BaseLoss(object): 10 | """Base loss class. 11 | 12 | The derived class can easily serialize its members to a `dict`, and load 13 | from such a `dict` to resume the saved loss. 14 | 15 | NOTE: By default, the derived class will save ALL members. To ensure members 16 | are saved and loaded as expected, you may need to override the 17 | `state_dict()` or `load_state_dict()` method. 18 | """ 19 | 20 | @property 21 | def name(self): 22 | """Returns the class name of the loss.""" 23 | return self.__class__.__name__ 24 | 25 | def state_dict(self): 26 | """Returns a serialized `dict` that records all members. 27 | 28 | The returned `dict` maps attribute names to their values. 29 | 30 | NOTE: Override this method if such default behavior is unexpected. 31 | """ 32 | return vars(self) 33 | 34 | def load_state_dict(self, state_dict): 35 | """Loads parameters from the `state_dict`. 36 | 37 | By default, this method directly sets all attributes from the given 38 | `state_dict`. 39 | 40 | NOTE: Override this method if such default behavior is unexpected. 41 | """ 42 | for key, val in state_dict.items(): 43 | if not hasattr(self, key): # current loss does not init with `key` 44 | continue 45 | origin_attr = getattr(self, key) 46 | if isinstance(origin_attr, torch.nn.Module): 47 | origin_attr.load_state_dict(val.state_dict()) 48 | continue 49 | if isinstance(origin_attr, torch.Tensor): 50 | val = val.to(device=origin_attr.device) 51 | setattr(self, key, val) 52 | -------------------------------------------------------------------------------- /runners/stylegan2_runner.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the runner for StyleGAN2.""" 3 | 4 | from copy import deepcopy 5 | 6 | from .base_runner import BaseRunner 7 | 8 | __all__ = ['StyleGAN2Runner'] 9 | 10 | 11 | class StyleGAN2Runner(BaseRunner): 12 | """Defines the runner for StyleGAN2.""" 13 | 14 | def build_models(self): 15 | super().build_models() 16 | 17 | self.g_ema_img = self.config.models['generator'].get( 18 | 'g_ema_img', 10_000) 19 | self.g_ema_rampup = self.config.models['generator'].get( 20 | 'g_ema_rampup', 0) 21 | if 'generator_smooth' not in self.models: 22 | self.models['generator_smooth'] = deepcopy(self.models['generator']) 23 | self.model_kwargs_init['generator_smooth'] = deepcopy( 24 | self.model_kwargs_init['generator']) 25 | if 'generator_smooth' not in self.model_kwargs_val: 26 | self.model_kwargs_val['generator_smooth'] = deepcopy( 27 | self.model_kwargs_val['generator']) 28 | 29 | def build_loss(self): 30 | super().build_loss() 31 | self.running_stats.add('Misc/Gs Beta', 32 | log_name='Gs_beta', 33 | log_format='.4f', 34 | log_strategy='CURRENT') 35 | 36 | def train_step(self, data): 37 | # Update generator. 38 | self.models['discriminator'].requires_grad_(False) 39 | self.models['generator'].requires_grad_(True) 40 | 41 | # Update with adversarial loss.
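# `sync=True` all-reduces gradients across replicas during the backward pass below.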
42 | g_loss = self.loss.g_loss(self, data, sync=True) 43 | self.zero_grad_optimizer('generator') 44 | g_loss.backward() 45 | self.step_optimizer('generator') 46 | 47 | # Update with perceptual path length regularization if needed. 48 | pl_penalty = self.loss.g_reg(self, data, sync=True) 49 | if pl_penalty is not None: 50 | self.zero_grad_optimizer('generator') 51 | pl_penalty.backward() 52 | self.step_optimizer('generator') 53 | 54 | # Update discriminator. 55 | self.models['discriminator'].requires_grad_(True) 56 | self.models['generator'].requires_grad_(False) 57 | 58 | # Update with adversarial loss. 59 | self.zero_grad_optimizer('discriminator') 60 | # Update with fake images (get synchronized together with real loss). 61 | d_fake_loss = self.loss.d_fake_loss(self, data, sync=False) 62 | d_fake_loss.backward() 63 | # Update with real images. 64 | d_real_loss = self.loss.d_real_loss(self, data, sync=True) 65 | d_real_loss.backward() 66 | self.step_optimizer('discriminator') 67 | 68 | # Update with gradient penalty. 69 | r1_penalty = self.loss.d_reg(self, data, sync=True) 70 | if r1_penalty is not None: 71 | self.zero_grad_optimizer('discriminator') 72 | r1_penalty.backward() 73 | self.step_optimizer('discriminator') 74 | 75 | # Life-long update generator. 76 | if self.g_ema_rampup is not None and self.g_ema_rampup > 0: 77 | g_ema_img = min(self.g_ema_img, self.seen_img * self.g_ema_rampup) 78 | else: 79 | g_ema_img = self.g_ema_img 80 | beta = 0.5 ** (self.minibatch / max(g_ema_img, 1e-8)) 81 | self.running_stats.update({'Misc/Gs Beta': beta}) 82 | self.smooth_model(src=self.models['generator'], 83 | avg=self.models['generator_smooth'], 84 | beta=beta) 85 | -------------------------------------------------------------------------------- /runners/stylegan3_runner.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the runner for StyleGAN3.""" 3 | 4 | from copy import deepcopy 5 | 6 | from .base_runner import BaseRunner 7 | 8 | __all__ = ['StyleGAN3Runner'] 9 | 10 | 11 | class StyleGAN3Runner(BaseRunner): 12 | """Defines the runner for StyleGAN3.""" 13 | 14 | def build_models(self): 15 | super().build_models() 16 | 17 | g_ema_img = self.config.models['generator'].get('g_ema_img', 10_000) 18 | self.g_ema_img = g_ema_img * self.minibatch / 32 19 | self.g_ema_rampup = self.config.models['generator'].get( 20 | 'g_ema_rampup', 0) 21 | if 'generator_smooth' not in self.models: 22 | self.models['generator_smooth'] = deepcopy(self.models['generator']) 23 | self.model_kwargs_init['generator_smooth'] = deepcopy( 24 | self.model_kwargs_init['generator']) 25 | if 'generator_smooth' not in self.model_kwargs_val: 26 | self.model_kwargs_val['generator_smooth'] = deepcopy( 27 | self.model_kwargs_val['generator']) 28 | 29 | def build_loss(self): 30 | super().build_loss() 31 | self.running_stats.add('Misc/Gs Beta', 32 | log_name='Gs_beta', 33 | log_format='.4f', 34 | log_strategy='CURRENT') 35 | 36 | def train_step(self, data): 37 | # Update generator. 38 | self.models['discriminator'].requires_grad_(False) 39 | self.models['generator'].requires_grad_(True) 40 | 41 | # Update with adversarial loss. 42 | g_loss = self.loss.g_loss(self, data, sync=True) 43 | self.zero_grad_optimizer('generator') 44 | g_loss.backward() 45 | self.step_optimizer('generator') 46 | 47 | # Update with perceptual path length regularization if needed. 
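# `g_reg` follows the lazy regularization scheme: it returns None on iterations where the penalty is skipped (controlled by `pl_interval` in the training configs).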
48 | pl_penalty = self.loss.g_reg(self, data, sync=True) 49 | if pl_penalty is not None: 50 | self.zero_grad_optimizer('generator') 51 | pl_penalty.backward() 52 | self.step_optimizer('generator') 53 | 54 | # Update discriminator. 55 | self.models['discriminator'].requires_grad_(True) 56 | self.models['generator'].requires_grad_(False) 57 | 58 | # Update with adversarial loss. 59 | self.zero_grad_optimizer('discriminator') 60 | # Update with fake images (get synchronized together with real loss). 61 | d_fake_loss = self.loss.d_fake_loss(self, data, sync=False) 62 | d_fake_loss.backward() 63 | # Update with real images. 64 | d_real_loss = self.loss.d_real_loss(self, data, sync=True) 65 | d_real_loss.backward() 66 | self.step_optimizer('discriminator') 67 | 68 | # Update with gradient penalty. 69 | r1_penalty = self.loss.d_reg(self, data, sync=True) 70 | if r1_penalty is not None: 71 | self.zero_grad_optimizer('discriminator') 72 | r1_penalty.backward() 73 | self.step_optimizer('discriminator') 74 | 75 | # Life-long update generator. 76 | if self.g_ema_rampup is not None and self.g_ema_rampup > 0: 77 | g_ema_img = min(self.g_ema_img, self.seen_img * self.g_ema_rampup) 78 | else: 79 | g_ema_img = self.g_ema_img 80 | beta = 0.5 ** (self.minibatch / max(g_ema_img, 1e-8)) 81 | self.running_stats.update({'Misc/Gs Beta': beta}) 82 | self.smooth_model(src=self.models['generator'], 83 | avg=self.models['generator_smooth'], 84 | beta=beta) 85 | -------------------------------------------------------------------------------- /runners/stylegan_runner.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the runner for StyleGAN.""" 3 | 4 | from copy import deepcopy 5 | 6 | from .base_runner import BaseRunner 7 | 8 | __all__ = ['StyleGANRunner'] 9 | 10 | 11 | class StyleGANRunner(BaseRunner): 12 | """Defines the runner for StyleGAN.""" 13 | 14 | def __init__(self, config): 15 | super().__init__(config) 16 | self.lod = getattr(self, 'lod', 0.0) 17 | self.D_repeats = self.config.get('D_repeats', 1) 18 | 19 | def build_models(self): 20 | super().build_models() 21 | self.g_ema_img = self.config.models['generator'].get( 22 | 'g_ema_img', 10_000) 23 | if 'generator_smooth' not in self.models: 24 | self.models['generator_smooth'] = deepcopy(self.models['generator']) 25 | self.model_kwargs_init['generator_smooth'] = deepcopy( 26 | self.model_kwargs_init['generator']) 27 | if 'generator_smooth' not in self.model_kwargs_val: 28 | self.model_kwargs_val['generator_smooth'] = deepcopy( 29 | self.model_kwargs_val['generator']) 30 | 31 | def build_loss(self): 32 | super().build_loss() 33 | self.running_stats.add('Misc/Gs Beta', 34 | log_name='Gs_beta', 35 | log_format='.4f', 36 | log_strategy='CURRENT') 37 | 38 | def train_step(self, data): 39 | # Set level-of-details. 40 | G = self.models['generator'] 41 | D = self.models['discriminator'] 42 | Gs = self.models['generator_smooth'] 43 | G.synthesis.lod.data.fill_(self.lod) 44 | D.lod.data.fill_(self.lod) 45 | Gs.synthesis.lod.data.fill_(self.lod) 46 | 47 | # Update discriminator. 48 | self.models['discriminator'].requires_grad_(True) 49 | self.models['generator'].requires_grad_(False) 50 | d_loss = self.loss.d_loss(self, data, sync=True) 51 | self.zero_grad_optimizer('discriminator') 52 | d_loss.backward() 53 | self.step_optimizer('discriminator') 54 | 55 | # Life-long update for generator. 
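# With beta = 0.5 ** (minibatch / g_ema_img), the exponential moving average of the generator weights has a half-life of `g_ema_img` images.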
56 | beta = 0.5 ** (self.minibatch / self.g_ema_img) 57 | self.running_stats.update({'Misc/Gs Beta': beta}) 58 | self.smooth_model(src=self.models['generator'], 59 | avg=self.models['generator_smooth'], 60 | beta=beta) 61 | 62 | # Update generator. 63 | if self.iter % self.D_repeats == 0: 64 | self.models['discriminator'].requires_grad_(False) 65 | self.models['generator'].requires_grad_(True) 66 | g_loss = self.loss.g_loss(self, data, sync=True) 67 | self.zero_grad_optimizer('generator') 68 | g_loss.backward() 69 | self.step_optimizer('generator') 70 | 71 | # Update automatic mixed-precision scaler. 72 | self.amp_scaler.update() 73 | -------------------------------------------------------------------------------- /runners/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/runners/utils/__init__.py -------------------------------------------------------------------------------- /runners/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the function to build an optimizer for a model.""" 3 | 4 | import torch 5 | 6 | __all__ = ['build_optimizer'] 7 | 8 | _ALLOWED_OPT_TYPES = ['sgd', 'adam'] 9 | 10 | 11 | def build_optimizer(config, model): 12 | """Builds an optimizer for the given model. 13 | 14 | Basically, the configuration is expected to contain the following settings: 15 | 16 | (1) opt_type: The type of the optimizer. (required) 17 | (2) base_lr: The base learning rate for all parameters. (required) 18 | (3) base_wd: The base weight decay for all parameters. (default: 0.0) 19 | (4) bias_lr_multiplier: The learning rate multiplier for bias parameters. 20 | (default: 1.0) 21 | (5) bias_wd_multiplier: The weight decay multiplier for bias parameters. 22 | (default: 1.0) 23 | (6) **kwargs: Additional settings for the optimizer, such as `momentum`. 24 | 25 | Args: 26 | config: The configuration used to build the optimizer. 27 | model: The model which the optimizer serves. 28 | 29 | Returns: 30 | A `torch.optim.Optimizer`. 31 | 32 | Raises: 33 | ValueError: If the `opt_type` is not supported. 34 | NotImplementedError: If `opt_type` is not implemented. 35 | """ 36 | assert isinstance(config, dict) 37 | opt_type = config['opt_type'].lower() 38 | base_lr = config['base_lr'] 39 | base_wd = config.get('base_wd', 0.0) 40 | bias_lr_multiplier = config.get('bias_lr_multiplier', 1.0) 41 | bias_wd_multiplier = config.get('bias_wd_multiplier', 1.0) 42 | 43 | if opt_type not in _ALLOWED_OPT_TYPES: 44 | raise ValueError(f'Invalid optimizer type `{opt_type}`!\n'
45 | f'Allowed types: {_ALLOWED_OPT_TYPES}.') 46 | 47 | model_params = [] 48 | for param_name, param in model.named_parameters(): 49 | param_group = {'params': [param]} 50 | if 'bias' in param_name: 51 | param_group['lr'] = base_lr * bias_lr_multiplier 52 | param_group['weight_decay'] = base_wd * bias_wd_multiplier 53 | else: 54 | param_group['lr'] = base_lr 55 | param_group['weight_decay'] = base_wd 56 | model_params.append(param_group) 57 | 58 | if opt_type == 'sgd': 59 | return torch.optim.SGD(params=model_params, 60 | lr=base_lr, 61 | momentum=config.get('momentum', 0.9), 62 | dampening=config.get('dampening', 0), 63 | weight_decay=base_wd, 64 | nesterov=config.get('nesterov', False)) 65 | if opt_type == 'adam': 66 | return torch.optim.Adam(params=model_params, 67 | lr=base_lr, 68 | betas=config.get('betas', (0.9, 0.999)), 69 | eps=config.get('eps', 1e-8), 70 | weight_decay=base_wd, 71 | amsgrad=config.get('amsgrad', False)) 72 | raise NotImplementedError(f'Not implemented optimizer type `{opt_type}`!') 73 | -------------------------------------------------------------------------------- /runners/utils/profiler.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class for profiling.""" 3 | 4 | import torch 5 | from torch import distributed as dist 6 | 7 | from utils.tf_utils import import_tb_writer 8 | 9 | SummaryWriter = import_tb_writer() 10 | 11 | __all__ = ['Profiler'] 12 | 13 | 14 | class Profiler(object): 15 | """Defines the profiler. 16 | 17 | Essentially, this is a wrapper of `torch.profiler.profile`. 18 | If `enable` is set to `False`, this profiler becomes a dummy context manager 19 | with a dummy `step()`. 20 | 21 | Args: 22 | enable: `bool`, whether to enable `torch.profiler`. 23 | tb_dir: `str`, path to save profiler's TensorBoard events file. 24 | logger: `utils.loggers` or `logging.Logger`, the event logging system. 25 | **schedule_kwargs: settings for the `schedule` of 26 | `torch.profiler.profile`. The profiler will skip the first 27 | `skip_first` steps, then wait for `wait` steps, then do the warmup 28 | for the next `warmup` steps, then do the active recording for the 29 | next `active` steps and then repeat the cycle starting with `wait` 30 | steps. (default: dict(wait=1, warmup=1, active=3, repeat=2)) 31 | """ 32 | def __init__(self, 33 | enable=False, 34 | tb_dir='.', 35 | logger=None, 36 | **schedule_kwargs): 37 | self.enable = enable 38 | rank = dist.get_rank() if dist.is_initialized() else 0 39 | if enable and SummaryWriter is not None: 40 | try: # In case the PyTorch version is outdated. 41 | if rank == 0: 42 | SummaryWriter(tb_dir) 43 | if dist.is_initialized(): 44 | dist.barrier() # Only create one TensorBoard event. 45 | if not schedule_kwargs: # `**kwargs` is never None. 46 | schedule_kwargs = dict(wait=1, warmup=1, active=3, repeat=2) 47 | # Profile CPU and each GPU.
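# By default, `torch.profiler.profile` records both CPU and CUDA activity; `record_shapes` and `with_stack` additionally capture tensor shapes and Python stack traces in the trace.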
48 | self.profiler = torch.profiler.profile( 49 | schedule=torch.profiler.schedule(**schedule_kwargs), 50 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 51 | tb_dir), 52 | record_shapes=True, 53 | with_stack=True) 54 | if logger: 55 | logger.info(f'Enable profiler with schedule: ' 56 | f'{schedule_kwargs}.\n') 57 | except AttributeError as error: 58 | if logger: 59 | logger.warning(f'Skipping profiler due to {error}!\n' 60 | f'Please update PyTorch to 1.8.1+ to enable profiler.\n') 61 | self.enable = False 62 | 63 | def __enter__(self): 64 | if self.enable: 65 | return self.profiler.__enter__() 66 | return self 67 | 68 | def __exit__(self, exc_type, exc_val, exc_tb): 69 | if self.enable: 70 | self.profiler.__exit__(exc_type, exc_val, exc_tb) 71 | 72 | def step(self): 73 | """Executes the profiler for one step.""" 74 | if self.enable: 75 | self.profiler.step() 76 | -------------------------------------------------------------------------------- /runners/volumegan_runner.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the runner for VolumeGAN.""" 3 | 4 | from copy import deepcopy 5 | 6 | from .base_runner import BaseRunner 7 | 8 | __all__ = ['VolumeGANRunner'] 9 | 10 | 11 | class VolumeGANRunner(BaseRunner): 12 | """Defines the runner for VolumeGAN.""" 13 | def __init__(self, config): 14 | super().__init__(config) 15 | self.lod = getattr(self, 'lod', 0.0) 16 | 17 | def build_models(self): 18 | super().build_models() 19 | 20 | self.g_ema_img = self.config.models['generator'].get( 21 | 'g_ema_img', 10_000) 22 | self.g_ema_rampup = self.config.models['generator'].get( 23 | 'g_ema_rampup', 0) 24 | if 'generator_smooth' not in self.models: 25 | self.models['generator_smooth'] = deepcopy(self.models['generator']) 26 | self.model_kwargs_init['generator_smooth'] = deepcopy( 27 | self.model_kwargs_init['generator']) 28 | if 'generator_smooth' not in self.model_kwargs_val: 29 | self.model_kwargs_val['generator_smooth'] = deepcopy( 30 | self.model_kwargs_val['generator']) 31 | 32 | def build_loss(self): 33 | super().build_loss() 34 | self.running_stats.add('Misc/Gs Beta', 35 | log_name='Gs_beta', 36 | log_format='.4f', 37 | log_strategy='CURRENT') 38 | 39 | def train_step(self, data): 40 | # Set lod for progressive training. 41 | self.models['generator'].synthesis.lod.data.fill_(self.lod) 42 | self.models['discriminator'].lod.data.fill_(self.lod) 43 | self.models['generator_smooth'].synthesis.lod.data.fill_(self.lod) 44 | 45 | # Update generator. 46 | self.models['discriminator'].requires_grad_(False) 47 | self.models['generator'].requires_grad_(True) 48 | 49 | # Update with adversarial loss. 50 | g_loss = self.loss.g_loss(self, data, sync=True) 51 | self.zero_grad_optimizer('generator') 52 | g_loss.backward() 53 | self.step_optimizer('generator') 54 | 55 | # Update with perceptual path length regularization if needed. 56 | pl_penalty = self.loss.g_reg(self, data, sync=True) 57 | if pl_penalty is not None: 58 | self.zero_grad_optimizer('generator') 59 | pl_penalty.backward() 60 | self.step_optimizer('generator') 61 | 62 | # Update discriminator. 63 | self.models['discriminator'].requires_grad_(True) 64 | self.models['generator'].requires_grad_(False) 65 | 66 | # Update with adversarial loss. 67 | self.zero_grad_optimizer('discriminator') 68 | # Update with fake images (get synchronized together with real loss).
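# `sync=False` defers the cross-replica gradient all-reduce here, so fake and real gradients are synchronized together during the `sync=True` backward pass below.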
69 | d_fake_loss = self.loss.d_fake_loss(self, data, sync=False) 70 | d_fake_loss.backward() 71 | # Update with real images. 72 | d_real_loss = self.loss.d_real_loss(self, data, sync=True) 73 | d_real_loss.backward() 74 | self.step_optimizer('discriminator') 75 | 76 | # Update with gradient penalty. 77 | r1_penalty = self.loss.d_reg(self, data, sync=True) 78 | if r1_penalty is not None: 79 | self.zero_grad_optimizer('discriminator') 80 | r1_penalty.backward() 81 | self.step_optimizer('discriminator') 82 | 83 | # Life-long update generator. 84 | if self.g_ema_rampup is not None and self.g_ema_rampup > 0: 85 | g_ema_img = min(self.g_ema_img, self.seen_img * self.g_ema_rampup) 86 | else: 87 | g_ema_img = self.g_ema_img 88 | beta = 0.5 ** (self.minibatch / max(g_ema_img, 1e-8)) 89 | self.running_stats.update({'Misc/Gs Beta': beta}) 90 | self.smooth_model(src=self.models['generator'], 91 | avg=self.models['generator_smooth'], 92 | beta=beta) 93 | -------------------------------------------------------------------------------- /scripts/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Detect `python3` command. 4 | # This workaround addresses a common issue: 5 | # `python` points to `python2`, which is deprecated. 6 | export PYTHONS 7 | export RVAL 8 | 9 | PYTHONS=$(compgen -c | grep "^python3$") 10 | 11 | # `$?` is a built-in variable in bash, which is the exit status of the most 12 | # recently-executed command; by convention, 0 means success and anything else 13 | # indicates failure. 14 | RVAL=$? 15 | 16 | if [[ $RVAL -eq 0 ]]; then # if `python3` exists 17 | PYTHON="python3" 18 | else 19 | PYTHON="python" 20 | fi 21 | 22 | # Help message. 23 | if [[ $# -lt 2 ]]; then 24 | echo "This script helps launch a distributed training job on the local machine." 25 | echo 26 | echo "Usage: $0 GPUS COMMAND [ARGS]" 27 | echo 28 | echo "Example: $0 8 stylegan2 [--help]" 29 | echo 30 | echo "Detailed instructions on available commands:" 31 | echo "--------------------------------------------------" 32 | ${PYTHON} ./train.py --help 33 | echo 34 | exit 0 35 | fi 36 | 37 | GPUS=$1 38 | COMMAND=$2 39 | 40 | # Help message for a particular command. 41 | if [[ $# -lt 3 || ${*: -1} == "--help" ]]; then 42 | echo "Detailed instructions on the arguments for command \`"${COMMAND}"\`:" 43 | echo "--------------------------------------------------" 44 | ${PYTHON} ./train.py ${COMMAND} --help 45 | echo 46 | exit 0 47 | fi 48 | 49 | # Switch memory allocator if available. 50 | # Search order: jemalloc.so -> tcmalloc.so. 51 | # According to https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html, 52 | # reusing memory as much as possible can give better performance than the 53 | # default malloc function. 54 | JEMALLOC=$(ldconfig -p | grep -i "libjemalloc.so$" | tr " " "\n" | grep "/" \ 55 | | head -n 1) 56 | TCMALLOC=$(ldconfig -p | grep -i "libtcmalloc.so.4$" | tr " " "\n" | grep "/" \ 57 | | head -n 1) 58 | if [ -n "$JEMALLOC" ]; then # if found the path to libjemalloc.so 59 | echo "Switch memory allocator to jemalloc." 60 | export LD_PRELOAD=$JEMALLOC:$LD_PRELOAD 61 | elif [ -n "$TCMALLOC" ]; then # if found the path to libtcmalloc.so.4 62 | echo "Switch memory allocator to tcmalloc." 63 | export LD_PRELOAD=$TCMALLOC:$LD_PRELOAD 64 | fi 65 | 66 | # Get an available port for launching distributed training. 67 | # Credit to https://superuser.com/a/1293762.
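# `comm -23` keeps only the lines unique to its first input, i.e. candidate ports in [49152, 65535] that are absent from the in-use ports reported by `ss -Htan`; `shuf | head -n 1` then picks one of them at random.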
68 | export DEFAULT_FREE_PORT 69 | DEFAULT_FREE_PORT=$(comm -23 <(seq 49152 65535 | sort) \ 70 | <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) \ 71 | | shuf | head -n 1) 72 | 73 | PORT=${PORT:-$DEFAULT_FREE_PORT} 74 | 75 | ${PYTHON} -m torch.distributed.launch \ 76 | --nproc_per_node=${GPUS} \ 77 | --master_port=${PORT} \ 78 | ./train.py \ 79 | --launcher="pytorch" \ 80 | --backend="nccl" \ 81 | ${COMMAND} \ 82 | ${@:3} \ 83 | || exit 1 # Stop the script when it finds an exception thrown by Python. 84 | -------------------------------------------------------------------------------- /scripts/kill_zombies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help information. 4 | if [[ $# -le 0 || ${*: -1} == "-h" || ${*: -1} == "--help" ]]; then 5 | echo "This script kills processes launched with" \ 6 | "\`./scripts/dist_train.sh\`, with arguments as keywords to filter." 7 | echo 8 | echo "Note: It does NOT check whether they are zombies. Hence," \ 9 | "to ensure killing the desired processes rather than innocent ones," \ 10 | "you MUST provide sufficient arguments to identify the targets." 11 | echo 12 | echo "Usage: $0 [any arguments passed to your \`dist_train.sh\`]." 13 | echo 14 | echo "Example: $0 configs/stylegan2_config.py --work_dir work_dirs/debug" 15 | echo 16 | exit 0 17 | fi 18 | 19 | kill -9 $(ps ux | grep "$*" | grep -v grep | awk '{print $2}') 20 | -------------------------------------------------------------------------------- /scripts/test_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help information. 4 | if [[ $# -lt 4 || ${*: -1} == "-h" || ${*: -1} == "--help" ]]; then 5 | echo "This script tests metrics defined in \`./metrics/\`." 6 | echo 7 | echo "Usage: $0 GPUS DATASET MODEL METRICS" 8 | echo 9 | echo "Note: Multiple metrics should be separated by commas." \ 10 | "Also, all metrics assume using all samples from the real dataset" \ 11 | "and 50000 fake samples for GAN-related metrics." 12 | echo 13 | echo "Example: $0 1 ~/data/ffhq1024.zip ~/checkpoints/ffhq1024.pth" \ 14 | "fid,is,kid,gan_pr,snapshot,equivariance" 15 | echo 16 | exit 0 17 | fi 18 | 19 | # Get an available port for launching distributed training. 20 | # Credit to https://superuser.com/a/1293762. 21 | export DEFAULT_FREE_PORT 22 | DEFAULT_FREE_PORT=$(comm -23 <(seq 49152 65535 | sort) \ 23 | <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) \ 24 | | shuf | head -n 1) 25 | 26 | GPUS=$1 27 | DATASET=$2 28 | MODEL=$3 29 | PORT=${PORT:-$DEFAULT_FREE_PORT} 30 | 31 | # Parse metrics to test.
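# The metrics arrive as one comma-separated argument (e.g. "fid,is,kid") and are split into a bash array below via "${METRICS//,/ }".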
32 | METRICS=$4 33 | TEST_FID="false" 34 | TEST_IS="false" 35 | TEST_KID="false" 36 | TEST_GAN_PR="false" 37 | TEST_SNAPSHOT="false" 38 | TEST_EQUIVARIANCE="false" 39 | if [[ ${METRICS} == "all" ]]; then 40 | TEST_FID="true" 41 | TEST_IS="true" 42 | TEST_KID="true" 43 | TEST_GAN_PR="true" 44 | TEST_SNAPSHOT="true" 45 | TEST_EQUIVARIANCE="true" 46 | else 47 | array=(${METRICS//,/ }) 48 | for var in ${array[@]}; do 49 | if [[ ${var} == "fid" ]]; then 50 | TEST_FID="true" 51 | fi 52 | if [[ ${var} == "is" ]]; then 53 | TEST_IS="true" 54 | fi 55 | if [[ ${var} == "kid" ]]; then 56 | TEST_KID="true" 57 | fi 58 | if [[ ${var} == "gan_pr" ]]; then 59 | TEST_GAN_PR="true" 60 | fi 61 | if [[ ${var} == "snapshot" ]]; then 62 | TEST_SNAPSHOT="true" 63 | fi 64 | if [[ ${var} == "equivariance" ]]; then 65 | TEST_EQUIVARIANCE="true" 66 | fi 67 | done 68 | fi 69 | 70 | # Detect `python3` command. 71 | # This workaround addresses a common issue: 72 | # `python` points to `python2`, which is deprecated. 73 | export PYTHONS 74 | export RVAL 75 | 76 | PYTHONS=$(compgen -c | grep "^python3$") 77 | 78 | # `$?` is a built-in variable in bash, which is the exit status of the most 79 | # recently-executed command; by convention, 0 means success and anything else 80 | # indicates failure. 81 | RVAL=$? 82 | 83 | if [ $RVAL -eq 0 ]; then # if `python3` exists 84 | PYTHON="python3" 85 | else 86 | PYTHON="python" 87 | fi 88 | 89 | ${PYTHON} -m torch.distributed.launch \ 90 | --nproc_per_node=${GPUS} \ 91 | --master_port=${PORT} \ 92 | ./test_metrics.py \ 93 | --launcher="pytorch" \ 94 | --backend="nccl" \ 95 | --dataset ${DATASET} \ 96 | --model ${MODEL} \ 97 | --real_num -1 \ 98 | --fake_num 50000 \ 99 | --test_fid ${TEST_FID} \ 100 | --test_is ${TEST_IS} \ 101 | --test_kid ${TEST_KID} \ 102 | --test_gan_pr ${TEST_GAN_PR} \ 103 | --test_snapshot ${TEST_SNAPSHOT} \ 104 | --test_equivariance ${TEST_EQUIVARIANCE} \ 105 | ${@:5} 106 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2_ffhq1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2 on FFHQ-1024." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed."
10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq1024.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2_ffhq1024' \ 23 | --seed=0 \ 24 | --resolution=1024 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=4 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.002 \ 45 | --g_lr=0.002 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_interval=16 \ 50 | --r1_gamma=10.0 \ 51 | --pl_interval=4 \ 52 | --pl_weight=2.0 \ 53 | --pl_decay=0.01 \ 54 | --pl_batch_shrink=2 \ 55 | --g_ema_img=10_000 \ 56 | --g_ema_rampup=0.0 \ 57 | --eval_at_start=true \ 58 | --eval_interval=6400 \ 59 | --ckpt_interval=6400 \ 60 | --log_interval=128 \ 61 | --enable_amp=false \ 62 | --use_ada=false \ 63 | --num_fp16_res=0 \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2_ffhq256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2 on FFHQ-256." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed." 10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq256.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2_ffhq256' \ 23 | --seed=0 \ 24 | --resolution=256 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=16 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.002 \ 45 | --g_lr=0.002 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_interval=16 \ 50 | --r1_gamma=10.0 \ 51 | --pl_interval=4 \ 52 | --pl_weight=2.0 \ 53 | --pl_decay=0.01 \ 54 | --pl_batch_shrink=2 \ 55 | --g_ema_img=10_000 \ 56 | --g_ema_rampup=0.0 \ 57 | --eval_at_start=true \ 58 | --eval_interval=6400 \ 59 | --ckpt_interval=6400 \ 60 | --log_interval=128 \ 61 | --enable_amp=false \ 62 | --use_ada=false \ 63 | --num_fp16_res=0 \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2_ffhq512.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message.
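# Any option appended after DATASET overrides the presets below, e.g.
# (hypothetical values):
#   ./scripts/training_demos/stylegan2_ffhq512.sh 8 /data/ffhq512.zip \
#       --batch_size=8 --total_img=12_000_000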
4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2 on FFHQ-512." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed." 10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq512.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2_ffhq512' \ 23 | --seed=0 \ 24 | --resolution=512 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=8 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.002 \ 45 | --g_lr=0.002 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_interval=16 \ 50 | --r1_gamma=10.0 \ 51 | --pl_interval=4 \ 52 | --pl_weight=2.0 \ 53 | --pl_decay=0.01 \ 54 | --pl_batch_shrink=2 \ 55 | --g_ema_img=10_000 \ 56 | --g_ema_rampup=0.0 \ 57 | --eval_at_start=true \ 58 | --eval_interval=6400 \ 59 | --ckpt_interval=6400 \ 60 | --log_interval=128 \ 61 | --enable_amp=false \ 62 | --use_ada=false \ 63 | --num_fp16_res=0 \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2_lsun_bedroom256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2 on LSUN-Bedroom-256." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed."
10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/LSUN/bedroom_train_lmdb [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2_lsun_bedroom256' \ 23 | --seed=0 \ 24 | --resolution=256 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=100_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=16 \ 32 | --train_data_mirror=false \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=30 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.002 \ 45 | --g_lr=0.002 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_interval=16 \ 50 | --r1_gamma=10.0 \ 51 | --pl_interval=4 \ 52 | --pl_weight=2.0 \ 53 | --pl_decay=0.01 \ 54 | --pl_batch_shrink=2 \ 55 | --g_ema_img=10_000 \ 56 | --g_ema_rampup=0.0 \ 57 | --eval_at_start=true \ 58 | --eval_interval=6400 \ 59 | --ckpt_interval=6400 \ 60 | --log_interval=128 \ 61 | --enable_amp=false \ 62 | --use_ada=false \ 63 | --num_fp16_res=0 \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2ada_afhq512.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2-ADA on AFHQ-512." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed." 10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/afhq512.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2ada_afhq512' \ 23 | --seed=0 \ 24 | --resolution=512 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=8 \ 31 | --val_batch_size=8 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=2000 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=8 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.0025 \ 45 | --g_lr=0.0025 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_interval=16 \ 50 | --r1_gamma=0.5 \ 51 | --pl_interval=4 \ 52 | --pl_weight=2.0 \ 53 | --pl_decay=0.01 \ 54 | --pl_batch_shrink=2 \ 55 | --g_ema_img=20_000 \ 56 | --g_ema_rampup=0.0 \ 57 | --eval_at_start=true \ 58 | --eval_interval=3200 \ 59 | --ckpt_interval=3200 \ 60 | --log_interval=64 \ 61 | --enable_amp=false \ 62 | --use_ada=true \ 63 | --num_fp16_res=4 \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2ada_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message.
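# This demo is class-conditional: the dataset archive is expected to carry an
# `annotation.json` with per-image labels (see --train_anno_meta and
# --label_dim=10 below).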
4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2-ADA on CIFAR10." 6 | echo 7 | echo "Note: All settings are already preset for training with 2 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed." 10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 2 /data/cifar10.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2ada_cifar10' \ 23 | --seed=0 \ 24 | --resolution=32 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --train_anno_meta='annotation.json' \ 28 | --val_dataset=${DATASET} \ 29 | --val_anno_meta='annotation.json' \ 30 | --val_max_samples=-1 \ 31 | --total_img=100_000_000 \ 32 | --batch_size=32 \ 33 | --val_batch_size=128 \ 34 | --train_data_mirror=false \ 35 | --data_loader_type='iter' \ 36 | --data_repeat=500 \ 37 | --data_workers=3 \ 38 | --data_prefetch_factor=2 \ 39 | --data_pin_memory=true \ 40 | --g_init_res=4 \ 41 | --latent_dim=512 \ 42 | --label_dim=10 \ 43 | --d_fmaps_factor=1.0 \ 44 | --g_fmaps_factor=1.0 \ 45 | --d_mbstd_groups=8 \ 46 | --g_num_mappings=2 \ 47 | --g_architecture='origin' \ 48 | --d_lr=0.0025 \ 49 | --g_lr=0.0025 \ 50 | --w_moving_decay=0.995 \ 51 | --sync_w_avg=false \ 52 | --style_mixing_prob=0.0 \ 53 | --r1_interval=16 \ 54 | --r1_gamma=0.01 \ 55 | --pl_interval=4 \ 56 | --pl_weight=0.0 \ 57 | --pl_decay=0.0 \ 58 | --pl_batch_shrink=2 \ 59 | --g_ema_img=500_000 \ 60 | --g_ema_rampup=0.05 \ 61 | --eval_at_start=true \ 62 | --eval_interval=3200 \ 63 | --ckpt_interval=3200 \ 64 | --log_interval=64 \ 65 | --enable_amp=false \ 66 | --use_ada=true \ 67 | --num_fp16_res=4 \ 68 | ${@:3} 69 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2ada_ffhq1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2-ADA on FFHQ-1024." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed."
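# ADA (enabled via --use_ada=true below) adjusts the augmentation strength
# automatically from discriminator overfitting statistics during training.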
10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq1024.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2ada_ffhq1024' \ 23 | --seed=0 \ 24 | --resolution=1024 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=4 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.002 \ 45 | --g_lr=0.002 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_interval=16 \ 50 | --r1_gamma=2.0 \ 51 | --pl_interval=4 \ 52 | --pl_weight=2.0 \ 53 | --pl_decay=0.01 \ 54 | --pl_batch_shrink=2 \ 55 | --g_ema_img=10_000 \ 56 | --g_ema_rampup=0.0 \ 57 | --eval_at_start=true \ 58 | --eval_interval=6400 \ 59 | --ckpt_interval=6400 \ 60 | --log_interval=128 \ 61 | --enable_amp=false \ 62 | --use_ada=true \ 63 | --num_fp16_res=4 \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan2ada_ffhq256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN2-ADA on FFHQ-256." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed." 10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq256.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan2 \ 22 | --job_name='stylegan2ada_ffhq256' \ 23 | --seed=0 \ 24 | --resolution=256 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=8 \ 31 | --val_batch_size=16 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=0.5 \ 41 | --g_fmaps_factor=0.5 \ 42 | --d_mbstd_groups=8 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.0025 \ 45 | --g_lr=0.0025 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_interval=16 \ 50 | --r1_gamma=1.0 \ 51 | --pl_interval=4 \ 52 | --pl_weight=2.0 \ 53 | --pl_decay=0.01 \ 54 | --pl_batch_shrink=2 \ 55 | --g_ema_img=20_000 \ 56 | --g_ema_rampup=0.0 \ 57 | --eval_at_start=true \ 58 | --eval_interval=3200 \ 59 | --ckpt_interval=3200 \ 60 | --log_interval=64 \ 61 | --enable_amp=false \ 62 | --use_ada=true \ 63 | --num_fp16_res=4 \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan3r_afhq512.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message.
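# Config R (rotation equivariant) specifics below: 1x1 synthesis kernels
# (--g_kernel_size=1), a blurred warm-up phase (--blur_init_sigma and
# --blur_fade_img), and path-length regularization disabled (--pl_interval=0).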
4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN3 (config R) on" \ 6 | "AFHQ-512." 7 | echo 8 | echo "Note: All settings are already preset for training with 8 GPUs." \ 9 | "Please pass additional options, which will overwrite the original" \ 10 | "settings, if needed." 11 | echo 12 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 13 | echo 14 | echo "Example: $0 8 /data/afhq512.zip [--help]" 15 | echo 16 | exit 0 17 | fi 18 | 19 | GPUS=$1 20 | DATASET=$2 21 | 22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \ 23 | --job_name='stylegan3r_afhq512' \ 24 | --seed=0 \ 25 | --resolution=512 \ 26 | --image_channels=3 \ 27 | --train_dataset=${DATASET} \ 28 | --val_dataset=${DATASET} \ 29 | --val_max_samples=-1 \ 30 | --total_img=25_000_000 \ 31 | --batch_size=4 \ 32 | --val_batch_size=8 \ 33 | --train_data_mirror=true \ 34 | --data_loader_type='iter' \ 35 | --data_repeat=2000 \ 36 | --data_workers=3 \ 37 | --data_prefetch_factor=2 \ 38 | --data_pin_memory=true \ 39 | --g_kernel_size=1 \ 40 | --latent_dim=512 \ 41 | --d_fmaps_factor=1.0 \ 42 | --g_fmaps_factor=1.0 \ 43 | --d_mbstd_groups=4 \ 44 | --g_num_mappings=2 \ 45 | --d_lr=0.002 \ 46 | --g_lr=0.0025 \ 47 | --w_moving_decay=0.998 \ 48 | --sync_w_avg=false \ 49 | --style_mixing_prob=0.0 \ 50 | --r1_interval=16 \ 51 | --r1_gamma=16.4 \ 52 | --blur_init_sigma=10.0 \ 53 | --blur_fade_img=200_000 \ 54 | --pl_interval=0 \ 55 | --pl_weight=0.0 \ 56 | --pl_decay=0.01 \ 57 | --pl_batch_shrink=2 \ 58 | --g_ema_img=10_000 \ 59 | --g_ema_rampup=0.05 \ 60 | --eval_at_start=true \ 61 | --eval_interval=6400 \ 62 | --ckpt_interval=6400 \ 63 | --log_interval=128 \ 64 | --enable_amp=false \ 65 | --use_ada=true \ 66 | --num_fp16_res=4 \ 67 | ${@:3} 68 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan3r_ffhqu1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN3 (config R) on" \ 6 | "FFHQ-U-1024." 7 | echo 8 | echo "Note: All settings are already preset for training with 8 GPUs." \ 9 | "Please pass additional options, which will overwrite the original" \ 10 | "settings, if needed."
11 | echo 12 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 13 | echo 14 | echo "Example: $0 8 /data/ffhqu1024.zip [--help]" 15 | echo 16 | exit 0 17 | fi 18 | 19 | GPUS=$1 20 | DATASET=$2 21 | 22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \ 23 | --job_name='stylegan3r_ffhqu1024' \ 24 | --seed=0 \ 25 | --resolution=1024 \ 26 | --image_channels=3 \ 27 | --train_dataset=${DATASET} \ 28 | --val_dataset=${DATASET} \ 29 | --val_max_samples=-1 \ 30 | --total_img=25_000_000 \ 31 | --batch_size=4 \ 32 | --val_batch_size=4 \ 33 | --train_data_mirror=true \ 34 | --data_loader_type='iter' \ 35 | --data_repeat=200 \ 36 | --data_workers=3 \ 37 | --data_prefetch_factor=2 \ 38 | --data_pin_memory=true \ 39 | --g_kernel_size=1 \ 40 | --latent_dim=512 \ 41 | --d_fmaps_factor=1.0 \ 42 | --g_fmaps_factor=1.0 \ 43 | --d_mbstd_groups=4 \ 44 | --g_num_mappings=2 \ 45 | --d_lr=0.002 \ 46 | --g_lr=0.0025 \ 47 | --w_moving_decay=0.998 \ 48 | --sync_w_avg=false \ 49 | --style_mixing_prob=0.0 \ 50 | --r1_interval=16 \ 51 | --r1_gamma=32.8 \ 52 | --blur_init_sigma=10.0 \ 53 | --blur_fade_img=200_000 \ 54 | --pl_interval=0 \ 55 | --pl_weight=0.0 \ 56 | --pl_decay=0.01 \ 57 | --pl_batch_shrink=2 \ 58 | --g_ema_img=10_000 \ 59 | --g_ema_rampup=0.05 \ 60 | --eval_at_start=true \ 61 | --eval_interval=6400 \ 62 | --ckpt_interval=6400 \ 63 | --log_interval=128 \ 64 | --enable_amp=false \ 65 | --use_ada=true \ 66 | --num_fp16_res=4 \ 67 | ${@:3} 68 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan3r_ffhqu256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN3 (config R) on" \ 6 | "FFHQ-U-256." 7 | echo 8 | echo "Note: All settings are already preset for training with 8 GPUs." \ 9 | "Please pass additional options, which will overwrite the original" \ 10 | "settings, if needed."
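# FFHQ-U refers to the unaligned variant of FFHQ introduced with StyleGAN3.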
11 | echo 12 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 13 | echo 14 | echo "Example: $0 8 /data/ffhqu256.zip [--help]" 15 | echo 16 | exit 0 17 | fi 18 | 19 | GPUS=$1 20 | DATASET=$2 21 | 22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \ 23 | --job_name='stylegan3r_ffhqu256' \ 24 | --seed=0 \ 25 | --resolution=256 \ 26 | --image_channels=3 \ 27 | --train_dataset=${DATASET} \ 28 | --val_dataset=${DATASET} \ 29 | --val_max_samples=-1 \ 30 | --total_img=25_000_000 \ 31 | --batch_size=8 \ 32 | --val_batch_size=16 \ 33 | --train_data_mirror=true \ 34 | --data_loader_type='iter' \ 35 | --data_repeat=200 \ 36 | --data_workers=3 \ 37 | --data_prefetch_factor=2 \ 38 | --data_pin_memory=true \ 39 | --g_kernel_size=1 \ 40 | --latent_dim=512 \ 41 | --d_fmaps_factor=0.5 \ 42 | --g_fmaps_factor=0.5 \ 43 | --d_mbstd_groups=4 \ 44 | --g_num_mappings=2 \ 45 | --d_lr=0.0025 \ 46 | --g_lr=0.0025 \ 47 | --w_moving_decay=0.998 \ 48 | --sync_w_avg=false \ 49 | --style_mixing_prob=0.0 \ 50 | --r1_interval=16 \ 51 | --r1_gamma=1.0 \ 52 | --blur_init_sigma=10.0 \ 53 | --blur_fade_img=200_000 \ 54 | --pl_interval=0 \ 55 | --pl_weight=0.0 \ 56 | --pl_decay=0.01 \ 57 | --pl_batch_shrink=2 \ 58 | --g_ema_img=10_000 \ 59 | --g_ema_rampup=0.05 \ 60 | --eval_at_start=true \ 61 | --eval_interval=3200 \ 62 | --ckpt_interval=3200 \ 63 | --log_interval=64 \ 64 | --enable_amp=false \ 65 | --use_ada=true \ 66 | --num_fp16_res=4 \ 67 | ${@:3} 68 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan3t_afhq512.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN3 (config T) on" \ 6 | "AFHQ-512." 7 | echo 8 | echo "Note: All settings are already preset for training with 8 GPUs." \ 9 | "Please pass additional options, which will overwrite the original" \ 10 | "settings, if needed."
11 | echo 12 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 13 | echo 14 | echo "Example: $0 8 /data/afhq512.zip [--help]" 15 | echo 16 | exit 0 17 | fi 18 | 19 | GPUS=$1 20 | DATASET=$2 21 | 22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \ 23 | --job_name='stylegan3t_afhq512' \ 24 | --seed=0 \ 25 | --resolution=512 \ 26 | --image_channels=3 \ 27 | --train_dataset=${DATASET} \ 28 | --val_dataset=${DATASET} \ 29 | --val_max_samples=-1 \ 30 | --total_img=25_000_000 \ 31 | --batch_size=4 \ 32 | --val_batch_size=8 \ 33 | --train_data_mirror=true \ 34 | --data_loader_type='iter' \ 35 | --data_repeat=2000 \ 36 | --data_workers=3 \ 37 | --data_prefetch_factor=2 \ 38 | --data_pin_memory=true \ 39 | --g_kernel_size=3 \ 40 | --latent_dim=512 \ 41 | --d_fmaps_factor=1.0 \ 42 | --g_fmaps_factor=1.0 \ 43 | --d_mbstd_groups=4 \ 44 | --g_num_mappings=2 \ 45 | --d_lr=0.002 \ 46 | --g_lr=0.0025 \ 47 | --w_moving_decay=0.998 \ 48 | --sync_w_avg=false \ 49 | --style_mixing_prob=0.0 \ 50 | --r1_interval=16 \ 51 | --r1_gamma=8.2 \ 52 | --blur_init_sigma=0.0 \ 53 | --blur_fade_img=0 \ 54 | --pl_interval=0 \ 55 | --pl_weight=0.0 \ 56 | --pl_decay=0.01 \ 57 | --pl_batch_shrink=2 \ 58 | --g_ema_img=10_000 \ 59 | --g_ema_rampup=0.05 \ 60 | --eval_at_start=true \ 61 | --eval_interval=6400 \ 62 | --ckpt_interval=6400 \ 63 | --log_interval=128 \ 64 | --enable_amp=false \ 65 | --use_ada=true \ 66 | --num_fp16_res=4 \ 67 | ${@:3} 68 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan3t_ffhqu1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN3 (config T) on" \ 6 | "FFHQ-U-1024." 7 | echo 8 | echo "Note: All settings are already preset for training with 8 GPUs." \ 9 | "Please pass additional options, which will overwrite the original" \ 10 | "settings, if needed."
11 | echo 12 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 13 | echo 14 | echo "Example: $0 8 /data/ffhqu1024.zip [--help]" 15 | echo 16 | exit 0 17 | fi 18 | 19 | GPUS=$1 20 | DATASET=$2 21 | 22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \ 23 | --job_name='stylegan3t_ffhqu1024' \ 24 | --seed=0 \ 25 | --resolution=1024 \ 26 | --image_channels=3 \ 27 | --train_dataset=${DATASET} \ 28 | --val_dataset=${DATASET} \ 29 | --val_max_samples=-1 \ 30 | --total_img=25_000_000 \ 31 | --batch_size=4 \ 32 | --val_batch_size=4 \ 33 | --train_data_mirror=true \ 34 | --data_loader_type='iter' \ 35 | --data_repeat=200 \ 36 | --data_workers=3 \ 37 | --data_prefetch_factor=2 \ 38 | --data_pin_memory=true \ 39 | --g_kernel_size=3 \ 40 | --latent_dim=512 \ 41 | --d_fmaps_factor=1.0 \ 42 | --g_fmaps_factor=1.0 \ 43 | --d_mbstd_groups=4 \ 44 | --g_num_mappings=2 \ 45 | --d_lr=0.002 \ 46 | --g_lr=0.0025 \ 47 | --w_moving_decay=0.998 \ 48 | --sync_w_avg=false \ 49 | --style_mixing_prob=0.0 \ 50 | --r1_interval=16 \ 51 | --r1_gamma=32.8 \ 52 | --blur_init_sigma=0.0 \ 53 | --blur_fade_img=0 \ 54 | --pl_interval=0 \ 55 | --pl_weight=0.0 \ 56 | --pl_decay=0.01 \ 57 | --pl_batch_shrink=2 \ 58 | --g_ema_img=10_000 \ 59 | --g_ema_rampup=0.05 \ 60 | --eval_at_start=true \ 61 | --eval_interval=6400 \ 62 | --ckpt_interval=6400 \ 63 | --log_interval=128 \ 64 | --enable_amp=false \ 65 | --use_ada=true \ 66 | --num_fp16_res=4 \ 67 | ${@:3} 68 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan3t_ffhqu256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN3 (config T) on" \ 6 | "FFHQ-U-256." 7 | echo 8 | echo "Note: All settings are already preset for training with 8 GPUs." \ 9 | "Please pass additional options, which will overwrite the original" \ 10 | "settings, if needed."
11 | echo 12 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 13 | echo 14 | echo "Example: $0 8 /data/ffhqu256.zip [--help]" 15 | echo 16 | exit 0 17 | fi 18 | 19 | GPUS=$1 20 | DATASET=$2 21 | 22 | ./scripts/dist_train.sh ${GPUS} stylegan3 \ 23 | --job_name='stylegan3t_ffhqu256' \ 24 | --seed=0 \ 25 | --resolution=256 \ 26 | --image_channels=3 \ 27 | --train_dataset=${DATASET} \ 28 | --val_dataset=${DATASET} \ 29 | --val_max_samples=-1 \ 30 | --total_img=25_000_000 \ 31 | --batch_size=8 \ 32 | --val_batch_size=16 \ 33 | --train_data_mirror=true \ 34 | --data_loader_type='iter' \ 35 | --data_repeat=200 \ 36 | --data_workers=3 \ 37 | --data_prefetch_factor=2 \ 38 | --data_pin_memory=true \ 39 | --g_kernel_size=3 \ 40 | --latent_dim=512 \ 41 | --d_fmaps_factor=0.5 \ 42 | --g_fmaps_factor=0.5 \ 43 | --d_mbstd_groups=4 \ 44 | --g_num_mappings=2 \ 45 | --d_lr=0.0025 \ 46 | --g_lr=0.0025 \ 47 | --w_moving_decay=0.998 \ 48 | --sync_w_avg=false \ 49 | --style_mixing_prob=0.0 \ 50 | --r1_interval=16 \ 51 | --r1_gamma=1.0 \ 52 | --blur_init_sigma=0.0 \ 53 | --blur_fade_img=0 \ 54 | --pl_interval=0 \ 55 | --pl_weight=0.0 \ 56 | --pl_decay=0.01 \ 57 | --pl_batch_shrink=2 \ 58 | --g_ema_img=10_000 \ 59 | --g_ema_rampup=0.05 \ 60 | --eval_at_start=true \ 61 | --eval_interval=3200 \ 62 | --ckpt_interval=3200 \ 63 | --log_interval=64 \ 64 | --enable_amp=false \ 65 | --use_ada=true \ 66 | --num_fp16_res=4 \ 67 | ${@:3} 68 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan_ffhq1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN on FFHQ-1024." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed."
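# StyleGAN trains progressively: the ProgressScheduler overrides at the end
# of this script grow the resolution from 8x8 upward and adapt the batch
# size and learning rate per resolution.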
10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq1024.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan \ 22 | --job_name='stylegan_ffhq1024' \ 23 | --seed=0 \ 24 | --resolution=1024 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=4 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.001 \ 45 | --g_lr=0.001 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_gamma=10.0 \ 50 | --g_ema_img=10_000 \ 51 | --eval_at_start=true \ 52 | --eval_interval=6400 \ 53 | --ckpt_interval=6400 \ 54 | --log_interval=128 \ 55 | --use_ada=false \ 56 | --enable_amp=true \ 57 | -o controllers.ProgressScheduler.init_res=8 \ 58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \ 59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \ 60 | ${@:3} 61 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan_ffhq256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN on FFHQ-256." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed."
10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq256.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan \ 22 | --job_name='stylegan_ffhq256' \ 23 | --seed=0 \ 24 | --resolution=256 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=16 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.001 \ 45 | --g_lr=0.001 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_gamma=10.0 \ 50 | --g_ema_img=10_000 \ 51 | --eval_at_start=true \ 52 | --eval_interval=6400 \ 53 | --ckpt_interval=6400 \ 54 | --log_interval=128 \ 55 | --use_ada=false \ 56 | --enable_amp=true \ 57 | -o controllers.ProgressScheduler.init_res=8 \ 58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \ 59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \ 60 | ${@:3} 61 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan_ffhq512.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN on FFHQ-512." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed."
10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/ffhq512.zip [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan \ 22 | --job_name='stylegan_ffhq512' \ 23 | --seed=0 \ 24 | --resolution=512 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=8 \ 32 | --train_data_mirror=true \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=200 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.001 \ 45 | --g_lr=0.001 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_gamma=10.0 \ 50 | --g_ema_img=10_000 \ 51 | --eval_at_start=true \ 52 | --eval_interval=6400 \ 53 | --ckpt_interval=6400 \ 54 | --log_interval=128 \ 55 | --use_ada=false \ 56 | --enable_amp=true \ 57 | -o controllers.ProgressScheduler.init_res=8 \ 58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \ 59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \ 60 | ${@:3} 61 | -------------------------------------------------------------------------------- /scripts/training_demos/stylegan_lsun_bedroom256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Help message. 4 | if [[ $# -lt 2 ]]; then 5 | echo "This script launches a job of training StyleGAN on LSUN-Bedroom-256." 6 | echo 7 | echo "Note: All settings are already preset for training with 8 GPUs." \ 8 | "Please pass additional options, which will overwrite the original" \ 9 | "settings, if needed."
10 | echo 11 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 12 | echo 13 | echo "Example: $0 8 /data/LSUN/bedroom_train_lmdb [--help]" 14 | echo 15 | exit 0 16 | fi 17 | 18 | GPUS=$1 19 | DATASET=$2 20 | 21 | ./scripts/dist_train.sh ${GPUS} stylegan \ 22 | --job_name='stylegan_lsun_bedroom256' \ 23 | --seed=0 \ 24 | --resolution=256 \ 25 | --image_channels=3 \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=100_000_000 \ 30 | --batch_size=4 \ 31 | --val_batch_size=16 \ 32 | --train_data_mirror=false \ 33 | --data_loader_type='iter' \ 34 | --data_repeat=30 \ 35 | --data_workers=3 \ 36 | --data_prefetch_factor=2 \ 37 | --data_pin_memory=true \ 38 | --g_init_res=4 \ 39 | --latent_dim=512 \ 40 | --d_fmaps_factor=1.0 \ 41 | --g_fmaps_factor=1.0 \ 42 | --d_mbstd_groups=4 \ 43 | --g_num_mappings=8 \ 44 | --d_lr=0.001 \ 45 | --g_lr=0.001 \ 46 | --w_moving_decay=0.995 \ 47 | --sync_w_avg=false \ 48 | --style_mixing_prob=0.9 \ 49 | --r1_gamma=10.0 \ 50 | --g_ema_img=10_000 \ 51 | --eval_at_start=true \ 52 | --eval_interval=6400 \ 53 | --ckpt_interval=6400 \ 54 | --log_interval=128 \ 55 | --use_ada=false \ 56 | --enable_amp=true \ 57 | -o controllers.ProgressScheduler.init_res=8 \ 58 | -o controllers.ProgressScheduler.batch_size_schedule='{"res4":64,"res8":32,"res16":16,"res32":8}' \ 59 | -o controllers.ProgressScheduler.lr_schedule='{"res128":1.5,"res256":2,"res512":3,"res1024":3}' \ 60 | ${@:3} 61 | -------------------------------------------------------------------------------- /scripts/training_demos/volumegan_ffhq256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | # Help message. 5 | if [[ $# -lt 2 ]]; then 6 | echo "This script launches a job of training VolumeGAN on FFHQ-256." 7 | echo 8 | echo "Note: All settings are already preset for training with 8 GPUs." \ 9 | "Please pass additional options, which will overwrite the original" \ 10 | "settings, if needed."
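# VolumeGAN synthesizes images from a NeRF-style feature volume; NeRFRES
# below sets the volume's spatial resolution and is forwarded as --g_init_res.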
11 | echo 12 | echo "Usage: $0 GPUS DATASET [OPTIONS]" 13 | echo 14 | echo "Example: $0 8 /data/ffhq256.zip [--help]" 15 | echo 16 | exit 0 17 | fi 18 | 19 | GPUS=$1 20 | DATASET=$2 21 | NeRFRES=32 22 | RES=256 23 | ./scripts/dist_train.sh ${GPUS} volumegan-ffhq \ 24 | --seed=0 \ 25 | --resolution=${RES} \ 26 | --train_dataset=${DATASET} \ 27 | --val_dataset=${DATASET} \ 28 | --val_max_samples=-1 \ 29 | --total_img=25_000_000 \ 30 | --batch_size=8 \ 31 | --val_batch_size=16 \ 32 | --train_data_mirror=true \ 33 | --data_workers=3 \ 34 | --data_pin_memory=true \ 35 | --data_prefetch_factor=2 \ 36 | --data_repeat=500 \ 37 | --g_init_res=${NeRFRES} \ 38 | --latent_dim=512 \ 39 | --d_fmaps_factor=1.0 \ 40 | --g_fmaps_factor=1.0 \ 41 | --d_mbstd_groups=4 \ 42 | --g_num_mappings=8 \ 43 | --d_lr=0.002 \ 44 | --g_lr=0.002 \ 45 | --w_moving_decay=0.995 \ 46 | --sync_w_avg=false \ 47 | --style_mixing_prob=0.9 \ 48 | --r1_interval=16 \ 49 | --r1_gamma=10.0 \ 50 | --pl_interval=4 \ 51 | --pl_weight=2.0 \ 52 | --pl_decay=0.01 \ 53 | --pl_batch_shrink=2 \ 54 | --g_ema_img=20000 \ 55 | --eval_at_start=false \ 56 | --eval_interval=6400 \ 57 | --ckpt_interval=6400 \ 58 | --log_interval=128 \ 59 | --enable_amp=false \ 60 | --use_ada=false \ 61 | --num_fp16_res=128 \ 62 | --logger_type=normal \ 63 | --use_pg=true \ 64 | ${@:3} 65 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/README.md: -------------------------------------------------------------------------------- 1 | # Operators for StyleGAN2 2 | 3 | All files in this directory are borrowed from the repository [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including 4 | 5 | - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator. 6 | - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering. 7 | - `upfirdn2d.filter2d()`: Filtering a 2D feature map with a given kernel. 8 | - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map. 9 | - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map. 10 | - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map. 11 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 12 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 13 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`). 14 | 15 | We make the following slight modifications beyond disabling some lint warnings: 16 | 17 | - Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
18 | - Line 35 of file `custom_ops.py`: Disable log message when setting up customized operators. 19 | - Line 53/89 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binary is not located at `/usr/local/cuda/bin`, please specify its path in function `_find_compiler_bindir_posix()`.*) 20 | - Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). 21 | - Line 32 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default. 22 | - Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator. 23 | - Line 33 of file `conv2d_gradfix.py`: Enable customized convolution operators by default. 24 | - Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators. 25 | - Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators. 26 | - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator. 27 | 28 | Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default. 29 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/third_party/stylegan2_official_ops/__init__.py -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/fma.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`. 12 | 13 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch 14 | """ 15 | 16 | # pylint: disable=line-too-long 17 | # pylint: disable=missing-function-docstring 18 | 19 | import torch 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def fma(a, b, c, impl='cuda'): # => a * b + c 24 | if impl == 'cuda': 25 | return _FusedMultiplyAdd.apply(a, b, c) 26 | return torch.addcmul(c, a, b) 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 31 | @staticmethod 32 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 33 | out = torch.addcmul(c, a, b) 34 | ctx.save_for_backward(a, b) 35 | ctx.c_shape = c.shape 36 | return out 37 | 38 | @staticmethod 39 | def backward(ctx, dout): # pylint: disable=arguments-differ 40 | a, b = ctx.saved_tensors 41 | c_shape = ctx.c_shape 42 | da = None 43 | db = None 44 | dc = None 45 | 46 | if ctx.needs_input_grad[0]: 47 | da = _unbroadcast(dout * b, a.shape) 48 | 49 | if ctx.needs_input_grad[1]: 50 | db = _unbroadcast(dout * a, b.shape) 51 | 52 | if ctx.needs_input_grad[2]: 53 | dc = _unbroadcast(dout, c_shape) 54 | 55 | return da, db, dc 56 | 57 | #---------------------------------------------------------------------------- 58 | 59 | def _unbroadcast(x, shape): 60 | extra_dims = x.ndim - len(shape) 61 | assert extra_dims >= 0 62 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 63 | if len(dim): 64 | x = x.sum(dim=dim, keepdim=True) 65 | if extra_dims: 66 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 67 | assert x.shape == shape 68 | return x 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | # pylint: enable=line-too-long 73 | # pylint: enable=missing-function-docstring 74 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Custom replacement for `torch.nn.functional.grid_sample`. 12 | 13 | This is useful for differentiable augmentation. This customized operator 14 | supports arbitrarily high order gradients between the input and output. Only 15 | works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and 16 | `align_corners=False`. 
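A minimal usage sketch (assuming a 4D input and grid, as required above):
`y = grid_sample(x, grid, impl='cuda')` matches the forward behavior of
`torch.nn.functional.grid_sample(x, grid, mode='bilinear',
padding_mode='zeros', align_corners=False)`, while any other `impl` value
falls back to the native operator.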
17 | 18 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch 19 | """ 20 | 21 | # pylint: disable=redefined-builtin 22 | # pylint: disable=arguments-differ 23 | # pylint: disable=protected-access 24 | # pylint: disable=line-too-long 25 | # pylint: disable=missing-function-docstring 26 | 27 | import warnings 28 | import torch 29 | from distutils.version import LooseVersion 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | enabled = True # Enable the custom op by setting this to true. 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def grid_sample(input, grid, impl='cuda'): 38 | if impl == 'cuda' and _should_use_custom_op(): 39 | return _GridSample2dForward.apply(input, grid) 40 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _should_use_custom_op(): 45 | if not enabled: 46 | return False 47 | if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): 48 | return True 49 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') 50 | return False 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | class _GridSample2dForward(torch.autograd.Function): 55 | @staticmethod 56 | def forward(ctx, input, grid): 57 | assert input.ndim == 4 58 | assert grid.ndim == 4 59 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 60 | ctx.save_for_backward(input, grid) 61 | return output 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | input, grid = ctx.saved_tensors 66 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 67 | return grad_input, grad_grid 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | class _GridSample2dBackward(torch.autograd.Function): 72 | @staticmethod 73 | def forward(ctx, grad_output, input, grid): 74 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 75 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 76 | ctx.save_for_backward(grid) 77 | return grad_input, grad_grid 78 | 79 | @staticmethod 80 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 81 | _ = grad2_grad_grid # unused 82 | grid, = ctx.saved_tensors 83 | grad2_grad_output = None 84 | grad2_input = None 85 | grad2_grid = None 86 | 87 | if ctx.needs_input_grad[0]: 88 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 89 | 90 | assert not ctx.needs_input_grad[2] 91 | return grad2_grad_output, grad2_input, grad2_grid 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | # pylint: enable=redefined-builtin 96 | # pylint: enable=arguments-differ 97 | # pylint: enable=protected-access 98 | # pylint: enable=line-too-long 99 | # pylint: enable=missing-function-docstring 100 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/README.md: -------------------------------------------------------------------------------- 1 | # Operators for StyleGAN3 2 | 3 | All files in this directory are borrowed from the repository [stylegan3](https://github.com/NVlabs/stylegan3). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including 4 | 5 | - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator. 6 | - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering. 7 | - `upfirdn2d.filter2d()`: Filtering a 2D feature map with a given kernel. 8 | - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map. 9 | - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map. 10 | - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map. 11 | - `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing. 12 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 13 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 14 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan3_generator.py` and `models/stylegan2_discriminator.py`). 15 | 16 | We make the following slight modifications beyond disabling some lint warnings: 17 | 18 | - Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3).
19 | - Line 36 of file `custom_ops.py`: Disable log message when setting up customized operators. 20 | - Line 54/109 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binary is not located at `/usr/local/cuda/bin`, please specify its path in function `_find_compiler_bindir_posix()`.*) 21 | - Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3). 22 | - Line 162-165 of file `filtered_lrelu.py`: Switch some operator calls in `_filtered_lrelu_ref()` to the `ref` implementation. 23 | - Line 31 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default. 24 | - Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator. 25 | - Line 34 of file `conv2d_gradfix.py`: Enable customized convolution operators by default. 26 | - Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators. 27 | - Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators. 28 | - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator. 29 | 30 | Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default. 31 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/third_party/stylegan3_official_ops/__init__.py -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.
-------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/third_party/stylegan3_official_ops/__init__.py -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection.
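// Note on the template parameters below: T is the tensor data type
// (c10::Half, float, or double), index_t selects 32-/64-bit indexing, and
// signWrite/signRead choose the sign-tensor mode; the filtered_lrelu_ns/
// _rd/_wr.cu translation units each instantiate one (signWrite, signRead)
// combination.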
85 | 86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void); 88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, true>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, true>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, true, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, true, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/fma.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.
12 | 13 | Please refer to https://github.com/NVlabs/stylegan3 14 | """ 15 | 16 | # pylint: disable=line-too-long 17 | # pylint: disable=missing-function-docstring 18 | 19 | import torch 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def fma(a, b, c, impl='cuda'): # => a * b + c 24 | if impl == 'cuda': 25 | return _FusedMultiplyAdd.apply(a, b, c) 26 | return torch.addcmul(c, a, b) 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 31 | @staticmethod 32 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 33 | out = torch.addcmul(c, a, b) 34 | ctx.save_for_backward(a, b) 35 | ctx.c_shape = c.shape 36 | return out 37 | 38 | @staticmethod 39 | def backward(ctx, dout): # pylint: disable=arguments-differ 40 | a, b = ctx.saved_tensors 41 | c_shape = ctx.c_shape 42 | da = None 43 | db = None 44 | dc = None 45 | 46 | if ctx.needs_input_grad[0]: 47 | da = _unbroadcast(dout * b, a.shape) 48 | 49 | if ctx.needs_input_grad[1]: 50 | db = _unbroadcast(dout * a, b.shape) 51 | 52 | if ctx.needs_input_grad[2]: 53 | dc = _unbroadcast(dout, c_shape) 54 | 55 | return da, db, dc 56 | 57 | #---------------------------------------------------------------------------- 58 | 59 | def _unbroadcast(x, shape): 60 | extra_dims = x.ndim - len(shape) 61 | assert extra_dims >= 0 62 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 63 | if len(dim): 64 | x = x.sum(dim=dim, keepdim=True) 65 | if extra_dims: 66 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 67 | assert x.shape == shape 68 | return x 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | # pylint: enable=line-too-long 73 | # pylint: enable=missing-function-docstring 74 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Custom replacement for `torch.nn.functional.grid_sample`. 12 | 13 | This is useful for differentiable augmentation. This customized operator 14 | supports arbitrarily high order gradients between the input and output. Only 15 | works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and 16 | `align_corners=False`. 17 | 18 | Please refer to https://github.com/NVlabs/stylegan3 19 | """ 20 | 21 | # pylint: disable=redefined-builtin 22 | # pylint: disable=arguments-differ 23 | # pylint: disable=protected-access 24 | # pylint: disable=line-too-long 25 | # pylint: disable=missing-function-docstring 26 | 27 | import torch 28 | 29 | #---------------------------------------------------------------------------- 30 | 31 | enabled = True # Enable the custom op by setting this to true. 
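# A minimal usage sketch (illustration only; the tensor shapes are
# hypothetical): the custom op is taken only when `impl='cuda'` and
# `enabled` is True, otherwise the call falls through to the native operator.
#
#   x = torch.randn(1, 3, 8, 8, requires_grad=True)
#   g = torch.randn(1, 8, 8, 2)
#   y = grid_sample(x, g, impl='cuda')  # double-differentiable custom op
#   y = grid_sample(x, g, impl='ref')   # plain torch.nn.functional.grid_sample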
32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def grid_sample(input, grid, impl='cuda'): 36 | if impl == 'cuda' and _should_use_custom_op(): 37 | return _GridSample2dForward.apply(input, grid) 38 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 39 | 40 | #---------------------------------------------------------------------------- 41 | 42 | def _should_use_custom_op(): 43 | return enabled 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | class _GridSample2dForward(torch.autograd.Function): 48 | @staticmethod 49 | def forward(ctx, input, grid): 50 | assert input.ndim == 4 51 | assert grid.ndim == 4 52 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 53 | ctx.save_for_backward(input, grid) 54 | return output 55 | 56 | @staticmethod 57 | def backward(ctx, grad_output): 58 | input, grid = ctx.saved_tensors 59 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 60 | return grad_input, grad_grid 61 | 62 | #---------------------------------------------------------------------------- 63 | 64 | class _GridSample2dBackward(torch.autograd.Function): 65 | @staticmethod 66 | def forward(ctx, grad_output, input, grid): 67 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 68 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 69 | ctx.save_for_backward(grid) 70 | return grad_input, grad_grid 71 | 72 | @staticmethod 73 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 74 | _ = grad2_grad_grid # unused 75 | grid, = ctx.saved_tensors 76 | grad2_grad_output = None 77 | grad2_input = None 78 | grad2_grid = None 79 | 80 | if ctx.needs_input_grad[0]: 81 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 82 | 83 | assert not ctx.needs_input_grad[2] 84 | return grad2_grad_output, grad2_input, grad2_grid 85 | 86 | #---------------------------------------------------------------------------- 87 | 88 | # pylint: enable=redefined-builtin 89 | # pylint: enable=arguments-differ 90 | # pylint: enable=protected-access 91 | # pylint: enable=line-too-long 92 | # pylint: enable=missing-function-docstring 93 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters.
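// Note: upfirdn2d applies, per 2D plane, upsampling by `up` (zero
// insertion), padding by `pad0`, FIR filtering with `f`, and downsampling
// by `down`; the shape/stride fields below are ordered [width, height,
// channel, batch].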
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Main function for model training.""" 3 | 4 | import click 5 | 6 | from configs import CONFIG_POOL 7 | from configs import build_config 8 | from runners import build_runner 9 | from utils.dist_utils import init_dist 10 | from utils.dist_utils import exit_dist 11 | 12 | 13 | @click.group(name='Distributed Training', 14 | help='Train a deep model by choosing a command (configuration).', 15 | context_settings={'show_default': True, 'max_content_width': 180}) 16 | @click.option('--launcher', default='pytorch', 17 | type=click.Choice(['pytorch', 'slurm']), 18 | help='Distributed launcher.') 19 | @click.option('--backend', default='nccl', 20 | type=click.Choice(['nccl', 'gloo', 'mpi']), 21 | help='Distributed backend.') 22 | @click.option('--local_rank', type=int, default=0, hidden=True, 23 | help='Replica rank on the current node. This field is required ' 24 | 'by `torch.distributed.launch`.') 25 | def command_group(launcher, backend, local_rank): # pylint: disable=unused-argument 26 | """Defines a command group for launching distributed jobs. 27 | 28 | This function is mainly for interaction with the command line. The real 29 | launching is executed by the `main()` function, through the 30 | `result_callback()` decorator. In other words, the arguments obtained from 31 | the command line will be passed to the `main()` function. As for how the 32 | arguments are passed, it is the responsibility of each command of this 33 | command group. Please refer to `BaseConfig.get_command()` in 34 | `configs/base_config.py` for more details. 35 | """ 36 | 37 | 38 | @command_group.result_callback() 39 | @click.pass_context 40 | def main(ctx, kwargs, launcher, backend, local_rank): 41 | """Main function for distributed training. 42 | 43 | Basically, this function first initializes a distributed environment, then 44 | parses configuration from the command line, and finally sets up the runner 45 | with the parsed configuration for training. 46 | """ 47 | _ = local_rank # unused variable 48 | 49 | # Initialize distributed environment. 50 | init_dist(launcher=launcher, backend=backend) 51 | 52 | # Build configurations and runner.
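# Note: `ctx.invoked_subcommand` is the name of the configuration command
# chosen on the command line (one registered from `CONFIG_POOL`, e.g. a
# StyleGAN2 config), and `kwargs` carries the parsed options, so
# `build_config()` can resolve them into a complete training configuration.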
52 | config = build_config(ctx.invoked_subcommand, kwargs).get_config() 53 | runner = build_runner(config) 54 | 55 | # Start training. 56 | runner.train() 57 | runner.close() 58 | 59 | # Exit distributed environment. 60 | exit_dist() 61 | 62 | 63 | if __name__ == '__main__': 64 | # Append all available commands (from `configs/`) into the command group. 65 | for cfg in CONFIG_POOL: 66 | command_group.add_command(cfg.get_command()) 67 | # Run by interacting with command line. 68 | command_group() # pylint: disable=no-value-for-parameter 69 | -------------------------------------------------------------------------------- /unit_tests.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all unit tests.""" 3 | 4 | import argparse 5 | 6 | from models.test import test_model 7 | from utils.loggers.test import test_logger 8 | from utils.visualizers.test import test_visualizer 9 | from utils.parsing_utils import parse_bool 10 | 11 | 12 | def parse_args(): 13 | """Parses arguments.""" 14 | parser = argparse.ArgumentParser(description='Run unit tests.') 15 | parser.add_argument('--result_dir', type=str, 16 | default='work_dirs/unit_tests', 17 | help='Path to save the test results. (default: ' 18 | '%(default)s)') 19 | parser.add_argument('--test_all', type=parse_bool, default=False, 20 | help='Whether to run all unit tests. (default: ' 21 | '%(default)s)') 22 | parser.add_argument('--test_model', type=parse_bool, default=False, 23 | help='Whether to run unit test on models. (default: ' 24 | '%(default)s)') 25 | parser.add_argument('--test_logger', type=parse_bool, default=False, 26 | help='Whether to run unit test on loggers. (default: ' 27 | '%(default)s)') 28 | parser.add_argument('--test_visualizer', type=parse_bool, default=False, 29 | help='Whether to do unit test on visualizers. 
' 30 | '(default: %(default)s)') 31 | return parser.parse_args() 32 | 33 | 34 | def main(): 35 | """Main function.""" 36 | args = parse_args() 37 | 38 | if args.test_all or args.test_model: 39 | test_model() 40 | 41 | if args.test_all or args.test_logger: 42 | test_logger(args.result_dir) 43 | 44 | if args.test_all or args.test_visualizer: 45 | test_visualizer(args.result_dir) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/volumegan/34bca216f7d16c600982422f0268bd896a68b759/utils/__init__.py -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains utility functions used for distribution.""" 3 | 4 | import contextlib 5 | import os 6 | import subprocess 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | 12 | __all__ = ['init_dist', 'exit_dist', 'ddp_sync', 'get_ddp_module'] 13 | 14 | 15 | def init_dist(launcher, backend='nccl', **kwargs): 16 | """Initializes distributed environment.""" 17 | if mp.get_start_method(allow_none=True) is None: 18 | mp.set_start_method('spawn') 19 | if launcher == 'pytorch': 20 | rank = int(os.environ['RANK']) 21 | num_gpus = torch.cuda.device_count() 22 | torch.cuda.set_device(rank % num_gpus) 23 | dist.init_process_group(backend=backend, **kwargs) 24 | elif launcher == 'slurm': 25 | proc_id = int(os.environ['SLURM_PROCID']) 26 | ntasks = int(os.environ['SLURM_NTASKS']) 27 | node_list = os.environ['SLURM_NODELIST'] 28 | num_gpus = torch.cuda.device_count() 29 | torch.cuda.set_device(proc_id % num_gpus) 30 | addr = subprocess.getoutput( 31 | f'scontrol show hostname {node_list} | head -n1') 32 | port = os.environ.get('PORT', 29500) 33 | os.environ['MASTER_PORT'] = str(port) 34 | os.environ['MASTER_ADDR'] = addr 35 | os.environ['WORLD_SIZE'] = str(ntasks) 36 | os.environ['RANK'] = str(proc_id) 37 | dist.init_process_group(backend=backend) 38 | else: 39 | raise NotImplementedError(f'Not implemented launcher type: ' 40 | f'`{launcher}`!') 41 | 42 | 43 | def exit_dist(): 44 | """Exits the distributed environment.""" 45 | if dist.is_initialized(): 46 | dist.destroy_process_group() 47 | 48 | 49 | @contextlib.contextmanager 50 | def ddp_sync(model, sync): 51 | """Controls whether the `DistributedDataParallel` model should be synced.""" 52 | assert isinstance(model, torch.nn.Module) 53 | is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel) 54 | if sync or not is_ddp: 55 | yield 56 | else: 57 | with model.no_sync(): 58 | yield 59 | 60 | 61 | def get_ddp_module(model): 62 | """Gets the module from `DistributedDataParallel`.""" 63 | assert isinstance(model, torch.nn.Module) 64 | is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel) 65 | if is_ddp: 66 | return model.module 67 | return model 68 | -------------------------------------------------------------------------------- /utils/file_transmitters/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all file transmitters.""" 3 | 4 | from .local_file_transmitter import LocalFileTransmitter 5 | from .dummy_file_transmitter import DummyFileTransmitter 6 | 7 | __all__ = 
['build_file_transmitter'] 8 | 9 | _TRANSMITTERS = { 10 | 'local': LocalFileTransmitter, 11 | 'dummy': DummyFileTransmitter, 12 | } 13 | 14 | 15 | def build_file_transmitter(transmitter_type='local', **kwargs): 16 | """Builds a file transmitter. 17 | 18 | Args: 19 | transmitter_type: Type of the file transmitter, which is case 20 | insensitive. (default: `local`) 21 | **kwargs: Additional arguments to build the file transmitter. 22 | 23 | Raises: 24 | ValueError: If the `transmitter_type` is not supported. 25 | """ 26 | transmitter_type = transmitter_type.lower() 27 | if transmitter_type not in _TRANSMITTERS: 28 | raise ValueError(f'Invalid transmitter type: `{transmitter_type}`!\n' 29 | f'Types allowed: {list(_TRANSMITTERS)}.') 30 | return _TRANSMITTERS[transmitter_type](**kwargs) 31 | -------------------------------------------------------------------------------- /utils/file_transmitters/base_file_transmitter.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the base class to transmit files across file systems. 3 | 4 | Basically, a file transmitter connects the local file system, on which the 5 | programme runs, to a remote file system. This is particularly used for 6 | (1) pulling files that are required by the programme from remote, and 7 | (2) pushing results that are produced by the programme to remote. In this way, 8 | the programme can focus on the local file system only. 9 | 10 | NOTE: The remote file system can be the same as the local file system, since 11 | users may want to transmit files across directories. 12 | """ 13 | 14 | import warnings 15 | 16 | __all__ = ['BaseFileTransmitter'] 17 | 18 | 19 | class BaseFileTransmitter(object): 20 | """Defines the base file transmitter. 21 | 22 | A transmitter should have the following functions: 23 | 24 | (1) pull(): The function to pull a file/directory from remote to local. 25 | (2) push(): The function to push a file/directory from local to remote. 26 | (3) remove(): The function to remove a file/directory. 27 | (4) make_remote_dir(): Make directory remotely. 28 | 29 | 30 | To simplify, each derived class just needs to implement the following 31 | helper functions: 32 | 33 | (1) download_hard(): Hard download a file/directory from remote to local. 34 | (2) download_soft(): Soft download a file/directory from remote to local. 35 | This is especially used to save space (e.g., soft link). 36 | (3) upload(): Upload a file/directory from local to remote. 37 | (4) delete(): Delete a file/directory according to the given path. 38 | """ 39 | 40 | def __init__(self): 41 | pass 42 | 43 | @property 44 | def name(self): 45 | """Returns the class name of the file transmitter.""" 46 | return self.__class__.__name__ 47 | 48 | @staticmethod 49 | def download_hard(src, dst): 50 | """Downloads (in hard mode) a file/directory from remote to local.""" 51 | raise NotImplementedError('Should be implemented in derived class!') 52 | 53 | @staticmethod 54 | def download_soft(src, dst): 55 | """Downloads (in soft mode) a file/directory from remote to local.""" 56 | raise NotImplementedError('Should be implemented in derived class!') 57 | 58 | @staticmethod 59 | def upload(src, dst): 60 | """Uploads a file/directory from local to remote.""" 61 | raise NotImplementedError('Should be implemented in derived class!') 62 | 63 | @staticmethod 64 | def delete(path): 65 | """Deletes the given path.""" 66 | # TODO: should we secure the path to avoid mis-removing / attacks?
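# One possible guard (a hypothetical sketch, not implemented here): resolve
# the path and refuse to delete anything outside a whitelisted root, e.g.
# `assert os.path.realpath(path).startswith(allowed_root)`.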
67 | raise NotImplementedError('Should be implemented in derived class!') 68 | 69 | def pull(self, src, dst, hard=False): 70 | """Pulls a file/directory from remote to local. 71 | 72 | The argument `hard` controls the download mode (hard or soft). 73 | For example, the hard mode may make an actual copy of the file, while 74 | the soft mode may create a symbolic link to it. 75 | """ 76 | if hard: 77 | self.download_hard(src, dst) 78 | else: 79 | self.download_soft(src, dst) 80 | 81 | def push(self, src, dst): 82 | """Pushes a file/directory from local to remote.""" 83 | self.upload(src, dst) 84 | 85 | def remove(self, path): 86 | """Removes the given path.""" 87 | warnings.warn(f'`{path}` will be removed!') 88 | self.delete(path) 89 | 90 | def make_remote_dir(self, directory): 91 | """Makes a directory on the remote system.""" 92 | raise NotImplementedError('Should be implemented in derived class!') 93 | -------------------------------------------------------------------------------- /utils/file_transmitters/dummy_file_transmitter.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of dummy file transmitter. 3 | 4 | This file transmitter has all expected data transmission functions but behaves 5 | silently, which is very useful in multi-processing mode. Only the chief process 6 | can have the file transmitter with normal behavior. 7 | """ 8 | 9 | from .base_file_transmitter import BaseFileTransmitter 10 | 11 | __all__ = ['DummyFileTransmitter'] 12 | 13 | 14 | class DummyFileTransmitter(BaseFileTransmitter): 15 | """Implements a dummy transmitter which transmits nothing.""" 16 | 17 | @staticmethod 18 | def download_hard(src, dst): 19 | return 20 | 21 | @staticmethod 22 | def download_soft(src, dst): 23 | return 24 | 25 | @staticmethod 26 | def upload(src, dst): 27 | return 28 | 29 | @staticmethod 30 | def delete(path): 31 | return 32 | 33 | def make_remote_dir(self, directory): 34 | return 35 | -------------------------------------------------------------------------------- /utils/file_transmitters/local_file_transmitter.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of local file transmitter. 3 | 4 | The transmitter builds the connection between the local file system and itself. 5 | This can be used to transmit files from one directory to another. Consequently, 6 | `remote` in this file also means `local`.
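Example (a minimal sketch; the file paths are hypothetical):

    transmitter = LocalFileTransmitter()
    transmitter.pull('/data/ffhq.zip', './ffhq.zip', hard=True)   # executes `cp`
    transmitter.pull('/data/ffhq.zip', './ffhq.zip', hard=False)  # executes `ln -s`
    transmitter.push('./ckpt.pth', '/data/ckpt.pth')              # executes `cp`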
7 | """ 8 | 9 | from utils.misc import print_and_execute 10 | from .base_file_transmitter import BaseFileTransmitter 11 | 12 | __all__ = ['LocalFileTransmitter'] 13 | 14 | 15 | class LocalFileTransmitter(BaseFileTransmitter): 16 | """Implements the transmitter connecting local file system to itself.""" 17 | 18 | @staticmethod 19 | def download_hard(src, dst): 20 | print_and_execute(f'cp {src} {dst}') 21 | 22 | @staticmethod 23 | def download_soft(src, dst): 24 | print_and_execute(f'ln -s {src} {dst}') 25 | 26 | @staticmethod 27 | def upload(src, dst): 28 | print_and_execute(f'cp {src} {dst}') 29 | 30 | @staticmethod 31 | def delete(path): 32 | print_and_execute(f'rm -r {path}') 33 | 34 | def make_remote_dir(self, directory): 35 | print_and_execute(f'mkdir -p {directory}') 36 | -------------------------------------------------------------------------------- /utils/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all loggers.""" 3 | 4 | from .normal_logger import NormalLogger 5 | from .rich_logger import RichLogger 6 | from .dummy_logger import DummyLogger 7 | 8 | __all__ = ['build_logger'] 9 | 10 | _LOGGERS = { 11 | 'normal': NormalLogger, 12 | 'rich': RichLogger, 13 | 'dummy': DummyLogger 14 | } 15 | 16 | 17 | def build_logger(logger_type='normal', **kwargs): 18 | """Builds a logger. 19 | 20 | Args: 21 | logger_type: Type of logger, which is case insensitive. 22 | (default: `normal`) 23 | **kwargs: Additional arguments to build the logger. 24 | 25 | Raises: 26 | ValueError: If the `logger_type` is not supported. 27 | """ 28 | logger_type = logger_type.lower() 29 | if logger_type not in _LOGGERS: 30 | raise ValueError(f'Invalid logger type: `{logger_type}`!\n' 31 | f'Types allowed: {list(_LOGGERS)}.') 32 | return _LOGGERS[logger_type](**kwargs) 33 | -------------------------------------------------------------------------------- /utils/loggers/dummy_logger.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of dummy logger. 3 | 4 | This logger has all expected logging functions but behaves silently, which is 5 | very useful in multi-processing mode. Only the chief process can have the logger 6 | with normal behavior. 
7 | """ 8 | 9 | from .base_logger import BaseLogger 10 | 11 | __all__ = ['DummyLogger'] 12 | 13 | 14 | class DummyLogger(BaseLogger): 15 | """Implements a dummy logger which logs nothing.""" 16 | 17 | def __init__(self, 18 | logger_name='logger', 19 | logfile=None, 20 | screen_level=None, 21 | file_level=None, 22 | indent_space=4, 23 | verbose_log=False): 24 | super().__init__(logger_name=logger_name, 25 | logfile=logfile, 26 | screen_level=screen_level, 27 | file_level=file_level, 28 | indent_space=indent_space, 29 | verbose_log=verbose_log) 30 | 31 | def _log(self, message, **kwargs): 32 | return 33 | 34 | def _debug(self, message, **kwargs): 35 | return 36 | 37 | def _info(self, message, **kwargs): 38 | return 39 | 40 | def _warning(self, message, **kwargs): 41 | return 42 | 43 | def _error(self, message, **kwargs): 44 | return 45 | 46 | def _exception(self, message, **kwargs): 47 | return 48 | 49 | def _critical(self, message, **kwargs): 50 | return 51 | 52 | def _print(self, *messages, **kwargs): 53 | return 54 | 55 | def init_pbar(self, leave=False): 56 | return 57 | 58 | def add_pbar_task(self, name, total, **kwargs): 59 | return -1 60 | 61 | def update_pbar(self, task_id, advance=1): 62 | return 63 | 64 | def close_pbar(self): 65 | return 66 | -------------------------------------------------------------------------------- /utils/loggers/test.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Unit test for logger.""" 3 | 4 | import os 5 | import time 6 | 7 | from . import build_logger 8 | 9 | __all__ = ['test_logger'] 10 | 11 | _TEST_DIR = 'logger_test' 12 | 13 | 14 | def test_logger(test_dir=_TEST_DIR): 15 | """Tests loggers.""" 16 | print('========== Start Logger Test ==========') 17 | 18 | os.makedirs(test_dir, exist_ok=True) 19 | 20 | for logger_type in ['normal', 'rich', 'dummy']: 21 | for indent_space in [2, 4]: 22 | for verbose_log in [False, True]: 23 | if logger_type == 'normal': 24 | class_name = 'Logger' 25 | elif logger_type == 'rich': 26 | class_name = 'RichLogger' 27 | elif logger_type == 'dummy': 28 | class_name = 'DummyLogger' 29 | 30 | print(f'===== ' 31 | f'Testing `utils.logger.{class_name}` ' 32 | f' (indent: {indent_space}, verbose: {verbose_log}) ' 33 | f'=====') 34 | logger_name = (f'{logger_type}_logger_' 35 | f'indent_{indent_space}_' 36 | f'verbose_{verbose_log}') 37 | logger = build_logger( 38 | logger_type, 39 | logger_name=logger_name, 40 | logfile=os.path.join(test_dir, f'test_{logger_name}.log'), 41 | verbose_log=verbose_log, 42 | indent_space=indent_space) 43 | logger.print('print log') 44 | logger.print('print log,', 'log 2') 45 | logger.print('print log (indent level 0)', indent_level=0) 46 | logger.print('print log (indent level 1)', indent_level=1) 47 | logger.print('print log (indent level 2)', indent_level=2) 48 | logger.print('print log (verbose `False`)', is_verbose=False) 49 | logger.print('print log (verbose `True`)', is_verbose=True) 50 | logger.debug('debug log') 51 | logger.info('info log') 52 | logger.warning('warning log') 53 | logger.init_pbar() 54 | task_1 = logger.add_pbar_task('Task 1', 500) 55 | task_2 = logger.add_pbar_task('Task 2', 1000) 56 | for _ in range(1000): 57 | logger.update_pbar(task_1, 1) 58 | logger.update_pbar(task_2, 1) 59 | time.sleep(0.002) 60 | logger.close_pbar() 61 | print('Success!') 62 | 63 | print('========== Finish Logger Test ==========') 64 | -------------------------------------------------------------------------------- /utils/tf_utils.py: 
-------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the utility functions to handle import TensorFlow modules. 3 | 4 | Basically, TensorFlow may not be supported in the current environment, or may 5 | cause some warnings. This file provides functions to help ease TensorFlow 6 | related imports, such as TensorBoard. 7 | """ 8 | 9 | import warnings 10 | 11 | __all__ = ['import_tf', 'import_tb_writer'] 12 | 13 | 14 | def import_tf(): 15 | """Imports TensorFlow module if possible. 16 | 17 | If `ImportError` is raised, `None` will be returned. Otherwise, the module 18 | `tensorflow` will be returned. 19 | """ 20 | warnings.filterwarnings('ignore', category=FutureWarning) 21 | try: 22 | import tensorflow as tf # pylint: disable=import-outside-toplevel 23 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 24 | module = tf 25 | except ImportError: 26 | module = None 27 | warnings.filterwarnings('default', category=FutureWarning) 28 | return module 29 | 30 | 31 | def import_tb_writer(): 32 | """Imports the SummaryWriter of TensorBoard. 33 | 34 | If `ImportError` is raised, `None` will be returned. Otherwise, the class 35 | `SummaryWriter` will be returned. 36 | 37 | NOTE: This function attempts to import `SummaryWriter` from 38 | `torch.utils.tensorboard`. But it does not necessarily mean the import 39 | always succeeds because installing TensorBoard is not a duty of `PyTorch`. 40 | """ 41 | warnings.filterwarnings('ignore', category=FutureWarning) 42 | try: 43 | from torch.utils.tensorboard import SummaryWriter # pylint: disable=import-outside-toplevel 44 | except ImportError: # In case TensorBoard is not supported. 45 | SummaryWriter = None 46 | warnings.filterwarnings('default', category=FutureWarning) 47 | return SummaryWriter 48 | -------------------------------------------------------------------------------- /utils/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all visualizers.""" 3 | 4 | from .grid_visualizer import GridVisualizer 5 | from .gif_visualizer import GifVisualizer 6 | from .html_visualizer import HtmlVisualizer 7 | from .html_visualizer import HtmlReader 8 | from .video_visualizer import VideoVisualizer 9 | from .video_visualizer import VideoReader 10 | 11 | __all__ = [ 12 | 'GridVisualizer', 'GifVisualizer', 'HtmlVisualizer', 'HtmlReader', 13 | 'VideoVisualizer', 'VideoReader' 14 | ] 15 | -------------------------------------------------------------------------------- /utils/visualizers/gif_visualizer.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the visualizer to visualize images as a GIF.""" 3 | 4 | from PIL import Image 5 | 6 | from ..image_utils import parse_image_size 7 | from ..image_utils import load_image 8 | from ..image_utils import resize_image 9 | from ..image_utils import list_images_from_dir 10 | 11 | __all__ = ['GifVisualizer'] 12 | 13 | 14 | class GifVisualizer(object): 15 | """Defines the visualizer that visualizes an image collection as GIF.""" 16 | 17 | def __init__(self, image_size=None, duration=100, loop=0): 18 | """Initializes the GIF visualizer. 19 | 20 | Args: 21 | image_size: Size for image visualization. (default: None) 22 | duration: Duration between two frames, in milliseconds. 23 | (default: 100) 24 | loop: How many times to loop the GIF. `0` means infinite. 
25 | (default: 0) 26 | """ 27 | self.set_image_size(image_size) 28 | self.set_duration(duration) 29 | self.set_loop(loop) 30 | 31 | def set_image_size(self, image_size=None): 32 | """Sets the image size of the GIF.""" 33 | height, width = parse_image_size(image_size) 34 | self.image_height = height 35 | self.image_width = width 36 | 37 | def set_duration(self, duration=100): 38 | """Sets the GIF duration.""" 39 | self.duration = duration 40 | 41 | def set_loop(self, loop=0): 42 | """Sets how many times the GIF will be looped. `0` means infinite.""" 43 | self.loop = loop 44 | 45 | def visualize_collection(self, images, save_path): 46 | """Visualizes a collection of images one by one.""" 47 | height, width = images[0].shape[0:2] 48 | height = self.image_height or height 49 | width = self.image_width or width 50 | pil_images = [] 51 | for image in images: 52 | if image.shape[0:2] != (height, width): 53 | image = resize_image(image, (width, height)) 54 | pil_images.append(Image.fromarray(image)) 55 | pil_images[0].save(save_path, format='GIF', save_all=True, 56 | append_images=pil_images[1:], 57 | duration=self.duration, 58 | loop=self.loop) 59 | 60 | def visualize_list(self, image_list, save_path): 61 | """Visualizes a list of image files.""" 62 | height, width = load_image(image_list[0]).shape[0:2] 63 | height = self.image_height or height 64 | width = self.image_width or width 65 | pil_images = [] 66 | for filename in image_list: 67 | image = load_image(filename) 68 | if image.shape[0:2] != (height, width): 69 | image = resize_image(image, (width, height)) 70 | pil_images.append(Image.fromarray(image)) 71 | pil_images[0].save(save_path, format='GIF', save_all=True, 72 | append_images=pil_images[1:], 73 | duration=self.duration, 74 | loop=self.loop) 75 | 76 | def visualize_directory(self, directory, save_path): 77 | """Visualizes all images under a directory.""" 78 | image_list = list_images_from_dir(directory) 79 | self.visualize_list(image_list, save_path) 80 | -------------------------------------------------------------------------------- /utils/visualizers/test.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Unit test for visualizer.""" 3 | 4 | import os 5 | import skvideo.datasets 6 | 7 | from ..image_utils import save_image 8 | from . import GridVisualizer 9 | from . import HtmlVisualizer 10 | from . import HtmlReader 11 | from . import GifVisualizer 12 | from . import VideoVisualizer 13 | from . import VideoReader 14 | 15 | __all__ = ['test_visualizer'] 16 | 17 | _TEST_DIR = 'visualizer_test' 18 | 19 | 20 | def test_visualizer(test_dir=_TEST_DIR): 21 | """Tests visualizers.""" 22 | print('========== Start Visualizer Test ==========') 23 | 24 | frame_dir = os.path.join(test_dir, 'test_frames') 25 | os.makedirs(frame_dir, exist_ok=True) 26 | 27 | print('===== Testing `VideoReader` =====') 28 | # Total 132 frames, with size (720, 1080). 29 | video_reader = VideoReader(skvideo.datasets.bigbuckbunny()) 30 | frame_height = video_reader.frame_height 31 | frame_width = video_reader.frame_width 32 | frame_size = (frame_height, frame_width) 33 | half_size = (frame_height // 2, frame_width // 2) 34 | # Save frames as the test set. 
35 | for idx in range(80): 36 | frame = video_reader.read() 37 | save_image(os.path.join(frame_dir, f'{idx:02d}.png'), frame) 38 | 39 | print('===== Testing `GridVisualizer` =====') 40 | grid_visualizer = GridVisualizer() 41 | grid_visualizer.set_row_spacing(30) 42 | grid_visualizer.set_col_spacing(30) 43 | grid_visualizer.set_background(use_black=True) 44 | path = os.path.join(test_dir, 'portrait_row_major_ori_space30_black.png') 45 | grid_visualizer.visualize_directory(frame_dir, path, 46 | is_portrait=True, is_row_major=True) 47 | path = os.path.join( 48 | test_dir, 'landscape_col_major_downsample_space15_white.png') 49 | grid_visualizer.set_image_size(half_size) 50 | grid_visualizer.set_row_spacing(15) 51 | grid_visualizer.set_col_spacing(15) 52 | grid_visualizer.set_background(use_black=False) 53 | grid_visualizer.visualize_directory(frame_dir, path, 54 | is_portrait=False, is_row_major=False) 55 | 56 | print('===== Testing `HtmlVisualizer` =====') 57 | html_visualizer = HtmlVisualizer() 58 | path = os.path.join(test_dir, 'portrait_col_major_ori.html') 59 | html_visualizer.visualize_directory(frame_dir, path, 60 | is_portrait=True, is_row_major=False) 61 | path = os.path.join(test_dir, 'landscape_row_major_downsample.html') 62 | html_visualizer.set_image_size(half_size) 63 | html_visualizer.visualize_directory(frame_dir, path, 64 | is_portrait=False, is_row_major=True) 65 | 66 | print('===== Testing `HtmlReader` =====') 67 | path = os.path.join(test_dir, 'landscape_row_major_downsample.html') 68 | html_reader = HtmlReader(path) 69 | for j in range(html_reader.num_cols): 70 | assert html_reader.get_header(j) == '' 71 | parsed_dir = os.path.join(test_dir, 'parsed_frames') 72 | os.makedirs(parsed_dir, exist_ok=True) 73 | for i in range(html_reader.num_rows): 74 | for j in range(html_reader.num_cols): 75 | idx = i * html_reader.num_cols + j 76 | assert html_reader.get_text(i, j).endswith(f'(index {idx:03d})') 77 | image = html_reader.get_image(i, j, image_size=frame_size) 78 | assert image.shape[0:2] == frame_size 79 | save_image(os.path.join(parsed_dir, f'{idx:02d}.png'), image) 80 | 81 | print('===== Testing `GifVisualizer` =====') 82 | gif_visualizer = GifVisualizer() 83 | path = os.path.join(test_dir, 'gif_ori.gif') 84 | gif_visualizer.visualize_directory(frame_dir, path) 85 | gif_visualizer.set_image_size(half_size) 86 | path = os.path.join(test_dir, 'gif_downsample.gif') 87 | gif_visualizer.visualize_directory(frame_dir, path) 88 | 89 | print('===== Testing `VideoVisualizer` =====') 90 | video_visualizer = VideoVisualizer() 91 | path = os.path.join(test_dir, 'video_ori.mp4') 92 | video_visualizer.visualize_directory(frame_dir, path) 93 | path = os.path.join(test_dir, 'video_downsample.mp4') 94 | video_visualizer.set_frame_size(half_size) 95 | video_visualizer.visualize_directory(frame_dir, path) 96 | 97 | print('========== Finish Visualizer Test ==========') 98 | --------------------------------------------------------------------------------
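To reproduce the unit tests above end to end, the collector script can be invoked as follows (a sketch based on the flags defined in `unit_tests.py`; `parse_bool` is assumed to accept `true`/`false`):

    python unit_tests.py --test_all true --result_dir work_dirs/unit_tests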