├── .gitignore
├── README.md
├── configs
    ├── config.yaml
    ├── dataset
    │   ├── base.yaml
    │   ├── cats_aligned.yaml
    │   ├── custom.yaml
    │   ├── ffhq.yaml
    │   ├── ffhq_posed.yaml
    │   ├── megascans_food.yaml
    │   ├── megascans_plants.yaml
    │   └── nerf_synth.yaml
    ├── env
    │   ├── base.yaml
    │   └── local.yaml
    ├── infra.yaml
    ├── model
    │   ├── base.yaml
    │   ├── eg3d.yaml
    │   └── stylegan2.yaml
    ├── scripts
    │   ├── calc_metrics.yaml
    │   ├── camera
    │   │   ├── front_circle.yaml
    │   │   ├── line.yaml
    │   │   ├── point.yaml
    │   │   ├── points.yaml
    │   │   ├── wiggle.yaml
    │   │   └── zoom_in_out.yaml
    │   ├── extract_geometry.yaml
    │   ├── inference.yaml
    │   └── vis
    │   │   ├── bg_nobg.yaml
    │   │   ├── density.yaml
    │   │   ├── front_grid.yaml
    │   │   ├── interp.yaml
    │   │   ├── interp_density.yaml
    │   │   ├── interp_video.yaml
    │   │   ├── interp_video_grid.yaml
    │   │   ├── minigrid.yaml
    │   │   ├── rotation_video.yaml
    │   │   └── video.yaml
    └── training
    │   ├── base.yaml
    │   ├── default.yaml
    │   ├── patch_beta.yaml
    │   ├── patch_categ.yaml
    │   ├── patch_discrete_uniform.yaml
    │   └── patch_uniform.yaml
├── environment.yml
├── render_dataset.py
├── setup.py
├── src
    ├── __init__.py
    ├── dnnlib
    │   ├── __init__.py
    │   └── util.py
    ├── infra
    │   ├── __init__.py
    │   ├── experiments.yaml
    │   ├── launch.py
    │   ├── slurm_job.py
    │   ├── slurm_job_proxy.sh
    │   └── utils.py
    ├── legacy.py
    ├── metrics
    │   ├── __init__.py
    │   ├── equivariance.py
    │   ├── frechet_inception_distance.py
    │   ├── inception_score.py
    │   ├── kernel_inception_distance.py
    │   ├── metric_main.py
    │   ├── metric_utils.py
    │   ├── perceptual_path_length.py
    │   └── precision_recall.py
    ├── scripts
    │   ├── calc_metrics.py
    │   ├── extract_geometry.py
    │   ├── inference.py
    │   └── utils.py
    ├── torch_utils
    │   ├── __init__.py
    │   ├── custom_ops.py
    │   ├── misc.py
    │   ├── ops
    │   │   ├── __init__.py
    │   │   ├── bias_act.cpp
    │   │   ├── bias_act.cu
    │   │   ├── bias_act.h
    │   │   ├── bias_act.py
    │   │   ├── conv2d_gradfix.py
    │   │   ├── conv2d_resample.py
    │   │   ├── filtered_lrelu.cpp
    │   │   ├── filtered_lrelu.cu
    │   │   ├── filtered_lrelu.h
    │   │   ├── filtered_lrelu.py
    │   │   ├── filtered_lrelu_ns.cu
    │   │   ├── filtered_lrelu_rd.cu
    │   │   ├── filtered_lrelu_wr.cu
    │   │   ├── fma.py
    │   │   ├── grid_sample_gradfix.py
    │   │   ├── upfirdn2d.cpp
    │   │   ├── upfirdn2d.cu
    │   │   ├── upfirdn2d.h
    │   │   └── upfirdn2d.py
    │   ├── persistence.py
    │   └── training_stats.py
    ├── train.py
    └── training
    │   ├── __init__.py
    │   ├── augment.py
    │   ├── dataset.py
    │   ├── layers.py
    │   ├── loss.py
    │   ├── networks_discriminator.py
    │   ├── networks_eg3d.py
    │   ├── networks_inr_gan.py
    │   ├── networks_stylegan2.py
    │   ├── networks_stylegan3.py
    │   ├── rendering.py
    │   ├── training_loop.py
    │   └── utils.py
└── teaser.jpg

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.cache/
lib/
/env/
/env-ampere/
outputs/
.hydra/
experiments/
old-experiments/
failed-experiments/
data
*.log
notebooks/zoo.ipynb
.ipynb_checkpoints/
checkpoints/
*.egg-info/
/notebooks/
/scripts/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Rethinking training of 3D GANs

[[Website]](https://rethinking-3d-gans.github.io)
[[Paper]](https://rethinking-3d-gans.github.io/rethinking-3d-gans.pdf)

![teaser](https://user-images.githubusercontent.com/105873229/170562989-bd409c04-bc49-4439-9b2e-7f6986eaa9e5.jpg)


## Installation
To install and activate the environment, run the following commands:
```
conda env create -f environment.yml -p env
conda activate ./env
```
This repo is built on top of [StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada-pytorch), so make sure that it runs on your system.

## Data format
Data should be stored in zip archives. The exact structure is not important: the script will use all the images it finds.
For FFHQ and Cats, we use the same data processing as [GRAM](https://yudeng.github.io/GRAM/).
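As a minimal sketch (a hypothetical helper, not part of this repo), a folder of images can be packed into such a zip archive with the Python standard library. The set of image extensions it looks for is an assumption, and so is the placement of the optional `dataset.json` used for camera angles: the sketch follows the StyleGAN2-ADA convention of storing that file at the archive root, and the key format of the `camera_angles` dict is not specified here.

```python
import json
import zipfile
from pathlib import Path
from typing import Dict, List, Optional

# Assumption: the extensions treated as images.
IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def pack_images_to_zip(src_dir: str, zip_path: str,
                       camera_angles: Optional[Dict[str, List[float]]] = None) -> List[str]:
    """Hypothetical helper: store every image under `src_dir` in a zip archive.

    Returns the archive-relative names of the stored images. If `camera_angles`
    is given, it is written as `dataset.json` at the archive root (assumed layout).
    """
    src = Path(src_dir)
    names = sorted(
        p.relative_to(src).as_posix()
        for p in src.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )
    # Images are already compressed, so store them without recompressing.
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_STORED) as zf:
        for name in names:
            zf.write(src / name, arcname=name)
        if camera_angles is not None:
            # Each value is [yaw, pitch, roll]; using archive-relative paths as keys is an assumption.
            zf.writestr("dataset.json", json.dumps({"camera_angles": camera_angles}))
    return names
```

For instance, `pack_images_to_zip("ffhq_images", "data/ffhq_256.zip")` (with a hypothetical `ffhq_images` folder) would produce the archive that `dataset=ffhq` points to at the default resolution, since `configs/dataset/base.yaml` sets `path: data/${dataset.name}.zip` and `ffhq.yaml` sets `name: ffhq_${dataset.resolution}`.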
Put your datasets into the `data/` directory.

If you want to train with camera angles enabled, create a `dataset.json` file with a `camera_angles` dict of `"": [yaw, pitch, roll]` key/value pairs.
Also, pass the `model.discriminator.camera_cond=true model.discriminator.camera_cond_drop_p=0.5` command-line arguments (or simply override them in the config).
If you want to train on a custom dataset, create its config in the `configs/dataset` folder.

Data links:
- [Megascans Plants 256x256](https://www.dropbox.com/s/078gy1govyyoye9/plants_256.zip?dl=0)
- [Megascans Food 256x256](https://www.dropbox.com/s/lekkx0agd4fjaaa/food_256.zip?dl=0)

## Training

To launch training, run:
```
python src/infra/launch.py hydra.run.dir=. exp_suffix=min0.125-anneal10k-gamma0.05-dblocks3-cameracond-drop0.5-cin dataset=ffhq_posed dataset.resolution=512 training=patch_beta training.patch.min_scale_trg=0.25 training.patch.anneal_kimg=10000 model=eg3d training.metrics=fid2k_full env=local training.patch.resolution=64 model.discriminator.num_additional_start_blocks=3 training.kimg=100000 training.gamma=0.05 model.generator.tri_plane.res=512
```

To continue training, run:
```
python src/infra/launch.py hydra.run.dir=. experiment_dir= training.resume=latest
```

## Evaluation
At train time, we compute FID on only 2,048 fake images (against all real images), since generating 50,000 images takes too long.
To compute FID for 50k fake images, run:
```
python src/scripts/calc_metrics.py hydra.run.dir=. ckpt.networks_dir= script.data= script.mirror=true script.gpus=4 script.metrics=fid50k_full img_resolution=256 ckpt.selection_metric=fid2k_full ckpt.reload_code=true script=calc_metrics
```

## Visualization

We provide many visualization types; the entry point is `configs/scripts/inference.yaml`.
For example, to create the main front grid, run:
```
python src/scripts/inference.py hydra.run.dir=. ckpt.networks_dir= vis=front_grid camera=points output_dir= num_seeds=16 truncation_psi=0.7
```

To create visualization videos, run:
```
python src/scripts/inference.py hydra.run.dir=. ckpt.networks_dir= output_dir= vis=video camera=front_circle camera.num_frames=64 vis.fps=30 num_seeds=16
```

## Rendering Megascans

To render the Megascans data, obtain the necessary models from the [website](https://quixel.com/megascans/home), convert them into glTF, create an `env.blend` Blender environment, and then run:
```
blender --python render_dataset.py env.blend --background
```
The rendering config is located in `render_dataset.py`.

--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
defaults:
  - model/base.yaml
  - model: eg3d

  # Import dataset common and overwrite it with the custom one
  - dataset/base.yaml
  - dataset: ffhq

  - training/base.yaml
  - training: patch_uniform

  - env/base.yaml
  - env: local
  - infra.yaml

--------------------------------------------------------------------------------
/configs/dataset/base.yaml:
--------------------------------------------------------------------------------
# @package _group_

path: data/${dataset.name}.zip
c_dim: 0
sampling: ~ # You must define it explicitly

# Default parameters
resolution: 256

# When used with slurm, this will take the dataset from `path_for_slurm_job`
# and will copy it into the `path` location
path_for_slurm_job: ${env.datasets_dir}/${dataset.name}.zip

--------------------------------------------------------------------------------
/configs/dataset/cats_aligned.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: cats_${dataset.resolution}_aligned
sampling:
  fov: 12.0
  ray_start: 0.88
  ray_end: 1.12
  radius: 1.0
  dist: custom
cube_scale: 0.2
white_back: false
last_back: false

--------------------------------------------------------------------------------
/configs/dataset/custom.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: ~
resolution: ~
white_back: true

--------------------------------------------------------------------------------
/configs/dataset/ffhq.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: ffhq_${dataset.resolution}
sampling:
  fov: 12.0
  ray_start: 0.88
  ray_end: 1.12
  radius: 1.0
  dist: gaussian
  horizontal_stddev: 0.3
  horizontal_mean: 0.0
  vertical_mean: 1.57079632679
  vertical_stddev: 0.155
cube_scale: 0.2
white_back: false
last_back: false

--------------------------------------------------------------------------------
/configs/dataset/ffhq_posed.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: ffhq_${dataset.resolution}_posed
sampling:
  fov: 12.0
  ray_start: 0.88
  ray_end: 1.12
  radius: 1.0
  dist: custom
cube_scale: 0.2
white_back: false
last_back: false

--------------------------------------------------------------------------------
/configs/dataset/megascans_food.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: food_${dataset.resolution}
sampling:
  fov: 30.0
  ray_start: 0.75
  ray_end: 1.25
  radius: 1.0
  # dist: custom
  dist: spherical_uniform
  horizontal_stddev: 3.141592653589793
  vertical_stddev: 1.5707963267948966
  horizontal_mean: 0.0
  vertical_mean: 1.5707963267948966
white_back: true
last_back: false
cube_scale: 0.475
c_dim: 6

--------------------------------------------------------------------------------
/configs/dataset/megascans_plants.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: plants_${dataset.resolution}
sampling:
  fov: 30.0
  ray_start: 0.75
  ray_end: 1.25
  radius: 1.0
  # dist: custom
  dist: spherical_uniform
  horizontal_stddev: 3.141592653589793
  vertical_stddev: 1.5707963267948966
  horizontal_mean: 0.0
  vertical_mean: 1.5707963267948966
white_back: true
last_back: false
cube_scale: 0.475
c_dim: 191

--------------------------------------------------------------------------------
/configs/dataset/nerf_synth.yaml:
--------------------------------------------------------------------------------
# @package _group_

scene: ficus
name: ${dataset.scene}_${dataset.resolution}
sampling:
  # fov: 20.0
  # ray_start: 3.5
  # ray_end: 4.5
  # radius: 4.0
  fov: 30.0
  ray_start: 0.75
  ray_end: 1.25
  radius: 1.0
  dist: custom
  # dist: spherical_uniform
  # horizontal_stddev: 3.141592653589793
  # vertical_stddev: 0.7417649320975901
  # horizontal_mean: 1.5707963267948966
  # vertical_mean: 0.7417649320975901
white_back: true
last_back: false
cube_scale: 0.475

--------------------------------------------------------------------------------
/configs/env/base.yaml:
--------------------------------------------------------------------------------
# @package _group_

python_bin: ${env.project_path}/env/bin/python
before_train_commands: []
torch_extensions_dir: "/tmp/torch_extensions"
objects_to_copy:
  - ${env.project_path}/src
  - ${env.project_path}/configs
# A list of objects that are static and too big
# to be copy-pasted for each experiment
symlinks_to_create:
  - ${env.project_path}/data
tmp_dir: "/tmp"
datasets_dir: ~
slurm_constraint: v100

--------------------------------------------------------------------------------
/configs/env/local.yaml:
--------------------------------------------------------------------------------
# @package _group_

project_path: ${hydra:runtime.cwd}

--------------------------------------------------------------------------------
/configs/infra.yaml:
--------------------------------------------------------------------------------
# @package _group_

# Arguments that we want to pass via env into slurm job launcher
env_args:
  project_dir: ${experiment_dir}
  python_bin: ${env.python_bin}
  python_script: ${experiment_dir}/src/infra/slurm_job.py

num_gpus: 4
print_only: false
slurm: false
use_qos: false

git_hash: {_target_: src.infra.utils.get_git_hash}
exp_suffix: no_spec
experiment_name: ${dataset.name}_${model.name}_${training.name}_${exp_suffix}
experiment_name_with_hash: ${experiment_name}-${git_hash}
experiments_root_dir: experiments
experiment_dir: ${env.project_path}/${experiments_root_dir}/${experiment_name_with_hash}
job_sequence_length: 1
run_profiling: false

sbatch_args:
  constraint: ${env.slurm_constraint}
  time: "1-0"
  gres: gpu:${num_gpus}
  cpus-per-task: 5
  mem:
    _target_: src.infra.utils.num_gpus_to_mem
    num_gpus: ${num_gpus}
    mem_per_gpu: 64
  # mem-per-gpu: 64G
  cpus-per-gpu: 5
  comment: ${experiment_name}

sbatch_args_str:
  _target_: src.infra.utils.cfg_to_args_str
  cfg: ${sbatch_args}

env_args_str:
  _target_: src.infra.utils.cfg_to_args_str
  cfg: ${env_args}
  use_dashes: true

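In `infra.yaml` above, `sbatch_args_str` is meant to turn `sbatch_args` into the flags that `src/infra/launch.py` passes to `sbatch`, with `mem` computed by `src.infra.utils.num_gpus_to_mem`. Those helpers live in `src/infra/utils.py`, which is not part of this section, so the following is only a hedged guess at their behaviour: plain `--key=value` flags with shell-quoted values, and the total memory as `num_gpus * mem_per_gpu` gigabytes. The `use_dashes` option of `env_args_str` is not modeled.

```python
import shlex
from typing import Any, Mapping


def num_gpus_to_mem(num_gpus: int, mem_per_gpu: int) -> str:
    """Guess at `src.infra.utils.num_gpus_to_mem`: total memory as an sbatch `--mem` value."""
    # sbatch accepts a unit suffix for --mem; "G" means gigabytes.
    return f"{num_gpus * mem_per_gpu}G"


def cfg_to_args_str(cfg: Mapping[str, Any]) -> str:
    """Guess at `src.infra.utils.cfg_to_args_str`: render a flat mapping as `--key=value` flags."""
    # Quote each value so that, e.g., a comment containing spaces stays a single argument.
    return " ".join(f"--{key}={shlex.quote(str(value))}" for key, value in cfg.items())


if __name__ == "__main__":
    # Values resolved from infra.yaml with num_gpus=4 and the default v100 constraint.
    sbatch_args = {
        "constraint": "v100",
        "time": "1-0",
        "gres": "gpu:4",
        "cpus-per-task": 5,
        "mem": num_gpus_to_mem(4, 64),
    }
    print(cfg_to_args_str(sbatch_args))
    # --constraint=v100 --time=1-0 --gres=gpu:4 --cpus-per-task=5 --mem=256G
```

Quoting through `shlex.quote` matters here because `launch.py` joins the flags into one command string and runs it with `subprocess.run(..., shell=True)`.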
--------------------------------------------------------------------------------
/configs/model/base.yaml:
--------------------------------------------------------------------------------
# @package _group_

generator:
  fmaps: 1.0 # Capacity multiplier (default: 1.0)
  cmax: 512
  cbase: 32768
  optim:
    betas: [0.0, 0.99]
  patch: ${training.patch}
  dataset: ${dataset}
  w_dim: 512
  camera_cond: false
  camera_cond_drop_p: 0.0
  camera_cond_spoof_p: 0.5
discriminator:
  fmaps: 0.5 # Capacity multiplier (default: 1.0)
  cmax: 512
  cbase: 32768
  patch: ${training.patch}
  num_additional_start_blocks: 0
  mbstd_group_size: 4 # Minibatch std group size
  camera_cond: false
  camera_cond_drop_p: 0.0
  predict_pose_loss_weight: ${training.predict_pose_loss_weight}
  renorm: false

  # Hyper-modulation parameters
  hyper_type: no_hyper # One of ["no_hyper", "hyper", "dummy_hyper"]. Disabled by default.
  optim:
    lr: 0.002
    betas: [0.0, 0.99]
loss_kwargs:
  pl_weight: 0.0

--------------------------------------------------------------------------------
/configs/model/eg3d.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: eg3d
generator:
  backbone: stylegan2
  num_ray_steps: 48
  clamp_mode: softplus
  nerf_noise_std_init: 1.0
  nerf_noise_kimg_growth: 5000
  use_noise: true # Should we use spatial noise in StyleGAN2?

  nerf_sp_beta_init: 1.0
  nerf_sp_beta_target: 1.0
  nerf_sp_beta_kimg_growth: 5000

  tri_plane:
    res: 512
    feat_dim: 32
    fp32: true
    last_block_fp32: false
    view_hid_dim: 0
    posenc_period_len: 0
    mlp:
      n_layers: 2
      hid_dim: 64

  bg_model:
    type: ~ # One of [null, "plane", "sphere"]
    output_channels: 4
    coord_dim: 4
    num_blocks: 2
    cbase: 32768
    cmax: 128
    num_fp16_blocks: 0
    fmm: {enabled: false, rank: 3, activation: demod}
    posenc_period_len: 64.0 # Fourier features period length

    # Sampling parameters
    num_steps: 16 # Number of steps per ray
    start: 1.0 # Start plane for the background (in terms of disparity)

discriminator:
  hyper_type: no_hyper
  renorm: true

loss_kwargs:
  blur_init_sigma: 10
  blur_fade_kimg: 200

--------------------------------------------------------------------------------
/configs/model/stylegan2.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: stylegan2
generator: {}
discriminator: {}
loss_kwargs:
  pl_weight: 2.0

--------------------------------------------------------------------------------
/configs/scripts/calc_metrics.yaml:
--------------------------------------------------------------------------------
# @package _group_

# Checkpoint loading options
ckpt:
  network_pkl: ~ # Network pickle filename
  networks_dir: ~ # Network pickles directory
  selection_metric: fid2k_full # Which metric to use when selecting the best ckpt?
  reload_code: true

metrics: "fid50k_full" # Comma-separated list of metric names
data: ~ # Path to the dataset to evaluate metrics against (directory or zip)
mirror: ~ # Whether the dataset was augmented with x-flips during training. Default: look up
gpus: 1 # Number of GPUs to use
img_resolution: 256 # Image resolution of the generator?
verbose: false

--------------------------------------------------------------------------------
/configs/scripts/camera/front_circle.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: front_circle
num_frames: 64
yaw_diff: 0.3
pitch_diff: 0.2
use_zoom: true

--------------------------------------------------------------------------------
/configs/scripts/camera/line.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: line
num_frames: 16
yaw_left: -1.57
yaw_right: 1.57
pitch_left: 0.835
pitch_right: 0.835

--------------------------------------------------------------------------------
/configs/scripts/camera/point.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: point
yaws: 0.0
pitch: 1.57

--------------------------------------------------------------------------------
/configs/scripts/camera/points.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: points
yaws: [0.0, -0.3, 0.3]
pitch: 1.57

--------------------------------------------------------------------------------
/configs/scripts/camera/wiggle.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: wiggle
num_frames: 16
yaw_left: -1.57
yaw_right: 1.57
pitch_diff: 0.5

--------------------------------------------------------------------------------
/configs/scripts/camera/zoom_in_out.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rethinking-3d-gans/code/9bfc3ab32bd2b0992a229501e50bafcf232c5c11/configs/scripts/camera/zoom_in_out.yaml
--------------------------------------------------------------------------------
/configs/scripts/extract_geometry.yaml:
--------------------------------------------------------------------------------
# @package _group_

# Main options
seeds: [1]
cube_size: 0.3
voxel_res: 256
max_batch_res: 32
voxel_origin: [0.0, 0.0, 0.0]
output_dir: shapes
thresh_percentile: 97.5

--------------------------------------------------------------------------------
/configs/scripts/inference.yaml:
--------------------------------------------------------------------------------
defaults:
  - vis: front_grid
  - camera: points

# Checkpoint loading options
ckpt:
  network_pkl: ~ # Network pickle filename
  networks_dir: ~ # Network pickles directory
  selection_metric: fid2k_full # Which metric to use when selecting the best ckpt?
  reload_code: true

# Randomness options
seed: 1 # Random seed to fix non-generation randomness
seeds: ~
num_seeds: ~

# Inference options
batch_size: 16
truncation_psi: 0.7 # Truncation psi.
max_batch_res: 64 # Split image generation into chunks of the `max_batch_res`^2 resolution
img_resolution: 256 # Image resolution of the generator?
ray_step_multiplier: 2 # Multiplier for the number of steps per ray
synthesis_kwargs: {} # Empty by default
force_whiteback: false

# Logging options
verbose: true
output_dir: ~

--------------------------------------------------------------------------------
/configs/scripts/vis/bg_nobg.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: bg_nobg

--------------------------------------------------------------------------------
/configs/scripts/vis/density.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: density

--------------------------------------------------------------------------------
/configs/scripts/vis/front_grid.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: front_grid

--------------------------------------------------------------------------------
/configs/scripts/vis/interp.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: interp
--------------------------------------------------------------------------------
/configs/scripts/vis/interp_density.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: interp_density

--------------------------------------------------------------------------------
/configs/scripts/vis/interp_video.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: interp_video
fps: 30

--------------------------------------------------------------------------------
/configs/scripts/vis/interp_video_grid.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: interp_video_grid
fps: 30
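The `max_batch_res` option in `inference.yaml` above bounds how much of an image is generated at once. A minimal sketch of one way to do that, assuming square tiles of `max_batch_res` x `max_batch_res` pixels that are rendered independently and stitched back together; `iter_tiles`, `render_in_tiles` and the `render_tile` callback are hypothetical names, and the actual chunking in `src/scripts/inference.py` is not shown here.

```python
from typing import Callable, Iterator, List, Tuple

# (row_start, row_end, col_start, col_end), end-exclusive
Tile = Tuple[int, int, int, int]


def iter_tiles(img_resolution: int, max_batch_res: int) -> Iterator[Tile]:
    """Yield square tiles of side `max_batch_res` (clipped at the border) covering the image."""
    if max_batch_res <= 0:
        raise ValueError("max_batch_res must be positive")
    for r0 in range(0, img_resolution, max_batch_res):
        for c0 in range(0, img_resolution, max_batch_res):
            yield (r0, min(r0 + max_batch_res, img_resolution),
                   c0, min(c0 + max_batch_res, img_resolution))


def render_in_tiles(render_tile: Callable[[Tile], List[List[float]]],
                    img_resolution: int, max_batch_res: int) -> List[List[float]]:
    """Render tile by tile and stitch the results into one square image.

    `render_tile` is a hypothetical stand-in for a generator call restricted to one tile;
    it must return (row_end - row_start) rows of (col_end - col_start) values.
    """
    image = [[0.0] * img_resolution for _ in range(img_resolution)]
    for tile in iter_tiles(img_resolution, max_batch_res):
        r0, r1, c0, c1 = tile
        block = render_tile(tile)
        for i, row in enumerate(block):
            image[r0 + i][c0:c1] = row  # paste one row of the tile in place
    return image


if __name__ == "__main__":
    tiles = list(iter_tiles(256, 64))
    print(len(tiles), tiles[0], tiles[-1])  # 16 (0, 64, 0, 64) (192, 256, 192, 256)
```

With the defaults (`img_resolution: 256`, `max_batch_res: 64`), this splits each image into 16 tiles, so peak memory scales with one 64 x 64 tile rather than the full frame.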

--------------------------------------------------------------------------------
/configs/scripts/vis/minigrid.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: minigrid

--------------------------------------------------------------------------------
/configs/scripts/vis/rotation_video.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: rotation_video
fps: 30

--------------------------------------------------------------------------------
/configs/scripts/vis/video.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: video
fps: 30

--------------------------------------------------------------------------------
/configs/training/base.yaml:
--------------------------------------------------------------------------------
# @package _group_

batch_size: 64 # Total batch size
test_batch_gpu: 4 # Batch size at test time
gamma: auto # R1 regularization weight

# Optional features.
use_labels: false # Train conditional model
mirror: true # Enable dataset x-flips
resume: latest # Resume from given network pickle
freezed: 0 # Freeze first layers of D

# Misc hyperparameters.
p: 0.2 # Probability for aug=fixed
target: 0.6 # Target value for aug=ada
batch_gpu: null # Limit batch size per GPU
map_depth: null # Mapping network depth [default: varies]

# Misc settings.
desc: null # String to include in result dir name
metrics: fid50k_full # Quality metrics
main_metric: __pick_first__ # Takes the first metric among `metrics` as the main one to compute the best checkpoint
kimg: 100000 # Total training duration
tick: 4 # How often to print progress
val_freq: 100 # How often to compute metrics
snap: 250 # How often to save snapshots
image_snap: 100 # How often to save samples?
seed: 0 # Random seed
fp32: false # Disable mixed-precision
nobench: false # Disable cuDNN benchmarking
workers: 3 # DataLoader worker processes
dry_run: false # Print training options and exit

# Hyperparams for the reconstruction loss
Grec_coef: 0.0 # For reconstruction losses
perceptual_embedder: mse # one of ["laplace", "vgg16", "mse"]
Grec_interval: ~ # Apply the reconstruction loss once per `Grec_interval` steps
Grec_only: false # If true, run only the Grec phase. Used for debugging.
use_instances: false # Should we use instances

# Predicting camera pose loss weight
predict_pose_loss_weight: 0.0

# Old hyperparams
mvc:
  adversarial: false
  perceptual_weight: 0.0

# Default parameters for patch-wise training (in case it is enabled)
patch:
  strategy: normal # Enabled by default
  patch_params_cond: true # Patch parameters pos-enc embeddings dimensionality
  discr_concat_coords_strategy: ~
  discr_scales_hyper_cond: false
  min_scale_trg: 0.5
  max_scale: 1.0
  anneal_kimg: 10000
  resolution: 128
  mbstd_group_size: ${model.discriminator.mbstd_group_size}

augment:
  mode: ada # Augmentation mode. One of ["noaug", "ada", "fixed"]
  # Augment probabilities for different transformation types
  probs:
    xflip: 0.0
    rotate90: 1.0
    xint: 1.0
    scale: 1.0
    rotate: 1.0
    xfrac: 1.0
    aniso: 1.0
    brightness: 1.0
    contrast: 1.0
    lumaflip: 1.0
    hue: 1.0
    saturation: 1.0

--------------------------------------------------------------------------------
/configs/training/default.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: default
patch:
  strategy: ~
  patch_params_cond: false

--------------------------------------------------------------------------------
/configs/training/patch_beta.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: patch_beta_p${training.patch.resolution}
patch:
  distribution: beta
  alpha: 1.0
  beta_val_start: 0.001
  beta_val_end: 0.8

--------------------------------------------------------------------------------
/configs/training/patch_categ.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: patch_categ_p${training.patch.resolution}
patch:
  distribution: categorical
  support: [0.25, 0.5, 1.0]
  probs: [0.125, 0.175, 0.7]
  anneal_kimg: 0

--------------------------------------------------------------------------------
/configs/training/patch_discrete_uniform.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: patch_discrete_uniform_p${training.patch.resolution}
patch:
  distribution: discrete_uniform
  discrete_support:
    _target_: src.infra.utils.linspace
    val_from: 0.25
    val_to: 1.0
    num_steps: 2

--------------------------------------------------------------------------------
/configs/training/patch_uniform.yaml:
--------------------------------------------------------------------------------
# @package _group_

name: patch_uniform_p${training.patch.resolution}
patch:
  distribution: uniform

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: rnf
channels:
  - pytorch
  - nvidia
dependencies:
  - python >= 3.8
  - pip
  - numpy>=1.20
  - click>=8.0
  - pillow=8.3.1
  - scipy=1.7.1
  - pytorch=1.10.2
  - torchvision
  - cudatoolkit=11.1
  - requests=2.26.0
  - tqdm=4.62.2
  - ninja=1.10.2
  - matplotlib=3.4.2
  - imageio=2.9.0
  - pip:
    - imgui==1.3.0
    - glfw==2.2.0
    - pyopengl==3.1.5
    - imageio-ffmpeg==0.4.3
    - pyspng
    - tensorboard==2.4.1
    - tqdm
    - gitpython
    - gpustat
    - hydra-core==1.0.7
    - av==8.1.0
    - torch_tb_profiler==0.3.1
    - joblib==1.1.0
    - -e .

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name='rethinking_3d_gans',
    version='0.0.1',
    description='Rethinking training of 3D GANs',
    packages=find_packages(),
)

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rethinking-3d-gans/code/9bfc3ab32bd2b0992a229501e50bafcf232c5c11/src/__init__.py
--------------------------------------------------------------------------------
/src/dnnlib/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path

--------------------------------------------------------------------------------
/src/infra/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rethinking-3d-gans/code/9bfc3ab32bd2b0992a229501e50bafcf232c5c11/src/infra/__init__.py
--------------------------------------------------------------------------------
/src/infra/experiments.yaml:
--------------------------------------------------------------------------------
default:
  common_args: {}
  experiments:
    default: {}

patchwise_stylegan2:
  common_args:
    training: patchwise
    training.patch.max_scale: 1.0
    training.patch.anneal_kimg: 10000
    training.metrics: fid50k_full
    model: stylegan2
    model.loss_kwargs.pl_weight: 0
  experiments:
    p64-min0.999_anneal10k:
      training.patch.min_scale_trg: 0.999
      training.patch.resolution: 64
    p64-min0.5_anneal10k:
      training.patch.min_scale_trg: 0.5
      training.patch.resolution: 64
    p64-min0.25_anneal10k:
      training.patch.min_scale_trg: 0.25
      training.patch.resolution: 64
    p128-min0.9_anneal10k:
      training.patch.min_scale_trg: 0.9
      training.patch.resolution: 128
    p128-min0.5_anneal10k:
      training.patch.min_scale_trg: 0.5
      training.patch.resolution: 128
    p64-min0.9_anneal10k:
      training.patch.min_scale_trg: 0.9
      training.patch.resolution: 64
    p64-min0.75_anneal10k:
      training.patch.min_scale_trg: 0.75
      training.patch.resolution: 64
    p128-min0.75_anneal10k:
      training.patch.min_scale_trg: 0.75
      training.patch.resolution: 128

--------------------------------------------------------------------------------
/src/infra/launch.py:
--------------------------------------------------------------------------------
"""
Run a __reproducible__ experiment on __allocated__ resources
It submits a slurm job(s) with the given hyperparams which will then execute `slurm_job.py`
This is the main entry-point
"""

import os
import subprocess
import re

import hydra
from omegaconf import DictConfig, OmegaConf
from pathlib import Path

from src.infra.utils import create_project_dir, recursive_instantiate

#----------------------------------------------------------------------------

HYDRA_ARGS = "hydra.output_subdir=null hydra/job_logging=disabled hydra/hydra_logging=disabled"

#----------------------------------------------------------------------------

@hydra.main(config_path="../../configs", config_name="config.yaml")
def main(cfg: DictConfig):
    recursive_instantiate(cfg)
    OmegaConf.set_struct(cfg, True)
    cfg.env.project_path = str(cfg.env.project_path) # This is needed to evaluate ${hydra:runtime.cwd}

    before_train_cmd = '\n'.join(cfg.env.before_train_commands)
    before_train_cmd = before_train_cmd + '\n' if len(before_train_cmd) > 0 else ''
    torch_extensions_dir = os.environ.get('TORCH_EXTENSIONS_DIR', cfg.env.torch_extensions_dir)
    # The env vars must prefix the python call itself: a `VAR=value cd ...` prefix would only apply to `cd`
    training_cmd = f'{before_train_cmd}cd {cfg.experiment_dir} && TORCH_EXTENSIONS_DIR={torch_extensions_dir} PYTHONPATH=. {cfg.env.python_bin} src/train.py hydra.run.dir={cfg.experiment_dir} {HYDRA_ARGS}'
    quiet = cfg.get('quiet', False)
    training_cmd_save_path = os.path.join(cfg.experiment_dir, 'training_cmd.sh')
    cfg_save_path = os.path.join(cfg.experiment_dir, 'experiment_config.yaml')

    if not quiet:
        print('<=== TRAINING COMMAND START ===>')
        print(training_cmd)
        print('<=== TRAINING COMMAND END ===>')

    is_running_from_scratch = True

    if cfg.training.resume == "latest" and os.path.isdir(cfg.experiment_dir) and os.path.isfile(training_cmd_save_path) and os.path.isfile(cfg_save_path):
        is_running_from_scratch = False
        if not quiet:
            print("We are going to resume the training and the experiment already exists. " \
                  "That's why the provided config/training_cmd are discarded and the project dir is not created.")

        # Adding training.resume=latest to the command.
        # TODO: this looks very dirty...
        with open(training_cmd_save_path) as f:
            training_cmd = f.read().splitlines()
            training_cmd[-1] = f'{training_cmd[-1]} training.resume=latest'
            training_cmd = '\n'.join(training_cmd)

        print('<=== NEW TRAINING COMMAND START ===>')
        print(training_cmd)
        print('<=== NEW TRAINING COMMAND END ===>')
    else:
        print('Running from scratch...')

    if is_running_from_scratch and not cfg.print_only:
        create_project_dir(
            cfg.experiment_dir,
            cfg.env.objects_to_copy,
            cfg.env.symlinks_to_create,
            quiet=quiet,
            ignore_uncommited_changes=cfg.get('ignore_uncommited_changes', False),
            overwrite=cfg.get('overwrite', False))

        with open(training_cmd_save_path, 'w') as f:
            f.write(training_cmd + '\n')
            if not quiet:
                print(f'Saved training command in {training_cmd_save_path}')

        with open(cfg_save_path, 'w') as f:
            OmegaConf.save(config=cfg, f=f)
            if not quiet:
                print(f'Saved config in {cfg_save_path}')

    if not cfg.print_only:
        os.chdir(cfg.experiment_dir)

    if cfg.get('slurm', False):
        assert Path(cfg.dataset.path_for_slurm_job).exists(), f"Dataset {cfg.dataset.path_for_slurm_job} does not exist."

        curr_job_id = None

        for i in range(cfg.job_sequence_length):
            if i == 0:
                deps_args_str = ''
            else:
                deps_args_str = f'--dependency=afterany:{curr_job_id}'

            # Submitting the slurm job
            env_args_str = ','.join([f'{k}={v}' for k, v in cfg.env_args.items()])
            output_file_arg_str = f'--output {cfg.experiment_dir}/slurm_{i}.log'
            submit_job_cmd = f'sbatch {cfg.sbatch_args_str} {output_file_arg_str} --export=ALL,{env_args_str} {deps_args_str} src/infra/slurm_job_proxy.sh'

            if cfg.print_only:
                print(submit_job_cmd)
                curr_job_id = "DUMMY_JOB_ID"
            else:
                result = subprocess.run(submit_job_cmd, stdout=subprocess.PIPE, shell=True)
                output_str = result.stdout.decode("utf-8").strip("\n") # It has a format of "Submitted batch job 17033559"
                if not quiet or i == 0:
                    print(output_str)
                curr_job_id = re.findall(r"^Submitted\ batch\ job\ \d{5,8}$", output_str)
                assert len(curr_job_id) == 1, f"Bad output: `{output_str}`"
                curr_job_id = int(curr_job_id[0][len('Submitted batch job '):])
    else:
        assert cfg.job_sequence_length == 1, "You can use a job sequence only when running via slurm."
114 | if cfg.print_only: 115 | print(training_cmd) 116 | else: 117 | os.system(training_cmd) 118 | 119 | #---------------------------------------------------------------------------- 120 | 121 | if __name__ == "__main__": 122 | main() 123 | 124 | #---------------------------------------------------------------------------- 125 | -------------------------------------------------------------------------------- /src/infra/slurm_job.py: -------------------------------------------------------------------------------- 1 | """ 2 | Must be launched from the released project dir 3 | """ 4 | 5 | import os 6 | import time 7 | import random 8 | import subprocess 9 | from shutil import copyfile 10 | 11 | import hydra 12 | from omegaconf import DictConfig, OmegaConf 13 | 14 | # Unfortunately, (AFAIK) we cannot pass arguments normally (to parse them with argparse) 15 | # that's why we are reading them from env 16 | SLURM_JOB_ID = os.getenv('SLURM_JOB_ID') 17 | project_dir = os.getenv('project_dir') 18 | python_bin = os.getenv('python_bin') 19 | 20 | # Printing the environment 21 | print('PROJECT DIR:', project_dir) 22 | print(f'SLURM_JOB_ID: {SLURM_JOB_ID}') 23 | print('HOSTNAME:', subprocess.run(['hostname'], stdout=subprocess.PIPE).stdout.decode('utf-8')) 24 | print(subprocess.run([os.path.join(os.path.dirname(python_bin), 'gpustat')], stdout=subprocess.PIPE).stdout.decode('utf-8')) 25 | 26 | @hydra.main(config_name=os.path.join(project_dir, 'experiment_config.yaml')) 27 | def main(cfg: DictConfig): 28 | os.chdir(project_dir) 29 | 30 | target_data_dir_base = os.path.dirname(cfg.dataset.path) 31 | if os.path.islink(target_data_dir_base): 32 | os.makedirs(os.readlink(target_data_dir_base), exist_ok=True) 33 | else: 34 | os.makedirs(target_data_dir_base, exist_ok=True) 35 | 36 | copyfile(cfg.dataset.path_for_slurm_job, cfg.dataset.path) 37 | print(f'Copied the data: {cfg.dataset.path_for_slurm_job} => {cfg.dataset.path}. 
Starting the training...') 38 | 39 | training_cmd = open('training_cmd.sh').read() 40 | print('<=== TRAINING COMMAND ===>') 41 | print(training_cmd) 42 | os.system(training_cmd) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /src/infra/slurm_job_proxy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # We need this proxy so as not to put the shebang into `slurm_job.py` 3 | # We cannot put a shebang there since we use different Python executables for it 4 | $python_bin $python_script 5 | -------------------------------------------------------------------------------- /src/infra/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | from distutils.dir_util import copy_tree 5 | from shutil import copyfile 6 | from typing import List, Optional 7 | 8 | from hydra.utils import instantiate 9 | import click 10 | import git 11 | from omegaconf import DictConfig 12 | 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | def copy_objects(target_dir: os.PathLike, objects_to_copy: List[os.PathLike]): 17 | for src_path in objects_to_copy: 18 | trg_path = os.path.join(target_dir, os.path.basename(src_path)) 19 | 20 | if os.path.islink(src_path): 21 | os.symlink(os.readlink(src_path), trg_path) 22 | elif os.path.isfile(src_path): 23 | copyfile(src_path, trg_path) 24 | elif os.path.isdir(src_path): 25 | copy_tree(src_path, trg_path) 26 | else: 27 | raise NotImplementedError(f"Unknown object type: {src_path}") 28 | 29 | #---------------------------------------------------------------------------- 30 | 31 | def create_symlinks(target_dir: os.PathLike, symlinks_to_create: List[os.PathLike]): 32 | """ 33 | Creates symlinks to the given paths 34 | """ 35 | for src_path in symlinks_to_create: 36 | trg_path = 
os.path.join(target_dir, os.path.basename(src_path)) 37 | 38 | if os.path.islink(src_path): 39 | # Let's not create symlinks to symlinks, 40 | # since dropping the current symlink would break the experiment 41 | os.symlink(os.readlink(src_path), trg_path) 42 | else: 43 | print(f'Creating a symlink to {src_path}, so try not to delete it accidentally!') 44 | os.symlink(src_path, trg_path) 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | def is_git_repo(path: os.PathLike): 49 | try: 50 | _ = git.Repo(path).git_dir 51 | return True 52 | except git.exc.InvalidGitRepositoryError: 53 | return False 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def create_project_dir( 58 | project_dir: os.PathLike, 59 | objects_to_copy: List[os.PathLike], 60 | symlinks_to_create: List[os.PathLike], 61 | quiet: bool=False, 62 | ignore_uncommited_changes: bool=False, 63 | overwrite: bool=False): 64 | 65 | if is_git_repo(os.getcwd()) and are_there_uncommitted_changes(): 66 | if ignore_uncommited_changes or click.confirm("There are uncommitted changes. Continue?", default=False): 67 | pass 68 | else: 69 | raise PermissionError("Cannot create a dir when there are uncommitted changes") 70 | 71 | if os.path.exists(project_dir): 72 | if overwrite or click.confirm(f'Dir {project_dir} already exists.
Overwrite it?', default=False): 73 | shutil.rmtree(project_dir) 74 | else: 75 | print('User refused to delete an existing project dir.') 76 | raise PermissionError("There is an existing dir and I cannot delete it.") 77 | 78 | os.makedirs(project_dir) 79 | copy_objects(project_dir, objects_to_copy) 80 | create_symlinks(project_dir, symlinks_to_create) 81 | 82 | if not quiet: 83 | print(f'Created a project dir: {project_dir}') 84 | 85 | #---------------------------------------------------------------------------- 86 | 87 | def get_git_hash() -> Optional[str]: 88 | if not is_git_repo(os.getcwd()): 89 | return None 90 | 91 | try: 92 | return subprocess \ 93 | .check_output(['git', 'rev-parse', '--short', 'HEAD']) \ 94 | .decode("utf-8") \ 95 | .strip() 96 | except: 97 | return None 98 | 99 | #---------------------------------------------------------------------------- 100 | 101 | # def get_experiment_path(master_dir: os.PathLike, experiment_name: str) -> os.PathLike: 102 | # return os.path.join(master_dir, f"{experiment_name}-{get_git_hash()}") 103 | 104 | #---------------------------------------------------------------------------- 105 | 106 | def get_git_hash_suffix() -> str: 107 | git_hash: Optional[str] = get_git_hash() 108 | git_hash_suffix = "-nogit" if git_hash is None else f"-{git_hash}" 109 | 110 | return git_hash_suffix 111 | 112 | #---------------------------------------------------------------------------- 113 | 114 | def are_there_uncommitted_changes() -> bool: 115 | return len(subprocess.check_output('git status -s'.split()).decode("utf-8")) > 0 116 | 117 | #---------------------------------------------------------------------------- 118 | 119 | def cfg_to_args_str(cfg: DictConfig, use_dashes=True) -> str: 120 | dashes = '--' if use_dashes else '' 121 | 122 | return ' '.join([f'{dashes}{p}={cfg[p]}' for p in cfg]) 123 | 124 | #---------------------------------------------------------------------------- 125 | 126 | def recursive_instantiate(cfg: 
DictConfig): 127 | for key in cfg: 128 | # print(type(cfg[key])) 129 | if isinstance(cfg[key], DictConfig): 130 | if '_target_' in cfg[key]: 131 | cfg[key] = instantiate(cfg[key]) 132 | else: 133 | recursive_instantiate(cfg[key]) 134 | 135 | #---------------------------------------------------------------------------- 136 | 137 | def num_gpus_to_mem(num_gpus: int, mem_per_gpu: int=64) -> str: 138 | # Doing it here since hydra config cannot do formatting for ${...} 139 | return f"{num_gpus * mem_per_gpu}G" 140 | 141 | #---------------------------------------------------------------------------- 142 | 143 | def product(values): 144 | import numpy as np 145 | return np.prod([x for x in values]).item() 146 | 147 | #---------------------------------------------------------------------------- 148 | 149 | def linspace(val_from: float, val_to: float, num_steps: int) -> List[float]: 150 | # import numpy as np 151 | # return np.linspace(val_from, val_to, num_steps).tolist() 152 | assert num_steps > 1, f"Too small num_steps: {num_steps}" 153 | return [val_from + (val_to - val_from) * i / (num_steps - 1) for i in range(num_steps)] 154 | 155 | #---------------------------------------------------------------------------- 156 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /src/metrics/equivariance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper 10 | "Alias-Free Generative Adversarial Networks".""" 11 | 12 | import copy 13 | import numpy as np 14 | import torch 15 | import torch.fft 16 | from src.torch_utils.ops import upfirdn2d 17 | from src.metrics import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | # Utilities. 21 | 22 | def sinc(x): 23 | y = (x * np.pi).abs() 24 | z = torch.sin(y) / y.clamp(1e-30, float('inf')) 25 | return torch.where(y < 1e-30, torch.ones_like(x), z) 26 | 27 | def lanczos_window(x, a): 28 | x = x.abs() / a 29 | return torch.where(x < 1, sinc(x), torch.zeros_like(x)) 30 | 31 | def rotation_matrix(angle): 32 | angle = torch.as_tensor(angle).to(torch.float32) 33 | mat = torch.eye(3, device=angle.device) 34 | mat[0, 0] = angle.cos() 35 | mat[0, 1] = angle.sin() 36 | mat[1, 0] = -angle.sin() 37 | mat[1, 1] = angle.cos() 38 | return mat 39 | 40 | #---------------------------------------------------------------------------- 41 | # Apply integer translation to a batch of 2D images. Corresponds to the 42 | # operator T_x in Appendix E.1. 
43 | 44 | def apply_integer_translation(x, tx, ty): 45 | _N, _C, H, W = x.shape 46 | tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) 47 | ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) 48 | ix = tx.round().to(torch.int64) 49 | iy = ty.round().to(torch.int64) 50 | 51 | z = torch.zeros_like(x) 52 | m = torch.zeros_like(x) 53 | if abs(ix) < W and abs(iy) < H: 54 | y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)] 55 | z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y 56 | m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1 57 | return z, m 58 | 59 | #---------------------------------------------------------------------------- 60 | # Apply fractional translation to a batch of 2D images. Corresponds to the 61 | # operator T_x in Appendix E.2. 62 | 63 | def apply_fractional_translation(x, tx, ty, a=3): 64 | _N, _C, H, W = x.shape 65 | tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) 66 | ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) 67 | ix = tx.floor().to(torch.int64) 68 | iy = ty.floor().to(torch.int64) 69 | fx = tx - ix 70 | fy = ty - iy 71 | b = a - 1 72 | 73 | z = torch.zeros_like(x) 74 | zx0 = max(ix - b, 0) 75 | zy0 = max(iy - b, 0) 76 | zx1 = min(ix + a, 0) + W 77 | zy1 = min(iy + a, 0) + H 78 | if zx0 < zx1 and zy0 < zy1: 79 | taps = torch.arange(a * 2, device=x.device) - b 80 | filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0) 81 | filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1) 82 | y = x 83 | y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0]) 84 | y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a]) 85 | y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)] 86 | z[:, :, zy0:zy1, zx0:zx1] = y 87 | 88 | m = torch.zeros_like(x) 89 | mx0 = max(ix + a, 0) 90 | my0 = max(iy + a, 0) 91 | mx1 = min(ix - b, 0) + W 92 | my1 = min(iy - b, 0) + H 93
| if mx0 < mx1 and my0 < my1: 94 | m[:, :, my0:my1, mx0:mx1] = 1 95 | return z, m 96 | 97 | #---------------------------------------------------------------------------- 98 | # Construct an oriented low-pass filter that applies the appropriate 99 | # bandlimit with respect to the input and output of the given affine 2D 100 | # image transformation. 101 | 102 | def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1): 103 | assert a <= amax < aflt 104 | mat = torch.as_tensor(mat).to(torch.float32) 105 | 106 | # Construct 2D filter taps in input & output coordinate spaces. 107 | taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up) 108 | yi, xi = torch.meshgrid(taps, taps) 109 | xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2) 110 | 111 | # Convolution of two oriented 2D sinc filters. 112 | fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in) 113 | fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out) 114 | f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real 115 | 116 | # Convolution of two oriented 2D Lanczos windows. 117 | wi = lanczos_window(xi, a) * lanczos_window(yi, a) 118 | wo = lanczos_window(xo, a) * lanczos_window(yo, a) 119 | w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real 120 | 121 | # Construct windowed FIR filter. 122 | f = f * w 123 | 124 | # Finalize. 125 | c = (aflt - amax) * up 126 | f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c] 127 | f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up) 128 | f = f / f.sum([0,2], keepdim=True) / (up ** 2) 129 | f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1] 130 | return f 131 | 132 | #---------------------------------------------------------------------------- 133 | # Apply the given affine transformation to a batch of 2D images. 
134 | 135 | def apply_affine_transformation(x, mat, up=4, **filter_kwargs): 136 | _N, _C, H, W = x.shape 137 | mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device) 138 | 139 | # Construct filter. 140 | f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs) 141 | assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1 142 | p = f.shape[0] // 2 143 | 144 | # Construct sampling grid. 145 | theta = mat.inverse() 146 | theta[:2, 2] *= 2 147 | theta[0, 2] += 1 / up / W 148 | theta[1, 2] += 1 / up / H 149 | theta[0, :] *= W / (W + p / up * 2) 150 | theta[1, :] *= H / (H + p / up * 2) 151 | theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1]) 152 | g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False) 153 | 154 | # Resample image. 155 | y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p) 156 | z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False) 157 | 158 | # Form mask. 159 | m = torch.zeros_like(y) 160 | c = p * 2 + 1 161 | m[:, :, c:-c, c:-c] = 1 162 | m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False) 163 | return z, m 164 | 165 | #---------------------------------------------------------------------------- 166 | # Apply fractional rotation to a batch of 2D images. Corresponds to the 167 | # operator R_\alpha in Appendix E.3. 168 | 169 | def apply_fractional_rotation(x, angle, a=3, **filter_kwargs): 170 | angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device) 171 | mat = rotation_matrix(angle) 172 | return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs) 173 | 174 | #---------------------------------------------------------------------------- 175 | # Modify the frequency content of a batch of 2D images as if they had undergone 176 | # fractional rotation -- but without actually rotating them. Corresponds to 177 | # the operator R^*_\alpha in Appendix E.3.
178 | 179 | def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs): 180 | angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device) 181 | mat = rotation_matrix(-angle) 182 | f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs) 183 | y = upfirdn2d.filter2d(x=x, f=f) 184 | m = torch.zeros_like(y) 185 | c = f.shape[0] // 2 186 | m[:, :, c:-c, c:-c] = 1 187 | return y, m 188 | 189 | #---------------------------------------------------------------------------- 190 | # Compute the selected equivariance metrics for the given generator. 191 | 192 | def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False): 193 | assert compute_eqt_int or compute_eqt_frac or compute_eqr 194 | 195 | # Setup generator and labels. 196 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) 197 | I = torch.eye(3, device=opts.device) 198 | M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None) 199 | if M is None: 200 | raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations') 201 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 202 | 203 | # Sampling loop. 204 | sums = None 205 | progress = opts.progress.sub(tag='eq sampling', num_items=num_samples) 206 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 207 | progress.update(batch_start) 208 | s = [] 209 | 210 | # Randomize noise buffers, if any. 211 | for name, buf in G.named_buffers(): 212 | if name.endswith('.noise_const'): 213 | buf.copy_(torch.randn_like(buf)) 214 | 215 | # Run mapping network. 216 | z = torch.randn([batch_size, G.z_dim], device=opts.device) 217 | c = next(c_iter) 218 | ws = G.mapping(z=z, c=c) 219 | 220 | # Generate reference image. 
221 | M[:] = I 222 | orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 223 | 224 | # Integer translation (EQ-T). 225 | if compute_eqt_int: 226 | t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max 227 | t = (t * G.img_resolution).round() / G.img_resolution 228 | M[:] = I 229 | M[:2, 2] = -t 230 | img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 231 | ref, mask = apply_integer_translation(orig, t[0], t[1]) 232 | s += [(ref - img).square() * mask, mask] 233 | 234 | # Fractional translation (EQ-T_frac). 235 | if compute_eqt_frac: 236 | t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max 237 | M[:] = I 238 | M[:2, 2] = -t 239 | img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 240 | ref, mask = apply_fractional_translation(orig, t[0], t[1]) 241 | s += [(ref - img).square() * mask, mask] 242 | 243 | # Rotation (EQ-R). 244 | if compute_eqr: 245 | angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi) 246 | M[:] = rotation_matrix(-angle) 247 | img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 248 | ref, ref_mask = apply_fractional_rotation(orig, angle) 249 | pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle) 250 | mask = ref_mask * pseudo_mask 251 | s += [(ref - pseudo).square() * mask, mask] 252 | 253 | # Accumulate results. 254 | s = torch.stack([x.to(torch.float64).sum() for x in s]) 255 | sums = sums + s if sums is not None else s 256 | progress.update(num_samples) 257 | 258 | # Compute PSNRs. 
259 | if opts.num_gpus > 1: 260 | torch.distributed.all_reduce(sums) 261 | sums = sums.cpu() 262 | mses = sums[0::2] / sums[1::2] 263 | psnrs = np.log10(2) * 20 - mses.log10() * 10 264 | psnrs = tuple(psnrs.numpy()) 265 | return psnrs[0] if len(psnrs) == 1 else psnrs 266 | 267 | #---------------------------------------------------------------------------- 268 | -------------------------------------------------------------------------------- /src/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from src.metrics import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 28 | 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /src/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. 
at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from src.metrics import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /src/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from src.metrics import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /src/metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Main API for computing and reporting quality metrics.""" 10 | 11 | import os 12 | import time 13 | import json 14 | import torch 15 | from src import dnnlib 16 | 17 | from src.metrics import metric_utils 18 | from src.metrics import frechet_inception_distance 19 | from src.metrics import kernel_inception_distance 20 | from src.metrics import precision_recall 21 | from src.metrics import perceptual_path_length 22 | from src.metrics import inception_score 23 | from src.metrics import equivariance 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _metric_dict = dict() # name => fn 28 | 29 | def register_metric(fn): 30 | assert callable(fn) 31 | _metric_dict[fn.__name__] = fn 32 | return fn 33 | 34 | def is_valid_metric(metric): 35 | return metric in _metric_dict 36 | 37 | def list_valid_metrics(): 38 | return list(_metric_dict.keys()) 39 | 40 | #---------------------------------------------------------------------------- 41 | 42 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 43 | assert is_valid_metric(metric) 44 | opts = metric_utils.MetricOptions(**kwargs) 45 | 46 | # Calculate. 47 | start_time = time.time() 48 | results = _metric_dict[metric](opts) 49 | total_time = time.time() - start_time 50 | 51 | # Broadcast results. 52 | for key, value in list(results.items()): 53 | if opts.num_gpus > 1: 54 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 55 | torch.distributed.broadcast(tensor=value, src=0) 56 | value = float(value.cpu()) 57 | results[key] = value 58 | 59 | # Decorate with metadata. 
60 | return dnnlib.EasyDict( 61 | results = dnnlib.EasyDict(results), 62 | metric = metric, 63 | total_time = total_time, 64 | total_time_str = dnnlib.util.format_time(total_time), 65 | num_gpus = opts.num_gpus, 66 | ) 67 | 68 | #---------------------------------------------------------------------------- 69 | 70 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 71 | metric = result_dict['metric'] 72 | assert is_valid_metric(metric) 73 | if run_dir is not None and snapshot_pkl is not None: 74 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 75 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 76 | print(jsonl_line) 77 | if run_dir is not None and os.path.isdir(run_dir): 78 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 79 | f.write(jsonl_line + '\n') 80 | 81 | #---------------------------------------------------------------------------- 82 | # Recommended metrics. 83 | 84 | @register_metric 85 | def fid50k_full(opts): 86 | opts.dataset_kwargs.update(max_size=None, xflip=False) 87 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 88 | return dict(fid50k_full=fid) 89 | 90 | @register_metric 91 | def kid50k_full(opts): 92 | opts.dataset_kwargs.update(max_size=None, xflip=False) 93 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 94 | return dict(kid50k_full=kid) 95 | 96 | @register_metric 97 | def pr50k3_full(opts): 98 | opts.dataset_kwargs.update(max_size=None, xflip=False) 99 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 100 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 101 | 102 | @register_metric 103 | def ppl2_wend(opts): 104 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', 
crop=False, batch_size=2) 105 | return dict(ppl2_wend=ppl) 106 | 107 | @register_metric 108 | def eqt50k_int(opts): 109 | opts.G_kwargs.update(force_fp32=True) 110 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 111 | return dict(eqt50k_int=psnr) 112 | 113 | @register_metric 114 | def eqt50k_frac(opts): 115 | opts.G_kwargs.update(force_fp32=True) 116 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 117 | return dict(eqt50k_frac=psnr) 118 | 119 | @register_metric 120 | def eqr50k(opts): 121 | opts.G_kwargs.update(force_fp32=True) 122 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 123 | return dict(eqr50k=psnr) 124 | 125 | @register_metric 126 | def fid5k_5k(opts): 127 | """The metric used by GRAM""" 128 | opts.dataset_kwargs.update(max_size=None) 129 | fid = frechet_inception_distance.compute_fid(opts, max_real=5000, num_gen=5000) 130 | return dict(fid5k_5k=fid) 131 | 132 | #---------------------------------------------------------------------------- 133 | # Legacy metrics. 
134 | 135 | @register_metric 136 | def fid50k(opts): 137 | opts.dataset_kwargs.update(max_size=None) 138 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 139 | return dict(fid50k=fid) 140 | 141 | @register_metric 142 | def kid50k(opts): 143 | opts.dataset_kwargs.update(max_size=None) 144 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 145 | return dict(kid50k=kid) 146 | 147 | @register_metric 148 | def pr50k3(opts): 149 | opts.dataset_kwargs.update(max_size=None) 150 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 151 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 152 | 153 | @register_metric 154 | def is50k(opts): 155 | opts.dataset_kwargs.update(max_size=None, xflip=False) 156 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 157 | return dict(is50k_mean=mean, is50k_std=std) 158 | 159 | #---------------------------------------------------------------------------- 160 | # Debugging metrics (fast to compute) 161 | 162 | @register_metric 163 | def fid2k_full(opts): 164 | opts.dataset_kwargs.update(max_size=None, xflip=False) 165 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=2048) 166 | return dict(fid2k_full=fid) 167 | 168 | #---------------------------------------------------------------------------- 169 | -------------------------------------------------------------------------------- /src/metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | from src.metrics import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | # Spherical interpolation of a batch of vectors. 22 | def slerp(a, b, t): 23 | a = a / a.norm(dim=-1, keepdim=True) 24 | b = b / b.norm(dim=-1, keepdim=True) 25 | d = (a * b).sum(dim=-1, keepdim=True) 26 | p = t * torch.acos(d) 27 | c = b - d * a 28 | c = c / c.norm(dim=-1, keepdim=True) 29 | d = a * torch.cos(p) + c * torch.sin(p) 30 | d = d / d.norm(dim=-1, keepdim=True) 31 | return d 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPLSampler(torch.nn.Module): 36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__() 40 | self.G = copy.deepcopy(G) 41 | self.G_kwargs = G_kwargs 42 | self.epsilon = epsilon 43 | self.space = space 44 | self.sampling = sampling 45 | self.crop = crop 46 | self.vgg16 = copy.deepcopy(vgg16) 47 | 48 | def forward(self, c): 49 | # Generate random latents and interpolation t-values. 50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 52 | 53 | # Interpolate in W or Z. 
54 | if self.space == 'w': 55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 58 | else: # space == 'z' 59 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 62 | 63 | # Randomize noise buffers. 64 | for name, buf in self.G.named_buffers(): 65 | if name.endswith('.noise_const'): 66 | buf.copy_(torch.randn_like(buf)) 67 | 68 | # Generate images. 69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 70 | 71 | # Center crop. 72 | if self.crop: 73 | assert img.shape[2] == img.shape[3] 74 | c = img.shape[2] // 8 75 | img = img[:, :, c*3 : c*7, c*2 : c*6] 76 | 77 | # Downsample to 256x256. 78 | factor = self.G.img_resolution // 256 79 | if factor > 1: 80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 81 | 82 | # Scale dynamic range from [-1,1] to [0,255]. 83 | img = (img + 1) * (255 / 2) 84 | if self.G.img_channels == 1: 85 | img = img.repeat([1, 3, 1, 1]) 86 | 87 | # Evaluate differential LPIPS. 88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 90 | return dist 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 97 | 98 | # Setup sampler and labels. 
99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 100 | sampler.eval().requires_grad_(False).to(opts.device) 101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 102 | 103 | # Sampling loop. 104 | dist = [] 105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 107 | progress.update(batch_start) 108 | x = sampler(next(c_iter)) 109 | for src in range(opts.num_gpus): 110 | y = x.clone() 111 | if opts.num_gpus > 1: 112 | torch.distributed.broadcast(y, src=src) 113 | dist.append(y) 114 | progress.update(num_samples) 115 | 116 | # Compute PPL. 117 | if opts.rank != 0: 118 | return float('nan') 119 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 120 | lo = np.percentile(dist, 1, interpolation='lower') 121 | hi = np.percentile(dist, 99, interpolation='higher') 122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 123 | return float(ppl) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /src/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". 
Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from src.metrics import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, 
max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /src/scripts/calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | # Add `src` to sys.path. 
Otherwise, we get a ModuleNotFoundError for torch_utils :( 12 | # (when loading the inception model). TODO: find out why. 13 | import sys; sys.path.append('src') 14 | 15 | import os 16 | import copy 17 | import tempfile 18 | 19 | import torch 20 | import hydra 21 | from omegaconf import DictConfig 22 | 23 | from src import dnnlib 24 | from src.metrics import metric_main 25 | from src.metrics import metric_utils 26 | from src.torch_utils import training_stats 27 | from src.torch_utils import custom_ops 28 | from src.torch_utils import misc 29 | from src.torch_utils.ops import conv2d_gradfix 30 | from src.scripts.utils import load_generator 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def subprocess_fn(rank, args, temp_dir): 35 | dnnlib.util.Logger(should_flush=True) 36 | 37 | # Init torch.distributed. 38 | if args.num_gpus > 1: 39 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 40 | if os.name == 'nt': 41 | init_method = 'file:///' + init_file.replace('\\', '/') 42 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 43 | else: 44 | init_method = f'file://{init_file}' 45 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 46 | 47 | # Init torch_utils. 48 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 49 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 50 | if rank != 0 or not args.verbose: 51 | custom_ops.verbosity = 'none' 52 | 53 | # Configure torch. 54 | device = torch.device('cuda', rank) 55 | torch.backends.cudnn.benchmark = True 56 | torch.backends.cuda.matmul.allow_tf32 = False 57 | torch.backends.cudnn.allow_tf32 = False 58 | conv2d_gradfix.enabled = True 59 | 60 | # Print network summary.
61 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 62 | if rank == 0 and args.verbose: 63 | z = torch.empty([2, G.z_dim], device=device) 64 | c = torch.empty([2, G.c_dim], device=device) 65 | camera_angles = torch.empty([2, 3], device=device) 66 | misc.print_module_summary(G, [z, c], module_kwargs={'camera_angles': camera_angles}) 67 | 68 | # Calculate each metric. 69 | for metric in args.metrics: 70 | if rank == 0 and args.verbose: 71 | print(f'Calculating {metric}...') 72 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 73 | result_dict = metric_main.calc_metric( 74 | metric=metric, 75 | G=G, 76 | dataset_kwargs=args.dataset_kwargs, 77 | num_gpus=args.num_gpus, 78 | rank=rank, 79 | device=device, 80 | progress=progress, 81 | ) 82 | if rank == 0: 83 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 84 | if rank == 0 and args.verbose: 85 | print() 86 | 87 | # Done. 88 | if rank == 0 and args.verbose: 89 | print('Exiting...') 90 | 91 | #---------------------------------------------------------------------------- 92 | 93 | def parse_comma_separated_list(s): 94 | if isinstance(s, list): 95 | return s 96 | if s is None or s.lower() == 'none' or s == '': 97 | return [] 98 | return s.split(',') 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | @hydra.main(config_path="../../configs/scripts", config_name="calc_metrics.yaml") 103 | def calc_metrics(cfg: DictConfig): 104 | dnnlib.util.Logger(should_flush=True) 105 | 106 | device = torch.device('cuda') 107 | G, snapshot, network_pkl = load_generator(cfg.ckpt, verbose=cfg.verbose) 108 | G = G.to(device).eval() 109 | 110 | # Validate arguments.
111 | args = dnnlib.EasyDict(metrics=cfg.metrics.split(','), num_gpus=cfg.gpus, network_pkl=network_pkl, verbose=cfg.verbose, G=G) 112 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 113 | raise ValueError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 114 | if not args.num_gpus >= 1: 115 | raise ValueError('--gpus must be at least 1') 116 | 117 | # Initialize dataset options. 118 | if cfg.data is not None: 119 | args.dataset_kwargs = dnnlib.EasyDict(class_name='src.training.dataset.ImageFolderDataset', path=cfg.data) 120 | elif snapshot['training_set_kwargs'] is not None: 121 | args.dataset_kwargs = dnnlib.EasyDict(snapshot['training_set_kwargs']) 122 | else: 123 | raise ValueError('Could not look up dataset options; please specify --data') 124 | 125 | # Finalize dataset options. 126 | args.G.img_resolution = args.G.img_resolution if cfg.img_resolution is None else cfg.img_resolution 127 | args.dataset_kwargs.resolution = args.G.img_resolution 128 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 129 | if not cfg.mirror is None: 130 | args.dataset_kwargs.xflip = cfg.mirror 131 | 132 | # Locate run dir. 133 | args.run_dir = None 134 | if os.path.isfile(network_pkl): 135 | pkl_dir = os.path.dirname(network_pkl) 136 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 137 | args.run_dir = pkl_dir 138 | 139 | # Launch processes. 
140 | if args.verbose: 141 | print('Launching processes...') 142 | torch.multiprocessing.set_start_method('spawn') 143 | with tempfile.TemporaryDirectory() as temp_dir: 144 | if args.num_gpus == 1: 145 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 146 | else: 147 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 148 | 149 | #---------------------------------------------------------------------------- 150 | 151 | if __name__ == "__main__": 152 | calc_metrics() # pylint: disable=no-value-for-parameter 153 | 154 | #---------------------------------------------------------------------------- 155 | -------------------------------------------------------------------------------- /src/scripts/extract_geometry.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | # from ast import DictComp 3 | # import plyfile 4 | # import argparse 5 | import torch 6 | import numpy as np 7 | # import skimage.measure 8 | # import scipy 9 | import mcubes 10 | import trimesh 11 | import mrcfile 12 | import os 13 | import hydra 14 | from tqdm import tqdm 15 | 16 | from src.scripts.utils import load_generator, set_seed, create_voxel_coords 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | @hydra.main(config_path="../../configs/scripts", config_name="extract_geometry.yaml") 21 | def extract_geometry(cfg: DictConfig): 22 | device = torch.device('cuda') 23 | G = load_generator(cfg.ckpt, verbose=cfg.verbose)[0].to(device).eval() 24 | 25 | for seed in tqdm(cfg.seeds, desc='Extracting geometry...'): 26 | set_seed(seed) 27 | batch_size = 1 28 | z = torch.randn(batch_size, G.z_dim, device=device) # [batch_size, z_dim] 29 | c = torch.zeros(batch_size, G.c_dim, device=device) # [batch_size, c_dim] 30 | assert G.c_dim == 0 31 | coords = create_voxel_coords(cfg.voxel_res, cfg.voxel_origin, cfg.cube_size, batch_size) # [batch_size, 
voxel_res ** 3, 3] 32 | coords = coords.to(z.device) # [batch_size, voxel_res ** 3, 3] 33 | ws = G.mapping(z, c, truncation_psi=cfg.truncation_psi, noise_mode='const') # [batch_size, num_ws, w_dim] 34 | sigma = G.synthesis.compute_densities(ws, coords, max_batch_res=cfg.max_batch_res) # [batch_size, voxel_res ** 3, 1] 35 | assert batch_size == 1 36 | sigma = sigma.reshape(cfg.voxel_res, cfg.voxel_res, cfg.voxel_res).cpu().numpy() # [voxel_res, voxel_res, voxel_res] 37 | 38 | print('sigma percentiles:', {q: np.percentile(sigma.reshape(-1), q) for q in [50.0, 90.0, 95.0, 97.5, 99.0, 99.5]}) 39 | vertices, triangles = mcubes.marching_cubes(sigma, np.percentile(sigma, cfg.thresh_percentile)) 40 | mesh = trimesh.Trimesh(vertices, triangles) 41 | os.makedirs('shapes', exist_ok=True) 42 | mesh.export(f'shapes/shape-{seed}.obj') 43 | 44 | # os.makedirs(cfg.output_dir, exist_ok=True) 45 | # with mrcfile.new_mmap(os.path.join(cfg.output_dir, f'{seed}.mrc'), overwrite=True, shape=sigma.shape, mrc_mode=2) as mrc: 46 | # mrc.data[:] = sigma 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | if __name__ == "__main__": 51 | extract_geometry() # pylint: disable=no-value-for-parameter 52 | 53 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /src/scripts/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import shutil 5 | import random 6 | import itertools 7 | import contextlib 8 | import zipfile 9 | from typing import List, Dict, Tuple 10 | 11 | import click 12 | import joblib 13 | from omegaconf import DictConfig 14 | import numpy as np 15 | from PIL import Image 16 | import torch 17 | import torchvision.transforms.functional as TVF 18 | from torchvision.utils import make_grid 19 | from tqdm import tqdm 20 | from src import dnnlib, legacy 21 | 22 |
#---------------------------------------------------------------------------- 24 | 25 | @contextlib.contextmanager 26 | def tqdm_joblib(tqdm_object): 27 | """Context manager that patches joblib to report progress into the tqdm progress bar given as the argument.""" 28 | class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): 29 | def __call__(self, *args, **kwargs): 30 | tqdm_object.update(n=self.batch_size) 31 | return super().__call__(*args, **kwargs) 32 | 33 | old_batch_callback = joblib.parallel.BatchCompletionCallBack 34 | joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback 35 | try: 36 | yield tqdm_object 37 | finally: 38 | joblib.parallel.BatchCompletionCallBack = old_batch_callback 39 | tqdm_object.close() 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | def display_dir(dir_path: os.PathLike, num_imgs: int=25, selection_strategy: str="order", n_skip_imgs: int=0, **kwargs) -> "Image": 44 | if selection_strategy in ('order', 'random'): 45 | img_fnames = [os.path.relpath(os.path.join(root, fname), start=dir_path) for root, _dirs, files in os.walk(dir_path) for fname in files] 46 | img_paths = [os.path.join(dir_path, f) for f in sorted(img_fnames)] 47 | img_paths = img_paths[n_skip_imgs:] 48 | 49 | if selection_strategy == 'order': 50 | img_paths = img_paths[:num_imgs] 51 | elif selection_strategy == 'random': 52 | img_paths = random.sample(img_paths, k=num_imgs) 53 | elif selection_strategy == 'random_imgs_from_subdirs': 54 | img_paths = [p for d in [d for d in listdir_full_paths(dir_path) if os.path.isdir(d)] for p in random.sample(listdir_full_paths(d), k=num_imgs)] 55 | else: 56 | raise NotImplementedError(f'Unknown selection strategy: {selection_strategy}') 57 | 58 | return display_imgs(img_paths, **kwargs) 59 | 60 | #---------------------------------------------------------------------------- 61 | 62 | def display_imgs(img_paths: List[os.PathLike], nrow: int=None, resize: int=None, crop:
Tuple=None, padding: int=2) -> "Image": 63 | imgs = [Image.open(p) for p in img_paths] 64 | if not crop is None: 65 | imgs = [img.crop(crop) for img in imgs] 66 | if not resize is None: 67 | imgs = [TVF.resize(x, size=resize, interpolation=TVF.InterpolationMode.LANCZOS) for x in imgs] 68 | imgs = torch.stack([TVF.to_tensor(TVF.center_crop(x, output_size=min(x.size))) for x in imgs]) # [num_imgs, c, h, w] 69 | grid = make_grid(imgs, nrow=(int(np.sqrt(imgs.shape[0])) if nrow is None else nrow), padding=padding) # [c, grid_h, grid_w] 70 | grid = TVF.to_pil_image(grid) 71 | 72 | return grid 73 | 74 | #---------------------------------------------------------------------------- 75 | 76 | def resize_and_save_image(src_path: str, trg_path: str, size: int): 77 | img = Image.open(src_path) 78 | img.load() # required for png.split() 79 | img = center_resize_crop(img, size) 80 | jpg_kwargs = {'quality': 95} if file_ext(trg_path) == '.jpg' else {} 81 | 82 | if file_ext(src_path) == '.png' and file_ext(trg_path) == '.jpg' and len(img.split()) == 4: 83 | jpg = Image.new("RGB", img.size, (255, 255, 255)) 84 | jpg.paste(img, mask=img.split()[3]) # 3 is the alpha channel 85 | jpg.save(trg_path, **jpg_kwargs) 86 | else: 87 | img.save(trg_path, **jpg_kwargs) 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | def center_resize_crop(img: Image, size: int) -> Image: 92 | img = TVF.center_crop(img, min(img.size)) # First, make it square 93 | img = TVF.resize(img, size, interpolation=TVF.InterpolationMode.LANCZOS) # Now, resize it 94 | 95 | return img 96 | 97 | #---------------------------------------------------------------------------- 98 | 99 | def file_ext(path: os.PathLike) -> str: 100 | return os.path.splitext(path)[1].lower() 101 | 102 | #---------------------------------------------------------------------------- 103 | 104 | # Extract the zip file for simplicity... 
105 | def extract_zip(zip_path: os.PathLike, overwrite: bool=False): 106 | assert file_ext(zip_path) == '.zip', f'Not a zip archive: {zip_path}' 107 | 108 | if os.path.exists(zip_path[:-4]): 109 | if overwrite or click.confirm(f'Dir {zip_path[:-4]} already exists. Delete it?', default=False): 110 | shutil.rmtree(zip_path[:-4]) 111 | 112 | with zipfile.ZipFile(zip_path, 'r') as zip_ref: 113 | zip_ref.extractall(os.path.dirname(zip_path[:-4])) 114 | 115 | #---------------------------------------------------------------------------- 116 | 117 | def compress_to_zip(dir_to_compress: os.PathLike, delete: bool=False): 118 | shutil.make_archive(dir_to_compress, 'zip', root_dir=os.path.dirname(dir_to_compress), base_dir=os.path.basename(dir_to_compress)) 119 | 120 | if delete: 121 | shutil.rmtree(dir_to_compress) 122 | 123 | #---------------------------------------------------------------------------- 124 | 125 | def list_full_paths(dir_path: os.PathLike) -> List[os.PathLike]: 126 | """ 127 | Returns a list of full paths to all objects in the given directory. 
128 | """ 129 | return [os.path.join(dir_path, f) for f in sorted(os.listdir(dir_path))] 130 | 131 | #---------------------------------------------------------------------------- 132 | 133 | def load_generator(cfg: DictConfig, verbose: bool=True) -> Tuple[torch.nn.Module, Dict, str]: 134 | if cfg.network_pkl is None: 135 | if not cfg.selection_metric is None: 136 | metrics_file = os.path.join(cfg.networks_dir, f'metric-{cfg.selection_metric}.jsonl') 137 | with open(metrics_file, 'r') as f: 138 | snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()] 139 | snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results'][cfg.selection_metric])[0] 140 | network_pkl = os.path.join(cfg.networks_dir, snapshot['snapshot_pkl']) 141 | if verbose: 142 | print(f'Using checkpoint: {network_pkl} with {cfg.selection_metric} of', snapshot['results'][cfg.selection_metric]) 143 | else: 144 | output_regex = "^network-snapshot-\d{6}.pkl$" 145 | ckpt_regex = re.compile(output_regex) 146 | ckpts = sorted([f for f in os.listdir(cfg.networks_dir) if ckpt_regex.match(f)]) 147 | network_pkl = os.path.join(cfg.networks_dir, ckpts[-1]) 148 | if verbose: 149 | print(f"Using the latest found checkpoint: {network_pkl}") 150 | else: 151 | assert cfg.networks_dir is None, "Cant have both parameters: network_pkl and cfg.networks_dir" 152 | network_pkl = cfg.network_pkl 153 | 154 | # Load network. 
155 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 156 | raise ValueError('--network must point to a file or URL') 157 | if verbose: 158 | print(f'Loading networks from {network_pkl}') 159 | with dnnlib.util.open_url(network_pkl) as f: 160 | snapshot = legacy.load_network_pkl(f) 161 | G = snapshot['G_ema'] # type: ignore 162 | 163 | # G.cfg.set('bg_model', dnnlib.EasyDict(type=None)) 164 | G.cfg = dnnlib.EasyDict(**G.cfg) 165 | G.cfg.bg_model = G.cfg.bg_model if 'bg_model' in G.cfg else dnnlib.EasyDict(type=None, num_steps=16) 166 | G.cfg.use_noise = True 167 | 168 | if cfg.reload_code: 169 | from src.training.networks_eg3d import Generator 170 | G_new = Generator( 171 | G.cfg, 172 | z_dim=G.z_dim, 173 | w_dim=G.w_dim, 174 | mapping_kwargs=dnnlib.EasyDict( 175 | num_layers=2, 176 | mean_camera_pose=(G.mapping.mean_camera_pose if hasattr(G.mapping, 'mean_camera_pose') else None), 177 | camera_cond=G.cfg.camera_cond if 'camera_cond' in G.cfg else False, 178 | ), 179 | # mapping_kwargs=dnnlib.EasyDict(num_layers=2, mean_camera_pose=torch.zeros(3)), 180 | # mapping_kwargs=dnnlib.EasyDict(num_layers=2), 181 | img_resolution=G.img_resolution, 182 | img_channels=G.img_channels, 183 | c_dim=G.c_dim, 184 | channel_base=int(G.cfg.get('fmaps', 0.5) * 32768), 185 | channel_max=G.cfg.get('channel_max', 512), 186 | ) 187 | G_new.load_state_dict(G.state_dict()) 188 | G = G_new 189 | 190 | return G, snapshot, network_pkl 191 | 192 | #---------------------------------------------------------------------------- 193 | 194 | def set_seed(seed: int): 195 | random.seed(seed) 196 | np.random.seed(seed) 197 | torch.manual_seed(seed) 198 | 199 | #---------------------------------------------------------------------------- 200 | 201 | def listdir_full_paths(d: os.PathLike) -> List[os.PathLike]: 202 | return [os.path.join(d, o) for o in sorted(os.listdir(d))] 203 | 204 | 
#---------------------------------------------------------------------------- 205 | 206 | def lanczos_resize_tensors(x: torch.Tensor, size): 207 | x = [TVF.to_pil_image(img) for img in x] 208 | x = [TVF.resize(img, size=size, interpolation=TVF.InterpolationMode.LANCZOS) for img in x] 209 | x = [TVF.to_tensor(img) for img in x] 210 | 211 | return torch.stack(x) 212 | 213 | #---------------------------------------------------------------------------- 214 | 215 | def maybe_makedirs(d: os.PathLike): 216 | # TODO: what the hell is this function name? 217 | if d != '': 218 | os.makedirs(d, exist_ok=True) 219 | 220 | #---------------------------------------------------------------------------- 221 | -------------------------------------------------------------------------------- /src/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /src/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
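`maybe_makedirs` exists because `os.makedirs('')` raises `FileNotFoundError`, and `''` is exactly what `os.path.dirname` returns for a bare filename. The sketch below shows the typical call pattern. `save_text` is a hypothetical caller that creates the parent directory, if any, before writing.

```python
import os
import tempfile

def maybe_makedirs(d):
    # Same guard as the helper above: skip '' because os.makedirs('') raises.
    if d != '':
        os.makedirs(d, exist_ok=True)

def save_text(path, text):
    # Hypothetical caller: ensure the parent directory exists, then write.
    maybe_makedirs(os.path.dirname(path))
    with open(path, 'w') as f:
        f.write(text)

print(repr(os.path.dirname('bare.txt')))  # ''
maybe_makedirs('')  # no-op instead of an exception

with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, 'samples', 'seed0000', 'img.txt')
    save_text(target, 'ok')
    created = os.path.isfile(target)
print(created)  # True
```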
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 
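`_get_mangled_gpu_name()` turns the CUDA device name into a directory-safe suffix for the cached build directory, so that builds for different GPUs do not collide. The sketch takes the name as an argument instead of calling `torch.cuda.get_device_name()`. The sample device names are illustrative.

```python
import re

def mangle_gpu_name(name: str) -> str:
    # Mirror of _get_mangled_gpu_name(): lower-case the name, then map every
    # character outside [a-z0-9_-] to '-' so it is safe inside a path.
    out = []
    for c in name.lower():
        out.append(c if re.match('[a-z0-9_-]+', c) else '-')
    return ''.join(out)

print(mangle_gpu_name('NVIDIA GeForce RTX 3090'))  # nvidia-geforce-rtx-3090
print(mangle_gpu_name('Tesla V100-SXM2-32GB'))     # tesla-v100-sxm2-32gb
```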
56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 
98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # cached build directory already exists (another process won the race); delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile.
135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /src/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | from src import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 
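The incremental-build trick in `get_plugin` can be isolated from PyTorch. It hashes all sources into one digest, copies them into a directory named after that digest, and publishes the directory with a single `os.replace`, so that concurrent processes never see a half-populated copy. The stdlib-only sketch below stops before the `torch.utils.cpp_extension.load` call. `materialize_cache_dir` is a hypothetical name, and its `tag` argument stands in for `_get_mangled_gpu_name()`.

```python
import hashlib
import os
import shutil
import tempfile
import uuid

def combined_digest(paths):
    # One MD5 over the contents of all files, visited in sorted-path order
    # (get_plugin sorts sources + headers the same way).
    h = hashlib.md5()
    for p in sorted(paths):
        with open(p, 'rb') as f:
            h.update(f.read())
    return h.hexdigest()

def materialize_cache_dir(paths, build_top_dir, tag='gpu-name'):
    cached = os.path.join(build_top_dir, f'{combined_digest(paths)}-{tag}')
    if not os.path.isdir(cached):
        tmp = os.path.join(build_top_dir, f'srctmp-{uuid.uuid4().hex}')
        os.makedirs(tmp)
        for p in paths:
            shutil.copyfile(p, os.path.join(tmp, os.path.basename(p)))
        try:
            os.replace(tmp, cached)  # rename the whole populated directory at once
        except OSError:
            # Another process created `cached` first; discard our copy.
            shutil.rmtree(tmp)
            if not os.path.isdir(cached):
                raise
    return cached

with tempfile.TemporaryDirectory() as root:
    src = os.path.join(root, 'op.cu')
    with open(src, 'w') as f:
        f.write('// v1')
    first = materialize_cache_dir([src], root)
    again = materialize_cache_dir([src], root)   # unchanged sources: same dir, no copy
    with open(src, 'w') as f:
        f.write('// v2')
    edited = materialize_cache_dir([src], root)  # new digest: new dir
print(first == again, first != edited)  # True True
```

Keeping the copied files' names and timestamps stable inside a content-addressed directory is what allows ninja to rebuild incrementally.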
19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | posinf = torch.finfo(input.dtype).max 53 | if neginf is None: 54 | neginf = torch.finfo(input.dtype).min 55 | assert nan == 0 56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 57 | 58 | #---------------------------------------------------------------------------- 59 | # Symbolic assert. 60 | 61 | try: 62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 63 | except AttributeError: 64 | symbolic_assert = torch.Assert # 1.7.0 65 | 66 | #---------------------------------------------------------------------------- 67 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 
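The `nan_to_num` fallback for older PyTorch builds relies on `input.unsqueeze(0).nansum(0)`. Summing over a length-1 axis with `nansum` returns every element unchanged except that NaN becomes 0. `torch.clamp` then maps ±inf to the largest and smallest finite values of the dtype. The pure-Python sketch below applies the same rules element-wise to a list of floats. It substitutes `sys.float_info.max` for `torch.finfo(dtype).max`, which corresponds to float64. Unlike the fallback, which asserts `nan == 0`, the sketch accepts any replacement value.

```python
import math
import sys

def nan_to_num(values, nan=0.0, posinf=None, neginf=None):
    # Defaults mirror the fallback: +/-inf map to the largest finite magnitude.
    posinf = sys.float_info.max if posinf is None else posinf
    neginf = -sys.float_info.max if neginf is None else neginf
    out = []
    for v in values:
        if math.isnan(v):
            v = nan  # what nansum over a singleton axis does (with nan == 0)
        out.append(min(max(v, neginf), posinf))  # what clamp does to +/-inf
    return out

print(nan_to_num([1.5, float('nan'), float('inf'), -float('inf')]))
```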
68 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 69 | 70 | @contextlib.contextmanager 71 | def suppress_tracer_warnings(): 72 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 73 | warnings.filters.insert(0, flt) 74 | yield 75 | warnings.filters.remove(flt) 76 | 77 | #---------------------------------------------------------------------------- 78 | # Assert that the shape of a tensor matches the given list of integers. 79 | # None indicates that the size of a dimension is allowed to vary. 80 | # Performs symbolic assertion when used in torch.jit.trace(). 81 | 82 | def assert_shape(tensor, ref_shape): 83 | if tensor.ndim != len(ref_shape): 84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 86 | if ref_size is None: 87 | pass 88 | elif isinstance(ref_size, torch.Tensor): 89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 91 | elif isinstance(size, torch.Tensor): 92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 94 | elif size != ref_size: 95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 96 | 97 | #---------------------------------------------------------------------------- 98 | # Function decorator that calls torch.autograd.profiler.record_function(). 
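For ordinary (non-traced) tensors, `assert_shape` reduces to a rank check plus a per-dimension comparison in which `None` acts as a wildcard. The `torch.Tensor` branches only matter under `torch.jit.trace()`. The sketch below applies the non-traced contract to plain tuples.

```python
from typing import Optional, Sequence

def assert_shape(shape: Sequence[int], ref_shape: Sequence[Optional[int]]) -> None:
    # Rank must match exactly; None in ref_shape accepts any size.
    if len(shape) != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {len(shape)}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(shape, ref_shape)):
        if ref_size is not None and size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

assert_shape((8, 3, 64, 64), [None, 3, 64, 64])  # batch size may vary
try:
    assert_shape((8, 4, 64, 64), [None, 3, 64, 64])
except AssertionError as e:
    print(e)  # Wrong size for dimension 1: got 4, expected 3
```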
99 | 100 | def profiled_function(fn): 101 | def decorator(*args, **kwargs): 102 | with torch.autograd.profiler.record_function(fn.__name__): 103 | return fn(*args, **kwargs) 104 | decorator.__name__ = fn.__name__ 105 | return decorator 106 | 107 | #---------------------------------------------------------------------------- 108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 109 | # indefinitely, shuffling items as it goes. 110 | 111 | class InfiniteSampler(torch.utils.data.Sampler): 112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 113 | assert len(dataset) > 0 114 | assert num_replicas > 0 115 | assert 0 <= rank < num_replicas 116 | assert 0 <= window_size <= 1 117 | super().__init__(dataset) 118 | self.dataset = dataset 119 | self.rank = rank 120 | self.num_replicas = num_replicas 121 | self.shuffle = shuffle 122 | self.seed = seed 123 | self.window_size = window_size 124 | 125 | def __iter__(self): 126 | order = np.arange(len(self.dataset)) 127 | rnd = None 128 | window = 0 129 | if self.shuffle: 130 | rnd = np.random.RandomState(self.seed) 131 | rnd.shuffle(order) 132 | window = int(np.rint(order.size * self.window_size)) 133 | 134 | idx = 0 135 | while True: 136 | i = idx % order.size 137 | if idx % self.num_replicas == self.rank: 138 | yield order[i] 139 | if window >= 2: 140 | j = (i - rnd.randint(window)) % order.size 141 | order[i], order[j] = order[j], order[i] 142 | idx += 1 143 | 144 | #---------------------------------------------------------------------------- 145 | # Utilities for operating with torch.nn.Module parameters and buffers. 
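`InfiniteSampler` walks one shared permutation forever. Global step `idx` belongs to rank `idx % num_replicas`. After each step, the sampler swaps the current slot with a random slot up to `window - 1` positions behind it (wrapping around), so the order keeps drifting from one pass to the next. The pure-Python sketch below reproduces the same loop. `infinite_indices` is a hypothetical name, and `random.Random` replaces `np.random.RandomState`, so the shuffled order differs from the original for the same seed.

```python
import random
from itertools import islice

def infinite_indices(n, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
    order = list(range(n))
    rnd = None
    window = 0
    if shuffle:
        rnd = random.Random(seed)  # every rank uses the same seed, hence the same order
        rnd.shuffle(order)
        window = int(round(n * window_size))
    idx = 0
    while True:
        i = idx % n
        if idx % num_replicas == rank:  # each global step is served by exactly one rank
            yield order[i]
        if window >= 2:
            j = (i - rnd.randrange(window)) % n  # slot up to window-1 positions back
            order[i], order[j] = order[j], order[i]
        idx += 1

print(list(islice(infinite_indices(5, rank=0, num_replicas=2, shuffle=False), 5)))  # [0, 2, 4, 1, 3]
print(list(islice(infinite_indices(5, rank=1, num_replicas=2, shuffle=False), 5)))  # [1, 3, 0, 2, 4]
```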
146 | 147 | def params_and_buffers(module): 148 | assert isinstance(module, torch.nn.Module) 149 | return list(module.parameters()) + list(module.buffers()) 150 | 151 | def named_params_and_buffers(module): 152 | assert isinstance(module, torch.nn.Module) 153 | return list(module.named_parameters()) + list(module.named_buffers()) 154 | 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' 
+ name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True, module_kwargs={}): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs, **module_kwargs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 
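`copy_params_and_buffers` matches tensors by name. Every destination entry found in the source is overwritten in place. Destination entries missing from the source keep their values unless `require_all=True`, and extra source entries are ignored. The dict-of-lists sketch below models those rules. The parameter names are hypothetical. The original enforces `require_all` with `assert`; the sketch raises `KeyError` explicitly so that the check survives `python -O`. It also does not model the `requires_grad` handling.

```python
from typing import Dict, List

def copy_params(src: Dict[str, List[float]], dst: Dict[str, List[float]], require_all: bool = False) -> None:
    for name, tensor in dst.items():
        if require_all and name not in src:
            raise KeyError(f'{name} missing from source module')
        if name in src:
            tensor[:] = src[name]  # in-place, like tensor.copy_(...)

src = {'mapping.fc0.weight': [1.0, 2.0], 'synthesis.b4.const': [0.5]}
dst = {'mapping.fc0.weight': [0.0, 0.0], 'synthesis.b4.const': [0.0], 'synthesis.extra.bias': [7.0]}
copy_params(src, dst)
print(dst['mapping.fc0.weight'], dst['synthesis.extra.bias'])  # [1.0, 2.0] [7.0]
```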
229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /src/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
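The table printer at the end of `print_module_summary` pads every cell to its column's maximum width. The sketch below isolates it as a pure function over lists of strings. The name `format_table` and the two-space `sep` default are assumptions of this sketch.

```python
def format_table(rows, sep='  '):
    # Each column is as wide as its longest cell; cells are left-aligned.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    return [sep.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))
            for row in rows]

header = ['Generator', 'Parameters', 'Buffers']
rows = [header, ['---'] * len(header), ['mapping', '1024', '-'],
        ['---'] * len(header), ['Total', '1024', '0']]
for line in format_table(rows):
    print(line)
```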
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /src/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ?
xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /src/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
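The launch configuration in `bias_act.cpp` gives each of the `gridSize` blocks `blockSize = 128` threads. Each thread handles up to `loopX = 4` elements spaced `blockDim.x` apart, so `gridSize = (sizeX - 1) / (loopX * blockSize) + 1` is a ceiling division. The Python sketch below replays the kernel's index arithmetic on the host to check that every element is visited exactly once. It models indexing only and assumes `sizeX >= 1`: for `sizeX == 0`, C's truncating division yields one block, whereas Python's floor division yields zero.

```python
def grid_size(size_x, loop_x=4, block_size=4 * 32):
    assert size_x >= 1, 'C and Python integer division disagree for size_x == 0'
    return (size_x - 1) // (loop_x * block_size) + 1  # ceil(size_x / (loop_x * block_size))

def covered_indices(size_x, loop_x=4, block_size=4 * 32):
    # Enumerate the xi values touched by bias_act_kernel's element loop.
    seen = []
    for block in range(grid_size(size_x, loop_x, block_size)):
        for thread in range(block_size):
            xi = block * loop_x * block_size + thread
            for _ in range(loop_x):
                if xi >= size_x:
                    break
                seen.append(xi)
                xi += block_size  # stride of blockDim.x between a thread's elements
    return seen

print(grid_size(512), grid_size(513))  # 1 2
```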
8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ?
x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 
136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /src/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto.
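For activations with `ref='y'`, the first-order backward pass (`G == 1`) never re-evaluates the activation. Instead, it recovers the forward output as `yy = yref / gain` and expresses the derivative in terms of it: tanh' = 1 - y², sigmoid' = y(1 - y), softplus' = 1 - exp(-y), and elu' = y + 1 for negative inputs. The kernel then multiplies this derivative by the incoming gradient. The sketch below checks those identities against central finite differences. It is a numerical sanity check written for this note, not part of the repository.

```python
import math

# name -> (forward f(x), derivative written in terms of the output y = f(x))
ACTS = {
    'tanh': (math.tanh, lambda y: 1 - y * y),
    'sigmoid': (lambda x: 1 / (1 + math.exp(-x)), lambda y: y * (1 - y)),
    'softplus': (lambda x: math.log1p(math.exp(x)), lambda y: 1 - math.exp(-y)),
    'elu': (lambda x: x if x >= 0 else math.exp(x) - 1, lambda y: 1.0 if y >= 0 else y + 1),
}

def max_grad_error(name, xs, h=1e-5):
    f, dfdy = ACTS[name]
    worst = 0.0
    for x in xs:
        numeric = (f(x + h) - f(x - h)) / (2 * h)  # central difference
        analytic = dfdy(f(x))                      # derivative from the output only
        worst = max(worst, abs(numeric - analytic))
    return worst

xs = [-3.0, -1.2, -0.3, 0.4, 1.7, 2.5]  # avoids x = 0, where elu has a kink
for name in ACTS:
    print(name, max_grad_error(name, xs))
```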
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /src/torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
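The kernel finds the bias for the flat element `xi` as `b[(xi / stepB) % sizeB]`, where `stepB = x.stride(dim)` and `sizeB = x.size(dim)`. Because `x` is required to be non-overlapping and dense, `xi` is a memory offset. For a contiguous NCHW tensor with `dim=1`, `stride(1) = H*W`, so the expression yields the channel index. In channels-last memory order, `stride(1) = 1`, and the same formula still works. The plain-Python sketch below checks both layouts by brute force. The stride helpers are written here for illustration.

```python
import itertools

def contiguous_strides(shape):
    # Row-major strides in elements, as for torch.contiguous_format.
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return list(reversed(strides))

def channels_last_strides(shape):
    # NHWC memory order for an NCHW-shaped tensor, as for torch.channels_last.
    n, c, h, w = shape
    return [h * w * c, 1, w * c, c]

def bias_index_matches(shape, strides, dim=1):
    step_b, size_b = strides[dim], shape[dim]
    for idx in itertools.product(*(range(s) for s in shape)):
        xi = sum(i * s for i, s in zip(idx, strides))  # memory offset of this element
        if (xi // step_b) % size_b != idx[dim]:
            return False
    return True

shape = [2, 3, 4, 5]
print(bias_index_matches(shape, contiguous_strides(shape)))     # True
print(bias_index_matches(shape, channels_last_strides(shape)))  # True
```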
8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | from src import dnnlib 15 | 16 | from src.torch_utils import custom_ops 17 | from src.torch_utils import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | activation_funcs = { 22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 31 | } 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | _plugin = None 36 | _null_tensor = torch.empty([0]) 37 | 38 | def _init(): 39 | global _plugin 40 | if _plugin is None: 41 | _plugin = custom_ops.get_plugin( 42 | module_name='bias_act_plugin', 43 | sources=['bias_act.cpp', 
'bias_act.cu'], 44 | headers=['bias_act.h'], 45 | source_dir=os.path.dirname(__file__), 46 | extra_cuda_cflags=['--use_fast_math'], 47 | ) 48 | return True 49 | 50 | #---------------------------------------------------------------------------- 51 | 52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 53 | r"""Fused bias and activation function. 54 | 55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 56 | and scales the result by `gain`. Each of the steps is optional. In most cases, 57 | the fused op is considerably more efficient than performing the same calculation 58 | using standard PyTorch ops. It supports first and second order gradients, 59 | but not third order gradients. 60 | 61 | Args: 62 | x: Input activation tensor. Can be of any shape. 63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 64 | as `x`. The shape must be known, and it must match the dimension of `x` 65 | corresponding to `dim`. 66 | dim: The dimension in `x` corresponding to the elements of `b`. 67 | The value of `dim` is ignored if `b` is not specified. 68 | act: Name of the activation function to evaluate, or `"linear"` to disable. 69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 70 | See `activation_funcs` for a full list. `None` is not allowed. 71 | alpha: Shape parameter for the activation function, or `None` to use the default. 72 | gain: Scaling factor for the output tensor, or `None` to use default. 73 | See `activation_funcs` for the default scaling of each activation function. 74 | If unsure, consider specifying 1. 75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 76 | the clamping (default). 77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 78 | 79 | Returns: 80 | Tensor of the same shape and datatype as `x`. 
81 | """ 82 | assert isinstance(x, torch.Tensor) 83 | assert impl in ['ref', 'cuda'] 84 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @misc.profiled_function 91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 92 | """Slow reference implementation of `bias_act()` using standard PyTorch ops. 93 | """ 94 | assert isinstance(x, torch.Tensor) 95 | assert clamp is None or clamp >= 0 96 | spec = activation_funcs[act] 97 | alpha = float(alpha if alpha is not None else spec.def_alpha) 98 | gain = float(gain if gain is not None else spec.def_gain) 99 | clamp = float(clamp if clamp is not None else -1) 100 | 101 | # Add bias. 102 | if b is not None: 103 | assert isinstance(b, torch.Tensor) and b.ndim == 1 104 | assert 0 <= dim < x.ndim 105 | assert b.shape[0] == x.shape[dim] 106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 107 | 108 | # Evaluate activation function. 109 | alpha = float(alpha) 110 | x = spec.func(x, alpha=alpha) 111 | 112 | # Scale by gain. 113 | gain = float(gain) 114 | if gain != 1: 115 | x = x * gain 116 | 117 | # Clamp. 118 | if clamp >= 0: 119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 120 | return x 121 | 122 | #---------------------------------------------------------------------------- 123 | 124 | _bias_act_cuda_cache = dict() 125 | 126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 127 | """Fast CUDA implementation of `bias_act()` using custom ops. 128 | """ 129 | # Parse arguments.
130 | assert clamp is None or clamp >= 0 131 | spec = activation_funcs[act] 132 | alpha = float(alpha if alpha is not None else spec.def_alpha) 133 | gain = float(gain if gain is not None else spec.def_gain) 134 | clamp = float(clamp if clamp is not None else -1) 135 | 136 | # Lookup from cache. 137 | key = (dim, act, alpha, gain, clamp) 138 | if key in _bias_act_cuda_cache: 139 | return _bias_act_cuda_cache[key] 140 | 141 | # Forward op. 142 | class BiasActCuda(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, x, b): # pylint: disable=arguments-differ 145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 146 | x = x.contiguous(memory_format=ctx.memory_format) 147 | b = b.contiguous() if b is not None else _null_tensor 148 | y = x 149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 151 | ctx.save_for_backward( 152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 154 | y if 'y' in spec.ref else _null_tensor) 155 | return y 156 | 157 | @staticmethod 158 | def backward(ctx, dy): # pylint: disable=arguments-differ 159 | dy = dy.contiguous(memory_format=ctx.memory_format) 160 | x, b, y = ctx.saved_tensors 161 | dx = None 162 | db = None 163 | 164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 165 | dx = dy 166 | if act != 'linear' or gain != 1 or clamp >= 0: 167 | dx = BiasActCudaGrad.apply(dy, x, b, y) 168 | 169 | if ctx.needs_input_grad[1]: 170 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 171 | 172 | return dx, db 173 | 174 | # Backward op. 
175 | class BiasActCudaGrad(torch.autograd.Function): 176 | @staticmethod 177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 180 | ctx.save_for_backward( 181 | dy if spec.has_2nd_grad else _null_tensor, 182 | x, b, y) 183 | return dx 184 | 185 | @staticmethod 186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 188 | dy, x, b, y = ctx.saved_tensors 189 | d_dy = None 190 | d_x = None 191 | d_b = None 192 | d_y = None 193 | 194 | if ctx.needs_input_grad[0]: 195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 196 | 197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 199 | 200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 202 | 203 | return d_dy, d_x, d_b, d_y 204 | 205 | # Add to cache. 206 | _bias_act_cuda_cache[key] = BiasActCuda 207 | return BiasActCuda 208 | 209 | #---------------------------------------------------------------------------- 210 | -------------------------------------------------------------------------------- /src/torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, 
weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 
83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, 
input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /src/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION 
& AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 
40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 
71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /src/torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes.
47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 
85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /src/torch_utils/ops/filtered_lrelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import numpy as np 11 | import torch 12 | import warnings 13 | 14 | from .. import custom_ops 15 | from .. import misc 16 | from . import upfirdn2d 17 | from . 
import bias_act 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | _plugin = None 22 | 23 | def _init(): 24 | global _plugin 25 | if _plugin is None: 26 | _plugin = custom_ops.get_plugin( 27 | module_name='filtered_lrelu_plugin', 28 | sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'], 29 | headers=['filtered_lrelu.h', 'filtered_lrelu.cu'], 30 | source_dir=os.path.dirname(__file__), 31 | extra_cuda_cflags=['--use_fast_math'], 32 | ) 33 | return True 34 | 35 | def _get_filter_size(f): 36 | if f is None: 37 | return 1, 1 38 | assert isinstance(f, torch.Tensor) 39 | assert 1 <= f.ndim <= 2 40 | return f.shape[-1], f.shape[0] # width, height 41 | 42 | def _parse_padding(padding): 43 | if isinstance(padding, int): 44 | padding = [padding, padding] 45 | assert isinstance(padding, (list, tuple)) 46 | assert all(isinstance(x, (int, np.integer)) for x in padding) 47 | padding = [int(x) for x in padding] 48 | if len(padding) == 2: 49 | px, py = padding 50 | padding = [px, px, py, py] 51 | px0, px1, py0, py1 = padding 52 | return px0, px1, py0, py1 53 | 54 | #---------------------------------------------------------------------------- 55 | 56 | def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'): 57 | r"""Filtered leaky ReLU for a batch of 2D images. 58 | 59 | Performs the following sequence of operations for each channel: 60 | 61 | 1. Add channel-specific bias if provided (`b`). 62 | 63 | 2. Upsample the image by inserting N-1 zeros after each pixel (`up`). 64 | 65 | 3. Pad the image with the specified number of zeros on each side (`padding`). 66 | Negative padding corresponds to cropping the image. 67 | 68 | 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it 69 | so that the footprint of all output pixels lies within the input image. 70 | 71 | 5. 
Multiply each value by the provided gain factor (`gain`). 72 | 73 | 6. Apply leaky ReLU activation function to each value. 74 | 75 | 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided. 76 | 77 | 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking 78 | it so that the footprint of all output pixels lies within the input image. 79 | 80 | 9. Downsample the image by keeping every Nth pixel (`down`). 81 | 82 | The fused op is considerably more efficient than performing the same calculation 83 | using standard PyTorch ops. It supports gradients of arbitrary order. 84 | 85 | Args: 86 | x: Float32/float16/float64 input tensor of the shape 87 | `[batch_size, num_channels, in_height, in_width]`. 88 | fu: Float32 upsampling FIR filter of the shape 89 | `[filter_height, filter_width]` (non-separable), 90 | `[filter_taps]` (separable), or 91 | `None` (identity). 92 | fd: Float32 downsampling FIR filter of the shape 93 | `[filter_height, filter_width]` (non-separable), 94 | `[filter_taps]` (separable), or 95 | `None` (identity). 96 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 97 | as `x`. The length of the vector must match the channel dimension of `x`. 98 | up: Integer upsampling factor (default: 1). 99 | down: Integer downsampling factor (default: 1). 100 | padding: Padding with respect to the upsampled image. Can be a single number 101 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 102 | (default: 0). 103 | gain: Overall scaling factor for signal magnitude (default: sqrt(2)). 104 | slope: Slope on the negative side of leaky ReLU (default: 0.2). 105 | clamp: Maximum magnitude for leaky ReLU output (default: None). 106 | flip_filter: False = convolution, True = correlation (default: False). 107 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
108 | 109 | Returns: 110 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 111 | """ 112 | assert isinstance(x, torch.Tensor) 113 | assert impl in ['ref', 'cuda'] 114 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 115 | return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0) 116 | return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter) 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | @misc.profiled_function 121 | def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): 122 | """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using 123 | existing `upfirdn2d()` and `bias_act()` ops. 124 | """ 125 | assert isinstance(x, torch.Tensor) and x.ndim == 4 126 | fu_w, fu_h = _get_filter_size(fu) 127 | fd_w, fd_h = _get_filter_size(fd) 128 | if b is not None: 129 | assert isinstance(b, torch.Tensor) and b.dtype == x.dtype 130 | misc.assert_shape(b, [x.shape[1]]) 131 | assert isinstance(up, int) and up >= 1 132 | assert isinstance(down, int) and down >= 1 133 | px0, px1, py0, py1 = _parse_padding(padding) 134 | assert gain == float(gain) and gain > 0 135 | assert slope == float(slope) and slope >= 0 136 | assert clamp is None or (clamp == float(clamp) and clamp >= 0) 137 | 138 | # Calculate output size. 139 | batch_size, channels, in_h, in_w = x.shape 140 | in_dtype = x.dtype 141 | out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down 142 | out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down 143 | 144 | # Compute using existing ops. 145 | x = bias_act.bias_act(x=x, b=b) # Apply bias.
146 | x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. 147 | x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp. 148 | x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample. 149 | 150 | # Check output shape & dtype. 151 | misc.assert_shape(x, [batch_size, channels, out_h, out_w]) 152 | assert x.dtype == in_dtype 153 | return x 154 | 155 | #---------------------------------------------------------------------------- 156 | 157 | _filtered_lrelu_cuda_cache = dict() 158 | 159 | def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): 160 | """Fast CUDA implementation of `filtered_lrelu()` using custom ops. 161 | """ 162 | assert isinstance(up, int) and up >= 1 163 | assert isinstance(down, int) and down >= 1 164 | px0, px1, py0, py1 = _parse_padding(padding) 165 | assert gain == float(gain) and gain > 0 166 | gain = float(gain) 167 | assert slope == float(slope) and slope >= 0 168 | slope = float(slope) 169 | assert clamp is None or (clamp == float(clamp) and clamp >= 0) 170 | clamp = float(clamp if clamp is not None else 'inf') 171 | 172 | # Lookup from cache. 173 | key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) 174 | if key in _filtered_lrelu_cuda_cache: 175 | return _filtered_lrelu_cuda_cache[key] 176 | 177 | # Forward op. 178 | class FilteredLReluCuda(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ 181 | assert isinstance(x, torch.Tensor) and x.ndim == 4 182 | 183 | # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). 
184 | if fu is None: 185 | fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) 186 | if fd is None: 187 | fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) 188 | assert 1 <= fu.ndim <= 2 189 | assert 1 <= fd.ndim <= 2 190 | 191 | # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. 192 | if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: 193 | fu = fu.square()[None] 194 | if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: 195 | fd = fd.square()[None] 196 | 197 | # Missing sign input tensor. 198 | if si is None: 199 | si = torch.empty([0]) 200 | 201 | # Missing bias tensor. 202 | if b is None: 203 | b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) 204 | 205 | # Construct internal sign tensor only if gradients are needed. 206 | write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) 207 | 208 | # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. 209 | strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] 210 | if any(a < b for a, b in zip(strides[:-1], strides[1:])): 211 | warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) 212 | 213 | # Call C++/Cuda plugin if datatype is supported. 214 | if x.dtype in [torch.float16, torch.float32]: 215 | if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): 216 | warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) 217 | y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) 218 | else: 219 | return_code = -1 220 | 221 | # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because 222 | # only the bit-packed sign tensor is retained for gradient computation. 
223 | if return_code < 0: 224 | warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) 225 | 226 | y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. 227 | y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. 228 | so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. 229 | y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. 230 | 231 | # Prepare for gradient computation. 232 | ctx.save_for_backward(fu, fd, (si if si.numel() else so)) 233 | ctx.x_shape = x.shape 234 | ctx.y_shape = y.shape 235 | ctx.s_ofs = sx, sy 236 | return y 237 | 238 | @staticmethod 239 | def backward(ctx, dy): # pylint: disable=arguments-differ 240 | fu, fd, si = ctx.saved_tensors 241 | _, _, xh, xw = ctx.x_shape 242 | _, _, yh, yw = ctx.y_shape 243 | sx, sy = ctx.s_ofs 244 | dx = None # 0 245 | dfu = None; assert not ctx.needs_input_grad[1] 246 | dfd = None; assert not ctx.needs_input_grad[2] 247 | db = None # 3 248 | dsi = None; assert not ctx.needs_input_grad[4] 249 | dsx = None; assert not ctx.needs_input_grad[5] 250 | dsy = None; assert not ctx.needs_input_grad[6] 251 | 252 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: 253 | pp = [ 254 | (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, 255 | xw * up - yw * down + px0 - (up - 1), 256 | (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, 257 | xh * up - yh * down + py0 - (up - 1), 258 | ] 259 | gg = gain * (up ** 2) / (down ** 2) 260 | ff = (not flip_filter) 261 | sx = sx - (fu.shape[-1] - 1) + px0 262 | sy = sy - (fu.shape[0] - 1) + py0 263 | dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) 264 | 265 | if ctx.needs_input_grad[3]: 266 | db = dx.sum([0, 2, 3]) 267 | 268 | return dx, 
dfu, dfd, db, dsi, dsx, dsy 269 | 270 | # Add to cache. 271 | _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda 272 | return FilteredLReluCuda 273 | 274 | #---------------------------------------------------------------------------- 275 | -------------------------------------------------------------------------------- /src/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 
27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /src/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 
27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /src/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 
27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /src/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert 
extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /src/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 
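The `_unbroadcast` helper in `fma.py` above sums an incoming gradient over every axis that broadcasting expanded, then folds away the extra leading axes so the result matches the operand's original shape. A minimal sketch of the same rule, assuming NumPy arrays stand in for PyTorch tensors; `unbroadcast` is a hypothetical standalone name, not part of this repo:

```python
import numpy as np


def unbroadcast(x, shape):
    """Reduce gradient `x` to `shape`, mirroring `_unbroadcast` in fma.py."""
    shape = tuple(shape)
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Sum over leading axes the operand never had, and over axes where the
    # operand had size 1 but the broadcast output had size > 1.
    dims = tuple(i for i in range(x.ndim)
                 if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1))
    if dims:
        x = x.sum(axis=dims, keepdims=True)
    if extra_dims:
        # The leading axes are now size 1; merge them into the first kept axis.
        x = x.reshape(-1, *x.shape[extra_dims + 1:])
    assert x.shape == shape
    return x


# Gradients of out = a * b + c under broadcasting (the fma semantics).
a = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
b = np.array([1.0, 2.0, 3.0, 4.0])
c = np.zeros((3, 1))
dout = np.ones((2, 3, 4))

dc = unbroadcast(dout, c.shape)      # each entry sums 2 * 4 = 8 ones
db = unbroadcast(dout * a, b.shape)  # sums `a` over axes 0 and 1
print(dc.ravel())  # [8. 8. 8.]
```

Reducing with `keepdims=True` first and reshaping afterwards keeps the axis bookkeeping simple: the kept axes never move, and only the now-singleton leading axes are collapsed.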
23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, 
grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /src/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr<float>(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ?
1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel.
95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /src/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization.
44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /src/torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | from src import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from src.torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. 
For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src,
class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src: Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported.
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 
218 |     """
219 |     module = _src_to_module_dict.get(src, None)
220 |     if module is None:
221 |         module_name = "_imported_module_" + uuid.uuid4().hex
222 |         module = types.ModuleType(module_name)
223 |         sys.modules[module_name] = module
224 |         _module_to_src_dict[module] = src
225 |         _src_to_module_dict[src] = module
226 |         exec(src, module.__dict__) # pylint: disable=exec-used
227 |     return module
228 |
229 | #----------------------------------------------------------------------------
230 |
231 | def _check_pickleable(obj):
232 |     r"""Check that the given object is pickleable, raising an exception if
233 |     it is not. This function is expected to be considerably more efficient
234 |     than actually pickling the object.
235 |     """
236 |     def recurse(obj):
237 |         if isinstance(obj, (list, tuple, set)):
238 |             return [recurse(x) for x in obj]
239 |         if isinstance(obj, dict):
240 |             return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 |         if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 |             return None # Python primitive types are pickleable.
243 |         if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
244 |             return None # NumPy arrays and PyTorch tensors are pickleable.
245 |         if is_persistent(obj):
246 |             return None # Persistent objects are pickleable, by virtue of the constructor check.
247 |         return obj
248 |     with io.BytesIO() as f:
249 |         pickle.dump(recurse(obj), f)
250 |
251 | #----------------------------------------------------------------------------
252 |
--------------------------------------------------------------------------------
/src/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2 | #
  3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
  4 | # and proprietary rights in and to this software, related documentation
  5 | # and any modifications thereto. Any use, reproduction, disclosure or
  6 | # distribution of this software and related documentation without an express
  7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
  8 |
  9 | """Facilities for reporting and collecting training statistics across
 10 | multiple processes and devices. The interface is designed to minimize
 11 | synchronization overhead as well as the amount of boilerplate in user
 12 | code."""
 13 |
 14 | import re
 15 | import numpy as np
 16 | import torch
 17 | from src import dnnlib
 18 |
 19 | from src.torch_utils import misc
 20 |
 21 | #----------------------------------------------------------------------------
 22 |
 23 | _num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares]
 24 | _reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction.
 25 | _counter_dtype  = torch.float64 # Data type to use for the internal counters.
 26 | _rank           = 0             # Rank of the current process.
 27 | _sync_device    = None          # Device to use for multiprocess communication. None = single-process.
 28 | _sync_called    = False         # Has _sync() been called yet?
 29 | _counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor
 30 | _cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
 31 |
 32 | #----------------------------------------------------------------------------
 33 |
 34 | def init_multiprocessing(rank, sync_device):
 35 |     r"""Initializes `torch_utils.training_stats` for collecting statistics
 36 |     across multiple processes.
 37 |
 38 |     This function must be called after
 39 |     `torch.distributed.init_process_group()` and before `Collector.update()`.
 40 |     The call is not necessary if multi-process collection is not needed.
 41 |
 42 |     Args:
 43 |         rank:           Rank of the current process.
 44 |         sync_device:    PyTorch device to use for inter-process
 45 |                         communication, or None to disable multi-process
 46 |                         collection. Typically `torch.device('cuda', rank)`.
 47 |     """
 48 |     global _rank, _sync_device
 49 |     assert not _sync_called
 50 |     _rank = rank
 51 |     _sync_device = sync_device
 52 |
 53 | #----------------------------------------------------------------------------
 54 |
 55 | @misc.profiled_function
 56 | def report(name, value):
 57 |     r"""Broadcasts the given set of scalars to all interested instances of
 58 |     `Collector`, across device and process boundaries.
 59 |
 60 |     This function is expected to be extremely cheap and can be safely
 61 |     called from anywhere in the training loop, loss function, or inside a
 62 |     `torch.nn.Module`.
 63 |
 64 |     Warning: The current implementation expects the set of unique names to
 65 |     be consistent across processes. Please make sure that `report()` is
 66 |     called at least once for each unique name by each process, and in the
 67 |     same order. If a given process has no scalars to broadcast, it can do
 68 |     `report(name, [])` (empty list).
 69 |
 70 |     Args:
 71 |         name:   Arbitrary string specifying the name of the statistic.
 72 |                 Averages are accumulated separately for each unique name.
 73 |         value:  Arbitrary set of scalars. Can be a list, tuple,
 74 |                 NumPy array, PyTorch tensor, or Python scalar.
 75 |
 76 |     Returns:
 77 |         The same `value` that was passed in.
 78 |     """
 79 |     if name not in _counters:
 80 |         _counters[name] = dict()
 81 |
 82 |     elems = torch.as_tensor(value)
 83 |     if elems.numel() == 0:
 84 |         return value
 85 |
 86 |     elems = elems.detach().flatten().to(_reduce_dtype)
 87 |     moments = torch.stack([
 88 |         torch.ones_like(elems).sum(),
 89 |         elems.sum(),
 90 |         elems.square().sum(),
 91 |     ])
 92 |     assert moments.ndim == 1 and moments.shape[0] == _num_moments
 93 |     moments = moments.to(_counter_dtype)
 94 |
 95 |     device = moments.device
 96 |     if device not in _counters[name]:
 97 |         _counters[name][device] = torch.zeros_like(moments)
 98 |     _counters[name][device].add_(moments)
 99 |     return value
100 |
101 | #----------------------------------------------------------------------------
102 |
103 | def report0(name, value):
104 |     r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105 |     but ignores any scalars provided by the other processes.
106 |     See `report()` for further details.
107 |     """
108 |     report(name, value if _rank == 0 else [])
109 |     return value
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | class Collector:
114 |     r"""Collects the scalars broadcasted by `report()` and `report0()` and
115 |     computes their long-term averages (mean and standard deviation) over
116 |     user-defined periods of time.
117 |
118 |     The averages are first collected into internal counters that are not
119 |     directly visible to the user. They are then copied to the user-visible
120 |     state as a result of calling `update()` and can then be queried using
121 |     `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122 |     internal counters for the next round, so that the user-visible state
123 |     effectively reflects averages collected between the last two calls to
124 |     `update()`.
125 |
126 |     Args:
127 |         regex:          Regular expression defining which statistics to
128 |                         collect. The default is to collect everything.
129 |         keep_previous:  Whether to retain the previous averages if no
130 |                         scalars were collected on a given round
131 |                         (default: True).
132 |     """
133 |     def __init__(self, regex='.*', keep_previous=True):
134 |         self._regex = re.compile(regex)
135 |         self._keep_previous = keep_previous
136 |         self._cumulative = dict()
137 |         self._moments = dict()
138 |         self.update()
139 |         self._moments.clear()
140 |
141 |     def names(self):
142 |         r"""Returns the names of all statistics broadcasted so far that
143 |         match the regular expression specified at construction time.
144 |         """
145 |         return [name for name in _counters if self._regex.fullmatch(name)]
146 |
147 |     def update(self):
148 |         r"""Copies current values of the internal counters to the
149 |         user-visible state and resets them for the next round.
150 |
151 |         If `keep_previous=True` was specified at construction time, the
152 |         operation is skipped for statistics that have received no scalars
153 |         since the last update, retaining their previous averages.
154 |
155 |         This method performs a number of GPU-to-CPU transfers and one
156 |         `torch.distributed.all_reduce()`. It is intended to be called
157 |         periodically in the main training loop, typically once every
158 |         N training steps.
159 |         """
160 |         if not self._keep_previous:
161 |             self._moments.clear()
162 |         for name, cumulative in _sync(self.names()):
163 |             if name not in self._cumulative:
164 |                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165 |             delta = cumulative - self._cumulative[name]
166 |             self._cumulative[name].copy_(cumulative)
167 |             if float(delta[0]) != 0:
168 |                 self._moments[name] = delta
169 |
170 |     def _get_delta(self, name):
171 |         r"""Returns the raw moments that were accumulated for the given
172 |         statistic between the last two calls to `update()`, or zero if
173 |         no scalars were collected.
174 |         """
175 |         assert self._regex.fullmatch(name)
176 |         if name not in self._moments:
177 |             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178 |         return self._moments[name]
179 |
180 |     def num(self, name):
181 |         r"""Returns the number of scalars that were accumulated for the given
182 |         statistic between the last two calls to `update()`, or zero if
183 |         no scalars were collected.
184 |         """
185 |         delta = self._get_delta(name)
186 |         return int(delta[0])
187 |
188 |     def mean(self, name):
189 |         r"""Returns the mean of the scalars that were accumulated for the
190 |         given statistic between the last two calls to `update()`, or NaN if
191 |         no scalars were collected.
192 |         """
193 |         delta = self._get_delta(name)
194 |         if int(delta[0]) == 0:
195 |             return float('nan')
196 |         return float(delta[1] / delta[0])
197 |
198 |     def std(self, name):
199 |         r"""Returns the standard deviation of the scalars that were
200 |         accumulated for the given statistic between the last two calls to
201 |         `update()`, or NaN if no scalars were collected.
202 |         """
203 |         delta = self._get_delta(name)
204 |         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205 |             return float('nan')
206 |         if int(delta[0]) == 1:
207 |             return float(0)
208 |         mean = float(delta[1] / delta[0])
209 |         raw_var = float(delta[2] / delta[0])
210 |         return np.sqrt(max(raw_var - np.square(mean), 0))
211 |
212 |     def as_dict(self):
213 |         r"""Returns the averages accumulated between the last two calls to
214 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
215 |
216 |             dnnlib.EasyDict(
217 |                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218 |                 ...
219 |             )
220 |         """
221 |         stats = dnnlib.EasyDict()
222 |         for name in self.names():
223 |             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224 |         return stats
225 |
226 |     def __getitem__(self, name):
227 |         r"""Convenience getter.
228 |         `collector[name]` is a synonym for `collector.mean(name)`.
229 |         """
230 |         return self.mean(name)
231 |
232 | #----------------------------------------------------------------------------
233 |
234 | def _sync(names):
235 |     r"""Synchronize the global cumulative counters across devices and
236 |     processes. Called internally by `Collector.update()`.
237 |     """
238 |     if len(names) == 0:
239 |         return []
240 |     global _sync_called
241 |     _sync_called = True
242 |
243 |     # Collect deltas within current rank.
244 |     deltas = []
245 |     device = _sync_device if _sync_device is not None else torch.device('cpu')
246 |     for name in names:
247 |         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248 |         for counter in _counters[name].values():
249 |             delta.add_(counter.to(device))
250 |             counter.copy_(torch.zeros_like(counter))
251 |         deltas.append(delta)
252 |     deltas = torch.stack(deltas)
253 |
254 |     # Sum deltas across ranks.
255 |     if _sync_device is not None:
256 |         torch.distributed.all_reduce(deltas)
257 |
258 |     # Update cumulative values.
259 |     deltas = deltas.cpu()
260 |     for idx, name in enumerate(names):
261 |         if name not in _cumulative:
262 |             _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263 |         _cumulative[name].add_(deltas[idx])
264 |
265 |     # Return name-value pairs.
266 |     return [(name, _cumulative[name]) for name in names]
267 |
268 | #----------------------------------------------------------------------------
269 |
--------------------------------------------------------------------------------
/src/training/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto. Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
 8 |
 9 | # empty
10 |
--------------------------------------------------------------------------------
/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rethinking-3d-gans/code/9bfc3ab32bd2b0992a229501e50bafcf232c5c11/teaser.jpg
--------------------------------------------------------------------------------
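In `src/torch_utils/training_stats.py`, `report()` reduces every statistic to three running moments, `[num_scalars, sum_of_scalars, sum_of_squares]`. `Collector` then derives the mean as `sum / n` and the standard deviation as `sqrt(max(sum_sq / n - mean^2, 0))`. The following is a minimal, hypothetical single-process sketch of that idea in plain Python. The `MomentStats` class is not part of the repository. It ignores devices, `torch.distributed.all_reduce()`, and the global cumulative counters that let several `Collector` instances coexist, and it models only the default `keep_previous=True` behaviour.

```python
import math


class MomentStats:
    """Hypothetical single-process stand-in for report() + Collector."""

    def __init__(self):
        self._running = {}  # name -> [count, sum, sum_of_squares] since the last update()
        self._moments = {}  # name -> moments published by the most recent non-empty update()

    def report(self, name, values):
        # Fold a batch of scalars into the three running moments, as report() does.
        m = self._running.setdefault(name, [0.0, 0.0, 0.0])
        for v in values:
            m[0] += 1.0
            m[1] += v
            m[2] += v * v

    def update(self):
        # Publish the running moments and reset them for the next round.
        # Names that received no scalars keep their previous averages,
        # mirroring Collector's default keep_previous=True.
        for name, m in self._running.items():
            if m[0] != 0:
                self._moments[name] = list(m)
            m[:] = [0.0, 0.0, 0.0]

    def num(self, name):
        return int(self._moments.get(name, [0.0, 0.0, 0.0])[0])

    def mean(self, name):
        n, s, _ = self._moments.get(name, [0.0, 0.0, 0.0])
        return float('nan') if n == 0 else s / n

    def std(self, name):
        # Population standard deviation: sqrt(E[x^2] - E[x]^2), clamped at zero
        # to absorb rounding error, as in Collector.std().
        n, s, ss = self._moments.get(name, [0.0, 0.0, 0.0])
        if n == 0:
            return float('nan')
        if n == 1:
            return 0.0
        mean = s / n
        return math.sqrt(max(ss / n - mean * mean, 0.0))


stats = MomentStats()
stats.report('Loss/G', [1.0, 2.0, 3.0])
stats.report('Loss/G', [4.0])
stats.update()
print(stats.num('Loss/G'), stats.mean('Loss/G'))  # 4 2.5
print(round(stats.std('Loss/G'), 3))              # 1.118
```

Like `Collector.std()`, the sketch divides by `n` rather than `n - 1`, so it reports the population standard deviation. Storing only three numbers per name is what keeps `report()` cheap, because the moments from different batches, devices, or ranks can simply be summed before any division takes place.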