├── .gitignore
├── README.md
├── build_container.sh
├── build_img.sh
├── dnnlib
    ├── __init__.py
    └── util.py
├── docker
    └── Dockerfile
├── environment.yml
├── gen_images.py
├── imgs
    ├── encoded.png
    ├── encoded_batch_1.png
    ├── encoded_batch_2.png
    ├── encoded_batch_3.png
    ├── encoded_transform_x-0.2_y-0.1.png
    ├── encoded_transform_x-0.2_y0.1.png
    ├── encoded_transform_x0.2_y0.0.png
    ├── encoded_transform_x0.2_y0.1.png
    ├── real_batch_1.png
    ├── real_batch_2.png
    ├── real_batch_3.png
    ├── target.png
    ├── train_id.png
    ├── train_id_improve.png
    ├── train_l2.png
    ├── train_lpips.png
    ├── tree_data.png
    └── tree_pretrained.png
├── legacy.py
├── samples
    ├── test01.jpeg
    ├── test02.jpeg
    └── test03.jpeg
├── setup.py
├── test.py
├── test_architecture.py
├── test_base.sh
├── test_config_b.sh
├── test_dataset.py
├── test_eval.sh
├── test_gen_images.sh
├── test_inference.py
├── test_loss.py
├── test_resume.sh
├── test_resume_config_a.sh
├── test_resume_config_b.sh
├── test_train.sh
├── test_train_config_a.sh
├── test_train_config_b.sh
├── torch_utils
    ├── __init__.py
    ├── custom_ops.py
    ├── misc.py
    ├── ops
    │   ├── __init__.py
    │   ├── bias_act.cpp
    │   ├── bias_act.cu
    │   ├── bias_act.h
    │   ├── bias_act.py
    │   ├── conv2d_gradfix.py
    │   ├── conv2d_resample.py
    │   ├── filtered_lrelu.cpp
    │   ├── filtered_lrelu.cu
    │   ├── filtered_lrelu.h
    │   ├── filtered_lrelu.py
    │   ├── filtered_lrelu_ns.cu
    │   ├── filtered_lrelu_rd.cu
    │   ├── filtered_lrelu_wr.cu
    │   ├── fma.py
    │   ├── grid_sample_gradfix.py
    │   ├── upfirdn2d.cpp
    │   ├── upfirdn2d.cu
    │   ├── upfirdn2d.h
    │   └── upfirdn2d.py
    ├── persistence.py
    └── training_stats.py
├── train.py
├── train_base.sh
├── train_config_a.sh
├── train_config_b.sh
└── training
    ├── __init__.py
    ├── augment.py
    ├── dataset.py
    ├── dataset_encoder.py
    ├── loss.py
    ├── loss_encoder.py
    ├── networks_arcface.py
    ├── networks_encoder.py
    ├── networks_irse.py
    ├── networks_stylegan2.py
    ├── networks_stylegan3.py
    ├── ranger.py
    ├── testing_loop_encoder.py
    ├── training_loop.py
    └── training_loop_encoder.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pth
3 | *.pkl
4 | *.dat
5 | *.pt
6 | tmp/
7 | exp/
8 | old/
9 | data/
10 | stylegan3_encoder.egg-info/
11 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # stylegan3-encoder
2 | 
3 | ## Introduction
4 | Encoder implementation for the image inversion task of the stylegan3 generator ([Alias-Free GAN](https://github.com/NVlabs/stylegan3)).
5 | The neural network architecture and hyper-parameter settings of the base configuration are almost the same as those of [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel); various improved encoder architectures will be added in the future.
6 | For fast training, PyTorch DistributedDataParallel is used.
7 | 
8 | For further research, please see [stylegan3-editing](https://github.com/yuval-alaluf/stylegan3-editing).
9 | 
10 | ## Installation
11 | 
12 | ### GPU and NVIDIA driver info
13 | * GeForce RTX 3090 x 8
14 | * NVIDIA driver version: 460.91.03
15 | 
16 | ### Docker build
17 | ```
18 | $ sh build_img.sh
19 | $ sh build_container.sh [container-name]
20 | ```
21 | 
22 | ### Install package
23 | ```
24 | $ docker start [container-name]
25 | $ docker attach [container-name]
26 | $ pip install -v -e .
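# optional: quick sanity check that PyTorch inside the container can see the GPUs
$ python -c "import torch; print(torch.cuda.is_available())"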
27 | ``` 28 | 29 | ### Pretrained weights 30 | ![tree](./imgs/tree_pretrained.png) 31 | - [encoder pretrained, base configuration](https://drive.google.com/file/d/1dog6vajt_1zUwh_hopxSvQ2ZSWALz71T/view?usp=sharing) 32 | - [stylegan3, vgg, inception](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan3) 33 | - [dlib landmarks detector](https://drive.google.com/file/d/1HKmjg6iXsWr4aFPuU0gBXPGR83wqMzq7/view?usp=sharing) 34 | - [IR-SE50](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing) 35 | 36 | ### Prepare dataset 37 | ![tree2](./imgs/tree_data.png) 38 | - [ffhq](https://github.com/NVlabs/ffhq-dataset) 39 | - [ffhqs - 1000 images sampled from FFHQ, for test](https://drive.google.com/drive/folders/1taHKxS66YKJNhdhiGcEdM6nnE5W9zBb1?usp=sharing) 40 | - [celeba-hq](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) 41 | - [celeba-hq-samples](https://drive.google.com/file/d/1IRIQTaTDn3NGuTauyultlQdYHlIntsBD/view?usp=sharing) 42 | 43 | ### Train 44 | ``` 45 | python train.py \ 46 | --outdir exp/[exp_name] \ 47 | --encoder [encoder_type] \ 48 | --data data/[dataset_name] \ 49 | --gpus [num_gpus] \ 50 | --batch [total_batch_size] \ 51 | --generator [generator_pkl] 52 | ``` 53 | 54 | ### Test 55 | ``` 56 | python test.py \ 57 | --testdir exp/[train_exp]/[train_exp_subdir] \ 58 | --data data/[dataset_name] \ 59 | --gpus [num_gpus] \ 60 | --batch [total_batch_size] 61 | ``` 62 | 63 | ## Experiments 64 | ### Base configuration 65 | **Train options** 66 | ``` 67 | { 68 | "model_architecture": "base", 69 | "dataset_dir": "data/ffhq", 70 | "num_gpus": 8, 71 | "batch_size": 32, 72 | "batch_gpu": 4, 73 | "generator_pkl": "pretrained/stylegan3-t-ffhq-1024x1024.pkl", 74 | "val_dataset_dir": null, 75 | "training_steps": 100001, 76 | "val_steps": 10000, 77 | "print_steps": 50, 78 | "tensorboard_steps": 50, 79 | "image_snapshot_steps": 100, 80 | "network_snapshot_steps": 5000, 81 | "learning_rate": 0.001, 82 | "l2_lambda": 1.0, 83 | "lpips_lambda": 0.8, 84 | "id_lambda": 0.1, 85 | "reg_lambda": 0.0, 86 | "gan_lambda": 0.0, 87 | "edit_lambda": 0.0, 88 | "random_seed": 0, 89 | "num_workers": 3, 90 | "resume_pkl": null, 91 | "run_dir": "exp/base/00000-base-ffhq-gpus8-batch32" 92 | } 93 | ``` 94 | **Learning Curve** 95 | ![l2loss](./imgs/train_l2.png) 96 | ![lpipsloss](./imgs/train_lpips.png) 97 | ![idloss](./imgs/train_id.png) 98 | ![idimprove](./imgs/train_id_improve.png) 99 | 100 | **Trainset examples** 101 | Real image batch X 102 | ![real1](./imgs/real_batch_1.png) 103 | ![real2](./imgs/real_batch_2.png) 104 | ![real3](./imgs/real_batch_3.png) 105 | Encoded image batch G.synthesis(E(X)) 106 | ![encoded1](./imgs/encoded_batch_1.png) 107 | ![encoded2](./imgs/encoded_batch_2.png) 108 | ![encoded3](./imgs/encoded_batch_3.png) 109 | 110 | **Testset examples(celeba-hq)** 111 | Target image 112 | ![target](./imgs/target.png) 113 | Encoded image 114 | ![encoded](./imgs/encoded.png) 115 | Encoded image, transform x=0.2, y=0 116 | ![x02y00](./imgs/encoded_transform_x0.2_y0.0.png) 117 | Encoded image, transform x=0.2, y=0.1 118 | ![x02y01](./imgs/encoded_transform_x0.2_y0.1.png) 119 | Encoded image, transform x=-0.2, y=0.1 120 | ![x-02y01](./imgs/encoded_transform_x-0.2_y0.1.png) 121 | Encoded image, transform x=-0.2, y=-0.1 122 | ![x-02y-01](./imgs/encoded_transform_x-0.2_y-0.1.png) 123 | 124 | ## References 125 | 1. [stylegan3](https://github.com/NVlabs/stylegan3) 126 | 2. 
[pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) 127 | -------------------------------------------------------------------------------- /build_container.sh: -------------------------------------------------------------------------------- 1 | nvidia-docker run -ti \ 2 | -v $(pwd):/workspace/stylegan3-encoder/ \ 3 | --ipc=host \ 4 | --net=host \ 5 | --name=$1 \ 6 | stylegan3 \ 7 | /bin/bash 8 | -------------------------------------------------------------------------------- /build_img.sh: -------------------------------------------------------------------------------- 1 | docker build -t stylegan3 docker/ 2 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | FROM nvcr.io/nvidia/pytorch:21.08-py3 10 | 11 | ENV PYTHONDONTWRITEBYTECODE 1 12 | ENV PYTHONUNBUFFERED 1 13 | 14 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 lpips 15 | 16 | WORKDIR /workspace 17 | 18 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 19 | ENTRYPOINT ["/entry.sh"] 20 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: stylegan3 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow=8.3.1 11 | - scipy=1.7.1 12 | - pytorch=1.9.1 13 | - cudatoolkit=11.1 14 | - requests=2.26.0 15 | - tqdm=4.62.2 16 | - ninja=1.10.2 17 | - matplotlib=3.4.2 18 | - imageio=2.9.0 19 | - pip: 20 | - imgui==1.3.0 21 | - glfw==2.2.0 22 | - pyopengl==3.1.5 23 | - imageio-ffmpeg==0.4.3 24 | - pyspng 25 | -------------------------------------------------------------------------------- /gen_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def parse_range(s: Union[str, List]) -> List[int]: 26 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 27 | 28 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 29 | ''' 30 | if isinstance(s, list): return s 31 | ranges = [] 32 | range_re = re.compile(r'^(\d+)-(\d+)$') 33 | for p in s.split(','): 34 | m = range_re.match(p) 35 | if m: 36 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 37 | else: 38 | ranges.append(int(p)) 39 | return ranges 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: 44 | '''Parse a floating point 2-vector of syntax 'a,b'. 45 | 46 | Example: 47 | '0,1' returns (0,1) 48 | ''' 49 | if isinstance(s, tuple): return s 50 | parts = s.split(',') 51 | if len(parts) == 2: 52 | return (float(parts[0]), float(parts[1])) 53 | raise ValueError(f'cannot parse 2-vector {s}') 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def make_transform(translate: Tuple[float,float], angle: float): 58 | m = np.eye(3) 59 | s = np.sin(angle/360.0*np.pi*2) 60 | c = np.cos(angle/360.0*np.pi*2) 61 | m[0][0] = c 62 | m[0][1] = s 63 | m[0][2] = translate[0] 64 | m[1][0] = -s 65 | m[1][1] = c 66 | m[1][2] = translate[1] 67 | return m 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | @click.command() 72 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 73 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) 74 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 75 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 76 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 77 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') 78 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') 79 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 80 | def generate_images( 81 | network_pkl: str, 82 | seeds: List[int], 83 | truncation_psi: float, 84 | noise_mode: str, 85 | outdir: str, 86 | translate: Tuple[float,float], 87 | rotate: float, 88 | class_idx: Optional[int] 89 | ): 90 | """Generate images using pretrained network pickle. 91 | 92 | Examples: 93 | 94 | \b 95 | # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). 
96 | python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ 97 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 98 | 99 | \b 100 | # Generate uncurated images with truncation using the MetFaces-U dataset 101 | python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 102 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl 103 | """ 104 | 105 | print('Loading networks from "%s"...' % network_pkl) 106 | device = torch.device('cuda') 107 | with dnnlib.util.open_url(network_pkl) as f: 108 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 109 | 110 | os.makedirs(outdir, exist_ok=True) 111 | 112 | # Labels. 113 | label = torch.zeros([1, G.c_dim], device=device) 114 | if G.c_dim != 0: 115 | if class_idx is None: 116 | raise click.ClickException('Must specify class label with --class when using a conditional network') 117 | label[:, class_idx] = 1 118 | else: 119 | if class_idx is not None: 120 | print ('warn: --class=lbl ignored when running on an unconditional network') 121 | 122 | # Generate images. 123 | for seed_idx, seed in enumerate(seeds): 124 | print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) 125 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 126 | 127 | # Construct an inverse rotation/translation matrix and pass to the generator. The 128 | # generator expects this matrix as an inverse to avoid potentially failing numerical 129 | # operations in the network. 130 | if hasattr(G.synthesis, 'input'): 131 | m = make_transform(translate, rotate) 132 | m = np.linalg.inv(m) 133 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 134 | 135 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) 136 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 137 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 138 | 139 | 140 | #---------------------------------------------------------------------------- 141 | 142 | if __name__ == "__main__": 143 | generate_images() # pylint: disable=no-value-for-parameter 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /imgs/encoded.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded.png -------------------------------------------------------------------------------- /imgs/encoded_batch_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded_batch_1.png -------------------------------------------------------------------------------- /imgs/encoded_batch_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded_batch_2.png -------------------------------------------------------------------------------- /imgs/encoded_batch_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded_batch_3.png -------------------------------------------------------------------------------- /imgs/encoded_transform_x-0.2_y-0.1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded_transform_x-0.2_y-0.1.png -------------------------------------------------------------------------------- /imgs/encoded_transform_x-0.2_y0.1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded_transform_x-0.2_y0.1.png -------------------------------------------------------------------------------- /imgs/encoded_transform_x0.2_y0.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded_transform_x0.2_y0.0.png -------------------------------------------------------------------------------- /imgs/encoded_transform_x0.2_y0.1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/encoded_transform_x0.2_y0.1.png -------------------------------------------------------------------------------- /imgs/real_batch_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/real_batch_1.png -------------------------------------------------------------------------------- /imgs/real_batch_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/real_batch_2.png -------------------------------------------------------------------------------- /imgs/real_batch_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/real_batch_3.png -------------------------------------------------------------------------------- /imgs/target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/target.png -------------------------------------------------------------------------------- /imgs/train_id.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/train_id.png -------------------------------------------------------------------------------- /imgs/train_id_improve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/train_id_improve.png -------------------------------------------------------------------------------- /imgs/train_l2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/train_l2.png -------------------------------------------------------------------------------- /imgs/train_lpips.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/train_lpips.png -------------------------------------------------------------------------------- /imgs/tree_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/tree_data.png -------------------------------------------------------------------------------- /imgs/tree_pretrained.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/imgs/tree_pretrained.png -------------------------------------------------------------------------------- /samples/test01.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/samples/test01.jpeg -------------------------------------------------------------------------------- /samples/test02.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/samples/test02.jpeg -------------------------------------------------------------------------------- /samples/test03.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nampyohong/stylegan3-encoder/783b5861e45fdf51474c57d6789e614c8ae35522/samples/test03.jpeg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from setuptools import setup 4 | 5 | 6 | if __name__ == '__main__': 7 | setup( 8 | name='stylegan3-encoder', 9 | version=1.0, 10 | description='stylegan3 encoder for image inversion', 11 | author='soushirou', 12 | author_email='nampyo24@kaist.ac.kr', 13 | ) 14 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """Test stylegan3 encoder""" 2 | 3 | import json 4 | import os 5 | import random 6 | import re 7 | import tempfile 8 | 9 | import click 10 | import numpy as np 11 | import torch 12 | 13 | import dnnlib 14 | from training import testing_loop_encoder 15 | from torch_utils import training_stats 16 | from torch_utils import custom_ops 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | @click.command() 21 | 22 | # Required. 
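# Notes:
# * --testdir should point at a single training run directory; its
#   training_options.json is re-read below so evaluation uses the same encoder
#   architecture and loss weights as training.
# * --batch is the total batch size summed over all GPUs and should be divisible
#   by --gpus, since the per-GPU size is computed as batch // gpus.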
23 | @click.option('--testdir', help='Training exp directory path', metavar='DIR', required=True) 24 | @click.option('--data', help='Testing data', metavar='[DIR]', type=str, required=True) 25 | @click.option('--gpus', help='Number of GPUs to use', metavar='INT', type=click.IntRange(min=1), required=True) 26 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True) 27 | 28 | # Reproducibility 29 | @click.option('--seed', help='Random seed', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) 30 | 31 | # Dataloader workers 32 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=3, show_default=True) 33 | 34 | 35 | def main(**kwargs): 36 | """Main training script 37 | """ 38 | # Initialize config. 39 | opts = dnnlib.EasyDict(kwargs) # Command line arguments. 40 | c = dnnlib.EasyDict() # Main config dict. 41 | 42 | c.test_dir = opts.testdir 43 | c.dataset_dir = opts.data 44 | c.num_gpus = opts.gpus 45 | c.batch_size = opts.batch 46 | c.batch_gpu = opts.batch // opts.gpus 47 | 48 | c.random_seed = opts.seed 49 | c.num_workers = opts.workers 50 | 51 | with open(os.path.join(c.test_dir, 'training_options.json'),'r') as f: 52 | training_options = dnnlib.EasyDict(json.load(f)) 53 | 54 | c.model_architecture = training_options.model_architecture 55 | if 'w_avg' in training_options: 56 | c.w_avg = training_options.w_avg 57 | if 'num_encoder_layers' in training_options: 58 | c.num_encoder_layers = training_options.num_encoder_layers 59 | c.generator_pkl = training_options.generator_pkl 60 | c.l2_lambda = training_options.l2_lambda 61 | c.lpips_lambda = training_options.lpips_lambda 62 | c.id_lambda = training_options.id_lambda 63 | c.reg_lambda = training_options.reg_lambda 64 | c.gan_lambda = training_options.gan_lambda 65 | c.edit_lambda = training_options.edit_lambda 66 | 67 | # Print options. 68 | print() 69 | print('Testing options:') 70 | print(json.dumps(c, indent=2)) 71 | print() 72 | 73 | # Create output directory. 74 | print('Creating output directory...') 75 | os.makedirs(f'{c.test_dir}/test', exist_ok=True) 76 | with open(os.path.join(c.test_dir, 'test', 'testing_options.json'), 'wt') as f: 77 | json.dump(c, f, indent=2) 78 | 79 | # Launch processes. 80 | print('Launching processes...') 81 | torch.multiprocessing.set_start_method('spawn') 82 | with tempfile.TemporaryDirectory() as temp_dir: 83 | if c.num_gpus == 1: 84 | subprocess_fn(rank=0, c=c, temp_dir=temp_dir) 85 | else: 86 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus) 87 | 88 | 89 | def subprocess_fn(rank, c, temp_dir): 90 | # Init torch.distributed. 91 | # if c.num_gpus > 1: 92 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 93 | init_method = f'file://{init_file}' 94 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=c.num_gpus) 95 | 96 | # Init torch_utils 97 | torch.cuda.set_device(rank) 98 | sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None 99 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 100 | if rank != 0: 101 | custom_ops.verbosity = 'none' 102 | 103 | # Execute training loop. 
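    # Each spawned rank runs the encoder evaluation implemented in
    # training/testing_loop_encoder.py; the config EasyDict assembled in main()
    # is unpacked into its keyword arguments.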
104 | testing_loop_encoder.testing_loop(rank=rank, **c) 105 | 106 | #---------------------------------------------------------------------------- 107 | 108 | if __name__ == "__main__": 109 | main() # pylint: disable=no-value-for-parameter 110 | 111 | #---------------------------------------------------------------------------- 112 | -------------------------------------------------------------------------------- /test_architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import dnnlib 4 | import legacy 5 | from training.networks_encoder import Encoder 6 | 7 | 8 | def n_param(module): 9 | return sum(p.numel() for p in module.parameters() if p.requires_grad) 10 | 11 | 12 | if __name__ == '__main__': 13 | device = torch.device('cuda:0') 14 | network_pkl = 'pretrained/stylegan3-t-ffhq-1024x1024.pkl' 15 | with dnnlib.util.open_url(network_pkl) as f: 16 | G = legacy.load_network_pkl(f)['G_ema'].to(device) 17 | 18 | # Base 19 | # encoder = Encoder(pretrained=None).to(device) 20 | # 21 | # x = torch.randn((1,3,256,256)).to(device) 22 | # print('\nBase configuration') 23 | # print(f'\nencoder # params: {n_param(encoder)}') 24 | # print(f'encoder input_layer # params: {n_param(encoder.encoder.input_layer)}') 25 | # print(f'encoder body # params: {n_param(encoder.encoder.body)}') 26 | # print(f'encoder styles # params: {n_param(encoder.encoder.styles)}') 27 | # print(f'encoder latlayer # params: {n_param(encoder.encoder.latlayer1)+n_param(encoder.encoder.latlayer2)}') 28 | # print('\ninput: [b,3,256,256]') 29 | # print(x.shape) 30 | # print('latent: [b,16,512]') 31 | # latent = encoder(x) 32 | # print(latent.shape) 33 | # synth = G.synthesis(latent) 34 | # print('synth: [b,3,1024,1024]') 35 | # print(synth.shape) 36 | 37 | # Config-a 38 | # x = torch.randn((2,3,256,256)).to(device) 39 | # styleblock = dict(arch='transformer', num_encoder_layers=1) # 1 40 | # encoder = Encoder(pretrained=None, **styleblock).to(device) 41 | # print('\nConfig-a: over parametrization') 42 | # print(f'\nencoder # params: {n_param(encoder)}') 43 | # print(f'encoder input_layer # params: {n_param(encoder.encoder.input_layer)}') 44 | # print(f'encoder body # params: {n_param(encoder.encoder.body)}') 45 | # print(f'encoder styles # params: {n_param(encoder.encoder.styles)}') 46 | # print(f'encoder latlayer # params: {n_param(encoder.encoder.latlayer1)+n_param(encoder.encoder.latlayer2)}') 47 | # print('\ninput: [b,3,256,256]') 48 | # print(x.shape) 49 | # print('latent: [b,16,512]') 50 | # latent = encoder(x) 51 | # print(latent.shape) 52 | # synth = G.synthesis(latent) 53 | # print('synth: [b,3,1024,1024]') 54 | # print(synth.shape) 55 | 56 | # Config-b train from w_avg 57 | w_avg = G.mapping.w_avg 58 | encoder = Encoder(pretrained=None,w_avg=w_avg).to(device) 59 | x = torch.randn((1,3,256,256)).to(device) 60 | print('\nBase configuration') 61 | print(f'\nencoder # params: {n_param(encoder)}') 62 | print(f'encoder input_layer # params: {n_param(encoder.encoder.input_layer)}') 63 | print(f'encoder body # params: {n_param(encoder.encoder.body)}') 64 | print(f'encoder styles # params: {n_param(encoder.encoder.styles)}') 65 | print(f'encoder latlayer # params: {n_param(encoder.encoder.latlayer1)+n_param(encoder.encoder.latlayer2)}') 66 | print('\ninput: [b,3,256,256]') 67 | print(x.shape) 68 | print('latent: [b,16,512]') 69 | latent = encoder(x) 70 | print(latent.shape) 71 | synth = G.synthesis(latent) 72 | print('synth: [b,3,1024,1024]') 73 | 
print(synth.shape) 74 | 75 | print("\nDone.") 76 | -------------------------------------------------------------------------------- /test_base.sh: -------------------------------------------------------------------------------- 1 | python test.py \ 2 | --testdir exp/base/00000-base-ffhq-gpus8-batch32/ \ 3 | --data data/celeba-hq/ \ 4 | --gpus 8 \ 5 | --batch 32 \ 6 | -------------------------------------------------------------------------------- /test_config_b.sh: -------------------------------------------------------------------------------- 1 | python test.py \ 2 | --testdir exp/config-b/00000-base-ffhq-gpus8-batch32/ \ 3 | --data data/celeba-hq/ \ 4 | --gpus 8 \ 5 | --batch 32 \ 6 | -------------------------------------------------------------------------------- /test_dataset.py: -------------------------------------------------------------------------------- 1 | from training.dataset_encoder import ImagesDataset 2 | 3 | 4 | if __name__ == '__main__': 5 | # TODO : get this path from config 6 | ffhqs_dataset_dir = 'data/ffhqs' 7 | ffhqs_dataset = ImagesDataset(ffhqs_dataset_dir, mode='train') 8 | print(f'dataset length: {len(ffhqs_dataset)}') 9 | print('transforms') 10 | print(ffhqs_dataset.transforms) 11 | print(f'input image shape: {ffhqs_dataset.__getitem__(0)[0].shape}') 12 | print("Done.") 13 | -------------------------------------------------------------------------------- /test_eval.sh: -------------------------------------------------------------------------------- 1 | python test.py \ 2 | --testdir exp/test/00000-base-ffhqs-gpus8-batch32/ \ 3 | --data data/ffhqs/ \ 4 | --gpus 8 \ 5 | --batch 32 \ 6 | -------------------------------------------------------------------------------- /test_gen_images.sh: -------------------------------------------------------------------------------- 1 | python gen_images.py --network pretrained/stylegan3-t-ffhq-1024x1024.pkl --seeds=100-109 --trunc=0.5 --outdir=tmp 2 | -------------------------------------------------------------------------------- /test_inference.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | import torch 5 | 6 | import dnnlib 7 | import legacy 8 | from training.dataset_encoder import ImagesDataset 9 | from training.networks_encoder import Encoder 10 | from training.training_loop_encoder import save_image 11 | from gen_images import make_transform 12 | 13 | 14 | if __name__ == '__main__': 15 | device = torch.device('cuda:0') 16 | with dnnlib.util.open_url('pretrained/stylegan3-t-ffhq-1024x1024.pkl') as f: 17 | G = legacy.load_network_pkl(f)['G_ema'].to(device) 18 | pretrained = 'pretrained/encoder-base-100000.pkl' 19 | E = Encoder(pretrained=pretrained).to(device) 20 | 21 | test_set = ImagesDataset('data/celeba-hq-samples', mode='inference') 22 | test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=len(test_set)) 23 | print(f'\ninference dataset length: {len(test_set)}') 24 | 25 | print(f'\ninput: [batch,3,256,256]') 26 | X,_ = next(iter(test_loader)) 27 | X = X.to(device) 28 | print(X.shape) 29 | 30 | print('\nlatent: [batch,16,512]') 31 | w = E(X) 32 | print(w.shape) 33 | 34 | print('\nsynth: [batch,3,1024,1024]') 35 | synth = G.synthesis(w) 36 | print(synth.shape) 37 | 38 | save_image(X, 'tmp/target.png', 1, len(test_set), 256, 256) 39 | save_image(synth, 'tmp/encoded.png', 1, len(test_set), 256, 256) 40 | 41 | # transform 42 | x_lst, y_lst = [-0.2, 0.0, 0.2], [-0.1, 0.0, 0.1] 43 | for x in x_lst: 44 | for y in y_lst: 45 | 
m = make_transform([x,y], 0) 46 | m = np.linalg.inv(m) 47 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 48 | synth = G.synthesis(w) 49 | save_image(synth, f'tmp/encoded_transform_x{x}_y{y}.png', 1, len(test_set), 256, 256) 50 | -------------------------------------------------------------------------------- /test_loss.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | import PIL.Image 4 | import numpy as np 5 | import torch 6 | 7 | import dnnlib 8 | import legacy 9 | from training.loss_encoder import l2_loss, IDLoss 10 | from lpips import LPIPS 11 | 12 | 13 | if __name__ == '__main__': 14 | device = torch.device('cuda:0') 15 | 16 | # generate 2 images 17 | network_pkl = 'pretrained/stylegan3-t-ffhq-1024x1024.pkl' 18 | truncation_psi = 0.5 19 | noise_mode = 'const' 20 | with dnnlib.util.open_url(network_pkl) as f: 21 | G = legacy.load_network_pkl(f)['G_ema'].to(device) 22 | label = torch.zeros([1, G.c_dim], device=device) 23 | seeds = [0,1] 24 | imgs = [] 25 | for seed_idx, seed in enumerate(seeds): 26 | z = torch.from_numpy(np.random.RandomState(seed).randn(1,G.z_dim)).to(device) 27 | img = G(z,label,truncation_psi=truncation_psi,noise_mode=noise_mode) 28 | imgs.append(img) 29 | generated_images, real_images = imgs[0], imgs[1] 30 | 31 | print(f'\ngenerated image shape : {imgs[0].shape}') # imgs[0] 32 | print(f'train image shape : {imgs[1].shape}') # imgs[1] 33 | print() 34 | 35 | # get loss 36 | id_loss = IDLoss().to(device) 37 | lpips_loss = LPIPS(net='alex').to(device).eval() 38 | loss_dict = dict() 39 | loss = 0.0 40 | # lambda1 : pixelwise l2 loss 41 | # lambda2 : lpips loss 42 | # lambda3 : id_loss 43 | # TODO : get lambda values from config 44 | lambda1, lambda2, lambda3 = 1,0.8,0.1 45 | loss_l2 = l2_loss(generated_images, real_images) 46 | loss_dict['l2'] = loss_l2.item() 47 | loss += loss_l2 * lambda1 48 | 49 | loss_lpips = lpips_loss(generated_images, real_images).squeeze() 50 | loss_dict['lpips'] = loss_lpips.item() 51 | loss += loss_lpips * lambda2 52 | 53 | loss_id, sim_improvement = id_loss(generated_images, real_images, real_images) 54 | loss_dict['id'] = loss_id.item() 55 | loss_dict['id_improve'] = sim_improvement 56 | loss += loss_id * lambda3 57 | 58 | print(f'\nloss: {loss}') 59 | print('\nloss dictionary') 60 | pprint(loss_dict) 61 | print('Done.') 62 | -------------------------------------------------------------------------------- /test_resume.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --outdir exp/test \ 3 | --encoder base \ 4 | --data data/ffhqs \ 5 | --gpus 8 \ 6 | --batch 32 \ 7 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 8 | --valdata data/ffhqs \ 9 | --training_steps 41 \ 10 | --val_steps 99 \ 11 | --print_steps 5 \ 12 | --tb_steps 5 \ 13 | --img_snshot_steps 10 \ 14 | --net_snshot_steps 10 \ 15 | --lr 0.001 \ 16 | --l2_lambda 1.0 \ 17 | --lpips_lambda 0.8 \ 18 | --id_lambda 0.1 \ 19 | --reg_lambda 0.0 \ 20 | --gan_lambda 0.0 \ 21 | --edit_lambda 0.0 \ 22 | --seed 0 \ 23 | --workers 3 \ 24 | --resume_pkl exp/test/00000-base-ffhqs-gpus8-batch32/network_snapshots/network-snapshot-000020.pkl 25 | -------------------------------------------------------------------------------- /test_resume_config_a.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --outdir exp/test \ 3 | --encoder transformer \ 4 | --data data/ffhqs \ 5 | --gpus 8 \ 6 | --batch 
32 \ 7 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 8 | --enc_layers 1 \ 9 | --valdata data/ffhqs \ 10 | --training_steps 41 \ 11 | --val_steps 99 \ 12 | --print_steps 5 \ 13 | --tb_steps 5 \ 14 | --img_snshot_steps 10 \ 15 | --net_snshot_steps 10 \ 16 | --lr 0.001 \ 17 | --l2_lambda 1.0 \ 18 | --lpips_lambda 0.8 \ 19 | --id_lambda 0.1 \ 20 | --reg_lambda 0.0 \ 21 | --gan_lambda 0.0 \ 22 | --edit_lambda 0.0 \ 23 | --seed 0 \ 24 | --workers 3 \ 25 | --resume_pkl exp/test/00000-transformer-ffhqs-gpus8-batch32/network_snapshots/network-snapshot-000020.pkl 26 | -------------------------------------------------------------------------------- /test_resume_config_b.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --outdir exp/test \ 3 | --encoder base \ 4 | --data data/ffhqs \ 5 | --gpus 8 \ 6 | --batch 32 \ 7 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 8 | --w_avg \ 9 | --valdata data/ffhqs \ 10 | --training_steps 41 \ 11 | --val_steps 99 \ 12 | --print_steps 5 \ 13 | --tb_steps 5 \ 14 | --img_snshot_steps 10 \ 15 | --net_snshot_steps 10 \ 16 | --lr 0.001 \ 17 | --l2_lambda 1.0 \ 18 | --lpips_lambda 0.8 \ 19 | --id_lambda 0.1 \ 20 | --reg_lambda 0.0 \ 21 | --gan_lambda 0.0 \ 22 | --edit_lambda 0.0 \ 23 | --seed 0 \ 24 | --workers 3 \ 25 | --resume_pkl exp/test/00000-base-ffhqs-gpus8-batch32/network_snapshots/network-snapshot-000020.pkl 26 | -------------------------------------------------------------------------------- /test_train.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --outdir exp/test \ 3 | --encoder base \ 4 | --data data/ffhqs \ 5 | --gpus 8 \ 6 | --batch 32 \ 7 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 8 | --valdata data/ffhqs \ 9 | --training_steps 21 \ 10 | --val_steps 10 \ 11 | --print_steps 5 \ 12 | --tb_steps 5 \ 13 | --img_snshot_steps 10 \ 14 | --net_snshot_steps 10 \ 15 | --lr 0.001 \ 16 | --l2_lambda 1.0 \ 17 | --lpips_lambda 0.8 \ 18 | --id_lambda 0.1 \ 19 | --reg_lambda 0.0 \ 20 | --gan_lambda 0.0 \ 21 | --edit_lambda 0.0 \ 22 | --seed 0 \ 23 | --workers 3 24 | -------------------------------------------------------------------------------- /test_train_config_a.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --outdir exp/test \ 3 | --encoder transformer \ 4 | --data data/ffhqs \ 5 | --gpus 8 \ 6 | --batch 32 \ 7 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 8 | --enc_layers 1 \ 9 | --valdata data/ffhqs \ 10 | --training_steps 21 \ 11 | --val_steps 10 \ 12 | --print_steps 5 \ 13 | --tb_steps 5 \ 14 | --img_snshot_steps 10 \ 15 | --net_snshot_steps 10 \ 16 | --lr 0.001 \ 17 | --l2_lambda 1.0 \ 18 | --lpips_lambda 0.8 \ 19 | --id_lambda 0.1 \ 20 | --reg_lambda 0.0 \ 21 | --gan_lambda 0.0 \ 22 | --edit_lambda 0.0 \ 23 | --seed 0 \ 24 | --workers 3 25 | -------------------------------------------------------------------------------- /test_train_config_b.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --outdir exp/test \ 3 | --encoder base \ 4 | --data data/ffhqs \ 5 | --gpus 8 \ 6 | --batch 32 \ 7 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 8 | --w_avg \ 9 | --valdata data/ffhqs \ 10 | --training_steps 21 \ 11 | --val_steps 10 \ 12 | --print_steps 5 \ 13 | --tb_steps 5 \ 14 | --img_snshot_steps 10 \ 15 | --net_snshot_steps 10 \ 16 | --lr 0.001 \ 17 | 
--l2_lambda 1.0 \ 18 | --lpips_lambda 0.8 \ 19 | --id_lambda 0.1 \ 20 | --reg_lambda 0.0 \ 21 | --gan_lambda 0.0 \ 22 | --edit_lambda 0.0 \ 23 | --seed 0 \ 24 | --workers 3 25 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 
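    # A plugin compiled and imported earlier in this process is returned
    # directly, skipping the C++/CUDA build below.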
68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 
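            # Build from the copies inside the cached build directory so that
            # ninja sees stable paths and timestamps, keeping incremental
            # rebuilds fast across runs.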
135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | posinf = torch.finfo(input.dtype).max 53 | if neginf is None: 54 | neginf = torch.finfo(input.dtype).min 55 | assert nan == 0 56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 57 | 58 | #---------------------------------------------------------------------------- 59 | # Symbolic assert. 
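# torch._assert participates in torch.jit.trace() graphs, so the shape checks in
# assert_shape() below survive tracing; older PyTorch exposes the same op as
# torch.Assert.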
60 | 61 | try: 62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 63 | except AttributeError: 64 | symbolic_assert = torch.Assert # 1.7.0 65 | 66 | #---------------------------------------------------------------------------- 67 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 68 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 69 | 70 | @contextlib.contextmanager 71 | def suppress_tracer_warnings(): 72 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 73 | warnings.filters.insert(0, flt) 74 | yield 75 | warnings.filters.remove(flt) 76 | 77 | #---------------------------------------------------------------------------- 78 | # Assert that the shape of a tensor matches the given list of integers. 79 | # None indicates that the size of a dimension is allowed to vary. 80 | # Performs symbolic assertion when used in torch.jit.trace(). 81 | 82 | def assert_shape(tensor, ref_shape): 83 | if tensor.ndim != len(ref_shape): 84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 86 | if ref_size is None: 87 | pass 88 | elif isinstance(ref_size, torch.Tensor): 89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 91 | elif isinstance(size, torch.Tensor): 92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 94 | elif size != ref_size: 95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 96 | 97 | #---------------------------------------------------------------------------- 98 | # Function decorator that calls torch.autograd.profiler.record_function(). 99 | 100 | def profiled_function(fn): 101 | def decorator(*args, **kwargs): 102 | with torch.autograd.profiler.record_function(fn.__name__): 103 | return fn(*args, **kwargs) 104 | decorator.__name__ = fn.__name__ 105 | return decorator 106 | 107 | #---------------------------------------------------------------------------- 108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 109 | # indefinitely, shuffling items as it goes. 
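# Usage sketch (variable names are illustrative): give each DDP rank an endless,
# locally shuffled stream of sample indices, e.g.
#   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=num_gpus, seed=0)
#   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_gpu)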
110 | 111 | class InfiniteSampler(torch.utils.data.Sampler): 112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 113 | assert len(dataset) > 0 114 | assert num_replicas > 0 115 | assert 0 <= rank < num_replicas 116 | assert 0 <= window_size <= 1 117 | super().__init__(dataset) 118 | self.dataset = dataset 119 | self.rank = rank 120 | self.num_replicas = num_replicas 121 | self.shuffle = shuffle 122 | self.seed = seed 123 | self.window_size = window_size 124 | 125 | def __iter__(self): 126 | order = np.arange(len(self.dataset)) 127 | rnd = None 128 | window = 0 129 | if self.shuffle: 130 | rnd = np.random.RandomState(self.seed) 131 | rnd.shuffle(order) 132 | window = int(np.rint(order.size * self.window_size)) 133 | 134 | idx = 0 135 | while True: 136 | i = idx % order.size 137 | if idx % self.num_replicas == self.rank: 138 | yield order[i] 139 | if window >= 2: 140 | j = (i - rnd.randint(window)) % order.size 141 | order[i], order[j] = order[j], order[i] 142 | idx += 1 143 | 144 | #---------------------------------------------------------------------------- 145 | # Utilities for operating with torch.nn.Module parameters and buffers. 146 | 147 | def params_and_buffers(module): 148 | assert isinstance(module, torch.nn.Module) 149 | return list(module.parameters()) + list(module.buffers()) 150 | 151 | def named_params_and_buffers(module): 152 | assert isinstance(module, torch.nn.Module) 153 | return list(module.named_parameters()) + list(module.named_buffers()) 154 | 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 
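# Usage sketch (assumes the encoder class from training/networks_encoder.py and
# a CUDA device, as in test_architecture.py):
#   E = Encoder(pretrained=None).to(device)
#   print_module_summary(E, [torch.randn(1, 3, 256, 256).to(device)])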
195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 
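For cross-referencing with the Python wrapper: the template parameter `A` corresponds to the `cuda_idx` values registered in `bias_act.py` (shown later in this listing), and `G` selects what is evaluated (0 = forward value, 1 = first-order gradient, 2 = second-order gradient). The mapping, written out as a small reference snippet:
```
# Mirrors the cuda_idx fields defined in bias_act.py's activation_funcs table.
ACTIVATION_CODES = {
    1: 'linear', 2: 'relu', 3: 'lrelu', 4: 'tanh', 5: 'sigmoid',
    6: 'elu', 7: 'selu', 8: 'softplus', 9: 'swish',
}
```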
151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | import dnnlib 15 | 16 | from .. import custom_ops 17 | from .. 
import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | activation_funcs = { 22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 31 | } 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | _plugin = None 36 | _null_tensor = torch.empty([0]) 37 | 38 | def _init(): 39 | global _plugin 40 | if _plugin is None: 41 | _plugin = custom_ops.get_plugin( 42 | module_name='bias_act_plugin', 43 | sources=['bias_act.cpp', 'bias_act.cu'], 44 | headers=['bias_act.h'], 45 | source_dir=os.path.dirname(__file__), 46 | extra_cuda_cflags=['--use_fast_math'], 47 | ) 48 | return True 49 | 50 | #---------------------------------------------------------------------------- 51 | 52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 53 | r"""Fused bias and activation function. 54 | 55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 56 | and scales the result by `gain`. Each of the steps is optional. In most cases, 57 | the fused op is considerably more efficient than performing the same calculation 58 | using standard PyTorch ops. It supports first and second order gradients, 59 | but not third order gradients. 60 | 61 | Args: 62 | x: Input activation tensor. Can be of any shape. 63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 64 | as `x`. The shape must be known, and it must match the dimension of `x` 65 | corresponding to `dim`. 66 | dim: The dimension in `x` corresponding to the elements of `b`. 67 | The value of `dim` is ignored if `b` is not specified. 68 | act: Name of the activation function to evaluate, or `"linear"` to disable. 69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 70 | See `activation_funcs` for a full list. `None` is not allowed. 71 | alpha: Shape parameter for the activation function, or `None` to use the default. 72 | gain: Scaling factor for the output tensor, or `None` to use default. 73 | See `activation_funcs` for the default scaling of each activation function. 74 | If unsure, consider specifying 1. 75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 76 | the clamping (default). 
77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 78 | 79 | Returns: 80 | Tensor of the same shape and datatype as `x`. 81 | """ 82 | assert isinstance(x, torch.Tensor) 83 | assert impl in ['ref', 'cuda'] 84 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @misc.profiled_function 91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 92 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 93 | """ 94 | assert isinstance(x, torch.Tensor) 95 | assert clamp is None or clamp >= 0 96 | spec = activation_funcs[act] 97 | alpha = float(alpha if alpha is not None else spec.def_alpha) 98 | gain = float(gain if gain is not None else spec.def_gain) 99 | clamp = float(clamp if clamp is not None else -1) 100 | 101 | # Add bias. 102 | if b is not None: 103 | assert isinstance(b, torch.Tensor) and b.ndim == 1 104 | assert 0 <= dim < x.ndim 105 | assert b.shape[0] == x.shape[dim] 106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 107 | 108 | # Evaluate activation function. 109 | alpha = float(alpha) 110 | x = spec.func(x, alpha=alpha) 111 | 112 | # Scale by gain. 113 | gain = float(gain) 114 | if gain != 1: 115 | x = x * gain 116 | 117 | # Clamp. 118 | if clamp >= 0: 119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 120 | return x 121 | 122 | #---------------------------------------------------------------------------- 123 | 124 | _bias_act_cuda_cache = dict() 125 | 126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 127 | """Fast CUDA implementation of `bias_act()` using custom ops. 128 | """ 129 | # Parse arguments. 130 | assert clamp is None or clamp >= 0 131 | spec = activation_funcs[act] 132 | alpha = float(alpha if alpha is not None else spec.def_alpha) 133 | gain = float(gain if gain is not None else spec.def_gain) 134 | clamp = float(clamp if clamp is not None else -1) 135 | 136 | # Lookup from cache. 137 | key = (dim, act, alpha, gain, clamp) 138 | if key in _bias_act_cuda_cache: 139 | return _bias_act_cuda_cache[key] 140 | 141 | # Forward op. 
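A usage sketch for `bias_act()`; the tensor sizes are arbitrary, and `impl='ref'` is chosen so the snippet also runs without the compiled CUDA plugin:
```
import torch
from torch_utils.ops import bias_act

x = torch.randn([4, 512, 16, 16])
b = torch.zeros([512])
# Fused bias + leaky ReLU + clamp; dim=1 applies the bias along the channel axis.
y = bias_act.bias_act(x, b=b, dim=1, act='lrelu', clamp=256, impl='ref')
```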
142 | class BiasActCuda(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, x, b): # pylint: disable=arguments-differ 145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 146 | x = x.contiguous(memory_format=ctx.memory_format) 147 | b = b.contiguous() if b is not None else _null_tensor 148 | y = x 149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 151 | ctx.save_for_backward( 152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 154 | y if 'y' in spec.ref else _null_tensor) 155 | return y 156 | 157 | @staticmethod 158 | def backward(ctx, dy): # pylint: disable=arguments-differ 159 | dy = dy.contiguous(memory_format=ctx.memory_format) 160 | x, b, y = ctx.saved_tensors 161 | dx = None 162 | db = None 163 | 164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 165 | dx = dy 166 | if act != 'linear' or gain != 1 or clamp >= 0: 167 | dx = BiasActCudaGrad.apply(dy, x, b, y) 168 | 169 | if ctx.needs_input_grad[1]: 170 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 171 | 172 | return dx, db 173 | 174 | # Backward op. 175 | class BiasActCudaGrad(torch.autograd.Function): 176 | @staticmethod 177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 180 | ctx.save_for_backward( 181 | dy if spec.has_2nd_grad else _null_tensor, 182 | x, b, y) 183 | return dx 184 | 185 | @staticmethod 186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 188 | dy, x, b, y = ctx.saved_tensors 189 | d_dy = None 190 | d_x = None 191 | d_b = None 192 | d_y = None 193 | 194 | if ctx.needs_input_grad[0]: 195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 196 | 197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 199 | 200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 202 | 203 | return d_dy, d_x, d_b, d_y 204 | 205 | # Add to cache. 206 | _bias_act_cuda_cache[key] = BiasActCuda 207 | return BiasActCuda 208 | 209 | #---------------------------------------------------------------------------- 210 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | 15 | # pylint: disable=redefined-builtin 16 | # pylint: disable=arguments-differ 17 | # pylint: disable=protected-access 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | enabled = False # Enable the custom op by setting this to true. 22 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 23 | 24 | @contextlib.contextmanager 25 | def no_weight_gradients(disable=True): 26 | global weight_gradients_disabled 27 | old = weight_gradients_disabled 28 | if disable: 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | return True 54 | 55 | def _tuple_of_ints(xs, ndim): 56 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 57 | assert len(xs) == ndim 58 | assert all(isinstance(x, int) for x in xs) 59 | return xs 60 | 61 | #---------------------------------------------------------------------------- 62 | 63 | _conv2d_gradfix_cache = dict() 64 | _null_tensor = torch.empty([0]) 65 | 66 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 67 | # Parse arguments. 68 | ndim = 2 69 | weight_shape = tuple(weight_shape) 70 | stride = _tuple_of_ints(stride, ndim) 71 | padding = _tuple_of_ints(padding, ndim) 72 | output_padding = _tuple_of_ints(output_padding, ndim) 73 | dilation = _tuple_of_ints(dilation, ndim) 74 | 75 | # Lookup from cache. 76 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 77 | if key in _conv2d_gradfix_cache: 78 | return _conv2d_gradfix_cache[key] 79 | 80 | # Validate arguments. 
81 | assert groups >= 1 82 | assert len(weight_shape) == ndim + 2 83 | assert all(stride[i] >= 1 for i in range(ndim)) 84 | assert all(padding[i] >= 0 for i in range(ndim)) 85 | assert all(dilation[i] >= 0 for i in range(ndim)) 86 | if not transpose: 87 | assert all(output_padding[i] == 0 for i in range(ndim)) 88 | else: # transpose 89 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 90 | 91 | # Helpers. 92 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 93 | def calc_output_padding(input_shape, output_shape): 94 | if transpose: 95 | return [0, 0] 96 | return [ 97 | input_shape[i + 2] 98 | - (output_shape[i + 2] - 1) * stride[i] 99 | - (1 - 2 * padding[i]) 100 | - dilation[i] * (weight_shape[i + 2] - 1) 101 | for i in range(ndim) 102 | ] 103 | 104 | # Forward & backward. 105 | class Conv2d(torch.autograd.Function): 106 | @staticmethod 107 | def forward(ctx, input, weight, bias): 108 | assert weight.shape == weight_shape 109 | ctx.save_for_backward( 110 | input if weight.requires_grad else _null_tensor, 111 | weight if input.requires_grad else _null_tensor, 112 | ) 113 | ctx.input_shape = input.shape 114 | 115 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 116 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 117 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 118 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 119 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 120 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 121 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 122 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 123 | 124 | # General case => cuDNN. 125 | if transpose: 126 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 127 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | input, weight = ctx.saved_tensors 132 | input_shape = ctx.input_shape 133 | grad_input = None 134 | grad_weight = None 135 | grad_bias = None 136 | 137 | if ctx.needs_input_grad[0]: 138 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 139 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 140 | grad_input = op.apply(grad_output, weight, None) 141 | assert grad_input.shape == input_shape 142 | 143 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 144 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 145 | assert grad_weight.shape == weight_shape 146 | 147 | if ctx.needs_input_grad[2]: 148 | grad_bias = grad_output.sum([0, 2, 3]) 149 | 150 | return grad_input, grad_weight, grad_bias 151 | 152 | # Gradient with respect to the weights. 
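A minimal sketch of this module's public entry points; `enabled` only takes effect for CUDA inputs, so on CPU the calls below simply fall back to `torch.nn.functional`:
```
import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True                      # opt in to the custom op (CUDA inputs only)

x = torch.randn([2, 3, 32, 32])
w = torch.randn([8, 3, 3, 3], requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)
y.sum().backward()

with conv2d_gradfix.no_weight_gradients():         # skip gradients w.r.t. the convolution weights
    y2 = conv2d_gradfix.conv2d(x, w, padding=1)
```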
153 | class Conv2dGradWeight(torch.autograd.Function): 154 | @staticmethod 155 | def forward(ctx, grad_output, input): 156 | ctx.save_for_backward( 157 | grad_output if input.requires_grad else _null_tensor, 158 | input if grad_output.requires_grad else _null_tensor, 159 | ) 160 | ctx.grad_output_shape = grad_output.shape 161 | ctx.input_shape = input.shape 162 | 163 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 164 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 165 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 166 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 167 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 168 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 169 | 170 | # General case => cuDNN. 171 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 172 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 173 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 174 | 175 | @staticmethod 176 | def backward(ctx, grad2_grad_weight): 177 | grad_output, input = ctx.saved_tensors 178 | grad_output_shape = ctx.grad_output_shape 179 | input_shape = ctx.input_shape 180 | grad2_grad_output = None 181 | grad2_input = None 182 | 183 | if ctx.needs_input_grad[0]: 184 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 185 | assert grad2_grad_output.shape == grad_output_shape 186 | 187 | if ctx.needs_input_grad[1]: 188 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 189 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 190 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 191 | assert grad2_input.shape == input_shape 192 | 193 | return grad2_grad_output, grad2_input 194 | 195 | _conv2d_gradfix_cache[key] = Conv2d 196 | return Conv2d 197 | 198 | #---------------------------------------------------------------------------- 199 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . 
import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
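A usage sketch for `conv2d_resample()`; the filter taps `[1, 3, 3, 1]` and tensor sizes are illustrative, with the low-pass filter prepared via `upfirdn2d.setup_filter()` as the docstring requires:
```
import torch
from torch_utils.ops import conv2d_resample, upfirdn2d

x = torch.randn([1, 16, 64, 64])
w = torch.randn([32, 16, 3, 3])
f = upfirdn2d.setup_filter([1, 3, 3, 1])                             # normalized low-pass filter
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)  # expected shape [1, 32, 128, 128]
```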
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 
17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, true>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, true>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, true, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, true, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
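The three `filtered_lrelu_*.cu` files above instantiate the same kernels from `filtered_lrelu.cu` with different sign-buffer behaviour: `_ns` builds the no-signs variants (no gradients required), `_wr` builds the variants that write the sign tensor, and `_rd` builds the ones that read it back, presumably so a forward pass can record signs that the backward pass reuses. As for the `fma()` helper defined below in `fma.py`, a small sketch with made-up shapes shows how gradients are un-broadcast back to each input's shape:
```
import torch
from torch_utils.ops import fma

a = torch.randn([4, 1, 8], requires_grad=True)
b = torch.randn([4, 8, 8], requires_grad=True)
c = torch.randn([8], requires_grad=True)

out = fma.fma(a, b, c)              # same value as a * b + c
out.sum().backward()
print(a.grad.shape, c.grad.shape)   # torch.Size([4, 1, 8]) torch.Size([8])
```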
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 
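With the flag above set to `True`, the module-level `grid_sample()` defined immediately below routes sampling through the custom autograd Functions; a minimal sketch with illustrative shapes:
```
import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True                  # opt in to the custom op
input = torch.randn([1, 3, 8, 8], requires_grad=True)
grid = torch.rand([1, 4, 4, 2]) * 2 - 1             # bilinear sampling coordinates in [-1, 1]
out = grid_sample_gradfix.grid_sample(input, grid)  # shape [1, 3, 4, 4]
```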
23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr<float>(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size.
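    // Reference note (not in the original source): a negative tileOutW marks the generic
    // fallback kernel, which is launched with a 4x32 thread block; the specialized kernels
    // advertise their tile size in `spec` and get one 256-thread block per output tile.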
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. 
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
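    Example (illustrative, mirroring the calls made in training/loss.py):

        training_stats.report('Loss/scores/fake', gen_logits)        # gen_logits: any tensor of scalars
        training_stats.report('Loss/signs/fake', gen_logits.sign())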
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
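        A typical collection round looks like this (illustrative sketch; the regex and
        statistic name are examples, not fixed by this module):

            collector = training_stats.Collector(regex='Loss/.*')
            ...                                   # training code calls report() as it runs
            collector.update()                    # fold newly reported scalars into the averages
            print(collector.mean('Loss/G/loss'))  # query via mean()/std()/as_dict()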
159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 
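    # Note (not in the original source): all_reduce() is a collective, so every rank must reach
    # this point with the same statistic names in the same order; this is why report() asks each
    # process to report every name at least once, even if only with an empty list.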
255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Train stylegan3 encoder""" 2 | 3 | import json 4 | import os 5 | import random 6 | import re 7 | import tempfile 8 | 9 | import click 10 | import numpy as np 11 | import torch 12 | 13 | import dnnlib 14 | from training import training_loop_encoder 15 | from torch_utils import training_stats 16 | from torch_utils import custom_ops 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | @click.command() 21 | 22 | # Required. 23 | @click.option('--outdir', help='Where to save the results', metavar='DIR', required=True) 24 | @click.option('--encoder', help='Encoder architecture type', type=click.Choice(['base','transformer']), required=True) 25 | @click.option('--data', help='Training data', metavar='[DIR]', type=str, required=True) 26 | @click.option('--gpus', help='Number of GPUs to use', metavar='INT', type=click.IntRange(min=1), required=True) 27 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True) 28 | @click.option('--generator', help='Generator pickle to encode', required=True) 29 | 30 | # Encoder settings 31 | @click.option('--w_avg', help='Train delta w from w_avg', is_flag=True) 32 | @click.option('--enc_layers', help='Transformer encoder layers', metavar='INT', type=click.IntRange(min=1), default=1) 33 | 34 | # Validate 35 | @click.option('--valdata', help='Validation data', metavar='[DIR]', type=str) 36 | 37 | # Training, logging batch steps 38 | @click.option('--training_steps', help='Total training steps', type=click.IntRange(min=1), default=100001) 39 | @click.option('--val_steps', help='Validation batch steps', type=click.IntRange(min=1), default=10000) 40 | @click.option('--print_steps', help='How often to print logs', type=click.IntRange(min=1), default=50) 41 | @click.option('--tb_steps', help='How often to log to tensorboard?', type=click.IntRange(min=1), default=50) 42 | @click.option('--img_snshot_steps', help='How often to save image snapshots?', type=click.IntRange(min=1), default=100) 43 | @click.option('--net_snshot_steps', help='How often to save network snapshots?', type=click.IntRange(min=1), default=5000) 44 | 45 | # Define Loss 46 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=0.001, show_default=True) 47 | @click.option('--l2_lambda', help='L2 loss multiplier factor', metavar='FLOAT', type=click.FloatRange(min=0), default=1.0, show_default=True) 48 | @click.option('--lpips_lambda', help='LPIPS loss multiplier factor', metavar='FLOAT', type=click.FloatRange(min=0), default=0.8, show_default=True) 49 | @click.option('--id_lambda', help='ID loss multiplier factor', metavar='FLOAT', type=click.FloatRange(min=0), default=0.1, show_default=True) 50 | @click.option('--reg_lambda', help='e4e reg loss multiplier factor', 
metavar='FLOAT', type=click.FloatRange(min=0), default=0.0, show_default=True) 51 | @click.option('--gan_lambda', help='e4e gan loss multiplier factor', metavar='FLOAT', type=click.FloatRange(min=0), default=0.0, show_default=True) 52 | @click.option('--edit_lambda', help='e4e editability lambda', metavar='FLOAT', type=click.FloatRange(min=0), default=0.0, show_default=True) 53 | 54 | # Reproducibility 55 | @click.option('--seed', help='Random seed', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) 56 | 57 | # Dataloader workers 58 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=3, show_default=True) 59 | 60 | # Resume 61 | @click.option('--resume_pkl', help='Network pickle to resume training', default=None, show_default=True) 62 | 63 | 64 | def main(**kwargs): 65 | """Main training script 66 | """ 67 | # Initialize config. 68 | opts = dnnlib.EasyDict(kwargs) # Command line arguments. 69 | c = dnnlib.EasyDict() # Main config dict. 70 | 71 | c.model_architecture = opts.encoder 72 | c.dataset_dir = opts.data 73 | c.num_gpus = opts.gpus 74 | c.batch_size = opts.batch 75 | c.batch_gpu = opts.batch // opts.gpus 76 | c.generator_pkl = opts.generator 77 | 78 | c.w_avg = opts.w_avg 79 | c.num_encoder_layers = opts.enc_layers 80 | 81 | c.val_dataset_dir = opts.valdata 82 | 83 | c.training_steps = opts.training_steps 84 | c.val_steps = opts.val_steps 85 | c.print_steps = opts.print_steps 86 | c.tensorboard_steps = opts.tb_steps 87 | c.image_snapshot_steps = opts.img_snshot_steps 88 | c.network_snapshot_steps = opts.net_snshot_steps 89 | 90 | c.learning_rate = opts.lr 91 | c.l2_lambda = opts.l2_lambda 92 | c.lpips_lambda = opts.lpips_lambda 93 | c.id_lambda = opts.id_lambda 94 | c.reg_lambda = opts.reg_lambda 95 | c.gan_lambda = opts.gan_lambda 96 | c.edit_lambda = opts.edit_lambda 97 | 98 | c.random_seed = opts.seed 99 | c.num_workers = opts.workers 100 | c.resume_pkl = opts.resume_pkl 101 | 102 | # Description string. 103 | dataset_name = c.dataset_dir.split('/')[-1] 104 | desc = f'{c.model_architecture:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}' 105 | # TODO: add resume related description 106 | 107 | # Pick output directory. 108 | prev_run_dirs = [] 109 | if os.path.isdir(opts.outdir): 110 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))] 111 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] 112 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] 113 | cur_run_id = max(prev_run_ids, default=-1) + 1 114 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05}-{desc}') 115 | assert not os.path.exists(c.run_dir) 116 | 117 | # Print options. 118 | print() 119 | print('Training options:') 120 | print(json.dumps(c, indent=2)) 121 | print() 122 | 123 | # Create output directory. 124 | print('Creating output directory...') 125 | os.makedirs(c.run_dir) 126 | os.makedirs(f'{c.run_dir}/image_snapshots/') 127 | os.makedirs(f'{c.run_dir}/network_snapshots/') 128 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: 129 | json.dump(c, f, indent=2) 130 | 131 | # Launch processes. 
132 | print('Launching processes...') 133 | torch.multiprocessing.set_start_method('spawn') 134 | with tempfile.TemporaryDirectory() as temp_dir: 135 | if c.num_gpus == 1: 136 | subprocess_fn(rank=0, c=c, temp_dir=temp_dir) 137 | else: 138 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus) 139 | 140 | 141 | def subprocess_fn(rank, c, temp_dir): 142 | # Init torch.distributed. 143 | # if c.num_gpus > 1: 144 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 145 | init_method = f'file://{init_file}' 146 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=c.num_gpus) 147 | 148 | # Init torch_utils 149 | torch.cuda.set_device(rank) 150 | sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None 151 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 152 | if rank != 0: 153 | custom_ops.verbosity = 'none' 154 | 155 | # Execute training loop. 156 | training_loop_encoder.training_loop(rank=rank, **c) 157 | 158 | #---------------------------------------------------------------------------- 159 | 160 | if __name__ == "__main__": 161 | main() # pylint: disable=no-value-for-parameter 162 | 163 | #---------------------------------------------------------------------------- 164 | -------------------------------------------------------------------------------- /train_base.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --outdir exp/base \ 3 | --encoder base \ 4 | --data data/ffhq \ 5 | --gpus 8 \ 6 | --batch 32 \ 7 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 8 | -------------------------------------------------------------------------------- /train_config_a.sh: -------------------------------------------------------------------------------- 1 | # use transformer encoder instead of gradual style block CNN architecture -> discarded 2 | python train.py \ 3 | --outdir exp/config_a \ 4 | --encoder transformer \ 5 | --data data/ffhq \ 6 | --gpus 8 \ 7 | --batch 32 \ 8 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 9 | --enc_layers 1 \ 10 | --img_snshot_steps 1000 \ 11 | -------------------------------------------------------------------------------- /train_config_b.sh: -------------------------------------------------------------------------------- 1 | # train delta_w instead of w, not apply reg loss 2 | python train.py \ 3 | --outdir exp/config-b \ 4 | --encoder base \ 5 | --data data/ffhq \ 6 | --gpus 8 \ 7 | --batch 32 \ 8 | --generator pretrained/stylegan3-t-ffhq-1024x1024.pkl \ 9 | --w_avg \ 10 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Streaming images and labels from datasets created with dataset_tool.py.""" 10 | 11 | import os 12 | import numpy as np 13 | import zipfile 14 | import PIL.Image 15 | import json 16 | import torch 17 | import dnnlib 18 | 19 | try: 20 | import pyspng 21 | except ImportError: 22 | pyspng = None 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, 28 | name, # Name of the dataset. 29 | raw_shape, # Shape of the raw image data (NCHW). 30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 33 | random_seed = 0, # Random seed to use when applying max_size. 34 | ): 35 | self._name = name 36 | self._raw_shape = list(raw_shape) 37 | self._use_labels = use_labels 38 | self._raw_labels = None 39 | self._label_shape = None 40 | 41 | # Apply max_size. 42 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 43 | if (max_size is not None) and (self._raw_idx.size > max_size): 44 | np.random.RandomState(random_seed).shuffle(self._raw_idx) 45 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 46 | 47 | # Apply xflip. 
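        # Note (not in the original source): x-flip augmentation is realized by duplicating the
        # index list and flagging the second copy in self._xflip; __getitem__ then mirrors those
        # images along the width axis, doubling the dataset without storing extra files.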
48 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 49 | if xflip: 50 | self._raw_idx = np.tile(self._raw_idx, 2) 51 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 52 | 53 | def _get_raw_labels(self): 54 | if self._raw_labels is None: 55 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 56 | if self._raw_labels is None: 57 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 58 | assert isinstance(self._raw_labels, np.ndarray) 59 | assert self._raw_labels.shape[0] == self._raw_shape[0] 60 | assert self._raw_labels.dtype in [np.float32, np.int64] 61 | if self._raw_labels.dtype == np.int64: 62 | assert self._raw_labels.ndim == 1 63 | assert np.all(self._raw_labels >= 0) 64 | return self._raw_labels 65 | 66 | def close(self): # to be overridden by subclass 67 | pass 68 | 69 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 70 | raise NotImplementedError 71 | 72 | def _load_raw_labels(self): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def __getstate__(self): 76 | return dict(self.__dict__, _raw_labels=None) 77 | 78 | def __del__(self): 79 | try: 80 | self.close() 81 | except: 82 | pass 83 | 84 | def __len__(self): 85 | return self._raw_idx.size 86 | 87 | def __getitem__(self, idx): 88 | image = self._load_raw_image(self._raw_idx[idx]) 89 | assert isinstance(image, np.ndarray) 90 | assert list(image.shape) == self.image_shape 91 | assert image.dtype == np.uint8 92 | if self._xflip[idx]: 93 | assert image.ndim == 3 # CHW 94 | image = image[:, :, ::-1] 95 | return image.copy(), self.get_label(idx) 96 | 97 | def get_label(self, idx): 98 | label = self._get_raw_labels()[self._raw_idx[idx]] 99 | if label.dtype == np.int64: 100 | onehot = np.zeros(self.label_shape, dtype=np.float32) 101 | onehot[label] = 1 102 | label = onehot 103 | return label.copy() 104 | 105 | def get_details(self, idx): 106 | d = dnnlib.EasyDict() 107 | d.raw_idx = int(self._raw_idx[idx]) 108 | d.xflip = (int(self._xflip[idx]) != 0) 109 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 110 | return d 111 | 112 | @property 113 | def name(self): 114 | return self._name 115 | 116 | @property 117 | def image_shape(self): 118 | return list(self._raw_shape[1:]) 119 | 120 | @property 121 | def num_channels(self): 122 | assert len(self.image_shape) == 3 # CHW 123 | return self.image_shape[0] 124 | 125 | @property 126 | def resolution(self): 127 | assert len(self.image_shape) == 3 # CHW 128 | assert self.image_shape[1] == self.image_shape[2] 129 | return self.image_shape[1] 130 | 131 | @property 132 | def label_shape(self): 133 | if self._label_shape is None: 134 | raw_labels = self._get_raw_labels() 135 | if raw_labels.dtype == np.int64: 136 | self._label_shape = [int(np.max(raw_labels)) + 1] 137 | else: 138 | self._label_shape = raw_labels.shape[1:] 139 | return list(self._label_shape) 140 | 141 | @property 142 | def label_dim(self): 143 | assert len(self.label_shape) == 1 144 | return self.label_shape[0] 145 | 146 | @property 147 | def has_labels(self): 148 | return any(x != 0 for x in self.label_shape) 149 | 150 | @property 151 | def has_onehot_labels(self): 152 | return self._get_raw_labels().dtype == np.int64 153 | 154 | #---------------------------------------------------------------------------- 155 | 156 | class ImageFolderDataset(Dataset): 157 | def __init__(self, 158 | path, # Path to directory or zip. 159 | resolution = None, # Ensure specific resolution, None = highest available. 
160 | **super_kwargs, # Additional arguments for the Dataset base class. 161 | ): 162 | self._path = path 163 | self._zipfile = None 164 | 165 | if os.path.isdir(self._path): 166 | self._type = 'dir' 167 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 168 | elif self._file_ext(self._path) == '.zip': 169 | self._type = 'zip' 170 | self._all_fnames = set(self._get_zipfile().namelist()) 171 | else: 172 | raise IOError('Path must point to a directory or zip') 173 | 174 | PIL.Image.init() 175 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 176 | if len(self._image_fnames) == 0: 177 | raise IOError('No image files found in the specified path') 178 | 179 | name = os.path.splitext(os.path.basename(self._path))[0] 180 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 181 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 182 | raise IOError('Image files do not match the specified resolution') 183 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 184 | 185 | @staticmethod 186 | def _file_ext(fname): 187 | return os.path.splitext(fname)[1].lower() 188 | 189 | def _get_zipfile(self): 190 | assert self._type == 'zip' 191 | if self._zipfile is None: 192 | self._zipfile = zipfile.ZipFile(self._path) 193 | return self._zipfile 194 | 195 | def _open_file(self, fname): 196 | if self._type == 'dir': 197 | return open(os.path.join(self._path, fname), 'rb') 198 | if self._type == 'zip': 199 | return self._get_zipfile().open(fname, 'r') 200 | return None 201 | 202 | def close(self): 203 | try: 204 | if self._zipfile is not None: 205 | self._zipfile.close() 206 | finally: 207 | self._zipfile = None 208 | 209 | def __getstate__(self): 210 | return dict(super().__getstate__(), _zipfile=None) 211 | 212 | def _load_raw_image(self, raw_idx): 213 | fname = self._image_fnames[raw_idx] 214 | with self._open_file(fname) as f: 215 | if pyspng is not None and self._file_ext(fname) == '.png': 216 | image = pyspng.load(f.read()) 217 | else: 218 | image = np.array(PIL.Image.open(f)) 219 | if image.ndim == 2: 220 | image = image[:, :, np.newaxis] # HW => HWC 221 | image = image.transpose(2, 0, 1) # HWC => CHW 222 | return image 223 | 224 | def _load_raw_labels(self): 225 | fname = 'dataset.json' 226 | if fname not in self._all_fnames: 227 | return None 228 | with self._open_file(fname) as f: 229 | labels = json.load(f)['labels'] 230 | if labels is None: 231 | return None 232 | labels = dict(labels) 233 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 234 | labels = np.array(labels) 235 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 236 | return labels 237 | 238 | #---------------------------------------------------------------------------- 239 | -------------------------------------------------------------------------------- /training/dataset_encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import PIL.Image 5 | import torch 6 | from torchvision.transforms import (Compose, Resize, RandomHorizontalFlip, 7 | ToTensor, Normalize) 8 | 9 | # data utils 10 | """ 11 | Code adopted from pix2pixHD: 12 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 13 | """ 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', 
'.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir_): 25 | images = [] 26 | assert os.path.isdir(dir_), '%s is not a valid directory' % dir_ 27 | for root, _, fnames in sorted(os.walk(dir_)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images 33 | 34 | 35 | # dataset 36 | class ImagesDataset(torch.utils.data.Dataset): 37 | def __init__(self, dataset_dir, mode='train'): 38 | assert mode in ['train', 'test', 'inference'] 39 | self.paths = sorted(make_dataset(dataset_dir)) 40 | transforms_dict = { 41 | 'train': Compose([ 42 | Resize((256, 256)), 43 | RandomHorizontalFlip(0.5), 44 | ToTensor(), 45 | Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 46 | 'test': Compose([ 47 | Resize((256, 256)), 48 | ToTensor(), 49 | Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 50 | 'inference': Compose([ 51 | Resize((256, 256)), 52 | ToTensor(), 53 | Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 54 | } 55 | self.transforms = transforms_dict.get(mode) 56 | 57 | def __len__(self): 58 | return len(self.paths) 59 | 60 | def __getitem__(self, i): 61 | x = self.transforms(PIL.Image.open(self.paths[i]).convert('RGB')) 62 | y = copy.deepcopy(x) 63 | return x,y 64 | 65 | 66 | # TODO : implement distributed sampler for ddp 67 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Loss functions.""" 10 | 11 | import numpy as np 12 | import torch 13 | from torch_utils import training_stats 14 | from torch_utils.ops import conv2d_gradfix 15 | from torch_utils.ops import upfirdn2d 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class Loss: 20 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass 21 | raise NotImplementedError() 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | class StyleGAN2Loss(Loss): 26 | def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0): 27 | super().__init__() 28 | self.device = device 29 | self.G = G 30 | self.D = D 31 | self.augment_pipe = augment_pipe 32 | self.r1_gamma = r1_gamma 33 | self.style_mixing_prob = style_mixing_prob 34 | self.pl_weight = pl_weight 35 | self.pl_batch_shrink = pl_batch_shrink 36 | self.pl_decay = pl_decay 37 | self.pl_no_weight_grad = pl_no_weight_grad 38 | self.pl_mean = torch.zeros([], device=device) 39 | self.blur_init_sigma = blur_init_sigma 40 | self.blur_fade_kimg = blur_fade_kimg 41 | 42 | def run_G(self, z, c, update_emas=False): 43 | ws = self.G.mapping(z, c, update_emas=update_emas) 44 | if self.style_mixing_prob > 0: 45 | with torch.autograd.profiler.record_function('style_mixing'): 46 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) 47 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) 48 | ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] 49 | img = self.G.synthesis(ws, update_emas=update_emas) 50 | return img, ws 51 | 52 | def run_D(self, img, c, blur_sigma=0, update_emas=False): 53 | blur_size = np.floor(blur_sigma * 3) 54 | if blur_size > 0: 55 | with torch.autograd.profiler.record_function('blur'): 56 | f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2() 57 | img = upfirdn2d.filter2d(img, f / f.sum()) 58 | if self.augment_pipe is not None: 59 | img = self.augment_pipe(img) 60 | logits = self.D(img, c, update_emas=update_emas) 61 | return logits 62 | 63 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): 64 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] 65 | if self.pl_weight == 0: 66 | phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase) 67 | if self.r1_gamma == 0: 68 | phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase) 69 | blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 70 | 71 | # Gmain: Maximize logits for generated images. 
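        # Note (not in the original source): softplus(-logits) == -log(sigmoid(logits)), so this
        # is the non-saturating generator loss; it directly maximizes log D(G(z)) rather than
        # minimizing log(1 - D(G(z))).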
72 | if phase in ['Gmain', 'Gboth']: 73 | with torch.autograd.profiler.record_function('Gmain_forward'): 74 | gen_img, _gen_ws = self.run_G(gen_z, gen_c) 75 | gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) 76 | training_stats.report('Loss/scores/fake', gen_logits) 77 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 78 | loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) 79 | training_stats.report('Loss/G/loss', loss_Gmain) 80 | with torch.autograd.profiler.record_function('Gmain_backward'): 81 | loss_Gmain.mean().mul(gain).backward() 82 | 83 | # Gpl: Apply path length regularization. 84 | if phase in ['Greg', 'Gboth']: 85 | with torch.autograd.profiler.record_function('Gpl_forward'): 86 | batch_size = gen_z.shape[0] // self.pl_batch_shrink 87 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size]) 88 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) 89 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad): 90 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] 91 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() 92 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) 93 | self.pl_mean.copy_(pl_mean.detach()) 94 | pl_penalty = (pl_lengths - pl_mean).square() 95 | training_stats.report('Loss/pl_penalty', pl_penalty) 96 | loss_Gpl = pl_penalty * self.pl_weight 97 | training_stats.report('Loss/G/reg', loss_Gpl) 98 | with torch.autograd.profiler.record_function('Gpl_backward'): 99 | loss_Gpl.mean().mul(gain).backward() 100 | 101 | # Dmain: Minimize logits for generated images. 102 | loss_Dgen = 0 103 | if phase in ['Dmain', 'Dboth']: 104 | with torch.autograd.profiler.record_function('Dgen_forward'): 105 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True) 106 | gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True) 107 | training_stats.report('Loss/scores/fake', gen_logits) 108 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 109 | loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) 110 | with torch.autograd.profiler.record_function('Dgen_backward'): 111 | loss_Dgen.mean().mul(gain).backward() 112 | 113 | # Dmain: Maximize logits for real images. 114 | # Dr1: Apply R1 regularization. 
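        # Note (not in the original source): the R1 term computed below is
        #   loss_Dr1 = (r1_gamma / 2) * ||d D(x_real) / d x_real||^2,
        # i.e. a gradient penalty evaluated only on real images.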
115 | if phase in ['Dmain', 'Dreg', 'Dboth']: 116 | name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' 117 | with torch.autograd.profiler.record_function(name + '_forward'): 118 | real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']) 119 | real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma) 120 | training_stats.report('Loss/scores/real', real_logits) 121 | training_stats.report('Loss/signs/real', real_logits.sign()) 122 | 123 | loss_Dreal = 0 124 | if phase in ['Dmain', 'Dboth']: 125 | loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) 126 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) 127 | 128 | loss_Dr1 = 0 129 | if phase in ['Dreg', 'Dboth']: 130 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): 131 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] 132 | r1_penalty = r1_grads.square().sum([1,2,3]) 133 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2) 134 | training_stats.report('Loss/r1_penalty', r1_penalty) 135 | training_stats.report('Loss/D/reg', loss_Dr1) 136 | 137 | with torch.autograd.profiler.record_function(name + '_backward'): 138 | (loss_Dreal + loss_Dr1).mean().mul(gain).backward() 139 | 140 | #---------------------------------------------------------------------------- 141 | -------------------------------------------------------------------------------- /training/loss_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from training.networks_irse import Backbone 3 | 4 | 5 | l2_criterion = torch.nn.MSELoss(reduction='mean') 6 | 7 | 8 | def l2_loss(generated_images, real_images): 9 | return l2_criterion(generated_images, real_images) 10 | 11 | 12 | class IDLoss(torch.nn.Module): 13 | def __init__(self): 14 | super(IDLoss, self).__init__() 15 | self.facenet = Backbone( 16 | input_size=112, 17 | num_layers=50, 18 | drop_ratio=0.6, 19 | mode='ir_se' 20 | ) 21 | # TODO : get pretrained weight path from config 22 | self.facenet.load_state_dict(torch.load('pretrained/model_ir_se50.pth')) 23 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112,112)) 24 | self.facenet.eval() 25 | 26 | def extract_feats(self, x): 27 | x = x[:, :, 35:223, 32:220] # Crop interesting region 28 | x = self.face_pool(x) 29 | x_feats = self.facenet(x) 30 | return x_feats 31 | 32 | def forward(self, generated_images, y, x): # y == x for image inversion 33 | n_samples = x.shape[0] 34 | x_feats = self.extract_feats(x) 35 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 36 | generated_feats = self.extract_feats(generated_images) 37 | y_feats = y_feats.detach() 38 | loss = 0 39 | sim_improvement = 0 40 | count = 0 41 | for i in range(n_samples): 42 | diff_target = generated_feats[i].dot(y_feats[i]) 43 | diff_input = generated_feats[i].dot(x_feats[i]) 44 | diff_views = y_feats[i].dot(x_feats[i]) 45 | loss += 1 - diff_target 46 | id_diff = float(diff_target) - float(diff_views) 47 | sim_improvement += id_diff 48 | count += 1 49 | 50 | return loss / count, sim_improvement / count 51 | -------------------------------------------------------------------------------- /training/networks_arcface.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, 
MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | class Flatten(Module): 10 | def forward(self, input): 11 | return input.view(input.size(0), -1) 12 | 13 | 14 | def l2_norm(input, axis=1): 15 | norm = torch.norm(input, 2, axis, True) 16 | output = torch.div(input, norm) 17 | return output 18 | 19 | 20 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 21 | """ A named tuple describing a ResNet block. """ 22 | 23 | 24 | def get_block(in_channel, depth, num_units, stride=2): 25 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 26 | 27 | 28 | def get_blocks(num_layers): 29 | if num_layers == 50: 30 | blocks = [ 31 | get_block(in_channel=64, depth=64, num_units=3), 32 | get_block(in_channel=64, depth=128, num_units=4), 33 | get_block(in_channel=128, depth=256, num_units=14), 34 | get_block(in_channel=256, depth=512, num_units=3) 35 | ] 36 | elif num_layers == 100: 37 | blocks = [ 38 | get_block(in_channel=64, depth=64, num_units=3), 39 | get_block(in_channel=64, depth=128, num_units=13), 40 | get_block(in_channel=128, depth=256, num_units=30), 41 | get_block(in_channel=256, depth=512, num_units=3) 42 | ] 43 | elif num_layers == 152: 44 | blocks = [ 45 | get_block(in_channel=64, depth=64, num_units=3), 46 | get_block(in_channel=64, depth=128, num_units=8), 47 | get_block(in_channel=128, depth=256, num_units=36), 48 | get_block(in_channel=256, depth=512, num_units=3) 49 | ] 50 | else: 51 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 52 | return blocks 53 | 54 | 55 | class SEModule(Module): 56 | def __init__(self, channels, reduction): 57 | super(SEModule, self).__init__() 58 | self.avg_pool = AdaptiveAvgPool2d(1) 59 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 60 | self.relu = ReLU(inplace=True) 61 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 62 | self.sigmoid = Sigmoid() 63 | 64 | def forward(self, x): 65 | module_input = x 66 | x = self.avg_pool(x) 67 | x = self.fc1(x) 68 | x = self.relu(x) 69 | x = self.fc2(x) 70 | x = self.sigmoid(x) 71 | return module_input * x 72 | 73 | 74 | class bottleneck_IR(Module): 75 | def __init__(self, in_channel, depth, stride): 76 | super(bottleneck_IR, self).__init__() 77 | if in_channel == depth: 78 | self.shortcut_layer = MaxPool2d(1, stride) 79 | else: 80 | self.shortcut_layer = Sequential( 81 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 82 | BatchNorm2d(depth) 83 | ) 84 | self.res_layer = Sequential( 85 | BatchNorm2d(in_channel), 86 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 87 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 88 | ) 89 | 90 | def forward(self, x): 91 | shortcut = self.shortcut_layer(x) 92 | res = self.res_layer(x) 93 | return res + shortcut 94 | 95 | 96 | class bottleneck_IR_SE(Module): 97 | def __init__(self, in_channel, depth, stride): 98 | super(bottleneck_IR_SE, self).__init__() 99 | if in_channel == depth: 100 | self.shortcut_layer = MaxPool2d(1, stride) 101 | else: 102 | self.shortcut_layer = Sequential( 103 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 104 | BatchNorm2d(depth) 105 | ) 106 | self.res_layer = Sequential( 107 | BatchNorm2d(in_channel), 108 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 
109 | PReLU(depth), 110 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 111 | BatchNorm2d(depth), 112 | SEModule(depth, 16) 113 | ) 114 | 115 | def forward(self, x): 116 | shortcut = self.shortcut_layer(x) 117 | res = self.res_layer(x) 118 | return res + shortcut 119 | -------------------------------------------------------------------------------- /training/networks_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from training.networks_arcface import get_blocks, bottleneck_IR_SE 9 | 10 | 11 | class GradualStyleBlock(torch.nn.Module): 12 | def __init__(self, in_c, out_c, spatial): 13 | super(GradualStyleBlock, self).__init__() 14 | self.out_c = out_c 15 | self.spatial = spatial 16 | num_pools = int(np.log2(spatial)) 17 | modules = [] 18 | modules += [nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 19 | nn.LeakyReLU()] 20 | for i in range(num_pools - 1): 21 | modules += [ 22 | nn.Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 23 | nn.LeakyReLU() 24 | ] 25 | self.convs = nn.Sequential(*modules) 26 | self.linear = nn.Linear(out_c, out_c) 27 | 28 | def forward(self, x): 29 | x = self.convs(x) # [b,512,H,W]->[b,512,1,1] 30 | # (H,W) in [(16,16),(32,32),(64,64)] 31 | x = x.view(-1, self.out_c) # [b,512,1,1]-> [b,512] 32 | x = self.linear(x) 33 | x = nn.LeakyReLU()(x) 34 | return x 35 | 36 | 37 | class PositionalEncoding(nn.Module): 38 | 39 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 40 | super().__init__() 41 | self.dropout = nn.Dropout(p=dropout) 42 | 43 | position = torch.arange(max_len).unsqueeze(1) 44 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 45 | pe = torch.zeros(max_len, 1, d_model) 46 | pe[:, 0, 0::2] = torch.sin(position * div_term) 47 | pe[:, 0, 1::2] = torch.cos(position * div_term) 48 | self.register_buffer('pe', pe) 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | """ 52 | Args: 53 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 54 | """ 55 | x = x + self.pe[:x.size(0)] 56 | return self.dropout(x) 57 | 58 | 59 | class TransformerBlock(torch.nn.Module): 60 | def __init__(self, in_c, out_c, spatial, **styleblock): 61 | super(TransformerBlock, self).__init__() 62 | self.out_c = out_c 63 | self.spatial = spatial 64 | self.num_encoder_layers = styleblock['num_encoder_layers'] 65 | num_pools = int(np.log2(spatial))-4 66 | modules = [] 67 | # modules += [nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 68 | # nn.LeakyReLU()] 69 | for i in range(num_pools-1): 70 | modules += [ 71 | nn.Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 72 | nn.LeakyReLU() 73 | ] 74 | self.convs = nn.Sequential(*modules) 75 | self.positional_encoding = PositionalEncoding(out_c) 76 | self.transformer_encoder = nn.Transformer(num_encoder_layers=self.num_encoder_layers).encoder 77 | #(out_c, out_c) 78 | 79 | def forward(self, x): 80 | x = self.convs(x) # [b,512,H,W]->[b,512,16,16] 81 | # (H,W) in [(16,16),(32,32),(64,64)] 82 | x = x.view(x.shape[0], -1, self.out_c) # [b,256,512] 83 | x = self.positional_encoding(x) 84 | x = self.transformer_encoder(x)[:,0,:] 85 | return x 86 | 87 | 88 | class GradualStyleEncoder(torch.nn.Module): 89 | def __init__(self, **styleblock): 90 | super(GradualStyleEncoder, self).__init__() 91 | blocks = get_blocks(50) # num_layers=50 92 | unit_module = bottleneck_IR_SE # 'ir_se' 
bottleneck 93 | 94 | self.input_layer = nn.Sequential( 95 | nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False), 96 | nn.BatchNorm2d(64), 97 | nn.PReLU(64) 98 | ) # [b,3,256,256]->[b,64,256,256] 99 | 100 | modules = [] 101 | for block in blocks: 102 | for bottleneck in block: 103 | modules.append(unit_module(bottleneck.in_channel, 104 | bottleneck.depth, 105 | bottleneck.stride)) 106 | self.body = nn.Sequential(*modules) 107 | 108 | self.styles = nn.ModuleList() # feat->latent 109 | 110 | # TODO: 111 | # need some other method for handling w[0] 112 | # train w[0] separately ? 113 | # coarse_ind, middle_ind tuning 114 | self.style_count = 16 115 | self.coarse_ind = 3 116 | self.middle_ind = 7 117 | 118 | if 'arch' in styleblock: 119 | for i in range(self.style_count): 120 | if i < self.coarse_ind: 121 | style = TransformerBlock(512, 512, 16, **styleblock) 122 | elif i < self.middle_ind: 123 | style = TransformerBlock(512, 512, 32, **styleblock) 124 | else: 125 | style = TransformerBlock(512, 512, 64, **styleblock) 126 | self.styles.append(style) 127 | else: 128 | for i in range(self.style_count): 129 | if i < self.coarse_ind: 130 | style = GradualStyleBlock(512, 512, 16) 131 | elif i < self.middle_ind: 132 | style = GradualStyleBlock(512, 512, 32) 133 | else: 134 | style = GradualStyleBlock(512, 512, 64) 135 | self.styles.append(style) 136 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 137 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 138 | 139 | def _upsample_add(self, x, y): 140 | '''Upsample and add two feature maps. 141 | Args: 142 | x: (Variable) top feature map to be upsampled. 143 | y: (Variable) lateral feature map. 144 | Returns: 145 | (Variable) added feature map. 146 | Note in PyTorch, when input size is odd, the upsampled feature map 147 | with `F.upsample(..., scale_factor=2, mode='nearest')` 148 | may not be equal to the lateral feature map size. 149 | e.g. 150 | original input size: [N,_,15,15] -> 151 | conv2d feature map size: [N,_,8,8] -> 152 | upsampled feature map size: [N,_,16,16] 153 | So we choose bilinear upsample which supports arbitrary output sizes. 
154 | ''' 155 | _, _, H, W = y.size() 156 | return nn.functional.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 157 | 158 | def forward(self, x): 159 | x = self.input_layer(x) # [b,3,256,256]->[b,64,256,256] 160 | 161 | latents = [] 162 | modulelist = list(self.body._modules.values()) 163 | 164 | for i, l in enumerate(modulelist): 165 | x = l(x) 166 | if i == 6: 167 | c1 = x # [b,128,64,64] 168 | elif i == 20: 169 | c2 = x # [b,256,32,32] 170 | elif i == 23: 171 | c3 = x # [b,512,16,16] 172 | 173 | for j in range(self.coarse_ind): 174 | latents.append(self.styles[j](c3)) 175 | 176 | p2 = self._upsample_add(c3, self.latlayer1(c2)) # [b,512,32,32] 177 | for j in range(self.coarse_ind, self.middle_ind): 178 | latents.append(self.styles[j](p2)) 179 | 180 | p1 = self._upsample_add(p2, self.latlayer2(c1)) # [b,512,64,64] 181 | for j in range(self.middle_ind, self.style_count): 182 | latents.append(self.styles[j](p1)) 183 | 184 | out = torch.stack(latents, dim=1) 185 | # Adding Transformer structure in this results in gradient exploding 186 | # if self.neck is not None: 187 | # out = self.neck(out) 188 | return out 189 | 190 | 191 | class Encoder(torch.nn.Module): 192 | """stylegan3 encoder implementation 193 | based on pixel2sylte2pixel GradualStyleEncoder 194 | 195 | (b, 3, 256, 256) -> (b, 16, 512) 196 | 197 | stylegan3 generator synthesis 198 | (b, 16, 512) -> (b, 3, 1024, 1024) 199 | """ 200 | def __init__( 201 | self, 202 | pretrained=None, 203 | w_avg=None, 204 | **kwargs, 205 | ): 206 | super(Encoder, self).__init__() 207 | self.encoder = GradualStyleEncoder(**kwargs) # 50, irse 208 | self.resume_step = 0 209 | self.w_avg = w_avg 210 | 211 | # load weight 212 | if pretrained is not None: 213 | with open(pretrained, 'rb') as f: 214 | dic = pickle.load(f) 215 | weights = dic['E'] 216 | weights_ = dict() 217 | for layer in weights: 218 | if 'module.encoder' in layer: 219 | weights_['.'.join(layer.split('.')[2:])] = weights[layer] 220 | self.resume_step = dic['step'] 221 | self.encoder.load_state_dict(weights_, strict=True) 222 | del weights 223 | else: 224 | irse50 = torch.load("pretrained/model_ir_se50.pth", map_location='cpu') 225 | weights = {k:v for k,v in irse50.items() if "input_layer" not in k} 226 | self.encoder.load_state_dict(weights, strict=False) 227 | 228 | def forward(self, img): 229 | if self.w_avg is None: 230 | return self.encoder(img) 231 | else: # train delta_w, from w_avg 232 | delta_w = self.encoder(img) 233 | w = delta_w + self.w_avg.repeat(delta_w.shape[0],1,1) 234 | return w 235 | -------------------------------------------------------------------------------- /training/networks_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from training.networks_arcface import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = 
bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /training/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 
21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | use_gc=True, gc_conv_only=False 35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 55 | eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | def __setstate__(self, state): 76 | super(Ranger, self).__setstate__(state) 77 | 78 | def step(self, closure=None): 79 | loss = None 80 | 81 | # Evaluate averages and grad, update param tensors 82 | for group in self.param_groups: 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | grad = p.grad.data.float() 88 | 89 | if grad.is_sparse: 90 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 91 | 92 | p_data_fp32 = p.data.float() 93 | 94 | state = self.state[p] # get state dict for this param 95 | 96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 97 | # if self.first_run_check==0: 98 | # self.first_run_check=1 99 | # print("Initializing slow buffer...should not see this at load from saved model!") 100 | state['step'] = 0 101 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 103 | 104 | # look ahead weight storage now in state dict 105 | state['slow_buffer'] = torch.empty_like(p.data) 106 | state['slow_buffer'].copy_(p.data) 107 | 108 | else: 109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 111 | 112 | # begin computations 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | # GC operation for Conv layers and FC layers 117 | if grad.dim() > self.gc_gradient_threshold: 118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 119 | 120 | state['step'] += 1 121 | 122 | # compute variance mov avg 123 | #exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 124 | 
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) 125 | # compute mean moving avg 126 | #exp_avg.mul_(beta1).add_(1 - beta1, grad) 127 | exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) 128 | 129 | buffered = self.radam_buffer[int(state['step'] % 10)] 130 | 131 | if state['step'] == buffered[0]: 132 | N_sma, step_size = buffered[1], buffered[2] 133 | else: 134 | buffered[0] = state['step'] 135 | beta2_t = beta2 ** state['step'] 136 | N_sma_max = 2 / (1 - beta2) - 1 137 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 138 | buffered[1] = N_sma 139 | if N_sma > self.N_sma_threshhold: 140 | step_size = math.sqrt( 141 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 142 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 143 | else: 144 | step_size = 1.0 / (1 - beta1 ** state['step']) 145 | buffered[2] = step_size 146 | 147 | if group['weight_decay'] != 0: 148 | #p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 149 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 150 | 151 | # apply lr 152 | if N_sma > self.N_sma_threshhold: 153 | denom = exp_avg_sq.sqrt().add_(group['eps']) 154 | #p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 155 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) 156 | else: 157 | #p_data_fp32.add_(-step_size * group['lr'], exp_avg) 158 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) 159 | 160 | p.data.copy_(p_data_fp32) 161 | 162 | # integrated look ahead... 163 | # we do it at the param level instead of group level 164 | if state['step'] % group['k'] == 0: 165 | slow_p = state['slow_buffer'] # get access to slow param tensor 166 | #slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha 167 | slow_p.add_(p.data - slow_p,alpha=self.alpha) # (fast weights - slow weights) * alpha 168 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 169 | 170 | return loss 171 | -------------------------------------------------------------------------------- /training/testing_loop_encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import random 5 | import pickle 6 | import time 7 | from pprint import pprint 8 | from tqdm import tqdm 9 | 10 | import PIL.Image 11 | import numpy as np 12 | import torch 13 | import torch.utils.tensorboard as tensorboard 14 | from lpips import LPIPS 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | 17 | import dnnlib 18 | import legacy 19 | from torch_utils import misc 20 | from training.dataset_encoder import ImagesDataset 21 | from training.loss_encoder import l2_loss, IDLoss 22 | from training.networks_encoder import Encoder 23 | from training.training_loop_encoder import save_image 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | @torch.no_grad() 28 | def testing_loop( 29 | test_dir = '.', # Output directory. 30 | rank = 0, # Rank of the current process in [0, num_gpus]. 31 | model_architecture = 'base', # Model architecture type, ['base', 'transformer'] 32 | w_avg = False, # Train delta w from w_avg 33 | num_encoder_layers = 1, # Encoder layers if model_architecture is transformer 34 | dataset_dir = 'celeba-hq', # Train dataset directory 35 | num_gpus = 8, # Number of GPUs participating in the training. 36 | batch_size = 32, # Total batch size for one training iteration. 
Can be larger than batch_gpu * num_gpus. 37 | batch_gpu = 4, # Number of samples processed at a time by one GPU. 38 | generator_pkl = None, # Generator pickle to encode. 39 | l2_lambda = 1.0, # L2 loss multiplier factor 40 | lpips_lambda = 0.8, # LPIPS loss multiplier factor 41 | id_lambda = 0.1, # ID loss multiplier factor 42 | reg_lambda = 0.0, # e4e reg loss multiplier factor 43 | gan_lambda = 0.0, # e4e latent gan loss multiplier factor 44 | edit_lambda = 0.0, # e4e editability loss multiplier factor 45 | random_seed = 0, # Global random seed. 46 | num_workers = 3, # Dataloader workers. 47 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 48 | ): 49 | 50 | # initialize 51 | device = torch.device('cuda', rank) 52 | 53 | # Reproducibility 54 | random.seed(random_seed * num_gpus + rank) 55 | np.random.seed(random_seed * num_gpus + rank) 56 | torch.manual_seed(random_seed * num_gpus + rank) 57 | torch.cuda.manual_seed(random_seed * num_gpus + rank) 58 | torch.cuda.manual_seed_all(random_seed * num_gpus + rank) 59 | torch.backends.cudnn.deterministic = True 60 | torch.backends.cudnn.benchmark = cudnn_benchmark 61 | 62 | # Load testing set. 63 | if rank == 0: 64 | print('Loading testing set...') 65 | testing_set = ImagesDataset(dataset_dir, mode='test') 66 | testing_set_sampler = torch.utils.data.distributed.DistributedSampler(testing_set, num_replicas=num_gpus, rank=rank, shuffle=False, seed=random_seed, drop_last=False) 67 | testing_loader = torch.utils.data.DataLoader(dataset=testing_set, sampler=testing_set_sampler, batch_size=batch_size//num_gpus, num_workers=num_workers) 68 | if rank == 0: 69 | print() 70 | print('Num images: ', len(testing_set)) 71 | print('Image shape:', testing_set.__getitem__(0)[0].shape) 72 | print() 73 | 74 | # Construct generator. 75 | if rank == 0: 76 | print('Constructing generator...') 77 | with dnnlib.util.open_url(generator_pkl) as f: 78 | G = legacy.load_network_pkl(f)['G_ema'].to(device) 79 | 80 | # Initialize loss 81 | if rank == 0: 82 | print('Initialize loss...') 83 | id_loss = IDLoss().to(device) 84 | lpips_loss = LPIPS(net='alex', verbose=False).to(device).eval() 85 | 86 | # Initialize logs. 87 | if rank == 0: 88 | print('Initialize tensorboard logs...') 89 | logger = tensorboard.SummaryWriter(test_dir) 90 | 91 | # Test. 92 | G.eval() 93 | 94 | latent_avg = None 95 | if w_avg: 96 | latent_avg = G.mapping.w_avg 97 | 98 | test_pkl_lst = [os.path.join(test_dir, 'network_snapshots', x) for x in sorted(os.listdir(os.path.join(test_dir, 'network_snapshots')))][-9:] 99 | 100 | for test_pkl in test_pkl_lst: 101 | if rank == 0: 102 | print(f'\nConstructing encoder from: {test_pkl}') 103 | # Construct encoder. 
104 | if model_architecture == 'base': 105 | E = DDP(Encoder(pretrained=test_pkl,w_avg=latent_avg).to(device), device_ids=[rank]) 106 | elif model_architecture == 'transformer': 107 | styleblock = dict(arch='transformer', num_encoder_layers=num_encoder_layers) 108 | E = DDP(Encoder(pretrained=test_pkl, w_avg=latent_avg, **styleblock).to(device), device_ids=[rank]) 109 | cur_step = E.module.resume_step 110 | assert cur_step.__repr__() in test_pkl 111 | 112 | E.eval() 113 | 114 | epoch_loss = 0.0 115 | epoch_loss_dict = {k:0.0 for k in ['l2', 'lpips', 'id', 'id_improve', 'loss']} 116 | 117 | for batch_idx, batch in tqdm(enumerate(testing_loader),total=len(testing_loader)): 118 | # x:source image = y:real image 119 | # E(x): w, encoded latent 120 | # G.synthesis(E(x)):encoded_images 121 | x,y = batch 122 | x,real_images = x.to(device),y.to(device) 123 | face_pool=torch.nn.AdaptiveAvgPool2d((256,256)) 124 | encoded_images = face_pool(G.synthesis(E(x))) 125 | 126 | # get loss 127 | loss = 0.0 128 | loss_dict = {} 129 | loss_l2 = l2_loss(encoded_images, real_images) 130 | loss_dict['l2'] = loss_l2.item() 131 | loss += loss_l2 * l2_lambda 132 | loss_lpips = lpips_loss(encoded_images, real_images).squeeze().mean() 133 | loss_dict['lpips'] = loss_lpips.item() 134 | loss += loss_lpips * lpips_lambda 135 | loss_id, sim_improvement = id_loss(encoded_images, real_images, x) 136 | loss_dict['id'] = loss_id.item() 137 | loss_dict['id_improve'] = sim_improvement 138 | loss += loss_id * id_lambda 139 | loss_dict['loss'] = loss.item() 140 | 141 | epoch_loss += loss.item() 142 | for k in epoch_loss_dict: 143 | epoch_loss_dict[k] += loss_dict[k] 144 | 145 | # barrier 146 | torch.distributed.barrier() 147 | torch.cuda.empty_cache() 148 | 149 | # Save image snapshot. 150 | if rank == 0 and batch_idx == 0: 151 | print(f"Saving image samples...") 152 | gh, gw = 1, batch_gpu 153 | H,W = real_images.shape[2], real_images.shape[3] 154 | real_path = f'test-image-snapshot-real-{cur_step:06d}.png' 155 | encoded_path = f'test-image-snapshot-encoded-{cur_step:06d}.png' 156 | save_image(real_images, os.path.join(test_dir, 'image_snapshots', real_path), gh, gw, H, W) 157 | save_image(encoded_images, os.path.join(test_dir, 'image_snapshots', encoded_path), gh, gw, H, W) 158 | 159 | # barrier 160 | torch.distributed.barrier() 161 | 162 | # Tensorboard logs. 163 | # TODO: need to get other devices' loss 164 | for k in epoch_loss_dict: 165 | epoch_loss_dict[k] /= len(testing_loader) 166 | if rank == 0: 167 | pprint(epoch_loss_dict) 168 | for key in epoch_loss_dict: 169 | logger.add_scalar(f'test/{key}', epoch_loss_dict[key], cur_step) 170 | # barrier 171 | torch.distributed.barrier() 172 | 173 | del E 174 | 175 | # Done. 
176 | torch.distributed.destroy_process_group() 177 | 178 | if rank == 0: 179 | print() 180 | print('Exiting...') 181 | -------------------------------------------------------------------------------- /training/training_loop_encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import random 5 | import pickle 6 | import time 7 | from pprint import pprint 8 | 9 | import PIL.Image 10 | import numpy as np 11 | import torch 12 | import torch.utils.tensorboard as tensorboard 13 | from lpips import LPIPS 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | 16 | import dnnlib 17 | import legacy 18 | from torch_utils import misc 19 | from training.dataset_encoder import ImagesDataset 20 | from training.loss_encoder import l2_loss, IDLoss 21 | from training.networks_encoder import Encoder 22 | from training.ranger import Ranger 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def save_image(images, save_path, gh, gw, H, W): 27 | np_imgs = [] 28 | for i, image in enumerate(images): 29 | image = images[i][None,:,:] 30 | image = (image.permute(0,2,3,1)*127.5+128).clamp(0,255).to(torch.uint8).cpu().numpy() 31 | np_imgs.append(np.asarray(PIL.Image.fromarray(image[0], 'RGB').resize((H,W),PIL.Image.LANCZOS))) 32 | np_imgs = np.stack(np_imgs) 33 | np_imgs = np_imgs.reshape(gh,gw,H,W,3) 34 | np_imgs = np_imgs.transpose(0,2,1,3,4) 35 | np_imgs = np_imgs.reshape(gh*H, gw*W, 3) 36 | PIL.Image.fromarray(np_imgs, 'RGB').save(save_path) 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def training_loop( 41 | run_dir = '.', # Output directory. 42 | rank = 0, # Rank of the current process in [0, num_gpus]. 43 | model_architecture = 'base', # Model architecture type, ['base', 'transformer'] 44 | w_avg = False, # Train delta w from w_avg 45 | num_encoder_layers = 1, # Encoder layers if model_architecture is transformer 46 | dataset_dir = 'ffhq', # Train dataset directory 47 | num_gpus = 1, # Number of GPUs participating in the training. 48 | batch_size = 32, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. 49 | batch_gpu = 4, # Number of samples processed at a time by one GPU. 50 | generator_pkl = None, # Generator pickle to encode. 51 | val_dataset_dir = 'celeba-hq', # Validation dataset directory 52 | training_steps = 100001, # Total training batch steps 53 | val_steps = 10000, # Validation batch steps 54 | print_steps = 50, # How often to print logs 55 | tensorboard_steps = 50, # How often to log to tensorboard? 56 | image_snapshot_steps = 100, # How often to save image snapshots? None=disable. 57 | network_snapshot_steps = 5000, # How often to save network snapshots? 58 | learning_rate = 0.001, # Learning rate 59 | l2_lambda = 1.0, # L2 loss multiplier factor 60 | lpips_lambda = 0.8, # LPIPS loss multiplier factor 61 | id_lambda = 0.1, # ID loss multiplier factor 62 | reg_lambda = 0.0, # e4e reg loss multiplier factor 63 | gan_lambda = 0.0, # e4e latent gan loss multiplier factor 64 | edit_lambda = 0.0, # e4e editability loss multiplier factor 65 | random_seed = 0, # Global random seed. 66 | num_workers = 3, # Dataloader workers. 67 | resume_pkl = None, # Network pickle to resume training from. 68 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 
69 | ): 70 | 71 | # initialize 72 | device = torch.device('cuda', rank) 73 | 74 | # Reproducibility 75 | random.seed(random_seed * num_gpus + rank) 76 | np.random.seed(random_seed * num_gpus + rank) 77 | torch.manual_seed(random_seed * num_gpus + rank) 78 | torch.cuda.manual_seed(random_seed * num_gpus + rank) 79 | torch.cuda.manual_seed_all(random_seed * num_gpus + rank) 80 | torch.backends.cudnn.deterministic = True 81 | torch.backends.cudnn.benchmark = cudnn_benchmark 82 | 83 | # Load training set. 84 | if rank == 0: 85 | print('Loading training set...') 86 | training_set = ImagesDataset(dataset_dir, mode='train') 87 | training_set_sampler = torch.utils.data.distributed.DistributedSampler(training_set, num_replicas=num_gpus, rank=rank, shuffle=True, seed=random_seed, drop_last=False) 88 | training_loader = torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, num_workers=num_workers) 89 | if rank == 0: 90 | print() 91 | print('Num images: ', len(training_set)) 92 | print('Image shape:', training_set.__getitem__(0)[0].shape) 93 | print() 94 | 95 | # Load validation set. 96 | 97 | # Construct networks. 98 | if rank == 0: 99 | print('Constructing networks...') 100 | with dnnlib.util.open_url(generator_pkl) as f: 101 | G = legacy.load_network_pkl(f)['G_ema'].to(device) 102 | 103 | latent_avg = None 104 | if w_avg: 105 | latent_avg = G.mapping.w_avg 106 | if model_architecture == 'base': 107 | if resume_pkl is None: 108 | E = DDP(Encoder(pretrained=None,w_avg=latent_avg).to(device), device_ids=[rank]) 109 | else: 110 | E = DDP(Encoder(pretrained=resume_pkl,w_avg=latent_avg).to(device), device_ids=[rank]) 111 | elif model_architecture == 'transformer': 112 | styleblock = dict(arch='transformer', num_encoder_layers=num_encoder_layers) 113 | if resume_pkl is None: 114 | E = DDP(Encoder(pretrained=None, w_avg=latent_avg, **styleblock).to(device), device_ids=[rank]) 115 | else: 116 | E = DDP(Encoder(pretrained=resume_pkl, w_avg=latent_avg, **styleblock).to(device), device_ids=[rank]) 117 | cur_step = E.module.resume_step 118 | 119 | # Initialize loss 120 | if rank == 0: 121 | print('Initialize loss...') 122 | id_loss = IDLoss().to(device) 123 | lpips_loss = LPIPS(net='alex', verbose=False).to(device).eval() 124 | 125 | # Initialize optimizer 126 | if rank == 0: 127 | print('Initialize optimizer...') 128 | params = list(E.parameters()) 129 | optimizer = Ranger(params, lr=learning_rate) 130 | 131 | # Initialize logs. 132 | if rank == 0: 133 | print('Initialize tensorboard logs...') 134 | logger = tensorboard.SummaryWriter(run_dir) 135 | 136 | # Train. 
137 | E.train() 138 | G.eval() 139 | 140 | #TODO : implement validation 141 | while cur_step < training_steps: 142 | for batch_idx, batch in enumerate(training_loader): 143 | optimizer.zero_grad() 144 | # x:source image = y:real image 145 | # E(x): w, encoded latent 146 | # G.synthesis(E(x)):encoded_images 147 | x,y = batch 148 | x,real_images = x.to(device),y.to(device) 149 | face_pool=torch.nn.AdaptiveAvgPool2d((256,256)) 150 | encoded_images = face_pool(G.synthesis(E(x))) 151 | 152 | # get loss 153 | loss = 0.0 154 | loss_dict = {} # for tb logs 155 | loss_l2 = l2_loss(encoded_images, real_images) 156 | loss_dict['l2'] = loss_l2.item() 157 | loss += loss_l2 * l2_lambda 158 | loss_lpips = lpips_loss(encoded_images, real_images).squeeze().mean() 159 | loss_dict['lpips'] = loss_lpips.item() 160 | loss += loss_lpips * lpips_lambda 161 | loss_id, sim_improvement = id_loss(encoded_images, real_images, x) 162 | loss_dict['id'] = loss_id.item() 163 | loss_dict['id_improve'] = sim_improvement 164 | loss += loss_id * id_lambda 165 | loss_dict['loss'] = loss.item() 166 | 167 | if rank == 0 and cur_step % print_steps == 0: 168 | print(f'\nCurrent batch step: {cur_step}') 169 | pprint(loss_dict) 170 | 171 | # back propagation 172 | loss.backward() 173 | 174 | # optimizer step 175 | optimizer.step() 176 | 177 | # barrier 178 | torch.distributed.barrier() 179 | 180 | # Save image snapshot. 181 | if rank == 0 and cur_step % image_snapshot_steps == 0: 182 | print(f"Saving image snapshot at step {cur_step}...") 183 | gh, gw = 1, batch_gpu 184 | H,W = real_images.shape[2], real_images.shape[3] 185 | real_path = f'image-snapshot-real-{cur_step:06d}.png' 186 | encoded_path = f'image-snapshot-encoded-{cur_step:06d}.png' 187 | save_image(real_images, os.path.join(run_dir, 'image_snapshots', real_path), gh, gw, H, W) 188 | save_image(encoded_images, os.path.join(run_dir, 'image_snapshots', encoded_path), gh, gw, H, W) 189 | 190 | # Save network snapshot. 191 | snapshot_pkl = None 192 | snapshot_data = None 193 | if rank == 0 and cur_step % network_snapshot_steps == 0: 194 | print(f"Saving network snapshot at step {cur_step}...") 195 | # TODO: save lr scheduler, optimizer status, etc... 196 | snapshot_data = dict( 197 | E=E.state_dict(), 198 | step=cur_step, 199 | ) 200 | snapshot_pkl = os.path.join(run_dir, 'network_snapshots',f'network-snapshot-{cur_step:06d}.pkl') 201 | with open(snapshot_pkl, 'wb') as f: 202 | pickle.dump(snapshot_data, f) 203 | del snapshot_data # conserve memory 204 | 205 | # Tensorboard logs. 206 | if rank == 0 and cur_step % tensorboard_steps == 0: 207 | for key in loss_dict: 208 | logger.add_scalar(f'train/{key}', loss_dict[key], cur_step) 209 | # barrier 210 | torch.distributed.barrier() 211 | 212 | # update cur_step 213 | cur_step += 1 214 | if cur_step == training_steps: 215 | break 216 | 217 | # Done. 218 | torch.distributed.destroy_process_group() 219 | 220 | if rank == 0: 221 | print() 222 | print('Exiting...') 223 | --------------------------------------------------------------------------------
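**Quick single-image inversion (sketch)**
The training and testing loops above always run the encoder under DistributedDataParallel with a DataLoader. For a quick single-image experiment the same components can be used directly. The snippet below is a minimal, non-distributed sketch, not part of the repository: the snapshot and output paths are placeholders, the preprocessing (256x256 resize, scaling to [-1, 1]) is assumed to match `training/dataset_encoder.py`, and a base-configuration encoder trained without w_avg is assumed (for the transformer encoder, pass `arch='transformer'` and `num_encoder_layers` as in the loops above).
```
# Minimal inversion sketch (assumed paths and preprocessing; single GPU, no DDP).
import PIL.Image
import numpy as np
import torch

import dnnlib
import legacy
from training.networks_encoder import Encoder

device = torch.device('cuda', 0)

# Placeholder paths: substitute your own run directory and generator pickle.
encoder_pkl = 'exp/base/00000-base-ffhq-gpus8-batch32/network_snapshots/network-snapshot-100000.pkl'
generator_pkl = 'pretrained/stylegan3-t-ffhq-1024x1024.pkl'

with dnnlib.util.open_url(generator_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device).eval()
E = Encoder(pretrained=encoder_pkl).to(device).eval()

# Load one aligned face and map it to the encoder's expected input range (assumed).
img = PIL.Image.open('samples/test01.jpeg').convert('RGB').resize((256, 256), PIL.Image.LANCZOS)
x = torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).to(device)
x = x.to(torch.float32) / 127.5 - 1.0  # uint8 [0,255] -> float [-1,1]

with torch.no_grad():
    w = E(x)            # (1, 16, 512) latent in W+
    y = G.synthesis(w)  # (1, 3, 1024, 1024) reconstruction

# Convert back to uint8 the same way save_image() does and write the result.
out = (y.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
PIL.Image.fromarray(out, 'RGB').save('encoded_test01.png')
```
Note that the encoder only predicts the W+ latent; any image-space manipulation (such as the translated examples in the README) happens on the generator side, with the encoded latent unchanged.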