├── .gitignore
├── README.md
├── freenerf
│   ├── freenerf.py
│   ├── freenerf_config.py
│   ├── freenerf_field.py
│   ├── freenerf_pipeline.py
│   └── util.py
└── pyproject.toml

/.gitignore:
--------------------------------------------------------------------------------
1 | *egg-info
2 | *__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FreeNeRF implemented with nerfstudio
2 | An unofficial implementation (only the `freq_mask` frequency regularization is implemented) of [FreeNeRF: Improving Few-shot Neural Rendering with Free Frequency Regularization](https://arxiv.org/abs/2303.07418), based on [nerfstudio](https://github.com/nerfstudio-project/nerfstudio).
3 | 
4 | # How to use
5 | 
6 | ## Install nerfstudio
7 | 
8 | Follow the guide in [nerfstudio](https://github.com/nerfstudio-project/nerfstudio) to install nerfstudio.
9 | 
10 | 
11 | ## Register freenerf
12 | 
13 | ```bash
14 | pip install -e .
15 | ```
16 | 
17 | ## Train
18 | 
19 | ```bash
20 | ns-train freenerfacto --data DATADIR
21 | ```
22 | 
23 | 
24 | ## Notice
25 | 
26 | Due to updates in nerfstudio (mainly the change in nerfacc's version), this repository may not work with the latest nerfstudio :(. I will fix that when I am available.
27 | 
--------------------------------------------------------------------------------
/freenerf/freenerf.py:
--------------------------------------------------------------------------------
1 | from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig
2 | from collections import defaultdict
3 | from dataclasses import dataclass, field
4 | from typing import Dict, List, Tuple, Type
5 | 
6 | import numpy as np
7 | import torch
8 | from nerfstudio.cameras.rays import RayBundle, RaySamples
9 | from nerfstudio.data.scene_box import SceneBox
10 | from nerfstudio.field_components.field_heads import FieldHeadNames
11 | from nerfstudio.field_components.spatial_distortions import SceneContraction
12 | from nerfstudio.model_components.ray_samplers import PDFSampler
13 | from nerfstudio.model_components.renderers import DepthRenderer
14 | from torch.nn import Parameter
15 | from dataclasses import dataclass, field
16 | from typing import Dict, List, Tuple, Type
17 | 
18 | import numpy as np
19 | import torch
20 | from torch.nn import Parameter
21 | from torchmetrics import PeakSignalNoiseRatio
22 | from torchmetrics.functional import structural_similarity_index_measure
23 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
24 | from typing_extensions import Literal
25 | 
26 | from nerfstudio.cameras.rays import RayBundle
27 | from nerfstudio.engine.callbacks import (
28 |     TrainingCallback,
29 |     TrainingCallbackAttributes,
30 |     TrainingCallbackLocation,
31 | )
32 | from nerfstudio.field_components.field_heads import FieldHeadNames
33 | from nerfstudio.field_components.spatial_distortions import SceneContraction
34 | from nerfstudio.fields.density_fields import HashMLPDensityField
35 | from nerfstudio.fields.nerfacto_field import TCNNNerfactoField, TorchNerfactoField
36 | from nerfstudio.model_components.losses import (
37 |     MSELoss,
38 |     distortion_loss,
39 |     interlevel_loss,
40 |     orientation_loss,
41 |     pred_normal_loss,
42 | )
43 | from nerfstudio.model_components.ray_samplers import (
44 |     ProposalNetworkSampler,
45 |     UniformSampler,
46 | )
47 | from nerfstudio.model_components.renderers import (
48 |     AccumulationRenderer,
49 |     DepthRenderer,
50 |     NormalsRenderer,
51 |     RGBRenderer,
52 | )
53 | from nerfstudio.model_components.scene_colliders import NearFarCollider 54 | from nerfstudio.model_components.shaders import NormalsShader 55 | from nerfstudio.models.base_model import Model, ModelConfig 56 | from nerfstudio.utils import colormaps 57 | 58 | from freenerf.freenerf_field import TorchFreeNerfactoField 59 | 60 | @dataclass 61 | class FreeNeRFactoConfig(NerfactoModelConfig): 62 | _target: Type = field(default_factory=lambda: FreeNeRFactoModel) 63 | 64 | 65 | class FreeNeRFactoModel(NerfactoModel): 66 | config: FreeNeRFactoConfig 67 | def populate_modules(self): 68 | super().populate_modules() 69 | if self.config.disable_scene_contraction: 70 | scene_contraction = None 71 | else: 72 | scene_contraction = SceneContraction(order=float("inf")) 73 | self.field = TorchFreeNerfactoField( 74 | self.scene_box.aabb, spatial_distortion=scene_contraction, num_images=self.num_train_data 75 | ) 76 | pass 77 | def get_param_groups(self) -> Dict[str, List[Parameter]]: 78 | param_groups = {} 79 | param_groups["proposal_networks"] = list(self.proposal_networks.parameters()) 80 | param_groups["fields"] = list(self.field.parameters()) 81 | return param_groups 82 | 83 | def get_training_callbacks( 84 | self, training_callback_attributes: TrainingCallbackAttributes 85 | ) -> List[TrainingCallback]: 86 | callbacks = [] 87 | if self.config.use_proposal_weight_anneal: 88 | # anneal the weights of the proposal network before doing PDF sampling 89 | N = self.config.proposal_weights_anneal_max_num_iters 90 | 91 | def set_anneal(step): 92 | # https://arxiv.org/pdf/2111.12077.pdf eq. 18 93 | train_frac = np.clip(step / N, 0, 1) 94 | bias = lambda x, b: (b * x) / ((b - 1) * x + 1) 95 | anneal = bias(train_frac, self.config.proposal_weights_anneal_slope) 96 | self.proposal_sampler.set_anneal(anneal) 97 | 98 | callbacks.append( 99 | TrainingCallback( 100 | where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], 101 | update_every_num_iters=1, 102 | func=set_anneal, 103 | ) 104 | ) 105 | callbacks.append( 106 | TrainingCallback( 107 | where_to_run=[TrainingCallbackLocation.AFTER_TRAIN_ITERATION], 108 | update_every_num_iters=1, 109 | func=self.proposal_sampler.step_cb, 110 | ) 111 | ) 112 | return callbacks 113 | # TODO 114 | def forward(self, ray_bundle: RayBundle,freq_mask=None) -> Dict[str, torch.Tensor]: 115 | """Run forward starting with a ray bundle. 
This outputs different things depending on the configuration
116 |         of the model and whether or not the batch is provided (whether or not we are training basically)
117 | 
118 |         Args:
119 |             ray_bundle: containing all the information needed to render that ray latents included
120 |         """
121 | 
122 |         if self.collider is not None:
123 |             # the near and far planes are set on the ray bundle here
124 |             ray_bundle = self.collider(ray_bundle)
125 | 
126 |         return self.get_outputs(ray_bundle, freq_mask)
127 |     def get_outputs(self, ray_bundle: RayBundle, freq_mask=None):
128 |         # print(step)
129 |         ray_samples, weights_list, ray_samples_list = self.proposal_sampler(ray_bundle, density_fns=self.density_fns)
130 |         field_outputs = self.field(ray_samples, compute_normals=self.config.predict_normals, freq_mask=freq_mask)
131 |         weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
132 |         weights_list.append(weights)
133 |         ray_samples_list.append(ray_samples)
134 | 
135 |         rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
136 |         depth = self.renderer_depth(weights=weights, ray_samples=ray_samples)
137 |         accumulation = self.renderer_accumulation(weights=weights)
138 | 
139 |         outputs = {
140 |             "rgb": rgb,
141 |             "accumulation": accumulation,
142 |             "depth": depth,
143 |         }
144 | 
145 |         if self.config.predict_normals:
146 |             normals = self.renderer_normals(normals=field_outputs[FieldHeadNames.NORMALS], weights=weights)
147 |             pred_normals = self.renderer_normals(normals=field_outputs[FieldHeadNames.PRED_NORMALS], weights=weights)
148 |             outputs["normals"] = self.normals_shader(normals)
149 |             outputs["pred_normals"] = self.normals_shader(pred_normals)
150 |         # These use a lot of GPU memory, so we avoid storing them for eval.
151 |         if self.training:
152 |             outputs["weights_list"] = weights_list
153 |             outputs["ray_samples_list"] = ray_samples_list
154 | 
155 |         if self.training and self.config.predict_normals:
156 |             outputs["rendered_orientation_loss"] = orientation_loss(
157 |                 weights.detach(), field_outputs[FieldHeadNames.NORMALS], ray_bundle.directions
158 |             )
159 | 
160 |             outputs["rendered_pred_normal_loss"] = pred_normal_loss(
161 |                 weights.detach(),
162 |                 field_outputs[FieldHeadNames.NORMALS].detach(),
163 |                 field_outputs[FieldHeadNames.PRED_NORMALS],
164 |             )
165 | 
166 |         for i in range(self.config.num_proposal_iterations):
167 |             outputs[f"prop_depth_{i}"] = self.renderer_depth(weights=weights_list[i], ray_samples=ray_samples_list[i])
168 | 
169 |         return outputs
170 | 
171 |     def get_metrics_dict(self, outputs, batch):
172 |         metrics_dict = {}
173 |         image = batch["image"].to(self.device)
174 |         metrics_dict["psnr"] = self.psnr(outputs["rgb"], image)
175 |         if self.training:
176 |             metrics_dict["distortion"] = distortion_loss(outputs["weights_list"], outputs["ray_samples_list"])
177 |         return metrics_dict
178 | 
179 |     def get_loss_dict(self, outputs, batch, metrics_dict=None):
180 |         loss_dict = {}
181 |         image = batch["image"].to(self.device)
182 |         loss_dict["rgb_loss"] = self.rgb_loss(image, outputs["rgb"])
183 |         if self.training:
184 |             loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss(
185 |                 outputs["weights_list"], outputs["ray_samples_list"]
186 |             )
187 |             assert metrics_dict is not None and "distortion" in metrics_dict
188 |             loss_dict["distortion_loss"] = self.config.distortion_loss_mult * metrics_dict["distortion"]
189 |             if self.config.predict_normals:
190 |                 # orientation loss for computed normals
191 |                 loss_dict["orientation_loss"] = self.config.orientation_loss_mult * torch.mean(
192 | 
outputs["rendered_orientation_loss"] 193 | ) 194 | 195 | # ground truth supervision for normals 196 | loss_dict["pred_normal_loss"] = self.config.pred_normal_loss_mult * torch.mean( 197 | outputs["rendered_pred_normal_loss"] 198 | ) 199 | return loss_dict 200 | 201 | def get_image_metrics_and_images( 202 | self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] 203 | ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: 204 | image = batch["image"].to(self.device) 205 | rgb = outputs["rgb"] 206 | acc = colormaps.apply_colormap(outputs["accumulation"]) 207 | depth = colormaps.apply_depth_colormap( 208 | outputs["depth"], 209 | accumulation=outputs["accumulation"], 210 | ) 211 | 212 | combined_rgb = torch.cat([image, rgb], dim=1) 213 | combined_acc = torch.cat([acc], dim=1) 214 | combined_depth = torch.cat([depth], dim=1) 215 | 216 | # Switch images from [H, W, C] to [1, C, H, W] for metrics computations 217 | image = torch.moveaxis(image, -1, 0)[None, ...] 218 | rgb = torch.moveaxis(rgb, -1, 0)[None, ...] 219 | 220 | psnr = self.psnr(image, rgb) 221 | ssim = self.ssim(image, rgb) 222 | lpips = self.lpips(image, rgb) 223 | 224 | # all of these metrics will be logged as scalars 225 | metrics_dict = {"psnr": float(psnr.item()), "ssim": float(ssim)} # type: ignore 226 | metrics_dict["lpips"] = float(lpips) 227 | 228 | images_dict = {"img": combined_rgb, "accumulation": combined_acc, "depth": combined_depth} 229 | 230 | for i in range(self.config.num_proposal_iterations): 231 | key = f"prop_depth_{i}" 232 | prop_depth_i = colormaps.apply_depth_colormap( 233 | outputs[key], 234 | accumulation=outputs["accumulation"], 235 | ) 236 | images_dict[key] = prop_depth_i 237 | 238 | return metrics_dict, images_dict 239 | -------------------------------------------------------------------------------- /freenerf/freenerf_config.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import tyro 4 | from nerfacc import ContractionType 5 | 6 | 7 | from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig 8 | from nerfstudio.configs.base_config import ViewerConfig 9 | from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig 10 | from nerfstudio.data.datamanagers.depth_datamanager import DepthDataManagerConfig 11 | from nerfstudio.data.datamanagers.sdf_datamanager import SDFDataManagerConfig 12 | from nerfstudio.data.datamanagers.semantic_datamanager import SemanticDataManagerConfig 13 | from nerfstudio.data.datamanagers.variable_res_datamanager import ( 14 | VariableResDataManagerConfig, 15 | ) 16 | from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig 17 | from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig 18 | from nerfstudio.data.dataparsers.dycheck_dataparser import DycheckDataParserConfig 19 | from nerfstudio.data.dataparsers.instant_ngp_dataparser import ( 20 | InstantNGPDataParserConfig, 21 | ) 22 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig 23 | from nerfstudio.data.dataparsers.phototourism_dataparser import ( 24 | PhototourismDataParserConfig, 25 | ) 26 | from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig 27 | from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig 28 | from nerfstudio.engine.optimizers import AdamOptimizerConfig, RAdamOptimizerConfig 29 | from nerfstudio.engine.schedulers import ( 30 | 
CosineDecaySchedulerConfig, 31 | ExponentialDecaySchedulerConfig, 32 | ) 33 | from nerfstudio.engine.trainer import TrainerConfig 34 | from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind 35 | from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig 36 | from nerfstudio.models.instant_ngp import InstantNGPModelConfig 37 | from nerfstudio.models.mipnerf import MipNerfModel 38 | from nerfstudio.models.nerfacto import NerfactoModelConfig 39 | from nerfstudio.models.nerfplayer_nerfacto import NerfplayerNerfactoModelConfig 40 | from nerfstudio.models.nerfplayer_ngp import NerfplayerNGPModelConfig 41 | from nerfstudio.models.neus import NeuSModelConfig 42 | from nerfstudio.models.semantic_nerfw import SemanticNerfWModelConfig 43 | from nerfstudio.models.tensorf import TensoRFModelConfig 44 | from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig 45 | from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig 46 | 47 | from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig 48 | from nerfstudio.plugins.registry import discover_methods 49 | from freenerf.util import freenerfT 50 | 51 | from nerfstudio.plugins.types import MethodSpecification 52 | from freenerf.freenerf import FreeNeRFactoModel,FreeNeRFactoConfig 53 | 54 | from freenerf.freenerf_pipeline import FreeNeRFactoPipelineConfig,FreeNeRFactoPipeline 55 | 56 | freenerfacto_method = MethodSpecification(config=TrainerConfig( 57 | method_name="freenerfacto", 58 | steps_per_eval_batch=500, 59 | steps_per_save=2000, 60 | max_num_iterations=freenerfT.max_num_iterations, 61 | mixed_precision=True, 62 | pipeline=FreeNeRFactoPipelineConfig( 63 | datamanager=VanillaDataManagerConfig( 64 | dataparser=NerfstudioDataParserConfig(), 65 | train_num_rays_per_batch=4096, 66 | eval_num_rays_per_batch=4096, 67 | camera_optimizer=CameraOptimizerConfig( 68 | mode="SO3xR3", optimizer=AdamOptimizerConfig(lr=6e-4, eps=1e-8, weight_decay=1e-2) 69 | ), 70 | ), 71 | model=FreeNeRFactoConfig(eval_num_rays_per_chunk=1 << 15), 72 | ), 73 | optimizers={ 74 | "proposal_networks": { 75 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 76 | "scheduler": None, 77 | }, 78 | "fields": { 79 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 80 | "scheduler": None, 81 | }, 82 | }, 83 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 84 | vis="viewer", 85 | ), 86 | description="Custom description" 87 | ) 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /freenerf/freenerf_field.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
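# NOTE: this module keeps nerfstudio's nerfacto field largely intact and adds
# TorchFreeNerfactoField, whose get_density()/get_outputs() accept an optional `freq_mask`
# tuple (position-encoding mask, direction-encoding mask) that is multiplied into the
# encoded inputs, implementing FreeNeRF's frequency regularization during training.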
14 | 15 | """ 16 | Field for compound nerf model, adds scene contraction and image embeddings to instant ngp 17 | """ 18 | 19 | 20 | from typing import Dict, Optional, Tuple 21 | 22 | import numpy as np 23 | import torch 24 | from torch import nn 25 | from torch.nn.parameter import Parameter 26 | from torchtyping import TensorType 27 | from nerfstudio.field_components.activations import trunc_exp 28 | from nerfstudio.cameras.rays import RaySamples 29 | from nerfstudio.data.scene_box import SceneBox 30 | from nerfstudio.field_components.activations import trunc_exp 31 | from nerfstudio.field_components.embedding import Embedding 32 | from nerfstudio.field_components.encodings import Encoding, HashEncoding, SHEncoding 33 | from nerfstudio.field_components.field_heads import ( 34 | DensityFieldHead, 35 | FieldHead, 36 | FieldHeadNames, 37 | PredNormalsFieldHead, 38 | RGBFieldHead, 39 | SemanticFieldHead, 40 | TransientDensityFieldHead, 41 | TransientRGBFieldHead, 42 | UncertaintyFieldHead, 43 | ) 44 | from nerfstudio.field_components.mlp import MLP 45 | from nerfstudio.field_components.spatial_distortions import ( 46 | SceneContraction, 47 | SpatialDistortion, 48 | ) 49 | from nerfstudio.fields.base_field import Field, shift_directions_for_tcnn 50 | 51 | try: 52 | import tinycudann as tcnn 53 | except ImportError: 54 | # tinycudann module doesn't exist 55 | pass 56 | 57 | 58 | class TCNNNerfactoField(Field): 59 | """Compound Field that uses TCNN 60 | 61 | Args: 62 | aabb: parameters of scene aabb bounds 63 | num_images: number of images in the dataset 64 | num_layers: number of hidden layers 65 | hidden_dim: dimension of hidden layers 66 | geo_feat_dim: output geo feat dimensions 67 | num_levels: number of levels of the hashmap for the base mlp 68 | max_res: maximum resolution of the hashmap for the base mlp 69 | log2_hashmap_size: size of the hashmap for the base mlp 70 | num_layers_color: number of hidden layers for color network 71 | num_layers_transient: number of hidden layers for transient network 72 | hidden_dim_color: dimension of hidden layers for color network 73 | hidden_dim_transient: dimension of hidden layers for transient network 74 | appearance_embedding_dim: dimension of appearance embedding 75 | transient_embedding_dim: dimension of transient embedding 76 | use_transient_embedding: whether to use transient embedding 77 | use_semantics: whether to use semantic segmentation 78 | num_semantic_classes: number of semantic classes 79 | use_pred_normals: whether to use predicted normals 80 | use_average_appearance_embedding: whether to use average appearance embedding or zeros for inference 81 | spatial_distortion: spatial distortion to apply to the scene 82 | """ 83 | 84 | def __init__( 85 | self, 86 | aabb: TensorType, 87 | num_images: int, 88 | num_layers: int = 2, 89 | hidden_dim: int = 64, 90 | geo_feat_dim: int = 15, 91 | num_levels: int = 16, 92 | max_res: int = 2048, 93 | log2_hashmap_size: int = 19, 94 | num_layers_color: int = 3, 95 | num_layers_transient: int = 2, 96 | hidden_dim_color: int = 64, 97 | hidden_dim_transient: int = 64, 98 | appearance_embedding_dim: int = 32, 99 | transient_embedding_dim: int = 16, 100 | use_transient_embedding: bool = False, 101 | use_semantics: bool = False, 102 | num_semantic_classes: int = 100, 103 | pass_semantic_gradients: bool = False, 104 | use_pred_normals: bool = False, 105 | use_average_appearance_embedding: bool = False, 106 | spatial_distortion: SpatialDistortion = None, 107 | ) -> None: 108 | super().__init__() 109 | 110 | 
self.register_buffer("aabb", aabb) 111 | self.geo_feat_dim = geo_feat_dim 112 | 113 | self.register_buffer("max_res", torch.tensor(max_res)) 114 | self.register_buffer("num_levels", torch.tensor(num_levels)) 115 | self.register_buffer("log2_hashmap_size", torch.tensor(log2_hashmap_size)) 116 | 117 | self.spatial_distortion = spatial_distortion 118 | self.num_images = num_images 119 | self.appearance_embedding_dim = appearance_embedding_dim 120 | self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim) 121 | self.use_average_appearance_embedding = use_average_appearance_embedding 122 | self.use_transient_embedding = use_transient_embedding 123 | self.use_semantics = use_semantics 124 | self.use_pred_normals = use_pred_normals 125 | self.pass_semantic_gradients = pass_semantic_gradients 126 | 127 | base_res: int = 16 128 | features_per_level: int = 2 129 | growth_factor = np.exp((np.log(max_res) - np.log(base_res)) / (num_levels - 1)) 130 | 131 | self.direction_encoding = tcnn.Encoding( 132 | n_input_dims=3, 133 | encoding_config={ 134 | "otype": "SphericalHarmonics", 135 | "degree": 4, 136 | }, 137 | ) 138 | 139 | self.position_encoding = tcnn.Encoding( 140 | n_input_dims=3, 141 | encoding_config={"otype": "Frequency", "n_frequencies": 2}, 142 | ) 143 | 144 | self.mlp_base = tcnn.NetworkWithInputEncoding( 145 | n_input_dims=3, 146 | n_output_dims=1 + self.geo_feat_dim, 147 | encoding_config={ 148 | "otype": "HashGrid", 149 | "n_levels": num_levels, 150 | "n_features_per_level": features_per_level, 151 | "log2_hashmap_size": log2_hashmap_size, 152 | "base_resolution": base_res, 153 | "per_level_scale": growth_factor, 154 | }, 155 | network_config={ 156 | "otype": "FullyFusedMLP", 157 | "activation": "ReLU", 158 | "output_activation": "None", 159 | "n_neurons": hidden_dim, 160 | "n_hidden_layers": num_layers - 1, 161 | }, 162 | ) 163 | 164 | # transients 165 | if self.use_transient_embedding: 166 | self.transient_embedding_dim = transient_embedding_dim 167 | self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim) 168 | self.mlp_transient = tcnn.Network( 169 | n_input_dims=self.geo_feat_dim + self.transient_embedding_dim, 170 | n_output_dims=hidden_dim_transient, 171 | network_config={ 172 | "otype": "FullyFusedMLP", 173 | "activation": "ReLU", 174 | "output_activation": "None", 175 | "n_neurons": hidden_dim_transient, 176 | "n_hidden_layers": num_layers_transient - 1, 177 | }, 178 | ) 179 | self.field_head_transient_uncertainty = UncertaintyFieldHead(in_dim=self.mlp_transient.n_output_dims) 180 | self.field_head_transient_rgb = TransientRGBFieldHead(in_dim=self.mlp_transient.n_output_dims) 181 | self.field_head_transient_density = TransientDensityFieldHead(in_dim=self.mlp_transient.n_output_dims) 182 | 183 | # semantics 184 | if self.use_semantics: 185 | self.mlp_semantics = tcnn.Network( 186 | n_input_dims=self.geo_feat_dim, 187 | n_output_dims=hidden_dim_transient, 188 | network_config={ 189 | "otype": "FullyFusedMLP", 190 | "activation": "ReLU", 191 | "output_activation": "None", 192 | "n_neurons": 64, 193 | "n_hidden_layers": 1, 194 | }, 195 | ) 196 | self.field_head_semantics = SemanticFieldHead( 197 | in_dim=self.mlp_semantics.n_output_dims, num_classes=num_semantic_classes 198 | ) 199 | 200 | # predicted normals 201 | if self.use_pred_normals: 202 | self.mlp_pred_normals = tcnn.Network( 203 | n_input_dims=self.geo_feat_dim + self.position_encoding.n_output_dims, 204 | n_output_dims=hidden_dim_transient, 205 | network_config={ 206 
| "otype": "FullyFusedMLP", 207 | "activation": "ReLU", 208 | "output_activation": "None", 209 | "n_neurons": 64, 210 | "n_hidden_layers": 2, 211 | }, 212 | ) 213 | self.field_head_pred_normals = PredNormalsFieldHead(in_dim=self.mlp_pred_normals.n_output_dims) 214 | 215 | self.mlp_head = tcnn.Network( 216 | n_input_dims=self.direction_encoding.n_output_dims + self.geo_feat_dim + self.appearance_embedding_dim, 217 | n_output_dims=3, 218 | network_config={ 219 | "otype": "FullyFusedMLP", 220 | "activation": "ReLU", 221 | "output_activation": "Sigmoid", 222 | "n_neurons": hidden_dim_color, 223 | "n_hidden_layers": num_layers_color - 1, 224 | }, 225 | ) 226 | 227 | def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType, TensorType]: 228 | """Computes and returns the densities.""" 229 | if self.spatial_distortion is not None: 230 | positions = ray_samples.frustums.get_positions() 231 | positions = self.spatial_distortion(positions) 232 | positions = (positions + 2.0) / 4.0 233 | else: 234 | positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb) 235 | # Make sure the tcnn gets inputs between 0 and 1. 236 | selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1) 237 | positions = positions * selector[..., None] 238 | self._sample_locations = positions 239 | if not self._sample_locations.requires_grad: 240 | self._sample_locations.requires_grad = True 241 | positions_flat = positions.view(-1, 3) 242 | h = self.mlp_base(positions_flat).view(*ray_samples.frustums.shape, -1) 243 | density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1) 244 | self._density_before_activation = density_before_activation 245 | 246 | # Rectifying the density with an exponential is much more stable than a ReLU or 247 | # softplus, because it enables high post-activation (float32) density outputs 248 | # from smaller internal (float16) parameters. 
249 | density = trunc_exp(density_before_activation.to(positions)) 250 | density = density * selector[..., None] 251 | return density, base_mlp_out 252 | 253 | def get_outputs( 254 | self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None 255 | ) -> Dict[FieldHeadNames, TensorType]: 256 | assert density_embedding is not None 257 | outputs = {} 258 | if ray_samples.camera_indices is None: 259 | raise AttributeError("Camera indices are not provided.") 260 | camera_indices = ray_samples.camera_indices.squeeze() 261 | directions = shift_directions_for_tcnn(ray_samples.frustums.directions) 262 | directions_flat = directions.view(-1, 3) 263 | d = self.direction_encoding(directions_flat) 264 | 265 | outputs_shape = ray_samples.frustums.directions.shape[:-1] 266 | 267 | # appearance 268 | if self.training: 269 | embedded_appearance = self.embedding_appearance(camera_indices) 270 | else: 271 | if self.use_average_appearance_embedding: 272 | embedded_appearance = torch.ones( 273 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device 274 | ) * self.embedding_appearance.mean(dim=0) 275 | else: 276 | embedded_appearance = torch.zeros( 277 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device 278 | ) 279 | 280 | # transients 281 | if self.use_transient_embedding and self.training: 282 | embedded_transient = self.embedding_transient(camera_indices) 283 | transient_input = torch.cat( 284 | [ 285 | density_embedding.view(-1, self.geo_feat_dim), 286 | embedded_transient.view(-1, self.transient_embedding_dim), 287 | ], 288 | dim=-1, 289 | ) 290 | x = self.mlp_transient(transient_input).view(*outputs_shape, -1).to(directions) 291 | outputs[FieldHeadNames.UNCERTAINTY] = self.field_head_transient_uncertainty(x) 292 | outputs[FieldHeadNames.TRANSIENT_RGB] = self.field_head_transient_rgb(x) 293 | outputs[FieldHeadNames.TRANSIENT_DENSITY] = self.field_head_transient_density(x) 294 | 295 | # semantics 296 | if self.use_semantics: 297 | semantics_input = density_embedding.view(-1, self.geo_feat_dim) 298 | if not self.pass_semantic_gradients: 299 | semantics_input = semantics_input.detach() 300 | 301 | x = self.mlp_semantics(semantics_input).view(*outputs_shape, -1).to(directions) 302 | outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x) 303 | 304 | # predicted normals 305 | if self.use_pred_normals: 306 | positions = ray_samples.frustums.get_positions() 307 | 308 | positions_flat = self.position_encoding(positions.view(-1, 3)) 309 | pred_normals_inp = torch.cat([positions_flat, density_embedding.view(-1, self.geo_feat_dim)], dim=-1) 310 | 311 | x = self.mlp_pred_normals(pred_normals_inp).view(*outputs_shape, -1).to(directions) 312 | outputs[FieldHeadNames.PRED_NORMALS] = self.field_head_pred_normals(x) 313 | 314 | h = torch.cat( 315 | [ 316 | d, 317 | density_embedding.view(-1, self.geo_feat_dim), 318 | embedded_appearance.view(-1, self.appearance_embedding_dim), 319 | ], 320 | dim=-1, 321 | ) 322 | rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions) 323 | outputs.update({FieldHeadNames.RGB: rgb}) 324 | 325 | return outputs 326 | 327 | 328 | class TorchFreeNerfactoField(Field): 329 | """ 330 | PyTorch implementation of the compound field. 
331 | """ 332 | 333 | def __init__( 334 | self, 335 | aabb: TensorType, 336 | num_images: int, 337 | geo_feat_dim: int = 15, 338 | hidden_dim: int = 64, 339 | num_layers: int = 2, 340 | num_layers_color: int = 3, 341 | hidden_dim_color: int = 64, 342 | num_levels: int = 16, 343 | max_res: int = 2048, 344 | hidden_dim_transient: int = 64, 345 | transient_embedding_dim: int = 16, 346 | num_semantic_classes: int = 100, 347 | num_layers_transient: int = 2, 348 | log2_hashmap_size: int = 19, 349 | appearance_embedding_dim: int = 40, 350 | use_average_appearance_embedding: bool = False, 351 | use_transient_embedding: bool = False, 352 | use_semantics: bool = False, 353 | use_pred_normals: bool = False, 354 | # field_heads: Tuple[FieldHead] = (RGBFieldHead(),), 355 | spatial_distortion: SpatialDistortion = SceneContraction(), 356 | ) -> None: 357 | super().__init__() 358 | self.aabb = Parameter(aabb, requires_grad=False) 359 | self.spatial_distortion = spatial_distortion 360 | self.num_images = num_images 361 | self.appearance_embedding_dim = appearance_embedding_dim 362 | self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim) 363 | self.geo_feat_dim = geo_feat_dim 364 | self.use_average_appearance_embedding=use_average_appearance_embedding 365 | self.use_transient_embedding=use_transient_embedding 366 | self.use_semantics=use_semantics 367 | self.use_pred_normals=use_pred_normals 368 | base_res: int = 16 369 | features_per_level: int = 2 370 | growth_factor = np.exp((np.log(max_res) - np.log(base_res)) / (num_levels - 1)) 371 | self.direction_encoding = tcnn.Encoding( 372 | n_input_dims=3, 373 | encoding_config={ 374 | "otype": "SphericalHarmonics", 375 | "degree": 4, 376 | }, 377 | ) 378 | self.position_encoding = tcnn.Encoding( 379 | n_input_dims=3, 380 | encoding_config={ 381 | "otype": "HashGrid", 382 | "n_levels": num_levels, 383 | "n_features_per_level": features_per_level, 384 | "log2_hashmap_size": log2_hashmap_size, 385 | "base_resolution": base_res, 386 | "per_level_scale": growth_factor,}, 387 | ) 388 | ## Fit it to FreeNeRF 389 | self.mlp_base = tcnn.Network( 390 | n_input_dims=self.position_encoding.n_output_dims, 391 | n_output_dims=1 + self.geo_feat_dim, 392 | network_config={ 393 | "otype": "FullyFusedMLP", 394 | "activation": "ReLU", 395 | "output_activation": "None", 396 | "n_neurons": hidden_dim, 397 | "n_hidden_layers": num_layers - 1, 398 | } 399 | ) 400 | if self.use_transient_embedding: 401 | self.transient_embedding_dim = transient_embedding_dim 402 | self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim) 403 | self.mlp_transient = tcnn.Network( 404 | n_input_dims=self.geo_feat_dim + self.transient_embedding_dim, 405 | n_output_dims=hidden_dim_transient, 406 | network_config={ 407 | "otype": "FullyFusedMLP", 408 | "activation": "ReLU", 409 | "output_activation": "None", 410 | "n_neurons": hidden_dim_transient, 411 | "n_hidden_layers": num_layers_transient - 1, 412 | }, 413 | ) 414 | self.field_head_transient_uncertainty = UncertaintyFieldHead(in_dim=self.mlp_transient.n_output_dims) 415 | self.field_head_transient_rgb = TransientRGBFieldHead(in_dim=self.mlp_transient.n_output_dims) 416 | self.field_head_transient_density = TransientDensityFieldHead(in_dim=self.mlp_transient.n_output_dims) 417 | 418 | # semantics 419 | if self.use_semantics: 420 | self.mlp_semantics = tcnn.Network( 421 | n_input_dims=self.geo_feat_dim, 422 | n_output_dims=hidden_dim_transient, 423 | network_config={ 424 | "otype": "FullyFusedMLP", 
425 | "activation": "ReLU", 426 | "output_activation": "None", 427 | "n_neurons": 64, 428 | "n_hidden_layers": 1, 429 | }, 430 | ) 431 | self.field_head_semantics = SemanticFieldHead( 432 | in_dim=self.mlp_semantics.n_output_dims, num_classes=num_semantic_classes 433 | ) 434 | 435 | # predicted normals 436 | if self.use_pred_normals: 437 | self.mlp_pred_normals = tcnn.Network( 438 | n_input_dims=self.geo_feat_dim + self.position_encoding.n_output_dims, 439 | n_output_dims=hidden_dim_transient, 440 | network_config={ 441 | "otype": "FullyFusedMLP", 442 | "activation": "ReLU", 443 | "output_activation": "None", 444 | "n_neurons": 64, 445 | "n_hidden_layers": 2, 446 | }, 447 | ) 448 | self.field_head_pred_normals = PredNormalsFieldHead(in_dim=self.mlp_pred_normals.n_output_dims) 449 | 450 | self.mlp_head = tcnn.Network( 451 | n_input_dims=self.direction_encoding.n_output_dims + self.geo_feat_dim + self.appearance_embedding_dim, 452 | n_output_dims=3, 453 | network_config={ 454 | "otype": "FullyFusedMLP", 455 | "activation": "ReLU", 456 | "output_activation": "Sigmoid", 457 | "n_neurons": hidden_dim_color, 458 | "n_hidden_layers": num_layers_color - 1, 459 | }, 460 | ) 461 | 462 | # self.field_output_density = DensityFieldHead(in_dim=self.mlp_base.get_out_dim()) 463 | # self.field_heads = nn.ModuleList(field_heads) 464 | # for field_head in self.field_heads: 465 | # field_head.set_in_dim(self.mlp_head.get_out_dim()) # type: ignore 466 | 467 | def get_density(self, ray_samples: RaySamples,freq_mask=None) -> Tuple[TensorType, TensorType]: 468 | if self.spatial_distortion is not None: 469 | positions = ray_samples.frustums.get_positions() 470 | positions = self.spatial_distortion(positions) 471 | positions = (positions + 2.0) / 4.0 472 | else: 473 | positions = ray_samples.frustums.get_positions() 474 | selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1) 475 | positions = positions * selector[..., None] 476 | self._sample_locations = positions 477 | if freq_mask is not None : 478 | (positions_freq_mask,dir_freq_mask) = freq_mask 479 | 480 | encoded_xyz = self.position_encoding(positions.view(-1, 3))*positions_freq_mask.view(1,-1) 481 | else: 482 | encoded_xyz = self.position_encoding(positions.view(-1, 3)) 483 | if not self._sample_locations.requires_grad: 484 | self._sample_locations.requires_grad = True 485 | h = self.mlp_base(encoded_xyz).view(*ray_samples.frustums.shape, -1) 486 | density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1) 487 | self._density_before_activation = density_before_activation 488 | density = trunc_exp(density_before_activation.to(positions)) 489 | 490 | density = density * selector[..., None] 491 | # density = self.field_output_density(base_mlp_out) 492 | return density, base_mlp_out 493 | 494 | def get_outputs( 495 | self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None,freq_mask=None 496 | ) -> Dict[FieldHeadNames, TensorType]: 497 | outputs_shape = ray_samples.frustums.directions.shape[:-1] 498 | assert density_embedding is not None 499 | if ray_samples.camera_indices is None: 500 | raise AttributeError("Camera indices are not provided.") 501 | camera_indices = ray_samples.camera_indices.squeeze() 502 | directions = shift_directions_for_tcnn(ray_samples.frustums.directions) 503 | directions_flat = directions.view(-1, 3) 504 | d = self.direction_encoding(directions_flat) 505 | outputs_shape = ray_samples.frustums.directions.shape[:-1] 506 | if self.training: 507 | embedded_appearance = 
self.embedding_appearance(camera_indices) 508 | else: 509 | embedded_appearance = torch.zeros( 510 | (*outputs_shape, self.appearance_embedding_dim), 511 | device=ray_samples.frustums.directions.device, 512 | ) 513 | 514 | outputs = {} 515 | if freq_mask is not None : 516 | (positions_freq_mask,dir_freq_mask) = freq_mask 517 | d=d*dir_freq_mask.view(1,-1) 518 | if self.training: 519 | embedded_appearance = self.embedding_appearance(camera_indices) 520 | else: 521 | if self.use_average_appearance_embedding: 522 | embedded_appearance = torch.ones( 523 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device 524 | ) * self.embedding_appearance.mean(dim=0) 525 | else: 526 | embedded_appearance = torch.zeros( 527 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device 528 | ) 529 | if self.use_transient_embedding and self.training: 530 | embedded_transient = self.embedding_transient(camera_indices) 531 | transient_input = torch.cat( 532 | [ 533 | density_embedding.view(-1, self.geo_feat_dim), 534 | embedded_transient.view(-1, self.transient_embedding_dim), 535 | ], 536 | dim=-1, 537 | ) 538 | x = self.mlp_transient(transient_input).view(*outputs_shape, -1).to(directions) 539 | outputs[FieldHeadNames.UNCERTAINTY] = self.field_head_transient_uncertainty(x) 540 | outputs[FieldHeadNames.TRANSIENT_RGB] = self.field_head_transient_rgb(x) 541 | outputs[FieldHeadNames.TRANSIENT_DENSITY] = self.field_head_transient_density(x) 542 | 543 | # semantics 544 | if self.use_semantics: 545 | semantics_input = density_embedding.view(-1, self.geo_feat_dim) 546 | if not self.pass_semantic_gradients: 547 | semantics_input = semantics_input.detach() 548 | 549 | x = self.mlp_semantics(semantics_input).view(*outputs_shape, -1).to(directions) 550 | outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x) 551 | 552 | # predicted normals 553 | if self.use_pred_normals: 554 | positions = ray_samples.frustums.get_positions() 555 | 556 | positions_flat = self.position_encoding(positions.view(-1, 3)) 557 | pred_normals_inp = torch.cat([positions_flat, density_embedding.view(-1, self.geo_feat_dim)], dim=-1) 558 | 559 | x = self.mlp_pred_normals(pred_normals_inp).view(*outputs_shape, -1).to(directions) 560 | outputs[FieldHeadNames.PRED_NORMALS] = self.field_head_pred_normals(x) 561 | 562 | h = torch.cat( 563 | [ 564 | d, 565 | density_embedding.view(-1, self.geo_feat_dim), 566 | embedded_appearance.view(-1, self.appearance_embedding_dim), 567 | ], 568 | dim=-1, 569 | ) 570 | rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions) 571 | outputs.update({FieldHeadNames.RGB: rgb}) 572 | # for field_head in self.field_heads: 573 | # # encoded_dir = self.direction_encoding(ray_samples.frustums.directions.reshape(-1, 3)).view( 574 | # # *outputs_shape, -1 575 | # # ) 576 | # if freq_mask is not None : 577 | # encoded_dir = self.direction_encoding(ray_samples.frustums.directions.reshape(-1, 3))*dir_freq_mask.view(1,-1) 578 | # else: 579 | # encoded_dir = self.direction_encoding(ray_samples.frustums.directions.reshape(-1, 3)) 580 | # mlp_out = self.mlp_head( 581 | # torch.cat( 582 | # [ 583 | # encoded_dir, 584 | # density_embedding.view(-1, density_embedding.shape[-1]), # type:ignore 585 | # embedded_appearance.view(-1, self.appearance_embedding_dim), 586 | # ], 587 | # dim=-1, # type:ignore 588 | # ) 589 | # ).view(*outputs_shape, -1) 590 | # outputs[field_head.field_head_name] = field_head(mlp_out) 591 | 592 | 593 | return outputs 594 | def forward(self, 
ray_samples: RaySamples, compute_normals: bool = False,freq_mask=None) -> Dict[FieldHeadNames, TensorType]: 595 | """Evaluates the field at points along the ray. 596 | 597 | Args: 598 | ray_samples: Samples to evaluate field on. 599 | """ 600 | if compute_normals: 601 | with torch.enable_grad(): 602 | density, density_embedding = self.get_density(ray_samples,freq_mask=freq_mask) 603 | else: 604 | density, density_embedding = self.get_density(ray_samples,freq_mask=freq_mask) 605 | 606 | field_outputs = self.get_outputs(ray_samples, density_embedding=density_embedding,freq_mask=freq_mask) 607 | field_outputs[FieldHeadNames.DENSITY] = density # type: ignore 608 | 609 | if compute_normals: 610 | with torch.enable_grad(): 611 | normals = self.get_normals() 612 | field_outputs[FieldHeadNames.NORMALS] = normals # type: ignore 613 | return field_outputs 614 | 615 | 616 | field_implementation_to_class: Dict[str, Field] = {"torch": TorchFreeNerfactoField} 617 | -------------------------------------------------------------------------------- /freenerf/freenerf_pipeline.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import abstractmethod 3 | from dataclasses import dataclass, field 4 | from time import time 5 | from typing import Any, Dict, List, Mapping, Optional, Type, Union, cast 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from rich.progress import ( 10 | BarColumn, 11 | MofNCompleteColumn, 12 | Progress, 13 | TextColumn, 14 | TimeElapsedColumn, 15 | ) 16 | from torch import nn 17 | from torch.nn import Parameter 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | from typing_extensions import Literal 20 | # from nerfstudio.engine.trainer import TrainerConfig 21 | from freenerf.util import freenerfT 22 | from nerfstudio.pipelines.base_pipeline import ( 23 | VanillaPipeline, 24 | VanillaPipelineConfig, 25 | ) 26 | 27 | from nerfstudio.configs import base_config as cfg 28 | from nerfstudio.data.datamanagers.base_datamanager import ( 29 | DataManager, 30 | DataManagerConfig, 31 | VanillaDataManager, 32 | VanillaDataManagerConfig, 33 | ) 34 | from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes 35 | from nerfstudio.models.base_model import Model, ModelConfig 36 | from nerfstudio.utils import profiler 37 | 38 | @dataclass 39 | class FreeNeRFactoPipelineConfig(VanillaPipelineConfig): 40 | """Configuration for pipeline instantiation""" 41 | 42 | _target: Type = field(default_factory=lambda: FreeNeRFactoPipeline) 43 | """target class to instantiate""" 44 | datamanager: DataManagerConfig = VanillaDataManagerConfig() 45 | """specifies the datamanager config""" 46 | model: ModelConfig = ModelConfig() 47 | """specifies the model config""" 48 | T:int=freenerfT.max_num_iterations*0.9 49 | 50 | 51 | class FreeNeRFactoPipeline(VanillaPipeline): 52 | """The pipeline class for the vanilla nerf setup of multiple cameras for one or a few scenes. 
53 | 54 | config: configuration to instantiate pipeline 55 | device: location to place model and data 56 | test_mode: 57 | 'val': loads train/val datasets into memory 58 | 'test': loads train/test dataset into memory 59 | 'inference': does not load any dataset into memory 60 | world_size: total number of machines available 61 | local_rank: rank of current machine 62 | 63 | Attributes: 64 | datamanager: The data manager that will be used 65 | model: The model that will be used 66 | """ 67 | 68 | def __init__( 69 | self, 70 | config: FreeNeRFactoPipelineConfig, 71 | device: str, 72 | test_mode: Literal["test", "val", "inference"] = "val", 73 | world_size: int = 1, 74 | local_rank: int = 0, 75 | ): 76 | super(VanillaPipeline,self).__init__() 77 | self.config = config 78 | self.test_mode = test_mode 79 | self.datamanager: VanillaDataManager = config.datamanager.setup( 80 | device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank 81 | ) 82 | self.datamanager.to(device) 83 | # TODO(ethan): get rid of scene_bounds from the model 84 | assert self.datamanager.train_dataset is not None, "Missing input dataset" 85 | 86 | self._model = config.model.setup( 87 | scene_box=self.datamanager.train_dataset.scene_box, 88 | num_train_data=len(self.datamanager.train_dataset), 89 | metadata=self.datamanager.train_dataset.metadata, 90 | ) 91 | self.model.to(device) 92 | 93 | self.world_size = world_size 94 | if world_size > 1: 95 | self._model = typing.cast(Model, DDP(self._model, device_ids=[local_rank], find_unused_parameters=True)) 96 | dist.barrier(device_ids=[local_rank]) 97 | 98 | @property 99 | def device(self): 100 | """Returns the device that the model is on.""" 101 | return self.model.device 102 | 103 | @profiler.time_function 104 | def get_train_loss_dict(self, step: int): 105 | """This function gets your training loss dict. This will be responsible for 106 | getting the next batch of data from the DataManager and interfacing with the 107 | Model class, feeding the data to the model's forward function. 
108 | 109 | Args: 110 | step: current iteration step to update sampler if using DDP (distributed) 111 | """ 112 | # print(self._model.field.position_encoding.get_out_dim()) 113 | from freenerf.util import get_freq_mask 114 | pos_freq_mask=get_freq_mask(self._model.field.position_encoding.n_output_dims,step,self.config.T).to(self.device) 115 | dir_freq_mask=get_freq_mask(self._model.field.direction_encoding.n_output_dims,step,self.config.T).to(self.device) 116 | # TODO Need a sanity check 117 | 118 | ray_bundle, batch = self.datamanager.next_train(step) 119 | model_outputs = self.model(ray_bundle,(pos_freq_mask,dir_freq_mask)) 120 | metrics_dict = self.model.get_metrics_dict(model_outputs, batch) 121 | 122 | if self.config.datamanager.camera_optimizer is not None: 123 | camera_opt_param_group = self.config.datamanager.camera_optimizer.param_group 124 | if camera_opt_param_group in self.datamanager.get_param_groups(): 125 | # Report the camera optimization metrics 126 | metrics_dict["camera_opt_translation"] = ( 127 | self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, :3].norm() 128 | ) 129 | metrics_dict["camera_opt_rotation"] = ( 130 | self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, 3:].norm() 131 | ) 132 | 133 | loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) 134 | 135 | return model_outputs, loss_dict, metrics_dict 136 | 137 | def forward(self): 138 | """Blank forward method 139 | 140 | This is an nn.Module, and so requires a forward() method normally, although in our case 141 | we do not need a forward() method""" 142 | raise NotImplementedError 143 | 144 | @profiler.time_function 145 | def get_eval_loss_dict(self, step: int): 146 | """This function gets your evaluation loss dict. It needs to get the data 147 | from the DataManager and feed it to the model's forward function 148 | 149 | Args: 150 | step: current iteration step 151 | """ 152 | self.eval() 153 | ray_bundle, batch = self.datamanager.next_eval(step) 154 | model_outputs = self.model(ray_bundle) 155 | metrics_dict = self.model.get_metrics_dict(model_outputs, batch) 156 | loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) 157 | self.train() 158 | return model_outputs, loss_dict, metrics_dict 159 | 160 | @profiler.time_function 161 | def get_eval_image_metrics_and_images(self, step: int): 162 | """This function gets your evaluation loss dict. It needs to get the data 163 | from the DataManager and feed it to the model's forward function 164 | 165 | Args: 166 | step: current iteration step 167 | """ 168 | self.eval() 169 | image_idx, camera_ray_bundle, batch = self.datamanager.next_eval_image(step) 170 | outputs = self.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle) 171 | metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch) 172 | assert "image_idx" not in metrics_dict 173 | metrics_dict["image_idx"] = image_idx 174 | assert "num_rays" not in metrics_dict 175 | metrics_dict["num_rays"] = len(camera_ray_bundle) 176 | self.train() 177 | return metrics_dict, images_dict 178 | 179 | @profiler.time_function 180 | def get_average_eval_image_metrics(self, step: Optional[int] = None): 181 | """Iterate over all the images in the eval dataset and get the average. 
182 | 183 | Returns: 184 | metrics_dict: dictionary of metrics 185 | """ 186 | self.eval() 187 | metrics_dict_list = [] 188 | num_images = len(self.datamanager.fixed_indices_eval_dataloader) 189 | with Progress( 190 | TextColumn("[progress.description]{task.description}"), 191 | BarColumn(), 192 | TimeElapsedColumn(), 193 | MofNCompleteColumn(), 194 | transient=True, 195 | ) as progress: 196 | task = progress.add_task("[green]Evaluating all eval images...", total=num_images) 197 | for camera_ray_bundle, batch in self.datamanager.fixed_indices_eval_dataloader: 198 | # time this the following line 199 | inner_start = time() 200 | height, width = camera_ray_bundle.shape 201 | num_rays = height * width 202 | outputs = self.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle) 203 | metrics_dict, _ = self.model.get_image_metrics_and_images(outputs, batch) 204 | assert "num_rays_per_sec" not in metrics_dict 205 | metrics_dict["num_rays_per_sec"] = num_rays / (time() - inner_start) 206 | fps_str = "fps" 207 | assert fps_str not in metrics_dict 208 | metrics_dict[fps_str] = metrics_dict["num_rays_per_sec"] / (height * width) 209 | metrics_dict_list.append(metrics_dict) 210 | progress.advance(task) 211 | # average the metrics list 212 | metrics_dict = {} 213 | for key in metrics_dict_list[0].keys(): 214 | metrics_dict[key] = float( 215 | torch.mean(torch.tensor([metrics_dict[key] for metrics_dict in metrics_dict_list])) 216 | ) 217 | self.train() 218 | return metrics_dict 219 | 220 | def load_pipeline(self, loaded_state: Dict[str, Any], step: int) -> None: 221 | """Load the checkpoint from the given path 222 | 223 | Args: 224 | loaded_state: pre-trained model state dict 225 | step: training step of the loaded checkpoint 226 | """ 227 | state = { 228 | (key[len("module.") :] if key.startswith("module.") else key): value for key, value in loaded_state.items() 229 | } 230 | self._model.update_to_step(step) 231 | self.load_state_dict(state, strict=True) 232 | 233 | def get_training_callbacks( 234 | self, training_callback_attributes: TrainingCallbackAttributes 235 | ) -> List[TrainingCallback]: 236 | """Returns the training callbacks from both the Dataloader and the Model.""" 237 | datamanager_callbacks = self.datamanager.get_training_callbacks(training_callback_attributes) 238 | model_callbacks = self.model.get_training_callbacks(training_callback_attributes) 239 | callbacks = datamanager_callbacks + model_callbacks 240 | return callbacks 241 | 242 | def get_param_groups(self) -> Dict[str, List[Parameter]]: 243 | """Get the param groups for the pipeline. 244 | 245 | Returns: 246 | A list of dictionaries containing the pipeline's param groups. 
247 |         """
248 |         datamanager_params = self.datamanager.get_param_groups()
249 |         model_params = self.model.get_param_groups()
250 |         # TODO(ethan): assert that key names don't overlap
251 |         return {**datamanager_params, **model_params}
252 | 
--------------------------------------------------------------------------------
/freenerf/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataclasses import dataclass, field
3 | @dataclass
4 | class freenerfT:
5 |     max_num_iterations: int = 30000
6 | 
7 | def get_freq_mask(pos_enc_length, current_iter, total_reg_iter):
8 |     if current_iter < total_reg_iter:
9 |         freq_mask = torch.zeros(pos_enc_length)
10 |         ptr = pos_enc_length / 3 * current_iter / total_reg_iter + 1  # number of 3-entry groups to unmask, grows linearly
11 |         ptr = ptr if ptr < pos_enc_length / 3 else pos_enc_length / 3
12 |         int_ptr = int(ptr)
13 |         freq_mask[: int_ptr * 3] = 1.0  # assign the integer part
14 |         freq_mask[int_ptr * 3 : int_ptr * 3 + 3] = ptr - int_ptr  # assign the fractional part
15 |         return torch.clip(freq_mask, 1e-8, 1 - 1e-8)  # for numerical stability
16 |     else:
17 |         return torch.ones(pos_enc_length)
18 | 
19 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "freenerfacto"
3 | version = "0.1.0"
4 | 
5 | dependencies = [
6 |     "nerfstudio"
7 | ]
8 | 
9 | [project.entry-points.'nerfstudio.method_configs']
10 | freenerfacto = 'freenerf.freenerf_config:freenerfacto_method'
11 | 
--------------------------------------------------------------------------------
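
As a quick reference, here is a minimal sketch (not part of the repository) of how the frequency mask from `freenerf/util.py` grows during training. The encoding length of 48 is an illustrative value only; in `FreeNeRFactoPipeline.get_train_loss_dict` the real lengths come from `position_encoding.n_output_dims` and `direction_encoding.n_output_dims`, and `T` defaults to `0.9 * max_num_iterations` as in `FreeNeRFactoPipelineConfig`.

```python
from freenerf.util import freenerfT, get_freq_mask

pos_enc_length = 48  # illustrative only; the pipeline uses the field's encoder output dims
T = int(freenerfT.max_num_iterations * 0.9)  # same schedule as FreeNeRFactoPipelineConfig.T

for step in (0, T // 4, T // 2, 3 * T // 4, T):
    mask = get_freq_mask(pos_enc_length, step, T)
    # Early on only the first few entries are (almost) 1; once step >= T the mask is all
    # ones, so the regularization has no effect for the remaining iterations.
    print(f"step {step:>5}: active entries ~ {mask.sum().item():.1f} / {pos_enc_length}")
```

During training, the pipeline passes the resulting `(pos_freq_mask, dir_freq_mask)` pair through `FreeNeRFactoModel.forward` down to `TorchFreeNerfactoField`, where the masks are multiplied element-wise with the position and direction encodings.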