├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── watermark.jpg
├── conf
│   ├── README.md
│   ├── attn_maps.yaml
│   ├── config.yaml
│   ├── default_conf.yaml
│   ├── inference.yaml
│   ├── meta_conf.yaml
│   ├── runs
│   │   └── ffhq_256.yaml
│   └── tags
├── dataset
│   └── dataset.py
├── images
│   ├── architecture.png
│   ├── ffhq256.png
│   ├── fidparams.png
│   └── header.png
├── main.py
├── models
│   ├── basic_layers.py
│   ├── discriminator.py
│   ├── generator.py
│   └── stylenat.py
├── op
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── requirements.txt
├── src
│   ├── analysis.py
│   ├── evaluate.py
│   ├── inference.py
│   ├── logging.py
│   ├── throughput.py
│   └── train.py
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── gen_utils.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── filtered_lrelu.cpp
│   │   ├── filtered_lrelu.cu
│   │   ├── filtered_lrelu.h
│   │   ├── filtered_lrelu.py
│   │   ├── filtered_lrelu_ns.cu
│   │   ├── filtered_lrelu_rd.cu
│   │   ├── filtered_lrelu_wr.cu
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   ├── training_stats.py
│   └── utils_spectrum.py
└── utils
    ├── CRDiffAug.py
    ├── distributed.py
    ├── fid_score.py
    ├── helpers.py
    ├── improved_precision_recall.py
    └── inception.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __*__
2 | *.pyc
3 | *.sw?
4 | wandb/
5 | *.pth
6 | *.tar
7 | *.gz
8 | checkpoints/
9 | samples/
10 | tags
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2022] [SHI-Labs]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## StyleNAT: Giving Each Head a New Perspective
2 |
3 |
4 | [](https://paperswithcode.com/sota/image-generation-on-ffhq-256-x-256?p=stylenat-giving-each-head-a-new-perspective)
5 | [](https://paperswithcode.com/sota/image-generation-on-ffhq-1024-x-1024?p=stylenat-giving-each-head-a-new-perspective)
6 | [](https://paperswithcode.com/sota/image-generation-on-lsun-churches-256-x-256?p=stylenat-giving-each-head-a-new-perspective)
7 |
8 | ##### Authors: [Steven Walton](https://github.com/stevenwalton), [Ali Hassani](https://github.com/alihassanijr), [Xingqian Xu](https://github.com/xingqian2018), Zhangyang Wang, [Humphrey Shi](https://github.com/honghuis)
9 |
10 | ![header](images/header.png)
11 | StyleNAT is a Style-based GAN that exploits [Neighborhood
12 | Attention](https://github.com/SHI-Labs/Neighborhood-Attention-Transformer) to
13 | extend the power of localized attention heads to capture long range features and
14 | maximize information gain within the generative process.
15 | The flexibility of the system allows it to be adapted to various
16 | environments and datasets.
17 |
18 | ## Abstract:
19 | Image generation has been a long sought-after but challenging task, and performing the generation task in an efficient manner is similarly difficult.
20 | Often researchers attempt to create a "one size fits all" generator, where there are few differences in the parameter space for drastically different datasets.
21 | Herein, we present a new transformer-based framework, dubbed StyleNAT, targeting high-quality image generation with superior efficiency and flexibility.
22 | At the core of our model is a carefully designed framework that partitions attention heads to capture local and global information, which is achieved through using Neighborhood Attention (NA).
23 | With different heads able to pay attention to varying receptive fields, the model is able to better combine this information, and adapt, in a highly flexible manner, to the data at hand.
24 | StyleNAT attains a new SOTA FID score on FFHQ-256 with 2.046, beating prior arts with convolutional models such as StyleGAN-XL and transformers such as HIT and StyleSwin, and a new transformer SOTA on FFHQ-1024 with an FID score of 4.174.
25 | These results show a 6.4% improvement on FFHQ-256 scores when compared to StyleGAN-XL with a 28% reduction in the number of parameters and 56% improvement in sampling throughput.
26 |
27 | ## Architecture
28 | ![architecture](images/architecture.png)
29 |
30 | ## Performance
31 | 
32 |
33 | | Dataset | FID | Throughput (imgs/s) | Number of Parameters (M) |
34 | |:---:|:---:|:---:|:---:|
35 | | FFHQ 256 | [2.046](https://shi-labs.com/projects/stylenat/checkpoints/FFHQ256_940k_flip.pt) | 32.56 | 48.92 |
36 | | FFHQ 1024 | [4.174](https://shi-labs.com/projects/stylenat/checkpoints/FFHQ1024_700k.pt) | - | 49.45 |
37 | | Church 256 | 3.400 | - | - |
38 |
39 | ## Building and Using StyleNAT
40 | We recommend building an environment with conda to get the best performance. The
41 | following build instructions should work, but your mileage may vary.
42 | ```bash
43 | conda create --name stylenat python=3.10
44 | conda activate stylenat
45 | conda install pytorch torchvision cudatoolkit=11.6 -c pytorch -c nvidia
46 | # Use xargs to install lines one at a time since natten requires torch to be installed first
47 | cat requirements.txt | xargs -L1 pip install
48 | ```
49 | Note: some version issues can produce poor FIDs. Always check your build
50 | environment first with the `evaluate` method. With the best FFHQ checkpoint you
51 | should always get an FID below 2.10 (hopefully closer to 2.05).
52 |
53 | Note: You may need to install torch and torchvision first, since NATTEN depends on them
54 | at build time. Pip does not install requirements sequentially, so NATTEN may otherwise fail to build.
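For example, a minimal sketch of an install order that avoids this (assuming a pip-only setup; the conda route above is still the recommended one):
```bash
# Install the build dependency first, then the rest one package at a time.
pip install torch torchvision
cat requirements.txt | xargs -L1 pip install
```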
55 |
56 | Notes:
57 | - [NATTEN can be sped up by using pre-built wheels directly.](https://shi-labs.com/natten/)
58 |
59 | - Some arguments and configurations have changed slightly. Everything should be
60 | backwards compatible, but if something isn't, please open an issue.
61 |
62 | - This is research code, not production code. There are plenty of optimizations that
63 | could easily be implemented. We are also obsessive about logging information and
64 | storing it in checkpoints. Official checkpoints may not contain all of the information
65 | the current code tracks, due to ongoing research and development. The most important
66 | things should exist, but if you're missing something important, open an issue. Sorry,
67 | seeds and rng states are only available if they exist in the checkpoints.
68 |
69 | ## Inference
70 | Using Meta's hydra-core we can easily run any mode. For inference, simply run
71 | ```bash
72 | python main.py type=inference
73 | ```
74 | Note that the first time you run this it will take some time while upfirdn2d compiles.
75 |
76 | By default this will create 10 random inference images from the checkpoint, and
77 | each image will be saved under the name of its random seed.
78 |
79 | You can specify seeds by using
80 | ```bash
81 | python main.py type=inference inference.seeds=[1,2,3,4]
82 | ```
83 | If you would like to specify a set of seeds in a range, use the following command:
84 | `python main.py type=inference 'inference.seeds="range(start, stop, step)"'`
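For example, to render every other seed from 0 to 18:
```bash
python main.py type=inference 'inference.seeds="range(0,20,2)"'
```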
85 |
86 |
87 | ## Evaluation
88 | If you would like to check the performance of a model, we provide the evaluation
89 | mode. Simply run
90 | ```bash
91 | python main.py type=evaluate
92 | ```
93 | See the config file to set the proper dataset, checkpoint, etc.
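The relevant keys can also be overridden on the command line. A sketch (the checkpoint name and paths are placeholders taken from the example configs):
```bash
python main.py type=evaluate \
    restart.ckpt=FFHQ256_940k_flip.pt \
    evaluation.gt_path=/data/datasets/ffhq_256/images/ \
    evaluation.batch=4
```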
94 |
95 | ## Training
96 | If you would like to train a model from scratch we provide the following mode
97 | ```bash
98 | python main.py type=train restart.ckpt=null
99 | ```
100 | We suggest explicitly setting the checkpoint to null so that you don't
101 | accidentally load a checkpoint.
102 | It is also advised to create a new run file and call
103 | ```bash
104 | python main.py type=train restart.ckpt=null runs=my_new_run
105 | ```
106 | We also support distributed training. Simply use torchrun
107 | ```bash
108 | torchrun --nnodes=$NUM_NODES --nproc_per_node=$NUM_GPUS --node_rank=$NODE_RANK \
109 | main.py type=train
110 | ```
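For example, a single-node run on 8 GPUs might look like this (a sketch; set the values to match your machine):
```bash
torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 main.py type=train
```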
111 |
112 | ## Modifying Hydra-Configs
113 | The `conf` directory holds yaml configs for different types of runs. If you would
114 | like to adjust parameters (such as changing checkpoints, inference settings, number of
115 | images, specifying seeds, and so on) you should edit the relevant config file there. The `conf/runs` folder holds
116 | parameters for the model and training options. It is not advised to modify these
117 | files. It is better to copy them to a new file and use that if you wish to
118 | train a new model, as in the example below.
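A sketch of that workflow (`my_new_run` is a hypothetical name):
```bash
cp conf/runs/ffhq_256.yaml conf/runs/my_new_run.yaml
# edit conf/runs/my_new_run.yaml, then:
python main.py type=train restart.ckpt=null runs=my_new_run
```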
119 |
120 | ## "Secret" Hydra args
121 | There are a few hydra configs around wandb that aren't spelled out in the example
122 | files. We're just providing a simple version, but we also support `tags` and
123 | `description` under the `wandb` argument.
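For example, a sketch of a training run that enables wandb and attaches tags and a description (the `+` prefix adds keys that are not in the base config; see `conf/README.md`):
```bash
python main.py type=train logging.wandb=True \
    '+wandb.tags=[ffhq256,baseline]' \
    '+wandb.description="reproduction run with default config"'
```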
124 |
125 |
126 |
127 | ## Citation:
128 | ```bibtex
129 | @article{walton2022stylenat,
130 | title = {StyleNAT: Giving Each Head a New Perspective},
131 | author = {Steven Walton and Ali Hassani and Xingqian Xu and Zhangyang Wang and Humphrey Shi},
132 | year = 2022,
133 | url = {https://arxiv.org/abs/2211.05770},
134 | eprint = {2211.05770},
135 | archiveprefix = {arXiv},
136 | primaryclass = {cs.CV}
137 | }
138 | ```
139 |
140 | ## Acknowledgements
141 | This code heavily relies upon
142 | [StyleSwin](https://github.com/microsoft/StyleSwin) which also relies upon
143 | [rosinality's StyleGAN2-pytorch](https://github.com/rosinality/stylegan2-pytorch) library.
144 | We also utilize [mseitzer's pytorch-fid](https://github.com/mseitzer/pytorch-fid).
145 | Finally, we utilize SHI-Lab's [NATTEN](https://github.com/SHI-Labs/NATTEN/).
146 |
147 | We'd also like to thank Intelligence Advanced Research Projects Activity
148 | (IARPA), University of Oregon, University of Illinois at Urbana-Champaign, and
149 | Picsart AI Research (PAIR) for their generous support.
150 |
--------------------------------------------------------------------------------
/assets/watermark.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/assets/watermark.jpg
--------------------------------------------------------------------------------
/conf/README.md:
--------------------------------------------------------------------------------
1 | # Hydra-core configurations
2 |
3 | We use [hydra](https://hydra.cc/docs/intro/) to track our hyper-parameters.
4 | The `runs` directory contains settings for the runs that were used for each dataset.
5 |
6 | `meta_conf.yaml` is an explanation of all the possible arguments, their types, and
7 | what they are used for.
8 |
9 | `inference.yaml` is a simple inference example configuration.
10 |
11 | `config.yaml` is the bare minimum configuration for training.
12 |
13 | Here's a quick reference on hydra's override syntax
14 |
15 | Overwrite an existing value
16 | ```
17 | arg_key=new_value
18 | nested/key=new_value
19 | ```
20 |
21 | Add a new argument
22 | ```
23 | +new_arg=new_value
24 | +new/nested/arg=new_value
25 | ```
26 |
27 | Remove an argument
28 | ```
29 | ~arg_key
30 | ~arg_key=value
31 | ~nested/arg
32 | ~nested/arg=value
33 | ```
34 |
35 | Use a different base config file
36 | ```
37 | python main.py --config-name attn_map
38 | python main.py -cn inference
39 | ```
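These pieces compose. A sketch combining a base config with value overrides (keys taken from `inference.yaml`):
```
python main.py -cn inference restart.ckpt=FFHQ256_940k_flip.pt 'inference.seeds="range(0,4)"'
```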
40 |
--------------------------------------------------------------------------------
/conf/attn_maps.yaml:
--------------------------------------------------------------------------------
1 | type: attention_map
2 | device: cuda
3 | distributed: False
4 | save_root: my/storage/path/
5 |
6 | defaults:
7 | - _self_
8 | - runs: ffhq_256
9 |
10 | restart:
11 | ckpt: FFHQ256_940k_flip.pt
12 |
13 | evaluation:
14 | attn_map: True
15 | save_attn_map: True
16 | attn_map_path: remove/if/want/save_root/as/dir
17 | const_attn_seed: True
18 |
19 |
--------------------------------------------------------------------------------
/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default_conf
3 | - runs: ffhq_256
4 | - _self_
5 |
6 | type: train
7 | device: cuda
8 | latent: 4096
9 | world_size: 1
10 | rank: 0
11 | local_rank: 0
12 | distributed: False
13 | workers: 0
14 | save_root: results
15 |
16 | dataset:
17 | name: ffhq
18 | path: /data/datasets/ffhq_256
19 |
20 | inference:
21 | num_images: 10
22 | save_path: sample_images
23 | batch: 1
24 |
25 | logging:
26 | wandb: False
27 | log_img_batch: False
28 | print_freq: 1000
29 | eval_freq: 50000
30 | save_freq: 25000
31 | checkpoint_path: checkpoints
32 | sample_path: eval_samples
33 | reuse_samplepath: True
34 |
35 | evaluation:
36 | gt_path: /data/datasets/ffhq_256/images/
37 | num_batches: 12500
38 | total_size: 50000
39 | batch: 4
40 | attn_map: True
41 | save_attn_map: False
42 | attn_map_path: attn_maps
43 | const_attn_seed: True
44 |
45 | wandb:
46 | project_name: my_stylenat
47 | entity: my_wand_team
48 | run_name: ffhq_256_reproduce
49 |
--------------------------------------------------------------------------------
/conf/default_conf.yaml:
--------------------------------------------------------------------------------
1 | type: train
2 | device: cuda
3 | latent: 4096
4 | world_size: 1
5 | rank: 0
6 | local_rank: 0
7 | distributed: False
8 | workers: 0
9 | save_root: "/tmp/"
10 |
11 | inference:
12 | num_images: 1
13 | batch: 1
14 | save_path: "inference_samples"
15 |
16 | logging:
17 | print_freq: 1000
18 | eval_freq: 50000
19 | save_freq: 25000
20 | checkpoint_path: "checkpoints"
21 | sample_path: "samples"
22 | reuse_samplepath: True
23 | log_img_batch: False
24 | wandb: False
25 |
26 | evaluation:
27 | total_size: 50000
28 | num_batches: 12500
29 | batch: 4
30 | attn_map: True
31 |
32 | analysis:
33 | save_path: "attn_maps"
34 | const_attn_seed: True
35 | attn_seed: 0
36 |
37 | restart:
38 | wandb_resume: False
39 | reuse_rng: False
40 |
41 | throughput:
42 | rounds: 500
43 | warmup: 10
44 | batch_size: 1
45 |
46 | misc:
47 | seed: null
48 | rng_state: null
49 | py_rng_state: null
50 |
--------------------------------------------------------------------------------
/conf/inference.yaml:
--------------------------------------------------------------------------------
1 | type: inference
2 | device: cuda
3 | distributed: False
4 | save_root: storage/
5 |
6 | defaults:
7 | - _self_
8 | - runs: ffhq_256
9 |
10 | restart:
11 | ckpt: FFHQ256_940k_flip.pt
12 |
13 | inference:
14 | seeds: range(0,20)
15 | save_path: tmp
16 |
--------------------------------------------------------------------------------
/conf/meta_conf.yaml:
--------------------------------------------------------------------------------
1 | explanation_of_conf: |-
2 | This file is an explanation of the configuration YAML file
3 | and every variable that is potentially used within the code.
4 | Since hydra-core was made by META we'll be meta and write this meta conf.
5 | Each section will denote if it is required (req) or optional (opt).
6 | Each item in the section will similarly be noted and then given a type and
7 | description of what it does.
8 |
9 | type: (req:str) Type of job to perform. Supports {train,inference,evaluate}
10 | logging_level: (opt:str) python's logging level. Defaults to warning to have better readouts
11 | device: (opt:str) Which acceleration device to use. Supports {cuda,cpu}
12 | world_size: (opt:int) The world size for distributed training. This value will automatically be set in src/train.py. Default 1
13 | rank: (opt:int) The rank of your GPU. This value will automatically be set in src/train.py. Default 1
14 | local_rank: (opt:int) Automatically set. Default 0
15 | distributed: (opt:bool) Used in training. Bool to determine if using distributed training. Default is based on whether multiple GPUs are detected
16 | truncation: (opt:float) Unused. Optional truncation used when generating images. Edit eval to enable this
17 | workers: (opt:int) Number of workers to spawn for the dataloader. Default 0 (use this)
18 | save_root: (opt:str) root path for saving. If other paths don't start with / then we assume relative from here
19 |
20 | defaults:
21 | - runs: (req:str) Which yaml file to use from the run dir
22 | - _self_
23 |
24 | dataset:
25 | (req for all)
26 | name: (req:str) name of the dataset
27 | path: (req:str) /path/to/dataset/
28 | lmdb: (opt:bool) use the lmdb format overrides other options
29 |
30 | logging:
31 | print_freq: (opt:int) frequency to print to std default 1000
32 | eval_freq: (opt:int) frequency to evaluate, default -1
33 | save_freq: (opt:int) frequency to save model checkpoint
34 | checkpoint_path: (opt:str) where to save checkpoint, default /tmp
35 | sample_path: (opt:str) where to save samples, default /tmp
36 | reuse_samplepath: (opt:bool) write fid samples to same directory (save space)
37 | wandb: (opt:bool) enable wandb logging
38 | log_img_batch: (opt:bool) bool to log the first image batch
39 |
40 | evaluation:
41 | gt_path: (str) path to ground truth images/data
42 | total_size: (int) number of fid images, typically 50000
43 | batch: (opt:int) batch size for generator during fid calculation
44 | attn_map: (opt:bool) generate attention maps (if wandb enabled, we log)
45 | save_attn_map: (opt:bool) do you want to save the attention maps?
46 | attn_map_path: (opt:str) path to save attention maps if save_attn_map
47 | const_attn_seed: (opt:bool or int) Specify true or an integer seed
48 |
49 | inference: need either num_images or seeds
50 | num_images: (opt:int) number of images to sample
51 | seeds: (opt:list,range) list of seeds to sample. "range(a,b)" also accepted
52 | save_path: (str) where to save images relative to cwd
53 | batch: (opt:int) batch size of images to generate. Default 1
54 |
55 | restart: options to restart a run
56 | ckpt: (str) path to checkpoint
57 | wandb_resume: (opt:str) see wandb.init resume
58 | wandb_id: (opt:str) wandb run id
59 | start_iter: (opt: int) you can manually specify the iteration but we'll try to figure it out
60 | reuse_rng: (opt:bool) attempt to use the checkpoint's rng information
61 |
62 | misc:
63 | seed: (opt:int) random seed for run
64 | rng_state: (opt:tensor) (don't set!) the rng_state of the run
65 | py_rng_state: (opt:list) (don't set!) the python random rng state
66 |
67 | wandb:
68 | entity: (str) your username
69 | project_name: (str) name of wandb project
70 | run_name: (opt:str) name of your run
71 | tags: (opt:list) list of tags for the run
72 | description: (opt:str) description information for your run
73 |
74 |
75 |
--------------------------------------------------------------------------------
/conf/runs/ffhq_256.yaml:
--------------------------------------------------------------------------------
1 | size: 256
2 |
3 | generator:
4 | n_mlp: 8
5 | block_type: nat
6 | style_dim: 512
7 | lr: 0.00004
8 | channel_multiplier: 2
9 | mlp_ratio: 4
10 | use_checkpoint: False
11 | lr_mlp: 0.01
12 | enable_full_resolution: 8
13 | min_heads: 4
14 | qkv_bias: True
15 | qk_scale: null
16 | proj_drop: 0.
17 | attn_drop: 0.
18 | kernels: [[3],[7],[7],[7],[7],[7],[7],[7],[7]]
19 | dilations: [[1],[1],[1,2],[1,4],[1,8],[1,16],[1,32],[1,64],[1,128]]
20 | reg_every: null
21 | params: 0
22 |
23 | discriminator:
24 | lr: 0.0002
25 | channel_multiplier: 2
26 | blur_kernel: [1, 3, 3, 1]
27 | sn: True
28 | ssd: False
29 | reg_every: 16
30 | params: 0
31 |
32 | training:
33 | iter: 1000000
34 | batch: 8
35 | use_flip: True
36 | ttur: True
37 | r1: 10
38 | bcr: True
39 | bcr_fake_lambda: 10
40 | bcr_real_lambda: 10
41 | beta1: 0.0
42 | beta2: 0.99
43 | start_dim: 512
44 | workers: 8
45 | lr_decay: True
46 | lr_decay_start_steps: 775000
47 | gan_weight: 1
48 |
--------------------------------------------------------------------------------
/conf/tags:
--------------------------------------------------------------------------------
1 | !_TAG_FILE_SORTED 2 /0=unsorted, 1=sorted, 2=foldcase/
2 |
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 | from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
6 | from torch.utils.data.distributed import DistributedSampler
7 |
8 | from torchvision import transforms as T
9 | from torchvision import datasets
10 |
11 | from utils.distributed import get_rank
12 |
13 |
14 | class MultiResolutionDataset(Dataset):
15 | def __init__(self, path, transform, resolution=256):
16 | self.env = lmdb.open(
17 | path,
18 | max_readers=32,
19 | readonly=True,
20 | lock=False,
21 | readahead=False,
22 | meminit=False,
23 | )
24 |
25 | if not self.env:
26 | raise IOError('Cannot open lmdb dataset', path)
27 |
28 | with self.env.begin(write=False) as txn:
29 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
30 |
31 | self.resolution = resolution
32 | self.transform = transform
33 |
34 | def __len__(self):
35 | return self.length
36 |
37 | def __getitem__(self, index):
38 | with self.env.begin(write=False) as txn:
39 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
40 | img_bytes = txn.get(key)
41 |
42 | buffer = BytesIO(img_bytes)
43 | img = Image.open(buffer)
44 | img = self.transform(img)
45 |
46 | return img
47 |
48 | def unnormalize(image):
49 | if image.dim() == 4:
50 | image[:, 0, :, :] = image[:, 0, :, :] * 0.229 + 0.485
51 | image[:, 1, :, :] = image[:, 1, :, :] * 0.224 + 0.456
52 | image[:, 2, :, :] = image[:, 2, :, :] * 0.225 + 0.406
53 | elif image.dim() == 3:
54 | image[0, :, :] = image[0, :, :] * 0.229 + 0.485
55 | image[1, :, :] = image[1, :, :] * 0.224 + 0.456
56 | image[2, :, :] = image[2, :, :] * 0.225 + 0.406
57 | else:
58 | raise NotImplementedError(f"Can't handle image of dimension {image.dim()}, please use a 3 or 4 dimensional image")
59 | return image
60 |
61 | def get_dataset(args, evaluation=True):
62 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
63 | std=[0.229, 0.224, 0.225])
64 | # Resize then center crop. Make sure data integrity is good
65 | transforms = [T.Resize(args.runs.size),
66 | T.CenterCrop(args.runs.size),
67 | ]
68 | if args.runs.training.use_flip and not evaluation:
69 | transforms.append(T.RandomHorizontalFlip())
70 | transforms.append(T.ToTensor())
71 | transforms.append(normalize)
72 | transforms = T.Compose(transforms)
73 |
74 | if "lmdb" in args.dataset and args.dataset.lmdb:
75 | if get_rank() == 0:
76 | print(f"Using LMDB with {args.dataset.path}")
77 | dataset = MultiResolutionDataset(path=args.dataset.path,
78 | transform=transforms,
79 | resolution=args.runs.size)
80 | elif args.dataset.name in ["cifar10"]:
81 | if get_rank() == 0:
82 | print(f"Loading CIFAR-10")
83 | dataset = datasets.CIFAR10(root=args.dataset.path,
84 | transform=transforms)
85 | else:
86 | if get_rank() == 0:
87 | print(f"Loading ImageFolder dataset from {args.dataset.path}")
88 | dataset = datasets.ImageFolder(root=args.dataset.path,
89 | transform=transforms)
90 | return dataset
91 |
92 | def data_sampler(dataset, shuffle, distributed):
93 | if distributed:
94 | return DistributedSampler(dataset, shuffle=shuffle)
95 | if shuffle:
96 | return RandomSampler(dataset)
97 | else:
98 | return SequentialSampler(dataset)
99 |
100 | def get_loader(args, dataset, batch_size=1):
101 | loader = DataLoader(dataset,
102 | batch_size=batch_size,
103 | num_workers=args.workers,
104 | sampler=data_sampler(dataset,
105 | shuffle=True,
106 | distributed=args.distributed),
107 | drop_last=True,
108 | )
109 | return loader
110 |
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/architecture.png
--------------------------------------------------------------------------------
/images/ffhq256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/ffhq256.png
--------------------------------------------------------------------------------
/images/fidparams.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/fidparams.png
--------------------------------------------------------------------------------
/images/header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/header.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from rich import print
2 | import hydra
3 | import os
4 | import random
5 | import warnings
6 | from datetime import timedelta
7 | import logging
8 | import torch
9 | from omegaconf import OmegaConf, open_dict
10 | import natten
11 |
12 | from models.generator import Generator
13 | from utils.distributed import get_rank, synchronize, get_world_size
14 | from utils import helpers
15 |
16 | from src.train import train
17 | from src.inference import inference
18 | from src.evaluate import evaluate
19 | from src.analysis import visualize_attention
20 | from src.throughput import throughput
21 |
22 | torch.backends.cudnn.benchmark = True
23 | torch.backends.cuda.matmul.allow_tf32 = True
24 | torch.backends.cudnn.allow_tf32 = True
25 |
26 | @hydra.main(version_base=None, config_path="conf", config_name="config")
27 | def main(args):
28 | if "logging_level" in args:
29 | if type(args.logging_level) == str:
30 | _logging_level = {"DEBUG": logging.DEBUG, "INFO":logging.INFO,
31 | "WARNING": logging.WARNING, "ERROR": logging.ERROR,
32 | "CRITICAL": logging.CRITICAL}[args.logging_level.upper()]
33 | else:
34 | _logging_level = int(args.logging_level)
35 | logging.getLogger().setLevel(_logging_level)
36 | else:
37 | logging.getLogger().setLevel(logging.WARNING)
38 | helpers.validate_args(args)
39 | ckpt = None
40 | if "restart" in args and "ckpt" in args.restart and args.restart.ckpt:
41 | assert(os.path.exists(args.restart.ckpt)),f"Can't find a checkpoint "\
42 | f"at {args.restart.ckpt}"
43 | ckpt = torch.load(args.restart.ckpt, map_location=lambda storage, loc: storage)
44 | if "start_iter" not in args.restart:
45 | with open_dict(args):
46 | try:
47 | args.restart.start_iter = \
48 | int(os.path.basename(args.restart.ckpt)\
49 | .split(".pt")[0])
50 | except:
51 | args.restart.start_iter = 0
52 |
53 | helpers.rng_reproducibility(args, ckpt)
54 | #if "WORLD_SIZE" in os.environ:
55 | # # Single node multi GPU
56 | # n_gpu = int(os.environ["WORLD_SIZE"])
57 | #else:
58 | # n_gpu = torch.cuda.device_count()
59 | #args.distributed = n_gpu > 1
60 |
61 | if args.distributed:
62 | try:
63 | args.local_rank = int(os.environ["LOCAL_RANK"])
64 | except:
65 | args.local_rank = 0
66 | torch.cuda.set_device(args.local_rank)
67 | torch.distributed.init_process_group(backend="nccl",
68 | init_method="env://",
69 | timeout=timedelta(0, 180000))
70 | args.rank = get_rank()
71 | args.world_size = get_world_size()
72 | args.device = f"cuda:{args.local_rank}"
73 | torch.cuda.set_device(args.local_rank)
74 | synchronize()
75 |
76 | if get_rank() == 0:
77 | if args.save_root[-1] != "/": args.save_root += "/"
78 | if not os.path.exists(args.save_root):
79 | print(f"[bold yellow]WARNING:[/] Save root {args.save_root} path "\
80 | f"does not exist. Creating...")
81 | os.mkdir(args.save_root)
82 | if get_rank() == 0 and args.type in ['train']:
83 | # Make sample path
84 | if "sample_path" not in args.logging:
85 | samp_path = args.save_root + "samples"
86 | else:
87 | samp_path = args.logging.sample_path
88 | if args.logging.sample_path[0] != "/":
89 | samp_path = args.save_root + samp_path
90 | if not os.path.exists(samp_path):
91 | print(f"====> MAKING SAMPLE DIRECTORY: {samp_path}")
92 | os.mkdir(samp_path)
93 | # make checkpoint path
94 | if "checkpoint_path" not in args.logging:
95 | ckpt_path = args.save_root + "checkpoints"
96 | else:
97 | ckpt_path = args.logging.checkpoint_path
98 | if args.logging.checkpoint_path[0] != "/":
99 | ckpt_path = args.save_root + ckpt_path
100 | if not os.path.exists(ckpt_path):
101 | print(f"====> MAKING CHECKPOINT DIRECTORY: {ckpt_path}")
102 | os.mkdir(ckpt_path)
103 |
104 |
105 | # Only load gen if training, to save space
106 | if args.type == "train":
107 | generator = Generator(args=args.runs.generator, size=args.runs.size).to(args.device)
108 | g_ema = Generator(args=args.runs.generator, size=args.runs.size).to(args.device)
109 |
110 | if hasattr(g_ema, "num_params"):
111 | args.runs.generator.params = g_ema.num_params() / 1e6
112 | else:
113 | num_params = sum([m.numel() for m in g_ema.parameters()])
114 | if hasattr(args.runs.generator, "params"):
115 | args.runs.generator.params = num_params / 1e6
116 | else:
117 | with open_dict(args):
118 | args.runs.generator.params = num_params / 1e6
119 |
120 | # Load generator checkpoint
121 | if ckpt is not None:
122 | # Load generator checkpoints. But only load g if training
123 | if get_rank() == 0:
124 | print(f"Loading Generative Model")
125 | if 'state_dicts' in ckpt.keys():
126 | if args.type == "train":
127 | generator.load_state_dict(ckpt["state_dicts"]["g"])
128 | g_ema.load_state_dict(ckpt["state_dicts"]["g_ema"])
129 | elif set(['g', 'g_ema']).issubset(ckpt.keys()): # Old
130 | if args.type == "train":
131 | generator = Generator(args=args.runs.generator,
132 | size=args.runs.size, legacy=True).to(args.device)
133 | try:
134 | generator.load_state_dict(ckpt['g'])
135 | except Exception as e:
136 | print(e)
137 | print(f"[bold red]ERROR:[/] Failed to load checkpoint. " \
138 | f"Likely a mismatch between kernel size or dilation " \
139 | f"in the config and the checkpoint. "\
140 | f"Checkpoint has kernels {args.runs.generator.kernels}, " \
141 | f"and dilations {args.runs.generator.dilations}")
142 | exit(1)
143 | g_ema = Generator(args=args.runs.generator,
144 | size=args.runs.size, legacy=True).to(args.device)
145 | try:
146 | g_ema.load_state_dict(ckpt["g_ema"])
147 | except Exception as e:
148 | print(e)
149 | print(f"[bold red]ERROR:[/] Failed to load checkpoint. " \
150 | f"Likely a mismatch between kernel size or dilation " \
151 | f"in the config and the checkpoint. "\
152 | f"\nCheckpoint has " \
153 | f"\n\tkernels {args.runs.generator.kernels}, " \
154 | f"\n\tdilations {args.runs.generator.dilations}")
155 | exit(1)
156 | else:
157 | raise ValueError(f"Checkpoint dict broken:\n"\
158 | f"Checkpoint name: {args.restart.ckpt}\n"
159 | f"Keys: {ckpt.keys()}")
160 | g_ema.eval()
161 |
162 | # Print mode in a nice format
163 | if get_rank() == 0:
164 | print("\n" + ("=" * 50))
165 | print(f" Mode: {args.type} ".center(49, "="))
166 | print("=" * 50, "\n")
167 |
168 | if args.type == "train":
169 | train(args=args,
170 | generator=generator,
171 | g_ema=g_ema,
172 | ckpt=ckpt,
173 | )
174 | elif args.type == "inference":
175 | inference(args=args, generator=g_ema)
176 | elif args.type == "evaluate":
177 | evaluate(args=args, generator=g_ema)
178 | elif args.type == "attention_map":
179 | visualize_attention(args, g_ema,
180 | save_maps=args.analysis.save_path,
181 | )
182 | elif args.type == "throughput":
183 | throughput(generator=g_ema,
184 | style_dim=args.runs.generator.style_dim,
185 | batch_size=args.throughput.batch_size,
186 | rounds=args.throughput.rounds,
187 | warmup=args.throughput.warmup,
188 | )
189 |
190 |
191 | if __name__ == '__main__':
192 | main()
193 |
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | '''
2 | Most of the Modules from here come from rosinality's StyleGAN2-pytorch
3 | https://github.com/rosinality/stylegan2-pytorch
4 |
5 | Slight changes to these have been made to include Spectral Norm.
6 | StyleSwin changed the code to include this option.
7 | '''
8 | import math
9 | import torch
10 | from op import FusedLeakyReLU, upfirdn2d
11 | from torch import nn
12 | from torch.nn import functional as F
13 | from torch.nn.utils import spectral_norm
14 |
15 | from models.basic_layers import (Blur, Downsample, EqualConv2d, EqualLinear,
16 | ScaledLeakyReLU)
17 |
18 |
19 | class ConvLayer(nn.Sequential):
20 | def __init__(
21 | self,
22 | in_channel,
23 | out_channel,
24 | kernel_size,
25 | downsample=False,
26 | blur_kernel=[1, 3, 3, 1],
27 | bias=True,
28 | activate=True,
29 | sn=False
30 | ):
31 | layers = []
32 |
33 | if downsample:
34 | factor = 2
35 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
36 | pad0 = (p + 1) // 2
37 | pad1 = p // 2
38 |
39 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
40 |
41 | stride = 2
42 | self.padding = 0
43 |
44 | else:
45 | stride = 1
46 | self.padding = kernel_size // 2
47 |
48 | if sn:
49 | # Not use equal conv2d when apply SN
50 | layers.append(
51 | spectral_norm(nn.Conv2d(
52 | in_channel,
53 | out_channel,
54 | kernel_size,
55 | padding=self.padding,
56 | stride=stride,
57 | bias=bias and not activate,
58 | ))
59 | )
60 | else:
61 | layers.append(
62 | EqualConv2d(
63 | in_channel,
64 | out_channel,
65 | kernel_size,
66 | padding=self.padding,
67 | stride=stride,
68 | bias=bias and not activate,
69 | )
70 | )
71 |
72 | if activate:
73 | if bias:
74 | layers.append(FusedLeakyReLU(out_channel))
75 | else:
76 | layers.append(ScaledLeakyReLU(0.2))
77 |
78 | super().__init__(*layers)
79 |
80 |
81 | class ConvBlock(nn.Module):
82 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], sn=False):
83 | super().__init__()
84 |
85 | self.conv1 = ConvLayer(in_channel, in_channel, 3, sn=sn)
86 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, sn=sn)
87 |
88 | def forward(self, input):
89 | out = self.conv1(input)
90 | out = self.conv2(out)
91 |
92 | return out
93 |
94 |
95 | def get_haar_wavelet(in_channels):
96 | haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2)
97 | haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2)
98 | haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0]
99 |
100 | haar_wav_ll = haar_wav_l.T * haar_wav_l
101 | haar_wav_lh = haar_wav_h.T * haar_wav_l
102 | haar_wav_hl = haar_wav_l.T * haar_wav_h
103 | haar_wav_hh = haar_wav_h.T * haar_wav_h
104 |
105 | return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh
106 |
107 |
108 | class HaarTransform(nn.Module):
109 | def __init__(self, in_channels):
110 | super().__init__()
111 |
112 | ll, lh, hl, hh = get_haar_wavelet(in_channels)
113 |
114 | self.register_buffer('ll', ll)
115 | self.register_buffer('lh', lh)
116 | self.register_buffer('hl', hl)
117 | self.register_buffer('hh', hh)
118 |
119 | def forward(self, input):
120 | ll = upfirdn2d(input, self.ll, down=2)
121 | lh = upfirdn2d(input, self.lh, down=2)
122 | hl = upfirdn2d(input, self.hl, down=2)
123 | hh = upfirdn2d(input, self.hh, down=2)
124 |
125 | return torch.cat((ll, lh, hl, hh), 1)
126 |
127 |
128 | class InverseHaarTransform(nn.Module):
129 | def __init__(self, in_channels):
130 | super().__init__()
131 |
132 | ll, lh, hl, hh = get_haar_wavelet(in_channels)
133 |
134 | self.register_buffer('ll', ll)
135 | self.register_buffer('lh', -lh)
136 | self.register_buffer('hl', -hl)
137 | self.register_buffer('hh', hh)
138 |
139 | def forward(self, input):
140 | ll, lh, hl, hh = input.chunk(4, 1)
141 | ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0))
142 | lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0))
143 | hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0))
144 | hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0))
145 |
146 | return ll + lh + hl + hh
147 |
148 |
149 | class FromRGB(nn.Module):
150 | def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1], sn=False):
151 | super().__init__()
152 |
153 | self.downsample = downsample
154 |
155 | if downsample:
156 | self.iwt = InverseHaarTransform(3)
157 | self.downsample = Downsample(blur_kernel)
158 | self.dwt = HaarTransform(3)
159 |
160 | self.conv = ConvLayer(3 * 4, out_channel, 1, sn=sn)
161 |
162 | def forward(self, input, skip=None):
163 | if self.downsample:
164 | input = self.iwt(input)
165 | input = self.downsample(input)
166 | input = self.dwt(input)
167 |
168 | out = self.conv(input)
169 |
170 | if skip is not None:
171 | out = out + skip
172 |
173 | return input, out
174 |
175 |
176 | class Discriminator(nn.Module):
177 | def __init__(self,
178 | args,
179 | size,
180 | ):
181 | super().__init__()
182 |
183 | channels = {
184 | 4: 512,
185 | 8: 512,
186 | 16: 512,
187 | 32: 512,
188 | 64: 256 * args.channel_multiplier,
189 | 128: 128 * args.channel_multiplier,
190 | 256: 64 * args.channel_multiplier,
191 | 512: 32 * args.channel_multiplier,
192 | 1024: 16 * args.channel_multiplier,
193 | }
194 | self.size = size
195 | self.args = args
196 |
197 | self.dwt = HaarTransform(3)
198 |
199 | self.from_rgbs = nn.ModuleList()
200 | self.convs = nn.ModuleList()
201 |
202 | log_size = int(math.log(self.size, 2)) - 1
203 |
204 | in_channel = channels[self.size]
205 |
206 | for i in range(log_size, 2, -1):
207 | out_channel = channels[2 ** (i - 1)]
208 |
209 | self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size, sn=args.sn))
210 | self.convs.append(ConvBlock(in_channel, out_channel, args.blur_kernel, sn=args.sn))
211 |
212 | in_channel = out_channel
213 |
214 | self.from_rgbs.append(FromRGB(channels[4], sn=args.sn))
215 |
216 | self.stddev_group = 4
217 | self.stddev_feat = 1
218 |
219 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3, sn=args.sn)
220 | if args.sn:
221 | self.final_linear = nn.Sequential(
222 | spectral_norm(nn.Linear(channels[4] * 4 * 4, channels[4])),
223 | FusedLeakyReLU(channels[4]),
224 | spectral_norm(nn.Linear(channels[4], 1)),
225 | )
226 | else:
227 | self.final_linear = nn.Sequential(
228 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
229 | EqualLinear(channels[4], 1),
230 | )
231 |
232 | def forward(self, input):
233 | input = self.dwt(input)
234 | out = None
235 |
236 | for from_rgb, conv in zip(self.from_rgbs, self.convs):
237 | input, out = from_rgb(input, out)
238 | out = conv(out)
239 |
240 | _, out = self.from_rgbs[-1](input, out)
241 |
242 | batch, channel, height, width = out.shape
243 | group = min(batch, self.stddev_group)
244 | stddev = out.view(
245 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
246 | )
247 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
248 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
249 | stddev = stddev.repeat(group, 1, height, width)
250 | out = torch.cat([out, stddev], 1)
251 |
252 | out = self.final_conv(out)
253 |
254 | out = out.view(batch, -1)
255 | out = self.final_linear(out)
256 |
257 | return out
258 |
259 |
--------------------------------------------------------------------------------
/op/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
5 | from .upfirdn2d import upfirdn2d
6 |
7 |
--------------------------------------------------------------------------------
/op/fused_act.py:
--------------------------------------------------------------------------------
1 | # From StyleSwin
2 |
3 | import os
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from torch.autograd import Function
9 | from torch.utils.cpp_extension import load
10 |
11 | from torch.cuda.amp import custom_fwd, custom_bwd
12 |
13 | module_path = os.path.dirname(__file__)
14 | fused = load(
15 | "fused",
16 | sources=[
17 | os.path.join(module_path, "fused_bias_act.cpp"),
18 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
19 | ],
20 | )
21 |
22 |
23 | class FusedLeakyReLUFunctionBackward(Function):
24 | @staticmethod
25 | def forward(ctx, grad_output, out, negative_slope, scale):
26 | ctx.save_for_backward(out)
27 | ctx.negative_slope = negative_slope
28 | ctx.scale = scale
29 |
30 | empty = grad_output.new_empty(0)
31 |
32 | grad_input = fused.fused_bias_act(
33 | grad_output, empty, out, 3, 1, negative_slope, scale
34 | )
35 |
36 | dim = [0]
37 |
38 | if grad_input.ndim > 2:
39 | dim += list(range(2, grad_input.ndim))
40 |
41 | grad_bias = grad_input.sum(dim).detach()
42 |
43 | return grad_input, grad_bias
44 |
45 | @staticmethod
46 | def backward(ctx, gradgrad_input, gradgrad_bias):
47 | out, = ctx.saved_tensors
48 | gradgrad_out = fused.fused_bias_act(
49 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
50 | )
51 |
52 | return gradgrad_out, None, None, None
53 |
54 |
55 | class FusedLeakyReLUFunction(Function):
56 | @staticmethod
57 | @custom_fwd(cast_inputs=torch.float32)
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
61 | ctx.save_for_backward(out)
62 | ctx.negative_slope = negative_slope
63 | ctx.scale = scale
64 |
65 | return out
66 |
67 | @staticmethod
68 | @custom_bwd
69 | def backward(ctx, grad_output):
70 | out, = ctx.saved_tensors
71 |
72 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
73 | grad_output, out, ctx.negative_slope, ctx.scale
74 | )
75 |
76 | return grad_input, grad_bias, None, None
77 |
78 |
79 | class FusedLeakyReLU(nn.Module):
80 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
81 | super().__init__()
82 |
83 | self.bias = nn.Parameter(torch.zeros(channel))
84 | self.negative_slope = negative_slope
85 | self.scale = scale
86 |
87 | def forward(self, input):
88 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
89 |
90 |
91 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
92 | if input.device.type == "cpu":
93 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
94 | return (
95 | F.leaky_relu(
96 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
97 | )
98 | * scale
99 | )
100 |
101 | else:
102 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
103 |
--------------------------------------------------------------------------------
/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT License.
3 |
4 | #include <torch/extension.h>
5 |
6 |
7 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
8 | int act, int grad, float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13 |
14 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
15 | int act, int grad, float alpha, float scale) {
16 | CHECK_CUDA(input);
17 | CHECK_CUDA(bias);
18 |
19 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
20 | }
21 |
22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
23 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
24 | }
--------------------------------------------------------------------------------
/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT License.
3 |
4 | #include <torch/extension.h>
5 |
6 |
7 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
8 | int up_x, int up_y, int down_x, int down_y,
9 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
10 |
11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
14 |
15 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
16 | int up_x, int up_y, int down_x, int down_y,
17 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
18 | CHECK_CUDA(input);
19 | CHECK_CUDA(kernel);
20 |
21 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
22 | }
23 |
24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
25 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
26 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 |
6 | import torch
7 | from torch.nn import functional as F
8 | from torch.autograd import Function
9 | from torch.utils.cpp_extension import load
10 |
11 |
12 | module_path = os.path.dirname(__file__)
13 | upfirdn2d_op = load(
14 | "upfirdn2d",
15 | sources=[
16 | os.path.join(module_path, "upfirdn2d.cpp"),
17 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
18 | ],
19 | )
20 |
21 |
22 | class UpFirDn2dBackward(Function):
23 | @staticmethod
24 | def forward(
25 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
26 | ):
27 |
28 | up_x, up_y = up
29 | down_x, down_y = down
30 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
31 |
32 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
33 |
34 | grad_input = upfirdn2d_op.upfirdn2d(
35 | grad_output,
36 | grad_kernel,
37 | down_x,
38 | down_y,
39 | up_x,
40 | up_y,
41 | g_pad_x0,
42 | g_pad_x1,
43 | g_pad_y0,
44 | g_pad_y1,
45 | )
46 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
47 |
48 | ctx.save_for_backward(kernel)
49 |
50 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
51 |
52 | ctx.up_x = up_x
53 | ctx.up_y = up_y
54 | ctx.down_x = down_x
55 | ctx.down_y = down_y
56 | ctx.pad_x0 = pad_x0
57 | ctx.pad_x1 = pad_x1
58 | ctx.pad_y0 = pad_y0
59 | ctx.pad_y1 = pad_y1
60 | ctx.in_size = in_size
61 | ctx.out_size = out_size
62 |
63 | return grad_input
64 |
65 | @staticmethod
66 | def backward(ctx, gradgrad_input):
67 | kernel, = ctx.saved_tensors
68 |
69 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
70 |
71 | gradgrad_out = upfirdn2d_op.upfirdn2d(
72 | gradgrad_input,
73 | kernel,
74 | ctx.up_x,
75 | ctx.up_y,
76 | ctx.down_x,
77 | ctx.down_y,
78 | ctx.pad_x0,
79 | ctx.pad_x1,
80 | ctx.pad_y0,
81 | ctx.pad_y1,
82 | )
83 | gradgrad_out = gradgrad_out.view(
84 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
85 | )
86 |
87 | return gradgrad_out, None, None, None, None, None, None, None, None
88 |
89 |
90 | class UpFirDn2d(Function):
91 | @staticmethod
92 | def forward(ctx, input, kernel, up, down, pad):
93 | up_x, up_y = up
94 | down_x, down_y = down
95 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
96 |
97 | kernel_h, kernel_w = kernel.shape
98 | batch, channel, in_h, in_w = input.shape
99 | ctx.in_size = input.shape
100 |
101 | input = input.reshape(-1, in_h, in_w, 1)
102 |
103 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
104 |
105 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
106 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
107 | ctx.out_size = (out_h, out_w)
108 |
109 | ctx.up = (up_x, up_y)
110 | ctx.down = (down_x, down_y)
111 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
112 |
113 | g_pad_x0 = kernel_w - pad_x0 - 1
114 | g_pad_y0 = kernel_h - pad_y0 - 1
115 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
116 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
117 |
118 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
119 |
120 | out = upfirdn2d_op.upfirdn2d(
121 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
122 | )
123 | out = out.view(-1, channel, out_h, out_w)
124 |
125 | return out
126 |
127 | @staticmethod
128 | def backward(ctx, grad_output):
129 | kernel, grad_kernel = ctx.saved_tensors
130 |
131 | grad_input = UpFirDn2dBackward.apply(
132 | grad_output,
133 | kernel,
134 | grad_kernel,
135 | ctx.up,
136 | ctx.down,
137 | ctx.pad,
138 | ctx.g_pad,
139 | ctx.in_size,
140 | ctx.out_size,
141 | )
142 |
143 | return grad_input, None, None, None, None
144 |
145 |
146 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
147 | if input.device.type == "cpu":
148 | out = upfirdn2d_native(
149 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
150 | )
151 |
152 | else:
153 | out = UpFirDn2d.apply(
154 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
155 | )
156 |
157 | return out
158 |
159 |
160 | def upfirdn2d_native(
161 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
162 | ):
163 | _, channel, in_h, in_w = input.shape
164 | input = input.reshape(-1, in_h, in_w, 1)
165 |
166 | _, in_h, in_w, minor = input.shape
167 | kernel_h, kernel_w = kernel.shape
168 |
169 | out = input.view(-1, in_h, 1, in_w, 1, minor)
170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
172 |
173 | out = F.pad(
174 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
175 | )
176 | out = out[
177 | :,
178 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
179 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
180 | :,
181 | ]
182 |
183 | out = out.permute(0, 3, 1, 2)
184 | out = out.reshape(
185 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
186 | )
187 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
188 | out = F.conv2d(out, w)
189 | out = out.reshape(
190 | -1,
191 | minor,
192 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
193 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
194 | )
195 | out = out.permute(0, 2, 3, 1)
196 | out = out[:, ::down_y, ::down_x, :]
197 |
198 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
199 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
200 |
201 | return out.view(-1, channel, out_h, out_w)
202 |
--------------------------------------------------------------------------------
/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 | x.data_ptr<scalar_t>(),
317 | k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 | x.data_ptr<scalar_t>(),
325 | k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 | x.data_ptr<scalar_t>(),
333 | k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 | x.data_ptr<scalar_t>(),
341 | k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 | x.data_ptr<scalar_t>(),
349 | k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 | x.data_ptr<scalar_t>(),
357 | k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 | k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
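370 |
371 | // Dispatch note (descriptive only): e.g. up_x = up_y = 2, down_x = down_y = 1
372 | // with a 4x4 kernel satisfies the mode 3 condition above, so the
373 | // <scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> specialization is launched with a
374 | // 16x64 output tile; shapes not matched by modes 1-6 fall back to the
375 | // generic upfirdn2d_kernel_large path.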
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | lmdb
4 | timm
5 | scipy
6 | scikit-learn
7 | einops
8 | tqdm
9 | wandb
10 | ninja
11 | natten>=0.14.6
12 | hydra-core
13 | hydra_colorlog
14 | joblib
15 | dill
16 | imageio-ffmpeg
17 | ftfy
18 | matplotlib
19 | rich
20 |
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | from rich import print
2 | import time
3 | import os
4 | from tqdm import tqdm
5 | from joblib import Parallel, delayed
6 | import wandb
7 | import torch
8 | from torchvision import transforms as T
9 | from torchvision.utils import make_grid, save_image
10 |
11 | from dataset.dataset import unnormalize
12 | from utils import fid_score
13 | from utils.improved_precision_recall import IPR
14 | from utils.distributed import get_rank
15 |
16 | @torch.inference_mode()
17 | def save_images_batched(args, generator, steps=None, log_first_batch=True):
18 | if args.logging.sample_path[0] != "/":
19 | path = args.save_root + args.logging.sample_path
20 | else:
21 | path = args.logging.sample_path
22 | if steps is not None:
23 | path += f"eval_{str(steps)}"
24 | # We're only going to make the paths for training
25 | if not os.path.exists(path):
26 | os.mkdir(path)
27 | assert(args.evaluation.total_size % args.evaluation.batch == 0),\
28 | f"Evaluation total size % batch should be zero. Got "\
29 | f"{args.evaluation.total_size % args.evaluation.batch}"
30 | print(f"Saving {args.evaluation.total_size} "\
31 | f"Images to {path}")
32 | cnt = 0
33 | nbatches = args.evaluation.total_size // args.evaluation.batch
34 | for _ in tqdm(range(nbatches), desc="Saving Images"):
35 | noise = torch.randn((args.evaluation.batch, args.runs.generator.style_dim)).to(args.device)
36 | sample, _ = generator(noise)
37 | sample = unnormalize(sample)
38 | if args.logging.wandb and get_rank() == 0 and log_first_batch and cnt == 0:
39 | grid = make_grid(sample, nrow=args.evaluation.batch)
40 | wandb.log({"samples": wandb.Image(grid)}, commit=False)
41 | Parallel(n_jobs=args.evaluation.batch)(delayed(save_image)
42 | (img, f"{path}/{str(cnt+j).zfill(6)}.png",
43 | nrow=1, padding=0, normalize=True, value_range=(0,1),)
44 | for j, img in enumerate(sample))
45 | cnt += args.evaluation.batch
46 |
47 | @torch.inference_mode()
48 | def clear_directory(args):
49 | if args.logging.sample_path[0] != "/":
50 | path = args.save_root + args.logging.sample_path
51 | else:
52 | path = args.logging.sample_path
53 | if not os.path.exists(path):
54 | print(f"[bold yellow]WARNING:[/] (Evaluate) {path} does not exist. Creating...")
55 | os.mkdir(path)
56 | _files = os.listdir(path)
57 | if _files == []:
58 | print(f"Directory {path} is already empty. Worry if not first time")
59 | return
60 | assert(".png" in _files[0])
61 | Parallel(n_jobs=32)(delayed(os.remove)
62 | (path + img) for img in _files)
63 |
64 |
65 | @torch.inference_mode()
66 | def evaluate(args,
67 | generator, # Should be g_ema
68 | steps=None, # Save to specific eval dir or common
69 | log_first_batch=True, # Log first batch of images to wandb?
70 | ):
71 | #print(f" Parameters ".center(40, "-"))
72 | print(f"> Generator has:".ljust(19),f"{args.runs.generator.params:.4f} M Parameters")
73 | if not hasattr(args.logging, "sample_path"):
74 | path = args.save_root
75 | else:
76 | if args.logging.sample_path[0] != "/":
77 | path = args.save_root+ args.logging.sample_path
78 | else:
79 | path = args.logging.sample_path
80 | assert(type(path) == str),f"Path needs to be a string not {type(path)}"
81 | if path[-1] != "/": path = path + "/"
82 | if steps is not None:
83 | path += f"eval_{str(steps)}"
84 | # We're only going to make the paths for training
85 | if not os.path.exists(path):
86 | os.mkdir(path)
87 | # Save ALL the images
88 | if args.logging.reuse_samplepath:
89 | clear_directory(args)
90 | torch.cuda.synchronize()
91 | save_images_batched(args=args,
92 | generator=generator,
93 | steps=steps,
94 | log_first_batch=log_first_batch)
95 | if args.evaluation.gt_path[0] != "/":
96 | gt_path = args.dataset.path + args.evaluation.gt_path
97 | else:
98 | gt_path = args.evaluation.gt_path
99 | fid = fid_score.calculate_fid_given_paths([path, gt_path],
100 | batch_size=args.evaluation.batch,
101 | device=args.device,
102 | dims=2048,
103 | num_workers=args.workers,
104 | N=args.evaluation.total_size)
105 | print(f"{args.dataset.name} FID-{args.evaluation.total_size//1000}k : "\
106 | f"{fid:.3f} for {steps} steps")
107 | return fid
108 |
109 |
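110 | # Usage sketch (illustrative only; the attribute names are those read above,
111 | # the values are placeholders):
112 | #   args.evaluation.batch, args.evaluation.total_size  -> generation batch / FID sample count
113 | #   args.evaluation.gt_path                             -> directory of real images
114 | #   args.logging.sample_path                            -> where generated samples are written
115 | #   fid = evaluate(args, g_ema, steps=step)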
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | from rich import print
2 | import os
3 | import time
4 | from PIL import Image
5 | import torch
6 | from torchvision import transforms as T
7 | from torchvision.utils import make_grid, save_image
8 |
9 | from dataset.dataset import unnormalize
10 |
11 | toPIL = T.ToPILImage()
12 | toTensor = T.ToTensor()
13 |
14 | @torch.inference_mode()
15 | def add_watermark(image, im_size, watermark_path="assets/watermark.jpg",
16 | wmsize=16, bbuf=5, opacity=0.9):
17 | '''
18 | Creates a watermark on the saved inference image.
19 | We request that you do not remove this to properly assign credit to
20 | Shi-Lab's work.
21 | '''
22 | image = image.cpu()
23 | watermark = Image.open(watermark_path).resize((wmsize, wmsize))
24 | watermark = toTensor(watermark)
25 | loc = im_size - wmsize - bbuf
26 | image[:,:,loc:-bbuf, loc:-bbuf] = watermark
27 | return image
28 |
29 | def extract_range(args):
30 | '''
31 | Helper function to allow for a range of inputs from hydra-core
32 | This only works if you specify a range from the config file.
33 | seeds: range(start, end, step)
34 |
35 | If you are using the command line pass in a comma delimited sequence with
36 | `seq`
37 | python main.py type=inference inference.seeds=[`seq -s, start step_size end`]
38 | The back ticks are important because they run a bash command (a concrete usage sketch is given at the end of this file).
39 | '''
40 | start = 0
41 | step_size=1
42 | start_end = args.inference.seeds.split("range(")[1].split(")")[0]
43 | if "," in start_end:
44 | nums = start_end.split(',')
45 | match len(nums):
46 | case 1:
47 | end = nums[0]
48 | case 2:
49 | start = nums[0]
50 | end = nums[1]
51 | case 3:
52 | start = nums[0]
53 | end = nums[1]
54 | step_size = nums[2]
55 | case _:
56 | raise ValueError
57 | else:
58 | end = start_end
59 | args.inference.seeds = list(range(int(start), int(end), int(step_size)))
60 | print(f"Using Range of Seeds from {start} to {end} with a step size of {step_size}")
61 |
62 | @torch.inference_mode()
63 | def inference(args, generator):
64 | save_path = args.inference.save_path
65 | if save_path[0] != "/":
66 | save_path = args.save_root + save_path
67 | if not os.path.exists(save_path):
68 | print(f"[bold yellow]WARNING:[/] (Inference) {save_path} does not exist. Creating...")
69 | os.mkdir(save_path)
70 | assert('num_images' in args.inference or 'seeds' in args.inference),\
71 | f"Inference must either specify a number of images "\
72 | f"(random seed generation) or a set of seeds to use to generate."
73 | if 'num_images' in args.inference:
74 | num_imgs = args.inference.num_images
75 | print(f"Got {num_imgs=}, {args.inference.num_images=}")
76 | if "seeds" in args.inference and args.inference.seeds is not None:
77 | # Handles "range(start, end)" input from hydra file
78 | if "range" in args.inference.seeds:
79 | extract_range(args)
80 | num_imgs = len(args.inference.seeds)
81 | else:
82 | num_imgs = len(args.inference.seeds)
83 | if "num_images" in args.inference and args.inference.num_images != num_imgs:
84 | print(f"[bold red]WARNING:[/] You asked for "\
85 | f"{args.inference.num_images} image in the config but specified " \
86 | f"seeds. Seeds overrides and you will get {num_imgs} images.")
87 |
88 | for i in range(num_imgs):
89 | if "seeds" in args.inference and args.inference.seeds is not None and \
90 | i < len(args.inference.seeds):
91 | seed = args.inference.seeds[i]
92 | torch.random.manual_seed(seed)
93 | else:
94 | seed = torch.random.seed()
95 | print(f"Using Seed: {seed}")
96 |
97 | noise = torch.randn((args.inference.batch,
98 | args.runs.generator.style_dim)).to(args.device)
99 | sample, latent = generator(noise)
100 | sample = unnormalize(sample)
101 | sample = add_watermark(sample, im_size=args.runs.size)
102 | save_image(sample, f"{save_path}/{seed}.png",
103 | nrow=1, padding=0, normalize=True, value_range=(0,1))
104 | print(f"Saved {save_path}/{seed}.png")
105 |
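106 | # Usage sketch (illustrative commands; `main.py` is the hydra entry point at the
107 | # repo root, and the flags follow the docstring of extract_range above):
108 | #   python main.py type=inference inference.seeds=[0,1,2,3]
109 | #   python main.py type=inference inference.seeds=[`seq -s, 0 5 95`]  # 0,5,...,95 via bash
110 | #   python main.py type=inference inference.num_images=8              # random seeds instead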
--------------------------------------------------------------------------------
/src/logging.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/src/logging.py
--------------------------------------------------------------------------------
/src/throughput.py:
--------------------------------------------------------------------------------
1 | from rich import print
2 | import torch
3 |
4 | @torch.inference_mode()
5 | def throughput(generator,
6 | rounds=100,
7 | warmup=10,
8 | batch_size=1,
9 | style_dim=512,
10 | ):
11 | #torch.backends.cuda.matmul.allow_tf32 = True
12 | times = torch.empty(rounds)
13 | noise = torch.randn((batch_size, style_dim), device="cuda")
14 | for _ in range(warmup):
15 | imgs, _ = generator(noise)
16 | for i in range(rounds):
17 | starter = torch.cuda.Event(enable_timing=True)
18 | ender = torch.cuda.Event(enable_timing=True)
19 | starter.record()
20 | imgs, _ = generator(noise)
21 | ender.record()
22 | torch.cuda.synchronize()
23 | times[i] = starter.elapsed_time(ender)/1000
24 | #print(f"{imgs.shape=}")
25 | total_time = times.sum()
26 | total_images = rounds * batch_size
27 | imgs_per_second = total_images / total_time
28 | print(f"{torch.std_mean(times)=}")
29 | print(f"{total_images=}, {total_time=}")
30 | print(f"{imgs_per_second=}")
31 | #print(f"{torch.std_mean(imgs_per_second)=}")
32 |
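33 | # Usage sketch (illustrative; any generator that maps a (batch, style_dim) noise
34 | # tensor to images works, here assumed to already live on the CUDA device):
35 | #   g_ema = g_ema.to("cuda").eval()
36 | #   throughput(g_ema, rounds=100, warmup=10, batch_size=1, style_dim=512)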
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import glob
10 | import hashlib
11 | import importlib
12 | import os
13 | import re
14 | import shutil
15 | import uuid
16 |
17 | import torch
18 | import torch.utils.cpp_extension
19 | from torch.utils.file_baton import FileBaton
20 |
21 | #----------------------------------------------------------------------------
22 | # Global options.
23 |
24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25 |
26 | #----------------------------------------------------------------------------
27 | # Internal helper funcs.
28 |
29 | def _find_compiler_bindir():
30 | patterns = [
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
35 | ]
36 | for pattern in patterns:
37 | matches = sorted(glob.glob(pattern))
38 | if len(matches):
39 | return matches[-1]
40 | return None
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | def _get_mangled_gpu_name():
45 | name = torch.cuda.get_device_name().lower()
46 | out = []
47 | for c in name:
48 | if re.match('[a-z0-9_-]+', c):
49 | out.append(c)
50 | else:
51 | out.append('-')
52 | return ''.join(out)
53 |
54 | #----------------------------------------------------------------------------
55 | # Main entry point for compiling and loading C++/CUDA plugins.
56 |
57 | _cached_plugins = dict()
58 |
59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60 | assert verbosity in ['none', 'brief', 'full']
61 | if headers is None:
62 | headers = []
63 | if source_dir is not None:
64 | sources = [os.path.join(source_dir, fname) for fname in sources]
65 | headers = [os.path.join(source_dir, fname) for fname in headers]
66 |
67 | # Already cached?
68 | if module_name in _cached_plugins:
69 | return _cached_plugins[module_name]
70 |
71 | # Print status.
72 | if verbosity == 'full':
73 | print(f'Setting up PyTorch plugin "{module_name}"...')
74 | elif verbosity == 'brief':
75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76 | verbose_build = (verbosity == 'full')
77 |
78 | # Compile and load.
79 | try: # pylint: disable=too-many-nested-blocks
80 | # Make sure we can find the necessary compiler binaries.
81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82 | compiler_bindir = _find_compiler_bindir()
83 | if compiler_bindir is None:
84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85 | os.environ['PATH'] += ';' + compiler_bindir
86 |
87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88 | # break the build or unnecessarily restrict what's available to nvcc.
89 | # Unset it to let nvcc decide based on what's available on the
90 | # machine.
91 | os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92 |
93 | # Incremental build md5sum trickery. Copies all the input source files
94 | # into a cached build directory under a combined md5 digest of the input
95 | # source files. Copying is done only if the combined digest has changed.
96 | # This keeps input file timestamps and filenames the same as in previous
97 | # extension builds, allowing for fast incremental rebuilds.
98 | #
99 | # This optimization is done only in case all the source files reside in
100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101 | # environment variable is set (we take this as a signal that the user
102 | # actually cares about this.)
103 | #
104 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
105 | # around the *.cu dependency bug in ninja config.
106 | #
107 | all_source_files = sorted(sources + headers)
108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110 |
111 | # Compute combined hash digest for all source files.
112 | hash_md5 = hashlib.md5()
113 | for src in all_source_files:
114 | with open(src, 'rb') as f:
115 | hash_md5.update(f.read())
116 |
117 | # Select cached build directory name.
118 | source_digest = hash_md5.hexdigest()
119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121 |
122 | if not os.path.isdir(cached_build_dir):
123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124 | os.makedirs(tmpdir)
125 | for src in all_source_files:
126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127 | try:
128 | os.replace(tmpdir, cached_build_dir) # atomic
129 | except OSError:
130 | # source directory already exists, delete tmpdir and its contents.
131 | shutil.rmtree(tmpdir)
132 | if not os.path.isdir(cached_build_dir): raise
133 |
134 | # Compile.
135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137 | verbose=verbose_build, sources=cached_sources, **build_kwargs)
138 | else:
139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140 |
141 | # Load.
142 | module = importlib.import_module(module_name)
143 |
144 | except:
145 | if verbosity == 'brief':
146 | print('Failed!')
147 | raise
148 |
149 | # Print status and add to cache dict.
150 | if verbosity == 'full':
151 | print(f'Done setting up PyTorch plugin "{module_name}".')
152 | elif verbosity == 'brief':
153 | print('Done.')
154 | _cached_plugins[module_name] = module
155 | return module
156 |
157 | #----------------------------------------------------------------------------
158 |
--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import re
10 | import contextlib
11 | import numpy as np
12 | import torch
13 | import warnings
14 | import dnnlib
15 |
16 | #----------------------------------------------------------------------------
17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18 | # same constant is used multiple times.
19 |
20 | _constant_cache = dict()
21 |
22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23 | value = np.asarray(value)
24 | if shape is not None:
25 | shape = tuple(shape)
26 | if dtype is None:
27 | dtype = torch.get_default_dtype()
28 | if device is None:
29 | device = torch.device('cpu')
30 | if memory_format is None:
31 | memory_format = torch.contiguous_format
32 |
33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34 | tensor = _constant_cache.get(key, None)
35 | if tensor is None:
36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37 | if shape is not None:
38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39 | tensor = tensor.contiguous(memory_format=memory_format)
40 | _constant_cache[key] = tensor
41 | return tensor
42 |
43 | #----------------------------------------------------------------------------
44 | # Replace NaN/Inf with specified numerical values.
45 |
46 | try:
47 | nan_to_num = torch.nan_to_num # 1.8.0a0
48 | except AttributeError:
49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50 | assert isinstance(input, torch.Tensor)
51 | if posinf is None:
52 | posinf = torch.finfo(input.dtype).max
53 | if neginf is None:
54 | neginf = torch.finfo(input.dtype).min
55 | assert nan == 0
56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57 |
58 | #----------------------------------------------------------------------------
59 | # Symbolic assert.
60 |
61 | try:
62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63 | except AttributeError:
64 | symbolic_assert = torch.Assert # 1.7.0
65 |
66 | #----------------------------------------------------------------------------
67 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
68 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
69 |
70 | @contextlib.contextmanager
71 | def suppress_tracer_warnings():
72 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
73 | warnings.filters.insert(0, flt)
74 | yield
75 | warnings.filters.remove(flt)
76 |
77 | #----------------------------------------------------------------------------
78 | # Assert that the shape of a tensor matches the given list of integers.
79 | # None indicates that the size of a dimension is allowed to vary.
80 | # Performs symbolic assertion when used in torch.jit.trace().
81 |
82 | def assert_shape(tensor, ref_shape):
83 | if tensor.ndim != len(ref_shape):
84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86 | if ref_size is None:
87 | pass
88 | elif isinstance(ref_size, torch.Tensor):
89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91 | elif isinstance(size, torch.Tensor):
92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94 | elif size != ref_size:
95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96 |
97 | #----------------------------------------------------------------------------
98 | # Function decorator that calls torch.autograd.profiler.record_function().
99 |
100 | def profiled_function(fn):
101 | def decorator(*args, **kwargs):
102 | with torch.autograd.profiler.record_function(fn.__name__):
103 | return fn(*args, **kwargs)
104 | decorator.__name__ = fn.__name__
105 | return decorator
106 |
107 | #----------------------------------------------------------------------------
108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
109 | # indefinitely, shuffling items as it goes.
110 |
111 | class InfiniteSampler(torch.utils.data.Sampler):
112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113 | assert len(dataset) > 0
114 | assert num_replicas > 0
115 | assert 0 <= rank < num_replicas
116 | assert 0 <= window_size <= 1
117 | super().__init__(dataset)
118 | self.dataset = dataset
119 | self.rank = rank
120 | self.num_replicas = num_replicas
121 | self.shuffle = shuffle
122 | self.seed = seed
123 | self.window_size = window_size
124 |
125 | def __iter__(self):
126 | order = np.arange(len(self.dataset))
127 | rnd = None
128 | window = 0
129 | if self.shuffle:
130 | rnd = np.random.RandomState(self.seed)
131 | rnd.shuffle(order)
132 | window = int(np.rint(order.size * self.window_size))
133 |
134 | idx = 0
135 | while True:
136 | i = idx % order.size
137 | if idx % self.num_replicas == self.rank:
138 | yield order[i]
139 | if window >= 2:
140 | j = (i - rnd.randint(window)) % order.size
141 | order[i], order[j] = order[j], order[i]
142 | idx += 1
143 |
144 | #----------------------------------------------------------------------------
145 | # Utilities for operating with torch.nn.Module parameters and buffers.
146 | def spectral_to_cpu(model: torch.nn.Module):
147 | def wrapped_in_spectral(m): return hasattr(m, 'weight_v')
148 | children = get_children(model)
149 | for child in children:
150 | if wrapped_in_spectral(child):
151 | child.weight = child.weight.cpu()
152 | return model
153 |
154 | def get_children(model: torch.nn.Module):
155 | children = list(model.children())
156 | flatt_children = []
157 | if children == []:
158 | return model
159 | else:
160 | for child in children:
161 | try:
162 | flatt_children.extend(get_children(child))
163 | except TypeError:
164 | flatt_children.append(get_children(child))
165 | return flatt_children
166 |
167 | def params_and_buffers(module):
168 | assert isinstance(module, torch.nn.Module)
169 | return list(module.parameters()) + list(module.buffers())
170 |
171 | def named_params_and_buffers(module):
172 | assert isinstance(module, torch.nn.Module)
173 | return list(module.named_parameters()) + list(module.named_buffers())
174 |
175 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
176 | assert isinstance(src_module, torch.nn.Module)
177 | assert isinstance(dst_module, torch.nn.Module)
178 | src_tensors = dict(named_params_and_buffers(src_module))
179 | for name, tensor in named_params_and_buffers(dst_module):
180 | assert (name in src_tensors) or (not require_all)
181 | if name in src_tensors:
182 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
183 |
184 | #----------------------------------------------------------------------------
185 | # Context manager for easily enabling/disabling DistributedDataParallel
186 | # synchronization.
187 |
188 | @contextlib.contextmanager
189 | def ddp_sync(module, sync):
190 | assert isinstance(module, torch.nn.Module)
191 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
192 | yield
193 | else:
194 | with module.no_sync():
195 | yield
196 |
197 | #----------------------------------------------------------------------------
198 | # Check DistributedDataParallel consistency across processes.
199 |
200 | def check_ddp_consistency(module, ignore_regex=None):
201 | assert isinstance(module, torch.nn.Module)
202 | for name, tensor in named_params_and_buffers(module):
203 | fullname = type(module).__name__ + '.' + name
204 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
205 | continue
206 | tensor = tensor.detach()
207 | if tensor.is_floating_point():
208 | tensor = nan_to_num(tensor)
209 | other = tensor.clone()
210 | torch.distributed.broadcast(tensor=other, src=0)
211 | assert (tensor == other).all(), fullname
212 |
213 | #----------------------------------------------------------------------------
214 | # Print summary table of module hierarchy.
215 |
216 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
217 | assert isinstance(module, torch.nn.Module)
218 | assert not isinstance(module, torch.jit.ScriptModule)
219 | assert isinstance(inputs, (tuple, list))
220 |
221 | # Register hooks.
222 | entries = []
223 | nesting = [0]
224 | def pre_hook(_mod, _inputs):
225 | nesting[0] += 1
226 | def post_hook(mod, _inputs, outputs):
227 | nesting[0] -= 1
228 | if nesting[0] <= max_nesting:
229 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
230 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
231 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
232 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
233 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
234 |
235 | # Run module.
236 | outputs = module(*inputs)
237 | for hook in hooks:
238 | hook.remove()
239 |
240 | # Identify unique outputs, parameters, and buffers.
241 | tensors_seen = set()
242 | for e in entries:
243 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
244 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
245 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
246 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
247 |
248 | # Filter out redundant entries.
249 | if skip_redundant:
250 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
251 |
252 | # Construct table.
253 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
254 | rows += [['---'] * len(rows[0])]
255 | param_total = 0
256 | buffer_total = 0
257 | submodule_names = {mod: name for name, mod in module.named_modules()}
258 | for e in entries:
259 | name = '' if e.mod is module else submodule_names[e.mod]
260 | param_size = sum(t.numel() for t in e.unique_params)
261 | buffer_size = sum(t.numel() for t in e.unique_buffers)
262 | output_shapes = [str(list(t.shape)) for t in e.outputs]
263 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
264 | rows += [[
265 | name + (':0' if len(e.outputs) >= 2 else ''),
266 | str(param_size) if param_size else '-',
267 | str(buffer_size) if buffer_size else '-',
268 | (output_shapes + ['-'])[0],
269 | (output_dtypes + ['-'])[0],
270 | ]]
271 | for idx in range(1, len(e.outputs)):
272 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
273 | param_total += param_size
274 | buffer_total += buffer_size
275 | rows += [['---'] * len(rows[0])]
276 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
277 |
278 | # Print table.
279 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
280 | print()
281 | for row in rows:
282 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
283 | print()
284 | return outputs
285 |
286 | #----------------------------------------------------------------------------
287 |
288 | # Added by Katja
289 | import os
290 |
291 | def get_ckpt_path(run_dir):
292 | return os.path.join(run_dir, f'network-snapshot.pkl')
293 |
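294 | # Usage sketches (illustrative only, exercising the helpers defined above):
295 | #   assert_shape(x, [None, 3, 256, 256])   # batch dimension free, others must match
296 | #   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=world_size, seed=0)
297 | #   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size)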
--------------------------------------------------------------------------------
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel<scalar_t>(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double> { typedef double scalar_t; };
17 | template <> struct InternalType<float> { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import numpy as np
13 | import torch
14 | import dnnlib
15 |
16 | from .. import custom_ops
17 | from .. import misc
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | activation_funcs = {
22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31 | }
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | _plugin = None
36 | _null_tensor = torch.empty([0])
37 |
38 | def _init():
39 | global _plugin
40 | if _plugin is None:
41 | _plugin = custom_ops.get_plugin(
42 | module_name='bias_act_plugin',
43 | sources=['bias_act.cpp', 'bias_act.cu'],
44 | headers=['bias_act.h'],
45 | source_dir=os.path.dirname(__file__),
46 | extra_cuda_cflags=['--use_fast_math'],
47 | )
48 | return True
49 |
50 | #----------------------------------------------------------------------------
51 |
52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53 | r"""Fused bias and activation function.
54 |
55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56 | and scales the result by `gain`. Each of the steps is optional. In most cases,
57 | the fused op is considerably more efficient than performing the same calculation
58 | using standard PyTorch ops. It supports first and second order gradients,
59 | but not third order gradients.
60 |
61 | Args:
62 | x: Input activation tensor. Can be of any shape.
63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64 | as `x`. The shape must be known, and it must match the dimension of `x`
65 | corresponding to `dim`.
66 | dim: The dimension in `x` corresponding to the elements of `b`.
67 | The value of `dim` is ignored if `b` is not specified.
68 | act: Name of the activation function to evaluate, or `"linear"` to disable.
69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70 | See `activation_funcs` for a full list. `None` is not allowed.
71 | alpha: Shape parameter for the activation function, or `None` to use the default.
72 | gain: Scaling factor for the output tensor, or `None` to use default.
73 | See `activation_funcs` for the default scaling of each activation function.
74 | If unsure, consider specifying 1.
75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76 | the clamping (default).
77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78 |
79 | Returns:
80 | Tensor of the same shape and datatype as `x`.
81 | """
82 | assert isinstance(x, torch.Tensor)
83 | assert impl in ['ref', 'cuda']
84 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87 |
88 | #----------------------------------------------------------------------------
89 |
90 | @misc.profiled_function
91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
93 | """
94 | assert isinstance(x, torch.Tensor)
95 | assert clamp is None or clamp >= 0
96 | spec = activation_funcs[act]
97 | alpha = float(alpha if alpha is not None else spec.def_alpha)
98 | gain = float(gain if gain is not None else spec.def_gain)
99 | clamp = float(clamp if clamp is not None else -1)
100 |
101 | # Add bias.
102 | if b is not None:
103 | assert isinstance(b, torch.Tensor) and b.ndim == 1
104 | assert 0 <= dim < x.ndim
105 | assert b.shape[0] == x.shape[dim]
106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107 |
108 | # Evaluate activation function.
109 | alpha = float(alpha)
110 | x = spec.func(x, alpha=alpha)
111 |
112 | # Scale by gain.
113 | gain = float(gain)
114 | if gain != 1:
115 | x = x * gain
116 |
117 | # Clamp.
118 | if clamp >= 0:
119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120 | return x
121 |
122 | #----------------------------------------------------------------------------
123 |
124 | _bias_act_cuda_cache = dict()
125 |
126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127 | """Fast CUDA implementation of `bias_act()` using custom ops.
128 | """
129 | # Parse arguments.
130 | assert clamp is None or clamp >= 0
131 | spec = activation_funcs[act]
132 | alpha = float(alpha if alpha is not None else spec.def_alpha)
133 | gain = float(gain if gain is not None else spec.def_gain)
134 | clamp = float(clamp if clamp is not None else -1)
135 |
136 | # Lookup from cache.
137 | key = (dim, act, alpha, gain, clamp)
138 | if key in _bias_act_cuda_cache:
139 | return _bias_act_cuda_cache[key]
140 |
141 | # Forward op.
142 | class BiasActCuda(torch.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, x, b): # pylint: disable=arguments-differ
145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
146 | x = x.contiguous(memory_format=ctx.memory_format)
147 | b = b.contiguous() if b is not None else _null_tensor
148 | y = x
149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
151 | ctx.save_for_backward(
152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
154 | y if 'y' in spec.ref else _null_tensor)
155 | return y
156 |
157 | @staticmethod
158 | def backward(ctx, dy): # pylint: disable=arguments-differ
159 | dy = dy.contiguous(memory_format=ctx.memory_format)
160 | x, b, y = ctx.saved_tensors
161 | dx = None
162 | db = None
163 |
164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
165 | dx = dy
166 | if act != 'linear' or gain != 1 or clamp >= 0:
167 | dx = BiasActCudaGrad.apply(dy, x, b, y)
168 |
169 | if ctx.needs_input_grad[1]:
170 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
171 |
172 | return dx, db
173 |
174 | # Backward op.
175 | class BiasActCudaGrad(torch.autograd.Function):
176 | @staticmethod
177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
180 | ctx.save_for_backward(
181 | dy if spec.has_2nd_grad else _null_tensor,
182 | x, b, y)
183 | return dx
184 |
185 | @staticmethod
186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
188 | dy, x, b, y = ctx.saved_tensors
189 | d_dy = None
190 | d_x = None
191 | d_b = None
192 | d_y = None
193 |
194 | if ctx.needs_input_grad[0]:
195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
196 |
197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
199 |
200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
202 |
203 | return d_dy, d_x, d_b, d_y
204 |
205 | # Add to cache.
206 | _bias_act_cuda_cache[key] = BiasActCuda
207 | return BiasActCuda
208 |
209 | #----------------------------------------------------------------------------
210 |
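
A minimal usage sketch for `bias_act()` as documented above. It is not part of the repository: the shapes are made up, it forces the pure-PyTorch reference path with `impl='ref'` so no CUDA plugin build is needed, and it assumes the repo's `dnnlib` dependency is importable.

```python
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 512, 8, 8)   # [batch, channels, height, width]
b = torch.zeros(512)            # one bias element per channel (dim=1)

# Fused bias + leaky ReLU, using the default gain of sqrt(2) and clamping to [-256, 256].
y = bias_act.bias_act(x, b, dim=1, act='lrelu', clamp=256, impl='ref')
print(y.shape)                  # torch.Size([4, 512, 8, 8])
```
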
--------------------------------------------------------------------------------
/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import contextlib
13 | import torch
14 |
15 | # pylint: disable=redefined-builtin
16 | # pylint: disable=arguments-differ
17 | # pylint: disable=protected-access
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | enabled = False # Enable the custom op by setting this to true.
22 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
23 |
24 | @contextlib.contextmanager
25 | def no_weight_gradients(disable=True):
26 | global weight_gradients_disabled
27 | old = weight_gradients_disabled
28 | if disable:
29 | weight_gradients_disabled = True
30 | yield
31 | weight_gradients_disabled = old
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36 | if _should_use_custom_op(input):
37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39 |
40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41 | if _should_use_custom_op(input):
42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | def _should_use_custom_op(input):
48 | assert isinstance(input, torch.Tensor)
49 | if (not enabled) or (not torch.backends.cudnn.enabled):
50 | return False
51 | if input.device.type != 'cuda':
52 | return False
53 | return True
54 |
55 | def _tuple_of_ints(xs, ndim):
56 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
57 | assert len(xs) == ndim
58 | assert all(isinstance(x, int) for x in xs)
59 | return xs
60 |
61 | #----------------------------------------------------------------------------
62 |
63 | _conv2d_gradfix_cache = dict()
64 | _null_tensor = torch.empty([0])
65 |
66 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
67 | # Parse arguments.
68 | ndim = 2
69 | weight_shape = tuple(weight_shape)
70 | stride = _tuple_of_ints(stride, ndim)
71 | padding = _tuple_of_ints(padding, ndim)
72 | output_padding = _tuple_of_ints(output_padding, ndim)
73 | dilation = _tuple_of_ints(dilation, ndim)
74 |
75 | # Lookup from cache.
76 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
77 | if key in _conv2d_gradfix_cache:
78 | return _conv2d_gradfix_cache[key]
79 |
80 | # Validate arguments.
81 | assert groups >= 1
82 | assert len(weight_shape) == ndim + 2
83 | assert all(stride[i] >= 1 for i in range(ndim))
84 | assert all(padding[i] >= 0 for i in range(ndim))
85 | assert all(dilation[i] >= 0 for i in range(ndim))
86 | if not transpose:
87 | assert all(output_padding[i] == 0 for i in range(ndim))
88 | else: # transpose
89 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
90 |
91 | # Helpers.
92 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
93 | def calc_output_padding(input_shape, output_shape):
94 | if transpose:
95 | return [0, 0]
96 | return [
97 | input_shape[i + 2]
98 | - (output_shape[i + 2] - 1) * stride[i]
99 | - (1 - 2 * padding[i])
100 | - dilation[i] * (weight_shape[i + 2] - 1)
101 | for i in range(ndim)
102 | ]
103 |
104 | # Forward & backward.
105 | class Conv2d(torch.autograd.Function):
106 | @staticmethod
107 | def forward(ctx, input, weight, bias):
108 | assert weight.shape == weight_shape
109 | ctx.save_for_backward(
110 | input if weight.requires_grad else _null_tensor,
111 | weight if input.requires_grad else _null_tensor,
112 | )
113 | ctx.input_shape = input.shape
114 |
115 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
116 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
117 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
118 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
119 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
120 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
121 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
122 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
123 |
124 | # General case => cuDNN.
125 | if transpose:
126 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
127 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
128 |
129 | @staticmethod
130 | def backward(ctx, grad_output):
131 | input, weight = ctx.saved_tensors
132 | input_shape = ctx.input_shape
133 | grad_input = None
134 | grad_weight = None
135 | grad_bias = None
136 |
137 | if ctx.needs_input_grad[0]:
138 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
139 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
140 | grad_input = op.apply(grad_output, weight, None)
141 | assert grad_input.shape == input_shape
142 |
143 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
144 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
145 | assert grad_weight.shape == weight_shape
146 |
147 | if ctx.needs_input_grad[2]:
148 | grad_bias = grad_output.sum([0, 2, 3])
149 |
150 | return grad_input, grad_weight, grad_bias
151 |
152 | # Gradient with respect to the weights.
153 | class Conv2dGradWeight(torch.autograd.Function):
154 | @staticmethod
155 | def forward(ctx, grad_output, input):
156 | ctx.save_for_backward(
157 | grad_output if input.requires_grad else _null_tensor,
158 | input if grad_output.requires_grad else _null_tensor,
159 | )
160 | ctx.grad_output_shape = grad_output.shape
161 | ctx.input_shape = input.shape
162 |
163 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
164 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
165 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
166 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
167 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
168 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
169 |
170 | # General case => cuDNN.
171 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
172 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
173 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
174 |
175 | @staticmethod
176 | def backward(ctx, grad2_grad_weight):
177 | grad_output, input = ctx.saved_tensors
178 | grad_output_shape = ctx.grad_output_shape
179 | input_shape = ctx.input_shape
180 | grad2_grad_output = None
181 | grad2_input = None
182 |
183 | if ctx.needs_input_grad[0]:
184 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
185 | assert grad2_grad_output.shape == grad_output_shape
186 |
187 | if ctx.needs_input_grad[1]:
188 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
189 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
190 | grad2_input = op.apply(grad_output, grad2_grad_weight, None)
191 | assert grad2_input.shape == input_shape
192 |
193 | return grad2_grad_output, grad2_input
194 |
195 | _conv2d_gradfix_cache[key] = Conv2d
196 | return Conv2d
197 |
198 | #----------------------------------------------------------------------------
199 |
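
A small usage sketch of the drop-in `conv2d()` replacement and the `no_weight_gradients()` context manager defined above. The shapes are illustrative; on CPU, or with `enabled = False` (the default), the calls simply fall through to `torch.nn.functional.conv2d`.

```python
import torch
from torch_utils.ops import conv2d_gradfix

x = torch.randn(2, 16, 32, 32, requires_grad=True)
w = torch.randn(32, 16, 3, 3, requires_grad=True)

# Drop-in replacement for F.conv2d; identical call signature and output.
y = conv2d_gradfix.conv2d(x, w, stride=1, padding=1)
print(y.shape)  # torch.Size([2, 32, 32, 32])

# When the custom op is active (enabled=True and a CUDA input), this context manager
# skips the weight-gradient computation, e.g. during a lazy regularization pass.
with conv2d_gradfix.no_weight_gradients():
    conv2d_gradfix.conv2d(x, w, padding=1).sum().backward()
```
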
--------------------------------------------------------------------------------
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | if not flip_weight and (kw > 1 or kh > 1):
37 | w = w.flip([2, 3])
38 |
39 | # Execute using conv2d_gradfix.
40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41 | return op(x, w, stride=stride, padding=padding, groups=groups)
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | @misc.profiled_function
46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47 | r"""2D convolution with optional up/downsampling.
48 |
49 | Padding is performed only once at the beginning, not between the operations.
50 |
51 | Args:
52 | x: Input tensor of shape
53 | `[batch_size, in_channels, in_height, in_width]`.
54 | w: Weight tensor of shape
55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57 | calling upfirdn2d.setup_filter(). None = identity (default).
58 | up: Integer upsampling factor (default: 1).
59 | down: Integer downsampling factor (default: 1).
60 | padding: Padding with respect to the upsampled image. Can be a single number
61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62 | (default: 0).
63 | groups: Split input channels into N groups (default: 1).
64 | flip_weight: False = convolution, True = correlation (default: True).
65 | flip_filter: False = convolution, True = correlation (default: False).
66 |
67 | Returns:
68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69 | """
70 | # Validate arguments.
71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74 | assert isinstance(up, int) and (up >= 1)
75 | assert isinstance(down, int) and (down >= 1)
76 | assert isinstance(groups, int) and (groups >= 1)
77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78 | fw, fh = _get_filter_size(f)
79 | px0, px1, py0, py1 = _parse_padding(padding)
80 |
81 | # Adjust padding to account for up/downsampling.
82 | if up > 1:
83 | px0 += (fw + up - 1) // 2
84 | px1 += (fw - up) // 2
85 | py0 += (fh + up - 1) // 2
86 | py1 += (fh - up) // 2
87 | if down > 1:
88 | px0 += (fw - down + 1) // 2
89 | px1 += (fw - down) // 2
90 | py0 += (fh - down + 1) // 2
91 | py1 += (fh - down) // 2
92 |
93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97 | return x
98 |
99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103 | return x
104 |
105 | # Fast path: downsampling only => use strided convolution.
106 | if down > 1 and up == 1:
107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109 | return x
110 |
111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112 | if up > 1:
113 | if groups == 1:
114 | w = w.transpose(0, 1)
115 | else:
116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117 | w = w.transpose(1, 2)
118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119 | px0 -= kw - 1
120 | px1 -= kw - up
121 | py0 -= kh - 1
122 | py1 -= kh - up
123 | pxt = max(min(-px0, -px1), 0)
124 | pyt = max(min(-py0, -py1), 0)
125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127 | if down > 1:
128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129 | return x
130 |
131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132 | if up == 1 and down == 1:
133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135 |
136 | # Fallback: Generic reference implementation.
137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139 | if down > 1:
140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141 | return x
142 |
143 | #----------------------------------------------------------------------------
144 |
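
A usage sketch for `conv2d_resample()` combining a 3x3 convolution with 2x upsampling. It assumes the `setup_filter()` helper from the repo's `torch_utils/ops/upfirdn2d.py` (as in the upstream NVIDIA code) and that dependencies such as `dnnlib` are importable; the shapes and filter taps are illustrative only.

```python
import torch
from torch_utils.ops import upfirdn2d, conv2d_resample

x = torch.randn(1, 8, 16, 16)             # [batch, in_channels, height, width]
w = torch.randn(4, 8, 3, 3)               # [out_channels, in_channels // groups, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])  # separable low-pass filter for resampling

# 3x3 convolution fused with 2x upsampling; `padding` is relative to the upsampled grid.
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                            # torch.Size([1, 4, 32, 32])
```
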
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct filtered_lrelu_kernel_params
15 | {
16 | // These parameters decide which kernel to use.
17 | int up; // upsampling ratio (1, 2, 4)
18 | int down; // downsampling ratio (1, 2, 4)
19 | int2 fuShape; // [size, 1] | [size, size]
20 | int2 fdShape; // [size, 1] | [size, size]
21 |
22 | int _dummy; // Alignment.
23 |
24 | // Rest of the parameters.
25 | const void* x; // Input tensor.
26 | void* y; // Output tensor.
27 | const void* b; // Bias tensor.
28 | unsigned char* s; // Sign tensor in/out. NULL if unused.
29 | const float* fu; // Upsampling filter.
30 | const float* fd; // Downsampling filter.
31 |
32 | int2 pad0; // Left/top padding.
33 | float gain; // Additional gain factor.
34 | float slope; // Leaky ReLU slope on negative side.
35 | float clamp; // Clamp after nonlinearity.
36 | int flip; // Filter kernel flip for gradient computation.
37 |
38 | int tilesXdim; // Original number of horizontal output tiles.
39 | int tilesXrep; // Number of horizontal tiles per CTA.
40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 |
42 | int4 xShape; // [width, height, channel, batch]
43 | int4 yShape; // [width, height, channel, batch]
44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 | int swLimit; // Active width of sign tensor in bytes.
47 |
48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49 | longlong4 yStride; //
50 | int64_t bStride; //
51 | longlong3 fuStride; //
52 | longlong3 fdStride; //
53 | };
54 |
55 | struct filtered_lrelu_act_kernel_params
56 | {
57 | void* x; // Input/output, modified in-place.
58 | unsigned char* s; // Sign tensor in/out. NULL if unused.
59 |
60 | float gain; // Additional gain factor.
61 | float slope; // Leaky ReLU slope on negative side.
62 | float clamp; // Clamp after nonlinearity.
63 |
64 | int4 xShape; // [width, height, channel, batch]
65 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68 | };
69 |
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 |
73 | struct filtered_lrelu_kernel_spec
74 | {
75 | void* setup; // Function for filter kernel setup.
76 | void* exec; // Function for main operation.
77 | int2 tileOut; // Width/height of launch tile.
78 | int numWarps; // Number of warps per thread block, determines launch block size.
79 | int xrep; // For processing multiple horizontal tiles per thread block.
80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81 | };
82 |
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 |
86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89 |
90 | //------------------------------------------------------------------------
91 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
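
A brief sketch of `fma()` above; the shapes are arbitrary and chosen to show that gradients are un-broadcast back to the addend's original shape.

```python
import torch
from torch_utils.ops import fma

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)   # broadcast against a * b

out = fma.fma(a, b, c)                   # a * b + c via torch.addcmul
out.sum().backward()

print(torch.allclose(out, a * b + c))    # True
print(c.grad.shape)                      # torch.Size([8]) -- gradient un-broadcast to c's shape
```
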
--------------------------------------------------------------------------------
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import torch
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 |
24 | #----------------------------------------------------------------------------
25 |
26 | def grid_sample(input, grid):
27 | if _should_use_custom_op():
28 | return _GridSample2dForward.apply(input, grid)
29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
30 |
31 | #----------------------------------------------------------------------------
32 |
33 | def _should_use_custom_op():
34 | return enabled
35 |
36 | #----------------------------------------------------------------------------
37 |
38 | class _GridSample2dForward(torch.autograd.Function):
39 | @staticmethod
40 | def forward(ctx, input, grid):
41 | assert input.ndim == 4
42 | assert grid.ndim == 4
43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
44 | ctx.save_for_backward(input, grid)
45 | return output
46 |
47 | @staticmethod
48 | def backward(ctx, grad_output):
49 | input, grid = ctx.saved_tensors
50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
51 | return grad_input, grad_grid
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | class _GridSample2dBackward(torch.autograd.Function):
56 | @staticmethod
57 | def forward(ctx, grad_output, input, grid):
58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
60 | ctx.save_for_backward(grid)
61 | return grad_input, grad_grid
62 |
63 | @staticmethod
64 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
65 | _ = grad2_grad_grid # unused
66 | grid, = ctx.saved_tensors
67 | grad2_grad_output = None
68 | grad2_input = None
69 | grad2_grid = None
70 |
71 | if ctx.needs_input_grad[0]:
72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
73 |
74 | assert not ctx.needs_input_grad[2]
75 | return grad2_grad_output, grad2_input, grad2_grid
76 |
77 | #----------------------------------------------------------------------------
78 |
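
A sketch of the `grid_sample()` wrapper above, sampling an image with an identity grid. The input tensor and grid construction are illustrative; with `enabled = False` (the default) this is plain `F.grid_sample`, while enabling the custom op additionally permits second-order gradients through the sampler.

```python
import torch
from torch_utils.ops import grid_sample_gradfix

image = torch.randn(1, 3, 64, 64, requires_grad=True)

# Identity sampling grid in normalized [-1, 1] coordinates (align_corners=False).
theta = torch.eye(2, 3).unsqueeze(0)
grid = torch.nn.functional.affine_grid(theta, size=(1, 3, 64, 64), align_corners=False)

out = grid_sample_gradfix.grid_sample(image, grid)
print(out.shape)   # torch.Size([1, 3, 64, 64])
```
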
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.numel() > 0, "x has zero size");
25 | TORCH_CHECK(f.numel() > 0, "f has zero size");
26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32 |
33 | // Create output tensor.
34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41 |
42 | // Initialize CUDA kernel parameters.
43 | upfirdn2d_kernel_params p;
44 | p.x = x.data_ptr();
45 | p.f = f.data_ptr();
46 | p.y = y.data_ptr();
47 | p.up = make_int2(upx, upy);
48 | p.down = make_int2(downx, downy);
49 | p.pad0 = make_int2(padx0, pady0);
50 | p.flip = (flip) ? 1 : 0;
51 | p.gain = gain;
52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60 |
61 | // Choose CUDA kernel.
62 | upfirdn2d_kernel_spec spec;
63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64 | {
65 | spec = choose_upfirdn2d_kernel<scalar_t>(p);
66 | });
67 |
68 | // Set looping options.
69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70 | p.loopMinor = spec.loopMinor;
71 | p.loopX = spec.loopX;
72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74 |
75 | // Compute grid size.
76 | dim3 blockSize, gridSize;
77 | if (spec.tileOutW < 0) // large
78 | {
79 | blockSize = dim3(4, 32, 1);
80 | gridSize = dim3(
81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83 | p.launchMajor);
84 | }
85 | else // small
86 | {
87 | blockSize = dim3(256, 1, 1);
88 | gridSize = dim3(
89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91 | p.launchMajor);
92 | }
93 |
94 | // Launch CUDA kernel.
95 | void* args[] = {&p};
96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97 | return y;
98 | }
99 |
100 | //------------------------------------------------------------------------
101 |
102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103 | {
104 | m.def("upfirdn2d", &upfirdn2d);
105 | }
106 |
107 | //------------------------------------------------------------------------
108 |
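
The output-size formula used by the C++ op above, restated as a small Python check (the example parameters are arbitrary).

```python
# Mirrors the outW/outH computation in upfirdn2d.cpp:
#   out = (in * up + pad0 + pad1 - filter_size + down) // down
def upfirdn2d_out_size(in_size, up, down, pad0, pad1, filter_size):
    return (in_size * up + pad0 + pad1 - filter_size + down) // down

# 2x upsampling of a 16-pixel axis with a 4-tap filter and padding (2, 1) -> 32 pixels.
print(upfirdn2d_out_size(16, up=2, down=1, pad0=2, pad1=1, filter_size=4))  # 32
```
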
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for pickling Python code alongside other data.
10 |
11 | The pickled code is automatically imported into a separate Python module
12 | during unpickling. This way, any previously exported pickles will remain
13 | usable even if the original code is no longer available, or if the current
14 | version of the code is not consistent with what was originally pickled."""
15 |
16 | import sys
17 | import pickle
18 | import io
19 | import inspect
20 | import copy
21 | import uuid
22 | import types
23 | import dnnlib
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | _version = 6 # internal version number
28 | _decorators = set() # {decorator_class, ...}
29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...}
31 | _src_to_module_dict = dict() # {src: module, ...}
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def persistent_class(orig_class):
36 | r"""Class decorator that extends a given class to save its source code
37 | when pickled.
38 |
39 | Example:
40 |
41 | from torch_utils import persistence
42 |
43 | @persistence.persistent_class
44 | class MyNetwork(torch.nn.Module):
45 | def __init__(self, num_inputs, num_outputs):
46 | super().__init__()
47 | self.fc = MyLayer(num_inputs, num_outputs)
48 | ...
49 |
50 | @persistence.persistent_class
51 | class MyLayer(torch.nn.Module):
52 | ...
53 |
54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55 | source code alongside other internal state (e.g., parameters, buffers,
56 | and submodules). This way, any previously exported pickle will remain
57 | usable even if the class definitions have been modified or are no
58 | longer available.
59 |
60 | The decorator saves the source code of the entire Python module
61 | containing the decorated class. It does *not* save the source code of
62 | any imported modules. Thus, the imported modules must be available
63 | during unpickling, also including `torch_utils.persistence` itself.
64 |
65 | It is ok to call functions defined in the same module from the
66 | decorated class. However, if the decorated class depends on other
67 | classes defined in the same module, they must be decorated as well.
68 | This is illustrated in the above example in the case of `MyLayer`.
69 |
70 | It is also possible to employ the decorator just-in-time before
71 | calling the constructor. For example:
72 |
73 | cls = MyLayer
74 | if want_to_make_it_persistent:
75 | cls = persistence.persistent_class(cls)
76 | layer = cls(num_inputs, num_outputs)
77 |
78 | As an additional feature, the decorator also keeps track of the
79 | arguments that were used to construct each instance of the decorated
80 | class. The arguments can be queried via `obj.init_args` and
81 | `obj.init_kwargs`, and they are automatically pickled alongside other
82 | object state. A typical use case is to first unpickle a previous
83 | instance of a persistent class, and then upgrade it to use the latest
84 | version of the source code:
85 |
86 | with open('old_pickle.pkl', 'rb') as f:
87 | old_net = pickle.load(f)
88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 | """
91 | assert isinstance(orig_class, type)
92 | if is_persistent(orig_class):
93 | return orig_class
94 |
95 | assert orig_class.__module__ in sys.modules
96 | orig_module = sys.modules[orig_class.__module__]
97 | orig_module_src = _module_to_src(orig_module)
98 |
99 | class Decorator(orig_class):
100 | _orig_module_src = orig_module_src
101 | _orig_class_name = orig_class.__name__
102 |
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 | self._init_args = copy.deepcopy(args)
106 | self._init_kwargs = copy.deepcopy(kwargs)
107 | assert orig_class.__name__ in orig_module.__dict__
108 | _check_pickleable(self.__reduce__())
109 |
110 | @property
111 | def init_args(self):
112 | return copy.deepcopy(self._init_args)
113 |
114 | @property
115 | def init_kwargs(self):
116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 |
118 | def __reduce__(self):
119 | fields = list(super().__reduce__())
120 | fields += [None] * max(3 - len(fields), 0)
121 | if fields[0] is not _reconstruct_persistent_obj:
122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 | fields[0] = _reconstruct_persistent_obj # reconstruct func
124 | fields[1] = (meta,) # reconstruct args
125 | fields[2] = None # state dict
126 | return tuple(fields)
127 |
128 | Decorator.__name__ = orig_class.__name__
129 | _decorators.add(Decorator)
130 | return Decorator
131 |
132 | #----------------------------------------------------------------------------
133 |
134 | def is_persistent(obj):
135 | r"""Test whether the given object or class is persistent, i.e.,
136 | whether it will save its source code when pickled.
137 | """
138 | try:
139 | if obj in _decorators:
140 | return True
141 | except TypeError:
142 | pass
143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 |
145 | #----------------------------------------------------------------------------
146 |
147 | def import_hook(hook):
148 | r"""Register an import hook that is called whenever a persistent object
149 | is being unpickled. A typical use case is to patch the pickled source
150 | code to avoid errors and inconsistencies when the API of some imported
151 | module has changed.
152 |
153 | The hook should have the following signature:
154 |
155 | hook(meta) -> modified meta
156 |
157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 |
159 | type: Type of the persistent object, e.g. `'class'`.
160 | version: Internal version number of `torch_utils.persistence`.
161 | module_src: Original source code of the Python module.
162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object.
164 |
165 | Example:
166 |
167 | @persistence.import_hook
168 | def wreck_my_network(meta):
169 | if meta.class_name == 'MyNetwork':
170 | print('MyNetwork is being imported. I will wreck it!')
171 | meta.module_src = meta.module_src.replace("True", "False")
172 | return meta
173 | """
174 | assert callable(hook)
175 | _import_hooks.append(hook)
176 |
177 | #----------------------------------------------------------------------------
178 |
179 | def _reconstruct_persistent_obj(meta):
180 | r"""Hook that is called internally by the `pickle` module to unpickle
181 | a persistent object.
182 | """
183 | meta = dnnlib.EasyDict(meta)
184 | meta.state = dnnlib.EasyDict(meta.state)
185 | for hook in _import_hooks:
186 | meta = hook(meta)
187 | assert meta is not None
188 |
189 | assert meta.version == _version
190 | module = _src_to_module(meta.module_src)
191 |
192 | assert meta.type == 'class'
193 | orig_class = module.__dict__[meta.class_name]
194 | decorator_class = persistent_class(orig_class)
195 | obj = decorator_class.__new__(decorator_class)
196 |
197 | setstate = getattr(obj, '__setstate__', None)
198 | if callable(setstate):
199 | setstate(meta.state) # pylint: disable=not-callable
200 | else:
201 | obj.__dict__.update(meta.state)
202 | return obj
203 |
204 | #----------------------------------------------------------------------------
205 |
206 | def _module_to_src(module):
207 | r"""Query the source code of a given Python module.
208 | """
209 | src = _module_to_src_dict.get(module, None)
210 | if src is None:
211 | src = inspect.getsource(module)
212 | _module_to_src_dict[module] = src
213 | _src_to_module_dict[src] = module
214 | return src
215 |
216 | def _src_to_module(src):
217 | r"""Get or create a Python module for the given source code.
218 | """
219 | module = _src_to_module_dict.get(src, None)
220 | if module is None:
221 | module_name = "_imported_module_" + uuid.uuid4().hex
222 | module = types.ModuleType(module_name)
223 | sys.modules[module_name] = module
224 | _module_to_src_dict[module] = src
225 | _src_to_module_dict[src] = module
226 | exec(src, module.__dict__) # pylint: disable=exec-used
227 | return module
228 |
229 | #----------------------------------------------------------------------------
230 |
231 | def _check_pickleable(obj):
232 | r"""Check that the given object is pickleable, raising an exception if
233 | it is not. This function is expected to be considerably more efficient
234 | than actually pickling the object.
235 | """
236 | def recurse(obj):
237 | if isinstance(obj, (list, tuple, set)):
238 | return [recurse(x) for x in obj]
239 | if isinstance(obj, dict):
240 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 | return None # Python primitive types are pickleable.
243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
244 | return None # NumPy arrays and PyTorch tensors are pickleable.
245 | if is_persistent(obj):
246 | return None # Persistent objects are pickleable, by virtue of the constructor check.
247 | return obj
248 | with io.BytesIO() as f:
249 | pickle.dump(recurse(obj), f)
250 |
251 | #----------------------------------------------------------------------------
252 |
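
A round-trip sketch for `persistent_class` above. It assumes the code is run from a file (so `inspect.getsource` can capture the defining module) and that the repo's `dnnlib` dependency is importable; the `TinyNet` class is made up for illustration.

```python
import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):
    def __init__(self, width):
        super().__init__()
        self.fc = torch.nn.Linear(width, width)

net = TinyNet(width=8)
blob = pickle.dumps(net)       # the source of this module is embedded alongside the weights

restored = pickle.loads(blob)  # usable even if the class definition later changes
print(restored.init_kwargs)    # {'width': 8} -- constructor arguments were recorded
```
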
--------------------------------------------------------------------------------
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for reporting and collecting training statistics across
10 | multiple processes and devices. The interface is designed to minimize
11 | synchronization overhead as well as the amount of boilerplate in user
12 | code."""
13 |
14 | import re
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | from . import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
26 | _rank = 0 # Rank of the current process.
27 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
28 | _sync_called = False # Has _sync() been called yet?
29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def init_multiprocessing(rank, sync_device):
35 | r"""Initializes `torch_utils.training_stats` for collecting statistics
36 | across multiple processes.
37 |
38 | This function must be called after
39 | `torch.distributed.init_process_group()` and before `Collector.update()`.
40 | The call is not necessary if multi-process collection is not needed.
41 |
42 | Args:
43 | rank: Rank of the current process.
44 | sync_device: PyTorch device to use for inter-process
45 | communication, or None to disable multi-process
46 | collection. Typically `torch.device('cuda', rank)`.
47 | """
48 | global _rank, _sync_device
49 | assert not _sync_called
50 | _rank = rank
51 | _sync_device = sync_device
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | @misc.profiled_function
56 | def report(name, value):
57 | r"""Broadcasts the given set of scalars to all interested instances of
58 | `Collector`, across device and process boundaries.
59 |
60 | This function is expected to be extremely cheap and can be safely
61 | called from anywhere in the training loop, loss function, or inside a
62 | `torch.nn.Module`.
63 |
64 | Warning: The current implementation expects the set of unique names to
65 | be consistent across processes. Please make sure that `report()` is
66 | called at least once for each unique name by each process, and in the
67 | same order. If a given process has no scalars to broadcast, it can do
68 | `report(name, [])` (empty list).
69 |
70 | Args:
71 | name: Arbitrary string specifying the name of the statistic.
72 | Averages are accumulated separately for each unique name.
73 | value: Arbitrary set of scalars. Can be a list, tuple,
74 | NumPy array, PyTorch tensor, or Python scalar.
75 |
76 | Returns:
77 | The same `value` that was passed in.
78 | """
79 | if name not in _counters:
80 | _counters[name] = dict()
81 |
82 | elems = torch.as_tensor(value)
83 | if elems.numel() == 0:
84 | return value
85 |
86 | elems = elems.detach().flatten().to(_reduce_dtype)
87 | moments = torch.stack([
88 | torch.ones_like(elems).sum(),
89 | elems.sum(),
90 | elems.square().sum(),
91 | ])
92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
93 | moments = moments.to(_counter_dtype)
94 |
95 | device = moments.device
96 | if device not in _counters[name]:
97 | _counters[name][device] = torch.zeros_like(moments)
98 | _counters[name][device].add_(moments)
99 | return value
100 |
101 | #----------------------------------------------------------------------------
102 |
103 | def report0(name, value):
104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105 | but ignores any scalars provided by the other processes.
106 | See `report()` for further details.
107 | """
108 | report(name, value if _rank == 0 else [])
109 | return value
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | class Collector:
114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
115 | computes their long-term averages (mean and standard deviation) over
116 | user-defined periods of time.
117 |
118 | The averages are first collected into internal counters that are not
119 | directly visible to the user. They are then copied to the user-visible
120 | state as a result of calling `update()` and can then be queried using
121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122 | internal counters for the next round, so that the user-visible state
123 | effectively reflects averages collected between the last two calls to
124 | `update()`.
125 |
126 | Args:
127 | regex: Regular expression defining which statistics to
128 | collect. The default is to collect everything.
129 | keep_previous: Whether to retain the previous averages if no
130 | scalars were collected on a given round
131 | (default: True).
132 | """
133 | def __init__(self, regex='.*', keep_previous=True):
134 | self._regex = re.compile(regex)
135 | self._keep_previous = keep_previous
136 | self._cumulative = dict()
137 | self._moments = dict()
138 | self.update()
139 | self._moments.clear()
140 |
141 | def names(self):
142 | r"""Returns the names of all statistics broadcasted so far that
143 | match the regular expression specified at construction time.
144 | """
145 | return [name for name in _counters if self._regex.fullmatch(name)]
146 |
147 | def update(self):
148 | r"""Copies current values of the internal counters to the
149 | user-visible state and resets them for the next round.
150 |
151 | If `keep_previous=True` was specified at construction time, the
152 | operation is skipped for statistics that have received no scalars
153 | since the last update, retaining their previous averages.
154 |
155 | This method performs a number of GPU-to-CPU transfers and one
156 | `torch.distributed.all_reduce()`. It is intended to be called
157 | periodically in the main training loop, typically once every
158 | N training steps.
159 | """
160 | if not self._keep_previous:
161 | self._moments.clear()
162 | for name, cumulative in _sync(self.names()):
163 | if name not in self._cumulative:
164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165 | delta = cumulative - self._cumulative[name]
166 | self._cumulative[name].copy_(cumulative)
167 | if float(delta[0]) != 0:
168 | self._moments[name] = delta
169 |
170 | def _get_delta(self, name):
171 | r"""Returns the raw moments that were accumulated for the given
172 | statistic between the last two calls to `update()`, or zero if
173 | no scalars were collected.
174 | """
175 | assert self._regex.fullmatch(name)
176 | if name not in self._moments:
177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178 | return self._moments[name]
179 |
180 | def num(self, name):
181 | r"""Returns the number of scalars that were accumulated for the given
182 | statistic between the last two calls to `update()`, or zero if
183 | no scalars were collected.
184 | """
185 | delta = self._get_delta(name)
186 | return int(delta[0])
187 |
188 | def mean(self, name):
189 | r"""Returns the mean of the scalars that were accumulated for the
190 | given statistic between the last two calls to `update()`, or NaN if
191 | no scalars were collected.
192 | """
193 | delta = self._get_delta(name)
194 | if int(delta[0]) == 0:
195 | return float('nan')
196 | return float(delta[1] / delta[0])
197 |
198 | def std(self, name):
199 | r"""Returns the standard deviation of the scalars that were
200 | accumulated for the given statistic between the last two calls to
201 | `update()`, or NaN if no scalars were collected.
202 | """
203 | delta = self._get_delta(name)
204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205 | return float('nan')
206 | if int(delta[0]) == 1:
207 | return float(0)
208 | mean = float(delta[1] / delta[0])
209 | raw_var = float(delta[2] / delta[0])
210 | return np.sqrt(max(raw_var - np.square(mean), 0))
211 |
212 | def as_dict(self):
213 | r"""Returns the averages accumulated between the last two calls to
214 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
215 |
216 | dnnlib.EasyDict(
217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218 | ...
219 | )
220 | """
221 | stats = dnnlib.EasyDict()
222 | for name in self.names():
223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224 | return stats
225 |
226 | def __getitem__(self, name):
227 | r"""Convenience getter.
228 | `collector[name]` is a synonym for `collector.mean(name)`.
229 | """
230 | return self.mean(name)
231 |
232 | #----------------------------------------------------------------------------
233 |
234 | def _sync(names):
235 | r"""Synchronize the global cumulative counters across devices and
236 | processes. Called internally by `Collector.update()`.
237 | """
238 | if len(names) == 0:
239 | return []
240 | global _sync_called
241 | _sync_called = True
242 |
243 | # Collect deltas within current rank.
244 | deltas = []
245 | device = _sync_device if _sync_device is not None else torch.device('cpu')
246 | for name in names:
247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248 | for counter in _counters[name].values():
249 | delta.add_(counter.to(device))
250 | counter.copy_(torch.zeros_like(counter))
251 | deltas.append(delta)
252 | deltas = torch.stack(deltas)
253 |
254 | # Sum deltas across ranks.
255 | if _sync_device is not None:
256 | torch.distributed.all_reduce(deltas)
257 |
258 | # Update cumulative values.
259 | deltas = deltas.cpu()
260 | for idx, name in enumerate(names):
261 | if name not in _cumulative:
262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263 | _cumulative[name].add_(deltas[idx])
264 |
265 | # Return name-value pairs.
266 | return [(name, _cumulative[name]) for name in names]
267 |
268 | #----------------------------------------------------------------------------
269 |
--------------------------------------------------------------------------------
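A minimal single-process sketch of the `report()` / `Collector` API defined above. The import path (`torch_utils.training_stats`, with the repo root on PYTHONPATH) and the loss values are illustrative assumptions; `init_multiprocessing()` is skipped because no process group is initialized.

    import torch
    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')    # only track Loss/* statistics

    for step in range(1, 101):
        g_loss = torch.rand([8])                              # placeholder per-sample losses
        training_stats.report('Loss/G', g_loss)
        training_stats.report0('Loss/lr', 2e-3)               # only rank 0 contributes this value
        if step % 50 == 0:
            collector.update()                                # sync counters, reset the averaging window
            print(step, collector.mean('Loss/G'), collector.std('Loss/G'))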
/torch_utils/utils_spectrum.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.fft import fftn
3 |
4 |
5 | def roll_quadrants(data, backwards=False):
6 |     """
7 |     Shift low frequencies to the center of the Fourier transform (like fftshift), i.e. reorder the standard FFT layout [0, ..., N-1] so that frequency zero sits at the center
8 |     Args:
9 |         data: Fourier transform, (NxHxW)
10 |         backwards: bool, if True undo the shift and move low frequencies back to the corners
11 |
12 |     Returns:
13 |         Shifted Fourier transform.
14 |     """
15 | dim = data.ndim - 1
16 |
17 | if dim != 2:
18 | raise AttributeError(f'Data must be 2d but it is {dim}d.')
19 | if any(s % 2 == 0 for s in data.shape[1:]):
20 | raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')
21 |
22 | # for each dimension swap left and right half
23 | dims = tuple(range(1, dim+1)) # add one for batch dimension
24 | shifts = torch.tensor(data.shape[1:]) // 2 #.div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd
25 | if backwards:
26 | shifts *= -1
27 | return data.roll(shifts.tolist(), dims=dims)
28 |
29 |
30 | def batch_fft(data, normalize=False):
31 | """
32 | Compute fourier transform of batch.
33 | Args:
34 | data: input tensor, (NxHxW)
35 |
36 | Returns:
37 | Batch fourier transform of input data.
38 | """
39 |
40 | dim = data.ndim - 1 # subtract one for batch dimension
41 | if dim != 2:
42 | raise AttributeError(f'Data must be 2d but it is {dim}d.')
43 |
44 | dims = tuple(range(1, dim + 1)) # add one for batch dimension
45 | if normalize:
46 | norm = 'ortho'
47 | else:
48 | norm = 'backward'
49 |
50 | if not torch.is_complex(data):
51 | data = torch.complex(data, torch.zeros_like(data))
52 | freq = fftn(data, dim=dims, norm=norm)
53 |
54 | return freq
55 |
56 |
57 | def azimuthal_average(image, center=None):
58 | # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
59 | """
60 | Calculate the azimuthally averaged radial profile.
61 | Requires low frequencies to be at the center of the image.
62 | Args:
63 | image: Batch of 2D images, NxHxW
64 | center: The [x,y] pixel coordinates used as the center. The default is
65 | None, which then uses the center of the image (including
66 |                 fractional pixels).
67 |
68 | Returns:
69 | Azimuthal average over the image around the center
70 | """
71 | # Check input shapes
72 | assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
73 |                                                  f'(but it is len(center)={len(center)}).'
74 | # Calculate the indices from the image
75 | H, W = image.shape[-2:]
76 |     h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
77 |
78 | if center is None:
79 | center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0])
80 |
81 | # Compute radius for each pixel wrt center
82 | r = torch.stack([w-center[0], h-center[1]]).norm(2, 0)
83 |
84 | # Get sorted radii
85 | r_sorted, ind = r.flatten().sort()
86 | i_sorted = image.flatten(-2, -1)[..., ind]
87 |
88 | # Get the integer part of the radii (bin size = 1)
89 |     r_int = r_sorted.long() # round down to the nearest integer (bin index)
90 |
91 | # Find all pixels that fall within each radial bin.
92 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii
93 | rind = torch.where(deltar)[0] # location of changed radius
94 |
95 | # compute number of elements in each bin
96 | nind = rind + 1 # number of elements = idx + 1
97 | nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders
98 |     nr = nind[1:] - nind[:-1] # number of elements in each radius bin
99 |
100 | # Cumulative sum to figure out sums for each radius bin
101 | if H % 2 == 0:
102 | raise NotImplementedError('Not sure if implementation correct, please check')
103 | rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders
104 | else:
105 | rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders
106 | csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius
107 | tbin = csim[..., rind[1:]] - csim[..., rind[:-1]]
108 | # add mean
109 | tbin = torch.cat([csim[:, 0:1], tbin], 1)
110 |
111 | radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins
112 |
113 | return radial_prof
114 |
115 |
116 | def get_spectrum(data, normalize=False):
117 | dim = data.ndim - 1 # subtract one for batch dimension
118 | if dim != 2:
119 | raise AttributeError(f'Data must be 2d but it is {dim}d.')
120 |
121 | freq = batch_fft(data, normalize=normalize)
122 | power_spec = freq.real ** 2 + freq.imag ** 2
123 | N = data.shape[1]
124 | if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum
125 | # and is not averaged with the mean value
126 | N_2 = N//2
127 | power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1)
128 | power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2)
129 |
130 | power_spec = roll_quadrants(power_spec)
131 | power_spec = azimuthal_average(power_spec)
132 | return power_spec
133 |
134 |
135 | def plot_std(mean, std, x=None, ax=None, **kwargs):
136 | import matplotlib.pyplot as plt
137 | if ax is None:
138 | fig, ax = plt.subplots(1)
139 |
140 | # plot error margins in same color as line
141 | err_kwargs = {
142 | 'alpha': 0.3
143 | }
144 |
145 | if 'c' in kwargs.keys():
146 | err_kwargs['color'] = kwargs['c']
147 | elif 'color' in kwargs.keys():
148 | err_kwargs['color'] = kwargs['color']
149 |
150 | if x is None:
151 | x = torch.linspace(0, 1, len(mean)) # use normalized x axis
152 | ax.plot(x, mean, **kwargs)
153 | ax.fill_between(x, mean-std, mean+std, **err_kwargs)
154 |
155 | return ax
156 |
--------------------------------------------------------------------------------
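A small sketch of how `get_spectrum()` and `plot_std()` above might be combined to plot the radially averaged power spectrum of a batch. The import path and the random images are illustrative assumptions; real use would feed generated or dataset images in (N, H, W) layout.

    import torch
    from torch_utils.utils_spectrum import get_spectrum, plot_std

    imgs = torch.rand(16, 3, 256, 256)             # placeholder batch in [0, 1]
    gray = imgs.mean(dim=1)                         # get_spectrum expects (N, H, W)

    spec = get_spectrum(gray, normalize=True)       # (N, n_radial_bins) power spectrum
    ax = plot_std(spec.mean(dim=0), spec.std(dim=0), color='C0', label='samples')
    ax.set_yscale('log')
    ax.figure.savefig('spectrum.png')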
/utils/CRDiffAug.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def CR_DiffAug(x, flip=True, translation=True, color=True, cutout=True):
9 | if flip:
10 | x = random_flip(x, 0.5)
11 | if translation:
12 | x = rand_translation(x, 1/8)
13 | if color:
14 | aug_list = [rand_brightness, rand_saturation, rand_contrast]
15 | for func in aug_list:
16 | x = func(x)
17 | if cutout:
18 | x = rand_cutout(x)
19 | if flip or translation:
20 | x = x.contiguous()
21 | return x
22 |
23 |
24 | def random_flip(x, p):
25 | x_out = x.clone()
26 | n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
27 | flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0)
28 | flip_mask = flip_prob < p
29 | flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device)
30 | x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1)
31 | return x_out
32 |
33 |
34 | def rand_brightness(x):
35 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
36 | return x
37 |
38 |
39 | def rand_saturation(x):
40 | x_mean = x.mean(dim=1, keepdim=True)
41 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
42 | return x
43 |
44 |
45 | def rand_contrast(x):
46 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
47 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
48 | return x
49 |
50 |
51 | def rand_translation(x, ratio=0.125):
52 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
53 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
54 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
55 | grid_batch, grid_x, grid_y = torch.meshgrid(
56 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
57 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
58 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
59 | indexing='ij'
60 | )
61 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
62 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
63 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
64 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
65 | return x
66 |
67 |
68 | def rand_cutout(x, ratio=0.5):
69 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
70 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
71 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
72 | grid_batch, grid_x, grid_y = torch.meshgrid(
73 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
74 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
75 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
76 | indexing='ij'
77 | )
78 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
79 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
80 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
81 | mask[grid_batch, grid_x, grid_y] = 0
82 | x = x * mask.unsqueeze(1)
83 | return x
84 |
--------------------------------------------------------------------------------
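A brief sketch of how `CR_DiffAug()` is typically used for consistency regularization of a discriminator: the same batch is scored before and after augmentation and the prediction drift is penalized. The toy `discriminator` and image size here are placeholders, not the repo's actual models.

    import torch
    import torch.nn.functional as F
    from utils.CRDiffAug import CR_DiffAug

    # toy stand-in for the real discriminator
    discriminator = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 64 * 64, 1))

    real = torch.rand(8, 3, 64, 64) * 2 - 1         # placeholder images in [-1, 1]
    aug = CR_DiffAug(real)                           # flip + translation + color + cutout

    # consistency regularization: penalize drift of D's prediction under augmentation
    cr_loss = F.mse_loss(discriminator(aug), discriminator(real).detach())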
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | '''
2 | From StyleSwin
3 | '''
4 |
5 | import pickle
6 |
7 | import torch
8 | from torch import distributed as dist
9 |
10 |
11 | def get_rank():
12 | if not dist.is_available():
13 | return 0
14 |
15 | if not dist.is_initialized():
16 | return 0
17 |
18 | return dist.get_rank()
19 |
20 |
21 | def synchronize():
22 | if not dist.is_available():
23 | return
24 |
25 | if not dist.is_initialized():
26 | return
27 |
28 | world_size = dist.get_world_size()
29 |
30 | if world_size == 1:
31 | return
32 |
33 | dist.barrier()
34 |
35 |
36 | def get_world_size():
37 | if not dist.is_available():
38 | return 1
39 |
40 | if not dist.is_initialized():
41 | return 1
42 |
43 | return dist.get_world_size()
44 |
45 |
46 | def reduce_sum(tensor):
47 | if not dist.is_available():
48 | return tensor
49 |
50 | if not dist.is_initialized():
51 | return tensor
52 |
53 | tensor = tensor.clone()
54 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
55 |
56 | return tensor
57 |
58 |
59 | def gather_grad(params):
60 | world_size = get_world_size()
61 |
62 | if world_size == 1:
63 | return
64 |
65 | for param in params:
66 | if param.grad is not None:
67 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
68 | param.grad.data.div_(world_size)
69 |
70 |
71 | def all_gather(data):
72 | world_size = get_world_size()
73 |
74 | if world_size == 1:
75 | return [data]
76 |
77 | buffer = pickle.dumps(data)
78 | storage = torch.ByteStorage.from_buffer(buffer)
79 | tensor = torch.ByteTensor(storage).to('cuda')
80 |
81 | local_size = torch.IntTensor([tensor.numel()]).to('cuda')
82 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
83 | dist.all_gather(size_list, local_size)
84 | size_list = [int(size.item()) for size in size_list]
85 | max_size = max(size_list)
86 |
87 | tensor_list = []
88 | for _ in size_list:
89 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
90 |
91 | if local_size != max_size:
92 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
93 | tensor = torch.cat((tensor, padding), 0)
94 |
95 | dist.all_gather(tensor_list, tensor)
96 |
97 | data_list = []
98 |
99 | for size, tensor in zip(size_list, tensor_list):
100 | buffer = tensor.cpu().numpy().tobytes()[:size]
101 | data_list.append(pickle.loads(buffer))
102 |
103 | return data_list
104 |
105 |
106 | def reduce_loss_dict(loss_dict):
107 | world_size = get_world_size()
108 |
109 | if world_size < 2:
110 | return loss_dict
111 |
112 | with torch.no_grad():
113 | keys = []
114 | losses = []
115 |
116 | for k in sorted(loss_dict.keys()):
117 | keys.append(k)
118 | losses.append(loss_dict[k])
119 |
120 | losses = torch.stack(losses, 0)
121 | dist.reduce(losses, dst=0)
122 |
123 | if dist.get_rank() == 0:
124 | losses /= world_size
125 |
126 | reduced_losses = {k: v for k, v in zip(keys, losses)}
127 |
128 | return reduced_losses
129 |
--------------------------------------------------------------------------------
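A single-process sketch of the helpers above. All of them degrade gracefully to no-ops (or identity) when `torch.distributed` has not been initialized, so this runs as-is; in a multi-GPU run, `reduce_loss_dict()` averages the losses onto rank 0.

    import torch
    from utils.distributed import get_rank, get_world_size, reduce_loss_dict, synchronize

    loss_dict = {'d': torch.tensor(0.7), 'g': torch.tensor(1.3)}   # placeholder losses
    reduced = reduce_loss_dict(loss_dict)       # averaged onto rank 0 when world_size > 1

    if get_rank() == 0:
        print({k: float(v) for k, v in reduced.items()}, 'world size:', get_world_size())
    synchronize()                               # barrier; no-op without an initialized process group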
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import os
3 | import random
4 | import warnings
5 | import logging
6 | from omegaconf import OmegaConf, open_dict
7 | import torch
8 | from rich import print
9 | # For legacy
10 | import argparse
11 |
12 | _allowed_run_types=['train', 'inference', 'evaluate', 'attention_map', 'throughput']
13 |
14 | def check_and_set_hydra(args, key : str, value : Any) -> None:
15 | if hasattr(args, key):
16 |         args[key] = value
17 | else:
18 | with open_dict(args):
19 |             args[key] = value
20 |     logging.info(f"args.{key} = {value}")
21 |
22 | def validate_args(args):
23 | '''
24 | Check some of the args and do some sanity checking
25 | We'll define default values here so that users don't need to
26 | set them themselves. Reduce user burden, reduce user error.
27 | '''
28 | assert(args.type in _allowed_run_types),f"Type must be from {_allowed_run_types} but got {args.type}"
29 | arg_keys = args.keys()
30 | with open_dict(args):
31 | if "rank" not in arg_keys: check_and_set_hydra(args,"rank",0)
32 | if "device" not in arg_keys: check_and_set_hydra(args,"device","cpu")
33 | if "world_size" not in arg_keys: check_and_set_hydra(args,"world_size",1)
34 | if "rank" not in arg_keys: check_and_set_hydra(args,"rank",0)
35 | if "local_rank" not in arg_keys: check_and_set_hydra(args,"local_rank",0)
36 | if "distributed" not in arg_keys: #args.distributed = False
37 | if "WORLD_SIZE" in os.environ:
38 | # Single node multi GPU
39 | n_gpu = int(os.environ["WORLD_SIZE"])
40 | else:
41 | n_gpu = torch.cuda.device_count()
42 | check_and_set_hydra(args,"distributed",n_gpu > 1)
43 | if "workers" not in arg_keys: args.workers = 0
44 | # Validate training args
45 | if args.type == "train":
46 | # logging
47 | assert(hasattr(args, "logging"))
48 | if "print_freq" not in args.logging:
49 | check_and_set_hydra(args.logging,"print_freq",10000)
50 | if "eval_freq" not in args.logging:
51 | check_and_set_hydra(args.logging,"eval_freq",-1)
52 | if "save_freq" not in args.logging:
53 | check_and_set_hydra(args.logging,"save_freq",-1)
54 | if "checkpoint_path" in args.logging:
55 | if args.logging.checkpoint_path[-1] != "/":
56 | args.logging.checkpoint_path += "/"
57 | else:
58 | check_and_set_hydra(args.logging,"checkpoint_path","./")
59 | if "sample_path" in args.logging:
60 | if args.logging.sample_path[-1] != "/":
61 | args.logging.sample_path += "/"
62 | else:
63 | check_and_set_hydra(args.logging, "sample_path", "./")
64 | if "reuse_samplepath" not in args.logging:
65 | check_and_set_hydra(args.logging,"reuse_samplepath",False)
66 |     if args.type == "evaluate" or args.type == "train":
67 | assert(hasattr(args.evaluation, "gt_path")),f"You must specify "\
68 | f"the ground truth data path"
69 | assert(hasattr(args.evaluation, "total_size")),f"You must specify "\
70 | f"the number of images for FID"
71 | if "batch" not in args.evaluation:
72 | check_and_set_hydra(args.evaluation,"batch", 1)
73 | if "save_root" not in args:
74 | check_and_set_hydra(args,"save_root","/tmp/")
75 |             if args.type == "train":
76 | logging.warning("Save root path not set, using /tmp")
77 | if args.save_root[-1] != "/":
78 | args.save_root += "/"
79 | if args.type == "inference":
80 | if "batch" not in args.inference:
81 | check_and_set_hydra(args.inference,"batch",1)
82 | # assert(hasattr(args.evaluation, "gt_path"))
83 | if "misc" not in args.keys():
84 | check_and_set_hydra(args,"misc",{})
85 | if "seed" not in args.misc:
86 | check_and_set_hydra(args.misc,"seed",None)
87 | if "rng_state" not in args.misc:
88 | check_and_set_hydra(args.misc,"rng_state",None)
89 | if "py_rng_state" not in args.misc:
90 |         check_and_set_hydra(args.misc,"py_rng_state",None)
91 |
92 | def rng_reproducibility(args, ckpt=None):
93 | # Store RNG info
94 | # Cumbersome but reproducibility is not super easy
95 | with open_dict(args):
96 | if args.misc.seed is None:
97 | args.misc.seed = torch.initial_seed()
98 | else:
99 | torch.manual_seed(args.misc.seed)
100 | if args.misc.rng_state is None:
101 | args.misc.rng_state = torch.get_rng_state().tolist()
102 | else:
103 | torch.set_rng_state(args.misc.rng_state)
104 | if args.misc.py_rng_state is None:
105 | args.misc.py_rng_state = random.getstate()
106 | else:
107 | random.setstate(args.misc.py_rng_state)
108 |
109 | if ckpt is not None \
110 | and "reuse_rng" in args.restart \
111 | and args.restart.reuse_rng:
112 | with open_dict(args):
113 | #if "misc" in ckpt['args'].keys() and "seed" in ckpt['args']['misc'].keys():
114 | try:
115 | if hasattr(ckpt['args'], "misc"):
116 | if hasattr(ckpt['args']['misc'], "seed"):
117 | try:
118 | args.misc.seed = ckpt['args']['misc']['seed']
119 | print(f"[bold green]RNG Seed successfully loaded")
120 | except:
121 | print("[bold yellow]Seed couldn't be loaded (new style ckpt)")
122 | else:
123 | print("[bold yellow]Couldn't find seed (new style ckpt)")
124 | if hasattr(ckpt['args']['misc'], 'rng_state'):
125 | try:
126 | args.misc.rng_state = ckpt['args']['misc']['rng_state']
127 | print(f"[bold green]RNG State successfully loaded")
128 | except:
129 | print("[bold yellow]RNG State couldn't be loaded (new style ckpt)")
130 | else:
131 | print("[bold yellow]Couldn't find RNG State (new style ckpt)")
132 | if hasattr(ckpt['args']['misc'], 'py_rng_state'):
133 | try:
134 | args.misc.py_rng_state = ckpt['args']['misc']['py_rng_state']
135 | print(f"[bold green] Py-RNG State successfully loaded")
136 | except:
137 | print("[bold yellow]Py-RNG State couldn't be loaded (new style ckpt)")
138 | else:
139 | print("[bold yellow]Couldn't find Py-RNG State (new style ckpt)")
140 | elif type(ckpt['args']) == argparse.Namespace:
141 | try:
142 | args.misc.seed = ckpt['args'].seed
143 | print(f"[bold green]RNG Seed successfully loaded")
144 | except:
145 | print("[bold yellow]Seed couldn't be loaded (old style ckpt)")
146 | try:
147 | args.misc.rng_state = ckpt['args'].rng_state.tolist()
148 | print(f"[bold green]RNG State successfully loaded")
149 | except:
150 | print("[bold yellow]RNG State couldn't be loaded (old style ckpt)")
151 | try:
152 | args.misc.py_rng_state = ckpt['args'].rng_state.tolist()
153 | print(f"[bold green] Py-RNG State successfully loaded")
154 | except:
155 | print("[bold yellow]Py-RNG State couldn't be loaded (old style ckpt)")
156 | else:
157 |                     print(f"[bold yellow]No rng loading. {type(ckpt['args'])=}")
158 | except:
159 | print("[bold yellow]Seeds couldn't be loaded and don't know why")
160 | print(f"[bold yellow]\t {type(ckpt['args'])}")
161 | print(f"{type(ckpt['args']) == argparse.Namespace}")
162 | try:
163 | torch.manual_seed(args.misc.seed)
164 | _seed = "[bold green]True[/]"
165 | except:
166 | print("[bold yellow]Unable to set manual_seed")
167 | _seed = "[bold red]False[/]"
168 | try:
169 | torch.set_rng_state(torch.as_tensor(
170 | args.misc.rng_state, dtype=torch.uint8),
171 | )
172 | _pt_rng = "[bold green]True[/]"
173 | except:
174 |         print("[bold yellow]Unable to set PyTorch's rng state")
175 | _pt_rng = "[bold red]False[/]"
176 | try:
177 | l = tuple(args.misc.py_rng_state)
178 | random.setstate((l[0], tuple(l[1]), l[2]))
179 | _py_rng = "[bold green]True[/]"
180 | except:
181 | print("[bold yellow]Unable to set python's rng state")
182 | _py_rng = "[bold red]False[/]"
183 | print(f"[bold green]RNG Loading Success States:[/]\n"\
184 | f"\tSeed: {_seed}, PyTorch RNG: {_pt_rng}, Python RNG {_py_rng}")
185 |
--------------------------------------------------------------------------------
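A sketch of filling in defaults on a run config with `validate_args()` and `rng_reproducibility()` above. The config keys shown are a guess at a minimal `inference` setup for illustration, not one of the full run configs shipped under `conf/`.

    from omegaconf import OmegaConf
    from utils.helpers import validate_args, rng_reproducibility

    args = OmegaConf.create({'type': 'inference', 'inference': {}, 'misc': {'seed': 42}})
    validate_args(args)         # fills rank, device, world_size, distributed, inference.batch, ...
    rng_reproducibility(args)   # seeds torch / python RNGs and records their states in args.misc
    print(args.type, args.device, args.inference.batch)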
/utils/inception.py:
--------------------------------------------------------------------------------
1 | '''
2 | From mseitzer's Pytorch-FID
3 | https://github.com/mseitzer/pytorch-fid
4 | License included in fid_score
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchvision
11 |
12 | try:
13 | from torchvision.models.utils import load_state_dict_from_url
14 | except ImportError:
15 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
16 |
17 | # Inception weights ported to Pytorch from
18 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
19 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
20 |
21 |
22 | class InceptionV3(nn.Module):
23 | """Pretrained InceptionV3 network returning feature maps"""
24 |
25 | # Index of default block of inception to return,
26 | # corresponds to output of final average pooling
27 | DEFAULT_BLOCK_INDEX = 3
28 |
29 | # Maps feature dimensionality to their output blocks indices
30 | BLOCK_INDEX_BY_DIM = {
31 | 64: 0, # First max pooling features
32 |         192: 1,   # Second max pooling features
33 | 768: 2, # Pre-aux classifier features
34 | 2048: 3 # Final average pooling features
35 | }
36 |
37 | def __init__(self,
38 | output_blocks=(DEFAULT_BLOCK_INDEX,),
39 | resize_input=True,
40 | normalize_input=True,
41 | requires_grad=False,
42 | use_fid_inception=True):
43 | """Build pretrained InceptionV3
44 |
45 | Parameters
46 | ----------
47 | output_blocks : list of int
48 | Indices of blocks to return features of. Possible values are:
49 | - 0: corresponds to output of first max pooling
50 | - 1: corresponds to output of second max pooling
51 | - 2: corresponds to output which is fed to aux classifier
52 | - 3: corresponds to output of final average pooling
53 | resize_input : bool
54 | If true, bilinearly resizes input to width and height 299 before
55 | feeding input to model. As the network without fully connected
56 | layers is fully convolutional, it should be able to handle inputs
57 | of arbitrary size, so resizing might not be strictly needed
58 | normalize_input : bool
59 | If true, scales the input from range (0, 1) to the range the
60 | pretrained Inception network expects, namely (-1, 1)
61 | requires_grad : bool
62 | If true, parameters of the model require gradients. Possibly useful
63 | for finetuning the network
64 | use_fid_inception : bool
65 | If true, uses the pretrained Inception model used in Tensorflow's
66 | FID implementation. If false, uses the pretrained Inception model
67 | available in torchvision. The FID Inception model has different
68 | weights and a slightly different structure from torchvision's
69 | Inception model. If you want to compute FID scores, you are
70 | strongly advised to set this parameter to true to get comparable
71 | results.
72 | """
73 | super(InceptionV3, self).__init__()
74 |
75 | self.resize_input = resize_input
76 | self.normalize_input = normalize_input
77 | self.output_blocks = sorted(output_blocks)
78 | self.last_needed_block = max(output_blocks)
79 |
80 | assert self.last_needed_block <= 3, \
81 | 'Last possible output block index is 3'
82 |
83 | self.blocks = nn.ModuleList()
84 |
85 | if use_fid_inception:
86 | inception = fid_inception_v3()
87 | else:
88 | inception = _inception_v3(pretrained=True)
89 |
90 | # Block 0: input to maxpool1
91 | block0 = [
92 | inception.Conv2d_1a_3x3,
93 | inception.Conv2d_2a_3x3,
94 | inception.Conv2d_2b_3x3,
95 | nn.MaxPool2d(kernel_size=3, stride=2)
96 | ]
97 | self.blocks.append(nn.Sequential(*block0))
98 |
99 | # Block 1: maxpool1 to maxpool2
100 | if self.last_needed_block >= 1:
101 | block1 = [
102 | inception.Conv2d_3b_1x1,
103 | inception.Conv2d_4a_3x3,
104 | nn.MaxPool2d(kernel_size=3, stride=2)
105 | ]
106 | self.blocks.append(nn.Sequential(*block1))
107 |
108 | # Block 2: maxpool2 to aux classifier
109 | if self.last_needed_block >= 2:
110 | block2 = [
111 | inception.Mixed_5b,
112 | inception.Mixed_5c,
113 | inception.Mixed_5d,
114 | inception.Mixed_6a,
115 | inception.Mixed_6b,
116 | inception.Mixed_6c,
117 | inception.Mixed_6d,
118 | inception.Mixed_6e,
119 | ]
120 | self.blocks.append(nn.Sequential(*block2))
121 |
122 | # Block 3: aux classifier to final avgpool
123 | if self.last_needed_block >= 3:
124 | block3 = [
125 | inception.Mixed_7a,
126 | inception.Mixed_7b,
127 | inception.Mixed_7c,
128 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
129 | ]
130 | self.blocks.append(nn.Sequential(*block3))
131 |
132 | for param in self.parameters():
133 | param.requires_grad = requires_grad
134 |
135 | def forward(self, inp):
136 | """Get Inception feature maps
137 |
138 | Parameters
139 | ----------
140 |         inp : torch.Tensor
141 | Input tensor of shape Bx3xHxW. Values are expected to be in
142 | range (0, 1)
143 |
144 | Returns
145 | -------
146 |         List of torch.Tensor, corresponding to the selected output
147 | block, sorted ascending by index
148 | """
149 | outp = []
150 | x = inp
151 |
152 | if self.resize_input:
153 | x = F.interpolate(x,
154 | size=(299, 299),
155 | mode='bilinear',
156 | align_corners=False)
157 |
158 | if self.normalize_input:
159 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
160 |
161 | for idx, block in enumerate(self.blocks):
162 | x = block(x)
163 | if idx in self.output_blocks:
164 | outp.append(x)
165 |
166 | if idx == self.last_needed_block:
167 | break
168 |
169 | return outp
170 |
171 |
172 | def _inception_v3(*args, **kwargs):
173 | """Wraps `torchvision.models.inception_v3`
174 |
175 | Skips default weight inititialization if supported by torchvision version.
176 | See https://github.com/mseitzer/pytorch-fid/issues/28.
177 | """
178 | try:
179 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
180 | except ValueError:
181 | # Just a caution against weird version strings
182 | version = (0,)
183 |
184 | if version >= (0, 6):
185 | kwargs['init_weights'] = False
186 |
187 | return torchvision.models.inception_v3(*args, **kwargs)
188 |
189 |
190 | def fid_inception_v3():
191 | """Build pretrained Inception model for FID computation
192 |
193 | The Inception model for FID computation uses a different set of weights
194 | and has a slightly different structure than torchvision's Inception.
195 |
196 | This method first constructs torchvision's Inception and then patches the
197 | necessary parts that are different in the FID Inception model.
198 | """
199 | inception = _inception_v3(num_classes=1008,
200 | aux_logits=False,
201 | pretrained=False)
202 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
203 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
204 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
205 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
206 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
207 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
208 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
209 | inception.Mixed_7b = FIDInceptionE_1(1280)
210 | inception.Mixed_7c = FIDInceptionE_2(2048)
211 |
212 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
213 | inception.load_state_dict(state_dict)
214 | return inception
215 |
216 |
217 | class FIDInceptionA(torchvision.models.inception.InceptionA):
218 | """InceptionA block patched for FID computation"""
219 | def __init__(self, in_channels, pool_features):
220 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
221 |
222 | def forward(self, x):
223 | branch1x1 = self.branch1x1(x)
224 |
225 | branch5x5 = self.branch5x5_1(x)
226 | branch5x5 = self.branch5x5_2(branch5x5)
227 |
228 | branch3x3dbl = self.branch3x3dbl_1(x)
229 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
230 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
231 |
232 |         # Patch: Tensorflow's average pool does not use the padded zeros in
233 | # its average calculation
234 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
235 | count_include_pad=False)
236 | branch_pool = self.branch_pool(branch_pool)
237 |
238 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
239 | return torch.cat(outputs, 1)
240 |
241 |
242 | class FIDInceptionC(torchvision.models.inception.InceptionC):
243 | """InceptionC block patched for FID computation"""
244 | def __init__(self, in_channels, channels_7x7):
245 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
246 |
247 | def forward(self, x):
248 | branch1x1 = self.branch1x1(x)
249 |
250 | branch7x7 = self.branch7x7_1(x)
251 | branch7x7 = self.branch7x7_2(branch7x7)
252 | branch7x7 = self.branch7x7_3(branch7x7)
253 |
254 | branch7x7dbl = self.branch7x7dbl_1(x)
255 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
256 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
257 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
258 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
259 |
260 |         # Patch: Tensorflow's average pool does not use the padded zeros in
261 | # its average calculation
262 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
263 | count_include_pad=False)
264 | branch_pool = self.branch_pool(branch_pool)
265 |
266 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
267 | return torch.cat(outputs, 1)
268 |
269 |
270 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
271 | """First InceptionE block patched for FID computation"""
272 | def __init__(self, in_channels):
273 | super(FIDInceptionE_1, self).__init__(in_channels)
274 |
275 | def forward(self, x):
276 | branch1x1 = self.branch1x1(x)
277 |
278 | branch3x3 = self.branch3x3_1(x)
279 | branch3x3 = [
280 | self.branch3x3_2a(branch3x3),
281 | self.branch3x3_2b(branch3x3),
282 | ]
283 | branch3x3 = torch.cat(branch3x3, 1)
284 |
285 | branch3x3dbl = self.branch3x3dbl_1(x)
286 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
287 | branch3x3dbl = [
288 | self.branch3x3dbl_3a(branch3x3dbl),
289 | self.branch3x3dbl_3b(branch3x3dbl),
290 | ]
291 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
292 |
293 |         # Patch: Tensorflow's average pool does not use the padded zeros in
294 | # its average calculation
295 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
296 | count_include_pad=False)
297 | branch_pool = self.branch_pool(branch_pool)
298 |
299 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
300 | return torch.cat(outputs, 1)
301 |
302 |
303 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
304 | """Second InceptionE block patched for FID computation"""
305 | def __init__(self, in_channels):
306 | super(FIDInceptionE_2, self).__init__(in_channels)
307 |
308 | def forward(self, x):
309 | branch1x1 = self.branch1x1(x)
310 |
311 | branch3x3 = self.branch3x3_1(x)
312 | branch3x3 = [
313 | self.branch3x3_2a(branch3x3),
314 | self.branch3x3_2b(branch3x3),
315 | ]
316 | branch3x3 = torch.cat(branch3x3, 1)
317 |
318 | branch3x3dbl = self.branch3x3dbl_1(x)
319 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
320 | branch3x3dbl = [
321 | self.branch3x3dbl_3a(branch3x3dbl),
322 | self.branch3x3dbl_3b(branch3x3dbl),
323 | ]
324 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
325 |
326 | # Patch: The FID Inception model uses max pooling instead of average
327 | # pooling. This is likely an error in this specific Inception
328 | # implementation, as other Inception models use average pooling here
329 | # (which matches the description in the paper).
330 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
331 | branch_pool = self.branch_pool(branch_pool)
332 |
333 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
334 | return torch.cat(outputs, 1)
335 |
--------------------------------------------------------------------------------
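A sketch of extracting the 2048-dimensional pool features used for FID with the `InceptionV3` wrapper above. Constructing the model downloads the ported FID weights on first use; the random batch is a placeholder for real or generated images in [0, 1].

    import torch
    from utils.inception import InceptionV3

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]        # final average-pool features
    model = InceptionV3(output_blocks=[block_idx]).eval()   # fetches the FID weights on first use

    imgs = torch.rand(4, 3, 299, 299)                        # placeholder batch in [0, 1]
    with torch.no_grad():
        feats = model(imgs)[0].squeeze(-1).squeeze(-1)       # (4, 2048), input to the FID statistics
    print(feats.shape)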