├── .gitignore ├── ENVIRONMENT.md ├── README.md ├── demo.py ├── external ├── README.md ├── einops_exts.py ├── external_utils.py ├── gridencoder │ ├── __init__.py │ ├── backend.py │ ├── grid.py │ ├── setup.py │ └── src │ │ ├── bindings.cpp │ │ ├── gridencoder.cu │ │ └── gridencoder.h ├── imagen_pytorch.py ├── ldm │ ├── configs │ │ └── sd-vae.yaml │ ├── models │ │ └── autoencoder.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ └── util.py ├── nerf │ ├── clip_utils.py │ ├── gui.py │ ├── network.py │ ├── network_df.py │ ├── network_ff.py │ ├── network_grid.py │ ├── network_tcnn.py │ ├── provider.py │ ├── renderer.py │ ├── renderer_df.py │ └── utils.py ├── ngp_activation.py ├── ngp_encoder.py └── plms.py ├── media └── teaser.jpg ├── raymarching ├── README.md ├── __init__.py ├── backend.py ├── raymarching.py ├── setup.py └── src │ ├── bindings.cpp │ ├── raymarching.cu │ └── raymarching.h ├── requirements.txt ├── sparsefusion ├── distillation.py ├── eft.py └── vldm.py ├── train.py └── utils ├── camera_utils.py ├── check_args.py ├── co3d_dataloader.py ├── co3d_toy_dataloader.py ├── common_utils.py ├── eft_raymarcher.py ├── eft_renderer.py ├── load_dataset.py ├── load_model.py └── render_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | **/__pycache__/ 3 | **/build/ 4 | **build** 5 | vis/ 6 | output/ 7 | out/ 8 | logs/ 9 | data/ 10 | checkpoints/ 11 | *.zip 12 | *.gz 13 | *.tar 14 | *.pt 15 | *.so -------------------------------------------------------------------------------- /ENVIRONMENT.md: -------------------------------------------------------------------------------- 1 | # Environment Setup 2 | We describe steps (in Linux command line) to setup the environment for SparseFusion. 3 | 4 | ## Conda Environment 5 | We install and setup a conda environment. 6 | 7 | ### (optional) Install Conda 8 | Required if conda not installed. 9 | ```bash 10 | cd ~ 11 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 12 | chmod +x Miniconda3-latest-Linux-x86_64.sh 13 | ./Miniconda3-latest-Linux-x86_64.sh 14 | export PATH="/home/username/miniconda/bin:$PATH" 15 | conda init 16 | source ~/.bashrc 17 | ``` 18 | 19 | ### Create New Environment 20 | ```bash 21 | conda create -n sparsefusion python=3.8 22 | conda activate sparsefusion 23 | ``` 24 | 25 | ## Install Dependencies 26 | We install the necessary dependencies. 27 | 28 | ### GCC and Cuda 29 | Make sure to do this first! 30 | 31 | We also assume that nvidia drivers and `cuda=11.3.x` is installed. 
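Before installing anything, you can verify this assumption (illustrative commands; `nvcc` is only available if a CUDA toolkit is already installed system-wide):
```bash
nvidia-smi       # driver version and the maximum CUDA version it supports
nvcc --version   # installed CUDA toolkit version (expects release 11.3)
```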
32 | ```bash 33 | conda install -c conda-forge cxx-compiler=1.3.0 34 | conda install -c conda-forge cudatoolkit-dev 35 | conda install -c conda-forge ninja 36 | ``` 37 | 38 | ### Python Libraries 39 | ```bash 40 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch 41 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 42 | conda install -c pytorch3d pytorch3d 43 | ``` 44 | 45 | ### Support Stable Diffusion 46 | ```bash 47 | pip install transformers==4.19.2 pytorch-lightning==1.4.2 torchmetrics==0.6.0 48 | pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 49 | ``` 50 | 51 | ### (optional) Install CO3D 52 | Required if using full CO3D dataset. 53 | ```bash 54 | git clone https://github.com/facebookresearch/co3d 55 | cd co3d 56 | pip install -r requirements.txt 57 | pip install -e . 58 | ``` 59 | 60 | ### Install Other SparseFusion Requirements 61 | ```bash 62 | cd sparsefusion 63 | pip install -r requirements.txt 64 | ``` 65 | 66 | ## Building Extensions 67 | We require a few extensions from [torch-ngp](https://github.com/ashawkey/torch-ngp). We detail how to install them below. See more details on the [torch-ngp](https://github.com/ashawkey/torch-ngp) Github. 68 | 69 | ### Build gridencoder 70 | ```bash 71 | pip install ./external/gridencoder 72 | ``` 73 | 74 | ### Build raymarcher 75 | ```bash 76 | pip install ./raymarching 77 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SparseFusion 2 | 3 | [**SparseFusion: Distilling View-conditioned Diffusion for 3D Reconstruction**](https://sparsefusion.github.io/)
4 | [Zhizhuo Zhou](https://www.zhiz.dev/),
5 | [Shubham Tulsiani](https://shubhtuls.github.io/)
6 | _CVPR '23 | [GitHub](https://github.com/zhizdev/sparsefusion) | [arXiv](https://arxiv.org/abs/2212.00792) | [Project page](https://sparsefusion.github.io/)_
7 | 
8 | ![txt2img-stable2](media/teaser.jpg)
9 | SparseFusion reconstructs a consistent and realistic 3D neural scene representation from as few as 2 input images with known relative pose. SparseFusion is able to generate detailed and plausible structures in uncertain or unobserved regions (such as the front of the hydrant, the teddybear's face, the back of the laptop, or the left side of the toybus).
10 | 
11 | ---
12 | ## Shoutouts and Credits
13 | This project is built on top of open-source code. We thank the open-source research community and credit our use of parts of [Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Imagen Pytorch](https://github.com/lucidrains/imagen-pytorch), and [torch-ngp](https://github.com/ashawkey/torch-ngp) [below](#acknowledgements).
14 | 
15 | 
16 | ## Code
17 | Our code release contains:
18 | 
19 | 1. Code for inference
20 | 2. Code for training
21 | 3. Pretrained weights for 10 categories
22 | 
23 | For bugs and issues, please open an issue on GitHub and I will try to address it promptly.
24 | 
25 | ---
26 | ## Environment Setup
27 | Please follow the environment setup guide in [ENVIRONMENT.md](ENVIRONMENT.md).
28 | 
29 | ## Dataset
30 | We provide two dataset options: the original CO3Dv2 dataset and a heavily cut-down toy dataset intended for demonstration purposes only. Please download at least one of them.
31 | 
32 | 1. (optional) Download the CO3Dv2 dataset (5.5TB) [here](https://github.com/facebookresearch/co3d) and follow the instructions to extract it to a folder. We assume the default location to be `data/co3d/{category_name}`.
33 | 2. Download the evaluation-only toy CO3Dv2 dataset (6.7GB) [here](https://drive.google.com/drive/folders/1IzgFjdgm_RjCHe2WOkIQa4BRdgKuSglL?usp=share_link) and place it in a folder. We assume the default location to be `data/co3d_toy/{category_name}`.
34 | 
35 | ## Pretrained Weights
36 | SparseFusion requires both SparseFusion weights and Stable Diffusion VAE weights.
37 | 1. Find the SparseFusion weights [here](https://drive.google.com/drive/folders/1Czsnf-PVjwH-HL7K5mTt_kF9u-PVWRyL?usp=share_link). Please download them and put them in `checkpoints/sf/{category_name}`.
38 | 2. Download the Stable Diffusion v-1-3 weights [here](https://huggingface.co/CompVis/stable-diffusion-v-1-3-original) and rename `sd-v1-3.ckpt` to `sd-v1-3-vae.ckpt`. While our code is compatible with the full downloaded weights, we only use the VAE weights from Stable Diffusion. We assume the default location and filename of the VAE checkpoint to be `checkpoints/sd/sd-v1-3-vae.ckpt`.
39 | 
40 | ## Evaluation
41 | 
42 | 
43 | ### Examples
44 | To run evaluation, assuming the CO3D toy dataset and model weights are in the default paths specified above, simply pass in `-d, --dataset_name` and `-c, --category`:
45 | ```shell
46 | $ python demo.py -d co3d_toy -c hydrant
47 | ```
48 | 
49 | To specify which scenes to evaluate on, pass the desired indices (e.g., `0,5,7`) to `-i, --idx`.
50 | ```shell
51 | $ python demo.py -d co3d_toy -c hydrant -i 0,5,7
52 | ```
53 | 
54 | To specify the number of input views to use, specify `-v, --input_views`.
55 | ```shell
56 | $ python demo.py -d co3d_toy -c hydrant -i 0,5,7 -v 3
57 | ```
58 | 
59 | To specify a custom dataset root location, specify `-r, --root`.
60 | ```shell
61 | $ python demo.py -d co3d_toy -r data/co3d_toy -c hydrant -i 0,5,7 -v 3
62 | ```
63 | 
64 | To specify custom model checkpoints, specify `--eft`, `--vldm`, and `--vae`.
65 | ```shell
66 | $ python demo.py -d co3d_toy -r data/co3d_toy -c hydrant -i 0,5,7 -v 3 \
67 | --eft checkpoints/sf/hydrant/ckpt_latest_eft.pt \
68 | --vldm checkpoints/sf/hydrant/ckpt_latest.pt \
69 | --vae checkpoints/sd/sd-v1-3-vae.ckpt
70 | ```
71 | 
72 | To use the original CO3Dv2 dataset, pass `co3d` for the dataset name `-d` and also specify the dataset root location `-r`.
73 | ```shell
74 | $ python demo.py -d co3d -r data/co3d/ -c hydrant -i 0
75 | ```
76 | 
77 | ### Flags
78 | ```
79 | -g, --gpus          number of gpus to use (default: 1)
80 | -p, --port          last digit of DDP port (default: 1)
81 | -d, --dataset_name  name of dataset (default: co3d_toy)
82 | -r, --root          root directory of the dataset
83 | -c, --category      CO3D category
84 | -v, --input_views   number of random input views (default: 2)
85 | -i, --idx           scene indices to evaluate (default: 0)
86 | -e, --eft           path to the EFT checkpoint
87 | -l, --vldm          path to the VLDM checkpoint
88 | -a, --vae           path to the Stable Diffusion VAE checkpoint
89 | ```
90 | 
91 | ### Output
92 | Output artifacts (images, gifs, and torch-ngp checkpoints) will be saved to `output/demo/` by default.
93 | 
94 | ---
95 | 
96 | ## Training
97 | Early-access training code is provided in `train.py`. Please follow the evaluation tutorial above to set up the environment and the pretrained VAE weights. It is recommended to directly modify `train.py` to specify the experiment directory and set the training hyperparameters. We describe the training flags below.
98 | 
99 | ### Flags
100 | ```
101 | -g, --gpus          number of gpus to use (default: 1)
102 | -p, --port          last digit of DDP port (default: 1)
103 | -d, --dataset_name  name of dataset (default: co3d_toy)
104 | -r, --root          root directory of the dataset
105 | -c, --category      CO3D category
106 | -a, --vae           path to the Stable Diffusion VAE checkpoint
107 | -b, --backend       distributed data parallel backend (default: nccl)
108 | ```
109 | 
110 | ### Using Custom Datasets
111 | To train on a custom dataset, one needs to write a custom dataloader. Its `__getitem__` function should return a dictionary containing the following keys (a minimal example dataloader is sketched below, after the first Acknowledgements paragraph):
112 | ```
113 | {
114 |   'images': (B, 3, H, W) image tensor,
115 |   'R': (B, 3, 3) PyTorch3D rotation,
116 |   'T': (B, 3) PyTorch3D translation,
117 |   'f': (B, 2) PyTorch3D focal_length in NDC space,
118 |   'c': (B, 2) PyTorch3D principal_point in NDC space,
119 |   'valid_region': (B, 1, H, W) binary tensor where 1 denotes valid image region,
120 |   'image_size': (B, 2) image size
121 | }
122 | ```
123 | 
124 | ---
125 | ## Citation
126 | If you find this work useful, please consider citing:
127 | 
128 | ```
129 | @inproceedings{zhou2023sparsefusion,
130 |   title={SparseFusion: Distilling View-conditioned Diffusion for 3D Reconstruction},
131 |   author={Zhizhuo Zhou and Shubham Tulsiani},
132 |   booktitle={CVPR},
133 |   year={2023}
134 | }
135 | ```
136 | 
137 | ## Acknowledgements
138 | We thank Naveen Venkat, Mayank Agarwal, Jeff Tan, Paritosh Mittal, Yen-Chi Cheng, and Nikolaos Gkanatsios for helpful discussions and feedback. We also thank David Novotny and Jonáš Kulhánek for sharing outputs of their work and for helpful correspondence. This material is based upon work supported by the National Science Foundation Graduate Research Fellowship under Grant No. (DGE1745016, DGE2140739).
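Below is the minimal custom dataloader sketch referenced in the Training section above. It is illustrative only: the class name, constructor arguments, and the assumed preprocessed `scenes` structure are hypothetical and not part of the released code.

```python
import torch
from torch.utils.data import Dataset


class CustomSceneDataset(Dataset):
    """Hypothetical per-scene dataset returning the dictionary described above."""

    def __init__(self, scenes):
        # `scenes` is assumed to be a list of dicts holding per-frame images and
        # PyTorch3D camera parameters (R, T, f, c) prepared offline.
        self.scenes = scenes

    def __len__(self):
        return len(self.scenes)

    def __getitem__(self, idx):
        scene = self.scenes[idx]
        images = scene['images'].float()              # (B, 3, H, W), values in [0, 1]
        B, _, H, W = images.shape
        return {
            'images': images,
            'R': scene['R'],                          # (B, 3, 3) PyTorch3D rotations
            'T': scene['T'],                          # (B, 3) PyTorch3D translations
            'f': scene['f'],                          # (B, 2) focal lengths in NDC space
            'c': scene['c'],                          # (B, 2) principal points in NDC space
            'valid_region': torch.ones(B, 1, H, W),   # mark every pixel as valid
            'image_size': torch.tensor([[H, W]] * B), # (B, 2)
        }
```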
139 | 140 | We also use parts of existing projects: 141 | 142 | VAE from [Stable Diffusion](https://github.com/CompVis/stable-diffusion). 143 | ``` 144 | @misc{rombach2021highresolution, 145 | title={High-Resolution Image Synthesis with Latent Diffusion Models}, 146 | author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, 147 | year={2021}, 148 | eprint={2112.10752}, 149 | archivePrefix={arXiv}, 150 | primaryClass={cs.CV} 151 | } 152 | ``` 153 | 154 | Diffusion model from [Imagen Pytorch](https://github.com/lucidrains/imagen-pytorch). 155 | ``` 156 | @misc{imagen-pytorch, 157 | Author = {Phil Wang}, 158 | Year = {2022}, 159 | Note = {https://github.com/lucidrains/imagen-pytorch}, 160 | Title = {Imagen - Pytorch} 161 | } 162 | ``` 163 | 164 | Instant NGP implementation from [torch-ngp](https://github.com/ashawkey/torch-ngp). 165 | ``` 166 | @misc{torch-ngp, 167 | Author = {Jiaxiang Tang}, 168 | Year = {2022}, 169 | Note = {https://github.com/ashawkey/torch-ngp}, 170 | Title = {Torch-ngp: a PyTorch implementation of instant-ngp} 171 | } 172 | ``` -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.distributed as dist 5 | import torch.multiprocessing as mp 6 | from pytorch3d.renderer import PerspectiveCameras 7 | 8 | 9 | from sparsefusion.distillation import distillation_loop, get_default_torch_ngp_opt 10 | from utils.camera_utils import RelativeCameraLoader 11 | from utils.common_utils import get_lpips_fn, get_metrics, split_list, normalize, unnormalize 12 | from utils.co3d_dataloader import CO3Dv2Wrapper 13 | from utils.co3d_dataloader import CO3D_ALL_CATEGORIES, CO3D_ALL_TEN 14 | from utils.load_model import load_models 15 | from utils.load_dataset import load_dataset_test 16 | from utils.check_args import check_args 17 | 18 | def fit(gpu, args): 19 | #@ SPAWN DISTRIBUTED NODES 20 | rank = args.nr * args.gpus + gpu 21 | print('spawning gpu rank', rank, 'out of', args.gpus) 22 | dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank) 23 | torch.cuda.set_device(gpu) 24 | os.makedirs(args.exp_dir, exist_ok=True) 25 | os.makedirs(args.exp_dir + '/log/', exist_ok=True) 26 | os.makedirs(args.exp_dir + '/metrics/', exist_ok=True) 27 | os.makedirs(args.exp_dir + '/render_imgs/', exist_ok=True) 28 | os.makedirs(args.exp_dir + '/render_gifs/', exist_ok=True) 29 | render_imgs_dir = args.exp_dir + '/render_imgs/' 30 | print('evaluating', args.exp_dir) 31 | 32 | #@ INIT METRICS 33 | loss_fn_vgg = get_lpips_fn() 34 | 35 | #@ SET CATEGORIES 36 | if args.category == 'all_ten': 37 | cat_list = CO3D_ALL_TEN 38 | elif args.category == 'all': 39 | cat_list = CO3D_ALL_CATEGORIES 40 | else: 41 | cat_list = [args.category] 42 | 43 | #@ LOOP THROUGH CATEGORIES 44 | for ci, cat in enumerate(cat_list): 45 | 46 | #@ LOAD MODELS 47 | eft, vae, vldm = load_models(gpu=gpu, args=args, verbose=False) 48 | use_diffusion = True 49 | 50 | #@ LOAD DATASET 51 | print(f'gpu {gpu}: setting category to {cat} {ci}/{len(cat_list)}') 52 | args.category = cat 53 | test_dataset = load_dataset_test(args, image_size=args.image_size, masked=False) 54 | 55 | #@ SPLIT VAL LIST 56 | if args.val_list == None: 57 | args.val_list = torch.arange(len(test_dataset)).long().tolist() 58 | 59 | val_list = split_list(args.val_list, args.gpus)[gpu] 60 | print(f'gpu {gpu}: 
assigned idx {val_list}') 61 | 62 | #@ 63 | args.val_seed = 0 64 | context_views = args.context_views 65 | 66 | #@ LOOP THROUGH VAL LIST 67 | for val_idx in val_list: 68 | 69 | #@ FETCH DATA 70 | data = test_dataset.__getitem__(val_idx) 71 | scene_idx = val_idx 72 | scene_cameras = PerspectiveCameras(R=data['R'],T=data['T'],focal_length=data['f'],principal_point=data['c'],image_size=data['image_size']).cuda(gpu) 73 | scene_rgb = data['images'].cuda(gpu) 74 | scene_mask = data['masks'].cuda(gpu) 75 | scene_valid_region = data['valid_region'].cuda(gpu) 76 | 77 | #@ SET RANDOM INPUT VIEWS 78 | g_cpu = torch.Generator() 79 | g_cpu.manual_seed(args.val_seed + val_idx) 80 | rand_perm = torch.randperm(len(data['R']), generator=g_cpu) 81 | input_idx = rand_perm[:context_views].long().tolist() 82 | output_idx = torch.arange(len(data['R'])).long().tolist() 83 | print('val_idx', val_idx, input_idx) 84 | 85 | #@ CALL DISTILLATION LOOP 86 | seq_name = f'{cat}_{val_idx:03d}_c{len(input_idx)}' 87 | opt = get_default_torch_ngp_opt() 88 | distillation_loop( 89 | gpu, 90 | args, 91 | opt, 92 | (eft, vae, vldm), 93 | args.exp_dir, 94 | seq_name, 95 | scene_cameras, 96 | scene_rgb, 97 | scene_mask, 98 | scene_valid_region, 99 | input_idx, 100 | use_diffusion=True, 101 | max_itr=3000, 102 | loss_fn_vgg=loss_fn_vgg 103 | ) 104 | 105 | 106 | def main(): 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', 109 | help='number of data loading workers (default: 4)') 110 | parser.add_argument('-g', '--gpus', default=1, type=int, 111 | help='number of gpus per node') 112 | parser.add_argument('-nr', '--nr', default=0, type=int, 113 | help='ranking within the nodes') 114 | parser.add_argument('-p', '--port', default=1, type=int, metavar='N', 115 | help='last digit of port (default: 1234[1])') 116 | parser.add_argument('-c', '--category', type=str, metavar='S', required=True, 117 | help='category') 118 | parser.add_argument('-r', '--root', type=str, default='data/co3d_toy', metavar='S', 119 | help='location of test features') 120 | parser.add_argument('-d', '--dataset_name', type=str, default='co3d_toy', metavar='S', 121 | help='dataset name') 122 | parser.add_argument('-e', '--eft', type=str, default='-DNE', metavar='S', 123 | help='eft ckpt') 124 | parser.add_argument('-l', '--vldm', type=str, default='-DNE', metavar='S', 125 | help='vldm ckpt') 126 | parser.add_argument('-a', '--vae', type=str, default='-DNE', metavar='S', 127 | help='vae ckpt') 128 | parser.add_argument('-i', '--idx', type=str, default='-DNE', metavar='N', 129 | help='evaluataion indicies') 130 | parser.add_argument('-v', '--input_views', type=int, default=2, metavar='N', 131 | help='input views') 132 | args = parser.parse_args() 133 | 134 | #@ SET MULTIPROCESSING PORTS 135 | args.world_size = args.gpus * args.nodes 136 | os.environ['MASTER_ADDR'] = 'localhost' 137 | os.environ['MASTER_PORT'] = f'1234{args.port}' 138 | print('using port', f'1234{args.port}') 139 | 140 | #@ SET DEFAUL PARAMETERS 141 | args.use_r = True 142 | args.encoder = 'resnet18' 143 | args.num_input = 4 144 | args.timesteps = 500 145 | args.objective = 'noise' 146 | args.scale_factor = 8 147 | args.image_size = 256 148 | args.z_scale_factor = 0.18215 149 | 150 | args.server_prefix = 'checkpoints/' 151 | args.diffusion_exp_name = 'sf' 152 | args.eft_ckpt = f'{args.server_prefix}/{args.diffusion_exp_name}/{args.category}/ckpt_latest_eft.pt' 153 | args.vae_ckpt = f'{args.server_prefix}/sd/sd-v1-3-vae.ckpt' 154 | 
args.vldm_ckpt = f'{args.server_prefix}/{args.diffusion_exp_name}/{args.category}/ckpt_latest.pt' 155 | 156 | args.context_views = args.input_views 157 | args.val_list = [0] 158 | args.exp_dir = 'output/demo/' 159 | 160 | #@ OVERRIDE DEFAULT ARGS WITH INPUTS 161 | if args.vae != '-DNE': 162 | args.vae_ckpt = args.vae 163 | if args.eft != '-DNE': 164 | args.eft_ckpt = args.eft 165 | if args.vldm != '-DNE': 166 | args.vldm_ckpt = args.vldm 167 | if args.idx != '-DNE': 168 | try: 169 | val_list_str = args.idx.split(',') 170 | args.val_list = [] 171 | for val_str in val_list_str: 172 | args.val_list.append(int(val_str)) 173 | except: 174 | print('ERROR: -i --idx arg invalid, please use form 1,2,3') 175 | print('Exiting...') 176 | exit(1) 177 | 178 | check_args(args) 179 | 180 | mp.spawn(fit, nprocs=args.gpus, args=(args,)) 181 | 182 | if __name__ == '__main__': 183 | main() -------------------------------------------------------------------------------- /external/README.md: -------------------------------------------------------------------------------- 1 | All code in external/ are from many amazing researchers and engineers who made 2 | their code public. Please see the README.md in the root directory for credits. -------------------------------------------------------------------------------- /external/einops_exts.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Einops exts 3 | #@ FROM https://github.com/lucidrains/einops-exts 4 | ''' 5 | 6 | import re 7 | from torch import nn 8 | from functools import wraps, partial 9 | 10 | from einops import rearrange, reduce, repeat 11 | 12 | # checking shape 13 | # @nils-werner 14 | # https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838 15 | 16 | def check_shape(tensor, pattern, **kwargs): 17 | return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs) 18 | 19 | # do same einops operations on a list of tensors 20 | 21 | def _many(fn): 22 | @wraps(fn) 23 | def inner(tensors, pattern, **kwargs): 24 | return (fn(tensor, pattern, **kwargs) for tensor in tensors) 25 | return inner 26 | 27 | # do einops with unflattening of anonymously named dimensions 28 | # (...flattened) -> ...flattened 29 | 30 | def _with_anon_dims(fn): 31 | @wraps(fn) 32 | def inner(tensor, pattern, **kwargs): 33 | regex = r'(\.\.\.[a-zA-Z]+)' 34 | matches = re.findall(regex, pattern) 35 | get_anon_dim_name = lambda t: t.lstrip('...') 36 | dim_prefixes = tuple(map(get_anon_dim_name, set(matches))) 37 | 38 | update_kwargs_dict = dict() 39 | 40 | for prefix in dim_prefixes: 41 | assert prefix in kwargs, f'dimension list "{prefix}" was not passed in' 42 | dim_list = kwargs[prefix] 43 | assert isinstance(dim_list, (list, tuple)), f'dimension list "{prefix}" needs to be a tuple of list of dimensions' 44 | dim_names = list(map(lambda ind: f'{prefix}{ind}', range(len(dim_list)))) 45 | update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list)) 46 | 47 | def sub_with_anonymous_dims(t): 48 | dim_name_prefix = get_anon_dim_name(t.groups()[0]) 49 | return ' '.join(update_kwargs_dict[dim_name_prefix].keys()) 50 | 51 | pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern) 52 | 53 | for prefix, update_dict in update_kwargs_dict.items(): 54 | del kwargs[prefix] 55 | kwargs.update(update_dict) 56 | 57 | return fn(tensor, pattern_new, **kwargs) 58 | return inner 59 | 60 | # generate all helper functions 61 | 62 | rearrange_many = _many(rearrange) 63 | repeat_many = _many(repeat) 64 | reduce_many = _many(reduce) 65 | 66 | 
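# Illustrative usage (not part of the original file): the *_many helpers apply a
# single einops pattern to several tensors at once, e.g.
#   q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h=8)
# where q, k, and v each have shape (b, n, h*d) before the call.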
rearrange_with_anon_dims = _with_anon_dims(rearrange) 67 | repeat_with_anon_dims = _with_anon_dims(repeat) 68 | reduce_with_anon_dims = _with_anon_dims(reduce) -------------------------------------------------------------------------------- /external/external_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | A wrapper class for the perceptual deep feature loss. 3 | 4 | Reference: 5 | Richard Zhang et al. The Unreasonable Effectiveness of Deep Features as a Perceptual Metric. (CVPR 2018). 6 | """ 7 | import lpips 8 | import torch.nn as nn 9 | 10 | 11 | class PerceptualLoss(nn.Module): 12 | def __init__(self, net, device): 13 | super().__init__() 14 | self.model = lpips.LPIPS(net=net, verbose=False).to(device) 15 | self.device = device 16 | 17 | def get_device(self, default_device=None): 18 | """ 19 | Returns which device module is on, assuming all parameters are on the same GPU. 20 | """ 21 | try: 22 | return next(self.parameters()).device 23 | except StopIteration: 24 | return default_device 25 | 26 | def __call__(self, pred, target, normalize=True): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is on, scales images between [-1, 1] 30 | Assumes the inputs are in range [0, 1]. 31 | B 3 H W 32 | """ 33 | if pred.shape[1] != 3: 34 | pred = pred.permute(0, 3, 1, 2) 35 | target = target.permute(0, 3, 1, 2) 36 | # print(pred.shape, target.shape) 37 | if normalize: 38 | target = 2 * target - 1 39 | pred = 2 * pred - 1 40 | 41 | # temp_device = pred.device 42 | # device = self.get_device(temp_device) 43 | 44 | device = self.device 45 | 46 | pred = pred.to(device).float() 47 | target = target.to(device) 48 | dist = self.model.forward(pred, target) 49 | return dist.to(device) 50 | -------------------------------------------------------------------------------- /external/gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /external/gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 
25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /external/gridencoder/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _gridencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | _gridtype_to_id = { 15 | 'hash': 0, 16 | 'tiled': 1, 17 | } 18 | 19 | class _grid_encode(Function): 20 | @staticmethod 21 | @custom_fwd 22 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False): 23 | # inputs: [B, D], float in [0, 1] 24 | # embeddings: [sO, C], float 25 | # offsets: [L + 1], int 26 | # RETURN: [B, F], float 27 | 28 | inputs = inputs.contiguous() 29 | 30 | B, D = inputs.shape # batch size, coord dim 31 | L = offsets.shape[0] - 1 # level 32 | C = embeddings.shape[1] # embedding dim for each level 33 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 34 | H = base_resolution # base resolution 35 | 36 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) 37 | # if C % 2 != 0, force float, since half for atomicAdd is very slow. 
38 | if torch.is_autocast_enabled() and C % 2 == 0: 39 | embeddings = embeddings.to(torch.half) 40 | 41 | # L first, optimize cache for cuda kernel, but needs an extra permute later 42 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) 43 | 44 | if calc_grad_inputs: 45 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) 46 | else: 47 | dy_dx = None 48 | 49 | _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners) 50 | 51 | # permute back to [B, L * C] 52 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 53 | 54 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 55 | ctx.dims = [B, D, C, L, S, H, gridtype] 56 | ctx.align_corners = align_corners 57 | 58 | return outputs 59 | 60 | @staticmethod 61 | #@once_differentiable 62 | @custom_bwd 63 | def backward(ctx, grad): 64 | 65 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 66 | B, D, C, L, S, H, gridtype = ctx.dims 67 | align_corners = ctx.align_corners 68 | 69 | # grad: [B, L * C] --> [L, B, C] 70 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 71 | 72 | grad_embeddings = torch.zeros_like(embeddings) 73 | 74 | if dy_dx is not None: 75 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 76 | else: 77 | grad_inputs = None 78 | 79 | _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners) 80 | 81 | if dy_dx is not None: 82 | grad_inputs = grad_inputs.to(inputs.dtype) 83 | 84 | return grad_inputs, grad_embeddings, None, None, None, None, None, None 85 | 86 | 87 | 88 | grid_encode = _grid_encode.apply 89 | 90 | 91 | class GridEncoder(nn.Module): 92 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False): 93 | super().__init__() 94 | 95 | # the finest resolution desired at the last level, if provided, overridee per_level_scale 96 | if desired_resolution is not None: 97 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) 98 | 99 | self.input_dim = input_dim # coord dims, 2 or 3 100 | self.num_levels = num_levels # num levels, each level multiply resolution by 2 101 | self.level_dim = level_dim # encode channels per level 102 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
103 | self.log2_hashmap_size = log2_hashmap_size 104 | self.base_resolution = base_resolution 105 | self.output_dim = num_levels * level_dim 106 | self.gridtype = gridtype 107 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 108 | self.align_corners = align_corners 109 | 110 | # allocate parameters 111 | offsets = [] 112 | offset = 0 113 | self.max_params = 2 ** log2_hashmap_size 114 | for i in range(num_levels): 115 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 116 | params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number 117 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 118 | offsets.append(offset) 119 | offset += params_in_level 120 | offsets.append(offset) 121 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 122 | self.register_buffer('offsets', offsets) 123 | 124 | self.n_params = offsets[-1] * level_dim 125 | 126 | # parameters 127 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 128 | 129 | self.reset_parameters() 130 | 131 | def reset_parameters(self): 132 | std = 1e-4 133 | self.embeddings.data.uniform_(-std, std) 134 | 135 | def __repr__(self): 136 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}" 137 | 138 | def forward(self, inputs, bound=1): 139 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 140 | # return: [..., num_levels * level_dim] 141 | 142 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 143 | 144 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 145 | 146 | prefix_shape = list(inputs.shape[:-1]) 147 | inputs = inputs.view(-1, self.input_dim) 148 | 149 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners) 150 | outputs = outputs.view(prefix_shape + [self.output_dim]) 151 | 152 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 153 | 154 | return outputs -------------------------------------------------------------------------------- /external/gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /external/gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /external/gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners); 14 | 15 | #endif -------------------------------------------------------------------------------- /external/ldm/configs/sd-vae.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: external.ldm.models.autoencoder.AutoencoderKL 3 | params: 4 | embed_dim: 4 5 | monitor: val/rec_loss 6 | ddconfig: 7 | double_z: true 8 | z_channels: 4 9 | resolution: 256 10 | in_channels: 3 11 | out_ch: 3 12 | ch: 128 13 | ch_mult: 14 | - 1 15 | - 2 16 | - 4 17 | - 4 18 | num_res_blocks: 2 19 | attn_resolutions: [] 20 | dropout: 0.0 21 | lossconfig: 22 | target: torch.nn.Identity -------------------------------------------------------------------------------- /external/ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from external.ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def 
default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 
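# Illustrative usage (not part of the original file): SpatialSelfAttention is
# shape-preserving on image-like feature maps, e.g.
#   attn = SpatialSelfAttention(in_channels=64)
#   out = attn(torch.randn(2, 64, 32, 32))  # out.shape == (2, 64, 32, 32)
# Note that in_channels must be divisible by 32, since Normalize uses GroupNorm(32, ...).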
| 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /external/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /external/ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from external.ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /external/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /external/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | 
def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /external/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /external/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /external/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | import kornia 8 | 9 | from external.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. 
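[Editor's sketch, not part of the repository] A hedged illustration of the usual LitEma workflow from ema.py above, using a throwaway nn.Linear as the model: the EMA object is called once per optimizer step to update its shadow buffers, and store()/copy_to()/restore() temporarily swap the averaged weights in for evaluation.

import torch
from external.ldm.modules.ema import LitEma

model = torch.nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(10):                      # dummy "training" updates
    with torch.no_grad():
        model.weight.add_(0.01)
    ema(model)                           # forward() moves the shadow weights toward the live ones

ema.store(model.parameters())            # stash the live weights
ema.copy_to(model)                       # ...evaluate or checkpoint with EMA weights here...
ema.restore(model.parameters())          # put the live training weights back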
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | class FrozenCLIPEmbedder(AbstractEncoder): 138 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 140 | super().__init__() 141 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 142 | self.transformer = CLIPTextModel.from_pretrained(version) 143 | self.device = device 144 | self.max_length = max_length 145 | self.freeze() 146 | 147 | def freeze(self): 148 | self.transformer = self.transformer.eval() 149 | for param in self.parameters(): 150 | param.requires_grad = False 151 | 152 | def forward(self, text): 153 | batch_encoding = 
self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 155 | tokens = batch_encoding["input_ids"].to(self.device) 156 | outputs = self.transformer(input_ids=tokens) 157 | 158 | z = outputs.last_hidden_state 159 | return z 160 | 161 | def encode(self, text): 162 | return self(text) 163 | 164 | 165 | class FrozenCLIPTextEmbedder(nn.Module): 166 | """ 167 | Uses the CLIP transformer encoder for text. 168 | """ 169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 170 | super().__init__() 171 | self.model, _ = clip.load(version, jit=False, device="cpu") 172 | self.device = device 173 | self.max_length = max_length 174 | self.n_repeat = n_repeat 175 | self.normalize = normalize 176 | 177 | def freeze(self): 178 | self.model = self.model.eval() 179 | for param in self.parameters(): 180 | param.requires_grad = False 181 | 182 | def forward(self, text): 183 | tokens = clip.tokenize(text).to(self.device) 184 | z = self.model.encode_text(tokens) 185 | if self.normalize: 186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 187 | return z 188 | 189 | def encode(self, text): 190 | z = self(text) 191 | if z.ndim==2: 192 | z = z[:, None, :] 193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 194 | return z 195 | 196 | 197 | class FrozenClipImageEmbedder(nn.Module): 198 | """ 199 | Uses the CLIP image encoder. 200 | """ 201 | def __init__( 202 | self, 203 | model, 204 | jit=False, 205 | device='cuda' if torch.cuda.is_available() else 'cpu', 206 | antialias=False, 207 | ): 208 | super().__init__() 209 | self.model, _ = clip.load(name=model, device=device, jit=jit) 210 | 211 | self.antialias = antialias 212 | 213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 215 | 216 | def preprocess(self, x): 217 | # normalize to [0,1] 218 | x = kornia.geometry.resize(x, (224, 224), 219 | interpolation='bicubic',align_corners=True, 220 | antialias=self.antialias) 221 | x = (x + 1.) / 2. 
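[Editor's sketch, not part of the repository] A short, hedged example of the FrozenCLIPEmbedder defined earlier in this file: a list of prompts goes in, a [B, 77, 768] conditioning tensor (the ViT-L/14 text hidden size) comes out. Running it downloads the Hugging Face CLIP text model on first use; device="cpu" here is only for illustration.

import torch
from external.ldm.modules.encoders.modules import FrozenCLIPEmbedder

text_encoder = FrozenCLIPEmbedder(device="cpu")        # "cuda" in the actual pipeline
with torch.no_grad():
    c = text_encoder.encode(["a photo of a hydrant", "a teddy bear"])
print(c.shape)                                         # torch.Size([2, 77, 768])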
222 | # renormalize according to clip 223 | x = kornia.enhance.normalize(x, self.mean, self.std) 224 | return x 225 | 226 | def forward(self, x): 227 | # x is assumed to be in range [-1,1] 228 | return self.model.encode_image(self.preprocess(x)) 229 | 230 | 231 | if __name__ == "__main__": 232 | from external.ldm.util import count_params 233 | model = FrozenCLIPEmbedder() 234 | count_params(model, verbose=True) -------------------------------------------------------------------------------- /external/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from external.ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from external.ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /external/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /external/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from external.ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /external/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
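[Editor's sketch, not part of the repository] A hedged usage example for the FrozenClipImageEmbedder in the encoders module above: forward() expects images in [-1, 1]; preprocess() resizes them to 224x224, shifts them to [0, 1], and applies the CLIP mean/std before encode_image(). The "ViT-B/32" name is only an example argument for clip.load; any CLIP variant should work.

import torch
from external.ldm.modules.encoders.modules import FrozenClipImageEmbedder

embedder = FrozenClipImageEmbedder(model="ViT-B/32", device="cpu")
images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0        # fake batch already in [-1, 1]
with torch.no_grad():
    z = embedder(images)
print(z.shape)                                          # [2, 512] for ViT-B/32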
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /external/ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
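[Editor's sketch, not part of the repository] Both LPIPSWithDiscriminator and VQLPIPSWithDiscriminator balance the GAN term with calculate_adaptive_weight: the generator loss is rescaled by the ratio of the reconstruction-loss gradient norm to the adversarial-loss gradient norm at the decoder's last layer. A standalone sketch of that computation, with a dummy parameter standing in for the last layer:

import torch

last_layer = torch.nn.Parameter(torch.randn(8, 8))       # stand-in for the decoder's final weight
nll_loss = (last_layer ** 2).sum()                        # stand-in reconstruction/NLL loss
g_loss = last_layer.mean()                                # stand-in generator (adversarial) loss

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()       # then multiplied by discriminator_weight
print(d_weight)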
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /external/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, 
len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /external/nerf/clip_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import torchvision.transforms as T 7 | import torchvision.transforms.functional as TF 8 | 9 | import clip 10 | 11 | class CLIPLoss: 12 | def __init__(self, device, name='ViT-B/16'): 13 | self.device = device 14 | self.name = name 15 | self.clip_model, self.transform_PIL = clip.load(self.name, device=self.device, jit=False) 16 | 17 | # disable training 18 | self.clip_model.eval() 19 | for p in self.clip_model.parameters(): 20 | p.requires_grad = False 21 | 22 | # image augmentation 23 | self.transform = T.Compose([ 24 | T.Resize((224, 224)), 25 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 26 | ]) 27 | 28 | # placeholder 29 | self.text_zs = None 30 | self.image_zs = None 31 | 32 | def normalize(self, x): 33 | return x / x.norm(dim=-1, keepdim=True) 34 | 35 | # image-text (e.g., dreamfields) 36 | def prepare_text(self, texts): 37 | # texts: list of strings. 38 | texts = clip.tokenize(texts).to(self.device) 39 | self.text_zs = self.normalize(self.clip_model.encode_text(texts)) 40 | print(f'[INFO] prepared CLIP text feature: {self.text_zs.shape}') 41 | 42 | def __call__(self, images, mode='text'): 43 | 44 | images = self.transform(images) 45 | image_zs = self.normalize(self.clip_model.encode_image(images)) 46 | 47 | if mode == 'text': 48 | # if more than one string, randomly choose one. 
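[Editor's sketch, not part of the repository] A hedged usage example for instantiate_from_config / get_obj_from_str in util.py above: any mapping with a dotted "target" path and an optional "params" dict is turned into a live object, which is how config entries such as those in configs/sd-vae.yaml are typically resolved in latent-diffusion codebases. torch.nn.Conv2d is used here purely as an importable example target.

from external.ldm.util import instantiate_from_config

config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
conv = instantiate_from_config(config)
print(type(conv))        # <class 'torch.nn.modules.conv.Conv2d'>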
49 | if self.text_zs.shape[0] > 1: 50 | idx = random.randint(0, self.text_zs.shape[0] - 1) 51 | text_zs = self.text_zs[[idx]] 52 | else: 53 | text_zs = self.text_zs 54 | # broadcast text_zs to all image_zs 55 | loss = - (image_zs * text_zs).sum(-1).mean() 56 | else: 57 | raise NotImplementedError 58 | 59 | return loss 60 | 61 | # image-image (e.g., diet-nerf) 62 | def prepare_image(self, dataset): 63 | # images: a nerf dataset (we need both poses and images!) 64 | pass -------------------------------------------------------------------------------- /external/nerf/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from .renderer import NeRFRenderer 8 | 9 | 10 | class NeRFNetwork(NeRFRenderer): 11 | def __init__(self, 12 | encoding="hashgrid", 13 | encoding_dir="sphere_harmonics", 14 | encoding_bg="hashgrid", 15 | num_layers=5, 16 | hidden_dim=64, 17 | geo_feat_dim=15, 18 | num_layers_color=3, 19 | hidden_dim_color=64, 20 | num_layers_bg=2, 21 | hidden_dim_bg=64, 22 | bound=1, 23 | **kwargs, 24 | ): 25 | super().__init__(bound, **kwargs) 26 | 27 | # sigma network 28 | self.num_layers = num_layers 29 | self.hidden_dim = hidden_dim 30 | self.geo_feat_dim = geo_feat_dim 31 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) 32 | 33 | sigma_net = [] 34 | for l in range(num_layers): 35 | if l == 0: 36 | in_dim = self.in_dim 37 | else: 38 | in_dim = hidden_dim 39 | 40 | if l == num_layers - 1: 41 | out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color 42 | else: 43 | out_dim = hidden_dim 44 | 45 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 46 | 47 | self.sigma_net = nn.ModuleList(sigma_net) 48 | 49 | # color network 50 | self.num_layers_color = num_layers_color 51 | self.hidden_dim_color = hidden_dim_color 52 | self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) 53 | 54 | color_net = [] 55 | for l in range(num_layers_color): 56 | if l == 0: 57 | in_dim = self.in_dim_dir + self.geo_feat_dim 58 | else: 59 | in_dim = hidden_dim_color 60 | 61 | if l == num_layers_color - 1: 62 | out_dim = 3 # 3 rgb 63 | else: 64 | out_dim = hidden_dim_color 65 | 66 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 67 | 68 | self.color_net = nn.ModuleList(color_net) 69 | 70 | # background network 71 | if self.bg_radius > 0: 72 | self.num_layers_bg = num_layers_bg 73 | self.hidden_dim_bg = hidden_dim_bg 74 | self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid 75 | 76 | bg_net = [] 77 | for l in range(num_layers_bg): 78 | if l == 0: 79 | in_dim = self.in_dim_bg + self.in_dim_dir 80 | else: 81 | in_dim = hidden_dim_bg 82 | 83 | if l == num_layers_bg - 1: 84 | out_dim = 3 # 3 rgb 85 | else: 86 | out_dim = hidden_dim_bg 87 | 88 | bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) 89 | 90 | self.bg_net = nn.ModuleList(bg_net) 91 | else: 92 | self.bg_net = None 93 | 94 | 95 | def forward(self, x, d): 96 | # x: [N, 3], in [-bound, bound] 97 | # d: [N, 3], nomalized in [-1, 1] 98 | 99 | # sigma 100 | x = self.encoder(x, bound=self.bound) 101 | 102 | h = x 103 | for l in range(self.num_layers): 104 | h = self.sigma_net[l](h) 105 | if l != self.num_layers - 1: 106 | h = F.relu(h, inplace=True) 107 | 108 | #sigma = F.relu(h[..., 0]) 109 | sigma = 
trunc_exp(h[..., 0]) 110 | geo_feat = h[..., 1:] 111 | 112 | # color 113 | 114 | d = self.encoder_dir(d) 115 | h = torch.cat([d, geo_feat], dim=-1) 116 | for l in range(self.num_layers_color): 117 | h = self.color_net[l](h) 118 | if l != self.num_layers_color - 1: 119 | h = F.relu(h, inplace=True) 120 | 121 | # sigmoid activation for rgb 122 | color = torch.sigmoid(h) 123 | 124 | return sigma, color 125 | 126 | def density(self, x): 127 | # x: [N, 3], in [-bound, bound] 128 | 129 | x = self.encoder(x, bound=self.bound) 130 | h = x 131 | for l in range(self.num_layers): 132 | h = self.sigma_net[l](h) 133 | if l != self.num_layers - 1: 134 | h = F.relu(h, inplace=True) 135 | 136 | #sigma = F.relu(h[..., 0]) 137 | sigma = trunc_exp(h[..., 0]) 138 | geo_feat = h[..., 1:] 139 | 140 | return { 141 | 'sigma': sigma, 142 | 'geo_feat': geo_feat, 143 | } 144 | 145 | def background(self, x, d): 146 | # x: [N, 2], in [-1, 1] 147 | 148 | h = self.encoder_bg(x) # [N, C] 149 | d = self.encoder_dir(d) 150 | 151 | h = torch.cat([d, h], dim=-1) 152 | for l in range(self.num_layers_bg): 153 | h = self.bg_net[l](h) 154 | if l != self.num_layers_bg - 1: 155 | h = F.relu(h, inplace=True) 156 | 157 | # sigmoid activation for rgb 158 | rgbs = torch.sigmoid(h) 159 | 160 | return rgbs 161 | 162 | # allow masked inference 163 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 164 | # x: [N, 3] in [-bound, bound] 165 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 166 | 167 | if mask is not None: 168 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 169 | # in case of empty mask 170 | if not mask.any(): 171 | return rgbs 172 | x = x[mask] 173 | d = d[mask] 174 | geo_feat = geo_feat[mask] 175 | 176 | d = self.encoder_dir(d) 177 | h = torch.cat([d, geo_feat], dim=-1) 178 | for l in range(self.num_layers_color): 179 | h = self.color_net[l](h) 180 | if l != self.num_layers_color - 1: 181 | h = F.relu(h, inplace=True) 182 | 183 | # sigmoid activation for rgb 184 | h = torch.sigmoid(h) 185 | 186 | if mask is not None: 187 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 188 | else: 189 | rgbs = h 190 | 191 | return rgbs 192 | 193 | # optimizer utils 194 | def get_params(self, lr): 195 | 196 | params = [ 197 | {'params': self.encoder.parameters(), 'lr': lr}, 198 | {'params': self.sigma_net.parameters(), 'lr': lr}, 199 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 200 | {'params': self.color_net.parameters(), 'lr': lr}, 201 | ] 202 | if self.bg_radius > 0: 203 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 204 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 205 | 206 | return params 207 | -------------------------------------------------------------------------------- /external/nerf/network_df.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from external.ngp_activation import trunc_exp 6 | from external.nerf.renderer_df import NeRFRenderer 7 | 8 | import numpy as np 9 | from external.ngp_encoder import get_encoder 10 | 11 | from .utils import safe_normalize 12 | 13 | class MLP(nn.Module): 14 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): 15 | super().__init__() 16 | self.dim_in = dim_in 17 | self.dim_out = dim_out 18 | self.dim_hidden = dim_hidden 19 | self.num_layers = num_layers 20 | 21 | net = [] 22 | for l in range(num_layers): 23 | net.append(nn.Linear(self.dim_in 
if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) 24 | 25 | self.net = nn.ModuleList(net) 26 | 27 | def forward(self, x): 28 | for l in range(self.num_layers): 29 | x = self.net[l](x) 30 | if l != self.num_layers - 1: 31 | x = F.relu(x, inplace=True) 32 | return x 33 | 34 | 35 | class NeRFNetwork(NeRFRenderer): 36 | def __init__(self, 37 | opt, 38 | num_layers=5, 39 | hidden_dim=128, 40 | num_layers_bg=2, 41 | hidden_dim_bg=64, 42 | ): 43 | 44 | super().__init__(opt) 45 | 46 | self.num_layers = num_layers 47 | self.hidden_dim = hidden_dim 48 | self.encoder, self.in_dim = get_encoder('frequency', input_dim=3) 49 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) 50 | 51 | # background network 52 | if self.bg_radius > 0: 53 | self.num_layers_bg = num_layers_bg 54 | self.hidden_dim_bg = hidden_dim_bg 55 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3) 56 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) 57 | 58 | else: 59 | self.bg_net = None 60 | 61 | def gaussian(self, x): 62 | # x: [B, N, 3] 63 | 64 | d = (x ** 2).sum(-1) 65 | g = 5 * torch.exp(-d / (2 * 0.2 ** 2)) 66 | 67 | return g 68 | 69 | def common_forward(self, x): 70 | # x: [N, 3], in [-bound, bound] 71 | 72 | # sigma 73 | h = self.encoder(x, bound=self.bound) 74 | 75 | h = self.sigma_net(h) 76 | 77 | sigma = trunc_exp(h[..., 0] + self.gaussian(x)) 78 | albedo = torch.sigmoid(h[..., 1:]) 79 | 80 | return sigma, albedo 81 | 82 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 83 | def finite_difference_normal(self, x, epsilon=1e-2): 84 | # x: [N, 3] 85 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 86 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 87 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 88 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 89 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 90 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 91 | 92 | normal = torch.stack([ 93 | 0.5 * (dx_pos - dx_neg) / epsilon, 94 | 0.5 * (dy_pos - dy_neg) / epsilon, 95 | 0.5 * (dz_pos - dz_neg) / epsilon 96 | ], dim=-1) 97 | 98 | return normal 99 | 100 | def normal(self, x): 101 | 102 | with torch.enable_grad(): 103 | x.requires_grad_(True) 104 | sigma, albedo = self.common_forward(x) 105 | # query gradient 106 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] 107 | 108 | # normalize... 
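[Editor's sketch, not part of the repository] The normal() method above recovers surface normals as the negative, normalized gradient of the density with respect to position. A self-contained illustration of that idea, with a toy density standing in for common_forward():

import torch

def toy_density(x):                       # stand-in for the sigma head of common_forward()
    return torch.exp(-(x ** 2).sum(-1))

x = torch.randn(1024, 3, requires_grad=True)
sigma = toy_density(x)
normal = -torch.autograd.grad(sigma.sum(), x)[0]                       # [N, 3]
normal = normal / normal.norm(dim=-1, keepdim=True).clamp(min=1e-20)   # safe_normalize-style
print(normal.shape)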
109 | normal = safe_normalize(normal) 110 | normal[torch.isnan(normal)] = 0 111 | return normal 112 | 113 | def forward(self, x, d, l=None, ratio=1, shading='albedo'): 114 | # x: [N, 3], in [-bound, bound] 115 | # d: [N, 3], view direction, nomalized in [-1, 1] 116 | # l: [3], plane light direction, nomalized in [-1, 1] 117 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) 118 | 119 | if shading == 'albedo': 120 | # no need to query normal 121 | sigma, color = self.common_forward(x) 122 | normal = None 123 | 124 | else: 125 | # query normal 126 | 127 | # sigma, albedo = self.common_forward(x) 128 | # normal = self.finite_difference_normal(x) 129 | 130 | with torch.enable_grad(): 131 | x.requires_grad_(True) 132 | sigma, albedo = self.common_forward(x) 133 | # query gradient 134 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] 135 | 136 | # normalize... 137 | normal = safe_normalize(normal) 138 | normal[torch.isnan(normal)] = 0 139 | 140 | # lambertian shading 141 | lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,] 142 | 143 | if shading == 'textureless': 144 | color = lambertian.unsqueeze(-1).repeat(1, 3) 145 | elif shading == 'normal': 146 | color = (normal + 1) / 2 147 | else: # 'lambertian' 148 | color = albedo * lambertian.unsqueeze(-1) 149 | 150 | return sigma, color, normal 151 | 152 | 153 | def density(self, x): 154 | # x: [N, 3], in [-bound, bound] 155 | 156 | sigma, albedo = self.common_forward(x) 157 | 158 | return { 159 | 'sigma': sigma, 160 | 'albedo': albedo, 161 | } 162 | 163 | 164 | def background(self, d): 165 | 166 | h = self.encoder_bg(d) # [N, C] 167 | 168 | h = self.bg_net(h) 169 | 170 | # sigmoid activation for rgb 171 | rgbs = torch.sigmoid(h) 172 | 173 | return rgbs 174 | 175 | # optimizer utils 176 | def get_params(self, lr): 177 | 178 | params = [ 179 | # {'params': self.encoder.parameters(), 'lr': lr * 10}, 180 | {'params': self.sigma_net.parameters(), 'lr': lr}, 181 | ] 182 | 183 | if self.bg_radius > 0: 184 | # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) 185 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 186 | 187 | return params -------------------------------------------------------------------------------- /external/nerf/network_ff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from ffmlp import FFMLP 8 | 9 | from .renderer import NeRFRenderer 10 | 11 | class NeRFNetwork(NeRFRenderer): 12 | def __init__(self, 13 | encoding="hashgrid", 14 | encoding_dir="sphere_harmonics", 15 | num_layers=2, 16 | hidden_dim=64, 17 | geo_feat_dim=15, 18 | num_layers_color=3, 19 | hidden_dim_color=64, 20 | bound=1, 21 | **kwargs 22 | ): 23 | super().__init__(bound, **kwargs) 24 | 25 | # sigma network 26 | self.num_layers = num_layers 27 | self.hidden_dim = hidden_dim 28 | self.geo_feat_dim = geo_feat_dim 29 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) 30 | 31 | self.sigma_net = FFMLP( 32 | input_dim=self.in_dim, 33 | output_dim=1 + self.geo_feat_dim, 34 | hidden_dim=self.hidden_dim, 35 | num_layers=self.num_layers, 36 | ) 37 | 38 | # color network 39 | self.num_layers_color = num_layers_color 40 | self.hidden_dim_color = hidden_dim_color 41 | self.encoder_dir, self.in_dim_color = 
get_encoder(encoding_dir) 42 | self.in_dim_color += self.geo_feat_dim + 1 # a manual fixing to make it 32, as done in nerf_network.h#178 43 | 44 | self.color_net = FFMLP( 45 | input_dim=self.in_dim_color, 46 | output_dim=3, 47 | hidden_dim=self.hidden_dim_color, 48 | num_layers=self.num_layers_color, 49 | ) 50 | 51 | def forward(self, x, d): 52 | # x: [N, 3], in [-bound, bound] 53 | # d: [N, 3], nomalized in [-1, 1] 54 | 55 | # sigma 56 | x = self.encoder(x, bound=self.bound) 57 | h = self.sigma_net(x) 58 | 59 | #sigma = F.relu(h[..., 0]) 60 | sigma = trunc_exp(h[..., 0]) 61 | geo_feat = h[..., 1:] 62 | 63 | # color 64 | d = self.encoder_dir(d) 65 | 66 | # TODO: preallocate space and avoid this cat? 67 | p = torch.zeros_like(geo_feat[..., :1]) # manual input padding 68 | h = torch.cat([d, geo_feat, p], dim=-1) 69 | h = self.color_net(h) 70 | 71 | # sigmoid activation for rgb 72 | rgb = torch.sigmoid(h) 73 | 74 | return sigma, rgb 75 | 76 | def density(self, x): 77 | # x: [N, 3], in [-bound, bound] 78 | 79 | x = self.encoder(x, bound=self.bound) 80 | h = self.sigma_net(x) 81 | 82 | #sigma = F.relu(h[..., 0]) 83 | sigma = trunc_exp(h[..., 0]) 84 | geo_feat = h[..., 1:] 85 | 86 | return { 87 | 'sigma': sigma, 88 | 'geo_feat': geo_feat, 89 | } 90 | 91 | # allow masked inference 92 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 93 | # x: [N, 3] in [-bound, bound] 94 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 95 | 96 | #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 97 | #starter.record() 98 | 99 | if mask is not None: 100 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 101 | # in case of empty mask 102 | if not mask.any(): 103 | return rgbs 104 | x = x[mask] 105 | d = d[mask] 106 | geo_feat = geo_feat[mask] 107 | 108 | #print(x.shape, rgbs.shape) 109 | 110 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'mask = {curr_time}') 111 | #starter.record() 112 | 113 | d = self.encoder_dir(d) 114 | 115 | p = torch.zeros_like(geo_feat[..., :1]) # manual input padding 116 | h = torch.cat([d, geo_feat, p], dim=-1) 117 | 118 | h = self.color_net(h) 119 | 120 | # sigmoid activation for rgb 121 | h = torch.sigmoid(h) 122 | 123 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'call = {curr_time}') 124 | #starter.record() 125 | 126 | if mask is not None: 127 | rgbs[mask] = h.to(rgbs.dtype) 128 | else: 129 | rgbs = h 130 | 131 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'unmask = {curr_time}') 132 | #starter.record() 133 | 134 | return rgbs 135 | 136 | # optimizer utils 137 | def get_params(self, lr): 138 | 139 | params = [ 140 | {'params': self.encoder.parameters(), 'lr': lr}, 141 | {'params': self.sigma_net.parameters(), 'lr': lr}, 142 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 143 | {'params': self.color_net.parameters(), 'lr': lr}, 144 | ] 145 | if self.bg_radius > 0: 146 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 147 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 148 | 149 | return params -------------------------------------------------------------------------------- /external/nerf/network_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from external.ngp_activation import 
trunc_exp 6 | from external.nerf.renderer_df import NeRFRenderer 7 | 8 | import numpy as np 9 | from external.ngp_encoder import get_encoder 10 | 11 | from .utils import safe_normalize 12 | 13 | 14 | class MLP(nn.Module): 15 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): 16 | super().__init__() 17 | self.dim_in = dim_in 18 | self.dim_out = dim_out 19 | self.dim_hidden = dim_hidden 20 | self.num_layers = num_layers 21 | 22 | net = [] 23 | for l in range(num_layers): 24 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) 25 | 26 | self.net = nn.ModuleList(net) 27 | 28 | def forward(self, x): 29 | for l in range(self.num_layers): 30 | x = self.net[l](x) 31 | if l != self.num_layers - 1: 32 | x = F.relu(x, inplace=True) 33 | return x 34 | 35 | 36 | class NeRFNetwork(NeRFRenderer): 37 | def __init__(self, 38 | opt, 39 | num_layers=3, 40 | hidden_dim=64, 41 | num_layers_bg=2, 42 | hidden_dim_bg=64, 43 | ): 44 | 45 | super().__init__(opt) 46 | 47 | self.num_layers = num_layers 48 | self.hidden_dim = hidden_dim 49 | 50 | self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, log2_hashmap_size=16, desired_resolution=2048 * self.bound) 51 | 52 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) 53 | 54 | # background network 55 | if self.bg_radius > 0: 56 | self.num_layers_bg = num_layers_bg 57 | self.hidden_dim_bg = hidden_dim_bg 58 | 59 | # use a very simple network to avoid it learning the prompt... 60 | # self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048) 61 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3) 62 | 63 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) 64 | 65 | else: 66 | self.bg_net = None 67 | 68 | # add a density blob to the scene center 69 | def gaussian(self, x): 70 | # x: [B, N, 3] 71 | 72 | d = (x ** 2).sum(-1) 73 | g = 5 * torch.exp(-d / (2 * 0.2 ** 2)) 74 | 75 | return g 76 | 77 | def common_forward(self, x): 78 | # x: [N, 3], in [-bound, bound] 79 | 80 | # sigma 81 | h = self.encoder(x, bound=self.bound) 82 | 83 | h = self.sigma_net(h) 84 | 85 | sigma = trunc_exp(h[..., 0] + self.gaussian(x)) 86 | albedo = torch.sigmoid(h[..., 1:]) 87 | 88 | return sigma, albedo 89 | 90 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 91 | def finite_difference_normal(self, x, epsilon=1e-2): 92 | # x: [N, 3] 93 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 94 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 95 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 96 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) 97 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 98 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) 99 | 100 | normal = torch.stack([ 101 | 0.5 * (dx_pos - dx_neg) / epsilon, 102 | 0.5 * (dy_pos - dy_neg) / epsilon, 103 | 0.5 * (dz_pos - dz_neg) / epsilon 104 | ], dim=-1) 105 | 106 | return normal 107 | 108 | @torch.no_grad() 
109 | def common_forward_smooth(self, x, radius=1e-2, k=5): 110 | # x: [N, 3], in [-bound, bound] 111 | 112 | x_repeat = x.unsqueeze(0).repeat(k, 1, 1) 113 | eps = torch.rand_like(x_repeat) * radius * 2 - radius 114 | x_sample = x_repeat + eps 115 | 116 | sigma_list, albedo_list = [], [] 117 | for xi in range(len(x_sample)): 118 | # sigma 119 | h = self.encoder(x_sample[xi], bound=self.bound) 120 | 121 | h = self.sigma_net(h) 122 | 123 | sigma = trunc_exp(h[..., 0] + self.gaussian(x)) 124 | albedo = torch.sigmoid(h[..., 1:]) 125 | 126 | sigma_list.append(sigma) 127 | albedo_list.append(albedo) 128 | sigma = torch.stack(sigma_list, dim=0).mean(dim=0) 129 | albedo = torch.stack(albedo_list, dim=0).mean(dim=0) 130 | return sigma, albedo 131 | 132 | @torch.no_grad() 133 | def finite_difference_normal_smooth(self, x, epsilon=3e-1): 134 | # x: [N, 3] 135 | # eps 1e-2 | radius 136 | # eps 3e-1 | radius 3e-1 137 | radius = 3e-1 138 | k = 10 139 | dx_pos, _ = self.common_forward_smooth((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k) 140 | dx_neg, _ = self.common_forward_smooth((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k) 141 | dy_pos, _ = self.common_forward_smooth((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k) 142 | dy_neg, _ = self.common_forward_smooth((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k) 143 | dz_pos, _ = self.common_forward_smooth((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k) 144 | dz_neg, _ = self.common_forward_smooth((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k) 145 | 146 | normal = torch.stack([ 147 | 0.5 * (dx_pos - dx_neg) / epsilon, 148 | 0.5 * (dy_pos - dy_neg) / epsilon, 149 | 0.5 * (dz_pos - dz_neg) / epsilon 150 | ], dim=-1) 151 | 152 | return normal 153 | 154 | 155 | def normal(self, x, smooth=False): 156 | 157 | if smooth: 158 | normal = self.finite_difference_normal_smooth(x) 159 | else: 160 | normal = self.finite_difference_normal(x) 161 | normal = safe_normalize(normal) 162 | normal[torch.isnan(normal)] = 0 163 | 164 | return normal 165 | 166 | 167 | def forward(self, x, d, l=None, ratio=1, shading='textureless'): 168 | # x: [N, 3], in [-bound, bound] 169 | # d: [N, 3], view direction, nomalized in [-1, 1] 170 | # l: [3], plane light direction, nomalized in [-1, 1] 171 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) 172 | 173 | if shading == 'albedo': 174 | # no need to query normal 175 | sigma, color = self.common_forward(x) 176 | normal = None 177 | 178 | else: 179 | # query normal 180 | 181 | sigma, albedo = self.common_forward(x) 182 | if shading == 'textureless' or shading == 'normal': 183 | normal = self.normal(x, smooth=True) 184 | else: 185 | normal = self.normal(x) 186 | 187 | # lambertian shading 188 | lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,] 189 | 190 | if shading == 'textureless': 191 | color = lambertian.unsqueeze(-1).repeat(1, 3)*0.8 + .2 192 | elif shading == 'normal': 193 | color = (normal + 1) / 2 194 | else: # 'lambertian' 195 | color = albedo * lambertian.unsqueeze(-1) 196 | 197 | return sigma, color, normal 198 | 199 | 200 | def density(self, x): 201 | # x: [N, 3], in [-bound, 
bound] 202 | 203 | sigma, albedo = self.common_forward(x) 204 | 205 | return { 206 | 'sigma': sigma, 207 | 'albedo': albedo, 208 | } 209 | 210 | 211 | def background(self, d): 212 | 213 | h = self.encoder_bg(d) # [N, C] 214 | 215 | h = self.bg_net(h) 216 | 217 | # sigmoid activation for rgb 218 | rgbs = torch.sigmoid(h) 219 | 220 | return rgbs 221 | 222 | # optimizer utils 223 | def get_params(self, lr): 224 | 225 | params = [ 226 | {'params': self.encoder.parameters(), 'lr': lr * 10}, 227 | {'params': self.sigma_net.parameters(), 'lr': lr}, 228 | ] 229 | 230 | if self.bg_radius > 0: 231 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) 232 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 233 | 234 | return params -------------------------------------------------------------------------------- /external/nerf/network_tcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | import tinycudann as tcnn 8 | from activation import trunc_exp 9 | from .renderer import NeRFRenderer 10 | 11 | 12 | class NeRFNetwork(NeRFRenderer): 13 | def __init__(self, 14 | encoding="HashGrid", 15 | encoding_dir="SphericalHarmonics", 16 | num_layers=2, 17 | hidden_dim=64, 18 | geo_feat_dim=15, 19 | num_layers_color=3, 20 | hidden_dim_color=64, 21 | bound=1, 22 | **kwargs 23 | ): 24 | super().__init__(bound, **kwargs) 25 | 26 | # sigma network 27 | self.num_layers = num_layers 28 | self.hidden_dim = hidden_dim 29 | self.geo_feat_dim = geo_feat_dim 30 | 31 | per_level_scale = np.exp2(np.log2(2048 * bound / 16) / (16 - 1)) 32 | 33 | self.encoder = tcnn.Encoding( 34 | n_input_dims=3, 35 | encoding_config={ 36 | "otype": "HashGrid", 37 | "n_levels": 16, 38 | "n_features_per_level": 2, 39 | "log2_hashmap_size": 19, 40 | "base_resolution": 16, 41 | "per_level_scale": per_level_scale, 42 | }, 43 | ) 44 | 45 | self.sigma_net = tcnn.Network( 46 | n_input_dims=32, 47 | n_output_dims=1 + self.geo_feat_dim, 48 | network_config={ 49 | "otype": "FullyFusedMLP", 50 | "activation": "ReLU", 51 | "output_activation": "None", 52 | "n_neurons": hidden_dim, 53 | "n_hidden_layers": num_layers - 1, 54 | }, 55 | ) 56 | 57 | # color network 58 | self.num_layers_color = num_layers_color 59 | self.hidden_dim_color = hidden_dim_color 60 | 61 | self.encoder_dir = tcnn.Encoding( 62 | n_input_dims=3, 63 | encoding_config={ 64 | "otype": "SphericalHarmonics", 65 | "degree": 4, 66 | }, 67 | ) 68 | 69 | self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim 70 | 71 | self.color_net = tcnn.Network( 72 | n_input_dims=self.in_dim_color, 73 | n_output_dims=3, 74 | network_config={ 75 | "otype": "FullyFusedMLP", 76 | "activation": "ReLU", 77 | "output_activation": "None", 78 | "n_neurons": hidden_dim_color, 79 | "n_hidden_layers": num_layers_color - 1, 80 | }, 81 | ) 82 | 83 | 84 | def forward(self, x, d): 85 | # x: [N, 3], in [-bound, bound] 86 | # d: [N, 3], nomalized in [-1, 1] 87 | 88 | 89 | # sigma 90 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 91 | x = self.encoder(x) 92 | h = self.sigma_net(x) 93 | 94 | #sigma = F.relu(h[..., 0]) 95 | sigma = trunc_exp(h[..., 0]) 96 | geo_feat = h[..., 1:] 97 | 98 | # color 99 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] 100 | d = self.encoder_dir(d) 101 | 102 | #p = torch.zeros_like(geo_feat[..., :1]) # manual input padding 103 | h = torch.cat([d, geo_feat], dim=-1) 104 | h = 
self.color_net(h) 105 | 106 | # sigmoid activation for rgb 107 | color = torch.sigmoid(h) 108 | 109 | return sigma, color 110 | 111 | def density(self, x): 112 | # x: [N, 3], in [-bound, bound] 113 | 114 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 115 | x = self.encoder(x) 116 | h = self.sigma_net(x) 117 | 118 | #sigma = F.relu(h[..., 0]) 119 | sigma = trunc_exp(h[..., 0]) 120 | geo_feat = h[..., 1:] 121 | 122 | return { 123 | 'sigma': sigma, 124 | 'geo_feat': geo_feat, 125 | } 126 | 127 | # allow masked inference 128 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 129 | # x: [N, 3] in [-bound, bound] 130 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 131 | 132 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 133 | 134 | if mask is not None: 135 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 136 | # in case of empty mask 137 | if not mask.any(): 138 | return rgbs 139 | x = x[mask] 140 | d = d[mask] 141 | geo_feat = geo_feat[mask] 142 | 143 | # color 144 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] 145 | d = self.encoder_dir(d) 146 | 147 | h = torch.cat([d, geo_feat], dim=-1) 148 | h = self.color_net(h) 149 | 150 | # sigmoid activation for rgb 151 | h = torch.sigmoid(h) 152 | 153 | if mask is not None: 154 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 155 | else: 156 | rgbs = h 157 | 158 | return rgbs 159 | 160 | # optimizer utils 161 | def get_params(self, lr): 162 | 163 | params = [ 164 | {'params': self.encoder.parameters(), 'lr': lr}, 165 | {'params': self.sigma_net.parameters(), 'lr': lr}, 166 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 167 | {'params': self.color_net.parameters(), 'lr': lr}, 168 | ] 169 | if self.bg_radius > 0: 170 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 171 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 172 | 173 | return params -------------------------------------------------------------------------------- /external/nerf/provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import json 5 | from cv2 import transform 6 | import tqdm 7 | import numpy as np 8 | from scipy.spatial.transform import Slerp, Rotation 9 | 10 | import trimesh 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from .utils import get_rays 16 | 17 | 18 | # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 19 | def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]): 20 | # for the fox dataset, 0.33 scales camera radius to ~ 2 21 | new_pose = np.array([ 22 | [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]], 23 | [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]], 24 | [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]], 25 | [0, 0, 0, 1], 26 | ], dtype=np.float32) 27 | return new_pose 28 | 29 | 30 | def visualize_poses(poses, size=0.1): 31 | # poses: [B, 4, 4] 32 | 33 | axes = trimesh.creation.axis(axis_length=4) 34 | box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() 35 | box.colors = np.array([[128, 128, 128]] * len(box.entities)) 36 | objects = [axes, box] 37 | 38 | for pose in poses: 39 | # a camera is visualized with 8 line segments. 
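# (descriptive note) the frustum below is drawn from the camera center `pos` to four
# image-plane corners a/b/c/d, obtained by offsetting `pos` along the scaled rotation
# columns pose[:3, 0], pose[:3, 1], pose[:3, 2]; one extra segment from `pos` along
# the mean corner direction marks the viewing axis.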
40 | pos = pose[:3, 3] 41 | a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 42 | b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 43 | c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 44 | d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 45 | 46 | dir = (a + b + c + d) / 4 - pos 47 | dir = dir / (np.linalg.norm(dir) + 1e-8) 48 | o = pos + dir * 3 49 | 50 | segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) 51 | segs = trimesh.load_path(segs) 52 | objects.append(segs) 53 | 54 | trimesh.Scene(objects).show() 55 | 56 | 57 | def rand_poses(size, device, radius=1, theta_range=[np.pi/3, 2*np.pi/3], phi_range=[0, 2*np.pi]): 58 | ''' generate random poses from an orbit camera 59 | Args: 60 | size: batch size of generated poses. 61 | device: where to allocate the output. 62 | radius: camera radius 63 | theta_range: [min, max], should be in [0, \pi] 64 | phi_range: [min, max], should be in [0, 2\pi] 65 | Return: 66 | poses: [size, 4, 4] 67 | ''' 68 | 69 | def normalize(vectors): 70 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) 71 | 72 | thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] 73 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] 74 | 75 | centers = torch.stack([ 76 | radius * torch.sin(thetas) * torch.sin(phis), 77 | radius * torch.cos(thetas), 78 | radius * torch.sin(thetas) * torch.cos(phis), 79 | ], dim=-1) # [B, 3] 80 | 81 | # lookat 82 | forward_vector = - normalize(centers) 83 | up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) # confused at the coordinate system... 84 | right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) 85 | up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) 86 | 87 | poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) 88 | poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) 89 | poses[:, :3, 3] = centers 90 | 91 | return poses 92 | 93 | 94 | class NeRFDataset: 95 | def __init__(self, opt, device, type='train', downscale=1, n_test=10): 96 | super().__init__() 97 | 98 | self.opt = opt 99 | self.device = device 100 | self.type = type # train, val, test 101 | self.downscale = downscale 102 | self.root_path = opt.path 103 | self.preload = opt.preload # preload data into GPU 104 | self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box. 105 | self.offset = opt.offset # camera offset 106 | self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses. 107 | self.fp16 = opt.fp16 # if preload, load into fp16. 108 | 109 | self.training = self.type in ['train', 'all', 'trainval'] 110 | self.num_rays = self.opt.num_rays if self.training else -1 111 | 112 | self.rand_pose = opt.rand_pose 113 | 114 | # auto-detect transforms.json and split mode. 115 | if os.path.exists(os.path.join(self.root_path, 'transforms.json')): 116 | self.mode = 'colmap' # manually split, use view-interpolation for test. 117 | elif os.path.exists(os.path.join(self.root_path, 'transforms_train.json')): 118 | self.mode = 'blender' # provided split 119 | else: 120 | raise NotImplementedError(f'[NeRFDataset] Cannot find transforms*.json under {self.root_path}') 121 | 122 | # load nerf-compatible format data. 
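# A rough sketch of the transforms file this loader expects (field names follow the
# instant-ngp / Blender-NeRF convention; the values here are purely illustrative):
#   {
#     "camera_angle_x": 0.69,        # or fl_x / fl_y, optionally with cx, cy, w, h
#     "frames": [
#       {"file_path": "./images/0001.png", "transform_matrix": [[...4x4 camera-to-world...]]}
#     ]
#   }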
123 | if self.mode == 'colmap': 124 | with open(os.path.join(self.root_path, 'transforms.json'), 'r') as f: 125 | transform = json.load(f) 126 | elif self.mode == 'blender': 127 | # load all splits (train/valid/test), this is what instant-ngp in fact does... 128 | if type == 'all': 129 | transform_paths = glob.glob(os.path.join(self.root_path, '*.json')) 130 | transform = None 131 | for transform_path in transform_paths: 132 | with open(transform_path, 'r') as f: 133 | tmp_transform = json.load(f) 134 | if transform is None: 135 | transform = tmp_transform 136 | else: 137 | transform['frames'].extend(tmp_transform['frames']) 138 | # load train and val split 139 | elif type == 'trainval': 140 | with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f: 141 | transform = json.load(f) 142 | with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f: 143 | transform_val = json.load(f) 144 | transform['frames'].extend(transform_val['frames']) 145 | # only load one specified split 146 | else: 147 | with open(os.path.join(self.root_path, f'transforms_{type}.json'), 'r') as f: 148 | transform = json.load(f) 149 | 150 | else: 151 | raise NotImplementedError(f'unknown dataset mode: {self.mode}') 152 | 153 | # load image size 154 | if 'h' in transform and 'w' in transform: 155 | self.H = int(transform['h']) // downscale 156 | self.W = int(transform['w']) // downscale 157 | else: 158 | # we have to actually read an image to get H and W later. 159 | self.H = self.W = None 160 | 161 | # read images 162 | frames = transform["frames"] 163 | #frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort... 164 | 165 | # for colmap, manually interpolate a test set. 166 | if self.mode == 'colmap' and type == 'test': 167 | 168 | # choose two random poses, and interpolate between. 169 | f0, f1 = np.random.choice(frames, 2, replace=False) 170 | pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] 171 | pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] 172 | rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]])) 173 | slerp = Slerp([0, 1], rots) 174 | 175 | self.poses = [] 176 | self.images = None 177 | for i in range(n_test + 1): 178 | ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5 179 | pose = np.eye(4, dtype=np.float32) 180 | pose[:3, :3] = slerp(ratio).as_matrix() 181 | pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3] 182 | self.poses.append(pose) 183 | 184 | else: 185 | # for colmap, manually split a valid set (the first frame). 186 | if self.mode == 'colmap': 187 | if type == 'train': 188 | frames = frames[1:] 189 | elif type == 'val': 190 | frames = frames[:1] 191 | # else 'all' or 'trainval' : use all frames 192 | 193 | self.poses = [] 194 | self.images = [] 195 | for f in tqdm.tqdm(frames, desc=f'Loading {type} data'): 196 | f_path = os.path.join(self.root_path, f['file_path']) 197 | if self.mode == 'blender' and '.' not in os.path.basename(f_path): 198 | f_path += '.png' # so silly... 199 | 200 | # there are non-exist paths in fox... 
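# Skipping a missing file here keeps self.poses and self.images aligned, since a pose
# is only appended together with its successfully loaded image at the end of the loop.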
201 | if not os.path.exists(f_path): 202 | continue 203 | 204 | pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4] 205 | pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset) 206 | 207 | image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4] 208 | if self.H is None or self.W is None: 209 | self.H = image.shape[0] // downscale 210 | self.W = image.shape[1] // downscale 211 | 212 | # add support for the alpha channel as a mask. 213 | if image.shape[-1] == 3: 214 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 215 | else: 216 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) 217 | 218 | if image.shape[0] != self.H or image.shape[1] != self.W: 219 | image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA) 220 | 221 | image = image.astype(np.float32) / 255 # [H, W, 3/4] 222 | 223 | self.poses.append(pose) 224 | self.images.append(image) 225 | 226 | self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4] 227 | if self.images is not None: 228 | self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C] 229 | 230 | # calculate mean radius of all camera poses 231 | self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() 232 | #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}') 233 | 234 | # initialize error_map 235 | if self.training and self.opt.error_map: 236 | self.error_map = torch.ones([self.images.shape[0], 128 * 128], dtype=torch.float) # [B, 128 * 128], flattened for easy indexing, fixed resolution... 237 | else: 238 | self.error_map = None 239 | 240 | # [debug] uncomment to view all training poses. 241 | # visualize_poses(self.poses.numpy()) 242 | 243 | # [debug] uncomment to view examples of randomly generated poses. 244 | # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy()) 245 | 246 | if self.preload: 247 | self.poses = self.poses.to(self.device) 248 | if self.images is not None: 249 | # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ? 250 | if self.fp16 and self.opt.color_space != 'linear': 251 | dtype = torch.half 252 | else: 253 | dtype = torch.float 254 | self.images = self.images.to(dtype).to(self.device) 255 | if self.error_map is not None: 256 | self.error_map = self.error_map.to(self.device) 257 | 258 | # load intrinsics 259 | if 'fl_x' in transform or 'fl_y' in transform: 260 | fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale 261 | fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale 262 | elif 'camera_angle_x' in transform or 'camera_angle_y' in transform: 263 | # blender, assert in radians. already downscaled since we use H/W 264 | fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None 265 | fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None 266 | if fl_x is None: fl_x = fl_y 267 | if fl_y is None: fl_y = fl_x 268 | else: 269 | raise RuntimeError('Failed to load focal length, please check the transforms.json!') 270 | 271 | cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2) 272 | cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2) 273 | 274 | self.intrinsics = np.array([fl_x, fl_y, cx, cy]) 275 | 276 | 277 | def collate(self, index): 278 | 279 | B = len(index) # a list of length 1 280 | 281 | # random pose without gt images. 
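# rand_pose == 0 makes every batch use a random pose; otherwise dataloader() extends
# the index range by size // rand_pose, so an index past len(self.poses) lands in this
# branch and returns rays from a random orbit pose with no ground-truth colors.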
282 | if self.rand_pose == 0 or index[0] >= len(self.poses): 283 | 284 | poses = rand_poses(B, self.device, radius=self.radius) 285 | 286 | # sample a low-resolution but full image for CLIP 287 | s = np.sqrt(self.H * self.W / self.num_rays) # only in training, assert num_rays > 0 288 | rH, rW = int(self.H / s), int(self.W / s) 289 | rays = get_rays(poses, self.intrinsics / s, rH, rW, -1) 290 | 291 | return { 292 | 'H': rH, 293 | 'W': rW, 294 | 'rays_o': rays['rays_o'], 295 | 'rays_d': rays['rays_d'], 296 | } 297 | 298 | poses = self.poses[index].to(self.device) # [B, 4, 4] 299 | 300 | error_map = None if self.error_map is None else self.error_map[index] 301 | 302 | rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map, self.opt.patch_size) 303 | 304 | results = { 305 | 'H': self.H, 306 | 'W': self.W, 307 | 'rays_o': rays['rays_o'], 308 | 'rays_d': rays['rays_d'], 309 | } 310 | 311 | if self.images is not None: 312 | images = self.images[index].to(self.device) # [B, H, W, 3/4] 313 | if self.training: 314 | C = images.shape[-1] 315 | images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4] 316 | results['images'] = images 317 | 318 | # need inds to update error_map 319 | if error_map is not None: 320 | results['index'] = index 321 | results['inds_coarse'] = rays['inds_coarse'] 322 | 323 | return results 324 | 325 | def dataloader(self): 326 | size = len(self.poses) 327 | if self.training and self.rand_pose > 0: 328 | size += size // self.rand_pose # index >= size means we use random pose. 329 | loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) 330 | loader._data = self # an ugly fix... we need to access error_map & poses in trainer. 
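# A usage sketch (the opt fields are simply whatever __init__ reads above, not a fixed API):
#   train_loader = NeRFDataset(opt, device, type='train').dataloader()
#   for data in train_loader:
#       rays_o, rays_d = data['rays_o'], data['rays_d']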
331 | loader.has_gt = self.images is not None 332 | return loader -------------------------------------------------------------------------------- /external/ngp_activation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Helper functions for diffusion model 3 | #@ FROM https://github.com/ashawkey/torch-ngp 4 | ''' 5 | 6 | import torch 7 | from torch.autograd import Function 8 | from torch.cuda.amp import custom_bwd, custom_fwd 9 | 10 | class _trunc_exp(Function): 11 | @staticmethod 12 | @custom_fwd(cast_inputs=torch.float32) # cast to float32 13 | def forward(ctx, x): 14 | ctx.save_for_backward(x) 15 | return torch.exp(x) 16 | 17 | @staticmethod 18 | @custom_bwd 19 | def backward(ctx, g): 20 | x = ctx.saved_tensors[0] 21 | return g * torch.exp(x.clamp(-15, 15)) 22 | 23 | trunc_exp = _trunc_exp.apply -------------------------------------------------------------------------------- /external/ngp_encoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Helper functions for diffusion model 3 | #@ FROM https://github.com/ashawkey/torch-ngp 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class FreqEncoder(nn.Module): 11 | def __init__(self, input_dim, max_freq_log2, N_freqs, 12 | log_sampling=True, include_input=True, 13 | periodic_fns=(torch.sin, torch.cos)): 14 | 15 | super().__init__() 16 | 17 | self.input_dim = input_dim 18 | self.include_input = include_input 19 | self.periodic_fns = periodic_fns 20 | 21 | self.output_dim = 0 22 | if self.include_input: 23 | self.output_dim += self.input_dim 24 | 25 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 26 | 27 | if log_sampling: 28 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) 29 | else: 30 | self.freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq_log2, N_freqs) 31 | 32 | self.freq_bands = self.freq_bands.numpy().tolist() 33 | 34 | def forward(self, input, **kwargs): 35 | 36 | out = [] 37 | if self.include_input: 38 | out.append(input) 39 | 40 | for i in range(len(self.freq_bands)): 41 | freq = self.freq_bands[i] 42 | for p_fn in self.periodic_fns: 43 | out.append(p_fn(input * freq)) 44 | 45 | out = torch.cat(out, dim=-1) 46 | 47 | 48 | return out 49 | 50 | def get_encoder(encoding, input_dim=3, 51 | multires=6, 52 | degree=4, 53 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, 54 | **kwargs): 55 | 56 | if encoding == 'None': 57 | return lambda x, **kwargs: x, input_dim 58 | 59 | elif encoding == 'frequency': 60 | raise NotImplementedError 61 | 62 | elif encoding == 'sphere_harmonics': 63 | raise NotImplementedError 64 | 65 | elif encoding == 'hashgrid': 66 | from external.gridencoder import GridEncoder 67 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) 68 | 69 | elif encoding == 'tiledgrid': 70 | from external.gridencoder import GridEncoder 71 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) 72 | 73 | elif encoding == 'ash': 74 | raise NotImplementedError 75 | 76 | else: 77 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') 78 | 79 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /external/plms.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Fast PNDM PLMS Sampler 3 | #@ FROM https://github.com/CompVis/stable-diffusion 4 | ''' 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | from external.imagen_pytorch import GaussianDiffusionContinuousTimes, right_pad_dims_to 9 | from sparsefusion.vldm import DDPM 10 | from einops import rearrange 11 | 12 | 13 | class PLMSSampler(): 14 | 15 | def __init__(self, diffusion: DDPM, plms_steps=100): 16 | 17 | self.diffusion = diffusion 18 | self.plms_steps = plms_steps 19 | 20 | @torch.no_grad() 21 | def sample(self, 22 | image=None, 23 | max_thres=.999, 24 | cond_images=None, 25 | cond_scale=1.0, 26 | use_tqdm=True, 27 | return_noise=False, 28 | **kwargs 29 | ): 30 | ''' 31 | Single UNet PLMS Sampler 32 | ''' 33 | outputs = [] 34 | batch_size = cond_images.shape[0] 35 | shape = (batch_size, self.diffusion.sample_channels[0], self.diffusion.image_sizes[0], self.diffusion.image_sizes[0]) 36 | img, x_noisy, noise, alpha_cumprod = self.plms_sample_loop( 37 | self.diffusion.unets[0], 38 | image = image, 39 | shape = shape, 40 | cond_images = cond_images, 41 | cond_scale = cond_scale, 42 | noise_scheduler = self.diffusion.noise_schedulers[0], 43 | pred_objective = self.diffusion.pred_objectives[0], 44 | dynamic_threshold = self.diffusion.dynamic_thresholding[0], 45 | use_tqdm = use_tqdm, 46 | max_thres = max_thres, 47 | ) 48 | outputs.append(img) 49 | if not return_noise: 50 | return outputs[-1] 51 | return outputs[-1], x_noisy, noise, alpha_cumprod 52 | 53 | @torch.no_grad() 54 | def plms_sample_loop(self, 55 | unet, 56 | image, 57 | 
shape, 58 | cond_images, 59 | cond_scale, 60 | noise_scheduler, 61 | pred_objective, 62 | dynamic_threshold, 63 | use_tqdm, 64 | max_thres = None, 65 | ): 66 | ''' 67 | Sampling loop 68 | ''' 69 | batch = shape[0] 70 | device = self.diffusion.device 71 | 72 | if image is None: 73 | image = torch.randn(shape, device = device) 74 | else: 75 | assert(max_thres is not None) 76 | 77 | 78 | old_eps = [] 79 | 80 | if max_thres >= .99: 81 | noise_scheduler_short = GaussianDiffusionContinuousTimes(noise_schedule='cosine', timesteps=self.plms_steps) 82 | timesteps = noise_scheduler_short.get_sampling_timesteps(batch, device=device) 83 | noise = torch.randn_like(image) 84 | x_noisy, log_snr= noise_scheduler_short.q_sample(image, t=max_thres, noise=noise) 85 | img = image 86 | else: 87 | n_steps = min(int(max_thres * self.plms_steps * 2), self.plms_steps) 88 | # n_steps = 50 89 | noise_scheduler_short = GaussianDiffusionContinuousTimes(noise_schedule='cosine', timesteps=self.plms_steps) 90 | timesteps = noise_scheduler_short.get_sampling_timesteps_custom(batch, device=device, max_thres=max_thres, n_steps=n_steps) 91 | noise = torch.randn_like(image) 92 | img, log_snr = noise_scheduler_short.q_sample(image, t=max_thres, noise=noise) 93 | x_noisy = img 94 | 95 | for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm): 96 | is_last_timestep = times_next == 0 97 | outs = self.p_sample( 98 | unet, 99 | img, 100 | times, 101 | t_next = times_next, 102 | cond_images = cond_images, 103 | cond_scale = cond_scale, 104 | noise_scheduler = noise_scheduler, 105 | pred_objective = pred_objective, 106 | dynamic_threshold = dynamic_threshold, 107 | old_eps = old_eps 108 | ) 109 | img, pred_x0, e_t = outs 110 | old_eps.append(e_t) 111 | if len(old_eps) >= 4: 112 | old_eps.pop(0) 113 | 114 | if self.diffusion.clip_output: 115 | img.clamp_(-self.diffusion.clip_value, self.diffusion.clip_value) 116 | 117 | unnormalize_img = self.diffusion.unnormalize_img(img) 118 | alpha_cumprod = torch.sigmoid(log_snr) 119 | return unnormalize_img, x_noisy, noise, alpha_cumprod 120 | 121 | @torch.no_grad() 122 | def p_sample(self, 123 | unet, 124 | x, 125 | t, 126 | t_next, 127 | cond_images, 128 | cond_scale, 129 | noise_scheduler, 130 | pred_objective, 131 | dynamic_threshold, 132 | old_eps 133 | ): 134 | 135 | b, *_, device = *x.shape, x.device 136 | _, _, e_t = self.get_model_output(unet, x, t, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold) 137 | if len(old_eps) == 0: 138 | # Pseudo Improved Euler (2nd order) 139 | # x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 140 | x_prev, pred_x0, _ = self.get_model_output(unet, x, t, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold, pred_e = e_t) 141 | # e_t_next = get_model_output(x_prev, t_next) 142 | _, _, e_t_next = self.get_model_output(unet, x_prev, t_next, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold) 143 | e_t_prime = (e_t + e_t_next) / 2 144 | elif len(old_eps) == 1: 145 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 146 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 147 | elif len(old_eps) == 2: 148 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 149 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 150 | elif len(old_eps) >= 3: 151 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 152 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 
24 153 | 154 | x_prev, pred_x0, _ = self.get_model_output(unet, x, t, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold, pred_e = e_t_prime) 155 | 156 | return x_prev, pred_x0, e_t 157 | 158 | def get_model_output(self, 159 | unet, 160 | x, 161 | t, 162 | t_next, 163 | cond_images, 164 | cond_scale, 165 | noise_scheduler, 166 | pred_objective, 167 | dynamic_threshold, 168 | pred_e = None, 169 | ): 170 | assert(pred_objective == 'noise') 171 | b, *_, device = *x.shape, x.device 172 | 173 | 174 | #@ PRED EPS 175 | if pred_e is None: 176 | pred = unet.forward_with_cond_scale(x, noise_scheduler.get_condition(t), cond_images = cond_images, cond_scale = cond_scale) 177 | pred_e = pred 178 | else: 179 | pred = pred_e 180 | 181 | 182 | #@ PREDICT X_0 183 | if pred_objective == 'noise': 184 | x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) 185 | elif pred_objective == 'x_start': 186 | x_start = pred 187 | else: 188 | raise ValueError(f'unknown objective {pred_objective}') 189 | 190 | #@ CLIP X_0 191 | if self.diffusion.clip_output: 192 | if dynamic_threshold: 193 | # following pseudocode in appendix 194 | # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element 195 | s = torch.quantile( 196 | rearrange(x_start, 'b ... -> b (...)').abs(), 197 | self.dynamic_thresholding_percentile, 198 | dim = -1 199 | ) 200 | 201 | s.clamp_(min = 1.) 202 | s = right_pad_dims_to(x_start, s) 203 | x_start = x_start.clamp(-s, s) / s 204 | else: 205 | x_start.clamp_(-self.diffusion.clip_value, self.diffusion.clip_value) 206 | 207 | #@ USE Q_POSTERIOR TO GET X_PREV 208 | model_mean, _, model_log_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next) 209 | noise = torch.randn_like(x) 210 | is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0) 211 | nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1))) 212 | x_prev = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 213 | 214 | return x_prev, x_start, pred_e 215 | -------------------------------------------------------------------------------- /media/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/media/teaser.jpg -------------------------------------------------------------------------------- /raymarching/README.md: -------------------------------------------------------------------------------- 1 | All code in raymarching is from https://github.com/ashawkey/torch-ngp. 
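A minimal usage sketch, assuming the CUDA extension has been built and the snippet is run from the repository root (so `raymarching` resolves to this folder); shapes follow the docstrings in `raymarching.py`:

```python
import torch
from raymarching import near_far_from_aabb  # loads the compiled _raymarching kernels

# 1024 rays starting at the origin with random unit directions
rays_o = torch.zeros(1024, 3, device='cuda')
rays_d = torch.nn.functional.normalize(torch.randn(1024, 3, device='cuda'), dim=-1)

# axis-aligned bounding box (xmin, ymin, zmin, xmax, ymax, zmax)
aabb = torch.tensor([-1., -1., -1., 1., 1., 1.], device='cuda')

# per-ray entry/exit distances, each of shape [1024]
nears, fars = near_far_from_aabb(rays_o, rays_d, aabb, 0.2)
```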
-------------------------------------------------------------------------------- /raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching import * -------------------------------------------------------------------------------- /raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_raymarching', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'raymarching.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /raymarching/raymarching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _raymarching as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | # ---------------------------------------- 16 | # utils 17 | # ---------------------------------------- 18 | 19 | class _near_far_from_aabb(Function): 20 | @staticmethod 21 | @custom_fwd(cast_inputs=torch.float32) 22 | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): 23 | ''' near_far_from_aabb, CUDA implementation 24 | Calculate rays' intersection time (near and far) with aabb 25 | Args: 26 | rays_o: float, [N, 3] 27 | rays_d: float, [N, 3] 28 | aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) 29 | min_near: float, scalar 30 | Returns: 31 | nears: float, [N] 32 | fars: float, [N] 33 | ''' 34 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 35 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 36 | 37 | rays_o = rays_o.contiguous().view(-1, 3) 38 | rays_d = rays_d.contiguous().view(-1, 3) 39 | 40 | N = rays_o.shape[0] # num rays 41 | 42 | nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) 43 | fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) 44 | 45 | _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) 46 | 47 | return nears, fars 48 | 49 | near_far_from_aabb = _near_far_from_aabb.apply 50 | 51 | 52 | class _sph_from_ray(Function): 53 | @staticmethod 54 | @custom_fwd(cast_inputs=torch.float32) 55 | def forward(ctx, rays_o, rays_d, radius): 56 | ''' sph_from_ray, CUDA 
implementation 57 | get spherical coordinate on the background sphere from rays. 58 | Assume rays_o are inside the Sphere(radius). 59 | Args: 60 | rays_o: [N, 3] 61 | rays_d: [N, 3] 62 | radius: scalar, float 63 | Return: 64 | coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) 65 | ''' 66 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 67 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 68 | 69 | rays_o = rays_o.contiguous().view(-1, 3) 70 | rays_d = rays_d.contiguous().view(-1, 3) 71 | 72 | N = rays_o.shape[0] # num rays 73 | 74 | coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) 75 | 76 | _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) 77 | 78 | return coords 79 | 80 | sph_from_ray = _sph_from_ray.apply 81 | 82 | 83 | class _morton3D(Function): 84 | @staticmethod 85 | def forward(ctx, coords): 86 | ''' morton3D, CUDA implementation 87 | Args: 88 | coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) 89 | TODO: check if the coord range is valid! (current 128 is safe) 90 | Returns: 91 | indices: [N], int32, in [0, 128^3) 92 | 93 | ''' 94 | if not coords.is_cuda: coords = coords.cuda() 95 | 96 | N = coords.shape[0] 97 | 98 | indices = torch.empty(N, dtype=torch.int32, device=coords.device) 99 | 100 | _backend.morton3D(coords.int(), N, indices) 101 | 102 | return indices 103 | 104 | morton3D = _morton3D.apply 105 | 106 | class _morton3D_invert(Function): 107 | @staticmethod 108 | def forward(ctx, indices): 109 | ''' morton3D_invert, CUDA implementation 110 | Args: 111 | indices: [N], int32, in [0, 128^3) 112 | Returns: 113 | coords: [N, 3], int32, in [0, 128) 114 | 115 | ''' 116 | if not indices.is_cuda: indices = indices.cuda() 117 | 118 | N = indices.shape[0] 119 | 120 | coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) 121 | 122 | _backend.morton3D_invert(indices.int(), N, coords) 123 | 124 | return coords 125 | 126 | morton3D_invert = _morton3D_invert.apply 127 | 128 | 129 | class _packbits(Function): 130 | @staticmethod 131 | @custom_fwd(cast_inputs=torch.float32) 132 | def forward(ctx, grid, thresh, bitfield=None): 133 | ''' packbits, CUDA implementation 134 | Pack up the density grid into a bit field to accelerate ray marching. 
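(Roughly: every 8 consecutive grid cells are packed into one byte, with a bit set
when the corresponding cell's density exceeds thresh.)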
135 | Args: 136 | grid: float, [C, H * H * H], assume H % 2 == 0 137 | thresh: float, threshold 138 | Returns: 139 | bitfield: uint8, [C, H * H * H / 8] 140 | ''' 141 | if not grid.is_cuda: grid = grid.cuda() 142 | grid = grid.contiguous() 143 | 144 | C = grid.shape[0] 145 | H3 = grid.shape[1] 146 | N = C * H3 // 8 147 | 148 | if bitfield is None: 149 | bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) 150 | 151 | _backend.packbits(grid, N, thresh, bitfield) 152 | 153 | return bitfield 154 | 155 | packbits = _packbits.apply 156 | 157 | # ---------------------------------------- 158 | # train functions 159 | # ---------------------------------------- 160 | 161 | class _march_rays_train(Function): 162 | @staticmethod 163 | @custom_fwd(cast_inputs=torch.float32) 164 | def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): 165 | ''' march rays to generate points (forward only) 166 | Args: 167 | rays_o/d: float, [N, 3] 168 | bound: float, scalar 169 | density_bitfield: uint8: [CHHH // 8] 170 | C: int 171 | H: int 172 | nears/fars: float, [N] 173 | step_counter: int32, (2), used to count the actual number of generated points. 174 | mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) 175 | perturb: bool 176 | align: int, pad output so its size is dividable by align, set to -1 to disable. 177 | force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. 178 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) 179 | max_steps: int, max number of sampled points along each ray, also affect min_stepsize. 180 | Returns: 181 | xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) 182 | dirs: float, [M, 3], all generated points' view dirs. 183 | deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth) 184 | rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0] 185 | ''' 186 | 187 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 188 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 189 | if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() 190 | 191 | rays_o = rays_o.contiguous().view(-1, 3) 192 | rays_d = rays_d.contiguous().view(-1, 3) 193 | density_bitfield = density_bitfield.contiguous() 194 | 195 | N = rays_o.shape[0] # num rays 196 | M = N * max_steps # init max points number in total 197 | 198 | # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) 199 | # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. 
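# When a mean_count estimate is available, the point buffers below are sized to it
# (rounded up to a multiple of `align`) instead of the worst case N * max_steps,
# trading possible ray drops for much smaller allocations.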
200 | if not force_all_rays and mean_count > 0: 201 | if align > 0: 202 | mean_count += align - mean_count % align 203 | M = mean_count 204 | 205 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 206 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 207 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) 208 | rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps 209 | 210 | if step_counter is None: 211 | step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter 212 | 213 | if perturb: 214 | noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) 215 | else: 216 | noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) 217 | 218 | _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number 219 | 220 | #print(step_counter, M) 221 | 222 | # only used at the first (few) epochs. 223 | if force_all_rays or mean_count <= 0: 224 | m = step_counter[0].item() # D2H copy 225 | if align > 0: 226 | m += align - m % align 227 | xyzs = xyzs[:m] 228 | dirs = dirs[:m] 229 | deltas = deltas[:m] 230 | 231 | torch.cuda.empty_cache() 232 | 233 | return xyzs, dirs, deltas, rays 234 | 235 | march_rays_train = _march_rays_train.apply 236 | 237 | 238 | class _composite_rays_train(Function): 239 | @staticmethod 240 | @custom_fwd(cast_inputs=torch.float32) 241 | def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4): 242 | ''' composite rays' rgbs, according to the ray marching formula. 243 | Args: 244 | rgbs: float, [M, 3] 245 | sigmas: float, [M,] 246 | deltas: float, [M, 2] 247 | rays: int32, [N, 3] 248 | Returns: 249 | weights_sum: float, [N,], the alpha channel 250 | depth: float, [N, ], the Depth 251 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 252 | ''' 253 | 254 | sigmas = sigmas.contiguous() 255 | rgbs = rgbs.contiguous() 256 | 257 | M = sigmas.shape[0] 258 | N = rays.shape[0] 259 | 260 | weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 261 | depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) 262 | image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) 263 | 264 | _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image) 265 | 266 | ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image) 267 | ctx.dims = [M, N, T_thresh] 268 | 269 | return weights_sum, depth, image 270 | 271 | @staticmethod 272 | @custom_bwd 273 | def backward(ctx, grad_weights_sum, grad_depth, grad_image): 274 | 275 | # NOTE: grad_depth is not used now! It won't be propagated to sigmas. 
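# The kernel below accumulates d(loss)/d(sigma) and d(loss)/d(rgb) per sampled point;
# deltas, rays and T_thresh are treated as non-differentiable, hence the trailing
# None gradients in the return value.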
276 | 277 | grad_weights_sum = grad_weights_sum.contiguous() 278 | grad_image = grad_image.contiguous() 279 | 280 | sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors 281 | M, N, T_thresh = ctx.dims 282 | 283 | grad_sigmas = torch.zeros_like(sigmas) 284 | grad_rgbs = torch.zeros_like(rgbs) 285 | 286 | _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs) 287 | 288 | return grad_sigmas, grad_rgbs, None, None, None 289 | 290 | 291 | composite_rays_train = _composite_rays_train.apply 292 | 293 | # ---------------------------------------- 294 | # infer functions 295 | # ---------------------------------------- 296 | 297 | class _march_rays(Function): 298 | @staticmethod 299 | @custom_fwd(cast_inputs=torch.float32) 300 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): 301 | ''' march rays to generate points (forward only, for inference) 302 | Args: 303 | n_alive: int, number of alive rays 304 | n_step: int, how many steps we march 305 | rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) 306 | rays_t: float, [N], the alive rays' time, we only use the first n_alive. 307 | rays_o/d: float, [N, 3] 308 | bound: float, scalar 309 | density_bitfield: uint8: [CHHH // 8] 310 | C: int 311 | H: int 312 | nears/fars: float, [N] 313 | align: int, pad output so its size is dividable by align, set to -1 to disable. 314 | perturb: bool/int, int > 0 is used as the random seed. 315 | dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) 316 | max_steps: int, max number of sampled points along each ray, also affect min_stepsize. 317 | Returns: 318 | xyzs: float, [n_alive * n_step, 3], all generated points' coords 319 | dirs: float, [n_alive * n_step, 3], all generated points' view dirs. 320 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). 
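Note: unlike march_rays_train, no per-ray index tensor is returned; each alive ray
contributes n_step consecutive entries (plus optional padding when align > 0).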
321 | ''' 322 | 323 | if not rays_o.is_cuda: rays_o = rays_o.cuda() 324 | if not rays_d.is_cuda: rays_d = rays_d.cuda() 325 | 326 | rays_o = rays_o.contiguous().view(-1, 3) 327 | rays_d = rays_d.contiguous().view(-1, 3) 328 | 329 | M = n_alive * n_step 330 | 331 | if align > 0: 332 | M += align - (M % align) 333 | 334 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 335 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) 336 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth 337 | 338 | if perturb: 339 | # torch.manual_seed(perturb) # test_gui uses spp index as seed 340 | noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) 341 | else: 342 | noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) 343 | 344 | _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises) 345 | 346 | return xyzs, dirs, deltas 347 | 348 | march_rays = _march_rays.apply 349 | 350 | 351 | class _composite_rays(Function): 352 | @staticmethod 353 | @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float 354 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2): 355 | ''' composite rays' rgbs, according to the ray marching formula. (for inference) 356 | Args: 357 | n_alive: int, number of alive rays 358 | n_step: int, how many steps we march 359 | rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) 360 | rays_t: float, [N], the alive rays' time 361 | sigmas: float, [n_alive * n_step,] 362 | rgbs: float, [n_alive * n_step, 3] 363 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). 364 | In-place Outputs: 365 | weights_sum: float, [N,], the alpha channel 366 | depth: float, [N,], the depth value 367 | image: float, [N, 3], the RGB channel (after multiplying alpha!) 368 | ''' 369 | _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) 370 | return tuple() 371 | 372 | 373 | composite_rays = _composite_rays.apply -------------------------------------------------------------------------------- /raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
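# If cl.exe cannot be located either, a RuntimeError is raised; otherwise the MSVC
# host-compiler directory found by find_cl_path() is appended to PATH so the CUDA
# extension below can be compiled.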
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | ''' 33 | Usage: 34 | 35 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) 36 | 37 | python setup.py install # build extensions and install (copy) to PATH. 38 | pip install . # ditto but better (e.g., dependency & metadata handling) 39 | 40 | python setup.py develop # build extensions and install (symbolic) to PATH. 41 | pip install -e . # ditto but better (e.g., dependency & metadata handling) 42 | 43 | ''' 44 | setup( 45 | name='raymarching', # package name, import this to use python API 46 | ext_modules=[ 47 | CUDAExtension( 48 | name='_raymarching', # extension name, import this to use CUDA API 49 | sources=[os.path.join(_src_path, 'src', f) for f in [ 50 | 'raymarching.cu', 51 | 'bindings.cpp', 52 | ]], 53 | extra_compile_args={ 54 | 'cxx': c_flags, 55 | 'nvcc': nvcc_flags, 56 | } 57 | ), 58 | ], 59 | cmdclass={ 60 | 'build_ext': BuildExtension, 61 | } 62 | ) -------------------------------------------------------------------------------- /raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("packbits", &packbits, "packbits (CUDA)"); 8 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 9 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); 10 | m.def("morton3D", &morton3D, "morton3D (CUDA)"); 11 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); 12 | // train 13 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 14 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 15 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 16 | // infer 17 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 18 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 19 | } -------------------------------------------------------------------------------- /raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | 7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); 9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); 10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); 11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); 12 | 13 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); 14 | 
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 15 | void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 16 | 17 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); 18 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch-ema 2 | trimesh 3 | opencv-python 4 | tensorboardX 5 | torch 6 | numpy 7 | pandas 8 | tqdm 9 | matplotlib 10 | PyMCubes 11 | rich 12 | pysdf 13 | packaging 14 | scipy 15 | lpips 16 | imageio 17 | einops 18 | scikit-image 19 | omegaconf 20 | plotly -------------------------------------------------------------------------------- /utils/check_args.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Check that the provided arguments are valid. 
3 | ''' 4 | 5 | import os 6 | import sys 7 | from utils.co3d_dataloader import CO3D_ALL_CATEGORIES 8 | 9 | def check_args(args): 10 | 11 | #@ CHECK ALL EXPECTED ARGS ARE PRESENT 12 | if not (args.dataset_name == 'co3d' or args.dataset_name == 'co3d_toy'): 13 | print(f'ERROR: Provided {args.dataset} as dataset, but only (co3d, co3d_toy) are supported.') 14 | print('Exiting...') 15 | exit(1) 16 | 17 | if args.category not in CO3D_ALL_CATEGORIES and args.category not in {'all_ten', 'all'}: 18 | print(f'ERROR: Provided category {args.category} is not in CO3D.') 19 | print('Exiting...') 20 | exit(1) 21 | 22 | #@ CHECK DATASET FOLDER EXISTS 23 | if not os.path.exists(args.root): 24 | print(f'ERROR: Provided dataset root {args.root} does not exist.') 25 | print('Exiting...') 26 | exit(1) 27 | 28 | #@ CHECK MODEL WEIGHTS PATHS EXIST 29 | if not os.path.exists(args.eft_ckpt): 30 | print(f'ERROR: Provided EFT weight {args.eft_ckpt} does not exist.') 31 | print('Exiting...') 32 | exit(1) 33 | 34 | if not os.path.exists(args.vldm_ckpt): 35 | print(f'ERROR: Provided VLDM weight {args.vldm_ckpt} does not exist.') 36 | print('Exiting...') 37 | exit(1) 38 | 39 | if not os.path.exists(args.vae_ckpt): 40 | print(f'ERROR: Provided VAE weight {args.vae_ckpt} does not exist.') 41 | print('Exiting...') 42 | exit(1) 43 | 44 | return 45 | -------------------------------------------------------------------------------- /utils/co3d_toy_dataloader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Wrapper for a preprocessed and simplified CO3Dv2 dataset 3 | #@ Modified from https://github.com/facebookresearch/pytorch3d 4 | ''' 5 | 6 | import os 7 | import torch 8 | 9 | class CO3Dv2ToyLoader(torch.utils.data.Dataset): 10 | 11 | def __init__(self, root, category): 12 | super().__init__() 13 | 14 | self.root = root 15 | self.category = category 16 | 17 | default_path = f'{root}/{category}/{category}_toy.pt' 18 | if not os.path.exists(default_path): 19 | print(f'ERROR: toy dataset not found at {default_path}') 20 | print('Exiting...') 21 | exit(1) 22 | 23 | dataset = torch.load(default_path) 24 | self.seq_list = dataset[category] 25 | 26 | def __len__(self): 27 | return len(self.seq_list) 28 | 29 | def __getitem__(self, index): 30 | return self.seq_list[index] -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A collection of common utilities 3 | ''' 4 | import torch 5 | import lpips 6 | import skimage.metrics 7 | 8 | 9 | def normalize(x): 10 | ''' 11 | Normalize [0, 1] to [-1, 1] 12 | ''' 13 | return torch.clip(x*2 - 1.0, -1.0, 1.0) 14 | 15 | def unnormalize(x): 16 | ''' 17 | Unnormalize [-1, 1] to [0, 1] 18 | ''' 19 | return torch.clip((x + 1.0) / 2.0, 0.0, 1.0) 20 | 21 | def split_list(a, n): 22 | ''' 23 | Split list into n parts 24 | 25 | Args: 26 | a (list): list 27 | n (int): number of parts 28 | 29 | Returns: 30 | a_split (list[list]): nested list of a split into n parts 31 | ''' 32 | k, m = divmod(len(a), n) 33 | return [a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)] 34 | 35 | 36 | def get_lpips_fn(): 37 | ''' 38 | Return LPIPS function 39 | ''' 40 | loss_fn_vgg = lpips.LPIPS(net='vgg') 41 | return loss_fn_vgg 42 | 43 | 44 | def get_metrics(pred, gt, use_lpips=False, loss_fn_vgg=None, device=None): 45 | ''' 46 | Compute image metrics 47 | 48 | Args: 49 | pred (np array): (H, W, 3) 50 | gt (np array): (H, W, 3) 51 
| ''' 52 | ssim = skimage.metrics.structural_similarity(pred, gt, channel_axis = -1, data_range=1) 53 | psnr = skimage.metrics.peak_signal_noise_ratio(pred, gt, data_range=1) 54 | if use_lpips: 55 | if loss_fn_vgg is None: 56 | loss_fn_vgg = lpips.LPIPS(net='vgg') 57 | gt_ = torch.from_numpy(gt).permute(2,0,1).unsqueeze(0)*2 - 1.0 58 | pred_ = torch.from_numpy(pred).permute(2,0,1).unsqueeze(0)*2 - 1.0 59 | if device is not None: 60 | lp = loss_fn_vgg(gt_.to(device), pred_.to(device)).detach().cpu().numpy().item() 61 | else: 62 | lp = loss_fn_vgg(gt_, pred_).detach().numpy().item() 63 | return ssim, psnr, lp 64 | return ssim, psnr 65 | 66 | 67 | #@ FROM PyTorch3D 68 | class HarmonicEmbedding(torch.nn.Module): 69 | def __init__( 70 | self, 71 | n_harmonic_functions: int = 6, 72 | omega_0: float = 1.0, 73 | logspace: bool = True, 74 | append_input: bool = True, 75 | ) -> None: 76 | """ 77 | Given an input tensor `x` of shape [minibatch, ... , dim], 78 | the harmonic embedding layer converts each feature 79 | (i.e. vector along the last dimension) in `x` 80 | into a series of harmonic features `embedding`, 81 | where for each i in range(dim) the following are present 82 | in embedding[...]: 83 | ``` 84 | [ 85 | sin(f_1*x[..., i]), 86 | sin(f_2*x[..., i]), 87 | ... 88 | sin(f_N * x[..., i]), 89 | cos(f_1*x[..., i]), 90 | cos(f_2*x[..., i]), 91 | ... 92 | cos(f_N * x[..., i]), 93 | x[..., i], # only present if append_input is True. 94 | ] 95 | ``` 96 | where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar 97 | denoting the i-th frequency of the harmonic embedding. 98 | If `logspace==True`, the frequencies `[f_1, ..., f_N]` are 99 | powers of 2: 100 | `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)` 101 | If `logspace==False`, frequencies are linearly spaced between 102 | `1.0` and `2**(n_harmonic_functions-1)`: 103 | `f_1, ..., f_N = torch.linspace( 104 | 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions 105 | )` 106 | Note that `x` is also premultiplied by the base frequency `omega_0` 107 | before evaluating the harmonic functions. 108 | Args: 109 | n_harmonic_functions: int, number of harmonic 110 | features 111 | omega_0: float, base frequency 112 | logspace: bool, Whether to space the frequencies in 113 | logspace or linear space 114 | append_input: bool, whether to concat the original 115 | input to the harmonic embedding. 
If true the 116 | output is of the form (x, embed.sin(), embed.cos() 117 | """ 118 | super().__init__() 119 | 120 | if logspace: 121 | frequencies = 2.0 ** torch.arange( 122 | n_harmonic_functions, 123 | dtype=torch.float32, 124 | ) 125 | else: 126 | frequencies = torch.linspace( 127 | 1.0, 128 | 2.0 ** (n_harmonic_functions - 1), 129 | n_harmonic_functions, 130 | dtype=torch.float32, 131 | ) 132 | 133 | self.register_buffer("_frequencies", frequencies * omega_0, persistent=False) 134 | self.append_input = append_input 135 | 136 | def forward(self, x: torch.Tensor) -> torch.Tensor: 137 | """ 138 | Args: 139 | x: tensor of shape [..., dim] 140 | Returns: 141 | embedding: a harmonic embedding of `x` 142 | of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim] 143 | """ 144 | embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1) 145 | embed = torch.cat( 146 | (embed.sin(), embed.cos(), x) 147 | if self.append_input 148 | else (embed.sin(), embed.cos()), 149 | dim=-1, 150 | ) 151 | return embed 152 | 153 | @staticmethod 154 | def get_output_dim_static( 155 | input_dims: int, 156 | n_harmonic_functions: int, 157 | append_input: bool, 158 | ) -> int: 159 | """ 160 | Utility to help predict the shape of the output of `forward`. 161 | Args: 162 | input_dims: length of the last dimension of the input tensor 163 | n_harmonic_functions: number of embedding frequencies 164 | append_input: whether or not to concat the original 165 | input to the harmonic embedding 166 | Returns: 167 | int: the length of the last dimension of the output tensor 168 | """ 169 | return input_dims * (2 * n_harmonic_functions + int(append_input)) 170 | 171 | def get_output_dim(self, input_dims: int = 3) -> int: 172 | """ 173 | Same as above. The default for input_dims is 3 for 3D applications 174 | which use harmonic embedding for positional encoding, 175 | so the input might be xyz. 176 | """ 177 | return self.get_output_dim_static( 178 | input_dims, len(self._frequencies), self.append_input 179 | ) 180 | 181 | 182 | #@ From PyTorch3D 183 | def huber(x, y, scaling=0.1): 184 | """ 185 | A helper function for evaluating the smooth L1 (huber) loss 186 | between the rendered silhouettes and colors. 187 | """ 188 | diff_sq = (x - y) ** 2 189 | loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling) 190 | return loss 191 | 192 | 193 | #@ From PyTorch3D 194 | def sample_images_at_mc_locs(target_images, sampled_rays_xy): 195 | """ 196 | Given a set of Monte Carlo pixel locations `sampled_rays_xy`, 197 | this method samples the tensor `target_images` at the 198 | respective 2D locations. 199 | 200 | This function is used in order to extract the colors from 201 | ground truth images that correspond to the colors 202 | rendered using `MonteCarloRaysampler`. 203 | """ 204 | ba = target_images.shape[0] 205 | dim = target_images.shape[-1] 206 | spatial_size = sampled_rays_xy.shape[1:-1] 207 | # In order to sample target_images, we utilize 208 | # the grid_sample function which implements a 209 | # bilinear image sampler. 210 | # Note that we have to invert the sign of the 211 | # sampled ray positions to convert the NDC xy locations 212 | # of the MonteCarloRaysampler to the coordinate 213 | # convention of grid_sample. 
214 | 215 | if target_images.shape[2] != target_images.shape[3]: 216 | target_images = target_images.permute(0, 3, 1, 2) 217 | elif target_images.shape[2] == target_images.shape[3]: 218 | dim = target_images.shape[1] 219 | 220 | images_sampled = torch.nn.functional.grid_sample( 221 | target_images, 222 | -sampled_rays_xy.view(ba, -1, 1, 2), # note the sign inversion 223 | align_corners=True 224 | ) 225 | return images_sampled.permute(0, 2, 3, 1).view( 226 | ba, *spatial_size, dim 227 | ) 228 | -------------------------------------------------------------------------------- /utils/eft_raymarcher.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Custom PyTorch3D Raymarcher for EFT 3 | #@ Modified from https://github.com/facebookresearch/pytorch3d 4 | ''' 5 | 6 | from typing import Optional, Tuple, Union 7 | import torch 8 | 9 | from pytorch3d.renderer import EmissionAbsorptionRaymarcher 10 | from pytorch3d.renderer.implicit.raymarching import ( 11 | _check_density_bounds, 12 | _check_raymarcher_inputs, 13 | _shifted_cumprod, 14 | ) 15 | 16 | class LightFieldRaymarcher(torch.nn.Module): 17 | """ 18 | A nominal ray marcher that returns LightField features without any raymarching 19 | """ 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def forward( 24 | self, 25 | rays_densities: torch.Tensor, 26 | rays_features: torch.Tensor, 27 | **kwargs 28 | ) -> Union[None, torch.Tensor]: 29 | """ 30 | """ 31 | return torch.cat((rays_densities, rays_features), dim=-1) -------------------------------------------------------------------------------- /utils/eft_renderer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Custom PyTorch3D Implicit Renderer 3 | #@ Modified from https://github.com/facebookresearch/pytorch3d 4 | ''' 5 | 6 | from typing import Callable, Tuple 7 | from einops import rearrange, repeat, reduce 8 | import torch 9 | 10 | from pytorch3d.ops.utils import eyes 11 | from pytorch3d.structures import Volumes 12 | from pytorch3d.transforms import Transform3d 13 | from pytorch3d.renderer.cameras import CamerasBase 14 | from pytorch3d.renderer.implicit.raysampling import RayBundle 15 | from pytorch3d.renderer.implicit.utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points, ray_bundle_to_ray_points 16 | 17 | #@ MODIFIED FROM PYTORCH3D 18 | class CustomImplicitRenderer(torch.nn.Module): 19 | """ 20 | A class for rendering a batch of implicit surfaces. The class should 21 | be initialized with a raysampler and raymarcher class which both have 22 | to be a `Callable`. 23 | VOLUMETRIC_FUNCTION 24 | The `forward` function of the renderer accepts as input the rendering cameras 25 | as well as the `volumetric_function` `Callable`, which defines a field of opacity 26 | and feature vectors over the 3D domain of the scene. 27 | A standard `volumetric_function` has the following signature: 28 | ``` 29 | def volumetric_function( 30 | ray_bundle: RayBundle, 31 | **kwargs, 32 | ) -> Tuple[torch.Tensor, torch.Tensor] 33 | ``` 34 | With the following arguments: 35 | `ray_bundle`: A RayBundle object containing the following variables: 36 | `origins`: A tensor of shape `(minibatch, ..., 3)` denoting 37 | the origins of the rendering rays. 38 | `directions`: A tensor of shape `(minibatch, ..., 3)` 39 | containing the direction vectors of rendering rays. 
40 | `lengths`: A tensor of shape 41 | `(minibatch, ..., num_points_per_ray)`containing the 42 | lengths at which the ray points are sampled. 43 | `xys`: A tensor of shape 44 | `(minibatch, ..., 2)` containing the 45 | xy locations of each ray's pixel in the screen space. 46 | Calling `volumetric_function` then returns the following: 47 | `rays_densities`: A tensor of shape 48 | `(minibatch, ..., num_points_per_ray, opacity_dim)` containing 49 | the an opacity vector for each ray point. 50 | `rays_features`: A tensor of shape 51 | `(minibatch, ..., num_points_per_ray, feature_dim)` containing 52 | the an feature vector for each ray point. 53 | Note that, in order to increase flexibility of the API, we allow multiple 54 | other arguments to enter the volumetric function via additional 55 | (optional) keyword arguments `**kwargs`. 56 | A typical use-case is passing a `CamerasBase` object as an additional 57 | keyword argument, which can allow the volumetric function to adjust its 58 | outputs based on the directions of the projection rays. 59 | Example: 60 | A simple volumetric function of a 0-centered 61 | RGB sphere with a unit diameter is defined as follows: 62 | ``` 63 | def volumetric_function( 64 | ray_bundle: RayBundle, 65 | **kwargs, 66 | ) -> Tuple[torch.Tensor, torch.Tensor]: 67 | # first convert the ray origins, directions and lengths 68 | # to 3D ray point locations in world coords 69 | rays_points_world = ray_bundle_to_ray_points(ray_bundle) 70 | # set the densities as an inverse sigmoid of the 71 | # ray point distance from the sphere centroid 72 | rays_densities = torch.sigmoid( 73 | -100.0 * rays_points_world.norm(dim=-1, keepdim=True) 74 | ) 75 | # set the ray features to RGB colors proportional 76 | # to the 3D location of the projection of ray points 77 | # on the sphere surface 78 | rays_features = torch.nn.functional.normalize( 79 | rays_points_world, dim=-1 80 | ) * 0.5 + 0.5 81 | return rays_densities, rays_features 82 | ``` 83 | """ 84 | 85 | def __init__(self, raysampler: Callable, raymarcher: Callable, reg=None) -> None: 86 | """ 87 | Args: 88 | raysampler: A `Callable` that takes as input scene cameras 89 | (an instance of `CamerasBase`) and returns a `RayBundle` that 90 | describes the rays emitted from the cameras. 91 | raymarcher: A `Callable` that receives the response of the 92 | `volumetric_function` (an input to `self.forward`) evaluated 93 | along the sampled rays, and renders the rays with a 94 | ray-marching algorithm. 95 | """ 96 | super().__init__() 97 | 98 | if not callable(raysampler): 99 | raise ValueError('"raysampler" has to be a "Callable" object.') 100 | if not callable(raymarcher): 101 | raise ValueError('"raymarcher" has to be a "Callable" object.') 102 | 103 | self.raysampler = raysampler 104 | self.raymarcher = raymarcher 105 | self.reg = reg 106 | 107 | def forward( 108 | self, cameras: CamerasBase, volumetric_function: Callable, **kwargs 109 | ) -> Tuple[torch.Tensor, RayBundle]: 110 | """ 111 | Render a batch of images using a volumetric function 112 | represented as a callable (e.g. a Pytorch module). 113 | Args: 114 | cameras: A batch of cameras that render the scene. A `self.raysampler` 115 | takes the cameras as input and samples rays that pass through the 116 | domain of the volumetric function. 117 | volumetric_function: A `Callable` that accepts the parametrizations 118 | of the rendering rays and returns the densities and features 119 | at the respective 3D of the rendering rays. 
Please refer to 120 | the main class documentation for details. 121 | Returns: 122 | images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)` 123 | containing the result of the rendering. 124 | ray_bundle: A `RayBundle` containing the parametrizations of the 125 | sampled rendering rays. 126 | """ 127 | 128 | if not callable(volumetric_function): 129 | raise ValueError('"volumetric_function" has to be a "Callable" object.') 130 | 131 | # first call the ray sampler that returns the RayBundle parametrizing 132 | # the rendering rays. 133 | ray_bundle = self.raysampler( 134 | cameras=cameras, volumetric_function=volumetric_function, **kwargs 135 | ) 136 | # ray_bundle.origins - minibatch x ... x 3 137 | # ray_bundle.directions - minibatch x ... x 3 138 | # ray_bundle.lengths - minibatch x ... x n_pts_per_ray 139 | # ray_bundle.xys - minibatch x ... x 2 140 | 141 | # given sampled rays, call the volumetric function that 142 | # evaluates the densities and features at the locations of the 143 | # ray points 144 | if self.reg is not None: 145 | rays_densities, rays_features, reg_term = volumetric_function( 146 | ray_bundle=ray_bundle, cameras=cameras, **kwargs 147 | ) 148 | else: 149 | rays_densities, rays_features, _ = volumetric_function( 150 | ray_bundle=ray_bundle, cameras=cameras, **kwargs 151 | ) 152 | # ray_densities - minibatch x ... x n_pts_per_ray x density_dim 153 | # ray_features - minibatch x ... x n_pts_per_ray x feature_dim 154 | 155 | # finally, march along the sampled rays to obtain the renders 156 | images = self.raymarcher( 157 | rays_densities=rays_densities, 158 | rays_features=rays_features, 159 | ray_bundle=ray_bundle, 160 | **kwargs 161 | ) 162 | # images - minibatch x ... x (feature_dim + opacity_dim) 163 | 164 | if self.reg is not None: 165 | return images, ray_bundle, reg_term 166 | else: 167 | return images, ray_bundle, 0 -------------------------------------------------------------------------------- /utils/load_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.co3d_dataloader import CO3Dv2Wrapper 3 | from utils.co3d_toy_dataloader import CO3Dv2ToyLoader 4 | 5 | def load_dataset_test(args, image_size=None, masked=True): 6 | ''' 7 | Load test dataset 8 | ''' 9 | 10 | if args.dataset_name == 'co3d': 11 | 12 | if image_size is None: 13 | test_dataset = CO3Dv2Wrapper(root=args.root, category=args.category, sample_batch_size=32, subset='fewview_dev', stage='test', masked=masked) 14 | else: 15 | test_dataset = CO3Dv2Wrapper(root=args.root, category=args.category, sample_batch_size=32, subset='fewview_dev', stage='test', image_size=image_size, masked=masked) 16 | 17 | elif args.dataset_name == 'co3d_toy': 18 | test_dataset = CO3Dv2ToyLoader(root=args.root, category=args.category) 19 | 20 | else: 21 | raise NotImplementedError 22 | 23 | return test_dataset -------------------------------------------------------------------------------- /utils/load_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Functions to load models 3 | ''' 4 | from collections import OrderedDict 5 | import importlib 6 | from omegaconf import OmegaConf 7 | import torch 8 | from external.imagen_pytorch import Unet 9 | from sparsefusion.eft import EpipolarFeatureTransformer 10 | from sparsefusion.vldm import DDPM 11 | 12 | def load_models(gpu, args, verbose=False): 13 | ''' 14 | Loads models 15 | 16 | Args: 17 | gpu (int): gpu id 18 | args (Namespace): model args 
19 | eft_ckpt (str): path to eft checkpoint (optional) 20 | vae_ckpt (str): path to vae checkpoint (optional) 21 | vldm_ckpt (str): path to vldm checkpoint (optional) 22 | Returns: 23 | eft (PyTorch module): eft feature extractor 24 | vae (PyTorch module): stable diffusion vae 25 | vldm (PyTorch module): vldm diffusion model 26 | ''' 27 | 28 | eft = None 29 | vae = None 30 | vldm = None 31 | 32 | #! LOAD EFT 33 | eft = EpipolarFeatureTransformer(use_r=args.use_r, encoder=args.encoder, return_features=True, remove_unused_layers=False).cuda(gpu) 34 | if args.eft_ckpt is not None: 35 | checkpoint = torch.load(args.eft_ckpt, map_location='cpu') 36 | 37 | model_dict = eft.state_dict() 38 | pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k in model_dict} 39 | model_dict.update(pretrained_dict) 40 | eft.load_state_dict(model_dict) 41 | 42 | print('LOADING 1/3 loaded eft checkpoint from', args.eft_ckpt) 43 | 44 | else: 45 | print('LOADING 1/3 initialized eft from scratch') 46 | 47 | 48 | #! LOAD VAE 49 | if args.vae_ckpt is not None: 50 | vae = load_vae(args.vae_ckpt, verbose=verbose).cuda(gpu) 51 | print('LOADING 2/3 loaded sd vae from', args.vae_ckpt) 52 | 53 | else: 54 | vae = load_vae(None) 55 | print('LOADING 2/3 initialized vae from scratch') 56 | 57 | #! LOAD UNet 58 | channels = 4 59 | feature_dim = 256 60 | unet1 = Unet( 61 | channels=channels, 62 | dim = 256, 63 | dim_mults = (1, 2, 4, 4), 64 | num_resnet_blocks = (2, 2, 2, 2), 65 | layer_attns = (False, False, False, True), 66 | layer_cross_attns = (False, False, False, False), 67 | cond_images_channels = feature_dim, 68 | attn_pool_text=False 69 | ) 70 | 71 | if verbose: 72 | total_params = sum(p.numel() for p in unet1.parameters()) 73 | print(f"{unet1.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 74 | 75 | #!
LOAD DIFFUSION 76 | vldm = DDPM( 77 | channels=channels, 78 | unets = (unet1, ), 79 | conditional_encoder = None, 80 | conditional_embed_dim = None, 81 | image_sizes = (32, ), 82 | timesteps = 500, 83 | cond_drop_prob = 0.1, 84 | pred_objectives=args.objective, 85 | conditional=False, 86 | auto_normalize_img=False, 87 | clip_output=True, 88 | dynamic_thresholding=False, 89 | dynamic_thresholding_percentile=.68, 90 | clip_value=10, 91 | ).cuda(gpu) 92 | if args.vldm_ckpt is not None: 93 | checkpoint = torch.load(args.vldm_ckpt, map_location='cpu') 94 | vldm.load_state_dict(checkpoint['model_state_dict']) 95 | print('LOADING 3/3 loaded diffusion from', args.vldm_ckpt) 96 | else: 97 | print('LOADING 3/3 initialized diffusion from scratch') 98 | 99 | 100 | return eft, vae, vldm 101 | 102 | 103 | def load_vae(vae_ckpt, verbose=False): 104 | ''' 105 | Load StableDiffusion VAE 106 | ''' 107 | config = OmegaConf.load("external/ldm/configs/sd-vae.yaml") 108 | vae = load_model_from_config(config, ckpt=vae_ckpt, verbose=verbose) 109 | vae = vae 110 | return vae 111 | 112 | 113 | def get_obj_from_str(string, reload=False): 114 | module, cls = string.rsplit(".", 1) 115 | if reload: 116 | module_imp = importlib.import_module(module) 117 | importlib.reload(module_imp) 118 | return getattr(importlib.import_module(module, package=None), cls) 119 | 120 | 121 | def instantiate_from_config(config): 122 | if "target" not in config: 123 | if config == '__is_first_stage__': 124 | return None 125 | elif config == "__is_unconditional__": 126 | return None 127 | raise KeyError("Expected key `target` to instantiate.") 128 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 129 | 130 | 131 | def load_model_from_config(config, ckpt=None, verbose=True, full_model=False): 132 | 133 | model = instantiate_from_config(config.model) 134 | 135 | if ckpt is not None: 136 | if verbose: 137 | print(f"Loading model from {ckpt}") 138 | pl_sd = torch.load(ckpt, map_location="cpu") 139 | if "global_step" in pl_sd: 140 | if verbose: 141 | print(f"Global Step: {pl_sd['global_step']}") 142 | sd = pl_sd["state_dict"] 143 | 144 | #!
Rename state dict params 145 | sd_ = OrderedDict() 146 | for k, v in sd.items(): 147 | if not full_model: 148 | name = k.replace('first_stage_model.','').replace('model.','') 149 | else: 150 | name = k 151 | sd_[name] = v 152 | 153 | m, u = model.load_state_dict(sd_, strict=False) 154 | 155 | if len(m) > 0 and verbose: 156 | print("missing keys:") 157 | parent_set = set() 158 | for uk in m: 159 | uk_ = uk[:uk.find('.')] 160 | if uk_ not in parent_set: 161 | parent_set.add(uk_) 162 | print(parent_set) 163 | else: 164 | if verbose: 165 | print('all weights found') 166 | if len(u) > 0 and verbose: 167 | print("unexpected keys:") 168 | print(len(u)) 169 | parent_set = set() 170 | for uk in u: 171 | uk_ = uk[:uk.find('.')] 172 | if uk_ not in parent_set: 173 | parent_set.add(uk_) 174 | print(parent_set) 175 | else: 176 | if verbose: 177 | print('initializing from scratch') 178 | 179 | model.eval() 180 | return model -------------------------------------------------------------------------------- /utils/render_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Initialize Renderers 3 | ''' 4 | 5 | import torch 6 | from pytorch3d.renderer import ( 7 | PerspectiveCameras, 8 | MonteCarloRaysampler, 9 | EmissionAbsorptionRaymarcher, 10 | GridRaysampler, 11 | ) 12 | from utils.eft_renderer import CustomImplicitRenderer 13 | from utils.eft_raymarcher import LightFieldRaymarcher 14 | 15 | 16 | def init_ray_sampler(gpu, img_h, img_w, min=0.1, max=4.0, bbox=None, n_pts_per_ray=128, n_rays=750, scale_factor=None): 17 | ''' 18 | Construct ray samplers for torch-ngp 19 | 20 | Args: 21 | gpu (int): gpu id 22 | img_h (int): image height 23 | img_w (int): image width 24 | min (int): min depth for point along ray 25 | max (int): max depth for point along ray 26 | bbox (List): bounding box for monte carlo sampler 27 | n_pts_per_ray (int): number of points along a ray 28 | n_rays (int): number of rays for monte carlo sampler 29 | scale_factor (int): return a grid sampler at a scale factor 30 | 31 | Returns: 32 | sampler_grid (sampler): a grid sampler at full resolution 33 | sampler_mc (sampler): a monte carlo sampler 34 | sampler_feat (sampler): a grid sampler at scale factor resolution 35 | if scale factor is provided 36 | ''' 37 | 38 | img_h, img_w = img_h, img_w 39 | volume_extent_world = max 40 | half_pix_width = 1.0 / img_w 41 | half_pix_height = 1.0 / img_h 42 | 43 | raysampler_grid = GridRaysampler( 44 | min_x=1.0 - half_pix_width, 45 | max_x=-1.0 + half_pix_width, 46 | min_y=1.0 - half_pix_height, 47 | max_y=-1.0 + half_pix_height, 48 | image_height=img_h, 49 | image_width=img_w, 50 | n_pts_per_ray=n_pts_per_ray, 51 | min_depth=min, 52 | max_depth=volume_extent_world, 53 | ) 54 | if scale_factor is not None: 55 | raysampler_features = GridRaysampler( 56 | min_x=1.0 - half_pix_width, 57 | max_x=-1.0 + half_pix_width, 58 | min_y=1.0 - half_pix_height, 59 | max_y=-1.0 + half_pix_height, 60 | image_height=int(img_h//scale_factor), 61 | image_width=int(img_w//scale_factor), 62 | n_pts_per_ray=20, 63 | min_depth=min, 64 | max_depth=volume_extent_world, 65 | ) 66 | if bbox is None: 67 | raysampler_mc = MonteCarloRaysampler( 68 | min_x = -1.0, 69 | max_x = 1.0, 70 | min_y = -1.0, 71 | max_y = 1.0, 72 | n_rays_per_image=n_rays, 73 | n_pts_per_ray=n_pts_per_ray, 74 | min_depth=min, 75 | max_depth=volume_extent_world, 76 | ) 77 | elif bbox is not None: 78 | raysampler_mc = MonteCarloRaysampler( 79 | min_x = -bbox[0,1], 80 | max_x = -bbox[0,3], 81 | min_y = 
-bbox[0,0], 82 | max_y = -bbox[0,2], 83 | n_rays_per_image=n_rays, 84 | n_pts_per_ray=n_pts_per_ray, 85 | min_depth=min, 86 | max_depth=volume_extent_world, 87 | ) 88 | 89 | if scale_factor is not None: 90 | return raysampler_grid, raysampler_mc, raysampler_features 91 | else: 92 | return raysampler_grid, raysampler_mc 93 | 94 | 95 | def init_light_field_renderer(gpu, img_h, img_w, min=0.1, max=4.0, bbox=None, n_pts_per_ray=128, n_rays=750, scale_factor=None): 96 | ''' 97 | Construct implicit renderers for EFT 98 | 99 | Args: 100 | gpu (int): gpu id 101 | img_h (int): image height 102 | img_w (int): image width 103 | min (int): min depth for point along ray 104 | max (int): max depth for point along ray 105 | bbox (List): bounding box for monte carlo sampler 106 | n_pts_per_ray (int): number of points along a ray 107 | n_rays (int): number of rays for monte carlo sampler 108 | scale_factor (int): return a grid sampler at a scale factor 109 | 110 | Returns: 111 | renderer_grid (renderer): a grid renderer at full resolution 112 | renderer_mc (renderer): a monte carlo renderer 113 | renderer_feat (renderer): a grid renderer at scale factor resolution 114 | if scale factor is provided 115 | ''' 116 | 117 | img_h, img_w = img_h, img_w 118 | volume_extent_world = max 119 | half_pix_width = 1.0 / img_w 120 | half_pix_height = 1.0 / img_h 121 | 122 | raysampler_grid = GridRaysampler( 123 | min_x=1.0 - half_pix_width, 124 | max_x=-1.0 + half_pix_width, 125 | min_y=1.0 - half_pix_height, 126 | max_y=-1.0 + half_pix_height, 127 | image_height=img_h, 128 | image_width=img_w, 129 | n_pts_per_ray=n_pts_per_ray, 130 | min_depth=min, 131 | max_depth=volume_extent_world, 132 | ) 133 | if scale_factor is not None: 134 | raysampler_features = GridRaysampler( 135 | min_x=1.0 - half_pix_width, 136 | max_x=-1.0 + half_pix_width, 137 | min_y=1.0 - half_pix_height, 138 | max_y=-1.0 + half_pix_height, 139 | image_height=int(img_h//scale_factor), 140 | image_width=int(img_w//scale_factor), 141 | n_pts_per_ray=20, 142 | min_depth=min, 143 | max_depth=volume_extent_world, 144 | ) 145 | if bbox is None: 146 | raysampler_mc = MonteCarloRaysampler( 147 | min_x = -1.0, 148 | max_x = 1.0, 149 | min_y = -1.0, 150 | max_y = 1.0, 151 | n_rays_per_image=n_rays, 152 | n_pts_per_ray=n_pts_per_ray, 153 | min_depth=min, 154 | max_depth=volume_extent_world, 155 | ) 156 | elif bbox is not None: 157 | raysampler_mc = MonteCarloRaysampler( 158 | min_x = -bbox[0,1], 159 | max_x = -bbox[0,3], 160 | min_y = -bbox[0,0], 161 | max_y = -bbox[0,2], 162 | n_rays_per_image=n_rays, 163 | n_pts_per_ray=n_pts_per_ray, 164 | min_depth=min, 165 | max_depth=volume_extent_world, 166 | ) 167 | 168 | raymarcher = LightFieldRaymarcher() 169 | 170 | renderer_grid = CustomImplicitRenderer( 171 | raysampler=raysampler_grid, raymarcher=raymarcher, reg=True 172 | ) 173 | renderer_mc = CustomImplicitRenderer( 174 | raysampler=raysampler_mc, raymarcher=raymarcher, reg=True 175 | ) 176 | 177 | renderer_grid = renderer_grid.cuda(gpu) 178 | renderer_mc = renderer_mc.cuda(gpu) 179 | 180 | if scale_factor is None: 181 | return renderer_grid, renderer_mc 182 | else: 183 | renderer_feat = CustomImplicitRenderer( 184 | raysampler=raysampler_features, raymarcher=raymarcher, reg=True 185 | ) 186 | renderer_feat = renderer_feat.cuda(gpu) 187 | return renderer_grid, renderer_mc, renderer_feat --------------------------------------------------------------------------------
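
A minimal usage sketch for the utility modules above (not a file from the repository), assuming the repository root is on `PYTHONPATH` and that numpy, torch, lpips, scikit-image, and pytorch3d are installed; all shapes, sizes, and values below are illustrative placeholders:

```python
# Hypothetical usage sketch, not part of the repository.
# Exercises helpers from utils/common_utils.py and utils/render_utils.py.
import numpy as np
import torch

from utils.common_utils import (
    HarmonicEmbedding,
    get_lpips_fn,
    get_metrics,
    normalize,
    unnormalize,
)
from utils.render_utils import init_ray_sampler

# HarmonicEmbedding: with n_harmonic_functions=6 and append_input=True,
# a 3-channel input expands to 3 * (2 * 6 + 1) = 39 channels.
embed = HarmonicEmbedding(n_harmonic_functions=6, omega_0=1.0, logspace=True, append_input=True)
xyz = torch.rand(2, 1024, 3)                            # arbitrary batch of 3D points
assert embed(xyz).shape[-1] == embed.get_output_dim(3)  # 39

# get_metrics expects (H, W, 3) float arrays in [0, 1]; LPIPS is optional.
pred = np.random.rand(128, 128, 3).astype(np.float32)
gt = np.clip(pred + 0.05 * np.random.randn(128, 128, 3).astype(np.float32), 0.0, 1.0)
ssim, psnr, lp = get_metrics(pred, gt, use_lpips=True, loss_fn_vgg=get_lpips_fn())
print(f'SSIM {ssim:.3f} | PSNR {psnr:.2f} dB | LPIPS {lp:.3f}')

# normalize/unnormalize map images between [0, 1] and [-1, 1].
img = torch.rand(1, 3, 256, 256)
assert torch.allclose(unnormalize(normalize(img)), img, atol=1e-6)

# init_ray_sampler builds a full-resolution grid raysampler and a
# Monte Carlo raysampler (PyTorch3D objects) without touching the GPU.
sampler_grid, sampler_mc = init_ray_sampler(gpu=0, img_h=128, img_w=128, n_pts_per_ray=64, n_rays=512)
```

Constructing the raysamplers requires PyTorch3D but no GPU; the renderers returned by `init_light_field_renderer` are moved with `.cuda(gpu)`, so that helper does expect a visible GPU.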