├── .gitignore
├── 3DDFA_V2_cropping
│   ├── crop_samples
│   │   └── img
│   │       ├── dataset.json
│   │       ├── man.jpg
│   │       └── woman.jpg
│   ├── cropping_guide.md
│   ├── dlib_kps.py
│   ├── recrop_images.py
│   └── test
│       └── original
│           ├── man.jpg
│           └── woman.jpg
├── LICENSES
│   ├── LICENSE.md
│   └── LICENSE_EG3D
├── README.md
├── calc_mbs.py
├── calc_metrics.py
├── camera_utils.py
├── dataset
│   ├── testdata_img.zip
│   ├── testdata_img
│   │   ├── 000134.jpg
│   │   ├── 000157.jpg
│   │   └── dataset.json
│   ├── testdata_seg.zip
│   └── testdata_seg
│       ├── 000134.png
│       └── 000157.png
├── dataset_tool.py
├── dataset_tool_seg.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── environment.yml
├── gen_interpolation.py
├── gen_pti_script.sh
├── gen_samples.py
├── gen_samples_forID.py
├── gen_videos.py
├── gen_videos_interp.py
├── gen_videos_proj.py
├── gen_videos_proj_withseg.py
├── get_metrics.sh
├── gui_utils
│   ├── __init__.py
│   ├── gl_utils.py
│   ├── glfw_window.py
│   ├── imgui_utils.py
│   ├── imgui_window.py
│   └── text_utils.py
├── legacy.py
├── metrics
│   ├── __init__.py
│   ├── equivariance.py
│   ├── frechet_inception_distance.py
│   ├── inception_score.py
│   ├── kernel_inception_distance.py
│   ├── metric_main.py
│   ├── metric_utils.py
│   ├── perceptual_path_length.py
│   └── precision_recall.py
├── misc
│   ├── segmentation_example.py
│   └── teaser.png
├── projector.py
├── projector_withseg.py
├── resave_model.py
├── shape_utils.py
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── filtered_lrelu.cpp
│   │   ├── filtered_lrelu.cu
│   │   ├── filtered_lrelu.h
│   │   ├── filtered_lrelu.py
│   │   ├── filtered_lrelu_ns.cu
│   │   ├── filtered_lrelu_rd.cu
│   │   ├── filtered_lrelu_wr.cu
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
├── train.py
└── training
    ├── __init__.py
    ├── augment.py
    ├── crosssection_utils.py
    ├── dataset.py
    ├── dual_discriminator.py
    ├── loss.py
    ├── networks_stylegan2.py
    ├── networks_stylegan3.py
    ├── superresolution.py
    ├── training_loop.py
    ├── triplane.py
    ├── utils.py
    └── volumetric_rendering
        ├── __init__.py
        ├── math_utils.py
        ├── ray_marcher.py
        ├── ray_sampler.py
        └── renderer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # paper results
163 | *.mp4
164 | out/
165 | pti_out/
166 | interpolation_out/
167 | models/
168 |
169 | # viz
170 | viz/
171 | visualizer.py
--------------------------------------------------------------------------------
/3DDFA_V2_cropping/crop_samples/img/dataset.json:
--------------------------------------------------------------------------------
1 | {
2 | "labels": [
3 | [
4 | "man.jpg",
5 | [
6 | "0.997649",
7 | "-0.015475",
8 | "0.064940",
9 | "-0.175337",
10 | "-0.015415",
11 | "-0.992914",
12 | "-0.118841",
13 | "0.320872",
14 | "0.067236",
15 | "0.118079",
16 | "-0.990818",
17 | "2.675209",
18 | "0.000000",
19 | "0.000000",
20 | "0.000000",
21 | "1.000000",
22 | "4.264700",
23 | "0.000000",
24 | "0.500000",
25 | "0.000000",
26 | "4.264700",
27 | "0.500000",
28 | "0.000000",
29 | "0.000000",
30 | "1.000000"
31 | ]
32 | ],
33 | [
34 | "woman.jpg",
35 | [
36 | "0.999654",
37 | "0.015362",
38 | "-0.000774",
39 | "0.002091",
40 | "0.029684",
41 | "-0.987381",
42 | "-0.158344",
43 | "0.427528",
44 | "-0.005544",
45 | "0.158313",
46 | "-0.987495",
47 | "2.666237",
48 | "0.000000",
49 | "0.000000",
50 | "0.000000",
51 | "1.000000",
52 | "4.264700",
53 | "0.000000",
54 | "0.500000",
55 | "0.000000",
56 | "4.264700",
57 | "0.500000",
58 | "0.000000",
59 | "0.000000",
60 | "1.000000"
61 | ]
62 | ]
63 | ]
64 | }
--------------------------------------------------------------------------------
/3DDFA_V2_cropping/crop_samples/img/man.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/crop_samples/img/man.jpg
--------------------------------------------------------------------------------
/3DDFA_V2_cropping/crop_samples/img/woman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/crop_samples/img/woman.jpg
--------------------------------------------------------------------------------
/3DDFA_V2_cropping/cropping_guide.md:
--------------------------------------------------------------------------------
1 | Guide for cropping images and obtaining camera poses using [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2).
2 |
3 | Our cropping script [recrop_images.py](./recrop_images.py) is based on 3DDFA_V2 and dlib. First, clone the 3DDFA_V2 repo, follow all of its installation instructions, and make sure its demos run successfully. After building the Cython extensions of 3DDFA_V2, you can install the other necessary packages with the following command:
4 |
5 | ```
6 | pip install opencv-python dlib pyyaml onnxruntime onnx
7 | ```
8 |
9 | Test images used here are from [test/original/man.jpg](https://www.freepik.com/free-photo/portrait-white-man-isolated_3199590.htm) and [test/original/woman.jpg](https://www.freepik.com/free-photo/pretty-smiling-joyfully-female-with-fair-hair-dressed-casually-looking-with-satisfaction_9117255.htm).
10 |
11 | ---
12 |
13 | # Steps
14 |
15 | ## 1. Move the folder `test`, `recrop_images.py`, and `dlib_kps.py` from this directory to the 3DDFA_V2 root dir. Also, remember to download `shape_predictor_68_face_landmarks.dat` and put it under the 3DDFA_V2 root dir.
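
For example, assuming PanoHead and 3DDFA_V2 are cloned side by side (the paths are illustrative, adjust them to your setup):

```.bash
# copy the cropping helpers and test images into the 3DDFA_V2 checkout
cp -r PanoHead/3DDFA_V2_cropping/test \
      PanoHead/3DDFA_V2_cropping/recrop_images.py \
      PanoHead/3DDFA_V2_cropping/dlib_kps.py 3DDFA_V2/
# shape_predictor_68_face_landmarks.dat has to be downloaded separately and placed there as well
cp shape_predictor_68_face_landmarks.dat 3DDFA_V2/
```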
16 |
17 | ## 2. cd to 3DDFA_V2. The cropping script has to be run from the 3DDFA_V2 dir.
18 | ```.bash
19 | cd 3DDFA_V2
20 | ```
21 |
22 | ## 3. Extract face keypoints using dlib. After this, you should have `data.pkl` under your 3DDFA_V2 root dir containing the keypoints.
23 | ```.bash
24 | python dlib_kps.py
25 | ```
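
`data.pkl` is simply a pickled dict that maps each image path to its 68 dlib keypoints (see [dlib_kps.py](./dlib_kps.py)). A minimal sketch for sanity-checking it:

```.python
import pickle

with open('data.pkl', 'rb') as f:
    landmarks = pickle.load(f)  # {image path: list of 68 (x, y) keypoint arrays}

for img_path, kps in landmarks.items():
    print(img_path, len(kps), kps[0])
```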
26 |
27 | ## 4. Obtain camera poses and crop the images using recrop_images.py
28 |
29 | ```.bash
30 | python recrop_images.py -i data.pkl -j dataset.json
31 | ```
32 |
33 | ## After this, you should have a folder called `crop_samples` in your 3DDFA_V2 root dir, matching the one under this directory. Then move the folder back to the PanoHead root dir, point the `--target_img` flag in `gen_pti_script.sh` to the corresponding folder, and run the PTI inversion, as sketched below.
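
A minimal sketch of that last step, assuming the PanoHead checkout sits next to 3DDFA_V2 (the exact `--target_img` value is illustrative):

```.bash
cp -r crop_samples ../PanoHead/
# then edit PanoHead/gen_pti_script.sh so that --target_img points to the copied
# folder (e.g. crop_samples/img) and run ./gen_pti_script.sh
```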
--------------------------------------------------------------------------------
/3DDFA_V2_cropping/dlib_kps.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import dlib
3 | import pickle
4 | import numpy as np
5 | import os
6 |
7 | # load face detector
8 | detector = dlib.get_frontal_face_detector()
9 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
10 |
11 | # load images
12 | img_dir = "test/original"
13 | list_dir = os.listdir(img_dir)
14 |
15 | # new dict for keypoints
16 | landmarks = {}
17 |
18 | for img_name in list_dir:
19 | _, extension = os.path.splitext(img_name)
20 |
21 | # only do it for images, not .json file
22 | if extension == '.json':
23 | continue
24 |
25 | img_path = os.path.join(img_dir, img_name)
26 | image = cv2.imread(img_path)
27 |
28 | # gray scale
29 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
30 |
31 | # detect face
32 | rects = detector(gray, 1)
33 |
34 |
35 |
36 | for (i, rect) in enumerate(rects):
37 | # get keypoints
38 | shape = predictor(gray, rect)
39 |
40 | # save kps to the dict
41 | landmarks[img_path] = [np.array([p.x, p.y]) for p in shape.parts()]
42 |
43 | # save the data.pkl pickle
44 | with open('data.pkl', 'wb') as f:
45 | pickle.dump(landmarks, f)
--------------------------------------------------------------------------------
/3DDFA_V2_cropping/test/original/man.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/test/original/man.jpg
--------------------------------------------------------------------------------
/3DDFA_V2_cropping/test/original/woman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/test/original/woman.jpg
--------------------------------------------------------------------------------
/LICENSES/LICENSE_EG3D:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021-2022, NVIDIA Corporation & affiliates. All rights
2 | reserved.
3 |
4 |
5 | NVIDIA Source Code License for EG3D
6 |
7 |
8 | =======================================================================
9 |
10 | 1. Definitions
11 |
12 | "Licensor" means any person or entity that distributes its Work.
13 |
14 | "Software" means the original work of authorship made available under
15 | this License.
16 |
17 | "Work" means the Software and any additions to or derivative works of
18 | the Software that are made available under this License.
19 |
20 | The terms "reproduce," "reproduction," "derivative works," and
21 | "distribution" have the meaning as provided under U.S. copyright law;
22 | provided, however, that for the purposes of this License, derivative
23 | works shall not include works that remain separable from, or merely
24 | link (or bind by name) to the interfaces of, the Work.
25 |
26 | Works, including the Software, are "made available" under this License
27 | by including in or with the Work either (a) a copyright notice
28 | referencing the applicability of this License to the Work, or (b) a
29 | copy of this License.
30 |
31 | 2. License Grants
32 |
33 | 2.1 Copyright Grant. Subject to the terms and conditions of this
34 | License, each Licensor grants to you a perpetual, worldwide,
35 | non-exclusive, royalty-free, copyright license to reproduce,
36 | prepare derivative works of, publicly display, publicly perform,
37 | sublicense and distribute its Work and any resulting derivative
38 | works in any form.
39 |
40 | 3. Limitations
41 |
42 | 3.1 Redistribution. You may reproduce or distribute the Work only
43 | if (a) you do so under this License, (b) you include a complete
44 | copy of this License with your distribution, and (c) you retain
45 | without modification any copyright, patent, trademark, or
46 | attribution notices that are present in the Work.
47 |
48 | 3.2 Derivative Works. You may specify that additional or different
49 | terms apply to the use, reproduction, and distribution of your
50 | derivative works of the Work ("Your Terms") only if (a) Your Terms
51 | provide that the use limitation in Section 3.3 applies to your
52 | derivative works, and (b) you identify the specific derivative
53 | works that are subject to Your Terms. Notwithstanding Your Terms,
54 | this License (including the redistribution requirements in Section
55 | 3.1) will continue to apply to the Work itself.
56 |
57 | 3.3 Use Limitation. The Work and any derivative works thereof only
58 | may be used or intended for use non-commercially. The Work or
59 | derivative works thereof may be used or intended for use by NVIDIA
60 | or it’s affiliates commercially or non-commercially. As used
61 | herein, "non-commercially" means for research or evaluation
62 | purposes only and not for any direct or indirect monetary gain.
63 |
64 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
65 | against any Licensor (including any claim, cross-claim or
66 | counterclaim in a lawsuit) to enforce any patents that you allege
67 | are infringed by any Work, then your rights under this License from
68 | such Licensor (including the grants in Sections 2.1) will terminate
69 | immediately.
70 |
71 | 3.5 Trademarks. This License does not grant any rights to use any
72 | Licensor’s or its affiliates’ names, logos, or trademarks, except
73 | as necessary to reproduce the notices described in this License.
74 |
75 | 3.6 Termination. If you violate any term of this License, then your
76 | rights under this License (including the grants in Sections 2.1)
77 | will terminate immediately.
78 |
79 | 4. Disclaimer of Warranty.
80 |
81 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
82 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
83 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
84 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
85 | THIS LICENSE.
86 |
87 | 5. Limitation of Liability.
88 |
89 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
90 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
91 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
92 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
93 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
94 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
95 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
96 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
97 | THE POSSIBILITY OF SUCH DAMAGES.
98 |
99 | =======================================================================
100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360°
2 |
3 |
4 |
5 |
6 |
7 |
8 | 
9 |
10 | **PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360°**
11 | Sizhe An, Hongyi Xu, Yichun Shi, Guoxian Song, Umit Y. Ogras, Linjie Luo
12 |
12 |
13 | https://sizhean.github.io/panohead
14 | Abstract: *Synthesis and reconstruction of 3D human head has gained increasing interests in computer vision and computer graphics recently. Existing state-of-the-art 3D generative adversarial networks (GANs) for 3D human head synthesis are either limited to near-frontal views or hard to preserve 3D consistency in large view angles. We propose PanoHead, the first 3D-aware generative model that enables high-quality view-consistent image synthesis of full heads in 360° with diverse appearance and detailed geometry using only in-the-wild unstructured images for training. At its core, we lift up the representation power of recent 3D GANs and bridge the data alignment gap when training from in-the-wild images with widely distributed views. Specifically, we propose a novel two-stage self-adaptive image alignment for robust 3D GAN training. We further introduce a tri-grid neural volume representation that effectively addresses front-face and back-head feature entanglement rooted in the widely-adopted tri-plane formulation. Our method instills prior knowledge of 2D image segmentation in adversarial learning of 3D neural scene structures, enabling compositable head synthesis in diverse backgrounds. Benefiting from these designs, our method significantly outperforms previous 3D GANs, generating high-quality 3D heads with accurate geometry and diverse appearances, even with long wavy and afro hairstyles, renderable from arbitrary poses. Furthermore, we show that our system can reconstruct full 3D heads from single input images for personalized realistic 3D avatars.*
15 |
16 |
17 | ## Requirements
18 |
19 | * We recommend Linux for performance and compatibility reasons.
20 | * 1–8 high-end NVIDIA GPUs. We have done all testing and development using V100, RTX3090, and A100 GPUs.
21 | * 64-bit Python 3.8 and PyTorch 1.11.0 (or later). See https://pytorch.org for PyTorch install instructions.
22 | * CUDA toolkit 11.3 or later. (Why is a separate CUDA toolkit installation required? We use the custom CUDA extensions from the StyleGAN3 repo. Please see [Troubleshooting](https://github.com/NVlabs/stylegan3/blob/main/docs/troubleshooting.md#why-is-cuda-toolkit-installation-necessary)).
23 | * Python libraries: see [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
24 | - `cd PanoHead`
25 | - `conda env create -f environment.yml`
26 | - `conda activate panohead`
27 |
28 |
29 | ## Getting started
30 |
31 | Download the whole `models` folder from [link](https://drive.google.com/drive/folders/1m517-F1NCTGA159dePs5R5qj02svtX1_?usp=sharing) and put it under the root dir.
32 |
33 | Pre-trained networks are stored as `*.pkl` files that can be referenced using local filenames.
34 |
35 |
36 | ## Generating results
37 |
38 | ```.bash
39 | # Generate videos using pre-trained model
40 |
41 | python gen_videos.py --network models/easy-khair-180-gpc0.8-trans10-025000.pkl \
42 | --seeds 0-3 --grid 2x2 --outdir=out --cfg Head --trunc 0.7
43 |
44 | ```
45 |
46 | ```.bash
47 | # Generate images and shapes (as .mrc files) using pre-trained model
48 |
49 | python gen_samples.py --outdir=out --trunc=0.7 --shapes=true --seeds=0-3 \
50 | --network models/easy-khair-180-gpc0.8-trans10-025000.pkl
51 | ```
52 |
53 | ## Applications
54 | ```.bash
55 | # Generate full head reconstruction from a single RGB image.
56 | # Please refer to ./gen_pti_script.sh
57 | # For this application we need to specify dataset folder instead of zip files.
58 | # Segmentation files are not necessary for PTI inversion.
59 |
60 | ./gen_pti_script.sh
61 | ```
62 |
63 | ```.bash
64 | # Generate full head interpolation from two seeds.
65 | # Please refer to ./gen_interpolation.py for the implementation
66 |
67 | python gen_interpolation.py --network models/easy-khair-180-gpc0.8-trans10-025000.pkl\
68 | --trunc 0.7 --outdir interpolation_out
69 | ```
70 |
71 |
72 |
73 | ## Using networks from Python
74 |
75 | You can use pre-trained networks in your own Python code as follows:
76 |
77 | ```.python
78 | with open('*.pkl', 'rb') as f:
79 | G = pickle.load(f)['G_ema'].cuda() # torch.nn.Module
80 | z = torch.randn([1, G.z_dim]).cuda() # latent codes
81 | c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters
82 | img = G(z, c)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation
83 | mask = G(z, c)['image_mask'] # NHW, int8, [0,255]
84 | ```
85 |
86 | The above code requires `torch_utils` and `dnnlib` to be accessible via `PYTHONPATH`. It does not need source code for the networks themselves — their class definitions are loaded from the pickle via `torch_utils.persistence`.
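
The `cam2world_pose` and `intrinsics` tensors above can be built with the helpers in [camera_utils.py](./camera_utils.py). A minimal sketch, mirroring the defaults used elsewhere in this repo (e.g. `calc_mbs.py`):

```.python
import numpy as np
import torch
from camera_utils import LookAtPoseSampler, FOV_to_intrinsics

device = torch.device('cuda')
cam_pivot = torch.tensor([0.0, 0.0, 0.0], device=device)
cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)

# yaw = pi/2, pitch = pi/2 gives a frontal camera looking at the pivot
cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, cam_pivot, radius=cam_radius, device=device)
intrinsics = FOV_to_intrinsics(18.837, device=device)  # normalized 3x3 intrinsics
c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
```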
87 |
88 | The pickle contains three networks. `'G'` and `'D'` are instantaneous snapshots taken during training, and `'G_ema'` represents a moving average of the generator weights over several training steps. The networks are regular instances of `torch.nn.Module`, with all of their parameters and buffers placed on the CPU at import and gradient computation disabled by default.
89 |
90 |
91 |
92 | ## Datasets
93 |
94 | FFHQ-F(ullhead) consists of [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset), [K-Hairstyle dataset](https://psh01087.github.io/K-Hairstyle/), and an in-house human head dataset. For head pose estimation, we use [WHENet](https://arxiv.org/abs/2005.10353).
95 |
96 | Due to license issues, we are not able to release the FFHQ-F dataset that we used to train the model. [test_data_img](./dataset/testdata_img/) and [test_data_seg](./dataset/testdata_seg/) are just examples showing the dataset structure. For the camera pose convention, please refer to [EG3D](https://github.com/NVlabs/eg3d).
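
For reference, each label in `dataset.json` (see [dataset/testdata_img/dataset.json](./dataset/testdata_img/dataset.json)) is a list of 25 numbers: a row-major 4x4 cam2world matrix followed by a row-major normalized 3x3 intrinsics matrix. A minimal parsing sketch:

```.python
import json
import numpy as np

with open('dataset/testdata_img/dataset.json') as f:
    labels = json.load(f)['labels']

for fname, label in labels:
    label = np.array(label, dtype=np.float32)
    cam2world = label[:16].reshape(4, 4)   # extrinsics, camera-to-world
    intrinsics = label[16:].reshape(3, 3)  # normalized intrinsics
    print(fname, cam2world[:3, 3], intrinsics[0, 0])  # camera position, focal length
```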
97 |
98 |
99 | ## Datasets format
100 | For training, we can use either zip files or normal folders for the image and segmentation datasets. For PTI inversion, we need to use folders.
101 |
102 | To compress a dataset folder into a zip file, we can use [dataset_tool_seg](./dataset_tool_seg.py).
103 |
104 | For example:
105 | ```.bash
106 | python dataset_tool_seg.py --img_source dataset/testdata_img --seg_source dataset/testdata_seg --img_dest dataset/testdata_img.zip --seg_dest dataset/testdata_seg.zip --resolution 512x512
107 | ```
108 |
109 | ## Obtaining camera pose and cropping the images
110 | Please follow the [guide](3DDFA_V2_cropping/cropping_guide.md)
111 |
112 | ## Obtaining segmentation masks
113 | You can try deeplabv3 or another off-the-shelf tool to generate the masks. For an example using deeplabv3, see [misc/segmentation_example.py](misc/segmentation_example.py).
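
A minimal sketch of the deeplabv3 approach (not necessarily identical to `misc/segmentation_example.py`; it follows the same normalization and person-class lookup as [calc_mbs.py](./calc_mbs.py), where class index 15 is 'person' in the Pascal VOC label set):

```.python
import numpy as np
import PIL.Image
import torch
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seg_net = deeplabv3_resnet101(pretrained=True, progress=False).to(device).eval()

img = PIL.Image.open('dataset/testdata_img/000134.jpg').convert('RGB')
x = transforms.functional.to_tensor(img).unsqueeze(0).to(device)
x = transforms.functional.normalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

with torch.no_grad():
    out = seg_net(x)['out']
person = (out.softmax(dim=1).argmax(1) == 15).squeeze(0)  # boolean person mask
PIL.Image.fromarray(person.cpu().numpy().astype(np.uint8) * 255).save('000134.png')
```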
114 |
115 |
116 |
117 |
118 | ## Training
119 |
120 | Examples of training using `train.py`:
121 |
122 | ```
123 | # Train with StyleGAN2 backbone from scratch with raw neural rendering resolution=64, using 8 GPUs.
124 | # with segmentation mask, trigrid_depth@3, self-adaptive camera pose loss regularizer@10
125 |
126 | python train.py --outdir training-runs --img_data dataset/testdata_img.zip --seg_data dataset/testdata_seg.zip --cfg=ffhq --batch=32 --gpus 8 \
127 | --gamma=1 --gamma_seg=1 --gen_pose_cond=True --mirror=1 --use_torgb_raw=1 --decoder_activation="none" --disc_module MaskDualDiscriminatorV2 \
128 | --bcg_reg_prob 0.2 --triplane_depth 3 --density_noise_fade_kimg 200 --density_reg 0 --min_yaw 0 --max_yaw 180 --back_repeat 4 --trans_reg 10 --gpc_reg_prob 0.7
129 |
130 |
131 | # Second stage finetuning to 128 neural rendering resolution (optional).
132 |
133 | python train.py --outdir results --img_data dataset/testdata_img.zip --seg_data dataset/testdata_seg.zip --cfg=ffhq --batch=32 --gpus 8 \
134 | --resume=~/training-runs/experiment_dir/network-snapshot-025000.pkl \
135 | --gamma=1 --gamma_seg=1 --gen_pose_cond=True --mirror=1 --use_torgb_raw=1 --decoder_activation="none" --disc_module MaskDualDiscriminatorV2 \
136 | --bcg_reg_prob 0.2 --triplane_depth 3 --density_noise_fade_kimg 200 --density_reg 0 --min_yaw 0 --max_yaw 180 --back_repeat 4 --trans_reg 10 --gpc_reg_prob 0.7 \
137 | --neural_rendering_resolution_final=128 --resume_kimg 1000
138 | ```
139 |
140 | ## Metrics
141 |
142 |
143 |
144 | ```.bash
145 | ./get_metrics.sh
146 | ```
147 | There are three evaluation modes: all, front, and back, as mentioned in the paper. Please refer to [calc_metrics.py](./calc_metrics.py) for the implementation.
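
For example, a hypothetical invocation that computes FID on front views only (paths are illustrative):

```.bash
python calc_metrics.py --network models/easy-khair-180-gpc0.8-trans10-025000.pkl \
    --metrics fid50k_full --img_data dataset/testdata_img.zip --seg_data dataset/testdata_seg.zip \
    --mode front --gpus 1
```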
148 |
149 |
150 | ## Citation
151 |
152 | If you find our repo helpful, please cite our paper using the following bib:
153 |
154 | ```
155 | @InProceedings{An_2023_CVPR,
156 | author = {An, Sizhe and Xu, Hongyi and Shi, Yichun and Song, Guoxian and Ogras, Umit Y. and Luo, Linjie},
157 | title = {PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360deg},
158 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
159 | month = {June},
160 | year = {2023},
161 | pages = {20950-20959}
162 | }
163 | ```
164 |
165 | ## Development
166 |
167 | This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.
168 |
169 | ## Acknowledgements
170 |
171 | We thank Shuhong Chen for the discussion during Sizhe's internship.
172 |
173 | This repo is heavily based on the [NVlabs/eg3d](https://github.com/NVlabs/eg3d) repo. Huge thanks to the EG3D authors for releasing their code!
--------------------------------------------------------------------------------
/calc_mbs.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Generate images and shapes using pretrained network pickle."""
12 |
13 | import os
14 | import re
15 | from typing import List, Optional, Tuple, Union
16 |
17 | import click
18 | import dnnlib
19 | import numpy as np
20 | import PIL.Image
21 | import torch
22 | from tqdm import tqdm
23 |
24 |
25 | import legacy
26 | from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
27 | from torch_utils import misc
28 | from training.triplane import TriPlaneGenerator
29 | from torchvision.models.segmentation import deeplabv3_resnet101
30 | from torchvision import transforms, utils
31 | from torch.nn import functional as F
32 |
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | def parse_range(s: Union[str, List]) -> List[int]:
37 | '''Parse a comma separated list of numbers or ranges and return a list of ints.
38 |
39 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
40 | '''
41 | if isinstance(s, list): return s
42 | ranges = []
43 | range_re = re.compile(r'^(\d+)-(\d+)$')
44 | for p in s.split(','):
45 | m = range_re.match(p)
46 | if m:
47 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
48 | else:
49 | ranges.append(int(p))
50 | return ranges
51 |
52 |
53 | #----------------------------------------------------------------------------
54 | def get_mask(model, batch, cid):
55 | normalized_batch = transforms.functional.normalize(
56 | batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
57 | output = model(normalized_batch)['out']
58 | # sem_classes = [
59 | # '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
60 | # 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
61 | # 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
62 | # ]
63 | # sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
64 | # cid = sem_class_to_idx['car']
65 |
66 | normalized_masks = torch.nn.functional.softmax(output, dim=1)
67 |
68 | boolean_car_masks = (normalized_masks.argmax(1) == cid)
69 | return boolean_car_masks.float()
70 |
71 | def norm_ip(img, low, high):
72 | img_ = img.clamp(min=low, max=high)
73 | img_.sub_(low).div_(max(high - low, 1e-5))
74 | return img_
75 |
76 |
77 | def norm_range(t, value_range=(-1, 1)):
78 | if value_range is not None:
79 | return norm_ip(t, value_range[0], value_range[1])
80 | else:
81 | return norm_ip(t, float(t.min()), float(t.max()))
82 |
83 | #----------------------------------------------------------------------------
84 | @click.command()
85 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
86 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.7, show_default=True)
87 | @click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
88 | @click.option('--fov-deg', help='Field of View of camera in degrees', type=float, required=False, metavar='float', default=18.837, show_default=True)
89 | @click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
90 | @click.option('--pose_cond', type=int, help='pose_cond angle', default=90, show_default=True)
91 |
92 | def generate_images(
93 | network_pkl: str,
94 | truncation_psi: float,
95 | truncation_cutoff: int,
96 | fov_deg: float,
97 | reload_modules: bool,
98 | pose_cond: int,
99 | ):
100 |     """Measure mask consistency (MSE between person masks) when the background latent is swapped
101 |     while the foreground latent is kept fixed, using a pretrained network pickle.
102 |
103 |     Examples:
104 |
105 |     \b
106 |     # Evaluate a pre-trained model.
107 |     python calc_mbs.py --network=models/easy-khair-180-gpc0.8-trans10-025000.pkl \\
108 |         --pose_cond=90
109 |     """
109 |
110 | print('Loading networks from "%s"...' % network_pkl)
111 | device = torch.device('cuda:1')
112 | with dnnlib.util.open_url(network_pkl) as f:
113 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
114 |
115 | # Specify reload_modules=True if you want code modifications to take effect; otherwise uses pickled code
116 | if reload_modules:
117 | print("Reloading Modules!")
118 | G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
119 | misc.copy_params_and_buffers(G, G_new, require_all=True)
120 | G_new.neural_rendering_resolution = G.neural_rendering_resolution
121 | G_new.rendering_kwargs = G.rendering_kwargs
122 | G = G_new
123 |
124 |
125 | pose_cond_rad = pose_cond/180*np.pi
126 | intrinsics = FOV_to_intrinsics(fov_deg, device=device)
127 |
128 |
129 | # load segmentation net
130 | seg_net = deeplabv3_resnet101(pretrained=True, progress=False).to(device)
131 | seg_net.requires_grad_(False)
132 | seg_net.eval()
133 |
134 | mse_total = 0
135 |
136 | n_sample = 64
137 | batch = 32
138 |
139 | n_sample = n_sample // batch * batch
140 | batch_li = n_sample // batch * [batch]
141 | pose_cond_rad = pose_cond/180*np.pi
142 |
143 | intrinsics = FOV_to_intrinsics(fov_deg, device=device)
144 |
145 | # Generate images.
146 | cam_pivot = torch.tensor([0, 0, 0], device=device)
147 | cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)
148 | conditioning_cam2world_pose = LookAtPoseSampler.sample(pose_cond_rad, np.pi/2, cam_pivot, radius=cam_radius, device=device)
149 | conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
150 |
151 | c = conditioning_params.repeat(batch, 1)
152 |
153 | for batch in tqdm(batch_li):
154 | # z and w
155 | z0 = torch.from_numpy(np.random.randn(batch, G.z_dim)).to(device)
156 | z1 = torch.from_numpy(np.random.randn(batch, G.z_dim)).to(device)
157 |
158 |
159 | ws0 = G.mapping(z0, c, truncation_psi=0.7, truncation_cutoff=None)
160 | ws1 = G.mapping(z1, c, truncation_psi=0.7, truncation_cutoff=None)
161 |
162 | c0 = c.clone()
163 | c1 = c.clone()
164 |
165 | img0 = G.synthesis(ws0, c0, ws_bcg = ws0.clone())['image']
166 | img0 = norm_range(img0)
167 | img1 = G.synthesis(ws0, c1, ws_bcg = ws1.clone())['image']
168 | img1 = norm_range(img1)
169 | # 15 means human mask
170 | mask0 = get_mask(seg_net, img0, 15).unsqueeze(1)
171 | mask1 = get_mask(seg_net, img1, 15).unsqueeze(1)
172 |
173 | diff = torch.abs(mask0-mask1)
174 | mse = F.mse_loss(mask0, mask1)
175 |
176 | # mutual_bg_mask = (1-mask0) * (1-mask1)
177 |
178 | # diff = F.l1_loss(mutual_bg_mask*img1, mutual_bg_mask*img0, reduction='none')
179 | # diff = torch.where(diff < 1/255, torch.zeros_like(diff), torch.ones_like(diff))
180 | # diff = torch.sum(diff, dim=1)
181 | # diff = torch.where(diff < 1, torch.zeros_like(diff), torch.ones_like(diff))
182 | utils.save_image(
183 | # (1-mask1)*img1,
184 | mask1,
185 | f'alphamse/changebg/mask1.png',
186 | nrow=8,
187 | normalize=True,
188 | range=(0, 1),
189 | padding=0,
190 | )
191 | utils.save_image(
192 | mask0,
193 | f'alphamse/changebg/mask0.png',
194 | nrow=8,
195 | normalize=True,
196 | range=(0, 1),
197 | padding=0,
198 | )
199 | utils.save_image(
200 | img0,
201 | f'alphamse/changebg/img0.png',
202 | nrow=8,
203 | normalize=True,
204 | range=(0, 1),
205 | padding=0,
206 | )
207 | utils.save_image(
208 | img1,
209 | f'alphamse/changebg/img1.png',
210 | nrow=8,
211 | normalize=True,
212 | range=(0, 1),
213 | padding=0,
214 | )
215 | utils.save_image(
216 | diff,
217 | f'alphamse/changebg/diff.png',
218 | nrow=8,
219 | normalize=True,
220 | range=(0, 1),
221 | padding=0,
222 | )
223 | # sys.exit()
224 |
225 | # change_fg_score += torch.sum(torch.sum(diff, dim=(1,2)) / (torch.sum(mutual_bg_mask, dim=(1,2,3))+1e-8))
226 | mse_total += mse.cpu().detach().numpy()
227 |
228 | print(f'mse_final: {mse_total/len(batch_li)}')
229 |
230 |
231 |
232 | #----------------------------------------------------------------------------
233 |
234 | if __name__ == "__main__":
235 | generate_images() # pylint: disable=no-value-for-parameter
236 |
237 | #----------------------------------------------------------------------------
238 |
239 |
--------------------------------------------------------------------------------
/calc_metrics.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Calculate quality metrics for previous training run or pretrained network pickle."""
12 |
13 | import os
14 | import click
15 | import json
16 | import tempfile
17 | import copy
18 | import torch
19 |
20 | import dnnlib
21 | import legacy
22 | from metrics import metric_main
23 | from metrics import metric_utils
24 | from torch_utils import training_stats
25 | from torch_utils import custom_ops
26 | from torch_utils import misc
27 | from torch_utils.ops import conv2d_gradfix
28 |
29 | #----------------------------------------------------------------------------
30 |
31 | def subprocess_fn(rank, args, temp_dir):
32 | dnnlib.util.Logger(should_flush=True)
33 |
34 | # Init torch.distributed.
35 | if args.num_gpus > 1:
36 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
37 | if os.name == 'nt':
38 | init_method = 'file:///' + init_file.replace('\\', '/')
39 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
40 | else:
41 | init_method = f'file://{init_file}'
42 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
43 |
44 | # Init torch_utils.
45 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
46 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
47 | if rank != 0 or not args.verbose:
48 | custom_ops.verbosity = 'none'
49 |
50 | # Configure torch.
51 | device = torch.device('cuda', rank)
52 | torch.backends.cuda.matmul.allow_tf32 = False
53 | torch.backends.cudnn.allow_tf32 = False
54 | conv2d_gradfix.enabled = True
55 |
56 | # Print network summary.
57 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
58 | if rank == 0 and args.verbose:
59 | z = torch.empty([1, G.z_dim], device=device)
60 | c = torch.empty([1, G.c_dim], device=device)
61 | misc.print_module_summary(G, [z, c])
62 |
63 | # Calculate each metric.
64 | for metric in args.metrics:
65 | if rank == 0 and args.verbose:
66 | print(f'Calculating {metric}...')
67 | progress = metric_utils.ProgressMonitor(verbose=args.verbose)
68 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
69 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress, mode=args.mode)
70 | if rank == 0:
71 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
72 | if rank == 0 and args.verbose:
73 | print()
74 |
75 | # Done.
76 | if rank == 0 and args.verbose:
77 | print('Exiting...')
78 |
79 | #----------------------------------------------------------------------------
80 |
81 | def parse_comma_separated_list(s):
82 | if isinstance(s, list):
83 | return s
84 | if s is None or s.lower() == 'none' or s == '':
85 | return []
86 | return s.split(',')
87 |
88 | #----------------------------------------------------------------------------
89 |
90 | @click.command()
91 | @click.pass_context
92 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
93 | @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
94 | @click.option('--img_data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
95 | @click.option('--seg_data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
96 | @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
97 | @click.option('--mode', help='Evaluation mode [back, front, all]', metavar='STR', type=click.Choice(['back','front','all']), default='all', required=False, show_default=False)
98 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
99 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
100 |
101 | def calc_metrics(ctx, network_pkl, metrics, img_data, seg_data, mirror, mode, gpus, verbose):
102 | """Calculate quality metrics for previous training run or pretrained network pickle.
103 |
104 | Examples:
105 |
106 | \b
107 | # Previous training run: look up options automatically, save result to JSONL file.
108 | python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
109 | --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
110 |
111 | \b
112 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
113 |     python calc_metrics.py --metrics=fid50k_full --img_data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
114 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
115 |
116 | \b
117 | Recommended metrics:
118 | fid50k_full Frechet inception distance against the full dataset.
119 | kid50k_full Kernel inception distance against the full dataset.
120 |       pr50k3_full      Precision and recall against the full dataset.
121 | ppl2_wend Perceptual path length in W, endpoints, full image.
122 | eqt50k_int Equivariance w.r.t. integer translation (EQ-T).
123 | eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac).
124 | eqr50k Equivariance w.r.t. rotation (EQ-R).
125 |
126 | \b
127 | Legacy metrics:
128 | fid50k Frechet inception distance against 50k real images.
129 | kid50k Kernel inception distance against 50k real images.
130 | pr50k3 Precision and recall against 50k real images.
131 | is50k Inception score for CIFAR-10.
132 | """
133 | dnnlib.util.Logger(should_flush=True)
134 |
135 | # decide evaluation mode
136 | if mode == 'all':
137 | min_yaw, max_yaw = 0, 180
138 | elif mode == 'front':
139 | min_yaw, max_yaw = 0, 90
140 | elif mode == 'back':
141 | min_yaw, max_yaw = 90, 180
142 |
143 | # Validate arguments.
144 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose, mode=mode)
145 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
146 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
147 | if not args.num_gpus >= 1:
148 | ctx.fail('--gpus must be at least 1')
149 |
150 | # Load network.
151 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
152 | ctx.fail('--network must point to a file or URL')
153 | if args.verbose:
154 | print(f'Loading network from "{network_pkl}"...')
155 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
156 | network_dict = legacy.load_network_pkl(f)
157 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module
158 |
159 | # Initialize dataset options.
160 | if img_data is not None:
161 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.MaskLabeledDataset', img_path=img_data, seg_path=seg_data, min_yaw = min_yaw, max_yaw = max_yaw, back_repeat=1)
162 | elif network_dict['training_set_kwargs'] is not None:
163 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
164 | else:
165 |         ctx.fail('Could not look up dataset options; please specify --img_data')
166 |
167 | # Finalize dataset options.
168 | args.dataset_kwargs.resolution = args.G.img_resolution
169 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
170 | if mirror is not None:
171 | args.dataset_kwargs.xflip = mirror
172 |
173 | # Print dataset options.
174 | if args.verbose:
175 | print('Dataset options:')
176 | print(json.dumps(args.dataset_kwargs, indent=2))
177 |
178 | # Locate run dir.
179 | args.run_dir = None
180 | if os.path.isfile(network_pkl):
181 | pkl_dir = os.path.dirname(network_pkl)
182 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
183 | args.run_dir = pkl_dir
184 |
185 | # Launch processes.
186 | if args.verbose:
187 | print('Launching processes...')
188 | torch.multiprocessing.set_start_method('spawn')
189 | with tempfile.TemporaryDirectory() as temp_dir:
190 | if args.num_gpus == 1:
191 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
192 | else:
193 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
194 |
195 | #----------------------------------------------------------------------------
196 |
197 | if __name__ == "__main__":
198 | calc_metrics() # pylint: disable=no-value-for-parameter
199 |
200 | #----------------------------------------------------------------------------
201 |
--------------------------------------------------------------------------------
/camera_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """
12 | Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
13 | """
14 |
15 | import math
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 | from training.volumetric_rendering import math_utils
21 |
22 | class GaussianCameraPoseSampler:
23 | """
24 | Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
25 | Camera is specified as looking at the origin.
26 | If horizontal and vertical stddev (specified in radians) are zero, gives a
27 | deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
28 | The coordinate system is specified with y-up, z-forward, x-left.
29 | Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
30 | vertical mean is the polar angle (angle from the y axis) in radians.
31 | A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
32 |
33 | Example:
34 | For a camera pose looking at the origin with the camera at position [0, 0, 1]:
35 | cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
36 | """
37 |
38 | @staticmethod
39 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
40 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
41 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
42 | v = torch.clamp(v, 1e-5, math.pi - 1e-5)
43 |
44 | theta = h
45 | v = v / math.pi
46 | phi = torch.arccos(1 - 2*v)
47 |
48 | camera_origins = torch.zeros((batch_size, 3), device=device)
49 |
50 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
51 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
52 | camera_origins[:, 1:2] = radius*torch.cos(phi)
53 |
54 | forward_vectors = math_utils.normalize_vecs(-camera_origins)
55 | return create_cam2world_matrix(forward_vectors, camera_origins)
56 |
57 |
58 | class LookAtPoseSampler:
59 | """
60 | Same as GaussianCameraPoseSampler, except the
61 | camera is specified as looking at 'lookat_position', a 3-vector.
62 |
63 | Example:
64 | For a camera pose looking at the origin with the camera at position [0, 0, 1]:
65 | cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
66 | """
67 |
68 | @staticmethod
69 | def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
70 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
71 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
72 | v = torch.clamp(v, 1e-5, math.pi - 1e-5)
73 |
74 | theta = h
75 | v = v / math.pi
76 | phi = torch.arccos(1 - 2*v)
77 |
78 | camera_origins = torch.zeros((batch_size, 3), device=device)
79 |
80 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
81 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
82 | camera_origins[:, 1:2] = radius*torch.cos(phi)
83 |
84 | # forward_vectors = math_utils.normalize_vecs(-camera_origins)
85 | forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins)
86 | return create_cam2world_matrix(forward_vectors, camera_origins)
87 |
88 | class UniformCameraPoseSampler:
89 | """
90 | Same as GaussianCameraPoseSampler, except the
91 | pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.
92 |
93 | Example:
94 | For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
95 |
96 | cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
97 | """
98 |
99 | @staticmethod
100 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
101 | h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
102 | v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean
103 | v = torch.clamp(v, 1e-5, math.pi - 1e-5)
104 |
105 | theta = h
106 | v = v / math.pi
107 | phi = torch.arccos(1 - 2*v)
108 |
109 | camera_origins = torch.zeros((batch_size, 3), device=device)
110 |
111 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
112 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
113 | camera_origins[:, 1:2] = radius*torch.cos(phi)
114 |
115 | forward_vectors = math_utils.normalize_vecs(-camera_origins)
116 | return create_cam2world_matrix(forward_vectors, camera_origins)
117 |
118 | def create_cam2world_matrix(forward_vector, origin):
119 | """
120 | Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
121 | Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
122 | """
123 |
124 | forward_vector = math_utils.normalize_vecs(forward_vector)
125 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector)
126 |
127 | right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
128 | up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
129 |
130 | rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
131 | rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
132 |
133 | translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
134 | translation_matrix[:, :3, 3] = origin
135 | cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
136 | assert(cam2world.shape[1:] == (4, 4))
137 | return cam2world
138 |
139 |
140 | def FOV_to_intrinsics(fov_degrees, device='cpu'):
141 | """
142 | Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
143 | Note the intrinsics are returned as normalized by image size, rather than in pixel units.
144 | Assumes principal point is at image center.
145 | """
146 |
147 | focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
148 | intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
149 | return intrinsics
--------------------------------------------------------------------------------
/dataset/testdata_img.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_img.zip
--------------------------------------------------------------------------------
/dataset/testdata_img/000134.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_img/000134.jpg
--------------------------------------------------------------------------------
/dataset/testdata_img/000157.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_img/000157.jpg
--------------------------------------------------------------------------------
/dataset/testdata_img/dataset.json:
--------------------------------------------------------------------------------
1 | {
2 | "labels": [
3 | [
4 | "000134.jpg",
5 | [
6 | "0.980169",
7 | "-0.004075",
8 | "-0.198125",
9 | "0.534939",
10 | "-0.003572",
11 | "-0.999991",
12 | "0.001775",
13 | "-0.004793",
14 | "-0.198131",
15 | "-0.000988",
16 | "-0.980175",
17 | "2.646472",
18 | "0.000000",
19 | "0.000000",
20 | "0.000000",
21 | "1.000000",
22 | "4.264700",
23 | "0.000000",
24 | "0.500000",
25 | "0.000000",
26 | "4.264700",
27 | "0.500000",
28 | "0.000000",
29 | "0.000000",
30 | "1.000000"
31 | ]
32 | ],
33 | [
34 | "000157.jpg",
35 | [
36 | "0.975032",
37 | "-0.009496",
38 | "0.221723",
39 | "-0.598652",
40 | "-0.008519",
41 | "-0.999792",
42 | "-0.020003",
43 | "0.054008",
44 | "0.221925",
45 | "0.018354",
46 | "-0.974910",
47 | "2.632258",
48 | "0.000000",
49 | "0.000000",
50 | "0.000000",
51 | "1.000000",
52 | "4.264700",
53 | "0.000000",
54 | "0.500000",
55 | "0.000000",
56 | "4.264700",
57 | "0.500000",
58 | "0.000000",
59 | "0.000000",
60 | "1.000000"
61 | ]
62 | ]
63 | ]
64 | }
--------------------------------------------------------------------------------
/dataset/testdata_seg.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_seg.zip
--------------------------------------------------------------------------------
/dataset/testdata_seg/000134.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_seg/000134.png
--------------------------------------------------------------------------------
/dataset/testdata_seg/000157.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_seg/000157.png
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from .util import EasyDict, make_cache_dir_path
12 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | name: panohead
12 | channels:
13 | - pytorch
14 | - nvidia
15 | dependencies:
16 | - python >= 3.8
17 | - pip
18 | - numpy>=1.20
19 | - click>=8.0
20 | - pillow=8.3.1
21 | - scipy=1.7.1
22 | - pytorch=1.11.0
23 | - cudatoolkit=11.1
24 | - requests=2.26.0
25 | - tqdm=4.62.2
26 | - ninja=1.10.2
27 | - matplotlib=3.4.2
28 | - imageio=2.9.0
29 | - pip:
30 | - imgui==1.3.0
31 | - glfw==2.2.0
32 | - pyopengl==3.1.5
33 | - imageio-ffmpeg==0.4.3
34 | - pyspng
35 | - psutil
36 | - mrcfile
37 | - tensorboard
38 | - torchvision==0.12.0
--------------------------------------------------------------------------------
/gen_interpolation.py:
--------------------------------------------------------------------------------
1 | ''' Generate images and shapes using pretrained network pickle.
2 | Code adapted from the following paper
3 | "Efficient Geometry-aware 3D Generative Adversarial Networks."
4 | See LICENSES/LICENSE_EG3D for original license.
5 | '''
6 |
7 | import os
8 | import re
9 | from typing import List, Optional, Tuple, Union
10 |
11 | import click
12 | import dnnlib
13 | import numpy as np
14 | import PIL.Image
15 | import torch
16 | from tqdm import tqdm
17 |
18 |
19 | import legacy
20 | from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
21 | from torch_utils import misc
22 | from training.triplane import TriPlaneGenerator
23 | from torchvision.models.segmentation import deeplabv3_resnet101
24 | from torchvision import transforms, utils
25 | from torch.nn import functional as F
26 |
27 |
28 | #----------------------------------------------------------------------------
29 |
30 | def parse_range(s: Union[str, List]) -> List[int]:
31 | '''Parse a comma separated list of numbers or ranges and return a list of ints.
32 |
33 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
34 | '''
35 | if isinstance(s, list): return s
36 | ranges = []
37 | range_re = re.compile(r'^(\d+)-(\d+)$')
38 | for p in s.split(','):
39 | m = range_re.match(p)
40 | if m:
41 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
42 | else:
43 | ranges.append(int(p))
44 | return ranges
45 |
46 |
47 | #----------------------------------------------------------------------------
48 | def get_mask(model, batch, cid):
49 | normalized_batch = transforms.functional.normalize(
50 | batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
51 | output = model(normalized_batch)['out']
52 | # sem_classes = [
53 | # '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
54 | # 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
55 | # 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
56 | # ]
57 | # sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
58 | # cid = sem_class_to_idx['car']
59 |
60 | normalized_masks = torch.nn.functional.softmax(output, dim=1)
61 |
62 | boolean_car_masks = (normalized_masks.argmax(1) == cid)
63 | return boolean_car_masks.float()
64 |
65 | def norm_ip(img, low, high):
66 | img_ = img.clamp(min=low, max=high)
67 | img_.sub_(low).div_(max(high - low, 1e-5))
68 | return img_
69 |
70 |
71 | def norm_range(t, value_range=(-1, 1)):
72 | if value_range is not None:
73 | return norm_ip(t, value_range[0], value_range[1])
74 | else:
75 | return norm_ip(t, float(t.min()), float(t.max()))
76 |
77 | #----------------------------------------------------------------------------
78 | @click.command()
79 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
80 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.7, show_default=True)
81 | @click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
82 | @click.option('--fov-deg', help='Field of View of camera in degrees', type=float, required=False, metavar='float', default=18.837, show_default=True)
83 | @click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
84 | @click.option('--pose_cond', type=int, help='pose_cond angle', default=90, show_default=True)
85 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
86 |
87 | def generate_images(
88 | network_pkl: str,
89 | outdir: str,
90 | truncation_psi: float,
91 | truncation_cutoff: int,
92 | fov_deg: float,
93 | reload_modules: bool,
94 | pose_cond: int,
95 | ):
96 | """Generate interpolation images using pretrained network pickle.
97 |
98 | Examples:
99 |
100 | \b
101 | python gen_interpolation.py --network models/easy-khair-180-gpc0.8-trans10-025000.pkl\
102 | --trunc 0.7 --outdir interpolation_out
103 | """
104 |
105 | print('Loading networks from "%s"...' % network_pkl)
106 | device = torch.device('cuda')
107 | with dnnlib.util.open_url(network_pkl) as f:
108 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
109 |
110 | # Specify reload_modules=True if you want code modifications to take effect; otherwise uses pickled code
111 | if reload_modules:
112 | print("Reloading Modules!")
113 | G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
114 | misc.copy_params_and_buffers(G, G_new, require_all=True)
115 | G_new.neural_rendering_resolution = G.neural_rendering_resolution
116 | G_new.rendering_kwargs = G.rendering_kwargs
117 | G = G_new
118 |
119 | network_pkl = os.path.basename(network_pkl)
120 | outdir = os.path.join(outdir, os.path.splitext(network_pkl)[0] + '_' + str(pose_cond))
121 | os.makedirs(outdir, exist_ok=True)
122 |
123 | pose_cond_rad = pose_cond/180*np.pi
124 | intrinsics = FOV_to_intrinsics(fov_deg, device=device)
125 |
126 |
127 |
128 | # n_sample = 64
129 | # batch = 32
130 |
131 | # n_sample = n_sample // batch * batch
132 | # batch_li = n_sample // batch * [batch]
133 | pose_cond_rad = pose_cond/180*np.pi
134 |
135 | intrinsics = FOV_to_intrinsics(fov_deg, device=device)
136 |
137 | # Generate images.
138 | cam_pivot = torch.tensor([0, 0, 0], device=device)
139 | cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)
140 | conditioning_cam2world_pose = LookAtPoseSampler.sample(pose_cond_rad, np.pi/2, cam_pivot, radius=cam_radius, device=device)
141 | conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
142 |
143 | conditioning_cam2world_pose_back = LookAtPoseSampler.sample(-pose_cond_rad, np.pi/2, cam_pivot, radius=cam_radius, device=device)
144 | conditioning_params_back = torch.cat([conditioning_cam2world_pose_back.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
145 |
146 |
147 | conditioning_cam2world_pose_side = LookAtPoseSampler.sample(45/180*np.pi, np.pi/2, cam_pivot, radius=cam_radius, device=device)
148 | conditioning_params_side = torch.cat([conditioning_cam2world_pose_side.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
149 |
150 | # set two random seeds for interpolation
151 | seed1, seed2 = 521, 329
152 | z0 = torch.from_numpy(np.random.RandomState(seed1).randn(G.z_dim).reshape(1,G.z_dim)).to(device)
153 | z1 = torch.from_numpy(np.random.RandomState(seed2).randn(G.z_dim).reshape(1,G.z_dim)).to(device)
154 | c = conditioning_params
155 |
156 | ws0 = G.mapping(z0, c, truncation_psi=truncation_psi, truncation_cutoff=None)
157 | ws1 = G.mapping(z1, c, truncation_psi=truncation_psi, truncation_cutoff=None)
158 |
159 | image_final = []
160 | for c in [conditioning_params, conditioning_params_side, conditioning_params_back]:
161 | img0 = G.synthesis(ws0, c)['image']
162 | img0 = norm_range(img0)
163 | img1 = G.synthesis(ws1, c)['image']
164 | img1 = norm_range(img1)
165 |
166 |
167 |
168 |
169 |
170 | img_list = []
171 | for interpolation_idx in [0,2,3,4,6,8]:
172 | # for interpolation_idx in range(0,14,1):
173 | # interpolation_idx = 8
174 | ws_new = ws0.clone()
175 | ws_new[:, interpolation_idx:, :] = ws1[:, interpolation_idx:, :]
176 | img_new = G.synthesis(ws_new, c)['image']
177 | img_new = norm_range(img_new)
178 | img_list.append(img_new)
179 |
180 | img_list.append(img0)
181 |
182 | img_new = torch.cat(img_list, dim=0)
183 | image_final.append(img_new)
184 |
185 | image_final = torch.cat(image_final, dim=2)
186 |
187 | utils.save_image(
188 | image_final,
189 | os.path.join(outdir, f'img_interpolation_seed{seed1}_{seed2}.png'),
190 | # nrow=8,
191 | normalize=True,
192 | range=(0, 1),
193 | padding=0,
194 | )
195 | # utils.save_image(
196 | # diff,
197 | # f'alphamse/changebg/diff.png',
198 | # nrow=8,
199 | # normalize=True,
200 | # range=(0, 1),
201 | # padding=0,
202 | # )
203 | # sys.exit()
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 | #----------------------------------------------------------------------------
213 |
214 | if __name__ == "__main__":
215 | generate_images() # pylint: disable=no-value-for-parameter
216 |
217 | #----------------------------------------------------------------------------
218 |
219 |
--------------------------------------------------------------------------------
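
generate_images() in gen_interpolation.py interpolates by style mixing in W+: both seeds are mapped to W+ codes, and for each cutoff index in [0, 2, 3, 4, 6, 8] the rows of ws0 from that index onward are replaced by the corresponding rows of ws1, so small cutoffs hand most layers to the second latent and large cutoffs keep most of the first. A minimal sketch of that mixing step, assuming W+ codes shaped [batch, num_ws, w_dim] as in the script; the helper name is illustrative:

import torch

# Sketch of the W+ layer swap used in generate_images(): rows below `cutoff`
# come from the first latent, rows from `cutoff` onward from the second,
# giving a coarse-to-fine style mix.
def mix_styles(ws0: torch.Tensor, ws1: torch.Tensor, cutoff: int) -> torch.Tensor:
    assert ws0.shape == ws1.shape
    ws_mixed = ws0.clone()
    ws_mixed[:, cutoff:, :] = ws1[:, cutoff:, :]
    return ws_mixed

# Inside the script this would be rendered as (illustrative):
#   img = G.synthesis(mix_styles(ws0, ws1, cutoff=4), conditioning_params)['image']

--------------------------------------------------------------------------------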
/gen_pti_script.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | models=("easy-khair-180-gpc0.8-trans10-025000.pkl"\
4 | "ablation-trigridD-1-025000.pkl")
5 |
6 | in="models"
7 | out="pti_out"
8 |
9 | for model in "${models[@]}"
10 |
11 | do
12 |
13 | for i in 0 1
14 |
15 | do
16 | # perform the pti and save w
17 | python projector_withseg.py --outdir=${out} --target_img=dataset/testdata_img --network ${in}/${model} --idx ${i}
18 | # generate .mp4 before finetune
19 | python gen_videos_proj_withseg.py --output=${out}/${model}/${i}/PTI_render/pre.mp4 --latent=${out}/${model}/${i}/projected_w.npz --trunc 0.7 --network ${in}/${model} --cfg Head
20 | # generate .mp4 after finetune
21 | python gen_videos_proj_withseg.py --output=${out}/${model}/${i}/PTI_render/post.mp4 --latent=${out}/${model}/${i}/projected_w.npz --trunc 0.7 --network ${out}/${model}/${i}/fintuned_generator.pkl --cfg Head
22 |
23 |
24 | done
25 |
26 | done
27 |
--------------------------------------------------------------------------------
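
gen_pti_script.sh runs, per checkpoint and per target index, one PTI inversion (projector_withseg.py) followed by two renders of the projected latent: one through the original generator and one through the finetuned generator the projector writes out. A hedged Python equivalent of the same loop using subprocess; the model names, directories and the finetuned-pickle filename are copied from the script above:

import subprocess

# Sketch of the same loop as gen_pti_script.sh, driven from Python.
models = ["easy-khair-180-gpc0.8-trans10-025000.pkl", "ablation-trigridD-1-025000.pkl"]
in_dir, out_dir = "models", "pti_out"

for model in models:
    for i in (0, 1):
        base = f"{out_dir}/{model}/{i}"
        # PTI inversion: writes projected_w.npz and fintuned_generator.pkl under `base`.
        subprocess.run(["python", "projector_withseg.py", f"--outdir={out_dir}",
                        "--target_img=dataset/testdata_img",
                        "--network", f"{in_dir}/{model}", "--idx", str(i)], check=True)
        # Render the projected latent with the original and the finetuned generator.
        for tag, network in [("pre", f"{in_dir}/{model}"), ("post", f"{base}/fintuned_generator.pkl")]:
            subprocess.run(["python", "gen_videos_proj_withseg.py",
                            f"--output={base}/PTI_render/{tag}.mp4",
                            f"--latent={base}/projected_w.npz", "--trunc", "0.7",
                            "--network", network, "--cfg", "Head"], check=True)

--------------------------------------------------------------------------------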
/get_metrics.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | models=("easy-khair-180-gpc0.8-trans10-025000.pkl"\
4 | "ablation-trigridD-1-025000.pkl"\
5 | )
6 |
7 |
8 | for model in "${models[@]}"
9 |
10 | do
11 |
12 |
13 | python calc_metrics.py --network models/${model} \
14 | --img_data=dataset/ffhq-3DDFA-exp-augx2-lpx2-easyx2-khairfiltered-img.zip\
15 | --seg_data=dataset/ffhq-3DDFA-exp-augx2-lpx2-easyx2-khairfiltered-seg.zip\
16 | --gpus 8 --mirror True --metrics fid50k_full,is50k | tee -a paper_metrics/${model}.log
17 |
18 |
19 | done
20 |
--------------------------------------------------------------------------------
/gui_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/gui_utils/glfw_window.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import time
12 | import glfw
13 | import OpenGL.GL as gl
14 | from . import gl_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | class GlfwWindow: # pylint: disable=too-many-public-methods
19 | def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
20 | self._glfw_window = None
21 | self._drawing_frame = False
22 | self._frame_start_time = None
23 | self._frame_delta = 0
24 | self._fps_limit = None
25 | self._vsync = None
26 | self._skip_frames = 0
27 | self._deferred_show = deferred_show
28 | self._close_on_esc = close_on_esc
29 | self._esc_pressed = False
30 | self._drag_and_drop_paths = None
31 | self._capture_next_frame = False
32 | self._captured_frame = None
33 |
34 | # Create window.
35 | glfw.init()
36 | glfw.window_hint(glfw.VISIBLE, False)
37 | self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
38 | self._attach_glfw_callbacks()
39 | self.make_context_current()
40 |
41 | # Adjust window.
42 | self.set_vsync(False)
43 | self.set_window_size(window_width, window_height)
44 | if not self._deferred_show:
45 | glfw.show_window(self._glfw_window)
46 |
47 | def close(self):
48 | if self._drawing_frame:
49 | self.end_frame()
50 | if self._glfw_window is not None:
51 | glfw.destroy_window(self._glfw_window)
52 | self._glfw_window = None
53 | #glfw.terminate() # Commented out to play it nice with other glfw clients.
54 |
55 | def __del__(self):
56 | try:
57 | self.close()
58 | except:
59 | pass
60 |
61 | @property
62 | def window_width(self):
63 | return self.content_width
64 |
65 | @property
66 | def window_height(self):
67 | return self.content_height + self.title_bar_height
68 |
69 | @property
70 | def content_width(self):
71 | width, _height = glfw.get_window_size(self._glfw_window)
72 | return width
73 |
74 | @property
75 | def content_height(self):
76 | _width, height = glfw.get_window_size(self._glfw_window)
77 | return height
78 |
79 | @property
80 | def title_bar_height(self):
81 | _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
82 | return top
83 |
84 | @property
85 | def monitor_width(self):
86 | _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
87 | return width
88 |
89 | @property
90 | def monitor_height(self):
91 | _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
92 | return height
93 |
94 | @property
95 | def frame_delta(self):
96 | return self._frame_delta
97 |
98 | def set_title(self, title):
99 | glfw.set_window_title(self._glfw_window, title)
100 |
101 | def set_window_size(self, width, height):
102 | width = min(width, self.monitor_width)
103 | height = min(height, self.monitor_height)
104 | glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
105 | if width == self.monitor_width and height == self.monitor_height:
106 | self.maximize()
107 |
108 | def set_content_size(self, width, height):
109 | self.set_window_size(width, height + self.title_bar_height)
110 |
111 | def maximize(self):
112 | glfw.maximize_window(self._glfw_window)
113 |
114 | def set_position(self, x, y):
115 | glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
116 |
117 | def center(self):
118 | self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
119 |
120 | def set_vsync(self, vsync):
121 | vsync = bool(vsync)
122 | if vsync != self._vsync:
123 | glfw.swap_interval(1 if vsync else 0)
124 | self._vsync = vsync
125 |
126 | def set_fps_limit(self, fps_limit):
127 | self._fps_limit = int(fps_limit)
128 |
129 | def should_close(self):
130 | return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
131 |
132 | def skip_frame(self):
133 | self.skip_frames(1)
134 |
135 | def skip_frames(self, num): # Do not update window for the next N frames.
136 | self._skip_frames = max(self._skip_frames, int(num))
137 |
138 | def is_skipping_frames(self):
139 | return self._skip_frames > 0
140 |
141 | def capture_next_frame(self):
142 | self._capture_next_frame = True
143 |
144 | def pop_captured_frame(self):
145 | frame = self._captured_frame
146 | self._captured_frame = None
147 | return frame
148 |
149 | def pop_drag_and_drop_paths(self):
150 | paths = self._drag_and_drop_paths
151 | self._drag_and_drop_paths = None
152 | return paths
153 |
154 | def draw_frame(self): # To be overridden by subclass.
155 | self.begin_frame()
156 | # Rendering code goes here.
157 | self.end_frame()
158 |
159 | def make_context_current(self):
160 | if self._glfw_window is not None:
161 | glfw.make_context_current(self._glfw_window)
162 |
163 | def begin_frame(self):
164 | # End previous frame.
165 | if self._drawing_frame:
166 | self.end_frame()
167 |
168 | # Apply FPS limit.
169 | if self._frame_start_time is not None and self._fps_limit is not None:
170 | delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
171 | if delay > 0:
172 | time.sleep(delay)
173 | cur_time = time.perf_counter()
174 | if self._frame_start_time is not None:
175 | self._frame_delta = cur_time - self._frame_start_time
176 | self._frame_start_time = cur_time
177 |
178 | # Process events.
179 | glfw.poll_events()
180 |
181 | # Begin frame.
182 | self._drawing_frame = True
183 | self.make_context_current()
184 |
185 | # Initialize GL state.
186 | gl.glViewport(0, 0, self.content_width, self.content_height)
187 | gl.glMatrixMode(gl.GL_PROJECTION)
188 | gl.glLoadIdentity()
189 | gl.glTranslate(-1, 1, 0)
190 | gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
191 | gl.glMatrixMode(gl.GL_MODELVIEW)
192 | gl.glLoadIdentity()
193 | gl.glEnable(gl.GL_BLEND)
194 | gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
195 |
196 | # Clear.
197 | gl.glClearColor(0, 0, 0, 1)
198 | gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
199 |
200 | def end_frame(self):
201 | assert self._drawing_frame
202 | self._drawing_frame = False
203 |
204 | # Skip frames if requested.
205 | if self._skip_frames > 0:
206 | self._skip_frames -= 1
207 | return
208 |
209 | # Capture frame if requested.
210 | if self._capture_next_frame:
211 | self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
212 | self._capture_next_frame = False
213 |
214 | # Update window.
215 | if self._deferred_show:
216 | glfw.show_window(self._glfw_window)
217 | self._deferred_show = False
218 | glfw.swap_buffers(self._glfw_window)
219 |
220 | def _attach_glfw_callbacks(self):
221 | glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
222 | glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
223 |
224 | def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
225 | if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
226 | self._esc_pressed = True
227 |
228 | def _glfw_drop_callback(self, _window, paths):
229 | self._drag_and_drop_paths = paths
230 |
231 | #----------------------------------------------------------------------------
232 |
--------------------------------------------------------------------------------
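
GlfwWindow encapsulates the per-frame lifecycle: begin_frame() applies the optional FPS limit, polls events and resets the GL state, end_frame() handles deferred show, frame capture and the buffer swap, and draw_frame() is the hook a subclass overrides. A minimal usage sketch under those assumptions; the window title and sizes are illustrative:

# Minimal sketch of driving a GlfwWindow subclass: draw_frame() brackets the
# rendering code between begin_frame() and end_frame().
from gui_utils.glfw_window import GlfwWindow

class DemoWindow(GlfwWindow):
    def draw_frame(self):
        self.begin_frame()
        # ... GL draw calls for this frame go here ...
        self.end_frame()

if __name__ == '__main__':
    win = DemoWindow(title='Demo', window_width=800, window_height=600, deferred_show=False)
    win.set_fps_limit(60)           # begin_frame() sleeps as needed to honor this cap
    while not win.should_close():   # True after ESC (close_on_esc) or the window close button
        win.draw_frame()
    win.close()

--------------------------------------------------------------------------------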
/gui_utils/imgui_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import contextlib
12 | import imgui
13 |
14 | #----------------------------------------------------------------------------
15 |
16 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
17 | s = imgui.get_style()
18 | s.window_padding = [spacing, spacing]
19 | s.item_spacing = [spacing, spacing]
20 | s.item_inner_spacing = [spacing, spacing]
21 | s.columns_min_spacing = spacing
22 | s.indent_spacing = indent
23 | s.scrollbar_size = scrollbar
24 | s.frame_padding = [4, 3]
25 | s.window_border_size = 1
26 | s.child_border_size = 1
27 | s.popup_border_size = 1
28 | s.frame_border_size = 1
29 | s.window_rounding = 0
30 | s.child_rounding = 0
31 | s.popup_rounding = 3
32 | s.frame_rounding = 3
33 | s.scrollbar_rounding = 3
34 | s.grab_rounding = 3
35 |
36 | getattr(imgui, f'style_colors_{color_scheme}')(s)
37 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
38 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
39 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
40 |
41 | #----------------------------------------------------------------------------
42 |
43 | @contextlib.contextmanager
44 | def grayed_out(cond=True):
45 | if cond:
46 | s = imgui.get_style()
47 | text = s.colors[imgui.COLOR_TEXT_DISABLED]
48 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
49 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
50 | imgui.push_style_color(imgui.COLOR_TEXT, *text)
51 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
52 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
53 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
55 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
56 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
57 | imgui.push_style_color(imgui.COLOR_BUTTON, *back)
58 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
59 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
60 | imgui.push_style_color(imgui.COLOR_HEADER, *back)
61 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
62 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
63 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
64 | yield
65 | imgui.pop_style_color(14)
66 | else:
67 | yield
68 |
69 | #----------------------------------------------------------------------------
70 |
71 | @contextlib.contextmanager
72 | def item_width(width=None):
73 | if width is not None:
74 | imgui.push_item_width(width)
75 | yield
76 | imgui.pop_item_width()
77 | else:
78 | yield
79 |
80 | #----------------------------------------------------------------------------
81 |
82 | def scoped_by_object_id(method):
83 | def decorator(self, *args, **kwargs):
84 | imgui.push_id(str(id(self)))
85 | res = method(self, *args, **kwargs)
86 | imgui.pop_id()
87 | return res
88 | return decorator
89 |
90 | #----------------------------------------------------------------------------
91 |
92 | def button(label, width=0, enabled=True):
93 | with grayed_out(not enabled):
94 | clicked = imgui.button(label, width=width)
95 | clicked = clicked and enabled
96 | return clicked
97 |
98 | #----------------------------------------------------------------------------
99 |
100 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
101 | expanded = False
102 | if show:
103 | if default:
104 | flags |= imgui.TREE_NODE_DEFAULT_OPEN
105 | if not enabled:
106 | flags |= imgui.TREE_NODE_LEAF
107 | with grayed_out(not enabled):
108 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
109 | expanded = expanded and enabled
110 | return expanded, visible
111 |
112 | #----------------------------------------------------------------------------
113 |
114 | def popup_button(label, width=0, enabled=True):
115 | if button(label, width, enabled):
116 | imgui.open_popup(label)
117 | opened = imgui.begin_popup(label)
118 | return opened
119 |
120 | #----------------------------------------------------------------------------
121 |
122 | def input_text(label, value, buffer_length, flags, width=None, help_text=''):
123 | old_value = value
124 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
125 | if value == '':
126 | color[-1] *= 0.5
127 | with item_width(width):
128 | imgui.push_style_color(imgui.COLOR_TEXT, *color)
129 | value = value if value != '' else help_text
130 | changed, value = imgui.input_text(label, value, buffer_length, flags)
131 | value = value if value != help_text else ''
132 | imgui.pop_style_color(1)
133 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
134 | changed = (value != old_value)
135 | return changed, value
136 |
137 | #----------------------------------------------------------------------------
138 |
139 | def drag_previous_control(enabled=True):
140 | dragging = False
141 | dx = 0
142 | dy = 0
143 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
144 | if enabled:
145 | dragging = True
146 | dx, dy = imgui.get_mouse_drag_delta()
147 | imgui.reset_mouse_drag_delta()
148 | imgui.end_drag_drop_source()
149 | return dragging, dx, dy
150 |
151 | #----------------------------------------------------------------------------
152 |
153 | def drag_button(label, width=0, enabled=True):
154 | clicked = button(label, width=width, enabled=enabled)
155 | dragging, dx, dy = drag_previous_control(enabled=enabled)
156 | return clicked, dragging, dx, dy
157 |
158 | #----------------------------------------------------------------------------
159 |
160 | def drag_hidden_window(label, x, y, width, height, enabled=True):
161 | imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
162 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
163 | imgui.set_next_window_position(x, y)
164 | imgui.set_next_window_size(width, height)
165 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
166 | dragging, dx, dy = drag_previous_control(enabled=enabled)
167 | imgui.end()
168 | imgui.pop_style_color(2)
169 | return dragging, dx, dy
170 |
171 | #----------------------------------------------------------------------------
172 |
--------------------------------------------------------------------------------
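
imgui_utils wraps pyimgui with small composable helpers: grayed_out() and item_width() are context managers that push and pop style state, while button(), collapsing_header() and input_text() return values already gated by their enabled/help-text handling. A short usage sketch, assuming it runs inside an active imgui frame (for example from an ImguiWindow draw loop); the widget labels and the state dict are illustrative:

from gui_utils import imgui_utils

# Sketch of a small panel built from the helpers; assumes an imgui frame is
# already active (e.g. between ImguiWindow.begin_frame() and end_frame()).
def draw_panel(state):
    expanded, _visible = imgui_utils.collapsing_header('Options', default=True)
    if expanded:
        with imgui_utils.item_width(200):   # fixed width for the text field
            _changed, state['name'] = imgui_utils.input_text('name', state['name'], 256, flags=0, help_text='enter a name')
        if imgui_utils.button('Run', enabled=state['ready']):   # grays itself out when disabled
            state['run_requested'] = True

--------------------------------------------------------------------------------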
/gui_utils/imgui_window.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import os
12 | import imgui
13 | import imgui.integrations.glfw
14 |
15 | from . import glfw_window
16 | from . import imgui_utils
17 | from . import text_utils
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | class ImguiWindow(glfw_window.GlfwWindow):
22 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
23 | if font is None:
24 | font = text_utils.get_default_font()
25 | font_sizes = {int(size) for size in font_sizes}
26 | super().__init__(title=title, **glfw_kwargs)
27 |
28 | # Init fields.
29 | self._imgui_context = None
30 | self._imgui_renderer = None
31 | self._imgui_fonts = None
32 | self._cur_font_size = max(font_sizes)
33 |
34 | # Delete leftover imgui.ini to avoid unexpected behavior.
35 | if os.path.isfile('imgui.ini'):
36 | os.remove('imgui.ini')
37 |
38 | # Init ImGui.
39 | self._imgui_context = imgui.create_context()
40 | self._imgui_renderer = _GlfwRenderer(self._glfw_window)
41 | self._attach_glfw_callbacks()
42 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
43 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
44 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
45 | self._imgui_renderer.refresh_font_texture()
46 |
47 | def close(self):
48 | self.make_context_current()
49 | self._imgui_fonts = None
50 | if self._imgui_renderer is not None:
51 | self._imgui_renderer.shutdown()
52 | self._imgui_renderer = None
53 | if self._imgui_context is not None:
54 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
55 | self._imgui_context = None
56 | super().close()
57 |
58 | def _glfw_key_callback(self, *args):
59 | super()._glfw_key_callback(*args)
60 | self._imgui_renderer.keyboard_callback(*args)
61 |
62 | @property
63 | def font_size(self):
64 | return self._cur_font_size
65 |
66 | @property
67 | def spacing(self):
68 | return round(self._cur_font_size * 0.4)
69 |
70 | def set_font_size(self, target): # Applied on next frame.
71 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
72 |
73 | def begin_frame(self):
74 | # Begin glfw frame.
75 | super().begin_frame()
76 |
77 | # Process imgui events.
78 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
79 | if self.content_width > 0 and self.content_height > 0:
80 | self._imgui_renderer.process_inputs()
81 |
82 | # Begin imgui frame.
83 | imgui.new_frame()
84 | imgui.push_font(self._imgui_fonts[self._cur_font_size])
85 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
86 |
87 | def end_frame(self):
88 | imgui.pop_font()
89 | imgui.render()
90 | imgui.end_frame()
91 | self._imgui_renderer.render(imgui.get_draw_data())
92 | super().end_frame()
93 |
94 | #----------------------------------------------------------------------------
95 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
96 |
97 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
98 | def __init__(self, *args, **kwargs):
99 | super().__init__(*args, **kwargs)
100 | self.mouse_wheel_multiplier = 1
101 |
102 | def scroll_callback(self, window, x_offset, y_offset):
103 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
104 |
105 | #----------------------------------------------------------------------------
106 |
--------------------------------------------------------------------------------
/gui_utils/text_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import functools
12 | from typing import Optional
13 |
14 | import dnnlib
15 | import numpy as np
16 | import PIL.Image
17 | import PIL.ImageFont
18 | import scipy.ndimage
19 |
20 | from . import gl_utils
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | def get_default_font():
25 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
26 | return dnnlib.util.open_url(url, return_filename=True)
27 |
28 | #----------------------------------------------------------------------------
29 |
30 | @functools.lru_cache(maxsize=None)
31 | def get_pil_font(font=None, size=32):
32 | if font is None:
33 | font = get_default_font()
34 | return PIL.ImageFont.truetype(font=font, size=size)
35 |
36 | #----------------------------------------------------------------------------
37 |
38 | def get_array(string, *, dropshadow_radius: int=None, **kwargs):
39 | if dropshadow_radius is not None:
40 | offset_x = int(np.ceil(dropshadow_radius*2/3))
41 | offset_y = int(np.ceil(dropshadow_radius*2/3))
42 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
43 | else:
44 | return _get_array_priv(string, **kwargs)
45 |
46 | @functools.lru_cache(maxsize=10000)
47 | def _get_array_priv(
48 | string: str, *,
49 | size: int = 32,
50 | max_width: Optional[int]=None,
51 | max_height: Optional[int]=None,
52 | min_size=10,
53 | shrink_coef=0.8,
54 | dropshadow_radius: int=None,
55 | offset_x: int=None,
56 | offset_y: int=None,
57 | **kwargs
58 | ):
59 | cur_size = size
60 | array = None
61 | while True:
62 | if dropshadow_radius is not None:
63 | # separate implementation for dropshadow text rendering
64 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
65 | else:
66 | array = _get_array_impl(string, size=cur_size, **kwargs)
67 | height, width, _ = array.shape
68 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
69 | break
70 | cur_size = max(int(cur_size * shrink_coef), min_size)
71 | return array
72 |
73 | #----------------------------------------------------------------------------
74 |
75 | @functools.lru_cache(maxsize=10000)
76 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
77 | pil_font = get_pil_font(font=font, size=size)
78 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
79 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
80 | width = max(line.shape[1] for line in lines)
81 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
82 | line_spacing = line_pad if line_pad is not None else size // 2
83 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
84 | mask = np.concatenate(lines, axis=0)
85 | alpha = mask
86 | if outline > 0:
87 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
88 | alpha = mask.astype(np.float32) / 255
89 | alpha = scipy.ndimage.gaussian_filter(alpha, outline)
90 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
91 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
92 | alpha = np.maximum(alpha, mask)
93 | return np.stack([mask, alpha], axis=-1)
94 |
95 | #----------------------------------------------------------------------------
96 |
97 | @functools.lru_cache(maxsize=10000)
98 | def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
99 | assert (offset_x > 0) and (offset_y > 0)
100 | pil_font = get_pil_font(font=font, size=size)
101 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
102 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
103 | width = max(line.shape[1] for line in lines)
104 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
105 | line_spacing = line_pad if line_pad is not None else size // 2
106 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
107 | mask = np.concatenate(lines, axis=0)
108 | alpha = mask
109 |
110 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
111 | alpha = mask.astype(np.float32) / 255
112 | alpha = scipy.ndimage.gaussian_filter(alpha, radius)
113 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
114 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
115 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
116 | alpha = np.maximum(alpha, mask)
117 | return np.stack([mask, alpha], axis=-1)
118 |
119 | #----------------------------------------------------------------------------
120 |
121 | @functools.lru_cache(maxsize=10000)
122 | def get_texture(string, bilinear=True, mipmap=True, **kwargs):
123 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
124 |
125 | #----------------------------------------------------------------------------
126 |
--------------------------------------------------------------------------------
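
get_array() rasterizes a (possibly multi-line) string with PIL into an (H, W, 2) uint8 array holding a glyph mask plus an alpha channel, optionally with a Gaussian outline or drop shadow, and keeps shrinking the font until the result fits max_width/max_height; get_texture() caches the same array as a GL texture. A small sketch of the array path under those assumptions (the first call downloads the default Open Sans font via dnnlib.util.open_url):

# Sketch: rasterize a label into the mask+alpha array used for GL text rendering.
from gui_utils import text_utils

arr = text_utils.get_array('Hello\nPanoHead', size=32, max_width=512, dropshadow_radius=3)
print(arr.shape, arr.dtype)   # (H, W, 2) uint8: channel 0 = glyph mask, channel 1 = alpha

--------------------------------------------------------------------------------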
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Frechet Inception Distance (FID) from the paper
12 | "GANs trained by a two time-scale update rule converge to a local Nash
13 | equilibrium". Matches the original implementation by Heusel et al. at
14 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
15 |
16 | import numpy as np
17 | import scipy.linalg
18 | from . import metric_utils
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | def compute_fid(opts, max_real, num_gen):
23 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
24 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
25 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
26 |
27 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
30 |
31 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
32 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
33 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
34 |
35 | if opts.rank != 0:
36 | return float('nan')
37 |
38 | m = np.square(mu_gen - mu_real).sum()
39 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
40 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
41 | return float(fid)
42 |
43 | #----------------------------------------------------------------------------
44 |
--------------------------------------------------------------------------------
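
compute_fid() reduces real and generated Inception features to their mean and covariance and evaluates the closed-form Frechet distance ||mu_g - mu_r||^2 + Tr(Sigma_g + Sigma_r - 2*(Sigma_g*Sigma_r)^(1/2)), which is what lines 38-40 implement. A self-contained numeric sketch of that final step, with random features standing in for Inception activations:

import numpy as np
import scipy.linalg

# Sketch of the closed-form Frechet distance between two Gaussians, matching
# the final step of compute_fid(); the stats stand in for Inception features.
def frechet_distance(mu_real, sigma_real, mu_gen, sigma_gen):
    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)  # matrix square root
    return float(np.real(m + np.trace(sigma_gen + sigma_real - s * 2)))

rng = np.random.default_rng(0)
feat_real = rng.normal(size=(1000, 16))
feat_gen = rng.normal(loc=0.1, size=(1000, 16))
fid = frechet_distance(feat_real.mean(0), np.cov(feat_real, rowvar=False),
                       feat_gen.mean(0), np.cov(feat_gen, rowvar=False))
print(f'toy FID: {fid:.4f}')

--------------------------------------------------------------------------------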
/metrics/inception_score.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Inception Score (IS) from the paper "Improved techniques for training
12 | GANs". Matches the original implementation by Salimans et al. at
13 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
14 |
15 | import numpy as np
16 | from . import metric_utils
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def compute_is(opts, num_gen, num_splits):
21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
23 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
24 |
25 | gen_probs = metric_utils.compute_feature_stats_for_generator(
26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27 | capture_all=True, max_items=num_gen).get_all()
28 |
29 | if opts.rank != 0:
30 | return float('nan'), float('nan')
31 |
32 | scores = []
33 | for i in range(num_splits):
34 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
35 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
36 | kl = np.mean(np.sum(kl, axis=1))
37 | scores.append(np.exp(kl))
38 | return float(np.mean(scores)), float(np.std(scores))
39 |
40 | #----------------------------------------------------------------------------
41 |
--------------------------------------------------------------------------------
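
compute_is() splits the generated-image class probabilities into num_splits chunks and, for each chunk, exponentiates the average KL divergence between each image's predicted distribution p(y|x) and the chunk's marginal p(y); the reported score is the mean and standard deviation over chunks. A small numeric sketch of that split computation with synthetic softmax outputs:

import numpy as np

# Sketch of the Inception Score split computation used in compute_is(),
# with random softmax outputs standing in for the Inception classifier.
def inception_score(probs, num_splits=10):
    scores = []
    n = probs.shape[0]
    for i in range(num_splits):
        part = probs[i * n // num_splits : (i + 1) * n // num_splits]
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return float(np.mean(scores)), float(np.std(scores))

rng = np.random.default_rng(0)
logits = rng.normal(size=(5000, 1000))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # synthetic softmax outputs
print(inception_score(probs))

--------------------------------------------------------------------------------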
/metrics/kernel_inception_distance.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD
12 | GANs". Matches the original implementation by Binkowski et al. at
13 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
14 |
15 | import numpy as np
16 | from . import metric_utils
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24 |
25 | real_features = metric_utils.compute_feature_stats_for_dataset(
26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
28 |
29 | gen_features = metric_utils.compute_feature_stats_for_generator(
30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
32 |
33 | if opts.rank != 0:
34 | return float('nan')
35 |
36 | n = real_features.shape[1]
37 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
38 | t = 0
39 | for _subset_idx in range(num_subsets):
40 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
41 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
42 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
43 | b = (x @ y.T / n + 1) ** 3
44 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
45 | kid = t / num_subsets / m
46 | return float(kid)
47 |
48 | #----------------------------------------------------------------------------
49 |
--------------------------------------------------------------------------------
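
compute_kid() is an unbiased MMD^2 estimate with the cubic polynomial kernel k(x, y) = (x.y/d + 1)^3, averaged over num_subsets random subsets of at most max_subset_size features; the within-set kernel diagonals are excluded to keep the estimate unbiased. A compact sketch of one subset's contribution, mirroring the arithmetic in the loop above, with random features standing in for Inception activations:

import numpy as np

# Sketch of the polynomial-kernel MMD^2 term accumulated per subset in compute_kid().
def kid_subset(x, y):
    n = x.shape[1]                       # feature dimensionality
    m = x.shape[0]                       # subset size (x and y are m x n)
    a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
    b = (x @ y.T / n + 1) ** 3
    return ((a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m) / m

rng = np.random.default_rng(0)
real = rng.normal(size=(1000, 64))
gen = rng.normal(loc=0.05, size=(1000, 64))
m = 100
vals = [kid_subset(gen[rng.choice(len(gen), m, replace=False)],
                   real[rng.choice(len(real), m, replace=False)]) for _ in range(10)]
print('toy KID:', float(np.mean(vals)))

--------------------------------------------------------------------------------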
/metrics/metric_main.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Main API for computing and reporting quality metrics."""
12 |
13 | import os
14 | import time
15 | import json
16 | import torch
17 | import dnnlib
18 |
19 | from . import metric_utils
20 | from . import frechet_inception_distance
21 | from . import kernel_inception_distance
22 | from . import precision_recall
23 | from . import perceptual_path_length
24 | from . import inception_score
25 | from . import equivariance
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | _metric_dict = dict() # name => fn
30 |
31 | def register_metric(fn):
32 | assert callable(fn)
33 | _metric_dict[fn.__name__] = fn
34 | return fn
35 |
36 | def is_valid_metric(metric):
37 | return metric in _metric_dict
38 |
39 | def list_valid_metrics():
40 | return list(_metric_dict.keys())
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
45 | assert is_valid_metric(metric)
46 | opts = metric_utils.MetricOptions(**kwargs)
47 |
48 | # Calculate.
49 | start_time = time.time()
50 | results = _metric_dict[metric](opts)
51 | total_time = time.time() - start_time
52 |
53 | # Broadcast results.
54 | for key, value in list(results.items()):
55 | if opts.num_gpus > 1:
56 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
57 | torch.distributed.broadcast(tensor=value, src=0)
58 | value = float(value.cpu())
59 | results[key] = value
60 |
61 | # Decorate with metadata.
62 | return dnnlib.EasyDict(
63 | results = dnnlib.EasyDict(results),
64 | metric = metric,
65 | total_time = total_time,
66 | total_time_str = dnnlib.util.format_time(total_time),
67 | num_gpus = opts.num_gpus,
68 | mode = opts.mode
69 | )
70 |
71 | #----------------------------------------------------------------------------
72 |
73 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
74 | metric = result_dict['metric']
75 | assert is_valid_metric(metric)
76 | if run_dir is not None and snapshot_pkl is not None:
77 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
78 |
79 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
80 | print(jsonl_line)
81 | if run_dir is not None and os.path.isdir(run_dir):
82 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
83 | f.write(jsonl_line + '\n')
84 |
85 | #----------------------------------------------------------------------------
86 | # Recommended metrics.
87 |
88 | @register_metric
89 | def fid50k_full(opts):
90 | opts.dataset_kwargs.update(max_size=None, xflip=False)
91 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
92 | return dict(fid50k_full=fid)
93 |
94 | @register_metric
95 | def kid50k_full(opts):
96 | opts.dataset_kwargs.update(max_size=None, xflip=False)
97 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
98 | return dict(kid50k_full=kid)
99 |
100 | @register_metric
101 | def pr50k3_full(opts):
102 | opts.dataset_kwargs.update(max_size=None, xflip=False)
103 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
104 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
105 |
106 | @register_metric
107 | def ppl2_wend(opts):
108 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
109 | return dict(ppl2_wend=ppl)
110 |
111 | @register_metric
112 | def eqt50k_int(opts):
113 | opts.G_kwargs.update(force_fp32=True)
114 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
115 | return dict(eqt50k_int=psnr)
116 |
117 | @register_metric
118 | def eqt50k_frac(opts):
119 | opts.G_kwargs.update(force_fp32=True)
120 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
121 | return dict(eqt50k_frac=psnr)
122 |
123 | @register_metric
124 | def eqr50k(opts):
125 | opts.G_kwargs.update(force_fp32=True)
126 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
127 | return dict(eqr50k=psnr)
128 |
129 | #----------------------------------------------------------------------------
130 | # Legacy metrics.
131 |
132 | @register_metric
133 | def fid50k(opts):
134 | opts.dataset_kwargs.update(max_size=None)
135 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
136 | return dict(fid50k=fid)
137 |
138 | @register_metric
139 | def kid50k(opts):
140 | opts.dataset_kwargs.update(max_size=None)
141 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
142 | return dict(kid50k=kid)
143 |
144 | @register_metric
145 | def pr50k3(opts):
146 | opts.dataset_kwargs.update(max_size=None)
147 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
148 | return dict(pr50k3_precision=precision, pr50k3_recall=recall)
149 |
150 | @register_metric
151 | def is50k(opts):
152 | opts.dataset_kwargs.update(max_size=None, xflip=False)
153 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
154 | return dict(is50k_mean=mean, is50k_std=std)
155 |
156 | #----------------------------------------------------------------------------
157 |
--------------------------------------------------------------------------------
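
metric_main keeps a simple registry: any function decorated with @register_metric that accepts the opts object and returns a dict of named results becomes selectable by name through calc_metric() and the --metrics flag of calc_metrics.py. A minimal sketch of adding one more metric this way; the fid10k_quick name and the reduced sample counts are illustrative, not part of the repo:

# Sketch: registering an additional metric with the existing decorator.
from metrics import metric_main, frechet_inception_distance

@metric_main.register_metric
def fid10k_quick(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=10000, num_gen=10000)
    return dict(fid10k_quick=fid)

# Once this module has been imported, metric_main.calc_metric(metric='fid10k_quick', ...)
# resolves the new name just like the built-in metrics above.

--------------------------------------------------------------------------------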
/metrics/perceptual_path_length.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
12 | Architecture for Generative Adversarial Networks". Matches the original
13 | implementation by Karras et al. at
14 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
15 |
16 | import copy
17 | import numpy as np
18 | import torch
19 | from . import metric_utils
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | # Spherical interpolation of a batch of vectors.
24 | def slerp(a, b, t):
25 | a = a / a.norm(dim=-1, keepdim=True)
26 | b = b / b.norm(dim=-1, keepdim=True)
27 | d = (a * b).sum(dim=-1, keepdim=True)
28 | p = t * torch.acos(d)
29 | c = b - d * a
30 | c = c / c.norm(dim=-1, keepdim=True)
31 | d = a * torch.cos(p) + c * torch.sin(p)
32 | d = d / d.norm(dim=-1, keepdim=True)
33 | return d
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | class PPLSampler(torch.nn.Module):
38 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
39 | assert space in ['z', 'w']
40 | assert sampling in ['full', 'end']
41 | super().__init__()
42 | self.G = copy.deepcopy(G)
43 | self.G_kwargs = G_kwargs
44 | self.epsilon = epsilon
45 | self.space = space
46 | self.sampling = sampling
47 | self.crop = crop
48 | self.vgg16 = copy.deepcopy(vgg16)
49 |
50 | def forward(self, c):
51 | # Generate random latents and interpolation t-values.
52 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
53 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
54 |
55 | # Interpolate in W or Z.
56 | if self.space == 'w':
57 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
58 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
59 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
60 | else: # space == 'z'
61 | zt0 = slerp(z0, z1, t.unsqueeze(1))
62 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
63 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
64 |
65 | # Randomize noise buffers.
66 | for name, buf in self.G.named_buffers():
67 | if name.endswith('.noise_const'):
68 | buf.copy_(torch.randn_like(buf))
69 |
70 | # Generate images.
71 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
72 |
73 | # Center crop.
74 | if self.crop:
75 | assert img.shape[2] == img.shape[3]
76 | c = img.shape[2] // 8
77 | img = img[:, :, c*3 : c*7, c*2 : c*6]
78 |
79 | # Downsample to 256x256.
80 | factor = self.G.img_resolution // 256
81 | if factor > 1:
82 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
83 |
84 | # Scale dynamic range from [-1,1] to [0,255].
85 | img = (img + 1) * (255 / 2)
86 | if self.G.img_channels == 1:
87 | img = img.repeat([1, 3, 1, 1])
88 |
89 | # Evaluate differential LPIPS.
90 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
91 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
92 | return dist
93 |
94 | #----------------------------------------------------------------------------
95 |
96 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
97 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99 |
100 | # Setup sampler and labels.
101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102 | sampler.eval().requires_grad_(False).to(opts.device)
103 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
104 |
105 | # Sampling loop.
106 | dist = []
107 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
108 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
109 | progress.update(batch_start)
110 | x = sampler(next(c_iter))
111 | for src in range(opts.num_gpus):
112 | y = x.clone()
113 | if opts.num_gpus > 1:
114 | torch.distributed.broadcast(y, src=src)
115 | dist.append(y)
116 | progress.update(num_samples)
117 |
118 | # Compute PPL.
119 | if opts.rank != 0:
120 | return float('nan')
121 | dist = torch.cat(dist)[:num_samples].cpu().numpy()
122 | lo = np.percentile(dist, 1, interpolation='lower')
123 | hi = np.percentile(dist, 99, interpolation='higher')
124 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
125 | return float(ppl)
126 |
127 | #----------------------------------------------------------------------------
128 |
--------------------------------------------------------------------------------
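A quick sanity check of the slerp() helper above (a minimal sketch; it assumes metrics/perceptual_path_length.py is importable from the repo root): interpolating halfway between two orthogonal unit vectors should land on their angular midpoint.

import torch
from metrics.perceptual_path_length import slerp

# Two orthogonal unit vectors; slerp at t=0.5 should return the 45-degree midpoint.
a = torch.tensor([[1.0, 0.0]])
b = torch.tensor([[0.0, 1.0]])
mid = slerp(a, b, torch.tensor([[0.5]]))
print(mid)  # approximately [[0.7071, 0.7071]]
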
/metrics/precision_recall.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Precision/Recall (PR) from the paper "Improved Precision and Recall
12 | Metric for Assessing Generative Models". Matches the original implementation
13 | by Kynkaanniemi et al. at
14 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
15 |
16 | import torch
17 | from . import metric_utils
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
22 | assert 0 <= rank < num_gpus
23 | num_cols = col_features.shape[0]
24 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
25 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
26 | dist_batches = []
27 | for col_batch in col_batches[rank :: num_gpus]:
28 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
29 | for src in range(num_gpus):
30 | dist_broadcast = dist_batch.clone()
31 | if num_gpus > 1:
32 | torch.distributed.broadcast(dist_broadcast, src=src)
33 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
34 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
35 |
36 | #----------------------------------------------------------------------------
37 |
38 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
39 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
40 | detector_kwargs = dict(return_features=True)
41 |
42 | real_features = metric_utils.compute_feature_stats_for_dataset(
43 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
44 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
45 |
46 | gen_features = metric_utils.compute_feature_stats_for_generator(
47 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
48 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
49 |
50 | results = dict()
51 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
52 | kth = []
53 | for manifold_batch in manifold.split(row_batch_size):
54 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
55 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
56 | kth = torch.cat(kth) if opts.rank == 0 else None
57 | pred = []
58 | for probes_batch in probes.split(row_batch_size):
59 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
60 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
61 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
62 | return results['precision'], results['recall']
63 |
64 | #----------------------------------------------------------------------------
65 |
--------------------------------------------------------------------------------
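The covering rule used by compute_pr() above can be illustrated on toy tensors (a hedged sketch with random stand-in features, not the real VGG16 features): a probe counts as covered if it falls within the k-th nearest-neighbour radius of at least one manifold point.

import torch

nhood_size = 3
manifold = torch.randn(100, 16)   # stand-in for one feature set (e.g. real images)
probes   = torch.randn(50, 16)    # stand-in for the other feature set

kth = torch.cdist(manifold, manifold).kthvalue(nhood_size + 1, dim=1).values  # k-NN radius per manifold point
covered = (torch.cdist(probes, manifold) <= kth.unsqueeze(0)).any(dim=1)      # inside any radius?
print(covered.float().mean())  # fraction of covered probes, i.e. precision or recall
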
/misc/segmentation_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from torchvision.transforms import ToPILImage
4 |
5 | from torchvision.models.segmentation import deeplabv3_resnet101
6 | from torchvision import transforms, utils
7 |
8 |
9 | def get_mask(model, batch, cid):
10 | normalized_batch = transforms.functional.normalize(
11 | batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
12 | output = model(normalized_batch)['out']
13 | # sem_classes = [
14 | # '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
15 | # 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
16 | # 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
17 | # ]
18 | # sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
19 | # cid = sem_class_to_idx['car']
20 |
21 | normalized_masks = torch.nn.functional.softmax(output, dim=1)
22 |
23 |     boolean_masks = (normalized_masks.argmax(1) == cid)   # True where the predicted class equals cid
24 |     return boolean_masks.float()
25 |
26 | image = Image.open('image.jpg')
27 | # Define the preprocessing transformation
28 | preprocess = transforms.Compose([
29 | transforms.Resize((512, 512)),
30 | transforms.ToTensor()
31 | ])
32 |
33 | # Apply the transformation to the image
34 | input_tensor = preprocess(image)
35 | input_batch = input_tensor.unsqueeze(0).to('cuda:0')
36 |
37 | # load segmentation net
38 | seg_net = deeplabv3_resnet101(pretrained=True, progress=False).to('cuda:0')
39 | seg_net.requires_grad_(False)
40 | seg_net.eval()
41 |
42 | # Class index 15 corresponds to 'person' in the PASCAL VOC class list above
43 | mask0 = get_mask(seg_net, input_batch, 15).unsqueeze(1)
44 |
45 | # Squeeze the tensor to remove unnecessary dimensions and convert to PIL Image
46 | mask_squeezed = torch.squeeze(mask0)
47 | mask_image = ToPILImage()(mask_squeezed)
48 |
49 | # Save as PNG
50 | mask_image.save("mask.png")
51 |
--------------------------------------------------------------------------------
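A possible follow-up to the script above (a hedged sketch; it reuses input_batch, mask0, and ToPILImage from misc/segmentation_example.py): multiply the RGB tensor by the binary mask to keep only the person pixels.

# Continuation of misc/segmentation_example.py; assumes input_batch and mask0 are defined as above.
masked = input_batch * mask0                          # zero out everything except the person class
masked_image = ToPILImage()(masked.squeeze(0).cpu())
masked_image.save("masked.png")
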
/misc/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/misc/teaser.png
--------------------------------------------------------------------------------
/resave_model.py:
--------------------------------------------------------------------------------
1 | """Reload a network pickle and re-save it with CPU, eval-mode copies of its modules."""
2 | import copy
3 | import os
4 | import pickle
5 |
6 | import click
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | import dnnlib
11 | import legacy
12 |
13 | @click.command()
14 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
15 | @click.option('--output', 'output_pkl', help='Output pickle filename', required=True)
16 | def main(
17 | network_pkl: str,
18 | output_pkl: str,
19 | ):
20 | # Load networks.
21 | print('Loading networks from "%s"...' % network_pkl)
22 | with dnnlib.util.open_url(network_pkl) as fp:
23 | data = legacy.load_network_pkl(fp)
24 | data_new = {}
25 | for name in data.keys():
26 | module = data.get(name, None)
27 | if module is not None and isinstance(module, torch.nn.Module):
28 | module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
29 | data_new[name] = module
30 | del module # conserve memory
31 | with open(output_pkl, 'wb') as f:
32 | pickle.dump(data_new, f)
33 | #----------------------------------------------------------------------------
34 |
35 | if __name__ == "__main__":
36 | main() # pylint: disable=no-value-for-parameter
37 |
38 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
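A minimal sketch of consuming the re-saved pickle on a CPU-only machine (the file name and the 'G_ema' key are assumptions; the key layout depends on the original network pickle):

import pickle

with open('network-resaved.pkl', 'rb') as f:      # hypothetical output of resave_model.py
    data = pickle.load(f)
G = data.get('G_ema', None)                       # 'G_ema' is an assumption about the pickle contents
if G is not None:
    print(sum(p.numel() for p in G.parameters()), 'parameters')
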
/shape_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | """
13 | Utils for extracting 3D shapes using marching cubes. Based on code from DeepSDF (Park et al.)
14 |
15 | Takes as input an .mrc file and extracts a mesh.
16 |
17 | Ex.
18 | python shape_utils.py my_shape.mrc
19 | Ex.
20 | python shape_utils.py myshapes_directory --level=12
21 | """
22 |
23 |
24 | import time
25 | import plyfile
26 | import glob
27 | import logging
28 | import numpy as np
29 | import os
30 | import random
31 | import torch
32 | import torch.utils.data
33 | import trimesh
34 | import skimage.measure
35 | import argparse
36 | import mrcfile
37 | from tqdm import tqdm
38 |
39 |
40 | def convert_sdf_samples_to_ply(
41 | numpy_3d_sdf_tensor,
42 | voxel_grid_origin,
43 | voxel_size,
44 | ply_filename_out,
45 | offset=None,
46 | scale=None,
47 | level=0.0
48 | ):
49 | """
50 | Convert sdf samples to .ply
51 |     :param numpy_3d_sdf_tensor: a numpy array of shape (n, n, n)
52 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
53 | :voxel_size: float, the size of the voxels
54 | :ply_filename_out: string, path of the filename to save to
55 | This function adapted from: https://github.com/RobotLocomotion/spartan
56 | """
57 | start_time = time.time()
58 |
59 | verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
60 | # try:
61 | verts, faces, normals, values = skimage.measure.marching_cubes(
62 | numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3
63 | )
64 | # except:
65 | # pass
66 |
67 | # transform from voxel coordinates to camera coordinates
68 | # note x and y are flipped in the output of marching_cubes
69 | mesh_points = np.zeros_like(verts)
70 | mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
71 | mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
72 | mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
73 |
74 | # apply additional offset and scale
75 | if scale is not None:
76 | mesh_points = mesh_points / scale
77 | if offset is not None:
78 | mesh_points = mesh_points - offset
79 |
80 | # try writing to the ply file
81 |
82 | num_verts = verts.shape[0]
83 | num_faces = faces.shape[0]
84 |
85 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
86 |
87 | for i in range(0, num_verts):
88 | verts_tuple[i] = tuple(mesh_points[i, :])
89 |
90 | faces_building = []
91 | for i in range(0, num_faces):
92 | faces_building.append(((faces[i, :].tolist(),)))
93 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
94 |
95 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
96 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
97 |
98 | ply_data = plyfile.PlyData([el_verts, el_faces])
99 | ply_data.write(ply_filename_out)
100 | print(f"wrote to {ply_filename_out}")
101 |
102 |
103 | def convert_mrc(input_filename, output_filename, isosurface_level=1):
104 | with mrcfile.open(input_filename) as mrc:
105 | convert_sdf_samples_to_ply(np.transpose(mrc.data, (2, 1, 0)), [0, 0, 0], 1, output_filename, level=isosurface_level)
106 |
107 | if __name__ == '__main__':
108 | start_time = time.time()
109 | parser = argparse.ArgumentParser()
110 | parser.add_argument('input_mrc_path')
111 | parser.add_argument('--level', type=float, default=10, help="The isosurface level for marching cubes")
112 | args = parser.parse_args()
113 |
114 |     if os.path.isfile(args.input_mrc_path) and args.input_mrc_path.split('.')[-1] == 'mrc':
115 |         output_obj_path = args.input_mrc_path.split('.mrc')[0] + '.ply'
116 |         convert_mrc(args.input_mrc_path, output_obj_path, isosurface_level=args.level)
117 |
118 | print(f"{time.time() - start_time:02f} s")
119 | else:
120 | assert os.path.isdir(args.input_mrc_path)
121 |
122 | for mrc_path in tqdm(glob.glob(os.path.join(args.input_mrc_path, '*.mrc'))):
123 | output_obj_path = mrc_path.split('.mrc')[0] + '.ply'
124 | convert_mrc(mrc_path, output_obj_path, isosurface_level=args.level)
--------------------------------------------------------------------------------
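Besides the CLI usage shown in the docstring, convert_mrc() can be called directly; a minimal sketch with hypothetical file names:

from shape_utils import convert_mrc

# Extract the level-10 isosurface from my_shape.mrc and write it out as a PLY mesh.
convert_mrc('my_shape.mrc', 'my_shape.ply', isosurface_level=10)
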
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import glob
12 | import hashlib
13 | import importlib
14 | import os
15 | import re
16 | import shutil
17 | import uuid
18 |
19 | import torch
20 | import torch.utils.cpp_extension
21 | from torch.utils.file_baton import FileBaton
22 |
23 | #----------------------------------------------------------------------------
24 | # Global options.
25 |
26 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
27 |
28 | #----------------------------------------------------------------------------
29 | # Internal helper funcs.
30 |
31 | def _find_compiler_bindir():
32 | patterns = [
33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
34 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
35 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
36 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
37 | ]
38 | for pattern in patterns:
39 | matches = sorted(glob.glob(pattern))
40 | if len(matches):
41 | return matches[-1]
42 | return None
43 |
44 | #----------------------------------------------------------------------------
45 |
46 | def _get_mangled_gpu_name():
47 | name = torch.cuda.get_device_name().lower()
48 | out = []
49 | for c in name:
50 | if re.match('[a-z0-9_-]+', c):
51 | out.append(c)
52 | else:
53 | out.append('-')
54 | return ''.join(out)
55 |
56 | #----------------------------------------------------------------------------
57 | # Main entry point for compiling and loading C++/CUDA plugins.
58 |
59 | _cached_plugins = dict()
60 |
61 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
62 | assert verbosity in ['none', 'brief', 'full']
63 | if headers is None:
64 | headers = []
65 | if source_dir is not None:
66 | sources = [os.path.join(source_dir, fname) for fname in sources]
67 | headers = [os.path.join(source_dir, fname) for fname in headers]
68 |
69 | # Already cached?
70 | if module_name in _cached_plugins:
71 | return _cached_plugins[module_name]
72 |
73 | # Print status.
74 | if verbosity == 'full':
75 | print(f'Setting up PyTorch plugin "{module_name}"...')
76 | elif verbosity == 'brief':
77 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
78 | verbose_build = (verbosity == 'full')
79 |
80 | # Compile and load.
81 | try: # pylint: disable=too-many-nested-blocks
82 | # Make sure we can find the necessary compiler binaries.
83 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
84 | compiler_bindir = _find_compiler_bindir()
85 | if compiler_bindir is None:
86 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
87 | os.environ['PATH'] += ';' + compiler_bindir
88 |
89 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
90 | # break the build or unnecessarily restrict what's available to nvcc.
91 | # Unset it to let nvcc decide based on what's available on the
92 | # machine.
93 | os.environ['TORCH_CUDA_ARCH_LIST'] = ''
94 |
95 | # Incremental build md5sum trickery. Copies all the input source files
96 | # into a cached build directory under a combined md5 digest of the input
97 | # source files. Copying is done only if the combined digest has changed.
98 | # This keeps input file timestamps and filenames the same as in previous
99 | # extension builds, allowing for fast incremental rebuilds.
100 | #
101 | # This optimization is done only in case all the source files reside in
102 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
103 | # environment variable is set (we take this as a signal that the user
104 | # actually cares about this.)
105 | #
106 |         # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
107 | # around the *.cu dependency bug in ninja config.
108 | #
109 | all_source_files = sorted(sources + headers)
110 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
111 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
112 |
113 | # Compute combined hash digest for all source files.
114 | hash_md5 = hashlib.md5()
115 | for src in all_source_files:
116 | with open(src, 'rb') as f:
117 | hash_md5.update(f.read())
118 |
119 | # Select cached build directory name.
120 | source_digest = hash_md5.hexdigest()
121 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
122 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
123 |
124 | if not os.path.isdir(cached_build_dir):
125 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
126 | os.makedirs(tmpdir)
127 | for src in all_source_files:
128 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
129 | try:
130 | os.replace(tmpdir, cached_build_dir) # atomic
131 | except OSError:
132 | # source directory already exists, delete tmpdir and its contents.
133 | shutil.rmtree(tmpdir)
134 | if not os.path.isdir(cached_build_dir): raise
135 |
136 | # Compile.
137 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
138 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
139 | verbose=verbose_build, sources=cached_sources, **build_kwargs)
140 | else:
141 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
142 |
143 | # Load.
144 | module = importlib.import_module(module_name)
145 |
146 | except:
147 | if verbosity == 'brief':
148 | print('Failed!')
149 | raise
150 |
151 | # Print status and add to cache dict.
152 | if verbosity == 'full':
153 | print(f'Done setting up PyTorch plugin "{module_name}".')
154 | elif verbosity == 'brief':
155 | print('Done.')
156 | _cached_plugins[module_name] = module
157 | return module
158 |
159 | #----------------------------------------------------------------------------
160 |
--------------------------------------------------------------------------------
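For reference, a hedged sketch of how get_plugin() is invoked in this repo (it mirrors the call in torch_utils/ops/bias_act.py further below; the source_dir expression is an assumption about where the .cpp/.cu sources live):

import os
from torch_utils import custom_ops

_plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir=os.path.join(os.path.dirname(custom_ops.__file__), 'ops'),  # assumed location of the sources
    extra_cuda_cflags=['--use_fast_math'],
)
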
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <torch/extension.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 | #include <c10/cuda/CUDAGuard.h>
16 | #include "bias_act.h"
17 |
18 | //------------------------------------------------------------------------
19 |
20 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
21 | {
22 | if (x.dim() != y.dim())
23 | return false;
24 | for (int64_t i = 0; i < x.dim(); i++)
25 | {
26 | if (x.size(i) != y.size(i))
27 | return false;
28 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
29 | return false;
30 | }
31 | return true;
32 | }
33 |
34 | //------------------------------------------------------------------------
35 |
36 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
37 | {
38 | // Validate arguments.
39 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
40 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
41 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
42 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
43 |     TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
44 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
45 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
46 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
47 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
48 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
49 |
50 | // Validate layout.
51 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
52 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
53 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
54 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
55 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
56 |
57 | // Create output tensor.
58 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
59 | torch::Tensor y = torch::empty_like(x);
60 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
61 |
62 | // Initialize CUDA kernel parameters.
63 | bias_act_kernel_params p;
64 | p.x = x.data_ptr();
65 | p.b = (b.numel()) ? b.data_ptr() : NULL;
66 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
67 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
68 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
69 | p.y = y.data_ptr();
70 | p.grad = grad;
71 | p.act = act;
72 | p.alpha = alpha;
73 | p.gain = gain;
74 | p.clamp = clamp;
75 | p.sizeX = (int)x.numel();
76 | p.sizeB = (int)b.numel();
77 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
78 |
79 | // Choose CUDA kernel.
80 | void* kernel;
81 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
82 | {
83 |         kernel = choose_bias_act_kernel<scalar_t>(p);
84 | });
85 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
86 |
87 | // Launch CUDA kernel.
88 | p.loopX = 4;
89 | int blockSize = 4 * 32;
90 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
91 | void* args[] = {&p};
92 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 | return y;
94 | }
95 |
96 | //------------------------------------------------------------------------
97 |
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 | m.def("bias_act", &bias_act);
101 | }
102 |
103 | //------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <c10/util/Half.h>
14 | #include "bias_act.h"
15 |
16 | //------------------------------------------------------------------------
17 | // Helpers.
18 |
19 | template <class T> struct InternalType;
20 | template <> struct InternalType<double>     { typedef double scalar_t; };
21 | template <> struct InternalType<float>      { typedef float  scalar_t; };
22 | template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
23 |
24 | //------------------------------------------------------------------------
25 | // CUDA kernel.
26 |
27 | template <class T, int A>
28 | __global__ void bias_act_kernel(bias_act_kernel_params p)
29 | {
30 |     typedef typename InternalType<T>::scalar_t scalar_t;
31 | int G = p.grad;
32 | scalar_t alpha = (scalar_t)p.alpha;
33 | scalar_t gain = (scalar_t)p.gain;
34 | scalar_t clamp = (scalar_t)p.clamp;
35 | scalar_t one = (scalar_t)1;
36 | scalar_t two = (scalar_t)2;
37 | scalar_t expRange = (scalar_t)80;
38 | scalar_t halfExpRange = (scalar_t)40;
39 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
40 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
41 |
42 | // Loop over elements.
43 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
44 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
45 | {
46 | // Load.
47 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
48 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
49 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
50 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
51 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
52 | scalar_t yy = (gain != 0) ? yref / gain : 0;
53 | scalar_t y = 0;
54 |
55 | // Apply bias.
56 | ((G == 0) ? x : xref) += b;
57 |
58 | // linear
59 | if (A == 1)
60 | {
61 | if (G == 0) y = x;
62 | if (G == 1) y = x;
63 | }
64 |
65 | // relu
66 | if (A == 2)
67 | {
68 | if (G == 0) y = (x > 0) ? x : 0;
69 | if (G == 1) y = (yy > 0) ? x : 0;
70 | }
71 |
72 | // lrelu
73 | if (A == 3)
74 | {
75 | if (G == 0) y = (x > 0) ? x : x * alpha;
76 | if (G == 1) y = (yy > 0) ? x : x * alpha;
77 | }
78 |
79 | // tanh
80 | if (A == 4)
81 | {
82 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
83 | if (G == 1) y = x * (one - yy * yy);
84 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
85 | }
86 |
87 | // sigmoid
88 | if (A == 5)
89 | {
90 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
91 | if (G == 1) y = x * yy * (one - yy);
92 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
93 | }
94 |
95 | // elu
96 | if (A == 6)
97 | {
98 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
99 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
100 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
101 | }
102 |
103 | // selu
104 | if (A == 7)
105 | {
106 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
107 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
108 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
109 | }
110 |
111 | // softplus
112 | if (A == 8)
113 | {
114 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
115 | if (G == 1) y = x * (one - exp(-yy));
116 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
117 | }
118 |
119 | // swish
120 | if (A == 9)
121 | {
122 | if (G == 0)
123 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
124 | else
125 | {
126 | scalar_t c = exp(xref);
127 | scalar_t d = c + one;
128 | if (G == 1)
129 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
130 | else
131 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
132 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
133 | }
134 | }
135 |
136 | // Apply gain.
137 | y *= gain * dy;
138 |
139 | // Clamp.
140 | if (clamp >= 0)
141 | {
142 | if (G == 0)
143 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
144 | else
145 | y = (yref > -clamp & yref < clamp) ? y : 0;
146 | }
147 |
148 | // Store.
149 | ((T*)p.y)[xi] = (T)y;
150 | }
151 | }
152 |
153 | //------------------------------------------------------------------------
154 | // CUDA kernel selection.
155 |
156 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
157 | {
158 |     if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
159 |     if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
160 |     if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
161 |     if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
162 |     if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
163 |     if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
164 |     if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
165 |     if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
166 |     if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
167 | return NULL;
168 | }
169 |
170 | //------------------------------------------------------------------------
171 | // Template specializations.
172 |
173 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
174 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
175 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
176 |
177 | //------------------------------------------------------------------------
178 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | //------------------------------------------------------------------------
14 | // CUDA kernel parameters.
15 |
16 | struct bias_act_kernel_params
17 | {
18 | const void* x; // [sizeX]
19 | const void* b; // [sizeB] or NULL
20 | const void* xref; // [sizeX] or NULL
21 | const void* yref; // [sizeX] or NULL
22 | const void* dy; // [sizeX] or NULL
23 | void* y; // [sizeX]
24 |
25 | int grad;
26 | int act;
27 | float alpha;
28 | float gain;
29 | float clamp;
30 |
31 | int sizeX;
32 | int sizeB;
33 | int stepB;
34 | int loopX;
35 | };
36 |
37 | //------------------------------------------------------------------------
38 | // CUDA kernel selection.
39 |
40 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
41 |
42 | //------------------------------------------------------------------------
43 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Custom PyTorch ops for efficient bias and activation."""
12 |
13 | import os
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 |
18 | from .. import custom_ops
19 | from .. import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | activation_funcs = {
24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33 | }
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | _plugin = None
38 | _null_tensor = torch.empty([0])
39 |
40 | def _init():
41 | global _plugin
42 | if _plugin is None:
43 | _plugin = custom_ops.get_plugin(
44 | module_name='bias_act_plugin',
45 | sources=['bias_act.cpp', 'bias_act.cu'],
46 | headers=['bias_act.h'],
47 | source_dir=os.path.dirname(__file__),
48 | extra_cuda_cflags=['--use_fast_math'],
49 | )
50 | return True
51 |
52 | #----------------------------------------------------------------------------
53 |
54 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
55 | r"""Fused bias and activation function.
56 |
57 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
58 | and scales the result by `gain`. Each of the steps is optional. In most cases,
59 | the fused op is considerably more efficient than performing the same calculation
60 | using standard PyTorch ops. It supports first and second order gradients,
61 | but not third order gradients.
62 |
63 | Args:
64 | x: Input activation tensor. Can be of any shape.
65 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
66 | as `x`. The shape must be known, and it must match the dimension of `x`
67 | corresponding to `dim`.
68 | dim: The dimension in `x` corresponding to the elements of `b`.
69 | The value of `dim` is ignored if `b` is not specified.
70 | act: Name of the activation function to evaluate, or `"linear"` to disable.
71 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
72 | See `activation_funcs` for a full list. `None` is not allowed.
73 | alpha: Shape parameter for the activation function, or `None` to use the default.
74 | gain: Scaling factor for the output tensor, or `None` to use default.
75 | See `activation_funcs` for the default scaling of each activation function.
76 | If unsure, consider specifying 1.
77 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
78 | the clamping (default).
79 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
80 |
81 | Returns:
82 | Tensor of the same shape and datatype as `x`.
83 | """
84 | assert isinstance(x, torch.Tensor)
85 | assert impl in ['ref', 'cuda']
86 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
87 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
88 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
89 |
90 | #----------------------------------------------------------------------------
91 |
92 | @misc.profiled_function
93 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
94 |     """Slow reference implementation of `bias_act()` using standard PyTorch ops.
95 | """
96 | assert isinstance(x, torch.Tensor)
97 | assert clamp is None or clamp >= 0
98 | spec = activation_funcs[act]
99 | alpha = float(alpha if alpha is not None else spec.def_alpha)
100 | gain = float(gain if gain is not None else spec.def_gain)
101 | clamp = float(clamp if clamp is not None else -1)
102 |
103 | # Add bias.
104 | if b is not None:
105 | assert isinstance(b, torch.Tensor) and b.ndim == 1
106 | assert 0 <= dim < x.ndim
107 | assert b.shape[0] == x.shape[dim]
108 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
109 |
110 | # Evaluate activation function.
111 | alpha = float(alpha)
112 | x = spec.func(x, alpha=alpha)
113 |
114 | # Scale by gain.
115 | gain = float(gain)
116 | if gain != 1:
117 | x = x * gain
118 |
119 | # Clamp.
120 | if clamp >= 0:
121 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
122 | return x
123 |
124 | #----------------------------------------------------------------------------
125 |
126 | _bias_act_cuda_cache = dict()
127 |
128 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
129 | """Fast CUDA implementation of `bias_act()` using custom ops.
130 | """
131 | # Parse arguments.
132 | assert clamp is None or clamp >= 0
133 | spec = activation_funcs[act]
134 | alpha = float(alpha if alpha is not None else spec.def_alpha)
135 | gain = float(gain if gain is not None else spec.def_gain)
136 | clamp = float(clamp if clamp is not None else -1)
137 |
138 | # Lookup from cache.
139 | key = (dim, act, alpha, gain, clamp)
140 | if key in _bias_act_cuda_cache:
141 | return _bias_act_cuda_cache[key]
142 |
143 | # Forward op.
144 | class BiasActCuda(torch.autograd.Function):
145 | @staticmethod
146 | def forward(ctx, x, b): # pylint: disable=arguments-differ
147 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
148 | x = x.contiguous(memory_format=ctx.memory_format)
149 | b = b.contiguous() if b is not None else _null_tensor
150 | y = x
151 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
152 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
153 | ctx.save_for_backward(
154 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
155 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156 | y if 'y' in spec.ref else _null_tensor)
157 | return y
158 |
159 | @staticmethod
160 | def backward(ctx, dy): # pylint: disable=arguments-differ
161 | dy = dy.contiguous(memory_format=ctx.memory_format)
162 | x, b, y = ctx.saved_tensors
163 | dx = None
164 | db = None
165 |
166 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
167 | dx = dy
168 | if act != 'linear' or gain != 1 or clamp >= 0:
169 | dx = BiasActCudaGrad.apply(dy, x, b, y)
170 |
171 | if ctx.needs_input_grad[1]:
172 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
173 |
174 | return dx, db
175 |
176 | # Backward op.
177 | class BiasActCudaGrad(torch.autograd.Function):
178 | @staticmethod
179 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
180 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
181 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
182 | ctx.save_for_backward(
183 | dy if spec.has_2nd_grad else _null_tensor,
184 | x, b, y)
185 | return dx
186 |
187 | @staticmethod
188 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
189 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
190 | dy, x, b, y = ctx.saved_tensors
191 | d_dy = None
192 | d_x = None
193 | d_b = None
194 | d_y = None
195 |
196 | if ctx.needs_input_grad[0]:
197 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
198 |
199 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
200 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
201 |
202 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
203 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
204 |
205 | return d_dy, d_x, d_b, d_y
206 |
207 | # Add to cache.
208 | _bias_act_cuda_cache[key] = BiasActCuda
209 | return BiasActCuda
210 |
211 | #----------------------------------------------------------------------------
212 |
--------------------------------------------------------------------------------
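A minimal usage sketch for bias_act() (shapes are illustrative; impl='ref' uses the pure-PyTorch path, so no CUDA plugin needs to be compiled):

import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 512)                               # activations, channels along dim=1
b = torch.zeros(512)                                  # per-channel bias
y = bias_act.bias_act(x, b, act='lrelu', impl='ref')  # fused bias + leaky ReLU + default gain
print(y.shape)                                        # torch.Size([4, 512])
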
/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Custom replacement for `torch.nn.functional.conv2d` that supports
12 | arbitrarily high order gradients with zero performance penalty."""
13 |
14 | import contextlib
15 | import torch
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
25 |
26 | @contextlib.contextmanager
27 | def no_weight_gradients(disable=True):
28 | global weight_gradients_disabled
29 | old = weight_gradients_disabled
30 | if disable:
31 | weight_gradients_disabled = True
32 | yield
33 | weight_gradients_disabled = old
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
38 | if _should_use_custom_op(input):
39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
41 |
42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
43 | if _should_use_custom_op(input):
44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _should_use_custom_op(input):
50 | assert isinstance(input, torch.Tensor)
51 | if (not enabled) or (not torch.backends.cudnn.enabled):
52 | return False
53 | if input.device.type != 'cuda':
54 | return False
55 | return True
56 |
57 | def _tuple_of_ints(xs, ndim):
58 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
59 | assert len(xs) == ndim
60 | assert all(isinstance(x, int) for x in xs)
61 | return xs
62 |
63 | #----------------------------------------------------------------------------
64 |
65 | _conv2d_gradfix_cache = dict()
66 | _null_tensor = torch.empty([0])
67 |
68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69 | # Parse arguments.
70 | ndim = 2
71 | weight_shape = tuple(weight_shape)
72 | stride = _tuple_of_ints(stride, ndim)
73 | padding = _tuple_of_ints(padding, ndim)
74 | output_padding = _tuple_of_ints(output_padding, ndim)
75 | dilation = _tuple_of_ints(dilation, ndim)
76 |
77 | # Lookup from cache.
78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79 | if key in _conv2d_gradfix_cache:
80 | return _conv2d_gradfix_cache[key]
81 |
82 | # Validate arguments.
83 | assert groups >= 1
84 | assert len(weight_shape) == ndim + 2
85 | assert all(stride[i] >= 1 for i in range(ndim))
86 | assert all(padding[i] >= 0 for i in range(ndim))
87 | assert all(dilation[i] >= 0 for i in range(ndim))
88 | if not transpose:
89 | assert all(output_padding[i] == 0 for i in range(ndim))
90 | else: # transpose
91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92 |
93 | # Helpers.
94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95 | def calc_output_padding(input_shape, output_shape):
96 | if transpose:
97 | return [0, 0]
98 | return [
99 | input_shape[i + 2]
100 | - (output_shape[i + 2] - 1) * stride[i]
101 | - (1 - 2 * padding[i])
102 | - dilation[i] * (weight_shape[i + 2] - 1)
103 | for i in range(ndim)
104 | ]
105 |
106 | # Forward & backward.
107 | class Conv2d(torch.autograd.Function):
108 | @staticmethod
109 | def forward(ctx, input, weight, bias):
110 | assert weight.shape == weight_shape
111 | ctx.save_for_backward(
112 | input if weight.requires_grad else _null_tensor,
113 | weight if input.requires_grad else _null_tensor,
114 | )
115 | ctx.input_shape = input.shape
116 |
117 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
118 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
119 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
120 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
121 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
122 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
123 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
124 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
125 |
126 | # General case => cuDNN.
127 | if transpose:
128 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
129 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
130 |
131 | @staticmethod
132 | def backward(ctx, grad_output):
133 | input, weight = ctx.saved_tensors
134 | input_shape = ctx.input_shape
135 | grad_input = None
136 | grad_weight = None
137 | grad_bias = None
138 |
139 | if ctx.needs_input_grad[0]:
140 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
141 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
142 | grad_input = op.apply(grad_output, weight, None)
143 | assert grad_input.shape == input_shape
144 |
145 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
146 | grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
147 | assert grad_weight.shape == weight_shape
148 |
149 | if ctx.needs_input_grad[2]:
150 | grad_bias = grad_output.sum([0, 2, 3])
151 |
152 | return grad_input, grad_weight, grad_bias
153 |
154 | # Gradient with respect to the weights.
155 | class Conv2dGradWeight(torch.autograd.Function):
156 | @staticmethod
157 | def forward(ctx, grad_output, input, weight):
158 | ctx.save_for_backward(
159 | grad_output if input.requires_grad else _null_tensor,
160 | input if grad_output.requires_grad else _null_tensor,
161 | )
162 | ctx.grad_output_shape = grad_output.shape
163 | ctx.input_shape = input.shape
164 |
165 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
166 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
167 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
168 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
169 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
170 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
171 |
172 | # General case => cuDNN.
173 | return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
174 |
175 |
176 | @staticmethod
177 | def backward(ctx, grad2_grad_weight):
178 | grad_output, input = ctx.saved_tensors
179 | grad_output_shape = ctx.grad_output_shape
180 | input_shape = ctx.input_shape
181 | grad2_grad_output = None
182 | grad2_input = None
183 |
184 | if ctx.needs_input_grad[0]:
185 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
186 | assert grad2_grad_output.shape == grad_output_shape
187 |
188 | if ctx.needs_input_grad[1]:
189 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
190 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
191 | grad2_input = op.apply(grad_output, grad2_grad_weight, None)
192 | assert grad2_input.shape == input_shape
193 |
194 | return grad2_grad_output, grad2_input
195 |
196 | _conv2d_gradfix_cache[key] = Conv2d
197 | return Conv2d
198 |
199 | #----------------------------------------------------------------------------
200 |
--------------------------------------------------------------------------------
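A hedged usage sketch for the module above. The custom op is only taken for CUDA tensors once conv2d_gradfix.enabled is set; the no_weight_gradients() context skips gradients with respect to the weights for any backward pass executed inside it.

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True                                   # otherwise it falls back to torch.nn.functional.conv2d
x = torch.randn(1, 3, 8, 8, device='cuda', requires_grad=True)
w = torch.randn(4, 3, 3, 3, device='cuda', requires_grad=True)
with conv2d_gradfix.no_weight_gradients():
    y = conv2d_gradfix.conv2d(x, w, padding=1)
    y.sum().backward()                                          # w.grad is skipped while the flag is set
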
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """2D convolution with optional up/downsampling."""
12 |
13 | import torch
14 |
15 | from .. import misc
16 | from . import conv2d_gradfix
17 | from . import upfirdn2d
18 | from .upfirdn2d import _parse_padding
19 | from .upfirdn2d import _get_filter_size
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | def _get_weight_shape(w):
24 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
25 | shape = [int(sz) for sz in w.shape]
26 | misc.assert_shape(w, shape)
27 | return shape
28 |
29 | #----------------------------------------------------------------------------
30 |
31 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
32 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
33 | """
34 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
35 |
36 | # Flip weight if requested.
37 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
38 | if not flip_weight and (kw > 1 or kh > 1):
39 | w = w.flip([2, 3])
40 |
41 | # Execute using conv2d_gradfix.
42 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
43 | return op(x, w, stride=stride, padding=padding, groups=groups)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | @misc.profiled_function
48 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
49 | r"""2D convolution with optional up/downsampling.
50 |
51 | Padding is performed only once at the beginning, not between the operations.
52 |
53 | Args:
54 | x: Input tensor of shape
55 | `[batch_size, in_channels, in_height, in_width]`.
56 | w: Weight tensor of shape
57 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
58 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
59 | calling upfirdn2d.setup_filter(). None = identity (default).
60 | up: Integer upsampling factor (default: 1).
61 | down: Integer downsampling factor (default: 1).
62 | padding: Padding with respect to the upsampled image. Can be a single number
63 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
64 | (default: 0).
65 | groups: Split input channels into N groups (default: 1).
66 | flip_weight: False = convolution, True = correlation (default: True).
67 | flip_filter: False = convolution, True = correlation (default: False).
68 |
69 | Returns:
70 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
71 | """
72 | # Validate arguments.
73 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
74 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
75 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
76 | assert isinstance(up, int) and (up >= 1)
77 | assert isinstance(down, int) and (down >= 1)
78 | assert isinstance(groups, int) and (groups >= 1)
79 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
80 | fw, fh = _get_filter_size(f)
81 | px0, px1, py0, py1 = _parse_padding(padding)
82 |
83 | # Adjust padding to account for up/downsampling.
84 | if up > 1:
85 | px0 += (fw + up - 1) // 2
86 | px1 += (fw - up) // 2
87 | py0 += (fh + up - 1) // 2
88 | py1 += (fh - up) // 2
89 | if down > 1:
90 | px0 += (fw - down + 1) // 2
91 | px1 += (fw - down) // 2
92 | py0 += (fh - down + 1) // 2
93 | py1 += (fh - down) // 2
94 |
95 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
96 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
97 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
98 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
99 | return x
100 |
101 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
102 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
103 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
104 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
105 | return x
106 |
107 | # Fast path: downsampling only => use strided convolution.
108 | if down > 1 and up == 1:
109 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
110 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
111 | return x
112 |
113 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
114 | if up > 1:
115 | if groups == 1:
116 | w = w.transpose(0, 1)
117 | else:
118 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
119 | w = w.transpose(1, 2)
120 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
121 | px0 -= kw - 1
122 | px1 -= kw - up
123 | py0 -= kh - 1
124 | py1 -= kh - up
125 | pxt = max(min(-px0, -px1), 0)
126 | pyt = max(min(-py0, -py1), 0)
127 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
128 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
129 | if down > 1:
130 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
131 | return x
132 |
133 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
134 | if up == 1 and down == 1:
135 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
136 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
137 |
138 | # Fallback: Generic reference implementation.
139 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
140 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
141 | if down > 1:
142 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
143 | return x
144 |
145 | #----------------------------------------------------------------------------
146 |
--------------------------------------------------------------------------------
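The resampling filter has to go through upfirdn2d.setup_filter() before it is passed in as `f`. A minimal usage sketch (not part of the repository; shapes and filter taps are illustrative only, and the script is assumed to run from the repository root):

    import torch
    from torch_utils.ops import upfirdn2d, conv2d_resample

    x = torch.randn(4, 16, 32, 32)              # [batch, in_channels, H, W]
    w = torch.randn(32, 16, 3, 3)               # [out_channels, in_channels, kh, kw]
    f = upfirdn2d.setup_filter([1, 3, 3, 1])    # normalized low-pass filter for resampling

    y_up   = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)    # 32x32 -> 64x64
    y_down = conv2d_resample.conv2d_resample(x=x, w=w, f=f, down=2, padding=1)  # 32x32 -> 16x16
    print(y_up.shape, y_down.shape)

On CPU the call should fall back to the reference upfirdn2d/conv2d paths, so the sketch does not require the compiled CUDA extensions.
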
/torch_utils/ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <cuda_runtime.h>
14 |
15 | //------------------------------------------------------------------------
16 | // CUDA kernel parameters.
17 |
18 | struct filtered_lrelu_kernel_params
19 | {
20 | // These parameters decide which kernel to use.
21 | int up; // upsampling ratio (1, 2, 4)
22 | int down; // downsampling ratio (1, 2, 4)
23 | int2 fuShape; // [size, 1] | [size, size]
24 | int2 fdShape; // [size, 1] | [size, size]
25 |
26 | int _dummy; // Alignment.
27 |
28 | // Rest of the parameters.
29 | const void* x; // Input tensor.
30 | void* y; // Output tensor.
31 | const void* b; // Bias tensor.
32 | unsigned char* s; // Sign tensor in/out. NULL if unused.
33 | const float* fu; // Upsampling filter.
34 | const float* fd; // Downsampling filter.
35 |
36 | int2 pad0; // Left/top padding.
37 | float gain; // Additional gain factor.
38 | float slope; // Leaky ReLU slope on negative side.
39 | float clamp; // Clamp after nonlinearity.
40 | int flip; // Filter kernel flip for gradient computation.
41 |
42 | int tilesXdim; // Original number of horizontal output tiles.
43 | int tilesXrep; // Number of horizontal tiles per CTA.
44 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
45 |
46 | int4 xShape; // [width, height, channel, batch]
47 | int4 yShape; // [width, height, channel, batch]
48 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
49 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
50 | int swLimit; // Active width of sign tensor in bytes.
51 |
52 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
53 | longlong4 yStride; //
54 | int64_t bStride; //
55 | longlong3 fuStride; //
56 | longlong3 fdStride; //
57 | };
58 |
59 | struct filtered_lrelu_act_kernel_params
60 | {
61 | void* x; // Input/output, modified in-place.
62 | unsigned char* s; // Sign tensor in/out. NULL if unused.
63 |
64 | float gain; // Additional gain factor.
65 | float slope; // Leaky ReLU slope on negative side.
66 | float clamp; // Clamp after nonlinearity.
67 |
68 | int4 xShape; // [width, height, channel, batch]
69 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
70 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
71 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
72 | };
73 |
74 | //------------------------------------------------------------------------
75 | // CUDA kernel specialization.
76 |
77 | struct filtered_lrelu_kernel_spec
78 | {
79 | void* setup; // Function for filter kernel setup.
80 | void* exec; // Function for main operation.
81 | int2 tileOut; // Width/height of launch tile.
82 | int numWarps; // Number of warps per thread block, determines launch block size.
83 | int xrep; // For processing multiple horizontal tiles per thread block.
84 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
85 | };
86 |
87 | //------------------------------------------------------------------------
88 | // CUDA kernel selection.
89 |
90 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
91 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
92 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
93 |
94 | //------------------------------------------------------------------------
95 |
--------------------------------------------------------------------------------
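These kernel parameters are filled in by the Python wrapper in torch_utils/ops/filtered_lrelu.py. A hedged sketch of the Python-side call (not part of this file; it assumes the wrapper keeps the StyleGAN3-style filtered_lrelu(x, fu, fd, b, up, down, padding, gain, slope, clamp, ...) signature and that a CPU tensor takes the reference path built from upfirdn2d and bias_act):

    import torch
    from torch_utils.ops import upfirdn2d, filtered_lrelu

    x  = torch.randn(2, 8, 16, 16)              # [batch, channels, H, W]
    b  = torch.zeros(8)                         # per-channel bias (the `b` pointer above)
    fu = upfirdn2d.setup_filter([1, 3, 3, 1])   # upsampling filter -> fuShape

    # up, gain, slope, and clamp map directly onto the kernel parameters declared above.
    y = filtered_lrelu.filtered_lrelu(x, fu=fu, b=b, up=2, slope=0.2, clamp=None)
    print(y.shape)
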
/torch_utils/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include "filtered_lrelu.cu"
14 |
15 | // Template/kernel specializations for no signs mode (no gradients required).
16 |
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 |
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
29 |
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
32 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include "filtered_lrelu.cu"
14 |
15 | // Template/kernel specializations for sign read mode.
16 |
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 |
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
29 |
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
32 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include "filtered_lrelu.cu"
14 |
15 | // Template/kernel specializations for sign write mode.
16 |
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 |
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
29 |
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
32 |
--------------------------------------------------------------------------------
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
12 |
13 | import torch
14 |
15 | #----------------------------------------------------------------------------
16 |
17 | def fma(a, b, c): # => a * b + c
18 | return _FusedMultiplyAdd.apply(a, b, c)
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
23 | @staticmethod
24 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
25 | out = torch.addcmul(c, a, b)
26 | ctx.save_for_backward(a, b)
27 | ctx.c_shape = c.shape
28 | return out
29 |
30 | @staticmethod
31 | def backward(ctx, dout): # pylint: disable=arguments-differ
32 | a, b = ctx.saved_tensors
33 | c_shape = ctx.c_shape
34 | da = None
35 | db = None
36 | dc = None
37 |
38 | if ctx.needs_input_grad[0]:
39 | da = _unbroadcast(dout * b, a.shape)
40 |
41 | if ctx.needs_input_grad[1]:
42 | db = _unbroadcast(dout * a, b.shape)
43 |
44 | if ctx.needs_input_grad[2]:
45 | dc = _unbroadcast(dout, c_shape)
46 |
47 | return da, db, dc
48 |
49 | #----------------------------------------------------------------------------
50 |
51 | def _unbroadcast(x, shape):
52 | extra_dims = x.ndim - len(shape)
53 | assert extra_dims >= 0
54 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
55 | if len(dim):
56 | x = x.sum(dim=dim, keepdim=True)
57 | if extra_dims:
58 | x = x.reshape(-1, *x.shape[extra_dims+1:])
59 | assert x.shape == shape
60 | return x
61 |
62 | #----------------------------------------------------------------------------
63 |
--------------------------------------------------------------------------------
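A small sanity check (not part of the repo) showing that fma.fma() matches torch.addcmul() in the forward pass while still providing gradients for all three operands:

    import torch
    from torch_utils.ops import fma

    a = torch.randn(4, 8, requires_grad=True)
    b = torch.randn(4, 8, requires_grad=True)
    c = torch.randn(4, 8, requires_grad=True)

    out = fma.fma(a, b, c)                                  # computes a * b + c
    assert torch.allclose(out, torch.addcmul(c, a, b))
    out.sum().backward()                                    # da = b, db = a, dc = 1 (broadcast-aware)
    print(a.grad.shape, b.grad.shape, c.grad.shape)
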
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Custom replacement for `torch.nn.functional.grid_sample` that
12 | supports arbitrarily high order gradients between the input and output.
13 | Only works on 2D images and assumes
14 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
15 |
16 | import torch
17 |
18 | # pylint: disable=redefined-builtin
19 | # pylint: disable=arguments-differ
20 | # pylint: disable=protected-access
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | enabled = False # Enable the custom op by setting this to true.
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def grid_sample(input, grid):
29 | if _should_use_custom_op():
30 | return _GridSample2dForward.apply(input, grid)
31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def _should_use_custom_op():
36 | return enabled
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | class _GridSample2dForward(torch.autograd.Function):
41 | @staticmethod
42 | def forward(ctx, input, grid):
43 | assert input.ndim == 4
44 | assert grid.ndim == 4
45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
46 | ctx.save_for_backward(input, grid)
47 | return output
48 |
49 | @staticmethod
50 | def backward(ctx, grad_output):
51 | input, grid = ctx.saved_tensors
52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
53 | return grad_input, grad_grid
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | class _GridSample2dBackward(torch.autograd.Function):
58 | @staticmethod
59 | def forward(ctx, grad_output, input, grid):
60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
61 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
62 | ctx.save_for_backward(grid)
63 | return grad_input, grad_grid
64 |
65 | @staticmethod
66 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
67 | _ = grad2_grad_grid # unused
68 | grid, = ctx.saved_tensors
69 | grad2_grad_output = None
70 | grad2_input = None
71 | grad2_grid = None
72 |
73 | if ctx.needs_input_grad[0]:
74 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
75 |
76 | assert not ctx.needs_input_grad[2]
77 | return grad2_grad_output, grad2_input, grad2_grid
78 |
79 | #----------------------------------------------------------------------------
80 |
--------------------------------------------------------------------------------
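A hedged sketch (not from the repo) of the double-backpropagation pattern this op exists for, e.g. an R1-style gradient penalty through a warped image. It requires a PyTorch version whose aten::grid_sampler_2d_backward signature matches the call above, and the grid itself must not require gradients:

    import torch
    from torch_utils.ops import grid_sample_gradfix

    grid_sample_gradfix.enabled = True                        # opt in to the custom op

    image = torch.randn(1, 3, 8, 8, requires_grad=True)
    theta = torch.eye(2, 3).unsqueeze(0)                      # identity affine transform (no grad on the grid)
    grid  = torch.nn.functional.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)

    out = grid_sample_gradfix.grid_sample(image, grid)
    (g_image,) = torch.autograd.grad(out.square().sum(), image, create_graph=True)
    penalty = g_image.square().sum()                          # second-order term, still differentiable
    penalty.backward()
    print(image.grad.shape)
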
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <torch/extension.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 | #include <c10/cuda/CUDAGuard.h>
16 | #include "upfirdn2d.h"
17 |
18 | //------------------------------------------------------------------------
19 |
20 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
21 | {
22 | // Validate arguments.
23 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
24 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
25 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
26 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
27 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
28 | TORCH_CHECK(x.numel() > 0, "x has zero size");
29 | TORCH_CHECK(f.numel() > 0, "f has zero size");
30 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
31 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
32 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
33 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
34 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
35 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
36 |
37 | // Create output tensor.
38 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
39 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
40 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
41 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
42 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
43 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
44 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
45 |
46 | // Initialize CUDA kernel parameters.
47 | upfirdn2d_kernel_params p;
48 | p.x = x.data_ptr();
49 | p.f = f.data_ptr();
50 | p.y = y.data_ptr();
51 | p.up = make_int2(upx, upy);
52 | p.down = make_int2(downx, downy);
53 | p.pad0 = make_int2(padx0, pady0);
54 | p.flip = (flip) ? 1 : 0;
55 | p.gain = gain;
56 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
57 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
58 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
59 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
60 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
61 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
62 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
63 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
64 |
65 | // Choose CUDA kernel.
66 | upfirdn2d_kernel_spec spec;
67 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
68 | {
69 |         spec = choose_upfirdn2d_kernel<scalar_t>(p);
70 | });
71 |
72 | // Set looping options.
73 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
74 | p.loopMinor = spec.loopMinor;
75 | p.loopX = spec.loopX;
76 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
77 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
78 |
79 | // Compute grid size.
80 | dim3 blockSize, gridSize;
81 | if (spec.tileOutW < 0) // large
82 | {
83 | blockSize = dim3(4, 32, 1);
84 | gridSize = dim3(
85 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
86 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
87 | p.launchMajor);
88 | }
89 | else // small
90 | {
91 | blockSize = dim3(256, 1, 1);
92 | gridSize = dim3(
93 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
94 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
95 | p.launchMajor);
96 | }
97 |
98 | // Launch CUDA kernel.
99 | void* args[] = {&p};
100 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
101 | return y;
102 | }
103 |
104 | //------------------------------------------------------------------------
105 |
106 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
107 | {
108 | m.def("upfirdn2d", &upfirdn2d);
109 | }
110 |
111 | //------------------------------------------------------------------------
112 |
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <cuda_runtime.h>
14 |
15 | //------------------------------------------------------------------------
16 | // CUDA kernel parameters.
17 |
18 | struct upfirdn2d_kernel_params
19 | {
20 | const void* x;
21 | const float* f;
22 | void* y;
23 |
24 | int2 up;
25 | int2 down;
26 | int2 pad0;
27 | int flip;
28 | float gain;
29 |
30 | int4 inSize; // [width, height, channel, batch]
31 | int4 inStride;
32 | int2 filterSize; // [width, height]
33 | int2 filterStride;
34 | int4 outSize; // [width, height, channel, batch]
35 | int4 outStride;
36 | int sizeMinor;
37 | int sizeMajor;
38 |
39 | int loopMinor;
40 | int loopMajor;
41 | int loopX;
42 | int launchMinor;
43 | int launchMajor;
44 | };
45 |
46 | //------------------------------------------------------------------------
47 | // CUDA kernel specialization.
48 |
49 | struct upfirdn2d_kernel_spec
50 | {
51 | void* kernel;
52 | int tileOutW;
53 | int tileOutH;
54 | int loopMinor;
55 | int loopX;
56 | };
57 |
58 | //------------------------------------------------------------------------
59 | // CUDA kernel selection.
60 |
61 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
62 |
63 | //------------------------------------------------------------------------
64 |
--------------------------------------------------------------------------------
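The kernel declared here is reached through the Python wrapper in torch_utils/ops/upfirdn2d.py, which also provides a reference implementation used on CPU. A short usage sketch (not part of these files; filter taps are illustrative):

    import torch
    from torch_utils.ops import upfirdn2d

    x = torch.randn(1, 3, 16, 16)
    f = upfirdn2d.setup_filter([1, 3, 3, 1])        # normalized low-pass filter

    y_up   = upfirdn2d.upsample2d(x, f, up=2)       # 16x16 -> 32x32 (filtered upsampling)
    y_down = upfirdn2d.downsample2d(x, f, down=2)   # 16x16 -> 8x8   (filtered downsampling)
    y_raw  = upfirdn2d.upfirdn2d(x, f, up=2, padding=1, gain=4)   # pad, upsample, FIR-filter
    print(y_up.shape, y_down.shape, y_raw.shape)
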
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Facilities for pickling Python code alongside other data.
12 |
13 | The pickled code is automatically imported into a separate Python module
14 | during unpickling. This way, any previously exported pickles will remain
15 | usable even if the original code is no longer available, or if the current
16 | version of the code is not consistent with what was originally pickled."""
17 |
18 | import sys
19 | import pickle
20 | import io
21 | import inspect
22 | import copy
23 | import uuid
24 | import types
25 | import dnnlib
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | _version = 6 # internal version number
30 | _decorators = set() # {decorator_class, ...}
31 | _import_hooks = [] # [hook_function, ...]
32 | _module_to_src_dict = dict() # {module: src, ...}
33 | _src_to_module_dict = dict() # {src: module, ...}
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def persistent_class(orig_class):
38 | r"""Class decorator that extends a given class to save its source code
39 | when pickled.
40 |
41 | Example:
42 |
43 | from torch_utils import persistence
44 |
45 | @persistence.persistent_class
46 | class MyNetwork(torch.nn.Module):
47 | def __init__(self, num_inputs, num_outputs):
48 | super().__init__()
49 | self.fc = MyLayer(num_inputs, num_outputs)
50 | ...
51 |
52 | @persistence.persistent_class
53 | class MyLayer(torch.nn.Module):
54 | ...
55 |
56 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
57 | source code alongside other internal state (e.g., parameters, buffers,
58 | and submodules). This way, any previously exported pickle will remain
59 | usable even if the class definitions have been modified or are no
60 | longer available.
61 |
62 | The decorator saves the source code of the entire Python module
63 | containing the decorated class. It does *not* save the source code of
64 | any imported modules. Thus, the imported modules must be available
65 | during unpickling, also including `torch_utils.persistence` itself.
66 |
67 | It is ok to call functions defined in the same module from the
68 | decorated class. However, if the decorated class depends on other
69 | classes defined in the same module, they must be decorated as well.
70 | This is illustrated in the above example in the case of `MyLayer`.
71 |
72 | It is also possible to employ the decorator just-in-time before
73 | calling the constructor. For example:
74 |
75 | cls = MyLayer
76 | if want_to_make_it_persistent:
77 | cls = persistence.persistent_class(cls)
78 | layer = cls(num_inputs, num_outputs)
79 |
80 | As an additional feature, the decorator also keeps track of the
81 | arguments that were used to construct each instance of the decorated
82 | class. The arguments can be queried via `obj.init_args` and
83 | `obj.init_kwargs`, and they are automatically pickled alongside other
84 | object state. A typical use case is to first unpickle a previous
85 | instance of a persistent class, and then upgrade it to use the latest
86 | version of the source code:
87 |
88 | with open('old_pickle.pkl', 'rb') as f:
89 | old_net = pickle.load(f)
90 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92 | """
93 | assert isinstance(orig_class, type)
94 | if is_persistent(orig_class):
95 | return orig_class
96 |
97 | assert orig_class.__module__ in sys.modules
98 | orig_module = sys.modules[orig_class.__module__]
99 | orig_module_src = _module_to_src(orig_module)
100 |
101 | class Decorator(orig_class):
102 | _orig_module_src = orig_module_src
103 | _orig_class_name = orig_class.__name__
104 |
105 | def __init__(self, *args, **kwargs):
106 | super().__init__(*args, **kwargs)
107 | self._init_args = copy.deepcopy(args)
108 | self._init_kwargs = copy.deepcopy(kwargs)
109 | assert orig_class.__name__ in orig_module.__dict__
110 | _check_pickleable(self.__reduce__())
111 |
112 | @property
113 | def init_args(self):
114 | return copy.deepcopy(self._init_args)
115 |
116 | @property
117 | def init_kwargs(self):
118 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
119 |
120 | def __reduce__(self):
121 | fields = list(super().__reduce__())
122 | fields += [None] * max(3 - len(fields), 0)
123 | if fields[0] is not _reconstruct_persistent_obj:
124 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
125 | fields[0] = _reconstruct_persistent_obj # reconstruct func
126 | fields[1] = (meta,) # reconstruct args
127 | fields[2] = None # state dict
128 | return tuple(fields)
129 |
130 | Decorator.__name__ = orig_class.__name__
131 | _decorators.add(Decorator)
132 | return Decorator
133 |
134 | #----------------------------------------------------------------------------
135 |
136 | def is_persistent(obj):
137 | r"""Test whether the given object or class is persistent, i.e.,
138 | whether it will save its source code when pickled.
139 | """
140 | try:
141 | if obj in _decorators:
142 | return True
143 | except TypeError:
144 | pass
145 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
146 |
147 | #----------------------------------------------------------------------------
148 |
149 | def import_hook(hook):
150 | r"""Register an import hook that is called whenever a persistent object
151 | is being unpickled. A typical use case is to patch the pickled source
152 | code to avoid errors and inconsistencies when the API of some imported
153 | module has changed.
154 |
155 | The hook should have the following signature:
156 |
157 | hook(meta) -> modified meta
158 |
159 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
160 |
161 | type: Type of the persistent object, e.g. `'class'`.
162 | version: Internal version number of `torch_utils.persistence`.
163 |         module_src:  Original source code of the Python module.
164 | class_name: Class name in the original Python module.
165 | state: Internal state of the object.
166 |
167 | Example:
168 |
169 | @persistence.import_hook
170 | def wreck_my_network(meta):
171 | if meta.class_name == 'MyNetwork':
172 | print('MyNetwork is being imported. I will wreck it!')
173 | meta.module_src = meta.module_src.replace("True", "False")
174 | return meta
175 | """
176 | assert callable(hook)
177 | _import_hooks.append(hook)
178 |
179 | #----------------------------------------------------------------------------
180 |
181 | def _reconstruct_persistent_obj(meta):
182 | r"""Hook that is called internally by the `pickle` module to unpickle
183 | a persistent object.
184 | """
185 | meta = dnnlib.EasyDict(meta)
186 | meta.state = dnnlib.EasyDict(meta.state)
187 | for hook in _import_hooks:
188 | meta = hook(meta)
189 | assert meta is not None
190 |
191 | assert meta.version == _version
192 | module = _src_to_module(meta.module_src)
193 |
194 | assert meta.type == 'class'
195 | orig_class = module.__dict__[meta.class_name]
196 | decorator_class = persistent_class(orig_class)
197 | obj = decorator_class.__new__(decorator_class)
198 |
199 | setstate = getattr(obj, '__setstate__', None)
200 | if callable(setstate):
201 | setstate(meta.state) # pylint: disable=not-callable
202 | else:
203 | obj.__dict__.update(meta.state)
204 | return obj
205 |
206 | #----------------------------------------------------------------------------
207 |
208 | def _module_to_src(module):
209 | r"""Query the source code of a given Python module.
210 | """
211 | src = _module_to_src_dict.get(module, None)
212 | if src is None:
213 | src = inspect.getsource(module)
214 | _module_to_src_dict[module] = src
215 | _src_to_module_dict[src] = module
216 | return src
217 |
218 | def _src_to_module(src):
219 | r"""Get or create a Python module for the given source code.
220 | """
221 | module = _src_to_module_dict.get(src, None)
222 | if module is None:
223 | module_name = "_imported_module_" + uuid.uuid4().hex
224 | module = types.ModuleType(module_name)
225 | sys.modules[module_name] = module
226 | _module_to_src_dict[module] = src
227 | _src_to_module_dict[src] = module
228 | exec(src, module.__dict__) # pylint: disable=exec-used
229 | return module
230 |
231 | #----------------------------------------------------------------------------
232 |
233 | def _check_pickleable(obj):
234 | r"""Check that the given object is pickleable, raising an exception if
235 | it is not. This function is expected to be considerably more efficient
236 | than actually pickling the object.
237 | """
238 | def recurse(obj):
239 | if isinstance(obj, (list, tuple, set)):
240 | return [recurse(x) for x in obj]
241 | if isinstance(obj, dict):
242 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
243 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
244 | return None # Python primitive types are pickleable.
245 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
246 | return None # NumPy arrays and PyTorch tensors are pickleable.
247 | if is_persistent(obj):
248 | return None # Persistent objects are pickleable, by virtue of the constructor check.
249 | return obj
250 | with io.BytesIO() as f:
251 | pickle.dump(recurse(obj), f)
252 |
253 | #----------------------------------------------------------------------------
254 |
--------------------------------------------------------------------------------
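A minimal round-trip sketch (not part of the repo) of what the decorator buys: the pickle carries the class's module source, so it can be restored later even if the defining code has changed. Run it as a script so that inspect.getsource() can retrieve the module source:

    import io, pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class TinyNet(torch.nn.Module):
        def __init__(self, num_inputs, num_outputs):
            super().__init__()
            self.fc = torch.nn.Linear(num_inputs, num_outputs)

    net = TinyNet(4, 2)
    assert persistence.is_persistent(net)

    buf = io.BytesIO()
    pickle.dump(net, buf)                       # source code travels with the parameters
    buf.seek(0)
    restored = pickle.load(buf)
    print(restored.init_args, dict(restored.init_kwargs))   # constructor args captured at build time
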
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/training/crosssection_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import torch
12 |
13 | def sample_cross_section(G, ws, resolution=256, w=1.2):
14 | axis=0
15 | A, B = torch.meshgrid(torch.linspace(w/2, -w/2, resolution, device=ws.device), torch.linspace(-w/2, w/2, resolution, device=ws.device), indexing='ij')
16 | A, B = A.reshape(-1, 1), B.reshape(-1, 1)
17 | C = torch.zeros_like(A)
18 | coordinates = [A, B]
19 | coordinates.insert(axis, C)
20 | coordinates = torch.cat(coordinates, dim=-1).expand(ws.shape[0], -1, -1)
21 |
22 | sigma = G.sample_mixed(coordinates, torch.randn_like(coordinates), ws)['sigma']
23 | return sigma.reshape(-1, 1, resolution, resolution)
24 |
25 | # if __name__ == '__main__':
26 | # sample_crossection(None)
--------------------------------------------------------------------------------
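A hedged, self-contained sketch (not from the repo) of how sample_cross_section() is driven. The DummyGenerator below merely stands in for a real tri-plane generator exposing sample_mixed(); it lets the plumbing and the output shape be checked without a trained model:

    import torch
    from training.crosssection_utils import sample_cross_section

    class DummyGenerator:
        def sample_mixed(self, coordinates, directions, ws):
            # A real generator evaluates its neural field here; this stub returns a radial density.
            sigma = (1.0 - coordinates.norm(dim=-1, keepdim=True)).clamp(min=0)
            return {'sigma': sigma}

    ws = torch.randn(2, 14, 512)                                  # dummy latent codes, one per sample
    sigma_slice = sample_cross_section(DummyGenerator(), ws, resolution=64, w=1.2)
    print(sigma_slice.shape)                                      # torch.Size([2, 1, 64, 64])
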
/training/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import atan2, asin, cos
3 |
4 | def P2sRt(P):
5 |     """ Decompose the affine camera matrix P.
6 | Args:
7 | P: (3, 4). Affine Camera Matrix.
8 | Returns:
9 | s: scale factor.
10 | R: (3, 3). rotation matrix.
11 |         t3d: (3,). 3d translation.
12 | """
13 | t3d = P[:, 3]
14 | R1 = P[0:1, :3]
15 | R2 = P[1:2, :3]
16 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0
17 | r1 = R1 / np.linalg.norm(R1)
18 | r2 = R2 / np.linalg.norm(R2)
19 | r3 = np.cross(r1, r2)
20 |
21 | R = np.concatenate((r1, r2, r3), 0)
22 | return s, R, t3d
23 |
24 | def matrix2angle(R):
25 | """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf
26 | refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv
27 | todo: check and debug
28 | Args:
29 | R: (3,3). rotation matrix
30 | Returns:
31 | x: yaw
32 | y: pitch
33 | z: roll
34 | """
35 | if R[2, 0] > 0.998:
36 | z = 0
37 | x = np.pi / 2
38 | y = z + atan2(-R[0, 1], -R[0, 2])
39 | elif R[2, 0] < -0.998:
40 | z = 0
41 | x = -np.pi / 2
42 | y = -z + atan2(R[0, 1], R[0, 2])
43 | else:
44 | x = asin(R[2, 0])
45 | y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))
46 | z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))
47 |
48 | if abs(y) > np.pi/2:
49 | if x > 0:
50 | x = np.pi - x
51 | else:
52 | x = -np.pi - x
53 | y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))
54 | z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))
55 | return x, y, z
56 |
57 | def calc_pose(param):
58 | P = param[:12].reshape(3, -1) # camera matrix
59 | s, R, t3d = P2sRt(P)
60 | P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale
61 | pose = matrix2angle(R)
62 | pose = [p * 180 / np.pi for p in pose]
63 |
64 | return P, pose
65 |
66 | def get_poseangle(eg3dparams):
67 |
68 | convert = np.array([
69 | [1, 0, 0, 0],
70 | [0, -1, 0, 0],
71 | [0, 0, -1, 0],
72 | [0, 0, 0, 1],
73 | ]).astype(np.float32)
74 |
75 | entry_cam = np.array([float(p) for p in eg3dparams][:16]).reshape((4,4))
76 |
77 | world2cam = np.linalg.inv(entry_cam@convert)
78 | pose = matrix2angle(world2cam[:3,:3])
79 | angle = [p * 180 / np.pi for p in pose]
80 |
81 | return angle
82 |
--------------------------------------------------------------------------------
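A small numeric check (not part of the repo): build an affine camera matrix from a known rotation, then verify that P2sRt() recovers the scale and matrix2angle() returns plausible Euler angles (in radians; sign conventions follow the 3DDFA-style decomposition):

    import numpy as np
    from training.utils import P2sRt, matrix2angle

    yaw = np.deg2rad(30.0)
    R_true = np.array([[ np.cos(yaw), 0.0, np.sin(yaw)],
                       [ 0.0,         1.0, 0.0        ],
                       [-np.sin(yaw), 0.0, np.cos(yaw)]])
    t_true = np.array([0.1, -0.2, 3.0])
    s_true = 2.5
    P = np.concatenate([s_true * R_true, (s_true * t_true)[:, None]], axis=1)   # (3, 4) affine camera matrix

    s, R, t3d = P2sRt(P)
    x, y, z = matrix2angle(R)
    print(s, np.degrees([x, y, z]))    # s is ~2.5; one angle is +/-30 deg depending on convention
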
/training/volumetric_rendering/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
--------------------------------------------------------------------------------
/training/volumetric_rendering/math_utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2022 Petr Kellnhofer
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 |
25 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
26 | """
27 |     Applies the MxM matrix to each of the N M-dimensional vectors in vectors4 (computed as vectors4 @ matrix.T). Returns NxM.
28 | """
29 | res = torch.matmul(vectors4, matrix.T)
30 | return res
31 |
32 |
33 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
34 | """
35 | Normalize vector lengths.
36 | """
37 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
38 |
39 | def torch_dot(x: torch.Tensor, y: torch.Tensor):
40 | """
41 | Dot product of two tensors.
42 | """
43 | return (x * y).sum(-1)
44 |
45 |
46 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
47 | """
48 | Author: Petr Kellnhofer
49 | Intersects rays with the [-1, 1] NDC volume.
50 | Returns min and max distance of entry.
51 | Returns -1 for no intersection.
52 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
53 | """
54 | o_shape = rays_o.shape
55 | rays_o = rays_o.detach().reshape(-1, 3)
56 | rays_d = rays_d.detach().reshape(-1, 3)
57 |
58 |
59 | bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
60 | bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
61 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
62 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
63 |
64 | # Precompute inverse for stability.
65 | invdir = 1 / rays_d
66 | sign = (invdir < 0).long()
67 |
68 | # Intersect with YZ plane.
69 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
70 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
71 |
72 | # Intersect with XZ plane.
73 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
74 | tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
75 |
76 | # Resolve parallel rays.
77 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
78 |
79 | # Use the shortest intersection.
80 | tmin = torch.max(tmin, tymin)
81 | tmax = torch.min(tmax, tymax)
82 |
83 | # Intersect with XY plane.
84 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
85 | tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
86 |
87 | # Resolve parallel rays.
88 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
89 |
90 | # Use the shortest intersection.
91 | tmin = torch.max(tmin, tzmin)
92 | tmax = torch.min(tmax, tzmax)
93 |
94 | # Mark invalid.
95 | tmin[torch.logical_not(is_valid)] = -1
96 | tmax[torch.logical_not(is_valid)] = -2
97 |
98 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
99 |
100 |
101 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
102 | """
103 |     Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
104 |     Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
105 | """
106 | # create a tensor of 'num' steps from 0 to 1
107 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
108 |
109 |     # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
110 | # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
111 |     # "cannot statically infer the expected size of a list in this context", hence the code below
112 | for i in range(start.ndim):
113 | steps = steps.unsqueeze(-1)
114 |
115 | # the output starts at 'start' and increments until 'stop' in each dimension
116 | out = start[None] + steps * (stop - start)[None]
117 |
118 | return out
119 |
--------------------------------------------------------------------------------
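A short sketch (not from the repo) exercising the batched linspace() and the ray/box intersection that bounds the sampling interval along each ray:

    import torch
    from training.volumetric_rendering import math_utils

    # linspace() broadcasts over arbitrary start/stop tensors.
    start, stop = torch.zeros(2, 3), torch.ones(2, 3)
    steps = math_utils.linspace(start, stop, num=5)               # shape (5, 2, 3)

    # One ray hits the [-1, 1] box, the other passes beside it.
    rays_o = torch.tensor([[0.0, 0.0, -3.0], [5.0, 5.0, 5.0]])
    rays_d = math_utils.normalize_vecs(torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]))
    t_near, t_far = math_utils.get_ray_limits_box(rays_o, rays_d, box_side_length=2)
    print(steps.shape, t_near.squeeze(), t_far.squeeze())         # the missing ray is flagged with -1 / -2
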
/training/volumetric_rendering/ray_marcher.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """
12 | The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
13 | Based on the implementation in MipNeRF (though this one does not do any cone tracing).
14 | """
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from torch_utils import persistence
20 |
21 | @persistence.persistent_class
22 | class MipRayMarcher2(nn.Module):
23 | def __init__(self):
24 | super().__init__()
25 |
26 |
27 | def run_forward(self, colors, densities, depths, rendering_options):
28 | deltas = depths[:, :, 1:] - depths[:, :, :-1]
29 | colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
30 | densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
31 | depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
32 |
33 |
34 | if rendering_options['clamp_mode'] == 'softplus':
35 | densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better
36 | else:
37 | assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!"
38 |
39 | density_delta = densities_mid * deltas
40 |
41 | alpha = 1 - torch.exp(-density_delta)
42 |
43 | alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
44 | weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
45 |
46 | composite_rgb = torch.sum(weights * colors_mid, -2)
47 | weight_total = weights.sum(2)
48 | composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
49 |
50 | # clip the composite to min/max range of depths
51 | composite_depth = torch.nan_to_num(composite_depth, float('inf'))
52 | composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
53 |
54 | if rendering_options.get('white_back', False):
55 | composite_rgb = composite_rgb + 1 - weight_total
56 |
57 | return composite_rgb, composite_depth, weights
58 |
59 |
60 | def forward(self, colors, densities, depths, rendering_options):
61 | composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
62 |
63 | return composite_rgb, composite_depth, weights
--------------------------------------------------------------------------------
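A self-contained sketch (not part of the repo) that feeds random per-sample colors and densities through MipRayMarcher2; the shapes follow the [batch, rays, samples_per_ray, channels] layout the code composites along dim 2:

    import torch
    from training.volumetric_rendering.ray_marcher import MipRayMarcher2

    batch, num_rays, samples_per_ray = 2, 64, 48
    colors    = torch.rand(batch, num_rays, samples_per_ray, 3)
    densities = torch.randn(batch, num_rays, samples_per_ray, 1)
    depths    = torch.linspace(2.0, 3.0, samples_per_ray).view(1, 1, -1, 1).expand(batch, num_rays, -1, -1)

    marcher = MipRayMarcher2()
    rgb, depth, weights = marcher(colors, densities, depths, {'clamp_mode': 'softplus'})
    print(rgb.shape, depth.shape, weights.shape)   # (2, 64, 3), (2, 64, 1), (2, 64, 47, 1)
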
/training/volumetric_rendering/ray_sampler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """
12 | The ray sampler is a module that takes in camera matrices and a resolution and returns batches of rays.
13 | Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
14 | """
15 |
16 | import torch
17 |
18 | class RaySampler(torch.nn.Module):
19 | def __init__(self):
20 | super().__init__()
21 | self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
22 |
23 |
24 | def forward(self, cam2world_matrix, intrinsics, resolution):
25 | """
26 | Create batches of rays and return origins and directions.
27 |
28 | cam2world_matrix: (N, 4, 4)
29 | intrinsics: (N, 3, 3)
30 | resolution: int
31 |
32 | ray_origins: (N, M, 3)
33 |         ray_dirs: (N, M, 3)
34 | """
35 | N, M = cam2world_matrix.shape[0], resolution**2
36 | cam_locs_world = cam2world_matrix[:, :3, 3]
37 | fx = intrinsics[:, 0, 0]
38 | fy = intrinsics[:, 1, 1]
39 | cx = intrinsics[:, 0, 2]
40 | cy = intrinsics[:, 1, 2]
41 | sk = intrinsics[:, 0, 1]
42 |
43 | uv = torch.stack(torch.meshgrid(torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), indexing='ij')) * (1./resolution) + (0.5/resolution)
44 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
45 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
46 |
47 | x_cam = uv[:, :, 0].view(N, -1)
48 | y_cam = uv[:, :, 1].view(N, -1)
49 | z_cam = torch.ones((N, M), device=cam2world_matrix.device)
50 |
51 | x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
52 | y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
53 |
54 | cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
55 |
56 | world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
57 |
58 | ray_dirs = world_rel_points - cam_locs_world[:, None, :]
59 | ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
60 |
61 | ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
62 |
63 | return ray_origins, ray_dirs
--------------------------------------------------------------------------------
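A short sketch (not part of the repo) that generates rays for a single camera placed on the -z axis and looking toward the origin; the intrinsics are written in normalized image units, following the OpenCV convention mentioned in the docstring (the focal length value is only illustrative):

    import torch
    from training.volumetric_rendering.ray_sampler import RaySampler

    cam2world = torch.eye(4).unsqueeze(0)              # OpenCV convention: camera looks along +z
    cam2world[0, 2, 3] = -2.7                          # move the camera back along -z
    intrinsics = torch.tensor([[[4.2647, 0.0,    0.5],
                                [0.0,    4.2647, 0.5],
                                [0.0,    0.0,    1.0]]])

    sampler = RaySampler()
    ray_origins, ray_dirs = sampler(cam2world, intrinsics, resolution=64)
    print(ray_origins.shape, ray_dirs.shape)           # both (1, 4096, 3)
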