├── .gitignore ├── .gitmodules ├── LICENSE.txt ├── README.md ├── dataset_preprocessing ├── afhq │ ├── preprocess_afhq_cameras.py │ └── runme.py ├── ffhq │ ├── 3dface2idr_mat.py │ ├── align_multiprocess.py │ ├── batch_mtcnn.py │ ├── crop_images.py │ ├── crop_images_in_the_wild.py │ ├── download_ffhq.py │ ├── preprocess.py │ ├── preprocess_face_cameras.py │ ├── preprocess_in_the_wild.py │ ├── runme.py │ └── validate_ffhq.py ├── mirror_dataset.py ├── rebalance_ffhq │ ├── num_replicas.json │ └── rebalance_ffhq_dataset.py └── shapenet_cars │ ├── preprocess_shapenet_cameras.py │ └── run_me.py ├── docs ├── camera_conventions.md ├── camera_coordinate_conventions.jpg ├── models.md ├── teaser.jpeg ├── training_guide.md ├── visualizer.png └── visualizer_guide.md └── eg3d ├── calc_metrics.py ├── camera_utils.py ├── dataset_tool.py ├── dnnlib ├── __init__.py └── util.py ├── environment.yml ├── gen_samples.py ├── gen_videos.py ├── gui_utils ├── __init__.py ├── gl_utils.py ├── glfw_window.py ├── imgui_utils.py ├── imgui_window.py └── text_utils.py ├── legacy.py ├── metrics ├── __init__.py ├── equivariance.py ├── frechet_inception_distance.py ├── inception_score.py ├── kernel_inception_distance.py ├── metric_main.py ├── metric_utils.py ├── perceptual_path_length.py └── precision_recall.py ├── shape_utils.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── train.py ├── training ├── __init__.py ├── augment.py ├── crosssection_utils.py ├── dataset.py ├── dual_discriminator.py ├── loss.py ├── networks_stylegan2.py ├── networks_stylegan3.py ├── superresolution.py ├── training_loop.py ├── triplane.py └── volumetric_rendering │ ├── __init__.py │ ├── math_utils.py │ ├── ray_marcher.py │ ├── ray_sampler.py │ └── renderer.py ├── visualizer.py └── viz ├── __init__.py ├── backbone_cache_widget.py ├── capture_widget.py ├── conditioning_pose_widget.py ├── latent_widget.py ├── layer_widget.py ├── performance_widget.py ├── pickle_widget.py ├── pose_widget.py ├── render_depth_sample_widget.py ├── render_type_widget.py ├── renderer.py ├── stylemix_widget.py ├── trunc_noise_widget.py └── zoom_widget.py /.gitignore: -------------------------------------------------------------------------------- 1 | images/ 2 | .DS_Store 3 | .ipynb_checkpoints 4 | data 5 | output 6 | debug 7 | *.pyc 8 | deep-head-pose 9 | EvalImages 10 | cache* 11 | *.pkl 12 | gif 13 | archive 14 | *.ply 15 | eval 16 | out 17 | 18 | # evaluation: 19 | temp/ 20 | shapes/ 21 | imgs/ 22 | vids/ 23 | *.mp4 24 | 25 | stylegan3/results 26 | eg3d/results 27 | eg3d_results 28 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch"] 2 | path = dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch 3 | url = https://github.com/sicxu/Deep3DFaceRecon_pytorch.git 4 | -------------------------------------------------------------------------------- /LICENSE.txt: 
-------------------------------------------------------------------------------- 1 | Copyright (c) 2021-2022, NVIDIA Corporation & affiliates. All rights 2 | reserved. 3 | 4 | 5 | NVIDIA Source Code License for EG3D 6 | 7 | 8 | ======================================================================= 9 | 10 | 1. Definitions 11 | 12 | "Licensor" means any person or entity that distributes its Work. 13 | 14 | "Software" means the original work of authorship made available under 15 | this License. 16 | 17 | "Work" means the Software and any additions to or derivative works of 18 | the Software that are made available under this License. 19 | 20 | The terms "reproduce," "reproduction," "derivative works," and 21 | "distribution" have the meaning as provided under U.S. copyright law; 22 | provided, however, that for the purposes of this License, derivative 23 | works shall not include works that remain separable from, or merely 24 | link (or bind by name) to the interfaces of, the Work. 25 | 26 | Works, including the Software, are "made available" under this License 27 | by including in or with the Work either (a) a copyright notice 28 | referencing the applicability of this License to the Work, or (b) a 29 | copy of this License. 30 | 31 | 2. License Grants 32 | 33 | 2.1 Copyright Grant. Subject to the terms and conditions of this 34 | License, each Licensor grants to you a perpetual, worldwide, 35 | non-exclusive, royalty-free, copyright license to reproduce, 36 | prepare derivative works of, publicly display, publicly perform, 37 | sublicense and distribute its Work and any resulting derivative 38 | works in any form. 39 | 40 | 3. Limitations 41 | 42 | 3.1 Redistribution. You may reproduce or distribute the Work only 43 | if (a) you do so under this License, (b) you include a complete 44 | copy of this License with your distribution, and (c) you retain 45 | without modification any copyright, patent, trademark, or 46 | attribution notices that are present in the Work. 47 | 48 | 3.2 Derivative Works. You may specify that additional or different 49 | terms apply to the use, reproduction, and distribution of your 50 | derivative works of the Work ("Your Terms") only if (a) Your Terms 51 | provide that the use limitation in Section 3.3 applies to your 52 | derivative works, and (b) you identify the specific derivative 53 | works that are subject to Your Terms. Notwithstanding Your Terms, 54 | this License (including the redistribution requirements in Section 55 | 3.1) will continue to apply to the Work itself. 56 | 57 | 3.3 Use Limitation. The Work and any derivative works thereof only 58 | may be used or intended for use non-commercially. The Work or 59 | derivative works thereof may be used or intended for use by NVIDIA 60 | or it’s affiliates commercially or non-commercially. As used 61 | herein, "non-commercially" means for research or evaluation 62 | purposes only and not for any direct or indirect monetary gain. 63 | 64 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 65 | against any Licensor (including any claim, cross-claim or 66 | counterclaim in a lawsuit) to enforce any patents that you allege 67 | are infringed by any Work, then your rights under this License from 68 | such Licensor (including the grants in Sections 2.1) will terminate 69 | immediately. 70 | 71 | 3.5 Trademarks. 
This License does not grant any rights to use any 72 | Licensor’s or its affiliates’ names, logos, or trademarks, except 73 | as necessary to reproduce the notices described in this License. 74 | 75 | 3.6 Termination. If you violate any term of this License, then your 76 | rights under this License (including the grants in Sections 2.1) 77 | will terminate immediately. 78 | 79 | 4. Disclaimer of Warranty. 80 | 81 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 82 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 83 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 84 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 85 | THIS LICENSE. 86 | 87 | 5. Limitation of Liability. 88 | 89 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 90 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 91 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 92 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 93 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 94 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 95 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 96 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 97 | THE POSSIBILITY OF SUCH DAMAGES. 98 | 99 | ======================================================================= 100 | -------------------------------------------------------------------------------- /dataset_preprocessing/afhq/preprocess_afhq_cameras.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | import json 12 | import numpy as np 13 | from PIL import Image, ImageOps 14 | import os 15 | from tqdm import tqdm 16 | import argparse 17 | 18 | 19 | def gen_pose(rot_mat): 20 | rot_mat = np.array(rot_mat).copy() 21 | forward = rot_mat[:, 2] 22 | translation = forward * -2.7 23 | pose = np.array([ 24 | [rot_mat[0, 0], rot_mat[0, 1], rot_mat[0, 2], translation[0]], 25 | [rot_mat[1, 0], rot_mat[1, 1], rot_mat[1, 2], translation[1]], 26 | [rot_mat[2, 0], rot_mat[2, 1], rot_mat[2, 2], translation[2]], 27 | [0, 0, 0, 1], 28 | ]) 29 | return pose 30 | 31 | def flip_yaw(pose_matrix): 32 | flipped = pose_matrix.copy() 33 | flipped[0, 1] *= -1 34 | flipped[0, 2] *= -1 35 | flipped[1, 0] *= -1 36 | flipped[2, 0] *= -1 37 | flipped[0, 3] *= -1 38 | return flipped 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--source", type=str) 42 | parser.add_argument("--dest", type=str, default=None) 43 | parser.add_argument("--max_images", type=int, default=None) 44 | args = parser.parse_args() 45 | 46 | camera_dataset_file = os.path.join(args.source, 'cameras.json') 47 | 48 | with open(camera_dataset_file, "r") as f: 49 | cameras = json.load(f) 50 | 51 | dataset = {'labels':[]} 52 | max_images = args.max_images if args.max_images is not None else len(cameras) 53 | for i, filename in tqdm(enumerate(cameras), total=max_images): 54 | if (max_images is not None and i >= max_images): break 55 | 56 | rot_mat = cameras[filename] 57 | pose = gen_pose(rot_mat) 58 | intrinsics = np.array([ 59 | [4.2647, 0.00000000e+00, 0.5], 60 | [0.00000000e+00, 4.2647, 0.5], 61 | [0.00000000e+00, 0.00000000e+00, 1.00000000e+00] 62 | ]) 63 | label = np.concatenate([pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 64 | 65 | filename = filename + '.png' 66 | image_path = os.path.join(args.source, filename) 67 | img = Image.open(image_path) 68 | dataset["labels"].append([filename, label]) 69 | 70 | flipped_img = ImageOps.mirror(img) 71 | flipped_pose = flip_yaw(pose) 72 | label = np.concatenate([flipped_pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 73 | base, ext = filename.split('.')[0], '.' + filename.split('.')[1] 74 | flipped_filename = base + '_mirror' + ext 75 | dataset["labels"].append([flipped_filename, label]) 76 | flipped_img.save(os.path.join(args.dest, flipped_filename)) 77 | 78 | with open(os.path.join(args.dest, 'dataset.json'), "w") as f: 79 | json.dump(dataset, f) 80 | -------------------------------------------------------------------------------- /dataset_preprocessing/afhq/runme.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
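# A minimal sanity check (illustrative sketch, not part of the original pipeline) of the
# camera conventions used in preprocess_afhq_cameras.py above: gen_pose() puts the camera
# on a sphere of radius 2.7 looking at the origin, flip_yaw() mirrors a pose about the
# x=0 plane, and the hard-coded normalized focal 4.2647 is numerically the FFHQ pixel
# focal divided by the 700px crop used elsewhere in these scripts (2985.29 / 700).
import numpy as np

rot = np.eye(3)                      # identity camera rotation
cam_pos = rot[:, 2] * -2.7           # the translation rule from gen_pose()
assert np.isclose(np.linalg.norm(cam_pos), 2.7)          # camera sits on the radius-2.7 sphere

pose = np.eye(4)
pose[:3, :3], pose[:3, 3] = rot, cam_pos
flipped = pose.copy()
for r, c in [(0, 1), (0, 2), (1, 0), (2, 0), (0, 3)]:    # same sign flips as flip_yaw()
    flipped[r, c] *= -1
assert np.isclose(np.linalg.norm(flipped[:3, 3]), 2.7)   # mirrored camera stays on the sphere

assert np.isclose(2985.29 / 700, 4.2647, atol=1e-4)      # the normalized focal used above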
10 | 11 | import argparse 12 | import os 13 | import sys 14 | import shutil 15 | import tempfile 16 | import subprocess 17 | 18 | import gdown 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('inzip', type=str) # the AFHQ zip downloaded from starganV2 (https://github.com/clovaai/stargan-v2) 22 | parser.add_argument('outzip', type=str, nargs='?', default='processed_afhq.zip') # this is the output path to write the new zip 23 | args = parser.parse_args() 24 | 25 | 26 | eg3d_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..') 27 | 28 | input_dataset_path = os.path.realpath(args.inzip) 29 | output_dataset_path = os.path.realpath(args.outzip) 30 | 31 | dataset_tool_path = os.path.join(eg3d_root, 'eg3d', 'dataset_tool.py') 32 | mirror_tool_path = os.path.join(eg3d_root, 'dataset_preprocessing', 'mirror_dataset.py') 33 | 34 | # Attempt to import dataset_tool.py and mirror_dataset.py to fail fast on errors (e.g., missing Python modules) before any processing 35 | try: 36 | sys.path.append(os.path.dirname(dataset_tool_path)) 37 | import dataset_tool 38 | sys.path.append(os.path.dirname(mirror_tool_path)) 39 | import mirror_dataset 40 | except Exception as e: 41 | print(e) 42 | print("There was a problem while importing the dataset_tool. Are you in the correct virtual environment?") 43 | sys.exit(1) 44 | 45 | 46 | with tempfile.TemporaryDirectory() as working_dir: 47 | cmd = f""" 48 | unzip {input_dataset_path} -d {working_dir}/extracted_images; 49 | mv {working_dir}/extracted_images/train/cat/ {working_dir}/cat_images/; 50 | """ 51 | subprocess.run([cmd], shell=True, check=True) 52 | 53 | 54 | """Download dataset.json file""" 55 | json_url = 'https://drive.google.com/file/d/1FQXQ26kAgRyN2iOH8CBl3P9CGPIQ5TAQ/view?usp=sharing' 56 | gdown.download(json_url, f'{working_dir}/cat_images/dataset.json', quiet=False, fuzzy=True) 57 | 58 | 59 | print("Mirroring dataset...") 60 | cmd = f""" 61 | python {mirror_tool_path} \ 62 | --source={working_dir}/cat_images \ 63 | --dest={working_dir}/mirrored_images 64 | """ 65 | subprocess.run([cmd], shell=True, check=True) 66 | 67 | 68 | print("Creating dataset zip...") 69 | cmd = f""" 70 | python {dataset_tool_path} \ 71 | --source {working_dir}/mirrored_images \ 72 | --dest {output_dataset_path} \ 73 | --resolution 512x512 74 | """ 75 | subprocess.run([cmd], shell=True, check=True) 76 | -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/3dface2idr_mat.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited.
10 | 11 | 12 | import numpy as np 13 | import os 14 | import torch 15 | import json 16 | import argparse 17 | import scipy.io 18 | import sys 19 | sys.path.append('Deep3DFaceRecon_pytorch') 20 | from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--in_root', type=str, default="", help='process folder') 24 | parser.add_argument('--out_path', type=str, default="cameras.json", help='output filename') 25 | args = parser.parse_args() 26 | in_root = args.in_root 27 | 28 | npys = sorted([x for x in os.listdir(in_root) if x.endswith(".mat")]) 29 | 30 | mode = 1 31 | outAll={} 32 | 33 | face_model = ParametricFaceModel(bfm_folder='Deep3DFaceRecon_pytorch/BFM') 34 | 35 | for src_filename in npys: 36 | src = os.path.join(in_root, src_filename) 37 | 38 | dict_load = scipy.io.loadmat(src) 39 | angle = dict_load['angle'] 40 | trans = dict_load['trans'][0] 41 | R = face_model.compute_rotation(torch.from_numpy(angle))[0].numpy() 42 | trans[2] += -10 43 | c = -np.dot(R, trans) 44 | pose = np.eye(4) 45 | pose[:3, :3] = R 46 | 47 | c *= 0.27 # normalize camera radius 48 | c[1] += 0.006 # additional offset used in submission 49 | c[2] += 0.161 # additional offset used in submission 50 | pose[0,3] = c[0] 51 | pose[1,3] = c[1] 52 | pose[2,3] = c[2] 53 | 54 | focal = 2985.29 # = 1015*1024/224*(300/466.285)# 55 | pp = 512#112 56 | w = 1024#224 57 | h = 1024#224 58 | 59 | count = 0 60 | K = np.eye(3) 61 | K[0][0] = focal 62 | K[1][1] = focal 63 | K[0][2] = w/2.0 64 | K[1][2] = h/2.0 65 | K = K.tolist() 66 | 67 | Rot = np.eye(3) 68 | Rot[0, 0] = 1 69 | Rot[1, 1] = -1 70 | Rot[2, 2] = -1 71 | pose[:3, :3] = np.dot(pose[:3, :3], Rot) 72 | 73 | pose = pose.tolist() 74 | out = {} 75 | out["intrinsics"] = K 76 | out["pose"] = pose 77 | outAll[src_filename.replace(".mat", ".png")] = out 78 | 79 | 80 | with open(args.out_path, "w") as outfile: 81 | json.dump(outAll, outfile) 82 | -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/batch_mtcnn.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
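# What the Rot = diag(1, -1, -1) right-multiplication in 3dface2idr_mat.py above does,
# as a minimal sketch (my reading of the convention change, with an identity rotation as
# a stand-in value): right-multiplying negates the rotation's y and z columns, i.e. flips
# the camera's y and z basis vectors, switching between a y-up/z-backward convention and
# the OpenCV cam2world convention documented in docs/camera_conventions.md.
import numpy as np

R = np.eye(3)                            # camera basis vectors as columns (x, y, z)
R_cv = R @ np.diag([1.0, -1.0, -1.0])    # right-multiply, as in the script
assert np.allclose(R_cv[:, 0], [1, 0, 0])    # x axis unchanged
assert np.allclose(R_cv[:, 1], [0, -1, 0])   # y axis flipped
assert np.allclose(R_cv[:, 2], [0, 0, -1])   # z axis flipped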
10 | 11 | 12 | import argparse 13 | import cv2 14 | import os 15 | from mtcnn import MTCNN 16 | import random 17 | detector = MTCNN() 18 | 19 | # see how to visualize the bounding box and the landmarks at : https://github.com/ipazc/mtcnn/blob/master/example.py 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--in_root', type=str, default="", help='process folder') 23 | args = parser.parse_args() 24 | in_root = args.in_root 25 | 26 | out_detection = os.path.join(in_root, "detections") 27 | 28 | if not os.path.exists(out_detection): 29 | os.makedirs(out_detection) 30 | 31 | imgs = sorted([x for x in os.listdir(in_root) if x.endswith(".jpg") or x.endswith(".png")]) 32 | random.shuffle(imgs) 33 | for img in imgs: 34 | src = os.path.join(in_root, img) 35 | print(src) 36 | if img.endswith(".jpg"): 37 | dst = os.path.join(out_detection, img.replace(".jpg", ".txt")) 38 | if img.endswith(".png"): 39 | dst = os.path.join(out_detection, img.replace(".png", ".txt")) 40 | 41 | if not os.path.exists(dst): 42 | image = cv2.cvtColor(cv2.imread(src), cv2.COLOR_BGR2RGB) 43 | result = detector.detect_faces(image) 44 | 45 | if len(result)>0: 46 | index = 0 47 | if len(result)>1: # if multiple faces, take the biggest face 48 | size = -100000 49 | for r in range(len(result)): 50 | size_ = result[r]["box"][2] + result[r]["box"][3] 51 | if size < size_: 52 | size = size_ 53 | index = r 54 | 55 | bounding_box = result[index]['box'] 56 | keypoints = result[index]['keypoints'] 57 | if result[index]["confidence"] > 0.9: 58 | 59 | if img.endswith(".jpg"): 60 | dst = os.path.join(out_detection, img.replace(".jpg", ".txt")) 61 | if img.endswith(".png"): 62 | dst = os.path.join(out_detection, img.replace(".png", ".txt")) 63 | 64 | outLand = open(dst, "w") 65 | outLand.write(str(float(keypoints['left_eye'][0])) + " " + str(float(keypoints['left_eye'][1])) + "\n") 66 | outLand.write(str(float(keypoints['right_eye'][0])) + " " + str(float(keypoints['right_eye'][1])) + "\n") 67 | outLand.write(str(float(keypoints['nose'][0])) + " " + str(float(keypoints['nose'][1])) + "\n") 68 | outLand.write(str(float(keypoints['mouth_left'][0])) + " " + str(float(keypoints['mouth_left'][1])) + "\n") 69 | outLand.write(str(float(keypoints['mouth_right'][0])) + " " + str(float(keypoints['mouth_right'][1])) + "\n") 70 | outLand.close() 71 | print(result) 72 | -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/crop_images.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
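# The biggest-face scan in batch_mtcnn.py above is equivalent to a single max() over
# width + height of each detection box -- a sketch assuming MTCNN's result format of
# dicts with a 'box' = [x, y, w, h] entry (the values below are made up):
result = [{'box': [0, 0, 50, 60]}, {'box': [0, 0, 120, 150]}]
index = max(range(len(result)), key=lambda r: result[r]['box'][2] + result[r]['box'][3])
assert index == 1   # the 120x150 box wins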
10 | 11 | import argparse 12 | import os 13 | import json 14 | 15 | import numpy as np 16 | from PIL import Image 17 | from tqdm import tqdm 18 | from preprocess import align_img 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--indir', type=str, required=True) 24 | parser.add_argument('--outdir', type=str, required=True) 25 | parser.add_argument('--compress_level', type=int, default=0) 26 | args = parser.parse_args() 27 | 28 | with open(os.path.join(args.indir, 'cropping_params.json')) as f: 29 | cropping_params = json.load(f) 30 | 31 | os.makedirs(args.outdir, exist_ok=True) 32 | 33 | for im_path, cropping_dict in tqdm(cropping_params.items()): 34 | im = Image.open(os.path.join(args.indir, im_path)).convert('RGB') 35 | 36 | _, H = im.size 37 | lm = np.array(cropping_dict['lm']) 38 | lm = lm.reshape([-1, 2]) 39 | lm[:, -1] = H - 1 - lm[:, -1] 40 | 41 | _, im_high, _, _, = align_img(im, lm, np.array(cropping_dict['lm3d_std']), target_size=1024., rescale_factor=cropping_dict['rescale_factor']) 42 | 43 | left = int(im_high.size[0]/2 - cropping_dict['center_crop_size']/2) 44 | upper = int(im_high.size[1]/2 - cropping_dict['center_crop_size']/2) 45 | right = left + cropping_dict['center_crop_size'] 46 | lower = upper + cropping_dict['center_crop_size'] 47 | im_cropped = im_high.crop((left, upper, right,lower)) 48 | im_cropped = im_cropped.resize((cropping_dict['output_size'], cropping_dict['output_size']), resample=Image.LANCZOS) 49 | 50 | im_cropped.save(os.path.join(args.outdir, os.path.basename(im_path)), compress_level=args.compress_level) -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/crop_images_in_the_wild.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
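# Worked example of the center-crop arithmetic in crop_images.py above, assuming the
# typical values used in this preprocessing (a 1024px aligned image, a 700px center
# crop, 512px output):
size, crop = 1024, 700
left = upper = size // 2 - crop // 2          # 512 - 350 = 162
right, lower = left + crop, upper + crop      # 162 + 700 = 862
assert (left, upper, right, lower) == (162, 162, 862, 862)
# The resulting 700x700 crop is then resized to 512x512 with Lanczos resampling.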
10 | 11 | import argparse 12 | import os 13 | from preprocess import align_img 14 | from PIL import Image 15 | import numpy as np 16 | import sys 17 | sys.path.append('Deep3DFaceRecon_pytorch') 18 | from Deep3DFaceRecon_pytorch.util.load_mats import load_lm3d 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--indir', type=str, required=True) 24 | args = parser.parse_args() 25 | 26 | lm_dir = os.path.join(args.indir, "detections") 27 | img_files = sorted([x for x in os.listdir(args.indir) if x.lower().endswith(".png") or x.lower().endswith(".jpg")]) 28 | lm_files = sorted([x for x in os.listdir(lm_dir) if x.endswith(".txt")]) 29 | 30 | lm3d_std = load_lm3d("Deep3DFaceRecon_pytorch/BFM/") 31 | 32 | out_dir = os.path.join(args.indir, "crop") 33 | if not os.path.exists(out_dir): 34 | os.makedirs(out_dir, exist_ok=True) 35 | 36 | for img_file, lm_file in zip(img_files, lm_files): 37 | 38 | img_path = os.path.join(args.indir, img_file) 39 | lm_path = os.path.join(lm_dir, lm_file) 40 | im = Image.open(img_path).convert('RGB') 41 | _,H = im.size 42 | lm = np.loadtxt(lm_path).astype(np.float32) 43 | lm = lm.reshape([-1, 2]) 44 | lm[:, -1] = H - 1 - lm[:, -1] 45 | 46 | target_size = 1024. 47 | rescale_factor = 300 48 | center_crop_size = 700 49 | output_size = 512 50 | 51 | _, im_high, _, _, = align_img(im, lm, lm3d_std, target_size=target_size, rescale_factor=rescale_factor) 52 | 53 | left = int(im_high.size[0]/2 - center_crop_size/2) 54 | upper = int(im_high.size[1]/2 - center_crop_size/2) 55 | right = left + center_crop_size 56 | lower = upper + center_crop_size 57 | im_cropped = im_high.crop((left, upper, right,lower)) 58 | im_cropped = im_cropped.resize((output_size, output_size), resample=Image.LANCZOS) 59 | out_path = os.path.join(out_dir, img_file.split(".")[0] + ".png") 60 | im_cropped.save(out_path) -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/preprocess.py: -------------------------------------------------------------------------------- 1 | ./Deep3DFaceRecon_pytorch/util/preprocess.py -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/preprocess_face_cameras.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
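# The landmark flip used by both crop scripts above (lm[:, -1] = H - 1 - lm[:, -1])
# converts y coordinates between a top-left and a bottom-left image origin. A minimal
# check with a hypothetical 512px-tall image:
H = 512
y_top = 10                    # 10px below the top edge
y_bottom = H - 1 - y_top      # 501: the same row, counted from the bottom
assert H - 1 - y_bottom == y_top   # the flip is its own inverse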
10 | 11 | ############################################################# 12 | 13 | # Usage: python dataset_preprocessing/ffhq/preprocess_face_cameras.py --source /data/ffhq --dest /data/preprocessed_ffhq_images 14 | 15 | ############################################################# 16 | 17 | import json 18 | import numpy as np 19 | from PIL import Image, ImageOps 20 | import os 21 | from tqdm import tqdm 22 | import argparse 23 | import torch 24 | import sys 25 | sys.path.append('../../eg3d') 26 | from camera_utils import create_cam2world_matrix 27 | 28 | COMPRESS_LEVEL=0 29 | 30 | def fix_intrinsics(intrinsics): 31 | intrinsics = np.array(intrinsics).copy() 32 | assert intrinsics.shape == (3, 3), intrinsics 33 | intrinsics[0,0] = 2985.29/700 34 | intrinsics[1,1] = 2985.29/700 35 | intrinsics[0,2] = 1/2 36 | intrinsics[1,2] = 1/2 37 | assert intrinsics[0,1] == 0 38 | assert intrinsics[2,2] == 1 39 | assert intrinsics[1,0] == 0 40 | assert intrinsics[2,0] == 0 41 | assert intrinsics[2,1] == 0 42 | return intrinsics 43 | 44 | # For our recropped images, with correction 45 | def fix_pose(pose): 46 | COR = np.array([0, 0, 0.175]) 47 | pose = np.array(pose).copy() 48 | location = pose[:3, 3] 49 | direction = (location - COR) / np.linalg.norm(location - COR) 50 | pose[:3, 3] = direction * 2.7 + COR 51 | return pose 52 | 53 | # Used in original submission 54 | def fix_pose_orig(pose): 55 | pose = np.array(pose).copy() 56 | location = pose[:3, 3] 57 | radius = np.linalg.norm(location) 58 | pose[:3, 3] = pose[:3, 3]/radius * 2.7 59 | return pose 60 | 61 | # Used for original crop images 62 | def fix_pose_simplify(pose): 63 | cam_location = torch.tensor(pose).clone()[:3, 3] 64 | normalized_cam_location = torch.nn.functional.normalize(cam_location - torch.tensor([0, 0, 0.175]), dim=0) 65 | camera_view_dir = - normalized_cam_location 66 | camera_pos = 2.7 * normalized_cam_location + np.array([0, 0, 0.175]) 67 | simple_pose_matrix = create_cam2world_matrix(camera_view_dir.unsqueeze(0), camera_pos.unsqueeze(0))[0] 68 | return simple_pose_matrix.numpy() 69 | 70 | def flip_yaw(pose_matrix): 71 | flipped = pose_matrix.copy() 72 | flipped[0, 1] *= -1 73 | flipped[0, 2] *= -1 74 | flipped[1, 0] *= -1 75 | flipped[2, 0] *= -1 76 | flipped[0, 3] *= -1 77 | return flipped 78 | 79 | if __name__ == '__main__': 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--source", type=str) 82 | parser.add_argument("--dest", type=str, default=None) 83 | parser.add_argument("--max_images", type=int, default=None) 84 | parser.add_argument("--mode", type=str, default="orig", choices=["orig", "cor", "simplify"]) 85 | args = parser.parse_args() 86 | 87 | camera_dataset_file = os.path.join(args.source, 'cameras.json') 88 | 89 | with open(camera_dataset_file, "r") as f: 90 | cameras = json.load(f) 91 | 92 | dataset = {'labels':[]} 93 | 94 | max_images = args.max_images if args.max_images is not None else len(cameras) 95 | for i, filename in tqdm(enumerate(cameras), total=max_images): 96 | if (max_images is not None and i >= max_images): break 97 | 98 | pose = cameras[filename]['pose'] 99 | intrinsics = cameras[filename]['intrinsics'] 100 | 101 | if args.mode == 'cor': 102 | pose = fix_pose(pose) 103 | elif args.mode == 'orig': 104 | pose = fix_pose_orig(pose) 105 | elif args.mode == 'simplify': 106 | pose = fix_pose_simplify(pose) 107 | else: 108 | assert False, "invalid mode" 109 | intrinsics = fix_intrinsics(intrinsics) 110 | label = np.concatenate([pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 111 | 112 | image_path = 
os.path.join(args.source, filename) 113 | img = Image.open(image_path) 114 | 115 | dataset["labels"].append([filename, label]) 116 | os.makedirs(os.path.dirname(os.path.join(args.dest, filename)), exist_ok=True) 117 | img.save(os.path.join(args.dest, filename)) 118 | 119 | 120 | flipped_img = ImageOps.mirror(img) 121 | flipped_pose = flip_yaw(pose) 122 | label = np.concatenate([flipped_pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 123 | base, ext = filename.split('.')[0], '.' + filename.split('.')[1] 124 | flipped_filename = base + '_mirror' + ext 125 | dataset["labels"].append([flipped_filename, label]) 126 | flipped_img.save(os.path.join(args.dest, flipped_filename)) 127 | 128 | with open(os.path.join(args.dest, 'dataset.json'), "w") as f: 129 | json.dump(dataset, f) -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/preprocess_in_the_wild.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import os 12 | import argparse 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--indir', type=str, required=True) 15 | args = parser.parse_args() 16 | 17 | # run mtcnn needed for Deep3DFaceRecon 18 | command = "python batch_mtcnn.py --in_root " + args.indir 19 | print(command) 20 | os.system(command) 21 | 22 | out_folder = args.indir.split("/")[-2] if args.indir.endswith("/") else args.indir.split("/")[-1] 23 | 24 | # run Deep3DFaceRecon 25 | os.chdir('Deep3DFaceRecon_pytorch') 26 | command = "python test.py --img_folder=" + args.indir + " --gpu_ids=0 --name=pretrained --epoch=20" 27 | print(command) 28 | os.system(command) 29 | os.chdir('..') 30 | 31 | # crop out the input image 32 | command = "python crop_images_in_the_wild.py --indir=" + args.indir 33 | print(command) 34 | os.system(command) 35 | 36 | # convert the pose to our format 37 | command = f"python 3dface2idr_mat.py --in_root Deep3DFaceRecon_pytorch/checkpoints/pretrained/results/{out_folder}/epoch_20_000000 --out_path {os.path.join(args.indir, 'crop', 'cameras.json')}" 38 | print(command) 39 | os.system(command) 40 | 41 | # additional correction to match the submission version 42 | command = f"python preprocess_face_cameras.py --source {os.path.join(args.indir, 'crop')} --dest {out_folder} --mode orig" 43 | print(command) 44 | os.system(command) -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/runme.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. 
Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import os 12 | import gdown 13 | import shutil 14 | import subprocess 15 | 16 | dir_path = os.path.dirname(os.path.realpath(__file__)) 17 | 18 | #--------------------------------------------------------------------------------------------------------# 19 | 20 | # Download wilds 21 | cmd = "python download_ffhq.py --wilds" 22 | subprocess.run([cmd], shell=True, check=True) 23 | 24 | #--------------------------------------------------------------------------------------------------------# 25 | 26 | # Validate wilds 27 | cmd = "python validate_ffhq.py" 28 | subprocess.run([cmd], shell=True, check=True) 29 | 30 | #--------------------------------------------------------------------------------------------------------# 31 | 32 | # Align wilds 33 | cmd = "python align_multiprocess.py --source=. --dest=realign1500 --threads=16" 34 | subprocess.run([cmd], shell=True, check=True) 35 | 36 | # #--------------------------------------------------------------------------------------------------------# 37 | 38 | # Move out of subdirs into single directory 39 | realign1500_dir = 'realign1500' 40 | for subdir in os.listdir(realign1500_dir): 41 | if not os.path.isdir(os.path.join(realign1500_dir, subdir)): continue 42 | if not(len(subdir) == 5 and subdir.isnumeric()): continue 43 | for filename in os.listdir(os.path.join(realign1500_dir, subdir)): 44 | shutil.move(os.path.join(realign1500_dir, subdir, filename), os.path.join(realign1500_dir, filename)) 45 | 46 | # #--------------------------------------------------------------------------------------------------------# 47 | 48 | print("Downloading cropping params...") 49 | gdown.download('https://drive.google.com/uc?id=1KdVf2lIepGECRaANGhfuR7mDpJ5nfb9K', 'realign1500/cropping_params.json', quiet=False) 50 | 51 | #--------------------------------------------------------------------------------------------------------# 52 | 53 | # Perform final cropping of 512x512 images. 
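# (crop_images.py reads the cropping_params.json downloaded above from realign1500/
# and writes the 512x512 cropped PNGs into final_crops/.)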
54 | print("Processing final crops...") 55 | cmd = "python crop_images.py" 56 | input_flag = " --indir " + 'realign1500' 57 | output_flag = " --outdir " + 'final_crops' 58 | cmd += input_flag + output_flag 59 | subprocess.run([cmd], shell=True, check=True) 60 | 61 | # #--------------------------------------------------------------------------------------------------------# 62 | 63 | print("Mirroring dataset...") 64 | cmd = f"python ../mirror_dataset.py --source=final_crops" 65 | subprocess.run([cmd], shell=True, check=True) 66 | 67 | # #--------------------------------------------------------------------------------------------------------# 68 | 69 | print("Downloading poses...") 70 | gdown.download('https://drive.google.com/uc?id=14mzYD1DxUjh7BGgeWKgXtLHWwvr-he1Z', 'final_crops/dataset.json', quiet=False) 71 | 72 | #--------------------------------------------------------------------------------------------------------# 73 | 74 | print("Creating dataset zip...") 75 | cmd = f"python {os.path.join(dir_path, '../../eg3d', 'dataset_tool.py')}" 76 | cmd += f" --source=final_crops --dest FFHQ_512.zip --resolution 512x512" 77 | subprocess.run([cmd], shell=True, check=True) 78 | -------------------------------------------------------------------------------- /dataset_preprocessing/ffhq/validate_ffhq.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | Usage: python validate_ffhq.py 13 | 14 | Checks in-the-wild images to verify images are complete and uncorrupted. Deletes files that 15 | failed check. After running this script, re-run download_ffhq.py to reacquire failed images. 
16 | """ 17 | 18 | 19 | import json 20 | from PIL import Image 21 | import hashlib 22 | import numpy as np 23 | from tqdm import tqdm 24 | import os 25 | import argparse 26 | import sys 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--dataset_json', type=str, default='ffhq-dataset-v2.json') 31 | parser.add_argument('--mode', type=str, default='file', choices=['file', 'pixel']) 32 | args = parser.parse_args() 33 | clean = True 34 | 35 | with open(args.dataset_json) as f: 36 | datasetjson = json.load(f) 37 | 38 | for key, val in tqdm(datasetjson.items()): 39 | file_spec = val['in_the_wild'] 40 | try: 41 | if args.mode == 'file': 42 | with open(file_spec['file_path'], 'rb') as file_to_check: 43 | data = file_to_check.read() 44 | if 'file_md5' in file_spec and hashlib.md5(data).hexdigest() != file_spec['file_md5']: 45 | raise IOError('Incorrect file MD5', file_spec['file_path']) 46 | elif args.mode == 'pixel': 47 | with Image.open(file_spec['file_path']) as image: 48 | if 'pixel_size' in file_spec and list(image.size) != file_spec['pixel_size']: 49 | raise IOError('Incorrect pixel size', file_spec['file_path']) 50 | if 'pixel_md5' in file_spec and hashlib.md5(np.array(image)).hexdigest() != file_spec['pixel_md5']: 51 | raise IOError('Incorrect pixel MD5', file_spec['file_path']) 52 | except IOError: 53 | clean = False 54 | tqdm.write(f"Bad file {file_spec['file_path']}") 55 | if os.path.isfile(file_spec['file_path']): 56 | os.remove(file_spec['file_path']) 57 | 58 | if not clean: 59 | sys.exit(1) -------------------------------------------------------------------------------- /dataset_preprocessing/mirror_dataset.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | ############################################################# 12 | 13 | # Usage: python dataset_preprocessing/mirror_dataset.py --source /data/images --dest /data/mirrored_images 14 | 15 | ############################################################# 16 | 17 | import json 18 | import numpy as np 19 | from PIL import Image, ImageOps 20 | import os 21 | from tqdm import tqdm 22 | import argparse 23 | 24 | COMPRESS_LEVEL=0 25 | 26 | def flip_yaw(pose_matrix): 27 | flipped = pose_matrix.copy() 28 | flipped[0, 1] *= -1 29 | flipped[0, 2] *= -1 30 | flipped[1, 0] *= -1 31 | flipped[2, 0] *= -1 32 | flipped[0, 3] *= -1 33 | return flipped 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--source", type=str) 38 | parser.add_argument("--dest", type=str, default=None) 39 | parser.add_argument("--max_images", type=int, default=None) 40 | args = parser.parse_args() 41 | 42 | 43 | dest = args.source if args.dest is None else args.dest 44 | 45 | dataset_file = os.path.join(args.source, 'dataset.json') 46 | 47 | if os.path.isfile(dataset_file): # If dataset.json present, mirror images and mirror labels. 48 | with open(dataset_file, "r") as f: 49 | dataset = json.load(f) 50 | 51 | max_images = args.max_images if args.max_images is not None else len(dataset['labels']) 52 | for i, example in tqdm(enumerate(dataset['labels']), total=max_images): 53 | if (max_images is not None and i >= max_images): break 54 | filename, label = example 55 | if '_mirror' in filename: 56 | continue 57 | 58 | image_path = os.path.join(args.source, filename) 59 | img = Image.open(image_path) 60 | 61 | if args.dest is not None: # skip saving originals if dest==source 62 | os.makedirs(os.path.dirname(os.path.join(dest, filename)), exist_ok=True) 63 | img.save(os.path.join(dest, filename), compress_level=COMPRESS_LEVEL) 64 | 65 | flipped_img = ImageOps.mirror(img) 66 | pose, intrinsics = np.array(label[:16]).reshape(4,4), np.array(label[16:]).reshape(3, 3) 67 | flipped_pose = flip_yaw(pose) 68 | label = np.concatenate([flipped_pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 69 | base, ext = filename.split('.')[0], '.' + filename.split('.')[1] 70 | flipped_filename = base + '_mirror' + ext 71 | dataset["labels"].append([flipped_filename, label]) 72 | flipped_img.save(os.path.join(dest, flipped_filename), compress_level=COMPRESS_LEVEL) 73 | 74 | with open(os.path.join(dest, 'dataset.json'), "w") as f: 75 | json.dump(dataset, f) 76 | 77 | else: # If dataset.json is not present, just mirror images.
78 | for filename in tqdm(os.listdir(args.source)): 79 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')): 80 | image_path = os.path.join(args.source, filename) 81 | img = Image.open(image_path) 82 | 83 | if args.dest is not None: # skip saving originals if dest==source 84 | os.makedirs(os.path.dirname(os.path.join(dest, filename)), exist_ok=True) 85 | img.save(os.path.join(dest, filename), compress_level=COMPRESS_LEVEL) 86 | 87 | flipped_img = ImageOps.mirror(img) 88 | base, ext = os.path.splitext(filename) 89 | flipped_filename = base + '_mirror' + ext 90 | flipped_img.save(os.path.join(dest, flipped_filename), compress_level=COMPRESS_LEVEL) -------------------------------------------------------------------------------- /dataset_preprocessing/rebalance_ffhq/rebalance_ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import argparse 12 | import json 13 | import os 14 | import zipfile 15 | 16 | from tqdm import tqdm 17 | 18 | 19 | #-------------------------------------------------------------------------------------------------- 20 | 21 | # create the new zipfile by duplicating files according to num_replicas.json 22 | # which is a file indicating how many times each file in the original ffhq 23 | # should be duplicated. 24 | 25 | # num_replicas was created with the following steps: 26 | # 1: get the min and max yaw over the dataset 27 | # 2: split the dataset into N=9 uniform size arcs across the range 28 | # (with possibly differing number of images in each arc) 29 | # 3: Mark images in edge bins to have a higher number of duplicates 30 | # 31 | # The new dataset is still biased towards frontal facing images 32 | # but much less so than before. 
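# A rough sketch of the three steps described above; the shipped num_replicas.json is
# authoritative, and the bin edges and duplicate counts below are illustrative
# stand-ins only:
import numpy as np

def make_num_replicas(yaws, n_bins=9, edge_duplicates=4):
    lo, hi = min(yaws), max(yaws)                                   # step 1: yaw range
    arcs = ((np.asarray(yaws) - lo) / (hi - lo + 1e-8) * n_bins).astype(int)
    arcs = np.minimum(arcs, n_bins - 1)                             # step 2: N uniform arcs
    return {str(i): edge_duplicates if a in (0, n_bins - 1) else 1  # step 3: boost edge bins
            for i, a in enumerate(arcs)}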
33 | 34 | 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('inzip', type=str) # the FFHQ dataset created by `dataset_preprocessing/ffhq/runme.py` 38 | parser.add_argument('outzip', type=str) # this is the output path to write the new zip 39 | args = parser.parse_args() 40 | 41 | print('Please verify that the following two md5 hashes are identical to ensure you are specifying the correct input dataset') 42 | print('Command: >> unzip -p [path/to/input_dataset.zip] dataset.json | md5sum') 43 | print('Expected:') 44 | print('a5893550587656894051685f1a5930ce -') 45 | print('Actual:') 46 | os.system(f'unzip -p {args.inzip} dataset.json | md5sum') 47 | 48 | num_replicas = os.path.join(os.path.dirname(__file__), 'num_replicas.json') 49 | with open(num_replicas) as f: 50 | duplicate_list = json.load(f) 51 | 52 | 53 | with zipfile.ZipFile(args.inzip, 'r') as zipread, zipfile.ZipFile(args.outzip, 'w') as zipwrite: 54 | dataset = json.loads(zipread.read('dataset.json')) 55 | 56 | new_dataset = [] 57 | for index, n_duplicates in tqdm(duplicate_list.items()): 58 | index = int(index) 59 | 60 | name, label = dataset['labels'][index] 61 | img = zipread.read(name) 62 | 63 | for replica in range(0, n_duplicates): 64 | newname = name.replace('.', f'_{replica:02}.') 65 | 66 | new_dataset.append([newname, label]) 67 | zipwrite.writestr(newname, img) 68 | 69 | new_dataset = {'labels': new_dataset} 70 | zipwrite.writestr('dataset.json', json.dumps(new_dataset)) 71 | 72 | print('Sanity check: to verify your dataset was created properly we recommend the following verification:') 73 | print('>> unzip -p [path/to/output_dataset.zip] dataset.json | md5sum') 74 | print('should give the value bae1b0b52267670f1735fef9092b5c11.') -------------------------------------------------------------------------------- /dataset_preprocessing/shapenet_cars/preprocess_shapenet_cameras.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited.
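# The 25-number labels stored in dataset.json by these preprocessing scripts unpack as
# a flattened 4x4 cam2world pose followed by a flattened 3x3 normalized intrinsics
# matrix (the same unpacking mirror_dataset.py performs):
import numpy as np

label = np.arange(25.0)                # stand-in for one dataset.json label
pose = label[:16].reshape(4, 4)        # OpenCV cam2world extrinsics
intrinsics = label[16:].reshape(3, 3)  # fx, fy, cx, cy in normalized (0..1) units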
10 | 11 | ############################################################# 12 | 13 | # Usage: python dataset_preprocessing/shapenet/preprocess_cars_cameras.py --source ~/downloads/cars_train --dest /data/cars_preprocessed 14 | 15 | ############################################################# 16 | 17 | 18 | import json 19 | import numpy as np 20 | import os 21 | from tqdm import tqdm 22 | import argparse 23 | 24 | def list_recursive(folderpath): 25 | return [os.path.join(folderpath, filename) for filename in os.listdir(folderpath)] 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--source", type=str) 30 | parser.add_argument("--max_images", type=int, default=None) 31 | args = parser.parse_args() 32 | 33 | # Parse cameras 34 | dataset_path = args.source 35 | cameras = {} 36 | for scene_folder_path in list_recursive(dataset_path): 37 | if not os.path.isdir(scene_folder_path): continue 38 | 39 | for rgb_path in list_recursive(os.path.join(scene_folder_path, 'rgb')): 40 | relative_path = os.path.relpath(rgb_path, dataset_path) 41 | intrinsics_path = os.path.join(scene_folder_path, 'intrinsics.txt') 42 | pose_path = rgb_path.replace('rgb', 'pose').replace('png', 'txt') 43 | assert os.path.isfile(rgb_path) 44 | assert os.path.isfile(intrinsics_path) 45 | assert os.path.isfile(pose_path) 46 | 47 | with open(pose_path, 'r') as f: 48 | pose = np.array([float(n) for n in f.read().split(' ')]).reshape(4, 4).tolist() 49 | 50 | with open(intrinsics_path, 'r') as f: 51 | first_line = f.read().split('\n')[0].split(' ') 52 | focal = float(first_line[0]) 53 | cx = float(first_line[1]) 54 | cy = float(first_line[2]) 55 | 56 | orig_img_size = 512 # cars_train has intrinsics corresponding to image size of 512 * 512 57 | intrinsics = np.array( 58 | [[focal / orig_img_size, 0.00000000e+00, cx / orig_img_size], 59 | [0.00000000e+00, focal / orig_img_size, cy / orig_img_size], 60 | [0.00000000e+00, 0.00000000e+00, 1.00000000e+00]] 61 | ).tolist() 62 | 63 | cameras[relative_path] = {'pose': pose, 'intrinsics': intrinsics, 'scene-name': os.path.basename(scene_folder_path)} 64 | 65 | with open(os.path.join(dataset_path, 'cameras.json'), 'w') as outfile: 66 | json.dump(cameras, outfile, indent=4) 67 | 68 | 69 | camera_dataset_file = os.path.join(args.source, 'cameras.json') 70 | 71 | with open(camera_dataset_file, "r") as f: 72 | cameras = json.load(f) 73 | 74 | dataset = {'labels':[]} 75 | max_images = args.max_images if args.max_images is not None else len(cameras) 76 | for i, filename in tqdm(enumerate(cameras), total=max_images): 77 | if (max_images is not None and i >= max_images): break 78 | 79 | pose = np.array(cameras[filename]['pose']) 80 | intrinsics = np.array(cameras[filename]['intrinsics']) 81 | label = np.concatenate([pose.reshape(-1), intrinsics.reshape(-1)]).tolist() 82 | 83 | image_path = os.path.join(args.source, filename) 84 | dataset["labels"].append([filename, label]) 85 | 86 | with open(os.path.join(args.source, 'dataset.json'), "w") as f: 87 | json.dump(dataset, f, indent=4) 88 | -------------------------------------------------------------------------------- /dataset_preprocessing/shapenet_cars/run_me.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import os 12 | import gdown 13 | import shutil 14 | import tempfile 15 | import subprocess 16 | 17 | 18 | if __name__ == '__main__': 19 | with tempfile.TemporaryDirectory() as working_dir: 20 | download_name = 'cars_train.zip' 21 | url = 'https://drive.google.com/uc?id=1bThUNtIHx4xEQyffVBSf82ABDDh2HlFn' 22 | output_dataset_name = 'cars_128.zip' 23 | 24 | dir_path = os.path.dirname(os.path.realpath(__file__)) 25 | extracted_data_path = os.path.join(working_dir, os.path.splitext(download_name)[0]) 26 | 27 | print("Downloading data...") 28 | zipped_dataset = os.path.join(working_dir, download_name) 29 | gdown.download(url, zipped_dataset, quiet=False) 30 | 31 | print("Unzipping downloaded data...") 32 | shutil.unpack_archive(zipped_dataset, working_dir) 33 | 34 | print("Converting camera parameters...") 35 | cmd = f"python {os.path.join(dir_path, 'preprocess_shapenet_cameras.py')} --source={extracted_data_path}" 36 | subprocess.run([cmd], shell=True) 37 | 38 | print("Creating dataset zip...") 39 | cmd = f"python {os.path.join(dir_path, '../../eg3d', 'dataset_tool.py')}" 40 | cmd += f" --source {extracted_data_path} --dest {output_dataset_name} --resolution 128x128" 41 | subprocess.run([cmd], shell=True) -------------------------------------------------------------------------------- /docs/camera_conventions.md: -------------------------------------------------------------------------------- 1 | Camera poses are in OpenCV Cam2World format. 2 | Intrinsics are normalized. -------------------------------------------------------------------------------- /docs/camera_coordinate_conventions.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/eg3d/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/docs/camera_coordinate_conventions.jpg -------------------------------------------------------------------------------- /docs/models.md: -------------------------------------------------------------------------------- 1 | Pre-trained checkpoints can be found on the [NGC Catalog](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/eg3d). 2 | 3 | Brief descriptions of models and the commands used to train them are found below. 4 | 5 | --- 6 | 7 | # FFHQ 8 | 9 | **ffhq512-64.pkl** 10 | 11 | FFHQ 512, trained with neural rendering resolution of 64x64. 12 | 13 | ```.bash 14 | # Train with FFHQ from scratch with raw neural rendering resolution=64, using 8 GPUs. 15 | python train.py --outdir=~/training-runs --cfg=ffhq --data=~/datasets/FFHQ_512.zip \ 16 | --gpus=8 --batch=32 --gamma=1 --gen_pose_cond=True 17 | ``` 18 | 19 | **ffhq512-128.pkl** 20 | 21 | Fine-tune FFHQ 512, with neural rendering resolution of 128x128. 22 | 23 | ```.bash 24 | # Second stage finetuning of FFHQ to 128 neural rendering resolution. 
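# (Resumes from the 64x64-rendering checkpoint via --resume and raises the neural
# rendering resolution to a final 128x128 via --neural_rendering_resolution_final
# over the --kimg=2000 fine-tuning run.)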
25 | python train.py --outdir=~/training-runs --cfg=ffhq --data=~/datasets/FFHQ_512.zip \ 26 | --resume=ffhq-64.pkl \ 27 | --gpus=8 --batch=32 --gamma=1 --gen_pose_cond=True --neural_rendering_resolution_final=128 --kimg=2000 28 | ``` 29 | 30 | ## FFHQ Rebalanced 31 | 32 | Same as the models above, but fine-tuned using a rebalanced version of FFHQ that has a more uniform pose distribution. Compared to models trained on standard FFHQ, these models should produce better 3D shapes and better renderings from steep angles. 33 | 34 | **ffhqrebalanced512-64.pkl** 35 | 36 | ```.bash 37 | # Finetune with rebalanced FFHQ at rendering resolution 64. 38 | python train.py --outdir=~/training-runs --cfg=ffhq --data=~/datasets/FFHQ_rebalanced_512.zip \ 39 | --resume=ffhq-64.pkl \ 40 | --gpus=8 --batch=32 --gamma=1 --gen_pose_cond=True --gpc_reg_prob=0.8 41 | ``` 42 | 43 | **ffhqrebalanced512-128.pkl** 44 | ```.bash 45 | # Finetune with rebalanced FFHQ at 128 neural rendering resolution. 46 | python train.py --outdir=~/training-runs --cfg=ffhq --data=~/datasets/FFHQ_rebalanced_512.zip \ 47 | --resume=ffhq-rebalanced-64.pkl \ 48 | --gpus=8 --batch=32 --gamma=1 --gen_pose_cond=True --gpc_reg_prob=0.8 --neural_rendering_resolution_final=128 49 | ``` 50 | 51 | # AFHQ Cats 52 | 53 | **afhqcats512-128.pkl** 54 | 55 | ```.bash 56 | # Train with AFHQ, finetuning from FFHQ with ADA, using 8 GPUs. 57 | python train.py --outdir=~/training-runs --cfg=afhq --data=~/datasets/afhq.zip \ 58 | --resume=ffhq-64.pkl \ 59 | --gpus=8 --batch=32 --gamma=5 --aug=ada --gen_pose_cond=True --gpc_reg_prob=0.8 --neural_rendering_resolution_final=128 60 | ``` 61 | 62 | 63 | # Shapenet 64 | 65 | **shapenetcars128-64.pkl** 66 | 67 | ```.bash 68 | # Train with Shapenet from scratch, using 8 GPUs. 69 | python train.py --outdir=~/training-runs --cfg=shapenet --data=~/datasets/cars_train.zip \ 70 | --gpus=8 --batch=32 --gamma=0.3 71 | ``` -------------------------------------------------------------------------------- /docs/teaser.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/eg3d/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/docs/teaser.jpeg -------------------------------------------------------------------------------- /docs/visualizer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/eg3d/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/docs/visualizer.png -------------------------------------------------------------------------------- /docs/visualizer_guide.md: -------------------------------------------------------------------------------- 1 | ## Guide to the Visualizer 2 | 3 | ![Visualizer](./visualizer.png) 4 | 5 | We include a 3D visualizer that is based on the amazing tool introduced in StyleGAN3. The following document describes important options and sliders of the visualizer UI. 6 | 7 | TLDR: 8 | 1. Press the "Pickle/Recent" button to select a pretrained EG3D model. 9 | 2. Click and drag the "Latent/Drag" button to sweep latent codes and change the scene identity. 10 | 3. Click and drag the rendering on the right to move the camera. 11 | 12 | --- 13 | 14 | ## Network & Latent 15 | 16 | ### Pickle 17 | Specify the path of the model checkpoint to visualize. You have a few options: 18 | 1. Drag and drop the .pkl file from your file browser into the visualizer window 19 | 1. Type the path (or url) of your .pkl file into the text field 20 | 1. 
Press the "Recent" button to access a list of recently used checkpoints
21 |
22 | ### Pose
23 | Control the pitch and yaw of the camera by clicking and dragging the rendering on the right. By default, the camera rotates on a sphere of fixed radius, pointed at the origin.
24 |
25 | ### FOV
26 | Control the field of view of the camera with this slider to zoom the camera in and out. For FFHQ, 18 degrees is about right; for ShapeNet, use a FOV of 45 degrees.
27 |
28 | ### Cond Pose
29 | The pose with which we condition the generator (see Generator Pose Conditioning in Sec. 4.4). By default, we condition on the fixed frontal camera pose. For models trained without generator pose conditioning, this has no effect.
30 |
31 | ### Render Type
32 | Toggle between the final super-resolved output (RGB image), a depth map (Depth image), or the raw neural rendering without super-resolution (Neural rendering).
33 |
34 | ### Depth Sample Multiplier / Depth Sample Importance Multiplier
35 | Adjust the number of depth samples taken per ray. Increasing the number of depth samples reduces flickering artifacts caused by depth aliasing, which leads to more temporally consistent videos; the tradeoff is slower rendering and slightly blurrier images. At 1X / 1X, the visualizer renders with the same number of depth samples as at training time; at 2X / 2X, it takes double the uniformly spaced samples and double the importance samples per ray. As an example: we train FFHQ with 48 uniformly spaced depth samples and 48 importance samples per ray. At 2X / 2X, we instead take 96 uniformly spaced depth samples and 96 importance samples (192 total).
36 |
37 | ### Latent
38 | The seed for the latent code, *z*, that is the input to the generator. Click and drag the "Drag" button to sweep between scene identities. Check the "Anim" checkbox to play an animation sweeping through latent codes.
39 |
40 | ### Stylemix
41 | The seed for a second latent code used for style mixing. Check the boxes on the right to select which layers should be conditioned by this second code.
42 |
43 | ### Truncate
44 | Apply the truncation trick in *w*-space to trade off fidelity for diversity. Psi=1 means no truncation; Psi=0 gives the "average" scene learned by the generator. A Psi between 0 and 1, e.g. 0.7, is a compromise that reduces diversity somewhat but improves overall consistency in quality. (See the Truncation Trick in StyleGAN for more info.)
45 |
46 | ---
47 |
48 | ## Performance & capture
49 |
50 | ### Render
51 |
52 | Displays the framerate of rendering. On an RTX 3090, with a neural rendering resolution of 128 and with 48 uniform and 48 importance depth samples, we get 25-30 FPS.
53 |
54 | ### Capture
55 |
56 | Save screenshots to the directory specified by the text field. "Save image" saves just the rendering; "Save GUI" saves the complete pane, including the user interface.
57 |
58 | ---
59 |
60 | ## Layers & channels
61 |
62 | ### Cache backbone
63 | For rendering where the scene identity (the latent code *z* and the conditioning pose) remains static but the rendering parameters (camera pose, FOV, render type, etc.) change, enabling 'backbone caching' lets us cache and reuse the triplanes computed by the convolutional backbone. Backbone caching slightly improves rendering speed.
64 |
65 | ### Layer viewer
66 | View and analyze the intermediate weights and layers of the generator. Scroll through the network and select a layer using the checkbox.
Use the "Channel" slider on the right to view different activations. Do note that when 'cache backbone' is enabled, you will be unable to view the intermediate weights of the convolutional backbone/triplanes. 67 | -------------------------------------------------------------------------------- /eg3d/camera_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts. 13 | """ 14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from training.volumetric_rendering import math_utils 21 | 22 | class GaussianCameraPoseSampler: 23 | """ 24 | Samples pitch and yaw from a Gaussian distribution and returns a camera pose. 25 | Camera is specified as looking at the origin. 26 | If horizontal and vertical stddev (specified in radians) are zero, gives a 27 | deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean. 28 | The coordinate system is specified with y-up, z-forward, x-left. 29 | Horizontal mean is the azimuthal angle (rotation around y axis) in radians, 30 | vertical mean is the polar angle (angle from the y axis) in radians. 31 | A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2. 32 | 33 | Example: 34 | For a camera pose looking at the origin with the camera at position [0, 0, 1]: 35 | cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1) 36 | """ 37 | 38 | @staticmethod 39 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'): 40 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean 41 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean 42 | v = torch.clamp(v, 1e-5, math.pi - 1e-5) 43 | 44 | theta = h 45 | v = v / math.pi 46 | phi = torch.arccos(1 - 2*v) 47 | 48 | camera_origins = torch.zeros((batch_size, 3), device=device) 49 | 50 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta) 51 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta) 52 | camera_origins[:, 1:2] = radius*torch.cos(phi) 53 | 54 | forward_vectors = math_utils.normalize_vecs(-camera_origins) 55 | return create_cam2world_matrix(forward_vectors, camera_origins) 56 | 57 | 58 | class LookAtPoseSampler: 59 | """ 60 | Same as GaussianCameraPoseSampler, except the 61 | camera is specified as looking at 'lookat_position', a 3-vector. 
62 | 63 | Example: 64 | For a camera pose looking at the origin with the camera at position [0, 0, 1]: 65 | cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1) 66 | """ 67 | 68 | @staticmethod 69 | def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'): 70 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean 71 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean 72 | v = torch.clamp(v, 1e-5, math.pi - 1e-5) 73 | 74 | theta = h 75 | v = v / math.pi 76 | phi = torch.arccos(1 - 2*v) 77 | 78 | camera_origins = torch.zeros((batch_size, 3), device=device) 79 | 80 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta) 81 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta) 82 | camera_origins[:, 1:2] = radius*torch.cos(phi) 83 | 84 | # forward_vectors = math_utils.normalize_vecs(-camera_origins) 85 | forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins) 86 | return create_cam2world_matrix(forward_vectors, camera_origins) 87 | 88 | class UniformCameraPoseSampler: 89 | """ 90 | Same as GaussianCameraPoseSampler, except the 91 | pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev. 92 | 93 | Example: 94 | For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians: 95 | 96 | cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16) 97 | """ 98 | 99 | @staticmethod 100 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'): 101 | h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean 102 | v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean 103 | v = torch.clamp(v, 1e-5, math.pi - 1e-5) 104 | 105 | theta = h 106 | v = v / math.pi 107 | phi = torch.arccos(1 - 2*v) 108 | 109 | camera_origins = torch.zeros((batch_size, 3), device=device) 110 | 111 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta) 112 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta) 113 | camera_origins[:, 1:2] = radius*torch.cos(phi) 114 | 115 | forward_vectors = math_utils.normalize_vecs(-camera_origins) 116 | return create_cam2world_matrix(forward_vectors, camera_origins) 117 | 118 | def create_cam2world_matrix(forward_vector, origin): 119 | """ 120 | Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix. 121 | Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll. 
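    Example (illustrative): a camera at [0, 0, 1] looking back at the origin:
        cam2world = create_cam2world_matrix(torch.tensor([[0., 0., -1.]]), torch.tensor([[0., 0., 1.]]))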
122 |     """
123 |
124 |     forward_vector = math_utils.normalize_vecs(forward_vector)
125 |     up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector)
126 |
127 |     right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
128 |     up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
129 |
130 |     rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
131 |     rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
132 |
133 |     translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
134 |     translation_matrix[:, :3, 3] = origin
135 |     cam2world = translation_matrix @ rotation_matrix
136 |     assert cam2world.shape[1:] == (4, 4)
137 |     return cam2world
138 |
139 |
140 | def FOV_to_intrinsics(fov_degrees, device='cpu'):
141 |     """
142 |     Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
143 |     Note the intrinsics are returned as normalized by image size, rather than in pixel units.
144 |     Assumes the principal point is at the image center.
145 |     """
146 |
147 |     focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))  # fov/2 in radians; 1.414 ~ sqrt(2), i.e. the FOV is treated as spanning the diagonal of the unit-sized image plane
148 |     intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
149 |     return intrinsics
--------------------------------------------------------------------------------
/eg3d/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from .util import EasyDict, make_cache_dir_path
12 |
--------------------------------------------------------------------------------
/eg3d/environment.yml:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | 11 | name: eg3d 12 | channels: 13 | - pytorch 14 | - nvidia 15 | dependencies: 16 | - python >= 3.8 17 | - pip 18 | - numpy>=1.20 19 | - click>=8.0 20 | - pillow=8.3.1 21 | - scipy=1.7.1 22 | - pytorch=1.11.0 23 | - cudatoolkit=11.1 24 | - requests=2.26.0 25 | - tqdm=4.62.2 26 | - ninja=1.10.2 27 | - matplotlib=3.4.2 28 | - imageio=2.9.0 29 | - pip: 30 | - imgui==1.3.0 31 | - glfw==2.2.0 32 | - pyopengl==3.1.5 33 | - imageio-ffmpeg==0.4.3 34 | - pyspng 35 | - psutil 36 | - mrcfile 37 | - tensorboard -------------------------------------------------------------------------------- /eg3d/gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /eg3d/gui_utils/imgui_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | import contextlib 12 | import imgui 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): 17 | s = imgui.get_style() 18 | s.window_padding = [spacing, spacing] 19 | s.item_spacing = [spacing, spacing] 20 | s.item_inner_spacing = [spacing, spacing] 21 | s.columns_min_spacing = spacing 22 | s.indent_spacing = indent 23 | s.scrollbar_size = scrollbar 24 | s.frame_padding = [4, 3] 25 | s.window_border_size = 1 26 | s.child_border_size = 1 27 | s.popup_border_size = 1 28 | s.frame_border_size = 1 29 | s.window_rounding = 0 30 | s.child_rounding = 0 31 | s.popup_rounding = 3 32 | s.frame_rounding = 3 33 | s.scrollbar_rounding = 3 34 | s.grab_rounding = 3 35 | 36 | getattr(imgui, f'style_colors_{color_scheme}')(s) 37 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 38 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] 39 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | @contextlib.contextmanager 44 | def grayed_out(cond=True): 45 | if cond: 46 | s = imgui.get_style() 47 | text = s.colors[imgui.COLOR_TEXT_DISABLED] 48 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] 49 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 50 | imgui.push_style_color(imgui.COLOR_TEXT, *text) 51 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) 52 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) 53 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) 54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) 55 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) 56 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) 57 | imgui.push_style_color(imgui.COLOR_BUTTON, *back) 58 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) 59 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) 60 | imgui.push_style_color(imgui.COLOR_HEADER, *back) 61 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) 62 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) 63 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) 64 | yield 65 | imgui.pop_style_color(14) 66 | else: 67 | yield 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | @contextlib.contextmanager 72 | def item_width(width=None): 73 | if width is not None: 74 | imgui.push_item_width(width) 75 | yield 76 | imgui.pop_item_width() 77 | else: 78 | yield 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def scoped_by_object_id(method): 83 | def decorator(self, *args, **kwargs): 84 | imgui.push_id(str(id(self))) 85 | res = method(self, *args, **kwargs) 86 | imgui.pop_id() 87 | return res 88 | return decorator 89 | 90 | #---------------------------------------------------------------------------- 91 | 92 | def button(label, width=0, enabled=True): 93 | with grayed_out(not enabled): 94 | clicked = imgui.button(label, width=width) 95 | clicked = clicked and enabled 96 | return clicked 97 | 98 | #---------------------------------------------------------------------------- 99 | 100 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): 101 | expanded = False 102 | if show: 103 | if default: 104 | flags |= imgui.TREE_NODE_DEFAULT_OPEN 105 | if not enabled: 106 | flags |= imgui.TREE_NODE_LEAF 107 
| with grayed_out(not enabled): 108 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) 109 | expanded = expanded and enabled 110 | return expanded, visible 111 | 112 | #---------------------------------------------------------------------------- 113 | 114 | def popup_button(label, width=0, enabled=True): 115 | if button(label, width, enabled): 116 | imgui.open_popup(label) 117 | opened = imgui.begin_popup(label) 118 | return opened 119 | 120 | #---------------------------------------------------------------------------- 121 | 122 | def input_text(label, value, buffer_length, flags, width=None, help_text=''): 123 | old_value = value 124 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 125 | if value == '': 126 | color[-1] *= 0.5 127 | with item_width(width): 128 | imgui.push_style_color(imgui.COLOR_TEXT, *color) 129 | value = value if value != '' else help_text 130 | changed, value = imgui.input_text(label, value, buffer_length, flags) 131 | value = value if value != help_text else '' 132 | imgui.pop_style_color(1) 133 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: 134 | changed = (value != old_value) 135 | return changed, value 136 | 137 | #---------------------------------------------------------------------------- 138 | 139 | def drag_previous_control(enabled=True): 140 | dragging = False 141 | dx = 0 142 | dy = 0 143 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): 144 | if enabled: 145 | dragging = True 146 | dx, dy = imgui.get_mouse_drag_delta() 147 | imgui.reset_mouse_drag_delta() 148 | imgui.end_drag_drop_source() 149 | return dragging, dx, dy 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def drag_button(label, width=0, enabled=True): 154 | clicked = button(label, width=width, enabled=enabled) 155 | dragging, dx, dy = drag_previous_control(enabled=enabled) 156 | return clicked, dragging, dx, dy 157 | 158 | #---------------------------------------------------------------------------- 159 | 160 | def drag_hidden_window(label, x, y, width, height, enabled=True): 161 | imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) 162 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) 163 | imgui.set_next_window_position(x, y) 164 | imgui.set_next_window_size(width, height) 165 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) 166 | dragging, dx, dy = drag_previous_control(enabled=enabled) 167 | imgui.end() 168 | imgui.pop_style_color(2) 169 | return dragging, dx, dy 170 | 171 | #---------------------------------------------------------------------------- 172 | -------------------------------------------------------------------------------- /eg3d/gui_utils/imgui_window.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | import os 12 | import imgui 13 | import imgui.integrations.glfw 14 | 15 | from . import glfw_window 16 | from . import imgui_utils 17 | from . import text_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class ImguiWindow(glfw_window.GlfwWindow): 22 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): 23 | if font is None: 24 | font = text_utils.get_default_font() 25 | font_sizes = {int(size) for size in font_sizes} 26 | super().__init__(title=title, **glfw_kwargs) 27 | 28 | # Init fields. 29 | self._imgui_context = None 30 | self._imgui_renderer = None 31 | self._imgui_fonts = None 32 | self._cur_font_size = max(font_sizes) 33 | 34 | # Delete leftover imgui.ini to avoid unexpected behavior. 35 | if os.path.isfile('imgui.ini'): 36 | os.remove('imgui.ini') 37 | 38 | # Init ImGui. 39 | self._imgui_context = imgui.create_context() 40 | self._imgui_renderer = _GlfwRenderer(self._glfw_window) 41 | self._attach_glfw_callbacks() 42 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. 43 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). 44 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} 45 | self._imgui_renderer.refresh_font_texture() 46 | 47 | def close(self): 48 | self.make_context_current() 49 | self._imgui_fonts = None 50 | if self._imgui_renderer is not None: 51 | self._imgui_renderer.shutdown() 52 | self._imgui_renderer = None 53 | if self._imgui_context is not None: 54 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. 55 | self._imgui_context = None 56 | super().close() 57 | 58 | def _glfw_key_callback(self, *args): 59 | super()._glfw_key_callback(*args) 60 | self._imgui_renderer.keyboard_callback(*args) 61 | 62 | @property 63 | def font_size(self): 64 | return self._cur_font_size 65 | 66 | @property 67 | def spacing(self): 68 | return round(self._cur_font_size * 0.4) 69 | 70 | def set_font_size(self, target): # Applied on next frame. 71 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] 72 | 73 | def begin_frame(self): 74 | # Begin glfw frame. 75 | super().begin_frame() 76 | 77 | # Process imgui events. 78 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 79 | if self.content_width > 0 and self.content_height > 0: 80 | self._imgui_renderer.process_inputs() 81 | 82 | # Begin imgui frame. 83 | imgui.new_frame() 84 | imgui.push_font(self._imgui_fonts[self._cur_font_size]) 85 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) 86 | 87 | def end_frame(self): 88 | imgui.pop_font() 89 | imgui.render() 90 | imgui.end_frame() 91 | self._imgui_renderer.render(imgui.get_draw_data()) 92 | super().end_frame() 93 | 94 | #---------------------------------------------------------------------------- 95 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. 
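# (The wrapper below scales each wheel event by `mouse_wheel_multiplier`, which
# ImguiWindow.begin_frame() ties to the current font size.)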
96 | 97 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): 98 | def __init__(self, *args, **kwargs): 99 | super().__init__(*args, **kwargs) 100 | self.mouse_wheel_multiplier = 1 101 | 102 | def scroll_callback(self, window, x_offset, y_offset): 103 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier 104 | 105 | #---------------------------------------------------------------------------- 106 | -------------------------------------------------------------------------------- /eg3d/gui_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import functools 12 | from typing import Optional 13 | 14 | import dnnlib 15 | import numpy as np 16 | import PIL.Image 17 | import PIL.ImageFont 18 | import scipy.ndimage 19 | 20 | from . import gl_utils 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | def get_default_font(): 25 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular 26 | return dnnlib.util.open_url(url, return_filename=True) 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | @functools.lru_cache(maxsize=None) 31 | def get_pil_font(font=None, size=32): 32 | if font is None: 33 | font = get_default_font() 34 | return PIL.ImageFont.truetype(font=font, size=size) 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | def get_array(string, *, dropshadow_radius: int=None, **kwargs): 39 | if dropshadow_radius is not None: 40 | offset_x = int(np.ceil(dropshadow_radius*2/3)) 41 | offset_y = int(np.ceil(dropshadow_radius*2/3)) 42 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 43 | else: 44 | return _get_array_priv(string, **kwargs) 45 | 46 | @functools.lru_cache(maxsize=10000) 47 | def _get_array_priv( 48 | string: str, *, 49 | size: int = 32, 50 | max_width: Optional[int]=None, 51 | max_height: Optional[int]=None, 52 | min_size=10, 53 | shrink_coef=0.8, 54 | dropshadow_radius: int=None, 55 | offset_x: int=None, 56 | offset_y: int=None, 57 | **kwargs 58 | ): 59 | cur_size = size 60 | array = None 61 | while True: 62 | if dropshadow_radius is not None: 63 | # separate implementation for dropshadow text rendering 64 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 65 | else: 66 | array = _get_array_impl(string, size=cur_size, **kwargs) 67 | height, width, _ = array.shape 68 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): 69 | break 70 | cur_size = max(int(cur_size * shrink_coef), min_size) 71 | return array 72 | 73 | #---------------------------------------------------------------------------- 74 | 75 | 
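# (Renders `string` into a [height, width, 2] uint8 array: channel 0 holds the glyph
# mask, channel 1 an alpha channel that optionally includes a Gaussian-blurred outline.)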
@functools.lru_cache(maxsize=10000) 76 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): 77 | pil_font = get_pil_font(font=font, size=size) 78 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 79 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 80 | width = max(line.shape[1] for line in lines) 81 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 82 | line_spacing = line_pad if line_pad is not None else size // 2 83 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 84 | mask = np.concatenate(lines, axis=0) 85 | alpha = mask 86 | if outline > 0: 87 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) 88 | alpha = mask.astype(np.float32) / 255 89 | alpha = scipy.ndimage.gaussian_filter(alpha, outline) 90 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp 91 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 92 | alpha = np.maximum(alpha, mask) 93 | return np.stack([mask, alpha], axis=-1) 94 | 95 | #---------------------------------------------------------------------------- 96 | 97 | @functools.lru_cache(maxsize=10000) 98 | def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): 99 | assert (offset_x > 0) and (offset_y > 0) 100 | pil_font = get_pil_font(font=font, size=size) 101 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 102 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 103 | width = max(line.shape[1] for line in lines) 104 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 105 | line_spacing = line_pad if line_pad is not None else size // 2 106 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 107 | mask = np.concatenate(lines, axis=0) 108 | alpha = mask 109 | 110 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) 111 | alpha = mask.astype(np.float32) / 255 112 | alpha = scipy.ndimage.gaussian_filter(alpha, radius) 113 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 114 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 115 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] 116 | alpha = np.maximum(alpha, mask) 117 | return np.stack([mask, alpha], axis=-1) 118 | 119 | #---------------------------------------------------------------------------- 120 | 121 | @functools.lru_cache(maxsize=10000) 122 | def get_texture(string, bilinear=True, mipmap=True, **kwargs): 123 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /eg3d/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /eg3d/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Frechet Inception Distance (FID) from the paper 12 | "GANs trained by a two time-scale update rule converge to a local Nash 13 | equilibrium". Matches the original implementation by Heusel et al. at 14 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 15 | 16 | import numpy as np 17 | import scipy.linalg 18 | from . import metric_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def compute_fid(opts, max_real, num_gen): 23 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 24 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 25 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 26 | 27 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 30 | 31 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 32 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 33 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 34 | 35 | if opts.rank != 0: 36 | return float('nan') 37 | 38 | m = np.square(mu_gen - mu_real).sum() 39 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 40 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 41 | return float(fid) 42 | 43 | #---------------------------------------------------------------------------- 44 | -------------------------------------------------------------------------------- /eg3d/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Inception Score (IS) from the paper "Improved techniques for training 12 | GANs". Matches the original implementation by Salimans et al. at 13 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 14 | 15 | import numpy as np 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_is(opts, num_gen, num_splits): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 24 | 25 | gen_probs = metric_utils.compute_feature_stats_for_generator( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | capture_all=True, max_items=num_gen).get_all() 28 | 29 | if opts.rank != 0: 30 | return float('nan'), float('nan') 31 | 32 | scores = [] 33 | for i in range(num_splits): 34 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 35 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 36 | kl = np.mean(np.sum(kl, axis=1)) 37 | scores.append(np.exp(kl)) 38 | return float(np.mean(scores)), float(np.std(scores)) 39 | 40 | #---------------------------------------------------------------------------- 41 | -------------------------------------------------------------------------------- /eg3d/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 12 | GANs". Matches the original implementation by Binkowski et al. at 13 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 14 | 15 | import numpy as np 16 | from . 
import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 24 | 25 | real_features = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 28 | 29 | gen_features = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | n = real_features.shape[1] 37 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 38 | t = 0 39 | for _subset_idx in range(num_subsets): 40 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 41 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 42 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 43 | b = (x @ y.T / n + 1) ** 3 44 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 45 | kid = t / num_subsets / m 46 | return float(kid) 47 | 48 | #---------------------------------------------------------------------------- 49 | -------------------------------------------------------------------------------- /eg3d/metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Main API for computing and reporting quality metrics.""" 12 | 13 | import os 14 | import time 15 | import json 16 | import torch 17 | import dnnlib 18 | 19 | from . import metric_utils 20 | from . import frechet_inception_distance 21 | from . import kernel_inception_distance 22 | from . import precision_recall 23 | from . import perceptual_path_length 24 | from . import inception_score 25 | from . import equivariance 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | _metric_dict = dict() # name => fn 30 | 31 | def register_metric(fn): 32 | assert callable(fn) 33 | _metric_dict[fn.__name__] = fn 34 | return fn 35 | 36 | def is_valid_metric(metric): 37 | return metric in _metric_dict 38 | 39 | def list_valid_metrics(): 40 | return list(_metric_dict.keys()) 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 
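    # Illustrative call (keyword names per metric_utils.MetricOptions; the values
    # here are placeholders):
    #   calc_metric('fid50k_full', G=G, dataset_kwargs=dataset_kwargs,
    #               num_gpus=1, rank=0, device=torch.device('cuda'))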
45 | assert is_valid_metric(metric) 46 | opts = metric_utils.MetricOptions(**kwargs) 47 | 48 | # Calculate. 49 | start_time = time.time() 50 | results = _metric_dict[metric](opts) 51 | total_time = time.time() - start_time 52 | 53 | # Broadcast results. 54 | for key, value in list(results.items()): 55 | if opts.num_gpus > 1: 56 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 57 | torch.distributed.broadcast(tensor=value, src=0) 58 | value = float(value.cpu()) 59 | results[key] = value 60 | 61 | # Decorate with metadata. 62 | return dnnlib.EasyDict( 63 | results = dnnlib.EasyDict(results), 64 | metric = metric, 65 | total_time = total_time, 66 | total_time_str = dnnlib.util.format_time(total_time), 67 | num_gpus = opts.num_gpus, 68 | ) 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 73 | metric = result_dict['metric'] 74 | assert is_valid_metric(metric) 75 | if run_dir is not None and snapshot_pkl is not None: 76 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 77 | 78 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 79 | print(jsonl_line) 80 | if run_dir is not None and os.path.isdir(run_dir): 81 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 82 | f.write(jsonl_line + '\n') 83 | 84 | #---------------------------------------------------------------------------- 85 | # Recommended metrics. 86 | 87 | @register_metric 88 | def fid50k_full(opts): 89 | opts.dataset_kwargs.update(max_size=None, xflip=False) 90 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 91 | return dict(fid50k_full=fid) 92 | 93 | @register_metric 94 | def kid50k_full(opts): 95 | opts.dataset_kwargs.update(max_size=None, xflip=False) 96 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 97 | return dict(kid50k_full=kid) 98 | 99 | @register_metric 100 | def pr50k3_full(opts): 101 | opts.dataset_kwargs.update(max_size=None, xflip=False) 102 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 103 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 104 | 105 | @register_metric 106 | def ppl2_wend(opts): 107 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 108 | return dict(ppl2_wend=ppl) 109 | 110 | @register_metric 111 | def eqt50k_int(opts): 112 | opts.G_kwargs.update(force_fp32=True) 113 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 114 | return dict(eqt50k_int=psnr) 115 | 116 | @register_metric 117 | def eqt50k_frac(opts): 118 | opts.G_kwargs.update(force_fp32=True) 119 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 120 | return dict(eqt50k_frac=psnr) 121 | 122 | @register_metric 123 | def eqr50k(opts): 124 | opts.G_kwargs.update(force_fp32=True) 125 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 126 | return dict(eqr50k=psnr) 127 | 128 | #---------------------------------------------------------------------------- 129 | # Legacy metrics. 
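# (These cap the number of real images used, e.g. max_real=50000 for fid50k, whereas
# the '_full' variants above evaluate against the full dataset with xflip disabled.)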
130 | 131 | @register_metric 132 | def fid50k(opts): 133 | opts.dataset_kwargs.update(max_size=None) 134 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 135 | return dict(fid50k=fid) 136 | 137 | @register_metric 138 | def kid50k(opts): 139 | opts.dataset_kwargs.update(max_size=None) 140 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 141 | return dict(kid50k=kid) 142 | 143 | @register_metric 144 | def pr50k3(opts): 145 | opts.dataset_kwargs.update(max_size=None) 146 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 147 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 148 | 149 | @register_metric 150 | def is50k(opts): 151 | opts.dataset_kwargs.update(max_size=None, xflip=False) 152 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 153 | return dict(is50k_mean=mean, is50k_std=std) 154 | 155 | #---------------------------------------------------------------------------- 156 | -------------------------------------------------------------------------------- /eg3d/metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 12 | Architecture for Generative Adversarial Networks". Matches the original 13 | implementation by Karras et al. at 14 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 15 | 16 | import copy 17 | import numpy as np 18 | import torch 19 | from . import metric_utils 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | # Spherical interpolation of a batch of vectors. 24 | def slerp(a, b, t): 25 | a = a / a.norm(dim=-1, keepdim=True) 26 | b = b / b.norm(dim=-1, keepdim=True) 27 | d = (a * b).sum(dim=-1, keepdim=True) 28 | p = t * torch.acos(d) 29 | c = b - d * a 30 | c = c / c.norm(dim=-1, keepdim=True) 31 | d = a * torch.cos(p) + c * torch.sin(p) 32 | d = d / d.norm(dim=-1, keepdim=True) 33 | return d 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | class PPLSampler(torch.nn.Module): 38 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 39 | assert space in ['z', 'w'] 40 | assert sampling in ['full', 'end'] 41 | super().__init__() 42 | self.G = copy.deepcopy(G) 43 | self.G_kwargs = G_kwargs 44 | self.epsilon = epsilon 45 | self.space = space 46 | self.sampling = sampling 47 | self.crop = crop 48 | self.vgg16 = copy.deepcopy(vgg16) 49 | 50 | def forward(self, c): 51 | # Generate random latents and interpolation t-values. 
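        # (With sampling='end', t is forced to 0, so each sample pair straddles a
        # segment endpoint exactly epsilon apart in interpolation space.)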
52 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 53 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 54 | 55 | # Interpolate in W or Z. 56 | if self.space == 'w': 57 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 58 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 59 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 60 | else: # space == 'z' 61 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 62 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 63 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 64 | 65 | # Randomize noise buffers. 66 | for name, buf in self.G.named_buffers(): 67 | if name.endswith('.noise_const'): 68 | buf.copy_(torch.randn_like(buf)) 69 | 70 | # Generate images. 71 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 72 | 73 | # Center crop. 74 | if self.crop: 75 | assert img.shape[2] == img.shape[3] 76 | c = img.shape[2] // 8 77 | img = img[:, :, c*3 : c*7, c*2 : c*6] 78 | 79 | # Downsample to 256x256. 80 | factor = self.G.img_resolution // 256 81 | if factor > 1: 82 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 83 | 84 | # Scale dynamic range from [-1,1] to [0,255]. 85 | img = (img + 1) * (255 / 2) 86 | if self.G.img_channels == 1: 87 | img = img.repeat([1, 3, 1, 1]) 88 | 89 | # Evaluate differential LPIPS. 90 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 91 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 92 | return dist 93 | 94 | #---------------------------------------------------------------------------- 95 | 96 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 97 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 99 | 100 | # Setup sampler and labels. 101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 102 | sampler.eval().requires_grad_(False).to(opts.device) 103 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 104 | 105 | # Sampling loop. 106 | dist = [] 107 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 108 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 109 | progress.update(batch_start) 110 | x = sampler(next(c_iter)) 111 | for src in range(opts.num_gpus): 112 | y = x.clone() 113 | if opts.num_gpus > 1: 114 | torch.distributed.broadcast(y, src=src) 115 | dist.append(y) 116 | progress.update(num_samples) 117 | 118 | # Compute PPL. 
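    # (PPL is the mean of the per-sample squared LPIPS differences, after discarding
    # values outside the 1st-99th percentile range as outliers.)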
119 | if opts.rank != 0: 120 | return float('nan') 121 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 122 | lo = np.percentile(dist, 1, interpolation='lower') 123 | hi = np.percentile(dist, 99, interpolation='higher') 124 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 125 | return float(ppl) 126 | 127 | #---------------------------------------------------------------------------- 128 | -------------------------------------------------------------------------------- /eg3d/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 12 | Metric for Assessing Generative Models". Matches the original implementation 13 | by Kynkaanniemi et al. at 14 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 15 | 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 22 | assert 0 <= rank < num_gpus 23 | num_cols = col_features.shape[0] 24 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 25 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 26 | dist_batches = [] 27 | for col_batch in col_batches[rank :: num_gpus]: 28 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 29 | for src in range(num_gpus): 30 | dist_broadcast = dist_batch.clone() 31 | if num_gpus > 1: 32 | torch.distributed.broadcast(dist_broadcast, src=src) 33 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 34 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 39 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 40 | detector_kwargs = dict(return_features=True) 41 | 42 | real_features = metric_utils.compute_feature_stats_for_dataset( 43 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 44 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 45 | 46 | gen_features = metric_utils.compute_feature_stats_for_generator( 47 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 48 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 49 | 50 | results = dict() 51 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 52 | kth = [] 53 | for manifold_batch in 
manifold.split(row_batch_size): 54 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 55 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 56 | kth = torch.cat(kth) if opts.rank == 0 else None 57 | pred = [] 58 | for probes_batch in probes.split(row_batch_size): 59 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 60 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 61 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 62 | return results['precision'], results['recall'] 63 | 64 | #---------------------------------------------------------------------------- 65 | -------------------------------------------------------------------------------- /eg3d/shape_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | 12 | """ 13 | Utils for extracting 3D shapes using marching cubes. Based on code from DeepSDF (Park et al.) 14 | 15 | Takes as input an .mrc file and extracts a mesh. 16 | 17 | Ex. 18 | python shape_utils.py my_shape.mrc 19 | Ex. 
20 | python shape_utils.py myshapes_directory --level=12
21 | """
22 |
23 |
24 | import time
25 | import plyfile
26 | import glob
27 | import logging
28 | import numpy as np
29 | import os
30 | import random
31 | import torch
32 | import torch.utils.data
33 | import trimesh
34 | import skimage.measure
35 | import argparse
36 | import mrcfile
37 | from tqdm import tqdm
38 |
39 |
40 | def convert_sdf_samples_to_ply(
41 |     numpy_3d_sdf_tensor,
42 |     voxel_grid_origin,
43 |     voxel_size,
44 |     ply_filename_out,
45 |     offset=None,
46 |     scale=None,
47 |     level=0.0
48 | ):
49 |     """
50 |     Convert sdf samples to .ply
51 |     :param numpy_3d_sdf_tensor: a numpy.ndarray of shape (n,n,n)
52 |     :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
53 |     :voxel_size: float, the size of the voxels
54 |     :ply_filename_out: string, path of the filename to save to
55 |     This function adapted from: https://github.com/RobotLocomotion/spartan
56 |     """
57 |     start_time = time.time()
58 |
59 |     verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
60 |
61 |     verts, faces, normals, values = skimage.measure.marching_cubes(
62 |         numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3
63 |     )
64 |
65 |
66 |
67 |     # transform from voxel coordinates to camera coordinates
68 |     # note x and y are flipped in the output of marching_cubes
69 |     mesh_points = np.zeros_like(verts)
70 |     mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
71 |     mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
72 |     mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
73 |
74 |     # apply additional offset and scale
75 |     if scale is not None:
76 |         mesh_points = mesh_points / scale
77 |     if offset is not None:
78 |         mesh_points = mesh_points - offset
79 |
80 |     # write the mesh to the ply file
81 |
82 |     num_verts = verts.shape[0]
83 |     num_faces = faces.shape[0]
84 |
85 |     verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
86 |
87 |     for i in range(0, num_verts):
88 |         verts_tuple[i] = tuple(mesh_points[i, :])
89 |
90 |     faces_building = []
91 |     for i in range(0, num_faces):
92 |         faces_building.append((faces[i, :].tolist(),))
93 |     faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
94 |
95 |     el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
96 |     el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
97 |
98 |     ply_data = plyfile.PlyData([el_verts, el_faces])
99 |     ply_data.write(ply_filename_out)
100 |     print(f"wrote to {ply_filename_out}")
101 |
102 |
103 | def convert_mrc(input_filename, output_filename, isosurface_level=1):
104 |     with mrcfile.open(input_filename) as mrc:
105 |         convert_sdf_samples_to_ply(np.transpose(mrc.data, (2, 1, 0)), [0, 0, 0], 1, output_filename, level=isosurface_level)
106 |
107 | if __name__ == '__main__':
108 |     start_time = time.time()
109 |     parser = argparse.ArgumentParser()
110 |     parser.add_argument('input_mrc_path')
111 |     parser.add_argument('--level', type=float, default=10, help="The isosurface level for marching cubes")
112 |     args = parser.parse_args()
113 |
114 |     if os.path.isfile(args.input_mrc_path) and args.input_mrc_path.split('.')[-1] == 'mrc':  # a single .mrc file
115 |         output_obj_path = args.input_mrc_path.split('.mrc')[0] + '.ply'
116 |         convert_mrc(args.input_mrc_path, output_obj_path, isosurface_level=args.level)
117 |
118 |         print(f"{time.time() - start_time:.2f} s")
119 |     else:
120 |         assert os.path.isdir(args.input_mrc_path)
121 |
122 |         for mrc_path in
tqdm(glob.glob(os.path.join(args.input_mrc_path, '*.mrc'))): 123 | output_obj_path = mrc_path.split('.mrc')[0] + '.ply' 124 | convert_mrc(mrc_path, output_obj_path, isosurface_level=args.level) -------------------------------------------------------------------------------- /eg3d/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /eg3d/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import glob 12 | import hashlib 13 | import importlib 14 | import os 15 | import re 16 | import shutil 17 | import uuid 18 | 19 | import torch 20 | import torch.utils.cpp_extension 21 | from torch.utils.file_baton import FileBaton 22 | 23 | #---------------------------------------------------------------------------- 24 | # Global options. 25 | 26 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 27 | 28 | #---------------------------------------------------------------------------- 29 | # Internal helper funcs. 30 | 31 | def _find_compiler_bindir(): 32 | patterns = [ 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 35 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 36 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 37 | ] 38 | for pattern in patterns: 39 | matches = sorted(glob.glob(pattern)) 40 | if len(matches): 41 | return matches[-1] 42 | return None 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def _get_mangled_gpu_name(): 47 | name = torch.cuda.get_device_name().lower() 48 | out = [] 49 | for c in name: 50 | if re.match('[a-z0-9_-]+', c): 51 | out.append(c) 52 | else: 53 | out.append('-') 54 | return ''.join(out) 55 | 56 | #---------------------------------------------------------------------------- 57 | # Main entry point for compiling and loading C++/CUDA plugins. 
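# A minimal usage sketch (the module and file names below are illustrative,
# not taken from this repo): get_plugin() compiles the given sources once,
# caches the built module per process, and returns the loaded extension.
#
#   plugin = get_plugin(
#       module_name='example_plugin',              # hypothetical name
#       sources=['example.cpp', 'example.cu'],     # hypothetical files
#       headers=['example.h'],
#       source_dir=os.path.dirname(__file__),
#   )
#   y = plugin.example_op(x)                       # call the bound C++/CUDA op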
58 | 
59 | _cached_plugins = dict()
60 | 
61 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
62 |     assert verbosity in ['none', 'brief', 'full']
63 |     if headers is None:
64 |         headers = []
65 |     if source_dir is not None:
66 |         sources = [os.path.join(source_dir, fname) for fname in sources]
67 |         headers = [os.path.join(source_dir, fname) for fname in headers]
68 | 
69 |     # Already cached?
70 |     if module_name in _cached_plugins:
71 |         return _cached_plugins[module_name]
72 | 
73 |     # Print status.
74 |     if verbosity == 'full':
75 |         print(f'Setting up PyTorch plugin "{module_name}"...')
76 |     elif verbosity == 'brief':
77 |         print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
78 |     verbose_build = (verbosity == 'full')
79 | 
80 |     # Compile and load.
81 |     try: # pylint: disable=too-many-nested-blocks
82 |         # Make sure we can find the necessary compiler binaries.
83 |         if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
84 |             compiler_bindir = _find_compiler_bindir()
85 |             if compiler_bindir is None:
86 |                 raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
87 |             os.environ['PATH'] += ';' + compiler_bindir
88 | 
89 |         # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
90 |         # break the build or unnecessarily restrict what's available to nvcc.
91 |         # Unset it to let nvcc decide based on what's available on the
92 |         # machine.
93 |         os.environ['TORCH_CUDA_ARCH_LIST'] = ''
94 | 
95 |         # Incremental build md5sum trickery. Copies all the input source files
96 |         # into a cached build directory under a combined md5 digest of the input
97 |         # source files. Copying is done only if the combined digest has changed.
98 |         # This keeps input file timestamps and filenames the same as in previous
99 |         # extension builds, allowing for fast incremental rebuilds.
100 |         #
101 |         # This optimization is done only in case all the source files reside in
102 |         # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
103 |         # environment variable is set (we take this as a signal that the user
104 |         # actually cares about this.)
105 |         #
106 |         # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
107 |         # around the *.cu dependency bug in ninja config.
108 |         #
109 |         all_source_files = sorted(sources + headers)
110 |         all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
111 |         if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
112 | 
113 |             # Compute combined hash digest for all source files.
114 |             hash_md5 = hashlib.md5()
115 |             for src in all_source_files:
116 |                 with open(src, 'rb') as f:
117 |                     hash_md5.update(f.read())
118 | 
119 |             # Select cached build directory name.
120 |             source_digest = hash_md5.hexdigest()
121 |             build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
122 |             cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
123 | 
124 |             if not os.path.isdir(cached_build_dir):
125 |                 tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
126 |                 os.makedirs(tmpdir)
127 |                 for src in all_source_files:
128 |                     shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
129 |                 try:
130 |                     os.replace(tmpdir, cached_build_dir) # atomic
131 |                 except OSError:
132 |                     # source directory already exists, delete tmpdir and its contents.
133 |                     shutil.rmtree(tmpdir)
134 |                     if not os.path.isdir(cached_build_dir): raise
135 | 
136 |             # Compile.
137 |             cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
138 |             torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
139 |                 verbose=verbose_build, sources=cached_sources, **build_kwargs)
140 |         else:
141 |             torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
142 | 
143 |         # Load.
144 |         module = importlib.import_module(module_name)
145 | 
146 |     except:
147 |         if verbosity == 'brief':
148 |             print('Failed!')
149 |         raise
150 | 
151 |     # Print status and add to cache dict.
152 |     if verbosity == 'full':
153 |         print(f'Done setting up PyTorch plugin "{module_name}".')
154 |     elif verbosity == 'brief':
155 |         print('Done.')
156 |     _cached_plugins[module_name] = module
157 |     return module
158 | 
159 | #----------------------------------------------------------------------------
160 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/__init__.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | 
11 | # empty
12 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/bias_act.cpp: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | #include <torch/extension.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 | #include <c10/cuda/CUDAGuard.h>
16 | #include "bias_act.h"
17 | 
18 | //------------------------------------------------------------------------
19 | 
20 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
21 | {
22 |     if (x.dim() != y.dim())
23 |         return false;
24 |     for (int64_t i = 0; i < x.dim(); i++)
25 |     {
26 |         if (x.size(i) != y.size(i))
27 |             return false;
28 |         if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
29 |             return false;
30 |     }
31 |     return true;
32 | }
33 | 
34 | //------------------------------------------------------------------------
35 | 
36 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
37 | {
38 |     // Validate arguments.
39 |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
40 |     TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
41 |     TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
42 |     TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
43 |     TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
44 |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
45 |     TORCH_CHECK(b.dim() == 1, "b must have rank 1");
46 |     TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
47 |     TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
48 |     TORCH_CHECK(grad >= 0, "grad must be non-negative");
49 | 
50 |     // Validate layout.
51 |     TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
52 |     TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
53 |     TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
54 |     TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
55 |     TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
56 | 
57 |     // Create output tensor.
58 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
59 |     torch::Tensor y = torch::empty_like(x);
60 |     TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
61 | 
62 |     // Initialize CUDA kernel parameters.
63 |     bias_act_kernel_params p;
64 |     p.x = x.data_ptr();
65 |     p.b = (b.numel()) ? b.data_ptr() : NULL;
66 |     p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
67 |     p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
68 |     p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
69 |     p.y = y.data_ptr();
70 |     p.grad = grad;
71 |     p.act = act;
72 |     p.alpha = alpha;
73 |     p.gain = gain;
74 |     p.clamp = clamp;
75 |     p.sizeX = (int)x.numel();
76 |     p.sizeB = (int)b.numel();
77 |     p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
78 | 
79 |     // Choose CUDA kernel.
80 |     void* kernel;
81 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
82 |     {
83 |         kernel = choose_bias_act_kernel<scalar_t>(p);
84 |     });
85 |     TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
86 | 
87 |     // Launch CUDA kernel.
88 |     p.loopX = 4;
89 |     int blockSize = 4 * 32;
90 |     int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
91 |     void* args[] = {&p};
92 |     AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 |     return y;
94 | }
95 | 
96 | //------------------------------------------------------------------------
97 | 
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 |     m.def("bias_act", &bias_act);
101 | }
102 | 
103 | //------------------------------------------------------------------------
104 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/bias_act.cu: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | #include <c10/util/Half.h>
14 | #include "bias_act.h"
15 | 
16 | //------------------------------------------------------------------------
17 | // Helpers.
18 | 
19 | template <class T> struct InternalType;
20 | template <> struct InternalType<double> { typedef double scalar_t; };
21 | template <> struct InternalType<float> { typedef float scalar_t; };
22 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
23 | 
24 | //------------------------------------------------------------------------
25 | // CUDA kernel.
26 | 
27 | template <class T, int A>
28 | __global__ void bias_act_kernel(bias_act_kernel_params p)
29 | {
30 |     typedef typename InternalType<T>::scalar_t scalar_t;
31 |     int G = p.grad;
32 |     scalar_t alpha = (scalar_t)p.alpha;
33 |     scalar_t gain = (scalar_t)p.gain;
34 |     scalar_t clamp = (scalar_t)p.clamp;
35 |     scalar_t one = (scalar_t)1;
36 |     scalar_t two = (scalar_t)2;
37 |     scalar_t expRange = (scalar_t)80;
38 |     scalar_t halfExpRange = (scalar_t)40;
39 |     scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
40 |     scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
41 | 
42 |     // Loop over elements.
43 |     int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
44 |     for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
45 |     {
46 |         // Load.
47 |         scalar_t x = (scalar_t)((const T*)p.x)[xi];
48 |         scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
49 |         scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
50 |         scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
51 |         scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
52 |         scalar_t yy = (gain != 0) ? yref / gain : 0;
53 |         scalar_t y = 0;
54 | 
55 |         // Apply bias.
56 |         ((G == 0) ? x : xref) += b;
57 | 
58 |         // linear
59 |         if (A == 1)
60 |         {
61 |             if (G == 0) y = x;
62 |             if (G == 1) y = x;
63 |         }
64 | 
65 |         // relu
66 |         if (A == 2)
67 |         {
68 |             if (G == 0) y = (x > 0) ? x : 0;
69 |             if (G == 1) y = (yy > 0) ? x : 0;
70 |         }
71 | 
72 |         // lrelu
73 |         if (A == 3)
74 |         {
75 |             if (G == 0) y = (x > 0) ? x : x * alpha;
76 |             if (G == 1) y = (yy > 0) ? x : x * alpha;
77 |         }
78 | 
79 |         // tanh
80 |         if (A == 4)
81 |         {
82 |             if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
83 |             if (G == 1) y = x * (one - yy * yy);
84 |             if (G == 2) y = x * (one - yy * yy) * (-two * yy);
85 |         }
86 | 
87 |         // sigmoid
88 |         if (A == 5)
89 |         {
90 |             if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
91 |             if (G == 1) y = x * yy * (one - yy);
92 |             if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
93 |         }
94 | 
95 |         // elu
96 |         if (A == 6)
97 |         {
98 |             if (G == 0) y = (x >= 0) ? x : exp(x) - one;
99 |             if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
100 |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
101 |         }
102 | 
103 |         // selu
104 |         if (A == 7)
105 |         {
106 |             if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
107 |             if (G == 1) y = (yy >= 0) ?
x * seluScale : x * (yy + seluScale * seluAlpha);
108 |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
109 |         }
110 | 
111 |         // softplus
112 |         if (A == 8)
113 |         {
114 |             if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
115 |             if (G == 1) y = x * (one - exp(-yy));
116 |             if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
117 |         }
118 | 
119 |         // swish
120 |         if (A == 9)
121 |         {
122 |             if (G == 0)
123 |                 y = (x < -expRange) ? 0 : x / (exp(-x) + one);
124 |             else
125 |             {
126 |                 scalar_t c = exp(xref);
127 |                 scalar_t d = c + one;
128 |                 if (G == 1)
129 |                     y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
130 |                 else
131 |                     y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
132 |                 yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
133 |             }
134 |         }
135 | 
136 |         // Apply gain.
137 |         y *= gain * dy;
138 | 
139 |         // Clamp.
140 |         if (clamp >= 0)
141 |         {
142 |             if (G == 0)
143 |                 y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
144 |             else
145 |                 y = (yref > -clamp & yref < clamp) ? y : 0;
146 |         }
147 | 
148 |         // Store.
149 |         ((T*)p.y)[xi] = (T)y;
150 |     }
151 | }
152 | 
153 | //------------------------------------------------------------------------
154 | // CUDA kernel selection.
155 | 
156 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
157 | {
158 |     if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
159 |     if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
160 |     if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
161 |     if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
162 |     if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
163 |     if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
164 |     if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
165 |     if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
166 |     if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
167 |     return NULL;
168 | }
169 | 
170 | //------------------------------------------------------------------------
171 | // Template specializations.
172 | 
173 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
174 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
175 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
176 | 
177 | //------------------------------------------------------------------------
178 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/bias_act.h: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | //------------------------------------------------------------------------
14 | // CUDA kernel parameters.
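// The `act` field of the struct below selects the nonlinearity; the codes
// match the kernels instantiated in bias_act.cu:
//   1 = linear, 2 = relu, 3 = lrelu, 4 = tanh, 5 = sigmoid,
//   6 = elu, 7 = selu, 8 = softplus, 9 = swish.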
15 | 
16 | struct bias_act_kernel_params
17 | {
18 |     const void* x;      // [sizeX]
19 |     const void* b;      // [sizeB] or NULL
20 |     const void* xref;   // [sizeX] or NULL
21 |     const void* yref;   // [sizeX] or NULL
22 |     const void* dy;     // [sizeX] or NULL
23 |     void* y;            // [sizeX]
24 | 
25 |     int grad;
26 |     int act;
27 |     float alpha;
28 |     float gain;
29 |     float clamp;
30 | 
31 |     int sizeX;
32 |     int sizeB;
33 |     int stepB;
34 |     int loopX;
35 | };
36 | 
37 | //------------------------------------------------------------------------
38 | // CUDA kernel selection.
39 | 
40 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
41 | 
42 | //------------------------------------------------------------------------
43 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/conv2d_resample.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | 
11 | """2D convolution with optional up/downsampling."""
12 | 
13 | import torch
14 | 
15 | from .. import misc
16 | from . import conv2d_gradfix
17 | from . import upfirdn2d
18 | from .upfirdn2d import _parse_padding
19 | from .upfirdn2d import _get_filter_size
20 | 
21 | #----------------------------------------------------------------------------
22 | 
23 | def _get_weight_shape(w):
24 |     with misc.suppress_tracer_warnings(): # this value will be treated as a constant
25 |         shape = [int(sz) for sz in w.shape]
26 |     misc.assert_shape(w, shape)
27 |     return shape
28 | 
29 | #----------------------------------------------------------------------------
30 | 
31 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
32 |     """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
33 |     """
34 |     _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
35 | 
36 |     # Flip weight if requested.
37 |     # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
38 |     if not flip_weight and (kw > 1 or kh > 1):
39 |         w = w.flip([2, 3])
40 | 
41 |     # Execute using conv2d_gradfix.
42 |     op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
43 |     return op(x, w, stride=stride, padding=padding, groups=groups)
44 | 
45 | #----------------------------------------------------------------------------
46 | 
47 | @misc.profiled_function
48 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
49 |     r"""2D convolution with optional up/downsampling.
50 | 
51 |     Padding is performed only once at the beginning, not between the operations.
52 | 
53 |     Args:
54 |         x: Input tensor of shape
55 |             `[batch_size, in_channels, in_height, in_width]`.
56 |         w: Weight tensor of shape
57 |             `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
58 |         f: Low-pass filter for up/downsampling. Must be prepared beforehand by
59 |             calling upfirdn2d.setup_filter().
None = identity (default). 60 | up: Integer upsampling factor (default: 1). 61 | down: Integer downsampling factor (default: 1). 62 | padding: Padding with respect to the upsampled image. Can be a single number 63 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 64 | (default: 0). 65 | groups: Split input channels into N groups (default: 1). 66 | flip_weight: False = convolution, True = correlation (default: True). 67 | flip_filter: False = convolution, True = correlation (default: False). 68 | 69 | Returns: 70 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 71 | """ 72 | # Validate arguments. 73 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 74 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 75 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 76 | assert isinstance(up, int) and (up >= 1) 77 | assert isinstance(down, int) and (down >= 1) 78 | assert isinstance(groups, int) and (groups >= 1) 79 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 80 | fw, fh = _get_filter_size(f) 81 | px0, px1, py0, py1 = _parse_padding(padding) 82 | 83 | # Adjust padding to account for up/downsampling. 84 | if up > 1: 85 | px0 += (fw + up - 1) // 2 86 | px1 += (fw - up) // 2 87 | py0 += (fh + up - 1) // 2 88 | py1 += (fh - up) // 2 89 | if down > 1: 90 | px0 += (fw - down + 1) // 2 91 | px1 += (fw - down) // 2 92 | py0 += (fh - down + 1) // 2 93 | py1 += (fh - down) // 2 94 | 95 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 96 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 97 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 98 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 99 | return x 100 | 101 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 102 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 103 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 104 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 105 | return x 106 | 107 | # Fast path: downsampling only => use strided convolution. 108 | if down > 1 and up == 1: 109 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 110 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 111 | return x 112 | 113 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 114 | if up > 1: 115 | if groups == 1: 116 | w = w.transpose(0, 1) 117 | else: 118 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 119 | w = w.transpose(1, 2) 120 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 121 | px0 -= kw - 1 122 | px1 -= kw - up 123 | py0 -= kh - 1 124 | py1 -= kh - up 125 | pxt = max(min(-px0, -px1), 0) 126 | pyt = max(min(-py0, -py1), 0) 127 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 129 | if down > 1: 130 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 131 | return x 132 | 133 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 
134 |     if up == 1 and down == 1:
135 |         if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
136 |             return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
137 | 
138 |     # Fallback: Generic reference implementation.
139 |     x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
140 |     x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
141 |     if down > 1:
142 |         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
143 |     return x
144 | 
145 | #----------------------------------------------------------------------------
146 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/filtered_lrelu.h: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | #include <cuda_runtime.h>
14 | 
15 | //------------------------------------------------------------------------
16 | // CUDA kernel parameters.
17 | 
18 | struct filtered_lrelu_kernel_params
19 | {
20 |     // These parameters decide which kernel to use.
21 |     int up;                 // upsampling ratio (1, 2, 4)
22 |     int down;               // downsampling ratio (1, 2, 4)
23 |     int2 fuShape;           // [size, 1] | [size, size]
24 |     int2 fdShape;           // [size, 1] | [size, size]
25 | 
26 |     int _dummy;             // Alignment.
27 | 
28 |     // Rest of the parameters.
29 |     const void* x;          // Input tensor.
30 |     void* y;                // Output tensor.
31 |     const void* b;          // Bias tensor.
32 |     unsigned char* s;       // Sign tensor in/out. NULL if unused.
33 |     const float* fu;        // Upsampling filter.
34 |     const float* fd;        // Downsampling filter.
35 | 
36 |     int2 pad0;              // Left/top padding.
37 |     float gain;             // Additional gain factor.
38 |     float slope;            // Leaky ReLU slope on negative side.
39 |     float clamp;            // Clamp after nonlinearity.
40 |     int flip;               // Filter kernel flip for gradient computation.
41 | 
42 |     int tilesXdim;          // Original number of horizontal output tiles.
43 |     int tilesXrep;          // Number of horizontal tiles per CTA.
44 |     int blockZofs;          // Block z offset to support large minibatch, channel dimensions.
45 | 
46 |     int4 xShape;            // [width, height, channel, batch]
47 |     int4 yShape;            // [width, height, channel, batch]
48 |     int2 sShape;            // [width, height] - width is in bytes. Contiguous. Zeros if unused.
49 |     int2 sOfs;              // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
50 |     int swLimit;            // Active width of sign tensor in bytes.
51 | 
52 |     longlong4 xStride;      // Strides of all tensors except signs, same component order as shapes.
53 |     longlong4 yStride;      //
54 |     int64_t bStride;        //
55 |     longlong3 fuStride;     //
56 |     longlong3 fdStride;     //
57 | };
58 | 
59 | struct filtered_lrelu_act_kernel_params
60 | {
61 |     void* x;                // Input/output, modified in-place.
62 |     unsigned char* s;       // Sign tensor in/out. NULL if unused.
63 | 
64 |     float gain;             // Additional gain factor.
65 |     float slope;            // Leaky ReLU slope on negative side.
66 |     float clamp;            // Clamp after nonlinearity.
67 | 
68 |     int4 xShape;            // [width, height, channel, batch]
69 |     longlong4 xStride;      // Input/output tensor strides, same order as in shape.
70 |     int2 sShape;            // [width, height] - width is in elements. Contiguous. Zeros if unused.
71 |     int2 sOfs;              // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
72 | };
73 | 
74 | //------------------------------------------------------------------------
75 | // CUDA kernel specialization.
76 | 
77 | struct filtered_lrelu_kernel_spec
78 | {
79 |     void* setup;            // Function for filter kernel setup.
80 |     void* exec;             // Function for main operation.
81 |     int2 tileOut;           // Width/height of launch tile.
82 |     int numWarps;           // Number of warps per thread block, determines launch block size.
83 |     int xrep;               // For processing multiple horizontal tiles per thread block.
84 |     int dynamicSharedKB;    // How much dynamic shared memory the exec kernel wants.
85 | };
86 | 
87 | //------------------------------------------------------------------------
88 | // CUDA kernel selection.
89 | 
90 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
91 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
92 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
93 | 
94 | //------------------------------------------------------------------------
95 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/filtered_lrelu_ns.cu: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | #include "filtered_lrelu.cu"
14 | 
15 | // Template/kernel specializations for no signs mode (no gradients required).
16 | 
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 | 
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 | 
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
29 | 
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
32 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/filtered_lrelu_rd.cu: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | #include "filtered_lrelu.cu"
14 | 
15 | // Template/kernel specializations for sign read mode.
16 | 
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 | 
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 | 
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
29 | 
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
32 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/filtered_lrelu_wr.cu: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | #include "filtered_lrelu.cu"
14 | 
15 | // Template/kernel specializations for sign write mode.
16 | 
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 | 
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 | 
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
29 | 
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
32 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/fma.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | 
11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
12 | 
13 | import torch
14 | 
15 | #----------------------------------------------------------------------------
16 | 
17 | def fma(a, b, c): # => a * b + c
18 |     return _FusedMultiplyAdd.apply(a, b, c)
19 | 
20 | #----------------------------------------------------------------------------
21 | 
22 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
23 |     @staticmethod
24 |     def forward(ctx, a, b, c): # pylint: disable=arguments-differ
25 |         out = torch.addcmul(c, a, b)
26 |         ctx.save_for_backward(a, b)
27 |         ctx.c_shape = c.shape
28 |         return out
29 | 
30 |     @staticmethod
31 |     def backward(ctx, dout): # pylint: disable=arguments-differ
32 |         a, b = ctx.saved_tensors
33 |         c_shape = ctx.c_shape
34 |         da = None
35 |         db = None
36 |         dc = None
37 | 
38 |         if ctx.needs_input_grad[0]:
39 |             da = _unbroadcast(dout * b, a.shape)
40 | 
41 |         if ctx.needs_input_grad[1]:
42 |             db = _unbroadcast(dout * a, b.shape)
43 | 
44 |         if ctx.needs_input_grad[2]:
45 |             dc = _unbroadcast(dout, c_shape)
46 | 
47 |         return da, db, dc
48 | 
49 | #----------------------------------------------------------------------------
50 | 
51 | def _unbroadcast(x, shape):
52 |     extra_dims = x.ndim - len(shape)
53 |     assert extra_dims >= 0
54 |     dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
55 |     if len(dim):
56 |         x = x.sum(dim=dim, keepdim=True)
57 |     if extra_dims:
58 |         x = x.reshape(-1, *x.shape[extra_dims+1:])
59 |     assert x.shape == shape
60 |     return x
61 | 
62 | #----------------------------------------------------------------------------
63 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/grid_sample_gradfix.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
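# Usage sketch (illustrative; `enabled` and `grid_sample()` are defined below):
# this module is a drop-in replacement for torch.nn.functional.grid_sample when
# second-order gradients are needed, and falls back to the stock op otherwise.
#
#   from torch_utils.ops import grid_sample_gradfix
#   grid_sample_gradfix.enabled = True
#   y = grid_sample_gradfix.grid_sample(input, grid)  # input: NCHW, grid: NHW2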
10 | 11 | """Custom replacement for `torch.nn.functional.grid_sample` that 12 | supports arbitrarily high order gradients between the input and output. 13 | Only works on 2D images and assumes 14 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 15 | 16 | import torch 17 | 18 | # pylint: disable=redefined-builtin 19 | # pylint: disable=arguments-differ 20 | # pylint: disable=protected-access 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | enabled = False # Enable the custom op by setting this to true. 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 62 | ctx.save_for_backward(grid) 63 | return grad_input, grad_grid 64 | 65 | @staticmethod 66 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 67 | _ = grad2_grad_grid # unused 68 | grid, = ctx.saved_tensors 69 | grad2_grad_output = None 70 | grad2_input = None 71 | grad2_grid = None 72 | 73 | if ctx.needs_input_grad[0]: 74 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 75 | 76 | assert not ctx.needs_input_grad[2] 77 | return grad2_grad_output, grad2_input, grad2_grid 78 | 79 | #---------------------------------------------------------------------------- 80 | -------------------------------------------------------------------------------- /eg3d/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | * 5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 6 | * property and proprietary rights in and to this material, related 7 | * documentation and any modifications thereto. Any use, reproduction, 8 | * disclosure or distribution of this material and related documentation 9 | * without an express license agreement from NVIDIA CORPORATION or 10 | * its affiliates is strictly prohibited. 
11 |  */
12 | 
13 | #include <torch/extension.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 | #include <c10/cuda/CUDAGuard.h>
16 | #include "upfirdn2d.h"
17 | 
18 | //------------------------------------------------------------------------
19 | 
20 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
21 | {
22 |     // Validate arguments.
23 |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
24 |     TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
25 |     TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
26 |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
27 |     TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
28 |     TORCH_CHECK(x.numel() > 0, "x has zero size");
29 |     TORCH_CHECK(f.numel() > 0, "f has zero size");
30 |     TORCH_CHECK(x.dim() == 4, "x must be rank 4");
31 |     TORCH_CHECK(f.dim() == 2, "f must be rank 2");
32 |     TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
33 |     TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
34 |     TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
35 |     TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
36 | 
37 |     // Create output tensor.
38 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
39 |     int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
40 |     int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
41 |     TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
42 |     torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
43 |     TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
44 |     TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
45 | 
46 |     // Initialize CUDA kernel parameters.
47 |     upfirdn2d_kernel_params p;
48 |     p.x = x.data_ptr();
49 |     p.f = f.data_ptr<float>();
50 |     p.y = y.data_ptr();
51 |     p.up = make_int2(upx, upy);
52 |     p.down = make_int2(downx, downy);
53 |     p.pad0 = make_int2(padx0, pady0);
54 |     p.flip = (flip) ? 1 : 0;
55 |     p.gain = gain;
56 |     p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
57 |     p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
58 |     p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
59 |     p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
60 |     p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
61 |     p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
62 |     p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
63 |     p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
64 | 
65 |     // Choose CUDA kernel.
66 |     upfirdn2d_kernel_spec spec;
67 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
68 |     {
69 |         spec = choose_upfirdn2d_kernel<scalar_t>(p);
70 |     });
71 | 
72 |     // Set looping options.
73 |     p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
74 |     p.loopMinor = spec.loopMinor;
75 |     p.loopX = spec.loopX;
76 |     p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
77 |     p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
78 | 
79 |     // Compute grid size.
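    // Two launch shapes are used below: a "large" fallback path when the
    // selected kernel reports tileOutW < 0, looping over output rows with a
    // 4x32 block, and a "small" tiled path that uses the kernel's own tile
    // size with 256 threads per block.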
80 |     dim3 blockSize, gridSize;
81 |     if (spec.tileOutW < 0) // large
82 |     {
83 |         blockSize = dim3(4, 32, 1);
84 |         gridSize = dim3(
85 |             ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
86 |             (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
87 |             p.launchMajor);
88 |     }
89 |     else // small
90 |     {
91 |         blockSize = dim3(256, 1, 1);
92 |         gridSize = dim3(
93 |             ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
94 |             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
95 |             p.launchMajor);
96 |     }
97 | 
98 |     // Launch CUDA kernel.
99 |     void* args[] = {&p};
100 |     AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
101 |     return y;
102 | }
103 | 
104 | //------------------------------------------------------------------------
105 | 
106 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
107 | {
108 |     m.def("upfirdn2d", &upfirdn2d);
109 | }
110 | 
111 | //------------------------------------------------------------------------
112 | 
-------------------------------------------------------------------------------- /eg3d/torch_utils/ops/upfirdn2d.h: --------------------------------------------------------------------------------
1 | /*
2 |  * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 |  *
5 |  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 |  * property and proprietary rights in and to this material, related
7 |  * documentation and any modifications thereto. Any use, reproduction,
8 |  * disclosure or distribution of this material and related documentation
9 |  * without an express license agreement from NVIDIA CORPORATION or
10 |  * its affiliates is strictly prohibited.
11 |  */
12 | 
13 | #include <cuda_runtime.h>
14 | 
15 | //------------------------------------------------------------------------
16 | // CUDA kernel parameters.
17 | 
18 | struct upfirdn2d_kernel_params
19 | {
20 |     const void* x;
21 |     const float* f;
22 |     void* y;
23 | 
24 |     int2 up;
25 |     int2 down;
26 |     int2 pad0;
27 |     int flip;
28 |     float gain;
29 | 
30 |     int4 inSize;        // [width, height, channel, batch]
31 |     int4 inStride;
32 |     int2 filterSize;    // [width, height]
33 |     int2 filterStride;
34 |     int4 outSize;       // [width, height, channel, batch]
35 |     int4 outStride;
36 |     int sizeMinor;
37 |     int sizeMajor;
38 | 
39 |     int loopMinor;
40 |     int loopMajor;
41 |     int loopX;
42 |     int launchMinor;
43 |     int launchMajor;
44 | };
45 | 
46 | //------------------------------------------------------------------------
47 | // CUDA kernel specialization.
48 | 
49 | struct upfirdn2d_kernel_spec
50 | {
51 |     void* kernel;
52 |     int tileOutW;
53 |     int tileOutH;
54 |     int loopMinor;
55 |     int loopX;
56 | };
57 | 
58 | //------------------------------------------------------------------------
59 | // CUDA kernel selection.
60 | 
61 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
62 | 
63 | //------------------------------------------------------------------------
64 | 
-------------------------------------------------------------------------------- /eg3d/training/__init__.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | 
11 | # empty
12 | 
-------------------------------------------------------------------------------- /eg3d/training/crosssection_utils.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | 
11 | import torch
12 | 
13 | def sample_cross_section(G, ws, resolution=256, w=1.2):
14 |     axis=0
15 |     A, B = torch.meshgrid(torch.linspace(w/2, -w/2, resolution, device=ws.device), torch.linspace(-w/2, w/2, resolution, device=ws.device), indexing='ij')
16 |     A, B = A.reshape(-1, 1), B.reshape(-1, 1)
17 |     C = torch.zeros_like(A)
18 |     coordinates = [A, B]
19 |     coordinates.insert(axis, C)
20 |     coordinates = torch.cat(coordinates, dim=-1).expand(ws.shape[0], -1, -1)
21 | 
22 |     sigma = G.sample_mixed(coordinates, torch.randn_like(coordinates), ws)['sigma']
23 |     return sigma.reshape(-1, 1, resolution, resolution)
24 | 
25 | # if __name__ == '__main__':
26 | #     sample_cross_section(None)
-------------------------------------------------------------------------------- /eg3d/training/volumetric_rendering/__init__.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | 11 | # empty -------------------------------------------------------------------------------- /eg3d/training/volumetric_rendering/math_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Petr Kellnhofer 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch 24 | 25 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: 26 | """ 27 | Left-multiplies MxM @ NxM. Returns NxM. 28 | """ 29 | res = torch.matmul(vectors4, matrix.T) 30 | return res 31 | 32 | 33 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Normalize vector lengths. 36 | """ 37 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) 38 | 39 | def torch_dot(x: torch.Tensor, y: torch.Tensor): 40 | """ 41 | Dot product of two tensors. 42 | """ 43 | return (x * y).sum(-1) 44 | 45 | 46 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): 47 | """ 48 | Author: Petr Kellnhofer 49 | Intersects rays with the [-1, 1] NDC volume. 50 | Returns min and max distance of entry. 51 | Returns -1 for no intersection. 52 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection 53 | """ 54 | o_shape = rays_o.shape 55 | rays_o = rays_o.detach().reshape(-1, 3) 56 | rays_d = rays_d.detach().reshape(-1, 3) 57 | 58 | 59 | bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] 60 | bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] 61 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) 62 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) 63 | 64 | # Precompute inverse for stability. 65 | invdir = 1 / rays_d 66 | sign = (invdir < 0).long() 67 | 68 | # Intersect with YZ plane. 69 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 70 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 71 | 72 | # Intersect with XZ plane. 73 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 74 | tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 75 | 76 | # Resolve parallel rays. 
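    # (Slab test: a ray misses the box when its computed entry distance along
    #  one axis exceeds its exit distance along another; such rays are flagged
    #  invalid here and assigned t = -1/-2 below.)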
77 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False 78 | 79 | # Use the shortest intersection. 80 | tmin = torch.max(tmin, tymin) 81 | tmax = torch.min(tmax, tymax) 82 | 83 | # Intersect with XY plane. 84 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 85 | tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 86 | 87 | # Resolve parallel rays. 88 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False 89 | 90 | # Use the shortest intersection. 91 | tmin = torch.max(tmin, tzmin) 92 | tmax = torch.min(tmax, tzmax) 93 | 94 | # Mark invalid. 95 | tmin[torch.logical_not(is_valid)] = -1 96 | tmax[torch.logical_not(is_valid)] = -2 97 | 98 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) 99 | 100 | 101 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): 102 | """ 103 | Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. 104 | Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch. 105 | """ 106 | # create a tensor of 'num' steps from 0 to 1 107 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) 108 | 109 | # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting 110 | # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript 111 | # "cannot statically infer the expected size of a list in this context", hence the code below 112 | for i in range(start.ndim): 113 | steps = steps.unsqueeze(-1) 114 | 115 | # the output starts at 'start' and increments until 'stop' in each dimension 116 | out = start[None] + steps * (stop - start)[None] 117 | 118 | return out 119 | -------------------------------------------------------------------------------- /eg3d/training/volumetric_rendering/ray_marcher.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. 13 | Based on the implementation in MipNeRF (this one doesn't do any cone tracing, though!)
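The discrete quadrature follows the standard NeRF scheme: alpha_i = 1 - exp(-sigma_i * delta_i) and w_i = alpha_i * prod_{j<i} (1 - alpha_j), with colors, densities, and depths averaged to segment midpoints before compositing.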
14 | """ 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | class MipRayMarcher2(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | 25 | def run_forward(self, colors, densities, depths, rendering_options): 26 | deltas = depths[:, :, 1:] - depths[:, :, :-1] 27 | colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 28 | densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 29 | depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 30 | 31 | 32 | if rendering_options['clamp_mode'] == 'softplus': 33 | densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better 34 | else: 35 | assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!" 36 | 37 | density_delta = densities_mid * deltas 38 | 39 | alpha = 1 - torch.exp(-density_delta) 40 | 41 | alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) 42 | weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] 43 | 44 | composite_rgb = torch.sum(weights * colors_mid, -2) 45 | weight_total = weights.sum(2) 46 | composite_depth = torch.sum(weights * depths_mid, -2) / weight_total 47 | 48 | # clip the composite to min/max range of depths 49 | composite_depth = torch.nan_to_num(composite_depth, float('inf')) 50 | composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) 51 | 52 | if rendering_options.get('white_back', False): 53 | composite_rgb = composite_rgb + 1 - weight_total 54 | 55 | composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) 56 | 57 | return composite_rgb, composite_depth, weights 58 | 59 | 60 | def forward(self, colors, densities, depths, rendering_options): 61 | composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) 62 | 63 | return composite_rgb, composite_depth, weights -------------------------------------------------------------------------------- /eg3d/training/volumetric_rendering/ray_sampler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | """ 12 | The ray sampler is a module that takes in camera matrices and resolution and batches of rays. 13 | Expects cam2world matrices that use the OpenCV camera coordinate system conventions. 14 | """ 15 | 16 | import torch 17 | 18 | class RaySampler(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None 22 | 23 | 24 | def forward(self, cam2world_matrix, intrinsics, resolution): 25 | """ 26 | Create batches of rays and return origins and directions. 
27 | 28 | cam2world_matrix: (N, 4, 4) 29 | intrinsics: (N, 3, 3) 30 | resolution: int 31 | 32 | ray_origins: (N, M, 3) 33 | ray_dirs: (N, M, 3) 34 | """ 35 | N, M = cam2world_matrix.shape[0], resolution**2 36 | cam_locs_world = cam2world_matrix[:, :3, 3] 37 | fx = intrinsics[:, 0, 0] 38 | fy = intrinsics[:, 1, 1] 39 | cx = intrinsics[:, 0, 2] 40 | cy = intrinsics[:, 1, 2] 41 | sk = intrinsics[:, 0, 1] 42 | 43 | uv = torch.stack(torch.meshgrid(torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), indexing='ij')) * (1./resolution) + (0.5/resolution) 44 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0) 45 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) 46 | 47 | x_cam = uv[:, :, 0].view(N, -1) 48 | y_cam = uv[:, :, 1].view(N, -1) 49 | z_cam = torch.ones((N, M), device=cam2world_matrix.device) 50 | 51 | x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam 52 | y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam 53 | 54 | cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) 55 | 56 | world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] 57 | 58 | ray_dirs = world_rel_points - cam_locs_world[:, None, :] 59 | ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) 60 | 61 | ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) 62 | 63 | return ray_origins, ray_dirs -------------------------------------------------------------------------------- /eg3d/viz/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | # empty 12 | -------------------------------------------------------------------------------- /eg3d/viz/backbone_cache_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited.
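Because RaySampler is pure tensor code, its conventions are easy to sanity-check in isolation. A small self-contained check, assuming the module is importable from inside eg3d/, that intrinsics are normalized by image size (cx = cy = 0.5, as the 0..1 uv grid above implies), and that cam2world is OpenCV-style with +z forward; the focal value 4.2647 is only a typical FFHQ-scale number, not a requirement:

import torch
from training.volumetric_rendering.ray_sampler import RaySampler

sampler = RaySampler()
cam2world = torch.eye(4).unsqueeze(0)            # camera axes aligned with the world
cam2world[0, 2, 3] = -2.7                        # pulled back 2.7 units, looking down +z
intrinsics = torch.tensor([[[4.2647, 0.0,    0.5],
                            [0.0,    4.2647, 0.5],
                            [0.0,    0.0,    1.0]]])
origins, dirs = sampler(cam2world, intrinsics, resolution=64)
print(origins.shape, dirs.shape)                 # torch.Size([1, 4096, 3]) for both
assert torch.allclose(dirs.norm(dim=-1), torch.ones(1, 4096))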
10 | 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class BackboneCacheWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.cache_backbone = True 20 | 21 | @imgui_utils.scoped_by_object_id 22 | def __call__(self, show=True): 23 | viz = self.viz 24 | 25 | if show: 26 | imgui.text('Cache Backbone') 27 | imgui.same_line(viz.label_w + viz.spacing * 4) 28 | _clicked, self.cache_backbone = imgui.checkbox('##backbonecache', self.cache_backbone) 29 | imgui.same_line(viz.label_w + viz.spacing * 10) 30 | imgui.text('Note that when enabled, you may be unable to view intermediate backbone weights below') 31 | 32 | viz.args.do_backbone_caching = self.cache_backbone 33 | 34 | #---------------------------------------------------------------------------- 35 | -------------------------------------------------------------------------------- /eg3d/viz/capture_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import os 12 | import re 13 | import numpy as np 14 | import imgui 15 | import PIL.Image 16 | from gui_utils import imgui_utils 17 | from . 
import renderer 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class CaptureWidget: 22 | def __init__(self, viz): 23 | self.viz = viz 24 | self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) 25 | self.dump_image = False 26 | self.dump_gui = False 27 | self.defer_frames = 0 28 | self.disabled_time = 0 29 | 30 | def dump_png(self, image): 31 | viz = self.viz 32 | try: 33 | _height, _width, channels = image.shape 34 | assert channels in [1, 3] 35 | assert image.dtype == np.uint8 36 | os.makedirs(self.path, exist_ok=True) 37 | file_id = 0 38 | for entry in os.scandir(self.path): 39 | if entry.is_file(): 40 | match = re.fullmatch(r'(\d+).*', entry.name) 41 | if match: 42 | file_id = max(file_id, int(match.group(1)) + 1) 43 | if channels == 1: 44 | pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') 45 | else: 46 | pil_image = PIL.Image.fromarray(image, 'RGB') 47 | pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) 48 | except: 49 | viz.result.error = renderer.CapturedException() 50 | 51 | @imgui_utils.scoped_by_object_id 52 | def __call__(self, show=True): 53 | viz = self.viz 54 | if show: 55 | with imgui_utils.grayed_out(self.disabled_time != 0): 56 | imgui.text('Capture') 57 | imgui.same_line(viz.label_w) 58 | _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, 59 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 60 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 61 | help_text='PATH') 62 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': 63 | imgui.set_tooltip(self.path) 64 | imgui.same_line() 65 | if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): 66 | self.dump_image = True 67 | self.defer_frames = 2 68 | self.disabled_time = 0.5 69 | imgui.same_line() 70 | if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)): 71 | self.dump_gui = True 72 | self.defer_frames = 2 73 | self.disabled_time = 0.5 74 | 75 | self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) 76 | if self.defer_frames > 0: 77 | self.defer_frames -= 1 78 | elif self.dump_image: 79 | if 'image' in viz.result: 80 | self.dump_png(viz.result.image) 81 | self.dump_image = False 82 | elif self.dump_gui: 83 | viz.capture_next_frame() 84 | self.dump_gui = False 85 | captured_frame = viz.pop_captured_frame() 86 | if captured_frame is not None: 87 | self.dump_png(captured_frame) 88 | 89 | #---------------------------------------------------------------------------- 90 | -------------------------------------------------------------------------------- /eg3d/viz/conditioning_pose_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
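One detail of the capture widget above worth pulling out: dump_png() never overwrites, because it derives the next file id from whatever digits already prefix the existing filenames. The same rule as a standalone sketch (the directory name is arbitrary):

import os
import re

def next_screenshot_path(dirpath):
    os.makedirs(dirpath, exist_ok=True)
    file_id = 0                                   # first free zero-padded index
    for entry in os.scandir(dirpath):
        match = re.fullmatch(r'(\d+).*', entry.name)
        if entry.is_file() and match:
            file_id = max(file_id, int(match.group(1)) + 1)
    return os.path.join(dirpath, f'{file_id:05d}.png')

print(next_screenshot_path('_screenshots'))       # e.g. _screenshots/00000.png on first call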
10 | 11 | import numpy as np 12 | import imgui 13 | import dnnlib 14 | from gui_utils import imgui_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | class ConditioningPoseWidget: 19 | def __init__(self, viz): 20 | self.viz = viz 21 | self.pose = dnnlib.EasyDict(yaw=0, pitch=0, anim=False, speed=0.25) 22 | self.pose_def = dnnlib.EasyDict(self.pose) 23 | 24 | def drag(self, dx, dy): 25 | viz = self.viz 26 | self.pose.yaw += -dx / viz.font_size * 3e-2 27 | self.pose.pitch += -dy / viz.font_size * 3e-2 28 | 29 | @imgui_utils.scoped_by_object_id 30 | def __call__(self, show=True): 31 | viz = self.viz 32 | if show: 33 | imgui.text('Cond Pose') 34 | imgui.same_line(viz.label_w) 35 | yaw = self.pose.yaw 36 | pitch = self.pose.pitch 37 | with imgui_utils.item_width(viz.font_size * 5): 38 | changed, (new_yaw, new_pitch) = imgui.input_float2('##frac', yaw, pitch, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 39 | if changed: 40 | self.pose.yaw = new_yaw 41 | self.pose.pitch = new_pitch 42 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 43 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 44 | if dragging: 45 | self.drag(dx, dy) 46 | imgui.same_line() 47 | snapped = dnnlib.EasyDict(self.pose, yaw=round(self.pose.yaw, 1), pitch=round(self.pose.pitch, 1)) 48 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.pose != snapped)): 49 | self.pose = snapped 50 | imgui.same_line() 51 | if imgui_utils.button('Reset', width=-1, enabled=(self.pose != self.pose_def)): 52 | self.pose = dnnlib.EasyDict(self.pose_def) 53 | 54 | viz.args.conditioning_yaw = self.pose.yaw 55 | viz.args.conditioning_pitch = self.pose.pitch 56 | 57 | #---------------------------------------------------------------------------- 58 | -------------------------------------------------------------------------------- /eg3d/viz/latent_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
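The conditioning-pose widget above only emits a yaw and a pitch; turning those into an actual camera happens downstream (the repo's own sampler lives in eg3d/camera_utils.py). As an illustration of the OpenCV convention the ray sampler expects (x right, y down, z forward) rather than a copy of the repo's code, one plausible orbit construction looks like this; the lookat point and radius defaults are merely FFHQ-style assumptions:

import numpy as np

def orbit_cam2world(yaw, pitch, lookat=(0.0, 0.0, 0.2), radius=2.7):
    lookat = np.asarray(lookat, dtype=np.float64)
    cam = lookat + radius * np.array([np.sin(yaw) * np.cos(pitch),
                                      np.sin(pitch),
                                      -np.cos(yaw) * np.cos(pitch)])
    forward = lookat - cam
    forward /= np.linalg.norm(forward)
    right = np.cross(forward, np.array([0.0, 1.0, 0.0]))
    right /= np.linalg.norm(right)                # degenerate at pitch = +/- pi/2 (forward || up)
    down = np.cross(forward, right)               # completes the y-down frame
    m = np.eye(4)
    m[:3, :3] = np.stack([right, down, forward], axis=1)  # columns: camera x, y, z axes in world
    m[:3, 3] = cam
    return m

print(orbit_cam2world(yaw=0.0, pitch=0.0).round(3))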
10 | 11 | import numpy as np 12 | import imgui 13 | import dnnlib 14 | from gui_utils import imgui_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | class LatentWidget: 19 | def __init__(self, viz): 20 | self.viz = viz 21 | self.latent = dnnlib.EasyDict(x=1, y=0, anim=False, speed=0.25) 22 | self.latent_def = dnnlib.EasyDict(self.latent) 23 | self.step_y = 100 24 | 25 | def drag(self, dx, dy): 26 | viz = self.viz 27 | self.latent.x += dx / viz.font_size * 4e-2 28 | self.latent.y += dy / viz.font_size * 4e-2 29 | 30 | @imgui_utils.scoped_by_object_id 31 | def __call__(self, show=True): 32 | viz = self.viz 33 | if show: 34 | imgui.text('Latent') 35 | imgui.same_line(viz.label_w) 36 | seed = round(self.latent.x) + round(self.latent.y) * self.step_y 37 | with imgui_utils.item_width(viz.font_size * 8): 38 | changed, seed = imgui.input_int('##seed', seed, step=0) 39 | if changed: 40 | self.latent.x = seed 41 | self.latent.y = 0 42 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 43 | frac_x = self.latent.x - round(self.latent.x) 44 | frac_y = self.latent.y - round(self.latent.y) 45 | with imgui_utils.item_width(viz.font_size * 5): 46 | changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 47 | if changed: 48 | self.latent.x += new_frac_x - frac_x 49 | self.latent.y += new_frac_y - frac_y 50 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 51 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 52 | if dragging: 53 | self.drag(dx, dy) 54 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) 55 | _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim) 56 | imgui.same_line(round(viz.font_size * 28.7)) 57 | with imgui_utils.item_width(-2 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim): 58 | changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3) 59 | if changed: 60 | self.latent.speed = speed 61 | imgui.same_line() 62 | snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y)) 63 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)): 64 | self.latent = snapped 65 | imgui.same_line() 66 | if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)): 67 | self.latent = dnnlib.EasyDict(self.latent_def) 68 | 69 | if self.latent.anim: 70 | self.latent.x += viz.frame_delta * self.latent.speed 71 | viz.args.w0_seeds = [] # [[seed, weight], ...] 72 | for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]: 73 | seed_x = np.floor(self.latent.x) + ofs_x 74 | seed_y = np.floor(self.latent.y) + ofs_y 75 | seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1) 76 | weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y)) 77 | if weight > 0: 78 | viz.args.w0_seeds.append([seed, weight]) 79 | 80 | #---------------------------------------------------------------------------- 81 | -------------------------------------------------------------------------------- /eg3d/viz/performance_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import array 12 | import numpy as np 13 | import imgui 14 | from gui_utils import imgui_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | class PerformanceWidget: 19 | def __init__(self, viz): 20 | self.viz = viz 21 | self.gui_times = [float('nan')] * 60 22 | self.render_times = [float('nan')] * 30 23 | self.fps_limit = 60 24 | self.use_vsync = False 25 | self.is_async = False 26 | self.force_fp32 = False 27 | 28 | @imgui_utils.scoped_by_object_id 29 | def __call__(self, show=True): 30 | viz = self.viz 31 | self.gui_times = self.gui_times[1:] + [viz.frame_delta] 32 | if 'render_time' in viz.result: 33 | self.render_times = self.render_times[1:] + [viz.result.render_time] 34 | del viz.result.render_time 35 | 36 | if show: 37 | imgui.text('GUI') 38 | imgui.same_line(viz.label_w) 39 | with imgui_utils.item_width(viz.font_size * 8): 40 | imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0) 41 | imgui.same_line(viz.label_w + viz.font_size * 9) 42 | t = [x for x in self.gui_times if x > 0] 43 | t = np.mean(t) if len(t) > 0 else 0 44 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 45 | imgui.same_line(viz.label_w + viz.font_size * 14) 46 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 47 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 48 | with imgui_utils.item_width(viz.font_size * 6): 49 | _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 50 | self.fps_limit = min(max(self.fps_limit, 5), 1000) 51 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 52 | _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync) 53 | 54 | if show: 55 | imgui.text('Render') 56 | imgui.same_line(viz.label_w) 57 | with imgui_utils.item_width(viz.font_size * 8): 58 | imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0) 59 | imgui.same_line(viz.label_w + viz.font_size * 9) 60 | t = [x for x in self.render_times if x > 0] 61 | t = np.mean(t) if len(t) > 0 else 0 62 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 63 | imgui.same_line(viz.label_w + viz.font_size * 14) 64 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 65 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 66 | _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async) 67 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 68 | _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32) 69 | 70 | viz.set_fps_limit(self.fps_limit) 71 | viz.set_vsync(self.use_vsync) 72 | viz.set_async(self.is_async) 73 | viz.args.force_fp32 = self.force_fp32 74 | 75 | #---------------------------------------------------------------------------- 76 | -------------------------------------------------------------------------------- /eg3d/viz/pickle_widget.py: -------------------------------------------------------------------------------- 1 | # 
SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import glob 12 | import os 13 | import re 14 | 15 | import dnnlib 16 | import imgui 17 | import numpy as np 18 | from gui_utils import imgui_utils 19 | 20 | from . import renderer 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | def _locate_results(pattern): 25 | return pattern 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | class PickleWidget: 30 | def __init__(self, viz): 31 | self.viz = viz 32 | self.search_dirs = [] 33 | self.cur_pkl = None 34 | self.user_pkl = '' 35 | self.recent_pkls = [] 36 | self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...} 37 | self.browse_refocus = False 38 | self.load('', ignore_errors=True) 39 | 40 | def add_recent(self, pkl, ignore_errors=False): 41 | try: 42 | resolved = self.resolve_pkl(pkl) 43 | if resolved not in self.recent_pkls: 44 | self.recent_pkls.append(resolved) 45 | except: 46 | if not ignore_errors: 47 | raise 48 | 49 | def load(self, pkl, ignore_errors=False): 50 | viz = self.viz 51 | viz.clear_result() 52 | viz.skip_frame() # The input field will change on next frame. 53 | try: 54 | resolved = self.resolve_pkl(pkl) 55 | name = resolved.replace('\\', '/').split('/')[-1] 56 | self.cur_pkl = resolved 57 | self.user_pkl = resolved 58 | viz.result.message = f'Loading {name}...' 
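# defer_rendering() below briefly holds off the renderer, presumably so the
# message above is what the user sees while the new pickle spins up; the load
# itself happens in the renderer once viz.args.pkl changes.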
59 | viz.defer_rendering() 60 | if resolved in self.recent_pkls: 61 | self.recent_pkls.remove(resolved) 62 | self.recent_pkls.insert(0, resolved) 63 | except: 64 | self.cur_pkl = None 65 | self.user_pkl = pkl 66 | if pkl == '': 67 | viz.result = dnnlib.EasyDict(message='No network pickle loaded') 68 | else: 69 | viz.result = dnnlib.EasyDict(error=renderer.CapturedException()) 70 | if not ignore_errors: 71 | raise 72 | 73 | @imgui_utils.scoped_by_object_id 74 | def __call__(self, show=True): 75 | viz = self.viz 76 | recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl] 77 | if show: 78 | imgui.text('Pickle') 79 | imgui.same_line(viz.label_w) 80 | changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024, 81 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 82 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 83 | help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl') 84 | if changed: 85 | self.load(self.user_pkl, ignore_errors=True) 86 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '': 87 | imgui.set_tooltip(self.user_pkl) 88 | imgui.same_line() 89 | if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)): 90 | imgui.open_popup('recent_pkls_popup') 91 | imgui.same_line() 92 | if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1): 93 | imgui.open_popup('browse_pkls_popup') 94 | self.browse_cache.clear() 95 | self.browse_refocus = True 96 | 97 | if imgui.begin_popup('recent_pkls_popup'): 98 | for pkl in recent_pkls: 99 | clicked, _state = imgui.menu_item(pkl) 100 | if clicked: 101 | self.load(pkl, ignore_errors=True) 102 | imgui.end_popup() 103 | 104 | if imgui.begin_popup('browse_pkls_popup'): 105 | def recurse(parents): 106 | key = tuple(parents) 107 | items = self.browse_cache.get(key, None) 108 | if items is None: 109 | items = self.list_runs_and_pkls(parents) 110 | self.browse_cache[key] = items 111 | for item in items: 112 | if item.type == 'run' and imgui.begin_menu(item.name): 113 | recurse([item.path]) 114 | imgui.end_menu() 115 | if item.type == 'pkl': 116 | clicked, _state = imgui.menu_item(item.name) 117 | if clicked: 118 | self.load(item.path, ignore_errors=True) 119 | if len(items) == 0: 120 | with imgui_utils.grayed_out(): 121 | imgui.menu_item('No results found') 122 | recurse(self.search_dirs) 123 | if self.browse_refocus: 124 | imgui.set_scroll_here() 125 | viz.skip_frame() # Focus will change on next frame.
126 | self.browse_refocus = False 127 | imgui.end_popup() 128 | 129 | paths = viz.pop_drag_and_drop_paths() 130 | if paths is not None and len(paths) >= 1: 131 | self.load(paths[0], ignore_errors=True) 132 | 133 | viz.args.pkl = self.cur_pkl 134 | 135 | def list_runs_and_pkls(self, parents): 136 | items = [] 137 | run_regex = re.compile(r'\d+-.*') 138 | pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') 139 | for parent in set(parents): 140 | if os.path.isdir(parent): 141 | for entry in os.scandir(parent): 142 | if entry.is_dir() and run_regex.fullmatch(entry.name): 143 | items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) 144 | if entry.is_file() and pkl_regex.fullmatch(entry.name): 145 | items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) 146 | 147 | items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) 148 | return items 149 | 150 | def resolve_pkl(self, pattern): 151 | assert isinstance(pattern, str) 152 | assert pattern != '' 153 | 154 | # URL => return as is. 155 | if dnnlib.util.is_url(pattern): 156 | return pattern 157 | 158 | # Short-hand pattern => locate. 159 | path = _locate_results(pattern) 160 | 161 | # Run dir => pick the last saved snapshot. 162 | if os.path.isdir(path): 163 | pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) 164 | if len(pkl_files) == 0: 165 | raise IOError(f'No network pickle found in "{path}"') 166 | path = pkl_files[-1] 167 | 168 | # Normalize. 169 | path = os.path.abspath(path) 170 | return path 171 | 172 | #---------------------------------------------------------------------------- 173 | -------------------------------------------------------------------------------- /eg3d/viz/pose_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
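A note on resolve_pkl() above: for a run directory it sorts the network-snapshot-*.pkl files lexicographically and takes the last one, which doubles as "newest" because the training loop zero-pads the kimg counter in the filename (an assumption worth re-checking against your own checkpoints; the example paths in the comment below are hypothetical). The rule in isolation:

import glob
import os

def latest_snapshot(run_dir):
    # Zero-padded counters make lexicographic order match training order.
    pkls = sorted(glob.glob(os.path.join(run_dir, 'network-snapshot-*.pkl')))
    if not pkls:
        raise IOError(f'No network pickle found in "{run_dir}"')
    return pkls[-1]

# e.g. latest_snapshot('results/00012-ffhq') -> 'results/00012-ffhq/network-snapshot-025000.pkl'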
10 | 11 | import numpy as np 12 | import imgui 13 | import dnnlib 14 | from gui_utils import imgui_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | class PoseWidget: 19 | def __init__(self, viz): 20 | self.viz = viz 21 | self.pose = dnnlib.EasyDict(yaw=0, pitch=0, anim=False, speed=0.25) 22 | self.pose_def = dnnlib.EasyDict(self.pose) 23 | 24 | self.lookat_point_choice = 0 25 | self.lookat_point_option = ['auto', 'ffhq', 'shapenet', 'afhq', 'manual'] 26 | self.lookat_point_labels = ['Auto Detect', 'FFHQ Default', 'Shapenet Default', 'AFHQ Default', 'Manual'] 27 | self.lookat_point = (0.0, 0.0, 0.2) 28 | 29 | def drag(self, dx, dy): 30 | viz = self.viz 31 | self.pose.yaw += -dx / viz.font_size * 3e-2 32 | self.pose.pitch += -dy / viz.font_size * 3e-2 33 | 34 | @imgui_utils.scoped_by_object_id 35 | def __call__(self, show=True): 36 | viz = self.viz 37 | if show: 38 | imgui.text('Pose') 39 | imgui.same_line(viz.label_w) 40 | yaw = self.pose.yaw 41 | pitch = self.pose.pitch 42 | with imgui_utils.item_width(viz.font_size * 5): 43 | changed, (new_yaw, new_pitch) = imgui.input_float2('##pose', yaw, pitch, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 44 | if changed: 45 | self.pose.yaw = new_yaw 46 | self.pose.pitch = new_pitch 47 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 48 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 49 | if dragging: 50 | self.drag(dx, dy) 51 | imgui.same_line() 52 | snapped = dnnlib.EasyDict(self.pose, yaw=round(self.pose.yaw, 1), pitch=round(self.pose.pitch, 1)) 53 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.pose != snapped)): 54 | self.pose = snapped 55 | imgui.same_line() 56 | if imgui_utils.button('Reset', width=-1, enabled=(self.pose != self.pose_def)): 57 | self.pose = dnnlib.EasyDict(self.pose_def) 58 | 59 | # New line starts here 60 | imgui.text('LookAt Point') 61 | imgui.same_line(viz.label_w) 62 | with imgui_utils.item_width(viz.font_size * 8): 63 | _clicked, self.lookat_point_choice = imgui.combo('', self.lookat_point_choice, self.lookat_point_labels) 64 | lookat_point = self.lookat_point_option[self.lookat_point_choice] 65 | if lookat_point == 'auto': 66 | self.lookat_point = None 67 | if lookat_point == 'ffhq': 68 | self.lookat_point = (0.0, 0.0, 0.2) 69 | changes_enabled=False 70 | if lookat_point == 'shapenet': 71 | self.lookat_point = (0.0, 0.0, 0.0) 72 | changes_enabled=False 73 | if lookat_point == 'afhq': 74 | self.lookat_point = (0.0, 0.0, 0.0) 75 | changes_enabled=False 76 | if lookat_point == 'manual': 77 | if self.lookat_point is None: 78 | self.lookat_point = (0.0, 0.0, 0.0) 79 | changes_enabled=True 80 | if lookat_point != 'auto': 81 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 82 | with imgui_utils.item_width(viz.font_size * 16): 83 | with imgui_utils.grayed_out(not changes_enabled): 84 | _changed, self.lookat_point = imgui.input_float3('##lookat', *self.lookat_point, format='%.2f', flags=(imgui.INPUT_TEXT_READ_ONLY if not changes_enabled else 0)) 85 | 86 | 87 | viz.args.yaw = self.pose.yaw 88 | viz.args.pitch = self.pose.pitch 89 | 90 | viz.args.lookat_point = self.lookat_point 91 | 92 | #---------------------------------------------------------------------------- 93 | -------------------------------------------------------------------------------- /eg3d/viz/render_depth_sample_widget.py: -------------------------------------------------------------------------------- 1 | 
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class RenderDepthSampleWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.depth_mult = 2 20 | self.depth_importance_mult = 2 21 | self.render_types = [.5, 1, 2, 4] 22 | self.labels = ['0.5x', '1x', '2x', '4x'] 23 | 24 | @imgui_utils.scoped_by_object_id 25 | def __call__(self, show=True): 26 | viz = self.viz 27 | 28 | if show: 29 | imgui.text('Depth Samples') 30 | imgui.same_line(viz.label_w) 31 | with imgui_utils.item_width(viz.font_size * 4): 32 | _clicked, self.depth_mult = imgui.combo('Depth Sample Multiplier', self.depth_mult, self.labels) 33 | imgui.same_line(viz.label_w + viz.font_size * 16 + viz.spacing * 2) 34 | with imgui_utils.item_width(viz.font_size * 4): 35 | _clicked, self.depth_importance_mult = imgui.combo('Depth Sample Importance Multiplier', self.depth_importance_mult, self.labels) 36 | 37 | viz.args.depth_mult = self.render_types[self.depth_mult] 38 | viz.args.depth_importance_mult = self.render_types[self.depth_importance_mult] 39 | 40 | #---------------------------------------------------------------------------- 41 | -------------------------------------------------------------------------------- /eg3d/viz/render_type_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited.
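The depth-sample widget above only publishes two scale factors (viz.args.depth_mult and viz.args.depth_importance_mult); the renderer is what folds them into its per-ray sample counts. A hedged sketch of that bookkeeping: depth_resolution and depth_resolution_importance are real rendering-option keys in this codebase, but the base value of 48 is merely illustrative.

def apply_depth_multipliers(rendering_options, depth_mult, depth_importance_mult):
    opts = dict(rendering_options)  # leave the caller's options untouched
    opts['depth_resolution'] = int(opts['depth_resolution'] * depth_mult)
    opts['depth_resolution_importance'] = int(opts['depth_resolution_importance'] * depth_importance_mult)
    return opts

base = {'depth_resolution': 48, 'depth_resolution_importance': 48}
print(apply_depth_multipliers(base, 0.5, 2))  # {'depth_resolution': 24, 'depth_resolution_importance': 96}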
10 | 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class RenderTypeWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.render_type = 0 20 | self.render_types = ['image', 'image_depth', 'image_raw'] 21 | self.labels = ['RGB Image', 'Depth Image', 'Neural Rendered Image'] 22 | 23 | @imgui_utils.scoped_by_object_id 24 | def __call__(self, show=True): 25 | viz = self.viz 26 | 27 | if show: 28 | imgui.text('Render Type') 29 | imgui.same_line(viz.label_w) 30 | with imgui_utils.item_width(viz.font_size * 10): 31 | _clicked, self.render_type = imgui.combo('', self.render_type, self.labels) 32 | 33 | viz.args.render_type = self.render_types[self.render_type] 34 | 35 | #---------------------------------------------------------------------------- 36 | -------------------------------------------------------------------------------- /eg3d/viz/stylemix_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class StyleMixingWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.seed_def = 1000 20 | self.seed = self.seed_def 21 | self.animate = False 22 | self.enables = [] 23 | 24 | @imgui_utils.scoped_by_object_id 25 | def __call__(self, show=True): 26 | viz = self.viz 27 | num_ws = viz.result.get('num_ws', 0) 28 | num_enables = viz.result.get('num_ws', 18) 29 | self.enables += [False] * max(num_enables - len(self.enables), 0) 30 | 31 | if show: 32 | imgui.text('Stylemix') 33 | imgui.same_line(viz.label_w) 34 | with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0): 35 | _changed, self.seed = imgui.input_int('##seed', self.seed) 36 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 37 | with imgui_utils.grayed_out(num_ws == 0): 38 | _clicked, self.animate = imgui.checkbox('Anim', self.animate) 39 | 40 | pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w 41 | pos1 = pos2 - imgui.get_text_line_height() - viz.spacing 42 | pos0 = viz.label_w + viz.font_size * 12 43 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) 44 | for idx in range(num_enables): 45 | imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1)))) 46 | if idx == 0: 47 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3) 48 | with imgui_utils.grayed_out(num_ws == 0): 49 | _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx]) 50 | if imgui.is_item_hovered(): 51 | imgui.set_tooltip(f'{idx}') 52 | imgui.pop_style_var(1) 53 | 54 | imgui.same_line(pos2) 55 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3) 56 | with imgui_utils.grayed_out(num_ws == 0): 57 | if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or 
self.animate or any(self.enables[:num_enables]))): 58 | self.seed = self.seed_def 59 | self.animate = False 60 | self.enables = [False] * num_enables 61 | 62 | if any(self.enables[:num_ws]): 63 | viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable] 64 | viz.args.stylemix_seed = self.seed & ((1 << 32) - 1) 65 | if self.animate: 66 | self.seed += 1 67 | 68 | #---------------------------------------------------------------------------- 69 | -------------------------------------------------------------------------------- /eg3d/viz/trunc_noise_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class TruncationNoiseWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.prev_num_ws = 0 20 | self.trunc_psi = 0.7 21 | self.trunc_cutoff = 7 22 | self.noise_enable = True 23 | self.noise_seed = 0 24 | self.noise_anim = False 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | num_ws = viz.result.get('num_ws', 0) 30 | has_noise = viz.result.get('has_noise', False) 31 | if num_ws > 0 and num_ws != self.prev_num_ws: 32 | if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws: 33 | self.trunc_cutoff = num_ws 34 | self.prev_num_ws = num_ws 35 | 36 | if show: 37 | imgui.text('Truncate') 38 | imgui.same_line(viz.label_w) 39 | with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0): 40 | _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f') 41 | imgui.same_line() 42 | if num_ws == 0: 43 | imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False) 44 | else: 45 | with imgui_utils.item_width(viz.font_size * 8 + viz.spacing): 46 | changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d') 47 | if changed: 48 | self.trunc_cutoff = min(max(new_cutoff, 0), num_ws) 49 | 50 | with imgui_utils.grayed_out(not has_noise): 51 | imgui.same_line() 52 | _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable) 53 | imgui.same_line(viz.font_size * 28.7) 54 | with imgui_utils.grayed_out(not self.noise_enable): 55 | with imgui_utils.item_width(-3 - viz.button_w - viz.spacing - viz.font_size * 4): 56 | _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed) 57 | imgui.same_line(spacing=0) 58 | _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim) 59 | 60 | is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws) 61 | is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim) 62 | with imgui_utils.grayed_out(is_def_trunc and not has_noise): 63 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - 
viz.button_w) 64 | if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)): 65 | self.prev_num_ws = num_ws 66 | self.trunc_psi = 0.7 67 | self.trunc_cutoff = 7 68 | self.noise_enable = True 69 | self.noise_seed = 0 70 | self.noise_anim = False 71 | 72 | if self.noise_anim: 73 | self.noise_seed += 1 74 | viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed) 75 | viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random') 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /eg3d/viz/zoom_widget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | from inspect import formatargvalues 12 | import numpy as np 13 | import imgui 14 | import dnnlib 15 | from gui_utils import imgui_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class ZoomWidget: 20 | def __init__(self, viz): 21 | self.viz = viz 22 | self.fov = 18.837 23 | self.fov_default = 18.837 24 | 25 | @imgui_utils.scoped_by_object_id 26 | def __call__(self, show=True): 27 | viz = self.viz 28 | if show: 29 | imgui.text('FOV') 30 | imgui.same_line(viz.label_w) 31 | with imgui_utils.item_width(viz.font_size * 10): 32 | _changed, self.fov = imgui.slider_float('##fov', self.fov, 12, 45, format='%.2f Degrees') 33 | 34 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) 35 | snapped = round(self.fov) 36 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.fov != snapped)): 37 | self.fov = snapped 38 | imgui.same_line() 39 | if imgui_utils.button('Reset', width=-1, enabled=(abs(self.fov - self.fov_default)) > .01): 40 | self.fov = self.fov_default 41 | 42 | viz.args.focal_length = float(1 / (np.tan(self.fov * 3.14159 / 360) * 1.414)) 43 | #---------------------------------------------------------------------------- 44 | --------------------------------------------------------------------------------
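The zoom widget's last line converts a field of view in degrees to a focal length in the same normalized units that RaySampler's intrinsics use: focal = 1 / (tan(fov/2) * 1.414). A quick numeric check at the widget's default FOV, plus the inverse mapping for round-trip sanity (the helper name is ours, not the repo's):

import numpy as np

fov = 18.837                                      # degrees, the widget default
focal = 1 / (np.tan(fov * np.pi / 360) * 1.414)   # fov*pi/360 = half-angle in radians
print(round(focal, 3))                            # ~4.263

def focal_to_fov(focal):
    # inverse of the widget's formula, degrees out
    return np.arctan(1 / (focal * 1.414)) * 360 / np.pi

print(round(focal_to_fov(focal), 3))              # ~18.837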