├── LICENSE
├── README.md
├── configs
├── flower.txt
├── lego.txt
├── truck.txt
├── wineholder.txt
└── your_own_data.txt
├── dataLoader
├── __init__.py
├── blender.py
├── colmap2nerf.py
├── llff.py
├── nsvf.py
├── ray_utils.py
├── tankstemple.py
└── your_own_data.py
├── extra
├── auto_run_paramsets.py
└── compute_metrics.py
├── models
├── __init__.py
├── sh.py
├── tensoRF.py
└── tensorBase.py
├── opt.py
├── renderer.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Anpei Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensoRF
2 | ## [Project page](https://apchenstu.github.io/TensoRF/) | [Paper](https://arxiv.org/abs/2203.09517)
3 | This repository contains a pytorch implementation for the paper: [TensoRF: Tensorial Radiance Fields](https://arxiv.org/abs/2203.09517). Our work present a novel approach to model and reconstruct radiance fields, which achieves super
4 | **fast** training process, **compact** memory footprint and **state-of-the-art** rendering quality.
5 |
6 |
7 | https://user-images.githubusercontent.com/16453770/158920837-3fafaa17-6ed9-4414-a0b1-a80dc9e10301.mp4
8 | ## Installation
9 |
10 | #### Tested on Ubuntu 20.04 + Pytorch 1.10.1
11 |
12 | Install environment:
13 | ```
14 | conda create -n TensoRF python=3.8
15 | conda activate TensoRF
16 | pip install torch torchvision
17 | pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia lpips tensorboard
18 | ```
19 |
20 |
21 | ## Dataset
22 | * [Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
23 | * [Synthetic-NSVF](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip)
24 | * [Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)
25 | * [Forward-facing](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
26 |
27 |
28 |
29 | ## Quick Start
30 | The training script is in `train.py`, to train a TensoRF:
31 |
32 | ```
33 | python train.py --config configs/lego.txt
34 | ```
35 |
36 |
37 | we provide a few examples in the configuration folder, please note:
38 |
39 | `dataset_name`, choices = ['blender', 'llff', 'nsvf', 'tankstemple'];
40 |
41 | `shadingMode`, choices = ['MLP_Fea', 'SH'];
42 |
43 | `model_name`, choices = ['TensorVMSplit', 'TensorCP'], corresponding to the VM and CP decomposition.
44 | You need to uncomment the last a few rows of the configuration file if you want to training with the TensorCP model;
45 |
46 | `n_lamb_sigma` and `n_lamb_sh` are string type refer to the basis number of density and appearance along XYZ
47 | dimension;
48 |
49 | `N_voxel_init` and `N_voxel_final` control the resolution of matrix and vector;
50 |
51 | `N_vis` and `vis_every` control the visualization during training;
52 |
53 | You need to set `--render_test 1`/`--render_path 1` if you want to render testing views or path after training.
54 |
55 | More options refer to the `opt.py`.
56 |
57 | ### For pretrained checkpoints and results please see:
58 | [https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm)
59 |
60 |
61 |
62 | ## Rendering
63 |
64 | ```
65 | python train.py --config configs/lego.txt --ckpt path/to/your/checkpoint --render_only 1 --render_test 1
66 | ```
67 |
68 | You can just simply pass `--render_only 1` and `--ckpt path/to/your/checkpoint` to render images from a pre-trained
69 | checkpoint. You may also need to specify what you want to render, like `--render_test 1`, `--render_train 1` or `--render_path 1`.
70 | The rendering results are located in your checkpoint folder.
71 |
72 | ## Extracting mesh
73 | You can also export the mesh by passing `--export_mesh 1`:
74 | ```
75 | python train.py --config configs/lego.txt --ckpt path/to/your/checkpoint --export_mesh 1
76 | ```
77 | Note: Please re-train the model and don't use the pretrained checkpoints provided by us for mesh extraction,
78 | because some render parameters has changed.
79 |
80 | ## Training with your own data
81 | We provide two options for training on your own image set:
82 |
83 | 1. Following the instructions in the [NSVF repo](https://github.com/facebookresearch/NSVF#prepare-your-own-dataset), then set the dataset_name to 'tankstemple'.
84 | 2. Calibrating images with the script from [NGP](https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md):
85 | `python dataLoader/colmap2nerf.py --colmap_matcher exhaustive --run_colmap`, then adjust the datadir in `configs/your_own_data.txt`. Please check the `scene_bbox` and `near_far` if you get abnormal results.
86 |
87 |
88 | ## Citation
89 | If you find our code or paper helps, please consider citing:
90 | ```
91 | @INPROCEEDINGS{Chen2022ECCV,
92 | author = {Anpei Chen and Zexiang Xu and Andreas Geiger and Jingyi Yu and Hao Su},
93 | title = {TensoRF: Tensorial Radiance Fields},
94 | booktitle = {European Conference on Computer Vision (ECCV)},
95 | year = {2022}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/configs/flower.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = llff
3 | datadir = ./data/nerf_llff_data/flower
4 | expname = tensorf_flower_VM
5 | basedir = ./log
6 |
7 | downsample_train = 4.0
8 | ndc_ray = 1
9 |
10 | n_iters = 25000
11 | batch_size = 4096
12 |
13 | N_voxel_init = 2097156 # 128**3
14 | N_voxel_final = 262144000 # 640**3
15 | upsamp_list = [2000,3000,4000,5500]
16 | update_AlphaMask_list = [2500]
17 |
18 | N_vis = -1 # vis all testing images
19 | vis_every = 10000
20 |
21 | render_test = 1
22 | render_path = 1
23 |
24 | n_lamb_sigma = [16,4,4]
25 | n_lamb_sh = [48,12,12]
26 |
27 | shadingMode = MLP_Fea
28 | fea2denseAct = relu
29 |
30 | view_pe = 0
31 | fea_pe = 0
32 |
33 | TV_weight_density = 1.0
34 | TV_weight_app = 1.0
35 |
36 |
--------------------------------------------------------------------------------
/configs/lego.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = blender
3 | datadir = ./data/nerf_synthetic/lego
4 | expname = tensorf_lego_VM
5 | basedir = ./log
6 |
7 | n_iters = 30000
8 | batch_size = 4096
9 |
10 | N_voxel_init = 2097156 # 128**3
11 | N_voxel_final = 27000000 # 300**3
12 | upsamp_list = [2000,3000,4000,5500,7000]
13 | update_AlphaMask_list = [2000,4000]
14 |
15 | N_vis = 5
16 | vis_every = 10000
17 |
18 | render_test = 1
19 |
20 | n_lamb_sigma = [16,16,16]
21 | n_lamb_sh = [48,48,48]
22 | model_name = TensorVMSplit
23 |
24 |
25 | shadingMode = MLP_Fea
26 | fea2denseAct = softplus
27 |
28 | view_pe = 2
29 | fea_pe = 2
30 |
31 | L1_weight_inital = 8e-5
32 | L1_weight_rest = 4e-5
33 | rm_weight_mask_thre = 1e-4
34 |
35 | ## please uncomment following configuration if hope to training on cp model
36 | #model_name = TensorCP
37 | #n_lamb_sigma = [96]
38 | #n_lamb_sh = [288]
39 | #N_voxel_final = 125000000 # 500**3
40 | #L1_weight_inital = 1e-5
41 | #L1_weight_rest = 1e-5
42 |
--------------------------------------------------------------------------------
/configs/truck.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | dataset_name = tankstemple
4 | datadir = ./data/TanksAndTemple/Truck
5 | expname = tensorf_truck_VM
6 | basedir = ./log
7 |
8 | n_iters = 30000
9 | batch_size = 4096
10 |
11 | N_voxel_init = 2097156 # 128**3
12 | N_voxel_final = 27000000 # 300**3
13 | upsamp_list = [2000,3000,4000,5500,7000]
14 | update_AlphaMask_list = [2000,4000]
15 |
16 | N_vis = 5
17 | vis_every = 10000
18 |
19 | render_test = 1
20 |
21 | n_lamb_sigma = [16,16,16]
22 | n_lamb_sh = [48,48,48]
23 |
24 | shadingMode = MLP_Fea
25 | fea2denseAct = softplus
26 |
27 | view_pe = 2
28 | fea_pe = 2
29 |
30 | TV_weight_density = 0.1
31 | TV_weight_app = 0.01
32 |
33 | ## please uncomment following configuration if hope to training on cp model
34 | #model_name = TensorCP
35 | #n_lamb_sigma = [96]
36 | #n_lamb_sh = [288]
37 | #N_voxel_final = 125000000 # 500**3
38 | #L1_weight_inital = 1e-5
39 | #L1_weight_rest = 1e-5
40 |
41 |
--------------------------------------------------------------------------------
/configs/wineholder.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = nsvf
3 | datadir = ./data/Synthetic_NSVF/Wineholder
4 | expname = tensorf_Wineholder_VM
5 | basedir = ./log
6 |
7 | n_iters = 30000
8 | batch_size = 4096
9 |
10 | N_voxel_init = 2097156 # 128**3
11 | N_voxel_final = 27000000 # 300**3
12 | upsamp_list = [2000,3000,4000,5500,7000]
13 | update_AlphaMask_list = [2000,4000]
14 |
15 | N_vis = 5
16 | vis_every = 10000
17 |
18 | render_test = 1
19 |
20 | n_lamb_sigma = [16,16,16]
21 | n_lamb_sh = [48,48,48]
22 |
23 | shadingMode = MLP_Fea
24 | fea2denseAct = softplus
25 |
26 | view_pe = 2
27 | fea_pe = 2
28 |
29 | L1_weight_inital = 8e-5
30 | L1_weight_rest = 4e-5
31 | rm_weight_mask_thre = 1e-4
32 |
33 | ## please uncomment following configuration if hope to training on cp model
34 | #model_name = TensorCP
35 | #n_lamb_sigma = [96]
36 | #n_lamb_sh = [288]
37 | #N_voxel_final = 125000000 # 500**3
38 | #L1_weight_inital = 1e-5
39 | #L1_weight_rest = 1e-5
40 |
--------------------------------------------------------------------------------
/configs/your_own_data.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = own_data
3 | datadir = ./data/xxx
4 | expname = tensorf_xxx_VM
5 | basedir = ./log
6 |
7 | n_iters = 30000
8 | batch_size = 4096
9 |
10 | N_voxel_init = 2097156 # 128**3
11 | N_voxel_final = 27000000 # 300**3
12 | upsamp_list = [2000,3000,4000,5500,7000]
13 | update_AlphaMask_list = [2000,4000]
14 |
15 | N_vis = 5
16 | vis_every = 10000
17 |
18 | render_test = 1
19 |
20 | n_lamb_sigma = [16,16,16]
21 | n_lamb_sh = [48,48,48]
22 | model_name = TensorVMSplit
23 |
24 |
25 | shadingMode = MLP_Fea
26 | fea2denseAct = softplus
27 |
28 | view_pe = 2
29 | fea_pe = 2
30 |
31 | view_pe = 2
32 | fea_pe = 2
33 |
34 | TV_weight_density = 0.1
35 | TV_weight_app = 0.01
36 |
37 | rm_weight_mask_thre = 1e-4
38 |
39 | ## please uncomment following configuration if hope to training on cp model
40 | #model_name = TensorCP
41 | #n_lamb_sigma = [96]
42 | #n_lamb_sh = [288]
43 | #N_voxel_final = 125000000 # 500**3
44 | #L1_weight_inital = 1e-5
45 | #L1_weight_rest = 1e-5
46 |
--------------------------------------------------------------------------------
/dataLoader/__init__.py:
--------------------------------------------------------------------------------
1 | from .llff import LLFFDataset
2 | from .blender import BlenderDataset
3 | from .nsvf import NSVF
4 | from .tankstemple import TanksTempleDataset
5 | from .your_own_data import YourOwnDataset
6 |
7 |
8 |
9 | dataset_dict = {'blender': BlenderDataset,
10 | 'llff':LLFFDataset,
11 | 'tankstemple':TanksTempleDataset,
12 | 'nsvf':NSVF,
13 | 'own_data':YourOwnDataset}
--------------------------------------------------------------------------------
/dataLoader/blender.py:
--------------------------------------------------------------------------------
1 | import torch,cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 |
10 | from .ray_utils import *
11 |
12 |
13 | class BlenderDataset(Dataset):
14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
15 |
16 | self.N_vis = N_vis
17 | self.root_dir = datadir
18 | self.split = split
19 | self.is_stack = is_stack
20 | self.img_wh = (int(800/downsample),int(800/downsample))
21 | self.define_transforms()
22 |
23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
25 | self.read_meta()
26 | self.define_proj_mat()
27 |
28 | self.white_bg = True
29 | self.near_far = [2.0,6.0]
30 |
31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
33 | self.downsample=downsample
34 |
35 | def read_depth(self, filename):
36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
37 | return depth
38 |
39 | def read_meta(self):
40 |
41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
42 | self.meta = json.load(f)
43 |
44 | w, h = self.img_wh
45 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
46 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh
47 |
48 |
49 | # ray directions for all pixels, same for all images (same H, W, focal)
50 | self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3)
51 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
52 | self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()
53 |
54 | self.image_paths = []
55 | self.poses = []
56 | self.all_rays = []
57 | self.all_rgbs = []
58 | self.all_masks = []
59 | self.all_depth = []
60 | self.downsample=1.0
61 |
62 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
63 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
64 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#
65 |
66 | frame = self.meta['frames'][i]
67 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv
68 | c2w = torch.FloatTensor(pose)
69 | self.poses += [c2w]
70 |
71 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
72 | self.image_paths += [image_path]
73 | img = Image.open(image_path)
74 |
75 | if self.downsample!=1.0:
76 | img = img.resize(self.img_wh, Image.LANCZOS)
77 | img = self.transform(img) # (4, h, w)
78 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
79 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
80 | self.all_rgbs += [img]
81 |
82 |
83 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
84 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
85 |
86 |
87 | self.poses = torch.stack(self.poses)
88 | if not self.is_stack:
89 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
90 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
91 |
92 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
93 | else:
94 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
95 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
96 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
97 |
98 |
99 | def define_transforms(self):
100 | self.transform = T.ToTensor()
101 |
102 | def define_proj_mat(self):
103 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
104 |
105 | def world2ndc(self,points,lindisp=None):
106 | device = points.device
107 | return (points - self.center.to(device)) / self.radius.to(device)
108 |
109 | def __len__(self):
110 | return len(self.all_rgbs)
111 |
112 | def __getitem__(self, idx):
113 |
114 | if self.split == 'train': # use data in the buffers
115 | sample = {'rays': self.all_rays[idx],
116 | 'rgbs': self.all_rgbs[idx]}
117 |
118 | else: # create data for each image separately
119 |
120 | img = self.all_rgbs[idx]
121 | rays = self.all_rays[idx]
122 | mask = self.all_masks[idx] # for quantity evaluation
123 |
124 | sample = {'rays': rays,
125 | 'rgbs': img,
126 | 'mask': mask}
127 | return sample
128 |
--------------------------------------------------------------------------------
/dataLoader/colmap2nerf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | import argparse
12 | import os
13 | from pathlib import Path, PurePosixPath
14 |
15 | import numpy as np
16 | import json
17 | import sys
18 | import math
19 | import cv2
20 | import os
21 | import shutil
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")
25 |
26 | parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
27 | parser.add_argument("--video_fps", default=2)
28 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
29 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
30 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
31 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
32 | parser.add_argument("--images", default="images", help="input path to the images")
33 | parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
34 | parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
35 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
36 | parser.add_argument("--out", default="transforms.json", help="output path")
37 | args = parser.parse_args()
38 | return args
39 |
40 | def do_system(arg):
41 | print(f"==== running: {arg}")
42 | err = os.system(arg)
43 | if err:
44 | print("FATAL: command failed")
45 | sys.exit(err)
46 |
47 | def run_ffmpeg(args):
48 | if not os.path.isabs(args.images):
49 | args.images = os.path.join(os.path.dirname(args.video_in), args.images)
50 | images = args.images
51 | video = args.video_in
52 | fps = float(args.video_fps) or 1.0
53 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
54 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
55 | sys.exit(1)
56 | try:
57 | shutil.rmtree(images)
58 | except:
59 | pass
60 | do_system(f"mkdir {images}")
61 |
62 | time_slice_value = ""
63 | time_slice = args.time_slice
64 | if time_slice:
65 | start, end = time_slice.split(",")
66 | time_slice_value = f",select='between(t\,{start}\,{end})'"
67 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")
68 |
69 | def run_colmap(args):
70 | db=args.colmap_db
71 | images=args.images
72 | db_noext=str(Path(db).with_suffix(""))
73 |
74 | if args.text=="text":
75 | args.text=db_noext+"_text"
76 | text=args.text
77 | sparse=db_noext+"_sparse"
78 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
79 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
80 | sys.exit(1)
81 | if os.path.exists(db):
82 | os.remove(db)
83 | do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
84 | do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}")
85 | try:
86 | shutil.rmtree(sparse)
87 | except:
88 | pass
89 | do_system(f"mkdir {sparse}")
90 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
91 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
92 | try:
93 | shutil.rmtree(text)
94 | except:
95 | pass
96 | do_system(f"mkdir {text}")
97 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")
98 |
99 | def variance_of_laplacian(image):
100 | return cv2.Laplacian(image, cv2.CV_64F).var()
101 |
102 | def sharpness(imagePath):
103 | image = cv2.imread(imagePath)
104 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
105 | fm = variance_of_laplacian(gray)
106 | return fm
107 |
108 | def qvec2rotmat(qvec):
109 | return np.array([
110 | [
111 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
112 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
113 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
114 | ], [
115 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
116 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
117 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
118 | ], [
119 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
120 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
121 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
122 | ]
123 | ])
124 |
125 | def rotmat(a, b):
126 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
127 | v = np.cross(a, b)
128 | c = np.dot(a, b)
129 | s = np.linalg.norm(v)
130 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
131 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))
132 |
133 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
134 | da = da / np.linalg.norm(da)
135 | db = db / np.linalg.norm(db)
136 | c = np.cross(da, db)
137 | denom = np.linalg.norm(c)**2
138 | t = ob - oa
139 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
140 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
141 | if ta > 0:
142 | ta = 0
143 | if tb > 0:
144 | tb = 0
145 | return (oa+ta*da+ob+tb*db) * 0.5, denom
146 |
147 | if __name__ == "__main__":
148 | args = parse_args()
149 | if args.video_in != "":
150 | run_ffmpeg(args)
151 | if args.run_colmap:
152 | run_colmap(args)
153 | AABB_SCALE = int(args.aabb_scale)
154 | SKIP_EARLY = int(args.skip_early)
155 | IMAGE_FOLDER = args.images
156 | TEXT_FOLDER = args.text
157 | OUT_PATH = args.out
158 | print(f"outputting to {OUT_PATH}...")
159 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
160 | angle_x = math.pi / 2
161 | for line in f:
162 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
163 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
164 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
165 | if line[0] == "#":
166 | continue
167 | els = line.split(" ")
168 | w = float(els[2])
169 | h = float(els[3])
170 | fl_x = float(els[4])
171 | fl_y = float(els[4])
172 | k1 = 0
173 | k2 = 0
174 | p1 = 0
175 | p2 = 0
176 | cx = w / 2
177 | cy = h / 2
178 | if els[1] == "SIMPLE_PINHOLE":
179 | cx = float(els[5])
180 | cy = float(els[6])
181 | elif els[1] == "PINHOLE":
182 | fl_y = float(els[5])
183 | cx = float(els[6])
184 | cy = float(els[7])
185 | elif els[1] == "SIMPLE_RADIAL":
186 | cx = float(els[5])
187 | cy = float(els[6])
188 | k1 = float(els[7])
189 | elif els[1] == "RADIAL":
190 | cx = float(els[5])
191 | cy = float(els[6])
192 | k1 = float(els[7])
193 | k2 = float(els[8])
194 | elif els[1] == "OPENCV":
195 | fl_y = float(els[5])
196 | cx = float(els[6])
197 | cy = float(els[7])
198 | k1 = float(els[8])
199 | k2 = float(els[9])
200 | p1 = float(els[10])
201 | p2 = float(els[11])
202 | else:
203 | print("unknown camera model ", els[1])
204 | # fl = 0.5 * w / tan(0.5 * angle_x);
205 | angle_x = math.atan(w / (fl_x * 2)) * 2
206 | angle_y = math.atan(h / (fl_y * 2)) * 2
207 | fovx = angle_x * 180 / math.pi
208 | fovy = angle_y * 180 / math.pi
209 |
210 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")
211 |
212 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
213 | i = 0
214 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
215 | out = {
216 | "camera_angle_x": angle_x,
217 | "camera_angle_y": angle_y,
218 | "fl_x": fl_x,
219 | "fl_y": fl_y,
220 | "k1": k1,
221 | "k2": k2,
222 | "p1": p1,
223 | "p2": p2,
224 | "cx": cx,
225 | "cy": cy,
226 | "w": w,
227 | "h": h,
228 | "aabb_scale": AABB_SCALE,
229 | "frames": [],
230 | }
231 |
232 | up = np.zeros(3)
233 | for line in f:
234 | line = line.strip()
235 | if line[0] == "#":
236 | continue
237 | i = i + 1
238 | if i < SKIP_EARLY*2:
239 | continue
240 | if i % 2 == 1:
241 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
242 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
243 | # why is this requireing a relitive path while using ^
244 | image_rel = os.path.relpath(IMAGE_FOLDER)
245 | name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
246 | b=sharpness(name)
247 | print(name, "sharpness=",b)
248 | image_id = int(elems[0])
249 | qvec = np.array(tuple(map(float, elems[1:5])))
250 | tvec = np.array(tuple(map(float, elems[5:8])))
251 | R = qvec2rotmat(-qvec)
252 | t = tvec.reshape([3,1])
253 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
254 | c2w = np.linalg.inv(m)
255 | c2w[0:3,2] *= -1 # flip the y and z axis
256 | c2w[0:3,1] *= -1
257 | c2w = c2w[[1,0,2,3],:] # swap y and z
258 | c2w[2,:] *= -1 # flip whole world upside down
259 |
260 | up += c2w[0:3,1]
261 |
262 | frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
263 | out["frames"].append(frame)
264 | nframes = len(out["frames"])
265 | up = up / np.linalg.norm(up)
266 | print("up vector was", up)
267 | R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
268 | R = np.pad(R,[0,1])
269 | R[-1, -1] = 1
270 |
271 |
272 | for f in out["frames"]:
273 | f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis
274 |
275 | # find a central point they are all looking at
276 | print("computing center of attention...")
277 | totw = 0.0
278 | totp = np.array([0.0, 0.0, 0.0])
279 | for f in out["frames"]:
280 | mf = f["transform_matrix"][0:3,:]
281 | for g in out["frames"]:
282 | mg = g["transform_matrix"][0:3,:]
283 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
284 | if w > 0.01:
285 | totp += p*w
286 | totw += w
287 | totp /= totw
288 | print(totp) # the cameras are looking at totp
289 | for f in out["frames"]:
290 | f["transform_matrix"][0:3,3] -= totp
291 |
292 | avglen = 0.
293 | for f in out["frames"]:
294 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
295 | avglen /= nframes
296 | print("avg camera distance from origin", avglen)
297 | for f in out["frames"]:
298 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"
299 |
300 | for f in out["frames"]:
301 | f["transform_matrix"] = f["transform_matrix"].tolist()
302 | print(nframes,"frames")
303 | print(f"writing {OUT_PATH}")
304 | with open(OUT_PATH, "w") as outfile:
305 | json.dump(out, outfile, indent=2)
--------------------------------------------------------------------------------
/dataLoader/llff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import glob
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 |
12 | def normalize(v):
13 | """Normalize a vector."""
14 | return v / np.linalg.norm(v)
15 |
16 |
17 | def average_poses(poses):
18 | """
19 | Calculate the average pose, which is then used to center all poses
20 | using @center_poses. Its computation is as follows:
21 | 1. Compute the center: the average of pose centers.
22 | 2. Compute the z axis: the normalized average z axis.
23 | 3. Compute axis y': the average y axis.
24 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
25 | 5. Compute the y axis: z cross product x.
26 |
27 | Note that at step 3, we cannot directly use y' as y axis since it's
28 | not necessarily orthogonal to z axis. We need to pass from x to y.
29 | Inputs:
30 | poses: (N_images, 3, 4)
31 | Outputs:
32 | pose_avg: (3, 4) the average pose
33 | """
34 | # 1. Compute the center
35 | center = poses[..., 3].mean(0) # (3)
36 |
37 | # 2. Compute the z axis
38 | z = normalize(poses[..., 2].mean(0)) # (3)
39 |
40 | # 3. Compute axis y' (no need to normalize as it's not the final output)
41 | y_ = poses[..., 1].mean(0) # (3)
42 |
43 | # 4. Compute the x axis
44 | x = normalize(np.cross(z, y_)) # (3)
45 |
46 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
47 | y = np.cross(x, z) # (3)
48 |
49 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
50 |
51 | return pose_avg
52 |
53 |
54 | def center_poses(poses, blender2opencv):
55 | """
56 | Center the poses so that we can use NDC.
57 | See https://github.com/bmild/nerf/issues/34
58 | Inputs:
59 | poses: (N_images, 3, 4)
60 | Outputs:
61 | poses_centered: (N_images, 3, 4) the centered poses
62 | pose_avg: (3, 4) the average pose
63 | """
64 | poses = poses @ blender2opencv
65 | pose_avg = average_poses(poses) # (3, 4)
66 | pose_avg_homo = np.eye(4)
67 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation
68 | pose_avg_homo = pose_avg_homo
69 | # by simply adding 0, 0, 0, 1 as the last row
70 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
71 | poses_homo = \
72 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
73 |
74 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
75 | # poses_centered = poses_centered @ blender2opencv
76 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
77 |
78 | return poses_centered, pose_avg_homo
79 |
80 |
81 | def viewmatrix(z, up, pos):
82 | vec2 = normalize(z)
83 | vec1_avg = up
84 | vec0 = normalize(np.cross(vec1_avg, vec2))
85 | vec1 = normalize(np.cross(vec2, vec0))
86 | m = np.eye(4)
87 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
88 | return m
89 |
90 |
91 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
92 | render_poses = []
93 | rads = np.array(list(rads) + [1.])
94 |
95 | for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
96 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
97 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
98 | render_poses.append(viewmatrix(z, up, c))
99 | return render_poses
100 |
101 |
102 | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
103 | # center pose
104 | c2w = average_poses(c2ws_all)
105 |
106 | # Get average pose
107 | up = normalize(c2ws_all[:, :3, 1].sum(0))
108 |
109 | # Find a reasonable "focus depth" for this dataset
110 | dt = 0.75
111 | close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
112 | focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
113 |
114 | # Get radii for spiral path
115 | zdelta = near_fars.min() * .2
116 | tt = c2ws_all[:, :3, 3]
117 | rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
118 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
119 | return np.stack(render_poses)
120 |
121 |
122 | class LLFFDataset(Dataset):
123 | def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):
124 | """
125 | spheric_poses: whether the images are taken in a spheric inward-facing manner
126 | default: False (forward-facing)
127 | val_num: number of val images (used for multigpu training, validate same image for all gpus)
128 | """
129 |
130 | self.root_dir = datadir
131 | self.split = split
132 | self.hold_every = hold_every
133 | self.is_stack = is_stack
134 | self.downsample = downsample
135 | self.define_transforms()
136 |
137 | self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
138 | self.read_meta()
139 | self.white_bg = False
140 |
141 | # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
142 | self.near_far = [0.0, 1.0]
143 | self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
144 | # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
145 | self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
146 | self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
147 |
148 | def read_meta(self):
149 |
150 |
151 | poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17)
152 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
153 | # load full resolution image then resize
154 | if self.split in ['train', 'test']:
155 | assert len(poses_bounds) == len(self.image_paths), \
156 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!'
157 |
158 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
159 | self.near_fars = poses_bounds[:, -2:] # (N_images, 2)
160 | hwf = poses[:, :, -1]
161 |
162 | # Step 1: rescale focal length according to training resolution
163 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images
164 | self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
165 | self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]
166 |
167 | # Step 2: correct poses
168 | # Original poses has rotation in form "down right back", change to "right up back"
169 | # See https://github.com/bmild/nerf/issues/34
170 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
171 | # (N_images, 3, 4) exclude H, W, focal
172 | self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)
173 |
174 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0
175 | # See https://github.com/bmild/nerf/issues/34
176 | near_original = self.near_fars.min()
177 | scale_factor = near_original * 0.75 # 0.75 is the default parameter
178 | # the nearest depth is at 1/0.75=1.33
179 | self.near_fars /= scale_factor
180 | self.poses[..., 3] /= scale_factor
181 |
182 | # build rendering path
183 | N_views, N_rots = 120, 2
184 | tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
185 | up = normalize(self.poses[:, :3, 1].sum(0))
186 | rads = np.percentile(np.abs(tt), 90, 0)
187 |
188 | self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
189 |
190 | # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
191 | # val_idx = np.argmin(distances_from_center) # choose val image as the closest to
192 | # center image
193 |
194 | # ray directions for all pixels, same for all images (same H, W, focal)
195 | W, H = self.img_wh
196 | self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3)
197 |
198 | average_pose = average_poses(self.poses)
199 | dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
200 | i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)]
201 | img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))
202 |
203 | # use first N_images-1 to train, the LAST is val
204 | self.all_rays = []
205 | self.all_rgbs = []
206 | for i in img_list:
207 | image_path = self.image_paths[i]
208 | c2w = torch.FloatTensor(self.poses[i])
209 |
210 | img = Image.open(image_path).convert('RGB')
211 | if self.downsample != 1.0:
212 | img = img.resize(self.img_wh, Image.LANCZOS)
213 | img = self.transform(img) # (3, h, w)
214 |
215 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB
216 | self.all_rgbs += [img]
217 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
218 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
219 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
220 |
221 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
222 |
223 | if not self.is_stack:
224 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
225 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)
226 | else:
227 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h,w, 3)
228 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
229 |
230 |
231 | def define_transforms(self):
232 | self.transform = T.ToTensor()
233 |
234 | def __len__(self):
235 | return len(self.all_rgbs)
236 |
237 | def __getitem__(self, idx):
238 |
239 | sample = {'rays': self.all_rays[idx],
240 | 'rgbs': self.all_rgbs[idx]}
241 |
242 | return sample
--------------------------------------------------------------------------------
/dataLoader/nsvf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from tqdm import tqdm
4 | import os
5 | from PIL import Image
6 | from torchvision import transforms as T
7 |
8 | from .ray_utils import *
9 |
10 | trans_t = lambda t : torch.Tensor([
11 | [1,0,0,0],
12 | [0,1,0,0],
13 | [0,0,1,t],
14 | [0,0,0,1]]).float()
15 |
16 | rot_phi = lambda phi : torch.Tensor([
17 | [1,0,0,0],
18 | [0,np.cos(phi),-np.sin(phi),0],
19 | [0,np.sin(phi), np.cos(phi),0],
20 | [0,0,0,1]]).float()
21 |
22 | rot_theta = lambda th : torch.Tensor([
23 | [np.cos(th),0,-np.sin(th),0],
24 | [0,1,0,0],
25 | [np.sin(th),0, np.cos(th),0],
26 | [0,0,0,1]]).float()
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 | c2w = trans_t(radius)
31 | c2w = rot_phi(phi/180.*np.pi) @ c2w
32 | c2w = rot_theta(theta/180.*np.pi) @ c2w
33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
34 | return c2w
35 |
36 | class NSVF(Dataset):
37 | """NSVF Generic Dataset."""
38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
39 | self.root_dir = datadir
40 | self.split = split
41 | self.is_stack = is_stack
42 | self.downsample = downsample
43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
44 | self.define_transforms()
45 |
46 | self.white_bg = True
47 | self.near_far = [0.5,6.0]
48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
50 | self.read_meta()
51 | self.define_proj_mat()
52 |
53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
54 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
55 |
56 | def bbox2corners(self):
57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
58 | for i in range(3):
59 | corners[i,[0,1],i] = corners[i,[1,0],i]
60 | return corners.view(-1,3)
61 |
62 |
63 | def read_meta(self):
64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
65 | focal = float(f.readline().split()[0])
66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)
68 |
69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
71 |
72 | if self.split == 'train':
73 | pose_files = [x for x in pose_files if x.startswith('0_')]
74 | img_files = [x for x in img_files if x.startswith('0_')]
75 | elif self.split == 'val':
76 | pose_files = [x for x in pose_files if x.startswith('1_')]
77 | img_files = [x for x in img_files if x.startswith('1_')]
78 | elif self.split == 'test':
79 | test_pose_files = [x for x in pose_files if x.startswith('2_')]
80 | test_img_files = [x for x in img_files if x.startswith('2_')]
81 | if len(test_pose_files) == 0:
82 | test_pose_files = [x for x in pose_files if x.startswith('1_')]
83 | test_img_files = [x for x in img_files if x.startswith('1_')]
84 | pose_files = test_pose_files
85 | img_files = test_img_files
86 |
87 | # ray directions for all pixels, same for all images (same H, W, focal)
88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
90 |
91 |
92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
93 |
94 | self.poses = []
95 | self.all_rays = []
96 | self.all_rgbs = []
97 |
98 | assert len(img_files) == len(pose_files)
99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname)
101 | img = Image.open(image_path)
102 | if self.downsample!=1.0:
103 | img = img.resize(self.img_wh, Image.LANCZOS)
104 | img = self.transform(img) # (4, h, w)
105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
106 | if img.shape[-1]==4:
107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
108 | self.all_rgbs += [img]
109 |
110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
111 | c2w = torch.FloatTensor(c2w)
112 | self.poses.append(c2w) # C2W
113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
114 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8)
115 |
116 | # w2c = torch.inverse(c2w)
117 | #
118 |
119 | self.poses = torch.stack(self.poses)
120 | if 'train' == self.split:
121 | if self.is_stack:
122 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3)
123 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3)
124 | else:
125 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
126 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
127 | else:
128 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
129 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
130 |
131 |
132 | def define_transforms(self):
133 | self.transform = T.ToTensor()
134 |
135 | def define_proj_mat(self):
136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
137 |
138 | def world2ndc(self, points):
139 | device = points.device
140 | return (points - self.center.to(device)) / self.radius.to(device)
141 |
142 | def __len__(self):
143 | if self.split == 'train':
144 | return len(self.all_rays)
145 | return len(self.all_rgbs)
146 |
147 | def __getitem__(self, idx):
148 |
149 | if self.split == 'train': # use data in the buffers
150 | sample = {'rays': self.all_rays[idx],
151 | 'rgbs': self.all_rgbs[idx]}
152 |
153 | else: # create data for each image separately
154 |
155 | img = self.all_rgbs[idx]
156 | rays = self.all_rays[idx]
157 |
158 | sample = {'rays': rays,
159 | 'rgbs': img}
160 | return sample
--------------------------------------------------------------------------------
/dataLoader/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch, re
2 | import numpy as np
3 | from torch import searchsorted
4 | from kornia import create_meshgrid
5 |
6 |
7 | # from utils import index_point_feature
8 |
9 | def depth2dist(z_vals, cos_angle):
10 | # z_vals: [N_ray N_sample]
11 | device = z_vals.device
12 | dists = z_vals[..., 1:] - z_vals[..., :-1]
13 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples]
14 | dists = dists * cos_angle.unsqueeze(-1)
15 | return dists
16 |
17 |
18 | def ndc2dist(ndc_pts, cos_angle):
19 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
20 | dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples]
21 | return dists
22 |
23 |
24 | def get_ray_directions(H, W, focal, center=None):
25 | """
26 | Get ray directions for all pixels in camera coordinate.
27 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
28 | ray-tracing-generating-camera-rays/standard-coordinate-systems
29 | Inputs:
30 | H, W, focal: image height, width and focal length
31 | Outputs:
32 | directions: (H, W, 3), the direction of the rays in camera coordinate
33 | """
34 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
35 |
36 | i, j = grid.unbind(-1)
37 | # the direction here is without +0.5 pixel centering as calibration is not so accurate
38 | # see https://github.com/bmild/nerf/issues/24
39 | cent = center if center is not None else [W / 2, H / 2]
40 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
41 |
42 | return directions
43 |
44 |
45 | def get_ray_directions_blender(H, W, focal, center=None):
46 | """
47 | Get ray directions for all pixels in camera coordinate.
48 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
49 | ray-tracing-generating-camera-rays/standard-coordinate-systems
50 | Inputs:
51 | H, W, focal: image height, width and focal length
52 | Outputs:
53 | directions: (H, W, 3), the direction of the rays in camera coordinate
54 | """
55 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
56 | i, j = grid.unbind(-1)
57 | # the direction here is without +0.5 pixel centering as calibration is not so accurate
58 | # see https://github.com/bmild/nerf/issues/24
59 | cent = center if center is not None else [W / 2, H / 2]
60 | directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
61 | -1) # (H, W, 3)
62 |
63 | return directions
64 |
65 |
66 | def get_rays(directions, c2w):
67 | """
68 | Get ray origin and normalized directions in world coordinate for all pixels in one image.
69 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
70 | ray-tracing-generating-camera-rays/standard-coordinate-systems
71 | Inputs:
72 | directions: (H, W, 3) precomputed ray directions in camera coordinate
73 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
74 | Outputs:
75 | rays_o: (H*W, 3), the origin of the rays in world coordinate
76 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
77 | """
78 | # Rotate ray directions from camera coordinate to the world coordinate
79 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
80 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
81 | # The origin of all rays is the camera origin in world coordinate
82 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)
83 |
84 | rays_d = rays_d.view(-1, 3)
85 | rays_o = rays_o.view(-1, 3)
86 |
87 | return rays_o, rays_d
88 |
89 |
90 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
91 | # Shift ray origins to near plane
92 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
93 | rays_o = rays_o + t[..., None] * rays_d
94 |
95 | # Projection
96 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
97 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
98 | o2 = 1. + 2. * near / rays_o[..., 2]
99 |
100 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
101 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
102 | d2 = -2. * near / rays_o[..., 2]
103 |
104 | rays_o = torch.stack([o0, o1, o2], -1)
105 | rays_d = torch.stack([d0, d1, d2], -1)
106 |
107 | return rays_o, rays_d
108 |
109 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
110 | # Shift ray origins to near plane
111 | t = (near - rays_o[..., 2]) / rays_d[..., 2]
112 | rays_o = rays_o + t[..., None] * rays_d
113 |
114 | # Projection
115 | o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
116 | o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
117 | o2 = 1. - 2. * near / rays_o[..., 2]
118 |
119 | d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
120 | d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
121 | d2 = 2. * near / rays_o[..., 2]
122 |
123 | rays_o = torch.stack([o0, o1, o2], -1)
124 | rays_d = torch.stack([d0, d1, d2], -1)
125 |
126 | return rays_o, rays_d
127 |
128 | # Hierarchical sampling (section 5.2)
129 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
130 | device = weights.device
131 | # Get pdf
132 | weights = weights + 1e-5 # prevent nans
133 | pdf = weights / torch.sum(weights, -1, keepdim=True)
134 | cdf = torch.cumsum(pdf, -1)
135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
136 |
137 | # Take uniform samples
138 | if det:
139 | u = torch.linspace(0., 1., steps=N_samples, device=device)
140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
141 | else:
142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)
143 |
144 | # Pytest, overwrite u with numpy's fixed random numbers
145 | if pytest:
146 | np.random.seed(0)
147 | new_shape = list(cdf.shape[:-1]) + [N_samples]
148 | if det:
149 | u = np.linspace(0., 1., N_samples)
150 | u = np.broadcast_to(u, new_shape)
151 | else:
152 | u = np.random.rand(*new_shape)
153 | u = torch.Tensor(u)
154 |
155 | # Invert CDF
156 | u = u.contiguous()
157 | inds = searchsorted(cdf.detach(), u, right=True)
158 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
159 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
160 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
161 |
162 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
163 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
164 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
165 |
166 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
167 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
168 | t = (u - cdf_g[..., 0]) / denom
169 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
170 |
171 | return samples
172 |
173 |
174 | def dda(rays_o, rays_d, bbox_3D):
175 | inv_ray_d = 1.0 / (rays_d + 1e-6)
176 | t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3
177 | t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
178 | t = torch.stack((t_min, t_max)) # 2 N_rays 3
179 | t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
180 | t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
181 | return t_min, t_max
182 |
183 |
184 | def ray_marcher(rays,
185 | N_samples=64,
186 | lindisp=False,
187 | perturb=0,
188 | bbox_3D=None):
189 | """
190 | sample points along the rays
191 | Inputs:
192 | rays: ()
193 |
194 | Returns:
195 |
196 | """
197 |
198 | # Decompose the inputs
199 | N_rays = rays.shape[0]
200 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
201 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)
202 |
203 | if bbox_3D is not None:
204 | # cal aabb boundles
205 | near, far = dda(rays_o, rays_d, bbox_3D)
206 |
207 | # Sample depth points
208 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples)
209 | if not lindisp: # use linear sampling in depth space
210 | z_vals = near * (1 - z_steps) + far * z_steps
211 | else: # use linear sampling in disparity space
212 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
213 |
214 | z_vals = z_vals.expand(N_rays, N_samples)
215 |
216 | if perturb > 0: # perturb sampling depths (z_vals)
217 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points
218 | # get intervals between samples
219 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
220 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)
221 |
222 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
223 | z_vals = lower + (upper - lower) * perturb_rand
224 |
225 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \
226 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)
227 |
228 | return xyz_coarse_sampled, rays_o, rays_d, z_vals
229 |
230 |
231 | def read_pfm(filename):
232 | file = open(filename, 'rb')
233 | color = None
234 | width = None
235 | height = None
236 | scale = None
237 | endian = None
238 |
239 | header = file.readline().decode('utf-8').rstrip()
240 | if header == 'PF':
241 | color = True
242 | elif header == 'Pf':
243 | color = False
244 | else:
245 | raise Exception('Not a PFM file.')
246 |
247 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
248 | if dim_match:
249 | width, height = map(int, dim_match.groups())
250 | else:
251 | raise Exception('Malformed PFM header.')
252 |
253 | scale = float(file.readline().rstrip())
254 | if scale < 0: # little-endian
255 | endian = '<'
256 | scale = -scale
257 | else:
258 | endian = '>' # big-endian
259 |
260 | data = np.fromfile(file, endian + 'f')
261 | shape = (height, width, 3) if color else (height, width)
262 |
263 | data = np.reshape(data, shape)
264 | data = np.flipud(data)
265 | file.close()
266 | return data, scale
267 |
268 |
269 | def ndc_bbox(all_rays):
270 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
271 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
272 | far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
273 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
274 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
275 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))
--------------------------------------------------------------------------------
/dataLoader/tankstemple.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from tqdm import tqdm
4 | import os
5 | from PIL import Image
6 | from torchvision import transforms as T
7 |
8 | from .ray_utils import *
9 |
10 |
11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
12 | if axis == 'z':
13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
14 | elif axis == 'y':
15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
16 | else:
17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]
18 |
19 |
20 | def cross(x, y, axis=0):
21 | T = torch if isinstance(x, torch.Tensor) else np
22 | return T.cross(x, y, axis)
23 |
24 |
25 | def normalize(x, axis=-1, order=2):
26 | if isinstance(x, torch.Tensor):
27 | l2 = x.norm(p=order, dim=axis, keepdim=True)
28 | return x / (l2 + 1e-8), l2
29 |
30 | else:
31 | l2 = np.linalg.norm(x, order, axis)
32 | l2 = np.expand_dims(l2, axis)
33 | l2[l2 == 0] = 1
34 | return x / l2,
35 |
36 |
37 | def cat(x, axis=1):
38 | if isinstance(x[0], torch.Tensor):
39 | return torch.cat(x, dim=axis)
40 | return np.concatenate(x, axis=axis)
41 |
42 |
43 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
44 | """
45 | This function takes a vector 'camera_position' which specifies the location
46 | of the camera in world coordinates and two vectors `at` and `up` which
47 | indicate the position of the object and the up directions of the world
48 | coordinate system respectively. The object is assumed to be centered at
49 | the origin.
50 | The output is a rotation matrix representing the transformation
51 | from world coordinates -> view coordinates.
52 | Input:
53 | camera_position: 3
54 | at: 1 x 3 or N x 3 (0, 0, 0) in default
55 | up: 1 x 3 or N x 3 (0, 1, 0) in default
56 | """
57 |
58 | if at is None:
59 | at = torch.zeros_like(camera_position)
60 | else:
61 | at = torch.tensor(at).type_as(camera_position)
62 | if up is None:
63 | up = torch.zeros_like(camera_position)
64 | up[2] = -1
65 | else:
66 | up = torch.tensor(up).type_as(camera_position)
67 |
68 | z_axis = normalize(at - camera_position)[0]
69 | x_axis = normalize(cross(up, z_axis))[0]
70 | y_axis = normalize(cross(z_axis, x_axis))[0]
71 |
72 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
73 | return R
74 |
75 |
76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
77 | c2ws = []
78 | for t in range(frames):
79 | c2w = torch.eye(4)
80 | cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
81 | cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
82 | c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
83 | c2ws.append(c2w)
84 | return torch.stack(c2ws)
85 |
86 | class TanksTempleDataset(Dataset):
87 | """NSVF Generic Dataset."""
88 | def __init__(self, datadir, split='train', downsample=1.0, wh=[1920,1080], is_stack=False):
89 | self.root_dir = datadir
90 | self.split = split
91 | self.is_stack = is_stack
92 | self.downsample = downsample
93 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
94 | self.define_transforms()
95 |
96 | self.white_bg = True
97 | self.near_far = [0.01,6.0]
98 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2
99 |
100 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
101 | self.read_meta()
102 | self.define_proj_mat()
103 |
104 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
105 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
106 |
107 | def bbox2corners(self):
108 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
109 | for i in range(3):
110 | corners[i,[0,1],i] = corners[i,[1,0],i]
111 | return corners.view(-1,3)
112 |
113 |
114 | def read_meta(self):
115 |
116 | self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
117 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)
118 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
119 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
120 |
121 | if self.split == 'train':
122 | pose_files = [x for x in pose_files if x.startswith('0_')]
123 | img_files = [x for x in img_files if x.startswith('0_')]
124 | elif self.split == 'val':
125 | pose_files = [x for x in pose_files if x.startswith('1_')]
126 | img_files = [x for x in img_files if x.startswith('1_')]
127 | elif self.split == 'test':
128 | test_pose_files = [x for x in pose_files if x.startswith('2_')]
129 | test_img_files = [x for x in img_files if x.startswith('2_')]
130 | if len(test_pose_files) == 0:
131 | test_pose_files = [x for x in pose_files if x.startswith('1_')]
132 | test_img_files = [x for x in img_files if x.startswith('1_')]
133 | pose_files = test_pose_files
134 | img_files = test_img_files
135 |
136 | # ray directions for all pixels, same for all images (same H, W, focal)
137 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
138 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
139 |
140 |
141 |
142 | self.poses = []
143 | self.all_rays = []
144 | self.all_rgbs = []
145 |
146 | assert len(img_files) == len(pose_files)
147 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
148 | image_path = os.path.join(self.root_dir, 'rgb', img_fname)
149 | img = Image.open(image_path)
150 | if self.downsample!=1.0:
151 | img = img.resize(self.img_wh, Image.LANCZOS)
152 | img = self.transform(img) # (4, h, w)
153 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
154 | if img.shape[-1]==4:
155 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
156 | self.all_rgbs.append(img)
157 |
158 |
159 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans
160 | c2w = torch.FloatTensor(c2w)
161 | self.poses.append(c2w) # C2W
162 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
163 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8)
164 |
165 | self.poses = torch.stack(self.poses)
166 |
167 | center = torch.mean(self.scene_bbox, dim=0)
168 | radius = torch.norm(self.scene_bbox[1]-center)*1.2
169 | up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
170 | pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')
171 | self.render_path = gen_path(pos_gen, up=up,frames=200)
172 | self.render_path[:, :3, 3] += center
173 |
174 |
175 |
176 | if 'train' == self.split:
177 | if self.is_stack:
178 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3)
179 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3)
180 | else:
181 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
182 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
183 | else:
184 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
185 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
186 |
187 |
188 | def define_transforms(self):
189 | self.transform = T.ToTensor()
190 |
191 | def define_proj_mat(self):
192 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
193 |
194 | def world2ndc(self, points):
195 | device = points.device
196 | return (points - self.center.to(device)) / self.radius.to(device)
197 |
198 | def __len__(self):
199 | if self.split == 'train':
200 | return len(self.all_rays)
201 | return len(self.all_rgbs)
202 |
203 | def __getitem__(self, idx):
204 |
205 | if self.split == 'train': # use data in the buffers
206 | sample = {'rays': self.all_rays[idx],
207 | 'rgbs': self.all_rgbs[idx]}
208 |
209 | else: # create data for each image separately
210 |
211 | img = self.all_rgbs[idx]
212 | rays = self.all_rays[idx]
213 |
214 | sample = {'rays': rays,
215 | 'rgbs': img}
216 | return sample
--------------------------------------------------------------------------------
/dataLoader/your_own_data.py:
--------------------------------------------------------------------------------
1 | import torch,cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 |
10 | from .ray_utils import *
11 |
12 |
13 | class YourOwnDataset(Dataset):
14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
15 |
16 | self.N_vis = N_vis
17 | self.root_dir = datadir
18 | self.split = split
19 | self.is_stack = is_stack
20 | self.downsample = downsample
21 | self.define_transforms()
22 |
23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
25 | self.read_meta()
26 | self.define_proj_mat()
27 |
28 | self.white_bg = True
29 | self.near_far = [0.1,100.0]
30 |
31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
33 | self.downsample=downsample
34 |
35 | def read_depth(self, filename):
36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
37 | return depth
38 |
39 | def read_meta(self):
40 |
41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
42 | self.meta = json.load(f)
43 |
44 | w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
45 | self.img_wh = [w,h]
46 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
47 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length
48 | self.cx, self.cy = self.meta['cx'],self.meta['cy']
49 |
50 |
51 | # ray directions for all pixels, same for all images (same H, W, focal)
52 | self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3)
53 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
54 | self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()
55 |
56 | self.image_paths = []
57 | self.poses = []
58 | self.all_rays = []
59 | self.all_rgbs = []
60 | self.all_masks = []
61 | self.all_depth = []
62 |
63 |
64 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
65 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
66 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#
67 |
68 | frame = self.meta['frames'][i]
69 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv
70 | c2w = torch.FloatTensor(pose)
71 | self.poses += [c2w]
72 |
73 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
74 | self.image_paths += [image_path]
75 | img = Image.open(image_path)
76 |
77 | if self.downsample!=1.0:
78 | img = img.resize(self.img_wh, Image.LANCZOS)
79 | img = self.transform(img) # (4, h, w)
80 | img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA
81 | if img.shape[-1]==4:
82 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
83 | self.all_rgbs += [img]
84 |
85 |
86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
87 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
88 |
89 |
90 | self.poses = torch.stack(self.poses)
91 | if not self.is_stack:
92 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
93 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
94 |
95 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
96 | else:
97 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
98 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
99 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
100 |
101 |
102 | def define_transforms(self):
103 | self.transform = T.ToTensor()
104 |
105 | def define_proj_mat(self):
106 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
107 |
108 | def world2ndc(self,points,lindisp=None):
109 | device = points.device
110 | return (points - self.center.to(device)) / self.radius.to(device)
111 |
112 | def __len__(self):
113 | return len(self.all_rgbs)
114 |
115 | def __getitem__(self, idx):
116 |
117 | if self.split == 'train': # use data in the buffers
118 | sample = {'rays': self.all_rays[idx],
119 | 'rgbs': self.all_rgbs[idx]}
120 |
121 | else: # create data for each image separately
122 |
123 | img = self.all_rgbs[idx]
124 | rays = self.all_rays[idx]
125 | mask = self.all_masks[idx] # for quantity evaluation
126 |
127 | sample = {'rays': rays,
128 | 'rgbs': img}
129 | return sample
130 |
--------------------------------------------------------------------------------
/extra/auto_run_paramsets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import threading, queue
3 | import numpy as np
4 | import time
5 |
6 |
7 | def getFolderLocker(logFolder):
8 | while True:
9 | try:
10 | os.makedirs(logFolder+"/lockFolder")
11 | break
12 | except:
13 | time.sleep(0.01)
14 |
15 | def releaseFolderLocker(logFolder):
16 | os.removedirs(logFolder+"/lockFolder")
17 |
18 | def getStopFolder(logFolder):
19 | return os.path.isdir(logFolder+"/stopFolder")
20 |
21 |
22 | def get_param_str(key, val):
23 | if key == 'data_name':
24 | return f'--datadir {datafolder}/{val} '
25 | else:
26 | return f'--{key} {val} '
27 |
28 | def get_param_list(param_dict):
29 | param_keys = list(param_dict.keys())
30 | param_modes = len(param_keys)
31 | param_nums = [len(param_dict[key]) for key in param_keys]
32 |
33 | param_ids = np.zeros(param_nums+[param_modes], dtype=int)
34 | for i in range(param_modes):
35 | broad_tuple = np.ones(param_modes, dtype=int).tolist()
36 | broad_tuple[i] = param_nums[i]
37 | broad_tuple = tuple(broad_tuple)
38 | print(broad_tuple)
39 | param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple)
40 | param_ids = param_ids.reshape(-1, param_modes)
41 | # print(param_ids)
42 | print(len(param_ids))
43 |
44 | params = []
45 | expnames = []
46 | for i in range(param_ids.shape[0]):
47 | one = ""
48 | name = ""
49 | param_id = param_ids[i]
50 | for j in range(param_modes):
51 | key = param_keys[j]
52 | val = param_dict[key][param_id[j]]
53 | if type(key) is tuple:
54 | assert len(key) == len(val)
55 | for k in range(len(key)):
56 | one += get_param_str(key[k], val[k])
57 | name += f'{val[k]},'
58 | name=name[:-1]+'-'
59 | else:
60 | one += get_param_str(key, val)
61 | name += f'{val}-'
62 | params.append(one)
63 | name=name.replace(' ','')
64 | print(name)
65 | expnames.append(name[:-1])
66 | # print(params)
67 | return params, expnames
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 | if __name__ == '__main__':
76 |
77 |
78 |
79 | # nerf
80 | expFolder = "nerf/"
81 | # parameters to iterate, use tuple to couple multiple parameters
82 | datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/'
83 | param_dict = {
84 | 'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'],
85 | 'data_dim_color': [13, 27, 54]
86 | }
87 |
88 | # n_iters = 30000
89 | # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
90 | # cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \
91 | # f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\
92 | # f'--expname {data_name} --batch_size {batch_size} ' \
93 | # f'--n_iters {n_iters} ' \
94 | # f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\
95 | # f'--N_vis {5} ' \
96 | # f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \
97 | # f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \
98 | # f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \
99 | # f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \
100 | # f'--render_test 1 '
101 | # print(cmd)
102 | # os.system(cmd)
103 |
104 | # nsvf
105 | # expFolder = "nsvf_0227/"
106 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/'
107 | # param_dict = {
108 | # 'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
109 | # 'shadingMode': ['SH'],
110 | # ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")],
111 | # ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)],
112 | # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)],
113 | # ('n_iters','N_voxel_final'): [(30000,300**3)],
114 | # ('dataset_name','N_vis','render_test') : [("nsvf",5,1)],
115 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")]
116 | #
117 | # }
118 |
119 | # tankstemple
120 | # expFolder = "tankstemple_0304/"
121 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/'
122 | # param_dict = {
123 | # 'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'],
124 | # 'shadingMode': ['MLP_Fea'],
125 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")],
126 | # ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)],
127 | # ('TV_weight_density','TV_weight_app'):[(0.1,0.01)],
128 | # # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)],
129 | # ('n_iters','N_voxel_final'): [(15000,300**3)],
130 | # ('dataset_name','N_vis') : [("tankstemple",5)],
131 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")]
132 | # }
133 |
134 | # llff
135 | # expFolder = "real_iconic/"
136 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/'
137 | # List = os.listdir(datafolder)
138 | # param_dict = {
139 | # 'data_name': List,
140 | # ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)],
141 | # ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")],
142 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
143 | # ('n_iters','N_voxel_final'): [(25000,640**3)],
144 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)],
145 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
146 | # }
147 |
148 | # expFolder = "llff/"
149 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data'
150 | # param_dict = {
151 | # 'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'
152 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")],
153 | # ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)],
154 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
155 | # ('n_iters','N_voxel_final'): [(25000,640**3)],
156 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)],
157 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
158 | # }
159 |
160 | #setting available gpus
161 | gpus_que = queue.Queue(3)
162 | for i in [1,2,3]:
163 | gpus_que.put(i)
164 |
165 | os.makedirs(f"log/{expFolder}", exist_ok=True)
166 |
167 | def run_program(gpu, expname, param):
168 | cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \
169 | f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \
170 | f'{param}' \
171 | f'> "log/{expFolder}{expname}/{expname}.txt"'
172 | print(cmd)
173 | os.system(cmd)
174 | gpus_que.put(gpu)
175 |
176 | params, expnames = get_param_list(param_dict)
177 |
178 |
179 | logFolder=f"log/{expFolder}"
180 | os.makedirs(logFolder, exist_ok=True)
181 |
182 | ths = []
183 | for i in range(len(params)):
184 |
185 | if getStopFolder(logFolder):
186 | break
187 |
188 |
189 | targetFolder = f"log/{expFolder}{expnames[i]}"
190 | gpu = gpus_que.get()
191 | getFolderLocker(logFolder)
192 | if os.path.isdir(targetFolder):
193 | releaseFolderLocker(logFolder)
194 | gpus_que.put(gpu)
195 | continue
196 | else:
197 | os.makedirs(targetFolder, exist_ok=True)
198 | print("making",targetFolder, "running",expnames[i], params[i])
199 | releaseFolderLocker(logFolder)
200 |
201 |
202 | t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True)
203 | t.start()
204 | ths.append(t)
205 |
206 | for th in ths:
207 | th.join()
--------------------------------------------------------------------------------
/extra/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import os, math
2 | import numpy as np
3 | import scipy.signal
4 | from typing import List, Optional
5 | from PIL import Image
6 | import os
7 | import torch
8 | import configargparse
9 |
10 | __LPIPS__ = {}
11 | def init_lpips(net_name, device):
12 | assert net_name in ['alex', 'vgg']
13 | import lpips
14 | print(f'init_lpips: lpips_{net_name}')
15 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
16 |
17 | def rgb_lpips(np_gt, np_im, net_name, device):
18 | if net_name not in __LPIPS__:
19 | __LPIPS__[net_name] = init_lpips(net_name, device)
20 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
21 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
22 | return __LPIPS__[net_name](gt, im, normalize=True).item()
23 |
24 |
25 | def findItem(items, target):
26 | for one in items:
27 | if one[:len(target)]==target:
28 | return one
29 | return None
30 |
31 |
32 | ''' Evaluation metrics (ssim, lpips)
33 | '''
34 | def rgb_ssim(img0, img1, max_val,
35 | filter_size=11,
36 | filter_sigma=1.5,
37 | k1=0.01,
38 | k2=0.03,
39 | return_map=False):
40 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
41 | assert len(img0.shape) == 3
42 | assert img0.shape[-1] == 3
43 | assert img0.shape == img1.shape
44 |
45 | # Construct a 1D Gaussian blur filter.
46 | hw = filter_size // 2
47 | shift = (2 * hw - filter_size + 1) / 2
48 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
49 | filt = np.exp(-0.5 * f_i)
50 | filt /= np.sum(filt)
51 |
52 | # Blur in x and y (faster than the 2D convolution).
53 | def convolve2d(z, f):
54 | return scipy.signal.convolve2d(z, f, mode='valid')
55 |
56 | filt_fn = lambda z: np.stack([
57 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
58 | for i in range(z.shape[-1])], -1)
59 | mu0 = filt_fn(img0)
60 | mu1 = filt_fn(img1)
61 | mu00 = mu0 * mu0
62 | mu11 = mu1 * mu1
63 | mu01 = mu0 * mu1
64 | sigma00 = filt_fn(img0**2) - mu00
65 | sigma11 = filt_fn(img1**2) - mu11
66 | sigma01 = filt_fn(img0 * img1) - mu01
67 |
68 | # Clip the variances and covariances to valid values.
69 | # Variance must be non-negative:
70 | sigma00 = np.maximum(0., sigma00)
71 | sigma11 = np.maximum(0., sigma11)
72 | sigma01 = np.sign(sigma01) * np.minimum(
73 | np.sqrt(sigma00 * sigma11), np.abs(sigma01))
74 | c1 = (k1 * max_val)**2
75 | c2 = (k2 * max_val)**2
76 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
77 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
78 | ssim_map = numer / denom
79 | ssim = np.mean(ssim_map)
80 | return ssim_map if return_map else ssim
81 |
82 |
83 | if __name__ == '__main__':
84 |
85 | parser = configargparse.ArgumentParser()
86 | parser.add_argument("--exp", type=str, help="folder of exps")
87 | parser.add_argument("--paramStr", type=str, help="str of params")
88 | args = parser.parse_args()
89 |
90 |
91 | # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']#
92 | # gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic"
93 | # expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp
94 |
95 | # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']#
96 | # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/"
97 | # expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp
98 | paramStr = args.paramStr
99 | fileNum = 200
100 |
101 |
102 | expitems = os.listdir(expFolder)
103 | finalFolder = f'{expFolder}/finals/{paramStr}'
104 | outFile = f'{finalFolder}/{paramStr}_metrics.txt'
105 | os.makedirs(finalFolder, exist_ok=True)
106 |
107 | expitems.sort(reverse=True)
108 |
109 |
110 | with open(outFile, 'w') as f:
111 | all_psnr = []
112 | all_ssim = []
113 | all_alex = []
114 | all_vgg = []
115 | for dataname in datanames:
116 |
117 |
118 | gtstr = gtFolder+"/"+dataname+"/test/r_%d.png"
119 | expname = findItem(expitems, f'{paramStr}-{dataname}')
120 | print("expname: ", expname)
121 | if expname is None:
122 | print("no ",dataname, "exists")
123 | continue
124 | resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png"
125 | metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt'
126 | video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4'
127 |
128 | exist_metric=False
129 | if os.path.isfile(metric_file):
130 | metrics = np.loadtxt(metric_file)
131 | print(metrics, metrics.tolist())
132 | if metrics.size == 4:
133 | psnr, ssim, l_a, l_v = metrics.tolist()
134 | exist_metric = True
135 | os.system(f"cp {video_file} {finalFolder}/")
136 |
137 | if not exist_metric:
138 | psnrs = []
139 | ssims = []
140 | l_alex = []
141 | l_vgg = []
142 | for i in range(fileNum):
143 | gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0
144 | gtmask = gt[...,[3]]
145 | gt = gt[...,:3]
146 | gt = gt*gtmask + (1-gtmask)
147 | img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0
148 | # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max())
149 |
150 |
151 | psnr = -10. * np.log10(np.mean(np.square(img - gt)))
152 | ssim = rgb_ssim(img, gt, 1)
153 | lpips_alex = rgb_lpips(gt, img, 'alex','cuda')
154 | lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda')
155 |
156 | print(i, psnr, ssim, lpips_alex, lpips_vgg)
157 | psnrs.append(psnr)
158 | ssims.append(ssim)
159 | l_alex.append(lpips_alex)
160 | l_vgg.append(lpips_vgg)
161 | psnr = np.mean(np.array(psnrs))
162 | ssim = np.mean(np.array(ssims))
163 | l_a = np.mean(np.array(l_alex))
164 | l_v = np.mean(np.array(l_vgg))
165 |
166 | rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
167 | print(rS)
168 | f.write(rS+"\n")
169 |
170 | all_psnr.append(psnr)
171 | all_ssim.append(ssim)
172 | all_alex.append(l_a)
173 | all_vgg.append(l_v)
174 |
175 | psnr = np.mean(np.array(all_psnr))
176 | ssim = np.mean(np.array(all_ssim))
177 | l_a = np.mean(np.array(all_alex))
178 | l_v = np.mean(np.array(all_vgg))
179 |
180 | rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
181 | print(rS)
182 | f.write(rS+"\n")
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apchenstu/TensoRF/9370a87c88bf41b309da694833c81845cc960d50/models/__init__.py
--------------------------------------------------------------------------------
/models/sh.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ################## sh function ##################
4 | C0 = 0.28209479177387814
5 | C1 = 0.4886025119029199
6 | C2 = [
7 | 1.0925484305920792,
8 | -1.0925484305920792,
9 | 0.31539156525252005,
10 | -1.0925484305920792,
11 | 0.5462742152960396
12 | ]
13 | C3 = [
14 | -0.5900435899266435,
15 | 2.890611442640554,
16 | -0.4570457994644658,
17 | 0.3731763325901154,
18 | -0.4570457994644658,
19 | 1.445305721320277,
20 | -0.5900435899266435
21 | ]
22 | C4 = [
23 | 2.5033429417967046,
24 | -1.7701307697799304,
25 | 0.9461746957575601,
26 | -0.6690465435572892,
27 | 0.10578554691520431,
28 | -0.6690465435572892,
29 | 0.47308734787878004,
30 | -1.7701307697799304,
31 | 0.6258357354491761,
32 | ]
33 |
34 | def eval_sh(deg, sh, dirs):
35 | """
36 | Evaluate spherical harmonics at unit directions
37 | using hardcoded SH polynomials.
38 | Works with torch/np/jnp.
39 | ... Can be 0 or more batch dimensions.
40 | :param deg: int SH max degree. Currently, 0-4 supported
41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
42 | :param dirs: torch.Tensor unit directions (..., 3)
43 | :return: (..., C)
44 | """
45 | assert deg <= 4 and deg >= 0
46 | assert (deg + 1) ** 2 == sh.shape[-1]
47 | C = sh.shape[-2]
48 |
49 | result = C0 * sh[..., 0]
50 | if deg > 0:
51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
52 | result = (result -
53 | C1 * y * sh[..., 1] +
54 | C1 * z * sh[..., 2] -
55 | C1 * x * sh[..., 3])
56 | if deg > 1:
57 | xx, yy, zz = x * x, y * y, z * z
58 | xy, yz, xz = x * y, y * z, x * z
59 | result = (result +
60 | C2[0] * xy * sh[..., 4] +
61 | C2[1] * yz * sh[..., 5] +
62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
63 | C2[3] * xz * sh[..., 7] +
64 | C2[4] * (xx - yy) * sh[..., 8])
65 |
66 | if deg > 2:
67 | result = (result +
68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
69 | C3[1] * xy * z * sh[..., 10] +
70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
73 | C3[5] * z * (xx - yy) * sh[..., 14] +
74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
75 | if deg > 3:
76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
85 | return result
86 |
87 | def eval_sh_bases(deg, dirs):
88 | """
89 | Evaluate spherical harmonics bases at unit directions,
90 | without taking linear combination.
91 | At each point, the final result may the be
92 | obtained through simple multiplication.
93 | :param deg: int SH max degree. Currently, 0-4 supported
94 | :param dirs: torch.Tensor (..., 3) unit directions
95 | :return: torch.Tensor (..., (deg+1) ** 2)
96 | """
97 | assert deg <= 4 and deg >= 0
98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
99 | result[..., 0] = C0
100 | if deg > 0:
101 | x, y, z = dirs.unbind(-1)
102 | result[..., 1] = -C1 * y;
103 | result[..., 2] = C1 * z;
104 | result[..., 3] = -C1 * x;
105 | if deg > 1:
106 | xx, yy, zz = x * x, y * y, z * z
107 | xy, yz, xz = x * y, y * z, x * z
108 | result[..., 4] = C2[0] * xy;
109 | result[..., 5] = C2[1] * yz;
110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy);
111 | result[..., 7] = C2[3] * xz;
112 | result[..., 8] = C2[4] * (xx - yy);
113 |
114 | if deg > 2:
115 | result[..., 9] = C3[0] * y * (3 * xx - yy);
116 | result[..., 10] = C3[1] * xy * z;
117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy);
118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy);
119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy);
120 | result[..., 14] = C3[5] * z * (xx - yy);
121 | result[..., 15] = C3[6] * x * (xx - 3 * yy);
122 |
123 | if deg > 3:
124 | result[..., 16] = C4[0] * xy * (xx - yy);
125 | result[..., 17] = C4[1] * yz * (3 * xx - yy);
126 | result[..., 18] = C4[2] * xy * (7 * zz - 1);
127 | result[..., 19] = C4[3] * yz * (7 * zz - 3);
128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3);
129 | result[..., 21] = C4[5] * xz * (7 * zz - 3);
130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1);
131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy);
132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy));
133 | return result
134 |
--------------------------------------------------------------------------------
/models/tensoRF.py:
--------------------------------------------------------------------------------
1 | from .tensorBase import *
2 |
3 |
4 | class TensorVM(TensorBase):
5 | def __init__(self, aabb, gridSize, device, **kargs):
6 | super(TensorVM, self).__init__(aabb, gridSize, device, **kargs)
7 |
8 |
9 | def init_svd_volume(self, res, device):
10 | self.plane_coef = torch.nn.Parameter(
11 | 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, res), device=device))
12 | self.line_coef = torch.nn.Parameter(
13 | 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, 1), device=device))
14 | self.basis_mat = torch.nn.Linear(self.app_n_comp * 3, self.app_dim, bias=False, device=device)
15 |
16 |
17 | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
18 | grad_vars = [{'params': self.line_coef, 'lr': lr_init_spatialxyz}, {'params': self.plane_coef, 'lr': lr_init_spatialxyz},
19 | {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
20 | if isinstance(self.renderModule, torch.nn.Module):
21 | grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
22 | return grad_vars
23 |
24 | def compute_features(self, xyz_sampled):
25 |
26 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach()
27 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
28 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach()
29 |
30 | plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view(
31 | -1, *xyz_sampled.shape[:1])
32 | line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
33 | -1, *xyz_sampled.shape[:1])
34 |
35 | sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
36 |
37 |
38 | plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1)
39 | line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1)
40 |
41 |
42 | app_features = self.basis_mat((plane_feats * line_feats).T)
43 |
44 | return sigma_feature, app_features
45 |
46 | def compute_densityfeature(self, xyz_sampled):
47 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
48 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
49 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
50 |
51 | plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view(
52 | -1, *xyz_sampled.shape[:1])
53 | line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
54 | -1, *xyz_sampled.shape[:1])
55 |
56 | sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
57 |
58 |
59 | return sigma_feature
60 |
61 | def compute_appfeature(self, xyz_sampled):
62 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
63 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
64 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
65 |
66 | plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1)
67 | line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1)
68 |
69 |
70 | app_features = self.basis_mat((plane_feats * line_feats).T)
71 |
72 |
73 | return app_features
74 |
75 |
76 | def vectorDiffs(self, vector_comps):
77 | total = 0
78 |
79 | for idx in range(len(vector_comps)):
80 | # print(self.line_coef.shape, vector_comps[idx].shape)
81 | n_comp, n_size = vector_comps[idx].shape[:-1]
82 |
83 | dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
84 | # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape)
85 | non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
86 | # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape,non_diagonal.shape)
87 | total = total + torch.mean(torch.abs(non_diagonal))
88 | return total
89 |
90 | def vector_comp_diffs(self):
91 |
92 | return self.vectorDiffs(self.line_coef[:,-self.density_n_comp:]) + self.vectorDiffs(self.line_coef[:,:self.app_n_comp])
93 |
94 |
95 | @torch.no_grad()
96 | def up_sampling_VM(self, plane_coef, line_coef, res_target):
97 |
98 | for i in range(len(self.vecMode)):
99 | vec_id = self.vecMode[i]
100 | mat_id_0, mat_id_1 = self.matMode[i]
101 |
102 | plane_coef[i] = torch.nn.Parameter(
103 | F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
104 | align_corners=True))
105 | line_coef[i] = torch.nn.Parameter(
106 | F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
107 |
108 | # plane_coef[0] = torch.nn.Parameter(
109 | # F.interpolate(plane_coef[0].data, size=(res_target[1], res_target[0]), mode='bilinear',
110 | # align_corners=True))
111 | # line_coef[0] = torch.nn.Parameter(
112 | # F.interpolate(line_coef[0].data, size=(res_target[2], 1), mode='bilinear', align_corners=True))
113 | # plane_coef[1] = torch.nn.Parameter(
114 | # F.interpolate(plane_coef[1].data, size=(res_target[2], res_target[0]), mode='bilinear',
115 | # align_corners=True))
116 | # line_coef[1] = torch.nn.Parameter(
117 | # F.interpolate(line_coef[1].data, size=(res_target[1], 1), mode='bilinear', align_corners=True))
118 | # plane_coef[2] = torch.nn.Parameter(
119 | # F.interpolate(plane_coef[2].data, size=(res_target[2], res_target[1]), mode='bilinear',
120 | # align_corners=True))
121 | # line_coef[2] = torch.nn.Parameter(
122 | # F.interpolate(line_coef[2].data, size=(res_target[0], 1), mode='bilinear', align_corners=True))
123 |
124 | return plane_coef, line_coef
125 |
126 | @torch.no_grad()
127 | def upsample_volume_grid(self, res_target):
128 | # self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
129 | # self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)
130 |
131 | scale = res_target[0]/self.line_coef.shape[2] #assuming xyz have the same scale
132 | plane_coef = F.interpolate(self.plane_coef.detach().data, scale_factor=scale, mode='bilinear',align_corners=True)
133 | line_coef = F.interpolate(self.line_coef.detach().data, size=(res_target[0],1), mode='bilinear',align_corners=True)
134 | self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef)
135 | self.compute_stepSize(res_target)
136 | print(f'upsamping to {res_target}')
137 |
138 |
139 | class TensorVMSplit(TensorBase):
140 | def __init__(self, aabb, gridSize, device, **kargs):
141 | super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)
142 |
143 |
144 | def init_svd_volume(self, res, device):
145 | self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)
146 | self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)
147 | self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)
148 |
149 |
150 | def init_one_svd(self, n_component, gridSize, scale, device):
151 | plane_coef, line_coef = [], []
152 | for i in range(len(self.vecMode)):
153 | vec_id = self.vecMode[i]
154 | mat_id_0, mat_id_1 = self.matMode[i]
155 | plane_coef.append(torch.nn.Parameter(
156 | scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])))) #
157 | line_coef.append(
158 | torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))
159 |
160 | return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)
161 |
162 |
163 |
164 | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
165 | grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, {'params': self.density_plane, 'lr': lr_init_spatialxyz},
166 | {'params': self.app_line, 'lr': lr_init_spatialxyz}, {'params': self.app_plane, 'lr': lr_init_spatialxyz},
167 | {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
168 | if isinstance(self.renderModule, torch.nn.Module):
169 | grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
170 | return grad_vars
171 |
172 |
173 | def vectorDiffs(self, vector_comps):
174 | total = 0
175 |
176 | for idx in range(len(vector_comps)):
177 | n_comp, n_size = vector_comps[idx].shape[1:-1]
178 |
179 | dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
180 | non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
181 | total = total + torch.mean(torch.abs(non_diagonal))
182 | return total
183 |
184 | def vector_comp_diffs(self):
185 | return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line)
186 |
187 | def density_L1(self):
188 | total = 0
189 | for idx in range(len(self.density_plane)):
190 | total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))# + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx]))
191 | return total
192 |
193 | def TV_loss_density(self, reg):
194 | total = 0
195 | for idx in range(len(self.density_plane)):
196 | total = total + reg(self.density_plane[idx]) * 1e-2 #+ reg(self.density_line[idx]) * 1e-3
197 | return total
198 |
199 | def TV_loss_app(self, reg):
200 | total = 0
201 | for idx in range(len(self.app_plane)):
202 | total = total + reg(self.app_plane[idx]) * 1e-2 #+ reg(self.app_line[idx]) * 1e-3
203 | return total
204 |
205 | def compute_densityfeature(self, xyz_sampled):
206 |
207 | # plane + line basis
208 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
209 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
210 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
211 |
212 | sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)
213 | for idx_plane in range(len(self.density_plane)):
214 | plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],
215 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
216 | line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],
217 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
218 | sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)
219 |
220 | return sigma_feature
221 |
222 |
223 | def compute_appfeature(self, xyz_sampled):
224 |
225 | # plane + line basis
226 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
227 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
228 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
229 |
230 | plane_coef_point,line_coef_point = [],[]
231 | for idx_plane in range(len(self.app_plane)):
232 | plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]],
233 | align_corners=True).view(-1, *xyz_sampled.shape[:1]))
234 | line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],
235 | align_corners=True).view(-1, *xyz_sampled.shape[:1]))
236 | plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)
237 |
238 |
239 | return self.basis_mat((plane_coef_point * line_coef_point).T)
240 |
241 |
242 |
243 | @torch.no_grad()
244 | def up_sampling_VM(self, plane_coef, line_coef, res_target):
245 |
246 | for i in range(len(self.vecMode)):
247 | vec_id = self.vecMode[i]
248 | mat_id_0, mat_id_1 = self.matMode[i]
249 | plane_coef[i] = torch.nn.Parameter(
250 | F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
251 | align_corners=True))
252 | line_coef[i] = torch.nn.Parameter(
253 | F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
254 |
255 |
256 | return plane_coef, line_coef
257 |
258 | @torch.no_grad()
259 | def upsample_volume_grid(self, res_target):
260 | self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
261 | self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)
262 |
263 | self.update_stepSize(res_target)
264 | print(f'upsamping to {res_target}')
265 |
266 | @torch.no_grad()
267 | def shrink(self, new_aabb):
268 | print("====> shrinking ...")
269 | xyz_min, xyz_max = new_aabb
270 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
271 | # print(new_aabb, self.aabb)
272 | # print(t_l, b_r,self.alphaMask.alpha_volume.shape)
273 | t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
274 | b_r = torch.stack([b_r, self.gridSize]).amin(0)
275 |
276 | for i in range(len(self.vecMode)):
277 | mode0 = self.vecMode[i]
278 | self.density_line[i] = torch.nn.Parameter(
279 | self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
280 | )
281 | self.app_line[i] = torch.nn.Parameter(
282 | self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
283 | )
284 | mode0, mode1 = self.matMode[i]
285 | self.density_plane[i] = torch.nn.Parameter(
286 | self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
287 | )
288 | self.app_plane[i] = torch.nn.Parameter(
289 | self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
290 | )
291 |
292 |
293 | if not torch.all(self.alphaMask.gridSize == self.gridSize):
294 | t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
295 | correct_aabb = torch.zeros_like(new_aabb)
296 | correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
297 | correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
298 | print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
299 | new_aabb = correct_aabb
300 |
301 | newSize = b_r - t_l
302 | self.aabb = new_aabb
303 | self.update_stepSize((newSize[0], newSize[1], newSize[2]))
304 |
305 |
306 | class TensorCP(TensorBase):
307 | def __init__(self, aabb, gridSize, device, **kargs):
308 | super(TensorCP, self).__init__(aabb, gridSize, device, **kargs)
309 |
310 |
311 | def init_svd_volume(self, res, device):
312 | self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device)
313 | self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device)
314 | self.basis_mat = torch.nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device)
315 |
316 |
317 | def init_one_svd(self, n_component, gridSize, scale, device):
318 | line_coef = []
319 | for i in range(len(self.vecMode)):
320 | vec_id = self.vecMode[i]
321 | line_coef.append(
322 | torch.nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1))))
323 | return torch.nn.ParameterList(line_coef).to(device)
324 |
325 |
326 | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
327 | grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
328 | {'params': self.app_line, 'lr': lr_init_spatialxyz},
329 | {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
330 | if isinstance(self.renderModule, torch.nn.Module):
331 | grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
332 | return grad_vars
333 |
334 | def compute_densityfeature(self, xyz_sampled):
335 |
336 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
337 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
338 |
339 |
340 | line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]],
341 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
342 | line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]],
343 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
344 | line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]],
345 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
346 | sigma_feature = torch.sum(line_coef_point, dim=0)
347 |
348 |
349 | return sigma_feature
350 |
351 | def compute_appfeature(self, xyz_sampled):
352 |
353 | coordinate_line = torch.stack(
354 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
355 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
356 |
357 |
358 | line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]],
359 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
360 | line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]],
361 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
362 | line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]],
363 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
364 |
365 | return self.basis_mat(line_coef_point.T)
366 |
367 |
368 | @torch.no_grad()
369 | def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target):
370 |
371 | for i in range(len(self.vecMode)):
372 | vec_id = self.vecMode[i]
373 | density_line_coef[i] = torch.nn.Parameter(
374 | F.interpolate(density_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
375 | app_line_coef[i] = torch.nn.Parameter(
376 | F.interpolate(app_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
377 |
378 | return density_line_coef, app_line_coef
379 |
380 | @torch.no_grad()
381 | def upsample_volume_grid(self, res_target):
382 | self.density_line, self.app_line = self.up_sampling_Vector(self.density_line, self.app_line, res_target)
383 |
384 | self.update_stepSize(res_target)
385 | print(f'upsamping to {res_target}')
386 |
387 | @torch.no_grad()
388 | def shrink(self, new_aabb):
389 | print("====> shrinking ...")
390 | xyz_min, xyz_max = new_aabb
391 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
392 |
393 | t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
394 | b_r = torch.stack([b_r, self.gridSize]).amin(0)
395 |
396 |
397 | for i in range(len(self.vecMode)):
398 | mode0 = self.vecMode[i]
399 | self.density_line[i] = torch.nn.Parameter(
400 | self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
401 | )
402 | self.app_line[i] = torch.nn.Parameter(
403 | self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
404 | )
405 |
406 | if not torch.all(self.alphaMask.gridSize == self.gridSize):
407 | t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
408 | correct_aabb = torch.zeros_like(new_aabb)
409 | correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
410 | correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
411 | print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
412 | new_aabb = correct_aabb
413 |
414 | newSize = b_r - t_l
415 | self.aabb = new_aabb
416 | self.update_stepSize((newSize[0], newSize[1], newSize[2]))
417 |
418 | def density_L1(self):
419 | total = 0
420 | for idx in range(len(self.density_line)):
421 | total = total + torch.mean(torch.abs(self.density_line[idx]))
422 | return total
423 |
424 | def TV_loss_density(self, reg):
425 | total = 0
426 | for idx in range(len(self.density_line)):
427 | total = total + reg(self.density_line[idx]) * 1e-3
428 | return total
429 |
430 | def TV_loss_app(self, reg):
431 | total = 0
432 | for idx in range(len(self.app_line)):
433 | total = total + reg(self.app_line[idx]) * 1e-3
434 | return total
--------------------------------------------------------------------------------
/models/tensorBase.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import torch.nn.functional as F
4 | from .sh import eval_sh_bases
5 | import numpy as np
6 | import time
7 |
8 |
9 | def positional_encoding(positions, freqs):
10 |
11 | freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,)
12 | pts = (positions[..., None] * freq_bands).reshape(
13 | positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF)
14 | pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
15 | return pts
16 |
17 | def raw2alpha(sigma, dist):
18 | # sigma, dist [N_rays, N_samples]
19 | alpha = 1. - torch.exp(-sigma*dist)
20 |
21 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)
22 |
23 | weights = alpha * T[:, :-1] # [N_rays, N_samples]
24 | return alpha, weights, T[:,-1:]
25 |
26 |
27 | def SHRender(xyz_sampled, viewdirs, features):
28 | sh_mult = eval_sh_bases(2, viewdirs)[:, None]
29 | rgb_sh = features.view(-1, 3, sh_mult.shape[-1])
30 | rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)
31 | return rgb
32 |
33 |
34 | def RGBRender(xyz_sampled, viewdirs, features):
35 |
36 | rgb = features
37 | return rgb
38 |
39 | class AlphaGridMask(torch.nn.Module):
40 | def __init__(self, device, aabb, alpha_volume):
41 | super(AlphaGridMask, self).__init__()
42 | self.device = device
43 |
44 | self.aabb=aabb.to(self.device)
45 | self.aabbSize = self.aabb[1] - self.aabb[0]
46 | self.invgridSize = 1.0/self.aabbSize * 2
47 | self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:])
48 | self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device)
49 |
50 | def sample_alpha(self, xyz_sampled):
51 | xyz_sampled = self.normalize_coord(xyz_sampled)
52 | alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1)
53 |
54 | return alpha_vals
55 |
56 | def normalize_coord(self, xyz_sampled):
57 | return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1
58 |
59 |
60 | class MLPRender_Fea(torch.nn.Module):
61 | def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
62 | super(MLPRender_Fea, self).__init__()
63 |
64 | self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
65 | self.viewpe = viewpe
66 | self.feape = feape
67 | layer1 = torch.nn.Linear(self.in_mlpC, featureC)
68 | layer2 = torch.nn.Linear(featureC, featureC)
69 | layer3 = torch.nn.Linear(featureC,3)
70 |
71 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
72 | torch.nn.init.constant_(self.mlp[-1].bias, 0)
73 |
74 | def forward(self, pts, viewdirs, features):
75 | indata = [features, viewdirs]
76 | if self.feape > 0:
77 | indata += [positional_encoding(features, self.feape)]
78 | if self.viewpe > 0:
79 | indata += [positional_encoding(viewdirs, self.viewpe)]
80 | mlp_in = torch.cat(indata, dim=-1)
81 | rgb = self.mlp(mlp_in)
82 | rgb = torch.sigmoid(rgb)
83 |
84 | return rgb
85 |
86 | class MLPRender_PE(torch.nn.Module):
87 | def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
88 | super(MLPRender_PE, self).__init__()
89 |
90 | self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel #
91 | self.viewpe = viewpe
92 | self.pospe = pospe
93 | layer1 = torch.nn.Linear(self.in_mlpC, featureC)
94 | layer2 = torch.nn.Linear(featureC, featureC)
95 | layer3 = torch.nn.Linear(featureC,3)
96 |
97 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
98 | torch.nn.init.constant_(self.mlp[-1].bias, 0)
99 |
100 | def forward(self, pts, viewdirs, features):
101 | indata = [features, viewdirs]
102 | if self.pospe > 0:
103 | indata += [positional_encoding(pts, self.pospe)]
104 | if self.viewpe > 0:
105 | indata += [positional_encoding(viewdirs, self.viewpe)]
106 | mlp_in = torch.cat(indata, dim=-1)
107 | rgb = self.mlp(mlp_in)
108 | rgb = torch.sigmoid(rgb)
109 |
110 | return rgb
111 |
112 | class MLPRender(torch.nn.Module):
113 | def __init__(self,inChanel, viewpe=6, featureC=128):
114 | super(MLPRender, self).__init__()
115 |
116 | self.in_mlpC = (3+2*viewpe*3) + inChanel
117 | self.viewpe = viewpe
118 |
119 | layer1 = torch.nn.Linear(self.in_mlpC, featureC)
120 | layer2 = torch.nn.Linear(featureC, featureC)
121 | layer3 = torch.nn.Linear(featureC,3)
122 |
123 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
124 | torch.nn.init.constant_(self.mlp[-1].bias, 0)
125 |
126 | def forward(self, pts, viewdirs, features):
127 | indata = [features, viewdirs]
128 | if self.viewpe > 0:
129 | indata += [positional_encoding(viewdirs, self.viewpe)]
130 | mlp_in = torch.cat(indata, dim=-1)
131 | rgb = self.mlp(mlp_in)
132 | rgb = torch.sigmoid(rgb)
133 |
134 | return rgb
135 |
136 |
137 |
138 | class TensorBase(torch.nn.Module):
139 | def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,
140 | shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],
141 | density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
142 | pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,
143 | fea2denseAct = 'softplus'):
144 | super(TensorBase, self).__init__()
145 |
146 | self.density_n_comp = density_n_comp
147 | self.app_n_comp = appearance_n_comp
148 | self.app_dim = app_dim
149 | self.aabb = aabb
150 | self.alphaMask = alphaMask
151 | self.device=device
152 |
153 | self.density_shift = density_shift
154 | self.alphaMask_thres = alphaMask_thres
155 | self.distance_scale = distance_scale
156 | self.rayMarch_weight_thres = rayMarch_weight_thres
157 | self.fea2denseAct = fea2denseAct
158 |
159 | self.near_far = near_far
160 | self.step_ratio = step_ratio
161 |
162 |
163 | self.update_stepSize(gridSize)
164 |
165 | self.matMode = [[0,1], [0,2], [1,2]]
166 | self.vecMode = [2, 1, 0]
167 | self.comp_w = [1,1,1]
168 |
169 |
170 | self.init_svd_volume(gridSize[0], device)
171 |
172 | self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
173 | self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device)
174 |
175 | def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device):
176 | if shadingMode == 'MLP_PE':
177 | self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device)
178 | elif shadingMode == 'MLP_Fea':
179 | self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device)
180 | elif shadingMode == 'MLP':
181 | self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device)
182 | elif shadingMode == 'SH':
183 | self.renderModule = SHRender
184 | elif shadingMode == 'RGB':
185 | assert self.app_dim == 3
186 | self.renderModule = RGBRender
187 | else:
188 | print("Unrecognized shading module")
189 | exit()
190 | print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe)
191 | print(self.renderModule)
192 |
193 | def update_stepSize(self, gridSize):
194 | print("aabb", self.aabb.view(-1))
195 | print("grid size", gridSize)
196 | self.aabbSize = self.aabb[1] - self.aabb[0]
197 | self.invaabbSize = 2.0/self.aabbSize
198 | self.gridSize= torch.LongTensor(gridSize).to(self.device)
199 | self.units=self.aabbSize / (self.gridSize-1)
200 | self.stepSize=torch.mean(self.units)*self.step_ratio
201 | self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
202 | self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1
203 | print("sampling step size: ", self.stepSize)
204 | print("sampling number: ", self.nSamples)
205 |
206 | def init_svd_volume(self, res, device):
207 | pass
208 |
209 | def compute_features(self, xyz_sampled):
210 | pass
211 |
212 | def compute_densityfeature(self, xyz_sampled):
213 | pass
214 |
215 | def compute_appfeature(self, xyz_sampled):
216 | pass
217 |
218 | def normalize_coord(self, xyz_sampled):
219 | return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1
220 |
221 | def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001):
222 | pass
223 |
224 | def get_kwargs(self):
225 | return {
226 | 'aabb': self.aabb,
227 | 'gridSize':self.gridSize.tolist(),
228 | 'density_n_comp': self.density_n_comp,
229 | 'appearance_n_comp': self.app_n_comp,
230 | 'app_dim': self.app_dim,
231 |
232 | 'density_shift': self.density_shift,
233 | 'alphaMask_thres': self.alphaMask_thres,
234 | 'distance_scale': self.distance_scale,
235 | 'rayMarch_weight_thres': self.rayMarch_weight_thres,
236 | 'fea2denseAct': self.fea2denseAct,
237 |
238 | 'near_far': self.near_far,
239 | 'step_ratio': self.step_ratio,
240 |
241 | 'shadingMode': self.shadingMode,
242 | 'pos_pe': self.pos_pe,
243 | 'view_pe': self.view_pe,
244 | 'fea_pe': self.fea_pe,
245 | 'featureC': self.featureC
246 | }
247 |
248 | def save(self, path):
249 | kwargs = self.get_kwargs()
250 | ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
251 | if self.alphaMask is not None:
252 | alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
253 | ckpt.update({'alphaMask.shape':alpha_volume.shape})
254 | ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))})
255 | ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
256 | torch.save(ckpt, path)
257 |
258 | def load(self, ckpt):
259 | if 'alphaMask.aabb' in ckpt.keys():
260 | length = np.prod(ckpt['alphaMask.shape'])
261 | alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
262 | self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))
263 | self.load_state_dict(ckpt['state_dict'])
264 |
265 |
266 | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
267 | N_samples = N_samples if N_samples > 0 else self.nSamples
268 | near, far = self.near_far
269 | interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
270 | if is_train:
271 | interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)
272 |
273 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
274 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
275 | return rays_pts, interpx, ~mask_outbbox
276 |
277 | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
278 | N_samples = N_samples if N_samples>0 else self.nSamples
279 | stepsize = self.stepSize
280 | near, far = self.near_far
281 | vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)
282 | rate_a = (self.aabb[1] - rays_o) / vec
283 | rate_b = (self.aabb[0] - rays_o) / vec
284 | t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)
285 |
286 | rng = torch.arange(N_samples)[None].float()
287 | if is_train:
288 | rng = rng.repeat(rays_d.shape[-2],1)
289 | rng += torch.rand_like(rng[:,[0]])
290 | step = stepsize * rng.to(rays_o.device)
291 | interpx = (t_min[...,None] + step)
292 |
293 | rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
294 | mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1)
295 |
296 | return rays_pts, interpx, ~mask_outbbox
297 |
298 |
299 | def shrink(self, new_aabb, voxel_size):
300 | pass
301 |
302 | @torch.no_grad()
303 | def getDenseAlpha(self,gridSize=None):
304 | gridSize = self.gridSize if gridSize is None else gridSize
305 |
306 | samples = torch.stack(torch.meshgrid(
307 | torch.linspace(0, 1, gridSize[0]),
308 | torch.linspace(0, 1, gridSize[1]),
309 | torch.linspace(0, 1, gridSize[2]),
310 | ), -1).to(self.device)
311 | dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples
312 |
313 | # dense_xyz = dense_xyz
314 | # print(self.stepSize, self.distance_scale*self.aabbDiag)
315 | alpha = torch.zeros_like(dense_xyz[...,0])
316 | for i in range(gridSize[0]):
317 | alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))
318 | return alpha, dense_xyz
319 |
320 | @torch.no_grad()
321 | def updateAlphaMask(self, gridSize=(200,200,200)):
322 |
323 | alpha, dense_xyz = self.getDenseAlpha(gridSize)
324 | dense_xyz = dense_xyz.transpose(0,2).contiguous()
325 | alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]
326 | total_voxels = gridSize[0] * gridSize[1] * gridSize[2]
327 |
328 | ks = 3
329 | alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
330 | alpha[alpha>=self.alphaMask_thres] = 1
331 | alpha[alpha0.5]
336 |
337 | xyz_min = valid_xyz.amin(0)
338 | xyz_max = valid_xyz.amax(0)
339 |
340 | new_aabb = torch.stack((xyz_min, xyz_max))
341 |
342 | total = torch.sum(alpha)
343 | print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100))
344 | return new_aabb
345 |
346 | @torch.no_grad()
347 | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False):
348 | print('========> filtering rays ...')
349 | tt = time.time()
350 |
351 | N = torch.tensor(all_rays.shape[:-1]).prod()
352 |
353 | mask_filtered = []
354 | idx_chunks = torch.split(torch.arange(N), chunk)
355 | for idx_chunk in idx_chunks:
356 | rays_chunk = all_rays[idx_chunk].to(self.device)
357 |
358 | rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
359 | if bbox_only:
360 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
361 | rate_a = (self.aabb[1] - rays_o) / vec
362 | rate_b = (self.aabb[0] - rays_o) / vec
363 | t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far)
364 | t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far)
365 | mask_inbbox = t_max > t_min
366 |
367 | else:
368 | xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
369 | mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)
370 |
371 | mask_filtered.append(mask_inbbox.cpu())
372 |
373 | mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])
374 |
375 | print(f'Ray filtering done! takes {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
376 | return all_rays[mask_filtered], all_rgbs[mask_filtered]
377 |
378 |
379 | def feature2density(self, density_features):
380 | if self.fea2denseAct == "softplus":
381 | return F.softplus(density_features+self.density_shift)
382 | elif self.fea2denseAct == "relu":
383 | return F.relu(density_features)
384 |
385 |
386 | def compute_alpha(self, xyz_locs, length=1):
387 |
388 | if self.alphaMask is not None:
389 | alphas = self.alphaMask.sample_alpha(xyz_locs)
390 | alpha_mask = alphas > 0
391 | else:
392 | alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool)
393 |
394 |
395 | sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
396 |
397 | if alpha_mask.any():
398 | xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
399 | sigma_feature = self.compute_densityfeature(xyz_sampled)
400 | validsigma = self.feature2density(sigma_feature)
401 | sigma[alpha_mask] = validsigma
402 |
403 |
404 | alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1])
405 |
406 | return alpha
407 |
408 |
409 | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
410 |
411 | # sample points
412 | viewdirs = rays_chunk[:, 3:6]
413 | if ndc_ray:
414 | xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
415 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
416 | rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
417 | dists = dists * rays_norm
418 | viewdirs = viewdirs / rays_norm
419 | else:
420 | xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
421 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
422 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
423 |
424 | if self.alphaMask is not None:
425 | alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
426 | alpha_mask = alphas > 0
427 | ray_invalid = ~ray_valid
428 | ray_invalid[ray_valid] |= (~alpha_mask)
429 | ray_valid = ~ray_invalid
430 |
431 |
432 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
433 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
434 |
435 | if ray_valid.any():
436 | xyz_sampled = self.normalize_coord(xyz_sampled)
437 | sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
438 |
439 | validsigma = self.feature2density(sigma_feature)
440 | sigma[ray_valid] = validsigma
441 |
442 |
443 | alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
444 |
445 | app_mask = weight > self.rayMarch_weight_thres
446 |
447 | if app_mask.any():
448 | app_features = self.compute_appfeature(xyz_sampled[app_mask])
449 | valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
450 | rgb[app_mask] = valid_rgbs
451 |
452 | acc_map = torch.sum(weight, -1)
453 | rgb_map = torch.sum(weight[..., None] * rgb, -2)
454 |
455 | if white_bg or (is_train and torch.rand((1,))<0.5):
456 | rgb_map = rgb_map + (1. - acc_map[..., None])
457 |
458 |
459 | rgb_map = rgb_map.clamp(0,1)
460 |
461 | with torch.no_grad():
462 | depth_map = torch.sum(weight * z_vals, -1)
463 | depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]
464 |
465 | return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight
466 |
467 |
--------------------------------------------------------------------------------
/opt.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 |
3 | def config_parser(cmd=None):
4 | parser = configargparse.ArgumentParser()
5 | parser.add_argument('--config', is_config_file=True,
6 | help='config file path')
7 | parser.add_argument("--expname", type=str,
8 | help='experiment name')
9 | parser.add_argument("--basedir", type=str, default='./log',
10 | help='where to store ckpts and logs')
11 | parser.add_argument("--add_timestamp", type=int, default=0,
12 | help='add timestamp to dir')
13 | parser.add_argument("--datadir", type=str, default='./data/llff/fern',
14 | help='input data directory')
15 | parser.add_argument("--progress_refresh_rate", type=int, default=10,
16 | help='how many iterations to show psnrs or iters')
17 |
18 | parser.add_argument('--with_depth', action='store_true')
19 | parser.add_argument('--downsample_train', type=float, default=1.0)
20 | parser.add_argument('--downsample_test', type=float, default=1.0)
21 |
22 | parser.add_argument('--model_name', type=str, default='TensorVMSplit',
23 | choices=['TensorVMSplit', 'TensorCP'])
24 |
25 | # loader options
26 | parser.add_argument("--batch_size", type=int, default=4096)
27 | parser.add_argument("--n_iters", type=int, default=30000)
28 |
29 | parser.add_argument('--dataset_name', type=str, default='blender',
30 | choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])
31 |
32 |
33 | # training options
34 | # learning rate
35 | parser.add_argument("--lr_init", type=float, default=0.02,
36 | help='learning rate')
37 | parser.add_argument("--lr_basis", type=float, default=1e-3,
38 | help='learning rate')
39 | parser.add_argument("--lr_decay_iters", type=int, default=-1,
40 | help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters')
41 | parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1,
42 | help='the target decay ratio; after decay_iters inital lr decays to lr*ratio')
43 | parser.add_argument("--lr_upsample_reset", type=int, default=1,
44 | help='reset lr to inital after upsampling')
45 |
46 | # loss
47 | parser.add_argument("--L1_weight_inital", type=float, default=0.0,
48 | help='loss weight')
49 | parser.add_argument("--L1_weight_rest", type=float, default=0,
50 | help='loss weight')
51 | parser.add_argument("--Ortho_weight", type=float, default=0.0,
52 | help='loss weight')
53 | parser.add_argument("--TV_weight_density", type=float, default=0.0,
54 | help='loss weight')
55 | parser.add_argument("--TV_weight_app", type=float, default=0.0,
56 | help='loss weight')
57 |
58 | # model
59 | # volume options
60 | parser.add_argument("--n_lamb_sigma", type=int, action="append")
61 | parser.add_argument("--n_lamb_sh", type=int, action="append")
62 | parser.add_argument("--data_dim_color", type=int, default=27)
63 |
64 | parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001,
65 | help='mask points in ray marching')
66 | parser.add_argument("--alpha_mask_thre", type=float, default=0.0001,
67 | help='threshold for creating alpha mask volume')
68 | parser.add_argument("--distance_scale", type=float, default=25,
69 | help='scaling sampling distance for computation')
70 | parser.add_argument("--density_shift", type=float, default=-10,
71 | help='shift density in softplus; making density = 0 when feature == 0')
72 |
73 | # network decoder
74 | parser.add_argument("--shadingMode", type=str, default="MLP_PE",
75 | help='which shading mode to use')
76 | parser.add_argument("--pos_pe", type=int, default=6,
77 | help='number of pe for pos')
78 | parser.add_argument("--view_pe", type=int, default=6,
79 | help='number of pe for view')
80 | parser.add_argument("--fea_pe", type=int, default=6,
81 | help='number of pe for features')
82 | parser.add_argument("--featureC", type=int, default=128,
83 | help='hidden feature channel in MLP')
84 |
85 |
86 |
87 | parser.add_argument("--ckpt", type=str, default=None,
88 | help='specific weights npy file to reload for coarse network')
89 | parser.add_argument("--render_only", type=int, default=0)
90 | parser.add_argument("--render_test", type=int, default=0)
91 | parser.add_argument("--render_train", type=int, default=0)
92 | parser.add_argument("--render_path", type=int, default=0)
93 | parser.add_argument("--export_mesh", type=int, default=0)
94 |
95 | # rendering options
96 | parser.add_argument('--lindisp', default=False, action="store_true",
97 | help='use disparity depth sampling')
98 | parser.add_argument("--perturb", type=float, default=1.,
99 | help='set to 0. for no jitter, 1. for jitter')
100 | parser.add_argument("--accumulate_decay", type=float, default=0.998)
101 | parser.add_argument("--fea2denseAct", type=str, default='softplus')
102 | parser.add_argument('--ndc_ray', type=int, default=0)
103 | parser.add_argument('--nSamples', type=int, default=1e6,
104 | help='sample point each ray, pass 1e6 if automatic adjust')
105 | parser.add_argument('--step_ratio',type=float,default=0.5)
106 |
107 |
108 | ## blender flags
109 | parser.add_argument("--white_bkgd", action='store_true',
110 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
111 |
112 |
113 |
114 | parser.add_argument('--N_voxel_init',
115 | type=int,
116 | default=100**3)
117 | parser.add_argument('--N_voxel_final',
118 | type=int,
119 | default=300**3)
120 | parser.add_argument("--upsamp_list", type=int, action="append")
121 | parser.add_argument("--update_AlphaMask_list", type=int, action="append")
122 |
123 | parser.add_argument('--idx_view',
124 | type=int,
125 | default=0)
126 | # logging/saving options
127 | parser.add_argument("--N_vis", type=int, default=5,
128 | help='N images to vis')
129 | parser.add_argument("--vis_every", type=int, default=10000,
130 | help='frequency of visualize the image')
131 | if cmd is not None:
132 | return parser.parse_args(cmd)
133 | else:
134 | return parser.parse_args()
--------------------------------------------------------------------------------
/renderer.py:
--------------------------------------------------------------------------------
1 | import torch,os,imageio,sys
2 | from tqdm.auto import tqdm
3 | from dataLoader.ray_utils import get_rays
4 | from models.tensoRF import TensorVM, TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask
5 | from utils import *
6 | from dataLoader.ray_utils import ndc_rays_blender
7 |
8 |
9 | def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'):
10 |
11 | rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], []
12 | N_rays_all = rays.shape[0]
13 | for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
14 | rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
15 |
16 | rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
17 |
18 | rgbs.append(rgb_map)
19 | depth_maps.append(depth_map)
20 |
21 | return torch.cat(rgbs), None, torch.cat(depth_maps), None, None
22 |
23 | @torch.no_grad()
24 | def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
25 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
26 | PSNRs, rgb_maps, depth_maps = [], [], []
27 | ssims,l_alex,l_vgg=[],[],[]
28 | os.makedirs(savePath, exist_ok=True)
29 | os.makedirs(savePath+"/rgbd", exist_ok=True)
30 |
31 | try:
32 | tqdm._instances.clear()
33 | except Exception:
34 | pass
35 |
36 | near_far = test_dataset.near_far
37 | img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] // N_vis,1)
38 | idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval))
39 | for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout):
40 |
41 | W, H = test_dataset.img_wh
42 | rays = samples.view(-1,samples.shape[-1])
43 |
44 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples,
45 | ndc_ray=ndc_ray, white_bg = white_bg, device=device)
46 | rgb_map = rgb_map.clamp(0.0, 1.0)
47 |
48 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
49 |
50 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
51 | if len(test_dataset.all_rgbs):
52 | gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, 3)
53 | loss = torch.mean((rgb_map - gt_rgb) ** 2)
54 | PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
55 |
56 | if compute_extra_metrics:
57 | ssim = rgb_ssim(rgb_map, gt_rgb, 1)
58 | l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device)
59 | l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device)
60 | ssims.append(ssim)
61 | l_alex.append(l_a)
62 | l_vgg.append(l_v)
63 |
64 | rgb_map = (rgb_map.numpy() * 255).astype('uint8')
65 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
66 | rgb_maps.append(rgb_map)
67 | depth_maps.append(depth_map)
68 | if savePath is not None:
69 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
70 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
71 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
72 |
73 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)
74 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10)
75 |
76 | if PSNRs:
77 | psnr = np.mean(np.asarray(PSNRs))
78 | if compute_extra_metrics:
79 | ssim = np.mean(np.asarray(ssims))
80 | l_a = np.mean(np.asarray(l_alex))
81 | l_v = np.mean(np.asarray(l_vgg))
82 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
83 | else:
84 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
85 |
86 |
87 | return PSNRs
88 |
89 | @torch.no_grad()
90 | def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
91 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
92 | PSNRs, rgb_maps, depth_maps = [], [], []
93 | ssims,l_alex,l_vgg=[],[],[]
94 | os.makedirs(savePath, exist_ok=True)
95 | os.makedirs(savePath+"/rgbd", exist_ok=True)
96 |
97 | try:
98 | tqdm._instances.clear()
99 | except Exception:
100 | pass
101 |
102 | near_far = test_dataset.near_far
103 | for idx, c2w in tqdm(enumerate(c2ws)):
104 |
105 | W, H = test_dataset.img_wh
106 |
107 | c2w = torch.FloatTensor(c2w)
108 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
109 | if ndc_ray:
110 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
111 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
112 |
113 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples,
114 | ndc_ray=ndc_ray, white_bg = white_bg, device=device)
115 | rgb_map = rgb_map.clamp(0.0, 1.0)
116 |
117 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
118 |
119 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
120 |
121 | rgb_map = (rgb_map.numpy() * 255).astype('uint8')
122 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
123 | rgb_maps.append(rgb_map)
124 | depth_maps.append(depth_map)
125 | if savePath is not None:
126 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
127 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
128 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
129 |
130 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
131 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)
132 |
133 | if PSNRs:
134 | psnr = np.mean(np.asarray(PSNRs))
135 | if compute_extra_metrics:
136 | ssim = np.mean(np.asarray(ssims))
137 | l_a = np.mean(np.asarray(l_alex))
138 | l_v = np.mean(np.asarray(l_vgg))
139 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
140 | else:
141 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
142 |
143 |
144 | return PSNRs
145 |
146 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from tqdm.auto import tqdm
4 | from opt import config_parser
5 |
6 |
7 |
8 | import json, random
9 | from renderer import *
10 | from utils import *
11 | from torch.utils.tensorboard import SummaryWriter
12 | import datetime
13 |
14 | from dataLoader import dataset_dict
15 | import sys
16 |
17 |
18 |
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 | renderer = OctreeRender_trilinear_fast
22 |
23 |
24 | class SimpleSampler:
25 | def __init__(self, total, batch):
26 | self.total = total
27 | self.batch = batch
28 | self.curr = total
29 | self.ids = None
30 |
31 | def nextids(self):
32 | self.curr+=self.batch
33 | if self.curr + self.batch > self.total:
34 | self.ids = torch.LongTensor(np.random.permutation(self.total))
35 | self.curr = 0
36 | return self.ids[self.curr:self.curr+self.batch]
37 |
38 |
39 | @torch.no_grad()
40 | def export_mesh(args):
41 |
42 | ckpt = torch.load(args.ckpt, map_location=device)
43 | kwargs = ckpt['kwargs']
44 | kwargs.update({'device': device})
45 | tensorf = eval(args.model_name)(**kwargs)
46 | tensorf.load(ckpt)
47 |
48 | alpha,_ = tensorf.getDenseAlpha()
49 | convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)
50 |
51 |
52 | @torch.no_grad()
53 | def render_test(args):
54 | # init dataset
55 | dataset = dataset_dict[args.dataset_name]
56 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
57 | white_bg = test_dataset.white_bg
58 | ndc_ray = args.ndc_ray
59 |
60 | if not os.path.exists(args.ckpt):
61 | print('the ckpt path does not exists!!')
62 | return
63 |
64 | ckpt = torch.load(args.ckpt, map_location=device)
65 | kwargs = ckpt['kwargs']
66 | kwargs.update({'device': device})
67 | tensorf = eval(args.model_name)(**kwargs)
68 | tensorf.load(ckpt)
69 |
70 | logfolder = os.path.dirname(args.ckpt)
71 | if args.render_train:
72 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
73 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
74 | PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
75 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
76 | print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
77 |
78 | if args.render_test:
79 | os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
80 | evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
81 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
82 |
83 | if args.render_path:
84 | c2ws = test_dataset.render_path
85 | os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
86 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
87 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
88 |
89 | def reconstruction(args):
90 |
91 | # init dataset
92 | dataset = dataset_dict[args.dataset_name]
93 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
94 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
95 | white_bg = train_dataset.white_bg
96 | near_far = train_dataset.near_far
97 | ndc_ray = args.ndc_ray
98 |
99 | # init resolution
100 | upsamp_list = args.upsamp_list
101 | update_AlphaMask_list = args.update_AlphaMask_list
102 | n_lamb_sigma = args.n_lamb_sigma
103 | n_lamb_sh = args.n_lamb_sh
104 |
105 |
106 | if args.add_timestamp:
107 | logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
108 | else:
109 | logfolder = f'{args.basedir}/{args.expname}'
110 |
111 |
112 | # init log file
113 | os.makedirs(logfolder, exist_ok=True)
114 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
115 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
116 | os.makedirs(f'{logfolder}/rgba', exist_ok=True)
117 | summary_writer = SummaryWriter(logfolder)
118 |
119 |
120 |
121 | # init parameters
122 | # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
123 | aabb = train_dataset.scene_bbox.to(device)
124 | reso_cur = N_to_reso(args.N_voxel_init, aabb)
125 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
126 |
127 |
128 | if args.ckpt is not None:
129 | ckpt = torch.load(args.ckpt, map_location=device)
130 | kwargs = ckpt['kwargs']
131 | kwargs.update({'device':device})
132 | tensorf = eval(args.model_name)(**kwargs)
133 | tensorf.load(ckpt)
134 | else:
135 | tensorf = eval(args.model_name)(aabb, reso_cur, device,
136 | density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
137 | shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
138 | pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct)
139 |
140 |
141 | grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
142 | if args.lr_decay_iters > 0:
143 | lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
144 | else:
145 | args.lr_decay_iters = args.n_iters
146 | lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)
147 |
148 | print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
149 |
150 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))
151 |
152 |
153 | #linear in logrithmic space
154 | N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]
155 |
156 |
157 | torch.cuda.empty_cache()
158 | PSNRs,PSNRs_test = [],[0]
159 |
160 | allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
161 | if not args.ndc_ray:
162 | allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)
163 | trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
164 |
165 | Ortho_reg_weight = args.Ortho_weight
166 | print("initial Ortho_reg_weight", Ortho_reg_weight)
167 |
168 | L1_reg_weight = args.L1_weight_inital
169 | print("initial L1_reg_weight", L1_reg_weight)
170 | TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
171 | tvreg = TVLoss()
172 | print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")
173 |
174 |
175 | pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
176 | for iteration in pbar:
177 |
178 |
179 | ray_idx = trainingSampler.nextids()
180 | rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)
181 |
182 | #rgb_map, alphas_map, depth_map, weights, uncertainty
183 | rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(rays_train, tensorf, chunk=args.batch_size,
184 | N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True)
185 |
186 | loss = torch.mean((rgb_map - rgb_train) ** 2)
187 |
188 |
189 | # loss
190 | total_loss = loss
191 | if Ortho_reg_weight > 0:
192 | loss_reg = tensorf.vector_comp_diffs()
193 | total_loss += Ortho_reg_weight*loss_reg
194 | summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
195 | if L1_reg_weight > 0:
196 | loss_reg_L1 = tensorf.density_L1()
197 | total_loss += L1_reg_weight*loss_reg_L1
198 | summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)
199 |
200 | if TV_weight_density>0:
201 | TV_weight_density *= lr_factor
202 | loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
203 | total_loss = total_loss + loss_tv
204 | summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
205 | if TV_weight_app>0:
206 | TV_weight_app *= lr_factor
207 | loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app
208 | total_loss = total_loss + loss_tv
209 | summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
210 |
211 | optimizer.zero_grad()
212 | total_loss.backward()
213 | optimizer.step()
214 |
215 | loss = loss.detach().item()
216 |
217 | PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
218 | summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
219 | summary_writer.add_scalar('train/mse', loss, global_step=iteration)
220 |
221 |
222 | for param_group in optimizer.param_groups:
223 | param_group['lr'] = param_group['lr'] * lr_factor
224 |
225 | # Print the current values of the losses.
226 | if iteration % args.progress_refresh_rate == 0:
227 | pbar.set_description(
228 | f'Iteration {iteration:05d}:'
229 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
230 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
231 | + f' mse = {loss:.6f}'
232 | )
233 | PSNRs = []
234 |
235 |
236 | if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
237 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
238 | prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False)
239 | summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
240 |
241 |
242 |
243 | if iteration in update_AlphaMask_list:
244 |
245 | if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution
246 | reso_mask = reso_cur
247 | new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
248 | if iteration == update_AlphaMask_list[0]:
249 | tensorf.shrink(new_aabb)
250 | # tensorVM.alphaMask = None
251 | L1_reg_weight = args.L1_weight_rest
252 | print("continuing L1_reg_weight", L1_reg_weight)
253 |
254 |
255 | if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
256 | # filter rays outside the bbox
257 | allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs)
258 | trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)
259 |
260 |
261 | if iteration in upsamp_list:
262 | n_voxels = N_voxel_list.pop(0)
263 | reso_cur = N_to_reso(n_voxels, tensorf.aabb)
264 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
265 | tensorf.upsample_volume_grid(reso_cur)
266 |
267 | if args.lr_upsample_reset:
268 | print("reset lr to initial")
269 | lr_scale = 1 #0.1 ** (iteration / args.n_iters)
270 | else:
271 | lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
272 | grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
273 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
274 |
275 |
276 | tensorf.save(f'{logfolder}/{args.expname}.th')
277 |
278 |
279 | if args.render_train:
280 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
281 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
282 | PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
283 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
284 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
285 |
286 | if args.render_test:
287 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
288 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
289 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
290 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
291 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
292 |
293 | if args.render_path:
294 | c2ws = test_dataset.render_path
295 | # c2ws = test_dataset.poses
296 | print('========>',c2ws.shape)
297 | os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
298 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
299 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
300 |
301 |
302 | if __name__ == '__main__':
303 |
304 | torch.set_default_dtype(torch.float32)
305 | torch.manual_seed(20211202)
306 | np.random.seed(20211202)
307 |
308 | args = config_parser()
309 | print(args)
310 |
311 | if args.export_mesh:
312 | export_mesh(args)
313 |
314 | if args.render_only and (args.render_test or args.render_path):
315 | render_test(args)
316 | else:
317 | reconstruction(args)
318 |
319 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2,torch
2 | import numpy as np
3 | from PIL import Image
4 | import torchvision.transforms as T
5 | import torch.nn.functional as F
6 | import scipy.signal
7 |
8 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
9 |
10 |
11 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
12 | """
13 | depth: (H, W)
14 | """
15 |
16 | x = np.nan_to_num(depth) # change nan to 0
17 | if minmax is None:
18 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
19 | ma = np.max(x)
20 | else:
21 | mi,ma = minmax
22 |
23 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
24 | x = (255*x).astype(np.uint8)
25 | x_ = cv2.applyColorMap(x, cmap)
26 | return x_, [mi,ma]
27 |
28 | def init_log(log, keys):
29 | for key in keys:
30 | log[key] = torch.tensor([0.0], dtype=float)
31 | return log
32 |
33 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
34 | """
35 | depth: (H, W)
36 | """
37 | if type(depth) is not np.ndarray:
38 | depth = depth.cpu().numpy()
39 |
40 | x = np.nan_to_num(depth) # change nan to 0
41 | if minmax is None:
42 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
43 | ma = np.max(x)
44 | else:
45 | mi,ma = minmax
46 |
47 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
48 | x = (255*x).astype(np.uint8)
49 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
50 | x_ = T.ToTensor()(x_) # (3, H, W)
51 | return x_, [mi,ma]
52 |
53 | def N_to_reso(n_voxels, bbox):
54 | xyz_min, xyz_max = bbox
55 | dim = len(xyz_min)
56 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
57 | return ((xyz_max - xyz_min) / voxel_size).long().tolist()
58 |
59 | def cal_n_samples(reso, step_ratio=0.5):
60 | return int(np.linalg.norm(reso)/step_ratio)
61 |
62 |
63 |
64 |
65 | __LPIPS__ = {}
66 | def init_lpips(net_name, device):
67 | assert net_name in ['alex', 'vgg']
68 | import lpips
69 | print(f'init_lpips: lpips_{net_name}')
70 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
71 |
72 | def rgb_lpips(np_gt, np_im, net_name, device):
73 | if net_name not in __LPIPS__:
74 | __LPIPS__[net_name] = init_lpips(net_name, device)
75 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
76 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
77 | return __LPIPS__[net_name](gt, im, normalize=True).item()
78 |
79 |
80 | def findItem(items, target):
81 | for one in items:
82 | if one[:len(target)]==target:
83 | return one
84 | return None
85 |
86 |
87 | ''' Evaluation metrics (ssim, lpips)
88 | '''
89 | def rgb_ssim(img0, img1, max_val,
90 | filter_size=11,
91 | filter_sigma=1.5,
92 | k1=0.01,
93 | k2=0.03,
94 | return_map=False):
95 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
96 | assert len(img0.shape) == 3
97 | assert img0.shape[-1] == 3
98 | assert img0.shape == img1.shape
99 |
100 | # Construct a 1D Gaussian blur filter.
101 | hw = filter_size // 2
102 | shift = (2 * hw - filter_size + 1) / 2
103 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
104 | filt = np.exp(-0.5 * f_i)
105 | filt /= np.sum(filt)
106 |
107 | # Blur in x and y (faster than the 2D convolution).
108 | def convolve2d(z, f):
109 | return scipy.signal.convolve2d(z, f, mode='valid')
110 |
111 | filt_fn = lambda z: np.stack([
112 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
113 | for i in range(z.shape[-1])], -1)
114 | mu0 = filt_fn(img0)
115 | mu1 = filt_fn(img1)
116 | mu00 = mu0 * mu0
117 | mu11 = mu1 * mu1
118 | mu01 = mu0 * mu1
119 | sigma00 = filt_fn(img0**2) - mu00
120 | sigma11 = filt_fn(img1**2) - mu11
121 | sigma01 = filt_fn(img0 * img1) - mu01
122 |
123 | # Clip the variances and covariances to valid values.
124 | # Variance must be non-negative:
125 | sigma00 = np.maximum(0., sigma00)
126 | sigma11 = np.maximum(0., sigma11)
127 | sigma01 = np.sign(sigma01) * np.minimum(
128 | np.sqrt(sigma00 * sigma11), np.abs(sigma01))
129 | c1 = (k1 * max_val)**2
130 | c2 = (k2 * max_val)**2
131 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
132 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
133 | ssim_map = numer / denom
134 | ssim = np.mean(ssim_map)
135 | return ssim_map if return_map else ssim
136 |
137 |
138 | import torch.nn as nn
139 | class TVLoss(nn.Module):
140 | def __init__(self,TVLoss_weight=1):
141 | super(TVLoss,self).__init__()
142 | self.TVLoss_weight = TVLoss_weight
143 |
144 | def forward(self,x):
145 | batch_size = x.size()[0]
146 | h_x = x.size()[2]
147 | w_x = x.size()[3]
148 | count_h = self._tensor_size(x[:,:,1:,:])
149 | count_w = self._tensor_size(x[:,:,:,1:])
150 | count_w = max(count_w, 1)
151 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
152 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
153 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
154 |
155 | def _tensor_size(self,t):
156 | return t.size()[1]*t.size()[2]*t.size()[3]
157 |
158 |
159 |
160 | import plyfile
161 | import skimage.measure
162 | def convert_sdf_samples_to_ply(
163 | pytorch_3d_sdf_tensor,
164 | ply_filename_out,
165 | bbox,
166 | level=0.5,
167 | offset=None,
168 | scale=None,
169 | ):
170 | """
171 | Convert sdf samples to .ply
172 |
173 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
174 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
175 | :voxel_size: float, the size of the voxels
176 | :ply_filename_out: string, path of the filename to save to
177 |
178 | This function adapted from: https://github.com/RobotLocomotion/spartan
179 | """
180 |
181 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
182 | voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))
183 |
184 | verts, faces, normals, values = skimage.measure.marching_cubes(
185 | numpy_3d_sdf_tensor, level=level, spacing=voxel_size
186 | )
187 | faces = faces[...,::-1] # inverse face orientation
188 |
189 | # transform from voxel coordinates to camera coordinates
190 | # note x and y are flipped in the output of marching_cubes
191 | mesh_points = np.zeros_like(verts)
192 | mesh_points[:, 0] = bbox[0,0] + verts[:, 0]
193 | mesh_points[:, 1] = bbox[0,1] + verts[:, 1]
194 | mesh_points[:, 2] = bbox[0,2] + verts[:, 2]
195 |
196 | # apply additional offset and scale
197 | if scale is not None:
198 | mesh_points = mesh_points / scale
199 | if offset is not None:
200 | mesh_points = mesh_points - offset
201 |
202 | # try writing to the ply file
203 |
204 | num_verts = verts.shape[0]
205 | num_faces = faces.shape[0]
206 |
207 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
208 |
209 | for i in range(0, num_verts):
210 | verts_tuple[i] = tuple(mesh_points[i, :])
211 |
212 | faces_building = []
213 | for i in range(0, num_faces):
214 | faces_building.append(((faces[i, :].tolist(),)))
215 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
216 |
217 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
218 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
219 |
220 | ply_data = plyfile.PlyData([el_verts, el_faces])
221 | print("saving mesh to %s" % (ply_filename_out))
222 | ply_data.write(ply_filename_out)
223 |
--------------------------------------------------------------------------------