├── .gitignore ├── 3DDFA_V2_cropping ├── crop_samples │ └── img │ │ ├── dataset.json │ │ ├── man.jpg │ │ └── woman.jpg ├── cropping_guide.md ├── dlib_kps.py ├── recrop_images.py └── test │ └── original │ ├── man.jpg │ └── woman.jpg ├── LICENSES ├── LICENSE.md └── LICENSE_EG3D ├── README.md ├── calc_mbs.py ├── calc_metrics.py ├── camera_utils.py ├── dataset ├── testdata_img.zip ├── testdata_img │ ├── 000134.jpg │ ├── 000157.jpg │ └── dataset.json ├── testdata_seg.zip └── testdata_seg │ ├── 000134.png │ └── 000157.png ├── dataset_tool.py ├── dataset_tool_seg.py ├── dnnlib ├── __init__.py └── util.py ├── environment.yml ├── gen_interpolation.py ├── gen_pti_script.sh ├── gen_samples.py ├── gen_samples_forID.py ├── gen_videos.py ├── gen_videos_interp.py ├── gen_videos_proj.py ├── gen_videos_proj_withseg.py ├── get_metrics.sh ├── gui_utils ├── __init__.py ├── gl_utils.py ├── glfw_window.py ├── imgui_utils.py ├── imgui_window.py └── text_utils.py ├── legacy.py ├── metrics ├── __init__.py ├── equivariance.py ├── frechet_inception_distance.py ├── inception_score.py ├── kernel_inception_distance.py ├── metric_main.py ├── metric_utils.py ├── perceptual_path_length.py └── precision_recall.py ├── misc ├── segmentation_example.py └── teaser.png ├── projector.py ├── projector_withseg.py ├── resave_model.py ├── shape_utils.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── train.py └── training ├── __init__.py ├── augment.py ├── crosssection_utils.py ├── dataset.py ├── dual_discriminator.py ├── loss.py ├── networks_stylegan2.py ├── networks_stylegan3.py ├── superresolution.py ├── training_loop.py ├── triplane.py ├── utils.py └── volumetric_rendering ├── __init__.py ├── math_utils.py ├── ray_marcher.py ├── ray_sampler.py └── renderer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # paper results 163 | *.mp4 164 | out/ 165 | pti_out/ 166 | interpolation_out/ 167 | models/ 168 | 169 | # viz 170 | viz/ 171 | visualizer.py -------------------------------------------------------------------------------- /3DDFA_V2_cropping/crop_samples/img/dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "labels": [ 3 | [ 4 | "man.jpg", 5 | [ 6 | "0.997649", 7 | "-0.015475", 8 | "0.064940", 9 | "-0.175337", 10 | "-0.015415", 11 | "-0.992914", 12 | "-0.118841", 13 | "0.320872", 14 | "0.067236", 15 | "0.118079", 16 | "-0.990818", 17 | "2.675209", 18 | "0.000000", 19 | "0.000000", 20 | "0.000000", 21 | "1.000000", 22 | "4.264700", 23 | "0.000000", 24 | "0.500000", 25 | "0.000000", 26 | "4.264700", 27 | "0.500000", 28 | "0.000000", 29 | "0.000000", 30 | "1.000000" 31 | ] 32 | ], 33 | [ 34 | "woman.jpg", 35 | [ 36 | "0.999654", 37 | "0.015362", 38 | "-0.000774", 39 | "0.002091", 40 | "0.029684", 41 | "-0.987381", 42 | "-0.158344", 43 | "0.427528", 44 | "-0.005544", 45 | "0.158313", 46 | "-0.987495", 47 | "2.666237", 48 | "0.000000", 49 | "0.000000", 50 | "0.000000", 51 | "1.000000", 52 | "4.264700", 53 | "0.000000", 54 | "0.500000", 55 | "0.000000", 56 | "4.264700", 57 | "0.500000", 58 | "0.000000", 59 | "0.000000", 60 | "1.000000" 61 | ] 62 | ] 63 | ] 64 | } -------------------------------------------------------------------------------- /3DDFA_V2_cropping/crop_samples/img/man.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/crop_samples/img/man.jpg -------------------------------------------------------------------------------- /3DDFA_V2_cropping/crop_samples/img/woman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/crop_samples/img/woman.jpg -------------------------------------------------------------------------------- /3DDFA_V2_cropping/cropping_guide.md: -------------------------------------------------------------------------------- 1 | Guide to cropping images and obtaining camera poses using [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2). 2 | 3 | Our cropping file [recrop_images.py](./recrop_images.py) is based on 3DDFA_V2 and dlib. First of all, please clone the 3DDFA_V2 repo, follow all the installation instructions, and make sure their demos can be run successfully. After building the cython modules of 3DDFA_V2, you can use the following command to install the other necessary packages: 4 | 5 | ``` 6 | pip install opencv-python dlib pyyaml onnxruntime onnx 7 | ``` 8 | 9 | Test images used here are from [test/original/man.jpg](https://www.freepik.com/free-photo/portrait-white-man-isolated_3199590.htm) and [test/original/woman.jpg](https://www.freepik.com/free-photo/pretty-smiling-joyfully-female-with-fair-hair-dressed-casually-looking-with-satisfaction_9117255.htm) 10 | 11 | --- 12 | 13 | # Steps 14 | 15 | ## 1. Move the folder `test`, `recrop_images.py`, and `dlib_kps.py` under this directory to the 3DDFA_V2 root dir. Also, remember to download `shape_predictor_68_face_landmarks.dat` and put it under the 3DDFA_V2 root dir. 16 | 17 | ## 2. cd to 3DDFA_V2. The cropping script has to run under the 3DDFA_V2 dir. 18 | ```.bash 19 | cd 3DDFA_V2 20 | ``` 21 | 22 | ## 3. Extract face keypoints using dlib.
After this, you should have data.pkl under your 3DDFA root dir saving the keypoints. 23 | ```.bash 24 | python dlib_kps.py 25 | ``` 26 | 27 | ## 4. Obtaining camera poses and cropping the images using recrop_images.py 28 | 29 | ```.bash 30 | python recrop_images.py -i data.pkl -j dataset.json 31 | ``` 32 | 33 | ## After this, you should have a folder called crop_samples in your 3DDFA_V2 root dir, which is the same as the one under this directory. Then, you can move the folder back to our panohead root dir, change the --target_img flag to the corresponding folder in `gen_pti_script.sh`, and do the PTI inversion. -------------------------------------------------------------------------------- /3DDFA_V2_cropping/dlib_kps.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import dlib 3 | import pickle 4 | import numpy as np 5 | import os 6 | 7 | # load face detector 8 | detector = dlib.get_frontal_face_detector() 9 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 10 | 11 | # load images 12 | img_dir = "test/original" 13 | list_dir = os.listdir(img_dir) 14 | 15 | # new dict for keypoints 16 | landmarks = {} 17 | 18 | for img_name in list_dir: 19 | _, extension = os.path.splitext(img_name) 20 | 21 | # only do it for images, not .json file 22 | if extension == '.json': 23 | continue 24 | 25 | img_path = os.path.join(img_dir, img_name) 26 | image = cv2.imread(img_path) 27 | 28 | # gray scale 29 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 30 | 31 | # detect face 32 | rects = detector(gray, 1) 33 | 34 | 35 | 36 | for (i, rect) in enumerate(rects): 37 | # get keypoints 38 | shape = predictor(gray, rect) 39 | 40 | # save kps to the dict 41 | landmarks[img_path] = [np.array([p.x, p.y]) for p in shape.parts()] 42 | 43 | # save the data.pkl pickle 44 | with open('data.pkl', 'wb') as f: 45 | pickle.dump(landmarks, f) -------------------------------------------------------------------------------- /3DDFA_V2_cropping/test/original/man.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/test/original/man.jpg -------------------------------------------------------------------------------- /3DDFA_V2_cropping/test/original/woman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/3DDFA_V2_cropping/test/original/woman.jpg -------------------------------------------------------------------------------- /LICENSES/LICENSE_EG3D: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021-2022, NVIDIA Corporation & affiliates. All rights 2 | reserved. 3 | 4 | 5 | NVIDIA Source Code License for EG3D 6 | 7 | 8 | ======================================================================= 9 | 10 | 1. Definitions 11 | 12 | "Licensor" means any person or entity that distributes its Work. 13 | 14 | "Software" means the original work of authorship made available under 15 | this License. 16 | 17 | "Work" means the Software and any additions to or derivative works of 18 | the Software that are made available under this License. 19 | 20 | The terms "reproduce," "reproduction," "derivative works," and 21 | "distribution" have the meaning as provided under U.S. 
copyright law; 22 | provided, however, that for the purposes of this License, derivative 23 | works shall not include works that remain separable from, or merely 24 | link (or bind by name) to the interfaces of, the Work. 25 | 26 | Works, including the Software, are "made available" under this License 27 | by including in or with the Work either (a) a copyright notice 28 | referencing the applicability of this License to the Work, or (b) a 29 | copy of this License. 30 | 31 | 2. License Grants 32 | 33 | 2.1 Copyright Grant. Subject to the terms and conditions of this 34 | License, each Licensor grants to you a perpetual, worldwide, 35 | non-exclusive, royalty-free, copyright license to reproduce, 36 | prepare derivative works of, publicly display, publicly perform, 37 | sublicense and distribute its Work and any resulting derivative 38 | works in any form. 39 | 40 | 3. Limitations 41 | 42 | 3.1 Redistribution. You may reproduce or distribute the Work only 43 | if (a) you do so under this License, (b) you include a complete 44 | copy of this License with your distribution, and (c) you retain 45 | without modification any copyright, patent, trademark, or 46 | attribution notices that are present in the Work. 47 | 48 | 3.2 Derivative Works. You may specify that additional or different 49 | terms apply to the use, reproduction, and distribution of your 50 | derivative works of the Work ("Your Terms") only if (a) Your Terms 51 | provide that the use limitation in Section 3.3 applies to your 52 | derivative works, and (b) you identify the specific derivative 53 | works that are subject to Your Terms. Notwithstanding Your Terms, 54 | this License (including the redistribution requirements in Section 55 | 3.1) will continue to apply to the Work itself. 56 | 57 | 3.3 Use Limitation. The Work and any derivative works thereof only 58 | may be used or intended for use non-commercially. The Work or 59 | derivative works thereof may be used or intended for use by NVIDIA 60 | or it’s affiliates commercially or non-commercially. As used 61 | herein, "non-commercially" means for research or evaluation 62 | purposes only and not for any direct or indirect monetary gain. 63 | 64 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 65 | against any Licensor (including any claim, cross-claim or 66 | counterclaim in a lawsuit) to enforce any patents that you allege 67 | are infringed by any Work, then your rights under this License from 68 | such Licensor (including the grants in Sections 2.1) will terminate 69 | immediately. 70 | 71 | 3.5 Trademarks. This License does not grant any rights to use any 72 | Licensor’s or its affiliates’ names, logos, or trademarks, except 73 | as necessary to reproduce the notices described in this License. 74 | 75 | 3.6 Termination. If you violate any term of this License, then your 76 | rights under this License (including the grants in Sections 2.1) 77 | will terminate immediately. 78 | 79 | 4. Disclaimer of Warranty. 80 | 81 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 82 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 83 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 84 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 85 | THIS LICENSE. 86 | 87 | 5. Limitation of Liability. 
88 | 89 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 90 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 91 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 92 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 93 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 94 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 95 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 96 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 97 | THE POSSIBILITY OF SUCH DAMAGES. 98 | 99 | ======================================================================= 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360°
2 | 3 | 4 | 5 | 6 | 7 | 8 | ![Teaser image](./misc/teaser.png) 9 | 10 | **PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360°**
11 | Sizhe An, Hongyi Xu, Yichun Shi, Guoxian Song, Umit Y. Ogras, Linjie Luo 12 |
https://sizhean.github.io/panohead
13 | 14 | Abstract: *Synthesis and reconstruction of 3D human head has gained increasing interests in computer vision and computer graphics recently. Existing state-of-the-art 3D generative adversarial networks (GANs) for 3D human head synthesis are either limited to near-frontal views or hard to preserve 3D consistency in large view angles. We propose PanoHead, the first 3D-aware generative model that enables high-quality view-consistent image synthesis of full heads in 360° with diverse appearance and detailed geometry using only in-the-wild unstructured images for training. At its core, we lift up the representation power of recent 3D GANs and bridge the data alignment gap when training from in-the-wild images with widely distributed views. Specifically, we propose a novel two-stage self-adaptive image alignment for robust 3D GAN training. We further introduce a tri-grid neural volume representation that effectively addresses front-face and back-head feature entanglement rooted in the widely-adopted tri-plane formulation. Our method instills prior knowledge of 2D image segmentation in adversarial learning of 3D neural scene structures, enabling compositable head synthesis in diverse backgrounds. Benefiting from these designs, our method significantly outperforms previous 3D GANs, generating high-quality 3D heads with accurate geometry and diverse appearances, even with long wavy and afro hairstyles, renderable from arbitrary poses. Furthermore, we show that our system can reconstruct full 3D heads from single input images for personalized realistic 3D avatars.* 15 | 16 | 17 | ## Requirements 18 | 19 | * We recommend Linux for performance and compatibility reasons. 20 | * 1–8 high-end NVIDIA GPUs. We have done all testing and development using V100, RTX3090, and A100 GPUs. 21 | * 64-bit Python 3.8 and PyTorch 1.11.0 (or later). See https://pytorch.org for PyTorch install instructions. 22 | * CUDA toolkit 11.3 or later. (Why is a separate CUDA toolkit installation required? We use the custom CUDA extensions from the StyleGAN3 repo. Please see [Troubleshooting](https://github.com/NVlabs/stylegan3/blob/main/docs/troubleshooting.md#why-is-cuda-toolkit-installation-necessary)). 23 | * Python libraries: see [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment: 24 | - `cd PanoHead` 25 | - `conda env create -f environment.yml` 26 | - `conda activate panohead` 27 | 28 | 29 | ## Getting started 30 | 31 | Download the whole `models` folder from [link](https://drive.google.com/drive/folders/1m517-F1NCTGA159dePs5R5qj02svtX1_?usp=sharing) and put it under the root dir. 32 | 33 | Pre-trained networks are stored as `*.pkl` files that can be referenced using local filenames. 34 | 35 | 36 | ## Generating results 37 | 38 | ```.bash 39 | # Generate videos using pre-trained model 40 | 41 | python gen_videos.py --network models/easy-khair-180-gpc0.8-trans10-025000.pkl \ 42 | --seeds 0-3 --grid 2x2 --outdir=out --cfg Head --trunc 0.7 43 | 44 | ``` 45 | 46 | ```.bash 47 | # Generate images and shapes (as .mrc files) using pre-trained model 48 | 49 | python gen_samples.py --outdir=out --trunc=0.7 --shapes=true --seeds=0-3 \ 50 | --network models/easy-khair-180-gpc0.8-trans10-025000.pkl 51 | ``` 52 | 53 | ## Applications 54 | ```.bash 55 | # Generate full head reconstruction from a single RGB image. 
56 | # Please refer to ./gen_pti_script.sh 57 | # For this application we need to specify a dataset folder instead of zip files. 58 | # Segmentation files are not necessary for PTI inversion. 59 | 60 | ./gen_pti_script.sh 61 | ``` 62 | 63 | ```.bash 64 | # Generate full head interpolation from two seeds. 65 | # Please refer to ./gen_interpolation.py for the implementation. 66 | 67 | python gen_interpolation.py --network models/easy-khair-180-gpc0.8-trans10-025000.pkl\ 68 | --trunc 0.7 --outdir interpolation_out 69 | ``` 70 | 71 | 72 | 73 | ## Using networks from Python 74 | 75 | You can use pre-trained networks in your own Python code as follows: 76 | 77 | ```.python 78 | with open('*.pkl', 'rb') as f: 79 | G = pickle.load(f)['G_ema'].cuda() # torch.nn.Module 80 | z = torch.randn([1, G.z_dim]).cuda() # latent codes 81 | c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) # camera parameters 82 | img = G(z, c)['image'] # NCHW, float32, dynamic range [-1, +1], no truncation 83 | mask = G(z, c)['image_mask'] # NHW, int8, [0,255] 84 | ``` 85 | 86 | The above code requires `torch_utils` and `dnnlib` to be accessible via `PYTHONPATH`. It does not need source code for the networks themselves — their class definitions are loaded from the pickle via `torch_utils.persistence`. 87 | 88 | The pickle contains three networks. `'G'` and `'D'` are instantaneous snapshots taken during training, and `'G_ema'` represents a moving average of the generator weights over several training steps. The networks are regular instances of `torch.nn.Module`, with all of their parameters and buffers placed on the CPU at import and gradient computation disabled by default. 89 | 90 | 91 | 92 | ## Datasets 93 | 94 | FFHQ-F(ullhead) consists of [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset), [K-Hairstyle dataset](https://psh01087.github.io/K-Hairstyle/), and an in-house human head dataset. For head pose estimation, we use [WHENet](https://arxiv.org/abs/2005.10353). 95 | 96 | Due to license issues, we are not able to release the FFHQ-F dataset that we used to train the model. [test_data_img](./dataset/testdata_img/) and [test_data_seg](./dataset/testdata_seg/) are just examples showing the dataset structure. For the camera pose convention, please refer to [EG3D](https://github.com/NVlabs/eg3d). 97 | 98 | 99 | ## Datasets format 100 | For training purposes, we can use either zip files or normal folders for the image dataset and the segmentation dataset. For PTI, we need to use folders. 101 | 102 | To compress a dataset folder to a zip file, we can use [dataset_tool_seg](./dataset_tool_seg.py). 103 | 104 | For example: 105 | ```.bash 106 | python dataset_tool_seg.py --img_source dataset/testdata_img --seg_source dataset/testdata_seg --img_dest dataset/testdata_img.zip --seg_dest dataset/testdata_seg.zip --resolution 512x512 107 | ``` 108 | 109 | ## Obtaining camera pose and cropping the images 110 | Please follow the [guide](3DDFA_V2_cropping/cropping_guide.md). 111 | 112 | ## Obtaining segmentation masks 113 | You can try using deeplabv3 or other off-the-shelf tools to generate the masks. For example, using deeplabv3: [misc/segmentation_example.py](misc/segmentation_example.py) 114 | 115 | 116 | 117 | 118 | ## Training 119 | 120 | Examples of training using `train.py`: 121 | 122 | ``` 123 | # Train with StyleGAN2 backbone from scratch with raw neural rendering resolution=64, using 8 GPUs.
124 | # with segmentation mask, trigrid_depth@3, self-adaptive camera pose loss regularizer@10 125 | 126 | python train.py --outdir training-runs --img_data dataset/testdata_img.zip --seg_data dataset/testdata_seg.zip --cfg=ffhq --batch=32 --gpus 8\\ 127 | --gamma=1 --gamma_seg=1 --gen_pose_cond=True --mirror=1 --use_torgb_raw=1 --decoder_activation="none" --disc_module MaskDualDiscriminatorV2\\ 128 | --bcg_reg_prob 0.2 --triplane_depth 3 --density_noise_fade_kimg 200 --density_reg 0 --min_yaw 0 --max_yaw 180 --back_repeat 4 --trans_reg 10 --gpc_reg_prob 0.7 129 | 130 | 131 | # Second stage finetuning to 128 neural rendering resolution (optional). 132 | 133 | python train.py --outdir results --img_data dataset/testdata_img.zip --seg_data dataset/testdata_seg.zip --cfg=ffhq --batch=32 --gpus 8\\ 134 | --resume=~/training-runs/experiment_dir/network-snapshot-025000.pkl\\ 135 | --gamma=1 --gamma_seg=1 --gen_pose_cond=True --mirror=1 --use_torgb_raw=1 --decoder_activation="none" --disc_module MaskDualDiscriminatorV2\\ 136 | --bcg_reg_prob 0.2 --triplane_depth 3 --density_noise_fade_kimg 200 --density_reg 0 --min_yaw 0 --max_yaw 180 --back_repeat 4 --trans_reg 10 --gpc_reg_prob 0.7\\ 137 | --neural_rendering_resolution_final=128 --resume_kimg 1000 138 | ``` 139 | 140 | ## Metrics 141 | 142 | 143 | 144 | ```.bash 145 | ./get_metrics.sh 146 | ``` 147 | There are three evaluation modes: all, front, and back as we mentioned in the paper. Please refer to [cal_metrics.py](./calc_metrics.py) for the implementation. 148 | 149 | 150 | ## Citation 151 | 152 | If you find our repo helpful, please cite our paper using the following bib: 153 | 154 | ``` 155 | @InProceedings{An_2023_CVPR, 156 | author = {An, Sizhe and Xu, Hongyi and Shi, Yichun and Song, Guoxian and Ogras, Umit Y. and Luo, Linjie}, 157 | title = {PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360deg}, 158 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 159 | month = {June}, 160 | year = {2023}, 161 | pages = {20950-20959} 162 | } 163 | ``` 164 | 165 | ## Development 166 | 167 | This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests. 168 | 169 | ## Acknowledgements 170 | 171 | We thank Shuhong Chen for the discussion during Sizhe's internship. 172 | 173 | This repo is heavily based off the [NVlabs/eg3d](https://github.com/NVlabs/eg3d) repo; Huge thanks to the EG3D authors for releasing their code! -------------------------------------------------------------------------------- /calc_mbs.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
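# Overview of the metric computed below (descriptive comment derived from generate_images): for each
# batch the foreground latent ws0 is held fixed while the background latent passed as ws_bcg is
# switched from ws0 to a second latent ws1; both renders are segmented with DeepLabV3 (COCO class 15,
# person), and the mean MSE between the two person masks is reported as 'mse_final', a
# background-switching consistency score.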
10 | 11 | """Generate images and shapes using pretrained network pickle.""" 12 | 13 | import os 14 | import re 15 | from typing import List, Optional, Tuple, Union 16 | 17 | import click 18 | import dnnlib 19 | import numpy as np 20 | import PIL.Image 21 | import torch 22 | from tqdm import tqdm 23 | 24 | 25 | import legacy 26 | from camera_utils import LookAtPoseSampler, FOV_to_intrinsics 27 | from torch_utils import misc 28 | from training.triplane import TriPlaneGenerator 29 | from torchvision.models.segmentation import deeplabv3_resnet101 30 | from torchvision import transforms, utils 31 | from torch.nn import functional as F 32 | 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def parse_range(s: Union[str, List]) -> List[int]: 37 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 38 | 39 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 40 | ''' 41 | if isinstance(s, list): return s 42 | ranges = [] 43 | range_re = re.compile(r'^(\d+)-(\d+)$') 44 | for p in s.split(','): 45 | m = range_re.match(p) 46 | if m: 47 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 48 | else: 49 | ranges.append(int(p)) 50 | return ranges 51 | 52 | 53 | #---------------------------------------------------------------------------- 54 | def get_mask(model, batch, cid): 55 | normalized_batch = transforms.functional.normalize( 56 | batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 57 | output = model(normalized_batch)['out'] 58 | # sem_classes = [ 59 | # '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 60 | # 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 61 | # 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' 62 | # ] 63 | # sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)} 64 | # cid = sem_class_to_idx['car'] 65 | 66 | normalized_masks = torch.nn.functional.softmax(output, dim=1) 67 | 68 | boolean_car_masks = (normalized_masks.argmax(1) == cid) 69 | return boolean_car_masks.float() 70 | 71 | def norm_ip(img, low, high): 72 | img_ = img.clamp(min=low, max=high) 73 | img_.sub_(low).div_(max(high - low, 1e-5)) 74 | return img_ 75 | 76 | 77 | def norm_range(t, value_range=(-1, 1)): 78 | if value_range is not None: 79 | return norm_ip(t, value_range[0], value_range[1]) 80 | else: 81 | return norm_ip(t, float(t.min()), float(t.max())) 82 | 83 | #---------------------------------------------------------------------------- 84 | @click.command() 85 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 86 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.7, show_default=True) 87 | @click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True) 88 | @click.option('--fov-deg', help='Field of View of camera in degrees', type=int, required=False, metavar='float', default=18.837, show_default=True) 89 | @click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True) 90 | @click.option('--pose_cond', type=int, help='pose_cond angle', default=90, show_default=True) 91 | 92 | def generate_images( 93 | network_pkl: str, 94 | truncation_psi: float, 95 | truncation_cutoff: int, 96 | fov_deg: float, 97 | reload_modules: bool, 98 | pose_cond: int, 99 | ): 100 | """Generate images using pretrained network pickle. 
101 | 102 | Examples: 103 | 104 | \b 105 | # Generate an image using pre-trained FFHQ model. 106 | python gen_samples.py --outdir=output --trunc=0.7 --seeds=0-5 --shapes=True\\ 107 | --network=ffhq-rebalanced-128.pkl 108 | """ 109 | 110 | print('Loading networks from "%s"...' % network_pkl) 111 | device = torch.device('cuda:1') 112 | with dnnlib.util.open_url(network_pkl) as f: 113 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 114 | 115 | # Specify reload_modules=True if you want code modifications to take effect; otherwise uses pickled code 116 | if reload_modules: 117 | print("Reloading Modules!") 118 | G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device) 119 | misc.copy_params_and_buffers(G, G_new, require_all=True) 120 | G_new.neural_rendering_resolution = G.neural_rendering_resolution 121 | G_new.rendering_kwargs = G.rendering_kwargs 122 | G = G_new 123 | 124 | 125 | pose_cond_rad = pose_cond/180*np.pi 126 | intrinsics = FOV_to_intrinsics(fov_deg, device=device) 127 | 128 | 129 | # load segmentation net 130 | seg_net = deeplabv3_resnet101(pretrained=True, progress=False).to(device) 131 | seg_net.requires_grad_(False) 132 | seg_net.eval() 133 | 134 | mse_total = 0 135 | 136 | n_sample = 64 137 | batch = 32 138 | 139 | n_sample = n_sample // batch * batch 140 | batch_li = n_sample // batch * [batch] 141 | pose_cond_rad = pose_cond/180*np.pi 142 | 143 | intrinsics = FOV_to_intrinsics(fov_deg, device=device) 144 | 145 | # Generate images. 146 | cam_pivot = torch.tensor([0, 0, 0], device=device) 147 | cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7) 148 | conditioning_cam2world_pose = LookAtPoseSampler.sample(pose_cond_rad, np.pi/2, cam_pivot, radius=cam_radius, device=device) 149 | conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) 150 | 151 | c = conditioning_params.repeat(batch, 1) 152 | 153 | for batch in tqdm(batch_li): 154 | # z and w 155 | z0 = torch.from_numpy(np.random.randn(batch, G.z_dim)).to(device) 156 | z1 = torch.from_numpy(np.random.randn(batch, G.z_dim)).to(device) 157 | 158 | 159 | ws0 = G.mapping(z0, c, truncation_psi=0.7, truncation_cutoff=None) 160 | ws1 = G.mapping(z1, c, truncation_psi=0.7, truncation_cutoff=None) 161 | 162 | c0 = c.clone() 163 | c1 = c.clone() 164 | 165 | img0 = G.synthesis(ws0, c0, ws_bcg = ws0.clone())['image'] 166 | img0 = norm_range(img0) 167 | img1 = G.synthesis(ws0, c1, ws_bcg = ws1.clone())['image'] 168 | img1 = norm_range(img1) 169 | # 15 means human mask 170 | mask0 = get_mask(seg_net, img0, 15).unsqueeze(1) 171 | mask1 = get_mask(seg_net, img1, 15).unsqueeze(1) 172 | 173 | diff = torch.abs(mask0-mask1) 174 | mse = F.mse_loss(mask0, mask1) 175 | 176 | # mutual_bg_mask = (1-mask0) * (1-mask1) 177 | 178 | # diff = F.l1_loss(mutual_bg_mask*img1, mutual_bg_mask*img0, reduction='none') 179 | # diff = torch.where(diff < 1/255, torch.zeros_like(diff), torch.ones_like(diff)) 180 | # diff = torch.sum(diff, dim=1) 181 | # diff = torch.where(diff < 1, torch.zeros_like(diff), torch.ones_like(diff)) 182 | utils.save_image( 183 | # (1-mask1)*img1, 184 | mask1, 185 | f'alphamse/changebg/mask1.png', 186 | nrow=8, 187 | normalize=True, 188 | range=(0, 1), 189 | padding=0, 190 | ) 191 | utils.save_image( 192 | mask0, 193 | f'alphamse/changebg/mask0.png', 194 | nrow=8, 195 | normalize=True, 196 | range=(0, 1), 197 | padding=0, 198 | ) 199 | utils.save_image( 200 | img0, 201 | f'alphamse/changebg/img0.png', 202 | nrow=8, 203 
| normalize=True, 204 | range=(0, 1), 205 | padding=0, 206 | ) 207 | utils.save_image( 208 | img1, 209 | f'alphamse/changebg/img1.png', 210 | nrow=8, 211 | normalize=True, 212 | range=(0, 1), 213 | padding=0, 214 | ) 215 | utils.save_image( 216 | diff, 217 | f'alphamse/changebg/diff.png', 218 | nrow=8, 219 | normalize=True, 220 | range=(0, 1), 221 | padding=0, 222 | ) 223 | # sys.exit() 224 | 225 | # change_fg_score += torch.sum(torch.sum(diff, dim=(1,2)) / (torch.sum(mutual_bg_mask, dim=(1,2,3))+1e-8)) 226 | mse_total += mse.cpu().detach().numpy() 227 | 228 | print(f'mse_final: {mse_total/len(batch_li)}') 229 | 230 | 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | if __name__ == "__main__": 235 | generate_images() # pylint: disable=no-value-for-parameter 236 | 237 | #---------------------------------------------------------------------------- 238 | 239 | -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 12 | 13 | import os 14 | import click 15 | import json 16 | import tempfile 17 | import copy 18 | import torch 19 | 20 | import dnnlib 21 | import legacy 22 | from metrics import metric_main 23 | from metrics import metric_utils 24 | from torch_utils import training_stats 25 | from torch_utils import custom_ops 26 | from torch_utils import misc 27 | from torch_utils.ops import conv2d_gradfix 28 | 29 | #---------------------------------------------------------------------------- 30 | 31 | def subprocess_fn(rank, args, temp_dir): 32 | dnnlib.util.Logger(should_flush=True) 33 | 34 | # Init torch.distributed. 35 | if args.num_gpus > 1: 36 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 37 | if os.name == 'nt': 38 | init_method = 'file:///' + init_file.replace('\\', '/') 39 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 40 | else: 41 | init_method = f'file://{init_file}' 42 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 43 | 44 | # Init torch_utils. 45 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 46 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 47 | if rank != 0 or not args.verbose: 48 | custom_ops.verbosity = 'none' 49 | 50 | # Configure torch. 51 | device = torch.device('cuda', rank) 52 | torch.backends.cuda.matmul.allow_tf32 = False 53 | torch.backends.cudnn.allow_tf32 = False 54 | conv2d_gradfix.enabled = True 55 | 56 | # Print network summary. 
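# Each rank evaluates its own frozen copy of the generator: metrics only need forward passes,
# so the copy is put in eval mode with gradient tracking disabled before it is moved to the local device.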
57 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 58 | if rank == 0 and args.verbose: 59 | z = torch.empty([1, G.z_dim], device=device) 60 | c = torch.empty([1, G.c_dim], device=device) 61 | misc.print_module_summary(G, [z, c]) 62 | 63 | # Calculate each metric. 64 | for metric in args.metrics: 65 | if rank == 0 and args.verbose: 66 | print(f'Calculating {metric}...') 67 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 68 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, 69 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress, mode=args.mode) 70 | if rank == 0: 71 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 72 | if rank == 0 and args.verbose: 73 | print() 74 | 75 | # Done. 76 | if rank == 0 and args.verbose: 77 | print('Exiting...') 78 | 79 | #---------------------------------------------------------------------------- 80 | 81 | def parse_comma_separated_list(s): 82 | if isinstance(s, list): 83 | return s 84 | if s is None or s.lower() == 'none' or s == '': 85 | return [] 86 | return s.split(',') 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @click.command() 91 | @click.pass_context 92 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) 93 | @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) 94 | @click.option('--img_data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]') 95 | @click.option('--seg_data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]') 96 | @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL') 97 | @click.option('--mode', help='Evaluation mode [back, front, all]', metavar='STR', type=click.Choice(['back','front','all']), default='all', required=False, show_default=False) 98 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) 99 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) 100 | 101 | def calc_metrics(ctx, network_pkl, metrics, img_data, seg_data, mirror, mode, gpus, verbose): 102 | """Calculate quality metrics for previous training run or pretrained network pickle. 103 | 104 | Examples: 105 | 106 | \b 107 | # Previous training run: look up options automatically, save result to JSONL file. 108 | python calc_metrics.py --metrics=eqt50k_int,eqr50k \\ 109 | --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl 110 | 111 | \b 112 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout. 113 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\ 114 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl 115 | 116 | \b 117 | Recommended metrics: 118 | fid50k_full Frechet inception distance against the full dataset. 119 | kid50k_full Kernel inception distance against the full dataset. 120 | pr50k3_full Precision and recall againt the full dataset. 121 | ppl2_wend Perceptual path length in W, endpoints, full image. 122 | eqt50k_int Equivariance w.r.t. integer translation (EQ-T). 123 | eqt50k_frac Equivariance w.r.t. 
fractional translation (EQ-T_frac). 124 | eqr50k Equivariance w.r.t. rotation (EQ-R). 125 | 126 | \b 127 | Legacy metrics: 128 | fid50k Frechet inception distance against 50k real images. 129 | kid50k Kernel inception distance against 50k real images. 130 | pr50k3 Precision and recall against 50k real images. 131 | is50k Inception score for CIFAR-10. 132 | """ 133 | dnnlib.util.Logger(should_flush=True) 134 | 135 | # decide evaluation mode 136 | if mode == 'all': 137 | min_yaw, max_yaw = 0, 180 138 | elif mode == 'front': 139 | min_yaw, max_yaw = 0, 90 140 | elif mode == 'back': 141 | min_yaw, max_yaw = 90, 180 142 | 143 | # Validate arguments. 144 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose, mode=mode) 145 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 146 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 147 | if not args.num_gpus >= 1: 148 | ctx.fail('--gpus must be at least 1') 149 | 150 | # Load network. 151 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 152 | ctx.fail('--network must point to a file or URL') 153 | if args.verbose: 154 | print(f'Loading network from "{network_pkl}"...') 155 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 156 | network_dict = legacy.load_network_pkl(f) 157 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 158 | 159 | # Initialize dataset options. 160 | if img_data is not None: 161 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.MaskLabeledDataset', img_path=img_data, seg_path=seg_data, min_yaw = min_yaw, max_yaw = max_yaw, back_repeat=1) 162 | elif network_dict['training_set_kwargs'] is not None: 163 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 164 | else: 165 | ctx.fail('Could not look up dataset options; please specify --data') 166 | 167 | # Finalize dataset options. 168 | args.dataset_kwargs.resolution = args.G.img_resolution 169 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 170 | if mirror is not None: 171 | args.dataset_kwargs.xflip = mirror 172 | 173 | # Print dataset options. 174 | if args.verbose: 175 | print('Dataset options:') 176 | print(json.dumps(args.dataset_kwargs, indent=2)) 177 | 178 | # Locate run dir. 179 | args.run_dir = None 180 | if os.path.isfile(network_pkl): 181 | pkl_dir = os.path.dirname(network_pkl) 182 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 183 | args.run_dir = pkl_dir 184 | 185 | # Launch processes. 186 | if args.verbose: 187 | print('Launching processes...') 188 | torch.multiprocessing.set_start_method('spawn') 189 | with tempfile.TemporaryDirectory() as temp_dir: 190 | if args.num_gpus == 1: 191 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 192 | else: 193 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 194 | 195 | #---------------------------------------------------------------------------- 196 | 197 | if __name__ == "__main__": 198 | calc_metrics() # pylint: disable=no-value-for-parameter 199 | 200 | #---------------------------------------------------------------------------- 201 | -------------------------------------------------------------------------------- /camera_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts. 13 | """ 14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from training.volumetric_rendering import math_utils 21 | 22 | class GaussianCameraPoseSampler: 23 | """ 24 | Samples pitch and yaw from a Gaussian distribution and returns a camera pose. 25 | Camera is specified as looking at the origin. 26 | If horizontal and vertical stddev (specified in radians) are zero, gives a 27 | deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean. 28 | The coordinate system is specified with y-up, z-forward, x-left. 29 | Horizontal mean is the azimuthal angle (rotation around y axis) in radians, 30 | vertical mean is the polar angle (angle from the y axis) in radians. 31 | A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2. 32 | 33 | Example: 34 | For a camera pose looking at the origin with the camera at position [0, 0, 1]: 35 | cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1) 36 | """ 37 | 38 | @staticmethod 39 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'): 40 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean 41 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean 42 | v = torch.clamp(v, 1e-5, math.pi - 1e-5) 43 | 44 | theta = h 45 | v = v / math.pi 46 | phi = torch.arccos(1 - 2*v) 47 | 48 | camera_origins = torch.zeros((batch_size, 3), device=device) 49 | 50 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta) 51 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta) 52 | camera_origins[:, 1:2] = radius*torch.cos(phi) 53 | 54 | forward_vectors = math_utils.normalize_vecs(-camera_origins) 55 | return create_cam2world_matrix(forward_vectors, camera_origins) 56 | 57 | 58 | class LookAtPoseSampler: 59 | """ 60 | Same as GaussianCameraPoseSampler, except the 61 | camera is specified as looking at 'lookat_position', a 3-vector. 
62 | 63 | Example: 64 | For a camera pose looking at the origin with the camera at position [0, 0, 1]: 65 | cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1) 66 | """ 67 | 68 | @staticmethod 69 | def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'): 70 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean 71 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean 72 | v = torch.clamp(v, 1e-5, math.pi - 1e-5) 73 | 74 | theta = h 75 | v = v / math.pi 76 | phi = torch.arccos(1 - 2*v) 77 | 78 | camera_origins = torch.zeros((batch_size, 3), device=device) 79 | 80 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta) 81 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta) 82 | camera_origins[:, 1:2] = radius*torch.cos(phi) 83 | 84 | # forward_vectors = math_utils.normalize_vecs(-camera_origins) 85 | forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins) 86 | return create_cam2world_matrix(forward_vectors, camera_origins) 87 | 88 | class UniformCameraPoseSampler: 89 | """ 90 | Same as GaussianCameraPoseSampler, except the 91 | pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev. 92 | 93 | Example: 94 | For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians: 95 | 96 | cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16) 97 | """ 98 | 99 | @staticmethod 100 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'): 101 | h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean 102 | v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean 103 | v = torch.clamp(v, 1e-5, math.pi - 1e-5) 104 | 105 | theta = h 106 | v = v / math.pi 107 | phi = torch.arccos(1 - 2*v) 108 | 109 | camera_origins = torch.zeros((batch_size, 3), device=device) 110 | 111 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta) 112 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta) 113 | camera_origins[:, 1:2] = radius*torch.cos(phi) 114 | 115 | forward_vectors = math_utils.normalize_vecs(-camera_origins) 116 | return create_cam2world_matrix(forward_vectors, camera_origins) 117 | 118 | def create_cam2world_matrix(forward_vector, origin): 119 | """ 120 | Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix. 121 | Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll. 
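Example (illustrative): for a single camera placed at [0, 0, 1] and looking back at the origin:
origins = torch.tensor([[0., 0., 1.]])
forward = math_utils.normalize_vecs(-origins)
cam2world = create_cam2world_matrix(forward, origins)  # tensor of shape [1, 4, 4]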
122 | """ 123 | 124 | forward_vector = math_utils.normalize_vecs(forward_vector) 125 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector) 126 | 127 | right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1)) 128 | up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1)) 129 | 130 | rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 131 | rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1) 132 | 133 | translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 134 | translation_matrix[:, :3, 3] = origin 135 | cam2world = (translation_matrix @ rotation_matrix)[:, :, :] 136 | assert(cam2world.shape[1:] == (4, 4)) 137 | return cam2world 138 | 139 | 140 | def FOV_to_intrinsics(fov_degrees, device='cpu'): 141 | """ 142 | Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. 143 | Note the intrinsics are returned as normalized by image size, rather than in pixel units. 144 | Assumes principal point is at image center. 145 | """ 146 | 147 | focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414)) 148 | intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) 149 | return intrinsics -------------------------------------------------------------------------------- /dataset/testdata_img.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_img.zip -------------------------------------------------------------------------------- /dataset/testdata_img/000134.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_img/000134.jpg -------------------------------------------------------------------------------- /dataset/testdata_img/000157.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_img/000157.jpg -------------------------------------------------------------------------------- /dataset/testdata_img/dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "labels": [ 3 | [ 4 | "000134.jpg", 5 | [ 6 | "0.980169", 7 | "-0.004075", 8 | "-0.198125", 9 | "0.534939", 10 | "-0.003572", 11 | "-0.999991", 12 | "0.001775", 13 | "-0.004793", 14 | "-0.198131", 15 | "-0.000988", 16 | "-0.980175", 17 | "2.646472", 18 | "0.000000", 19 | "0.000000", 20 | "0.000000", 21 | "1.000000", 22 | "4.264700", 23 | "0.000000", 24 | "0.500000", 25 | "0.000000", 26 | "4.264700", 27 | "0.500000", 28 | "0.000000", 29 | "0.000000", 30 | "1.000000" 31 | ] 32 | ], 33 | [ 34 | "000157.jpg", 35 | [ 36 | "0.975032", 37 | "-0.009496", 38 | "0.221723", 39 | "-0.598652", 40 | "-0.008519", 41 | "-0.999792", 42 | "-0.020003", 43 | "0.054008", 44 | "0.221925", 45 | "0.018354", 46 | "-0.974910", 47 | "2.632258", 48 | "0.000000", 49 | "0.000000", 50 | "0.000000", 51 | "1.000000", 52 | "4.264700", 53 | "0.000000", 54 | "0.500000", 55 | "0.000000", 56 | "4.264700", 57 | "0.500000", 58 | "0.000000", 59 | "0.000000", 60 | "1.000000" 61 | 
] 62 | ] 63 | ] 64 | } -------------------------------------------------------------------------------- /dataset/testdata_seg.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_seg.zip -------------------------------------------------------------------------------- /dataset/testdata_seg/000134.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_seg/000134.png -------------------------------------------------------------------------------- /dataset/testdata_seg/000157.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/dataset/testdata_seg/000157.png -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | from .util import EasyDict, make_cache_dir_path 12 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | name: panohead 12 | channels: 13 | - pytorch 14 | - nvidia 15 | dependencies: 16 | - python >= 3.8 17 | - pip 18 | - numpy>=1.20 19 | - click>=8.0 20 | - pillow=8.3.1 21 | - scipy=1.7.1 22 | - pytorch=1.11.0 23 | - cudatoolkit=11.1 24 | - requests=2.26.0 25 | - tqdm=4.62.2 26 | - ninja=1.10.2 27 | - matplotlib=3.4.2 28 | - imageio=2.9.0 29 | - pip: 30 | - imgui==1.3.0 31 | - glfw==2.2.0 32 | - pyopengl==3.1.5 33 | - imageio-ffmpeg==0.4.3 34 | - pyspng 35 | - psutil 36 | - mrcfile 37 | - tensorboard 38 | - torchvision==0.12.0 -------------------------------------------------------------------------------- /gen_interpolation.py: -------------------------------------------------------------------------------- 1 | ''' Generate images and shapes using pretrained network pickle. 2 | Code adapted from following paper 3 | "Efficient Geometry-aware 3D Generative Adversarial Networks." 4 | See LICENSES/LICENSE_EG3D for original license. 
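In outline: two fixed seeds are mapped to W+ codes ws0 and ws1, and for front, side, and back
conditioning poses the script renders both seeds plus a row of style-mixed images in which ws1
replaces ws0 from a chosen layer index (interpolation_idx) onward.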
5 | ''' 6 | 7 | import os 8 | import re 9 | from typing import List, Optional, Tuple, Union 10 | 11 | import click 12 | import dnnlib 13 | import numpy as np 14 | import PIL.Image 15 | import torch 16 | from tqdm import tqdm 17 | 18 | 19 | import legacy 20 | from camera_utils import LookAtPoseSampler, FOV_to_intrinsics 21 | from torch_utils import misc 22 | from training.triplane import TriPlaneGenerator 23 | from torchvision.models.segmentation import deeplabv3_resnet101 24 | from torchvision import transforms, utils 25 | from torch.nn import functional as F 26 | 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | def parse_range(s: Union[str, List]) -> List[int]: 31 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 32 | 33 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 34 | ''' 35 | if isinstance(s, list): return s 36 | ranges = [] 37 | range_re = re.compile(r'^(\d+)-(\d+)$') 38 | for p in s.split(','): 39 | m = range_re.match(p) 40 | if m: 41 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 42 | else: 43 | ranges.append(int(p)) 44 | return ranges 45 | 46 | 47 | #---------------------------------------------------------------------------- 48 | def get_mask(model, batch, cid): 49 | normalized_batch = transforms.functional.normalize( 50 | batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 51 | output = model(normalized_batch)['out'] 52 | # sem_classes = [ 53 | # '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 54 | # 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 55 | # 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' 56 | # ] 57 | # sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)} 58 | # cid = sem_class_to_idx['car'] 59 | 60 | normalized_masks = torch.nn.functional.softmax(output, dim=1) 61 | 62 | boolean_car_masks = (normalized_masks.argmax(1) == cid) 63 | return boolean_car_masks.float() 64 | 65 | def norm_ip(img, low, high): 66 | img_ = img.clamp(min=low, max=high) 67 | img_.sub_(low).div_(max(high - low, 1e-5)) 68 | return img_ 69 | 70 | 71 | def norm_range(t, value_range=(-1, 1)): 72 | if value_range is not None: 73 | return norm_ip(t, value_range[0], value_range[1]) 74 | else: 75 | return norm_ip(t, float(t.min()), float(t.max())) 76 | 77 | #---------------------------------------------------------------------------- 78 | @click.command() 79 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 80 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.7, show_default=True) 81 | @click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True) 82 | @click.option('--fov-deg', help='Field of View of camera in degrees', type=int, required=False, metavar='float', default=18.837, show_default=True) 83 | @click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True) 84 | @click.option('--pose_cond', type=int, help='pose_cond angle', default=90, show_default=True) 85 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 86 | 87 | def generate_images( 88 | network_pkl: str, 89 | outdir: str, 90 | truncation_psi: float, 91 | truncation_cutoff: int, 92 | fov_deg: float, 93 | reload_modules: bool, 94 | pose_cond: int, 95 | ): 96 | """Generate 
interpolation images using pretrained network pickle. 97 | 98 | Examples: 99 | 100 | \b 101 | python gen_interpolation.py --network models/easy-khair-180-gpc0.8-trans10-025000.pkl\ 102 | --trunc 0.7 --outdir interpolation_out 103 | """ 104 | 105 | print('Loading networks from "%s"...' % network_pkl) 106 | device = torch.device('cuda:1') 107 | with dnnlib.util.open_url(network_pkl) as f: 108 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 109 | 110 | # Specify reload_modules=True if you want code modifications to take effect; otherwise uses pickled code 111 | if reload_modules: 112 | print("Reloading Modules!") 113 | G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device) 114 | misc.copy_params_and_buffers(G, G_new, require_all=True) 115 | G_new.neural_rendering_resolution = G.neural_rendering_resolution 116 | G_new.rendering_kwargs = G.rendering_kwargs 117 | G = G_new 118 | 119 | network_pkl = os.path.basename(network_pkl) 120 | outdir = os.path.join(outdir, os.path.splitext(network_pkl)[0] + '_' + str(pose_cond)) 121 | os.makedirs(outdir, exist_ok=True) 122 | 123 | pose_cond_rad = pose_cond/180*np.pi 124 | intrinsics = FOV_to_intrinsics(fov_deg, device=device) 125 | 126 | 127 | 128 | # n_sample = 64 129 | # batch = 32 130 | 131 | # n_sample = n_sample // batch * batch 132 | # batch_li = n_sample // batch * [batch] 133 | pose_cond_rad = pose_cond/180*np.pi 134 | 135 | intrinsics = FOV_to_intrinsics(fov_deg, device=device) 136 | 137 | # Generate images. 138 | cam_pivot = torch.tensor([0, 0, 0], device=device) 139 | cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7) 140 | conditioning_cam2world_pose = LookAtPoseSampler.sample(pose_cond_rad, np.pi/2, cam_pivot, radius=cam_radius, device=device) 141 | conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) 142 | 143 | conditioning_cam2world_pose_back = LookAtPoseSampler.sample(-pose_cond_rad, np.pi/2, cam_pivot, radius=cam_radius, device=device) 144 | conditioning_params_back = torch.cat([conditioning_cam2world_pose_back.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) 145 | 146 | 147 | conditioning_cam2world_pose_side = LookAtPoseSampler.sample(45/180*np.pi, np.pi/2, cam_pivot, radius=cam_radius, device=device) 148 | conditioning_params_side = torch.cat([conditioning_cam2world_pose_side.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) 149 | 150 | # set two random seeds for interpolation 151 | seed1, seed2 = 521, 329 152 | z0 = torch.from_numpy(np.random.RandomState(seed1).randn(G.z_dim).reshape(1,G.z_dim)).to(device) 153 | z1 = torch.from_numpy(np.random.RandomState(seed2).randn(G.z_dim).reshape(1,G.z_dim)).to(device) 154 | c = conditioning_params 155 | 156 | ws0 = G.mapping(z0, c, truncation_psi=0.7, truncation_cutoff=None) 157 | ws1 = G.mapping(z1, c, truncation_psi=0.7, truncation_cutoff=None) 158 | 159 | image_final = [] 160 | for c in [conditioning_params, conditioning_params_side, conditioning_params_back]: 161 | img0 = G.synthesis(ws0, c)['image'] 162 | img0 = norm_range(img0) 163 | img1 = G.synthesis(ws1, c)['image'] 164 | img1 = norm_range(img1) 165 | 166 | 167 | 168 | 169 | 170 | img_list = [] 171 | for interpolation_idx in [0,2,3,4,6,8]: 172 | # for interpolation_idx in range(0,14,1): 173 | # interpolation_idx = 8 174 | ws_new = ws0.clone() 175 | ws_new[:, interpolation_idx:, :] = ws1[:, interpolation_idx:, :] 176 | img_new = G.synthesis(ws_new, c)['image'] 177 | img_new = norm_range(img_new) 178 | 
img_list.append(img_new) 179 | 180 | img_list.append(img0) 181 | 182 | img_new = torch.cat(img_list, dim=0) 183 | image_final.append(img_new) 184 | 185 | image_final = torch.cat(image_final, dim=2) 186 | 187 | utils.save_image( 188 | image_final, 189 | os.path.join(outdir, f'img_interpolation_seed{seed1}_{seed2}.png'), 190 | # nrow=8, 191 | normalize=True, 192 | range=(0, 1), 193 | padding=0, 194 | ) 195 | # utils.save_image( 196 | # diff, 197 | # f'alphamse/changebg/diff.png', 198 | # nrow=8, 199 | # normalize=True, 200 | # range=(0, 1), 201 | # padding=0, 202 | # ) 203 | # sys.exit() 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | #---------------------------------------------------------------------------- 213 | 214 | if __name__ == "__main__": 215 | generate_images() # pylint: disable=no-value-for-parameter 216 | 217 | #---------------------------------------------------------------------------- 218 | 219 | -------------------------------------------------------------------------------- /gen_pti_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | models=("easy-khair-180-gpc0.8-trans10-025000.pkl"\ 4 | "ablation-trigridD-1-025000.pkl") 5 | 6 | in="models" 7 | out="pti_out" 8 | 9 | for model in ${models[@]} 10 | 11 | do 12 | 13 | for i in 0 1 14 | 15 | do 16 | # perform the pti and save w 17 | python projector_withseg.py --outdir=${out} --target_img=dataset/testdata_img --network ${in}/${model} --idx ${i} 18 | # generate .mp4 before finetune 19 | python gen_videos_proj_withseg.py --output=${out}/${model}/${i}/PTI_render/pre.mp4 --latent=${out}/${model}/${i}/projected_w.npz --trunc 0.7 --network ${in}/${model} --cfg Head 20 | # generate .mp4 after finetune 21 | python gen_videos_proj_withseg.py --output=${out}/${model}/${i}/PTI_render/post.mp4 --latent=${out}/${model}/${i}/projected_w.npz --trunc 0.7 --network ${out}/${model}/${i}/fintuned_generator.pkl --cfg Head 22 | 23 | 24 | done 25 | 26 | done 27 | -------------------------------------------------------------------------------- /get_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | models=("easy-khair-180-gpc0.8-trans10-025000.pkl"\ 4 | "ablation-trigridD-1-025000.pkl"\ 5 | ) 6 | 7 | 8 | for model in ${models[@]} 9 | 10 | do 11 | 12 | 13 | python calc_metrics.py --network models/${model} \ 14 | --img_data=dataset/ffhq-3DDFA-exp-augx2-lpx2-easyx2-khairfiltered-img.zip\ 15 | --seg_data=dataset/ffhq-3DDFA-exp-augx2-lpx2-easyx2-khairfiltered-seg.zip\ 16 | --gpus 8 --mirror True --metrics fid50k_full,is50k | tee -a paper_metrics/${model}.log 17 | 18 | 19 | done 20 | -------------------------------------------------------------------------------- /gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
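gen_interpolation.py above conditions the generator on a 25-value camera label: a flattened 4x4 cam2world matrix followed by a flattened 3x3 normalized intrinsics matrix, which is also the per-image layout stored in dataset/testdata_img/dataset.json. A minimal sketch of assembling one such label with the camera_utils helpers shown earlier (illustrative values; assumes the repository root is on PYTHONPATH):

import numpy as np
import torch

from camera_utils import LookAtPoseSampler, FOV_to_intrinsics

device = torch.device('cpu')
cam_pivot = torch.tensor([0.0, 0.0, 0.0], device=device)
# Frontal pose (yaw = pitch = pi/2 radians) at the fallback camera radius of 2.7 used above.
cam2world = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2, cam_pivot, radius=2.7, device=device)  # [1, 4, 4]
intrinsics = FOV_to_intrinsics(18.837, device=device)                                             # [3, 3]
# 16 extrinsic values followed by 9 intrinsic values -> the 25 floats stored per image.
label = torch.cat([cam2world.reshape(-1, 16), intrinsics.reshape(-1, 9)], dim=1)                  # [1, 25]
print(label.shape)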
10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /gui_utils/glfw_window.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import time 12 | import glfw 13 | import OpenGL.GL as gl 14 | from . import gl_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | class GlfwWindow: # pylint: disable=too-many-public-methods 19 | def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True): 20 | self._glfw_window = None 21 | self._drawing_frame = False 22 | self._frame_start_time = None 23 | self._frame_delta = 0 24 | self._fps_limit = None 25 | self._vsync = None 26 | self._skip_frames = 0 27 | self._deferred_show = deferred_show 28 | self._close_on_esc = close_on_esc 29 | self._esc_pressed = False 30 | self._drag_and_drop_paths = None 31 | self._capture_next_frame = False 32 | self._captured_frame = None 33 | 34 | # Create window. 35 | glfw.init() 36 | glfw.window_hint(glfw.VISIBLE, False) 37 | self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None) 38 | self._attach_glfw_callbacks() 39 | self.make_context_current() 40 | 41 | # Adjust window. 42 | self.set_vsync(False) 43 | self.set_window_size(window_width, window_height) 44 | if not self._deferred_show: 45 | glfw.show_window(self._glfw_window) 46 | 47 | def close(self): 48 | if self._drawing_frame: 49 | self.end_frame() 50 | if self._glfw_window is not None: 51 | glfw.destroy_window(self._glfw_window) 52 | self._glfw_window = None 53 | #glfw.terminate() # Commented out to play it nice with other glfw clients. 
54 | 55 | def __del__(self): 56 | try: 57 | self.close() 58 | except: 59 | pass 60 | 61 | @property 62 | def window_width(self): 63 | return self.content_width 64 | 65 | @property 66 | def window_height(self): 67 | return self.content_height + self.title_bar_height 68 | 69 | @property 70 | def content_width(self): 71 | width, _height = glfw.get_window_size(self._glfw_window) 72 | return width 73 | 74 | @property 75 | def content_height(self): 76 | _width, height = glfw.get_window_size(self._glfw_window) 77 | return height 78 | 79 | @property 80 | def title_bar_height(self): 81 | _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window) 82 | return top 83 | 84 | @property 85 | def monitor_width(self): 86 | _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 87 | return width 88 | 89 | @property 90 | def monitor_height(self): 91 | _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 92 | return height 93 | 94 | @property 95 | def frame_delta(self): 96 | return self._frame_delta 97 | 98 | def set_title(self, title): 99 | glfw.set_window_title(self._glfw_window, title) 100 | 101 | def set_window_size(self, width, height): 102 | width = min(width, self.monitor_width) 103 | height = min(height, self.monitor_height) 104 | glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0)) 105 | if width == self.monitor_width and height == self.monitor_height: 106 | self.maximize() 107 | 108 | def set_content_size(self, width, height): 109 | self.set_window_size(width, height + self.title_bar_height) 110 | 111 | def maximize(self): 112 | glfw.maximize_window(self._glfw_window) 113 | 114 | def set_position(self, x, y): 115 | glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height) 116 | 117 | def center(self): 118 | self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2) 119 | 120 | def set_vsync(self, vsync): 121 | vsync = bool(vsync) 122 | if vsync != self._vsync: 123 | glfw.swap_interval(1 if vsync else 0) 124 | self._vsync = vsync 125 | 126 | def set_fps_limit(self, fps_limit): 127 | self._fps_limit = int(fps_limit) 128 | 129 | def should_close(self): 130 | return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed) 131 | 132 | def skip_frame(self): 133 | self.skip_frames(1) 134 | 135 | def skip_frames(self, num): # Do not update window for the next N frames. 136 | self._skip_frames = max(self._skip_frames, int(num)) 137 | 138 | def is_skipping_frames(self): 139 | return self._skip_frames > 0 140 | 141 | def capture_next_frame(self): 142 | self._capture_next_frame = True 143 | 144 | def pop_captured_frame(self): 145 | frame = self._captured_frame 146 | self._captured_frame = None 147 | return frame 148 | 149 | def pop_drag_and_drop_paths(self): 150 | paths = self._drag_and_drop_paths 151 | self._drag_and_drop_paths = None 152 | return paths 153 | 154 | def draw_frame(self): # To be overridden by subclass. 155 | self.begin_frame() 156 | # Rendering code goes here. 157 | self.end_frame() 158 | 159 | def make_context_current(self): 160 | if self._glfw_window is not None: 161 | glfw.make_context_current(self._glfw_window) 162 | 163 | def begin_frame(self): 164 | # End previous frame. 165 | if self._drawing_frame: 166 | self.end_frame() 167 | 168 | # Apply FPS limit. 
169 | if self._frame_start_time is not None and self._fps_limit is not None: 170 | delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit 171 | if delay > 0: 172 | time.sleep(delay) 173 | cur_time = time.perf_counter() 174 | if self._frame_start_time is not None: 175 | self._frame_delta = cur_time - self._frame_start_time 176 | self._frame_start_time = cur_time 177 | 178 | # Process events. 179 | glfw.poll_events() 180 | 181 | # Begin frame. 182 | self._drawing_frame = True 183 | self.make_context_current() 184 | 185 | # Initialize GL state. 186 | gl.glViewport(0, 0, self.content_width, self.content_height) 187 | gl.glMatrixMode(gl.GL_PROJECTION) 188 | gl.glLoadIdentity() 189 | gl.glTranslate(-1, 1, 0) 190 | gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1) 191 | gl.glMatrixMode(gl.GL_MODELVIEW) 192 | gl.glLoadIdentity() 193 | gl.glEnable(gl.GL_BLEND) 194 | gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha. 195 | 196 | # Clear. 197 | gl.glClearColor(0, 0, 0, 1) 198 | gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) 199 | 200 | def end_frame(self): 201 | assert self._drawing_frame 202 | self._drawing_frame = False 203 | 204 | # Skip frames if requested. 205 | if self._skip_frames > 0: 206 | self._skip_frames -= 1 207 | return 208 | 209 | # Capture frame if requested. 210 | if self._capture_next_frame: 211 | self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height) 212 | self._capture_next_frame = False 213 | 214 | # Update window. 215 | if self._deferred_show: 216 | glfw.show_window(self._glfw_window) 217 | self._deferred_show = False 218 | glfw.swap_buffers(self._glfw_window) 219 | 220 | def _attach_glfw_callbacks(self): 221 | glfw.set_key_callback(self._glfw_window, self._glfw_key_callback) 222 | glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback) 223 | 224 | def _glfw_key_callback(self, _window, key, _scancode, action, _mods): 225 | if action == glfw.PRESS and key == glfw.KEY_ESCAPE: 226 | self._esc_pressed = True 227 | 228 | def _glfw_drop_callback(self, _window, paths): 229 | self._drag_and_drop_paths = paths 230 | 231 | #---------------------------------------------------------------------------- 232 | -------------------------------------------------------------------------------- /gui_utils/imgui_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
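GlfwWindow above is designed to be subclassed: draw_frame() is overridden and the window is pumped until should_close() reports true. A minimal usage sketch (hypothetical DemoWindow class; requires a GLFW-capable display and the repository root on PYTHONPATH):

from gui_utils.glfw_window import GlfwWindow

class DemoWindow(GlfwWindow):
    def draw_frame(self):
        self.begin_frame()
        # Rendering code would go here (e.g. gl_utils draw calls).
        self.end_frame()

if __name__ == '__main__':
    win = DemoWindow(title='Demo', window_width=640, window_height=480, deferred_show=False)
    while not win.should_close():
        win.draw_frame()
    win.close()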
10 | 11 | import contextlib 12 | import imgui 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): 17 | s = imgui.get_style() 18 | s.window_padding = [spacing, spacing] 19 | s.item_spacing = [spacing, spacing] 20 | s.item_inner_spacing = [spacing, spacing] 21 | s.columns_min_spacing = spacing 22 | s.indent_spacing = indent 23 | s.scrollbar_size = scrollbar 24 | s.frame_padding = [4, 3] 25 | s.window_border_size = 1 26 | s.child_border_size = 1 27 | s.popup_border_size = 1 28 | s.frame_border_size = 1 29 | s.window_rounding = 0 30 | s.child_rounding = 0 31 | s.popup_rounding = 3 32 | s.frame_rounding = 3 33 | s.scrollbar_rounding = 3 34 | s.grab_rounding = 3 35 | 36 | getattr(imgui, f'style_colors_{color_scheme}')(s) 37 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 38 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] 39 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | @contextlib.contextmanager 44 | def grayed_out(cond=True): 45 | if cond: 46 | s = imgui.get_style() 47 | text = s.colors[imgui.COLOR_TEXT_DISABLED] 48 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] 49 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 50 | imgui.push_style_color(imgui.COLOR_TEXT, *text) 51 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) 52 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) 53 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) 54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) 55 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) 56 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) 57 | imgui.push_style_color(imgui.COLOR_BUTTON, *back) 58 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) 59 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) 60 | imgui.push_style_color(imgui.COLOR_HEADER, *back) 61 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) 62 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) 63 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) 64 | yield 65 | imgui.pop_style_color(14) 66 | else: 67 | yield 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | @contextlib.contextmanager 72 | def item_width(width=None): 73 | if width is not None: 74 | imgui.push_item_width(width) 75 | yield 76 | imgui.pop_item_width() 77 | else: 78 | yield 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def scoped_by_object_id(method): 83 | def decorator(self, *args, **kwargs): 84 | imgui.push_id(str(id(self))) 85 | res = method(self, *args, **kwargs) 86 | imgui.pop_id() 87 | return res 88 | return decorator 89 | 90 | #---------------------------------------------------------------------------- 91 | 92 | def button(label, width=0, enabled=True): 93 | with grayed_out(not enabled): 94 | clicked = imgui.button(label, width=width) 95 | clicked = clicked and enabled 96 | return clicked 97 | 98 | #---------------------------------------------------------------------------- 99 | 100 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): 101 | expanded = False 102 | if show: 103 | if default: 104 | flags |= imgui.TREE_NODE_DEFAULT_OPEN 105 | if not enabled: 106 | flags |= imgui.TREE_NODE_LEAF 107 
| with grayed_out(not enabled): 108 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) 109 | expanded = expanded and enabled 110 | return expanded, visible 111 | 112 | #---------------------------------------------------------------------------- 113 | 114 | def popup_button(label, width=0, enabled=True): 115 | if button(label, width, enabled): 116 | imgui.open_popup(label) 117 | opened = imgui.begin_popup(label) 118 | return opened 119 | 120 | #---------------------------------------------------------------------------- 121 | 122 | def input_text(label, value, buffer_length, flags, width=None, help_text=''): 123 | old_value = value 124 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 125 | if value == '': 126 | color[-1] *= 0.5 127 | with item_width(width): 128 | imgui.push_style_color(imgui.COLOR_TEXT, *color) 129 | value = value if value != '' else help_text 130 | changed, value = imgui.input_text(label, value, buffer_length, flags) 131 | value = value if value != help_text else '' 132 | imgui.pop_style_color(1) 133 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: 134 | changed = (value != old_value) 135 | return changed, value 136 | 137 | #---------------------------------------------------------------------------- 138 | 139 | def drag_previous_control(enabled=True): 140 | dragging = False 141 | dx = 0 142 | dy = 0 143 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): 144 | if enabled: 145 | dragging = True 146 | dx, dy = imgui.get_mouse_drag_delta() 147 | imgui.reset_mouse_drag_delta() 148 | imgui.end_drag_drop_source() 149 | return dragging, dx, dy 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def drag_button(label, width=0, enabled=True): 154 | clicked = button(label, width=width, enabled=enabled) 155 | dragging, dx, dy = drag_previous_control(enabled=enabled) 156 | return clicked, dragging, dx, dy 157 | 158 | #---------------------------------------------------------------------------- 159 | 160 | def drag_hidden_window(label, x, y, width, height, enabled=True): 161 | imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) 162 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) 163 | imgui.set_next_window_position(x, y) 164 | imgui.set_next_window_size(width, height) 165 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) 166 | dragging, dx, dy = drag_previous_control(enabled=enabled) 167 | imgui.end() 168 | imgui.pop_style_color(2) 169 | return dragging, dx, dy 170 | 171 | #---------------------------------------------------------------------------- 172 | -------------------------------------------------------------------------------- /gui_utils/imgui_window.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
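The imgui_utils helpers above wrap raw pyimgui calls with consistent handling: grayed_out() visually disables a group of widgets, item_width() scopes a widget width, and button()/input_text() fold the enabled-state and placeholder-text logic into single calls. A small sketch of how such a panel could be drawn (hypothetical draw_controls function; must run between an ImguiWindow's begin_frame() and end_frame()):

import imgui
from gui_utils import imgui_utils

def draw_controls(enabled=True, path=''):
    # Disabled widgets are drawn grayed out but still laid out normally.
    with imgui_utils.grayed_out(not enabled):
        clicked = imgui_utils.button('Generate', width=100, enabled=enabled)
    # Text field with placeholder text; returns (changed, new_value).
    with imgui_utils.item_width(250):
        changed, path = imgui_utils.input_text('##pkl', path, 1024,
                                               flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE,
                                               help_text='Network pickle path')
    return clicked, changed, path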
10 | 11 | import os 12 | import imgui 13 | import imgui.integrations.glfw 14 | 15 | from . import glfw_window 16 | from . import imgui_utils 17 | from . import text_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class ImguiWindow(glfw_window.GlfwWindow): 22 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): 23 | if font is None: 24 | font = text_utils.get_default_font() 25 | font_sizes = {int(size) for size in font_sizes} 26 | super().__init__(title=title, **glfw_kwargs) 27 | 28 | # Init fields. 29 | self._imgui_context = None 30 | self._imgui_renderer = None 31 | self._imgui_fonts = None 32 | self._cur_font_size = max(font_sizes) 33 | 34 | # Delete leftover imgui.ini to avoid unexpected behavior. 35 | if os.path.isfile('imgui.ini'): 36 | os.remove('imgui.ini') 37 | 38 | # Init ImGui. 39 | self._imgui_context = imgui.create_context() 40 | self._imgui_renderer = _GlfwRenderer(self._glfw_window) 41 | self._attach_glfw_callbacks() 42 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. 43 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). 44 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} 45 | self._imgui_renderer.refresh_font_texture() 46 | 47 | def close(self): 48 | self.make_context_current() 49 | self._imgui_fonts = None 50 | if self._imgui_renderer is not None: 51 | self._imgui_renderer.shutdown() 52 | self._imgui_renderer = None 53 | if self._imgui_context is not None: 54 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. 55 | self._imgui_context = None 56 | super().close() 57 | 58 | def _glfw_key_callback(self, *args): 59 | super()._glfw_key_callback(*args) 60 | self._imgui_renderer.keyboard_callback(*args) 61 | 62 | @property 63 | def font_size(self): 64 | return self._cur_font_size 65 | 66 | @property 67 | def spacing(self): 68 | return round(self._cur_font_size * 0.4) 69 | 70 | def set_font_size(self, target): # Applied on next frame. 71 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] 72 | 73 | def begin_frame(self): 74 | # Begin glfw frame. 75 | super().begin_frame() 76 | 77 | # Process imgui events. 78 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 79 | if self.content_width > 0 and self.content_height > 0: 80 | self._imgui_renderer.process_inputs() 81 | 82 | # Begin imgui frame. 83 | imgui.new_frame() 84 | imgui.push_font(self._imgui_fonts[self._cur_font_size]) 85 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) 86 | 87 | def end_frame(self): 88 | imgui.pop_font() 89 | imgui.render() 90 | imgui.end_frame() 91 | self._imgui_renderer.render(imgui.get_draw_data()) 92 | super().end_frame() 93 | 94 | #---------------------------------------------------------------------------- 95 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. 
96 | 97 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): 98 | def __init__(self, *args, **kwargs): 99 | super().__init__(*args, **kwargs) 100 | self.mouse_wheel_multiplier = 1 101 | 102 | def scroll_callback(self, window, x_offset, y_offset): 103 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier 104 | 105 | #---------------------------------------------------------------------------- 106 | -------------------------------------------------------------------------------- /gui_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import functools 12 | from typing import Optional 13 | 14 | import dnnlib 15 | import numpy as np 16 | import PIL.Image 17 | import PIL.ImageFont 18 | import scipy.ndimage 19 | 20 | from . import gl_utils 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | def get_default_font(): 25 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular 26 | return dnnlib.util.open_url(url, return_filename=True) 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | @functools.lru_cache(maxsize=None) 31 | def get_pil_font(font=None, size=32): 32 | if font is None: 33 | font = get_default_font() 34 | return PIL.ImageFont.truetype(font=font, size=size) 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | def get_array(string, *, dropshadow_radius: int=None, **kwargs): 39 | if dropshadow_radius is not None: 40 | offset_x = int(np.ceil(dropshadow_radius*2/3)) 41 | offset_y = int(np.ceil(dropshadow_radius*2/3)) 42 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 43 | else: 44 | return _get_array_priv(string, **kwargs) 45 | 46 | @functools.lru_cache(maxsize=10000) 47 | def _get_array_priv( 48 | string: str, *, 49 | size: int = 32, 50 | max_width: Optional[int]=None, 51 | max_height: Optional[int]=None, 52 | min_size=10, 53 | shrink_coef=0.8, 54 | dropshadow_radius: int=None, 55 | offset_x: int=None, 56 | offset_y: int=None, 57 | **kwargs 58 | ): 59 | cur_size = size 60 | array = None 61 | while True: 62 | if dropshadow_radius is not None: 63 | # separate implementation for dropshadow text rendering 64 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 65 | else: 66 | array = _get_array_impl(string, size=cur_size, **kwargs) 67 | height, width, _ = array.shape 68 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): 69 | break 70 | cur_size = max(int(cur_size * shrink_coef), min_size) 71 | return array 72 | 73 | #---------------------------------------------------------------------------- 74 | 75 | @functools.lru_cache(maxsize=10000) 
76 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): 77 | pil_font = get_pil_font(font=font, size=size) 78 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 79 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 80 | width = max(line.shape[1] for line in lines) 81 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 82 | line_spacing = line_pad if line_pad is not None else size // 2 83 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 84 | mask = np.concatenate(lines, axis=0) 85 | alpha = mask 86 | if outline > 0: 87 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) 88 | alpha = mask.astype(np.float32) / 255 89 | alpha = scipy.ndimage.gaussian_filter(alpha, outline) 90 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp 91 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 92 | alpha = np.maximum(alpha, mask) 93 | return np.stack([mask, alpha], axis=-1) 94 | 95 | #---------------------------------------------------------------------------- 96 | 97 | @functools.lru_cache(maxsize=10000) 98 | def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): 99 | assert (offset_x > 0) and (offset_y > 0) 100 | pil_font = get_pil_font(font=font, size=size) 101 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 102 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 103 | width = max(line.shape[1] for line in lines) 104 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 105 | line_spacing = line_pad if line_pad is not None else size // 2 106 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 107 | mask = np.concatenate(lines, axis=0) 108 | alpha = mask 109 | 110 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) 111 | alpha = mask.astype(np.float32) / 255 112 | alpha = scipy.ndimage.gaussian_filter(alpha, radius) 113 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 114 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 115 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] 116 | alpha = np.maximum(alpha, mask) 117 | return np.stack([mask, alpha], axis=-1) 118 | 119 | #---------------------------------------------------------------------------- 120 | 121 | @functools.lru_cache(maxsize=10000) 122 | def get_texture(string, bilinear=True, mipmap=True, **kwargs): 123 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
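text_utils.get_array() above rasterizes a (possibly multi-line) string into a grayscale mask plus alpha channel, optionally with an outline or drop shadow, and get_texture() uploads the result as an OpenGL texture. A quick sketch of the array path (downloads the default Open Sans font on first use; the string and sizes are illustrative):

from gui_utils import text_utils

# Two-line label with a drop shadow -> uint8 array of shape [height, width, 2]
# holding the text mask and its blurred alpha.
arr = text_utils.get_array('PanoHead\nviewer', size=32, dropshadow_radius=3)
print(arr.shape, arr.dtype)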
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Frechet Inception Distance (FID) from the paper 12 | "GANs trained by a two time-scale update rule converge to a local Nash 13 | equilibrium". Matches the original implementation by Heusel et al. at 14 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 15 | 16 | import numpy as np 17 | import scipy.linalg 18 | from . import metric_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def compute_fid(opts, max_real, num_gen): 23 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 24 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 25 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 26 | 27 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 30 | 31 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 32 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 33 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 34 | 35 | if opts.rank != 0: 36 | return float('nan') 37 | 38 | m = np.square(mu_gen - mu_real).sum() 39 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 40 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 41 | return float(fid) 42 | 43 | #---------------------------------------------------------------------------- 44 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
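compute_fid() above reduces the real and generated image sets to Inception feature statistics (mean and covariance) and then evaluates the closed-form Frechet distance between the two Gaussians. The same arithmetic on toy statistics, as a sanity-check sketch (hypothetical fid_from_stats helper; the random covariances are for illustration only):

import numpy as np
import scipy.linalg

def fid_from_stats(mu_real, sigma_real, mu_gen, sigma_gen):
    # FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 * (Sigma_r @ Sigma_g)^(1/2))
    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)
    return float(np.real(m + np.trace(sigma_gen + sigma_real - s * 2)))

rng = np.random.RandomState(0)
mu_r, mu_g = rng.randn(8), rng.randn(8)
a, b = rng.randn(8, 8), rng.randn(8, 8)
print(f'{fid_from_stats(mu_r, a @ a.T, mu_g, b @ b.T):.3f}')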
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Inception Score (IS) from the paper "Improved techniques for training 12 | GANs". Matches the original implementation by Salimans et al. at 13 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 14 | 15 | import numpy as np 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_is(opts, num_gen, num_splits): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 24 | 25 | gen_probs = metric_utils.compute_feature_stats_for_generator( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | capture_all=True, max_items=num_gen).get_all() 28 | 29 | if opts.rank != 0: 30 | return float('nan'), float('nan') 31 | 32 | scores = [] 33 | for i in range(num_splits): 34 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 35 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 36 | kl = np.mean(np.sum(kl, axis=1)) 37 | scores.append(np.exp(kl)) 38 | return float(np.mean(scores)), float(np.std(scores)) 39 | 40 | #---------------------------------------------------------------------------- 41 | -------------------------------------------------------------------------------- /metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 12 | GANs". Matches the original implementation by Binkowski et al. at 13 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 14 | 15 | import numpy as np 16 | from . 
import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 24 | 25 | real_features = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 28 | 29 | gen_features = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | n = real_features.shape[1] 37 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 38 | t = 0 39 | for _subset_idx in range(num_subsets): 40 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 41 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 42 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 43 | b = (x @ y.T / n + 1) ** 3 44 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 45 | kid = t / num_subsets / m 46 | return float(kid) 47 | 48 | #---------------------------------------------------------------------------- 49 | -------------------------------------------------------------------------------- /metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Main API for computing and reporting quality metrics.""" 12 | 13 | import os 14 | import time 15 | import json 16 | import torch 17 | import dnnlib 18 | 19 | from . import metric_utils 20 | from . import frechet_inception_distance 21 | from . import kernel_inception_distance 22 | from . import precision_recall 23 | from . import perceptual_path_length 24 | from . import inception_score 25 | from . import equivariance 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | _metric_dict = dict() # name => fn 30 | 31 | def register_metric(fn): 32 | assert callable(fn) 33 | _metric_dict[fn.__name__] = fn 34 | return fn 35 | 36 | def is_valid_metric(metric): 37 | return metric in _metric_dict 38 | 39 | def list_valid_metrics(): 40 | return list(_metric_dict.keys()) 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 
45 | assert is_valid_metric(metric) 46 | opts = metric_utils.MetricOptions(**kwargs) 47 | 48 | # Calculate. 49 | start_time = time.time() 50 | results = _metric_dict[metric](opts) 51 | total_time = time.time() - start_time 52 | 53 | # Broadcast results. 54 | for key, value in list(results.items()): 55 | if opts.num_gpus > 1: 56 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 57 | torch.distributed.broadcast(tensor=value, src=0) 58 | value = float(value.cpu()) 59 | results[key] = value 60 | 61 | # Decorate with metadata. 62 | return dnnlib.EasyDict( 63 | results = dnnlib.EasyDict(results), 64 | metric = metric, 65 | total_time = total_time, 66 | total_time_str = dnnlib.util.format_time(total_time), 67 | num_gpus = opts.num_gpus, 68 | mode = opts.mode 69 | ) 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 74 | metric = result_dict['metric'] 75 | assert is_valid_metric(metric) 76 | if run_dir is not None and snapshot_pkl is not None: 77 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 78 | 79 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 80 | print(jsonl_line) 81 | if run_dir is not None and os.path.isdir(run_dir): 82 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 83 | f.write(jsonl_line + '\n') 84 | 85 | #---------------------------------------------------------------------------- 86 | # Recommended metrics. 87 | 88 | @register_metric 89 | def fid50k_full(opts): 90 | opts.dataset_kwargs.update(max_size=None, xflip=False) 91 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 92 | return dict(fid50k_full=fid) 93 | 94 | @register_metric 95 | def kid50k_full(opts): 96 | opts.dataset_kwargs.update(max_size=None, xflip=False) 97 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 98 | return dict(kid50k_full=kid) 99 | 100 | @register_metric 101 | def pr50k3_full(opts): 102 | opts.dataset_kwargs.update(max_size=None, xflip=False) 103 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 104 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 105 | 106 | @register_metric 107 | def ppl2_wend(opts): 108 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 109 | return dict(ppl2_wend=ppl) 110 | 111 | @register_metric 112 | def eqt50k_int(opts): 113 | opts.G_kwargs.update(force_fp32=True) 114 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 115 | return dict(eqt50k_int=psnr) 116 | 117 | @register_metric 118 | def eqt50k_frac(opts): 119 | opts.G_kwargs.update(force_fp32=True) 120 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 121 | return dict(eqt50k_frac=psnr) 122 | 123 | @register_metric 124 | def eqr50k(opts): 125 | opts.G_kwargs.update(force_fp32=True) 126 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 127 | return dict(eqr50k=psnr) 128 | 129 | #---------------------------------------------------------------------------- 130 | # Legacy metrics. 
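Every function decorated with @register_metric above becomes selectable by its function name, e.g. the fid50k_full and is50k entries passed via --metrics in get_metrics.sh. A sketch of adding a cheaper debugging metric in the same style (hypothetical fid2k_debug; the sample counts are illustrative and it only takes effect once this code is imported):

from metrics import metric_main, frechet_inception_distance

@metric_main.register_metric
def fid2k_debug(opts):
    # Same recipe as fid50k_full above, but with far fewer samples for quick checks.
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=2000, num_gen=2000)
    return dict(fid2k_debug=fid)

assert metric_main.is_valid_metric('fid2k_debug')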
131 | 132 | @register_metric 133 | def fid50k(opts): 134 | opts.dataset_kwargs.update(max_size=None) 135 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 136 | return dict(fid50k=fid) 137 | 138 | @register_metric 139 | def kid50k(opts): 140 | opts.dataset_kwargs.update(max_size=None) 141 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 142 | return dict(kid50k=kid) 143 | 144 | @register_metric 145 | def pr50k3(opts): 146 | opts.dataset_kwargs.update(max_size=None) 147 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 148 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 149 | 150 | @register_metric 151 | def is50k(opts): 152 | opts.dataset_kwargs.update(max_size=None, xflip=False) 153 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 154 | return dict(is50k_mean=mean, is50k_std=std) 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 12 | Architecture for Generative Adversarial Networks". Matches the original 13 | implementation by Karras et al. at 14 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 15 | 16 | import copy 17 | import numpy as np 18 | import torch 19 | from . import metric_utils 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | # Spherical interpolation of a batch of vectors. 24 | def slerp(a, b, t): 25 | a = a / a.norm(dim=-1, keepdim=True) 26 | b = b / b.norm(dim=-1, keepdim=True) 27 | d = (a * b).sum(dim=-1, keepdim=True) 28 | p = t * torch.acos(d) 29 | c = b - d * a 30 | c = c / c.norm(dim=-1, keepdim=True) 31 | d = a * torch.cos(p) + c * torch.sin(p) 32 | d = d / d.norm(dim=-1, keepdim=True) 33 | return d 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | class PPLSampler(torch.nn.Module): 38 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 39 | assert space in ['z', 'w'] 40 | assert sampling in ['full', 'end'] 41 | super().__init__() 42 | self.G = copy.deepcopy(G) 43 | self.G_kwargs = G_kwargs 44 | self.epsilon = epsilon 45 | self.space = space 46 | self.sampling = sampling 47 | self.crop = crop 48 | self.vgg16 = copy.deepcopy(vgg16) 49 | 50 | def forward(self, c): 51 | # Generate random latents and interpolation t-values. 
52 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 53 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 54 | 55 | # Interpolate in W or Z. 56 | if self.space == 'w': 57 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 58 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 59 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 60 | else: # space == 'z' 61 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 62 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 63 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 64 | 65 | # Randomize noise buffers. 66 | for name, buf in self.G.named_buffers(): 67 | if name.endswith('.noise_const'): 68 | buf.copy_(torch.randn_like(buf)) 69 | 70 | # Generate images. 71 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 72 | 73 | # Center crop. 74 | if self.crop: 75 | assert img.shape[2] == img.shape[3] 76 | c = img.shape[2] // 8 77 | img = img[:, :, c*3 : c*7, c*2 : c*6] 78 | 79 | # Downsample to 256x256. 80 | factor = self.G.img_resolution // 256 81 | if factor > 1: 82 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 83 | 84 | # Scale dynamic range from [-1,1] to [0,255]. 85 | img = (img + 1) * (255 / 2) 86 | if self.G.img_channels == 1: 87 | img = img.repeat([1, 3, 1, 1]) 88 | 89 | # Evaluate differential LPIPS. 90 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 91 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 92 | return dist 93 | 94 | #---------------------------------------------------------------------------- 95 | 96 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 97 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 99 | 100 | # Setup sampler and labels. 101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 102 | sampler.eval().requires_grad_(False).to(opts.device) 103 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 104 | 105 | # Sampling loop. 106 | dist = [] 107 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 108 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 109 | progress.update(batch_start) 110 | x = sampler(next(c_iter)) 111 | for src in range(opts.num_gpus): 112 | y = x.clone() 113 | if opts.num_gpus > 1: 114 | torch.distributed.broadcast(y, src=src) 115 | dist.append(y) 116 | progress.update(num_samples) 117 | 118 | # Compute PPL. 
119 | if opts.rank != 0: 120 | return float('nan') 121 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 122 | lo = np.percentile(dist, 1, interpolation='lower') 123 | hi = np.percentile(dist, 99, interpolation='higher') 124 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 125 | return float(ppl) 126 | 127 | #---------------------------------------------------------------------------- 128 | -------------------------------------------------------------------------------- /metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 12 | Metric for Assessing Generative Models". Matches the original implementation 13 | by Kynkaanniemi et al. at 14 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 15 | 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 22 | assert 0 <= rank < num_gpus 23 | num_cols = col_features.shape[0] 24 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 25 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 26 | dist_batches = [] 27 | for col_batch in col_batches[rank :: num_gpus]: 28 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 29 | for src in range(num_gpus): 30 | dist_broadcast = dist_batch.clone() 31 | if num_gpus > 1: 32 | torch.distributed.broadcast(dist_broadcast, src=src) 33 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 34 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 39 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 40 | detector_kwargs = dict(return_features=True) 41 | 42 | real_features = metric_utils.compute_feature_stats_for_dataset( 43 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 44 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 45 | 46 | gen_features = metric_utils.compute_feature_stats_for_generator( 47 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 48 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 49 | 50 | results = dict() 51 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 52 | kth = [] 53 | for manifold_batch in 
manifold.split(row_batch_size): 54 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 55 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 56 | kth = torch.cat(kth) if opts.rank == 0 else None 57 | pred = [] 58 | for probes_batch in probes.split(row_batch_size): 59 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 60 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 61 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 62 | return results['precision'], results['recall'] 63 | 64 | #---------------------------------------------------------------------------- 65 | -------------------------------------------------------------------------------- /misc/segmentation_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from torchvision.transforms import ToPILImage 4 | 5 | from torchvision.models.segmentation import deeplabv3_resnet101 6 | from torchvision import transforms, utils 7 | 8 | 9 | def get_mask(model, batch, cid): 10 | normalized_batch = transforms.functional.normalize( 11 | batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 12 | output = model(normalized_batch)['out'] 13 | # sem_classes = [ 14 | # '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 15 | # 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 16 | # 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' 17 | # ] 18 | # sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)} 19 | # cid = sem_class_to_idx['car'] 20 | 21 | normalized_masks = torch.nn.functional.softmax(output, dim=1) 22 | 23 | boolean_car_masks = (normalized_masks.argmax(1) == cid) 24 | return boolean_car_masks.float() 25 | 26 | image = Image.open('image.jpg') 27 | # Define the preprocessing transformation 28 | preprocess = transforms.Compose([ 29 | transforms.Resize((512, 512)), 30 | transforms.ToTensor() 31 | ]) 32 | 33 | # Apply the transformation to the image 34 | input_tensor = preprocess(image) 35 | input_batch = input_tensor.unsqueeze(0).to('cuda:0') 36 | 37 | # load segmentation net 38 | seg_net = deeplabv3_resnet101(pretrained=True, progress=False).to('cuda:0') 39 | seg_net.requires_grad_(False) 40 | seg_net.eval() 41 | 42 | # 15 means human mask 43 | mask0 = get_mask(seg_net, input_batch, 15).unsqueeze(1) 44 | 45 | # Squeeze the tensor to remove unnecessary dimensions and convert to PIL Image 46 | mask_squeezed = torch.squeeze(mask0) 47 | mask_image = ToPILImage()(mask_squeezed) 48 | 49 | # Save as PNG 50 | mask_image.save("mask.png") 51 | -------------------------------------------------------------------------------- /misc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SizheAn/PanoHead/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/misc/teaser.png -------------------------------------------------------------------------------- /resave_model.py: -------------------------------------------------------------------------------- 1 | """ Reload a model and save it. 
""" 2 | import copy 3 | import os 4 | import pickle 5 | 6 | import click 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import dnnlib 11 | import legacy 12 | 13 | @click.command() 14 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 15 | @click.option('--output', 'output_pkl', help='Network pickle filename', required=True) 16 | def main( 17 | network_pkl: str, 18 | output_pkl: str, 19 | ): 20 | # Load networks. 21 | print('Loading networks from "%s"...' % network_pkl) 22 | with dnnlib.util.open_url(network_pkl) as fp: 23 | data = legacy.load_network_pkl(fp) 24 | data_new = {} 25 | for name in data.keys(): 26 | module = data.get(name, None) 27 | if module is not None and isinstance(module, torch.nn.Module): 28 | module = copy.deepcopy(module).eval().requires_grad_(False).cpu() 29 | data_new[name] = module 30 | del module # conserve memory 31 | with open(output_pkl, 'wb') as f: 32 | pickle.dump(data_new, f) 33 | #---------------------------------------------------------------------------- 34 | 35 | if __name__ == "__main__": 36 | main() # pylint: disable=no-value-for-parameter 37 | 38 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /shape_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | 12 | """ 13 | Utils for extracting 3D shapes using marching cubes. Based on code from DeepSDF (Park et al.) 14 | 15 | Takes as input an .mrc file and extracts a mesh. 16 | 17 | Ex. 18 | python shape_utils.py my_shape.mrc 19 | Ex. 
20 | python shape_utils.py myshapes_directory --level=12 21 | """ 22 | 23 | 24 | import time 25 | import plyfile 26 | import glob 27 | import logging 28 | import numpy as np 29 | import os 30 | import random 31 | import torch 32 | import torch.utils.data 33 | import trimesh 34 | import skimage.measure 35 | import argparse 36 | import mrcfile 37 | from tqdm import tqdm 38 | 39 | 40 | def convert_sdf_samples_to_ply( 41 | numpy_3d_sdf_tensor, 42 | voxel_grid_origin, 43 | voxel_size, 44 | ply_filename_out, 45 | offset=None, 46 | scale=None, 47 | level=0.0 48 | ): 49 | """ 50 | Convert sdf samples to .ply 51 | :param numpy_3d_sdf_tensor: a numpy array of shape (n,n,n) 52 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid 53 | :voxel_size: float, the size of the voxels 54 | :ply_filename_out: string, path of the filename to save to 55 | This function adapted from: https://github.com/RobotLocomotion/spartan 56 | """ 57 | start_time = time.time() 58 | 59 | verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0) 60 | # try: 61 | verts, faces, normals, values = skimage.measure.marching_cubes( 62 | numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3 63 | ) 64 | # except: 65 | # pass 66 | 67 | # transform from voxel coordinates to camera coordinates 68 | # note x and y are flipped in the output of marching_cubes 69 | mesh_points = np.zeros_like(verts) 70 | mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0] 71 | mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1] 72 | mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2] 73 | 74 | # apply additional offset and scale 75 | if scale is not None: 76 | mesh_points = mesh_points / scale 77 | if offset is not None: 78 | mesh_points = mesh_points - offset 79 | 80 | # try writing to the ply file 81 | 82 | num_verts = verts.shape[0] 83 | num_faces = faces.shape[0] 84 | 85 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 86 | 87 | for i in range(0, num_verts): 88 | verts_tuple[i] = tuple(mesh_points[i, :]) 89 | 90 | faces_building = [] 91 | for i in range(0, num_faces): 92 | faces_building.append(((faces[i, :].tolist(),))) 93 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 94 | 95 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex") 96 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face") 97 | 98 | ply_data = plyfile.PlyData([el_verts, el_faces]) 99 | ply_data.write(ply_filename_out) 100 | print(f"wrote to {ply_filename_out}") 101 | 102 | 103 | def convert_mrc(input_filename, output_filename, isosurface_level=1): 104 | with mrcfile.open(input_filename) as mrc: 105 | convert_sdf_samples_to_ply(np.transpose(mrc.data, (2, 1, 0)), [0, 0, 0], 1, output_filename, level=isosurface_level) 106 | 107 | if __name__ == '__main__': 108 | start_time = time.time() 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('input_mrc_path') 111 | parser.add_argument('--level', type=float, default=10, help="The isosurface level for marching cubes") 112 | args = parser.parse_args() 113 | 114 | if os.path.isfile(args.input_mrc_path) and args.input_mrc_path.split('.')[-1] == 'mrc': 115 | output_obj_path = args.input_mrc_path.split('.mrc')[0] + '.ply' 116 | convert_mrc(args.input_mrc_path, output_obj_path, isosurface_level=args.level) 117 | 118 | print(f"{time.time() - start_time:.2f} s") 119 | else: 120 | assert os.path.isdir(args.input_mrc_path) 121 | 122 | for mrc_path in
tqdm(glob.glob(os.path.join(args.input_mrc_path, '*.mrc'))): 123 | output_obj_path = mrc_path.split('.mrc')[0] + '.ply' 124 | convert_mrc(mrc_path, output_obj_path, isosurface_level=args.level) -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import glob 12 | import hashlib 13 | import importlib 14 | import os 15 | import re 16 | import shutil 17 | import uuid 18 | 19 | import torch 20 | import torch.utils.cpp_extension 21 | from torch.utils.file_baton import FileBaton 22 | 23 | #---------------------------------------------------------------------------- 24 | # Global options. 25 | 26 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 27 | 28 | #---------------------------------------------------------------------------- 29 | # Internal helper funcs. 30 | 31 | def _find_compiler_bindir(): 32 | patterns = [ 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 35 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 36 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 37 | ] 38 | for pattern in patterns: 39 | matches = sorted(glob.glob(pattern)) 40 | if len(matches): 41 | return matches[-1] 42 | return None 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def _get_mangled_gpu_name(): 47 | name = torch.cuda.get_device_name().lower() 48 | out = [] 49 | for c in name: 50 | if re.match('[a-z0-9_-]+', c): 51 | out.append(c) 52 | else: 53 | out.append('-') 54 | return ''.join(out) 55 | 56 | #---------------------------------------------------------------------------- 57 | # Main entry point for compiling and loading C++/CUDA plugins. 
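get_plugin() below JIT-compiles the listed C++/CUDA sources through torch.utils.cpp_extension and caches the imported module per process. For orientation, a minimal sketch of how the ops in this repo call it, mirroring _init() in torch_utils/ops/bias_act.py (the file names are that op's; the other ops follow the same pattern):

    import os
    from torch_utils import custom_ops

    _plugin = None

    def _init():
        # Build the extension on first use, then reuse the cached module.
        global _plugin
        if _plugin is None:
            _plugin = custom_ops.get_plugin(
                module_name='bias_act_plugin',
                sources=['bias_act.cpp', 'bias_act.cu'],
                headers=['bias_act.h'],
                source_dir=os.path.dirname(__file__),       # directory holding the sources
                extra_cuda_cflags=['--use_fast_math'],
            )
        return True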
58 | 59 | _cached_plugins = dict() 60 | 61 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 62 | assert verbosity in ['none', 'brief', 'full'] 63 | if headers is None: 64 | headers = [] 65 | if source_dir is not None: 66 | sources = [os.path.join(source_dir, fname) for fname in sources] 67 | headers = [os.path.join(source_dir, fname) for fname in headers] 68 | 69 | # Already cached? 70 | if module_name in _cached_plugins: 71 | return _cached_plugins[module_name] 72 | 73 | # Print status. 74 | if verbosity == 'full': 75 | print(f'Setting up PyTorch plugin "{module_name}"...') 76 | elif verbosity == 'brief': 77 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 78 | verbose_build = (verbosity == 'full') 79 | 80 | # Compile and load. 81 | try: # pylint: disable=too-many-nested-blocks 82 | # Make sure we can find the necessary compiler binaries. 83 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 84 | compiler_bindir = _find_compiler_bindir() 85 | if compiler_bindir is None: 86 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 87 | os.environ['PATH'] += ';' + compiler_bindir 88 | 89 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 90 | # break the build or unnecessarily restrict what's available to nvcc. 91 | # Unset it to let nvcc decide based on what's available on the 92 | # machine. 93 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 94 | 95 | # Incremental build md5sum trickery. Copies all the input source files 96 | # into a cached build directory under a combined md5 digest of the input 97 | # source files. Copying is done only if the combined digest has changed. 98 | # This keeps input file timestamps and filenames the same as in previous 99 | # extension builds, allowing for fast incremental rebuilds. 100 | # 101 | # This optimization is done only in case all the source files reside in 102 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 103 | # environment variable is set (we take this as a signal that the user 104 | # actually cares about this.) 105 | # 106 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 107 | # around the *.cu dependency bug in ninja config. 108 | # 109 | all_source_files = sorted(sources + headers) 110 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 111 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 112 | 113 | # Compute combined hash digest for all source files. 114 | hash_md5 = hashlib.md5() 115 | for src in all_source_files: 116 | with open(src, 'rb') as f: 117 | hash_md5.update(f.read()) 118 | 119 | # Select cached build directory name. 120 | source_digest = hash_md5.hexdigest() 121 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 122 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 123 | 124 | if not os.path.isdir(cached_build_dir): 125 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 126 | os.makedirs(tmpdir) 127 | for src in all_source_files: 128 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 129 | try: 130 | os.replace(tmpdir, cached_build_dir) # atomic 131 | except OSError: 132 | # source directory already exists, delete tmpdir and its contents. 
133 | shutil.rmtree(tmpdir) 134 | if not os.path.isdir(cached_build_dir): raise 135 | 136 | # Compile. 137 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 138 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 139 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 140 | else: 141 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 142 | 143 | # Load. 144 | module = importlib.import_module(module_name) 145 | 146 | except: 147 | if verbosity == 'brief': 148 | print('Failed!') 149 | raise 150 | 151 | # Print status and add to cache dict. 152 | if verbosity == 'full': 153 | print(f'Done setting up PyTorch plugin "{module_name}".') 154 | elif verbosity == 'brief': 155 | print('Done.') 156 | _cached_plugins[module_name] = module 157 | return module 158 | 159 | #---------------------------------------------------------------------------- 160 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | #include "bias_act.h" 17 | 18 | //------------------------------------------------------------------------ 19 | 20 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 21 | { 22 | if (x.dim() != y.dim()) 23 | return false; 24 | for (int64_t i = 0; i < x.dim(); i++) 25 | { 26 | if (x.size(i) != y.size(i)) 27 | return false; 28 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 29 | return false; 30 | } 31 | return true; 32 | } 33 | 34 | //------------------------------------------------------------------------ 35 | 36 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 37 | { 38 | // Validate arguments. 
39 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 40 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 41 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 42 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 43 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 44 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 45 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 46 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 47 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 48 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 49 | 50 | // Validate layout. 51 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 52 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 53 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 54 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 55 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 56 | 57 | // Create output tensor. 58 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 59 | torch::Tensor y = torch::empty_like(x); 60 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 61 | 62 | // Initialize CUDA kernel parameters. 63 | bias_act_kernel_params p; 64 | p.x = x.data_ptr(); 65 | p.b = (b.numel()) ? b.data_ptr() : NULL; 66 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 67 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 68 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 69 | p.y = y.data_ptr(); 70 | p.grad = grad; 71 | p.act = act; 72 | p.alpha = alpha; 73 | p.gain = gain; 74 | p.clamp = clamp; 75 | p.sizeX = (int)x.numel(); 76 | p.sizeB = (int)b.numel(); 77 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 78 | 79 | // Choose CUDA kernel. 80 | void* kernel; 81 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 82 | { 83 | kernel = choose_bias_act_kernel(p); 84 | }); 85 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 86 | 87 | // Launch CUDA kernel. 88 | p.loopX = 4; 89 | int blockSize = 4 * 32; 90 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("bias_act", &bias_act); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include 14 | #include "bias_act.h" 15 | 16 | //------------------------------------------------------------------------ 17 | // Helpers. 18 | 19 | template struct InternalType; 20 | template <> struct InternalType { typedef double scalar_t; }; 21 | template <> struct InternalType { typedef float scalar_t; }; 22 | template <> struct InternalType { typedef float scalar_t; }; 23 | 24 | //------------------------------------------------------------------------ 25 | // CUDA kernel. 26 | 27 | template 28 | __global__ void bias_act_kernel(bias_act_kernel_params p) 29 | { 30 | typedef typename InternalType::scalar_t scalar_t; 31 | int G = p.grad; 32 | scalar_t alpha = (scalar_t)p.alpha; 33 | scalar_t gain = (scalar_t)p.gain; 34 | scalar_t clamp = (scalar_t)p.clamp; 35 | scalar_t one = (scalar_t)1; 36 | scalar_t two = (scalar_t)2; 37 | scalar_t expRange = (scalar_t)80; 38 | scalar_t halfExpRange = (scalar_t)40; 39 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 40 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 41 | 42 | // Loop over elements. 43 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 44 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 45 | { 46 | // Load. 47 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 48 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 49 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 50 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 51 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 52 | scalar_t yy = (gain != 0) ? yref / gain : 0; 53 | scalar_t y = 0; 54 | 55 | // Apply bias. 56 | ((G == 0) ? x : xref) += b; 57 | 58 | // linear 59 | if (A == 1) 60 | { 61 | if (G == 0) y = x; 62 | if (G == 1) y = x; 63 | } 64 | 65 | // relu 66 | if (A == 2) 67 | { 68 | if (G == 0) y = (x > 0) ? x : 0; 69 | if (G == 1) y = (yy > 0) ? x : 0; 70 | } 71 | 72 | // lrelu 73 | if (A == 3) 74 | { 75 | if (G == 0) y = (x > 0) ? x : x * alpha; 76 | if (G == 1) y = (yy > 0) ? x : x * alpha; 77 | } 78 | 79 | // tanh 80 | if (A == 4) 81 | { 82 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 83 | if (G == 1) y = x * (one - yy * yy); 84 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 85 | } 86 | 87 | // sigmoid 88 | if (A == 5) 89 | { 90 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 91 | if (G == 1) y = x * yy * (one - yy); 92 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 93 | } 94 | 95 | // elu 96 | if (A == 6) 97 | { 98 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 99 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 100 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 101 | } 102 | 103 | // selu 104 | if (A == 7) 105 | { 106 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 107 | if (G == 1) y = (yy >= 0) ? 
x * seluScale : x * (yy + seluScale * seluAlpha); 108 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 109 | } 110 | 111 | // softplus 112 | if (A == 8) 113 | { 114 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 115 | if (G == 1) y = x * (one - exp(-yy)); 116 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 117 | } 118 | 119 | // swish 120 | if (A == 9) 121 | { 122 | if (G == 0) 123 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 124 | else 125 | { 126 | scalar_t c = exp(xref); 127 | scalar_t d = c + one; 128 | if (G == 1) 129 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 130 | else 131 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 132 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 133 | } 134 | } 135 | 136 | // Apply gain. 137 | y *= gain * dy; 138 | 139 | // Clamp. 140 | if (clamp >= 0) 141 | { 142 | if (G == 0) 143 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 144 | else 145 | y = (yref > -clamp & yref < clamp) ? y : 0; 146 | } 147 | 148 | // Store. 149 | ((T*)p.y)[xi] = (T)y; 150 | } 151 | } 152 | 153 | //------------------------------------------------------------------------ 154 | // CUDA kernel selection. 155 | 156 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 157 | { 158 | if (p.act == 1) return (void*)bias_act_kernel; 159 | if (p.act == 2) return (void*)bias_act_kernel; 160 | if (p.act == 3) return (void*)bias_act_kernel; 161 | if (p.act == 4) return (void*)bias_act_kernel; 162 | if (p.act == 5) return (void*)bias_act_kernel; 163 | if (p.act == 6) return (void*)bias_act_kernel; 164 | if (p.act == 7) return (void*)bias_act_kernel; 165 | if (p.act == 8) return (void*)bias_act_kernel; 166 | if (p.act == 9) return (void*)bias_act_kernel; 167 | return NULL; 168 | } 169 | 170 | //------------------------------------------------------------------------ 171 | // Template specializations. 172 | 173 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 174 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 175 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 176 | 177 | //------------------------------------------------------------------------ 178 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | //------------------------------------------------------------------------ 14 | // CUDA kernel parameters. 
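The parameter struct below feeds a single fused kernel. In the forward pass (grad == 0) it is equivalent to: add the bias along `dim`, apply the activation, scale by `gain`, and clamp. A rough PyTorch sketch of that forward semantics for the 'lrelu' case (hypothetical helper name; the full reference implementation is _bias_act_ref() in bias_act.py):

    import torch

    def fused_bias_act_forward(x, b=None, dim=1, alpha=0.2, gain=2 ** 0.5, clamp=-1):
        # y = clamp(leaky_relu(x + b) * gain), with the bias broadcast along `dim`.
        if b is not None:
            x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
        y = torch.nn.functional.leaky_relu(x, alpha) * gain
        return y.clamp(-clamp, clamp) if clamp >= 0 else y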
15 | 16 | struct bias_act_kernel_params 17 | { 18 | const void* x; // [sizeX] 19 | const void* b; // [sizeB] or NULL 20 | const void* xref; // [sizeX] or NULL 21 | const void* yref; // [sizeX] or NULL 22 | const void* dy; // [sizeX] or NULL 23 | void* y; // [sizeX] 24 | 25 | int grad; 26 | int act; 27 | float alpha; 28 | float gain; 29 | float clamp; 30 | 31 | int sizeX; 32 | int sizeB; 33 | int stepB; 34 | int loopX; 35 | }; 36 | 37 | //------------------------------------------------------------------------ 38 | // CUDA kernel selection. 39 | 40 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 41 | 42 | //------------------------------------------------------------------------ 43 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Custom PyTorch ops for efficient bias and activation.""" 12 | 13 | import os 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from .. import custom_ops 19 | from .. import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | activation_funcs = { 24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 33 | } 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | _plugin = None 38 | _null_tensor = torch.empty([0]) 39 | 40 | def _init(): 41 | global _plugin 42 | if _plugin is None: 43 | _plugin = custom_ops.get_plugin( 44 | module_name='bias_act_plugin', 45 | sources=['bias_act.cpp', 'bias_act.cu'], 46 | headers=['bias_act.h'], 47 | source_dir=os.path.dirname(__file__), 48 | 
extra_cuda_cflags=['--use_fast_math'], 49 | ) 50 | return True 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 55 | r"""Fused bias and activation function. 56 | 57 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 58 | and scales the result by `gain`. Each of the steps is optional. In most cases, 59 | the fused op is considerably more efficient than performing the same calculation 60 | using standard PyTorch ops. It supports first and second order gradients, 61 | but not third order gradients. 62 | 63 | Args: 64 | x: Input activation tensor. Can be of any shape. 65 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 66 | as `x`. The shape must be known, and it must match the dimension of `x` 67 | corresponding to `dim`. 68 | dim: The dimension in `x` corresponding to the elements of `b`. 69 | The value of `dim` is ignored if `b` is not specified. 70 | act: Name of the activation function to evaluate, or `"linear"` to disable. 71 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 72 | See `activation_funcs` for a full list. `None` is not allowed. 73 | alpha: Shape parameter for the activation function, or `None` to use the default. 74 | gain: Scaling factor for the output tensor, or `None` to use default. 75 | See `activation_funcs` for the default scaling of each activation function. 76 | If unsure, consider specifying 1. 77 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 78 | the clamping (default). 79 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 80 | 81 | Returns: 82 | Tensor of the same shape and datatype as `x`. 83 | """ 84 | assert isinstance(x, torch.Tensor) 85 | assert impl in ['ref', 'cuda'] 86 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 87 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 88 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 89 | 90 | #---------------------------------------------------------------------------- 91 | 92 | @misc.profiled_function 93 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 94 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 95 | """ 96 | assert isinstance(x, torch.Tensor) 97 | assert clamp is None or clamp >= 0 98 | spec = activation_funcs[act] 99 | alpha = float(alpha if alpha is not None else spec.def_alpha) 100 | gain = float(gain if gain is not None else spec.def_gain) 101 | clamp = float(clamp if clamp is not None else -1) 102 | 103 | # Add bias. 104 | if b is not None: 105 | assert isinstance(b, torch.Tensor) and b.ndim == 1 106 | assert 0 <= dim < x.ndim 107 | assert b.shape[0] == x.shape[dim] 108 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 109 | 110 | # Evaluate activation function. 111 | alpha = float(alpha) 112 | x = spec.func(x, alpha=alpha) 113 | 114 | # Scale by gain. 115 | gain = float(gain) 116 | if gain != 1: 117 | x = x * gain 118 | 119 | # Clamp. 
120 | if clamp >= 0: 121 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 122 | return x 123 | 124 | #---------------------------------------------------------------------------- 125 | 126 | _bias_act_cuda_cache = dict() 127 | 128 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 129 | """Fast CUDA implementation of `bias_act()` using custom ops. 130 | """ 131 | # Parse arguments. 132 | assert clamp is None or clamp >= 0 133 | spec = activation_funcs[act] 134 | alpha = float(alpha if alpha is not None else spec.def_alpha) 135 | gain = float(gain if gain is not None else spec.def_gain) 136 | clamp = float(clamp if clamp is not None else -1) 137 | 138 | # Lookup from cache. 139 | key = (dim, act, alpha, gain, clamp) 140 | if key in _bias_act_cuda_cache: 141 | return _bias_act_cuda_cache[key] 142 | 143 | # Forward op. 144 | class BiasActCuda(torch.autograd.Function): 145 | @staticmethod 146 | def forward(ctx, x, b): # pylint: disable=arguments-differ 147 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 148 | x = x.contiguous(memory_format=ctx.memory_format) 149 | b = b.contiguous() if b is not None else _null_tensor 150 | y = x 151 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 152 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 153 | ctx.save_for_backward( 154 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 155 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 156 | y if 'y' in spec.ref else _null_tensor) 157 | return y 158 | 159 | @staticmethod 160 | def backward(ctx, dy): # pylint: disable=arguments-differ 161 | dy = dy.contiguous(memory_format=ctx.memory_format) 162 | x, b, y = ctx.saved_tensors 163 | dx = None 164 | db = None 165 | 166 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 167 | dx = dy 168 | if act != 'linear' or gain != 1 or clamp >= 0: 169 | dx = BiasActCudaGrad.apply(dy, x, b, y) 170 | 171 | if ctx.needs_input_grad[1]: 172 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 173 | 174 | return dx, db 175 | 176 | # Backward op. 177 | class BiasActCudaGrad(torch.autograd.Function): 178 | @staticmethod 179 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 180 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 181 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 182 | ctx.save_for_backward( 183 | dy if spec.has_2nd_grad else _null_tensor, 184 | x, b, y) 185 | return dx 186 | 187 | @staticmethod 188 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 189 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 190 | dy, x, b, y = ctx.saved_tensors 191 | d_dy = None 192 | d_x = None 193 | d_b = None 194 | d_y = None 195 | 196 | if ctx.needs_input_grad[0]: 197 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 198 | 199 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 200 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 201 | 202 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 203 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 204 | 205 | return d_dy, d_x, d_b, d_y 206 | 207 | # Add to cache. 
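For reference, a typical call into the bias_act() API documented above (illustrative shapes and values; on a CUDA tensor with the plugin built this dispatches to the cached BiasActCuda op, otherwise it falls back to _bias_act_ref()):

    import torch
    from torch_utils.ops import bias_act

    device = 'cuda' if torch.cuda.is_available() else 'cpu'   # CPU uses the reference path
    x = torch.randn(4, 512, 32, 32, device=device)            # activations
    b = torch.randn(512, device=device)                        # per-channel bias along dim=1
    y = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=2 ** 0.5, clamp=256)
    print(y.shape)                                             # same shape and dtype as x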
208 | _bias_act_cuda_cache[key] = BiasActCuda 209 | return BiasActCuda 210 | 211 | #---------------------------------------------------------------------------- 212 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Custom replacement for `torch.nn.functional.conv2d` that supports 12 | arbitrarily high order gradients with zero performance penalty.""" 13 | 14 | import contextlib 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 25 | 26 | @contextlib.contextmanager 27 | def no_weight_gradients(disable=True): 28 | global weight_gradients_disabled 29 | old = weight_gradients_disabled 30 | if disable: 31 | weight_gradients_disabled = True 32 | yield 33 | weight_gradients_disabled = old 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 38 | if _should_use_custom_op(input): 39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 41 | 42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 43 | if _should_use_custom_op(input): 44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _should_use_custom_op(input): 50 | assert isinstance(input, torch.Tensor) 51 | if (not enabled) or (not torch.backends.cudnn.enabled): 52 | return False 53 | if input.device.type != 'cuda': 54 | return False 55 | return True 56 | 57 | def _tuple_of_ints(xs, ndim): 58 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 59 | assert len(xs) == ndim 60 | assert all(isinstance(x, int) for x in xs) 61 | return xs 62 | 63 | #---------------------------------------------------------------------------- 64 | 65 | _conv2d_gradfix_cache = dict() 66 
| _null_tensor = torch.empty([0]) 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | ctx.save_for_backward( 112 | input if weight.requires_grad else _null_tensor, 113 | weight if input.requires_grad else _null_tensor, 114 | ) 115 | ctx.input_shape = input.shape 116 | 117 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 118 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 119 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 120 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 121 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 122 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 123 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 124 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 125 | 126 | # General case => cuDNN. 
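The fast path above rewrites a 1x1, stride-1, unpadded convolution as a plain matrix multiply over the channel dimension so it can be handled by cuBLAS. A small self-check of that equivalence for the groups=1 case (illustrative shapes):

    import torch

    N, Cin, Cout, H, W = 2, 8, 16, 5, 5
    x = torch.randn(N, Cin, H, W)
    w = torch.randn(Cout, Cin, 1, 1)

    ref = torch.nn.functional.conv2d(x, w)                                          # 1x1 convolution
    mat = (w.reshape(Cout, Cin) @ x.reshape(N, Cin, H * W)).reshape(N, Cout, H, W)  # channel-mixing matmul
    print(torch.allclose(ref, mat, atol=1e-5))                                      # True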
127 | if transpose: 128 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 129 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 130 | 131 | @staticmethod 132 | def backward(ctx, grad_output): 133 | input, weight = ctx.saved_tensors 134 | input_shape = ctx.input_shape 135 | grad_input = None 136 | grad_weight = None 137 | grad_bias = None 138 | 139 | if ctx.needs_input_grad[0]: 140 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 141 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 142 | grad_input = op.apply(grad_output, weight, None) 143 | assert grad_input.shape == input_shape 144 | 145 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 146 | grad_weight = Conv2dGradWeight.apply(grad_output, input, weight) 147 | assert grad_weight.shape == weight_shape 148 | 149 | if ctx.needs_input_grad[2]: 150 | grad_bias = grad_output.sum([0, 2, 3]) 151 | 152 | return grad_input, grad_weight, grad_bias 153 | 154 | # Gradient with respect to the weights. 155 | class Conv2dGradWeight(torch.autograd.Function): 156 | @staticmethod 157 | def forward(ctx, grad_output, input, weight): 158 | ctx.save_for_backward( 159 | grad_output if input.requires_grad else _null_tensor, 160 | input if grad_output.requires_grad else _null_tensor, 161 | ) 162 | ctx.grad_output_shape = grad_output.shape 163 | ctx.input_shape = input.shape 164 | 165 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 166 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 167 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 168 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 169 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 170 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 171 | 172 | # General case => cuDNN. 
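At module level this file exposes drop-in conv2d() / conv_transpose2d() replacements gated by the `enabled` flag; the surrounding training code wraps gradient-penalty terms in no_weight_gradients() so the weight-gradient pass is skipped. A hedged usage sketch (the custom op only engages on CUDA; elsewhere it falls back to torch.nn.functional.conv2d):

    import torch
    from torch_utils.ops import conv2d_gradfix

    conv2d_gradfix.enabled = True                        # opt in to the custom op

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(2, 3, 64, 64, device=device, requires_grad=True)
    w = torch.randn(8, 3, 3, 3, device=device, requires_grad=True)

    # Differentiate w.r.t. the input only, e.g. around an R1-style penalty.
    with conv2d_gradfix.no_weight_gradients():
        y = conv2d_gradfix.conv2d(x, w, padding=1)
        grad_x, = torch.autograd.grad(y.sum(), [x], create_graph=True)
    print(grad_x.shape)                                  # same shape as x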
173 | return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1] 174 | 175 | 176 | @staticmethod 177 | def backward(ctx, grad2_grad_weight): 178 | grad_output, input = ctx.saved_tensors 179 | grad_output_shape = ctx.grad_output_shape 180 | input_shape = ctx.input_shape 181 | grad2_grad_output = None 182 | grad2_input = None 183 | 184 | if ctx.needs_input_grad[0]: 185 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 186 | assert grad2_grad_output.shape == grad_output_shape 187 | 188 | if ctx.needs_input_grad[1]: 189 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 190 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 191 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 192 | assert grad2_input.shape == input_shape 193 | 194 | return grad2_grad_output, grad2_input 195 | 196 | _conv2d_gradfix_cache[key] = Conv2d 197 | return Conv2d 198 | 199 | #---------------------------------------------------------------------------- 200 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """2D convolution with optional up/downsampling.""" 12 | 13 | import torch 14 | 15 | from .. import misc 16 | from . import conv2d_gradfix 17 | from . import upfirdn2d 18 | from .upfirdn2d import _parse_padding 19 | from .upfirdn2d import _get_filter_size 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def _get_weight_shape(w): 24 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 25 | shape = [int(sz) for sz in w.shape] 26 | misc.assert_shape(w, shape) 27 | return shape 28 | 29 | #---------------------------------------------------------------------------- 30 | 31 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 32 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 33 | """ 34 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 35 | 36 | # Flip weight if requested. 37 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 38 | if not flip_weight and (kw > 1 or kh > 1): 39 | w = w.flip([2, 3]) 40 | 41 | # Execute using conv2d_gradfix. 
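As the note above says, torch.nn.functional.conv2d actually computes correlation; the wrapper therefore flips the kernel spatially when flip_weight=False is requested, which turns correlation into true convolution. A tiny demonstration of that flip:

    import torch

    x = torch.randn(1, 1, 7, 7)
    w = torch.randn(1, 1, 3, 3)

    corr = torch.nn.functional.conv2d(x, w)               # what F.conv2d computes (correlation)
    conv = torch.nn.functional.conv2d(x, w.flip([2, 3]))  # flipped kernel => true convolution
    print(corr.shape, conv.shape)                          # both [1, 1, 5, 5]; values differ unless w is symmetric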
42 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 43 | return op(x, w, stride=stride, padding=padding, groups=groups) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | @misc.profiled_function 48 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 49 | r"""2D convolution with optional up/downsampling. 50 | 51 | Padding is performed only once at the beginning, not between the operations. 52 | 53 | Args: 54 | x: Input tensor of shape 55 | `[batch_size, in_channels, in_height, in_width]`. 56 | w: Weight tensor of shape 57 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 58 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 59 | calling upfirdn2d.setup_filter(). None = identity (default). 60 | up: Integer upsampling factor (default: 1). 61 | down: Integer downsampling factor (default: 1). 62 | padding: Padding with respect to the upsampled image. Can be a single number 63 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 64 | (default: 0). 65 | groups: Split input channels into N groups (default: 1). 66 | flip_weight: False = convolution, True = correlation (default: True). 67 | flip_filter: False = convolution, True = correlation (default: False). 68 | 69 | Returns: 70 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 71 | """ 72 | # Validate arguments. 73 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 74 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 75 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 76 | assert isinstance(up, int) and (up >= 1) 77 | assert isinstance(down, int) and (down >= 1) 78 | assert isinstance(groups, int) and (groups >= 1) 79 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 80 | fw, fh = _get_filter_size(f) 81 | px0, px1, py0, py1 = _parse_padding(padding) 82 | 83 | # Adjust padding to account for up/downsampling. 84 | if up > 1: 85 | px0 += (fw + up - 1) // 2 86 | px1 += (fw - up) // 2 87 | py0 += (fh + up - 1) // 2 88 | py1 += (fh - up) // 2 89 | if down > 1: 90 | px0 += (fw - down + 1) // 2 91 | px1 += (fw - down) // 2 92 | py0 += (fh - down + 1) // 2 93 | py1 += (fh - down) // 2 94 | 95 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 96 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 97 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 98 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 99 | return x 100 | 101 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 102 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 103 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 104 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 105 | return x 106 | 107 | # Fast path: downsampling only => use strided convolution. 108 | if down > 1 and up == 1: 109 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 110 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 111 | return x 112 | 113 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
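Typical usage from the synthesis layers: a 3x3 convolution that also upsamples by 2 prepares the low-pass filter once with upfirdn2d.setup_filter() and passes it in. An illustrative sketch (the shapes and the [1,3,3,1] filter mirror the StyleGAN2 backbone defaults; flip_weight=False matches how the backbone calls this path when up > 1):

    import torch
    from torch_utils.ops import upfirdn2d, conv2d_resample

    f = upfirdn2d.setup_filter([1, 3, 3, 1])        # separable low-pass filter, prepared once
    x = torch.randn(4, 512, 16, 16)                 # [batch, in_channels, H, W]
    w = torch.randn(256, 512, 3, 3)                 # [out_channels, in_channels, kh, kw]

    y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1, flip_weight=False)
    print(y.shape)                                  # torch.Size([4, 256, 32, 32])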
114 | if up > 1: 115 | if groups == 1: 116 | w = w.transpose(0, 1) 117 | else: 118 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 119 | w = w.transpose(1, 2) 120 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 121 | px0 -= kw - 1 122 | px1 -= kw - up 123 | py0 -= kh - 1 124 | py1 -= kh - up 125 | pxt = max(min(-px0, -px1), 0) 126 | pyt = max(min(-py0, -py1), 0) 127 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 129 | if down > 1: 130 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 131 | return x 132 | 133 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 134 | if up == 1 and down == 1: 135 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 136 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 137 | 138 | # Fallback: Generic reference implementation. 139 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 140 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 141 | if down > 1: 142 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 143 | return x 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include 14 | 15 | //------------------------------------------------------------------------ 16 | // CUDA kernel parameters. 17 | 18 | struct filtered_lrelu_kernel_params 19 | { 20 | // These parameters decide which kernel to use. 21 | int up; // upsampling ratio (1, 2, 4) 22 | int down; // downsampling ratio (1, 2, 4) 23 | int2 fuShape; // [size, 1] | [size, size] 24 | int2 fdShape; // [size, 1] | [size, size] 25 | 26 | int _dummy; // Alignment. 27 | 28 | // Rest of the parameters. 29 | const void* x; // Input tensor. 30 | void* y; // Output tensor. 31 | const void* b; // Bias tensor. 32 | unsigned char* s; // Sign tensor in/out. NULL if unused. 33 | const float* fu; // Upsampling filter. 34 | const float* fd; // Downsampling filter. 35 | 36 | int2 pad0; // Left/top padding. 37 | float gain; // Additional gain factor. 38 | float slope; // Leaky ReLU slope on negative side. 39 | float clamp; // Clamp after nonlinearity. 40 | int flip; // Filter kernel flip for gradient computation. 41 | 42 | int tilesXdim; // Original number of horizontal output tiles. 43 | int tilesXrep; // Number of horizontal tiles per CTA. 
44 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 45 | 46 | int4 xShape; // [width, height, channel, batch] 47 | int4 yShape; // [width, height, channel, batch] 48 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 49 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 50 | int swLimit; // Active width of sign tensor in bytes. 51 | 52 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 53 | longlong4 yStride; // 54 | int64_t bStride; // 55 | longlong3 fuStride; // 56 | longlong3 fdStride; // 57 | }; 58 | 59 | struct filtered_lrelu_act_kernel_params 60 | { 61 | void* x; // Input/output, modified in-place. 62 | unsigned char* s; // Sign tensor in/out. NULL if unused. 63 | 64 | float gain; // Additional gain factor. 65 | float slope; // Leaky ReLU slope on negative side. 66 | float clamp; // Clamp after nonlinearity. 67 | 68 | int4 xShape; // [width, height, channel, batch] 69 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 70 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 71 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 72 | }; 73 | 74 | //------------------------------------------------------------------------ 75 | // CUDA kernel specialization. 76 | 77 | struct filtered_lrelu_kernel_spec 78 | { 79 | void* setup; // Function for filter kernel setup. 80 | void* exec; // Function for main operation. 81 | int2 tileOut; // Width/height of launch tile. 82 | int numWarps; // Number of warps per thread block, determines launch block size. 83 | int xrep; // For processing multiple horizontal tiles per thread block. 84 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 85 | }; 86 | 87 | //------------------------------------------------------------------------ 88 | // CUDA kernel selection. 89 | 90 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 91 | template void* choose_filtered_lrelu_act_kernel(void); 92 | template cudaError_t copy_filters(cudaStream_t stream); 93 | 94 | //------------------------------------------------------------------------ 95 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "filtered_lrelu.cu" 14 | 15 | // Template/kernel specializations for no signs mode (no gradients required). 16 | 17 | // Full op, 32-bit indexing. 
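// The explicit instantiations below (and their counterparts in filtered_lrelu_rd.cu
// and filtered_lrelu_wr.cu) pre-compile the fused upsample + bias + leaky ReLU +
// clamp + downsample kernel for each supported scalar type and for 32-bit vs. 64-bit
// tensor indexing. "Full op" covers the whole pipeline, while the "act" kernels only
// apply bias/LReLU/clamp and back the generic fallback path. Each of these .cu files
// simply re-includes filtered_lrelu.cu with a different sign-tensor mode
// (none / read / write), so one translation unit only builds one mode's kernels.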
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Full op, 64-bit indexing. 22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 24 | 25 | // Activation/signs only for generic variant. 64-bit indexing. 26 | template void* choose_filtered_lrelu_act_kernel(void); 27 | template void* choose_filtered_lrelu_act_kernel(void); 28 | template void* choose_filtered_lrelu_act_kernel(void); 29 | 30 | // Copy filters to constant memory. 31 | template cudaError_t copy_filters(cudaStream_t stream); 32 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include "filtered_lrelu.cu" 14 | 15 | // Template/kernel specializations for sign read mode. 16 | 17 | // Full op, 32-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Full op, 64-bit indexing. 22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 24 | 25 | // Activation/signs only for generic variant. 64-bit indexing. 26 | template void* choose_filtered_lrelu_act_kernel(void); 27 | template void* choose_filtered_lrelu_act_kernel(void); 28 | template void* choose_filtered_lrelu_act_kernel(void); 29 | 30 | // Copy filters to constant memory. 31 | template cudaError_t copy_filters(cudaStream_t stream); 32 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 
11 | */ 12 | 13 | #include "filtered_lrelu.cu" 14 | 15 | // Template/kernel specializations for sign write mode. 16 | 17 | // Full op, 32-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Full op, 64-bit indexing. 22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 24 | 25 | // Activation/signs only for generic variant. 64-bit indexing. 26 | template void* choose_filtered_lrelu_act_kernel(void); 27 | template void* choose_filtered_lrelu_act_kernel(void); 28 | template void* choose_filtered_lrelu_act_kernel(void); 29 | 30 | // Copy filters to constant memory. 31 | template cudaError_t copy_filters(cudaStream_t stream); 32 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 12 | 13 | import torch 14 | 15 | #---------------------------------------------------------------------------- 16 | 17 | def fma(a, b, c): # => a * b + c 18 | return _FusedMultiplyAdd.apply(a, b, c) 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 23 | @staticmethod 24 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 25 | out = torch.addcmul(c, a, b) 26 | ctx.save_for_backward(a, b) 27 | ctx.c_shape = c.shape 28 | return out 29 | 30 | @staticmethod 31 | def backward(ctx, dout): # pylint: disable=arguments-differ 32 | a, b = ctx.saved_tensors 33 | c_shape = ctx.c_shape 34 | da = None 35 | db = None 36 | dc = None 37 | 38 | if ctx.needs_input_grad[0]: 39 | da = _unbroadcast(dout * b, a.shape) 40 | 41 | if ctx.needs_input_grad[1]: 42 | db = _unbroadcast(dout * a, b.shape) 43 | 44 | if ctx.needs_input_grad[2]: 45 | dc = _unbroadcast(dout, c_shape) 46 | 47 | return da, db, dc 48 | 49 | #---------------------------------------------------------------------------- 50 | 51 | def _unbroadcast(x, shape): 52 | extra_dims = x.ndim - len(shape) 53 | assert extra_dims >= 0 54 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 55 | if len(dim): 56 | x = x.sum(dim=dim, keepdim=True) 57 | if extra_dims: 58 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 59 | assert x.shape == shape 60 | return x 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Custom replacement for `torch.nn.functional.grid_sample` that 12 | supports arbitrarily high order gradients between the input and output. 13 | Only works on 2D images and assumes 14 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 15 | 16 | import torch 17 | 18 | # pylint: disable=redefined-builtin 19 | # pylint: disable=arguments-differ 20 | # pylint: disable=protected-access 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | enabled = False # Enable the custom op by setting this to true. 
25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 62 | ctx.save_for_backward(grid) 63 | return grad_input, grad_grid 64 | 65 | @staticmethod 66 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 67 | _ = grad2_grad_grid # unused 68 | grid, = ctx.saved_tensors 69 | grad2_grad_output = None 70 | grad2_input = None 71 | grad2_grid = None 72 | 73 | if ctx.needs_input_grad[0]: 74 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 75 | 76 | assert not ctx.needs_input_grad[2] 77 | return grad2_grad_output, grad2_input, grad2_grid 78 | 79 | #---------------------------------------------------------------------------- 80 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | #include "upfirdn2d.h" 17 | 18 | //------------------------------------------------------------------------ 19 | 20 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 21 | { 22 | // Validate arguments. 
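    // The size checks below do double duty: besides catching malformed inputs, the
    // numel and memory-footprint limits of INT_MAX guarantee that 32-bit indexing is
    // sufficient, which is why the kernel parameters further down store shapes and
    // strides as plain ints.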
23 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 24 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 25 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 26 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 27 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 28 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 29 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 30 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 31 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 32 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 33 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 34 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 35 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 36 | 37 | // Create output tensor. 38 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 39 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 40 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 41 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 42 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 43 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 44 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 45 | 46 | // Initialize CUDA kernel parameters. 47 | upfirdn2d_kernel_params p; 48 | p.x = x.data_ptr(); 49 | p.f = f.data_ptr(); 50 | p.y = y.data_ptr(); 51 | p.up = make_int2(upx, upy); 52 | p.down = make_int2(downx, downy); 53 | p.pad0 = make_int2(padx0, pady0); 54 | p.flip = (flip) ? 1 : 0; 55 | p.gain = gain; 56 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 57 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 58 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 59 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 60 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 61 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 62 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 63 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 64 | 65 | // Choose CUDA kernel. 66 | upfirdn2d_kernel_spec spec; 67 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 68 | { 69 | spec = choose_upfirdn2d_kernel(p); 70 | }); 71 | 72 | // Set looping options. 73 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 74 | p.loopMinor = spec.loopMinor; 75 | p.loopX = spec.loopX; 76 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 77 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 78 | 79 | // Compute grid size. 
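    // Two launch shapes are used below: kernels specialized for a fixed output tile
    // (spec.tileOutW >= 0) are launched with 256 threads per block and a grid measured
    // in whole tiles, while the generic large-filter kernel (spec.tileOutW < 0) uses a
    // 4x32 thread block sized directly from the output dimensions. In both cases the
    // launchMinor/launchMajor factors derived above keep the batch*channel part of the
    // grid within CUDA launch limits.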
80 | dim3 blockSize, gridSize; 81 | if (spec.tileOutW < 0) // large 82 | { 83 | blockSize = dim3(4, 32, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | else // small 90 | { 91 | blockSize = dim3(256, 1, 1); 92 | gridSize = dim3( 93 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 94 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 95 | p.launchMajor); 96 | } 97 | 98 | // Launch CUDA kernel. 99 | void* args[] = {&p}; 100 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 101 | return y; 102 | } 103 | 104 | //------------------------------------------------------------------------ 105 | 106 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 107 | { 108 | m.def("upfirdn2d", &upfirdn2d); 109 | } 110 | 111 | //------------------------------------------------------------------------ 112 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 11 | */ 12 | 13 | #include 14 | 15 | //------------------------------------------------------------------------ 16 | // CUDA kernel parameters. 17 | 18 | struct upfirdn2d_kernel_params 19 | { 20 | const void* x; 21 | const float* f; 22 | void* y; 23 | 24 | int2 up; 25 | int2 down; 26 | int2 pad0; 27 | int flip; 28 | float gain; 29 | 30 | int4 inSize; // [width, height, channel, batch] 31 | int4 inStride; 32 | int2 filterSize; // [width, height] 33 | int2 filterStride; 34 | int4 outSize; // [width, height, channel, batch] 35 | int4 outStride; 36 | int sizeMinor; 37 | int sizeMajor; 38 | 39 | int loopMinor; 40 | int loopMajor; 41 | int loopX; 42 | int launchMinor; 43 | int launchMajor; 44 | }; 45 | 46 | //------------------------------------------------------------------------ 47 | // CUDA kernel specialization. 48 | 49 | struct upfirdn2d_kernel_spec 50 | { 51 | void* kernel; 52 | int tileOutW; 53 | int tileOutH; 54 | int loopMinor; 55 | int loopX; 56 | }; 57 | 58 | //------------------------------------------------------------------------ 59 | // CUDA kernel selection. 60 | 61 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 62 | 63 | //------------------------------------------------------------------------ 64 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Facilities for pickling Python code alongside other data. 12 | 13 | The pickled code is automatically imported into a separate Python module 14 | during unpickling. This way, any previously exported pickles will remain 15 | usable even if the original code is no longer available, or if the current 16 | version of the code is not consistent with what was originally pickled.""" 17 | 18 | import sys 19 | import pickle 20 | import io 21 | import inspect 22 | import copy 23 | import uuid 24 | import types 25 | import dnnlib 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | _version = 6 # internal version number 30 | _decorators = set() # {decorator_class, ...} 31 | _import_hooks = [] # [hook_function, ...] 32 | _module_to_src_dict = dict() # {module: src, ...} 33 | _src_to_module_dict = dict() # {src: module, ...} 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def persistent_class(orig_class): 38 | r"""Class decorator that extends a given class to save its source code 39 | when pickled. 40 | 41 | Example: 42 | 43 | from torch_utils import persistence 44 | 45 | @persistence.persistent_class 46 | class MyNetwork(torch.nn.Module): 47 | def __init__(self, num_inputs, num_outputs): 48 | super().__init__() 49 | self.fc = MyLayer(num_inputs, num_outputs) 50 | ... 51 | 52 | @persistence.persistent_class 53 | class MyLayer(torch.nn.Module): 54 | ... 55 | 56 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 57 | source code alongside other internal state (e.g., parameters, buffers, 58 | and submodules). This way, any previously exported pickle will remain 59 | usable even if the class definitions have been modified or are no 60 | longer available. 61 | 62 | The decorator saves the source code of the entire Python module 63 | containing the decorated class. It does *not* save the source code of 64 | any imported modules. Thus, the imported modules must be available 65 | during unpickling, also including `torch_utils.persistence` itself. 66 | 67 | It is ok to call functions defined in the same module from the 68 | decorated class. However, if the decorated class depends on other 69 | classes defined in the same module, they must be decorated as well. 70 | This is illustrated in the above example in the case of `MyLayer`. 71 | 72 | It is also possible to employ the decorator just-in-time before 73 | calling the constructor. For example: 74 | 75 | cls = MyLayer 76 | if want_to_make_it_persistent: 77 | cls = persistence.persistent_class(cls) 78 | layer = cls(num_inputs, num_outputs) 79 | 80 | As an additional feature, the decorator also keeps track of the 81 | arguments that were used to construct each instance of the decorated 82 | class. The arguments can be queried via `obj.init_args` and 83 | `obj.init_kwargs`, and they are automatically pickled alongside other 84 | object state. 
A typical use case is to first unpickle a previous 85 | instance of a persistent class, and then upgrade it to use the latest 86 | version of the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | self._init_args = copy.deepcopy(args) 108 | self._init_kwargs = copy.deepcopy(kwargs) 109 | assert orig_class.__name__ in orig_module.__dict__ 110 | _check_pickleable(self.__reduce__()) 111 | 112 | @property 113 | def init_args(self): 114 | return copy.deepcopy(self._init_args) 115 | 116 | @property 117 | def init_kwargs(self): 118 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 119 | 120 | def __reduce__(self): 121 | fields = list(super().__reduce__()) 122 | fields += [None] * max(3 - len(fields), 0) 123 | if fields[0] is not _reconstruct_persistent_obj: 124 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 125 | fields[0] = _reconstruct_persistent_obj # reconstruct func 126 | fields[1] = (meta,) # reconstruct args 127 | fields[2] = None # state dict 128 | return tuple(fields) 129 | 130 | Decorator.__name__ = orig_class.__name__ 131 | _decorators.add(Decorator) 132 | return Decorator 133 | 134 | #---------------------------------------------------------------------------- 135 | 136 | def is_persistent(obj): 137 | r"""Test whether the given object or class is persistent, i.e., 138 | whether it will save its source code when pickled. 139 | """ 140 | try: 141 | if obj in _decorators: 142 | return True 143 | except TypeError: 144 | pass 145 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 146 | 147 | #---------------------------------------------------------------------------- 148 | 149 | def import_hook(hook): 150 | r"""Register an import hook that is called whenever a persistent object 151 | is being unpickled. A typical use case is to patch the pickled source 152 | code to avoid errors and inconsistencies when the API of some imported 153 | module has changed. 154 | 155 | The hook should have the following signature: 156 | 157 | hook(meta) -> modified meta 158 | 159 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 160 | 161 | type: Type of the persistent object, e.g. `'class'`. 162 | version: Internal version number of `torch_utils.persistence`. 163 | module_src Original source code of the Python module. 164 | class_name: Class name in the original Python module. 165 | state: Internal state of the object. 166 | 167 | Example: 168 | 169 | @persistence.import_hook 170 | def wreck_my_network(meta): 171 | if meta.class_name == 'MyNetwork': 172 | print('MyNetwork is being imported. 
I will wreck it!') 173 | meta.module_src = meta.module_src.replace("True", "False") 174 | return meta 175 | """ 176 | assert callable(hook) 177 | _import_hooks.append(hook) 178 | 179 | #---------------------------------------------------------------------------- 180 | 181 | def _reconstruct_persistent_obj(meta): 182 | r"""Hook that is called internally by the `pickle` module to unpickle 183 | a persistent object. 184 | """ 185 | meta = dnnlib.EasyDict(meta) 186 | meta.state = dnnlib.EasyDict(meta.state) 187 | for hook in _import_hooks: 188 | meta = hook(meta) 189 | assert meta is not None 190 | 191 | assert meta.version == _version 192 | module = _src_to_module(meta.module_src) 193 | 194 | assert meta.type == 'class' 195 | orig_class = module.__dict__[meta.class_name] 196 | decorator_class = persistent_class(orig_class) 197 | obj = decorator_class.__new__(decorator_class) 198 | 199 | setstate = getattr(obj, '__setstate__', None) 200 | if callable(setstate): 201 | setstate(meta.state) # pylint: disable=not-callable 202 | else: 203 | obj.__dict__.update(meta.state) 204 | return obj 205 | 206 | #---------------------------------------------------------------------------- 207 | 208 | def _module_to_src(module): 209 | r"""Query the source code of a given Python module. 210 | """ 211 | src = _module_to_src_dict.get(module, None) 212 | if src is None: 213 | src = inspect.getsource(module) 214 | _module_to_src_dict[module] = src 215 | _src_to_module_dict[src] = module 216 | return src 217 | 218 | def _src_to_module(src): 219 | r"""Get or create a Python module for the given source code. 220 | """ 221 | module = _src_to_module_dict.get(src, None) 222 | if module is None: 223 | module_name = "_imported_module_" + uuid.uuid4().hex 224 | module = types.ModuleType(module_name) 225 | sys.modules[module_name] = module 226 | _module_to_src_dict[module] = src 227 | _src_to_module_dict[src] = module 228 | exec(src, module.__dict__) # pylint: disable=exec-used 229 | return module 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def _check_pickleable(obj): 234 | r"""Check that the given object is pickleable, raising an exception if 235 | it is not. This function is expected to be considerably more efficient 236 | than actually pickling the object. 237 | """ 238 | def recurse(obj): 239 | if isinstance(obj, (list, tuple, set)): 240 | return [recurse(x) for x in obj] 241 | if isinstance(obj, dict): 242 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 243 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 244 | return None # Python primitive types are pickleable. 245 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 246 | return None # NumPy arrays and PyTorch tensors are pickleable. 247 | if is_persistent(obj): 248 | return None # Persistent objects are pickleable, by virtue of the constructor check. 249 | return obj 250 | with io.BytesIO() as f: 251 | pickle.dump(recurse(obj), f) 252 | 253 | #---------------------------------------------------------------------------- 254 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /training/crosssection_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import torch 12 | 13 | def sample_cross_section(G, ws, resolution=256, w=1.2): 14 | axis=0 15 | A, B = torch.meshgrid(torch.linspace(w/2, -w/2, resolution, device=ws.device), torch.linspace(-w/2, w/2, resolution, device=ws.device), indexing='ij') 16 | A, B = A.reshape(-1, 1), B.reshape(-1, 1) 17 | C = torch.zeros_like(A) 18 | coordinates = [A, B] 19 | coordinates.insert(axis, C) 20 | coordinates = torch.cat(coordinates, dim=-1).expand(ws.shape[0], -1, -1) 21 | 22 | sigma = G.sample_mixed(coordinates, torch.randn_like(coordinates), ws)['sigma'] 23 | return sigma.reshape(-1, 1, resolution, resolution) 24 | 25 | # if __name__ == '__main__': 26 | # sample_crossection(None) -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import atan2, asin, cos 3 | 4 | def P2sRt(P): 5 | """ decompositing camera matrix P. 6 | Args: 7 | P: (3, 4). Affine Camera Matrix. 8 | Returns: 9 | s: scale factor. 10 | R: (3, 3). rotation matrix. 11 | t2d: (2,). 2d translation. 12 | """ 13 | t3d = P[:, 3] 14 | R1 = P[0:1, :3] 15 | R2 = P[1:2, :3] 16 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 17 | r1 = R1 / np.linalg.norm(R1) 18 | r2 = R2 / np.linalg.norm(R2) 19 | r3 = np.cross(r1, r2) 20 | 21 | R = np.concatenate((r1, r2, r3), 0) 22 | return s, R, t3d 23 | 24 | def matrix2angle(R): 25 | """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf 26 | refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv 27 | todo: check and debug 28 | Args: 29 | R: (3,3). 
rotation matrix 30 | Returns: 31 | x: yaw 32 | y: pitch 33 | z: roll 34 | """ 35 | if R[2, 0] > 0.998: 36 | z = 0 37 | x = np.pi / 2 38 | y = z + atan2(-R[0, 1], -R[0, 2]) 39 | elif R[2, 0] < -0.998: 40 | z = 0 41 | x = -np.pi / 2 42 | y = -z + atan2(R[0, 1], R[0, 2]) 43 | else: 44 | x = asin(R[2, 0]) 45 | y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) 46 | z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) 47 | 48 | if abs(y) > np.pi/2: 49 | if x > 0: 50 | x = np.pi - x 51 | else: 52 | x = -np.pi - x 53 | y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) 54 | z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) 55 | return x, y, z 56 | 57 | def calc_pose(param): 58 | P = param[:12].reshape(3, -1) # camera matrix 59 | s, R, t3d = P2sRt(P) 60 | P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale 61 | pose = matrix2angle(R) 62 | pose = [p * 180 / np.pi for p in pose] 63 | 64 | return P, pose 65 | 66 | def get_poseangle(eg3dparams): 67 | 68 | convert = np.array([ 69 | [1, 0, 0, 0], 70 | [0, -1, 0, 0], 71 | [0, 0, -1, 0], 72 | [0, 0, 0, 1], 73 | ]).astype(np.float32) 74 | 75 | entry_cam = np.array([float(p) for p in eg3dparams][:16]).reshape((4,4)) 76 | 77 | world2cam = np.linalg.inv(entry_cam@convert) 78 | pose = matrix2angle(world2cam[:3,:3]) 79 | angle = [p * 180 / np.pi for p in pose] 80 | 81 | return angle 82 | -------------------------------------------------------------------------------- /training/volumetric_rendering/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty -------------------------------------------------------------------------------- /training/volumetric_rendering/math_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Petr Kellnhofer 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
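# Small tensor-math helpers shared by the volumetric renderer: homogeneous vector
# transforms, vector normalization, a batched ray / axis-aligned-box intersection
# test (slab method), and a linspace variant that broadcasts over tensors of
# start/stop values, e.g. to generate per-ray sample depths in a single call.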
22 | 23 | import torch 24 | 25 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: 26 | """ 27 | Left-multiplies MxM @ NxM. Returns NxM. 28 | """ 29 | res = torch.matmul(vectors4, matrix.T) 30 | return res 31 | 32 | 33 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Normalize vector lengths. 36 | """ 37 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) 38 | 39 | def torch_dot(x: torch.Tensor, y: torch.Tensor): 40 | """ 41 | Dot product of two tensors. 42 | """ 43 | return (x * y).sum(-1) 44 | 45 | 46 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): 47 | """ 48 | Author: Petr Kellnhofer 49 | Intersects rays with the [-1, 1] NDC volume. 50 | Returns min and max distance of entry. 51 | Returns -1 for no intersection. 52 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection 53 | """ 54 | o_shape = rays_o.shape 55 | rays_o = rays_o.detach().reshape(-1, 3) 56 | rays_d = rays_d.detach().reshape(-1, 3) 57 | 58 | 59 | bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] 60 | bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] 61 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) 62 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) 63 | 64 | # Precompute inverse for stability. 65 | invdir = 1 / rays_d 66 | sign = (invdir < 0).long() 67 | 68 | # Intersect with YZ plane. 69 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 70 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 71 | 72 | # Intersect with XZ plane. 73 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 74 | tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 75 | 76 | # Resolve parallel rays. 77 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False 78 | 79 | # Use the shortest intersection. 80 | tmin = torch.max(tmin, tymin) 81 | tmax = torch.min(tmax, tymax) 82 | 83 | # Intersect with XY plane. 84 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 85 | tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 86 | 87 | # Resolve parallel rays. 88 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False 89 | 90 | # Use the shortest intersection. 91 | tmin = torch.max(tmin, tzmin) 92 | tmax = torch.min(tmax, tzmax) 93 | 94 | # Mark invalid. 95 | tmin[torch.logical_not(is_valid)] = -1 96 | tmax[torch.logical_not(is_valid)] = -2 97 | 98 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) 99 | 100 | 101 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): 102 | """ 103 | Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. 104 | Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. 
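    Example (shapes only): with `start` and `stop` of shape [R, 3] holding per-ray
    near and far points, linspace(start, stop, num) returns a [num, R, 3] tensor of
    evenly spaced samples along each ray.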
105 | """ 106 | # create a tensor of 'num' steps from 0 to 1 107 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) 108 | 109 | # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings 110 | # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript 111 | # "cannot statically infer the expected size of a list in this contex", hence the code below 112 | for i in range(start.ndim): 113 | steps = steps.unsqueeze(-1) 114 | 115 | # the output starts at 'start' and increments until 'stop' in each dimension 116 | out = start[None] + steps * (stop - start)[None] 117 | 118 | return out 119 | -------------------------------------------------------------------------------- /training/volumetric_rendering/ray_marcher.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. 13 | Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) 14 | """ 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch_utils import persistence 20 | 21 | @persistence.persistent_class 22 | class MipRayMarcher2(nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | 26 | 27 | def run_forward(self, colors, densities, depths, rendering_options): 28 | deltas = depths[:, :, 1:] - depths[:, :, :-1] 29 | colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 30 | densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 31 | depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 32 | 33 | 34 | if rendering_options['clamp_mode'] == 'softplus': 35 | densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better 36 | else: 37 | assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!" 
38 | 39 | density_delta = densities_mid * deltas 40 | 41 | alpha = 1 - torch.exp(-density_delta) 42 | 43 | alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) 44 | weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] 45 | 46 | composite_rgb = torch.sum(weights * colors_mid, -2) 47 | weight_total = weights.sum(2) 48 | composite_depth = torch.sum(weights * depths_mid, -2) / weight_total 49 | 50 | # clip the composite to min/max range of depths 51 | composite_depth = torch.nan_to_num(composite_depth, float('inf')) 52 | composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) 53 | 54 | if rendering_options.get('white_back', False): 55 | composite_rgb = composite_rgb + 1 - weight_total 56 | 57 | return composite_rgb, composite_depth, weights 58 | 59 | 60 | def forward(self, colors, densities, depths, rendering_options): 61 | composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) 62 | 63 | return composite_rgb, composite_depth, weights -------------------------------------------------------------------------------- /training/volumetric_rendering/ray_sampler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | The ray sampler is a module that takes in camera matrices and resolution and batches of rays. 13 | Expects cam2world matrices that use the OpenCV camera coordinate system conventions. 14 | """ 15 | 16 | import torch 17 | 18 | class RaySampler(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None 22 | 23 | 24 | def forward(self, cam2world_matrix, intrinsics, resolution): 25 | """ 26 | Create batches of rays and return origins and directions. 
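        Rays are cast through pixel centers (hence the 0.5/resolution offset below),
        M equals resolution**2, and the returned ray directions are unit-length
        vectors in world space.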
27 | 28 | cam2world_matrix: (N, 4, 4) 29 | intrinsics: (N, 3, 3) 30 | resolution: int 31 | 32 | ray_origins: (N, M, 3) 33 | ray_dirs: (N, M, 2) 34 | """ 35 | N, M = cam2world_matrix.shape[0], resolution**2 36 | cam_locs_world = cam2world_matrix[:, :3, 3] 37 | fx = intrinsics[:, 0, 0] 38 | fy = intrinsics[:, 1, 1] 39 | cx = intrinsics[:, 0, 2] 40 | cy = intrinsics[:, 1, 2] 41 | sk = intrinsics[:, 0, 1] 42 | 43 | uv = torch.stack(torch.meshgrid(torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), indexing='ij')) * (1./resolution) + (0.5/resolution) 44 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0) 45 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) 46 | 47 | x_cam = uv[:, :, 0].view(N, -1) 48 | y_cam = uv[:, :, 1].view(N, -1) 49 | z_cam = torch.ones((N, M), device=cam2world_matrix.device) 50 | 51 | x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam 52 | y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam 53 | 54 | cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) 55 | 56 | world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] 57 | 58 | ray_dirs = world_rel_points - cam_locs_world[:, None, :] 59 | ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) 60 | 61 | ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) 62 | 63 | return ray_origins, ray_dirs --------------------------------------------------------------------------------