├── .dockerignore ├── .gitattributes ├── .gitignore ├── .prettierrc.js ├── .readthedocs.yaml ├── .vscode ├── c_cpp_properties.json ├── launch.json └── settings.json ├── Dockerfile ├── LICENSE ├── MSTH ├── SpaceTimeHashing │ ├── field.py │ ├── mle.py │ ├── model.py │ ├── permute_field.py │ ├── ray_samplers.py │ ├── rect_model.py │ ├── render.py │ ├── stmodel.py │ ├── stmodel_components.py │ ├── stmodel_freeze.py │ ├── test_pose.py │ ├── trainer.py │ └── viz.py ├── __init__.py ├── configs │ ├── __init__.py │ ├── freeze.py │ ├── method_config_coffee.py │ ├── method_config_cook_spinach.py │ ├── method_config_parameter_search.py │ ├── method_config_ps2.py │ ├── method_config_rubik.py │ ├── method_config_salmon.py │ ├── method_config_success.py │ ├── method_configs.py │ ├── method_configs_czl.py │ ├── method_configs_imm.py │ ├── method_configs_imm_final.py │ ├── method_configs_tp.py │ ├── method_configs_tpa.py │ └── method_import.py ├── cos_scheduler.py ├── datamanager.py ├── dataparser.py ├── dataset.py ├── directModel │ ├── deferred_nerfacto_field.py │ ├── deferred_nerfacto_model.py │ ├── density_field_nomlp.py │ └── trainer.py ├── docs │ └── immersive.md ├── exploration │ ├── psnr.ipynb │ └── spare_loss.ipynb ├── grid_field.py ├── gridencoder │ ├── __init__.py │ ├── backend.py │ ├── grid.py │ ├── grid_backup.py │ ├── grid_czl.py │ ├── setup.py │ ├── src │ │ ├── bindings.cpp │ │ ├── gridencoder.cu │ │ ├── gridencoder.h │ │ ├── stencoder.cu │ │ ├── test_1.cu │ │ └── test_2.cu │ ├── stgrid.py │ └── test_stgrid.py ├── ibrnet │ ├── colorizer.py │ ├── feature_extractor.py │ ├── field.py │ └── test.py ├── projector.py ├── sampler.py ├── scripts │ ├── __init__.py │ ├── eval.py │ ├── eval_spiral.py │ ├── frames2video.py │ ├── imm2nerfstudio.py │ ├── nerfacto.sh │ ├── nerfacto_profiling.sh │ ├── prepdata │ │ ├── downsample.py │ │ ├── downsample_fps.py │ │ ├── llff2nerf.py │ │ └── neural3d.sh │ ├── profiling │ │ ├── cProfile.sh │ │ ├── fast_rendering.sh │ │ ├── instant_ngp.sh │ │ └── line_prof.sh │ ├── run.sh │ ├── run_imm.py │ ├── run_seq.sh │ ├── run_taks.py │ ├── sth.sh │ ├── sth_profiling.sh │ ├── sth_rect.sh │ ├── stream_nerfacto_baseline.sh │ ├── test_storage.py │ ├── train.py │ ├── tunning_machine.py │ ├── video_frames.sh │ ├── video_nerfacto_baseline.sh │ ├── video_tensorf_baseline.sh │ └── video_train.py ├── streamable_density_fields.py ├── streamable_model.py ├── streamable_nerfacto_field.py ├── streamable_pipeline.py ├── test.py ├── test_grid.py ├── utils.py ├── video_pipeline.py └── wandb │ ├── latest-run │ ├── run-20230413_195110-2c1ru4ae │ ├── files │ │ └── config.yaml │ └── run-2c1ru4ae.wandb │ ├── run-20230413_200145-7prtgoar │ ├── files │ │ └── config.yaml │ └── run-7prtgoar.wandb │ ├── run-20230413_200502-qmbtu7bi │ ├── files │ │ └── config.yaml │ └── run-qmbtu7bi.wandb │ └── run-20230413_200948-gorthja9 │ ├── files │ ├── code │ │ └── TRTT │ │ │ └── test.py │ ├── conda-environment.yaml │ ├── config.yaml │ ├── diff.patch │ ├── requirements.txt │ ├── wandb-metadata.json │ └── wandb-summary.json │ └── run-gorthja9.wandb ├── README.md ├── dataparser.py ├── download.py ├── nerfstudio ├── __init__.py ├── cameras │ ├── __init__.py │ ├── camera_optimizers.py │ ├── camera_paths.py │ ├── camera_utils.py │ ├── cameras.py │ ├── lie_groups.py │ └── rays.py ├── configs │ ├── __init__.py │ ├── base_config.py │ ├── config_utils.py │ ├── experiment_config.py │ └── method_configs.py ├── data │ ├── __init__.py │ ├── datamanagers │ │ ├── __init__.py │ │ ├── base_datamanager.py │ │ ├── depth_datamanager.py 
│ │ ├── semantic_datamanager.py │ │ └── variable_res_datamanager.py │ ├── dataparsers │ │ ├── __init__.py │ │ ├── arkitscenes_dataparser.py │ │ ├── base_dataparser.py │ │ ├── blender_dataparser.py │ │ ├── dnerf_dataparser.py │ │ ├── dycheck_dataparser.py │ │ ├── instant_ngp_dataparser.py │ │ ├── minimal_dataparser.py │ │ ├── nerfstudio_dataparser.py │ │ ├── nuscenes_dataparser.py │ │ ├── phototourism_dataparser.py │ │ ├── scannet_dataparser.py │ │ ├── sdfstudio_dataparser.py │ │ └── sitcoms3d_dataparser.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── depth_dataset.py │ │ ├── sdf_dataset.py │ │ └── semantic_dataset.py │ ├── pixel_samplers.py │ ├── scene_box.py │ └── utils │ │ ├── __init__.py │ │ ├── colmap_parsing_utils.py │ │ ├── data_utils.py │ │ ├── dataloaders.py │ │ └── nerfstudio_collate.py ├── engine │ ├── __init__.py │ ├── callbacks.py │ ├── optimizers.py │ ├── schedulers.py │ └── trainer.py ├── exporter │ ├── __init__.py │ ├── exporter_utils.py │ ├── texture_utils.py │ └── tsdf_utils.py ├── field_components │ ├── __init__.py │ ├── activations.py │ ├── base_field_component.py │ ├── cuda │ │ ├── __init__.py │ │ ├── _backend.py │ │ └── csrc │ │ │ ├── include │ │ │ └── temporal_gridencoder.h │ │ │ ├── pybind.cu │ │ │ └── temporal_gridencoder.cu │ ├── embedding.py │ ├── encodings.py │ ├── field_heads.py │ ├── mlp.py │ ├── spatial_distortions.py │ ├── temporal_distortions.py │ └── temporal_grid.py ├── fields │ ├── __init__.py │ ├── base_field.py │ ├── density_fields.py │ ├── instant_ngp_field.py │ ├── nerfacto_field.py │ ├── nerfplayer_nerfacto_field.py │ ├── nerfplayer_ngp_field.py │ ├── nerfw_field.py │ ├── semantic_nerf_field.py │ ├── tensorf_field.py │ └── vanilla_nerf_field.py ├── generative │ ├── __init__.py │ └── stable_diffusion.py ├── model_components │ ├── __init__.py │ ├── losses.py │ ├── ray_generators.py │ ├── ray_samplers.py │ ├── renderers.py │ ├── scene_colliders.py │ └── shaders.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── depth_nerfacto.py │ ├── instant_ngp.py │ ├── mipnerf.py │ ├── nerfacto.py │ ├── nerfplayer_nerfacto.py │ ├── nerfplayer_ngp.py │ ├── semantic_nerfw.py │ ├── tensorf.py │ └── vanilla_nerf.py ├── pipelines │ ├── __init__.py │ ├── base_pipeline.py │ └── dynamic_batch.py ├── plugins │ ├── __init__.py │ ├── registry.py │ └── types.py ├── process_data │ ├── __init__.py │ ├── colmap_utils.py │ ├── equirect_utils.py │ ├── hloc_utils.py │ ├── insta360_utils.py │ ├── metashape_utils.py │ ├── polycam_utils.py │ ├── process_data_utils.py │ ├── realitycapture_utils.py │ └── record3d_utils.py ├── py.typed ├── utils │ ├── __init__.py │ ├── colormaps.py │ ├── colors.py │ ├── comms.py │ ├── decorators.py │ ├── eval_utils.py │ ├── install_checks.py │ ├── io.py │ ├── math.py │ ├── misc.py │ ├── plotly_utils.py │ ├── poses.py │ ├── printing.py │ ├── profiler.py │ ├── rich_utils.py │ ├── scripts.py │ ├── tensor_dataclass.py │ └── writer.py └── viewer │ ├── __init__.py │ ├── app │ ├── .env.development │ ├── .eslintrc.json │ ├── .gitignore │ ├── package.json │ ├── public │ │ ├── electron.js │ │ ├── index.html │ │ ├── manifest.json │ │ └── robots.txt │ ├── requirements.txt │ ├── run_deploy.py │ ├── src │ │ ├── App.jsx │ │ ├── SceneNode.js │ │ ├── index.jsx │ │ ├── index.scss │ │ ├── modules │ │ │ ├── Banner │ │ │ │ ├── Banner.jsx │ │ │ │ └── index.jsx │ │ │ ├── ConfigPanel │ │ │ │ ├── ConfigPanel.jsx │ │ │ │ └── ConfigPanelSlice.js │ │ │ ├── LandingModal │ │ │ │ ├── LandingModal.jsx │ │ │ │ └── index.jsx │ │ │ ├── LoadPathModal │ │ │ │ ├── 
LoadPathModal.jsx │ │ │ │ └── index.jsx │ │ │ ├── LogPanel │ │ │ │ └── LogPanel.jsx │ │ │ ├── RenderModal │ │ │ │ ├── RenderModal.jsx │ │ │ │ └── index.jsx │ │ │ ├── Scene │ │ │ │ ├── Scene.jsx │ │ │ │ └── drawing.js │ │ │ ├── SidePanel │ │ │ │ ├── CameraPanel │ │ │ │ │ ├── CameraHelper.js │ │ │ │ │ ├── CameraPanel.jsx │ │ │ │ │ ├── CameraPropPanel.jsx │ │ │ │ │ ├── curve.js │ │ │ │ │ └── index.jsx │ │ │ │ ├── ExportPanel │ │ │ │ │ ├── ExportPanel.jsx │ │ │ │ │ ├── MeshSubPanel.jsx │ │ │ │ │ ├── PointcloudSubPanel.jsx │ │ │ │ │ └── index.jsx │ │ │ │ ├── ScenePanel │ │ │ │ │ ├── ScenePanel.jsx │ │ │ │ │ └── index.jsx │ │ │ │ ├── SidePanel.jsx │ │ │ │ └── StatusPanel │ │ │ │ │ ├── StatusPanel.jsx │ │ │ │ │ └── index.jsx │ │ │ ├── ViewerWindow │ │ │ │ ├── ViewerWindow.jsx │ │ │ │ └── ViewerWindowSlice.js │ │ │ ├── ViewportControlsModal │ │ │ │ ├── ViewportControlsModal.jsx │ │ │ │ └── index.jsx │ │ │ ├── WebRtcWindow │ │ │ │ └── WebRtcWindow.jsx │ │ │ ├── WebSocket │ │ │ │ └── WebSocket.jsx │ │ │ └── WebSocketUrlField.jsx │ │ ├── reducer.js │ │ ├── setupTests.js │ │ ├── store.js │ │ ├── subscriber.js │ │ ├── themes │ │ │ ├── leva_theme.json │ │ │ └── theme.ts │ │ └── utils.js │ └── yarn.lock │ └── server │ ├── README.md │ ├── __init__.py │ ├── path.py │ ├── server.py │ ├── state │ ├── node.py │ └── state_node.py │ ├── subprocess.py │ ├── utils.py │ ├── video_stream.py │ ├── viewer_utils.py │ └── visualizer.py ├── pyproject.toml ├── scripts ├── __init__.py ├── benchmarking │ ├── launch_eval_blender.sh │ └── launch_train_blender.sh ├── blender │ ├── __init__.py │ └── nerfstudio_blender.py ├── completions │ ├── .gitignore │ ├── __init__.py │ ├── install.py │ ├── setup.bash │ └── setup.zsh ├── datasets │ └── process_nuscenes_masks.py ├── docs │ ├── __init__.py │ ├── add_nb_tags.py │ └── build_docs.py ├── downloads │ ├── __init__.py │ └── download_data.py ├── eval.py ├── exporter.py ├── generative │ ├── __init__.py │ └── trace_stable_diffusion.py ├── github │ ├── __init__.py │ └── run_actions.py ├── licensing │ ├── copyright.txt │ └── license_headers.sh ├── process_data.py ├── render.py ├── texture.py ├── train.py └── viewer │ ├── __init__.py │ ├── run_viewer.py │ └── view_dataset.py ├── test.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /.prettierrc.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | trailingComma: 'all', 3 | arrowParens: 'always', 4 | singleQuote: true, 5 | jsxSingleQuote: false, 6 | }; -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: '3.9' 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | fail_on_warning: true 21 | configuration: docs/conf.py 22 | 23 | # If using Sphinx, optionally build your docs in additional formats such as PDF 24 | # formats: 25 | # - pdf 26 | 27 | 
# Optionally declare the Python requirements required to build your docs 28 | python: 29 | install: 30 | # Equivalent to 'pip install .' 31 | - method: pip 32 | path: . 33 | # Equivalent to 'pip install .[docs]' 34 | - method: pip 35 | path: . 36 | extra_requirements: 37 | - docs 38 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/**" 7 | ], 8 | "defines": [], 9 | "compilerPath": "/usr/bin/gcc", 10 | "cStandard": "c11", 11 | "intelliSenseMode": "linux-gcc-x64", 12 | "configurationProvider": "ms-vscode.makefile-tools" 13 | } 14 | ], 15 | "version": 4 16 | } -------------------------------------------------------------------------------- /MSTH/SpaceTimeHashing/field.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import tinycudann as tcnn 4 | from typing import * 5 | from nerfacc import ContractionType, contract 6 | from torch.nn.parameter import Parameter 7 | from torchtyping import TensorType 8 | 9 | from nerfstudio.cameras.rays import RaySamples 10 | from nerfstudio.data.scene_box import SceneBox 11 | from nerfstudio.field_components.activations import trunc_exp 12 | from nerfstudio.field_components.embedding import Embedding 13 | from nerfstudio.field_components.field_heads import FieldHeadNames 14 | from nerfstudio.fields.base_field import Field 15 | 16 | 17 | class SpaceTimeHashingField(Field): 18 | def __init__( 19 | self, 20 | aabb: TensorType, 21 | num_layers: int = 2, 22 | hidden_dim: int = 64, 23 | geo_feat_dim: int = 15, 24 | num_layers_color: int = 3, 25 | hidden_dim_color: int = 64, 26 | use_appearance_embedding: Optional[bool] = False, 27 | num_images: Optional[int] = None, 28 | appearance_embedding_dim: int = 32, 29 | contraction_type: ContractionType = ContractionType.UN_BOUNDED_SPHERE, 30 | num_levels: int = 16, 31 | log2_hashmap_size: int = 19, 32 | base_res: int = 16, 33 | max_res: int = 2048, 34 | ) -> None: 35 | super().__init__() 36 | self.aabb = Parameter(aabb, requires_grad=False) 37 | self.geo_feat_dim = geo_feat_dim 38 | self.contraction_type = contraction_type 39 | self.use_appearance_embedding = use_appearance_embedding 40 | self.appearance_embedding_dim = appearance_embedding_dim 41 | if use_appearance_embedding: 42 | assert num_images is not None 43 | self.appearance_embedding = Embedding(num_images, appearance_embedding_dim) 44 | # growth factor between consecutive hash-grid levels (instant-NGP convention) 45 | per_level_scale = np.exp((np.log(max_res) - np.log(base_res)) / (num_levels - 1)) 46 | 47 | self.direction_encoding = tcnn.Encoding( 48 | n_input_dims=3, 49 | encoding_config={ 50 | "otype": "SphericalHarmonics", 51 | "degree": 4, 52 | }, 53 | ) 54 | 55 | self.mlp_base = tcnn.NetworkWithInputEncoding( 56 | n_input_dims=4, 57 | n_output_dims=1 + self.geo_feat_dim, 58 | encoding_config={ 59 | "otype": "HashGrid", 60 | "n_levels": num_levels, 61 | "n_features_per_level": 2, 62 | "log2_hashmap_size": log2_hashmap_size, 63 | "base_resolution": base_res, 64 | "per_level_scale": per_level_scale, 65 | }, 66 | network_config={ 67 | "otype": "FullyFusedMLP", 68 | "activation": "ReLU", 69 | "output_activation": "None", 70 | "n_neurons": hidden_dim, 71 | "n_hidden_layers": num_layers - 1, 72 | }, 73 | ) 74 | 75 | in_dim = self.direction_encoding.n_output_dims + self.geo_feat_dim 76 | if self.use_appearance_embedding: 77 | in_dim += self.appearance_embedding_dim 78 | self.mlp_head = tcnn.Network( 79 | n_input_dims=in_dim, 80 | n_output_dims=3, 81 | network_config={ 82 | "otype": "FullyFusedMLP", 83 | "activation": "ReLU", 84 | "output_activation": "Sigmoid", 85 | "n_neurons": hidden_dim_color, 86 | "n_hidden_layers": num_layers_color - 1, 87 | }, 88 | ) 89 | 90 | def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType, TensorType]: 91 | pass 92 |
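Note that `mlp_base` above hashes a 4D input, so this field is meant to be queried with spatial positions concatenated with a time coordinate. The `get_density` body is still a stub; below is a minimal sketch of how it could look, following the instant-NGP pattern used elsewhere in nerfstudio. It assumes `ray_samples` carries a `times` tensor already normalized to [0, 1]; that is an assumption for illustration, not the repository's final implementation.

```python
# Hedged sketch only; the real implementation in this repo may differ.
def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType, TensorType]:
    positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
    times = ray_samples.times  # assumed: shape (..., 1), normalized to [0, 1]
    positions_flat = torch.cat([positions, times], dim=-1).view(-1, 4)
    h = self.mlp_base(positions_flat).view(*ray_samples.frustums.shape, -1)
    density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1)
    density = trunc_exp(density_before_activation.to(positions))  # keep density positive
    return density, base_mlp_out
```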
-------------------------------------------------------------------------------- /MSTH/SpaceTimeHashing/mle.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for maximum likelihood estimation of the distributions in ray sampling. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class MLELoss(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, ts, est_mean, est_std): 13 | # ts: B x Ns 14 | mle_mean = ts.mean(dim=-1).unsqueeze(dim=-1) 15 | mle_val = ((ts - mle_mean) ** 2).mean(dim=-1, keepdim=True) # maximum-likelihood variance estimate 16 | loss_mean = ((mle_mean - est_mean) ** 2).mean() 17 | loss_std = ((mle_val - est_std**2) ** 2).mean() 18 | return loss_mean + loss_std 19 | -------------------------------------------------------------------------------- /MSTH/SpaceTimeHashing/test_pose.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def normalize(v): 5 | """Normalize a vector.""" 6 | return v / np.linalg.norm(v) 7 | 8 | def average_poses(poses): 9 | """ 10 | Calculate the average pose, which is then used to center all poses 11 | using @center_poses. Its computation is as follows: 12 | 1. Compute the center: the average of pose centers. 13 | 2. Compute the z axis: the normalized average z axis. 14 | 3. Compute axis y': the average y axis. 15 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 16 | 5. Compute the y axis: z cross product x. 17 | 18 | Note that at step 3, we cannot directly use y' as y axis since it's 19 | not necessarily orthogonal to z axis. We need to pass from x to y. 20 | Inputs: 21 | poses: (N_images, 3, 4) 22 | Outputs: 23 | pose_avg: (3, 4) the average pose 24 | """ 25 | # 1. Compute the center 26 | center = poses[..., 3].mean(0) # (3) 27 | 28 | # 2. Compute the z axis 29 | z = normalize(poses[..., 2].mean(0)) # (3) 30 | 31 | # 3. Compute axis y' (no need to normalize as it's not the final output) 32 | y_ = poses[..., 1].mean(0) # (3) 33 | 34 | # 4. Compute the x axis 35 | x = normalize(np.cross(z, y_)) # (3) 36 | 37 | # 5.
Compute the y axis (as z and x are normalized, y is already of norm 1) 38 | y = np.cross(x, z) # (3) 39 | 40 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 41 | 42 | return pose_avg 43 | 44 | def viewmatrix(z, up, pos): 45 | vec2 = normalize(z) 46 | vec1_avg = up 47 | vec0 = normalize(np.cross(vec1_avg, vec2)) # camera x axis in world coord 48 | vec1 = normalize(np.cross(vec2, vec0)) # camera y axis in world coord 49 | m = np.eye(4) 50 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1) 51 | return m 52 | 53 | up = np.array([0., 0., 1.]) 54 | 55 | z = np.array([-3.0, 0., 0.]) 56 | print(viewmatrix(z, up, np.array([0, 0, 0]))) -------------------------------------------------------------------------------- /MSTH/SpaceTimeHashing/viz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from plotly.subplots import make_subplots 4 | import plotly.graph_objects as go 5 | import plotly.express as px 6 | 7 | df = px.data.tips() 8 | fig = px.histogram(df, x="total_bill") 9 | 10 | 11 | def gen_spatial_grid(reso=256): 12 | """3d uniform spatial grid in [0,1]^3""" 13 | x = torch.linspace(0, 1, reso) 14 | y = torch.linspace(0, 1, reso) 15 | z = torch.linspace(0, 1, reso) 16 | x, y, z = torch.meshgrid(x, y, z) 17 | return torch.stack([x, y, z], dim=-1).reshape(-1, 3) 18 | 19 | 20 | def viz_histograms(xs, names=None, show=True): 21 | _num = len(xs) 22 | xs = [x.flatten().cpu().numpy() for x in xs] 23 | fig = make_subplots(rows=_num, cols=1, shared_yaxes=True, subplot_titles=names) 24 | hists = [] 25 | for x in xs: 26 | counts, bins = np.histogram(x, bins=100) 27 | bins = 0.5 * (bins[:-1] + bins[1:]) 28 | hists.append(go.Bar(x=bins, y=counts)) 29 | fig.add_traces(hists, rows=list(range(1, _num + 1)), cols=[1] * _num) 30 | 31 | if show: 32 | fig.show() 33 | 34 | 35 | @torch.no_grad() 36 | def hist_from_mask(mask, reso=256, chunk_size=1 << 17): 37 | grid = gen_spatial_grid(reso) 38 | tot_size = grid.size(0) 39 | vals = torch.zeros([tot_size, 1]).to() 40 | 41 | for start in range(0, tot_size, chunk_size): 42 | end = min(start + chunk_size, tot_size) 43 | vals[start:end] = mask(grid[start:end])[..., 0:1].to() 44 | 45 | 46 | def viz_distribution(dists, show=True): 47 | dists = dists.cpu().numpy() 48 | _num = dists.shape[0] 49 | dists = [dist.flatten() for dist in dists] 50 | fig = make_subplots(rows=_num, cols=1, shared_yaxes=True) 51 | 52 | hists = [] 53 | for x in dists: 54 | counts, bins = np.histogram(x, bins=100) 55 | bins = 0.5 * (bins[:-1] + bins[1:]) 56 | hists.append(go.Bar(x=bins, y=counts)) 57 | fig.add_traces(hists, rows=list(range(1, _num + 1)), cols=[1] * _num) 58 | 59 | if show: 60 | fig.show() 61 | -------------------------------------------------------------------------------- /MSTH/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /MSTH/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /MSTH/configs/freeze.py: -------------------------------------------------------------------------------- 1 | from MSTH.configs.method_import import * 2 | 3 | method_configs: Dict[str, Union[TrainerConfig, VideoTrainerConfig]] = {} 4 | 5 | method_configs["freeze_mask"] = SpaceTimeHashingTrainerConfig( 6 | method_name="stmodel with mask freeze", 7 | steps_per_eval_batch=1000, 8 | steps_per_save=20000, 9 | max_num_iterations=30000, 10 | mixed_precision=True, 11 | log_gradients=True, 12 | pipeline=SpaceTimePipelineConfig( 13 | datamanager=SpaceTimeDataManagerConfig( 14 | dataparser=VideoDataParserConfig( 15 | # data=Path("/data/machine/data/flame_salmon_videos_2"), 16 | data=Path("/data/machine/data/flame_salmon_videos_2"), 17 | # data=Path("/data/machine/data/flame_salmon_videos_test"), 18 | downscale_factor=2, 19 | scale_factor=1 / 2.0, 20 | # scene_scale=8, 21 | ), 22 | train_num_rays_per_batch=4096, 23 | eval_num_rays_per_batch=4096, 24 | camera_optimizer=CameraOptimizerConfig(mode="off"), 25 | use_uint8=True, 26 | use_stratified_pixel_sampler=True, 27 | static_dynamic_sampling_ratio=50.0, 28 | static_dynamic_sampling_ratio_end=10.0, 29 | static_ratio_decay_total_steps=20000, 30 | ), 31 | model=DSpaceTimeHashingModelConfig( 32 | freeze_mask=True, 33 | freeze_mask_step=7000, 34 | max_res=(2048, 2048, 2048, 300), 35 | base_res=(16, 16, 16, 30), 36 | proposal_weights_anneal_max_num_iters=5000, 37 | # proposal_weights_anneal_slope = 10.0, 38 | log2_hashmap_size_spatial=19, 39 | log2_hashmap_size_temporal=21, 40 | proposal_net_args_list=[ 41 | { 42 | "hidden_dim": 16, 43 | "log2_hashmap_size_spatial": 17, 44 | "log2_hashmap_size_temporal": 17, 45 | "num_levels": 5, 46 | "max_res": (128, 128, 128, 150), 47 | "base_res": (16, 16, 16, 30), 48 | "use_linear": False, 49 | }, 50 | { 51 | "hidden_dim": 16, 52 | "log2_hashmap_size_spatial": 17, 53 | "log2_hashmap_size_temporal": 17, 54 | "num_levels": 5, 55 | "max_res": (256, 256, 256, 300), 56 | "base_res": (16, 16, 16, 30), 57 | "use_linear": False, 58 | }, 59 | ], 60 | # use_field_with_base=True, 61 | # use_sampler_with_base=True, 62 | ), 63 | ), 64 | optimizers={ 65 | "proposal_networks": { 66 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 67 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-3, max_steps=15000), 68 | }, 69 | "fields": { 70 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 71 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-3, max_steps=15000), 72 | }, 73 | }, 74 | ) 
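freeze.py registers a single trainer config under the key "freeze_mask". A sketch of how such an entry can be looked up and tweaked before training follows; the lookup and the printed values are grounded in this file, while the override line is a hypothetical illustration. Training itself is normally launched through `MSTH/scripts/train.py` via `MSTH/scripts/run.sh`, shown later in this dump.

```python
from MSTH.configs.freeze import method_configs

config = method_configs["freeze_mask"]
print(config.method_name)         # "stmodel with mask freeze"
print(config.max_num_iterations)  # 30000
# Hypothetical override before handing the config to the trainer:
config.pipeline.datamanager.train_num_rays_per_batch = 8192
```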
-------------------------------------------------------------------------------- /MSTH/configs/method_config_salmon.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/configs/method_config_salmon.py -------------------------------------------------------------------------------- /MSTH/configs/method_config_success.py: -------------------------------------------------------------------------------- 1 | # some success configs 2 | 3 | cook_spinach = "base_it40000_uniform3_cook_spinach_high_dynamic_base_new_50_30_mst" 4 | -------------------------------------------------------------------------------- /MSTH/configs/method_import.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from copy import deepcopy 3 | from typing import Dict, Optional 4 | from typing import * 5 | import copy 6 | import tyro 7 | from MSTH.directModel.deferred_nerfacto_model import DeferredNerfactoModel, DeferredNerfactoModelConfig 8 | from nerfacc import ContractionType 9 | 10 | from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig 11 | from nerfstudio.configs.base_config import ViewerConfig 12 | from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig 13 | from nerfstudio.data.datamanagers.depth_datamanager import DepthDataManagerConfig 14 | from nerfstudio.data.datamanagers.semantic_datamanager import SemanticDataManagerConfig 15 | from nerfstudio.data.datamanagers.variable_res_datamanager import ( 16 | VariableResDataManagerConfig, 17 | ) 18 | from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig 19 | from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig 20 | from nerfstudio.data.dataparsers.dycheck_dataparser import DycheckDataParserConfig 21 | from nerfstudio.data.dataparsers.instant_ngp_dataparser import ( 22 | InstantNGPDataParserConfig, 23 | ) 24 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig 25 | from nerfstudio.data.dataparsers.phototourism_dataparser import ( 26 | PhototourismDataParserConfig, 27 | ) 28 | 29 | from dataclasses import dataclass, field 30 | 31 | # from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig 32 | from nerfstudio.engine.optimizers import AdamOptimizerConfig, RAdamOptimizerConfig 33 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig 34 | from MSTH.cos_scheduler import CosineDecaySchedulerConfig 35 | from nerfstudio.engine.trainer import TrainerConfig 36 | from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind 37 | from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig 38 | from nerfstudio.models.instant_ngp import InstantNGPModelConfig 39 | from nerfstudio.models.mipnerf import MipNerfModel 40 | from nerfstudio.models.nerfacto import NerfactoModelConfig 41 | from nerfstudio.models.nerfplayer_nerfacto import NerfplayerNerfactoModelConfig 42 | from nerfstudio.models.nerfplayer_ngp import NerfplayerNGPModelConfig 43 | from nerfstudio.models.semantic_nerfw import SemanticNerfWModelConfig 44 | from nerfstudio.models.tensorf import TensoRFModelConfig 45 | from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig 46 | from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig 47 | from 
nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig 48 | from nerfstudio.plugins.registry import discover_methods 49 | from MSTH.streamable_model import StreamableNerfactoModelConfig 50 | from MSTH.directModel.trainer import VideoTrainerConfig 51 | from MSTH.datamanager import VideoDataManagerConfig 52 | from MSTH.dataparser import VideoDataParserConfig 53 | from MSTH.video_pipeline import VideoPipelineConfig 54 | 55 | from MSTH.datamanager import VideoFeatureDataManagerConfig 56 | from MSTH.SpaceTimeHashing.trainer import SpaceTimeHashingTrainerConfig 57 | from MSTH.datamanager import SpaceTimeDataManagerConfig 58 | from MSTH.video_pipeline import SpaceTimePipelineConfig 59 | from MSTH.SpaceTimeHashing.model import SpaceTimeHashingModelConfig 60 | from MSTH.SpaceTimeHashing.stmodel import DSpaceTimeHashingModelConfig 61 | from pathlib import Path 62 | import numpy as np 63 | -------------------------------------------------------------------------------- /MSTH/cos_scheduler.py: -------------------------------------------------------------------------------- 1 | from nerfstudio.engine.schedulers import Scheduler, SchedulerConfig 2 | from dataclasses import dataclass, field 3 | from typing import Optional, Type 4 | 5 | import numpy as np 6 | from torch.optim import Optimizer, lr_scheduler 7 | from typing_extensions import Literal 8 | 9 | from nerfstudio.configs.base_config import InstantiateConfig 10 | 11 | 12 | @dataclass 13 | class CosineDecaySchedulerConfig(SchedulerConfig): 14 | """Config for cosine decay scheduler with warmup""" 15 | 16 | _target: Type = field(default_factory=lambda: CosineDecayScheduler) 17 | """target class to instantiate""" 18 | lr_pre_warmup: float = 1e-8 19 | """Learning rate before warmup.""" 20 | lr_final: Optional[float] = None 21 | """Final learning rate. If not provided, it will be set to the optimizers learning rate.""" 22 | warmup_steps: int = 0 23 | """Number of warmup steps.""" 24 | max_steps: int = 100000 25 | """The maximum number of steps.""" 26 | ramp: Literal["linear", "cosine"] = "cosine" 27 | """The ramp function to use during the warmup.""" 28 | 29 | 30 | class CosineDecayScheduler(Scheduler): 31 | """Cosine decay scheduler with warmup. The scheduler first ramps up to `lr_init` in `warmup_steps` 32 | steps, then follows a cosine decay to `lr_final` by `max_steps` steps.
33 | """ 34 | 35 | config: CosineDecaySchedulerConfig 36 | 37 | def get_scheduler(self, optimizer: Optimizer, lr_init: float) -> lr_scheduler._LRScheduler: 38 | if self.config.lr_final is None: 39 | lr_final = lr_init 40 | else: 41 | lr_final = self.config.lr_final 42 | 43 | def func(step): 44 | if step < self.config.warmup_steps: 45 | if self.config.ramp == "cosine": 46 | lr = self.config.lr_pre_warmup + (lr_init - self.config.lr_pre_warmup) * np.sin( 47 | 0.5 * np.pi * np.clip(step / self.config.warmup_steps, 0, 1) 48 | ) 49 | else: 50 | lr = ( 51 | self.config.lr_pre_warmup 52 | + (lr_init - self.config.lr_pre_warmup) * step / self.config.warmup_steps 53 | ) 54 | else: 55 | t = np.clip( 56 | (step - self.config.warmup_steps) / (self.config.max_steps - self.config.warmup_steps), 0, 1 57 | ) 58 | # lr = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 59 | lr = lr_final + 0.5 * (lr_init - lr_final) * (1 + np.cos(np.pi * t)) 60 | return lr / lr_init # divided by lr_init because the multiplier is with the initial learning rate 61 | 62 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=func) 63 | return scheduler 64 |
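Putting the pieces together: the LambdaLR multiplier traced by `func` ramps from `lr_pre_warmup` up to `lr_init` during warmup, then follows half a cosine from `lr_init` down to `lr_final`. Below is a standalone sketch of that curve in pure NumPy with hypothetical hyperparameter values; the warmup branch is simplified (it ignores the tiny `lr_pre_warmup` offset) and is not part of the repository.

```python
import numpy as np

lr_init, lr_final, warmup_steps, max_steps = 1e-2, 1e-3, 500, 15000

def cosine_decay_lr(step):
    """Same curve as CosineDecayScheduler.func, written without the config object."""
    if step < warmup_steps:  # cosine ramp from ~0 up to lr_init
        return lr_init * np.sin(0.5 * np.pi * step / warmup_steps)
    t = np.clip((step - warmup_steps) / (max_steps - warmup_steps), 0, 1)
    return lr_final + 0.5 * (lr_init - lr_final) * (1 + np.cos(np.pi * t))

for step in (0, 500, 7750, 15000):
    print(step, f"{cosine_decay_lr(step):.2e}")  # 0.00e+00, 1.00e-02, 5.50e-03, 1.00e-03
```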
-------------------------------------------------------------------------------- /MSTH/docs/immersive.md: -------------------------------------------------------------------------------- 1 | ## Immersive Dataset 2 | 3 | [Paper](https://storage.googleapis.com/immersive-lf-video-siggraph2020/ImmersiveLightFieldVideoWithALayeredMeshRepresentation.pdf) [Dataset](https://github.com/augmentedperception/deepview_video_dataset) -------------------------------------------------------------------------------- /MSTH/gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder 2 | from .stgrid import SpatialTemporalGridEncoder 3 | -------------------------------------------------------------------------------- /MSTH/gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /MSTH/gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-w', 9 | '-O3', '-std=c++14', 10 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 11 | ] 12 | 13 | if os.name == "posix": 14 | c_flags = ['-O3', '-std=c++14'] 15 | elif os.name == "nt": 16 | c_flags = ['/O2', '/std:c++17'] 17 | 18 | # find cl.exe 19 | def find_cl_path(): 20 | import glob 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | setup( 34 | name='gridencoder', # package name, import this to use python API 35 | ext_modules=[ 36 | CUDAExtension( 37 | name='_gridencoder', # extension name, import this to use CUDA API 38 | sources=[os.path.join(_src_path, 'src', f) for f in [ 39 | 'gridencoder.cu', 40 | 'bindings.cpp', 41 | ]], 42 | extra_compile_args={ 43 | 'cxx': c_flags, 44 | 'nvcc': nvcc_flags, 45 | } 46 | ), 47 | ], 48 | cmdclass={ 49 | 'build_ext': BuildExtension, 50 | } 51 | ) 52 | -------------------------------------------------------------------------------- /MSTH/gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | m.def("grid_encode_hash_reinitialize", &grid_encode_hash_reinitialize, "grid_encode_hash_reinitialize (CUDA)"); 10 | m.def("grid_encode_set_static", &grid_encode_set_static, "grid_encode_set_static (CUDA)"); 11 | m.def("rect_grid_encode_forward", &rect_grid_encode_forward, "rect_grid_encode_forward (CUDA)"); 12 | m.def("rect_grid_encode_backward", &rect_grid_encode_backward, "rect_grid_encode_backward (CUDA)"); 13 | m.def("rect_grad_total_variation", &rect_grad_total_variation, "rect_grad_total_variation (CUDA)"); 14 | m.def("stgrid_encode_forward", &stgrid_encode_forward, "stgrid_encode_forward (CUDA)"); 15 | m.def("stgrid_encode_backward", &stgrid_encode_backward, "stgrid_encode_backward (CUDA)"); 16 | }
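The bindings above only surface raw CUDA kernels; on the Python side, forward/backward pairs like `grid_encode_forward` / `grid_encode_backward` are typically wrapped in a `torch.autograd.Function` so they participate in autograd. The actual wrappers live in `grid.py` / `stgrid.py`, which are not shown in this dump; below is a hedged, self-contained sketch of the usual torch-ngp-style pattern, with stand-in PyTorch ops in place of the extension calls so it runs anywhere.

```python
import torch

class _GridEncode(torch.autograd.Function):
    """Pattern used to expose a CUDA forward/backward kernel pair to autograd.
    The stand-in ops here are plain matmuls; the real wrapper would pass the
    tensors straight to the compiled _backend module instead."""

    @staticmethod
    def forward(ctx, inputs, embeddings):
        ctx.save_for_backward(inputs, embeddings)
        # real code would call something like _backend.grid_encode_forward(...)
        return inputs @ embeddings

    @staticmethod
    def backward(ctx, grad_out):
        inputs, embeddings = ctx.saved_tensors
        # real code would call something like _backend.grid_encode_backward(...)
        return grad_out @ embeddings.t(), inputs.t() @ grad_out

x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(4, 16, requires_grad=True)
_GridEncode.apply(x, w).sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([8, 4]) torch.Size([4, 16])
```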
-------------------------------------------------------------------------------- /MSTH/gridencoder/src/test_1.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/gridencoder/src/test_1.cu -------------------------------------------------------------------------------- /MSTH/gridencoder/src/test_2.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/gridencoder/src/test_2.cu -------------------------------------------------------------------------------- /MSTH/gridencoder/test_stgrid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from stgrid import SpatialTemporalGridEncoder as stencoder 4 | from grid import GridEncoder as encoder 5 | 6 | pls = (1.2, 1.2, 1.2, 1.2) 7 | bs = (16, 16, 16, 16) 8 | stm = stencoder(input_dim=4, per_level_scale=pls, base_resolution=bs, std=1).cuda() 9 | s = encoder(per_level_scale=pls[:3], base_resolution=bs[:3]).cuda() 10 | t = encoder( 11 | input_dim=4, 12 | per_level_scale=pls, 13 | base_resolution=bs, 14 | ).cuda() 15 | m = encoder( 16 | input_dim=3, 17 | num_levels=1, 18 | level_dim=1, 19 | per_level_scale=(1, 1, 1), 20 | base_resolution=(128, 128, 128), 21 | gridtype="tiled", 22 | log2_hashmap_size=21, 23 | interpolation="all_nearest", 24 | ).cuda() 25 | 26 | print(m.embeddings.shape) 27 | print(stm.membeddings.shape) 28 | 29 | s.embeddings.data.copy_(stm.sembeddings.data) 30 | t.embeddings.data.copy_(stm.tembeddings.data) 31 | m.embeddings.data[..., 0].copy_(stm.membeddings.data) 32 | 33 | 34 | class STMTorch(nn.Module): 35 | def __init__(self, sencoder, tencoder, mencoder): 36 | super().__init__() 37 | self.s = sencoder 38 | self.t = tencoder 39 | self.m = mencoder 40 | 41 | def forward(self, x): 42 | s = self.s(x[:, :3]) 43 | t = self.t(x) 44 | m = self.m(x[:, :3]).sigmoid() 45 | print(t) 46 | print(m) 47 | return s + t * (1 - m) 48 | 49 | 50 | stmtorch = STMTorch(s, t, m).cuda() 51 | 52 | x = torch.rand(5, 4).cuda() 53 | print(x) 54 | # x = x.abs() 55 | # x = x / x.max()[0] 56 | 57 | y1 = stmtorch(x) 58 | y2 = stm(x) 59 | print(stm.toffsets) 60 | print(stmtorch.t.offsets) 61 | print(stm.soffsets) 62 | print(stmtorch.s.offsets) 63 | # print(y1) 64 | # print(y2) 65 | # print((y1).abs().mean()) 66 | # print((y1 - y2).abs().mean()) 67 | -------------------------------------------------------------------------------- /MSTH/ibrnet/colorizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from MSTH.projector import Projector 5 | from MSTH.dataset import VideoDatasetWithFeature 6 | from nerfstudio.cameras.cameras import Cameras 7 | from nerfstudio.cameras.rays import RaySamples 8 | 9 | class Colorizer(nn.Module): 10 | def __init__(self, cameras: Cameras, dataset: VideoDatasetWithFeature, device) -> None: 11 | super().__init__() 12 | self.cameras = cameras 13 | self.dataset = dataset 14 | self.projector = Projector(cameras) 15 | 16 | self.n_feats = dataset.cur_frame_feats_buffer.shape[-1] 17 | 18 | self.mlp_base = nn.Sequential(nn.Linear(3 + self.n_feats, 32)) 19 | 20 | self.mlp_head = nn.Linear(36, 1) 21 | 22 | self.device = device 23 | 24 | def get_images(self): 25 | return 
torch.from_numpy(self.dataset.cur_frame_buffer) 26 | 27 | def get_features(self): 28 | # return torch.from_numpy(self.dataset.cur_frame_feats_buffer).to(self.device) 29 | return self.dataset.cur_frame_feats_buffer 30 | 31 | def colorize(self, ray_samples: RaySamples, c2ws): 32 | positions = ray_samples.frustums.get_positions() 33 | ray_shape = positions.size() 34 | positions = positions.reshape(-1, 3) 35 | pixels, feats, masks = self.projector(positions, self.get_images(), self.get_features()) 36 | 37 | # pixels: [b, ncams, 3] 38 | 39 | # print(feats.shape) 40 | # feats_mean = torch.mean(feats, dim=1, keepdim=True).repeat(1, feats.size(1), 1) 41 | # feats_var = torch.var(feats, dim=1, keepdim=True).repeat(1, feats.size(1), 1) 42 | # feats = torch.cat([pixels, feats, feats_mean, feats_var], dim=-1) 43 | pixels = pixels.to(self.device) 44 | feats = feats.to(self.device) 45 | masks = masks.to(self.device) 46 | feats = torch.cat([pixels, feats], dim=-1) 47 | 48 | cam_shap = feats.size() 49 | # feats = feats.reshape(-1, feats.size(-1)) 50 | 51 | camera_indices = ray_samples.camera_indices.squeeze() 52 | 53 | dir_diff = self.projector.get_relative_directions_with_positions(positions, camera_indices, c2ws) 54 | 55 | latent = self.mlp_base(feats) 56 | 57 | latent = torch.cat([latent, dir_diff], dim=-1) 58 | 59 | weights = self.mlp_head(latent).squeeze() 60 | # [b, ncams] 61 | weights[~masks] = -100. 62 | 63 | # weights = weights.resize(*cam_shap[:-1], -1) 64 | # [b, 18] 65 | 66 | weights = F.softmax(weights, dim=-1).unsqueeze(-1) 67 | 68 | colors = torch.sum(weights * pixels, dim=1) 69 | 70 | return colors.reshape(ray_shape) 71 | 72 | def forward(self, ray_samples, c2ws): 73 | return self.colorize(ray_samples, c2ws) -------------------------------------------------------------------------------- /MSTH/ibrnet/field.py: -------------------------------------------------------------------------------- 1 | from .colorizer import Colorizer 2 | -------------------------------------------------------------------------------- /MSTH/ibrnet/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from feature_extractor import ResUNet 4 | 5 | def test_load_from_ckpt(): 6 | u = ResUNet.load_from_pretrained("model_255000.pth") 7 | print(u) 8 | 9 | def test_shape(): 10 | # torch.set_default_tensor_type(torch.cuda.FloatTensor) 11 | u = ResUNet.load_from_pretrained("model_255000.pth") 12 | inputs = torch.randn(1, 2704, 2028, 3) 13 | outputs = u(inputs) 14 | print(outputs[0].shape) 15 | print(outputs[1].shape) 16 | 17 | if __name__ == "__main__": 18 | test_shape() -------------------------------------------------------------------------------- /MSTH/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/scripts/__init__.py -------------------------------------------------------------------------------- /MSTH/scripts/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | eval.py 4 | """ 5 | from __future__ import annotations 6 | 7 | import json 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | 11 | import tyro 12 | from rich.console import Console 13 | 14 | from nerfstudio.utils.eval_utils import eval_setup 15 | 16 | CONSOLE = Console(width=120) 17 | 18 | 19 | @dataclass 20 | class 
ComputePSNR: 21 | """Load a checkpoint, compute some PSNR metrics, and save it to a JSON file.""" 22 | 23 | # Path to config YAML file. 24 | load_config: Path 25 | # Name of the output file. 26 | output_path: Path = Path("output.json") 27 | 28 | def main(self) -> None: 29 | """Main function.""" 30 | config, pipeline, checkpoint_path = eval_setup(self.load_config) 31 | assert self.output_path.suffix == ".json" 32 | # metrics_dict = pipeline.get_average_eval_image_metrics() 33 | metrics_dict, images_dict = pipeline.get_eval_image_metrics_and_images(step=None, interval=10, thresh=0.99) 34 | self.output_path.parent.mkdir(parents=True, exist_ok=True) 35 | # Get the output and define the names to save to 36 | benchmark_info = { 37 | "experiment_name": config.experiment_name, 38 | "method_name": config.method_name, 39 | "checkpoint": str(checkpoint_path), 40 | "results": metrics_dict, 41 | } 42 | # Save output to output file 43 | self.output_path.write_text(json.dumps(benchmark_info, indent=2), "utf8") 44 | CONSOLE.print(f"Saved results to: {self.output_path}") 45 | 46 | 47 | def entrypoint(): 48 | """Entrypoint for use with pyproject scripts.""" 49 | tyro.extras.set_accent_color("bright_yellow") 50 | tyro.cli(ComputePSNR).main() 51 | 52 | 53 | if __name__ == "__main__": 54 | entrypoint() 55 | 56 | # For sphinx docs 57 | get_parser_fn = lambda: tyro.extras.get_parser(ComputePSNR) # noqa 58 | -------------------------------------------------------------------------------- /MSTH/scripts/frames2video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import sys 3 | import numpy as np 4 | import torch 5 | 6 | if len(sys.argv) < 3: 7 | print("[Usage]: python frames2video.py [input .pt filename] [output filename] ([fps])") 8 | sys.exit(1) 9 | 10 | video_data = torch.load(f"{sys.argv[1]}", map_location="cpu") 11 | 12 | assert video_data.dtype is torch.uint8 13 | 14 | video_data = video_data.numpy().astype(np.uint8)[..., [2, 1, 0]] 15 | 16 | print(video_data.shape) 17 | 18 | assert video_data.dtype == np.uint8 19 | 20 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 21 | fps = 30.0 if len(sys.argv) < 4 else float(sys.argv[3]) 22 | 23 | out = cv2.VideoWriter(f"{sys.argv[2]}.mp4", fourcc, fps, (video_data.shape[2], video_data.shape[1])) 24 | for frame in video_data: 25 | out.write(frame) 26 | 27 | out.release()
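frames2video.py expects a `(num_frames, height, width, 3)` uint8 RGB tensor saved with `torch.save`; the channel flip `[..., [2, 1, 0]]` converts RGB to the BGR order that OpenCV writes. A hypothetical round-trip (the file names are placeholders, not files from this repository):

```python
import torch

frames = (torch.rand(30, 720, 960, 3) * 255).to(torch.uint8)  # 30 random RGB frames
torch.save(frames, "demo_frames.pt")
# then: python MSTH/scripts/frames2video.py demo_frames.pt demo_video 30
```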
-------------------------------------------------------------------------------- /MSTH/scripts/imm2nerfstudio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from copy import deepcopy 4 | import cv2 5 | import numpy as np 6 | from scipy.spatial.transform import Rotation 7 | 8 | 9 | def imm2nerfstudio(filename="models.json", video_path="", video_suffix=".mp4", num_frames=-1): 10 | imm = json.load(open(filename)) 11 | nfs = {} 12 | nfs_val = {} 13 | initial_cam = imm[0] 14 | nfs["w"] = initial_cam["width"] 15 | nfs["h"] = initial_cam["height"] 16 | if initial_cam["projection_type"] == "fisheye": 17 | nfs["camera_model"] = "OPENCV_FISHEYE" 18 | nfs_val["camera_model"] = "OPENCV_FISHEYE" 19 | else: 20 | raise NotImplementedError 21 | nfs["frames"] = [] 22 | nfs_val["w"] = nfs["w"] 23 | nfs_val["h"] = nfs["h"] 24 | nfs_val["frames"] = [] 25 | for camera_setting in imm: 26 | new_cam = {} 27 | new_cam["file_path"] = camera_setting["name"] + video_suffix 28 | if len(video_path) > 1: 29 | new_cam["file_path"] = video_path + "/" + new_cam["file_path"] 30 | R = Rotation.from_rotvec(camera_setting["orientation"]).as_matrix() 31 | T = np.array(camera_setting["position"]) 32 | pose = np.eye(4) 33 | pose[:3, :3] = R.T 34 | pose[:3, -1] = T 35 | pose_pre = np.eye(4) 36 | pose_pre[1, 1] *= -1 37 | pose_pre[2, 2] *= -1 38 | pose = pose_pre @ pose @ pose_pre 39 | k1 = camera_setting["radial_distortion"][0] 40 | k2 = camera_setting["radial_distortion"][1] 41 | k3 = 0 42 | k4 = 0 43 | fl_x = camera_setting["focal_length"] 44 | fl_y = fl_x 45 | cx = camera_setting["principal_point"][0] 46 | cy = camera_setting["principal_point"][1] 47 | new_cam["transform_matrix"] = pose.tolist() 48 | new_cam["k1"] = k1 49 | new_cam["k2"] = k2 50 | new_cam["k3"] = k3 51 | new_cam["k4"] = k4 52 | new_cam["fl_x"] = fl_x 53 | new_cam["fl_y"] = fl_y 54 | new_cam["cx"] = cx 55 | new_cam["cy"] = cy 56 | if camera_setting["name"] != "camera_0001": 57 | nfs["frames"].append(new_cam) 58 | else: 59 | nfs_val["frames"].append(new_cam) 60 | 61 | example_video = nfs["frames"][0]["file_path"] 62 | vc = cv2.VideoCapture(example_video) 63 | num_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT)) if num_frames < 0 else num_frames 64 | print(f"num frames: {num_frames}") 65 | nfs["num_frames"] = num_frames 66 | assert len(nfs_val["frames"]) == 1 67 | 68 | with open("transforms_train.json", "w") as f: 69 | json.dump(nfs, f, indent=4) 70 | with open("transforms_val.json", "w") as f: 71 | json.dump(nfs_val, f, indent=4) 72 | with open("transforms_test.json", "w") as f: 73 | json.dump(nfs_val, f, indent=4) 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument("filename", type=str, default=".") 79 | parser.add_argument("--suffix", type=str, default=".mp4") 80 | parser.add_argument("--video_path", type=str, default="") 81 | parser.add_argument("--num_frames", type=int, default=-1) 82 | 83 | opt = parser.parse_args() 84 | 85 | imm2nerfstudio(opt.filename, opt.video_path, opt.suffix, opt.num_frames) 86 |
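The `pose = pose_pre @ pose @ pose_pre` line conjugates the camera-to-world matrix by `diag(1, -1, -1, 1)`, which negates the y/z rows and columns together, the usual flip between OpenCV-style and OpenGL-style camera conventions (that reading is an inference; the script itself does not document it). A small numeric check of what the conjugation does:

```python
import numpy as np

pose_pre = np.diag([1.0, -1.0, -1.0, 1.0])  # same matrix the script builds
pose = np.arange(1, 17, dtype=float).reshape(4, 4)

conv = pose_pre @ pose @ pose_pre
signs = np.outer([1, -1, -1, 1], [1, -1, -1, 1])
assert np.allclose(conv, pose * signs)  # entry (i, j) is scaled by s_i * s_j
print(signs)
```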
-------------------------------------------------------------------------------- /MSTH/scripts/nerfacto.sh: -------------------------------------------------------------------------------- 1 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 2 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH" 3 | CUDA_VISIBLE_DEVICES=$1 \ 4 | python -m scripts.train nerfacto_profiling \ 5 | --experiment-name nerfacto \ 6 | --vis wandb \ 7 | --output-dir tmp \ 8 | --pipeline.model.predict_normals False 9 | -------------------------------------------------------------------------------- /MSTH/scripts/nerfacto_profiling.sh: -------------------------------------------------------------------------------- 1 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 2 | export PROFILING=1 3 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH" 4 | CUDA_VISIBLE_DEVICES=$1 \ 5 | python -m scripts.train nerfacto_profiling \ 6 | --experiment-name nerfacto_profiling \ 7 | --vis tensorboard \ 8 | --output-dir tmp 9 | -------------------------------------------------------------------------------- /MSTH/scripts/prepdata/downsample.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | import os 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | 7 | input_dir = "flame_salmon_videos" 8 | output_dir = "flame_salmon_videos_2" 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--height", type=int, default=720) 13 | parser.add_argument("-w", type=int, default=960) 14 | parser.add_argument("-i", type=str, default="flame_salmon_videos") 15 | parser.add_argument("-f", type=int, default=30) 16 | parser.add_argument("-o", type=str, default="flame_salmon_videos_2") 17 | parser.add_argument("-t", type=str, default="area") 18 | 19 | args = parser.parse_args() 20 | 21 | interps = { 22 | "area": cv2.INTER_AREA, 23 | "cubic": cv2.INTER_CUBIC, 24 | "linear": cv2.INTER_LINEAR, 25 | } 26 | 27 | opth = Path(args.o) 28 | if not opth.exists(): 29 | opth.mkdir(parents=True) 30 | input_dir = args.i 31 | output_dir = args.o 32 | 33 | # Iterate over all video files in the input directory 34 | for filename in tqdm(os.listdir(input_dir)): 35 | if filename.endswith(".mp4") or filename.endswith(".avi") or filename.endswith(".MP4"): 36 | # Build the input and output file paths 37 | input_path = os.path.join(input_dir, filename) 38 | output_path = os.path.join(output_dir, filename) 39 | # Open the input video file 40 | cap = cv2.VideoCapture(input_path) 41 | # Read the video frame rate, width, and height 42 | fps = cap.get(cv2.CAP_PROP_FPS) 43 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 44 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 45 | # Build the output video writer 46 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") # the codec can be changed if needed 47 | out = cv2.VideoWriter( 48 | output_path, 49 | fourcc, 50 | fps, 51 | (int(args.w), int(args.height)), 52 | ) 53 | # Read frame by frame, downsample, and write to the output video 54 | while cap.isOpened(): 55 | ret, frame = cap.read() 56 | if ret: 57 | # Downsample 58 | frame = cv2.resize( 59 | frame, 60 | dsize=(args.w, args.height), 61 | interpolation=interps[args.t], 62 | ) 63 | # Write to the output video 64 | out.write(frame) 65 | else: 66 | break 67 | # Release resources 68 | cap.release() 69 | out.release() 70 | -------------------------------------------------------------------------------- /MSTH/scripts/prepdata/downsample_fps.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--target_fps", type=int, default=30) 9 | parser.add_argument("--input_path", type=str, default="") 10 | parser.add_argument("--output_path", type=str, default="") 11 | args = parser.parse_args() 12 | 13 | if not os.path.exists(args.output_path): 14 | os.makedirs(args.output_path) 15 | 16 | for filename in tqdm(os.listdir(args.input_path)): 17 | if filename.endswith("mp4") or filename.endswith("MP4"): 18 | fp = os.path.join(args.input_path, filename) 19 | fpo = os.path.join(args.output_path, filename) 20 | cmd = f"ffmpeg -i {fp} -r {args.target_fps} {fpo}" 21 | os.system(cmd) 22 | -------------------------------------------------------------------------------- /MSTH/scripts/prepdata/neural3d.sh: -------------------------------------------------------------------------------- 1 | SCENE=$1 2 | echo "downsample" 3 | python MSTH/scripts/prepdata/downsample.py --height 1014 -w 1352 -i /data/machine/data/$SCENE/videos -o /data/machine/data/$SCENE/videos_2/ -t area 4 | 5 | echo "running json from pose_bound.npy" 6 | python MSTH/scripts/prepdata/llff2nerf.py /data/machine/data/$SCENE/ --videos videos --downscale 1 --hold_list 0 --num_frames 300 7 | python MSTH/scripts/prepdata/llff2nerf.py /data/machine/data/$SCENE/ --videos videos_2 --downscale 1 --hold_list 0 --num_frames 300 8 | -------------------------------------------------------------------------------- /MSTH/scripts/profiling/cProfile.sh: -------------------------------------------------------------------------------- 1 | cd /data/czl/nerf/MSTH_new/MSTH 2 | export
LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 3 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH_new" 4 | CUDA_VISIBLE_DEVICES=$1 \ 5 | python -m cProfile -o cprofile_ret.prof scripts/train.py dsth_with_base \ 6 | --experiment-name dsth_with_base \ 7 | --vis tensorboard \ 8 | --max_num_iterations 500 \ 9 | --output-dir tmp \ 10 | --pipeline.datamanager.dataparser.scale_factor 0.5 \ 11 | --pipeline.datamanager.use_stratified_pixel_sampler True \ 12 | --pipeline.datamanager.static_dynamic_sampling_ratio 50.0 \ 13 | --pipeline.datamanager.static_dynamic_sampling_ratio_end 10.0 \ 14 | --pipeline.datamanager.static_ratio_decay_total_steps 20000 \ 15 | --save_eval_video False \ 16 | --steps_full_video 1000000000000000 17 | -------------------------------------------------------------------------------- /MSTH/scripts/profiling/fast_rendering.sh: -------------------------------------------------------------------------------- 1 | cd /data/czl/nerf/MSTH_new/MSTH 2 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 3 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH_new" 4 | CUDA_VISIBLE_DEVICES=$1 \ 5 | kernprof -lv -o line_profiler.prof scripts/train.py tp29_hierarchicay_16384_nostatic \ 6 | --experiment-name profiling \ 7 | --vis tensorboard \ 8 | --max_num_iterations 500 \ 9 | --steps_per_eval_batch 100 \ 10 | --steps_per_eval_image 100 \ 11 | --output-dir tmp 12 | -------------------------------------------------------------------------------- /MSTH/scripts/profiling/instant_ngp.sh: -------------------------------------------------------------------------------- 1 | cd /data/czl/nerf/MSTH_new 2 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 3 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH_new" 4 | export PROFILING=1 5 | export CSV_PATH=instant-ngp.csv 6 | CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=$1 \ 7 | python -m MSTH.scripts.train instant-ngp \ 8 | --experiment-name instant-ngp \ 9 | --vis tensorboard \ 10 | --output-dir tmp \ 11 | --max_num_iterations 1000 \ 12 | --steps_per_eval_image 100 \ 13 | --data /data/machine/data/flame_salmon_image 14 | -------------------------------------------------------------------------------- /MSTH/scripts/profiling/line_prof.sh: -------------------------------------------------------------------------------- 1 | cd /data/czl/nerf/MSTH_new/MSTH 2 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 3 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH_new" 4 | CUDA_VISIBLE_DEVICES=$1 \ 5 | kernprof -lv -o line_profiler.prof scripts/train.py dsth_with_base \ 6 | --experiment-name dsth_with_base \ 7 | --vis tensorboard \ 8 | --max_num_iterations 1000 \ 9 | --output-dir tmp \ 10 | --pipeline.datamanager.dataparser.scale_factor 0.5 \ 11 | --pipeline.datamanager.use_stratified_pixel_sampler True \ 12 | --pipeline.datamanager.static_dynamic_sampling_ratio 50.0 \ 13 | --pipeline.datamanager.static_dynamic_sampling_ratio_end 10.0 \ 14 | --pipeline.datamanager.static_ratio_decay_total_steps 20000 \ 15 | --save_eval_video False \ 16 | --steps_full_video 1000000000000000 17 | -------------------------------------------------------------------------------- /MSTH/scripts/run.sh: -------------------------------------------------------------------------------- 1 | viewer=${3:-'wandb'} 2 | port=${4:-'7007'} 3 | cd /opt/czl/nerf/exp/MSTH 4 | # export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 5 | # export 
PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH_new" 6 | CUDA_VISIBLE_DEVICES=$1 \ 7 | python -m MSTH.scripts.train ${2} \ 8 | --experiment-name ${2} \ 9 | --vis $viewer \ 10 | --output-dir tmp \ 11 | --viewer.websocket-port $port 12 | -------------------------------------------------------------------------------- /MSTH/scripts/run_imm.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import time 4 | 5 | 6 | def get_gpu_memory_usage(): 7 | result = subprocess.run( 8 | ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,nounits,noheader"], stdout=subprocess.PIPE 9 | ) 10 | output = result.stdout.decode("utf-8").strip() 11 | output = output.split("\n") 12 | infos = [] 13 | for gpu_id, content in enumerate(output): 14 | gpu_memory = [float(x.strip()) for x in content.split(",")] 15 | infos.append({"used": gpu_memory[0], "total": gpu_memory[1]}) 16 | return infos 17 | 18 | 19 | print(get_gpu_memory_usage()) 20 | with open("/data/czl/nerf/MSTH_new/MSTH/scripts/task_4_19.txt", "r") as f: 21 | tasks = f.readlines() 22 | 23 | using_gpu_ids = [2, 3] 24 | 25 | tasks = [task.strip() for task in tasks] 26 | while len(tasks) > 0: 27 | infos = get_gpu_memory_usage() 28 | for gpu_id in using_gpu_ids: 29 | if infos[gpu_id]["used"] < 100: 30 | cmd = f"bash /data/czl/nerf/MSTH_new/MSTH/scripts/run.sh {gpu_id} {tasks.pop(0)} wandb &" 31 | print(f"GPU {gpu_id} is available, Running: {cmd}") 32 | os.system(cmd) 33 | time.sleep(50) 34 | break 35 | 36 | # os.system(cmd) -------------------------------------------------------------------------------- /MSTH/scripts/run_seq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in {1..20}; do 3 | bash MSTH/scripts/run.sh 3 anoynmous_method_${i} wandb 4 | rm -r tmp 5 | done 6 | -------------------------------------------------------------------------------- /MSTH/scripts/run_taks.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import time 4 | 5 | 6 | def get_gpu_memory_usage(): 7 | result = subprocess.run( 8 | ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,nounits,noheader"], stdout=subprocess.PIPE 9 | ) 10 | output = result.stdout.decode("utf-8").strip() 11 | output = output.split("\n") 12 | infos = [] 13 | for gpu_id, content in enumerate(output): 14 | gpu_memory = [float(x.strip()) for x in content.split(",")] 15 | infos.append({"used": gpu_memory[0], "total": gpu_memory[1]}) 16 | return infos 17 | 18 | 19 | print(get_gpu_memory_usage()) 20 | with open("/data/czl/nerf/MSTH_new/MSTH/scripts/imm_task.txt", "r") as f: 21 | tasks = f.readlines() 22 | 23 | using_gpu_ids = [3] 24 | 25 | tasks = [task.strip() for task in tasks] 26 | while len(tasks) > 0: 27 | infos = get_gpu_memory_usage() 28 | for gpu_id in using_gpu_ids: 29 | if infos[gpu_id]["used"] < 100: 30 | cmd = f"bash /data/czl/nerf/MSTH_new/MSTH/scripts/run.sh {gpu_id} {tasks.pop(0)} wandb &" 31 | print(f"GPU {gpu_id} is available, Running: {cmd}") 32 | os.system(cmd) 33 | time.sleep(50) 34 | break 35 | 36 | # os.system(cmd) 37 | -------------------------------------------------------------------------------- /MSTH/scripts/sth.sh: -------------------------------------------------------------------------------- 1 | viewer=${2:-'tensorboard'} 2 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 3 | export 
PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH" 4 | CUDA_VISIBLE_DEVICES=$1 \ 5 | python -m scripts.train sth \ 6 | --experiment-name sth \ 7 | --vis $viewer \ 8 | --output-dir tmp 9 | -------------------------------------------------------------------------------- /MSTH/scripts/sth_profiling.sh: -------------------------------------------------------------------------------- 1 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 2 | export PROFILING=1 3 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH" 4 | CUDA_VISIBLE_DEVICES=$1 \ 5 | python -m scripts.train sth_profiling \ 6 | --experiment-name sth_profiling \ 7 | --vis tensorboard \ 8 | --output-dir tmp 9 | -------------------------------------------------------------------------------- /MSTH/scripts/sth_rect.sh: -------------------------------------------------------------------------------- 1 | viewer=${2:-'tensorboard'} 2 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/data/czl/anaconda3/envs/MSTH/lib/" 3 | export PYTHONPATH="$PYTHONPATH:/data/czl/nerf/MSTH" 4 | CUDA_VISIBLE_DEVICES=$1 \ 5 | python -m scripts.train sth_rect \ 6 | --experiment-name sth_rect \ 7 | --vis $viewer \ 8 | --output-dir tmp \ 9 | --pipeline.datamanager.dataparser.scene_scale 4 \ 10 | --pipeline.model.use_proposal_weight_anneal True 11 | -------------------------------------------------------------------------------- /MSTH/scripts/stream_nerfacto_baseline.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="$PYTHONPATH:/opt/czl/nerf/MSTH" 2 | CUDA_VISIBLE_DEVICES=$1 \ 3 | python -m scripts.video_train stream-nerfacto-baseline \ 4 | --experiment-name streamable-nerfacto \ 5 | --vis wandb \ 6 | --output-dir tmp 7 | -------------------------------------------------------------------------------- /MSTH/scripts/test_storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pth = "/data/machine/nerfstudio/tmp/base_it40000_base_lrlonger/Spatial_Time_Hashing_With_Base/2023-04-16_202626/nerfstudio_models/step-000039999.ckpt" 4 | a = torch.load(pth) 5 | pipeline = a["pipeline"] 6 | b = {} 7 | for k in pipeline.keys(): 8 | if k.startswith("_model.field") or k.startswith("_model.proposal"): 9 | b[k] = pipeline[k] 10 | 11 | torch.save(b, "test.ckpt") 12 | -------------------------------------------------------------------------------- /MSTH/scripts/tunning_machine.py: -------------------------------------------------------------------------------- 1 | from MSTH.configs.method_configs import * 2 | import numpy as np 3 | import itertools 4 | import random 5 | import os 6 | from pathlib import Path 7 | 8 | base_method = method_configs["05-horse-10-2-hidden-dim-128"] 9 | 10 | 11 | def setp(exps): 12 | def setfunc(x, v): 13 | if isinstance(exps, (tuple, list)): 14 | for exp in exps: 15 | command = "x." + exp + "=v" 16 | exec(command) 17 | else: 18 | command = "x." 
+ exps + "=v" 19 | exec(command) 20 | 21 | return setfunc 22 | 23 | 24 | set_functions = { 25 | "dataset": setp("pipeline.datamanager.dataparser.data"), 26 | } 27 | 28 | potential_values = { 29 | "dataset": [ 30 | Path("/data/machine/data/immersive/05_Horse_2"), 31 | Path("/data/machine/data/immersive/01_Welder_2"), 32 | Path("/data/machine/data/immersive/09_Alexa_Meade_Exhibit_2"), 33 | Path("/data/machine/data/immersive/10_Face_2"), 34 | Path("/data/machine/data/immersive/02_Flames_2"), 35 | Path("/data/machine/data/immersive/11_Alexa_2"), 36 | Path("/data/machine/data/immersive/04_Truck_2"), 37 | ] 38 | } 39 | 40 | all_hyper_parameter_key = potential_values.keys() 41 | all_hyper_parameter_value = [potential_values[k] for k in all_hyper_parameter_key] 42 | all_specs = list(itertools.product(*all_hyper_parameter_value)) 43 | all_specs = [{k: v for k, v in zip(all_hyper_parameter_key, spec)} for spec in all_specs] 44 | random.shuffle(all_specs) 45 | print("==== ALL SPECS ====") 46 | print(all_specs) 47 | 48 | for i, spec in enumerate(all_specs): 49 | method_configs[f"imm_{i}_4_23"] = copy.deepcopy(base_method) 50 | for k, v in spec.items(): 51 | set_functions[k](method_configs[f"imm_{i}_4_23"], v) 52 | print(method_configs[f"imm_{i}_4_23"]) 53 | -------------------------------------------------------------------------------- /MSTH/scripts/video_frames.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | input_dir=$1 4 | output_dir=$2 5 | frame_num=$3 6 | 7 | # create the output directory 8 | mkdir -p $output_dir 9 | 10 | # traverse all video files 11 | for video_file in $input_dir/*.{mp4,avi,mkv,flv,wmv} 12 | do 13 | # extract the file name and extension from the video file name 14 | filename=$(basename -- "$video_file") 15 | extension="${filename##*.}" 16 | filename="${filename%.*}" 17 | 18 | # set the output snapshot file name and path 19 | output_file="$output_dir/$filename-$frame_num.jpg" 20 | 21 | # use FFmpeg to grab the $frame_num-th frame of the video and save it as a JPEG file 22 | ffmpeg -i "$video_file" -vf "select=eq(n\,$frame_num)" -q:v 1 "$output_file" 23 | done 24 | -------------------------------------------------------------------------------- /MSTH/scripts/video_nerfacto_baseline.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="$PYTHONPATH:/opt/czl/nerf/MSTH" 2 | CUDA_VISIBLE_DEVICES=$1 \ 3 | python -m scripts.video_train video-nerfacto-baseline \ 4 | --vis tensorboard \ 5 | --output-dir tmp 6 | -------------------------------------------------------------------------------- /MSTH/scripts/video_tensorf_baseline.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="$PYTHONPATH:/opt/czl/nerf/MSTH" 2 | CUDA_VISIBLE_DEVICES=$1 \ 3 | python scripts/video_train.py video-tensorf-baseline \ 4 | --vis wandb \ 5 | --output-dir tmp 6 | -------------------------------------------------------------------------------- /MSTH/test_grid.py: -------------------------------------------------------------------------------- 1 | from rich.console import Console 2 | from gridencoder.grid import _backend 3 | import numpy as np 4 | import torch 5 | import copy 6 | import torch.nn as nn 7 | import time 8 | from utils import Timer 9 | CONSOLE = Console() 10 | 11 | grid_encode_forward = _backend.grid_encode_forward 12 | grid_encode_hash_reinitialize = _backend.grid_encode_hash_reinitialize 13 | grid_encode_set_static = _backend.grid_encode_set_static 14 | 15 | log2_hashmap_size = 19 16 | max_params = 2 ** log2_hashmap_size 17 | num_levels = 16 18 |
base_resolution = 16 19 | per_level_scale = 2 20 | level_dim = 2 21 | align_corners = False 22 | input_dim = 3 23 | 24 | offsets = [] 25 | offset = 0 26 | for i in range(num_levels): 27 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 28 | params_in_level = min(max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number 29 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 30 | offsets.append(offset) 31 | offset += params_in_level 32 | offsets.append(offset) 33 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)).cuda() 34 | 35 | inputs = torch.rand(4096*48, 3).cuda() 36 | embeddings = nn.Parameter(torch.empty(offset, level_dim)).cuda() 37 | std = 1e-4 38 | embeddings.data.uniform_(-std, std) 39 | 40 | B, D = inputs.shape # batch size, coord dim 41 | L = offsets.shape[0] - 1 # level 42 | C = embeddings.shape[1] # embedding dim for each level 43 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 44 | H = base_resolution # base resolution 45 | outputs = torch.zeros(L, B, C, device=inputs.device, dtype=embeddings.dtype).cuda() 46 | 47 | CONSOLE.log("start forwarding") 48 | print(outputs.sum()) 49 | grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, None, 0, align_corners, 0) 50 | CONSOLE.log("end forwarding") 51 | print(outputs.sum()) 52 | 53 | CONSOLE.log("start forwarding") 54 | inputs = torch.rand(1, 3).cuda() 55 | old_embeddings = torch.clone(embeddings) 56 | s = time.time() 57 | inputs[0][0] = 0.7321 58 | inputs[0][1] = 0.8612 59 | inputs[0][2] = 0.7708 60 | B = 1 61 | D = 3 62 | C = 2 63 | L = 1 64 | S = 0 65 | H = 512 66 | align_corners = False 67 | offsets = [0, 513**3+1] 68 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)).cuda() 69 | # 0.7321, 0.8612, 0.7708 70 | # grid_encode_hash_reinitialize(inputs, embeddings.to(torch.float16), offsets, B, D, C, L, S, H, 0, align_corners, 0, std) 71 | grid_mask = torch.ones(offsets[-1]).float().cuda() 72 | for i in range(100): 73 | with Timer(des="reinit"): 74 | grid_encode_hash_reinitialize(inputs, embeddings.to(torch.float16), offsets, B, D, C, L, S, H, 0, align_corners, 0, std, grid_mask.bool()) 75 | #grid_encode_set_static(inputs, grid_mask.to(torch.float16), offsets, B, D, C, L, S, H, 0, align_corners) 76 | print((grid_mask==0).sum()) 77 | grid_encode_set_static(inputs, grid_mask, offsets, B, D, C, L, S, H, 0, align_corners) 78 | print((grid_mask==0).sum()) 79 | print(time.time()-s) 80 | 81 | -------------------------------------------------------------------------------- /MSTH/wandb/latest-run: -------------------------------------------------------------------------------- 1 | run-20230413_200948-gorthja9 -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_195110-2c1ru4ae/files/config.yaml: -------------------------------------------------------------------------------- 1 | wandb_version: 1 2 | 3 | _wandb: 4 | desc: null 5 | value: 6 | cli_version: 0.13.5 7 | is_jupyter_run: false 8 | is_kaggle_kernel: false 9 | python_version: 3.8.16 10 | start_time: 1681386670.441585 11 | t: 12 | 1: 13 | - 55 14 | 3: 15 | - 16 16 | - 23 17 | 4: 3.8.16 18 | 5: 0.13.5 19 | 8: 20 | - 5 21 | architecture: 22 | desc: null 23 | value: CNN 24 | dataset: 25 | desc: null 26 | value: CIFAR-100 27 | epochs: 28 | desc: null 29 | value: 10 30 | learning_rate: 31 | desc: null 32 | value: 0.02 33 | 
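Stepping back to MSTH/test_grid.py above: the offset table built at the top of that file is what gives a multi-resolution hash grid its layout. Each level is stored as a dense grid until its vertex count would exceed the hash-table capacity of 2**log2_hashmap_size entries, after which the level is capped at that size. A minimal, self-contained NumPy sketch of the same computation (the function name and defaults here are illustrative, not part of the repository):

```python
import numpy as np

def hash_grid_offsets(num_levels=16, base_resolution=16, per_level_scale=2.0,
                      input_dim=3, log2_hashmap_size=19, align_corners=False):
    """Per-level parameter offsets, mirroring the loop in MSTH/test_grid.py."""
    max_params = 2 ** log2_hashmap_size
    offsets, offset = [], 0
    for level in range(num_levels):
        resolution = int(np.ceil(base_resolution * per_level_scale ** level))
        # a level stays a dense grid until its vertex count hits the table capacity
        vertices = (resolution if align_corners else resolution + 1) ** input_dim
        params_in_level = int(np.ceil(min(max_params, vertices) / 8) * 8)  # pad to a multiple of 8
        offsets.append(offset)
        offset += params_in_level
    offsets.append(offset)  # final entry is the total parameter count
    return offsets

print(hash_grid_offsets()[:3])  # [0, 4920, 40864]: the first levels are still dense
```

The forward call in test_grid.py then treats the flat embedding table as L consecutive segments, one per level, addressed through exactly these offsets.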
-------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_195110-2c1ru4ae/run-2c1ru4ae.wandb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/wandb/run-20230413_195110-2c1ru4ae/run-2c1ru4ae.wandb -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200145-7prtgoar/files/config.yaml: -------------------------------------------------------------------------------- 1 | wandb_version: 1 2 | 3 | learning_rate: 4 | desc: null 5 | value: 0.02 6 | architecture: 7 | desc: null 8 | value: CNN 9 | dataset: 10 | desc: null 11 | value: CIFAR-100 12 | epochs: 13 | desc: null 14 | value: 10 15 | _wandb: 16 | desc: null 17 | value: 18 | python_version: 3.8.16 19 | cli_version: 0.14.2 20 | is_jupyter_run: false 21 | is_kaggle_kernel: false 22 | start_time: 1681387305.81065 23 | t: 24 | 1: 25 | - 55 26 | 3: 27 | - 16 28 | - 23 29 | 4: 3.8.16 30 | 5: 0.14.2 31 | 8: 32 | - 5 33 | -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200145-7prtgoar/run-7prtgoar.wandb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/wandb/run-20230413_200145-7prtgoar/run-7prtgoar.wandb -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200502-qmbtu7bi/files/config.yaml: -------------------------------------------------------------------------------- 1 | wandb_version: 1 2 | 3 | learning_rate: 4 | desc: null 5 | value: 0.02 6 | architecture: 7 | desc: null 8 | value: CNN 9 | dataset: 10 | desc: null 11 | value: CIFAR-100 12 | epochs: 13 | desc: null 14 | value: 10 15 | _wandb: 16 | desc: null 17 | value: 18 | python_version: 3.8.16 19 | cli_version: 0.14.2 20 | is_jupyter_run: false 21 | is_kaggle_kernel: false 22 | start_time: 1681387502.352086 23 | t: 24 | 1: 25 | - 55 26 | 3: 27 | - 16 28 | - 23 29 | 4: 3.8.16 30 | 5: 0.14.2 31 | 8: 32 | - 5 33 | -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200502-qmbtu7bi/run-qmbtu7bi.wandb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/wandb/run-20230413_200502-qmbtu7bi/run-qmbtu7bi.wandb -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200948-gorthja9/files/code/TRTT/test.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import random 3 | 4 | # start a new wandb run to track this script 5 | wandb.init( 6 | # set the wandb project where this run will be logged 7 | project="my-awesome-project", 8 | # track hyperparameters and run metadata 9 | config={ 10 | "learning_rate": 0.02, 11 | "architecture": "CNN", 12 | "dataset": "CIFAR-100", 13 | "epochs": 10, 14 | }, 15 | ) 16 | 17 | # simulate training 18 | epochs = 10 19 | offset = random.random() / 5 20 | for epoch in range(2, epochs): 21 | acc = 1 - 2**-epoch - random.random() / epoch - offset 22 | loss = 2**-epoch + random.random() / epoch + offset 23 | 24 | # log metrics to wandb 25 | wandb.log({"acc": acc, "loss": loss}) 26 | 27 | # 
[optional] finish the wandb run, necessary in notebooks 28 | wandb.finish() 29 | -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200948-gorthja9/files/conda-environment.yaml: -------------------------------------------------------------------------------- 1 | name: nerfstudio 2 | channels: 3 | - defaults 4 | prefix: /home/feng/.conda/envs/nerfstudio 5 | -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200948-gorthja9/files/config.yaml: -------------------------------------------------------------------------------- 1 | wandb_version: 1 2 | 3 | learning_rate: 4 | desc: null 5 | value: 0.02 6 | architecture: 7 | desc: null 8 | value: CNN 9 | dataset: 10 | desc: null 11 | value: CIFAR-100 12 | epochs: 13 | desc: null 14 | value: 10 15 | _wandb: 16 | desc: null 17 | value: 18 | code_path: code/MSTH/test.py 19 | python_version: 3.8.16 20 | cli_version: 0.14.2 21 | is_jupyter_run: false 22 | is_kaggle_kernel: false 23 | start_time: 1681387788.135257 24 | t: 25 | 1: 26 | - 55 27 | 2: 28 | - 55 29 | 3: 30 | - 2 31 | - 16 32 | - 23 33 | 4: 3.8.16 34 | 5: 0.14.2 35 | 8: 36 | - 5 37 | -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200948-gorthja9/files/wandb-summary.json: -------------------------------------------------------------------------------- 1 | {"acc": 0.8575738787050718, "loss": 0.13056499067897515, "_timestamp": 1681387790.878599, "_runtime": 2.7433419227600098, "_step": 7, "_wandb": {"runtime": 2}} -------------------------------------------------------------------------------- /MSTH/wandb/run-20230413_200948-gorthja9/run-gorthja9.wandb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/MSTH/wandb/run-20230413_200948-gorthja9/run-gorthja9.wandb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Masked Space-Time Hash Encoding for Efficient Dynamic Scene Reconstruction 2 | [NeurIPS 2023 Spotlight] 3 | 4 | [Project Page](https://masked-spacetime-hashing.github.io/) | [Paper](https://arxiv.org/pdf/2310.17527.pdf) | [Data](https://huggingface.co/datasets/masked-spacetime-hashing/Campus) 5 | 6 | [Feng Wang]()1*, [Zilong Chen]()1*, Guokang Wang1, Yafei Song2, [Huaping Liu]()1 7 | 8 | 1Department of Computer Science and Technology, Tsinghua University 2Alibaba Group 9 | 10 | 11 | ### Introduction 12 | We propose the Masked Space-Time Hash encoding (MSTH), a novel method for efficiently reconstructing dynamic 3D scenes from multi-view or monocular videos. Based on the observation that dynamic scenes often contain substantial static areas that result in redundancy in storage and computations, MSTH represents a dynamic scene as a weighted combination of a 3D hash encoding and a 4D hash encoding. The weights for the two components are represented by a learnable mask which is guided by an uncertainty-based objective to reflect the spatial and temporal importance of each 3D position. With this design, our method can reduce the hash collision rate by avoiding redundant queries and modifications on static areas, making it feasible to represent a large number of space-time voxels by hash tables with small size. 
Besides, since it does not need to fit a large number of temporally redundant features independently, our method is easier to optimize and converges rapidly, requiring only twenty minutes of training for a 300-frame dynamic scene. We evaluate our method on a wide range of dynamic scenes. As a result, MSTH obtains consistently better results than previous state-of-the-art methods with only 20 minutes of training time and 130 MB of memory storage. 13 | 14 | ### Demos 15 | We recommend visiting our [project page](https://masked-spacetime-hashing.github.io/) to watch the videos in full quality. 16 | #### [Immersive Dataset](https://augmentedperception.github.io/deepviewvideo/) 17 | 18 | https://github.com/masked-spacetime-hashing/msth/assets/43294876/c14dcb57-c600-43b9-adf1-f8a532785d8f 19 | 20 | 21 | #### [Plenoptic Dataset](https://neural-3d-video.github.io/) 22 | 23 | https://github.com/masked-spacetime-hashing/msth/assets/43294876/7094fee1-3cfb-49f4-abed-dc5f61a7fb72 24 | 25 | #### [Campus Dataset](https://huggingface.co/datasets/masked-spacetime-hashing/Campus) 26 | 27 | https://github.com/masked-spacetime-hashing/msth/assets/43294876/1fbc7417-e66b-4cdd-8e8c-1863f031fa30 28 | 29 | 30 | ### Instructions 31 | #### Create env 32 | ```bash 33 | conda create -n MSTH python=3.8 34 | ``` 35 | #### Install dependencies 36 | ```bash 37 | pip install -e . 38 | ``` 39 | and install tiny-cuda-nn for fast feed-forward NNs: 40 | ```bash 41 | pip install 42 | ``` 43 | #### Download data 44 | ```bash 45 | python download.py --scene 46 | ``` 47 | #### Run MSTH 48 | ```bash 49 | python -m MSTH.scripts.train --experiment-name --vis --output-dir 50 | ``` 51 | #### Viewer 52 | Our code provides a viewer based on the [NeRFStudio web viewer](). 53 | 54 | ### Campus Dataset 55 | 56 | ### Acknowledgements 57 | Our code is based on [NeRFStudio](https://github.com/nerfstudio-project/nerfstudio). 58 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import zipfile 3 | from pathlib import Path 4 | 5 | import wget 6 | 7 | n3dv_scenes = ["sear_steak", "cut_roasted_beef", "cook_spinach", "coffee_martini", "flame_salmon"] 8 | immersive_scenes = [] 9 | dnerf_scenes = [] 10 | campus_scenes = [] 11 | 12 | tmp_dir = Path("./.msth_tmp") 13 | tmp_dir.mkdir(exist_ok=True) 14 | 15 | 16 | def download_n3dv(scene: str, download_dir: Path): 17 | assert scene in n3dv_scenes, f"Scene {scene} not found in N3DV dataset" 18 | data_url = f"https://github.com/facebookresearch/Neural_3D_Video/releases/download/v1.0/{scene}.zip" 19 | wget.download(data_url, out=str(download_dir)) 20 | zipfile.ZipFile(download_dir / f"{scene}.zip").extractall(download_dir) 21 | 22 | 23 | def download_immersive(scene: str, download_dir: Path): 24 | pass 25 | 26 | 27 | def download_dnerf(scene: str, download_dir: Path): 28 | pass 29 | 30 | 31 | def download_campus(scene: str, download_dir: Path): 32 | pass 33 | 34 | 35 | download_fns = { 36 | "n3dv": download_n3dv, 37 | "immersive": download_immersive, 38 | "dnerf": download_dnerf, 39 | "campus": download_campus, 40 | } 41 | 42 | 43 | def main(): 44 | parser = argparse.ArgumentParser("Download datasets") 45 | parser.add_argument("dataset", type=str, nargs="?", default="n3dv", choices=["n3dv", "immersive", "dnerf", "campus"]) 46 | parser.add_argument("--scene", type=str, required=True) 47 | parser.add_argument("--download-dir", type=Path, default=Path("data")) 48 | 49 | opt = parser.parse_args() 50 | opt.download_dir.mkdir(exist_ok=True, parents=True) 51 | download_fns[opt.dataset](opt.scene, opt.download_dir) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 |
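To make the README's core idea above concrete: MSTH blends a 3D hash encoding (static content) with a 4D space-time hash encoding (dynamic content) through a learnable mask. The toy PyTorch module below sketches only that weighting scheme; the class and parameter names are hypothetical and the encoders are stand-in MLPs, whereas the actual hash-grid implementation lives under MSTH/SpaceTimeHashing/ and MSTH/gridencoder/:

```python
import torch
import torch.nn as nn

class MaskedSpaceTimeHashSketch(nn.Module):
    """Illustrative blend of a static (xyz) and a dynamic (xyzt) feature encoder.

    A per-position mask m in [0, 1] weights the two feature sets:
        feat = m * static_feat + (1 - m) * dynamic_feat
    In MSTH a mask near 1 lets static regions avoid queries to the 4D table;
    this sketch always evaluates both encoders for simplicity.
    """

    def __init__(self, feat_dim: int = 32):
        super().__init__()
        self.static_enc = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, feat_dim))
        self.dynamic_enc = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, feat_dim))
        self.mask_head = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, xyz: torch.Tensor, t: torch.Tensor):
        m = torch.sigmoid(self.mask_head(xyz))  # learnable mask in [0, 1]
        static_feat = self.static_enc(xyz)
        dynamic_feat = self.dynamic_enc(torch.cat([xyz, t], dim=-1))
        return m * static_feat + (1.0 - m) * dynamic_feat, m

model = MaskedSpaceTimeHashSketch()
xyz, t = torch.rand(8, 3), torch.rand(8, 1)  # normalized positions and times
feat, mask = model(xyz, t)
print(feat.shape, mask.shape)  # torch.Size([8, 32]) torch.Size([8, 1])
```

Per the paper, the mask is additionally guided by an uncertainty-based objective so that it reflects which positions are truly static; the sketch omits that loss term.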
-------------------------------------------------------------------------------- /nerfstudio/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/cameras/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/configs/config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Some utility code for configs. 
17 | """ 18 | 19 | from __future__ import annotations 20 | 21 | from dataclasses import field 22 | from typing import Any, Dict 23 | 24 | from rich.console import Console 25 | 26 | CONSOLE = Console() 27 | # pylint: disable=import-outside-toplevel 28 | 29 | # cannot use mutable types directly within dataclass; abstracting default factory calls 30 | def to_immutable_dict(d: Dict[str, Any]): 31 | """Method to convert mutable dict to default factory dict 32 | 33 | Args: 34 | d: dictionary to convert into default factory dict for dataclass 35 | """ 36 | return field(default_factory=lambda: dict(d)) 37 | 38 | 39 | def convert_markup_to_ansi(markup_string: str) -> str: 40 | """Convert rich-style markup to ANSI sequences for command-line formatting. 41 | 42 | Args: 43 | markup_string: Text with rich-style markup. 44 | 45 | Returns: 46 | Text formatted via ANSI sequences. 47 | """ 48 | with CONSOLE.capture() as out: 49 | CONSOLE.print(markup_string, soft_wrap=True) 50 | return out.get() 51 | -------------------------------------------------------------------------------- /nerfstudio/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/data/datamanagers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/data/datamanagers/depth_datamanager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Depth datamanager. 17 | """ 18 | 19 | from dataclasses import dataclass, field 20 | from typing import Type 21 | 22 | from nerfstudio.data.datamanagers import base_datamanager 23 | from nerfstudio.data.datasets.depth_dataset import DepthDataset 24 | 25 | 26 | @dataclass 27 | class DepthDataManagerConfig(base_datamanager.VanillaDataManagerConfig): 28 | """A depth datamanager - required to use with .setup()""" 29 | 30 | _target: Type = field(default_factory=lambda: DepthDataManager) 31 | 32 | 33 | class DepthDataManager(base_datamanager.VanillaDataManager): # pylint: disable=abstract-method 34 | """Data manager implementation for data that also requires processing depth data. 35 | Args: 36 | config: the DataManagerConfig used to instantiate class 37 | """ 38 | 39 | def create_train_dataset(self) -> DepthDataset: 40 | self.train_dataparser_outputs = self.dataparser.get_dataparser_outputs(split="train") 41 | return DepthDataset( 42 | dataparser_outputs=self.train_dataparser_outputs, 43 | ) 44 | 45 | def create_eval_dataset(self) -> DepthDataset: 46 | return DepthDataset( 47 | dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split), 48 | ) 49 | -------------------------------------------------------------------------------- /nerfstudio/data/datamanagers/semantic_datamanager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Semantic datamanager. 17 | """ 18 | 19 | from dataclasses import dataclass, field 20 | from typing import Type 21 | 22 | from nerfstudio.data.datamanagers.base_datamanager import ( 23 | VanillaDataManager, 24 | VanillaDataManagerConfig, 25 | ) 26 | from nerfstudio.data.datasets.semantic_dataset import SemanticDataset 27 | 28 | 29 | @dataclass 30 | class SemanticDataManagerConfig(VanillaDataManagerConfig): 31 | """A semantic datamanager - required to use with .setup()""" 32 | 33 | _target: Type = field(default_factory=lambda: SemanticDataManager) 34 | 35 | 36 | class SemanticDataManager(VanillaDataManager): # pylint: disable=abstract-method 37 | """Data manager implementation for data that also requires processing semantic data. 
38 | 39 | Args: 40 | config: the DataManagerConfig used to instantiate class 41 | """ 42 | 43 | def create_train_dataset(self) -> SemanticDataset: 44 | self.train_dataparser_outputs = self.dataparser.get_dataparser_outputs(split="train") 45 | return SemanticDataset( 46 | dataparser_outputs=self.train_dataparser_outputs, 47 | scale_factor=self.config.camera_res_scale_factor, 48 | ) 49 | 50 | def create_eval_dataset(self) -> SemanticDataset: 51 | return SemanticDataset( 52 | dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split), 53 | scale_factor=self.config.camera_res_scale_factor, 54 | ) 55 | -------------------------------------------------------------------------------- /nerfstudio/data/datamanagers/variable_res_datamanager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Data loader for variable resolution datasets, where batching raw image tensors isn't possible. 17 | """ 18 | 19 | from __future__ import annotations 20 | 21 | from dataclasses import dataclass 22 | from typing import Dict, List 23 | 24 | from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig 25 | from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate 26 | 27 | 28 | def variable_res_collate(batch: List[Dict]) -> Dict: 29 | """Default collate function for the cached dataloader. 30 | Args: 31 | batch: Batch of samples from the dataset. 32 | Returns: 33 | Collated batch. 34 | """ 35 | images = [] 36 | masks = [] 37 | for data in batch: 38 | image = data.pop("image") 39 | mask = data.pop("mask", None) 40 | images.append(image) 41 | if mask is not None: 42 | masks.append(mask) 43 | 44 | new_batch: dict = nerfstudio_collate(batch) 45 | new_batch["image"] = images 46 | if masks: 47 | new_batch["mask"] = masks 48 | 49 | return new_batch 50 | 51 | 52 | @dataclass 53 | class VariableResDataManagerConfig(VanillaDataManagerConfig): 54 | """A datamanager for variable resolution datasets, with presets to optimize 55 | for the fact that we are now dealing with lists of images and masks. 56 | """ 57 | 58 | train_num_images_to_sample_from: int = 40 59 | """Number of images to sample during training iteration.""" 60 | train_num_times_to_repeat_images: int = 100 61 | """When not training on all images, number of iterations before picking new 62 | images. If -1, never pick new images.""" 63 | eval_num_images_to_sample_from: int = 40 64 | """Number of images to sample during eval iteration.""" 65 | eval_num_times_to_repeat_images: int = 100 66 | collate_fn = staticmethod(variable_res_collate) 67 | -------------------------------------------------------------------------------- /nerfstudio/data/dataparsers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/data/datasets/depth_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Depth dataset. 17 | """ 18 | 19 | from typing import Dict 20 | 21 | from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs 22 | from nerfstudio.data.datasets.base_dataset import InputDataset 23 | from nerfstudio.data.utils.data_utils import get_depth_image_from_path 24 | 25 | 26 | class DepthDataset(InputDataset): 27 | """Dataset that returns images and depths. 28 | 29 | Args: 30 | dataparser_outputs: description of where and how to read input images. 31 | scale_factor: The scaling factor for the dataparser outputs. 
32 | """ 33 | 34 | def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0): 35 | super().__init__(dataparser_outputs, scale_factor) 36 | assert ( 37 | "depth_filenames" in dataparser_outputs.metadata.keys() 38 | and dataparser_outputs.metadata["depth_filenames"] is not None 39 | ) 40 | self.depth_filenames = self.metadata["depth_filenames"] 41 | self.depth_unit_scale_factor = self.metadata["depth_unit_scale_factor"] 42 | 43 | def get_metadata(self, data: Dict) -> Dict: 44 | filepath = self.depth_filenames[data["image_idx"]] 45 | height = int(self._dataparser_outputs.cameras.height[data["image_idx"]]) 46 | width = int(self._dataparser_outputs.cameras.width[data["image_idx"]]) 47 | 48 | # Scale depth images to meter units and also by scaling applied to cameras 49 | scale_factor = self.depth_unit_scale_factor * self._dataparser_outputs.dataparser_scale 50 | depth_image = get_depth_image_from_path( 51 | filepath=filepath, height=height, width=width, scale_factor=scale_factor 52 | ) 53 | 54 | return {"depth_image": depth_image} 55 | -------------------------------------------------------------------------------- /nerfstudio/data/datasets/semantic_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Semantic dataset. 17 | """ 18 | 19 | from typing import Dict 20 | 21 | import torch 22 | 23 | from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs, Semantics 24 | from nerfstudio.data.datasets.base_dataset import InputDataset 25 | from nerfstudio.data.utils.data_utils import get_semantics_and_mask_tensors_from_path 26 | 27 | 28 | class SemanticDataset(InputDataset): 29 | """Dataset that returns images and semantics and masks. 30 | 31 | Args: 32 | dataparser_outputs: description of where and how to read input images. 
33 | """ 34 | 35 | def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0): 36 | super().__init__(dataparser_outputs, scale_factor) 37 | assert "semantics" in dataparser_outputs.metadata.keys() and isinstance(self.metadata["semantics"], Semantics) 38 | self.semantics = self.metadata["semantics"] 39 | self.mask_indices = torch.tensor( 40 | [self.semantics.classes.index(mask_class) for mask_class in self.semantics.mask_classes] 41 | ).view(1, 1, -1) 42 | 43 | def get_metadata(self, data: Dict) -> Dict: 44 | # handle mask 45 | filepath = self.semantics.filenames[data["image_idx"]] 46 | semantic_label, mask = get_semantics_and_mask_tensors_from_path( 47 | filepath=filepath, mask_indices=self.mask_indices, scale_factor=self.scale_factor 48 | ) 49 | if "mask" in data.keys(): 50 | mask = mask & data["mask"] 51 | return {"mask": mask, "semantics": semantic_label} 52 | -------------------------------------------------------------------------------- /nerfstudio/data/scene_box.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Dataset input structures. 17 | """ 18 | 19 | from dataclasses import dataclass 20 | from typing import Dict, Union 21 | 22 | import torch 23 | from torchtyping import TensorType 24 | 25 | 26 | @dataclass 27 | class SceneBox: 28 | """Data to represent the scene box.""" 29 | 30 | aabb: TensorType[2, 3] = None 31 | """aabb: axis-aligned bounding box. 32 | aabb[0] is the minimum (x,y,z) point. 33 | aabb[1] is the maximum (x,y,z) point.""" 34 | 35 | def get_diagonal_length(self): 36 | """Returns the longest diagonal length.""" 37 | diff = self.aabb[1] - self.aabb[0] 38 | length = torch.sqrt((diff**2).sum() + 1e-20) 39 | return length 40 | 41 | def get_center(self): 42 | """Returns the center of the box.""" 43 | diff = self.aabb[1] - self.aabb[0] 44 | return self.aabb[0] + diff / 2.0 45 | 46 | def get_centered_and_scaled_scene_box(self, scale_factor: Union[float, torch.Tensor] = 1.0): 47 | """Returns a new box that has been shifted and rescaled to be centered 48 | about the origin. 49 | 50 | Args: 51 | scale_factor: How much to scale the camera origins by. 52 | """ 53 | return SceneBox(aabb=(self.aabb - self.get_center()) * scale_factor) 54 | 55 | @staticmethod 56 | def get_normalized_positions(positions: TensorType[..., 3], aabb: TensorType[2, 3]): 57 | """Return normalized positions in range [0, 1] based on the aabb axis-aligned bounding box. 
58 | 59 | Args: 60 | positions: the xyz positions 61 | aabb: the axis-aligned bounding box 62 | """ 63 | aabb_lengths = aabb[1] - aabb[0] 64 | normalized_positions = (positions - aabb[0]) / aabb_lengths 65 | return normalized_positions 66 | 67 | def to_json(self) -> Dict: 68 | """Returns a json object from the Python object.""" 69 | return {"type": "aabb", "min_point": self.aabb[0].tolist(), "max_point": self.aabb[1].tolist()} 70 | 71 | @staticmethod 72 | def from_json(json_: Dict) -> "SceneBox": 73 | """Returns an instance of SceneBox from a json dictionary. 74 | 75 | Args: 76 | json_: the json dictionary containing scene box information 77 | """ 78 | assert json_["type"] == "aabb" 79 | aabb = torch.tensor([json_["min_point"], json_["max_point"]]) 80 | return SceneBox(aabb=aabb) 81 | 82 | @staticmethod 83 | def from_camera_poses(poses: TensorType[..., 3, 4], scale_factor: float) -> "SceneBox": 84 | """Returns the instance of SceneBox that fully envelops a set of poses 85 | 86 | Args: 87 | poses: tensor of camera pose matrices 88 | scale_factor: How much to scale the camera origins by. 89 | """ 90 | xyzs = poses[..., :3, -1] 91 | aabb = torch.stack([torch.min(xyzs, dim=0)[0], torch.max(xyzs, dim=0)[0]]) 92 | return SceneBox(aabb=aabb * scale_factor) 93 | -------------------------------------------------------------------------------- /nerfstudio/data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/exporter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/field_components/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """init field modules""" 16 | from .base_field_component import FieldComponent 17 | from .encodings import Encoding, ScalingAndOffset 18 | from .mlp import MLP 19 | -------------------------------------------------------------------------------- /nerfstudio/field_components/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Special activation functions. 17 | """ 18 | 19 | import torch 20 | from torch.autograd import Function 21 | from torch.cuda.amp import custom_bwd, custom_fwd 22 | 23 | 24 | class _TruncExp(Function): # pylint: disable=abstract-method 25 | # Implementation from torch-ngp: 26 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 27 | @staticmethod 28 | @custom_fwd(cast_inputs=torch.float32) 29 | def forward(ctx, x): # pylint: disable=arguments-differ 30 | ctx.save_for_backward(x) 31 | return torch.exp(x) 32 | 33 | @staticmethod 34 | @custom_bwd 35 | def backward(ctx, g): # pylint: disable=arguments-differ 36 | x = ctx.saved_tensors[0] 37 | return g * torch.exp(x.clamp(-15, 15)) 38 | 39 | 40 | trunc_exp = _TruncExp.apply 41 | """Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding 42 | gradients.""" 43 | -------------------------------------------------------------------------------- /nerfstudio/field_components/base_field_component.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | The field module baseclass. 17 | """ 18 | from abc import abstractmethod 19 | from typing import Optional 20 | 21 | from torch import nn 22 | from torchtyping import TensorType 23 | 24 | 25 | class FieldComponent(nn.Module): 26 | """Field modules that can be combined to store and compute the fields. 27 | 28 | Args: 29 | in_dim: Input dimension to module. 30 | out_dim: Output dimension to module. 31 | """ 32 | 33 | def __init__(self, in_dim: Optional[int] = None, out_dim: Optional[int] = None) -> None: 34 | super().__init__() 35 | self.in_dim = in_dim 36 | self.out_dim = out_dim 37 | 38 | def build_nn_modules(self) -> None: 39 | """Function instantiates any torch.nn members within the module. 40 | If none exist, do nothing.""" 41 | 42 | def set_in_dim(self, in_dim: int) -> None: 43 | """Sets input dimension of encoding 44 | 45 | Args: 46 | in_dim: input dimension 47 | """ 48 | if in_dim <= 0: 49 | raise ValueError("Input dimension should be greater than zero") 50 | self.in_dim = in_dim 51 | 52 | def get_out_dim(self) -> int: 53 | """Calculates output dimension of encoding.""" 54 | if self.out_dim is None: 55 | raise ValueError("Output dimension has not been set") 56 | return self.out_dim 57 | 58 | @abstractmethod 59 | def forward(self, in_tensor: TensorType["bs":..., "input_dim"]) -> TensorType["bs":..., "output_dim"]: 60 | """ 61 | Returns processed tensor 62 | 63 | Args: 64 | in_tensor: Input tensor to process 65 | """ 66 | raise NotImplementedError 67 | -------------------------------------------------------------------------------- /nerfstudio/field_components/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """init cuda functions""" 16 | from typing import Callable 17 | 18 | 19 | def _make_lazy_cuda_func(name: str) -> Callable: 20 | """_make_lazy_cuda_func from nerfacc.cuda""" 21 | 22 | def call_cuda(*args, **kwargs): 23 | # pylint: disable=import-outside-toplevel 24 | from ._backend import _C 25 | 26 | return getattr(_C, name)(*args, **kwargs) 27 | 28 | return call_cuda 29 | 30 | 31 | temporal_grid_encode_forward = _make_lazy_cuda_func("temporal_grid_encode_forward") 32 | temporal_grid_encode_backward = _make_lazy_cuda_func("temporal_grid_encode_backward") 33 | -------------------------------------------------------------------------------- /nerfstudio/field_components/cuda/_backend.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """compiling the cuda kernels""" 16 | import glob 17 | import os 18 | import shutil 19 | 20 | from nerfacc.cuda._backend import cuda_toolkit_available 21 | from torch.utils.cpp_extension import _get_build_directory, load 22 | 23 | PATH = os.path.dirname(os.path.abspath(__file__)) 24 | NAME = "nerfstudio_field_components_cuda" 25 | BUILD_DIR = _get_build_directory(NAME, verbose=False) 26 | 27 | 28 | _C = None 29 | if cuda_toolkit_available(): 30 | if os.listdir(BUILD_DIR) != []: 31 | # If the build exists, we assume the extension has been built 32 | # and we can load it. 33 | print("nerfstudio field components: CUDA set up, loading (should be quick)") 34 | _C = load( 35 | name=NAME, 36 | sources=glob.glob(os.path.join(PATH, "csrc/*.cu")), 37 | extra_cflags=["-O3", "-std=c++14"], 38 | extra_cuda_cflags=["-O3", "-std=c++14"], 39 | extra_include_paths=[], 40 | ) 41 | else: 42 | # Build from scratch. Remove the build directory just to be safe: pytorch jit might stuck 43 | # if the build directory exists. 44 | shutil.rmtree(BUILD_DIR) 45 | print("nerfstudio field components: Setting up CUDA (This may take a few minutes the first time)") 46 | _C = load( 47 | name=NAME, 48 | sources=glob.glob(os.path.join(PATH, "csrc/*.cu")), 49 | extra_cflags=["-O3", "-std=c++14"], 50 | extra_cuda_cflags=["-O3", "-std=c++14"], 51 | extra_include_paths=[], 52 | ) 53 | print("nerfstudio field components: Setting up CUDA finished") 54 | else: 55 | print("nerfstudio field components: No CUDA toolkit found. 
Some models may fail.") 56 | 57 | 58 | __all__ = ["_C"] 59 | -------------------------------------------------------------------------------- /nerfstudio/field_components/cuda/csrc/include/temporal_gridencoder.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Code adapted from (@ashawkey) https://github.com/ashawkey/torch-ngp/ 3 | * Author: @lsongx 4 | */ 5 | 6 | #ifndef _TEMPORAL_HASH_ENCODE_H 7 | #define _TEMPORAL_HASH_ENCODE_H 8 | 9 | #include 10 | #include 11 | 12 | // inputs: input coordinates, [B, D], float, in [0, 1] 13 | // temporal_row_index: row index for sampling from channels, [B, 4*num_of_channels], uint32_t 14 | // embeddings: the grid embedding, [sO, grid_C], float 15 | // offsets: offsets for different levels used in NGP, [L + 1], uint32_t 16 | // outputs: interpolated outputs, [B, L * C], float 17 | // grid_C: number of channels for the grid embedding 18 | // B: batch size 19 | // D: coord dim 20 | // L: number of levels 21 | // S: resolution multiplier at each level 22 | // H: base resolution 23 | 24 | void temporal_grid_encode_forward( 25 | const at::Tensor inputs, 26 | const at::Tensor temporal_row_index, 27 | const at::Tensor embeddings, 28 | const at::Tensor offsets, 29 | at::Tensor outputs, 30 | const uint32_t B, 31 | const uint32_t D, 32 | const uint32_t grid_C, 33 | const uint32_t C, 34 | const uint32_t L, 35 | const float S, 36 | const uint32_t H, 37 | at::optional dy_dx, 38 | const uint32_t gridtype, 39 | const bool align_corners 40 | ); 41 | 42 | void temporal_grid_encode_backward( 43 | const at::Tensor grad, 44 | const at::Tensor inputs, 45 | const at::Tensor temporal_row_index, 46 | const at::Tensor embeddings, 47 | const at::Tensor offsets, 48 | at::Tensor grad_embeddings, 49 | const uint32_t B, 50 | const uint32_t D, 51 | const uint32_t grid_C, 52 | const uint32_t C, 53 | const uint32_t L, 54 | const float S, 55 | const uint32_t H, 56 | const at::optional dy_dx, 57 | at::optional grad_inputs, 58 | const uint32_t gridtype, 59 | const bool align_corners 60 | ); 61 | 62 | #endif -------------------------------------------------------------------------------- /nerfstudio/field_components/cuda/csrc/pybind.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Code adapted from (@ashawkey) https://github.com/ashawkey/torch-ngp/ 3 | * Author: @lsongx 4 | */ 5 | 6 | 7 | #include 8 | 9 | #include "include/temporal_gridencoder.h" 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("temporal_grid_encode_forward", &temporal_grid_encode_forward, "temporal_grid_encode_forward (CUDA)"); 13 | m.def("temporal_grid_encode_backward", &temporal_grid_encode_backward, "temporal_grid_encode_backward (CUDA)"); 14 | } -------------------------------------------------------------------------------- /nerfstudio/field_components/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """ 
16 | Code for embeddings. 
17 | """ 
18 | 
19 | 
20 | import torch 
21 | from torchtyping import TensorType 
22 | 
23 | from nerfstudio.field_components.base_field_component import FieldComponent 
24 | 
25 | 
26 | class Embedding(FieldComponent): 
27 | """Index into embeddings. 
28 | # TODO: add different types of initializations 
29 | 
30 | Args: 
31 | in_dim: Number of embeddings 
32 | out_dim: Dimension of the embedding vectors 
33 | """ 
34 | 
35 | def __init__(self, in_dim: int, out_dim: int) -> None: 
36 | super().__init__() 
37 | self.in_dim = in_dim 
38 | self.out_dim = out_dim 
39 | self.build_nn_modules() 
40 | 
41 | def build_nn_modules(self) -> None: 
42 | self.embedding = torch.nn.Embedding(self.in_dim, self.out_dim) 
43 | 
44 | def mean(self, dim=0): 
45 | """Return the mean of the embedding weights along a dim.""" 
46 | return self.embedding.weight.mean(dim) 
47 | 
48 | def forward(self, in_tensor: TensorType[..., "input_dim"]) -> TensorType[..., "output_dim"]: 
49 | """Call forward 
50 | 
51 | Args: 
52 | in_tensor: input tensor to process 
53 | """ 
54 | return self.embedding(in_tensor) 
55 | 
-------------------------------------------------------------------------------- /nerfstudio/field_components/spatial_distortions.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """Space distortions.""" 
16 | 
17 | from typing import Optional, Union 
18 | 
19 | import torch 
20 | from functorch import jacrev, vmap 
21 | from torch import nn 
22 | from torchtyping import TensorType 
23 | 
24 | from nerfstudio.utils.math import Gaussians 
25 | 
26 | 
27 | class SpatialDistortion(nn.Module): 
28 | """Apply spatial distortions""" 
29 | 
30 | def forward( 
31 | self, positions: Union[TensorType["bs":..., 3], Gaussians] 
32 | ) -> Union[TensorType["bs":..., 3], Gaussians]: 
33 | """ 
34 | Args: 
35 | positions: Sample to distort 
36 | 
37 | Returns: 
38 | The distorted sample 
39 | """ 
40 | 
41 | 
42 | class SceneContraction(SpatialDistortion): 
43 | """Contract unbounded space using the contraction proposed in MipNeRF-360. 
44 | We use the following contraction equation: 
45 | 
46 | .. math:: 
47 | 
48 | f(x) = \\begin{cases} 
49 | x & ||x|| \\leq 1 \\\\ 
50 | (2 - \\frac{1}{||x||})(\\frac{x}{||x||}) & ||x|| > 1 
51 | \\end{cases} 
52 | 
53 | If the order is not specified, we use the Frobenius norm; this will contract the space to a sphere of 
54 | radius 2. If the order is L_inf (order=float("inf")), we will contract the space to a cube of side length 4. 
55 | If using voxel based encodings such as the Hash encoder, we recommend using the L_inf norm. 
56 | 
57 | Args: 
58 | order: Order of the norm. Defaults to the Frobenius norm. Must be set to None for Gaussians.
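    For example, with the default norm a point at distance 4 from the origin
    is mapped to distance 2 - 1/4 = 1.75, points inside the unit sphere are
    left unchanged, and points at infinity approach distance 2.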
59 | 60 | """ 61 | 62 | def __init__(self, order: Optional[Union[float, int]] = None) -> None: 63 | super().__init__() 64 | self.order = order 65 | 66 | def forward(self, positions): 67 | def contract(x): 68 | mag = torch.linalg.norm(x, ord=self.order, dim=-1)[..., None] 69 | return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag)) 70 | 71 | if isinstance(positions, Gaussians): 72 | means = contract(positions.mean.clone()) 73 | 74 | contract = lambda x: (2 - (1 / torch.linalg.norm(x, ord=self.order, dim=-1, keepdim=True))) * ( 75 | x / torch.linalg.norm(x, ord=self.order, dim=-1, keepdim=True) 76 | ) 77 | jc_means = vmap(jacrev(contract))(positions.mean.view(-1, positions.mean.shape[-1])) 78 | jc_means = jc_means.view(list(positions.mean.shape) + [positions.mean.shape[-1]]) 79 | 80 | # Only update covariances on positions outside the unit sphere 81 | mag = positions.mean.norm(dim=-1) 82 | mask = mag >= 1 83 | cov = positions.cov.clone() 84 | cov[mask] = jc_means[mask] @ positions.cov[mask] @ torch.transpose(jc_means[mask], -2, -1) 85 | 86 | return Gaussians(mean=means, cov=cov) 87 | 88 | return contract(positions) 89 | -------------------------------------------------------------------------------- /nerfstudio/fields/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/generative/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/model_components/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/model_components/ray_generators.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Ray generator. 17 | """ 18 | from torch import nn 19 | from torchtyping import TensorType 20 | 21 | from nerfstudio.cameras.camera_optimizers import CameraOptimizer 22 | from nerfstudio.cameras.cameras import Cameras 23 | from nerfstudio.cameras.rays import RayBundle 24 | 25 | 26 | class RayGenerator(nn.Module): 27 | """torch.nn Module for generating rays. 28 | This class is the interface between the scene's cameras/camera optimizer and the ray sampler. 29 | 30 | Args: 31 | cameras: Camera objects containing camera info. 32 | pose_optimizer: pose optimization module, for optimizing noisy camera intrinsics/extrinsics. 33 | """ 34 | 35 | def __init__(self, cameras: Cameras, pose_optimizer: CameraOptimizer) -> None: 36 | super().__init__() 37 | self.cameras = cameras 38 | self.pose_optimizer = pose_optimizer 39 | self.register_buffer("image_coords", cameras.get_image_coords(), persistent=False) 40 | 41 | def forward(self, ray_indices: TensorType["num_rays", 3]) -> RayBundle: 42 | """Index into the cameras to generate the rays. 43 | 44 | Args: 45 | ray_indices: Contains camera, row, and col indices for target rays. 46 | """ 47 | c = ray_indices[:, 0] # camera indices 48 | y = ray_indices[:, 1] # row indices 49 | x = ray_indices[:, 2] # col indices 50 | coords = self.image_coords[y, x] 51 | 52 | camera_opt_to_camera = self.pose_optimizer(c) 53 | 54 | ray_bundle = self.cameras.generate_rays( 55 | camera_indices=c.unsqueeze(-1), 56 | coords=coords, 57 | camera_opt_to_camera=camera_opt_to_camera, 58 | ) 59 | return ray_bundle 60 | -------------------------------------------------------------------------------- /nerfstudio/model_components/shaders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """Shaders for rendering.""" 
16 | from typing import Optional 
17 | 
18 | from torch import nn 
19 | from torchtyping import TensorType 
20 | 
21 | 
22 | class LambertianShader(nn.Module): 
23 | """Calculate Lambertian shading.""" 
24 | 
25 | @classmethod 
26 | def forward( 
27 | cls, 
28 | rgb: TensorType["bs":..., 3], 
29 | normals: TensorType["bs":..., 3], 
30 | light_direction: TensorType["bs":..., 3], 
31 | shading_weight: float = 1.0, 
32 | detach_normals: bool = True, 
33 | ): 
34 | """Calculate Lambertian shading. 
35 | 
36 | Args: 
37 | rgb: Accumulated rgb along a ray. 
38 | normals: Accumulated normals along a ray. 
39 | light_direction: Direction of light source. 
40 | shading_weight: Lambertian shading (1.0) vs. ambient lighting (0.0) ratio 
41 | detach_normals: Detach normals from the computation graph when computing shading. 
42 | 
43 | Returns: 
44 | Textureless Lambertian shading, Lambertian shading 
45 | """ 
46 | if detach_normals: 
47 | normals = normals.detach() 
48 | 
49 | lambertian = (1 - shading_weight) + shading_weight * (normals @ light_direction).clamp(min=0) 
50 | shaded = lambertian.unsqueeze(-1).repeat(1, 3) 
51 | shaded_albedo = rgb * lambertian.unsqueeze(-1) 
52 | 
53 | return shaded, shaded_albedo 
54 | 
55 | 
56 | class NormalsShader(nn.Module): 
57 | """Calculate shading for normals.""" 
58 | 
59 | @classmethod 
60 | def forward( 
61 | cls, 
62 | normals: TensorType["bs":..., 3], 
63 | weights: Optional[TensorType["bs":..., 1]] = None, 
64 | ): 
65 | """Applies a rainbow colormap to the normals. 
66 | 
67 | Args: 
68 | normals: Normalized 3D vectors. 
69 | weights: Optional weights to scale the normal colors. (Can be used for masking) 
70 | 
71 | Returns: 
72 | Colored normals 
73 | """ 
74 | normals = (normals + 1) / 2 
75 | if weights is not None: 
76 | normals = normals * weights 
77 | return normals 
78 | 
-------------------------------------------------------------------------------- /nerfstudio/models/__init__.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
-------------------------------------------------------------------------------- /nerfstudio/pipelines/__init__.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/plugins/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Module that keeps all registered plugins and allows for plugin discovery. 17 | """ 18 | 19 | import sys 20 | import typing as t 21 | 22 | from rich.progress import Console 23 | 24 | from nerfstudio.engine.trainer import TrainerConfig 25 | from nerfstudio.plugins.types import MethodSpecification 26 | 27 | if sys.version_info < (3, 10): 28 | from importlib_metadata import entry_points 29 | else: 30 | from importlib.metadata import entry_points 31 | CONSOLE = Console(width=120) 32 | 33 | 34 | def discover_methods() -> t.Tuple[t.Dict[str, TrainerConfig], t.Dict[str, str]]: 35 | """ 36 | Discovers all methods registered using the `nerfstudio.method_configs` entrypoint. 
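    As an illustrative sketch (the package, module, and method names here are
    hypothetical), a plugin package exposes a method via its pyproject.toml:

        [project.entry-points.'nerfstudio.method_configs']
        my-method = 'my_package.my_config:my_method_spec'

    where `my_method_spec` is a `MethodSpecification` instance.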
37 | """ 38 | methods = {} 39 | descriptions = {} 40 | discovered_entry_points = entry_points(group="nerfstudio.method_configs") 41 | for name in discovered_entry_points.names: 42 | specification = discovered_entry_points[name].load() 43 | if not isinstance(specification, MethodSpecification): 44 | CONSOLE.print( 45 | "[bold yellow]Warning: Could not entry point {n} as it is not an instance of MethodSpecification" 46 | ) 47 | continue 48 | specification = t.cast(MethodSpecification, specification) 49 | methods[specification.config.method_name] = specification.config 50 | descriptions[specification.config.method_name] = specification.description 51 | return methods, descriptions 52 | -------------------------------------------------------------------------------- /nerfstudio/plugins/types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | This package contains specifications used to register plugins. 17 | """ 18 | from dataclasses import dataclass 19 | 20 | from nerfstudio.engine.trainer import TrainerConfig 21 | 22 | 23 | @dataclass 24 | class MethodSpecification: 25 | """ 26 | Method specification class used to register custom methods with Nerfstudio. 27 | The registered methods will be available in commands such as `ns-train` 28 | """ 29 | 30 | config: TrainerConfig 31 | """Trainer configuration""" 32 | description: str 33 | """Method description shown in `ns-train` help""" 34 | -------------------------------------------------------------------------------- /nerfstudio/process_data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/process_data/record3d_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """Helper functions for processing record3d data.""" 
16 | 
17 | import json 
18 | from pathlib import Path 
19 | from typing import List 
20 | 
21 | import numpy as np 
22 | from rich.console import Console 
23 | from scipy.spatial.transform import Rotation 
24 | 
25 | from nerfstudio.process_data.process_data_utils import CAMERA_MODELS 
26 | from nerfstudio.utils import io 
27 | 
28 | CONSOLE = Console(width=120) 
29 | 
30 | 
31 | def record3d_to_json(images_paths: List[Path], metadata_path: Path, output_dir: Path, indices: np.ndarray) -> int: 
32 | """Converts Record3D's metadata and image paths to a JSON file. 
33 | 
34 | Args: 
35 | images_paths: list of image paths. 
36 | metadata_path: Path to the Record3D metadata JSON file. 
37 | output_dir: Path to the output directory. 
38 | indices: Indices to sample the metadata_path. Should be the same length as images_paths. 
39 | 
40 | Returns: 
41 | The number of registered images. 
42 | """ 
43 | 
44 | assert len(images_paths) == len(indices) 
45 | 
46 | metadata_dict = io.load_from_json(metadata_path) 
47 | 
48 | poses_data = np.array(metadata_dict["poses"]) # (N, 3, 4) 
49 | # NB: Record3D / scipy use "scalar-last" format quaternions (x y z w) 
50 | # https://fzheng.me/2017/11/12/quaternion_conventions_en/ 
51 | camera_to_worlds = np.concatenate( 
52 | [Rotation.from_quat(poses_data[:, :4]).as_matrix(), poses_data[:, 4:, None]], 
53 | axis=-1, 
54 | ).astype(np.float32) 
55 | camera_to_worlds = camera_to_worlds[indices] 
56 | 
57 | homogeneous_coord = np.zeros_like(camera_to_worlds[..., :1, :]) 
58 | homogeneous_coord[..., :, 3] = 1 
59 | camera_to_worlds = np.concatenate([camera_to_worlds, homogeneous_coord], -2) 
60 | 
61 | frames = [] 
62 | for i, im_path in enumerate(images_paths): 
63 | c2w = camera_to_worlds[i] 
64 | frame = { 
65 | "file_path": im_path.as_posix(), 
66 | "transform_matrix": c2w.tolist(), 
67 | } 
68 | frames.append(frame) 
69 | 
70 | # Camera intrinsics 
71 | K = np.array(metadata_dict["K"]).reshape((3, 3)).T 
72 | focal_length = K[0, 0] 
73 | 
74 | H = metadata_dict["h"] 
75 | W = metadata_dict["w"] 
76 | 
77 | # TODO(akristoffersen): The metadata dict comes with principal points, 
78 | # but they caused errors in image coord indexing. Should update once that is fixed.
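    # As a concrete sketch of the fallback below (the resolution is hypothetical): 
    # a 1920x1440 capture gets (cx, cy) = (960.0, 720.0), the exact image center.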
79 | cx, cy = W / 2, H / 2 80 | 81 | out = { 82 | "fl_x": focal_length, 83 | "fl_y": focal_length, 84 | "cx": cx, 85 | "cy": cy, 86 | "w": W, 87 | "h": H, 88 | "camera_model": CAMERA_MODELS["perspective"].name, 89 | } 90 | 91 | out["frames"] = frames 92 | 93 | with open(output_dir / "transforms.json", "w", encoding="utf-8") as f: 94 | json.dump(out, f, indent=4) 95 | 96 | return len(frames) 97 | -------------------------------------------------------------------------------- /nerfstudio/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/nerfstudio/py.typed -------------------------------------------------------------------------------- /nerfstudio/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/utils/colors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Common Colors""" 16 | from typing import Union 17 | 18 | import torch 19 | from torchtyping import TensorType 20 | 21 | WHITE = torch.tensor([1.0, 1.0, 1.0]) 22 | BLACK = torch.tensor([0.0, 0.0, 0.0]) 23 | RED = torch.tensor([1.0, 0.0, 0.0]) 24 | GREEN = torch.tensor([0.0, 1.0, 0.0]) 25 | BLUE = torch.tensor([0.0, 0.0, 1.0]) 26 | 27 | COLORS_DICT = { 28 | "white": WHITE, 29 | "black": BLACK, 30 | "red": RED, 31 | "green": GREEN, 32 | "blue": BLUE, 33 | } 34 | 35 | 36 | def get_color(color: Union[str, list]) -> TensorType[3]: 37 | """ 38 | Args: 39 | color (Union[str, list]): Color as a string or a rgb list 40 | 41 | Returns: 42 | TensorType[3]: Parsed color 43 | """ 44 | if isinstance(color, str): 45 | color = color.lower() 46 | if color not in COLORS_DICT: 47 | raise ValueError(f"{color} is not a valid preset color") 48 | return COLORS_DICT[color] 49 | if isinstance(color, list): 50 | if len(color) != 3: 51 | raise ValueError(f"Color should be 3 values (RGB) instead got {color}") 52 | return torch.tensor(color) 53 | 54 | raise ValueError(f"Color should be an RGB list or string, instead got {type(color)}") 55 | -------------------------------------------------------------------------------- /nerfstudio/utils/comms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """functionality to handle multiprocessing syncing and communicating""" 16 | import torch.distributed as dist 17 | 18 | LOCAL_PROCESS_GROUP = None 19 | 20 | 21 | def is_dist_avail_and_initialized() -> bool: 22 | """Returns True if distributed is available and initialized.""" 23 | return dist.is_available() and dist.is_initialized() 24 | 25 | 26 | def get_world_size() -> int: 27 | """Get total number of available gpus""" 28 | if not is_dist_avail_and_initialized(): 29 | return 1 30 | return dist.get_world_size() 31 | 32 | 33 | def get_rank() -> int: 34 | """Get global rank of current thread""" 35 | if not is_dist_avail_and_initialized(): 36 | return 0 37 | return dist.get_rank() 38 | 39 | 40 | def get_local_rank() -> int: 41 | """The rank of the current process within the local (per-machine) process group.""" 42 | if not is_dist_avail_and_initialized(): 43 | return 0 44 | assert ( 45 | LOCAL_PROCESS_GROUP is not None 46 | ), "Local process group is not created! Please use launch() to spawn processes!" 47 | return dist.get_rank(group=LOCAL_PROCESS_GROUP) 48 | 49 | 50 | def get_local_size() -> int: 51 | """ 52 | The size of the per-machine process group, 53 | i.e. the number of processes per machine. 
54 | """ 55 | if not is_dist_avail_and_initialized(): 56 | return 1 57 | return dist.get_world_size(group=LOCAL_PROCESS_GROUP) 58 | 59 | 60 | def is_main_process() -> bool: 61 | """check to see if you are currently on the main process""" 62 | return get_rank() == 0 63 | 64 | 65 | def synchronize(): 66 | """ 67 | Helper function to synchronize (barrier) among all processes when 68 | using distributed training 69 | """ 70 | if dist.get_world_size() == 1: 71 | return 72 | if dist.get_backend() == dist.Backend.NCCL: 73 | # This argument is needed to avoid warnings. 74 | # It's valid only for NCCL backend. 75 | dist.barrier(device_ids=[get_local_rank()]) 76 | else: 77 | dist.barrier() 78 | -------------------------------------------------------------------------------- /nerfstudio/utils/decorators.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Decorator definitions 17 | """ 18 | from typing import Callable, List 19 | 20 | from nerfstudio.utils import comms 21 | 22 | 23 | def decorate_all(decorators: List[Callable]) -> Callable: 24 | """A decorator to decorate all member functions of a class 25 | 26 | Args: 27 | decorators: list of decorators to add to all functions in the class 28 | """ 29 | 30 | def decorate(cls): 31 | for attr in cls.__dict__: 32 | if callable(getattr(cls, attr)) and attr != "__init__": 33 | for decorator in decorators: 34 | setattr(cls, attr, decorator(getattr(cls, attr))) 35 | return cls 36 | 37 | return decorate 38 | 39 | 40 | def check_profiler_enabled(func: Callable) -> Callable: 41 | """Decorator: check if profiler is enabled""" 42 | 43 | def wrapper(self, *args, **kwargs): 44 | ret = None 45 | if self.config.enable_profiler: 46 | ret = func(self, *args, **kwargs) 47 | return ret 48 | 49 | return wrapper 50 | 51 | 52 | def check_viewer_enabled(func: Callable) -> Callable: 53 | """Decorator: check if viewer is enabled and only run on main process""" 54 | 55 | def wrapper(self, *args, **kwargs): 56 | ret = None 57 | if self.config.is_viewer_enabled() and comms.is_main_process(): 58 | ret = func(self, *args, **kwargs) 59 | return ret 60 | 61 | return wrapper 62 | 63 | 64 | def check_eval_enabled(func: Callable) -> Callable: 65 | """Decorator: check if evaluation step is enabled""" 66 | 67 | def wrapper(self, *args, **kwargs): 68 | ret = None 69 | if self.config.is_wandb_enabled() or self.config.is_tensorboard_enabled(): 70 | ret = func(self, *args, **kwargs) 71 | return ret 72 | 73 | return wrapper 74 | 75 | 76 | def check_main_thread(func: Callable) -> Callable: 77 | """Decorator: check if you are on main thread""" 78 | 79 | def wrapper(*args, **kwargs): 80 | ret = None 81 | if comms.is_main_process(): 82 | ret = func(*args, **kwargs) 83 | return ret 84 | 85 | return wrapper 86 | -------------------------------------------------------------------------------- 
/nerfstudio/utils/install_checks.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """Helpers for checking if programs are installed""" 
16 | 
17 | import shutil 
18 | import sys 
19 | 
20 | from rich.console import Console 
21 | 
22 | CONSOLE = Console(width=120) 
23 | 
24 | 
25 | def check_ffmpeg_installed(): 
26 | """Checks if ffmpeg is installed.""" 
27 | ffmpeg_path = shutil.which("ffmpeg") 
28 | if ffmpeg_path is None: 
29 | CONSOLE.print("[bold red]Could not find ffmpeg. Please install ffmpeg.") 
30 | print("See https://ffmpeg.org/download.html for installation instructions.") 
31 | sys.exit(1) 
32 | 
33 | 
34 | def check_colmap_installed(): 
35 | """Checks if colmap is installed.""" 
36 | colmap_path = shutil.which("colmap") 
37 | if colmap_path is None: 
38 | CONSOLE.print("[bold red]Could not find COLMAP. Please install COLMAP.") 
39 | print("See https://colmap.github.io/install.html for installation instructions.") 
40 | sys.exit(1) 
41 | 
42 | 
43 | def check_curl_installed(): 
44 | """Checks if curl is installed.""" 
45 | curl_path = shutil.which("curl") 
46 | if curl_path is None: 
47 | CONSOLE.print("[bold red]Could not find [yellow]curl[red]. Please install [yellow]curl") 
48 | sys.exit(1) 
49 | 
-------------------------------------------------------------------------------- /nerfstudio/utils/io.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """ 
16 | Input/output utils. 
17 | """ 
18 | 
19 | import json 
20 | from pathlib import Path 
21 | 
22 | 
23 | def load_from_json(filename: Path): 
24 | """Load a dictionary from a JSON filename. 
25 | 
26 | Args: 
27 | filename: The filename to load from. 
28 | """ 
29 | assert filename.suffix == ".json" 
30 | with open(filename, encoding="UTF-8") as file: 
31 | return json.load(file) 
32 | 
33 | 
34 | def write_to_json(filename: Path, content: dict): 
35 | """Write data to a JSON file. 
36 | 
37 | Args: 
38 | filename: The filename to write to. 
39 | content: The dictionary data to write.
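    Example (the filename and content here are hypothetical):

        write_to_json(Path("transforms.json"), {"fl_x": 1000.0})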
40 | """ 41 | assert filename.suffix == ".json" 42 | with open(filename, "w", encoding="UTF-8") as file: 43 | json.dump(content, file) 44 | -------------------------------------------------------------------------------- /nerfstudio/utils/poses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Common 3D pose methods 17 | """ 18 | 19 | import torch 20 | from torchtyping import TensorType 21 | 22 | 23 | def to4x4(pose: TensorType[..., 3, 4]) -> TensorType[..., 4, 4]: 24 | """Convert 3x4 pose matrices to a 4x4 with the addition of a homogeneous coordinate. 25 | 26 | Args: 27 | pose: Camera pose without homogenous coordinate. 28 | 29 | Returns: 30 | Camera poses with additional homogenous coordinate added. 31 | """ 32 | constants = torch.zeros_like(pose[..., :1, :], device=pose.device) 33 | constants[..., :, 3] = 1 34 | return torch.cat([pose, constants], dim=-2) 35 | 36 | 37 | def inverse(pose: TensorType[..., 3, 4]) -> TensorType[..., 3, 4]: 38 | """Invert provided pose matrix. 39 | 40 | Args: 41 | pose: Camera pose without homogenous coordinate. 42 | 43 | Returns: 44 | Inverse of pose. 45 | """ 46 | R = pose[..., :3, :3] 47 | t = pose[..., :3, 3:] 48 | R_inverse = R.transpose(-2, -1) # pylint: disable=invalid-name 49 | t_inverse = -R_inverse.matmul(t) 50 | return torch.cat([R_inverse, t_inverse], dim=-1) 51 | 52 | 53 | def multiply(pose_a: TensorType[..., 3, 4], pose_b: TensorType[..., 3, 4]) -> TensorType[..., 3, 4]: 54 | """Multiply two pose matrices, A @ B. 55 | 56 | Args: 57 | pose_a: Left pose matrix, usually a transformation applied to the right. 58 | pose_b: Right pose matrix, usually a camera pose that will be transformed by pose_a. 59 | 60 | Returns: 61 | Camera pose matrix where pose_a was applied to pose_b. 62 | """ 63 | R1, t1 = pose_a[..., :3, :3], pose_a[..., :3, 3:] 64 | R2, t2 = pose_b[..., :3, :3], pose_b[..., :3, 3:] 65 | R = R1.matmul(R2) 66 | t = t1 + R1.matmul(t2) 67 | return torch.cat([R, t], dim=-1) 68 | 69 | 70 | def normalize(poses: TensorType[..., 3, 4]) -> TensorType[..., 3, 4]: 71 | """Normalize the XYZs of poses to fit within a unit cube ([-1, 1]). Note: This operation is not in-place. 72 | 73 | Args: 74 | poses: A collection of poses to be normalized. 75 | 76 | Returns; 77 | Normalized collection of poses. 78 | """ 79 | pose_copy = torch.clone(poses) 80 | pose_copy[..., :3, 3] /= torch.max(torch.abs(poses[..., :3, 3])) 81 | 82 | return pose_copy 83 | -------------------------------------------------------------------------------- /nerfstudio/utils/printing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """A collection of common strings and print statements used throughout the codebase.""" 
16 | 
17 | from math import floor, log 
18 | 
19 | from rich.console import Console 
20 | 
21 | CONSOLE = Console(width=120) 
22 | 
23 | 
24 | def print_tcnn_speed_warning(method_name: str): 
25 | """Prints a warning about the speed of the TCNN.""" 
26 | CONSOLE.line() 
27 | CONSOLE.print(f"[bold yellow]WARNING: Using a slow implementation of {method_name}. ") 
28 | CONSOLE.print( 
29 | "[bold yellow]:person_running: :person_running: " 
30 | + "Install tcnn for speedups :person_running: :person_running:" 
31 | ) 
32 | CONSOLE.print("[yellow]pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch") 
33 | CONSOLE.line() 
34 | 
35 | 
36 | def human_format(num): 
37 | """Format a number in a more human readable way 
38 | 
39 | Args: 
40 | num: number to format 
41 | """ 
42 | units = ["", "K", "M", "B", "T", "P"] 
43 | k = 1000.0 
44 | # Guard against num == 0 (log is undefined) and clamp to the largest unit. 
45 | magnitude = max(0, min(len(units) - 1, int(floor(0 if num == 0 else log(num, k))))) 
46 | return f"{(num / k**magnitude):.2f} {units[magnitude]}" 
47 | 
-------------------------------------------------------------------------------- /nerfstudio/utils/profiler.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | # http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License.
14 | 15 | """ 16 | Profiler base class and functionality 17 | """ 18 | from __future__ import annotations 19 | 20 | import time 21 | from typing import Callable 22 | 23 | from rich.console import Console 24 | 25 | from nerfstudio.configs import base_config as cfg 26 | from nerfstudio.utils import comms 27 | from nerfstudio.utils.decorators import ( 28 | check_main_thread, 29 | check_profiler_enabled, 30 | decorate_all, 31 | ) 32 | 33 | CONSOLE = Console(width=120) 34 | 35 | PROFILER = [] 36 | 37 | 38 | def time_function(func: Callable) -> Callable: 39 | """Decorator: time a function call""" 40 | 41 | def wrapper(*args, **kwargs): 42 | start = time.time() 43 | ret = func(*args, **kwargs) 44 | if PROFILER: 45 | class_str = func.__qualname__ 46 | PROFILER[0].update_time(class_str, start, time.time()) 47 | return ret 48 | 49 | return wrapper 50 | 51 | 52 | def flush_profiler(config: cfg.LoggingConfig): 53 | """Method that checks if profiler is enabled before flushing""" 54 | if config.enable_profiler and PROFILER: 55 | PROFILER[0].print_profile() 56 | 57 | 58 | def setup_profiler(config: cfg.LoggingConfig): 59 | """Initialization of profilers""" 60 | if comms.is_main_process(): 61 | PROFILER.append(Profiler(config)) 62 | 63 | 64 | @decorate_all([check_profiler_enabled, check_main_thread]) 65 | class Profiler: 66 | """Profiler class""" 67 | 68 | def __init__(self, config: cfg.LoggingConfig): 69 | self.config = config 70 | self.profiler_dict = {} 71 | 72 | def update_time(self, func_name: str, start_time: float, end_time: float): 73 | """update the profiler dictionary with running averages of durations 74 | 75 | Args: 76 | func_name: the function name that is being profiled 77 | start_time: the start time when function is called 78 | end_time: the end time when function terminated 79 | """ 80 | val = end_time - start_time 81 | func_dict = self.profiler_dict.get(func_name, {"val": 0, "step": 0}) 82 | prev_val = func_dict["val"] 83 | prev_step = func_dict["step"] 84 | self.profiler_dict[func_name] = {"val": (prev_val * prev_step + val) / (prev_step + 1), "step": prev_step + 1} 85 | 86 | def print_profile(self): 87 | """helper to print out the profiler stats""" 88 | CONSOLE.print("Printing profiling stats, from longest to shortest duration in seconds") 89 | sorted_keys = sorted( 90 | self.profiler_dict.keys(), 91 | key=lambda k: self.profiler_dict[k]["val"], 92 | reverse=True, 93 | ) 94 | for k in sorted_keys: 95 | val = f"{self.profiler_dict[k]['val']:0.4f}" 96 | CONSOLE.print(f"{k:<20}: {val:<20}") 97 | -------------------------------------------------------------------------------- /nerfstudio/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Additional rich ui components""" 16 | 17 | from contextlib import nullcontext 18 | from typing import Optional 19 | 20 | from rich.console import Console 21 | from rich.progress import ( 22 | BarColumn, 23 | Progress, 24 | ProgressColumn, 25 | TaskProgressColumn, 26 | TextColumn, 27 | TimeRemainingColumn, 28 | ) 29 | from rich.text import Text 30 | 31 | CONSOLE = Console(width=120) 32 | 33 | 34 | class ItersPerSecColumn(ProgressColumn): 35 | """Renders the iterations per second for a progress bar.""" 36 | 37 | def __init__(self, suffix="it/s") -> None: 38 | super().__init__() 39 | self.suffix = suffix 40 | 41 | def render(self, task: "Task") -> Text: 42 | """Show data transfer speed.""" 43 | speed = task.finished_speed or task.speed 44 | if speed is None: 45 | return Text("?", style="progress.data.speed") 46 | return Text(f"{speed:.2f} {self.suffix}", style="progress.data.speed") 47 | 48 | 49 | def status(msg: str, spinner: str = "bouncingBall", verbose: bool = False): 50 | """A context manager that does nothing is verbose is True. Otherwise it hides logs under a message. 51 | 52 | Args: 53 | msg: The message to log. 54 | spinner: The spinner to use. 55 | verbose: If True, print all logs, else hide them. 56 | """ 57 | if verbose: 58 | return nullcontext() 59 | return CONSOLE.status(msg, spinner=spinner) 60 | 61 | 62 | def get_progress(description: str, suffix: Optional[str] = None): 63 | """Helper function to return a rich Progress object.""" 64 | progress_list = [TextColumn(description), BarColumn(), TaskProgressColumn(show_speed=True)] 65 | progress_list += [ItersPerSecColumn(suffix=suffix)] if suffix else [] 66 | progress_list += [TimeRemainingColumn(elapsed_when_finished=True, compact=True)] 67 | progress = Progress(*progress_list) 68 | return progress 69 | -------------------------------------------------------------------------------- /nerfstudio/utils/scripts.py: -------------------------------------------------------------------------------- 1 | """Helpers for running script commands.""" 2 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import subprocess 17 | import sys 18 | from typing import Optional 19 | 20 | from rich.console import Console 21 | 22 | CONSOLE = Console(width=120) 23 | 24 | 25 | def run_command(cmd: str, verbose=False) -> Optional[str]: 26 | """Runs a command and returns the output. 27 | 28 | Args: 29 | cmd: Command to run. 30 | verbose: If True, logs the output of the command. 31 | Returns: 32 | The output of the command if return_output is True, otherwise None. 
33 | """ 34 | out = subprocess.run(cmd, capture_output=not verbose, shell=True, check=False) 35 | if out.returncode != 0: 36 | CONSOLE.rule("[bold red] :skull: :skull: :skull: ERROR :skull: :skull: :skull: ", style="red") 37 | CONSOLE.print(f"[bold red]Error running command: {cmd}") 38 | CONSOLE.rule(style="red") 39 | CONSOLE.print(out.stderr.decode("utf-8")) 40 | sys.exit(1) 41 | if out.stdout is not None: 42 | return out.stdout.decode("utf-8") 43 | return out 44 | -------------------------------------------------------------------------------- /nerfstudio/viewer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/.env.development: -------------------------------------------------------------------------------- 1 | BROWSER=none 2 | FAST_REFRESH=false 3 | HOST=0.0.0.0 4 | PORT=4000 5 | ESLINT_NO_DEV_ERRORS=true -------------------------------------------------------------------------------- /nerfstudio/viewer/app/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "es2021": true 5 | }, 6 | "extends": [ 7 | "eslint:recommended", 8 | "plugin:react/recommended", 9 | "airbnb", 10 | "prettier" 11 | ], 12 | "parser": "@typescript-eslint/parser", 13 | "parserOptions": { 14 | "ecmaFeatures": { 15 | "jsx": true 16 | }, 17 | "ecmaVersion": "latest", 18 | "sourceType": "module" 19 | }, 20 | "plugins": ["react", "@typescript-eslint", "unused-imports"], 21 | "rules": { 22 | "arrow-body-style": "off", 23 | "camelcase": "off", 24 | "import/prefer-default-export": "off", 25 | "no-alert": "off", 26 | "no-console": "off", 27 | "prefer-destructuring": "off", 28 | "react/destructuring-assignment": "off", 29 | "react/prop-types": 0, 30 | "unused-imports/no-unused-imports-ts": 2 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | node_modules/ 4 | build/ 5 | .DS_Store 6 | *.tgz 7 | my-app* 8 | template/src/__tests__/__snapshots__/ 9 | lerna-debug.log 10 | npm-debug.log* 11 | yarn-debug.log* 12 | yarn-error.log* 13 | /.changelog 14 | .npm/ -------------------------------------------------------------------------------- /nerfstudio/viewer/app/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "viewer", 3 | "homepage": ".", 4 | "version": "23-02-3-0", 5 | "private": true, 6 | "dependencies": { 7 | "@emotion/react": "^11.10.4", 8 | "@emotion/styled": "^11.10.4", 9 | "@mui/icons-material": "^5.10.3", 10 | "@mui/lab": "^5.0.0-alpha.98", 11 | "@mui/material": "^5.10.3", 12 | "@mui/system": 
"^5.10.6", 13 | "@mui/x-date-pickers": "^5.0.0", 14 | "@reduxjs/toolkit": "^1.8.3", 15 | "@testing-library/jest-dom": "^5.16.4", 16 | "@testing-library/react": "^13.3.0", 17 | "@testing-library/user-event": "^14.2.0", 18 | "camera-controls": "^1.37.2", 19 | "classnames": "^2.3.1", 20 | "dat.gui": "^0.7.9", 21 | "dayjs": "^1.11.5", 22 | "eslint-config-prettier": "^8.5.0", 23 | "eslint-plugin-unused-imports": "^2.0.0", 24 | "leva": "^0.9.29", 25 | "meshline": "^2.0.4", 26 | "msgpack-lite": "^0.1.26", 27 | "prop-types": "^15.8.1", 28 | "re-resizable": "^6.9.9", 29 | "react": "^18.1.0", 30 | "react-dom": "^18.1.0", 31 | "react-icons": "^4.4.0", 32 | "react-pro-sidebar": "^0.7.1", 33 | "react-redux": "^8.0.2", 34 | "redux": "^4.2.0", 35 | "sass": "^1.54.8", 36 | "socket.io-client": "^4.5.1", 37 | "three": "^0.142.0", 38 | "three-wtm": "^1.0", 39 | "websocket": "^1.0.34", 40 | "wwobjloader2": "^4.0" 41 | }, 42 | "scripts": { 43 | "start": "react-scripts start", 44 | "build": "react-scripts build", 45 | "test": "react-scripts test", 46 | "eject": "react-scripts eject", 47 | "electron": "electron .", 48 | "lint": "eslint --ext .js,.jsx .", 49 | "lint:fix": "eslint --fix --ext .js,.jsx ." 50 | }, 51 | "eslintConfig": { 52 | "extends": "react-app" 53 | }, 54 | "browserslist": { 55 | "production": [ 56 | ">0.2%", 57 | "not dead", 58 | "not op_mini all" 59 | ], 60 | "development": [ 61 | "last 1 chrome version", 62 | "last 1 firefox version", 63 | "last 1 safari version" 64 | ] 65 | }, 66 | "main": "public/electron.js", 67 | "author": "", 68 | "license": "ISC", 69 | "description": "", 70 | "devDependencies": { 71 | "concurrently": "^7.2.1", 72 | "eslint": "^8.2.0", 73 | "eslint-config-airbnb": "19.0.4", 74 | "eslint-plugin-import": "^2.25.3", 75 | "eslint-plugin-jsx-a11y": "^6.5.1", 76 | "eslint-plugin-react": "^7.28.0", 77 | "eslint-plugin-react-hooks": "^4.3.0", 78 | "prettier": "2.7.1", 79 | "react-scripts": "^5.0.1", 80 | "typescript": "^4.7.3", 81 | "wait-on": "^6.0.1" 82 | }, 83 | "resolutions": { 84 | "nth-check": "^2.0.1" 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/public/electron.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | 3 | const { app, BrowserWindow } = require('electron'); 4 | const isDev = require('electron-is-dev'); 5 | 6 | function createWindow() { 7 | // Create the browser window. 8 | const win = new BrowserWindow({ 9 | width: 800, 10 | height: 600, 11 | webPreferences: { 12 | nodeIntegration: true, 13 | }, 14 | }); 15 | 16 | // and load the index.html of the app. 17 | // win.loadFile("index.html"); 18 | win.loadURL( 19 | isDev 20 | ? 'http://localhost:3000' 21 | : `file://${path.join(__dirname, '../build/index.html')}`, 22 | ); 23 | // Open the DevTools. 24 | if (isDev) { 25 | win.webContents.openDevTools({ mode: 'detach' }); 26 | } 27 | } 28 | 29 | // This method will be called when Electron has finished 30 | // initialization and is ready to create browser windows. 31 | // Some APIs can only be used after this event occurs. 32 | app.whenReady().then(createWindow); 33 | 34 | // Quit when all windows are closed, except on macOS. There, it's common 35 | // for applications and their menu bar to stay active until the user quits 36 | // explicitly with Cmd + Q. 
37 | app.on('window-all-closed', () => { 38 | if (process.platform !== 'darwin') { 39 | app.quit(); 40 | } 41 | }); 42 | 43 | app.on('activate', () => { 44 | if (BrowserWindow.getAllWindows().length === 0) { 45 | createWindow(); 46 | } 47 | }); 48 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/public/index.html: -------------------------------------------------------------------------------- 1-79 | [HTML markup stripped during extraction; only the page title "nerfstudio" is recoverable.] -------------------------------------------------------------------------------- /nerfstudio/viewer/app/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "nerfstudio viewer", 3 | "name": "Interactive NeRF viewer", 4 | "start_url": ".", 5 | "display": "standalone", 6 | "theme_color": "#000000", 7 | "background_color": "#ffffff" 8 | } 9 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/requirements.txt: -------------------------------------------------------------------------------- 1 | tyro>=0.3.22 2 | sshconf==0.2.5 3 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/App.jsx: -------------------------------------------------------------------------------- 1 | import { CssBaseline, ThemeProvider } from '@mui/material'; 2 | import React from 'react'; 3 | import { 4 | SceneTreeWebSocketListener, 5 | get_scene_tree, 6 | } from './modules/Scene/Scene'; 7 | 8 | import Banner from './modules/Banner'; 9 | import { BasicTabs } from './modules/SidePanel/SidePanel'; 10 | import ViewerWindow from './modules/ViewerWindow/ViewerWindow'; 11 | import { appTheme } from './themes/theme.ts'; 12 | 13 | export default function App() { 14 | // The scene tree won't rerender but it will listen to changes 15 | // from the redux store and draw three.js objects. 16 | // In particular, it listens to changes to 'sceneState' coming over the websocket. 17 | const sceneTree = get_scene_tree(); 18 | 19 | return ( 20-35 | [JSX markup stripped during extraction; surviving comments: "Listens for websocket 'write' messages and updates the redux store.", "The banner at the top of the page.", "Order matters here. The viewer window must be rendered first."]
36 | ); 37 | } 38 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/index.jsx: -------------------------------------------------------------------------------- 1 | import './index.scss'; 2 | import React from 'react'; 3 | import ReactDOM from 'react-dom'; 4 | import { Provider } from 'react-redux'; 5 | import App from './App'; 6 | import WebSocketProvider from './modules/WebSocket/WebSocket'; 7 | import store from './store'; 8 | 9 | const root = ReactDOM.createRoot(document.getElementById('root')); 10 | root.render( 11 | 12 | 13 | 14 | 15 | , 16 | ); 17 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/Banner/Banner.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { useDispatch } from 'react-redux'; 3 | 4 | import Button from '@mui/material/Button'; 5 | import GitHubIcon from '@mui/icons-material/GitHub'; 6 | import DescriptionRoundedIcon from '@mui/icons-material/DescriptionRounded'; 7 | import LandingModal from '../LandingModal'; 8 | import ViewportControlsModal from '../ViewportControlsModal'; 9 | 10 | function getParam(param_name) { 11 | // https://stackoverflow.com/questions/831030/how-to-get-get-request-parameters-in-javascript 12 | const params = new RegExp( 13 | `[?&]${encodeURIComponent(param_name)}=([^&]*)`, 14 | ).exec(window.location.href); 15 | if (params === null) { 16 | return undefined; 17 | } 18 | return decodeURIComponent(params[1]); 19 | } 20 | 21 | export default function Banner() { 22 | const dispatch = useDispatch(); 23 | 24 | let open_modal = true; 25 | 26 | // possibly set the websocket url 27 | const websocket_url_from_argument = getParam('websocket_url'); 28 | if (websocket_url_from_argument !== undefined) { 29 | open_modal = false; 30 | dispatch({ 31 | type: 'write', 32 | path: 'websocketState/websocket_url', 33 | data: websocket_url_from_argument, 34 | }); 35 | } 36 | 37 | return ( 38 |
39-69 | [JSX markup stripped during extraction; surviving text: image alt "The favicon.".]
70 | ); 71 | } 72 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/Banner/index.jsx: -------------------------------------------------------------------------------- 1 | import Banner from './Banner'; 2 | 3 | export default Banner; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/ConfigPanel/ConfigPanelSlice.js: -------------------------------------------------------------------------------- 1 | // The function below is called a selector and allows us to select a value from 2 | // the state. Selectors can also be defined inline where they're used instead of 3 | // in the slice file. For example: `useSelector((state) => state.counter.value)` 4 | export const selectTrainingState = (state) => 5 | state.shared.rendering.training_state; 6 | export const selectOutputOptions = (state) => state.shared.output_options; 7 | export const selectColormapOptions = (state) => state.shared.colormap_options; 8 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/LandingModal/index.jsx: -------------------------------------------------------------------------------- 1 | import LandingModel from './LandingModal'; 2 | 3 | export default LandingModel; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/LoadPathModal/index.jsx: -------------------------------------------------------------------------------- 1 | import LoadPathModal from './LoadPathModal'; 2 | 3 | export default LoadPathModal; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/LogPanel/LogPanel.jsx: -------------------------------------------------------------------------------- 1 | import { useContext, useEffect } from 'react'; 2 | import { useDispatch, useSelector } from 'react-redux'; 3 | 4 | import { WebSocketContext } from '../WebSocket/WebSocket'; 5 | 6 | const msgpack = require('msgpack-lite'); 7 | 8 | export function LogPanel() { 9 | const websocket = useContext(WebSocketContext).socket; 10 | const dispatch = useDispatch(); 11 | const gpu_oom_error_msg = 'GPU out of memory'; 12 | const resolved_msg = 'resolved'; 13 | let local_error = resolved_msg; 14 | // connection status indicators 15 | 16 | const set_max_train_util = () => { 17 | if (websocket.readyState === WebSocket.OPEN) { 18 | dispatch({ 19 | type: 'write', 20 | path: 'renderingState/targetTrainUtil', 21 | data: 0.9, 22 | }); 23 | const cmd = 'write'; 24 | const path = 'renderingState/targetTrainUtil'; 25 | const data = { 26 | type: cmd, 27 | path, 28 | data: 0.9, 29 | }; 30 | const message = msgpack.encode(data); 31 | websocket.send(message); 32 | } 33 | }; 34 | 35 | const set_small_resolution = () => { 36 | if (websocket.readyState === WebSocket.OPEN) { 37 | dispatch({ 38 | type: 'write', 39 | path: 'renderingState/maxResolution', 40 | data: 512, 41 | }); 42 | const cmd = 'write'; 43 | const path = 'renderingState/maxResolution'; 44 | const data = { 45 | type: cmd, 46 | path, 47 | data: 512, 48 | }; 49 | const message = msgpack.encode(data); 50 | websocket.send(message); 51 | } 52 | }; 53 | 54 | const set_log_message = () => { 55 | if (websocket.readyState === WebSocket.OPEN) { 56 | dispatch({ 57 | type: 'write', 58 | path: 'renderingState/log_errors', 59 | data: resolved_msg, 60 | }); 61 | const cmd = 'write'; 62 | const path = 
'renderingState/log_errors'; 63 | const data = { 64 | type: cmd, 65 | path, 66 | data: resolved_msg, 67 | }; 68 | const message = msgpack.encode(data); 69 | websocket.send(message); 70 | } 71 | }; 72 | 73 | const check_error = useSelector((state) => { 74 | local_error = state.renderingState.log_errors; 75 | if (local_error.includes(gpu_oom_error_msg)) { 76 | console.log(local_error); 77 | set_log_message(); 78 | set_small_resolution(); 79 | set_max_train_util(); 80 | } 81 | }); 82 | 83 | useEffect(() => {}, [check_error, local_error]); 84 | 85 | return null; 86 | } 87 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/RenderModal/RenderModal.jsx: -------------------------------------------------------------------------------- 1 | /* eslint-disable react/jsx-props-no-spreading */ 2 | import * as React from 'react'; 3 | 4 | import { Box, Button, Modal, Typography } from '@mui/material'; 5 | import { useSelector } from 'react-redux'; 6 | import ContentCopyRoundedIcon from '@mui/icons-material/ContentCopyRounded'; 7 | 8 | interface RenderModalProps { 9 | open: object; 10 | setOpen: object; 11 | } 12 | 13 | export default function RenderModal(props: RenderModalProps) { 14 | const open = props.open; 15 | const setOpen = props.setOpen; 16 | 17 | // redux store state 18 | const config_base_dir = useSelector( 19 | (state) => state.renderingState.config_base_dir, 20 | ); 21 | 22 | const export_path = useSelector((state) => state.renderingState.export_path); 23 | 24 | const data_base_dir = useSelector( 25 | (state) => state.renderingState.data_base_dir, 26 | ); 27 | 28 | // react state 29 | 30 | const handleClose = () => setOpen(false); 31 | 32 | // Copy the text inside the text field 33 | const config_filename = `${config_base_dir}/config.yml`; 34 | const camera_path_filename = `${export_path}.json`; 35 | const data_base_dir_leaf = data_base_dir.split('/').pop(); 36 | const cmd = `ns-render --load-config ${config_filename} --traj filename --camera-path-filename ${data_base_dir}/camera_paths/${camera_path_filename} --output-path renders/${data_base_dir_leaf}/${export_path}.mp4`; 37 | 38 | const text_intro = `To render a full resolution video, run the following command in a terminal.`; 39 | 40 | const handleCopy = () => { 41 | navigator.clipboard.writeText(cmd); 42 | handleClose(); 43 | }; 44 | 45 | return ( 46 |
47-87 | [JSX markup stripped during extraction; surviving modal text: "Rendering"; {text_intro}; "The video will be saved to ./renders/{data_base_dir_leaf}/{export_path}.mp4 ."; {cmd}]
88 | ); 89 | } 90 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/RenderModal/index.jsx: -------------------------------------------------------------------------------- 1 | import RenderModal from './RenderModal'; 2 | 3 | export default RenderModal; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/SidePanel/CameraPanel/CameraPropPanel.jsx: -------------------------------------------------------------------------------- 1 | import { useControls, useStoreContext } from 'leva'; 2 | import { useDispatch, useSelector } from 'react-redux'; 3 | 4 | export default function CameraPropPanel(props) { 5 | const seconds = props.seconds; 6 | const set_seconds = props.set_seconds; 7 | const fps = props.fps; 8 | const set_fps = props.set_fps; 9 | 10 | // redux store state 11 | const store = useStoreContext(); 12 | 13 | const dispatch = useDispatch(); 14 | 15 | // redux store state 16 | const render_height = useSelector( 17 | (state) => state.renderingState.render_height, 18 | ); 19 | const render_width = useSelector( 20 | (state) => state.renderingState.render_width, 21 | ); 22 | const camera_type = useSelector((state) => state.renderingState.camera_type); 23 | 24 | const export_path = useSelector((state) => state.renderingState.export_path); 25 | 26 | const setExportPath = (value) => { 27 | dispatch({ 28 | type: 'write', 29 | path: 'renderingState/export_path', 30 | data: value, 31 | }); 32 | }; 33 | 34 | const setResolution = (value) => { 35 | dispatch({ 36 | type: 'write', 37 | path: 'renderingState/render_width', 38 | data: value.width, 39 | }); 40 | dispatch({ 41 | type: 'write', 42 | path: 'renderingState/render_height', 43 | data: value.height, 44 | }); 45 | }; 46 | 47 | const setCameraType = (value) => { 48 | dispatch({ 49 | type: 'write', 50 | path: 'renderingState/camera_type', 51 | data: value, 52 | }); 53 | }; 54 | 55 | const [, setControls] = useControls( 56 | () => ({ 57 | path: { 58 | label: 'Export Name', 59 | value: export_path, 60 | onChange: (v) => { 61 | const valid_filename_reg = /^([a-z]|[A-Z]|[0-9]|-|_)+$/g; 62 | if(!valid_filename_reg.test(v)){ 63 | alert("Please only use letters, numbers, and hyphens"); 64 | } 65 | else { 66 | setExportPath(v); 67 | } 68 | }, 69 | 70 | }, 71 | camera_resolution: { 72 | label: 'Resolution', 73 | value: { width: render_width, height: render_height }, 74 | joystick: false, 75 | onChange: (v) => { 76 | setResolution(v); 77 | }, 78 | }, 79 | video_duration: { 80 | label: 'Duration (Sec)', 81 | value: seconds, 82 | min: 0.1, 83 | step: 0.1, 84 | onChange: (v) => { 85 | set_seconds(v); 86 | }, 87 | }, 88 | video_fps: { 89 | label: 'Framerate (FPS)', 90 | value: fps, 91 | min: 0.1, 92 | onChange: (v) => { 93 | set_fps(v); 94 | }, 95 | }, 96 | camera_type_selector: { 97 | label: 'Camera Type', 98 | value: camera_type, 99 | options: { 100 | Perspective: 'perspective', 101 | Fisheye: 'fisheye', 102 | Equirectangular: 'equirectangular', 103 | }, 104 | onChange: (v) => { 105 | setCameraType(v); 106 | }, 107 | }, 108 | }), 109 | { store }, 110 | ); 111 | 112 | setControls({path: export_path}); 113 | setControls({ video_fps: fps }); 114 | setControls({ video_duration: seconds }); 115 | setControls({ 116 | camera_resolution: { width: render_width, height: render_height }, 117 | }); 118 | setControls({ camera_type_selector: camera_type }); 119 | 120 | return null; 121 | } 122 | 
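Note on the two components above: the export name, resolution, duration, and framerate collected by CameraPropPanel are the same redux values that RenderModal reads when it assembles its `ns-render` command string. As a sketch of what that template expands to (every path and name below is a hypothetical example, not taken from this repo):

```bash
# Hypothetical values: config_base_dir=outputs/poster/nerfacto/2023-01-01_000000,
# data_base_dir=data/nerfstudio/poster, export_path=my-render
ns-render --load-config outputs/poster/nerfacto/2023-01-01_000000/config.yml \
  --traj filename \
  --camera-path-filename data/nerfstudio/poster/camera_paths/my-render.json \
  --output-path renders/poster/my-render.mp4
```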
-------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/SidePanel/CameraPanel/curve.js: -------------------------------------------------------------------------------- 1 | // Code for creating a curve from a set of points 2 | 3 | import * as THREE from 'three'; 4 | 5 | function get_catmull_rom_curve(list_of_3d_vectors, is_cycle, smoothness_value) { 6 | // TODO: add some hyperparameters to this function 7 | const curve = new THREE.CatmullRomCurve3( 8 | list_of_3d_vectors, 9 | is_cycle, 10 | // 'centripetal' 11 | 'catmullrom', 12 | smoothness_value, 13 | ); 14 | return curve; 15 | } 16 | 17 | export function get_curve_object_from_cameras( 18 | cameras, 19 | is_cycle, 20 | smoothness_value, 21 | ) { 22 | if (cameras.length === 0) { 23 | return null; 24 | } 25 | // interpolate positions, lookat directions, and ups 26 | // similar to 27 | // https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L281 28 | 29 | const positions = []; 30 | const lookats = []; 31 | const ups = []; 32 | const fovs = []; 33 | const render_times = []; 34 | 35 | for (let i = 0; i < cameras.length; i += 1) { 36 | const camera = cameras[i]; 37 | 38 | const up = new THREE.Vector3(0, 1, 0); // y is up in local space 39 | const lookat = new THREE.Vector3(0, 0, 1); // z is forward in local space 40 | 41 | up.applyQuaternion(camera.quaternion); 42 | lookat.applyQuaternion(camera.quaternion); 43 | 44 | positions.push(camera.position); 45 | ups.push(up); 46 | lookats.push(lookat); 47 | // Reuse catmullromcurve3 for 1d values. TODO fix this 48 | fovs.push(new THREE.Vector3(0, 0, camera.fov)); 49 | render_times.push(new THREE.Vector3(0, 0, camera.renderTime)); 50 | } 51 | 52 | let curve_positions = null; 53 | let curve_lookats = null; 54 | let curve_ups = null; 55 | let curve_fovs = null; 56 | let curve_render_times = null; 57 | 58 | curve_positions = get_catmull_rom_curve(positions, is_cycle, smoothness_value); 59 | curve_lookats = get_catmull_rom_curve(lookats, is_cycle, smoothness_value); 60 | curve_ups = get_catmull_rom_curve(ups, is_cycle, smoothness_value); 61 | curve_fovs = get_catmull_rom_curve(fovs, is_cycle, smoothness_value / 10); 62 | curve_render_times = get_catmull_rom_curve(render_times, is_cycle, smoothness_value); 63 | 64 | const curve_object = { 65 | curve_positions, 66 | curve_lookats, 67 | curve_ups, 68 | curve_fovs, 69 | curve_render_times, 70 | }; 71 | return curve_object; 72 | } 73 | 74 | export function get_transform_matrix(position, lookat, up) { 75 | // normalize the vectors 76 | lookat.normalize(); 77 | // make up orthogonal to lookat 78 | const up_proj = lookat.clone().multiplyScalar(up.dot(lookat)); 79 | up.sub(up_proj); 80 | up.normalize(); 81 | 82 | // create a copy of the vector up 83 | const up_copy = up.clone(); 84 | const cross = up_copy.cross(lookat); 85 | cross.normalize(); 86 | 87 | // create the camera transform matrix 88 | const mat = new THREE.Matrix4(); 89 | mat.set( 90 | cross.x, 91 | up.x, 92 | lookat.x, 93 | position.x, 94 | cross.y, 95 | up.y, 96 | lookat.y, 97 | position.y, 98 | cross.z, 99 | up.z, 100 | lookat.z, 101 | position.z, 102 | 0, 103 | 0, 104 | 0, 105 | 1, 106 | ); 107 | return mat; 108 | } 109 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/SidePanel/CameraPanel/index.jsx: -------------------------------------------------------------------------------- 1 | import CameraPanel from 
'./CameraPanel'; 2 | 3 | export default CameraPanel; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/SidePanel/ExportPanel/index.jsx: -------------------------------------------------------------------------------- 1 | import ExportPanel from './ExportPanel'; 2 | 3 | export default ExportPanel; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/SidePanel/ScenePanel/index.jsx: -------------------------------------------------------------------------------- 1 | import ScenePanel from './ScenePanel'; 2 | 3 | export default ScenePanel; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/SidePanel/StatusPanel/index.jsx: -------------------------------------------------------------------------------- 1 | import StatusPanel from './StatusPanel'; 2 | 3 | export default StatusPanel; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/ViewerWindow/ViewerWindowSlice.js: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/nerfstudio/viewer/app/src/modules/ViewerWindow/ViewerWindowSlice.js -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/ViewportControlsModal/ViewportControlsModal.jsx: -------------------------------------------------------------------------------- 1 | /* eslint-disable react/jsx-props-no-spreading */ 2 | import * as React from 'react'; 3 | 4 | import { Box, Button, Modal } from '@mui/material'; 5 | import KeyboardIcon from '@mui/icons-material/Keyboard'; 6 | 7 | export default function ControlsModal() { 8 | const [open, setOpen] = React.useState(false); 9 | const handleOpen = () => setOpen(true); 10 | const handleClose = () => setOpen(false); 11 | 12 | return ( 13 |
14-44 | [JSX markup stripped during extraction; surviving text: image alts "The favicon." and "Controls diagram".]
45 | ); 46 | } 47 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/ViewportControlsModal/index.jsx: -------------------------------------------------------------------------------- 1 | import ViewportControlsModal from './ViewportControlsModal'; 2 | 3 | export default ViewportControlsModal; 4 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/WebSocket/WebSocket.jsx: -------------------------------------------------------------------------------- 1 | // Much of this code comes from or is inspired by: 2 | // https://www.pluralsight.com/guides/using-web-sockets-in-your-reactredux-app 3 | 4 | import React, { createContext, useEffect } from 'react'; 5 | import { useDispatch, useSelector } from 'react-redux'; 6 | 7 | import PropTypes from 'prop-types'; 8 | 9 | const WebSocketContext = createContext(null); 10 | 11 | export { WebSocketContext }; 12 | 13 | export default function WebSocketContextFunction({ children }) { 14 | const dispatch = useDispatch(); 15 | let ws = null; 16 | let socket = null; 17 | 18 | // this code will rerender anytime the websocket url changes 19 | const websocket_url = useSelector( 20 | (state) => state.websocketState.websocket_url, 21 | ); 22 | 23 | const connect = () => { 24 | // of the form wss://ip_address:port 25 | console.log(websocket_url); 26 | try { 27 | socket = new WebSocket(websocket_url); 28 | } catch (error) { 29 | socket = new WebSocket('ws://localhost:7007'); 30 | } 31 | socket.binaryType = 'arraybuffer'; 32 | socket.onopen = () => { 33 | dispatch({ 34 | type: 'write', 35 | path: 'websocketState/isConnected', 36 | data: true, 37 | }); 38 | }; 39 | 40 | socket.onclose = () => { 41 | // when closed, the websocket will try to reconnect every second 42 | dispatch({ 43 | type: 'write', 44 | path: 'websocketState/isConnected', 45 | data: false, 46 | }); 47 | }; 48 | 49 | socket.onerror = (err) => { 50 | console.error( 51 | 'Socket encountered error: ', 52 | err.message, 53 | 'Closing socket', 54 | ); 55 | socket.close(); 56 | }; 57 | return socket; 58 | }; 59 | 60 | useEffect(() => { 61 | // cleanup function to close the websocket on rerender 62 | return () => { 63 | if (socket !== null) { 64 | socket.close(); 65 | } 66 | }; 67 | }, [websocket_url]); 68 | 69 | connect(); 70 | ws = { 71 | socket, 72 | }; 73 | 74 | return ( 75 | {children} 76 | ); 77 | } 78 | 79 | WebSocketContextFunction.propTypes = { 80 | children: PropTypes.node.isRequired, 81 | }; 82 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/modules/WebSocketUrlField.jsx: -------------------------------------------------------------------------------- 1 | import * as React from 'react'; 2 | 3 | import { TextField, Link } from '@mui/material'; 4 | import { useDispatch, useSelector } from 'react-redux'; 5 | 6 | export default function WebSocketUrlField() { 7 | const dispatch = useDispatch(); 8 | 9 | // websocket url 10 | const websocket_url = useSelector( 11 | (state) => state.websocketState.websocket_url, 12 | ); 13 | const websocket_url_onchange = (event) => { 14 | const value = event.target.value; 15 | dispatch({ 16 | type: 'write', 17 | path: 'websocketState/websocket_url', 18 | data: value, 19 | }); 20 | }; 21 | 22 | const testWebSocket = (url) => { 23 | try { 24 | // eslint-disable-next-line no-new 25 | new WebSocket(url); 26 | return false; 27 | } catch (error) { 28 | return true; 29 | } 30 | }; 31 | 32 
| return ( 33-47 | [JSX markup stripped during extraction; surviving link text: viewer.nerf.studio?websocket_url={websocket_url}.]
48 | ); 49 | } 50 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/setupTests.js: -------------------------------------------------------------------------------- 1 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 2 | // allows you to do things like: 3 | // expect(element).toHaveTextContent(/react/i) 4 | // learn more: https://github.com/testing-library/jest-dom 5 | import '@testing-library/jest-dom/extend-expect'; 6 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/store.js: -------------------------------------------------------------------------------- 1 | import { configureStore } from '@reduxjs/toolkit'; 2 | import rootReducer from './reducer'; 3 | 4 | export default configureStore({ reducer: rootReducer }); 5 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/subscriber.js: -------------------------------------------------------------------------------- 1 | import { useContext } from 'react'; 2 | import { ReactReduxContext } from 'react-redux'; 3 | 4 | export function subscribe_to_changes(selector_fn, fn) { 5 | // selector_fn: returns a value from the redux state 6 | // fn_valid: function to run on a valid input 7 | // fn_null: function to run on a null input 8 | const { store } = useContext(ReactReduxContext); 9 | 10 | let current; 11 | const handleChange = () => { 12 | const previous = current; 13 | current = selector_fn(store.getState()); 14 | if (previous !== current) { 15 | fn(previous, current); 16 | } 17 | }; 18 | store.subscribe(handleChange); 19 | } 20 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/themes/leva_theme.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors": { 3 | "elevation1": "#292d39", 4 | "elevation2": "#222831", 5 | "elevation3": "#393E46", 6 | "accent1": "#ffc640", 7 | "accent2": "#FFD369", 8 | "accent3": "#ffd65c", 9 | "highlight1": "#d1d4db", 10 | "highlight2": "#EEEEEE", 11 | "highlight3": "#222831", 12 | "disabled": "#595959", 13 | "vivid1": "#ffcc00" 14 | }, 15 | "radii": { 16 | "xs": "2px", 17 | "sm": "4px", 18 | "lg": "10px" 19 | }, 20 | "space": { 21 | "sm": "6px", 22 | "md": "5px", 23 | "rowGap": "6px", 24 | "colGap": "7px" 25 | }, 26 | "fontSizes": { 27 | "root": "11px" 28 | }, 29 | "sizes": { 30 | "rootWidth": "310px", 31 | "controlWidth": "170px", 32 | "scrubberWidth": "8px", 33 | "scrubberHeight": "16px", 34 | "rowHeight": "24px", 35 | "folderHeight": "20px", 36 | "checkboxSize": "16px", 37 | "joystickWidth": "100px", 38 | "joystickHeight": "100px", 39 | "colorPickerWidth": "160px", 40 | "colorPickerHeight": "100px", 41 | "monitorHeight": "60px", 42 | "titleBarHeight": "39px" 43 | }, 44 | "borderWidths": { 45 | "root": "0px", 46 | "input": "1px", 47 | "focus": "1px", 48 | "hover": "1px", 49 | "active": "1px", 50 | "folder": "1px" 51 | }, 52 | "fontWeights": { 53 | "label": "normal", 54 | "folder": "normal", 55 | "button": "normal" 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /nerfstudio/viewer/app/src/utils.js: -------------------------------------------------------------------------------- 1 | export function split_path(path_str) { 2 | return path_str.split('/').filter((x) => x.length > 0); 3 | } 4 | -------------------------------------------------------------------------------- 
/nerfstudio/viewer/server/README.md: -------------------------------------------------------------------------------- 1 | # Python Kernel and Client Viewer App communication 2 | 3 | > The purpose of this document is to explain how to communicate from Python with the Client Viewer app. We will eventually move this into Read the Docs. 4 | 5 | - Python Kernel (nerfstudio code) 6 | - Bridge Server 7 | - Client Viewer App 8 | 9 | We have two types of components that we want to keep state updated in both locations. 10 | 11 | - Widgets 12 | - The widgets are used to keep track of the 13 | - SceneNode 14 | - The scene nodes are used to represent the three.js objects. The properties relevant to these objects are the following: `"object", "transform", "properties"`. 15 | 16 | # Checklist 17 | 18 | - [ ] Currently using request-reply (REQ, REP with zmq). I.e., Python Kernel -> Bridge Server <-> Client Viewer App. We want a way to update the Python Kernel when the Bridge Server is updated. This requires some form of binding with callbacks. When the Bridge Server is updated, we want to update the bound Python variable. We can take inspiration from [ipywidgets](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20Basics.html). 19 | 20 | - [ ] 21 | -------------------------------------------------------------------------------- /nerfstudio/viewer/server/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /nerfstudio/viewer/server/path.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
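An aside on the server README above: its open checklist item asks for ipywidgets-style two-way binding, where a Python-side variable is updated (and user callbacks fire) whenever the Bridge Server receives a write from the viewer. A minimal sketch of that pattern follows; `BoundState`, `observe`, and `write` are hypothetical names for illustration, not existing nerfstudio API.

```python
# Hypothetical sketch, not part of nerfstudio: an observable store keyed by
# slash-separated state paths, mirroring the viewer's "write" messages.
from typing import Any, Callable, Dict, List


class BoundState:
    """Store path -> value pairs and notify observers on each write."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}
        self._callbacks: Dict[str, List[Callable[[Any], None]]] = {}

    def observe(self, path: str, callback: Callable[[Any], None]) -> None:
        """Register a callback fired whenever `path` is written."""
        self._callbacks.setdefault(path, []).append(callback)

    def write(self, path: str, value: Any) -> None:
        """Would be called when the Bridge Server forwards a viewer write."""
        self._data[path] = value
        for callback in self._callbacks.get(path, []):
            callback(value)


state = BoundState()
state.observe("renderingState/maxResolution", lambda v: print("resolution ->", v))
state.write("renderingState/maxResolution", 512)  # prints: resolution -> 512
```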
14 | 15 | """Path class 16 | """ 17 | 18 | 19 | from typing import Tuple 20 | 21 | UNICODE = str 22 | 23 | 24 | class Path: 25 | """Path class 26 | 27 | Args: 28 | entries: component parts of the path 29 | """ 30 | 31 | __slots__ = ["entries"] 32 | 33 | def __init__(self, entries: Tuple = tuple()): 34 | self.entries = entries 35 | 36 | def append(self, other: str) -> "Path": 37 | """Method that appends a new component and returns new Path 38 | 39 | Args: 40 | other: path string to append; may contain several "/"-separated components 41 | """ 42 | new_path = self.entries 43 | for element in other.split("/"): 44 | if len(element) == 0: 45 | new_path = tuple() 46 | else: 47 | new_path = new_path + (element,) 48 | return Path(new_path) 49 | 50 | def lower(self): 51 | """Convert path object to serializable format""" 52 | return UNICODE("/" + "/".join(self.entries)) 53 | 54 | def __hash__(self): 55 | return hash(self.entries) 56 | 57 | def __eq__(self, other): 58 | return self.entries == other.entries 59 | -------------------------------------------------------------------------------- /nerfstudio/viewer/server/state/node.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | For tree logic code. 17 | """ 18 | 19 | from collections import defaultdict 20 | from typing import Callable 21 | 22 | 23 | class Node(defaultdict): 24 | """ 25 | The base class Node. 26 | """ 27 | 28 | def __init__(self, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | 31 | 32 | def get_tree(node_class: Callable) -> Callable: 33 | """ 34 | Get a tree from a node class. 35 | This allows one to do tree["path"]["to"]["node"] 36 | and it will return a new node if it doesn't exist 37 | or the current node if it does. 38 | """ 39 | assert isinstance(node_class(), Node) 40 | tree = lambda: node_class(tree) 41 | return tree() 42 | 43 | 44 | def find_node(tree, path): 45 | """Return the node at the given path, descending from the tree root.""" 46 | if len(path) == 0: 47 | return tree 48 | else: 49 | return find_node(tree[path[0]], path[1:]) 50 | 51 | 52 | def set_node_value(tree, path, value): 53 | """Set the data value of the node at the given path.""" 54 | if len(path) == 0: 55 | tree.data = value 56 | else: 57 | set_node_value(tree[path[0]], path[1:], value) 58 | 59 | 60 | def walk(path, tree): 61 | """Walk the entire tree and yield (path, node) pairs 62 | Args: 63 | path: path prefix of the current node 64 | tree: the root of the tree to start search 65 | """ 66 | yield path, tree 67 | for k, v in tree.items(): 68 | yield from walk(path + "/" + k, v) 69 | -------------------------------------------------------------------------------- /nerfstudio/viewer/server/state/state_node.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List 16 | 17 | from nerfstudio.viewer.server.state.node import Node 18 | 19 | 20 | class StateNode(Node): 21 | """Node that holds a hierarchy of state nodes""" 22 | 23 | __slots__ = ["data"] 24 | 25 | def __init__(self, *args, **kwargs): 26 | super().__init__(*args, **kwargs) 27 | self.path = None 28 | self.data = None 29 | -------------------------------------------------------------------------------- /nerfstudio/viewer/server/video_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Video Stream objects for WebRTC""" 16 | 17 | import numpy as np 18 | from aiortc import VideoStreamTrack 19 | from av import VideoFrame 20 | 21 | 22 | class SingleFrameStreamTrack(VideoStreamTrack): 23 | """Single Frame stream class: pushes single frames to a stream""" 24 | 25 | def __init__(self): 26 | super().__init__() 27 | self.background_frame = np.ones((480, 640, 3), dtype="uint8") * 100 # gray background 28 | self.frame = None 29 | self.put_frame(self.background_frame) 30 | 31 | def put_frame(self, frame: np.ndarray) -> None: 32 | """Sets the current viewing frame 33 | 34 | Args: 35 | frame: image to be viewed 36 | """ 37 | self.frame = VideoFrame.from_ndarray(frame) 38 | 39 | async def recv(self): 40 | """Async method to grab and wait on frame""" 41 | pts, time_base = await self.next_timestamp() 42 | 43 | frame = self.frame 44 | frame.pts = pts 45 | frame.time_base = time_base 46 | return frame 47 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/benchmarking/launch_eval_blender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | helpFunction_launch_eval() 4 | { 5 | echo "Usage: $0 -m <method_name> -o <output_dir> -t <timestamp> [-s] [<gpu_list>]" 6 | echo -e "\t-m name of method to benchmark (e.g. nerfacto, instant-ngp)" 7 | echo -e "\t-o base directory for where all the benchmarks are stored (e.g. outputs/)" 8 | echo -e "\t-t <timestamp>: if using launch_train_blender.sh will be of format %Y-%m-%d_%H%M%S" 9 | echo -e "\t-s: Launch a single evaluation job per gpu."
10 | echo -e "\t<gpu_list> [OPTIONAL] list of space-separated gpu numbers to launch train on (e.g. 0 2 4 5)" 11 | exit 1 # Exit program after printing help 12 | } 13 | 14 | single=false 15 | while getopts "m:o:t:s" opt; do 16 | case "$opt" in 17 | m ) method_name="$OPTARG" ;; 18 | o ) output_dir="$OPTARG" ;; 19 | t ) timestamp="$OPTARG" ;; 20 | s ) single=true ;; 21 | ? ) helpFunction_launch_eval ;; 22 | esac 23 | done 24 | 25 | if [ -z "$method_name" ]; then 26 | echo "Missing method name" 27 | helpFunction_launch_eval 28 | fi 29 | 30 | if [ -z "$output_dir" ]; then 31 | echo "Missing output directory location" 32 | helpFunction_launch_eval 33 | fi 34 | 35 | if [ -z "$timestamp" ]; then 36 | echo "Missing timestamp specification" 37 | helpFunction_launch_eval 38 | fi 39 | 40 | shift $((OPTIND-1)) 41 | 42 | # Deal with gpu's. If passed in, use those. 43 | GPU_IDX=("$@") 44 | if [ -z "${GPU_IDX[0]+x}" ]; then 45 | echo "no gpus set... finding available gpus" 46 | # Find available devices 47 | num_device=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 48 | START=0 49 | END=${num_device}-1 50 | GPU_IDX=() 51 | 52 | for (( id=START; id<=END; id++ )); do 53 | free_mem=$(nvidia-smi --query-gpu=memory.free --format=csv -i $id | grep -Eo '[0-9]+') 54 | if [[ $free_mem -gt 10000 ]]; then 55 | GPU_IDX+=( "$id" ) 56 | fi 57 | done 58 | fi 59 | echo "available gpus... ${GPU_IDX[*]}" 60 | 61 | DATASETS=("mic" "ficus" "chair" "hotdog" "materials" "drums" "ship" "lego") 62 | idx=0 63 | len=${#GPU_IDX[@]} 64 | GPU_PID=() 65 | # kill all the background jobs if terminated: 66 | trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT 67 | 68 | for dataset in "${DATASETS[@]}"; do 69 | if "$single" && [ -n "${GPU_PID[$idx]+x}" ]; then 70 | wait "${GPU_PID[$idx]}" 71 | fi 72 | export CUDA_VISIBLE_DEVICES=${GPU_IDX[$idx]} 73 | config_path="${output_dir}/blender_${dataset}_${timestamp::-7}/${method_name}/${timestamp}/config.yml" 74 | ns-eval --load-config="${config_path}" \ 75 | --output-path="${output_dir}/${method_name}/blender_${dataset}_${timestamp}.json" & GPU_PID[$idx]=$! 76 | echo "Launched ${config_path} on gpu ${GPU_IDX[$idx]}" 77 | 78 | # update gpu 79 | ((idx=(idx+1)%len)) 80 | done 81 | wait 82 | echo "Done." 83 | -------------------------------------------------------------------------------- /scripts/blender/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/blender/__init__.py -------------------------------------------------------------------------------- /scripts/completions/.gitignore: -------------------------------------------------------------------------------- 1 | bash/ 2 | zsh/ 3 | -------------------------------------------------------------------------------- /scripts/completions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/completions/__init__.py -------------------------------------------------------------------------------- /scripts/completions/setup.bash: -------------------------------------------------------------------------------- 1 | # nerfstudio completions for bash. 2 | # 3 | # This should generally be installed automatically by `configure.py`. 4 | 5 | completions_dir="$(dirname "$BASH_SOURCE")"/bash 6 | 7 | if [ !
-d "${completions_dir}" ]; then 8 | echo "$0: Completions are missing!" 9 | echo "Please generate them with nerfstudio/scripts/completions/generate.py!" 10 | return 1 11 | fi 12 | 13 | # Source each completion script. 14 | for completion_path in ${completions_dir}/* 15 | do 16 | source $completion_path 17 | done 18 | -------------------------------------------------------------------------------- /scripts/completions/setup.zsh: -------------------------------------------------------------------------------- 1 | # nerfstudio completions for zsh. 2 | # 3 | # This should generally be installed automatically by `configure.py`. 4 | 5 | completions_dir="${0:a:h}"/zsh 6 | 7 | if [ ! -d "${completions_dir}" ]; then 8 | echo "$0: Completions are missing!" 9 | echo "Please generate them with nerfstudio/scripts/completions/generate.py!" 10 | return 1 11 | fi 12 | 13 | # Manually load and define each completion. 14 | # 15 | # Adding the completions directory to our fpath and re-initializing would work 16 | # as well: 17 | # fpath+=${completions_dir} 18 | # autoload -Uz compinit; compinit 19 | # But would be several orders of magnitude slower. 20 | for completion_path in ${completions_dir}/* 21 | do 22 | # /some/path/to/_our_completion_py => _our_completion_py 23 | completion_name=${completion_path##*/} 24 | if [[ $name == *_py ]]; then 25 | # _our_completion_py => our_completion.py 26 | script_name="${completion_name:1:-3}.py" 27 | else 28 | # _entry-point => entry-point 29 | script_name="${completion_name:1}" 30 | fi 31 | 32 | autoload -Uz $completion_path 33 | compdef $completion_name $script_name 34 | done 35 | -------------------------------------------------------------------------------- /scripts/docs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/docs/__init__.py -------------------------------------------------------------------------------- /scripts/docs/add_nb_tags.py: -------------------------------------------------------------------------------- 1 | """Helper that add tags to notebooks based on cell comments.""" 2 | 3 | import sys 4 | from glob import glob 5 | 6 | import nbformat as nbf 7 | import tyro 8 | from rich.console import Console 9 | 10 | CONSOLE = Console(width=120) 11 | 12 | 13 | def main(check: bool = False): 14 | """Add tags to notebooks based on cell comments. 15 | 16 | In notebook cells, you can add the following tags to the notebook by adding a comment: 17 | "# HIDDEN" - This cell will be hidden from the notebook. 18 | "# OUTPUT_ONLY" - This cell will only show the output. 19 | "# COLLAPSED" - Hide the code and include a button to show the code. 20 | 21 | Args: 22 | check: check will not modify the notebooks. 
23 | """ 24 | # Collect a list of all notebooks in the content folder 25 | notebooks = glob("./docs/**/*.ipynb", recursive=True) 26 | 27 | # Text to look for in adding tags 28 | text_search_dict = { 29 | "# HIDDEN": "remove-cell", # Remove the whole cell 30 | "# OUTPUT_ONLY": "remove-input", # Remove only the input 31 | "# COLLAPSED": "hide-input", # Hide the input w/ a button to show 32 | } 33 | 34 | # Search through each notebook and look for the text, add a tag if necessary 35 | any_missing = False 36 | for ipath in notebooks: 37 | ntbk = nbf.read(ipath, nbf.NO_CONVERT) 38 | 39 | incorrect_metadata = False 40 | for cell in ntbk.cells: 41 | cell_tags = cell.get("metadata", {}).get("tags", []) 42 | found_keys = [] 43 | found_tags = [] 44 | for key, val in text_search_dict.items(): 45 | if key in cell.source: 46 | found_keys.append(key) 47 | found_tags.append(val) 48 | 49 | if len(found_keys) > 1: 50 | CONSOLE.print(f"[bold yellow]Found multiple tags {found_keys} for {ipath}") 51 | sys.exit(1) 52 | 53 | if len(cell_tags) != len(found_tags): 54 | incorrect_metadata = True 55 | elif len(cell_tags) == 1 and len(found_keys) == 1: 56 | if found_tags[0] != cell_tags[0]: 57 | incorrect_metadata = True 58 | 59 | cell["metadata"]["tags"] = found_tags 60 | if incorrect_metadata: 61 | if check: 62 | CONSOLE.print( 63 | f"[bold yellow]{ipath} has incorrect metadata. Call `python scripts.docs.add_nb_tags.py` to add it." 64 | ) 65 | any_missing = True 66 | else: 67 | print(f"Adding metadata to {ipath}") 68 | nbf.write(ntbk, ipath) 69 | 70 | if not any_missing: 71 | CONSOLE.print("[green]All notebooks have correct metadata.") 72 | 73 | if check and any_missing: 74 | sys.exit(1) 75 | 76 | 77 | if __name__ == "__main__": 78 | tyro.extras.set_accent_color("bright_yellow") 79 | tyro.cli(main) 80 | -------------------------------------------------------------------------------- /scripts/docs/build_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Simple yaml debugger""" 3 | import subprocess 4 | import sys 5 | 6 | import tyro 7 | from rich.console import Console 8 | from rich.style import Style 9 | 10 | CONSOLE = Console(width=120) 11 | 12 | LOCAL_TESTS = ["Run license checks", "Run Black", "Python Pylint", "Test with pytest"] 13 | 14 | 15 | def run_command(command: str) -> None: 16 | """Run a command kill actions if it fails 17 | 18 | Args: 19 | command: command to run 20 | """ 21 | ret_code = subprocess.call(command, shell=True) 22 | if ret_code != 0: 23 | CONSOLE.print(f"[bold red]Error: `{command}` failed. Exiting...") 24 | sys.exit(1) 25 | 26 | 27 | def main(clean_cache: bool = False): 28 | """Run the github actions locally. 29 | 30 | Args: 31 | clean_cache: whether to clean the cache before building docs. 
32 | """ 33 | 34 | CONSOLE.print("[green]Adding notebook documentation metadata") 35 | run_command("python scripts/docs/add_nb_tags.py") 36 | 37 | # Add checks for building documentation 38 | CONSOLE.print("[green]Building Documentation") 39 | if clean_cache: 40 | run_command("cd docs/; make clean; make html SPHINXOPTS='-W;'") 41 | else: 42 | run_command("cd docs/; make html SPHINXOPTS='-W;'") 43 | 44 | CONSOLE.line() 45 | CONSOLE.rule(characters="=", style=Style(color="green")) 46 | CONSOLE.print("[bold green]Done") 47 | CONSOLE.rule(characters="=", style=Style(color="green")) 48 | 49 | 50 | if __name__ == "__main__": 51 | tyro.extras.set_accent_color("bright_yellow") 52 | tyro.cli(main) 53 | -------------------------------------------------------------------------------- /scripts/downloads/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/downloads/__init__.py -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | eval.py 4 | """ 5 | from __future__ import annotations 6 | 7 | import json 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | 11 | import tyro 12 | from rich.console import Console 13 | 14 | from nerfstudio.utils.eval_utils import eval_setup 15 | 16 | CONSOLE = Console(width=120) 17 | 18 | 19 | @dataclass 20 | class ComputePSNR: 21 | """Load a checkpoint, compute some PSNR metrics, and save it to a JSON file.""" 22 | 23 | # Path to config YAML file. 24 | load_config: Path 25 | # Name of the output file. 
26 | output_path: Path = Path("output.json") 27 | 28 | def main(self) -> None: 29 | """Main function.""" 30 | config, pipeline, checkpoint_path = eval_setup(self.load_config) 31 | assert self.output_path.suffix == ".json" 32 | metrics_dict = pipeline.get_average_eval_image_metrics() 33 | self.output_path.parent.mkdir(parents=True, exist_ok=True) 34 | # Get the output and define the names to save to 35 | benchmark_info = { 36 | "experiment_name": config.experiment_name, 37 | "method_name": config.method_name, 38 | "checkpoint": str(checkpoint_path), 39 | "results": metrics_dict, 40 | } 41 | # Save output to output file 42 | self.output_path.write_text(json.dumps(benchmark_info, indent=2), "utf8") 43 | CONSOLE.print(f"Saved results to: {self.output_path}") 44 | 45 | 46 | def entrypoint(): 47 | """Entrypoint for use with pyproject scripts.""" 48 | tyro.extras.set_accent_color("bright_yellow") 49 | tyro.cli(ComputePSNR).main() 50 | 51 | 52 | if __name__ == "__main__": 53 | entrypoint() 54 | 55 | # For sphinx docs 56 | get_parser_fn = lambda: tyro.extras.get_parser(ComputePSNR) # noqa 57 | -------------------------------------------------------------------------------- /scripts/generative/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/generative/__init__.py -------------------------------------------------------------------------------- /scripts/github/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/github/__init__.py -------------------------------------------------------------------------------- /scripts/licensing/copyright.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /scripts/licensing/license_headers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | VALID_ARGS=$(getopt -o c --long check -- "$@") 4 | 5 | eval set -- "$VALID_ARGS" 6 | check=false 7 | while [ : ]; do 8 | case "$1" in 9 | -c | --check) 10 | check=true 11 | shift 12 | ;; 13 | --) shift; 14 | break 15 | ;; 16 | esac 17 | done 18 | 19 | check_failed=false 20 | added_headers=false 21 | for i in $(find nerfstudio/ -name '*.py'); 22 | do 23 | if ! grep -q Copyright $i 24 | then 25 | if [ "$check" = true ]; 26 | then 27 | echo "$i missing copyright header" 28 | check_failed=true 29 | else 30 | cat scripts/licensing/copyright.txt $i >$i.new && mv $i.new $i 31 | echo "Adding license header to $i." 
32 | fi 33 | added_headers=true 34 | fi 35 | done 36 | 37 | if [ "$check_failed" = true ]; 38 | then 39 | echo "Run './scripts/licensing/license_headers.sh' to add missing headers." 40 | exit 1 41 | fi 42 | 43 | if [ "$added_headers" = false ]; 44 | then 45 | echo "No missing license headers found." 46 | fi 47 | 48 | exit 0 -------------------------------------------------------------------------------- /scripts/texture.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to texture an existing mesh file. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from dataclasses import dataclass 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import tyro 12 | from rich.console import Console 13 | from typing_extensions import Literal 14 | 15 | from nerfstudio.exporter import texture_utils 16 | from nerfstudio.exporter.exporter_utils import get_mesh_from_filename 17 | from nerfstudio.utils.eval_utils import eval_setup 18 | 19 | CONSOLE = Console(width=120) 20 | 21 | 22 | @dataclass 23 | class TextureMesh: 24 | """ 25 | Export a textured mesh with color computed from the NeRF. 26 | """ 27 | 28 | load_config: Path 29 | """Path to the config YAML file.""" 30 | output_dir: Path 31 | """Path to the output directory.""" 32 | input_mesh_filename: Path 33 | """Mesh filename to texture.""" 34 | px_per_uv_triangle: int = 4 35 | """Number of pixels per UV triangle.""" 36 | unwrap_method: Literal["xatlas", "custom"] = "xatlas" 37 | """The method to use for unwrapping the mesh.""" 38 | num_pixels_per_side: int = 2048 39 | """If using xatlas for unwrapping, the pixels per side of the texture image.""" 40 | target_num_faces: Optional[int] = 50000 41 | """Target number of faces for the mesh to texture.""" 42 | 43 | def main(self) -> None: 44 | """Export textured mesh""" 45 | # pylint: disable=too-many-statements 46 | 47 | if not self.output_dir.exists(): 48 | self.output_dir.mkdir(parents=True) 49 | 50 | # load the Mesh 51 | mesh = get_mesh_from_filename(str(self.input_mesh_filename), target_num_faces=self.target_num_faces) 52 | 53 | # load the Pipeline 54 | _, pipeline, _ = eval_setup(self.load_config, test_mode="inference") 55 | 56 | # texture the mesh with NeRF and export to a mesh.obj file 57 | # and a material and texture file 58 | texture_utils.export_textured_mesh( 59 | mesh=mesh, 60 | pipeline=pipeline, 61 | output_dir=self.output_dir, 62 | px_per_uv_triangle=self.px_per_uv_triangle, 63 | unwrap_method=self.unwrap_method, 64 | num_pixels_per_side=self.num_pixels_per_side, 65 | ) 66 | 67 | 68 | def entrypoint(): 69 | """Entrypoint for use with pyproject scripts.""" 70 | tyro.extras.set_accent_color("bright_yellow") 71 | tyro.cli(tyro.conf.FlagConversionOff[TextureMesh]).main() 72 | 73 | 74 | if __name__ == "__main__": 75 | entrypoint() 76 | 77 | # For sphinx docs 78 | get_parser_fn = lambda: tyro.extras.get_parser(TextureMesh) # noqa 79 | -------------------------------------------------------------------------------- /scripts/viewer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masked-spacetime-hashing/msth/6f085f3643b589c2c548f65b3d74350611e2bff6/scripts/viewer/__init__.py -------------------------------------------------------------------------------- /scripts/viewer/view_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | view_dataset.py 4 | """ 5 | 6 | import time 7
| from datetime import timedelta 8 | from pathlib import Path 9 | 10 | import torch 11 | import tyro 12 | from rich.console import Console 13 | 14 | from nerfstudio.configs.base_config import ViewerConfig 15 | from nerfstudio.data.datamanagers.base_datamanager import AnnotatedDataParserUnion 16 | from nerfstudio.data.datasets.base_dataset import InputDataset 17 | from nerfstudio.viewer.server import viewer_utils 18 | 19 | DEFAULT_TIMEOUT = timedelta(minutes=30) 20 | CONSOLE = Console(width=120) 21 | 22 | # speedup for when input size to model doesn't change (much) 23 | torch.backends.cudnn.benchmark = True # type: ignore 24 | 25 | 26 | def main( 27 | dataparser: AnnotatedDataParserUnion, 28 | viewer: ViewerConfig, 29 | log_base_dir: Path = Path("/tmp/nerfstudio_viewer_logs"), 30 | ) -> None: 31 | """Main function.""" 32 | viewer_state, _ = viewer_utils.setup_viewer( 33 | viewer, 34 | log_filename=log_base_dir / viewer.relative_log_filename, 35 | datapath=dataparser.data, 36 | ) 37 | dataset = InputDataset(dataparser.setup().get_dataparser_outputs(split="train")) 38 | viewer_state.init_scene(dataset=dataset, start_train=False) 39 | CONSOLE.log(f"Please refresh and load page at: {viewer_state.viewer_url}") 40 | time.sleep(30) # allowing time to refresh page 41 | 42 | 43 | if __name__ == "__main__": 44 | tyro.extras.set_accent_color("bright_yellow") 45 | tyro.cli(main) 46 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | T = 2 ** 19 5 | base_res = 16 6 | max_res = 2048 7 | num_levels = 16 8 | growth_factor = np.exp((np.log(max_res) - np.log(base_res)) / (num_levels - 1)) 9 | 10 | R = base_res 11 | for level in range(num_levels): 12 | a = [0 for i in range(T)] 13 | for i in range(R): 14 | for j in range(R): 15 | for k in range(R): 16 | hash_value = (i * 1) ^ (j * 2654435761) ^ (k * 805459861) 17 | hash_value = hash_value % T 18 | a[hash_value] += 1 19 | 20 | print(T) 21 | print(R) 22 | a = np.array(a) 23 | print(f"collision rate: {(a>1).sum()/T}") 24 | R = int(R * growth_factor) 25 | --------------------------------------------------------------------------------
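A closing note on test.py above: it measures, level by level, how often the Instant-NGP-style spatial hash `((i * 1) ^ (j * 2654435761) ^ (k * 805459861)) mod T` maps distinct grid vertices to the same table slot. The triple Python loop is O(R^3), which is slow at moderate resolutions and infeasible at R = 2048 (about 8.6 billion vertices). Below is a vectorized sketch of the same single-level measurement, using the same primes and table size; `R = 64` is an arbitrary example resolution.

```python
# Vectorized sketch of one level of the collision measurement in test.py.
# Memory grows as R**3, so very large levels would need to be sampled
# rather than enumerated exhaustively.
import numpy as np

T = 2**19  # hash table size, as in test.py
R = 64     # example grid resolution for this level

i, j, k = np.meshgrid(np.arange(R), np.arange(R), np.arange(R), indexing="ij")
hashes = ((i * 1) ^ (j * 2654435761) ^ (k * 805459861)) % T
counts = np.bincount(hashes.ravel(), minlength=T)

# Fraction of table slots hit by more than one grid vertex:
print(f"collision rate: {(counts > 1).sum() / T}")
```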