├── switch_nerf ├── __init__.py ├── datasets │ ├── __init__.py │ ├── nerf_data │ │ ├── __init__.py │ │ ├── load_gigapixel.py │ │ ├── ray_utils.py │ │ ├── load_blender.py │ │ ├── load_LINEMOD.py │ │ ├── load_deepvoxels.py │ │ ├── load_bungee.py │ │ ├── nerf_loader.py │ │ └── load_llff.py │ ├── lists │ │ └── block_nerf_train_val_dummy.txt │ ├── dataset_utils.py │ └── memory_dataset.py ├── models │ ├── __init__.py │ ├── cascade.py │ ├── mega_nerf_container.py │ ├── mega_nerf.py │ ├── nerf.py │ └── model_utils.py ├── modules │ ├── __init__.py │ └── tutel_moe_ext │ │ ├── tutel_moe_nobatch.py │ │ ├── tutel_system.py │ │ ├── tutel_sparse_nobatch.py │ │ ├── torch_moe_layer_nobatch.py │ │ ├── tutel_communicate_nobatch.py │ │ └── tutel_fast_dispatch.py ├── utils │ ├── __init__.py │ ├── logger.py │ └── functions.py ├── train.py ├── misc_utils.py ├── train_nerf_moe.py ├── eval.py ├── eval_ckpt.py ├── eval_image.py ├── eval_points.py ├── eval_nerf_moe.py ├── eval_image_blocknerf.py ├── configs │ └── switch_nerf │ │ ├── rubble.yaml │ │ ├── building.yaml │ │ ├── campus.yaml │ │ ├── residence.yaml │ │ ├── sci-art.yaml │ │ ├── mission_bay.yaml │ │ └── bungee.yaml ├── image_metadata.py ├── scripts │ ├── copy_images.py │ ├── convert_to_container_moe.py │ ├── merge_points.py │ └── create_octree_moe.py ├── ray_utils.py ├── spherical_harmonics.py └── metrics.py ├── .vscode ├── settings.json └── launch.json ├── requirements.txt ├── install_tutel.md ├── LICENSE ├── .gitignore └── README.md /switch_nerf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | } -------------------------------------------------------------------------------- /switch_nerf/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /switch_nerf/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /switch_nerf/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /switch_nerf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /switch_nerf/datasets/lists/block_nerf_train_val_dummy.txt: -------------------------------------------------------------------------------- 1 | waymo_block_nerf_mission_bay_train.tfrecord-00000-of-01063 2 | waymo_block_nerf_mission_bay_train.tfrecord-00001-of-01063 3 | waymo_block_nerf_mission_bay_validation.tfrecord-00000-of-00373 4 | waymo_block_nerf_mission_bay_validation.tfrecord-00001-of-00373 -------------------------------------------------------------------------------- /switch_nerf/modules/tutel_moe_ext/tutel_moe_nobatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT license. 3 | 4 | 5 | # Level-level Ops 6 | from tutel.jit_kernels.gating import fast_cumsum_sub_one 7 | from .tutel_fast_dispatch_nobatch import fast_dispatcher, extract_critical, fast_encode, fast_decode 8 | 9 | # High-level Ops 10 | from .tutel_moe_layer_nobatch import moe_layer, SingleExpert 11 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/load_gigapixel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import imageio 3 | import torchvision 4 | import cv2 5 | import math 6 | 7 | def load_gigapixel_data(img_path, scale=1.0): 8 | img = imageio.imread(img_path) / 255.0 9 | H, W = img.shape[0:2] 10 | 11 | if scale < 1.0: 12 | H = math.floor(scale * H) 13 | W = math.floor(scale * W) 14 | 15 | img = cv2.resize(img, (W, H), interpolation=cv2.INTER_LINEAR) 16 | 17 | return img -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # nerf_3_8 python 3.8 2 | --find-links https://download.pytorch.org/whl/torch_stable.html 3 | torch==1.10.0+cu111 4 | torchvision==0.11.0+cu111 5 | torchaudio==0.10.0 6 | tqdm 7 | ConfigArgParse 8 | # deepspeed 9 | einops 10 | imageio 11 | imageio-ffmpeg 12 | matplotlib 13 | tensorboard 14 | # tensorboard==2.7 15 | PyYAML 16 | npy_append_array 17 | parscript 18 | opencv-python 19 | setuptools==58.0.4 20 | lpips==0.1.4 21 | # tutel 22 | fairscale 23 | protobuf==3.20.1 24 | torch_tb_profiler 25 | timm 26 | plyfile 27 | tensorflow==2.10.0 -------------------------------------------------------------------------------- /switch_nerf/models/cascade.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class Cascade(nn.Module): 8 | def __init__(self, coarse: nn.Module, fine: nn.Module): 9 | super(Cascade, self).__init__() 10 | self.coarse = coarse 11 | self.fine = fine 12 | 13 | def forward(self, use_coarse: bool, x: torch.Tensor, sigma_only: bool = False, 14 | sigma_noise: Optional[torch.Tensor] = None) -> torch.Tensor: 15 | if use_coarse: 16 | return self.coarse(x, sigma_only, sigma_noise) 17 | else: 18 | return self.fine(x, sigma_only, sigma_noise) 19 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "env": { 15 | // Enable this to turn on redux logging during debugging 16 | // "CUDA_VISIBLE_DEVICES": "4,5,6" 17 | "PYTHONPATH": "${workspaceRoot}" 18 | } 19 | } 20 | ] 21 | } -------------------------------------------------------------------------------- /switch_nerf/train.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts import get_opts_base 7 | from switch_nerf.runner import Runner 8 | 9 | 10 | def _get_train_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | if hparams.detect_anomalies: 21 | with torch.autograd.detect_anomaly(): 22 | Runner(hparams).train() 23 | else: 24 | Runner(hparams).train() 25 | 26 | 27 | if __name__ == '__main__': 28 | main(_get_train_opts()) 29 | -------------------------------------------------------------------------------- /switch_nerf/misc_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import logging 5 | 6 | def main_print(log) -> None: 7 | if ('LOCAL_RANK' not in os.environ) or int(os.environ['LOCAL_RANK']) == 0: 8 | print(log) 9 | 10 | def main_log(log) -> None: 11 | if ('LOCAL_RANK' not in os.environ) or int(os.environ['LOCAL_RANK']) == 0: 12 | logger = logging.getLogger() 13 | logger.info(log) 14 | 15 | def process_log(log) -> None: 16 | logger = logging.getLogger() 17 | logger.info(log) 18 | 19 | def main_tqdm(inner): 20 | if ('LOCAL_RANK' not in os.environ) or int(os.environ['LOCAL_RANK']) == 0: 21 | return tqdm(inner) 22 | else: 23 | return inner 24 | 25 | # https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8 26 | def count_parameters(model): 27 | return sum(p.numel() for p in model.parameters() if p.requires_grad) -------------------------------------------------------------------------------- /switch_nerf/train_nerf_moe.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts_nerf import get_opts_base 7 | from switch_nerf.runner import Runner 8 | 9 | 10 | def _get_train_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | assert hparams.data_type == "nerf" 21 | if hparams.detect_anomalies: 22 | with torch.autograd.detect_anomaly(): 23 | Runner(hparams).train_nerf() 24 | else: 25 | Runner(hparams).train_nerf() 26 | 27 | 28 | if __name__ == '__main__': 29 | main(_get_train_opts()) 30 | 
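# ====== Editor's note: usage sketch, not part of the repository files ==========
# The entry points above (train.py, train_nerf_moe.py) are run once per GPU and
# use torch.distributed.elastic's @record, so each process sees a LOCAL_RANK
# environment variable. The helpers in switch_nerf/misc_utils.py gate console
# output on that variable so that only rank 0 prints or shows a progress bar.
# A minimal sketch of that behaviour, assuming the repository root is on
# PYTHONPATH (as .vscode/launch.json configures) and tqdm is installed:
import os
from switch_nerf.misc_utils import main_print, main_tqdm

os.environ['LOCAL_RANK'] = '1'      # pretend to be a non-zero rank
main_print('hidden on this rank')   # prints nothing
for _ in main_tqdm(range(3)):       # falls back to the plain iterable, no bar
    pass

os.environ['LOCAL_RANK'] = '0'      # rank 0 behaves normally
main_print('visible on rank 0')     # printed
for _ in main_tqdm(range(3)):       # wrapped in tqdm, progress bar shown
    pass
# ================================================================================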
-------------------------------------------------------------------------------- /switch_nerf/eval.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts import get_opts_base 7 | from switch_nerf.runner import Runner 8 | 9 | 10 | def _get_eval_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | assert hparams.ckpt_path is not None or hparams.container_path is not None 21 | 22 | if hparams.detect_anomalies: 23 | with torch.autograd.detect_anomaly(): 24 | Runner(hparams).eval() 25 | else: 26 | Runner(hparams).eval() 27 | 28 | 29 | if __name__ == '__main__': 30 | main(_get_eval_opts()) 31 | -------------------------------------------------------------------------------- /switch_nerf/eval_ckpt.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts import get_opts_base 7 | from switch_nerf.runner import Runner 8 | 9 | 10 | def _get_eval_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | assert hparams.ckpt_path is not None or hparams.container_path is not None 21 | 22 | if hparams.detect_anomalies: 23 | with torch.autograd.detect_anomaly(): 24 | Runner(hparams).eval_ckpt() 25 | else: 26 | Runner(hparams).eval_ckpt() 27 | 28 | 29 | if __name__ == '__main__': 30 | main(_get_eval_opts()) 31 | -------------------------------------------------------------------------------- /switch_nerf/eval_image.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts import get_opts_base 7 | from switch_nerf.runner import Runner 8 | 9 | 10 | def _get_eval_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | assert hparams.ckpt_path is not None or hparams.container_path is not None 21 | 22 | if hparams.detect_anomalies: 23 | with torch.autograd.detect_anomaly(): 24 | Runner(hparams).eval_image() 25 | else: 26 | Runner(hparams).eval_image() 27 | 28 | 29 | if __name__ == '__main__': 30 | main(_get_eval_opts()) 31 | -------------------------------------------------------------------------------- /switch_nerf/eval_points.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts import get_opts_base 7 | from 
switch_nerf.runner import Runner 8 | 9 | 10 | def _get_eval_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | assert hparams.ckpt_path is not None or hparams.container_path is not None 21 | 22 | if hparams.detect_anomalies: 23 | with torch.autograd.detect_anomaly(): 24 | Runner(hparams).eval_points() 25 | else: 26 | Runner(hparams).eval_points() 27 | 28 | 29 | if __name__ == '__main__': 30 | main(_get_eval_opts()) 31 | -------------------------------------------------------------------------------- /switch_nerf/eval_nerf_moe.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts_nerf import get_opts_base 7 | from switch_nerf.runner import Runner 8 | 9 | 10 | def _get_eval_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | assert hparams.ckpt_path is not None or hparams.container_path is not None 21 | 22 | if hparams.detect_anomalies: 23 | with torch.autograd.detect_anomaly(): 24 | Runner(hparams).eval_nerf() 25 | else: 26 | Runner(hparams).eval_nerf() 27 | 28 | 29 | if __name__ == '__main__': 30 | main(_get_eval_opts()) 31 | -------------------------------------------------------------------------------- /switch_nerf/eval_image_blocknerf.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from switch_nerf.opts import get_opts_base 7 | from switch_nerf.runner import Runner 8 | 9 | 10 | def _get_eval_opts() -> Namespace: 11 | parser = get_opts_base() 12 | 13 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 14 | parser.add_argument('--dataset_path', type=str, required=True) 15 | 16 | return parser.parse_args() 17 | 18 | @record 19 | def main(hparams: Namespace) -> None: 20 | assert hparams.ckpt_path is not None or hparams.container_path is not None 21 | 22 | if hparams.detect_anomalies: 23 | with torch.autograd.detect_anomaly(): 24 | Runner(hparams).eval_image_blocknerf() 25 | else: 26 | Runner(hparams).eval_image_blocknerf() 27 | 28 | 29 | if __name__ == '__main__': 30 | main(_get_eval_opts()) 31 | -------------------------------------------------------------------------------- /install_tutel.md: -------------------------------------------------------------------------------- 1 | # Install Tutel 2 | 3 | Our method depends on a very early [version](https://github.com/microsoft/tutel/tree/56dbd664341cf6485c9fa292955f77d3ac918a65) of Tutel. The commit id is 56dbd664341cf6485c9fa292955f77d3ac918a65. Tutel has changed a lot since then, so please make sure you download the correct version. 4 | 5 | After downloading the code, run: 6 | 7 | ```sh 8 | cd tutel 9 | python3 ./setup.py install 10 | ``` 11 | 12 | You may need to change the CUDA version. 
Just search `/usr/local/cuda` in an editor and change all occurrences to your CUDA location, such as `/usr/local/cuda-11.1`. 13 | 14 | You may need to install the NCCL library. Please follow the instructions on its website. 15 | 16 | If you install NCCL manually, you should add the include path to the `setup.py` of tutel. You can add `ext_args['cxx'] += ['-I/your/path/to/NCCL/include']` after Line 68 in the `setup.py`. 17 | If you have several local libraries installed in your own directory and you need them for the compilation, you should add `library_dirs += ['/your/local/lib']` after Line 115 in the `setup.py`. -------------------------------------------------------------------------------- /switch_nerf/models/mega_nerf_container.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class MegaNeRFContainer(nn.Module): 8 | def __init__(self, sub_modules: List[nn.Module], bg_sub_modules: List[nn.Module], centroids: torch.Tensor, 9 | grid_dim: torch.Tensor, min_position: torch.Tensor, max_position: torch.Tensor, need_viewdir: bool, 10 | need_appearance_embedding: bool, cluster_2d: bool): 11 | super(MegaNeRFContainer, self).__init__() 12 | 13 | for i, sub_module in enumerate(sub_modules): 14 | setattr(self, 'sub_module_{}'.format(i), sub_module) 15 | 16 | for i, bg_sub_module in enumerate(bg_sub_modules): 17 | setattr(self, 'bg_sub_module_{}'.format(i), bg_sub_module) 18 | 19 | self.centroids = centroids 20 | self.grid_dim = grid_dim 21 | self.min_position = min_position 22 | self.max_position = max_position 23 | self.need_viewdir = need_viewdir 24 | self.need_appearance_embedding = need_appearance_embedding 25 | self.cluster_2d = cluster_2d 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MiZhenxing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /switch_nerf/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
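# ====== Editor's note: sanity-check sketch, not part of logger.py ==============
# This relates to install_tutel.md above. The repository only imports a few
# symbols from the pinned Tutel commit (see switch_nerf/modules/tutel_moe_ext/),
# so a quick way to confirm that the build succeeded is to import exactly those
# symbols. A minimal sketch, assuming Tutel was installed into the active
# Python environment as described above:
from tutel import net                                      # used by tutel_system.py
from tutel.jit_kernels.gating import fast_cumsum_sub_one   # used by tutel_moe_nobatch.py

print('Tutel import check passed')
# ================================================================================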
2 | import logging 3 | import os 4 | import sys 5 | import time 6 | import numpy as np 7 | import torch 8 | from os.path import join 9 | import cv2 10 | 11 | 12 | def setup_logger(name, save_dir, prefix="", timestamp=True): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.INFO) 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.INFO) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | logger.addHandler(ch) 20 | 21 | if save_dir: 22 | timestamp = time.strftime(".%m_%d_%H_%M_%S") if timestamp else "" 23 | prefix = "." + prefix if prefix else "" 24 | log_file = os.path.join(save_dir, "log{}.txt".format(prefix + timestamp)) 25 | fh = logging.FileHandler(log_file) 26 | fh.setLevel(logging.INFO) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | logger.propagate = False 31 | return logger 32 | 33 | 34 | def shutdown_logger(logger): 35 | logger.handlers = [] 36 | 37 | def setup_logger_file(logger, save_dir, prefix="", timestamp=True): 38 | timestamp = time.strftime(".%m_%d_%H_%M_%S") if timestamp else "" 39 | prefix = "." + prefix if prefix else "" 40 | log_file = os.path.join(save_dir, "log{}.txt".format(prefix + timestamp)) 41 | fh = logging.FileHandler(log_file) 42 | fh.setLevel(logging.INFO) 43 | fh.setFormatter(logger.handlers[0].formatter) 44 | logger.addHandler(fh) 45 | return logger -------------------------------------------------------------------------------- /switch_nerf/configs/switch_nerf/rubble.yaml: -------------------------------------------------------------------------------- 1 | ray_altitude_range: [11, 38] 2 | appearance_dim: 48 3 | 4 | data_type: "mega_nerf" 5 | nerfmoe_class_name: "NeRFMoE" 6 | model: 7 | layer_num_main: 3 8 | sigma_tag: 0 9 | dir_tag: 1 10 | color_tag: 2 11 | 12 | layers: 13 | "xyz": 14 | in_ch: 75 # 3 + 12 * 3 * 2 15 | h_ch: 0 16 | out_ch: 256 17 | num: 1 18 | type: "mlp" 19 | act: "none" 20 | 21 | "0": 22 | in_ch: 256 23 | h_ch: 256 24 | out_ch: 256 25 | num: 7 26 | skips: [3] 27 | init_factor: 1.0 28 | type: "moe" 29 | act: "relu" 30 | 31 | # gate 32 | gate_type: "top" 33 | k: 1 34 | fp32_gate: True 35 | gate_dim: 256 36 | 37 | "1": # xyz_encoding_final 38 | in_ch: 256 39 | h_ch: 0 40 | out_ch: 256 41 | num: 1 42 | type: "mlp" 43 | act: "none" 44 | 45 | "2": # dir_a_encoding 46 | in_ch: 331 # 256 + 27 + 48 47 | h_ch: 0 48 | out_ch: 128 49 | num: 1 50 | type: "mlp" 51 | act: "relu" 52 | 53 | sigma: # sigma 54 | in_ch: 256 55 | h_ch: 0 56 | out_ch: 1 57 | num: 1 58 | type: "mlp" 59 | act: "none" 60 | 61 | color: # rgb 62 | in_ch: 128 63 | h_ch: 0 64 | out_ch: 3 65 | num: 1 66 | type: "mlp" 67 | act: "none" 68 | 69 | moe_external_gate: 70 | in_ch: 256 71 | h_ch: 256 72 | out_ch: 256 73 | num: 2 74 | type: "mlp" 75 | act: "none" 76 | out_skip: False 77 | 78 | gate_input_norm: 79 | in_ch: 256 80 | h_ch: 0 81 | out_ch: 0 82 | num: 1 83 | type: "layernorm" -------------------------------------------------------------------------------- /switch_nerf/configs/switch_nerf/building.yaml: -------------------------------------------------------------------------------- 1 | ray_altitude_range: [8, 50] 2 | appearance_dim: 48 3 | 4 | data_type: "mega_nerf" 5 | nerfmoe_class_name: "NeRFMoE" 6 | model: 7 | layer_num_main: 3 8 | sigma_tag: 0 9 | dir_tag: 1 10 | color_tag: 2 11 | 12 | layers: 13 | "xyz": 14 | in_ch: 75 # 3 + 12 * 3 * 2 15 | h_ch: 0 16 | out_ch: 256 17 | num: 1 18 | type: "mlp" 19 | act: "none" 20 | 21 | "0": 22 | in_ch: 256 23 
| h_ch: 256 24 | out_ch: 256 25 | num: 7 26 | skips: [3] 27 | init_factor: 1.0 28 | type: "moe" 29 | act: "relu" 30 | 31 | # gate 32 | gate_type: "top" 33 | k: 1 34 | fp32_gate: True 35 | gate_dim: 256 36 | 37 | "1": # xyz_encoding_final 38 | in_ch: 256 39 | h_ch: 0 40 | out_ch: 256 41 | num: 1 42 | type: "mlp" 43 | act: "none" 44 | 45 | "2": # dir_a_encoding 46 | in_ch: 331 # 256 + 27 + 48 47 | h_ch: 0 48 | out_ch: 128 49 | num: 1 50 | type: "mlp" 51 | act: "relu" 52 | 53 | sigma: # sigma 54 | in_ch: 256 55 | h_ch: 0 56 | out_ch: 1 57 | num: 1 58 | type: "mlp" 59 | act: "none" 60 | 61 | color: # rgb 62 | in_ch: 128 63 | h_ch: 0 64 | out_ch: 3 65 | num: 1 66 | type: "mlp" 67 | act: "none" 68 | 69 | moe_external_gate: 70 | in_ch: 256 71 | h_ch: 256 72 | out_ch: 256 73 | num: 2 74 | type: "mlp" 75 | act: "none" 76 | out_skip: False 77 | 78 | gate_input_norm: 79 | in_ch: 256 80 | h_ch: 0 81 | out_ch: 0 82 | num: 1 83 | type: "layernorm" -------------------------------------------------------------------------------- /switch_nerf/configs/switch_nerf/campus.yaml: -------------------------------------------------------------------------------- 1 | ray_altitude_range: [3, 132] 2 | appearance_dim: 48 3 | 4 | data_type: "mega_nerf" 5 | nerfmoe_class_name: "NeRFMoE" 6 | model: 7 | layer_num_main: 3 8 | sigma_tag: 0 9 | dir_tag: 1 10 | color_tag: 2 11 | 12 | layers: 13 | "xyz": 14 | in_ch: 75 # 3 + 12 * 3 * 2 15 | h_ch: 0 16 | out_ch: 256 17 | num: 1 18 | type: "mlp" 19 | act: "none" 20 | 21 | "0": 22 | in_ch: 256 23 | h_ch: 256 24 | out_ch: 256 25 | num: 7 26 | skips: [3] 27 | init_factor: 1.0 28 | type: "moe" 29 | act: "relu" 30 | 31 | # gate 32 | gate_type: "top" 33 | k: 1 34 | fp32_gate: True 35 | gate_dim: 256 36 | 37 | "1": # xyz_encoding_final 38 | in_ch: 256 39 | h_ch: 0 40 | out_ch: 256 41 | num: 1 42 | type: "mlp" 43 | act: "none" 44 | 45 | "2": # dir_a_encoding 46 | in_ch: 331 # 256 + 27 + 48 47 | h_ch: 0 48 | out_ch: 128 49 | num: 1 50 | type: "mlp" 51 | act: "relu" 52 | 53 | sigma: # sigma 54 | in_ch: 256 55 | h_ch: 0 56 | out_ch: 1 57 | num: 1 58 | type: "mlp" 59 | act: "none" 60 | 61 | color: # rgb 62 | in_ch: 128 63 | h_ch: 0 64 | out_ch: 3 65 | num: 1 66 | type: "mlp" 67 | act: "none" 68 | 69 | moe_external_gate: 70 | in_ch: 256 71 | h_ch: 256 72 | out_ch: 256 73 | num: 2 74 | type: "mlp" 75 | act: "none" 76 | out_skip: False 77 | 78 | gate_input_norm: 79 | in_ch: 256 80 | h_ch: 0 81 | out_ch: 0 82 | num: 1 83 | type: "layernorm" -------------------------------------------------------------------------------- /switch_nerf/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | 3 | import torch 4 | 5 | from switch_nerf.image_metadata import ImageMetadata 6 | 7 | 8 | def get_rgb_index_mask(metadata: ImageMetadata) -> Optional[ 9 | Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: 10 | rgbs = metadata.load_image().view(-1, 3) 11 | 12 | keep_mask = metadata.load_mask() 13 | 14 | if metadata.is_val: 15 | if keep_mask is None: 16 | keep_mask = torch.ones(metadata.H, metadata.W, dtype=torch.bool) 17 | else: 18 | # Get how many pixels we're discarding that would otherwise be added 19 | discard_half = keep_mask[:, metadata.W // 2:] 20 | discard_pos_count = discard_half[discard_half == True].shape[0] 21 | 22 | candidates_to_add = torch.arange(metadata.H * metadata.W).view(metadata.H, metadata.W)[:, :metadata.W // 2] 23 | keep_half = keep_mask[:, :metadata.W // 2] 24 | candidates_to_add = 
candidates_to_add[keep_half == False].reshape(-1) 25 | to_add = candidates_to_add[torch.randperm(candidates_to_add.shape[0])[:discard_pos_count]] 26 | 27 | keep_mask.view(-1).scatter_(0, to_add, torch.ones_like(to_add, dtype=torch.bool)) 28 | 29 | keep_mask[:, metadata.W // 2:] = False 30 | 31 | if keep_mask is not None: 32 | if keep_mask[keep_mask == True].shape[0] == 0: 33 | return None 34 | 35 | keep_mask = keep_mask.view(-1) 36 | rgbs = rgbs[keep_mask == True] 37 | 38 | assert metadata.image_index <= torch.iinfo(torch.short).max 39 | return rgbs, metadata.image_index * torch.ones(rgbs.shape[0], dtype=torch.short), keep_mask 40 | -------------------------------------------------------------------------------- /switch_nerf/configs/switch_nerf/residence.yaml: -------------------------------------------------------------------------------- 1 | ray_altitude_range: [30, 118] 2 | appearance_dim: 48 3 | 4 | data_type: "mega_nerf" 5 | nerfmoe_class_name: "NeRFMoE" 6 | model: 7 | layer_num_main: 3 8 | sigma_tag: 0 9 | dir_tag: 1 10 | color_tag: 2 11 | 12 | layers: 13 | "xyz": 14 | in_ch: 75 # 3 + 12 * 3 * 2 15 | h_ch: 0 16 | out_ch: 256 17 | num: 1 18 | type: "mlp" 19 | act: "none" 20 | 21 | "0": 22 | in_ch: 256 23 | h_ch: 256 24 | out_ch: 256 25 | num: 7 26 | skips: [3] 27 | init_factor: 1.0 28 | type: "moe" 29 | act: "relu" 30 | 31 | # gate 32 | gate_type: "top" 33 | k: 1 34 | fp32_gate: True 35 | gate_dim: 256 36 | 37 | "1": # xyz_encoding_final 38 | in_ch: 256 39 | h_ch: 0 40 | out_ch: 256 41 | num: 1 42 | type: "mlp" 43 | act: "none" 44 | 45 | "2": # dir_a_encoding 46 | in_ch: 331 # 256 + 27 + 48 47 | h_ch: 0 48 | out_ch: 128 49 | num: 1 50 | type: "mlp" 51 | act: "relu" 52 | 53 | sigma: # sigma 54 | in_ch: 256 55 | h_ch: 0 56 | out_ch: 1 57 | num: 1 58 | type: "mlp" 59 | act: "none" 60 | 61 | color: # rgb 62 | in_ch: 128 63 | h_ch: 0 64 | out_ch: 3 65 | num: 1 66 | type: "mlp" 67 | act: "none" 68 | 69 | moe_external_gate: 70 | in_ch: 256 71 | h_ch: 256 72 | out_ch: 256 73 | num: 2 74 | type: "mlp" 75 | act: "none" 76 | out_skip: False 77 | 78 | gate_input_norm: 79 | in_ch: 256 80 | h_ch: 0 81 | out_ch: 0 82 | num: 1 83 | type: "layernorm" -------------------------------------------------------------------------------- /switch_nerf/configs/switch_nerf/sci-art.yaml: -------------------------------------------------------------------------------- 1 | ray_altitude_range: [-14, 70] 2 | appearance_dim: 48 3 | 4 | data_type: "mega_nerf" 5 | nerfmoe_class_name: "NeRFMoE" 6 | model: 7 | layer_num_main: 3 8 | sigma_tag: 0 9 | dir_tag: 1 10 | color_tag: 2 11 | 12 | layers: 13 | "xyz": 14 | in_ch: 75 # 3 + 12 * 3 * 2 15 | h_ch: 0 16 | out_ch: 256 17 | num: 1 18 | type: "mlp" 19 | act: "none" 20 | 21 | "0": 22 | in_ch: 256 23 | h_ch: 256 24 | out_ch: 256 25 | num: 7 26 | skips: [3] 27 | init_factor: 1.0 28 | type: "moe" 29 | act: "relu" 30 | 31 | # gate 32 | gate_type: "top" 33 | k: 1 34 | fp32_gate: True 35 | gate_dim: 256 36 | 37 | "1": # xyz_encoding_final 38 | in_ch: 256 39 | h_ch: 0 40 | out_ch: 256 41 | num: 1 42 | type: "mlp" 43 | act: "none" 44 | 45 | "2": # dir_a_encoding 46 | in_ch: 331 # 256 + 27 + 48 47 | h_ch: 0 48 | out_ch: 128 49 | num: 1 50 | type: "mlp" 51 | act: "relu" 52 | 53 | sigma: # sigma 54 | in_ch: 256 55 | h_ch: 0 56 | out_ch: 1 57 | num: 1 58 | type: "mlp" 59 | act: "none" 60 | 61 | color: # rgb 62 | in_ch: 128 63 | h_ch: 0 64 | out_ch: 3 65 | num: 1 66 | type: "mlp" 67 | act: "none" 68 | 69 | moe_external_gate: 70 | in_ch: 256 71 | h_ch: 256 72 | out_ch: 256 73 | num: 2 74 | 
type: "mlp" 75 | act: "none" 76 | out_skip: False 77 | 78 | gate_input_norm: 79 | in_ch: 256 80 | h_ch: 0 81 | out_ch: 0 82 | num: 1 83 | type: "layernorm" -------------------------------------------------------------------------------- /switch_nerf/configs/switch_nerf/mission_bay.yaml: -------------------------------------------------------------------------------- 1 | data_type: "block_nerf" 2 | appearance_dim: 48 3 | val_scale_factor: 1 4 | use_mip: True 5 | use_moe: True 6 | no_bg_nerf: True 7 | pos_xyz_dim: 12 8 | pos_dir_dim: 4 9 | fine_samples: 513 10 | coarse_samples: 513 11 | training_step_fn: "_training_step_mip" 12 | nerfmoe_class_name: "MipNeRFMoE" 13 | 14 | model: 15 | layer_num_main: 3 16 | sigma_tag: 0 17 | dir_tag: 1 18 | color_tag: 2 19 | 20 | layers: 21 | "xyz": 22 | in_ch: 75 # 3 + 12 * 3 * 2 23 | h_ch: 0 24 | out_ch: 512 25 | num: 1 26 | type: "mlp" 27 | act: "none" 28 | 29 | "0": 30 | in_ch: 512 31 | h_ch: 512 32 | out_ch: 512 33 | num: 7 34 | skips: [3] 35 | init_factor: 1.0 36 | type: "moe" 37 | act: "relu" 38 | 39 | # gate 40 | gate_type: "top" 41 | k: 1 42 | fp32_gate: True 43 | gate_dim: 512 44 | 45 | "1": # xyz_encoding_final 46 | in_ch: 512 47 | h_ch: 0 48 | out_ch: 512 49 | num: 1 50 | type: "mlp" 51 | act: "none" 52 | 53 | "2": # dir_a_encoding 54 | in_ch: 587 # 512 + 27 + 48 55 | h_ch: 0 56 | out_ch: 128 57 | num: 1 58 | type: "mlp" 59 | act: "relu" 60 | 61 | sigma: # sigma 62 | in_ch: 512 63 | h_ch: 0 64 | out_ch: 1 65 | num: 1 66 | type: "mlp" 67 | act: "none" 68 | 69 | color: # rgb 70 | in_ch: 128 71 | h_ch: 0 72 | out_ch: 3 73 | num: 1 74 | type: "mlp" 75 | act: "none" 76 | 77 | moe_external_gate: 78 | in_ch: 512 79 | h_ch: 512 80 | out_ch: 512 81 | num: 2 82 | type: "mlp" 83 | act: "none" 84 | out_skip: False 85 | 86 | gate_input_norm: 87 | in_ch: 512 88 | h_ch: 0 89 | out_ch: 0 90 | num: 1 91 | type: "layernorm" -------------------------------------------------------------------------------- /switch_nerf/configs/switch_nerf/bungee.yaml: -------------------------------------------------------------------------------- 1 | data_type: "nerf" 2 | dataset_type: "bungee" 3 | # white_bkgd: True 4 | appearance_dim: 0 5 | # use_viewdirs: True 6 | use_moe: True 7 | use_mip: True 8 | no_bg_nerf: True 9 | pos_xyz_dim: 10 10 | pos_dir_dim: 4 11 | fine_samples: 65 12 | coarse_samples: 65 13 | training_step_fn: "_training_step_nerf_mip" 14 | nerfmoe_class_name: "MipNeRFMoE" 15 | # moe_expert_num: 4 16 | llffhold: 16 17 | scale_factor: 3 18 | bungee_ray_nearfar: "sphere" 19 | 20 | model: 21 | layer_num_main: 3 22 | sigma_tag: 0 23 | dir_tag: 1 24 | color_tag: 2 25 | 26 | layers: 27 | "xyz": 28 | in_ch: 63 # 3 + 10 * 3 * 2 29 | h_ch: 0 30 | out_ch: 256 31 | num: 1 32 | type: "mlp" 33 | act: "none" 34 | 35 | "0": 36 | in_ch: 256 37 | h_ch: 256 38 | out_ch: 256 39 | num: 7 40 | skips: [3] 41 | init_factor: 1.0 42 | type: "moe" 43 | act: "relu" 44 | 45 | # gate 46 | gate_type: "top" 47 | k: 1 48 | fp32_gate: True 49 | gate_dim: 256 50 | 51 | "1": # xyz_encoding_final 52 | in_ch: 256 53 | h_ch: 0 54 | out_ch: 256 55 | num: 1 56 | type: "mlp" 57 | act: "none" 58 | 59 | "2": # dir_a_encoding 60 | in_ch: 283 # 256 + 27 61 | h_ch: 0 62 | out_ch: 128 63 | num: 1 64 | type: "mlp" 65 | act: "relu" 66 | 67 | sigma: # sigma 68 | in_ch: 256 69 | h_ch: 0 70 | out_ch: 1 71 | num: 1 72 | type: "mlp" 73 | act: "none" 74 | 75 | color: # rgb 76 | in_ch: 128 77 | h_ch: 0 78 | out_ch: 3 79 | num: 1 80 | type: "mlp" 81 | act: "none" 82 | 83 | moe_external_gate: 84 | in_ch: 256 85 | h_ch: 256 
86 | out_ch: 256 87 | num: 2 88 | type: "mlp" 89 | act: "none" 90 | out_skip: False 91 | 92 | gate_input_norm: 93 | in_ch: 256 94 | h_ch: 0 95 | out_ch: 0 96 | num: 1 97 | type: "layernorm" -------------------------------------------------------------------------------- /switch_nerf/image_metadata.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | from zipfile import ZipFile 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | 10 | 11 | class ImageMetadata: 12 | def __init__(self, image_path: Path, c2w: torch.Tensor, W: int, H: int, intrinsics: torch.Tensor, image_index: int, 13 | mask_path: Optional[Path], is_val: bool): 14 | self.image_path = image_path 15 | self.c2w = c2w.float() 16 | self.W = W 17 | self.H = H 18 | self.intrinsics = intrinsics 19 | self.image_index = image_index 20 | self._mask_path = mask_path 21 | self.is_val = is_val 22 | 23 | if self.intrinsics.numel() == 2: 24 | # for dataset of waymo processed by LargeScaleNeRFPytorch 25 | intrinsics = torch.zeros([4]) 26 | intrinsics[0] = self.intrinsics[0] 27 | intrinsics[1] = self.intrinsics[1] 28 | intrinsics[2] = self.W / 2.0 29 | intrinsics[3] = self.H / 2.0 30 | self.intrinsics = intrinsics 31 | 32 | def load_image(self) -> torch.Tensor: 33 | rgbs = Image.open(self.image_path).convert('RGB') 34 | size = rgbs.size 35 | 36 | if size[0] != self.W or size[1] != self.H: 37 | rgbs = rgbs.resize((self.W, self.H), Image.LANCZOS) 38 | 39 | return torch.ByteTensor(np.asarray(rgbs)) 40 | 41 | def load_mask(self) -> Optional[torch.Tensor]: 42 | if self._mask_path is None: 43 | return None 44 | 45 | with ZipFile(self._mask_path) as zf: 46 | with zf.open(self._mask_path.name) as f: 47 | keep_mask = torch.load(f, map_location='cpu') 48 | 49 | if keep_mask.shape[0] != self.H or keep_mask.shape[1] != self.W: 50 | keep_mask = F.interpolate(keep_mask.unsqueeze(0).unsqueeze(0).float(), 51 | size=(self.H, self.W)).bool().squeeze() 52 | 53 | return keep_mask 54 | -------------------------------------------------------------------------------- /switch_nerf/scripts/copy_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Namespace 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | 11 | def _get_images_opts(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--image_path', type=str, required=True) 15 | parser.add_argument('--dataset_path', type=str, required=True) 16 | 17 | return parser.parse_args() 18 | 19 | 20 | def main(hparams: Namespace) -> None: 21 | image_path = Path(hparams.image_path) 22 | dataset_path = Path(hparams.dataset_path) 23 | (dataset_path / 'train' / 'rgbs').mkdir() 24 | (dataset_path / 'val' / 'rgbs').mkdir() 25 | 26 | with (Path(hparams.dataset_path) / 'mappings.txt').open() as f: 27 | for line in tqdm(f): 28 | image_name, metadata_name = line.strip().split(',') 29 | metadata_path = dataset_path / 'train' / 'metadata' / metadata_name 30 | if not metadata_path.exists(): 31 | metadata_path = dataset_path / 'val' / 'metadata' / metadata_name 32 | assert metadata_path.exists() 33 | 34 | distorted = cv2.imread(str(image_path / image_name)) 35 | metadata = torch.load(metadata_path, map_location='cpu') 36 | intrinsics = metadata['intrinsics'] 37 | camera_matrix = np.array([[intrinsics[0], 0, 
intrinsics[2]], 38 | [0, intrinsics[1], intrinsics[3]], 39 | [0, 0, 1]]) 40 | 41 | undistorted = cv2.undistort(distorted, camera_matrix, metadata['distortion'].numpy()) 42 | assert undistorted.shape[0] == metadata['H'] 43 | assert undistorted.shape[1] == metadata['W'] 44 | 45 | cv2.imwrite(str(dataset_path / metadata_path.parent.parent / 'rgbs' / '{}.{}'.format(metadata_path.stem, 46 | image_name.split('.')[ 47 | -1])), 48 | undistorted) 49 | 50 | 51 | if __name__ == '__main__': 52 | main(_get_images_opts()) 53 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/ray_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | # from nerf_pytorch 5 | # def get_rays_np(H, W, K, c2w): 6 | # i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 7 | # dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1) 8 | # # Rotate ray directions from camera frame to the world frame 9 | # rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 10 | # # Translate camera frame's origin to the world frame. It is the origin of all rays. 11 | # rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 12 | # return rays_o, rays_d 13 | 14 | def get_rays(H, W, K, c2w): 15 | i, j = torch.meshgrid(torch.linspace(0, W-1, W, device=c2w.device), torch.linspace(0, H-1, H, device=c2w.device)) # pytorch's meshgrid has indexing='ij' 16 | i = i.t() 17 | j = j.t() 18 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) 19 | # Rotate ray directions from camera frame to the world frame 20 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 21 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 22 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 23 | rays_o = c2w[:3,-1].expand(rays_d.shape) 24 | return rays_o, rays_d 25 | 26 | 27 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 28 | # Shift ray origins to near plane 29 | t = -(near + rays_o[...,2]) / rays_d[...,2] 30 | rays_o = rays_o + t[...,None] * rays_d 31 | 32 | # Projection 33 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 34 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 35 | o2 = 1. + 2. * near / rays_o[...,2] 36 | 37 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 38 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 39 | d2 = -2. 
* near / rays_o[...,2] 40 | 41 | rays_o = torch.stack([o0,o1,o2], -1) 42 | rays_d = torch.stack([d0,d1,d2], -1) 43 | 44 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 45 | 46 | return rays_o, rays_d -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /switch_nerf/datasets/memory_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from switch_nerf.datasets.dataset_utils import get_rgb_index_mask 7 | from switch_nerf.image_metadata import ImageMetadata 8 | from switch_nerf.misc_utils import main_tqdm, main_print 9 | from switch_nerf.ray_utils import get_rays, get_ray_directions 10 | 11 | 12 | class MemoryDataset(Dataset): 13 | 14 | def __init__(self, metadata_items: List[ImageMetadata], near: float, far: float, ray_altitude_range: List[float], 15 | center_pixels: bool, device: torch.device): 16 | super(MemoryDataset, self).__init__() 17 | 18 | rgbs = [] 19 | rays = [] 20 | indices = [] 21 | 22 | main_print('Loading data') 23 | 24 | for metadata_item in main_tqdm(metadata_items): 25 | image_data = get_rgb_index_mask(metadata_item) 26 | 27 | if image_data is None: 28 | continue 29 | 30 | image_rgbs, image_indices, image_keep_mask = image_data 31 | 32 | directions = get_ray_directions(metadata_item.W, 33 | metadata_item.H, 34 | metadata_item.intrinsics[0], 35 | metadata_item.intrinsics[1], 36 | metadata_item.intrinsics[2], 37 | metadata_item.intrinsics[3], 38 | center_pixels, 39 | device) 40 | image_rays = get_rays(directions, metadata_item.c2w.to(device), near, far, ray_altitude_range).view(-1, 41 | 8).cpu() 42 | if image_keep_mask is not None: 43 | image_rays = image_rays[image_keep_mask == True] 44 | 45 | rgbs.append(image_rgbs.float() / 255.) 46 | rays.append(image_rays) 47 | indices.append(image_indices) 48 | 49 | main_print('Finished loading data') 50 | 51 | self._rgbs = torch.cat(rgbs) 52 | self._rays = torch.cat(rays) 53 | self._image_indices = torch.cat(indices) 54 | 55 | def __len__(self) -> int: 56 | return self._rgbs.shape[0] 57 | 58 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]: 59 | return { 60 | 'rgbs': self._rgbs[idx], 61 | 'rays': self._rays[idx], 62 | 'image_indices': self._image_indices[idx] 63 | } 64 | -------------------------------------------------------------------------------- /switch_nerf/modules/tutel_moe_ext/tutel_system.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import os, sys 5 | import re 6 | import logging 7 | 8 | TUTEL_CUDA_SANDBOX = int(os.environ.get('TUTEL_CUDA_SANDBOX', 0)) 9 | 10 | def init_affinity_at_program_beginning(): 11 | if TUTEL_CUDA_SANDBOX: 12 | return 13 | try: 14 | numa_type = int(os.environ.get('NUMA_TYPE', '1')) 15 | if numa_type <= 0: 16 | return 17 | group_rank = int(os.environ.get('LOCAL_RANK', '0')) 18 | nodes = sorted([int(x[4:]) for x in os.listdir('/sys/devices/system/node') if re.match('node[0-9]+', x)]) 19 | cpus = [sorted([int(x[3:]) for x in os.listdir('/sys/devices/system/node/node%d' % node_id) if re.match('cpu[0-9]+', x)]) for node_id in nodes] 20 | sel_node = (group_rank // numa_type) % len(nodes) 21 | os.sched_setaffinity(0, cpus[sel_node]) 22 | logging.info('LOCAL_RANK %d is to set NUMA node: %d (total NUMA nodes = %d)' % (group_rank, sel_node, len(nodes))) 23 | except Exception as ex: 24 | if group_rank == 0: 25 | logging.warning('Failed to set NUMA status: %s' % ex) 26 | 27 | def init_data_model_parallel(group_count=1, backend='nccl', use_slurm=False, timeout=None): 28 | from tutel import net as C 29 | from .tutel_communicate_nobatch import create_groups_from_world_slurm, create_groups_from_world 30 | if use_slurm: 31 | result = create_groups_from_world_slurm(group_count=group_count, include_init=backend) 32 | else: 33 | result = create_groups_from_world(group_count=group_count, include_init=backend, timeout=timeout) 34 | result.is_cuda = (result.local_device.type == 'cuda') 35 | 36 | logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}') 37 | init_data_model_parallel.default_env = result 38 | 39 | def on_quit(): 40 | sys.stdout.flush() 41 | sys.stderr.flush() 42 | # Builtin dist.all_to_all_single in torch is unstable in some versions. 
43 | # Temp work around: https://github.com/pytorch/pytorch/issues/56390 44 | if getattr(C.simple_all_to_all, '_use_builtins', False): 45 | os._exit(0) 46 | 47 | import atexit 48 | atexit.register(lambda *args: on_quit()) 49 | return result 50 | 51 | def get_local_session(): 52 | if not hasattr(init_data_model_parallel, 'default_env'): 53 | raise Exception("Current session is not initialized with: system.init_data_model_parallel() from tutel") 54 | return init_data_model_parallel.default_env 55 | 56 | def record_time(): 57 | import time 58 | if get_local_session().is_cuda: 59 | import torch 60 | torch.cuda.synchronize() 61 | return time.time() 62 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_blender_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for frame in meta['frames'][::skip]: 57 | fname = os.path.join(basedir, frame['file_path'] + '.png') 58 | imgs.append(imageio.imread(fname)) 59 | poses.append(np.array(frame['transform_matrix'])) 60 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 61 | poses = np.array(poses).astype(np.float32) 62 | counts.append(counts[-1] + imgs.shape[0]) 63 | all_imgs.append(imgs) 64 | all_poses.append(poses) 65 | 66 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 67 | 68 | imgs = np.concatenate(all_imgs, 0) 69 | poses = np.concatenate(all_poses, 0) 70 | 71 | H, W = imgs[0].shape[:2] 72 | camera_angle_x = float(meta['camera_angle_x']) 73 | focal = .5 * W / np.tan(.5 * camera_angle_x) 74 | 75 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 76 | 77 | if half_res: 78 | H = H//2 79 | W = W//2 80 | focal = focal/2. 
81 | 82 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 83 | for i, img in enumerate(imgs): 84 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 85 | imgs = imgs_half_res 86 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 87 | 88 | 89 | return imgs, poses, render_poses, [H, W, focal], i_split 90 | 91 | 92 | -------------------------------------------------------------------------------- /switch_nerf/models/mega_nerf.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class MegaNeRF(nn.Module): 8 | def __init__(self, sub_modules: List[nn.Module], centroids: torch.Tensor, boundary_margin: float, xyz_real: bool, 9 | cluster_2d: bool, joint_training: bool = False): 10 | super(MegaNeRF, self).__init__() 11 | assert boundary_margin >= 1 12 | self.sub_modules = nn.ModuleList(sub_modules) 13 | self.register_buffer('centroids', centroids) 14 | self.boundary_margin = boundary_margin 15 | self.xyz_real = xyz_real 16 | self.cluster_dim_start = 1 if cluster_2d else 0 17 | self.joint_training = joint_training 18 | 19 | def forward(self, x: torch.Tensor, sigma_only: bool = False, 20 | sigma_noise: Optional[torch.Tensor] = None) -> torch.Tensor: 21 | if self.boundary_margin > 1: 22 | cluster_distances = torch.cdist(x[:, self.cluster_dim_start:3], self.centroids[:, self.cluster_dim_start:]) 23 | inverse_cluster_distances = 1 / (cluster_distances + 1e-8) 24 | 25 | min_cluster_distances = cluster_distances.min(dim=1)[0].unsqueeze(-1).repeat(1, cluster_distances.shape[1]) 26 | inverse_cluster_distances[cluster_distances > self.boundary_margin * min_cluster_distances] = 0 27 | weights = inverse_cluster_distances / inverse_cluster_distances.sum(dim=-1).unsqueeze(-1) 28 | else: 29 | cluster_assignments = torch.cdist(x[:, self.cluster_dim_start:3], 30 | self.centroids[:, self.cluster_dim_start:]).argmin(dim=1) 31 | 32 | results = torch.empty(0) 33 | 34 | for i, child in enumerate(self.sub_modules): 35 | cluster_mask = cluster_assignments == i if self.boundary_margin == 1 else weights[:, i] > 0 36 | sub_input = x[cluster_mask, 3:] if self.xyz_real else x[cluster_mask] 37 | 38 | if sub_input.shape[0] > 0: 39 | sub_result = child(sub_input, sigma_only, 40 | sigma_noise[cluster_mask] if sigma_noise is not None else None) 41 | 42 | if results.shape[0] == 0: 43 | results = torch.zeros(x.shape[0], sub_result.shape[1], device=sub_result.device, 44 | dtype=sub_result.dtype) 45 | 46 | if self.boundary_margin == 1: 47 | results[cluster_mask] = sub_result 48 | else: 49 | results[cluster_mask] += sub_result * weights[cluster_mask, i].unsqueeze(-1) 50 | 51 | elif self.joint_training: # Hack to make distributed training happy 52 | sub_result = child(x[:0, 3:] if self.xyz_real else x[:0], sigma_only, 53 | sigma_noise[:0] if sigma_noise is not None else None) 54 | 55 | if results.shape[0] == 0: 56 | results = torch.zeros(x.shape[0], sub_result.shape[1], device=sub_result.device, 57 | dtype=sub_result.dtype) 58 | 59 | results[:0] += 0 * sub_result 60 | 61 | return results 62 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : 
torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for idx_test, frame in enumerate(meta['frames'][::skip]): 57 | fname = frame['file_path'] 58 | if s == 'test': 59 | print(f"{idx_test}th test frame: {fname}") 60 | imgs.append(imageio.imread(fname)) 61 | poses.append(np.array(frame['transform_matrix'])) 62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 63 | poses = np.array(poses).astype(np.float32) 64 | counts.append(counts[-1] + imgs.shape[0]) 65 | all_imgs.append(imgs) 66 | all_poses.append(poses) 67 | 68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 69 | 70 | imgs = np.concatenate(all_imgs, 0) 71 | poses = np.concatenate(all_poses, 0) 72 | 73 | H, W = imgs[0].shape[:2] 74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 75 | K = meta['frames'][0]['intrinsic_matrix'] 76 | print(f"Focal: {focal}") 77 | 78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 79 | 80 | if half_res: 81 | H = H//2 82 | W = W//2 83 | focal = focal/2. 
84 | 85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 86 | for i, img in enumerate(imgs): 87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 88 | imgs = imgs_half_res 89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 90 | 91 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 92 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 94 | 95 | 96 | -------------------------------------------------------------------------------- /switch_nerf/scripts/convert_to_container_moe.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 6 | 7 | from switch_nerf.models.mega_nerf import MegaNeRF 8 | from switch_nerf.models.mega_nerf_container import MegaNeRFContainer 9 | from switch_nerf.models.model_utils import get_nerf, get_bg_nerf 10 | from switch_nerf.opts import get_opts_base 11 | 12 | 13 | def _get_merge_opts() -> Namespace: 14 | parser = get_opts_base() 15 | parser.add_argument('--exp_name', type=str, required=True, help='experiment name') 16 | parser.add_argument('--centroid_path', type=str, required=True) 17 | parser.add_argument('--output', type=str, required=True) 18 | 19 | return parser.parse_known_args()[0] 20 | 21 | 22 | @torch.inference_mode() 23 | def main(hparams: Namespace) -> None: 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | exp_name = Path(hparams.exp_name) 26 | exp_name.mkdir(parents=True, exist_ok=True) 27 | output = exp_name / hparams.output 28 | hparams.moe_local_expert_num = hparams.moe_expert_num 29 | hparams.single_data_group = None 30 | centroid_metadata = torch.load(hparams.centroid_path, map_location='cpu') 31 | centroids = centroid_metadata['centroids'] 32 | 33 | loaded = torch.load(hparams.ckpt_path, map_location='cpu') 34 | consume_prefix_in_state_dict_if_present(loaded['model_state_dict'], prefix='module.') 35 | 36 | if hparams.appearance_dim > 0: 37 | appearance_count = len(loaded['model_state_dict']['embedding_a.weight']) 38 | else: 39 | appearance_count = 0 40 | 41 | sub_module = get_nerf(hparams, appearance_count) 42 | 43 | if 'bg_model_state_dict' in loaded: 44 | bg_sub_module = get_bg_nerf(hparams, appearance_count) 45 | 46 | container = MegaNeRFContainer([sub_module], [bg_sub_module] if 'bg_model_state_dict' in loaded else [], centroids, 47 | torch.IntTensor(centroid_metadata['grid_dim']), 48 | centroid_metadata['min_position'], 49 | centroid_metadata['max_position'], 50 | hparams.pos_dir_dim > 0, 51 | hparams.appearance_dim > 0, 52 | centroid_metadata['cluster_2d']) 53 | torch.jit.save(torch.jit.script(container.eval()), output) 54 | container = torch.jit.load(output, map_location='cpu') 55 | 56 | # Test container 57 | nerf = getattr(container, 'sub_module_{}'.format(0)).to(device) 58 | 59 | width = 3 60 | if hparams.pos_dir_dim > 0: 61 | width += 3 62 | if hparams.appearance_dim > 0: 63 | width += 1 64 | 65 | print('fg test eval: {}'.format(nerf(torch.ones(1, width, device=device)))) 66 | sub_module = sub_module.to(device) 67 | print('fg sub_module test eval: {}'.format(sub_module(torch.ones(1, width, device=device)))) 68 | 69 | if 'bg_model_state_dict' in loaded: 70 | bg_nerf = getattr(container, 'bg_sub_module_{}'.format(0)).to(device) 71 | 72 | width = 8 73 | print('bg test eval: 
{}'.format(bg_nerf(torch.ones(1, width, device=device)))) 74 | bg_sub_module = bg_sub_module.to(device) 75 | print('bg bg_sub_module test eval: {}'.format(bg_nerf(torch.ones(1, width, device=device)))) 76 | 77 | 78 | if __name__ == '__main__': 79 | main(_get_merge_opts()) 80 | -------------------------------------------------------------------------------- /switch_nerf/ray_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def get_ray_directions(W: int, H: int, fx: float, fy: float, cx: float, cy: float, center_pixels: bool, 7 | device: torch.device) -> torch.Tensor: 8 | i, j = torch.meshgrid(torch.arange(W, dtype=torch.float32, device=device), 9 | torch.arange(H, dtype=torch.float32, device=device), indexing='xy') 10 | if center_pixels: 11 | i = i.clone() + 0.5 12 | j = j.clone() + 0.5 13 | 14 | directions = \ 15 | torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) # (H, W, 3) 16 | directions /= torch.linalg.norm(directions, dim=-1, keepdim=True) 17 | 18 | return directions 19 | 20 | 21 | def get_rays(directions: torch.Tensor, c2w: torch.Tensor, near: float, far: float, 22 | ray_altitude_range: List[float]) -> torch.Tensor: 23 | # Rotate ray directions from camera coordinate to the world coordinate 24 | rays_d = directions @ c2w[:, :3].T # (H, W, 3) 25 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 26 | 27 | # The origin of all rays is the camera origin in world coordinate 28 | rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3) 29 | 30 | return _get_rays_inner(rays_o, rays_d, near, far, ray_altitude_range) 31 | 32 | 33 | def get_rays_batch(directions: torch.Tensor, c2w: torch.Tensor, near: float, far: float, 34 | ray_altitude_range: List[float]) -> torch.Tensor: 35 | # Rotate ray directions from camera coordinate to the world coordinate 36 | rays_d = directions @ c2w[:, :, :3].transpose(1, 2) # (n, H*W, 3) 37 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 38 | # The origin of all rays is the camera origin in world coordinate 39 | rays_o = c2w[:, :, 3].unsqueeze(1).expand(rays_d.shape) # (n, H*W, 3) 40 | 41 | return _get_rays_inner(rays_o, rays_d, near, far, ray_altitude_range) 42 | 43 | 44 | def _get_rays_inner(rays_o: torch.Tensor, rays_d: torch.Tensor, near: float, far: float, 45 | ray_altitude_range: List[float]) -> torch.Tensor: 46 | # c2w is drb, ray_altitude_range is max_altitude (neg), min_altitude (neg) 47 | near_bounds = near * torch.ones_like(rays_o[..., :1]) 48 | far_bounds = far * torch.ones_like(rays_o[..., :1]) 49 | 50 | if ray_altitude_range is not None: 51 | _truncate_with_plane_intersection(rays_o, rays_d, ray_altitude_range[0], near_bounds) 52 | near_bounds = torch.clamp(near_bounds, min=near) 53 | _truncate_with_plane_intersection(rays_o, rays_d, ray_altitude_range[1], far_bounds) 54 | 55 | far_bounds = torch.clamp(far_bounds, max=far) 56 | far_bounds = torch.maximum(near_bounds, far_bounds) 57 | 58 | return torch.cat([rays_o, 59 | rays_d, 60 | near_bounds, 61 | far_bounds], 62 | -1) # (h, w, 8) 63 | 64 | 65 | def _truncate_with_plane_intersection(rays_o: torch.Tensor, rays_d: torch.Tensor, altitude: float, 66 | default_bounds: torch.Tensor) -> None: 67 | starts_before = rays_o[:, :, 0] < altitude 68 | goes_down = rays_d[:, :, 0] > 0 69 | boundable_rays = torch.minimum(starts_before, goes_down) 70 | 71 | ray_points = rays_o[boundable_rays] 72 | if ray_points.shape[0] == 0: 73 | return 74 | 75 | ray_directions = 
rays_d[boundable_rays] 76 | 77 | plane_normal = torch.FloatTensor([-1, 0, 0]).to(rays_o.device).unsqueeze(1) 78 | ndotu = ray_directions.mm(plane_normal) 79 | 80 | plane_point = torch.FloatTensor([altitude, 0, 0]).to(rays_o.device) 81 | w = ray_points - plane_point 82 | si = -w.mm(plane_normal) / ndotu 83 | plane_intersection = w + si * ray_directions + plane_point 84 | default_bounds[boundable_rays] = (ray_points - plane_intersection).norm(dim=-1).unsqueeze(1) 85 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. 
for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /switch_nerf/scripts/merge_points.py: -------------------------------------------------------------------------------- 1 | from plyfile import PlyData, PlyElement 2 | import os 3 | import argparse 4 | import random 5 | import numpy as np 6 | from pathlib import Path 7 | 8 | parser = argparse.ArgumentParser(description='Merge points.') 9 | parser.add_argument('--data_path', type=str, 10 | help='root data path') 11 | parser.add_argument('--image_ids', type=str, nargs='+', default=None, 12 | help='image ids for process') 13 | parser.add_argument("--merge_all", action='store_true', default=False, 14 | help='''merge all the data in data_path''') 15 | parser.add_argument("--image_num", type=int, default=0, 16 | help='''image number used for merge all''') 17 | parser.add_argument("--expert_num", type=int, default=8, 18 | help='''expert or submodel number''') 19 | parser.add_argument("--model_type", type=str, 20 | help='''mega or switch or nerf''') 21 | parser.add_argument("--data_type", type=str, default="coarse", 22 | help='''coarse or fine, only support coarse''') 23 | parser.add_argument("--topk", type=int, default=0, 24 | help='''topk for expert''') 25 | parser.add_argument("-r", "--sample_ratio", type=float, default=1.0, 26 | help='''topk for expert''') 27 | 28 | args = parser.parse_args() 29 | 30 | data_path = args.data_path 31 | merge_all = args.merge_all 32 | expert_num = args.expert_num 33 | model_type = args.model_type 34 | data_type = args.data_type 35 | topk = args.topk 36 | sample_ratio = args.sample_ratio 37 | 38 | if model_type == "nerf": 39 | if merge_all: 40 | data_path_1 = Path(data_path) 41 | plys = [i.name for i in data_path_1.glob('**/*') if i.suffix == ".ply"] 42 | image_ids = [i.split("_")[0] for i in plys if i.split("_")[0].isdigit()] 43 | image_ids = list(set(image_ids)) 44 | else: 45 | image_ids = args.image_ids 46 | else: 47 | if merge_all: 48 | image_ids = [str(i) for i in range(args.image_num)] 49 | else: 50 | image_ids = args.image_ids 51 | 52 | print("image_ids", image_ids) 53 | if expert_num > 0: 54 | for expert_id in range(expert_num): 55 | out_ply_name = '{}_pts_rgba_exp_{}.ply'.format(data_type, expert_id) 56 | out_ply_path = os.path.join(data_path, out_ply_name) 57 | sample_datas = [] 58 | for image_id in image_ids: 59 | if model_type == "mega": 60 | ply_name = '{:03d}_{}_pts_rgba_exp_{}.ply'.format(int(image_id), data_type, expert_id) 61 | elif 
model_type == "switch" or model_type == "nerf": 62 | ply_name = '{:03d}_{}_pts_rgba_top_{:01d}_exp_{}.ply'.format(int(image_id), data_type, topk, expert_id) 63 | 64 | ply_path = os.path.join(data_path, image_id, ply_name) 65 | ply_data = PlyData.read(ply_path) 66 | pts_data = ply_data.elements[0].data 67 | 68 | pts_num = ply_data.elements[0].count 69 | sample_num = int(pts_num * sample_ratio) 70 | if sample_num == 0: 71 | continue 72 | else: 73 | sample_ids = random.sample(range(pts_num), sample_num) 74 | sample_data = pts_data[sample_ids] 75 | sample_datas.append(sample_data) 76 | 77 | sample_data = np.concatenate(sample_datas) 78 | el = PlyElement.describe(sample_data, 'vertex') 79 | PlyData([el]).write(out_ply_path) 80 | pass 81 | else: 82 | # no moe or clusters 83 | out_ply_name = '{}_pts_rgba.ply'.format(data_type) 84 | out_ply_path = os.path.join(data_path, out_ply_name) 85 | sample_datas = [] 86 | for image_id in image_ids: 87 | ply_name = '{:03d}_{}_pts_rgba.ply'.format(int(image_id), data_type) 88 | 89 | ply_path = os.path.join(data_path, image_id, ply_name) 90 | ply_data = PlyData.read(ply_path) 91 | pts_data = ply_data.elements[0].data 92 | 93 | pts_num = ply_data.elements[0].count 94 | sample_num = int(pts_num * sample_ratio) 95 | if sample_num == 0: 96 | continue 97 | else: 98 | sample_ids = random.sample(range(pts_num), sample_num) 99 | sample_data = pts_data[sample_ids] 100 | sample_datas.append(sample_data) 101 | 102 | sample_data = np.concatenate(sample_datas) 103 | el = PlyElement.describe(sample_data, 'vertex') 104 | PlyData([el]).write(out_ply_path) 105 | pass 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/load_bungee.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | import cv2 5 | import imageio 6 | import torch 7 | 8 | def _load_google_data(basedir, factor=None): 9 | img_basedir = basedir 10 | img_folder = 'images' 11 | imgdir = os.path.join(img_basedir, img_folder) 12 | 13 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png') or f.endswith('jpeg')] 14 | sh = np.array(cv2.imread(imgfiles[0]).shape) 15 | imgs = [] 16 | for f in imgfiles: 17 | im = cv2.imread(f, cv2.IMREAD_UNCHANGED) 18 | if im.shape[-1] == 3: 19 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 20 | else: 21 | im = cv2.cvtColor(im, cv2.COLOR_BGRA2RGBA) 22 | im = cv2.resize(im, (sh[1]//factor, sh[0]//factor), interpolation=cv2.INTER_AREA) 23 | im = im.astype(np.float32) / 255 24 | imgs.append(im) 25 | imgs = np.stack(imgs, -1) 26 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 27 | 28 | data = json.load(open(os.path.join(basedir, 'poses_enu.json'))) 29 | poses = np.array(data['poses'])[:, :-2].reshape([-1, 3, 5]) 30 | poses[:, :2, 4] = np.array(sh[:2]//factor).reshape([1, 2]) 31 | poses[:, 2, 4] = poses[:,2, 4] * 1./factor 32 | 33 | scene_scaling_factor = data['scene_scale'] 34 | scene_origin = np.array(data['scene_origin']) 35 | scale_split = data['scale_split'] 36 | 37 | return imgs, poses, scene_scaling_factor, scene_origin, scale_split 38 | 39 | def load_bungee_multiscale_data(basedir, factor=3): 40 | imgs, poses, scene_scaling_factor, scene_origin, scale_split = _load_google_data(basedir, factor=factor) 41 | print('Loaded image data shape:', imgs.shape, ' hwf:', poses[0,:,-1]) 42 | return imgs, poses, scene_scaling_factor, scene_origin, 
scale_split 43 | 44 | def get_bungee_nearfar_radii(rays, scene_scaling_factor, scene_origin, ray_nearfar): 45 | rays_o = rays[..., 0:3] 46 | rays_d = rays[..., 3:6] 47 | # rays_shape = rays.shape[0:-1] 48 | 49 | if ray_nearfar == 'sphere': ## treats earth as a sphere and computes the intersection of a ray and a sphere 50 | globe_center = torch.tensor(np.array(scene_origin) * scene_scaling_factor).float() 51 | 52 | # 6371011 is earth radius, 250 is the assumed height limitation of buildings in the scene 53 | earth_radius = 6371011 * scene_scaling_factor 54 | earth_radius_plus_bldg = (6371011+250) * scene_scaling_factor 55 | 56 | ## intersect with building upper limit sphere 57 | delta = (2*torch.sum((rays_o-globe_center) * rays_d, dim=-1))**2 - 4*torch.norm(rays_d, dim=-1)**2 * (torch.norm((rays_o-globe_center), dim=-1)**2 - (earth_radius_plus_bldg)**2) 58 | d_near = (-2*torch.sum((rays_o-globe_center) * rays_d, dim=-1) - delta**0.5) / (2*torch.norm(rays_d, dim=-1)**2) 59 | rays_start = rays_o + (d_near[...,None]*rays_d) 60 | 61 | ## intersect with earth 62 | delta = (2*torch.sum((rays_o-globe_center) * rays_d, dim=-1))**2 - 4*torch.norm(rays_d, dim=-1)**2 * (torch.norm((rays_o-globe_center), dim=-1)**2 - (earth_radius)**2) 63 | d_far = (-2*torch.sum((rays_o-globe_center) * rays_d, dim=-1) - delta**0.5) / (2*torch.norm(rays_d, dim=-1)**2) 64 | rays_end = rays_o + (d_far[...,None]*rays_d) 65 | 66 | ## compute near and far for each ray 67 | new_near = torch.norm(rays_o - rays_start, dim=-1, keepdim=True) 68 | near = new_near * 0.9 69 | 70 | new_far = torch.norm(rays_o - rays_end, dim=-1, keepdim=True) 71 | far = new_far * 1.1 72 | 73 | elif ray_nearfar == 'flat': ## treats earth as a flat surface and computes the intersection of a ray and a plane 74 | normal = torch.tensor([0, 0, 1]).to(rays_o) * scene_scaling_factor 75 | p0_far = torch.tensor([0, 0, 0]).to(rays_o) * scene_scaling_factor 76 | p0_near = torch.tensor([0, 0, 250]).to(rays_o) * scene_scaling_factor 77 | 78 | near = (p0_near - rays_o * normal).sum(-1) / (rays_d * normal).sum(-1) 79 | far = (p0_far - rays_o * normal).sum(-1) / (rays_d * normal).sum(-1) 80 | near = near.clamp(min=1e-6) 81 | near, far = near.unsqueeze(-1), far.unsqueeze(-1) 82 | 83 | new_rays = torch.cat([rays, near, far], dim=-1) 84 | # new_rays = new_rays.reshape(list(rays_shape) + [8]) 85 | # rays_d: N, H, W, 3 86 | dx = torch.sqrt( 87 | torch.sum((rays_d[:, :-1, :, :] - rays_d[:, 1:, :, :])**2, -1)) 88 | dx = torch.cat([dx, dx[:, -2:-1, :]], 1) 89 | radii = dx[..., None] * 2 / np.sqrt(12) 90 | return new_rays, radii -------------------------------------------------------------------------------- /switch_nerf/spherical_harmonics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 
11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | import torch 24 | 25 | C0 = 0.28209479177387814 26 | C1 = 0.4886025119029199 27 | C2 = [ 28 | 1.0925484305920792, 29 | -1.0925484305920792, 30 | 0.31539156525252005, 31 | -1.0925484305920792, 32 | 0.5462742152960396 33 | ] 34 | C3 = [ 35 | -0.5900435899266435, 36 | 2.890611442640554, 37 | -0.4570457994644658, 38 | 0.3731763325901154, 39 | -0.4570457994644658, 40 | 1.445305721320277, 41 | -0.5900435899266435 42 | ] 43 | C4 = [ 44 | 2.5033429417967046, 45 | -1.7701307697799304, 46 | 0.9461746957575601, 47 | -0.6690465435572892, 48 | 0.10578554691520431, 49 | -0.6690465435572892, 50 | 0.47308734787878004, 51 | -1.7701307697799304, 52 | 0.6258357354491761, 53 | ] 54 | 55 | def eval_sh(deg: int, sh: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Evaluate spherical harmonics at unit directions 58 | using hardcoded SH polynomials. 59 | Works with torch/np/jnp. 60 | ... Can be 0 or more batch dimensions. 61 | Args: 62 | deg: int SH deg. 
Currently, 0-3 supported 63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 64 | dirs: jnp.ndarray unit directions [..., 3] 65 | Returns: 66 | [..., C] 67 | """ 68 | assert deg <= 4 and deg >= 0 69 | assert (deg + 1) ** 2 == sh.shape[-1] 70 | 71 | result = C0 * sh[..., 0] 72 | if deg > 0: 73 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 74 | result = (result - 75 | C1 * y * sh[..., 1] + 76 | C1 * z * sh[..., 2] - 77 | C1 * x * sh[..., 3]) 78 | if deg > 1: 79 | xx, yy, zz = x * x, y * y, z * z 80 | xy, yz, xz = x * y, y * z, x * z 81 | result = (result + 82 | C2[0] * xy * sh[..., 4] + 83 | C2[1] * yz * sh[..., 5] + 84 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 85 | C2[3] * xz * sh[..., 7] + 86 | C2[4] * (xx - yy) * sh[..., 8]) 87 | 88 | if deg > 2: 89 | result = (result + 90 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 91 | C3[1] * xy * z * sh[..., 10] + 92 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 93 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 94 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 95 | C3[5] * z * (xx - yy) * sh[..., 14] + 96 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 97 | if deg > 3: 98 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 99 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 100 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 101 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 102 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 103 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 104 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 105 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 106 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 107 | return result -------------------------------------------------------------------------------- /switch_nerf/modules/tutel_moe_ext/tutel_sparse_nobatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
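# This module JIT-compiles the sparse dispatch kernels used by the no-batch
# MoE layer: a forward kernel that scatters gated token features into the
# per-expert dispatch buffer via atomicAdd, a backward kernel that recovers
# input gradients from that buffer, and a backward kernel that reduces
# per-token gate gradients with warp-level shuffles on CUDA. CPU fallbacks
# are produced through JitCompiler.generate_cpu_kernel.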
3 | 4 | import torch 5 | from tutel.impls.jit_compiler import JitCompiler 6 | 7 | 8 | def get_kernel_dtype(param_dtype): 9 | if param_dtype == torch.float16: 10 | return '__half2' 11 | elif param_dtype == torch.float32: 12 | return 'float' 13 | else: 14 | raise Exception("Unrecognized data type: %s" % param_dtype) 15 | 16 | 17 | def create_forward(param_dtype, is_cuda=True): 18 | if not is_cuda: 19 | return JitCompiler.generate_cpu_kernel(kernel_type=0) 20 | 21 | return JitCompiler.generate_kernel({'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' 22 | #define __dtype @dtype@ 23 | 24 | extern "C" __global__ __launch_bounds__(1024) void execute(__dtype* __restrict__ gates1_s, int* __restrict__ indices1_s, int* __restrict__ locations1_s, int* __restrict__ expert_locations_begin1_s, __dtype* __restrict__ reshaped_input, __dtype* __restrict__ dispatched_input, int samples, int hidden, int capacity) { 25 | // [thread_extent] blockIdx.x = 512 26 | // [thread_extent] threadIdx.x = 1024 27 | 28 | for (int i = blockIdx.x; i < samples; i += gridDim.x) 29 | if (indices1_s[i] >= 0) { 30 | #pragma unroll 31 | for (int j = threadIdx.x; j < hidden; j += 1024) 32 | atomicAdd(&dispatched_input[(expert_locations_begin1_s[indices1_s[i]] + locations1_s[i]) * (hidden) + j], gates1_s[i] * reshaped_input[i * (hidden) + j]); 33 | } 34 | } 35 | ''') 36 | 37 | 38 | def create_backward_data(param_dtype, is_cuda=True): 39 | if not is_cuda: 40 | return JitCompiler.generate_cpu_kernel(kernel_type=1) 41 | 42 | return JitCompiler.generate_kernel({'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' 43 | #define __dtype @dtype@ 44 | 45 | extern "C" __global__ __launch_bounds__(1024) void execute(__dtype* __restrict__ gates1_s, int* __restrict__ indices1_s, int* __restrict__ locations1_s, int* __restrict__ expert_locations_begin1_s, __dtype* __restrict__ grad_reshaped_input, __dtype* __restrict__ dispatched_input, int samples, int hidden, int capacity) { 46 | // [thread_extent] blockIdx.x = 512 47 | // [thread_extent] threadIdx.x = 1024 48 | 49 | for (int i = blockIdx.x; i < samples; i += gridDim.x) 50 | if (indices1_s[i] >= 0) { 51 | #pragma unroll 52 | for (int j = threadIdx.x; j < hidden; j += 1024) 53 | grad_reshaped_input[i * hidden + j] = gates1_s[i] * dispatched_input[(expert_locations_begin1_s[indices1_s[i]] + locations1_s[i]) * (hidden) + j]; 54 | } else { 55 | #pragma unroll 56 | for (int j = threadIdx.x; j < hidden; j += 1024) 57 | #if @IS_FLOAT@ 58 | grad_reshaped_input[i * hidden + j] = __dtype(0); 59 | #else 60 | grad_reshaped_input[i * hidden + j] = __dtype(0, 0); 61 | #endif 62 | } 63 | } 64 | ''') 65 | 66 | 67 | def create_backward_gate(param_dtype, is_cuda=True): 68 | if not is_cuda: 69 | return JitCompiler.generate_cpu_kernel(kernel_type=2) 70 | 71 | return JitCompiler.generate_kernel({'dtype': get_kernel_dtype(param_dtype), 'IS_FLOAT': 1 if param_dtype == torch.float32 else 0}, ''' 72 | #define __dtype @dtype@ 73 | 74 | extern "C" __global__ __launch_bounds__(32) void execute(void* __restrict__ grad_gates1_s, int* __restrict__ indices1_s, int* __restrict__ locations1_s, int* __restrict__ expert_locations_begin1_s, __dtype* __restrict__ reshaped_input, __dtype* __restrict__ dispatched_input, int samples, int hidden, int capacity) { 75 | // [thread_extent] blockIdx.x = 512 76 | // [thread_extent] threadIdx.x = 32 77 | for (int index = blockIdx.x; index < samples; index += gridDim.x) { 78 | if (indices1_s[index] 
< 0) { 79 | if (((int)threadIdx.x) == 0) 80 | #if @IS_FLOAT@ 81 | ((float*)grad_gates1_s)[index] = 0; 82 | #else 83 | ((half*)grad_gates1_s)[index] = __float2half_rn(0.000000e+00f); 84 | #endif 85 | continue; 86 | } 87 | int indice = expert_locations_begin1_s[indices1_s[index]] + locations1_s[index]; 88 | #if @IS_FLOAT@ 89 | __dtype grad_gates1_s_rf = 0.000000e+00f; 90 | #else 91 | __dtype grad_gates1_s_rf = __dtype(0, 0); 92 | #endif 93 | for (int i = threadIdx.x; i < hidden; i += 32) 94 | grad_gates1_s_rf += dispatched_input[indice * (hidden) + i] * reshaped_input[index * (hidden) + i]; 95 | 96 | #if !defined(__HIPCC__) 97 | __dtype red_buf0[1]; 98 | unsigned int mask[1]; 99 | __dtype t0[1]; 100 | red_buf0[(0)] = grad_gates1_s_rf; 101 | mask[(0)] = __activemask(); 102 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 16, 32); 103 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 104 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32); 105 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 106 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32); 107 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 108 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32); 109 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 110 | t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32); 111 | red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]); 112 | red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32); 113 | #else 114 | __shared__ __dtype red_buf0[32]; 115 | __syncthreads(); 116 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = grad_gates1_s_rf; 117 | if (((int)threadIdx.x) < 16) { 118 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 16))])); 119 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 8))])); 120 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 4))])); 121 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 2))])); 122 | ((volatile __dtype*)red_buf0)[(((int)threadIdx.x))] = ((__dtype)(((volatile __dtype*)red_buf0)[(((int)threadIdx.x))]) + (__dtype)(((volatile __dtype*)red_buf0)[((((int)threadIdx.x) + 1))])); 123 | } 124 | __syncthreads(); 125 | #endif 126 | if (((int)threadIdx.x) == 0) 127 | #if @IS_FLOAT@ 128 | ((float*)grad_gates1_s)[index] = red_buf0[(0)]; 129 | #else 130 | ((half*)grad_gates1_s)[index] = red_buf0[(0)].x + red_buf0[(0)].y; 131 | #endif 132 | } 133 | } 134 | ''') 135 | -------------------------------------------------------------------------------- /switch_nerf/models/nerf.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch.cuda.amp import custom_bwd, custom_fwd 7 | from torch.autograd import Function 8 | 9 | class Embedding(nn.Module): 10 | def __init__(self, num_freqs: int, logscale=True): 11 | """ 12 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 
13 | """ 14 | super(Embedding, self).__init__() 15 | 16 | if logscale: 17 | self.freq_bands = 2 ** torch.linspace(0, num_freqs - 1, num_freqs) 18 | else: 19 | self.freq_bands = torch.linspace(1, 2 ** (num_freqs - 1), num_freqs) 20 | 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: 22 | out = [x] 23 | for freq in self.freq_bands: 24 | out += [torch.sin(freq * x), torch.cos(freq * x)] 25 | 26 | return torch.cat(out, -1) 27 | 28 | class MipEmbedder(nn.Module): 29 | def __init__(self, num_freqs: int, logscale=True, input_dims=3): 30 | super(MipEmbedder, self).__init__() 31 | embed_fns = [] 32 | d = input_dims 33 | out_dim = 0 34 | embed_fns.append(lambda x : x[:,:d]) 35 | out_dim += d 36 | 37 | max_freq = num_freqs - 1 38 | min_freq = 0 39 | 40 | if logscale: 41 | freq_bands_y = 2.**torch.linspace(min_freq, max_freq, steps=num_freqs) 42 | freq_bands_w = 4.**torch.linspace(min_freq, max_freq, steps=num_freqs) 43 | else: 44 | freq_bands_y = torch.linspace(2.**min_freq, 2.**max_freq, steps=num_freqs) 45 | freq_bands_w = torch.linspace(4.**min_freq, 4.**max_freq, steps=num_freqs) 46 | 47 | for ctr in range(len(freq_bands_y)): 48 | for p_fn in [torch.sin, torch.cos]: 49 | embed_fns.append(lambda inputs, p_fn=p_fn, freq_y=freq_bands_y[ctr], freq_w=freq_bands_w[ctr] : p_fn(inputs[:,:d] * freq_y) * torch.exp((-0.5) * freq_w * inputs[:,d:])) 50 | out_dim += d 51 | 52 | self.embed_fns = embed_fns 53 | self.out_dim = out_dim 54 | 55 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 56 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 57 | 58 | class ShiftedSoftplus(nn.Module): 59 | __constants__ = ['beta', 'threshold'] 60 | beta: int 61 | threshold: int 62 | 63 | def __init__(self, beta: int = 1, threshold: int = 20) -> None: 64 | super(ShiftedSoftplus, self).__init__() 65 | self.beta = beta 66 | self.threshold = threshold 67 | 68 | def forward(self, x: torch.Tensor) -> torch.Tensor: 69 | return F.softplus(x - 1, self.beta, self.threshold) 70 | 71 | def extra_repr(self) -> str: 72 | return 'beta={}, threshold={}'.format(self.beta, self.threshold) 73 | 74 | 75 | class NeRF(nn.Module): 76 | def __init__(self, pos_xyz_dim: int, pos_dir_dim: int, layers: int, skip_layers: List[int], layer_dim: int, 77 | appearance_dim: int, affine_appearance: bool, appearance_count: int, rgb_dim: int, xyz_dim: int, 78 | sigma_activation: nn.Module): 79 | super(NeRF, self).__init__() 80 | self.xyz_dim = xyz_dim 81 | 82 | if rgb_dim > 3: 83 | assert pos_dir_dim == 0 84 | 85 | self.embedding_xyz = Embedding(pos_xyz_dim) 86 | in_channels_xyz = xyz_dim + xyz_dim * pos_xyz_dim * 2 87 | 88 | self.skip_layers = skip_layers 89 | 90 | xyz_encodings = [] 91 | 92 | # xyz encoding layers 93 | for i in range(layers): 94 | if i == 0: 95 | layer = nn.Linear(in_channels_xyz, layer_dim) 96 | elif i in skip_layers: 97 | layer = nn.Linear(layer_dim + in_channels_xyz, layer_dim) 98 | else: 99 | layer = nn.Linear(layer_dim, layer_dim) 100 | layer = nn.Sequential(layer, nn.ReLU(True)) 101 | xyz_encodings.append(layer) 102 | 103 | self.xyz_encodings = nn.ModuleList(xyz_encodings) 104 | 105 | if pos_dir_dim > 0: 106 | self.embedding_dir = Embedding(pos_dir_dim) 107 | in_channels_dir = 3 + 3 * pos_dir_dim * 2 108 | else: 109 | self.embedding_dir = None 110 | in_channels_dir = 0 111 | 112 | if appearance_dim > 0: 113 | self.embedding_a = nn.Embedding(appearance_count, appearance_dim) 114 | else: 115 | self.embedding_a = None 116 | 117 | if affine_appearance: 118 | assert appearance_dim > 0 119 | self.affine = 
nn.Linear(appearance_dim, 12) 120 | else: 121 | self.affine = None 122 | 123 | if pos_dir_dim > 0 or (appearance_dim > 0 and not affine_appearance): 124 | self.xyz_encoding_final = nn.Linear(layer_dim, layer_dim) 125 | # direction and appearance encoding layers 126 | self.dir_a_encoding = nn.Sequential( 127 | nn.Linear(layer_dim + in_channels_dir + (appearance_dim if not affine_appearance else 0), 128 | layer_dim // 2), 129 | nn.ReLU(True)) 130 | else: 131 | self.xyz_encoding_final = None 132 | 133 | # output layers 134 | self.sigma = nn.Linear(layer_dim, 1) 135 | self.sigma_activation = sigma_activation 136 | 137 | self.rgb = nn.Linear( 138 | layer_dim // 2 if (pos_dir_dim > 0 or (appearance_dim > 0 and not affine_appearance)) else layer_dim, 139 | rgb_dim) 140 | if rgb_dim == 3: 141 | self.rgb_activation = nn.Sigmoid() # = nn.Sequential(rgb, nn.Sigmoid()) 142 | else: 143 | self.rgb_activation = None # We're using spherical harmonics and will convert to sigmoid in rendering.py 144 | 145 | def forward(self, x: torch.Tensor, sigma_only: bool = False, 146 | sigma_noise: Optional[torch.Tensor] = None) -> torch.Tensor: 147 | expected = self.xyz_dim \ 148 | + (0 if (sigma_only or self.embedding_dir is None) else 3) \ 149 | + (0 if (sigma_only or self.embedding_a is None) else 1) 150 | 151 | if x.shape[1] != expected: 152 | raise Exception( 153 | 'Unexpected input shape: {} (expected: {}, xyz_dim: {})'.format(x.shape, expected, self.xyz_dim)) 154 | 155 | input_xyz = self.embedding_xyz(x[:, :self.xyz_dim]) 156 | xyz_ = input_xyz 157 | for i, xyz_encoding in enumerate(self.xyz_encodings): 158 | if i in self.skip_layers: 159 | xyz_ = torch.cat([input_xyz, xyz_], -1) 160 | xyz_ = xyz_encoding(xyz_) 161 | 162 | sigma = self.sigma(xyz_) 163 | if sigma_noise is not None: 164 | sigma += sigma_noise 165 | 166 | sigma = self.sigma_activation(sigma) 167 | 168 | if sigma_only: 169 | return sigma 170 | 171 | if self.xyz_encoding_final is not None: 172 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 173 | dir_a_encoding_input = [xyz_encoding_final] 174 | 175 | if self.embedding_dir is not None: 176 | # dir_a_encoding_input.append(self.embedding_dir(x[:, -4:-1])) 177 | dir_a_encoding_input.append(self.embedding_dir(x[:, self.xyz_dim:self.xyz_dim + 3])) 178 | 179 | if self.embedding_a is not None and self.affine is None: 180 | dir_a_encoding_input.append(self.embedding_a(x[:, -1].long())) 181 | 182 | dir_a_encoding = self.dir_a_encoding(torch.cat(dir_a_encoding_input, -1)) 183 | rgb = self.rgb(dir_a_encoding) 184 | else: 185 | rgb = self.rgb(xyz_) 186 | 187 | if self.affine is not None and self.embedding_a is not None: 188 | affine_transform = self.affine(self.embedding_a(x[:, -1].long())).view(-1, 3, 4) 189 | rgb = (affine_transform[:, :, :3] @ rgb.unsqueeze(-1) + affine_transform[:, :, 3:]).squeeze(-1) 190 | 191 | return torch.cat([self.rgb_activation(rgb) if self.rgb_activation is not None else rgb, sigma], -1) 192 | -------------------------------------------------------------------------------- /switch_nerf/models/model_utils.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 6 | 7 | from switch_nerf.models.cascade import Cascade 8 | from switch_nerf.models.mega_nerf import MegaNeRF 9 | from switch_nerf.models.nerf import NeRF, ShiftedSoftplus 10 | from switch_nerf.models.nerf_moe import get_nerf_moe_inner 11 | 
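# The convert_to_seqexperts* helpers below rewrite a checkpoint that stores
# fused expert-MLP parameters (one tensor per layer covering all experts,
# apparently laid out as [num_experts, in_dim, out_dim] for weights and
# [num_experts, 1, out_dim] for biases) into the layout expected by
# sequential per-expert modules: each tensor is unbound along the expert
# dimension, weights are transposed into nn.Linear's [out_dim, in_dim]
# layout, and the entries are re-keyed as
# module.layers.<moe_layer>.experts.0.experts.<expert>.layers.<layer>.<weight|bias>.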
12 | def convert_to_seqexperts(state_dict): 13 | keys = list(state_dict.keys()) 14 | for key in keys: 15 | if "layers.0.experts.0." in key: 16 | if "weight" in key or "bias" in key: 17 | para_type = "weight" if "weight" in key else "bias" 18 | layer_id = int(key[-1]) 19 | v = state_dict.pop(key) 20 | v = torch.unbind(v, dim=0) 21 | for expert_id, expert_v in enumerate(v): 22 | new_key = f'module.layers.0.experts.0.experts.{expert_id}.layers.{layer_id}.{para_type}' 23 | if para_type == "weight": 24 | new_v = expert_v.t().contiguous() 25 | if para_type == "bias": 26 | new_v = expert_v.squeeze(0) 27 | state_dict[new_key] = new_v 28 | return state_dict 29 | 30 | def convert_to_seqexperts1(state_dict, moe_layer_num): 31 | keys = list(state_dict.keys()) 32 | for key in keys: 33 | for moe_layer_id in range(moe_layer_num): 34 | if f"layers.{moe_layer_id}.experts.0." in key: 35 | if "weight" in key or "bias" in key: 36 | para_type = "weight" if "weight" in key else "bias" 37 | layer_id = int(key[-1]) 38 | v = state_dict.pop(key) 39 | v = torch.unbind(v, dim=0) 40 | for expert_id, expert_v in enumerate(v): 41 | new_key = f'module.layers.{moe_layer_id}.experts.0.experts.{expert_id}.layers.{layer_id}.{para_type}' 42 | if para_type == "weight": 43 | new_v = expert_v.t().contiguous() 44 | if para_type == "bias": 45 | new_v = expert_v.squeeze(0) 46 | state_dict[new_key] = new_v 47 | return state_dict 48 | 49 | 50 | def convert_to_seqexperts2(state_dict, moe_layer_ids): 51 | keys = list(state_dict.keys()) 52 | for key in keys: 53 | for moe_layer_id in moe_layer_ids: 54 | if f"layers.{moe_layer_id}.experts.0." in key: 55 | if "weight" in key or "bias" in key: 56 | para_type = "weight" if "weight" in key else "bias" 57 | layer_id = int(key[-1]) 58 | v = state_dict.pop(key) 59 | v = torch.unbind(v, dim=0) 60 | for expert_id, expert_v in enumerate(v): 61 | new_key = f'module.layers.{moe_layer_id}.experts.0.experts.{expert_id}.layers.{layer_id}.{para_type}' 62 | if para_type == "weight": 63 | new_v = expert_v.t().contiguous() 64 | if para_type == "bias": 65 | new_v = expert_v.squeeze(0) 66 | state_dict[new_key] = new_v 67 | return state_dict 68 | 69 | def get_nerf(hparams: Namespace, appearance_count: int) -> nn.Module: 70 | return _get_nerf_inner(hparams, appearance_count, hparams.layer_dim, 3, 'model_state_dict') 71 | 72 | 73 | def get_bg_nerf(hparams: Namespace, appearance_count: int) -> nn.Module: 74 | if hparams.bg_use_cfg: 75 | tmp_use_moe = hparams.use_moe 76 | hparams.use_moe = hparams.bg_use_moe 77 | bg_nerf = _get_nerf_inner(hparams, appearance_count, hparams.bg_layer_dim, 4, 'bg_model_state_dict') 78 | hparams.use_moe = tmp_use_moe 79 | else: 80 | tmp_use_moe = hparams.use_moe 81 | hparams.use_moe = False 82 | bg_nerf = _get_nerf_inner(hparams, appearance_count, hparams.bg_layer_dim, 4, 'bg_model_state_dict') 83 | hparams.use_moe = tmp_use_moe 84 | return bg_nerf 85 | 86 | 87 | def _get_nerf_inner(hparams: Namespace, appearance_count: int, layer_dim: int, xyz_dim: int, 88 | weight_key: str) -> nn.Module: 89 | if hparams.container_path is not None: 90 | container = torch.jit.load(hparams.container_path, map_location='cpu') 91 | if xyz_dim == 3: 92 | return MegaNeRF([getattr(container, 'sub_module_{}'.format(i)) for i in range(len(container.centroids))], 93 | container.centroids, hparams.boundary_margin, False, container.cluster_2d) 94 | else: 95 | return MegaNeRF([getattr(container, 'bg_sub_module_{}'.format(i)) for i in range(len(container.centroids))], 96 | container.centroids, 
hparams.boundary_margin, True, container.cluster_2d) 97 | elif hparams.use_cascade: 98 | if hparams.use_moe: 99 | if weight_key == "model_state_dict": 100 | model_cfg_name = "model" 101 | elif weight_key == "bg_model_state_dict": 102 | model_cfg_name = "model_bg" 103 | else: 104 | model_cfg_name = None 105 | raise NotImplementedError 106 | nerf = Cascade( 107 | get_nerf_moe_inner(hparams, appearance_count, 108 | xyz_dim, model_cfg_name=model_cfg_name), 109 | get_nerf_moe_inner(hparams, appearance_count, 110 | xyz_dim, model_cfg_name=model_cfg_name) 111 | if hparams.fine_samples > 0 else None) 112 | else: 113 | nerf = Cascade( 114 | _get_single_nerf_inner(hparams, appearance_count, 115 | layer_dim if xyz_dim == 4 else layer_dim, 116 | xyz_dim), 117 | _get_single_nerf_inner(hparams, appearance_count, layer_dim, xyz_dim) if hparams.fine_samples > 0 else None) 118 | elif hparams.train_mega_nerf is not None: 119 | centroid_metadata = torch.load(hparams.train_mega_nerf, map_location='cpu') 120 | centroids = centroid_metadata['centroids'] 121 | nerf = MegaNeRF( 122 | [_get_single_nerf_inner(hparams, appearance_count, layer_dim, xyz_dim) for _ in 123 | range(len(centroids))], centroids, 1, xyz_dim == 4, centroid_metadata['cluster_2d'], True) 124 | elif hparams.use_moe: 125 | if weight_key == "model_state_dict": 126 | model_cfg_name = "model" 127 | elif weight_key == "bg_model_state_dict": 128 | model_cfg_name = "model_bg" 129 | else: 130 | model_cfg_name = None 131 | raise NotImplementedError 132 | nerf = get_nerf_moe_inner(hparams, appearance_count, xyz_dim, model_cfg_name=model_cfg_name) 133 | else: 134 | nerf = _get_single_nerf_inner(hparams, appearance_count, layer_dim, xyz_dim) 135 | 136 | if hparams.ckpt_path is not None: 137 | state_dict = torch.load(hparams.ckpt_path, map_location='cpu')[weight_key] 138 | 139 | if hparams.expertmlp2seqexperts and hparams.use_moe: 140 | if getattr(hparams, "moe_layer_num", 1) > 1: 141 | state_dict = convert_to_seqexperts1(state_dict, hparams.moe_layer_num) 142 | elif getattr(hparams, "moe_layer_ids", None) is not None: 143 | state_dict = convert_to_seqexperts2(state_dict, hparams.moe_layer_ids) 144 | else: 145 | state_dict = convert_to_seqexperts(state_dict) 146 | 147 | consume_prefix_in_state_dict_if_present(state_dict, prefix='module.') 148 | 149 | model_dict = nerf.state_dict() 150 | model_dict.update(state_dict) 151 | nerf.load_state_dict(model_dict) 152 | 153 | return nerf 154 | 155 | 156 | def _get_single_nerf_inner(hparams: Namespace, appearance_count: int, layer_dim: int, xyz_dim: int) -> nn.Module: 157 | rgb_dim = 3 * ((hparams.sh_deg + 1) ** 2) if hparams.sh_deg is not None else 3 158 | 159 | return NeRF(hparams.pos_xyz_dim, 160 | hparams.pos_dir_dim, 161 | hparams.layers, 162 | hparams.skip_layers, 163 | layer_dim, 164 | hparams.appearance_dim, 165 | hparams.affine_appearance, 166 | appearance_count, 167 | rgb_dim, 168 | xyz_dim, 169 | ShiftedSoftplus() if hparams.shifted_softplus else nn.ReLU()) 170 | -------------------------------------------------------------------------------- /switch_nerf/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import lpips as plips 6 | 7 | 8 | def psnr(rgbs: torch.Tensor, target_rgbs: torch.Tensor) -> float: 9 | mse = torch.mean((rgbs - target_rgbs) ** 2) 10 | return -10 * torch.log10(mse).item() 11 | 12 | 13 | def lpips(rgbs: torch.Tensor, target_rgbs: torch.Tensor) -> Dict[str, 
float]: 14 | gt = target_rgbs.permute([2, 0, 1]).contiguous() 15 | pred = rgbs.permute([2, 0, 1]).contiguous() 16 | 17 | lpips_vgg = plips.LPIPS(net='vgg').eval().to(rgbs.device) 18 | lpips_vgg_i = lpips_vgg(gt, pred, normalize=True) 19 | 20 | lpips_alex = plips.LPIPS(net='alex').eval().to(rgbs.device) 21 | lpips_alex_i = lpips_alex(gt, pred, normalize=True) 22 | 23 | lpips_squeeze = plips.LPIPS(net='squeeze').eval().to(rgbs.device) 24 | lpips_squeeze_i = lpips_squeeze(gt, pred, normalize=True) 25 | 26 | return {'vgg': lpips_vgg_i.item(), 'alex': lpips_alex_i.item(), 'squeeze': lpips_squeeze_i.item()} 27 | 28 | 29 | # Copyright 2021 The PlenOctree Authors. 30 | # Redistribution and use in source and binary forms, with or without 31 | # modification, are permitted provided that the following conditions are met: 32 | # 33 | # 1. Redistributions of source code must retain the above copyright notice, 34 | # this list of conditions and the following disclaimer. 35 | # 36 | # 2. Redistributions in binary form must reproduce the above copyright notice, 37 | # this list of conditions and the following disclaimer in the documentation 38 | # and/or other materials provided with the distribution. 39 | # 40 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 41 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 42 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 43 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 44 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 45 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 46 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 47 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 48 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 49 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 50 | # POSSIBILITY OF SUCH DAMAGE. 51 | def ssim( 52 | rgbs: torch.Tensor, 53 | target_rgbs: torch.Tensor, 54 | max_val: float, 55 | filter_size: int = 11, 56 | filter_sigma: float = 1.5, 57 | k1: float = 0.01, 58 | k2: float = 0.03, 59 | ) -> float: 60 | """Computes SSIM from two images. 61 | This function was modeled after tf.image.ssim, and should produce comparable 62 | output. 63 | Args: 64 | rgbs: torch.tensor. An image of size [..., width, height, num_channels]. 65 | target_rgbs: torch.tensor. An image of size [..., width, height, num_channels]. 66 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. 67 | filter_size: int >= 1. Window size. 68 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 69 | k1: float > 0. One of the SSIM dampening parameters. 70 | k2: float > 0. One of the SSIM dampening parameters. 71 | Returns: 72 | Each image's mean SSIM. 73 | """ 74 | device = rgbs.device 75 | ori_shape = rgbs.size() 76 | width, height, num_channels = ori_shape[-3:] 77 | rgbs = rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 78 | target_rgbs = target_rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 79 | 80 | # Construct a 1D Gaussian blur filter. 81 | hw = filter_size // 2 82 | shift = (2 * hw - filter_size + 1) / 2 83 | f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 84 | filt = torch.exp(-0.5 * f_i) 85 | filt /= torch.sum(filt) 86 | 87 | # Blur in x and y (faster than the 2D convolution). 
88 | # z is a tensor of size [B, H, W, C] 89 | filt_fn1 = lambda z: F.conv2d( 90 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), 91 | padding=[hw, 0], groups=num_channels) 92 | filt_fn2 = lambda z: F.conv2d( 93 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), 94 | padding=[0, hw], groups=num_channels) 95 | 96 | # Vmap the blurs to the tensor size, and then compose them. 97 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 98 | mu0 = filt_fn(rgbs) 99 | mu1 = filt_fn(target_rgbs) 100 | mu00 = mu0 * mu0 101 | mu11 = mu1 * mu1 102 | mu01 = mu0 * mu1 103 | sigma00 = filt_fn(rgbs ** 2) - mu00 104 | sigma11 = filt_fn(target_rgbs ** 2) - mu11 105 | sigma01 = filt_fn(rgbs * target_rgbs) - mu01 106 | 107 | # Clip the variances and covariances to valid values. 108 | # Variance must be non-negative: 109 | sigma00 = torch.clamp(sigma00, min=0.0) 110 | sigma11 = torch.clamp(sigma11, min=0.0) 111 | sigma01 = torch.sign(sigma01) * torch.min( 112 | torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) 113 | ) 114 | 115 | c1 = (k1 * max_val) ** 2 116 | c2 = (k2 * max_val) ** 2 117 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 118 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 119 | ssim_map = numer / denom 120 | 121 | return torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1).item() 122 | 123 | 124 | def psnr_mask(rgbs: torch.Tensor, target_rgbs: torch.Tensor, valid_mask: torch.Tensor) -> float: 125 | # valid_mask: bool tensor 126 | rgbs = rgbs[valid_mask] 127 | target_rgbs = target_rgbs[valid_mask] 128 | mse = torch.mean((rgbs - target_rgbs) ** 2) 129 | return -10 * torch.log10(mse).item() 130 | 131 | def ssim_mask( 132 | rgbs: torch.Tensor, 133 | target_rgbs: torch.Tensor, 134 | max_val: float, 135 | valid_mask: torch.Tensor, 136 | filter_size: int = 11, 137 | filter_sigma: float = 1.5, 138 | k1: float = 0.01, 139 | k2: float = 0.03, 140 | ) -> float: 141 | """Computes SSIM from two images. 142 | This function was modeled after tf.image.ssim, and should produce comparable 143 | output. 144 | Args: 145 | rgbs: torch.tensor. An image of size [..., width, height, num_channels]. 146 | target_rgbs: torch.tensor. An image of size [..., width, height, num_channels]. 147 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. 148 | filter_size: int >= 1. Window size. 149 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 150 | k1: float > 0. One of the SSIM dampening parameters. 151 | k2: float > 0. One of the SSIM dampening parameters. 152 | Returns: 153 | Each image's mean SSIM. 154 | """ 155 | device = rgbs.device 156 | ori_shape = rgbs.size() 157 | width, height, num_channels = ori_shape[-3:] 158 | rgbs = rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 159 | target_rgbs = target_rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 160 | 161 | # Construct a 1D Gaussian blur filter. 162 | hw = filter_size // 2 163 | shift = (2 * hw - filter_size + 1) / 2 164 | f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 165 | filt = torch.exp(-0.5 * f_i) 166 | filt /= torch.sum(filt) 167 | 168 | # Blur in x and y (faster than the 2D convolution). 
169 | # z is a tensor of size [B, H, W, C] 170 | filt_fn1 = lambda z: F.conv2d( 171 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), 172 | padding=[hw, 0], groups=num_channels) 173 | filt_fn2 = lambda z: F.conv2d( 174 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), 175 | padding=[0, hw], groups=num_channels) 176 | 177 | # Vmap the blurs to the tensor size, and then compose them. 178 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 179 | mu0 = filt_fn(rgbs) 180 | mu1 = filt_fn(target_rgbs) 181 | mu00 = mu0 * mu0 182 | mu11 = mu1 * mu1 183 | mu01 = mu0 * mu1 184 | sigma00 = filt_fn(rgbs ** 2) - mu00 185 | sigma11 = filt_fn(target_rgbs ** 2) - mu11 186 | sigma01 = filt_fn(rgbs * target_rgbs) - mu01 187 | 188 | # Clip the variances and covariances to valid values. 189 | # Variance must be non-negative: 190 | sigma00 = torch.clamp(sigma00, min=0.0) 191 | sigma11 = torch.clamp(sigma11, min=0.0) 192 | sigma01 = torch.sign(sigma01) * torch.min( 193 | torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) 194 | ) 195 | 196 | c1 = (k1 * max_val) ** 2 197 | c2 = (k2 * max_val) ** 2 198 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 199 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 200 | ssim_map = numer / denom 201 | 202 | # BCHW to BHWC 203 | ssim_map = ssim_map.permute(0, 2, 3, 1) 204 | ssim_map = ssim_map.squeeze(0) 205 | ssim_map = ssim_map[valid_mask] 206 | ssim_val = torch.mean(ssim_map).item() 207 | 208 | return ssim_val 209 | -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/nerf_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from .load_llff import load_llff_data 4 | from .load_deepvoxels import load_dv_data 5 | from .load_blender import load_blender_data 6 | from .load_LINEMOD import load_LINEMOD_data 7 | from .load_bungee import load_bungee_multiscale_data, get_bungee_nearfar_radii 8 | from .ray_utils import get_rays, ndc_rays 9 | 10 | import numpy as np 11 | import cv2 12 | from switch_nerf.misc_utils import main_log 13 | 14 | class NeRFDataset(Dataset): 15 | def __init__(self, args) -> None: 16 | super().__init__() 17 | # self.split = split 18 | self.K = None 19 | if args.dataset_type == 'llff': 20 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 21 | recenter=True, bd_factor=.75, 22 | spherify=args.spherify) 23 | hwf = poses[0,:3,-1] 24 | poses = poses[:,:3,:4] 25 | main_log(f'Loaded llff {images.shape}, {render_poses.shape}, {hwf}, {args.datadir}') 26 | if not isinstance(i_test, list): 27 | i_test = [i_test] 28 | 29 | if args.llffhold > 0: 30 | main_log(f'Auto LLFF holdout, {args.llffhold}') 31 | i_test = np.arange(images.shape[0])[::args.llffhold] 32 | 33 | i_val = i_test 34 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 35 | (i not in i_test and i not in i_val)]) 36 | 37 | main_log('DEFINING BOUNDS') 38 | if args.no_ndc: 39 | near = np.ndarray.min(bds) * .9 40 | far = np.ndarray.max(bds) * 1. 41 | 42 | else: 43 | near = 0. 44 | far = 1. 45 | main_log(f'NEAR {near} FAR {far}') 46 | 47 | elif args.dataset_type == 'blender': 48 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip) 49 | main_log(f'Loaded blender {images.shape}, {render_poses.shape}, {hwf}, {args.datadir}') 50 | i_train, i_val, i_test = i_split 51 | 52 | near = 2. 53 | far = 6. 
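# Blender frames are RGBA: either alpha-composite them onto a white
# background or keep only the RGB channels.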
54 | 55 | if args.white_bkgd: 56 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 57 | else: 58 | images = images[...,:3] 59 | 60 | elif args.dataset_type == 'LINEMOD': 61 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip) 62 | main_log(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}') 63 | main_log(f'[CHECK HERE] near: {near}, far: {far}.') 64 | i_train, i_val, i_test = i_split 65 | 66 | if args.white_bkgd: 67 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 68 | else: 69 | images = images[...,:3] 70 | 71 | elif args.dataset_type == 'deepvoxels': 72 | 73 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, 74 | basedir=args.datadir, 75 | testskip=args.testskip) 76 | 77 | main_log(f'Loaded deepvoxels {images.shape}, {render_poses.shape}, {hwf}, {args.datadir}') 78 | i_train, i_val, i_test = i_split 79 | 80 | hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1)) 81 | near = hemi_R-1. 82 | far = hemi_R+1. 83 | 84 | elif args.dataset_type == 'bungee': 85 | images, poses, scene_scaling_factor, scene_origin, scale_split = load_bungee_multiscale_data(args.datadir, args.factor) 86 | self.scene_origin = scene_origin 87 | self.scale_split = scale_split 88 | self.scene_scaling_factor = scene_scaling_factor 89 | # if args.llffhold > 0: 90 | print('Auto holdout,', args.llffhold) 91 | i_test = np.arange(images.shape[0])[::args.llffhold] 92 | i_val = i_test 93 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 94 | (i not in i_test)]) 95 | hwf = poses[0,:3,-1] 96 | poses = poses[:,:3,:4] 97 | render_poses = poses # no use 98 | near = 0. # no use 99 | far = 1. # no use 100 | else: 101 | main_log(f'Unknown dataset type {args.dataset_type}, exiting') 102 | raise NotImplementedError 103 | 104 | self.poses = torch.tensor(poses) 105 | self.images = images 106 | # self.bds = bds 107 | self.render_poses = torch.tensor(render_poses) 108 | self.i_test = i_test 109 | self.i_train = i_train 110 | self.i_val = i_val 111 | 112 | self.near = near 113 | self.far = far 114 | 115 | H, W, focal = hwf 116 | H, W = int(H), int(W) 117 | hwf = [H, W, focal] 118 | 119 | if self.K is None: 120 | self.K = torch.tensor([ 121 | [focal, 0, 0.5*W], 122 | [0, focal, 0.5*H], 123 | [0, 0, 1] 124 | ]) 125 | 126 | self.hwf = hwf 127 | self.H = H 128 | self.W = W 129 | 130 | if hasattr(args, "scale_factor") and args.scale_factor > 1.0: 131 | assert (self.H % args.scale_factor == 0.0 and self.W % args.scale_factor == 0.0) 132 | self.H = int(self.H // args.scale_factor) 133 | self.W = int(self.W // args.scale_factor) 134 | H = int(H // args.scale_factor) 135 | W = int(W // args.scale_factor) 136 | 137 | self.hwf = [self.H, self.W, focal / args.scale_factor] 138 | 139 | # self.K[:2, 2] = self.K[:2, 2] / args.scale_factor 140 | # self.K[0, 0] = self.K[0, 0] / args.scale_factor 141 | # self.K[1, 1] = self.K[1, 1] / args.scale_factor 142 | self.K[:2, :] = self.K[:2, :] / args.scale_factor 143 | 144 | tmp_images = [] 145 | for i, img in enumerate(self.images): 146 | tmp_images.append(cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)) 147 | # tmp_images.append(torch.nn.functional.interpolate(img, [self.H, self.W], mode='area')) 148 | self.images = np.array(tmp_images) # actually stack 149 | 150 | self.images = torch.tensor(self.images) 151 | 152 | main_log('get rays') 153 | # rays = torch.stack([torch.cat(get_rays(self.H, self.W, self.K, p), -1) for p in 
self.poses[:,:3,:4]], 0) # [N, H, W, 6] 154 | rays = [] 155 | for p in self.poses[:,:3,:4]: 156 | rays_o, rays_d = get_rays(self.H, self.W, self.K, p) 157 | if not args.no_ndc: 158 | rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d) 159 | else: 160 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 161 | rays.append(torch.cat([rays_o, rays_d], -1)) 162 | rays = torch.stack(rays, 0) # [N, H, W, 6] 163 | 164 | main_log('done, concats') 165 | 166 | # rays = torch.permute(rays, [0,2,3,1,4]) # [N, H, W, 2, 3] 167 | # rays = torch.reshape(rays, list(rays.shape[0:3]) + [6]) # [N, H, W, 6] 168 | if args.dataset_type == 'bungee': 169 | rays, radii = get_bungee_nearfar_radii(rays=rays, scene_scaling_factor=scene_scaling_factor, 170 | scene_origin=scene_origin, ray_nearfar=args.bungee_ray_nearfar) 171 | self.radii = radii # N, H, W, 1 172 | else: 173 | rays = torch.cat([rays, self.near*torch.ones_like(rays[..., :1]), self.far*torch.ones_like(rays[..., :1])], -1) # [N, H, W, 8] 174 | self.radii = None 175 | self.rays = rays.to(torch.float32) 176 | if self.radii is not None: 177 | self.radii = self.radii.to(torch.float32) 178 | 179 | self.rgbs = self.images # N, H, W, 3 180 | 181 | self.rays_train = self.rays[i_train] 182 | self.rays_train = torch.reshape(self.rays_train, [-1,8]) 183 | self.rgbs_train = self.rgbs[i_train] 184 | self.rgbs_train = torch.reshape(self.rgbs_train, [-1,3]) 185 | if self.radii is not None: 186 | self.radii_train = self.radii[i_train] 187 | self.radii_train = torch.reshape(self.radii_train, [-1,1]) 188 | 189 | self.rays_val = self.rays[i_val] # N, H, W, 8 190 | self.rgbs_val = self.rgbs[i_val] # N, H, W, 3 191 | if self.radii is not None: 192 | self.radii_val = self.radii[i_val] # N, H, W, 1 193 | 194 | self.rays_test = self.rays[i_test] # N, H, W, 8 195 | self.rgbs_test = self.rgbs[i_test] # N, H, W, 3 196 | if self.radii is not None: 197 | self.radii_test = self.radii[i_test] # N, H, W, 1 198 | 199 | self.args = args 200 | 201 | class NeRFDatasetTrain(Dataset): 202 | def __init__(self, dataset) -> None: 203 | super().__init__() 204 | self.dataset = dataset 205 | 206 | def __len__(self): 207 | return self.dataset.rays_train.shape[0] 208 | 209 | def __getitem__(self, idx): 210 | sample = { 211 | 'rays': self.dataset.rays_train[idx], 212 | 'rgbs': self.dataset.rgbs_train[idx]} 213 | if self.dataset.args.dataset_type == 'bungee': 214 | sample["radii"] = self.dataset.radii_train[idx] 215 | return sample 216 | 217 | class NeRFDatasetVal(Dataset): 218 | def __init__(self, dataset) -> None: 219 | super().__init__() 220 | self.dataset = dataset 221 | 222 | def __len__(self): 223 | return len(self.dataset.i_val) 224 | 225 | def __getitem__(self, idx): 226 | img_i = self.dataset.i_val[idx] 227 | sample = { 228 | 'rays': self.dataset.rays_val[idx], 229 | 'rgbs': self.dataset.rgbs_val[idx], 230 | "img_i": img_i} 231 | if self.dataset.args.dataset_type == 'bungee': 232 | sample["radii"] = self.dataset.radii_val[idx] 233 | return sample 234 | 235 | class NeRFDatasetTest(Dataset): 236 | def __init__(self, dataset) -> None: 237 | super().__init__() 238 | self.dataset = dataset 239 | 240 | def __len__(self): 241 | return len(self.dataset.i_test) 242 | 243 | def __getitem__(self, idx): 244 | img_i = self.dataset.i_test[idx] 245 | sample = { 246 | 'rays': self.dataset.rays_test[idx], 247 | 'rgbs': self.dataset.rgbs_test[idx], 248 | "img_i": img_i} 249 | if self.dataset.args.dataset_type == 'bungee': 250 | sample["radii"] = self.dataset.radii_test[idx] 251 | return sample 
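NeRFDatasetTrain, NeRFDatasetVal and NeRFDatasetTest above are thin views over a shared NeRFDataset that loads the images and precomputes all rays once. The snippet below is a hedged usage sketch, not part of the repository: the args namespace, its field values and the data path are assumptions based on the attributes NeRFDataset reads for the 'blender' dataset type.

# Hypothetical usage sketch (assumed args fields and path, not from the repo).
from argparse import Namespace

from torch.utils.data import DataLoader

from switch_nerf.datasets.nerf_data.nerf_loader import (
    NeRFDataset, NeRFDatasetTrain, NeRFDatasetVal)

args = Namespace(dataset_type='blender',
                 datadir='/path/to/nerf_synthetic/lego',  # hypothetical path
                 half_res=True, testskip=8, white_bkgd=True, no_ndc=True)

base = NeRFDataset(args)  # loads frames and precomputes per-ray origins, directions and bounds

# Per-ray batches for training, per-image samples for validation.
train_loader = DataLoader(NeRFDatasetTrain(base), batch_size=1024, shuffle=True)
val_loader = DataLoader(NeRFDatasetVal(base), batch_size=1, shuffle=False)

for batch in train_loader:
    rays, rgbs = batch['rays'], batch['rgbs']  # rays: [B, 8] (origin, direction, near, far), rgbs: [B, 3]
    break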
-------------------------------------------------------------------------------- /switch_nerf/modules/tutel_moe_ext/torch_moe_layer_nobatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import re 5 | import logging 6 | 7 | import torch 8 | from torch import Tensor 9 | from torch.nn import ModuleList 10 | import torch.nn.functional as F 11 | import torch.nn as nn 12 | import copy 13 | from timm.models.layers import trunc_normal_ 14 | from typing import List, Optional 15 | 16 | class TopKGate(torch.nn.Module): 17 | """General-purpose Top-K Gate for MoE 18 | """ 19 | 20 | def __init__( 21 | self, 22 | model_dim, 23 | num_global_experts, 24 | a2a_ffn_overlap_degree=1, 25 | capacity_factor=1.0, 26 | k=2, 27 | batch_prioritized_routing=False, 28 | fp32_gate=False, 29 | is_postscore=True, 30 | input_dropout_p=0, 31 | use_normal_noise=False, 32 | ray_prioritized_droping=False, 33 | ray_prioritized_droping_mode="max", 34 | ray_prioritized_droping_factor=1.0, 35 | ray_random_droping=False, 36 | ray_random_droping_factor=1.0, 37 | gate_dim=None, 38 | gate_noise=-1.0, 39 | use_load_importance_loss=False, 40 | compute_balance_loss=False, 41 | dispatcher_no_score=False 42 | ): 43 | super().__init__() 44 | k = min(k, num_global_experts) 45 | self.top_k = k 46 | assert self.top_k > 0, "Top-k value %d is not valid." % self.top_k 47 | 48 | self.gate_dim = gate_dim 49 | if self.gate_dim is None: 50 | self.gate_dim = model_dim 51 | 52 | self.wg = torch.nn.Linear(self.gate_dim, num_global_experts, bias=False) 53 | self.fp32_gate = fp32_gate 54 | 55 | self.num_global_experts = num_global_experts 56 | 57 | def forward(self, gate_input: torch.Tensor): 58 | 59 | # if self.fp32_gate: 60 | # wg = self.wg.to(torch.float32) 61 | # else: 62 | # wg = self.wg 63 | 64 | # wg = self.wg 65 | # with torch.cuda.amp.autocast(enabled=(not self.fp32_gate)): 66 | # logits = wg(gate_input.to((wg.weight).dtype)) 67 | logits = self.wg(gate_input) 68 | 69 | gates = F.softmax(logits, dim=1) 70 | 71 | return gates 72 | 73 | class MOELayer(torch.nn.Module): 74 | """Tutel optimized MOELayer 75 | """ 76 | 77 | def __init__( 78 | self, 79 | gate_type, 80 | model_dim: int, 81 | experts=None, 82 | scan_expert_func=None, 83 | result_func=None, 84 | group=None, 85 | seeds=None, 86 | a2a_ffn_overlap_degree=1, 87 | parallel_type='auto', 88 | pad_samples=False, 89 | moe_no_batch=False, 90 | use_residual=False, 91 | return_gates=False, 92 | return_gate_logits=False, 93 | return_expert_mean_feature=False, 94 | use_random_balance_expert=False, 95 | use_scaled_dot=False, 96 | ): 97 | super().__init__() 98 | assert model_dim % 2 == 0, "Model_dim (%s) must be even value, while this Model_dim mod 2 > 0." 
% model_dim 99 | 100 | self.moe_no_batch = moe_no_batch 101 | self.num_global_experts = self.num_local_experts = experts.get('count_per_node', 1) 102 | self.hidden_size = experts.get('hidden_size_per_expert', 'None') 103 | self.model_dim = model_dim 104 | self.is_builtin_experts = True 105 | 106 | self.expert_type = experts['type'] 107 | if experts['type'] == 'seqexperts' or experts['type'] == 'multiseqexperts': 108 | if seeds is not None and seeds[1] is not None: 109 | torch.manual_seed(seeds[1]) 110 | net = experts['net'] 111 | self.experts = ModuleList([SeqExperts(net, local_experts=self.num_local_experts)]) 112 | else: 113 | raise Exception('Builtin expert type is not recognized: %s' % experts['type']) 114 | 115 | if isinstance(gate_type, str): 116 | assert re.match(r'^Top[0-9]+Gate$', gate_type), "Unrecognized gate_type: %s" % gate_type 117 | top_k = int(gate_type[3:-4]) 118 | logging.warning(f"gate_type value `{gate_type}` in tutel.moe_layer has been deprecated, please use gate_type = {{'type': 'top', 'k': {top_k}}} instead.") 119 | gate_type = {'type': 'top', 'k': top_k} 120 | 121 | if not isinstance(gate_type, list): 122 | gate_type = [gate_type] 123 | 124 | self.gates = [] 125 | for gi, single_gate_type in enumerate(gate_type): 126 | if single_gate_type['type'] == 'top': 127 | if seeds is not None and seeds[0] is not None: 128 | torch.manual_seed(seeds[0] + gi) 129 | 130 | single_gate_type.pop('type') 131 | self.gates += [TopKGate(model_dim=model_dim, num_global_experts=self.num_global_experts, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, **single_gate_type)] 132 | else: 133 | raise Exception("Unrecognized gate_type: %s" % single_gate_type) 134 | 135 | self.gates = ModuleList(self.gates) 136 | 137 | if seeds is not None and len(seeds) > 2 and seeds[2] is not None: 138 | torch.manual_seed(seeds[2]) 139 | 140 | def forward(self, input: Tensor, gate_input: Optional[torch.Tensor] = None): 141 | 142 | if gate_input is None: 143 | gate_input = input 144 | 145 | original_shape, original_dtype = input.shape, input.dtype 146 | # self.original_shape = original_shape 147 | assert len(input.shape) >= 2, "Input data must be at least 2D tensor: (s)amples, .., (m)odel_dim" 148 | reshaped_input = input.reshape(-1, input.shape[-1]) 149 | reshaped_input_samples = reshaped_input.shape[0] 150 | reshaped_gate_input = gate_input.reshape(-1, gate_input.shape[-1]) 151 | 152 | results = torch.empty(0) 153 | cluster_gates = self.gates[0](reshaped_gate_input) 154 | cluster_gates_mul, cluster_assignments = torch.topk(cluster_gates, k=1, dim=1) 155 | cluster_assignments = cluster_assignments.squeeze(1) 156 | 157 | for i, child in enumerate(self.experts[0].experts): 158 | cluster_mask = cluster_assignments == i 159 | sub_input = input[cluster_mask] 160 | 161 | if sub_input.shape[0] > 0: 162 | sub_result = child(sub_input) 163 | sub_result = sub_result * cluster_gates_mul[cluster_mask] 164 | 165 | if results.shape[0] == 0: 166 | results = torch.zeros(input.shape[0], sub_result.shape[1], device=sub_result.device, 167 | dtype=sub_result.dtype) 168 | results[cluster_mask] = sub_result 169 | result_output = results[:reshaped_input_samples, :] 170 | # result_output = result_output.view(original_shape).to(original_dtype) 171 | result_output = result_output.view(original_shape) 172 | return result_output 173 | 174 | moe_layer = MOELayer 175 | 176 | class SeqExperts(torch.nn.Module): 177 | def __init__(self, expert, local_experts=1): 178 | super(SeqExperts, self).__init__() 179 | 180 | if isinstance(expert, 
torch.nn.ModuleList): 181 | self.experts = expert 182 | else: 183 | self.experts = torch.nn.ModuleList( 184 | [copy.deepcopy(expert) for i in range(local_experts)]) 185 | self.local_experts = local_experts 186 | 187 | 188 | 189 | class SingleExpert(torch.nn.Module): 190 | # one layer MLP with experts 191 | def __init__(self, model_dim, layer_num, skips=None, activation=F.relu, init_factor=1.0, norm_layer=nn.LayerNorm, use_norm=False, init_trunc_normal=False): 192 | super().__init__() 193 | self.model_dim = model_dim 194 | self.layer_num = layer_num 195 | self.hidden_dim = model_dim 196 | self.activation = activation 197 | self.skips = skips 198 | # self.norm_layer = norm_layer 199 | self.use_norm = use_norm 200 | self.norms = None 201 | self.init_trunc_normal = init_trunc_normal 202 | if self.use_norm: 203 | self.norms = nn.ModuleDict() 204 | for skip in self.skips: 205 | self.norms[str(skip)] = norm_layer(model_dim) 206 | 207 | self.layers = nn.ModuleList() 208 | 209 | for i in range(self.layer_num): 210 | layer = nn.Linear(self.model_dim, self.model_dim) 211 | if self.init_trunc_normal: 212 | trunc_normal_linear(layer, std=init_factor) 213 | else: 214 | if init_factor != 1.0: 215 | with torch.no_grad(): 216 | layer.weight.multiply_(init_factor) 217 | if layer.bias is not None: 218 | layer.bias.multiply_(init_factor) 219 | self.layers.append(layer) 220 | 221 | def extra_repr(self): 222 | return 'model_dim=%d, layer_num=%d' % (self.model_dim, self.layer_num) 223 | 224 | def forward(self, x: torch.Tensor): 225 | # x: 1, E, N, C 226 | h = x 227 | for layer_id, fc in enumerate(self.layers): 228 | # fc = self.layers[layer_id] 229 | h = fc(h) 230 | 231 | # skip connections 232 | if self.skips is not None: 233 | if layer_id in self.skips: 234 | h = h + x 235 | if layer_id < (self.layer_num - 1): 236 | h = self.activation(h) 237 | x = h 238 | else: 239 | if layer_id < (self.layer_num - 1): 240 | h = self.activation(h) 241 | else: 242 | if layer_id < (self.layer_num - 1): 243 | h = self.activation(h) 244 | return h 245 | 246 | def trunc_normal_linear(module, std=1.0): 247 | trunc_normal_(module.weight, std=std) 248 | if module.bias is not None: 249 | nn.init.zeros_(module.bias) 250 | 251 | class Mlp(nn.Module): 252 | def __init__(self, in_features, hidden_features, out_features, layer_num, skips=None, act_fn=F.relu): 253 | super().__init__() 254 | 255 | self.act_fn = act_fn 256 | self.layer_num = layer_num 257 | self.fcs = nn.ModuleList() 258 | self.skips = skips 259 | 260 | for i in range(layer_num): 261 | in_ch = in_features if i == 0 else hidden_features 262 | out_ch = out_features if i == layer_num - 1 else hidden_features 263 | self.fcs.append(nn.Linear(in_ch, out_ch)) 264 | 265 | def forward(self, x): 266 | h = x 267 | for i, fc in enumerate(self.fcs): 268 | # fc = self.fcs[i] 269 | h = fc(h) 270 | 271 | # skip connections 272 | if self.skips is not None: 273 | if i in self.skips: 274 | h = h + x 275 | if i < self.layer_num - 1: 276 | h = self.act_fn(h) 277 | x = h 278 | else: 279 | if i < self.layer_num - 1: 280 | h = self.act_fn(h) 281 | else: 282 | if i < self.layer_num - 1: 283 | h = self.act_fn(h) 284 | return h -------------------------------------------------------------------------------- /switch_nerf/datasets/nerf_data/load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | 4 | 5 | ########## Slightly modified version of LLFF data loading code 6 | ########## see https://github.com/Fyusion/LLFF 
for original 7 | 8 | def _minify(basedir, factors=[], resolutions=[]): 9 | needtoload = False 10 | for r in factors: 11 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 12 | if not os.path.exists(imgdir): 13 | needtoload = True 14 | for r in resolutions: 15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 16 | if not os.path.exists(imgdir): 17 | needtoload = True 18 | if not needtoload: 19 | return 20 | 21 | from shutil import copy 22 | from subprocess import check_output 23 | 24 | imgdir = os.path.join(basedir, 'images') 25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 27 | imgdir_orig = imgdir 28 | 29 | wd = os.getcwd() 30 | 31 | for r in factors + resolutions: 32 | if isinstance(r, int): 33 | name = 'images_{}'.format(r) 34 | resizearg = '{}%'.format(100./r) 35 | else: 36 | name = 'images_{}x{}'.format(r[1], r[0]) 37 | resizearg = '{}x{}'.format(r[1], r[0]) 38 | imgdir = os.path.join(basedir, name) 39 | if os.path.exists(imgdir): 40 | continue 41 | 42 | print('Minifying', r, basedir) 43 | 44 | os.makedirs(imgdir) 45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 46 | 47 | ext = imgs[0].split('.')[-1] 48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 49 | print(args) 50 | os.chdir(imgdir) 51 | check_output(args, shell=True) 52 | os.chdir(wd) 53 | 54 | if ext != 'png': 55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 56 | print('Removed duplicates') 57 | print('Done') 58 | 59 | 60 | 61 | 62 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 63 | 64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) 66 | bds = poses_arr[:, -2:].transpose([1,0]) 67 | 68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 70 | sh = imageio.imread(img0).shape 71 | 72 | sfx = '' 73 | 74 | if factor is not None: 75 | sfx = '_{}'.format(factor) 76 | _minify(basedir, factors=[factor]) 77 | factor = factor 78 | elif height is not None: 79 | factor = sh[0] / float(height) 80 | width = int(sh[1] / factor) 81 | _minify(basedir, resolutions=[[height, width]]) 82 | sfx = '_{}x{}'.format(width, height) 83 | elif width is not None: 84 | factor = sh[1] / float(width) 85 | height = int(sh[0] / factor) 86 | _minify(basedir, resolutions=[[height, width]]) 87 | sfx = '_{}x{}'.format(width, height) 88 | else: 89 | factor = 1 90 | 91 | imgdir = os.path.join(basedir, 'images' + sfx) 92 | if not os.path.exists(imgdir): 93 | print( imgdir, 'does not exist, returning' ) 94 | return 95 | 96 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 97 | if poses.shape[-1] != len(imgfiles): 98 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) 99 | return 100 | 101 | sh = imageio.imread(imgfiles[0]).shape 102 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 103 | poses[2, 4, :] = poses[2, 4, :] * 1./factor 104 | 105 | if not load_imgs: 106 | return poses, bds 107 | 108 | def imread(f): 109 | if f.endswith('png'): 110 | return imageio.imread(f, ignoregamma=True) 111 | else: 112 | return imageio.imread(f) 113 | 114 | imgs = imgs = 
[imread(f)[...,:3]/255. for f in imgfiles] 115 | imgs = np.stack(imgs, -1) 116 | 117 | print('Loaded image data', imgs.shape, poses[:,-1,0]) 118 | return poses, bds, imgs 119 | 120 | 121 | 122 | 123 | 124 | 125 | def normalize(x): 126 | return x / np.linalg.norm(x) 127 | 128 | def viewmatrix(z, up, pos): 129 | vec2 = normalize(z) 130 | vec1_avg = up 131 | vec0 = normalize(np.cross(vec1_avg, vec2)) 132 | vec1 = normalize(np.cross(vec2, vec0)) 133 | m = np.stack([vec0, vec1, vec2, pos], 1) 134 | return m 135 | 136 | def ptstocam(pts, c2w): 137 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] 138 | return tt 139 | 140 | def poses_avg(poses): 141 | 142 | hwf = poses[0, :3, -1:] 143 | 144 | center = poses[:, :3, 3].mean(0) 145 | vec2 = normalize(poses[:, :3, 2].sum(0)) 146 | up = poses[:, :3, 1].sum(0) 147 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 148 | 149 | return c2w 150 | 151 | 152 | 153 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 154 | render_poses = [] 155 | rads = np.array(list(rads) + [1.]) 156 | hwf = c2w[:,4:5] 157 | 158 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]: 159 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 160 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 161 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 162 | return render_poses 163 | 164 | 165 | 166 | def recenter_poses(poses): 167 | 168 | poses_ = poses+0 169 | bottom = np.reshape([0,0,0,1.], [1,4]) 170 | c2w = poses_avg(poses) 171 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 172 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) 173 | poses = np.concatenate([poses[:,:3,:4], bottom], -2) 174 | 175 | poses = np.linalg.inv(c2w) @ poses 176 | poses_[:,:3,:4] = poses[:,:3,:4] 177 | poses = poses_ 178 | return poses 179 | 180 | 181 | ##################### 182 | 183 | 184 | def spherify_poses(poses, bds): 185 | 186 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 187 | 188 | rays_d = poses[:,:3,2:3] 189 | rays_o = poses[:,:3,3:4] 190 | 191 | def min_line_dist(rays_o, rays_d): 192 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 193 | b_i = -A_i @ rays_o 194 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 195 | return pt_mindist 196 | 197 | pt_mindist = min_line_dist(rays_o, rays_d) 198 | 199 | center = pt_mindist 200 | up = (poses[:,:3,3] - center).mean(0) 201 | 202 | vec0 = normalize(up) 203 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 204 | vec2 = normalize(np.cross(vec0, vec1)) 205 | pos = center 206 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 207 | 208 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 209 | 210 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) 211 | 212 | sc = 1./rad 213 | poses_reset[:,:3,3] *= sc 214 | bds *= sc 215 | rad *= sc 216 | 217 | centroid = np.mean(poses_reset[:,:3,3], 0) 218 | zh = centroid[2] 219 | radcircle = np.sqrt(rad**2-zh**2) 220 | new_poses = [] 221 | 222 | for th in np.linspace(0.,2.*np.pi, 120): 223 | 224 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 225 | up = np.array([0,0,-1.]) 226 | 227 | vec2 = normalize(camorigin) 228 | vec0 = normalize(np.cross(vec2, up)) 229 | vec1 = normalize(np.cross(vec2, vec0)) 230 | pos = camorigin 231 | p = np.stack([vec0, vec1, vec2, pos], 
1) 232 | 233 | new_poses.append(p) 234 | 235 | new_poses = np.stack(new_poses, 0) 236 | 237 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) 238 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 239 | 240 | return poses_reset, new_poses, bds 241 | 242 | 243 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): 244 | 245 | 246 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 247 | print('Loaded', basedir, bds.min(), bds.max()) 248 | 249 | # Correct rotation matrix ordering and move variable dim to axis 0 250 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 251 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 252 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 253 | images = imgs 254 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 255 | 256 | # Rescale if bd_factor is provided 257 | sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor) 258 | poses[:,:3,3] *= sc 259 | bds *= sc 260 | 261 | if recenter: 262 | poses = recenter_poses(poses) 263 | 264 | if spherify: 265 | poses, render_poses, bds = spherify_poses(poses, bds) 266 | 267 | else: 268 | 269 | c2w = poses_avg(poses) 270 | print('recentered', c2w.shape) 271 | print(c2w[:3,:4]) 272 | 273 | ## Get spiral 274 | # Get average pose 275 | up = normalize(poses[:, :3, 1].sum(0)) 276 | 277 | # Find a reasonable "focus depth" for this dataset 278 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 279 | dt = .75 280 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 281 | focal = mean_dz 282 | 283 | # Get radii for spiral path 284 | shrink_factor = .8 285 | zdelta = close_depth * .2 286 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T 287 | rads = np.percentile(np.abs(tt), 90, 0) 288 | c2w_path = c2w 289 | N_views = 120 290 | N_rots = 2 291 | if path_zflat: 292 | # zloc = np.percentile(tt, 10, 0)[2] 293 | zloc = -close_depth * .1 294 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] 295 | rads[2] = 0. 296 | N_rots = 1 297 | N_views/=2 298 | 299 | # Generate poses for spiral path 300 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 301 | 302 | 303 | render_poses = np.array(render_poses).astype(np.float32) 304 | 305 | c2w = poses_avg(poses) 306 | print('Data:') 307 | print(poses.shape, images.shape, bds.shape) 308 | 309 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) 310 | i_test = np.argmin(dists) 311 | print('HOLDOUT view is', i_test) 312 | 313 | images = images.astype(np.float32) 314 | poses = poses.astype(np.float32) 315 | 316 | return images, poses, bds, render_poses, i_test 317 | 318 | 319 | 320 | -------------------------------------------------------------------------------- /switch_nerf/modules/tutel_moe_ext/tutel_communicate_nobatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
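# Note (illustrative summary, with assumed example sizes): the helpers in this
# module implement a variable-split all_to_all. list_all_to_all(input,
# input_splits, output_splits) splits `input` along dim 0 so that rank i sends
# input_splits[j] rows to rank j and receives output_splits[j] rows from rank j;
# with a single process the input is returned unchanged. For example, on 2 ranks,
# rank 0 with input_splits=[2, 3] keeps its first 2 rows and sends the next 3 to
# rank 1, and rank 0's output_splits[j] must equal rank j's input_splits[0].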
3 | 4 | from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast 5 | 6 | import os 7 | import re 8 | import time 9 | import torch 10 | import logging 11 | 12 | from torch import Tensor 13 | import torch.distributed as dist 14 | 15 | from tutel.impls.communicate import get_world_size 16 | import subprocess 17 | 18 | def list_all_to_all(input, input_splits, output_splits, group=None, background=False): 19 | world_size = get_world_size(group) 20 | if world_size == 1: 21 | return input if not background else (input, lambda *args: None) 22 | list_all_to_all._use_builtins = True 23 | input = input.contiguous() 24 | input = list(torch.split(input, input_splits, dim=0)) 25 | # input = [i.contiguous() for i in input] 26 | output = [torch.empty([i] + list(input[0].shape[1:]), dtype=input[0].dtype, device=input[0].device, requires_grad=input[0].requires_grad) for i in output_splits] 27 | if background: 28 | future_op = dist.all_to_all(output, input, group=group, async_op=True) 29 | return output, future_op.wait 30 | dist.all_to_all(output, input, group=group) 31 | output = torch.cat(output, dim=0) 32 | return output 33 | 34 | class ListAllToAll(torch.autograd.Function): 35 | @staticmethod 36 | def forward(ctx, input, input_splits, output_splits, group=None): 37 | ctx.group = group 38 | ctx.input_splits = input_splits 39 | ctx.output_splits = output_splits 40 | return list_all_to_all(input, input_splits, output_splits, group) 41 | 42 | @staticmethod 43 | def backward(ctx, grad_output): 44 | return (list_all_to_all(grad_output, ctx.output_splits, ctx.input_splits, ctx.group), None, None, None) 45 | # return (ListAllToAll.apply(grad_output, ctx.output_splits, ctx.input_splits, ctx.group), None, None, None) 46 | 47 | @staticmethod 48 | def single(input, input_splits, output_splits, group=None): 49 | return ListAllToAll.apply(input, input_splits, output_splits, group) 50 | 51 | list_all_to_all_single = ListAllToAll.single 52 | 53 | 54 | from tutel.impls.communicate import TUTEL_GROUPING_CACHE 55 | 56 | def create_groups_from_world_slurm(group_count, include_init=None): 57 | backend = TUTEL_GROUPING_CACHE.get('', include_init) 58 | if include_init: 59 | assert backend == include_init, "Only 1 backend type is allowed, get: %s v.s. 
%s" % (backend, include_init) 60 | TUTEL_GROUPING_CACHE[''] = backend 61 | 62 | if group_count in TUTEL_GROUPING_CACHE: 63 | return TUTEL_GROUPING_CACHE[group_count] 64 | 65 | def dist_init(host_addr, rank, local_rank, world_size, port=23456): 66 | host_addr_full = 'tcp://' + host_addr + ':' + str(port) 67 | torch.distributed.init_process_group(backend, init_method=host_addr_full, 68 | rank=rank, world_size=world_size) 69 | assert torch.distributed.is_initialized() 70 | 71 | try: 72 | rank = int(os.environ['SLURM_PROCID']) 73 | local_rank = int(os.environ['SLURM_LOCALID']) 74 | world_size = int(os.environ['SLURM_NTASKS']) 75 | iplist = os.environ['SLURM_JOB_NODELIST'] 76 | ip = subprocess.getoutput(f"scontrol show hostname {iplist} | head -n1") 77 | 78 | dist_init(ip, rank, local_rank, world_size, port=os.environ.get('MASTER_PORT', '23456')) 79 | dist_local_rank = local_rank 80 | 81 | glob_world_size, glob_world_rank = dist.get_world_size(), dist.get_rank() 82 | is_distributed = True 83 | 84 | def dist_print(*args): 85 | if glob_world_rank == 0: 86 | print(*args) 87 | 88 | # debug 89 | logging.info('successfully inin dist') 90 | 91 | except ValueError: 92 | glob_world_size, glob_world_rank, dist_local_rank = 1, 0, 0 93 | is_distributed = False 94 | dist_print = print 95 | 96 | assert glob_world_size % group_count == 0, f"Expected to evenly divide devices into {group_count} groups, while the world size of current sesion is {glob_world_size}." 97 | 98 | dist_group_size = group_count 99 | dist_world_size = glob_world_size // dist_group_size 100 | dist_world_rank = glob_world_rank % dist_world_size 101 | dist_group_rank = glob_world_rank // dist_world_size 102 | 103 | if is_distributed: 104 | global_group = model_group = data_group = dist.group.WORLD 105 | 106 | if dist_world_size != glob_world_size: 107 | groups, inner_ranks = [], [] 108 | for gr in range(dist_group_size): 109 | group_ranks = [x for x in range(gr * dist_world_size, (gr + 1) * dist_world_size)] 110 | groups += [dist.new_group(ranks=group_ranks)] 111 | inner_ranks += [group_ranks] 112 | model_group = groups[dist_group_rank] 113 | 114 | if dist_group_size != glob_world_size: 115 | groups, outer_ranks = [], [] 116 | for gr in range(dist_world_size): 117 | group_ranks = [x for x in range(gr, dist_world_size * dist_group_size, dist_world_size)] 118 | groups += [dist.new_group(ranks=group_ranks)] 119 | outer_ranks += [group_ranks] 120 | data_group = groups[dist_world_rank] 121 | else: 122 | model_group, data_group, global_group = None, None, None 123 | 124 | class ParallelPropStorage: 125 | pass 126 | 127 | result = ParallelPropStorage() 128 | 129 | result.global_size = glob_world_size 130 | result.global_rank = glob_world_rank 131 | 132 | result.group_count = dist_group_size 133 | result.data_rank = dist_group_rank 134 | 135 | result.model_size = dist_world_size 136 | result.model_rank = dist_world_rank 137 | 138 | if backend == 'nccl': 139 | result.local_device = torch.device('cuda', dist_local_rank) 140 | torch.cuda.set_device(result.local_device) 141 | elif backend == 'gloo': 142 | result.local_device = torch.device('cpu') 143 | elif backend is None: 144 | result.local_device = None 145 | else: 146 | raise Exception('Unsupported backend type: %s' % backend) 147 | 148 | result.data_group = data_group 149 | result.model_group = model_group 150 | result.global_group = global_group 151 | 152 | result.is_distributed = is_distributed 153 | result.dist_print = dist_print 154 | 155 | TUTEL_GROUPING_CACHE[group_count] = result 156 | 
return result 157 | 158 | 159 | 160 | def create_groups_from_world(group_count, include_init=None, timeout=None): 161 | backend = TUTEL_GROUPING_CACHE.get('', include_init) 162 | if include_init: 163 | assert backend == include_init, "Only 1 backend type is allowed, get: %s v.s. %s" % (backend, include_init) 164 | TUTEL_GROUPING_CACHE[''] = backend 165 | 166 | if group_count in TUTEL_GROUPING_CACHE: 167 | return TUTEL_GROUPING_CACHE[group_count] 168 | 169 | try: 170 | if timeout is not None: 171 | if ('LOCAL_RANK' not in os.environ) and ('OMPI_COMM_WORLD_SIZE' in os.environ): 172 | if include_init: 173 | dist.init_process_group(backend=backend, 174 | init_method='tcp://%s:%s' % (os.environ['MASTER_ADDR'], os.environ.get('MASTER_PORT', '23456')), 175 | rank=int(os.environ['OMPI_COMM_WORLD_RANK']), world_size=int(os.environ['OMPI_COMM_WORLD_SIZE']), timeout=timeout) 176 | dist_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 177 | else: 178 | if include_init: 179 | dist.init_process_group(backend=backend, timeout=timeout) 180 | dist_local_rank = min(int(os.environ.get('LOCAL_RANK', 0)), torch.cuda.device_count() - 1) 181 | else: 182 | if ('LOCAL_RANK' not in os.environ) and ('OMPI_COMM_WORLD_SIZE' in os.environ): 183 | if include_init: 184 | dist.init_process_group(backend=backend, 185 | init_method='tcp://%s:%s' % (os.environ['MASTER_ADDR'], os.environ.get('MASTER_PORT', '23456')), 186 | rank=int(os.environ['OMPI_COMM_WORLD_RANK']), world_size=int(os.environ['OMPI_COMM_WORLD_SIZE'])) 187 | dist_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 188 | else: 189 | if include_init: 190 | dist.init_process_group(backend=backend) 191 | dist_local_rank = min(int(os.environ.get('LOCAL_RANK', 0)), torch.cuda.device_count() - 1) 192 | glob_world_size, glob_world_rank = dist.get_world_size(), dist.get_rank() 193 | is_distributed = True 194 | 195 | def dist_print(*args): 196 | if glob_world_rank == 0: 197 | print(*args) 198 | except ValueError: 199 | glob_world_size, glob_world_rank, dist_local_rank = 1, 0, 0 200 | is_distributed = False 201 | dist_print = print 202 | 203 | assert glob_world_size % group_count == 0, f"Expected to evenly divide devices into {group_count} groups, while the world size of current sesion is {glob_world_size}." 
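# Worked example (illustrative, with assumed numbers): the code below splits the
# flat world into group_count model-parallel groups of size
# glob_world_size // group_count. With 8 processes and group_count = 2,
# dist_world_size = 4, so ranks 0-3 form one model group and ranks 4-7 the
# other; global rank 6 then gets dist_world_rank = 6 % 4 = 2 and
# dist_group_rank = 6 // 4 = 1, and its data group links it to rank 2, which
# holds the same model rank in the other group.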
204 | 205 | dist_group_size = group_count 206 | dist_world_size = glob_world_size // dist_group_size 207 | dist_world_rank = glob_world_rank % dist_world_size 208 | dist_group_rank = glob_world_rank // dist_world_size 209 | 210 | if is_distributed: 211 | global_group = model_group = data_group = dist.group.WORLD 212 | 213 | if dist_world_size != glob_world_size: 214 | groups, inner_ranks = [], [] 215 | for gr in range(dist_group_size): 216 | group_ranks = [x for x in range(gr * dist_world_size, (gr + 1) * dist_world_size)] 217 | groups += [dist.new_group(ranks=group_ranks)] 218 | inner_ranks += [group_ranks] 219 | model_group = groups[dist_group_rank] 220 | 221 | if dist_group_size != glob_world_size: 222 | groups, outer_ranks = [], [] 223 | for gr in range(dist_world_size): 224 | group_ranks = [x for x in range(gr, dist_world_size * dist_group_size, dist_world_size)] 225 | groups += [dist.new_group(ranks=group_ranks)] 226 | outer_ranks += [group_ranks] 227 | data_group = groups[dist_world_rank] 228 | else: 229 | model_group, data_group, global_group = None, None, None 230 | 231 | class ParallelPropStorage: 232 | pass 233 | 234 | result = ParallelPropStorage() 235 | 236 | result.global_size = glob_world_size 237 | result.global_rank = glob_world_rank 238 | 239 | result.group_count = dist_group_size 240 | result.data_rank = dist_group_rank 241 | 242 | result.model_size = dist_world_size 243 | result.model_rank = dist_world_rank 244 | 245 | if backend == 'nccl': 246 | result.local_device = torch.device('cuda', dist_local_rank) 247 | torch.cuda.set_device(result.local_device) 248 | elif backend == 'gloo': 249 | result.local_device = torch.device('cpu') 250 | elif backend is None: 251 | result.local_device = None 252 | else: 253 | raise Exception('Unsupported backend type: %s' % backend) 254 | 255 | result.data_group = data_group 256 | result.model_group = model_group 257 | result.global_group = global_group 258 | 259 | result.is_distributed = is_distributed 260 | result.dist_print = dist_print 261 | 262 | TUTEL_GROUPING_CACHE[group_count] = result 263 | return result -------------------------------------------------------------------------------- /switch_nerf/scripts/create_octree_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS' 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | from argparse import Namespace 25 | from pathlib import Path 26 | from typing import List, Tuple 27 | 28 | import numpy as np 29 | import torch 30 | from svox import N3Tree 31 | from svox.helpers import _get_c_extension 32 | from torch import nn 33 | from tqdm import tqdm 34 | 35 | from switch_nerf.models.model_utils import get_nerf 36 | from switch_nerf.opts import get_opts_base 37 | 38 | _C = _get_c_extension() 39 | 40 | 41 | def _get_extraction_opts() -> Namespace: 42 | parser = get_opts_base() 43 | 44 | parser.add_argument('--dataset_path', type=str, required=True) 45 | parser.add_argument('--output', type=str, required=True) 46 | parser.add_argument('--alpha_thresh', type=float, default=0.01) 47 | parser.add_argument('--scale_alpha_thresh', type=float, default=0.01) 48 | parser.add_argument('--max_refine_prop', type=float, default=0.5) 49 | parser.add_argument('--tree_branch_n', type=int, default=2) 50 | parser.add_argument('--init_grid_depth', type=int, default=8) 51 | parser.add_argument('--samples_per_cell', type=int, default=256) 52 | parser.add_argument('--masking_mode', type=str, default='weight', choices=['sigma', 'weight']) 53 | parser.add_argument('--weight_thresh', type=float, default=0.001) 54 | parser.add_argument('--embedding_index', type=int, default=0) 55 | parser.add_argument('--camera_params', type=int, nargs='+', default=[800, 800, 400, 400, 400, 400]) 56 | parser.add_argument('--renderer_step_size', type=float, default=1e-6) 57 | 58 | return parser.parse_known_args()[0] 59 | 60 | 61 | def _auto_scale(hparams: Namespace, nerf: nn.Module, center: List[float], radius: List[float], 62 | device: torch.device) -> Tuple[List[float], List[float]]: 63 | print('Step 0: Auto scale') 64 | reso = 2 ** hparams.init_grid_depth 65 | 66 | radius = torch.tensor(radius, dtype=torch.float32) 67 | center = torch.tensor(center, dtype=torch.float32) 68 | scale = 0.5 / radius 69 | offset = 0.5 * (1.0 - center / radius) 70 | 71 | arr = (torch.arange(0, reso, dtype=torch.float32) + 0.5) / reso 72 | xx = (arr - offset[0]) / scale[0] 73 | yy = (arr - offset[1]) / scale[1] 74 | zz = (arr - offset[2]) / scale[2] 75 | 76 | grid = torch.stack(torch.meshgrid(xx, yy, zz)).reshape(3, -1).T 77 | 78 | approx_delta = 2.0 / reso 79 | sigma_thresh = -np.log(1.0 - hparams.scale_alpha_thresh) / approx_delta 80 | 81 | lc = None 82 | uc = None 83 | 84 | for i in tqdm(range(0, grid.shape[0], hparams.model_chunk_size)): 85 | grid_chunk = grid[i:i + hparams.model_chunk_size].to(device) 86 | 87 | output = nerf(False, grid_chunk, sigma_only=True) if hparams.use_cascade else nerf(grid_chunk, sigma_only=True) 88 | sigmas = output[:, 0] 89 | mask = sigmas >= sigma_thresh 90 | grid_chunk = grid_chunk[mask] 91 | del mask 92 | 93 | if grid_chunk.shape[0] > 0: 94 | if lc is None: 95 | lc = grid_chunk.min(dim=0)[0] 96 | uc = grid_chunk.max(dim=0)[0] 97 | else: 98 | lc = torch.minimum(lc, grid_chunk.min(dim=0)[0]) 99 | uc = torch.maximum(uc, 
grid_chunk.max(dim=0)[0]) 100 | 101 | del grid_chunk 102 | 103 | lc = lc - 0.5 / reso 104 | uc = uc + 0.5 / reso 105 | return ((lc + uc) * 0.5).tolist(), ((uc - lc) * 0.5).tolist() 106 | 107 | 108 | def _calculate_grid_weights(hparams: Namespace, tree: N3Tree, poses: torch.Tensor, sigmas: torch.Tensor, 109 | reso: int) -> torch.Tensor: 110 | opts = _C.RenderOptions() 111 | opts.step_size = hparams.renderer_step_size 112 | opts.sigma_thresh = 0.0 113 | opts.ndc_width = -1 114 | 115 | cam = _C.CameraSpec() 116 | cam.fx = hparams.camera_params[2] 117 | cam.fy = hparams.camera_params[3] 118 | cam.width = hparams.camera_params[0] 119 | cam.height = hparams.camera_params[1] 120 | 121 | grid_data = sigmas.reshape((reso, reso, reso)) 122 | maximum_weight = torch.zeros_like(grid_data) 123 | 124 | for idx in tqdm(range(poses.shape[0])): 125 | cam.c2w = poses[idx].to(sigmas.device) 126 | grid_weight, _ = _C.grid_weight_render( 127 | grid_data, 128 | cam, 129 | opts, 130 | tree.offset, 131 | tree.invradius, 132 | ) 133 | 134 | maximum_weight = torch.max(maximum_weight, grid_weight) 135 | 136 | return maximum_weight 137 | 138 | 139 | def _step1(hparams: Namespace, nerf: nn.Module, tree: N3Tree, poses: torch.Tensor, device: torch.device): 140 | print('Step 1: Grid eval') 141 | reso = 2 ** (hparams.init_grid_depth + 1) 142 | offset = tree.offset.cpu() 143 | scale = tree.invradius.cpu() 144 | 145 | arr = (torch.arange(0, reso, dtype=torch.float32) + 0.5) / reso 146 | xx = (arr - offset[0]) / scale[0] 147 | yy = (arr - offset[1]) / scale[1] 148 | zz = (arr - offset[2]) / scale[2] 149 | 150 | grid = torch.stack(torch.meshgrid(xx, yy, zz)).reshape(3, -1).T 151 | 152 | approx_delta = 2.0 / reso 153 | sigma_thresh = -np.log(1.0 - hparams.alpha_thresh) / approx_delta 154 | 155 | out_chunks = [] 156 | for i in tqdm(range(0, grid.shape[0], hparams.model_chunk_size)): 157 | grid_chunk = grid[i:i + hparams.model_chunk_size].to(device) 158 | result = nerf(False, grid_chunk, sigma_only=True) if hparams.use_cascade else nerf(grid_chunk, sigma_only=True) 159 | del grid_chunk 160 | out_chunks.append(result[:, 0]) 161 | 162 | sigmas = torch.cat(out_chunks, 0) 163 | del out_chunks 164 | 165 | if hparams.masking_mode == 'sigma': 166 | mask = sigmas >= sigma_thresh 167 | elif hparams.masking_mode == 'weight': 168 | print('Calculating grid weights') 169 | grid_weights = _calculate_grid_weights(hparams, tree, poses, sigmas, reso) 170 | mask = grid_weights.reshape(-1) >= hparams.weight_thresh 171 | del grid_weights 172 | else: 173 | raise Exception('Unsupported masking mode: {}'.format(hparams.masking_mode)) 174 | del sigmas 175 | 176 | grid = grid[mask] 177 | del mask 178 | 179 | print('Building octree') 180 | 181 | tree = tree.cpu() 182 | 183 | for i in range(hparams.init_grid_depth): 184 | tree[grid].refine() 185 | 186 | print(tree) 187 | 188 | 189 | def _step2(hparams: Namespace, nerf: nn.Module, tree: N3Tree, device: torch.device): 190 | print('Step 2: AA with {} samples per cell'.format(hparams.samples_per_cell)) 191 | 192 | chunk_size = hparams.model_chunk_size // hparams.samples_per_cell 193 | for i in tqdm(range(0, tree.n_leaves, chunk_size)): 194 | points = tree[i:i + chunk_size].sample(hparams.samples_per_cell) # (n_cells, n_samples, 3) 195 | points = points.view(-1, 3).to(device) 196 | 197 | if hparams.pos_dir_dim > 0: 198 | dirs = torch.zeros_like(points) 199 | dirs[:, 0] = 1 200 | points = torch.cat([points, dirs], -1) 201 | 202 | if hparams.appearance_dim > 0: 203 | points = torch.cat([points, 
hparams.embedding_index * torch.ones(points.shape[0], 1, device=points.device)], 204 | -1) 205 | 206 | rgba = nerf(False, points) if hparams.use_cascade else nerf(points) 207 | rgba = rgba.reshape(-1, hparams.samples_per_cell, tree.data_dim).mean(dim=1) 208 | 209 | tree[i:i + chunk_size] = rgba.cpu() 210 | 211 | 212 | @torch.inference_mode() 213 | def main(hparams: Namespace) -> None: 214 | assert hparams.ckpt_path is not None or hparams.container_path is not None 215 | assert hparams.ray_altitude_range is not None 216 | hparams.moe_local_expert_num = hparams.moe_expert_num 217 | hparams.single_data_group = None 218 | 219 | dataset_path = Path(hparams.dataset_path) 220 | train_path_candidates = sorted(list((dataset_path / 'train' / 'metadata').iterdir())) 221 | train_paths = [train_path_candidates[i] for i in 222 | range(0, len(train_path_candidates), hparams.train_every)] 223 | 224 | metadata_paths = train_paths + list((dataset_path / 'val' / 'metadata').iterdir()) 225 | 226 | poses = torch.cat([torch.load(x, map_location='cpu')['c2w'].unsqueeze(0) for x in metadata_paths]) 227 | 228 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 229 | nerf = get_nerf(hparams, poses.shape[0]).to(device).eval() 230 | 231 | coordinate_info = torch.load(dataset_path / 'coordinates.pt', map_location='cpu') 232 | origin_drb = coordinate_info['origin_drb'] 233 | pose_scale_factor = coordinate_info['pose_scale_factor'] 234 | 235 | max_values = poses[:, :3, 3].max(0)[0] 236 | min_values = poses[:, :3, 3].min(0)[0] 237 | 238 | ray_altitude_range = [(x - origin_drb[0]) / pose_scale_factor for x in hparams.ray_altitude_range] 239 | 240 | min_values[0] = ray_altitude_range[0] 241 | max_values[0] = ray_altitude_range[1] 242 | 243 | print('Min and Max values: {} {}'.format(min_values, max_values)) 244 | 245 | center = ((max_values + min_values) * 0.5).tolist() 246 | radius = ((max_values - min_values) * 0.5).tolist() 247 | print('Center and radius before autoscale: {}, {}'.format(center, radius)) 248 | 249 | center, radius = _auto_scale(hparams, nerf, center, radius, device) 250 | print('Center and radius after autoscale: {}, {}'.format(center, radius)) 251 | 252 | sh_deg = hparams.sh_deg if hparams.sh_deg is not None else 0 253 | num_rgb_channels = 3 * (sh_deg + 1) ** 2 254 | data_dim = 1 + num_rgb_channels # alpha + rgb 255 | 256 | print('Data dim is', data_dim) 257 | 258 | print('Creating tree') 259 | data_format = f'SH{(sh_deg + 1) ** 2}' if sh_deg > 0 else 'RGBA' 260 | tree = N3Tree(N=hparams.tree_branch_n, 261 | data_dim=data_dim, 262 | init_refine=0, 263 | init_reserve=500000, 264 | geom_resize_fact=1.0, 265 | depth_limit=hparams.init_grid_depth, 266 | radius=radius, 267 | center=center, 268 | data_format=data_format, 269 | device=device) 270 | 271 | _step1(hparams, nerf, tree, poses, device) 272 | _step2(hparams, nerf, tree, device) 273 | 274 | tree.shrink_to_fit() 275 | 276 | print('Filling in internal nodes') 277 | child = tree.child.clone() 278 | parent_depth = tree.parent_depth.clone() 279 | n_free = tree._n_free.item() 280 | while tree.n_frontier > 1: 281 | print('Internal {} leaves {} frontier {} free {}'.format(tree.n_internal, tree.n_leaves, tree.n_frontier, 282 | tree._n_free)) 283 | tree.merge() 284 | 285 | tree.child.set_(child) 286 | tree.parent_depth.set_(parent_depth) 287 | tree._n_free.fill_(n_free) 288 | 289 | print(tree) 290 | 291 | print('Saving tree to: {}'.format(hparams.output)) 292 | Path(hparams.output).parent.mkdir(parents=True, exist_ok=True) 293 | 
tree.save(hparams.output, compress=False) 294 | 295 | 296 | if __name__ == '__main__': 297 | main(_get_extraction_opts()) 298 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Switch-NeRF: Learning Scene Decomposition with Mixture of Experts for Large-scale Neural Radiance Fields (ICLR 2023) 2 | 3 | ### [Openreview](https://openreview.net/forum?id=PQ2zoIZqvm) | [Project Page](https://mizhenxing.github.io/switchnerf) | [Checkpoints](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zmiaa_connect_ust_hk/ErqiFEjmMVBCrA8-Y8Q2yTMBTPdhMFhPCnLeUYh_oSLBVQ?e=1H0H2u) | [Visualizer](https://github.com/MiZhenxing/alpha_visualizer) 4 | 5 | 6 | ## Demo 7 | 8 | ![](https://raw.githubusercontent.com/MiZhenxing/Switch-NeRF-demo/master/sci-art_image_depth_video_fps_24.gif)![](https://raw.githubusercontent.com/MiZhenxing/Switch-NeRF-demo/master/building_image_depth_video_fps_24.gif) 9 | 10 | ![](https://raw.githubusercontent.com/MiZhenxing/Switch-NeRF-demo/master/residence_image_depth_video_fps_24.gif)![](https://raw.githubusercontent.com/MiZhenxing/Switch-NeRF-demo/master/rubble_image_depth_video_fps_24.gif) 11 | 12 | ## Updation 13 | 14 | - 2023-04-13, move ckpts to onedrive 15 | - 2023-03-30, stable release. 16 | - 2023-03-28, release the checkpoints and codes for three datasets. 17 | 18 | ## Installation 19 | 20 | The main dependencies are in the `requirements.txt`. We use [this version](https://github.com/microsoft/tutel/tree/56dbd664341cf6485c9fa292955f77d3ac918a65) of Tutel in for MoE layers. The Tutel has changed a lot so make sure to install the version of the correct commit. Please follow the instructions in Tutel to install it. We give an [instruction](install_tutel.md) on the Tutel installation. 21 | 22 | ## Dataset 23 | We have performed experiments on the datasets from the Mega-NeRF, Block-NeRF and Bungee-NeRF. 24 | 25 | ### Mega-NeRF 26 | 27 | Please follow the instructions in the code of [Mega-NeRF](https://github.com/cmusatyalab/mega-nerf) to download and process the Mill 19 and UrbanScene 3D datasets. 28 | 29 | ### Block-NeRF 30 | 31 | Please follow the website of [Block-NeRF](https://waymo.com/intl/zh-cn/research/block-nerf) to download the raw Mission Bay dataset. 32 | 33 | ### Bungee-NeRF 34 | 35 | Please follow the [BungeeNeRF](https://github.com/city-super/BungeeNeRF) to download its two scenes. 36 | 37 | ## Training 38 | 39 | ### Mega-NeRF scenes 40 | We provide the example commands to train the model on Building scene. 41 | 42 | We should first generate data chunks. The `dataset_path` should be set to the scene folder processed above. The `exp_name` is used for logging results. If it does not exit, the program will make a new one. The `chunk_paths` is used to store the generate the data chunks. The chunks will be reused in later experiments. 43 | 44 | Generate chunks. Please edit the `exp_name`, `dataset_path` and `chunk_paths`. 45 | ```sh 46 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch \ 47 | --use_env --master_port=12345 --nproc_per_node=1 -m \ 48 | switch_nerf.train \ 49 | --config=switch_nerf/configs/switch_nerf/building.yaml \ 50 | --use_moe \ 51 | --exp_name=/your/absolute/experiment/path \ 52 | --dataset_path=/your/absolute/scene/path/building-pixsfm \ 53 | --chunk_paths=/your/absolute/chunk/path/building_chunk_factor_1_bg \ 54 | --generate_chunk 55 | ``` 56 | 57 | Train the model on the Building scene and the generated chunks. 
The `chunk_paths` is reused after generating chunks. 58 | ```sh 59 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 60 | --use_env --master_port=12345 --nproc_per_node=8 -m \ 61 | switch_nerf.train \ 62 | --config=switch_nerf/configs/switch_nerf/building.yaml \ 63 | --use_moe \ 64 | --exp_name=/your/absolute/experiment/path \ 65 | --dataset_path=/your/absolute/scene/path/building-pixsfm \ 66 | --chunk_paths=/your/absolute/chunk/path/building_chunk_factor_1_bg \ 67 | --use_balance_loss \ 68 | --i_print=1000 \ 69 | --batch_size=8192 \ 70 | --moe_expert_type=expertmlp \ 71 | --moe_train_batch \ 72 | --moe_test_batch \ 73 | --model_chunk_size=131072 \ 74 | --moe_capacity_factor=1.0 \ 75 | --batch_prioritized_routing \ 76 | --moe_l_aux_wt=0.0005 \ 77 | --amp_use_bfloat16 \ 78 | --use_moe_external_gate \ 79 | --use_gate_input_norm \ 80 | --use_sigma_noise \ 81 | --sigma_noise_std=1.0 82 | ``` 83 | 84 | ### Block-NeRF scenes 85 | 86 | We adapt a data interface mainly based on the [UnboundedNeRFPytorch](https://github.com/sjtuytc/UnboundedNeRFPytorch). We first generate data chunks from the raw `tf_records` in Block-NeRF dataset. 87 | 88 | Please edit the `exp_name`, `dataset_path` and `chunk_paths`. 89 | ```sh 90 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch \ 91 | --use_env --master_port=12345 --nproc_per_node=1 -m \ 92 | switch_nerf.train \ 93 | --config=switch_nerf/configs/switch_nerf/mission_bay.yaml \ 94 | --use_moe \ 95 | --exp_name=/your/absolute/experiment/path \ 96 | --dataset_path=/your/absolute/scene/path/Mission_Bay/v1.0 \ 97 | --block_train_list_path=switch_nerf/datasets/lists/block_nerf_train_val.txt \ 98 | --block_image_hash_id_map_path=switch_nerf/datasets/lists/block_nerf_id_map.json \ 99 | --chunk_paths=/your/absolute/chunk/path/mission_bay_chunk_radii_1 \ 100 | --no_bg_nerf --near=0.01 --far=10.0 --generate_chunk 101 | ``` 102 | 103 | Then we train the model on the Mission Bay scene and the generated chunks. The `batch_size` is set according to the memory of RTX 3090. 104 | 105 | ```sh 106 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 107 | --use_env --master_port=12345 --nproc_per_node=8 -m \ 108 | switch_nerf.train \ 109 | --config=switch_nerf/configs/switch_nerf/mission_bay.yaml \ 110 | --use_moe --exp_name=/your/absolute/experiment/path \ 111 | --dataset_path=/your/absolute/scene/path/Mission_Bay/v1.0 \ 112 | --block_train_list_path=switch_nerf/datasets/lists/block_nerf_train_val.txt \ 113 | --block_image_hash_id_map_path=switch_nerf/datasets/lists/block_nerf_id_map.json \ 114 | --chunk_paths=/your/absolute/chunk/path/mission_bay_chunk_radii_1 \ 115 | --no_bg_nerf --near=0.01 --far=10.0 \ 116 | --use_balance_loss \ 117 | --i_print=1000 \ 118 | --batch_size=13312 \ 119 | --moe_expert_type=expertmlp \ 120 | --moe_train_batch \ 121 | --moe_test_batch \ 122 | --model_chunk_size=212992 \ 123 | --coarse_samples=257 \ 124 | --fine_samples=257 \ 125 | --moe_capacity_factor=1.0 \ 126 | --batch_prioritized_routing \ 127 | --moe_l_aux_wt=0.0005 \ 128 | --amp_use_bfloat16 \ 129 | --use_moe_external_gate \ 130 | --use_gate_input_norm \ 131 | --use_sigma_noise \ 132 | --sigma_noise_std=1.0 133 | ``` 134 | 135 | ### Bungee-NeRF scenes 136 | 137 | We need not to generate chunks for Bungee-NeRF scenes. We provide the example commands to train the model on Transamerica scene. 
138 | 139 | ```sh 140 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \ 141 | --use_env --master_port=12345 --nproc_per_node=4 -m \ 142 | switch_nerf.train_nerf_moe \ 143 | --config=switch_nerf/configs/switch_nerf/bungee.yaml \ 144 | --use_moe --exp_name=/your/absolute/experiment/path \ 145 | --dataset_path=/your/absolute/scene/path/multiscale_google_Transamerica \ 146 | --use_balance_loss \ 147 | --i_print=1000 \ 148 | --batch_size=4096 \ 149 | --moe_expert_type=expertmlp \ 150 | --moe_train_batch \ 151 | --moe_test_batch \ 152 | --model_chunk_size=65536 \ 153 | --moe_capacity_factor=1.0 \ 154 | --batch_prioritized_routing \ 155 | --moe_l_aux_wt=0.0005 \ 156 | --no_amp \ 157 | --use_moe_external_gate \ 158 | --use_gate_input_norm \ 159 | --use_sigma_noise \ 160 | --sigma_noise_std=1.0 \ 161 | --moe_expert_num=4 162 | ``` 163 | The two scenes in Bungee-NeRF use the same configure file. 164 | 165 | ## Testing 166 | 167 | We provide checkpoints in [onedrive](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zmiaa_connect_ust_hk/ErqiFEjmMVBCrA8-Y8Q2yTMBTPdhMFhPCnLeUYh_oSLBVQ?e=1H0H2u). 168 | 169 | Test on the Building scene in Mega-NeRF dataset. 170 | 171 | ```sh 172 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 173 | --use_env --master_port=12345 --nproc_per_node=8 -m \ 174 | switch_nerf.eval_image \ 175 | --config=switch_nerf/configs/switch_nerf/building.yaml \ 176 | --use_moe --exp_name=/your/absolute/experiment/path \ 177 | --dataset_path=/your/absolute/scene/path/building-pixsfm \ 178 | --i_print=1000 \ 179 | --moe_expert_type=seqexperts \ 180 | --model_chunk_size=131072 \ 181 | --ckpt_path=/your/absolute/ckpt/path/building.pt \ 182 | --expertmlp2seqexperts \ 183 | --use_moe_external_gate \ 184 | --use_gate_input_norm 185 | ``` 186 | 187 | Test on the the Mission Bay scene in Block-NeRF dataset. 188 | 189 | ```sh 190 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 191 | --use_env --master_port=12345 --nproc_per_node=8 -m \ 192 | switch_nerf.eval_image_blocknerf \ 193 | --config=switch_nerf/configs/switch_nerf/mission_bay.yaml \ 194 | --use_moe \ 195 | --exp_name=/your/absolute/experiment/path \ 196 | --dataset_path=/your/absolute/scene/path/Mission_Bay/v1.0 \ 197 | --block_val_list_path=switch_nerf/datasets/lists/block_nerf_val.txt \ 198 | --block_train_list_path=switch_nerf/datasets/lists/block_nerf_train_val.txt \ 199 | --block_image_hash_id_map_path=switch_nerf/datasets/lists/block_nerf_id_map.json \ 200 | --i_print=1000 \ 201 | --near=0.01 --far=10.0 \ 202 | --moe_expert_type=seqexperts \ 203 | --model_chunk_size=212992 \ 204 | --coarse_samples=513 \ 205 | --fine_samples=513 \ 206 | --ckpt_path=/your/absolute/ckpt/path/mission_bay.pt \ 207 | --expertmlp2seqexperts \ 208 | --use_moe_external_gate \ 209 | --use_gate_input_norm \ 210 | --set_timeout \ 211 | --image_pixel_batch_size=8192 212 | ``` 213 | You can also use less GPUs. 214 | 215 | Test on the Transamerica scene in Bungee-NeRF dataset. 
216 | 217 | ```sh 218 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \ 219 | --use_env --master_port=12345 --nproc_per_node=4 -m \ 220 | switch_nerf.eval_nerf_moe \ 221 | --config=switch_nerf/configs/switch_nerf/bungee.yaml \ 222 | --use_moe \ 223 | --exp_name=/your/absolute/experiment/path \ 224 | --dataset_path=/your/absolute/scene/path/multiscale_google_Transamerica \ 225 | --i_print=1000 \ 226 | --batch_size=4096 \ 227 | --moe_expert_type=seqexperts \ 228 | --model_chunk_size=65536 \ 229 | --ckpt_path=/your/absolute/ckpt/path/transamerica.pt \ 230 | --expertmlp2seqexperts \ 231 | --no_amp \ 232 | --use_moe_external_gate \ 233 | --use_gate_input_norm \ 234 | --moe_expert_num=4 235 | ``` 236 | 237 | ## Visualization 238 | 239 | We provide a simple point cloud visualizer in this [repository](https://github.com/MiZhenxing/alpha_visualizer). You can use the commands below to create point clouds and visualize them with transparency. You can use [Meshlab](https://www.meshlab.net) to visualize the point clouds without transparency. Meshlab can also visualize the transparency with "Shading: Dot Decorator" selected but the visualization is not clear enough. 240 | 241 | Generate point clouds: 242 | ```sh 243 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 244 | --use_env --master_port=12345 --nproc_per_node=8 -m \ 245 | switch_nerf.eval_points \ 246 | --config=switch_nerf/configs/switch_nerf/building.yaml \ 247 | --use_moe --exp_name=/your/absolute/experiment/path \ 248 | --dataset_path=/your/absolute/scene/path/building-pixsfm \ 249 | --i_print=1000 \ 250 | --moe_expert_type=seqexperts \ 251 | --model_chunk_size=131072 \ 252 | --ckpt_path=/your/absolute/ckpt/path/500000.pt \ 253 | --expertmlp2seqexperts \ 254 | --use_moe_external_gate \ 255 | --use_gate_input_norm \ 256 | --moe_return_gates \ 257 | --return_pts \ 258 | --return_pts_rgb \ 259 | --return_pts_alpha \ 260 | --render_test_points_sample_skip=4 \ 261 | --val_scale_factor=8 \ 262 | --render_test_points_image_num=20 263 | ``` 264 | 265 | Other scenes in Mega-NeRF use `--render_test_points_image_num=21`. 266 | 267 | Merge point clouds from different validation images. 268 | 269 | ```sh 270 | python -m switch_nerf.scripts.merge_points \ 271 | --data_path=/your/absolute/experiment/path/0/eval_points \ 272 | --merge_all \ 273 | --image_num=20 \ 274 | --model_type=switch \ 275 | -r=0.2 276 | ``` 277 | 278 | Other scenes in Mega-NeRF use `--image_num=21`. `-r` is used to randomly downsample point clouds by a ratio so that it can be visualized on our desktop. 279 | 280 | 281 | ## License 282 | 283 | Our code is distributed under the MIT License. See `LICENSE` file for more information. 
284 | 285 | ## Citation 286 | 287 | ```bibtex 288 | @inproceedings{mi2025switchnerfplus, 289 | title={Learning Heterogeneous Mixture of Scene Experts for Large-scale Neural Radiance Fields}, 290 | author={Zhenxing Mi, Ping Yin, Xue Xiao and Dan Xu}, 291 | booktitle={Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, 292 | year={2025}, 293 | url={https://arxiv.org/abs/2505.02005} 294 | } 295 | @inproceedings{mi2023switchnerf, 296 | title={Switch-NeRF: Learning Scene Decomposition with Mixture of Experts for Large-scale Neural Radiance Fields}, 297 | author={Zhenxing Mi and Dan Xu}, 298 | booktitle={International Conference on Learning Representations (ICLR)}, 299 | year={2023}, 300 | url={https://openreview.net/forum?id=PQ2zoIZqvm} 301 | } 302 | ``` 303 | 304 | ## Contact 305 | 306 | If you have any questions, please raise an issue or email to Zhenxing Mi (`zmiaa@connect.ust.hk`). 307 | 308 | ## Acknowledgments 309 | 310 | Our code follows several awesome repositories. We appreciate them for making their codes available to public. 311 | 312 | * [Mega-NeRF](https://github.com/cmusatyalab/mega-nerf) 313 | * [Tutel](https://github.com/microsoft/tutel/tree/56dbd664341cf6485c9fa292955f77d3ac918a65) 314 | * [UnboundedNeRFPytorch](https://github.com/sjtuytc/UnboundedNeRFPytorch) 315 | * [xrnerf](https://github.com/openxrlab/xrnerf) 316 | -------------------------------------------------------------------------------- /switch_nerf/utils/functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | def set_random_seed(seed): 7 | random.seed(seed) 8 | np.random.seed(seed) 9 | torch.manual_seed(seed) 10 | torch.cuda.manual_seed_all(seed) 11 | 12 | def batch_img2mse_np(x, y): 13 | img2mse_np = lambda x, y : np.mean((x - y) ** 2) 14 | mse2psnr_np = lambda x : -10. * np.log(x) / np.log(10.) 15 | 16 | B = x.shape[0] 17 | mses = [] 18 | psnrs = [] 19 | for i in range(B): 20 | xi = x[i] 21 | yi = y[i] 22 | if isinstance(xi, torch.Tensor): 23 | xi = xi.cpu().numpy() 24 | 25 | if isinstance(yi, torch.Tensor): 26 | yi = yi.cpu().numpy() 27 | 28 | mses.append(img2mse_np(xi, yi)) 29 | psnrs.append(mse2psnr_np(mses[-1])) 30 | mse = sum(mses) / B 31 | psnr = sum(psnrs) / B 32 | 33 | return mse, psnr 34 | 35 | def batch_img2mse_torch(x, y): 36 | img2mse = lambda x, y : torch.mean((x - y) ** 2) 37 | mse2psnr = lambda x : -10. 
* torch.log(x) / torch.log(torch.Tensor([10.])) 38 | 39 | B = x.shape[0] 40 | mses = [] 41 | psnrs = [] 42 | for i in range(B): 43 | xi = x[i] 44 | yi = y[i] 45 | mses.append(img2mse(xi, yi)) 46 | psnrs.append(mse2psnr(mses[-1])) 47 | mse = sum(mses) / B 48 | psnr = sum(psnrs) / B 49 | 50 | return mse.item(), psnr.item() 51 | 52 | class DictAverageMeter(object): 53 | def __init__(self): 54 | self.data = {} 55 | self.count = 0 56 | 57 | def update(self, new_input): 58 | self.count += 1 59 | if len(self.data) == 0: 60 | for k, v in new_input.items(): 61 | if (not isinstance(v, float)) and (not isinstance(v, torch.Tensor)): 62 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 63 | self.data[k] = v 64 | else: 65 | for k, v in new_input.items(): 66 | if (not isinstance(v, float)) and (not isinstance(v, torch.Tensor)): 67 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 68 | self.data[k] += v 69 | 70 | def mean(self): 71 | return {k: v / self.count for k, v in self.data.items()} 72 | 73 | class DictAverageMeter1(object): 74 | def __init__(self): 75 | self.data = {} 76 | self.count = {} 77 | 78 | def update(self, new_input): 79 | for k, v in new_input.items(): 80 | if (not isinstance(v, float)) and (not isinstance(v, torch.Tensor)): 81 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 82 | if k not in self.data: 83 | self.data[k] = v 84 | self.count[k] = 1 85 | else: 86 | self.data[k] += v 87 | self.count[k] += 1 88 | 89 | def mean(self): 90 | return {k: v / self.count[k] for k, v in self.data.items()} 91 | 92 | r"""" 93 | from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py 94 | """ 95 | 96 | # import torch 97 | import re 98 | import collections 99 | from torch._six import string_classes 100 | 101 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 102 | 103 | default_collate_cat_err_msg_format = ( 104 | "default_collate_cat: batch must contain tensors, numpy arrays, numbers, " 105 | "dicts or lists; found {}") 106 | 107 | def default_collate_cat(batch): 108 | r""" 109 | Function that takes in a batch of data and puts the elements within the batch 110 | into a tensor with an additional outer dimension - batch size. The exact output type can be 111 | a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a 112 | Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type. 113 | This is used as the default function for collation when 114 | `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`. 
115 | Here is the general input type (based on the type of the element within the batch) to output type mapping: 116 | * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size) 117 | * NumPy Arrays -> :class:`torch.Tensor` 118 | * `float` -> :class:`torch.Tensor` 119 | * `int` -> :class:`torch.Tensor` 120 | * `str` -> `str` (unchanged) 121 | * `bytes` -> `bytes` (unchanged) 122 | * `Mapping[K, V_i]` -> `Mapping[K, default_collate_cat([V_1, V_2, ...])]` 123 | * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate_cat([V1_1, V1_2, ...]), default_collate_cat([V2_1, V2_2, ...]), ...]` 124 | * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate_cat([V1_1, V1_2, ...]), default_collate_cat([V2_1, V2_2, ...]), ...]` 125 | Args: 126 | batch: a single batch to be collated 127 | Examples: 128 | >>> # Example with a batch of `int`s: 129 | >>> default_collate_cat([0, 1, 2, 3]) 130 | tensor([0, 1, 2, 3]) 131 | >>> # Example with a batch of `str`s: 132 | >>> default_collate_cat(['a', 'b', 'c']) 133 | ['a', 'b', 'c'] 134 | >>> # Example with `Map` inside the batch: 135 | >>> default_collate_cat([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) 136 | {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} 137 | >>> # Example with `NamedTuple` inside the batch: 138 | >>> Point = namedtuple('Point', ['x', 'y']) 139 | >>> default_collate_cat([Point(0, 0), Point(1, 1)]) 140 | Point(x=tensor([0, 1]), y=tensor([0, 1])) 141 | >>> # Example with `Tuple` inside the batch: 142 | >>> default_collate_cat([(0, 1), (2, 3)]) 143 | [tensor([0, 2]), tensor([1, 3])] 144 | >>> # Example with `List` inside the batch: 145 | >>> default_collate_cat([[0, 1], [2, 3]]) 146 | [tensor([0, 2]), tensor([1, 3])] 147 | """ 148 | elem = batch[0] 149 | elem_type = type(elem) 150 | if isinstance(elem, torch.Tensor): 151 | out = None 152 | if torch.utils.data.get_worker_info() is not None: 153 | # If we're in a background process, concatenate directly into a 154 | # shared memory tensor to avoid an extra copy 155 | numel = sum(x.numel() for x in batch) 156 | storage = elem.storage()._new_shared(numel) 157 | out = elem.new(storage).resize_(len(batch), *list(elem.size())) 158 | return torch.cat(batch, 0, out=out) 159 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 160 | and elem_type.__name__ != 'string_': 161 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 162 | # array of string classes and object 163 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 164 | raise TypeError(default_collate_cat_err_msg_format.format(elem.dtype)) 165 | 166 | return default_collate_cat([torch.as_tensor(b) for b in batch]) 167 | elif elem.shape == (): # scalars 168 | return torch.as_tensor(batch) 169 | elif isinstance(elem, float): 170 | return torch.tensor(batch, dtype=torch.float64) 171 | elif isinstance(elem, int): 172 | return torch.tensor(batch) 173 | elif isinstance(elem, string_classes): 174 | return batch 175 | elif isinstance(elem, collections.abc.Mapping): 176 | try: 177 | return elem_type({key: default_collate_cat([d[key] for d in batch]) for key in elem}) 178 | except TypeError: 179 | # The mapping type may not support `__init__(iterable)`. 
180 | return {key: default_collate_cat([d[key] for d in batch]) for key in elem} 181 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 182 | return elem_type(*(default_collate_cat(samples) for samples in zip(*batch))) 183 | elif isinstance(elem, collections.abc.Sequence): 184 | # check to make sure that the elements in batch have consistent size 185 | it = iter(batch) 186 | elem_size = len(next(it)) 187 | if not all(len(elem) == elem_size for elem in it): 188 | raise RuntimeError('each element in list of batch should be of equal size') 189 | transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. 190 | 191 | if isinstance(elem, tuple): 192 | return [default_collate_cat(samples) for samples in transposed] # Backwards compatibility. 193 | else: 194 | try: 195 | return elem_type([default_collate_cat(samples) for samples in transposed]) 196 | except TypeError: 197 | # The sequence type may not support `__init__(iterable)` (e.g., `range`). 198 | return [default_collate_cat(samples) for samples in transposed] 199 | 200 | raise TypeError(default_collate_cat_err_msg_format.format(elem_type)) 201 | 202 | 203 | def make_recursive_func(func): 204 | def wrapper(vars): 205 | if isinstance(vars, list): 206 | return [wrapper(x) for x in vars] 207 | elif isinstance(vars, tuple): 208 | return tuple([wrapper(x) for x in vars]) 209 | elif isinstance(vars, dict): 210 | return {k: wrapper(v) for k, v in vars.items()} 211 | else: 212 | return func(vars) 213 | 214 | return wrapper 215 | 216 | def to_device(vars, device): 217 | if isinstance(vars, list): 218 | return [to_device(x, device) for x in vars] 219 | elif isinstance(vars, tuple): 220 | return tuple([to_device(x, device) for x in vars]) 221 | elif isinstance(vars, dict): 222 | return {k: to_device(v, device) for k, v in vars.items()} 223 | elif isinstance(vars, str): 224 | return vars 225 | elif isinstance(vars, torch.Tensor): 226 | return vars.to(device) 227 | else: 228 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 229 | 230 | import torchvision.transforms as T 231 | import cv2 232 | from PIL import Image 233 | 234 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET): 235 | """ 236 | depth: (H, W) 237 | """ 238 | x = depth.cpu().numpy() 239 | x = np.nan_to_num(x) # change nan to 0 240 | mi = np.min(x) # get minimum depth 241 | ma = np.max(x) 242 | x = (x-mi)/max(ma-mi, 1e-8) # normalize to 0~1 243 | x = (255*x).astype(np.uint8) 244 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 245 | x_ = T.ToTensor()(x_) # (3, H, W) 246 | return x_ 247 | 248 | 249 | # from https://github.com/open-mmlab/mmsegmentation/blob/441be4e435127868a0c72a4e0e6b87662a4c415b/mmseg/core/evaluation/class_names.py#L64 250 | def ade_palette(): 251 | """ADE20K palette for external use.""" 252 | return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 253 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 254 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 255 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 256 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 257 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 258 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 259 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 260 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 261 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 262 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 
5, 153], 263 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 264 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 265 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 266 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 267 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 268 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 269 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 270 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 271 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 272 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 273 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 274 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 275 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 276 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 277 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 278 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 279 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 280 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], 281 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 282 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 283 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 284 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 285 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 286 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 287 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 288 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 289 | [102, 255, 0], [92, 0, 255]] 290 | 291 | def cityscapes_palette(): 292 | """Cityscapes palette for external use.""" 293 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], 294 | [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], 295 | [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], 296 | [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], 297 | [0, 0, 230], [119, 11, 32]] 298 | 299 | def voc_palette(): 300 | """Pascal VOC palette for external use.""" 301 | return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], 302 | [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], 303 | [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], 304 | [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], 305 | [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] -------------------------------------------------------------------------------- /switch_nerf/modules/tutel_moe_ext/tutel_fast_dispatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from tutel.impls.jit_compiler import IS_HIP_EXTENSION 10 | from tutel.jit_kernels import sparse as jit_kernel 11 | from tutel.jit_kernels.gating import fast_cumsum_sub_one, torch_cumsum_sub_one 12 | from tutel.impls.communicate import simple_all_reduce 13 | from torch.distributions.normal import Normal 14 | 15 | class GatingEncoder(torch.autograd.Function): 16 | @staticmethod 17 | def forward(ctx: Any, config: Any, reshaped_input: Tensor, *gates_): 18 | ctx.reshaped_input = reshaped_input 19 | ctx.config = config 20 | if gates_: 21 | ctx.gates_h2 = [x.view(-1, 1).repeat(1, 2) if x.dtype == torch.float16 else x for x in gates_] 22 | else: 23 | ctx.gates_h2 = [ctx.config.ones_helper] * len(ctx.config.indices_) 24 | 25 | dispatched_input = torch.zeros([ctx.config.num_global_experts * ctx.config.capacity, ctx.config.model_dim], dtype=reshaped_input.dtype, device=reshaped_input.device) 26 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 27 | ctx.config.func_fwd(g, i, l, reshaped_input, dispatched_input, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 28 | return dispatched_input 29 | 30 | @staticmethod 31 | def backward(ctx: Any, dispatched_input: Tensor): 32 | dispatched_input = dispatched_input.contiguous() 33 | last_result = None 34 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 35 | grad_data = torch.empty(ctx.reshaped_input.shape, dtype=dispatched_input.dtype, device=dispatched_input.device) 36 | ctx.config.func_bwd_data(g, i, l, grad_data, dispatched_input, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 37 | last_result = grad_data if last_result is None else last_result + grad_data 38 | 39 | grad_gates = [] 40 | if id(ctx.gates_h2[0]) != id(ctx.config.ones_helper): 41 | for i, l in zip(ctx.config.indices_, ctx.config.locations_): 42 | grad_gates1_s = torch.empty([ctx.config.sample_size,], dtype=dispatched_input.dtype, device=dispatched_input.device) 43 | ctx.config.func_bwd_gate(grad_gates1_s, i, l, ctx.reshaped_input, dispatched_input, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 44 | grad_gates.append(grad_gates1_s) 45 | return (None, last_result, *grad_gates) 46 | 47 | 48 | class GatingDecoder(torch.autograd.Function): 49 | @staticmethod 50 | def forward(ctx: Any, config: Any, expert_output: Tensor, *gates_): 51 | ctx.expert_output = expert_output 52 | ctx.config = config 53 | if gates_: 54 | ctx.gates_h2 = [x.view(-1, 1).repeat(1, 2) if x.dtype == torch.float16 else x for x in gates_] 55 | else: 56 | ctx.gates_h2 = [ctx.config.ones_helper] * len(ctx.config.indices_) 57 | 58 | last_result = None 59 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 60 | single_output = torch.empty([config.sample_size, config.model_dim], dtype=expert_output.dtype, device=expert_output.device) 61 | config.func_bwd_data(g, i, l, single_output, expert_output, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 62 | last_result = single_output if last_result is None else last_result + single_output 63 | return last_result 64 | 65 | @staticmethod 66 | def backward(ctx: Any, combined_output: Tensor): 67 | combined_output = combined_output.contiguous() 68 | grad_expert_output = torch.zeros(ctx.expert_output.shape, 
dtype=combined_output.dtype, device=combined_output.device) 69 | for g, i, l in zip(ctx.gates_h2, ctx.config.indices_, ctx.config.locations_): 70 | ctx.config.func_fwd(g, i, l, combined_output, grad_expert_output, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 71 | 72 | grad_gates = [] 73 | if id(ctx.gates_h2[0]) != id(ctx.config.ones_helper): 74 | for i, l in zip(ctx.config.indices_, ctx.config.locations_): 75 | grad_gates1_s = torch.empty([ctx.config.sample_size,], dtype=combined_output.dtype, device=combined_output.device) 76 | ctx.config.func_bwd_gate(grad_gates1_s, i, l, combined_output, ctx.expert_output, extra=[ctx.config.indices_[0].size(0), ctx.config.aligned_dim, ctx.config.capacity]) 77 | grad_gates.append(grad_gates1_s) 78 | return (None, grad_expert_output, *grad_gates) 79 | 80 | 81 | class TutelMoeFastDispatcher: 82 | 83 | kernel_pool = dict() 84 | 85 | def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype): 86 | self.num_global_experts = int(num_global_experts) 87 | self.capacity = int(capacity) 88 | self.model_dim = int(model_dim) 89 | self.dtype = dispatch_dtype 90 | if IS_HIP_EXTENSION or dispatch_dtype != torch.float16: 91 | self.dtype = torch.float32 92 | self.original_dtype = dispatch_dtype 93 | self.aligned_dim = model_dim // (2 if self.dtype == torch.float16 else 1) 94 | self.is_cuda = None 95 | 96 | def update(self, indices_, locations_, gates_, capacity=None, is_postscore=True): 97 | self.indices_ = [x.to(torch.int32).view(-1) for x in indices_] 98 | self.locations_ = [x.to(torch.int32) for x in locations_] 99 | self.gates_ = [x.to(self.dtype) for x in gates_] 100 | self.is_postscore = is_postscore 101 | self.sample_size, self.capacity = int(self.indices_[0].size(0)), int(capacity) or self.capacity 102 | 103 | if self.is_cuda != indices_[0].is_cuda: 104 | self.is_cuda = indices_[0].is_cuda 105 | if self.is_cuda not in TutelMoeFastDispatcher.kernel_pool: 106 | self.func_fwd = jit_kernel.create_forward(self.dtype, indices_[0].is_cuda) 107 | self.func_bwd_data = jit_kernel.create_backward_data(self.dtype, indices_[0].is_cuda) 108 | self.func_bwd_gate = jit_kernel.create_backward_gate(self.dtype, indices_[0].is_cuda) 109 | self.ones_helper = torch.ones([self.sample_size, 2], dtype=self.dtype, device=self.indices_[0].device) 110 | TutelMoeFastDispatcher.kernel_pool[self.is_cuda] = self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper 111 | else: 112 | self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper = TutelMoeFastDispatcher.kernel_pool[self.is_cuda] 113 | if self.ones_helper.shape[0] < self.sample_size: 114 | self.ones_helper = torch.ones([self.sample_size, 2], dtype=self.dtype, device=self.indices_[0].device) 115 | TutelMoeFastDispatcher.kernel_pool[self.is_cuda] = self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper 116 | 117 | def encode(self, data): 118 | if self.is_postscore: 119 | return GatingEncoder.apply(self, data.to(self.dtype)).to(self.original_dtype) 120 | else: 121 | return GatingEncoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype) 122 | 123 | def decode(self, data): 124 | if self.is_postscore: 125 | return GatingDecoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype) 126 | else: 127 | return GatingDecoder.apply(self, data.to(self.dtype)).to(self.original_dtype) 128 | 129 | fast_dispatcher = TutelMoeFastDispatcher 130 | 131 | def one_hot_with_dtype(data, num_classes, dtype): 132 | result = 
torch.zeros([data.size(0), num_classes], device=data.device, dtype=dtype) 133 | result.scatter_(1, data.unsqueeze(-1), 1) 134 | return result 135 | 136 | def compute_sorted_location(x, importance_scores): 137 | sorted_x = x[importance_scores.argsort(dim=0)] 138 | sorted_cumsum = fast_cumsum_sub_one(sorted_x) * sorted_x 139 | return sorted_cumsum[importance_scores.argsort(dim=0).argsort(dim=0)] 140 | 141 | def load_balance(gates, mask1, num_global_experts, fp32_gate): 142 | if gates.dtype == torch.float32 or fp32_gate: 143 | me = torch.sum(gates.float(), dim=0) 144 | ce = torch.sum(mask1.to(me.dtype), dim=0) 145 | l_loss = torch.sum(me * ce) * (num_global_experts / (gates.size(0) * gates.size(0))) 146 | else: 147 | me = torch.mean(gates, dim=0) 148 | ce = torch.mean(mask1.to(gates.dtype), dim=0) 149 | l_loss = torch.sum(me * ce) * num_global_experts 150 | return l_loss 151 | 152 | def load_importance_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise): 153 | def load_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise): 154 | assert gate_noise > 0, "`gate_noise` must be > 0 for normalization in load_importance_loss()." 155 | normal = Normal( 156 | torch.tensor([0.0], device=scores_wo_noise.device), 157 | torch.tensor([gate_noise / num_global_experts], device=scores_wo_noise.device), 158 | ) 159 | threshold = topk_logits[:, -1].view(-1, 1).float() 160 | diff = scores_wo_noise.float() - threshold.float() 161 | prob = normal.cdf(diff) 162 | Load = prob.sum(0) 163 | l_load = Load.float().var() / (Load.float().mean() ** 2 + 1e-10) 164 | return l_load 165 | 166 | def importance_loss(scores_wo_noise): 167 | Impi = scores_wo_noise.float().sum(0) 168 | l_imp = Impi.float().var() / (Impi.float().mean() ** 2 + 1e-10) 169 | 170 | return l_imp 171 | 172 | l_imp = importance_loss(scores_wo_noise) 173 | l_load = load_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise) 174 | return (l_imp + l_load) / 2.0 175 | 176 | def extract_critical(gates, top_k, capacity_factor=1.0, fp32_gate=False, batch_prioritized_routing=False): 177 | topk_indices = torch.topk(gates, top_k, dim=1).indices 178 | num_global_experts = gates.size(1) 179 | 180 | indices_s = [x.view(-1) for x in topk_indices.chunk(top_k, dim=1)] 181 | masks_se = [one_hot_with_dtype(x, num_classes=num_global_experts, dtype=x.dtype) for x in indices_s] 182 | gates_s = [(gates * x).sum(dim=1) for x in masks_se] 183 | 184 | l_loss = load_balance(gates, masks_se[0], num_global_experts, fp32_gate) 185 | 186 | if batch_prioritized_routing: 187 | importance_scores = -1 * gates.max(dim=1)[0] 188 | compute_location = lambda x: compute_sorted_location(x, importance_scores) 189 | else: 190 | compute_location = fast_cumsum_sub_one 191 | 192 | locations1 = compute_location(masks_se[0]) 193 | 194 | locations_s = [torch.sum(locations1 * masks_se[0], dim=1).to(torch.int32)] 195 | 196 | if top_k > 1: 197 | acc_base = None 198 | for k in range(1, top_k): 199 | acc_base = torch.sum(masks_se[k - 1], dim=0, keepdim=True) if acc_base is None else acc_base + torch.sum(masks_se[k - 1], dim=0, keepdim=True) 200 | locations2 = compute_location(masks_se[k]) 201 | locations2 += acc_base 202 | locations_s.append(torch.sum(locations2 * masks_se[k], dim=1).to(torch.int32)) 203 | 204 | # Normalize Gate 205 | denom_s = torch.clamp(sum(gates_s), min=torch.finfo(gates_s[0].dtype).eps) 206 | gates_s = [x / denom_s for x in gates_s] 207 | 208 | indices_s = [x.to(torch.int32) for x in indices_s] 209 | 210 | if capacity_factor > 0: 211 | capacity = 
top_k * int(capacity_factor * ((int(gates.size(0)) + num_global_experts - 1) // num_global_experts)) 212 | else: 213 | capacity = torch.max(torch.concat(locations_s, dim=0)) 214 | capacity = int(simple_all_reduce(capacity, op=torch.distributed.ReduceOp.MAX)) + 1 215 | if capacity_factor < 0: 216 | capacity = min(capacity, top_k * int(-capacity_factor * ((int(gates.size(0)) + num_global_experts - 1) // num_global_experts))) 217 | return (num_global_experts, indices_s, locations_s, gates_s, capacity), l_loss 218 | 219 | def extract_critical_load_importance(gates, gates_wo_noise, logits_w_noise, top_k, gate_noise, capacity_factor=1.0, fp32_gate=False, batch_prioritized_routing=False, compute_balance_loss=False): 220 | # gates is from logits_w_noise 221 | topk_indices = torch.topk(gates, top_k, dim=1).indices 222 | num_global_experts = gates.size(1) 223 | 224 | indices_s = [x.view(-1) for x in topk_indices.chunk(top_k, dim=1)] 225 | masks_se = [one_hot_with_dtype(x, num_classes=num_global_experts, dtype=x.dtype) for x in indices_s] 226 | gates_s = [(gates * x).sum(dim=1) for x in masks_se] 227 | 228 | if compute_balance_loss: 229 | l_balance_loss = load_balance(gates, masks_se[0], num_global_experts, fp32_gate) 230 | else: 231 | l_balance_loss = torch.tensor(0.0, device=gates.device)  # zero placeholder when the balance loss is disabled 232 | l_loss = load_importance_loss(gates_wo_noise, logits_w_noise.gather(index=topk_indices, dim=1), num_global_experts, gate_noise) 233 | 234 | if batch_prioritized_routing: 235 | importance_scores = -1 * gates.max(dim=1)[0] 236 | compute_location = lambda x: compute_sorted_location(x, importance_scores) 237 | else: 238 | compute_location = fast_cumsum_sub_one 239 | 240 | locations1 = compute_location(masks_se[0]) 241 | 242 | locations_s = [torch.sum(locations1 * masks_se[0], dim=1).to(torch.int32)] 243 | 244 | if top_k > 1: 245 | acc_base = None 246 | for k in range(1, top_k): 247 | acc_base = torch.sum(masks_se[k - 1], dim=0, keepdim=True) if acc_base is None else acc_base + torch.sum(masks_se[k - 1], dim=0, keepdim=True) 248 | locations2 = compute_location(masks_se[k]) 249 | locations2 += acc_base 250 | locations_s.append(torch.sum(locations2 * masks_se[k], dim=1).to(torch.int32)) 251 | 252 | # Normalize Gate 253 | denom_s = torch.clamp(sum(gates_s), min=torch.finfo(gates_s[0].dtype).eps) 254 | gates_s = [x / denom_s for x in gates_s] 255 | 256 | indices_s = [x.to(torch.int32) for x in indices_s] 257 | 258 | if capacity_factor > 0: 259 | capacity = top_k * int(capacity_factor * ((int(gates.size(0)) + num_global_experts - 1) // num_global_experts)) 260 | else: 261 | capacity = torch.max(torch.concat(locations_s, dim=0)) 262 | capacity = int(simple_all_reduce(capacity, op=torch.distributed.ReduceOp.MAX)) + 1 263 | if capacity_factor < 0: 264 | capacity = min(capacity, top_k * int(-capacity_factor * ((int(gates.size(0)) + num_global_experts - 1) // num_global_experts))) 265 | return (num_global_experts, indices_s, locations_s, gates_s, capacity), l_loss, l_balance_loss 266 | 267 | 268 | def fast_encode(data, critial_data, is_postscore=True): 269 | num_global_experts = critial_data[0] 270 | dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype) 271 | dispatcher.update(*critial_data[1:], is_postscore=is_postscore) 272 | return dispatcher.encode(data).view(num_global_experts, -1, data.size(-1)) 273 | 274 | def fast_decode(data, critial_data, is_postscore=True): 275 | num_global_experts = critial_data[0] 276 | dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), 
data.dtype) 277 | dispatcher.update(*critial_data[1:], is_postscore=is_postscore) 278 | return dispatcher.decode(data).view(-1, data.size(-1)) --------------------------------------------------------------------------------
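For reference, below is a minimal usage sketch (not a file from this repository) of how `extract_critical`, `fast_encode`, and `fast_decode` in `tutel_fast_dispatch.py` compose into one dispatch/combine round trip; the tensor sizes, the CUDA assumption, and the placeholder expert computation are illustrative only.

```python
# Illustrative sketch only, under the assumption that the tutel JIT kernels are
# available on a CUDA device; it is not part of the Switch-NeRF code base.
import torch

from switch_nerf.modules.tutel_moe_ext.tutel_fast_dispatch import (
    extract_critical, fast_encode, fast_decode)

samples, model_dim, num_experts = 4096, 256, 4
tokens = torch.randn(samples, model_dim, device="cuda")  # flattened per-sample features
gates = torch.softmax(torch.randn(samples, num_experts, device="cuda"), dim=1)

# Top-1 routing: returns (num_experts, indices, locations, gate values, capacity)
# plus the load-balance auxiliary loss.
crit, l_aux = extract_critical(gates, top_k=1, capacity_factor=1.25)

# Scatter tokens into per-expert buffers of shape [num_experts, capacity, model_dim].
dispatched = fast_encode(tokens, crit, is_postscore=True)

expert_out = dispatched * 2.0  # stand-in for running the per-expert MLPs

# Gather the expert outputs back to [samples, model_dim]; with is_postscore=True
# the gate values are applied during this combine step.
combined = fast_decode(expert_out, crit, is_postscore=True)
```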