├── .gitignore
├── ENVIRONMENT.md
├── README.md
├── demo.py
├── external
│   ├── README.md
│   ├── einops_exts.py
│   ├── external_utils.py
│   ├── gridencoder
│   │   ├── __init__.py
│   │   ├── backend.py
│   │   ├── grid.py
│   │   ├── setup.py
│   │   └── src
│   │       ├── bindings.cpp
│   │       ├── gridencoder.cu
│   │       └── gridencoder.h
│   ├── imagen_pytorch.py
│   ├── ldm
│   │   ├── configs
│   │   │   └── sd-vae.yaml
│   │   ├── models
│   │   │   └── autoencoder.py
│   │   ├── modules
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   ├── utils
│   │   │   │   │   └── test.png
│   │   │   │   └── utils_image.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   └── x_transformer.py
│   │   └── util.py
│   ├── nerf
│   │   ├── clip_utils.py
│   │   ├── gui.py
│   │   ├── network.py
│   │   ├── network_df.py
│   │   ├── network_ff.py
│   │   ├── network_grid.py
│   │   ├── network_tcnn.py
│   │   ├── provider.py
│   │   ├── renderer.py
│   │   ├── renderer_df.py
│   │   └── utils.py
│   ├── ngp_activation.py
│   ├── ngp_encoder.py
│   └── plms.py
├── media
│   └── teaser.jpg
├── raymarching
│   ├── README.md
│   ├── __init__.py
│   ├── backend.py
│   ├── raymarching.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── raymarching.cu
│       └── raymarching.h
├── requirements.txt
├── sparsefusion
│   ├── distillation.py
│   ├── eft.py
│   └── vldm.py
├── train.py
└── utils
    ├── camera_utils.py
    ├── check_args.py
    ├── co3d_dataloader.py
    ├── co3d_toy_dataloader.py
    ├── common_utils.py
    ├── eft_raymarcher.py
    ├── eft_renderer.py
    ├── load_dataset.py
    ├── load_model.py
    └── render_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info/
2 | **/__pycache__/
3 | **/build/
4 | **build**
5 | vis/
6 | output/
7 | out/
8 | logs/
9 | data/
10 | checkpoints/
11 | *.zip
12 | *.gz
13 | *.tar
14 | *.pt
15 | *.so
--------------------------------------------------------------------------------
/ENVIRONMENT.md:
--------------------------------------------------------------------------------
1 | # Environment Setup
2 | We describe the steps (as Linux command-line instructions) to set up the environment for SparseFusion.
3 |
4 | ## Conda Environment
5 | We install and set up a conda environment.
6 |
7 | ### (optional) Install Conda
8 | Required only if conda is not already installed.
9 | ```bash
10 | cd ~
11 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
12 | chmod +x Miniconda3-latest-Linux-x86_64.sh
13 | ./Miniconda3-latest-Linux-x86_64.sh
14 | export PATH="/home/username/miniconda/bin:$PATH"
15 | conda init
16 | source ~/.bashrc
17 | ```
18 |
19 | ### Create New Environment
20 | ```bash
21 | conda create -n sparsefusion python=3.8
22 | conda activate sparsefusion
23 | ```
24 |
25 | ## Install Dependencies
26 | We install the necessary dependencies.
27 |
28 | ### GCC and Cuda
29 | Make sure to do this first!
30 |
31 | We also assume that the NVIDIA drivers and `cuda=11.3.x` are already installed.
32 | ```bash
33 | conda install -c conda-forge cxx-compiler=1.3.0
34 | conda install -c conda-forge cudatoolkit-dev
35 | conda install -c conda-forge ninja
36 | ```
37 |
38 | ### Python Libraries
39 | ```bash
40 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
41 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
42 | conda install -c pytorch3d pytorch3d
43 | ```
44 |
45 | ### Support Stable Diffusion
46 | ```bash
47 | pip install transformers==4.19.2 pytorch-lightning==1.4.2 torchmetrics==0.6.0
48 | pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
49 | ```
50 |
51 | ### (optional) Install CO3D
52 | Required if using full CO3D dataset.
53 | ```bash
54 | git clone https://github.com/facebookresearch/co3d
55 | cd co3d
56 | pip install -r requirements.txt
57 | pip install -e .
58 | ```
59 |
60 | ### Install Other SparseFusion Requirements
61 | ```bash
62 | cd sparsefusion
63 | pip install -r requirements.txt
64 | ```
65 |
66 | ## Building Extensions
67 | We require a few CUDA extensions from [torch-ngp](https://github.com/ashawkey/torch-ngp). We detail how to install them below; see the [torch-ngp](https://github.com/ashawkey/torch-ngp) GitHub repository for more details.
68 |
69 | ### Build gridencoder
70 | ```bash
71 | pip install ./external/gridencoder
72 | ```
73 |
74 | ### Build raymarcher
75 | ```bash
76 | pip install ./raymarching
77 | ```
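
### (optional) Verify the Extensions
A minimal import check, assuming both extensions were built with the commands above. The grid encoder's compiled module is named `_gridencoder` (per its `setup.py`); `_raymarching` is the assumed analogous name from torch-ngp for the raymarcher. If both imports succeed inside the `sparsefusion` environment, the CUDA extensions compiled correctly.
```python
# Hedged sanity check (not part of the repository).
import torch  # must be imported before the compiled extensions

print('torch', torch.__version__, '| cuda available:', torch.cuda.is_available())

import _gridencoder   # built by `pip install ./external/gridencoder`
import _raymarching   # assumed module name, built by `pip install ./raymarching`

print('gridencoder and raymarching extensions loaded successfully')
```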
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SparseFusion
2 |
3 | [**SparseFusion: Distilling View-conditioned Diffusion for 3D Reconstruction**](https://sparsefusion.github.io/)
4 | [Zhizhuo Zhou](https://www.zhiz.dev/),
5 | [Shubham Tulsiani](https://shubhtuls.github.io/)
6 | _CVPR '23 | [GitHub](https://github.com/zhizdev/sparsefusion) | [arXiv](https://arxiv.org/abs/2212.00792) | [Project page](https://sparsefusion.github.io/)_
7 |
8 | ![teaser](media/teaser.jpg)
9 | SparseFusion reconstructs a consistent and realistic 3D neural scene representation from as few as 2 input images with known relative pose. SparseFusion is able to generate detailed and plausible structures in uncertain or unobserved regions (such as the front of the hydrant, the teddybear's face, the back of the laptop, or the left side of the toybus).
10 |
11 | ---
12 | ## Shoutouts and Credits
13 | This project is built on top of open-source code. We thank the open-source research community and credit our use of parts of [Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Imagen Pytorch](https://github.com/lucidrains/imagen-pytorch), and [torch-ngp](https://github.com/ashawkey/torch-ngp) [below](#acknowledgements).
14 |
15 |
16 | ## Code
17 | Our code release contains:
18 |
19 | 1. Code for inference
20 | 2. Code for training
21 | 3. Pretrained weights for 10 categories
22 |
23 | For bugs and issues, please open an issue on GitHub and I will try to address them promptly.
24 |
25 | ---
26 | ## Environment Setup
27 | Please follow the environment setup guide in [ENVIRONMENT.md](ENVIRONMENT.md).
28 |
29 | ## Dataset
30 | We provide two dataset options: the original CO3Dv2 dataset and a heavily cut-down toy dataset intended for demonstration purposes only. Please download at least one of them.
31 |
32 | 1. (optional) Download the CO3Dv2 dataset (5.5TB) [here](https://github.com/facebookresearch/co3d) and follow the instructions to extract it to a folder. We assume the default location to be `data/co3d/{category_name}`.
33 | 2. Download the evaluation-only toy CO3Dv2 dataset (6.7GB) [here](https://drive.google.com/drive/folders/1IzgFjdgm_RjCHe2WOkIQa4BRdgKuSglL?usp=share_link) and place it in a folder. We assume the default location to be `data/co3d_toy/{category_name}`.
34 |
35 | ## Pretrained Weights
36 | SparseFusion requires both SparseFusion weights and Stable Diffusion VAE weights.
37 | 1. Find the SparseFusion weights [here](https://drive.google.com/drive/folders/1Czsnf-PVjwH-HL7K5mTt_kF9u-PVWRyL?usp=share_link). Please download them and put them in `checkpoints/sf/{category_name}`.
38 | 2. Download the Stable Diffusion v-1-3 weights [here](https://huggingface.co/CompVis/stable-diffusion-v-1-3-original) and rename `sd-v1-3.ckpt` to `sd-v1-3-vae.ckpt`. While our code is compatible with the full downloaded checkpoint, we only use the VAE weights from Stable Diffusion. We assume the default location and filename of the VAE checkpoint to be `checkpoints/sd/sd-v1-3-vae.ckpt`.
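
A quick sanity check of the checkpoint layout, assuming only the default paths listed above (the same defaults used by `demo.py`); `category` here is an example value:
```python
# Illustrative check (not part of the repository): verify that the default
# checkpoint locations exist before launching the demo.
import os

category = 'hydrant'  # example CO3D category
expected = [
    f'checkpoints/sf/{category}/ckpt_latest_eft.pt',  # EFT checkpoint
    f'checkpoints/sf/{category}/ckpt_latest.pt',      # VLDM checkpoint
    'checkpoints/sd/sd-v1-3-vae.ckpt',                # Stable Diffusion VAE
]
for path in expected:
    print(('found   ' if os.path.exists(path) else 'MISSING ') + path)
```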
39 |
40 | ## Evaluation
41 |
42 |
43 | ### Examples
44 | To run evaluation, assuming the CO3D toy dataset and model weights are in the default paths specified above, simply pass in `-d, --dataset_name` and `-c, --category`:
45 | ```shell
46 | $ python demo.py -d co3d_toy -c hydrant
47 | ```
48 |
49 | To specify which scenes to evaluate on, pass the desired indices (e.g. `0,5,7`) to `-i, --idx`.
50 | ```shell
51 | $ python demo.py -d co3d_toy -c hydrant -i 0,5,7
52 | ```
53 |
54 | To specify the number of input views to use, specify `-v, --input_views`.
55 | ```shell
56 | $ python demo.py -d co3d_toy -c hydrant -i 0,5,7 -v 3
57 | ```
58 |
59 | To specify a custom dataset root location, specify `-r, --root`.
60 | ```shell
61 | $ python demo.py -d co3d_toy -r data/co3d_toy -c hydrant -i 0,5,7 -v 3
62 | ```
63 |
64 | To specify custom model checkpoints, specify `--eft`, `--vldm`, and `--vae`.
65 | ```shell
66 | $ python demo.py -d co3d_toy -r data/co3d_toy -c hydrant -i 0,5,7 -v 3 \
67 | --eft checkpoints/sf/hydrant/ckpt_latest_eft.pt \
68 | --vldm checkpoints/sf/hydrant/ckpt_latest.pt \
69 | --vae checkpoints/sd/sd-v1-3-vae.pt
70 | ```
71 |
72 | To use the original CO3Dv2 dataset, pass `co3d` to `-d, --dataset_name` and the dataset root location to `-r, --root`.
73 | ```shell
74 | $ python demo.py -d co3d -r data/co3d/ -c hydrant -i 0
75 | ```
76 |
77 | ### Flags
78 | ```
79 | -g, --gpus number of gpus to use (default: 1)
80 | -p, --port last digit of DDP port (default: 1)
81 | -d, --dataset_name name of dataset (default: co3d_toy)
82 | -r, --root root directory of the dataset
83 | -c, --category CO3D category
84 | -v, --input_views number of random input views (default: 2)
85 | -i, --idx scene indices to evaluate (default: 0)
86 | -e, --eft location of EFT checkpoint
87 | -l, --vldm location of VLDM checkpoint
88 | -a, --vae location of Stable Diffusion VAE checkpoint
89 | ```
90 |
91 | ### Output
92 | Output artifacts (images, gifs, and torch-ngp checkpoints) will be saved to `output/demo/` by default.
93 |
94 | ---
95 |
96 | ## Training
97 | Early-access training code is provided in `train.py`. Please follow the evaluation tutorial above to set up the environment and the pretrained VAE weights. We recommend modifying `train.py` directly to specify the experiment directory and set the training hyperparameters. The training flags are shown below.
98 |
99 | ### Flags
100 | ```
101 | -g, --gpus number of gpus to use (default: 1)
102 | -p, --port last digit of DDP port (default: 1)
103 | -d, --dataset_name name of dataset (default: co3d_toy)
104 | -r, --root root directory of the dataset
105 | -c, --category CO3D category
106 | -a, --vae location of Stable Diffusion VAE checkpoint
107 | -b, --backend distributed data parallel backend (default: nccl)
108 | ```
109 |
110 | ### Using Custom Datasets
111 | To train on a custom dataset, one needs to write a custom dataloader. We describe the required output of its `__getitem__` function, which should be a dictionary containing the following (an illustrative sketch follows the specification):
112 | ```
113 | {
114 | 'images': (B, 3, H, W) image tensor,
115 | 'R': (B, 3, 3) PyTorch3D rotation,
116 | 'T': (B, 3) PyTorch3D translation,
117 | 'f': (B, 2) PyTorch3D focal_length in NDC space,
118 | 'c': (B, 2) PyTorch3D principal_point in NDC space,
119 | 'valid_region': (B, 1, H, W) binary tensor where 1 denotes valid image region,
120 | 'image_size': (B, 2) image size
121 | }
122 | ```
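
A minimal sketch of such a dataset, assuming the field conventions above (here `B` is the number of frames returned for a scene); the random tensors are placeholders for real images and PyTorch3D camera parameters:
```python
# Illustrative sketch only (not part of the repository): a custom dataset whose
# __getitem__ returns the dictionary described above, filled with dummy data.
import torch
from torch.utils.data import Dataset


class MyCustomDataset(Dataset):
    def __init__(self, num_scenes=10, frames_per_scene=8, image_size=256):
        self.num_scenes = num_scenes
        self.num_frames = frames_per_scene
        self.image_size = image_size

    def __len__(self):
        return self.num_scenes

    def __getitem__(self, idx):
        B, H = self.num_frames, self.image_size
        return {
            'images': torch.rand(B, 3, H, H),            # RGB in [0, 1]
            'R': torch.eye(3).expand(B, 3, 3).clone(),   # PyTorch3D rotations
            'T': torch.zeros(B, 3),                      # PyTorch3D translations
            'f': torch.full((B, 2), 2.0),                # focal_length in NDC space
            'c': torch.zeros(B, 2),                      # principal_point in NDC space
            'valid_region': torch.ones(B, 1, H, H),      # 1 = valid image region
            'image_size': torch.full((B, 2), float(H)),  # (height, width) per frame
        }
```
The demo script constructs `PerspectiveCameras(R=..., T=..., focal_length=..., principal_point=..., image_size=...)` directly from these fields, so the rotations and translations must follow the PyTorch3D camera convention.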
123 |
124 | ---
125 | ## Citation
126 | If you find this work useful, please consider citing:
127 |
128 | ```
129 | @inproceedings{zhou2023sparsefusion,
130 | title={SparseFusion: Distilling View-conditioned Diffusion for 3D Reconstruction},
131 | author={Zhizhuo Zhou and Shubham Tulsiani},
132 | booktitle={CVPR},
133 | year={2023}
134 | }
135 | ```
136 |
137 | ## Acknowledgements
138 | We thank Naveen Venkat, Mayank Agarwal, Jeff Tan, Paritosh Mittal, Yen-Chi Cheng, and Nikolaos Gkanatsios for helpful discussions and feedback. We also thank David Novotny and Jonáš Kulhánek for sharing outputs of their work and helpful correspondence. This material is based upon work supported by the National Science Foundation Graduate Research Fellowship under Grant No. (DGE1745016, DGE2140739).
139 |
140 | We also use parts of existing projects:
141 |
142 | VAE from [Stable Diffusion](https://github.com/CompVis/stable-diffusion).
143 | ```
144 | @misc{rombach2021highresolution,
145 | title={High-Resolution Image Synthesis with Latent Diffusion Models},
146 | author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
147 | year={2021},
148 | eprint={2112.10752},
149 | archivePrefix={arXiv},
150 | primaryClass={cs.CV}
151 | }
152 | ```
153 |
154 | Diffusion model from [Imagen Pytorch](https://github.com/lucidrains/imagen-pytorch).
155 | ```
156 | @misc{imagen-pytorch,
157 | Author = {Phil Wang},
158 | Year = {2022},
159 | Note = {https://github.com/lucidrains/imagen-pytorch},
160 | Title = {Imagen - Pytorch}
161 | }
162 | ```
163 |
164 | Instant NGP implementation from [torch-ngp](https://github.com/ashawkey/torch-ngp).
165 | ```
166 | @misc{torch-ngp,
167 | Author = {Jiaxiang Tang},
168 | Year = {2022},
169 | Note = {https://github.com/ashawkey/torch-ngp},
170 | Title = {Torch-ngp: a PyTorch implementation of instant-ngp}
171 | }
172 | ```
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import torch.distributed as dist
5 | import torch.multiprocessing as mp
6 | from pytorch3d.renderer import PerspectiveCameras
7 |
8 |
9 | from sparsefusion.distillation import distillation_loop, get_default_torch_ngp_opt
10 | from utils.camera_utils import RelativeCameraLoader
11 | from utils.common_utils import get_lpips_fn, get_metrics, split_list, normalize, unnormalize
12 | from utils.co3d_dataloader import CO3Dv2Wrapper
13 | from utils.co3d_dataloader import CO3D_ALL_CATEGORIES, CO3D_ALL_TEN
14 | from utils.load_model import load_models
15 | from utils.load_dataset import load_dataset_test
16 | from utils.check_args import check_args
17 |
18 | def fit(gpu, args):
19 | #@ SPAWN DISTRIBUTED NODES
20 | rank = args.nr * args.gpus + gpu
21 | print('spawning gpu rank', rank, 'out of', args.gpus)
22 | dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
23 | torch.cuda.set_device(gpu)
24 | os.makedirs(args.exp_dir, exist_ok=True)
25 | os.makedirs(args.exp_dir + '/log/', exist_ok=True)
26 | os.makedirs(args.exp_dir + '/metrics/', exist_ok=True)
27 | os.makedirs(args.exp_dir + '/render_imgs/', exist_ok=True)
28 | os.makedirs(args.exp_dir + '/render_gifs/', exist_ok=True)
29 | render_imgs_dir = args.exp_dir + '/render_imgs/'
30 | print('evaluating', args.exp_dir)
31 |
32 | #@ INIT METRICS
33 | loss_fn_vgg = get_lpips_fn()
34 |
35 | #@ SET CATEGORIES
36 | if args.category == 'all_ten':
37 | cat_list = CO3D_ALL_TEN
38 | elif args.category == 'all':
39 | cat_list = CO3D_ALL_CATEGORIES
40 | else:
41 | cat_list = [args.category]
42 |
43 | #@ LOOP THROUGH CATEGORIES
44 | for ci, cat in enumerate(cat_list):
45 |
46 | #@ LOAD MODELS
47 | eft, vae, vldm = load_models(gpu=gpu, args=args, verbose=False)
48 | use_diffusion = True
49 |
50 | #@ LOAD DATASET
51 | print(f'gpu {gpu}: setting category to {cat} {ci}/{len(cat_list)}')
52 | args.category = cat
53 | test_dataset = load_dataset_test(args, image_size=args.image_size, masked=False)
54 |
55 | #@ SPLIT VAL LIST
56 | if args.val_list == None:
57 | args.val_list = torch.arange(len(test_dataset)).long().tolist()
58 |
59 | val_list = split_list(args.val_list, args.gpus)[gpu]
60 | print(f'gpu {gpu}: assigned idx {val_list}')
61 |
62 | #@
63 | args.val_seed = 0
64 | context_views = args.context_views
65 |
66 | #@ LOOP THROUGH VAL LIST
67 | for val_idx in val_list:
68 |
69 | #@ FETCH DATA
70 | data = test_dataset.__getitem__(val_idx)
71 | scene_idx = val_idx
72 | scene_cameras = PerspectiveCameras(R=data['R'],T=data['T'],focal_length=data['f'],principal_point=data['c'],image_size=data['image_size']).cuda(gpu)
73 | scene_rgb = data['images'].cuda(gpu)
74 | scene_mask = data['masks'].cuda(gpu)
75 | scene_valid_region = data['valid_region'].cuda(gpu)
76 |
77 | #@ SET RANDOM INPUT VIEWS
78 | g_cpu = torch.Generator()
79 | g_cpu.manual_seed(args.val_seed + val_idx)
80 | rand_perm = torch.randperm(len(data['R']), generator=g_cpu)
81 | input_idx = rand_perm[:context_views].long().tolist()
82 | output_idx = torch.arange(len(data['R'])).long().tolist()
83 | print('val_idx', val_idx, input_idx)
84 |
85 | #@ CALL DISTILLATION LOOP
86 | seq_name = f'{cat}_{val_idx:03d}_c{len(input_idx)}'
87 | opt = get_default_torch_ngp_opt()
88 | distillation_loop(
89 | gpu,
90 | args,
91 | opt,
92 | (eft, vae, vldm),
93 | args.exp_dir,
94 | seq_name,
95 | scene_cameras,
96 | scene_rgb,
97 | scene_mask,
98 | scene_valid_region,
99 | input_idx,
100 | use_diffusion=True,
101 | max_itr=3000,
102 | loss_fn_vgg=loss_fn_vgg
103 | )
104 |
105 |
106 | def main():
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
109 | help='number of nodes (default: 1)')
110 | parser.add_argument('-g', '--gpus', default=1, type=int,
111 | help='number of gpus per node')
112 | parser.add_argument('-nr', '--nr', default=0, type=int,
113 | help='ranking within the nodes')
114 | parser.add_argument('-p', '--port', default=1, type=int, metavar='N',
115 | help='last digit of port (default: 1234[1])')
116 | parser.add_argument('-c', '--category', type=str, metavar='S', required=True,
117 | help='category')
118 | parser.add_argument('-r', '--root', type=str, default='data/co3d_toy', metavar='S',
119 | help='location of test features')
120 | parser.add_argument('-d', '--dataset_name', type=str, default='co3d_toy', metavar='S',
121 | help='dataset name')
122 | parser.add_argument('-e', '--eft', type=str, default='-DNE', metavar='S',
123 | help='eft ckpt')
124 | parser.add_argument('-l', '--vldm', type=str, default='-DNE', metavar='S',
125 | help='vldm ckpt')
126 | parser.add_argument('-a', '--vae', type=str, default='-DNE', metavar='S',
127 | help='vae ckpt')
128 | parser.add_argument('-i', '--idx', type=str, default='-DNE', metavar='N',
129 | help='evaluation indices')
130 | parser.add_argument('-v', '--input_views', type=int, default=2, metavar='N',
131 | help='input views')
132 | args = parser.parse_args()
133 |
134 | #@ SET MULTIPROCESSING PORTS
135 | args.world_size = args.gpus * args.nodes
136 | os.environ['MASTER_ADDR'] = 'localhost'
137 | os.environ['MASTER_PORT'] = f'1234{args.port}'
138 | print('using port', f'1234{args.port}')
139 |
140 | #@ SET DEFAULT PARAMETERS
141 | args.use_r = True
142 | args.encoder = 'resnet18'
143 | args.num_input = 4
144 | args.timesteps = 500
145 | args.objective = 'noise'
146 | args.scale_factor = 8
147 | args.image_size = 256
148 | args.z_scale_factor = 0.18215
149 |
150 | args.server_prefix = 'checkpoints/'
151 | args.diffusion_exp_name = 'sf'
152 | args.eft_ckpt = f'{args.server_prefix}/{args.diffusion_exp_name}/{args.category}/ckpt_latest_eft.pt'
153 | args.vae_ckpt = f'{args.server_prefix}/sd/sd-v1-3-vae.ckpt'
154 | args.vldm_ckpt = f'{args.server_prefix}/{args.diffusion_exp_name}/{args.category}/ckpt_latest.pt'
155 |
156 | args.context_views = args.input_views
157 | args.val_list = [0]
158 | args.exp_dir = 'output/demo/'
159 |
160 | #@ OVERRIDE DEFAULT ARGS WITH INPUTS
161 | if args.vae != '-DNE':
162 | args.vae_ckpt = args.vae
163 | if args.eft != '-DNE':
164 | args.eft_ckpt = args.eft
165 | if args.vldm != '-DNE':
166 | args.vldm_ckpt = args.vldm
167 | if args.idx != '-DNE':
168 | try:
169 | val_list_str = args.idx.split(',')
170 | args.val_list = []
171 | for val_str in val_list_str:
172 | args.val_list.append(int(val_str))
173 | except:
174 | print('ERROR: -i --idx arg invalid, please use form 1,2,3')
175 | print('Exiting...')
176 | exit(1)
177 |
178 | check_args(args)
179 |
180 | mp.spawn(fit, nprocs=args.gpus, args=(args,))
181 |
182 | if __name__ == '__main__':
183 | main()
--------------------------------------------------------------------------------
/external/README.md:
--------------------------------------------------------------------------------
1 | All code in external/ is from many amazing researchers and engineers who made
2 | their code public. Please see the README.md in the root directory for credits.
--------------------------------------------------------------------------------
/external/einops_exts.py:
--------------------------------------------------------------------------------
1 | '''
2 | Einops exts
3 | #@ FROM https://github.com/lucidrains/einops-exts
4 | '''
5 |
6 | import re
7 | from torch import nn
8 | from functools import wraps, partial
9 |
10 | from einops import rearrange, reduce, repeat
11 |
12 | # checking shape
13 | # @nils-werner
14 | # https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838
15 |
16 | def check_shape(tensor, pattern, **kwargs):
17 | return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)
18 |
19 | # do same einops operations on a list of tensors
20 |
21 | def _many(fn):
22 | @wraps(fn)
23 | def inner(tensors, pattern, **kwargs):
24 | return (fn(tensor, pattern, **kwargs) for tensor in tensors)
25 | return inner
26 |
27 | # do einops with unflattening of anonymously named dimensions
28 | # (...flattened) -> ...flattened
29 |
30 | def _with_anon_dims(fn):
31 | @wraps(fn)
32 | def inner(tensor, pattern, **kwargs):
33 | regex = r'(\.\.\.[a-zA-Z]+)'
34 | matches = re.findall(regex, pattern)
35 | get_anon_dim_name = lambda t: t.lstrip('...')
36 | dim_prefixes = tuple(map(get_anon_dim_name, set(matches)))
37 |
38 | update_kwargs_dict = dict()
39 |
40 | for prefix in dim_prefixes:
41 | assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
42 | dim_list = kwargs[prefix]
43 | assert isinstance(dim_list, (list, tuple)), f'dimension list "{prefix}" needs to be a tuple or list of dimensions'
44 | dim_names = list(map(lambda ind: f'{prefix}{ind}', range(len(dim_list))))
45 | update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))
46 |
47 | def sub_with_anonymous_dims(t):
48 | dim_name_prefix = get_anon_dim_name(t.groups()[0])
49 | return ' '.join(update_kwargs_dict[dim_name_prefix].keys())
50 |
51 | pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)
52 |
53 | for prefix, update_dict in update_kwargs_dict.items():
54 | del kwargs[prefix]
55 | kwargs.update(update_dict)
56 |
57 | return fn(tensor, pattern_new, **kwargs)
58 | return inner
59 |
60 | # generate all helper functions
61 |
62 | rearrange_many = _many(rearrange)
63 | repeat_many = _many(repeat)
64 | reduce_many = _many(reduce)
65 |
66 | rearrange_with_anon_dims = _with_anon_dims(rearrange)
67 | repeat_with_anon_dims = _with_anon_dims(repeat)
68 | reduce_with_anon_dims = _with_anon_dims(reduce)
--------------------------------------------------------------------------------
/external/external_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | A wrapper class for the perceptual deep feature loss.
3 |
4 | Reference:
5 | Richard Zhang et al. The Unreasonable Effectiveness of Deep Features as a Perceptual Metric. (CVPR 2018).
6 | """
7 | import lpips
8 | import torch.nn as nn
9 |
10 |
11 | class PerceptualLoss(nn.Module):
12 | def __init__(self, net, device):
13 | super().__init__()
14 | self.model = lpips.LPIPS(net=net, verbose=False).to(device)
15 | self.device = device
16 |
17 | def get_device(self, default_device=None):
18 | """
19 | Returns which device module is on, assuming all parameters are on the same GPU.
20 | """
21 | try:
22 | return next(self.parameters()).device
23 | except StopIteration:
24 | return default_device
25 |
26 | def __call__(self, pred, target, normalize=True):
27 | """
28 | Computes the LPIPS distance between pred and target.
29 | Expects image batches of shape (B, 3, H, W); (B, H, W, 3) inputs are permuted automatically.
30 | If normalize is True, inputs are assumed to be in [0, 1]
31 | and are rescaled to [-1, 1] before computing the distance.
32 | """
33 | if pred.shape[1] != 3:
34 | pred = pred.permute(0, 3, 1, 2)
35 | target = target.permute(0, 3, 1, 2)
36 | # print(pred.shape, target.shape)
37 | if normalize:
38 | target = 2 * target - 1
39 | pred = 2 * pred - 1
40 |
41 | # temp_device = pred.device
42 | # device = self.get_device(temp_device)
43 |
44 | device = self.device
45 |
46 | pred = pred.to(device).float()
47 | target = target.to(device)
48 | dist = self.model.forward(pred, target)
49 | return dist.to(device)
50 |
--------------------------------------------------------------------------------
/external/gridencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .grid import GridEncoder
--------------------------------------------------------------------------------
/external/gridencoder/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21 | if paths:
22 | return paths[0]
23 |
24 | # If cl.exe is not on path, try to find it.
25 | if os.system("where cl.exe >nul 2>nul") != 0:
26 | cl_path = find_cl_path()
27 | if cl_path is None:
28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29 | os.environ["PATH"] += ";" + cl_path
30 |
31 | _backend = load(name='_grid_encoder',
32 | extra_cflags=c_flags,
33 | extra_cuda_cflags=nvcc_flags,
34 | sources=[os.path.join(_src_path, 'src', f) for f in [
35 | 'gridencoder.cu',
36 | 'bindings.cpp',
37 | ]],
38 | )
39 |
40 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/external/gridencoder/grid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Function
6 | from torch.autograd.function import once_differentiable
7 | from torch.cuda.amp import custom_bwd, custom_fwd
8 |
9 | try:
10 | import _gridencoder as _backend
11 | except ImportError:
12 | from .backend import _backend
13 |
14 | _gridtype_to_id = {
15 | 'hash': 0,
16 | 'tiled': 1,
17 | }
18 |
19 | class _grid_encode(Function):
20 | @staticmethod
21 | @custom_fwd
22 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
23 | # inputs: [B, D], float in [0, 1]
24 | # embeddings: [sO, C], float
25 | # offsets: [L + 1], int
26 | # RETURN: [B, F], float
27 |
28 | inputs = inputs.contiguous()
29 |
30 | B, D = inputs.shape # batch size, coord dim
31 | L = offsets.shape[0] - 1 # level
32 | C = embeddings.shape[1] # embedding dim for each level
33 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
34 | H = base_resolution # base resolution
35 |
36 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
37 | # if C % 2 != 0, force float, since half for atomicAdd is very slow.
38 | if torch.is_autocast_enabled() and C % 2 == 0:
39 | embeddings = embeddings.to(torch.half)
40 |
41 | # L first, optimize cache for cuda kernel, but needs an extra permute later
42 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
43 |
44 | if calc_grad_inputs:
45 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
46 | else:
47 | dy_dx = None
48 |
49 | _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
50 |
51 | # permute back to [B, L * C]
52 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
53 |
54 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
55 | ctx.dims = [B, D, C, L, S, H, gridtype]
56 | ctx.align_corners = align_corners
57 |
58 | return outputs
59 |
60 | @staticmethod
61 | #@once_differentiable
62 | @custom_bwd
63 | def backward(ctx, grad):
64 |
65 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
66 | B, D, C, L, S, H, gridtype = ctx.dims
67 | align_corners = ctx.align_corners
68 |
69 | # grad: [B, L * C] --> [L, B, C]
70 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
71 |
72 | grad_embeddings = torch.zeros_like(embeddings)
73 |
74 | if dy_dx is not None:
75 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
76 | else:
77 | grad_inputs = None
78 |
79 | _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
80 |
81 | if dy_dx is not None:
82 | grad_inputs = grad_inputs.to(inputs.dtype)
83 |
84 | return grad_inputs, grad_embeddings, None, None, None, None, None, None
85 |
86 |
87 |
88 | grid_encode = _grid_encode.apply
89 |
90 |
91 | class GridEncoder(nn.Module):
92 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
93 | super().__init__()
94 |
95 | # the finest resolution desired at the last level; if provided, it overrides per_level_scale
96 | if desired_resolution is not None:
97 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
98 |
99 | self.input_dim = input_dim # coord dims, 2 or 3
100 | self.num_levels = num_levels # num levels, each level multiply resolution by 2
101 | self.level_dim = level_dim # encode channels per level
102 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
103 | self.log2_hashmap_size = log2_hashmap_size
104 | self.base_resolution = base_resolution
105 | self.output_dim = num_levels * level_dim
106 | self.gridtype = gridtype
107 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
108 | self.align_corners = align_corners
109 |
110 | # allocate parameters
111 | offsets = []
112 | offset = 0
113 | self.max_params = 2 ** log2_hashmap_size
114 | for i in range(num_levels):
115 | resolution = int(np.ceil(base_resolution * per_level_scale ** i))
116 | params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
117 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
118 | offsets.append(offset)
119 | offset += params_in_level
120 | offsets.append(offset)
121 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
122 | self.register_buffer('offsets', offsets)
123 |
124 | self.n_params = offsets[-1] * level_dim
125 |
126 | # parameters
127 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
128 |
129 | self.reset_parameters()
130 |
131 | def reset_parameters(self):
132 | std = 1e-4
133 | self.embeddings.data.uniform_(-std, std)
134 |
135 | def __repr__(self):
136 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
137 |
138 | def forward(self, inputs, bound=1):
139 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
140 | # return: [..., num_levels * level_dim]
141 |
142 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
143 |
144 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
145 |
146 | prefix_shape = list(inputs.shape[:-1])
147 | inputs = inputs.view(-1, self.input_dim)
148 |
149 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
150 | outputs = outputs.view(prefix_shape + [self.output_dim])
151 |
152 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
153 |
154 | return outputs
--------------------------------------------------------------------------------
/external/gridencoder/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | setup(
33 | name='gridencoder', # package name, import this to use python API
34 | ext_modules=[
35 | CUDAExtension(
36 | name='_gridencoder', # extension name, import this to use CUDA API
37 | sources=[os.path.join(_src_path, 'src', f) for f in [
38 | 'gridencoder.cu',
39 | 'bindings.cpp',
40 | ]],
41 | extra_compile_args={
42 | 'cxx': c_flags,
43 | 'nvcc': nvcc_flags,
44 | }
45 | ),
46 | ],
47 | cmdclass={
48 | 'build_ext': BuildExtension,
49 | }
50 | )
--------------------------------------------------------------------------------
/external/gridencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "gridencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
8 | }
--------------------------------------------------------------------------------
/external/gridencoder/src/gridencoder.h:
--------------------------------------------------------------------------------
1 | #ifndef _HASH_ENCODE_H
2 | #define _HASH_ENCODE_H
3 |
4 | #include <stdint.h>
5 | #include <torch/torch.h>
6 |
7 | // inputs: [B, D], float, in [0, 1]
8 | // embeddings: [sO, C], float
9 | // offsets: [L + 1], uint32_t
10 | // outputs: [B, L * C], float
11 | // H: base resolution
12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners);
13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners);
14 |
15 | #endif
--------------------------------------------------------------------------------
/external/ldm/configs/sd-vae.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: external.ldm.models.autoencoder.AutoencoderKL
3 | params:
4 | embed_dim: 4
5 | monitor: val/rec_loss
6 | ddconfig:
7 | double_z: true
8 | z_channels: 4
9 | resolution: 256
10 | in_channels: 3
11 | out_ch: 3
12 | ch: 128
13 | ch_mult:
14 | - 1
15 | - 2
16 | - 4
17 | - 4
18 | num_res_blocks: 2
19 | attn_resolutions: []
20 | dropout: 0.0
21 | lossconfig:
22 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/external/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from external.ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 | return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 | Finally, reshape to image
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
--------------------------------------------------------------------------------
/external/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/external/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from external.ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 | # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
/external/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/external/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
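A minimal usage sketch for the `DiagonalGaussianDistribution` above, assuming `torch` and `numpy` are installed and the repository root is on `PYTHONPATH`; the parameter tensor is random data, standing in for an encoder output whose channel dimension stacks mean and log-variance:

```python
import torch
from external.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# 8 channels = 4 mean channels + 4 log-variance channels
params = torch.randn(2, 8, 16, 16)
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()   # reparameterized sample, shape [2, 4, 16, 16]
kl = posterior.kl()      # KL to a standard normal, one value per batch element
nll = posterior.nll(z)   # negative log-likelihood of the sample, summed over non-batch dims
print(z.shape, kl.shape, nll.shape)
```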
/external/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
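A sketch of how `LitEma` above is typically driven from a training loop: update the shadow weights after every optimizer step, then temporarily swap them in for evaluation via `store`/`copy_to`/`restore`. The linear model and optimizer below are placeholders, not part of this repository:

```python
import torch
from external.ldm.modules.ema import LitEma

model = torch.nn.Linear(4, 4)              # placeholder model
ema = LitEma(model, decay=0.9999)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                             # update the EMA shadow parameters

# evaluate with EMA weights without losing the current ones
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation here ...
ema.restore(model.parameters())
```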
/external/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/external/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | from transformers import CLIPTokenizer, CLIPTextModel
7 | import kornia
8 |
9 | from external.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66 | tokens = batch_encoding["input_ids"].to(self.device)
67 | return tokens
68 |
69 | @torch.no_grad()
70 | def encode(self, text):
71 | tokens = self(text)
72 | if not self.vq_interface:
73 | return tokens
74 | return None, None, [None, None, tokens]
75 |
76 | def decode(self, text):
77 | return text
78 |
79 |
80 | class BERTEmbedder(AbstractEncoder):
81 |     """Uses the BERT tokenizer model and adds some transformer encoder layers"""
82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84 | super().__init__()
85 | self.use_tknz_fn = use_tokenizer
86 | if self.use_tknz_fn:
87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88 | self.device = device
89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
91 | emb_dropout=embedding_dropout)
92 |
93 | def forward(self, text):
94 | if self.use_tknz_fn:
95 | tokens = self.tknz_fn(text)#.to(self.device)
96 | else:
97 | tokens = text
98 | z = self.transformer(tokens, return_embeddings=True)
99 | return z
100 |
101 | def encode(self, text):
102 | # output of length 77
103 | return self(text)
104 |
105 |
106 | class SpatialRescaler(nn.Module):
107 | def __init__(self,
108 | n_stages=1,
109 | method='bilinear',
110 | multiplier=0.5,
111 | in_channels=3,
112 | out_channels=None,
113 | bias=False):
114 | super().__init__()
115 | self.n_stages = n_stages
116 | assert self.n_stages >= 0
117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118 | self.multiplier = multiplier
119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120 | self.remap_output = out_channels is not None
121 | if self.remap_output:
122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124 |
125 | def forward(self,x):
126 | for stage in range(self.n_stages):
127 | x = self.interpolator(x, scale_factor=self.multiplier)
128 |
129 |
130 | if self.remap_output:
131 | x = self.channel_mapper(x)
132 | return x
133 |
134 | def encode(self, x):
135 | return self(x)
136 |
137 | class FrozenCLIPEmbedder(AbstractEncoder):
138 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
140 | super().__init__()
141 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
142 | self.transformer = CLIPTextModel.from_pretrained(version)
143 | self.device = device
144 | self.max_length = max_length
145 | self.freeze()
146 |
147 | def freeze(self):
148 | self.transformer = self.transformer.eval()
149 | for param in self.parameters():
150 | param.requires_grad = False
151 |
152 | def forward(self, text):
153 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
155 | tokens = batch_encoding["input_ids"].to(self.device)
156 | outputs = self.transformer(input_ids=tokens)
157 |
158 | z = outputs.last_hidden_state
159 | return z
160 |
161 | def encode(self, text):
162 | return self(text)
163 |
164 |
165 | class FrozenCLIPTextEmbedder(nn.Module):
166 | """
167 | Uses the CLIP transformer encoder for text.
168 | """
169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
170 | super().__init__()
171 | self.model, _ = clip.load(version, jit=False, device="cpu")
172 | self.device = device
173 | self.max_length = max_length
174 | self.n_repeat = n_repeat
175 | self.normalize = normalize
176 |
177 | def freeze(self):
178 | self.model = self.model.eval()
179 | for param in self.parameters():
180 | param.requires_grad = False
181 |
182 | def forward(self, text):
183 | tokens = clip.tokenize(text).to(self.device)
184 | z = self.model.encode_text(tokens)
185 | if self.normalize:
186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
187 | return z
188 |
189 | def encode(self, text):
190 | z = self(text)
191 | if z.ndim==2:
192 | z = z[:, None, :]
193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
194 | return z
195 |
196 |
197 | class FrozenClipImageEmbedder(nn.Module):
198 | """
199 | Uses the CLIP image encoder.
200 | """
201 | def __init__(
202 | self,
203 | model,
204 | jit=False,
205 | device='cuda' if torch.cuda.is_available() else 'cpu',
206 | antialias=False,
207 | ):
208 | super().__init__()
209 | self.model, _ = clip.load(name=model, device=device, jit=jit)
210 |
211 | self.antialias = antialias
212 |
213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
215 |
216 | def preprocess(self, x):
217 | # normalize to [0,1]
218 | x = kornia.geometry.resize(x, (224, 224),
219 | interpolation='bicubic',align_corners=True,
220 | antialias=self.antialias)
221 | x = (x + 1.) / 2.
222 | # renormalize according to clip
223 | x = kornia.enhance.normalize(x, self.mean, self.std)
224 | return x
225 |
226 | def forward(self, x):
227 | # x is assumed to be in range [-1,1]
228 | return self.model.encode_image(self.preprocess(x))
229 |
230 |
231 | if __name__ == "__main__":
232 | from external.ldm.util import count_params
233 | model = FrozenCLIPEmbedder()
234 | count_params(model, verbose=True)
--------------------------------------------------------------------------------
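Most conditioners in this file wrap large pretrained models, but `SpatialRescaler` is self-contained and shows the shared `encode()` interface. A small sketch, assuming the module's imports (`clip`, `kornia`, `transformers`) are installed so the file itself can be imported:

```python
import torch
from external.ldm.modules.encoders.modules import SpatialRescaler

# downsample by 0.5 twice (n_stages=2) and remap 3 -> 8 channels with a 1x1 conv
rescaler = SpatialRescaler(n_stages=2, method='bilinear', multiplier=0.5,
                           in_channels=3, out_channels=8)
x = torch.randn(1, 3, 64, 64)
y = rescaler.encode(x)
print(y.shape)  # expected: torch.Size([1, 8, 16, 16])
```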
/external/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from external.ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from external.ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/external/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/external/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/external/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from external.ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/external/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
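The adaptive weight in `calculate_adaptive_weight` balances the reconstruction and adversarial terms by the ratio of their gradient norms at the decoder's last layer. A self-contained illustration of that computation on a toy layer (the two losses below are stand-ins, not the actual LPIPS/GAN terms):

```python
import torch

last_layer = torch.nn.Parameter(torch.randn(4, 4))   # stand-in for the decoder's last weight
x = torch.randn(8, 4)

nll_loss = (x @ last_layer).pow(2).mean()   # stand-in for the reconstruction/NLL term
g_loss = -(x @ last_layer).mean()           # stand-in for the generator (adversarial) term

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

# same formula as LPIPSWithDiscriminator.calculate_adaptive_weight
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
print(d_weight)
```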
/external/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
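`adopt_weight` gates the discriminator term off until `global_step` reaches the threshold, and `measure_perplexity` reports how evenly a batch of codebook indices uses the `n_embed` entries. A standalone check of both behaviours, restating the warm-up helper locally so the sketch does not require the `taming` package:

```python
import torch
import torch.nn.functional as F

def adopt_weight(weight, global_step, threshold=0, value=0.):
    # identical logic to the helper above: disable the weight before `threshold` steps
    return value if global_step < threshold else weight

print(adopt_weight(1.0, global_step=100, threshold=500))   # 0.0 (discriminator off)
print(adopt_weight(1.0, global_step=1000, threshold=500))  # 1.0 (discriminator on)

# perplexity of uniformly used codes approaches n_embed
n_embed = 16
idx = torch.arange(1024) % n_embed
enc = F.one_hot(idx, n_embed).float().reshape(-1, n_embed)
avg = enc.mean(0)
perplexity = (-(avg * torch.log(avg + 1e-10)).sum()).exp()
print(perplexity)  # close to 16 when all codes are used equally
```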
/external/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 |             print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
--------------------------------------------------------------------------------
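`instantiate_from_config` builds an object from a dict with a dotted `target` path and optional `params`; this is how the YAML configs in this repository specify modules. A minimal sketch using a standard PyTorch class as the target (any importable dotted path works), assuming the repo root is on `PYTHONPATH`:

```python
from external.ldm.util import instantiate_from_config, get_obj_from_str

config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
layer = instantiate_from_config(config)
print(type(layer))  # <class 'torch.nn.modules.conv.Conv2d'>

# get_obj_from_str resolves the class without instantiating it
print(get_obj_from_str("torch.nn.GELU"))
```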
/external/nerf/clip_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import torchvision.transforms as T
7 | import torchvision.transforms.functional as TF
8 |
9 | import clip
10 |
11 | class CLIPLoss:
12 | def __init__(self, device, name='ViT-B/16'):
13 | self.device = device
14 | self.name = name
15 | self.clip_model, self.transform_PIL = clip.load(self.name, device=self.device, jit=False)
16 |
17 | # disable training
18 | self.clip_model.eval()
19 | for p in self.clip_model.parameters():
20 | p.requires_grad = False
21 |
22 | # image augmentation
23 | self.transform = T.Compose([
24 | T.Resize((224, 224)),
25 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
26 | ])
27 |
28 | # placeholder
29 | self.text_zs = None
30 | self.image_zs = None
31 |
32 | def normalize(self, x):
33 | return x / x.norm(dim=-1, keepdim=True)
34 |
35 | # image-text (e.g., dreamfields)
36 | def prepare_text(self, texts):
37 | # texts: list of strings.
38 | texts = clip.tokenize(texts).to(self.device)
39 | self.text_zs = self.normalize(self.clip_model.encode_text(texts))
40 | print(f'[INFO] prepared CLIP text feature: {self.text_zs.shape}')
41 |
42 | def __call__(self, images, mode='text'):
43 |
44 | images = self.transform(images)
45 | image_zs = self.normalize(self.clip_model.encode_image(images))
46 |
47 | if mode == 'text':
48 | # if more than one string, randomly choose one.
49 | if self.text_zs.shape[0] > 1:
50 | idx = random.randint(0, self.text_zs.shape[0] - 1)
51 | text_zs = self.text_zs[[idx]]
52 | else:
53 | text_zs = self.text_zs
54 | # broadcast text_zs to all image_zs
55 | loss = - (image_zs * text_zs).sum(-1).mean()
56 | else:
57 | raise NotImplementedError
58 |
59 | return loss
60 |
61 | # image-image (e.g., diet-nerf)
62 | def prepare_image(self, dataset):
63 | # images: a nerf dataset (we need both poses and images!)
64 | pass
--------------------------------------------------------------------------------
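Typical use of `CLIPLoss` above: prepare text features once, then score rendered images against them. This sketch assumes the `clip` package is installed (ViT-B/16 weights download on first use); the image batch is random data in [0, 1] standing in for rendered views:

```python
import torch
from external.nerf.clip_utils import CLIPLoss

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_loss = CLIPLoss(device, name='ViT-B/16')
clip_loss.prepare_text(['a photo of a teddy bear'])

images = torch.rand(4, 3, 256, 256, device=device)  # stand-in for rendered views
loss = clip_loss(images, mode='text')
print(loss.item())  # negative cosine similarity between image and text features
```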
/external/nerf/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from encoding import get_encoder
6 | from activation import trunc_exp
7 | from .renderer import NeRFRenderer
8 |
9 |
10 | class NeRFNetwork(NeRFRenderer):
11 | def __init__(self,
12 | encoding="hashgrid",
13 | encoding_dir="sphere_harmonics",
14 | encoding_bg="hashgrid",
15 | num_layers=5,
16 | hidden_dim=64,
17 | geo_feat_dim=15,
18 | num_layers_color=3,
19 | hidden_dim_color=64,
20 | num_layers_bg=2,
21 | hidden_dim_bg=64,
22 | bound=1,
23 | **kwargs,
24 | ):
25 | super().__init__(bound, **kwargs)
26 |
27 | # sigma network
28 | self.num_layers = num_layers
29 | self.hidden_dim = hidden_dim
30 | self.geo_feat_dim = geo_feat_dim
31 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound)
32 |
33 | sigma_net = []
34 | for l in range(num_layers):
35 | if l == 0:
36 | in_dim = self.in_dim
37 | else:
38 | in_dim = hidden_dim
39 |
40 | if l == num_layers - 1:
41 | out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color
42 | else:
43 | out_dim = hidden_dim
44 |
45 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))
46 |
47 | self.sigma_net = nn.ModuleList(sigma_net)
48 |
49 | # color network
50 | self.num_layers_color = num_layers_color
51 | self.hidden_dim_color = hidden_dim_color
52 | self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir)
53 |
54 | color_net = []
55 | for l in range(num_layers_color):
56 | if l == 0:
57 | in_dim = self.in_dim_dir + self.geo_feat_dim
58 | else:
59 | in_dim = hidden_dim_color
60 |
61 | if l == num_layers_color - 1:
62 | out_dim = 3 # 3 rgb
63 | else:
64 | out_dim = hidden_dim_color
65 |
66 | color_net.append(nn.Linear(in_dim, out_dim, bias=False))
67 |
68 | self.color_net = nn.ModuleList(color_net)
69 |
70 | # background network
71 | if self.bg_radius > 0:
72 | self.num_layers_bg = num_layers_bg
73 | self.hidden_dim_bg = hidden_dim_bg
74 | self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid
75 |
76 | bg_net = []
77 | for l in range(num_layers_bg):
78 | if l == 0:
79 | in_dim = self.in_dim_bg + self.in_dim_dir
80 | else:
81 | in_dim = hidden_dim_bg
82 |
83 | if l == num_layers_bg - 1:
84 | out_dim = 3 # 3 rgb
85 | else:
86 | out_dim = hidden_dim_bg
87 |
88 | bg_net.append(nn.Linear(in_dim, out_dim, bias=False))
89 |
90 | self.bg_net = nn.ModuleList(bg_net)
91 | else:
92 | self.bg_net = None
93 |
94 |
95 | def forward(self, x, d):
96 | # x: [N, 3], in [-bound, bound]
97 |         # d: [N, 3], normalized in [-1, 1]
98 |
99 | # sigma
100 | x = self.encoder(x, bound=self.bound)
101 |
102 | h = x
103 | for l in range(self.num_layers):
104 | h = self.sigma_net[l](h)
105 | if l != self.num_layers - 1:
106 | h = F.relu(h, inplace=True)
107 |
108 | #sigma = F.relu(h[..., 0])
109 | sigma = trunc_exp(h[..., 0])
110 | geo_feat = h[..., 1:]
111 |
112 | # color
113 |
114 | d = self.encoder_dir(d)
115 | h = torch.cat([d, geo_feat], dim=-1)
116 | for l in range(self.num_layers_color):
117 | h = self.color_net[l](h)
118 | if l != self.num_layers_color - 1:
119 | h = F.relu(h, inplace=True)
120 |
121 | # sigmoid activation for rgb
122 | color = torch.sigmoid(h)
123 |
124 | return sigma, color
125 |
126 | def density(self, x):
127 | # x: [N, 3], in [-bound, bound]
128 |
129 | x = self.encoder(x, bound=self.bound)
130 | h = x
131 | for l in range(self.num_layers):
132 | h = self.sigma_net[l](h)
133 | if l != self.num_layers - 1:
134 | h = F.relu(h, inplace=True)
135 |
136 | #sigma = F.relu(h[..., 0])
137 | sigma = trunc_exp(h[..., 0])
138 | geo_feat = h[..., 1:]
139 |
140 | return {
141 | 'sigma': sigma,
142 | 'geo_feat': geo_feat,
143 | }
144 |
145 | def background(self, x, d):
146 | # x: [N, 2], in [-1, 1]
147 |
148 | h = self.encoder_bg(x) # [N, C]
149 | d = self.encoder_dir(d)
150 |
151 | h = torch.cat([d, h], dim=-1)
152 | for l in range(self.num_layers_bg):
153 | h = self.bg_net[l](h)
154 | if l != self.num_layers_bg - 1:
155 | h = F.relu(h, inplace=True)
156 |
157 | # sigmoid activation for rgb
158 | rgbs = torch.sigmoid(h)
159 |
160 | return rgbs
161 |
162 | # allow masked inference
163 | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
164 | # x: [N, 3] in [-bound, bound]
165 |         # mask: [N,], bool, indicates where we actually need to compute rgb.
166 |
167 | if mask is not None:
168 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
169 | # in case of empty mask
170 | if not mask.any():
171 | return rgbs
172 | x = x[mask]
173 | d = d[mask]
174 | geo_feat = geo_feat[mask]
175 |
176 | d = self.encoder_dir(d)
177 | h = torch.cat([d, geo_feat], dim=-1)
178 | for l in range(self.num_layers_color):
179 | h = self.color_net[l](h)
180 | if l != self.num_layers_color - 1:
181 | h = F.relu(h, inplace=True)
182 |
183 | # sigmoid activation for rgb
184 | h = torch.sigmoid(h)
185 |
186 | if mask is not None:
187 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
188 | else:
189 | rgbs = h
190 |
191 | return rgbs
192 |
193 | # optimizer utils
194 | def get_params(self, lr):
195 |
196 | params = [
197 | {'params': self.encoder.parameters(), 'lr': lr},
198 | {'params': self.sigma_net.parameters(), 'lr': lr},
199 | {'params': self.encoder_dir.parameters(), 'lr': lr},
200 | {'params': self.color_net.parameters(), 'lr': lr},
201 | ]
202 | if self.bg_radius > 0:
203 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr})
204 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
205 |
206 | return params
207 |
--------------------------------------------------------------------------------
/external/nerf/network_df.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from external.ngp_activation import trunc_exp
6 | from external.nerf.renderer_df import NeRFRenderer
7 |
8 | import numpy as np
9 | from external.ngp_encoder import get_encoder
10 |
11 | from .utils import safe_normalize
12 |
13 | class MLP(nn.Module):
14 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
15 | super().__init__()
16 | self.dim_in = dim_in
17 | self.dim_out = dim_out
18 | self.dim_hidden = dim_hidden
19 | self.num_layers = num_layers
20 |
21 | net = []
22 | for l in range(num_layers):
23 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
24 |
25 | self.net = nn.ModuleList(net)
26 |
27 | def forward(self, x):
28 | for l in range(self.num_layers):
29 | x = self.net[l](x)
30 | if l != self.num_layers - 1:
31 | x = F.relu(x, inplace=True)
32 | return x
33 |
34 |
35 | class NeRFNetwork(NeRFRenderer):
36 | def __init__(self,
37 | opt,
38 | num_layers=5,
39 | hidden_dim=128,
40 | num_layers_bg=2,
41 | hidden_dim_bg=64,
42 | ):
43 |
44 | super().__init__(opt)
45 |
46 | self.num_layers = num_layers
47 | self.hidden_dim = hidden_dim
48 | self.encoder, self.in_dim = get_encoder('frequency', input_dim=3)
49 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
50 |
51 | # background network
52 | if self.bg_radius > 0:
53 | self.num_layers_bg = num_layers_bg
54 | self.hidden_dim_bg = hidden_dim_bg
55 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
56 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
57 |
58 | else:
59 | self.bg_net = None
60 |
61 | def gaussian(self, x):
62 | # x: [B, N, 3]
63 |
64 | d = (x ** 2).sum(-1)
65 | g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
66 |
67 | return g
68 |
69 | def common_forward(self, x):
70 | # x: [N, 3], in [-bound, bound]
71 |
72 | # sigma
73 | h = self.encoder(x, bound=self.bound)
74 |
75 | h = self.sigma_net(h)
76 |
77 | sigma = trunc_exp(h[..., 0] + self.gaussian(x))
78 | albedo = torch.sigmoid(h[..., 1:])
79 |
80 | return sigma, albedo
81 |
82 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
83 | def finite_difference_normal(self, x, epsilon=1e-2):
84 | # x: [N, 3]
85 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
86 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
87 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
88 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
89 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
90 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
91 |
92 | normal = torch.stack([
93 | 0.5 * (dx_pos - dx_neg) / epsilon,
94 | 0.5 * (dy_pos - dy_neg) / epsilon,
95 | 0.5 * (dz_pos - dz_neg) / epsilon
96 | ], dim=-1)
97 |
98 | return normal
99 |
100 | def normal(self, x):
101 |
102 | with torch.enable_grad():
103 | x.requires_grad_(True)
104 | sigma, albedo = self.common_forward(x)
105 | # query gradient
106 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
107 |
108 | # normalize...
109 | normal = safe_normalize(normal)
110 | normal[torch.isnan(normal)] = 0
111 | return normal
112 |
113 | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
114 | # x: [N, 3], in [-bound, bound]
115 |         # d: [N, 3], view direction, normalized in [-1, 1]
116 |         # l: [3], plane light direction, normalized in [-1, 1]
117 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
118 |
119 | if shading == 'albedo':
120 | # no need to query normal
121 | sigma, color = self.common_forward(x)
122 | normal = None
123 |
124 | else:
125 | # query normal
126 |
127 | # sigma, albedo = self.common_forward(x)
128 | # normal = self.finite_difference_normal(x)
129 |
130 | with torch.enable_grad():
131 | x.requires_grad_(True)
132 | sigma, albedo = self.common_forward(x)
133 | # query gradient
134 | normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
135 |
136 | # normalize...
137 | normal = safe_normalize(normal)
138 | normal[torch.isnan(normal)] = 0
139 |
140 | # lambertian shading
141 | lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
142 |
143 | if shading == 'textureless':
144 | color = lambertian.unsqueeze(-1).repeat(1, 3)
145 | elif shading == 'normal':
146 | color = (normal + 1) / 2
147 | else: # 'lambertian'
148 | color = albedo * lambertian.unsqueeze(-1)
149 |
150 | return sigma, color, normal
151 |
152 |
153 | def density(self, x):
154 | # x: [N, 3], in [-bound, bound]
155 |
156 | sigma, albedo = self.common_forward(x)
157 |
158 | return {
159 | 'sigma': sigma,
160 | 'albedo': albedo,
161 | }
162 |
163 |
164 | def background(self, d):
165 |
166 | h = self.encoder_bg(d) # [N, C]
167 |
168 | h = self.bg_net(h)
169 |
170 | # sigmoid activation for rgb
171 | rgbs = torch.sigmoid(h)
172 |
173 | return rgbs
174 |
175 | # optimizer utils
176 | def get_params(self, lr):
177 |
178 | params = [
179 | # {'params': self.encoder.parameters(), 'lr': lr * 10},
180 | {'params': self.sigma_net.parameters(), 'lr': lr},
181 | ]
182 |
183 | if self.bg_radius > 0:
184 | # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
185 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
186 |
187 | return params
--------------------------------------------------------------------------------
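The shading branch of `forward` mixes an ambient term with a Lambertian term, `lambertian = ratio + (1 - ratio) * max(0, n @ -l)`, and then selects the output colour by mode. A standalone sketch of that arithmetic with made-up normals, albedo, and light direction (torch only):

```python
import torch
import torch.nn.functional as F

N = 4
normal = F.normalize(torch.randn(N, 3), dim=-1)   # unit surface normals
albedo = torch.rand(N, 3)                         # per-point base colour
l = F.normalize(torch.randn(3), dim=0)            # light direction
ratio = 0.1                                       # ambient ratio

lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0)   # [N]

color_textureless = lambertian.unsqueeze(-1).repeat(1, 3)   # grey shading only
color_normal = (normal + 1) / 2                             # normals visualized as RGB
color_lambertian = albedo * lambertian.unsqueeze(-1)        # shaded albedo
print(color_textureless.shape, color_normal.shape, color_lambertian.shape)
```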
/external/nerf/network_ff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from encoding import get_encoder
6 | from activation import trunc_exp
7 | from ffmlp import FFMLP
8 |
9 | from .renderer import NeRFRenderer
10 |
11 | class NeRFNetwork(NeRFRenderer):
12 | def __init__(self,
13 | encoding="hashgrid",
14 | encoding_dir="sphere_harmonics",
15 | num_layers=2,
16 | hidden_dim=64,
17 | geo_feat_dim=15,
18 | num_layers_color=3,
19 | hidden_dim_color=64,
20 | bound=1,
21 | **kwargs
22 | ):
23 | super().__init__(bound, **kwargs)
24 |
25 | # sigma network
26 | self.num_layers = num_layers
27 | self.hidden_dim = hidden_dim
28 | self.geo_feat_dim = geo_feat_dim
29 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound)
30 |
31 | self.sigma_net = FFMLP(
32 | input_dim=self.in_dim,
33 | output_dim=1 + self.geo_feat_dim,
34 | hidden_dim=self.hidden_dim,
35 | num_layers=self.num_layers,
36 | )
37 |
38 | # color network
39 | self.num_layers_color = num_layers_color
40 | self.hidden_dim_color = hidden_dim_color
41 | self.encoder_dir, self.in_dim_color = get_encoder(encoding_dir)
42 | self.in_dim_color += self.geo_feat_dim + 1 # a manual fixing to make it 32, as done in nerf_network.h#178
43 |
44 | self.color_net = FFMLP(
45 | input_dim=self.in_dim_color,
46 | output_dim=3,
47 | hidden_dim=self.hidden_dim_color,
48 | num_layers=self.num_layers_color,
49 | )
50 |
51 | def forward(self, x, d):
52 | # x: [N, 3], in [-bound, bound]
53 |         # d: [N, 3], normalized in [-1, 1]
54 |
55 | # sigma
56 | x = self.encoder(x, bound=self.bound)
57 | h = self.sigma_net(x)
58 |
59 | #sigma = F.relu(h[..., 0])
60 | sigma = trunc_exp(h[..., 0])
61 | geo_feat = h[..., 1:]
62 |
63 | # color
64 | d = self.encoder_dir(d)
65 |
66 | # TODO: preallocate space and avoid this cat?
67 | p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
68 | h = torch.cat([d, geo_feat, p], dim=-1)
69 | h = self.color_net(h)
70 |
71 | # sigmoid activation for rgb
72 | rgb = torch.sigmoid(h)
73 |
74 | return sigma, rgb
75 |
76 | def density(self, x):
77 | # x: [N, 3], in [-bound, bound]
78 |
79 | x = self.encoder(x, bound=self.bound)
80 | h = self.sigma_net(x)
81 |
82 | #sigma = F.relu(h[..., 0])
83 | sigma = trunc_exp(h[..., 0])
84 | geo_feat = h[..., 1:]
85 |
86 | return {
87 | 'sigma': sigma,
88 | 'geo_feat': geo_feat,
89 | }
90 |
91 | # allow masked inference
92 | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
93 | # x: [N, 3] in [-bound, bound]
94 |         # mask: [N,], bool, indicates where we actually need to compute rgb.
95 |
96 | #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
97 | #starter.record()
98 |
99 | if mask is not None:
100 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
101 | # in case of empty mask
102 | if not mask.any():
103 | return rgbs
104 | x = x[mask]
105 | d = d[mask]
106 | geo_feat = geo_feat[mask]
107 |
108 | #print(x.shape, rgbs.shape)
109 |
110 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'mask = {curr_time}')
111 | #starter.record()
112 |
113 | d = self.encoder_dir(d)
114 |
115 | p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
116 | h = torch.cat([d, geo_feat, p], dim=-1)
117 |
118 | h = self.color_net(h)
119 |
120 | # sigmoid activation for rgb
121 | h = torch.sigmoid(h)
122 |
123 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'call = {curr_time}')
124 | #starter.record()
125 |
126 | if mask is not None:
127 | rgbs[mask] = h.to(rgbs.dtype)
128 | else:
129 | rgbs = h
130 |
131 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'unmask = {curr_time}')
132 | #starter.record()
133 |
134 | return rgbs
135 |
136 | # optimizer utils
137 | def get_params(self, lr):
138 |
139 | params = [
140 | {'params': self.encoder.parameters(), 'lr': lr},
141 | {'params': self.sigma_net.parameters(), 'lr': lr},
142 | {'params': self.encoder_dir.parameters(), 'lr': lr},
143 | {'params': self.color_net.parameters(), 'lr': lr},
144 | ]
145 | if self.bg_radius > 0:
146 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr})
147 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
148 |
149 | return params
--------------------------------------------------------------------------------
/external/nerf/network_grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from external.ngp_activation import trunc_exp
6 | from external.nerf.renderer_df import NeRFRenderer
7 |
8 | import numpy as np
9 | from external.ngp_encoder import get_encoder
10 |
11 | from .utils import safe_normalize
12 |
13 |
14 | class MLP(nn.Module):
15 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
16 | super().__init__()
17 | self.dim_in = dim_in
18 | self.dim_out = dim_out
19 | self.dim_hidden = dim_hidden
20 | self.num_layers = num_layers
21 |
22 | net = []
23 | for l in range(num_layers):
24 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
25 |
26 | self.net = nn.ModuleList(net)
27 |
28 | def forward(self, x):
29 | for l in range(self.num_layers):
30 | x = self.net[l](x)
31 | if l != self.num_layers - 1:
32 | x = F.relu(x, inplace=True)
33 | return x
34 |
35 |
36 | class NeRFNetwork(NeRFRenderer):
37 | def __init__(self,
38 | opt,
39 | num_layers=3,
40 | hidden_dim=64,
41 | num_layers_bg=2,
42 | hidden_dim_bg=64,
43 | ):
44 |
45 | super().__init__(opt)
46 |
47 | self.num_layers = num_layers
48 | self.hidden_dim = hidden_dim
49 |
50 | self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, log2_hashmap_size=16, desired_resolution=2048 * self.bound)
51 |
52 | self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
53 |
54 | # background network
55 | if self.bg_radius > 0:
56 | self.num_layers_bg = num_layers_bg
57 | self.hidden_dim_bg = hidden_dim_bg
58 |
59 | # use a very simple network to avoid it learning the prompt...
60 | # self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048)
61 | self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
62 |
63 | self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
64 |
65 | else:
66 | self.bg_net = None
67 |
68 | # add a density blob to the scene center
69 | def gaussian(self, x):
70 | # x: [B, N, 3]
71 |
72 | d = (x ** 2).sum(-1)
73 | g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
74 |
75 | return g
76 |
77 | def common_forward(self, x):
78 | # x: [N, 3], in [-bound, bound]
79 |
80 | # sigma
81 | h = self.encoder(x, bound=self.bound)
82 |
83 | h = self.sigma_net(h)
84 |
85 | sigma = trunc_exp(h[..., 0] + self.gaussian(x))
86 | albedo = torch.sigmoid(h[..., 1:])
87 |
88 | return sigma, albedo
89 |
90 | # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
91 | def finite_difference_normal(self, x, epsilon=1e-2):
92 | # x: [N, 3]
93 | dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
94 | dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
95 | dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
96 | dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
97 | dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
98 | dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
99 |
100 | normal = torch.stack([
101 | 0.5 * (dx_pos - dx_neg) / epsilon,
102 | 0.5 * (dy_pos - dy_neg) / epsilon,
103 | 0.5 * (dz_pos - dz_neg) / epsilon
104 | ], dim=-1)
105 |
106 | return normal
107 |
108 | @torch.no_grad()
109 | def common_forward_smooth(self, x, radius=1e-2, k=5):
110 | # x: [N, 3], in [-bound, bound]
111 |
112 | x_repeat = x.unsqueeze(0).repeat(k, 1, 1)
113 | eps = torch.rand_like(x_repeat) * radius * 2 - radius
114 | x_sample = x_repeat + eps
115 |
116 | sigma_list, albedo_list = [], []
117 | for xi in range(len(x_sample)):
118 | # sigma
119 | h = self.encoder(x_sample[xi], bound=self.bound)
120 |
121 | h = self.sigma_net(h)
122 |
123 | sigma = trunc_exp(h[..., 0] + self.gaussian(x))
124 | albedo = torch.sigmoid(h[..., 1:])
125 |
126 | sigma_list.append(sigma)
127 | albedo_list.append(albedo)
128 | sigma = torch.stack(sigma_list, dim=0).mean(dim=0)
129 | albedo = torch.stack(albedo_list, dim=0).mean(dim=0)
130 | return sigma, albedo
131 |
132 | @torch.no_grad()
133 | def finite_difference_normal_smooth(self, x, epsilon=3e-1):
134 | # x: [N, 3]
135 | # eps 1e-2 | radius
136 | # eps 3e-1 | radius 3e-1
137 | radius = 3e-1
138 | k = 10
139 | dx_pos, _ = self.common_forward_smooth((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k)
140 | dx_neg, _ = self.common_forward_smooth((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k)
141 | dy_pos, _ = self.common_forward_smooth((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k)
142 | dy_neg, _ = self.common_forward_smooth((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k)
143 | dz_pos, _ = self.common_forward_smooth((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k)
144 | dz_neg, _ = self.common_forward_smooth((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound), radius=radius, k=k)
145 |
146 | normal = torch.stack([
147 | 0.5 * (dx_pos - dx_neg) / epsilon,
148 | 0.5 * (dy_pos - dy_neg) / epsilon,
149 | 0.5 * (dz_pos - dz_neg) / epsilon
150 | ], dim=-1)
151 |
152 | return normal
153 |
154 |
155 | def normal(self, x, smooth=False):
156 |
157 | if smooth:
158 | normal = self.finite_difference_normal_smooth(x)
159 | else:
160 | normal = self.finite_difference_normal(x)
161 | normal = safe_normalize(normal)
162 | normal[torch.isnan(normal)] = 0
163 |
164 | return normal
165 |
166 |
167 | def forward(self, x, d, l=None, ratio=1, shading='textureless'):
168 | # x: [N, 3], in [-bound, bound]
169 |         # d: [N, 3], view direction, normalized in [-1, 1]
170 |         # l: [3], plane light direction, normalized in [-1, 1]
171 | # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
172 |
173 | if shading == 'albedo':
174 | # no need to query normal
175 | sigma, color = self.common_forward(x)
176 | normal = None
177 |
178 | else:
179 | # query normal
180 |
181 | sigma, albedo = self.common_forward(x)
182 | if shading == 'textureless' or shading == 'normal':
183 | normal = self.normal(x, smooth=True)
184 | else:
185 | normal = self.normal(x)
186 |
187 | # lambertian shading
188 | lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
189 |
190 | if shading == 'textureless':
191 | color = lambertian.unsqueeze(-1).repeat(1, 3)*0.8 + .2
192 | elif shading == 'normal':
193 | color = (normal + 1) / 2
194 | else: # 'lambertian'
195 | color = albedo * lambertian.unsqueeze(-1)
196 |
197 | return sigma, color, normal
198 |
199 |
200 | def density(self, x):
201 | # x: [N, 3], in [-bound, bound]
202 |
203 | sigma, albedo = self.common_forward(x)
204 |
205 | return {
206 | 'sigma': sigma,
207 | 'albedo': albedo,
208 | }
209 |
210 |
211 | def background(self, d):
212 |
213 | h = self.encoder_bg(d) # [N, C]
214 |
215 | h = self.bg_net(h)
216 |
217 | # sigmoid activation for rgb
218 | rgbs = torch.sigmoid(h)
219 |
220 | return rgbs
221 |
222 | # optimizer utils
223 | def get_params(self, lr):
224 |
225 | params = [
226 | {'params': self.encoder.parameters(), 'lr': lr * 10},
227 | {'params': self.sigma_net.parameters(), 'lr': lr},
228 | ]
229 |
230 | if self.bg_radius > 0:
231 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
232 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
233 |
234 | return params
--------------------------------------------------------------------------------
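`finite_difference_normal` above approximates the gradient of the density field with central differences along each axis. A minimal standalone sketch (not part of the repository; the toy Gaussian density and helper names are illustrative only) compares the same estimator against an analytic gradient:

```python
import torch

def toy_density(x):
    # isotropic Gaussian blob centered at the origin; its gradient is known in closed form
    return torch.exp(-(x ** 2).sum(-1) / (2 * 0.3 ** 2))

def finite_difference_grad(f, x, epsilon=1e-2):
    # x: [N, 3] -> central-difference gradient of f at x, [N, 3]
    grads = []
    for i in range(3):
        offset = torch.zeros(1, 3)
        offset[0, i] = epsilon
        grads.append(0.5 * (f(x + offset) - f(x - offset)) / epsilon)
    return torch.stack(grads, dim=-1)

x = torch.randn(8, 3) * 0.5
numeric = finite_difference_grad(toy_density, x)
analytic = -x / 0.3 ** 2 * toy_density(x).unsqueeze(-1)
print((numeric - analytic).abs().max())  # O(epsilon^2) error
```

The smoothed variant above applies the same estimate after averaging the network output over `k` jittered copies of the query points, trading extra forward passes for less noisy normals.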
/external/nerf/network_tcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
7 | import tinycudann as tcnn
8 | from activation import trunc_exp
9 | from .renderer import NeRFRenderer
10 |
11 |
12 | class NeRFNetwork(NeRFRenderer):
13 | def __init__(self,
14 | encoding="HashGrid",
15 | encoding_dir="SphericalHarmonics",
16 | num_layers=2,
17 | hidden_dim=64,
18 | geo_feat_dim=15,
19 | num_layers_color=3,
20 | hidden_dim_color=64,
21 | bound=1,
22 | **kwargs
23 | ):
24 | super().__init__(bound, **kwargs)
25 |
26 | # sigma network
27 | self.num_layers = num_layers
28 | self.hidden_dim = hidden_dim
29 | self.geo_feat_dim = geo_feat_dim
30 |
31 | per_level_scale = np.exp2(np.log2(2048 * bound / 16) / (16 - 1))
32 |
33 | self.encoder = tcnn.Encoding(
34 | n_input_dims=3,
35 | encoding_config={
36 | "otype": "HashGrid",
37 | "n_levels": 16,
38 | "n_features_per_level": 2,
39 | "log2_hashmap_size": 19,
40 | "base_resolution": 16,
41 | "per_level_scale": per_level_scale,
42 | },
43 | )
44 |
45 | self.sigma_net = tcnn.Network(
46 | n_input_dims=32,
47 | n_output_dims=1 + self.geo_feat_dim,
48 | network_config={
49 | "otype": "FullyFusedMLP",
50 | "activation": "ReLU",
51 | "output_activation": "None",
52 | "n_neurons": hidden_dim,
53 | "n_hidden_layers": num_layers - 1,
54 | },
55 | )
56 |
57 | # color network
58 | self.num_layers_color = num_layers_color
59 | self.hidden_dim_color = hidden_dim_color
60 |
61 | self.encoder_dir = tcnn.Encoding(
62 | n_input_dims=3,
63 | encoding_config={
64 | "otype": "SphericalHarmonics",
65 | "degree": 4,
66 | },
67 | )
68 |
69 | self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim
70 |
71 | self.color_net = tcnn.Network(
72 | n_input_dims=self.in_dim_color,
73 | n_output_dims=3,
74 | network_config={
75 | "otype": "FullyFusedMLP",
76 | "activation": "ReLU",
77 | "output_activation": "None",
78 | "n_neurons": hidden_dim_color,
79 | "n_hidden_layers": num_layers_color - 1,
80 | },
81 | )
82 |
83 |
84 | def forward(self, x, d):
85 | # x: [N, 3], in [-bound, bound]
86 |         # d: [N, 3], normalized in [-1, 1]
87 |
88 |
89 | # sigma
90 | x = (x + self.bound) / (2 * self.bound) # to [0, 1]
91 | x = self.encoder(x)
92 | h = self.sigma_net(x)
93 |
94 | #sigma = F.relu(h[..., 0])
95 | sigma = trunc_exp(h[..., 0])
96 | geo_feat = h[..., 1:]
97 |
98 | # color
99 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
100 | d = self.encoder_dir(d)
101 |
102 | #p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
103 | h = torch.cat([d, geo_feat], dim=-1)
104 | h = self.color_net(h)
105 |
106 | # sigmoid activation for rgb
107 | color = torch.sigmoid(h)
108 |
109 | return sigma, color
110 |
111 | def density(self, x):
112 | # x: [N, 3], in [-bound, bound]
113 |
114 | x = (x + self.bound) / (2 * self.bound) # to [0, 1]
115 | x = self.encoder(x)
116 | h = self.sigma_net(x)
117 |
118 | #sigma = F.relu(h[..., 0])
119 | sigma = trunc_exp(h[..., 0])
120 | geo_feat = h[..., 1:]
121 |
122 | return {
123 | 'sigma': sigma,
124 | 'geo_feat': geo_feat,
125 | }
126 |
127 | # allow masked inference
128 | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
129 | # x: [N, 3] in [-bound, bound]
130 |         # mask: [N,], bool, indicates where we actually need to compute rgb.
131 |
132 | x = (x + self.bound) / (2 * self.bound) # to [0, 1]
133 |
134 | if mask is not None:
135 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
136 | # in case of empty mask
137 | if not mask.any():
138 | return rgbs
139 | x = x[mask]
140 | d = d[mask]
141 | geo_feat = geo_feat[mask]
142 |
143 | # color
144 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
145 | d = self.encoder_dir(d)
146 |
147 | h = torch.cat([d, geo_feat], dim=-1)
148 | h = self.color_net(h)
149 |
150 | # sigmoid activation for rgb
151 | h = torch.sigmoid(h)
152 |
153 | if mask is not None:
154 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
155 | else:
156 | rgbs = h
157 |
158 | return rgbs
159 |
160 | # optimizer utils
161 | def get_params(self, lr):
162 |
163 | params = [
164 | {'params': self.encoder.parameters(), 'lr': lr},
165 | {'params': self.sigma_net.parameters(), 'lr': lr},
166 | {'params': self.encoder_dir.parameters(), 'lr': lr},
167 | {'params': self.color_net.parameters(), 'lr': lr},
168 | ]
169 | if self.bg_radius > 0:
170 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr})
171 | params.append({'params': self.bg_net.parameters(), 'lr': lr})
172 |
173 | return params
--------------------------------------------------------------------------------
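The `per_level_scale` expression above spaces the 16 hash-grid levels geometrically so that the finest level reaches roughly `2048 * bound`, starting from a base resolution of 16. A small standalone sketch (same constants as above) makes the resulting resolution ladder explicit:

```python
import numpy as np

bound = 1
n_levels, base_res, target_res = 16, 16, 2048 * bound
per_level_scale = np.exp2(np.log2(target_res / base_res) / (n_levels - 1))
resolutions = [int(round(base_res * per_level_scale ** i)) for i in range(n_levels)]
print(per_level_scale)                  # ~1.3819 for bound=1
print(resolutions[0], resolutions[-1])  # 16 ... 2048
```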
/external/nerf/provider.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import json
5 | from cv2 import transform
6 | import tqdm
7 | import numpy as np
8 | from scipy.spatial.transform import Slerp, Rotation
9 |
10 | import trimesh
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 |
15 | from .utils import get_rays
16 |
17 |
18 | # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
19 | def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]):
20 | # for the fox dataset, 0.33 scales camera radius to ~ 2
21 | new_pose = np.array([
22 | [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
23 | [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
24 | [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
25 | [0, 0, 0, 1],
26 | ], dtype=np.float32)
27 | return new_pose
28 |
29 |
30 | def visualize_poses(poses, size=0.1):
31 | # poses: [B, 4, 4]
32 |
33 | axes = trimesh.creation.axis(axis_length=4)
34 | box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
35 | box.colors = np.array([[128, 128, 128]] * len(box.entities))
36 | objects = [axes, box]
37 |
38 | for pose in poses:
39 | # a camera is visualized with 8 line segments.
40 | pos = pose[:3, 3]
41 | a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
42 | b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
43 | c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
44 | d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
45 |
46 | dir = (a + b + c + d) / 4 - pos
47 | dir = dir / (np.linalg.norm(dir) + 1e-8)
48 | o = pos + dir * 3
49 |
50 | segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]])
51 | segs = trimesh.load_path(segs)
52 | objects.append(segs)
53 |
54 | trimesh.Scene(objects).show()
55 |
56 |
57 | def rand_poses(size, device, radius=1, theta_range=[np.pi/3, 2*np.pi/3], phi_range=[0, 2*np.pi]):
58 | ''' generate random poses from an orbit camera
59 | Args:
60 | size: batch size of generated poses.
61 | device: where to allocate the output.
62 | radius: camera radius
63 | theta_range: [min, max], should be in [0, \pi]
64 | phi_range: [min, max], should be in [0, 2\pi]
65 | Return:
66 | poses: [size, 4, 4]
67 | '''
68 |
69 | def normalize(vectors):
70 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
71 |
72 | thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
73 | phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
74 |
75 | centers = torch.stack([
76 | radius * torch.sin(thetas) * torch.sin(phis),
77 | radius * torch.cos(thetas),
78 | radius * torch.sin(thetas) * torch.cos(phis),
79 | ], dim=-1) # [B, 3]
80 |
81 | # lookat
82 | forward_vector = - normalize(centers)
83 | up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) # confused at the coordinate system...
84 | right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))
85 | up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))
86 |
87 | poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
88 | poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
89 | poses[:, :3, 3] = centers
90 |
91 | return poses
92 |
93 |
94 | class NeRFDataset:
95 | def __init__(self, opt, device, type='train', downscale=1, n_test=10):
96 | super().__init__()
97 |
98 | self.opt = opt
99 | self.device = device
100 | self.type = type # train, val, test
101 | self.downscale = downscale
102 | self.root_path = opt.path
103 | self.preload = opt.preload # preload data into GPU
104 | self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box.
105 | self.offset = opt.offset # camera offset
106 | self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses.
107 | self.fp16 = opt.fp16 # if preload, load into fp16.
108 |
109 | self.training = self.type in ['train', 'all', 'trainval']
110 | self.num_rays = self.opt.num_rays if self.training else -1
111 |
112 | self.rand_pose = opt.rand_pose
113 |
114 | # auto-detect transforms.json and split mode.
115 | if os.path.exists(os.path.join(self.root_path, 'transforms.json')):
116 | self.mode = 'colmap' # manually split, use view-interpolation for test.
117 | elif os.path.exists(os.path.join(self.root_path, 'transforms_train.json')):
118 | self.mode = 'blender' # provided split
119 | else:
120 | raise NotImplementedError(f'[NeRFDataset] Cannot find transforms*.json under {self.root_path}')
121 |
122 | # load nerf-compatible format data.
123 | if self.mode == 'colmap':
124 | with open(os.path.join(self.root_path, 'transforms.json'), 'r') as f:
125 | transform = json.load(f)
126 | elif self.mode == 'blender':
127 | # load all splits (train/valid/test), this is what instant-ngp in fact does...
128 | if type == 'all':
129 | transform_paths = glob.glob(os.path.join(self.root_path, '*.json'))
130 | transform = None
131 | for transform_path in transform_paths:
132 | with open(transform_path, 'r') as f:
133 | tmp_transform = json.load(f)
134 | if transform is None:
135 | transform = tmp_transform
136 | else:
137 | transform['frames'].extend(tmp_transform['frames'])
138 | # load train and val split
139 | elif type == 'trainval':
140 | with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f:
141 | transform = json.load(f)
142 | with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f:
143 | transform_val = json.load(f)
144 | transform['frames'].extend(transform_val['frames'])
145 | # only load one specified split
146 | else:
147 | with open(os.path.join(self.root_path, f'transforms_{type}.json'), 'r') as f:
148 | transform = json.load(f)
149 |
150 | else:
151 | raise NotImplementedError(f'unknown dataset mode: {self.mode}')
152 |
153 | # load image size
154 | if 'h' in transform and 'w' in transform:
155 | self.H = int(transform['h']) // downscale
156 | self.W = int(transform['w']) // downscale
157 | else:
158 | # we have to actually read an image to get H and W later.
159 | self.H = self.W = None
160 |
161 | # read images
162 | frames = transform["frames"]
163 | #frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort...
164 |
165 | # for colmap, manually interpolate a test set.
166 | if self.mode == 'colmap' and type == 'test':
167 |
168 | # choose two random poses, and interpolate between.
169 | f0, f1 = np.random.choice(frames, 2, replace=False)
170 | pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
171 | pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
172 | rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]]))
173 | slerp = Slerp([0, 1], rots)
174 |
175 | self.poses = []
176 | self.images = None
177 | for i in range(n_test + 1):
178 | ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5
179 | pose = np.eye(4, dtype=np.float32)
180 | pose[:3, :3] = slerp(ratio).as_matrix()
181 | pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3]
182 | self.poses.append(pose)
183 |
184 | else:
185 | # for colmap, manually split a valid set (the first frame).
186 | if self.mode == 'colmap':
187 | if type == 'train':
188 | frames = frames[1:]
189 | elif type == 'val':
190 | frames = frames[:1]
191 | # else 'all' or 'trainval' : use all frames
192 |
193 | self.poses = []
194 | self.images = []
195 | for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):
196 | f_path = os.path.join(self.root_path, f['file_path'])
197 | if self.mode == 'blender' and '.' not in os.path.basename(f_path):
198 | f_path += '.png' # so silly...
199 |
200 | # there are non-exist paths in fox...
201 | if not os.path.exists(f_path):
202 | continue
203 |
204 | pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4]
205 | pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
206 |
207 |                 image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] or [H, W, 4]
208 | if self.H is None or self.W is None:
209 | self.H = image.shape[0] // downscale
210 | self.W = image.shape[1] // downscale
211 |
212 | # add support for the alpha channel as a mask.
213 | if image.shape[-1] == 3:
214 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
215 | else:
216 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
217 |
218 | if image.shape[0] != self.H or image.shape[1] != self.W:
219 | image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA)
220 |
221 | image = image.astype(np.float32) / 255 # [H, W, 3/4]
222 |
223 | self.poses.append(pose)
224 | self.images.append(image)
225 |
226 | self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4]
227 | if self.images is not None:
228 | self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C]
229 |
230 | # calculate mean radius of all camera poses
231 | self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
232 | #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}')
233 |
234 | # initialize error_map
235 | if self.training and self.opt.error_map:
236 | self.error_map = torch.ones([self.images.shape[0], 128 * 128], dtype=torch.float) # [B, 128 * 128], flattened for easy indexing, fixed resolution...
237 | else:
238 | self.error_map = None
239 |
240 | # [debug] uncomment to view all training poses.
241 | # visualize_poses(self.poses.numpy())
242 |
243 | # [debug] uncomment to view examples of randomly generated poses.
244 | # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy())
245 |
246 | if self.preload:
247 | self.poses = self.poses.to(self.device)
248 | if self.images is not None:
249 | # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ?
250 | if self.fp16 and self.opt.color_space != 'linear':
251 | dtype = torch.half
252 | else:
253 | dtype = torch.float
254 | self.images = self.images.to(dtype).to(self.device)
255 | if self.error_map is not None:
256 | self.error_map = self.error_map.to(self.device)
257 |
258 | # load intrinsics
259 | if 'fl_x' in transform or 'fl_y' in transform:
260 | fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
261 | fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
262 | elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
263 | # blender, assert in radians. already downscaled since we use H/W
264 | fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
265 | fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
266 | if fl_x is None: fl_x = fl_y
267 | if fl_y is None: fl_y = fl_x
268 | else:
269 | raise RuntimeError('Failed to load focal length, please check the transforms.json!')
270 |
271 | cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
272 | cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)
273 |
274 | self.intrinsics = np.array([fl_x, fl_y, cx, cy])
275 |
276 |
277 | def collate(self, index):
278 |
279 | B = len(index) # a list of length 1
280 |
281 | # random pose without gt images.
282 | if self.rand_pose == 0 or index[0] >= len(self.poses):
283 |
284 | poses = rand_poses(B, self.device, radius=self.radius)
285 |
286 | # sample a low-resolution but full image for CLIP
287 | s = np.sqrt(self.H * self.W / self.num_rays) # only in training, assert num_rays > 0
288 | rH, rW = int(self.H / s), int(self.W / s)
289 | rays = get_rays(poses, self.intrinsics / s, rH, rW, -1)
290 |
291 | return {
292 | 'H': rH,
293 | 'W': rW,
294 | 'rays_o': rays['rays_o'],
295 | 'rays_d': rays['rays_d'],
296 | }
297 |
298 | poses = self.poses[index].to(self.device) # [B, 4, 4]
299 |
300 | error_map = None if self.error_map is None else self.error_map[index]
301 |
302 | rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map, self.opt.patch_size)
303 |
304 | results = {
305 | 'H': self.H,
306 | 'W': self.W,
307 | 'rays_o': rays['rays_o'],
308 | 'rays_d': rays['rays_d'],
309 | }
310 |
311 | if self.images is not None:
312 | images = self.images[index].to(self.device) # [B, H, W, 3/4]
313 | if self.training:
314 | C = images.shape[-1]
315 | images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
316 | results['images'] = images
317 |
318 | # need inds to update error_map
319 | if error_map is not None:
320 | results['index'] = index
321 | results['inds_coarse'] = rays['inds_coarse']
322 |
323 | return results
324 |
325 | def dataloader(self):
326 | size = len(self.poses)
327 | if self.training and self.rand_pose > 0:
328 | size += size // self.rand_pose # index >= size means we use random pose.
329 | loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
330 | loader._data = self # an ugly fix... we need to access error_map & poses in trainer.
331 | loader.has_gt = self.images is not None
332 | return loader
--------------------------------------------------------------------------------
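`rand_poses` parameterizes camera centers on a sphere by a polar angle `theta` (measured from the +y axis) and an azimuth `phi`, then builds a look-at rotation toward the origin with a -y up vector. A standalone NumPy sketch of the same construction for a single pose (a sketch of the convention only, not repository code):

```python
import numpy as np

def orbit_pose(theta, phi, radius=1.0):
    # camera center on a sphere, matching the parameterization in rand_poses
    center = radius * np.array([
        np.sin(theta) * np.sin(phi),
        np.cos(theta),
        np.sin(theta) * np.cos(phi),
    ])
    forward = -center / np.linalg.norm(center)   # look at the origin
    up = np.array([0.0, -1.0, 0.0])              # -y up, as in rand_poses
    right = np.cross(forward, up)
    right /= np.linalg.norm(right)
    up = np.cross(right, forward)
    up /= np.linalg.norm(up)
    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = np.stack([right, up, forward], axis=-1)  # columns: right, up, forward
    pose[:3, 3] = center
    return pose

print(orbit_pose(np.pi / 2, 0.0))  # camera on the +z axis, looking back at the origin
```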
/external/ngp_activation.py:
--------------------------------------------------------------------------------
1 | '''
2 | Helper functions for diffusion model
3 | #@ FROM https://github.com/ashawkey/torch-ngp
4 | '''
5 |
6 | import torch
7 | from torch.autograd import Function
8 | from torch.cuda.amp import custom_bwd, custom_fwd
9 |
10 | class _trunc_exp(Function):
11 | @staticmethod
12 | @custom_fwd(cast_inputs=torch.float32) # cast to float32
13 | def forward(ctx, x):
14 | ctx.save_for_backward(x)
15 | return torch.exp(x)
16 |
17 | @staticmethod
18 | @custom_bwd
19 | def backward(ctx, g):
20 | x = ctx.saved_tensors[0]
21 | return g * torch.exp(x.clamp(-15, 15))
22 |
23 | trunc_exp = _trunc_exp.apply
--------------------------------------------------------------------------------
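`trunc_exp` is an ordinary `exp` in the forward pass, but its backward pass clamps the saved input to [-15, 15], which keeps density gradients finite when the network momentarily predicts very large pre-activations. A minimal CPU sketch of the same rule (the AMP decorators are dropped here for simplicity; this is an illustration, not a replacement for the module above):

```python
import torch
from torch.autograd import Function

class TruncExp(Function):
    # same forward/backward rule as _trunc_exp above, without the AMP decorators
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    def backward(ctx, g):
        x, = ctx.saved_tensors
        return g * torch.exp(x.clamp(-15, 15))

x = torch.tensor([20.0], requires_grad=True)
TruncExp.apply(x).backward()
print(x.grad)             # exp(15) ~ 3.27e6, clamped gradient

x2 = torch.tensor([20.0], requires_grad=True)
torch.exp(x2).backward()
print(x2.grad)            # exp(20) ~ 4.85e8, unclamped gradient
```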
/external/ngp_encoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | Helper functions for diffusion model
3 | #@ FROM https://github.com/ashawkey/torch-ngp
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | class FreqEncoder(nn.Module):
11 | def __init__(self, input_dim, max_freq_log2, N_freqs,
12 | log_sampling=True, include_input=True,
13 | periodic_fns=(torch.sin, torch.cos)):
14 |
15 | super().__init__()
16 |
17 | self.input_dim = input_dim
18 | self.include_input = include_input
19 | self.periodic_fns = periodic_fns
20 |
21 | self.output_dim = 0
22 | if self.include_input:
23 | self.output_dim += self.input_dim
24 |
25 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
26 |
27 | if log_sampling:
28 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
29 | else:
30 | self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
31 |
32 | self.freq_bands = self.freq_bands.numpy().tolist()
33 |
34 | def forward(self, input, **kwargs):
35 |
36 | out = []
37 | if self.include_input:
38 | out.append(input)
39 |
40 | for i in range(len(self.freq_bands)):
41 | freq = self.freq_bands[i]
42 | for p_fn in self.periodic_fns:
43 | out.append(p_fn(input * freq))
44 |
45 | out = torch.cat(out, dim=-1)
46 |
47 |
48 | return out
49 |
50 | def get_encoder(encoding, input_dim=3,
51 | multires=6,
52 | degree=4,
53 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
54 | **kwargs):
55 |
56 | if encoding == 'None':
57 | return lambda x, **kwargs: x, input_dim
58 |
59 | elif encoding == 'frequency':
60 | raise NotImplementedError
61 |
62 | elif encoding == 'sphere_harmonics':
63 | raise NotImplementedError
64 |
65 | elif encoding == 'hashgrid':
66 | from external.gridencoder import GridEncoder
67 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
68 |
69 | elif encoding == 'tiledgrid':
70 | from external.gridencoder import GridEncoder
71 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
72 |
73 | elif encoding == 'ash':
74 | raise NotImplementedError
75 |
76 | else:
77 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')
78 |
79 | return encoder, encoder.output_dim
--------------------------------------------------------------------------------
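`FreqEncoder` maps each input coordinate to `sin`/`cos` features at `N_freqs` frequencies, optionally keeping the raw input, so its output width is `input_dim * (include_input + 2 * N_freqs)`. A minimal standalone sketch of the same mapping and its dimensionality:

```python
import torch

def freq_encode(x, n_freqs=6, max_freq_log2=5, include_input=True):
    # x: [N, D] -> [N, D * (include_input + 2 * n_freqs)]
    freqs = 2.0 ** torch.linspace(0.0, max_freq_log2, n_freqs)
    out = [x] if include_input else []
    for f in freqs:
        out += [torch.sin(x * f), torch.cos(x * f)]
    return torch.cat(out, dim=-1)

x = torch.rand(4, 3)
enc = freq_encode(x)
print(enc.shape)  # torch.Size([4, 39]) = 3 * (1 + 2 * 6)
```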
/external/plms.py:
--------------------------------------------------------------------------------
1 | '''
2 | Fast PNDM PLMS Sampler
3 | #@ FROM https://github.com/CompVis/stable-diffusion
4 | '''
5 | import torch
6 | import numpy as np
7 | from tqdm import tqdm
8 | from external.imagen_pytorch import GaussianDiffusionContinuousTimes, right_pad_dims_to
9 | from sparsefusion.vldm import DDPM
10 | from einops import rearrange
11 |
12 |
13 | class PLMSSampler():
14 |
15 | def __init__(self, diffusion: DDPM, plms_steps=100):
16 |
17 | self.diffusion = diffusion
18 | self.plms_steps = plms_steps
19 |
20 | @torch.no_grad()
21 | def sample(self,
22 | image=None,
23 | max_thres=.999,
24 | cond_images=None,
25 | cond_scale=1.0,
26 | use_tqdm=True,
27 | return_noise=False,
28 | **kwargs
29 | ):
30 | '''
31 | Single UNet PLMS Sampler
32 | '''
33 | outputs = []
34 | batch_size = cond_images.shape[0]
35 | shape = (batch_size, self.diffusion.sample_channels[0], self.diffusion.image_sizes[0], self.diffusion.image_sizes[0])
36 | img, x_noisy, noise, alpha_cumprod = self.plms_sample_loop(
37 | self.diffusion.unets[0],
38 | image = image,
39 | shape = shape,
40 | cond_images = cond_images,
41 | cond_scale = cond_scale,
42 | noise_scheduler = self.diffusion.noise_schedulers[0],
43 | pred_objective = self.diffusion.pred_objectives[0],
44 | dynamic_threshold = self.diffusion.dynamic_thresholding[0],
45 | use_tqdm = use_tqdm,
46 | max_thres = max_thres,
47 | )
48 | outputs.append(img)
49 | if not return_noise:
50 | return outputs[-1]
51 | return outputs[-1], x_noisy, noise, alpha_cumprod
52 |
53 | @torch.no_grad()
54 | def plms_sample_loop(self,
55 | unet,
56 | image,
57 | shape,
58 | cond_images,
59 | cond_scale,
60 | noise_scheduler,
61 | pred_objective,
62 | dynamic_threshold,
63 | use_tqdm,
64 | max_thres = None,
65 | ):
66 | '''
67 | Sampling loop
68 | '''
69 | batch = shape[0]
70 | device = self.diffusion.device
71 |
72 | if image is None:
73 | image = torch.randn(shape, device = device)
74 | else:
75 | assert(max_thres is not None)
76 |
77 |
78 | old_eps = []
79 |
80 | if max_thres >= .99:
81 | noise_scheduler_short = GaussianDiffusionContinuousTimes(noise_schedule='cosine', timesteps=self.plms_steps)
82 | timesteps = noise_scheduler_short.get_sampling_timesteps(batch, device=device)
83 | noise = torch.randn_like(image)
84 |             x_noisy, log_snr = noise_scheduler_short.q_sample(image, t=max_thres, noise=noise)
85 | img = image
86 | else:
87 | n_steps = min(int(max_thres * self.plms_steps * 2), self.plms_steps)
88 | # n_steps = 50
89 | noise_scheduler_short = GaussianDiffusionContinuousTimes(noise_schedule='cosine', timesteps=self.plms_steps)
90 | timesteps = noise_scheduler_short.get_sampling_timesteps_custom(batch, device=device, max_thres=max_thres, n_steps=n_steps)
91 | noise = torch.randn_like(image)
92 | img, log_snr = noise_scheduler_short.q_sample(image, t=max_thres, noise=noise)
93 | x_noisy = img
94 |
95 | for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm):
96 | is_last_timestep = times_next == 0
97 | outs = self.p_sample(
98 | unet,
99 | img,
100 | times,
101 | t_next = times_next,
102 | cond_images = cond_images,
103 | cond_scale = cond_scale,
104 | noise_scheduler = noise_scheduler,
105 | pred_objective = pred_objective,
106 | dynamic_threshold = dynamic_threshold,
107 | old_eps = old_eps
108 | )
109 | img, pred_x0, e_t = outs
110 | old_eps.append(e_t)
111 | if len(old_eps) >= 4:
112 | old_eps.pop(0)
113 |
114 | if self.diffusion.clip_output:
115 | img.clamp_(-self.diffusion.clip_value, self.diffusion.clip_value)
116 |
117 | unnormalize_img = self.diffusion.unnormalize_img(img)
118 | alpha_cumprod = torch.sigmoid(log_snr)
119 | return unnormalize_img, x_noisy, noise, alpha_cumprod
120 |
121 | @torch.no_grad()
122 | def p_sample(self,
123 | unet,
124 | x,
125 | t,
126 | t_next,
127 | cond_images,
128 | cond_scale,
129 | noise_scheduler,
130 | pred_objective,
131 | dynamic_threshold,
132 | old_eps
133 | ):
134 |
135 | b, *_, device = *x.shape, x.device
136 | _, _, e_t = self.get_model_output(unet, x, t, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold)
137 | if len(old_eps) == 0:
138 | # Pseudo Improved Euler (2nd order)
139 | # x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
140 | x_prev, pred_x0, _ = self.get_model_output(unet, x, t, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold, pred_e = e_t)
141 | # e_t_next = get_model_output(x_prev, t_next)
142 | _, _, e_t_next = self.get_model_output(unet, x_prev, t_next, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold)
143 | e_t_prime = (e_t + e_t_next) / 2
144 | elif len(old_eps) == 1:
145 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
146 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
147 | elif len(old_eps) == 2:
148 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
149 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
150 | elif len(old_eps) >= 3:
151 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
152 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
153 |
154 | x_prev, pred_x0, _ = self.get_model_output(unet, x, t, t_next, cond_images, cond_scale, noise_scheduler, pred_objective, dynamic_threshold, pred_e = e_t_prime)
155 |
156 | return x_prev, pred_x0, e_t
157 |
158 | def get_model_output(self,
159 | unet,
160 | x,
161 | t,
162 | t_next,
163 | cond_images,
164 | cond_scale,
165 | noise_scheduler,
166 | pred_objective,
167 | dynamic_threshold,
168 | pred_e = None,
169 | ):
170 | assert(pred_objective == 'noise')
171 | b, *_, device = *x.shape, x.device
172 |
173 |
174 | #@ PRED EPS
175 | if pred_e is None:
176 | pred = unet.forward_with_cond_scale(x, noise_scheduler.get_condition(t), cond_images = cond_images, cond_scale = cond_scale)
177 | pred_e = pred
178 | else:
179 | pred = pred_e
180 |
181 |
182 | #@ PREDICT X_0
183 | if pred_objective == 'noise':
184 | x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
185 | elif pred_objective == 'x_start':
186 | x_start = pred
187 | else:
188 | raise ValueError(f'unknown objective {pred_objective}')
189 |
190 | #@ CLIP X_0
191 | if self.diffusion.clip_output:
192 | if dynamic_threshold:
193 | # following pseudocode in appendix
194 | # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
195 | s = torch.quantile(
196 | rearrange(x_start, 'b ... -> b (...)').abs(),
197 |                     self.diffusion.dynamic_thresholding_percentile,
198 | dim = -1
199 | )
200 |
201 | s.clamp_(min = 1.)
202 | s = right_pad_dims_to(x_start, s)
203 | x_start = x_start.clamp(-s, s) / s
204 | else:
205 | x_start.clamp_(-self.diffusion.clip_value, self.diffusion.clip_value)
206 |
207 | #@ USE Q_POSTERIOR TO GET X_PREV
208 | model_mean, _, model_log_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
209 | noise = torch.randn_like(x)
210 | is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
211 | nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
212 | x_prev = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
213 |
214 | return x_prev, x_start, pred_e
215 |
--------------------------------------------------------------------------------
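The branches in `p_sample` are the standard pseudo linear multistep (PLMS) schedule: once enough history is available, the effective noise prediction is the 4th-order Adams-Bashforth combination `(55*e_t - 59*e_{t-1} + 37*e_{t-2} - 9*e_{t-3}) / 24`, with lower-order fallbacks for the first few steps. A toy sketch (not the diffusion model) applying the same coefficient schedule to the ODE dy/dt = -y; the very first step here uses plain Euler instead of the improved-Euler bootstrap used above:

```python
import numpy as np

def f(y):
    return -y  # toy ODE dy/dt = -y, exact solution y(t) = exp(-t)

y, h, t = 1.0, 0.05, 0.0
old = []  # history of f evaluations, newest last (mirrors old_eps)
for _ in range(100):
    e = f(y)
    if len(old) == 0:
        e_prime = e                                                      # Euler bootstrap
    elif len(old) == 1:
        e_prime = (3 * e - old[-1]) / 2                                  # 2nd order
    elif len(old) == 2:
        e_prime = (23 * e - 16 * old[-1] + 5 * old[-2]) / 12             # 3rd order
    else:
        e_prime = (55 * e - 59 * old[-1] + 37 * old[-2] - 9 * old[-3]) / 24  # 4th order
    y, t = y + h * e_prime, t + h
    old.append(e)
    old = old[-3:]  # keep at most 3 past evaluations, as old_eps does

print(y, np.exp(-t))  # multistep result is close to the exact exp(-5)
```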
/media/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhizdev/sparsefusion/0324bae8bd32c854d7222122110d866ff0ceba5e/media/teaser.jpg
--------------------------------------------------------------------------------
/raymarching/README.md:
--------------------------------------------------------------------------------
1 | All code in raymarching is from https://github.com/ashawkey/torch-ngp.
--------------------------------------------------------------------------------
/raymarching/__init__.py:
--------------------------------------------------------------------------------
1 | from .raymarching import *
--------------------------------------------------------------------------------
/raymarching/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21 | if paths:
22 | return paths[0]
23 |
24 | # If cl.exe is not on path, try to find it.
25 | if os.system("where cl.exe >nul 2>nul") != 0:
26 | cl_path = find_cl_path()
27 | if cl_path is None:
28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29 | os.environ["PATH"] += ";" + cl_path
30 |
31 | _backend = load(name='_raymarching',
32 | extra_cflags=c_flags,
33 | extra_cuda_cflags=nvcc_flags,
34 | sources=[os.path.join(_src_path, 'src', f) for f in [
35 | 'raymarching.cu',
36 | 'bindings.cpp',
37 | ]],
38 | )
39 |
40 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/raymarching/raymarching.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Function
7 | from torch.cuda.amp import custom_bwd, custom_fwd
8 |
9 | try:
10 | import _raymarching as _backend
11 | except ImportError:
12 | from .backend import _backend
13 |
14 |
15 | # ----------------------------------------
16 | # utils
17 | # ----------------------------------------
18 |
19 | class _near_far_from_aabb(Function):
20 | @staticmethod
21 | @custom_fwd(cast_inputs=torch.float32)
22 | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
23 | ''' near_far_from_aabb, CUDA implementation
24 | Calculate rays' intersection time (near and far) with aabb
25 | Args:
26 | rays_o: float, [N, 3]
27 | rays_d: float, [N, 3]
28 | aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
29 | min_near: float, scalar
30 | Returns:
31 | nears: float, [N]
32 | fars: float, [N]
33 | '''
34 | if not rays_o.is_cuda: rays_o = rays_o.cuda()
35 | if not rays_d.is_cuda: rays_d = rays_d.cuda()
36 |
37 | rays_o = rays_o.contiguous().view(-1, 3)
38 | rays_d = rays_d.contiguous().view(-1, 3)
39 |
40 | N = rays_o.shape[0] # num rays
41 |
42 | nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
43 | fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
44 |
45 | _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
46 |
47 | return nears, fars
48 |
49 | near_far_from_aabb = _near_far_from_aabb.apply
50 |
51 |
52 | class _sph_from_ray(Function):
53 | @staticmethod
54 | @custom_fwd(cast_inputs=torch.float32)
55 | def forward(ctx, rays_o, rays_d, radius):
56 | ''' sph_from_ray, CUDA implementation
57 | get spherical coordinate on the background sphere from rays.
58 | Assume rays_o are inside the Sphere(radius).
59 | Args:
60 | rays_o: [N, 3]
61 | rays_d: [N, 3]
62 | radius: scalar, float
63 | Return:
64 | coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
65 | '''
66 | if not rays_o.is_cuda: rays_o = rays_o.cuda()
67 | if not rays_d.is_cuda: rays_d = rays_d.cuda()
68 |
69 | rays_o = rays_o.contiguous().view(-1, 3)
70 | rays_d = rays_d.contiguous().view(-1, 3)
71 |
72 | N = rays_o.shape[0] # num rays
73 |
74 | coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
75 |
76 | _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
77 |
78 | return coords
79 |
80 | sph_from_ray = _sph_from_ray.apply
81 |
82 |
83 | class _morton3D(Function):
84 | @staticmethod
85 | def forward(ctx, coords):
86 | ''' morton3D, CUDA implementation
87 | Args:
88 | coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
89 | TODO: check if the coord range is valid! (current 128 is safe)
90 | Returns:
91 | indices: [N], int32, in [0, 128^3)
92 |
93 | '''
94 | if not coords.is_cuda: coords = coords.cuda()
95 |
96 | N = coords.shape[0]
97 |
98 | indices = torch.empty(N, dtype=torch.int32, device=coords.device)
99 |
100 | _backend.morton3D(coords.int(), N, indices)
101 |
102 | return indices
103 |
104 | morton3D = _morton3D.apply
105 |
106 | class _morton3D_invert(Function):
107 | @staticmethod
108 | def forward(ctx, indices):
109 | ''' morton3D_invert, CUDA implementation
110 | Args:
111 | indices: [N], int32, in [0, 128^3)
112 | Returns:
113 | coords: [N, 3], int32, in [0, 128)
114 |
115 | '''
116 | if not indices.is_cuda: indices = indices.cuda()
117 |
118 | N = indices.shape[0]
119 |
120 | coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
121 |
122 | _backend.morton3D_invert(indices.int(), N, coords)
123 |
124 | return coords
125 |
126 | morton3D_invert = _morton3D_invert.apply
127 |
128 |
129 | class _packbits(Function):
130 | @staticmethod
131 | @custom_fwd(cast_inputs=torch.float32)
132 | def forward(ctx, grid, thresh, bitfield=None):
133 | ''' packbits, CUDA implementation
134 | Pack up the density grid into a bit field to accelerate ray marching.
135 | Args:
136 | grid: float, [C, H * H * H], assume H % 2 == 0
137 | thresh: float, threshold
138 | Returns:
139 | bitfield: uint8, [C, H * H * H / 8]
140 | '''
141 | if not grid.is_cuda: grid = grid.cuda()
142 | grid = grid.contiguous()
143 |
144 | C = grid.shape[0]
145 | H3 = grid.shape[1]
146 | N = C * H3 // 8
147 |
148 | if bitfield is None:
149 | bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
150 |
151 | _backend.packbits(grid, N, thresh, bitfield)
152 |
153 | return bitfield
154 |
155 | packbits = _packbits.apply
156 |
157 | # ----------------------------------------
158 | # train functions
159 | # ----------------------------------------
160 |
161 | class _march_rays_train(Function):
162 | @staticmethod
163 | @custom_fwd(cast_inputs=torch.float32)
164 | def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
165 | ''' march rays to generate points (forward only)
166 | Args:
167 | rays_o/d: float, [N, 3]
168 | bound: float, scalar
169 | density_bitfield: uint8: [CHHH // 8]
170 | C: int
171 | H: int
172 | nears/fars: float, [N]
173 | step_counter: int32, (2), used to count the actual number of generated points.
174 |             mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeds this threshold.)
175 | perturb: bool
176 |             align: int, pad output so its size is divisible by align, set to -1 to disable.
177 | force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
178 |             dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
179 | max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
180 | Returns:
181 |             xyzs: float, [M, 3], all generated points' coords. (all rays concatenated, need to use `rays` to extract points belonging to each ray)
182 | dirs: float, [M, 3], all generated points' view dirs.
183 | deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
184 | rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
185 | '''
186 |
187 | if not rays_o.is_cuda: rays_o = rays_o.cuda()
188 | if not rays_d.is_cuda: rays_d = rays_d.cuda()
189 | if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
190 |
191 | rays_o = rays_o.contiguous().view(-1, 3)
192 | rays_d = rays_d.contiguous().view(-1, 3)
193 | density_bitfield = density_bitfield.contiguous()
194 |
195 | N = rays_o.shape[0] # num rays
196 | M = N * max_steps # init max points number in total
197 |
198 | # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
199 |         # It estimates the max number of points to enable faster training, but will lead to randomly ignored rays if underestimated.
200 | if not force_all_rays and mean_count > 0:
201 | if align > 0:
202 | mean_count += align - mean_count % align
203 | M = mean_count
204 |
205 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
206 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
207 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
208 | rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
209 |
210 | if step_counter is None:
211 | step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
212 |
213 | if perturb:
214 | noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
215 | else:
216 | noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
217 |
218 | _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
219 |
220 | #print(step_counter, M)
221 |
222 | # only used at the first (few) epochs.
223 | if force_all_rays or mean_count <= 0:
224 | m = step_counter[0].item() # D2H copy
225 | if align > 0:
226 | m += align - m % align
227 | xyzs = xyzs[:m]
228 | dirs = dirs[:m]
229 | deltas = deltas[:m]
230 |
231 | torch.cuda.empty_cache()
232 |
233 | return xyzs, dirs, deltas, rays
234 |
235 | march_rays_train = _march_rays_train.apply
236 |
237 |
238 | class _composite_rays_train(Function):
239 | @staticmethod
240 | @custom_fwd(cast_inputs=torch.float32)
241 | def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
242 | ''' composite rays' rgbs, according to the ray marching formula.
243 | Args:
244 | rgbs: float, [M, 3]
245 | sigmas: float, [M,]
246 | deltas: float, [M, 2]
247 | rays: int32, [N, 3]
248 | Returns:
249 | weights_sum: float, [N,], the alpha channel
250 | depth: float, [N, ], the Depth
251 | image: float, [N, 3], the RGB channel (after multiplying alpha!)
252 | '''
253 |
254 | sigmas = sigmas.contiguous()
255 | rgbs = rgbs.contiguous()
256 |
257 | M = sigmas.shape[0]
258 | N = rays.shape[0]
259 |
260 | weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
261 | depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
262 | image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
263 |
264 | _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
265 |
266 | ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
267 | ctx.dims = [M, N, T_thresh]
268 |
269 | return weights_sum, depth, image
270 |
271 | @staticmethod
272 | @custom_bwd
273 | def backward(ctx, grad_weights_sum, grad_depth, grad_image):
274 |
275 | # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
276 |
277 | grad_weights_sum = grad_weights_sum.contiguous()
278 | grad_image = grad_image.contiguous()
279 |
280 | sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
281 | M, N, T_thresh = ctx.dims
282 |
283 | grad_sigmas = torch.zeros_like(sigmas)
284 | grad_rgbs = torch.zeros_like(rgbs)
285 |
286 | _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
287 |
288 | return grad_sigmas, grad_rgbs, None, None, None
289 |
290 |
291 | composite_rays_train = _composite_rays_train.apply
292 |
293 | # ----------------------------------------
294 | # infer functions
295 | # ----------------------------------------
296 |
297 | class _march_rays(Function):
298 | @staticmethod
299 | @custom_fwd(cast_inputs=torch.float32)
300 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
301 | ''' march rays to generate points (forward only, for inference)
302 | Args:
303 | n_alive: int, number of alive rays
304 | n_step: int, how many steps we march
305 | rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
306 | rays_t: float, [N], the alive rays' time, we only use the first n_alive.
307 | rays_o/d: float, [N, 3]
308 | bound: float, scalar
309 | density_bitfield: uint8: [CHHH // 8]
310 | C: int
311 | H: int
312 | nears/fars: float, [N]
313 |             align: int, pad output so its size is divisible by align, set to -1 to disable.
314 | perturb: bool/int, int > 0 is used as the random seed.
315 |             dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
316 | max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
317 | Returns:
318 | xyzs: float, [n_alive * n_step, 3], all generated points' coords
319 | dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
320 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
321 | '''
322 |
323 | if not rays_o.is_cuda: rays_o = rays_o.cuda()
324 | if not rays_d.is_cuda: rays_d = rays_d.cuda()
325 |
326 | rays_o = rays_o.contiguous().view(-1, 3)
327 | rays_d = rays_d.contiguous().view(-1, 3)
328 |
329 | M = n_alive * n_step
330 |
331 | if align > 0:
332 | M += align - (M % align)
333 |
334 | xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
335 | dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
336 | deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
337 |
338 | if perturb:
339 | # torch.manual_seed(perturb) # test_gui uses spp index as seed
340 | noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
341 | else:
342 | noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
343 |
344 | _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
345 |
346 | return xyzs, dirs, deltas
347 |
348 | march_rays = _march_rays.apply
349 |
350 |
351 | class _composite_rays(Function):
352 | @staticmethod
353 | @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
354 | def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
355 | ''' composite rays' rgbs, according to the ray marching formula. (for inference)
356 | Args:
357 | n_alive: int, number of alive rays
358 | n_step: int, how many steps we march
359 | rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
360 | rays_t: float, [N], the alive rays' time
361 | sigmas: float, [n_alive * n_step,]
362 | rgbs: float, [n_alive * n_step, 3]
363 | deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
364 | In-place Outputs:
365 | weights_sum: float, [N,], the alpha channel
366 | depth: float, [N,], the depth value
367 | image: float, [N, 3], the RGB channel (after multiplying alpha!)
368 | '''
369 | _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
370 | return tuple()
371 |
372 |
373 | composite_rays = _composite_rays.apply
--------------------------------------------------------------------------------
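`composite_rays_train` accumulates the usual volume-rendering weights on the GPU: per-sample opacity `alpha_i = 1 - exp(-sigma_i * delta_i)`, transmittance `T_i = prod_{j<i} (1 - alpha_j)`, and weight `w_i = T_i * alpha_i`. A pure-PyTorch reference for a single ray, as a sketch of the formula only; the CUDA kernel additionally packs variable-length rays via `rays` and terminates early once transmittance falls below `T_thresh`:

```python
import torch

def composite_single_ray(sigmas, rgbs, deltas, ts):
    # sigmas: [M], rgbs: [M, 3], deltas: [M] step sizes, ts: [M] sample distances
    alphas = 1.0 - torch.exp(-sigmas * deltas)                                    # [M]
    trans = torch.cumprod(torch.cat([torch.ones(1), 1.0 - alphas + 1e-10]), 0)[:-1]  # T_i
    weights = alphas * trans                                                      # [M]
    image = (weights.unsqueeze(-1) * rgbs).sum(0)   # [3], RGB premultiplied by alpha
    depth = (weights * ts).sum(0)                   # expected termination distance
    weights_sum = weights.sum(0)                    # alpha channel
    return weights_sum, depth, image

M = 64
sigmas = torch.rand(M) * 5
rgbs = torch.rand(M, 3)
deltas = torch.full((M,), 0.02)
ts = torch.cumsum(deltas, dim=0)
print(composite_single_ray(sigmas, rgbs, deltas, ts))
```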
/raymarching/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | '''
33 | Usage:
34 |
35 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
36 |
37 | python setup.py install # build extensions and install (copy) to PATH.
38 | pip install . # ditto but better (e.g., dependency & metadata handling)
39 |
40 | python setup.py develop # build extensions and install (symbolic) to PATH.
41 | pip install -e . # ditto but better (e.g., dependency & metadata handling)
42 |
43 | '''
44 | setup(
45 | name='raymarching', # package name, import this to use python API
46 | ext_modules=[
47 | CUDAExtension(
48 | name='_raymarching', # extension name, import this to use CUDA API
49 | sources=[os.path.join(_src_path, 'src', f) for f in [
50 | 'raymarching.cu',
51 | 'bindings.cpp',
52 | ]],
53 | extra_compile_args={
54 | 'cxx': c_flags,
55 | 'nvcc': nvcc_flags,
56 | }
57 | ),
58 | ],
59 | cmdclass={
60 | 'build_ext': BuildExtension,
61 | }
62 | )
--------------------------------------------------------------------------------
/raymarching/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "raymarching.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | // utils
7 | m.def("packbits", &packbits, "packbits (CUDA)");
8 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
9 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
10 | m.def("morton3D", &morton3D, "morton3D (CUDA)");
11 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
12 | // train
13 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
14 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
15 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
16 | // infer
17 | m.def("march_rays", &march_rays, "march rays (CUDA)");
18 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
19 | }
--------------------------------------------------------------------------------
/raymarching/src/raymarching.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 |
7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
12 |
13 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
14 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
15 | void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
16 |
17 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
18 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch-ema
2 | trimesh
3 | opencv-python
4 | tensorboardX
5 | torch
6 | numpy
7 | pandas
8 | tqdm
9 | matplotlib
10 | PyMCubes
11 | rich
12 | pysdf
13 | packaging
14 | scipy
15 | lpips
16 | imageio
17 | einops
18 | scikit-image
19 | omegaconf
20 | plotly
--------------------------------------------------------------------------------
/utils/check_args.py:
--------------------------------------------------------------------------------
1 | '''
2 | Check that the provided arguments are valid.
3 | '''
4 |
5 | import os
6 | import sys
7 | from utils.co3d_dataloader import CO3D_ALL_CATEGORIES
8 |
9 | def check_args(args):
10 |
11 | #@ CHECK ALL EXPECTED ARGS ARE PRESENT
12 | if not (args.dataset_name == 'co3d' or args.dataset_name == 'co3d_toy'):
13 |         print(f'ERROR: Provided {args.dataset_name} as dataset, but only (co3d, co3d_toy) are supported.')
14 | print('Exiting...')
15 | exit(1)
16 |
17 | if args.category not in CO3D_ALL_CATEGORIES and args.category not in {'all_ten', 'all'}:
18 | print(f'ERROR: Provided category {args.category} is not in CO3D.')
19 | print('Exiting...')
20 | exit(1)
21 |
22 | #@ CHECK DATASET FOLDER EXISTS
23 | if not os.path.exists(args.root):
24 | print(f'ERROR: Provided dataset root {args.root} does not exist.')
25 | print('Exiting...')
26 | exit(1)
27 |
28 | #@ CHECK MODEL WEIGHTS PATHS EXIST
29 | if not os.path.exists(args.eft_ckpt):
30 | print(f'ERROR: Provided EFT weight {args.eft_ckpt} does not exist.')
31 | print('Exiting...')
32 | exit(1)
33 |
34 | if not os.path.exists(args.vldm_ckpt):
35 | print(f'ERROR: Provided VLDM weight {args.vldm_ckpt} does not exist.')
36 | print('Exiting...')
37 | exit(1)
38 |
39 | if not os.path.exists(args.vae_ckpt):
40 | print(f'ERROR: Provided VAE weight {args.vae_ckpt} does not exist.')
41 | print('Exiting...')
42 | exit(1)
43 |
44 | return
45 |
--------------------------------------------------------------------------------
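`check_args` only reads a handful of attributes from the parsed arguments: `dataset_name`, `category`, `root`, `eft_ckpt`, `vldm_ckpt`, and `vae_ckpt`. A hedged sketch of the minimal namespace it expects; the paths below are placeholders, and the real argument parser (in `train.py`) may define additional flags:

```python
from argparse import Namespace

# placeholder values; check_args only verifies membership and path existence
args = Namespace(
    dataset_name='co3d_toy',          # 'co3d' or 'co3d_toy'
    category='hydrant',               # a CO3D category, or 'all' / 'all_ten'
    root='data/co3d_toy',             # dataset root (must exist)
    eft_ckpt='checkpoints/eft.pt',    # EFT weights (must exist)
    vldm_ckpt='checkpoints/vldm.pt',  # VLDM weights (must exist)
    vae_ckpt='checkpoints/vae.ckpt',  # VAE weights (must exist)
)
print(vars(args))

# from utils.check_args import check_args  # requires the repository on PYTHONPATH
# check_args(args)                         # exits with an error message if anything is missing
```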
/utils/co3d_toy_dataloader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Wrapper for a preprocessed and simplified CO3Dv2 dataset
3 | #@ Modified from https://github.com/facebookresearch/pytorch3d
4 | '''
5 |
6 | import os
7 | import torch
8 |
9 | class CO3Dv2ToyLoader(torch.utils.data.Dataset):
10 |
11 | def __init__(self, root, category):
12 | super().__init__()
13 |
14 | self.root = root
15 | self.category = category
16 |
17 | default_path = f'{root}/{category}/{category}_toy.pt'
18 | if not os.path.exists(default_path):
19 | print(f'ERROR: toy dataset not found at {default_path}')
20 | print('Exiting...')
21 | exit(1)
22 |
23 | dataset = torch.load(default_path)
24 | self.seq_list = dataset[category]
25 |
26 | def __len__(self):
27 | return len(self.seq_list)
28 |
29 | def __getitem__(self, index):
30 | return self.seq_list[index]
--------------------------------------------------------------------------------
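A minimal usage sketch for the toy loader above; the `data/co3d_toy` root and `hydrant` category are illustrative placeholders, and the loader expects a preprocessed file at `{root}/{category}/{category}_toy.pt`:

```python
# Hypothetical example: paths and category are placeholders, not shipped defaults.
from utils.co3d_toy_dataloader import CO3Dv2ToyLoader

dataset = CO3Dv2ToyLoader(root='data/co3d_toy', category='hydrant')
print(len(dataset))    # number of sequences in the toy split
sequence = dataset[0]  # one preprocessed CO3Dv2 sequence
```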
/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | A collection of common utilities
3 | '''
4 | import torch
5 | import lpips
6 | import skimage.metrics
7 |
8 |
9 | def normalize(x):
10 | '''
11 | Normalize [0, 1] to [-1, 1]
12 | '''
13 | return torch.clip(x*2 - 1.0, -1.0, 1.0)
14 |
15 | def unnormalize(x):
16 | '''
17 | Unnormalize [-1, 1] to [0, 1]
18 | '''
19 | return torch.clip((x + 1.0) / 2.0, 0.0, 1.0)
20 |
21 | def split_list(a, n):
22 | '''
23 | Split list into n parts
24 |
25 | Args:
26 | a (list): list
27 | n (int): number of parts
28 |
29 | Returns:
30 | a_split (list[list]): nested list of a split into n parts
31 | '''
32 | k, m = divmod(len(a), n)
33 | return [a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)]
34 |
35 |
36 | def get_lpips_fn():
37 | '''
38 | Return LPIPS function
39 | '''
40 | loss_fn_vgg = lpips.LPIPS(net='vgg')
41 | return loss_fn_vgg
42 |
43 |
44 | def get_metrics(pred, gt, use_lpips=False, loss_fn_vgg=None, device=None):
45 | '''
46 | Compute image metrics
47 |
48 | Args:
49 | pred (np array): (H, W, 3)
50 | gt (np array): (H, W, 3)
51 | '''
52 | ssim = skimage.metrics.structural_similarity(pred, gt, channel_axis = -1, data_range=1)
53 | psnr = skimage.metrics.peak_signal_noise_ratio(pred, gt, data_range=1)
54 | if use_lpips:
55 | if loss_fn_vgg is None:
56 | loss_fn_vgg = lpips.LPIPS(net='vgg')
57 | gt_ = torch.from_numpy(gt).permute(2,0,1).unsqueeze(0)*2 - 1.0
58 | pred_ = torch.from_numpy(pred).permute(2,0,1).unsqueeze(0)*2 - 1.0
59 | if device is not None:
60 | lp = loss_fn_vgg(gt_.to(device), pred_.to(device)).detach().cpu().numpy().item()
61 | else:
62 | lp = loss_fn_vgg(gt_, pred_).detach().numpy().item()
63 | return ssim, psnr, lp
64 | return ssim, psnr
65 |
66 |
67 | #@ FROM PyTorch3D
68 | class HarmonicEmbedding(torch.nn.Module):
69 | def __init__(
70 | self,
71 | n_harmonic_functions: int = 6,
72 | omega_0: float = 1.0,
73 | logspace: bool = True,
74 | append_input: bool = True,
75 | ) -> None:
76 | """
77 | Given an input tensor `x` of shape [minibatch, ... , dim],
78 | the harmonic embedding layer converts each feature
79 | (i.e. vector along the last dimension) in `x`
80 | into a series of harmonic features `embedding`,
81 | where for each i in range(dim) the following are present
82 | in embedding[...]:
83 | ```
84 | [
85 | sin(f_1*x[..., i]),
86 | sin(f_2*x[..., i]),
87 | ...
88 | sin(f_N * x[..., i]),
89 | cos(f_1*x[..., i]),
90 | cos(f_2*x[..., i]),
91 | ...
92 | cos(f_N * x[..., i]),
93 | x[..., i], # only present if append_input is True.
94 | ]
95 | ```
96 |         where N corresponds to `n_harmonic_functions`, and f_i is a scalar
97 | denoting the i-th frequency of the harmonic embedding.
98 | If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
99 | powers of 2:
100 | `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
101 | If `logspace==False`, frequencies are linearly spaced between
102 | `1.0` and `2**(n_harmonic_functions-1)`:
103 | `f_1, ..., f_N = torch.linspace(
104 | 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
105 | )`
106 | Note that `x` is also premultiplied by the base frequency `omega_0`
107 | before evaluating the harmonic functions.
108 | Args:
109 | n_harmonic_functions: int, number of harmonic
110 | features
111 | omega_0: float, base frequency
112 | logspace: bool, Whether to space the frequencies in
113 | logspace or linear space
114 | append_input: bool, whether to concat the original
115 | input to the harmonic embedding. If true the
116 |                 output is of the form (embed.sin(), embed.cos(), x).
117 | """
118 | super().__init__()
119 |
120 | if logspace:
121 | frequencies = 2.0 ** torch.arange(
122 | n_harmonic_functions,
123 | dtype=torch.float32,
124 | )
125 | else:
126 | frequencies = torch.linspace(
127 | 1.0,
128 | 2.0 ** (n_harmonic_functions - 1),
129 | n_harmonic_functions,
130 | dtype=torch.float32,
131 | )
132 |
133 | self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
134 | self.append_input = append_input
135 |
136 | def forward(self, x: torch.Tensor) -> torch.Tensor:
137 | """
138 | Args:
139 | x: tensor of shape [..., dim]
140 | Returns:
141 | embedding: a harmonic embedding of `x`
142 | of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
143 | """
144 | embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
145 | embed = torch.cat(
146 | (embed.sin(), embed.cos(), x)
147 | if self.append_input
148 | else (embed.sin(), embed.cos()),
149 | dim=-1,
150 | )
151 | return embed
152 |
153 | @staticmethod
154 | def get_output_dim_static(
155 | input_dims: int,
156 | n_harmonic_functions: int,
157 | append_input: bool,
158 | ) -> int:
159 | """
160 | Utility to help predict the shape of the output of `forward`.
161 | Args:
162 | input_dims: length of the last dimension of the input tensor
163 | n_harmonic_functions: number of embedding frequencies
164 | append_input: whether or not to concat the original
165 | input to the harmonic embedding
166 | Returns:
167 | int: the length of the last dimension of the output tensor
168 | """
169 | return input_dims * (2 * n_harmonic_functions + int(append_input))
170 |
171 | def get_output_dim(self, input_dims: int = 3) -> int:
172 | """
173 | Same as above. The default for input_dims is 3 for 3D applications
174 | which use harmonic embedding for positional encoding,
175 | so the input might be xyz.
176 | """
177 | return self.get_output_dim_static(
178 | input_dims, len(self._frequencies), self.append_input
179 | )
180 |
181 |
182 | #@ From PyTorch3D
183 | def huber(x, y, scaling=0.1):
184 | """
185 | A helper function for evaluating the smooth L1 (huber) loss
186 | between the rendered silhouettes and colors.
187 | """
188 | diff_sq = (x - y) ** 2
189 | loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)
190 | return loss
191 |
192 |
193 | #@ From PyTorch3D
194 | def sample_images_at_mc_locs(target_images, sampled_rays_xy):
195 | """
196 | Given a set of Monte Carlo pixel locations `sampled_rays_xy`,
197 | this method samples the tensor `target_images` at the
198 | respective 2D locations.
199 |
200 | This function is used in order to extract the colors from
201 | ground truth images that correspond to the colors
202 | rendered using `MonteCarloRaysampler`.
203 | """
204 | ba = target_images.shape[0]
205 | dim = target_images.shape[-1]
206 | spatial_size = sampled_rays_xy.shape[1:-1]
207 | # In order to sample target_images, we utilize
208 | # the grid_sample function which implements a
209 | # bilinear image sampler.
210 | # Note that we have to invert the sign of the
211 | # sampled ray positions to convert the NDC xy locations
212 | # of the MonteCarloRaysampler to the coordinate
213 | # convention of grid_sample.
214 |
215 | if target_images.shape[2] != target_images.shape[3]:
216 | target_images = target_images.permute(0, 3, 1, 2)
217 | elif target_images.shape[2] == target_images.shape[3]:
218 | dim = target_images.shape[1]
219 |
220 | images_sampled = torch.nn.functional.grid_sample(
221 | target_images,
222 | -sampled_rays_xy.view(ba, -1, 1, 2), # note the sign inversion
223 | align_corners=True
224 | )
225 | return images_sampled.permute(0, 2, 3, 1).view(
226 | ba, *spatial_size, dim
227 | )
228 |
--------------------------------------------------------------------------------
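A short sketch exercising the utilities above. The embedding width follows `get_output_dim_static`: a 3D input with 6 frequencies and `append_input=True` yields 3 * (2*6 + 1) = 39 channels.

```python
import torch
from utils.common_utils import HarmonicEmbedding, normalize, unnormalize, split_list

# Positional encoding: 3D points -> 3 * (2*6 + 1) = 39-dimensional embeddings.
embed = HarmonicEmbedding(n_harmonic_functions=6, append_input=True)
xyz = torch.rand(4, 128, 3)
enc = embed(xyz)
assert enc.shape == (4, 128, embed.get_output_dim(3))  # (4, 128, 39)

# normalize/unnormalize round-trip between [0, 1] and [-1, 1].
img = torch.rand(1, 3, 64, 64)
assert torch.allclose(unnormalize(normalize(img)), img, atol=1e-6)

# split_list divides a list into n nearly equal contiguous parts.
assert split_list(list(range(7)), 3) == [[0, 1, 2], [3, 4], [5, 6]]
```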
/utils/eft_raymarcher.py:
--------------------------------------------------------------------------------
1 | '''
2 | Custom PyTorch3D Raymarcher for EFT
3 | #@ Modified from https://github.com/facebookresearch/pytorch3d
4 | '''
5 |
6 | from typing import Optional, Tuple, Union
7 | import torch
8 |
9 | from pytorch3d.renderer import EmissionAbsorptionRaymarcher
10 | from pytorch3d.renderer.implicit.raymarching import (
11 | _check_density_bounds,
12 | _check_raymarcher_inputs,
13 | _shifted_cumprod,
14 | )
15 |
16 | class LightFieldRaymarcher(torch.nn.Module):
17 | """
18 | A nominal ray marcher that returns LightField features without any raymarching
19 | """
20 | def __init__(self):
21 | super().__init__()
22 |
23 | def forward(
24 | self,
25 | rays_densities: torch.Tensor,
26 | rays_features: torch.Tensor,
27 | **kwargs
28 | ) -> Union[None, torch.Tensor]:
29 |         """Concatenate `rays_densities` and `rays_features` along the last
30 |         dimension, performing no compositing."""
31 | return torch.cat((rays_densities, rays_features), dim=-1)
--------------------------------------------------------------------------------
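The raymarcher above does no alpha compositing; it only concatenates the density and feature channels, as the shape check in this sketch (with illustrative tensor sizes) shows:

```python
import torch
from utils.eft_raymarcher import LightFieldRaymarcher

raymarcher = LightFieldRaymarcher()
rays_densities = torch.rand(2, 750, 1, 1)   # batch x rays x pts_per_ray x opacity_dim
rays_features = torch.rand(2, 750, 1, 256)  # batch x rays x pts_per_ray x feature_dim
out = raymarcher(rays_densities=rays_densities, rays_features=rays_features)
assert out.shape == (2, 750, 1, 257)        # channels concatenated, nothing composited
```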
/utils/eft_renderer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Custom PyTorch3D Implicit Renderer
3 | #@ Modified from https://github.com/facebookresearch/pytorch3d
4 | '''
5 |
6 | from typing import Callable, Tuple
7 | from einops import rearrange, repeat, reduce
8 | import torch
9 |
10 | from pytorch3d.ops.utils import eyes
11 | from pytorch3d.structures import Volumes
12 | from pytorch3d.transforms import Transform3d
13 | from pytorch3d.renderer.cameras import CamerasBase
14 | from pytorch3d.renderer.implicit.raysampling import RayBundle
15 | from pytorch3d.renderer.implicit.utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points, ray_bundle_to_ray_points
16 |
17 | #@ MODIFIED FROM PYTORCH3D
18 | class CustomImplicitRenderer(torch.nn.Module):
19 | """
20 | A class for rendering a batch of implicit surfaces. The class should
21 | be initialized with a raysampler and raymarcher class which both have
22 | to be a `Callable`.
23 | VOLUMETRIC_FUNCTION
24 | The `forward` function of the renderer accepts as input the rendering cameras
25 | as well as the `volumetric_function` `Callable`, which defines a field of opacity
26 | and feature vectors over the 3D domain of the scene.
27 | A standard `volumetric_function` has the following signature:
28 | ```
29 | def volumetric_function(
30 | ray_bundle: RayBundle,
31 | **kwargs,
32 | ) -> Tuple[torch.Tensor, torch.Tensor]
33 | ```
34 | With the following arguments:
35 | `ray_bundle`: A RayBundle object containing the following variables:
36 | `origins`: A tensor of shape `(minibatch, ..., 3)` denoting
37 | the origins of the rendering rays.
38 | `directions`: A tensor of shape `(minibatch, ..., 3)`
39 | containing the direction vectors of rendering rays.
40 | `lengths`: A tensor of shape
41 | `(minibatch, ..., num_points_per_ray)`containing the
42 | lengths at which the ray points are sampled.
43 | `xys`: A tensor of shape
44 | `(minibatch, ..., 2)` containing the
45 | xy locations of each ray's pixel in the screen space.
46 | Calling `volumetric_function` then returns the following:
47 | `rays_densities`: A tensor of shape
48 | `(minibatch, ..., num_points_per_ray, opacity_dim)` containing
49 |             an opacity vector for each ray point.
50 | `rays_features`: A tensor of shape
51 | `(minibatch, ..., num_points_per_ray, feature_dim)` containing
52 |             a feature vector for each ray point.
53 | Note that, in order to increase flexibility of the API, we allow multiple
54 | other arguments to enter the volumetric function via additional
55 | (optional) keyword arguments `**kwargs`.
56 | A typical use-case is passing a `CamerasBase` object as an additional
57 | keyword argument, which can allow the volumetric function to adjust its
58 | outputs based on the directions of the projection rays.
59 | Example:
60 | A simple volumetric function of a 0-centered
61 | RGB sphere with a unit diameter is defined as follows:
62 | ```
63 | def volumetric_function(
64 | ray_bundle: RayBundle,
65 | **kwargs,
66 | ) -> Tuple[torch.Tensor, torch.Tensor]:
67 | # first convert the ray origins, directions and lengths
68 | # to 3D ray point locations in world coords
69 | rays_points_world = ray_bundle_to_ray_points(ray_bundle)
70 | # set the densities as an inverse sigmoid of the
71 | # ray point distance from the sphere centroid
72 | rays_densities = torch.sigmoid(
73 | -100.0 * rays_points_world.norm(dim=-1, keepdim=True)
74 | )
75 | # set the ray features to RGB colors proportional
76 | # to the 3D location of the projection of ray points
77 | # on the sphere surface
78 | rays_features = torch.nn.functional.normalize(
79 | rays_points_world, dim=-1
80 | ) * 0.5 + 0.5
81 | return rays_densities, rays_features
82 | ```
83 | """
84 |
85 | def __init__(self, raysampler: Callable, raymarcher: Callable, reg=None) -> None:
86 | """
87 | Args:
88 | raysampler: A `Callable` that takes as input scene cameras
89 | (an instance of `CamerasBase`) and returns a `RayBundle` that
90 | describes the rays emitted from the cameras.
91 | raymarcher: A `Callable` that receives the response of the
92 | `volumetric_function` (an input to `self.forward`) evaluated
93 | along the sampled rays, and renders the rays with a
94 | ray-marching algorithm.
95 | """
96 | super().__init__()
97 |
98 | if not callable(raysampler):
99 | raise ValueError('"raysampler" has to be a "Callable" object.')
100 | if not callable(raymarcher):
101 | raise ValueError('"raymarcher" has to be a "Callable" object.')
102 |
103 | self.raysampler = raysampler
104 | self.raymarcher = raymarcher
105 | self.reg = reg
106 |
107 | def forward(
108 | self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
109 | ) -> Tuple[torch.Tensor, RayBundle]:
110 | """
111 | Render a batch of images using a volumetric function
112 | represented as a callable (e.g. a Pytorch module).
113 | Args:
114 | cameras: A batch of cameras that render the scene. A `self.raysampler`
115 | takes the cameras as input and samples rays that pass through the
116 | domain of the volumetric function.
117 | volumetric_function: A `Callable` that accepts the parametrizations
118 | of the rendering rays and returns the densities and features
119 |                 at the respective 3D points of the rendering rays. Please refer to
120 | the main class documentation for details.
121 | Returns:
122 | images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
123 | containing the result of the rendering.
124 | ray_bundle: A `RayBundle` containing the parametrizations of the
125 | sampled rendering rays.
126 | """
127 |
128 | if not callable(volumetric_function):
129 | raise ValueError('"volumetric_function" has to be a "Callable" object.')
130 |
131 | # first call the ray sampler that returns the RayBundle parametrizing
132 | # the rendering rays.
133 | ray_bundle = self.raysampler(
134 | cameras=cameras, volumetric_function=volumetric_function, **kwargs
135 | )
136 | # ray_bundle.origins - minibatch x ... x 3
137 | # ray_bundle.directions - minibatch x ... x 3
138 | # ray_bundle.lengths - minibatch x ... x n_pts_per_ray
139 | # ray_bundle.xys - minibatch x ... x 2
140 |
141 | # given sampled rays, call the volumetric function that
142 | # evaluates the densities and features at the locations of the
143 | # ray points
144 | if self.reg is not None:
145 | rays_densities, rays_features, reg_term = volumetric_function(
146 | ray_bundle=ray_bundle, cameras=cameras, **kwargs
147 | )
148 | else:
149 | rays_densities, rays_features, _ = volumetric_function(
150 | ray_bundle=ray_bundle, cameras=cameras, **kwargs
151 | )
152 | # ray_densities - minibatch x ... x n_pts_per_ray x density_dim
153 | # ray_features - minibatch x ... x n_pts_per_ray x feature_dim
154 |
155 | # finally, march along the sampled rays to obtain the renders
156 | images = self.raymarcher(
157 | rays_densities=rays_densities,
158 | rays_features=rays_features,
159 | ray_bundle=ray_bundle,
160 | **kwargs
161 | )
162 | # images - minibatch x ... x (feature_dim + opacity_dim)
163 |
164 | if self.reg is not None:
165 | return images, ray_bundle, reg_term
166 | else:
167 | return images, ray_bundle, 0
--------------------------------------------------------------------------------
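A minimal end-to-end sketch of the renderer above, pairing it with `LightFieldRaymarcher` and a toy volumetric function adapted from the class docstring. Note that the renderer's `forward` unpacks three return values, so the function returns a dummy regularization term; the camera and raysampler settings here are illustrative.

```python
import torch
from pytorch3d.renderer import PerspectiveCameras, MonteCarloRaysampler
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
from utils.eft_renderer import CustomImplicitRenderer
from utils.eft_raymarcher import LightFieldRaymarcher

def volumetric_function(ray_bundle, **kwargs):
    # Toy 0-centered RGB sphere; the third value is a placeholder regularization term.
    pts = ray_bundle_to_ray_points(ray_bundle)
    densities = torch.sigmoid(-100.0 * pts.norm(dim=-1, keepdim=True))
    features = torch.nn.functional.normalize(pts, dim=-1) * 0.5 + 0.5
    return densities, features, 0.0

raysampler = MonteCarloRaysampler(
    min_x=-1.0, max_x=1.0, min_y=-1.0, max_y=1.0,
    n_rays_per_image=128, n_pts_per_ray=32, min_depth=0.1, max_depth=4.0,
)
renderer = CustomImplicitRenderer(raysampler=raysampler, raymarcher=LightFieldRaymarcher())
images, ray_bundle, reg = renderer(
    cameras=PerspectiveCameras(), volumetric_function=volumetric_function
)
# images: (1, 128, 32, 4) -- density and RGB channels concatenated per ray point
```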
/utils/load_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.co3d_dataloader import CO3Dv2Wrapper
3 | from utils.co3d_toy_dataloader import CO3Dv2ToyLoader
4 |
5 | def load_dataset_test(args, image_size=None, masked=True):
6 | '''
7 | Load test dataset
8 | '''
9 |
10 | if args.dataset_name == 'co3d':
11 |
12 | if image_size is None:
13 | test_dataset = CO3Dv2Wrapper(root=args.root, category=args.category, sample_batch_size=32, subset='fewview_dev', stage='test', masked=masked)
14 | else:
15 | test_dataset = CO3Dv2Wrapper(root=args.root, category=args.category, sample_batch_size=32, subset='fewview_dev', stage='test', image_size=image_size, masked=masked)
16 |
17 | elif args.dataset_name == 'co3d_toy':
18 | test_dataset = CO3Dv2ToyLoader(root=args.root, category=args.category)
19 |
20 | else:
21 | raise NotImplementedError
22 |
23 | return test_dataset
--------------------------------------------------------------------------------
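A brief sketch of calling the loader above; the argument values are placeholders and must point at an existing dataset root:

```python
from argparse import Namespace
from utils.load_dataset import load_dataset_test

# Hypothetical arguments; 'co3d_toy' avoids the full CO3Dv2 download.
args = Namespace(dataset_name='co3d_toy', root='data/co3d_toy', category='hydrant')
test_dataset = load_dataset_test(args)
print(len(test_dataset))
```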
/utils/load_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Functions to load models
3 | '''
4 | from collections import OrderedDict
5 | import importlib
6 | from omegaconf import OmegaConf
7 | import torch
8 | from external.imagen_pytorch import Unet
9 | from sparsefusion.eft import EpipolarFeatureTransformer
10 | from sparsefusion.vldm import DDPM
11 |
12 | def load_models(gpu, args, verbose=False):
13 | '''
14 | Loads models
15 |
16 | Args:
17 | gpu (int): gpu id
18 |         args (Namespace): model args, including:
19 |             args.eft_ckpt (str): path to eft checkpoint (optional)
20 |             args.vae_ckpt (str): path to vae checkpoint (optional)
21 |             args.vldm_ckpt (str): path to vldm checkpoint (optional)
22 | Returns:
23 | eft (PyTorch module): eft feature extractor
24 | vae (PyTorch module): stable diffusion vae
25 |         vldm (PyTorch module): vldm diffusion model
26 | '''
27 |
28 | eft = None
29 | vae = None
30 | vldm = None
31 |
32 | #! LOAD EFT
33 | eft = EpipolarFeatureTransformer(use_r=args.use_r, encoder=args.encoder, return_features=True, remove_unused_layers=False).cuda(gpu)
34 | if args.eft_ckpt is not None:
35 | checkpoint = torch.load(args.eft_ckpt, map_location='cpu')
36 |
37 | model_dict = eft.state_dict()
38 | pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k in model_dict}
39 | model_dict.update(pretrained_dict)
40 | eft.load_state_dict(model_dict)
41 |
42 | print('LOADING 1/3 loaded eft checkpoint from', args.eft_ckpt)
43 |
44 | else:
45 | print('LOADING 1/3 initialized eft from scratch')
46 |
47 |
48 | #! LOAD VAE
49 | if args.vae_ckpt is not None:
50 | vae = load_vae(args.vae_ckpt, verbose=verbose).cuda(gpu)
51 | print('LOADING 2/3 loaded sd vae from', args.vae_ckpt)
52 |
53 | else:
54 |         vae = load_vae(None, verbose=verbose).cuda(gpu)
55 | print('LOADING 2/3 initialized vae from scratch')
56 |
57 | #! LOAD UNet
58 | channels = 4
59 | feature_dim = 256
60 | unet1 = Unet(
61 | channels=channels,
62 | dim = 256,
63 | dim_mults = (1, 2, 4, 4),
64 | num_resnet_blocks = (2, 2, 2, 2),
65 | layer_attns = (False, False, False, True),
66 | layer_cross_attns = (False, False, False, False),
67 | cond_images_channels = feature_dim,
68 | attn_pool_text=False
69 | )
70 |
71 | if verbose:
72 | total_params = sum(p.numel() for p in unet1.parameters())
73 | print(f"{unet1.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
74 |
75 | #! LOAD DIFFUSION
76 | vldm = DDPM(
77 | channels=channels,
78 | unets = (unet1, ),
79 | conditional_encoder = None,
80 | conditional_embed_dim = None,
81 | image_sizes = (32, ),
82 | timesteps = 500,
83 | cond_drop_prob = 0.1,
84 | pred_objectives=args.objective,
85 | conditional=False,
86 | auto_normalize_img=False,
87 | clip_output=True,
88 | dynamic_thresholding=False,
89 | dynamic_thresholding_percentile=.68,
90 | clip_value=10,
91 | ).cuda(gpu)
92 | if args.vldm_ckpt is not None:
93 | checkpoint = torch.load(args.vldm_ckpt, map_location='cpu')
94 | vldm.load_state_dict(checkpoint['model_state_dict'])
95 | print('LOADING 3/3 loaded diffusion from', args.vldm_ckpt)
96 | else:
97 |         print('LOADING 3/3 initialized diffusion from scratch')
98 |
99 |
100 | return eft, vae, vldm
101 |
102 |
103 | def load_vae(vae_ckpt, verbose=False):
104 | '''
105 | Load StableDiffusion VAE
106 | '''
107 | config = OmegaConf.load(f"external/ldm/configs/sd-vae.yaml")
108 | vae = load_model_from_config(config, ckpt=vae_ckpt, verbose=verbose)
109 |
110 | return vae
111 |
112 |
113 | def get_obj_from_str(string, reload=False):
114 | module, cls = string.rsplit(".", 1)
115 | if reload:
116 | module_imp = importlib.import_module(module)
117 | importlib.reload(module_imp)
118 | return getattr(importlib.import_module(module, package=None), cls)
119 |
120 |
121 | def instantiate_from_config(config):
122 |     if "target" not in config:
123 | if config == '__is_first_stage__':
124 | return None
125 | elif config == "__is_unconditional__":
126 | return None
127 | raise KeyError("Expected key `target` to instantiate.")
128 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
129 |
130 |
131 | def load_model_from_config(config, ckpt=None, verbose=True, full_model=False):
132 |
133 | model = instantiate_from_config(config.model)
134 |
135 | if ckpt is not None:
136 | if verbose:
137 | print(f"Loading model from {ckpt}")
138 | pl_sd = torch.load(ckpt, map_location="cpu")
139 | if "global_step" in pl_sd:
140 | if verbose:
141 | print(f"Global Step: {pl_sd['global_step']}")
142 | sd = pl_sd["state_dict"]
143 |
144 | #! Rename state dict params
145 | sd_ = OrderedDict()
146 | for k, v in sd.items():
147 | if not full_model:
148 | name = k.replace('first_stage_model.','').replace('model.','')
149 | else:
150 | name = k
151 | sd_[name] = v
152 |
153 | m, u = model.load_state_dict(sd_, strict=False)
154 |
155 | if len(m) > 0 and verbose:
156 | print("missing keys:")
157 | parent_set = set()
158 | for uk in m:
159 | uk_ = uk[:uk.find('.')]
160 | if uk_ not in parent_set:
161 | parent_set.add(uk_)
162 | print(parent_set)
163 | else:
164 | if verbose:
165 | print('all weights found')
166 | if len(u) > 0 and verbose:
167 | print("unexpected keys:")
168 | print(len(u))
169 | parent_set = set()
170 | for uk in u:
171 | uk_ = uk[:uk.find('.')]
172 | if uk_ not in parent_set:
173 | parent_set.add(uk_)
174 | print(parent_set)
175 | else:
176 | if verbose:
177 | print('initializing from scratch')
178 |
179 | model.eval()
180 | return model
--------------------------------------------------------------------------------
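A small sketch of the config-driven instantiation used by `load_vae` above (`get_obj_from_str` resolves a dotted class path, and `instantiate_from_config` constructs it with the config's `params`); `torch.nn.Linear` is just a stand-in target for illustration:

```python
from omegaconf import OmegaConf
from utils.load_model import instantiate_from_config

# Stand-in config; load_vae does the same with external/ldm/configs/sd-vae.yaml.
config = OmegaConf.create(
    {'target': 'torch.nn.Linear', 'params': {'in_features': 8, 'out_features': 4}}
)
layer = instantiate_from_config(config)
print(layer)  # Linear(in_features=8, out_features=4, bias=True)
```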
/utils/render_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Initialize Renderers
3 | '''
4 |
5 | import torch
6 | from pytorch3d.renderer import (
7 | PerspectiveCameras,
8 | MonteCarloRaysampler,
9 | EmissionAbsorptionRaymarcher,
10 | GridRaysampler,
11 | )
12 | from utils.eft_renderer import CustomImplicitRenderer
13 | from utils.eft_raymarcher import LightFieldRaymarcher
14 |
15 |
16 | def init_ray_sampler(gpu, img_h, img_w, min=0.1, max=4.0, bbox=None, n_pts_per_ray=128, n_rays=750, scale_factor=None):
17 | '''
18 | Construct ray samplers for torch-ngp
19 |
20 | Args:
21 | gpu (int): gpu id
22 | img_h (int): image height
23 | img_w (int): image width
24 |         min (float): min depth for points along a ray
25 |         max (float): max depth for points along a ray
26 | bbox (List): bounding box for monte carlo sampler
27 | n_pts_per_ray (int): number of points along a ray
28 | n_rays (int): number of rays for monte carlo sampler
29 | scale_factor (int): return a grid sampler at a scale factor
30 |
31 | Returns:
32 | sampler_grid (sampler): a grid sampler at full resolution
33 | sampler_mc (sampler): a monte carlo sampler
34 | sampler_feat (sampler): a grid sampler at scale factor resolution
35 | if scale factor is provided
36 | '''
37 |
38 | img_h, img_w = img_h, img_w
39 | volume_extent_world = max
40 | half_pix_width = 1.0 / img_w
41 | half_pix_height = 1.0 / img_h
42 |
43 | raysampler_grid = GridRaysampler(
44 | min_x=1.0 - half_pix_width,
45 | max_x=-1.0 + half_pix_width,
46 | min_y=1.0 - half_pix_height,
47 | max_y=-1.0 + half_pix_height,
48 | image_height=img_h,
49 | image_width=img_w,
50 | n_pts_per_ray=n_pts_per_ray,
51 | min_depth=min,
52 | max_depth=volume_extent_world,
53 | )
54 | if scale_factor is not None:
55 | raysampler_features = GridRaysampler(
56 | min_x=1.0 - half_pix_width,
57 | max_x=-1.0 + half_pix_width,
58 | min_y=1.0 - half_pix_height,
59 | max_y=-1.0 + half_pix_height,
60 | image_height=int(img_h//scale_factor),
61 | image_width=int(img_w//scale_factor),
62 | n_pts_per_ray=20,
63 | min_depth=min,
64 | max_depth=volume_extent_world,
65 | )
66 | if bbox is None:
67 | raysampler_mc = MonteCarloRaysampler(
68 | min_x = -1.0,
69 | max_x = 1.0,
70 | min_y = -1.0,
71 | max_y = 1.0,
72 | n_rays_per_image=n_rays,
73 | n_pts_per_ray=n_pts_per_ray,
74 | min_depth=min,
75 | max_depth=volume_extent_world,
76 | )
77 | elif bbox is not None:
78 | raysampler_mc = MonteCarloRaysampler(
79 | min_x = -bbox[0,1],
80 | max_x = -bbox[0,3],
81 | min_y = -bbox[0,0],
82 | max_y = -bbox[0,2],
83 | n_rays_per_image=n_rays,
84 | n_pts_per_ray=n_pts_per_ray,
85 | min_depth=min,
86 | max_depth=volume_extent_world,
87 | )
88 |
89 | if scale_factor is not None:
90 | return raysampler_grid, raysampler_mc, raysampler_features
91 | else:
92 | return raysampler_grid, raysampler_mc
93 |
94 |
95 | def init_light_field_renderer(gpu, img_h, img_w, min=0.1, max=4.0, bbox=None, n_pts_per_ray=128, n_rays=750, scale_factor=None):
96 | '''
97 | Construct implicit renderers for EFT
98 |
99 | Args:
100 | gpu (int): gpu id
101 | img_h (int): image height
102 | img_w (int): image width
103 |         min (float): min depth for points along a ray
104 |         max (float): max depth for points along a ray
105 | bbox (List): bounding box for monte carlo sampler
106 | n_pts_per_ray (int): number of points along a ray
107 | n_rays (int): number of rays for monte carlo sampler
108 | scale_factor (int): return a grid sampler at a scale factor
109 |
110 | Returns:
111 | renderer_grid (renderer): a grid renderer at full resolution
112 | renderer_mc (renderer): a monte carlo renderer
113 | renderer_feat (renderer): a grid renderer at scale factor resolution
114 | if scale factor is provided
115 | '''
116 |
117 | img_h, img_w = img_h, img_w
118 | volume_extent_world = max
119 | half_pix_width = 1.0 / img_w
120 | half_pix_height = 1.0 / img_h
121 |
122 | raysampler_grid = GridRaysampler(
123 | min_x=1.0 - half_pix_width,
124 | max_x=-1.0 + half_pix_width,
125 | min_y=1.0 - half_pix_height,
126 | max_y=-1.0 + half_pix_height,
127 | image_height=img_h,
128 | image_width=img_w,
129 | n_pts_per_ray=n_pts_per_ray,
130 | min_depth=min,
131 | max_depth=volume_extent_world,
132 | )
133 | if scale_factor is not None:
134 | raysampler_features = GridRaysampler(
135 | min_x=1.0 - half_pix_width,
136 | max_x=-1.0 + half_pix_width,
137 | min_y=1.0 - half_pix_height,
138 | max_y=-1.0 + half_pix_height,
139 | image_height=int(img_h//scale_factor),
140 | image_width=int(img_w//scale_factor),
141 | n_pts_per_ray=20,
142 | min_depth=min,
143 | max_depth=volume_extent_world,
144 | )
145 | if bbox is None:
146 | raysampler_mc = MonteCarloRaysampler(
147 | min_x = -1.0,
148 | max_x = 1.0,
149 | min_y = -1.0,
150 | max_y = 1.0,
151 | n_rays_per_image=n_rays,
152 | n_pts_per_ray=n_pts_per_ray,
153 | min_depth=min,
154 | max_depth=volume_extent_world,
155 | )
156 | elif bbox is not None:
157 | raysampler_mc = MonteCarloRaysampler(
158 | min_x = -bbox[0,1],
159 | max_x = -bbox[0,3],
160 | min_y = -bbox[0,0],
161 | max_y = -bbox[0,2],
162 | n_rays_per_image=n_rays,
163 | n_pts_per_ray=n_pts_per_ray,
164 | min_depth=min,
165 | max_depth=volume_extent_world,
166 | )
167 |
168 | raymarcher = LightFieldRaymarcher()
169 |
170 | renderer_grid = CustomImplicitRenderer(
171 | raysampler=raysampler_grid, raymarcher=raymarcher, reg=True
172 | )
173 | renderer_mc = CustomImplicitRenderer(
174 | raysampler=raysampler_mc, raymarcher=raymarcher, reg=True
175 | )
176 |
177 | renderer_grid = renderer_grid.cuda(gpu)
178 | renderer_mc = renderer_mc.cuda(gpu)
179 |
180 | if scale_factor is None:
181 | return renderer_grid, renderer_mc
182 | else:
183 | renderer_feat = CustomImplicitRenderer(
184 | raysampler=raysampler_features, raymarcher=raymarcher, reg=True
185 | )
186 | renderer_feat = renderer_feat.cuda(gpu)
187 | return renderer_grid, renderer_mc, renderer_feat
--------------------------------------------------------------------------------
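A closing sketch of the constructors above; the image size, depth range, and scale factor are illustrative, and the light-field renderers assume a CUDA device since they are moved to `gpu`:

```python
from utils.render_utils import init_ray_sampler, init_light_field_renderer

# Without scale_factor two samplers are returned; with it, a third low-resolution
# feature sampler (here 256 // 8 = 32 pixels per side) is added.
sampler_grid, sampler_mc = init_ray_sampler(gpu=0, img_h=256, img_w=256)
sampler_grid, sampler_mc, sampler_feat = init_ray_sampler(
    gpu=0, img_h=256, img_w=256, scale_factor=8
)

# EFT renderers wrap the samplers with LightFieldRaymarcher (CUDA assumed).
renderer_grid, renderer_mc, renderer_feat = init_light_field_renderer(
    gpu=0, img_h=256, img_w=256, n_pts_per_ray=64, scale_factor=8
)
```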