├── .gitignore ├── LICENSE ├── README.md ├── feature_splatting ├── __init__.py ├── feature_extractor.py ├── feature_extractor_cfg.py ├── feature_splatting_config.py ├── feature_splatting_datamgr.py ├── model.py └── utils │ ├── __init__.py │ ├── clip_text_encoder.py │ ├── decoder_utils.py │ ├── gaussian_editor.py │ ├── math_utils.py │ ├── mpm_engine │ ├── __init__.py │ ├── mesh_io.py │ ├── mpm_solver.py │ ├── particle_io.py │ ├── renderer.py │ ├── renderer_utils.py │ └── voxelizer.py │ ├── segment_utils.py │ └── viewer_utils.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # User-defined 2 | *.bak 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # feature-splatting-ns
2 |
3 | Official Nerfstudio implementation of Feature Splatting.
4 |
5 | **Note:** The Nerfstudio version is designed to be easy to use and efficient, which is achieved via several
6 | tradeoffs relative to the original feature splatting paper, such as replacing SAM with MobileSAMV2 and
7 | using a simple bbox to select Gaussians for editing.
We recommend using this repo to check the quality
8 | of features. To reproduce the full physics effects and examples on the website, please check out our
9 | [original codebase based on INRIA 3DGS](https://github.com/vuer-ai/feature-splatting-inria).
10 |
11 | ## Instructions
12 |
13 | Follow the [Nerfstudio installation instructions](https://docs.nerf.studio/quickstart/installation.html) to install a conda environment. For convenience,
14 | here are the commands I run to install nerfstudio on two machines.
15 |
16 | ```bash
17 | # Create an isolated conda environment
18 | conda create --name feature_splatting_ns -y python=3.8
19 | conda activate feature_splatting_ns
20 |
21 | # Install necessary NS dependencies
22 | pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
23 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
24 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
25 |
26 | # Install nerfstudio
27 | pip install nerfstudio
28 | ```
29 |
30 | As of this version (0.0.3), we use the gsplat kernel, which comes with the NS installation. If you just want to try out feature splatting,
31 | you can run:
32 |
33 | ```bash
34 | pip install git+https://github.com/vuer-ai/feature-splatting
35 | ```
36 |
37 | or, for development purposes, run:
38 |
39 | ```bash
40 | # Clone and cd to this repository
41 | pip install -e .
42 | ```
43 |
44 | After this is done, you can train feature-splatting on any nerfstudio-format dataset. An example command is given below:
45 |
46 | ```bash
47 | ns-download-data nerfstudio --capture-name=poster
48 | ns-train feature-splatting --data data/nerfstudio/poster
49 | ```
50 |
51 | Specifically, check out the various custom outputs listed under `Output Type` in the viewer. The `consistent_latent_pca` output
52 | projects high-dimensional features to low dimensions without flickering. After any text is supplied to `Positive Text Queries`,
53 | a new output, `similarity`, shows up in the `Output Type` dropdown menu; it visualizes the heatmap response to the language queries (a hedged sketch of this computation is given after the TODO list below).
54 |
55 | **NOTE:** Please **PAUSE TRAINING** before using any editing utility. Otherwise, editing seems to trigger race conditions. Unfortunately, fixing this
56 | issue seems to require modifying core components of Nerfstudio, which cannot be gracefully implemented as part of an extension plugin.
57 |
58 | ### TODOs
59 |
60 | - Remove the simple bbox selection and implement better segmentation
61 | - Support more feature extractors
62 | - Improve the object segmentation workflow. Sometimes it causes unexpected errors that are only printed to the terminal
63 | - Add gravity / ground plane estimation
64 | - Improve thread safety; editing while training seems to lead to race conditions
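For reference, here is a minimal sketch of how a paired-softmax `similarity` heatmap can be computed from rendered CLIP-aligned features and encoded text queries. This is an illustration under assumptions, not the shipped implementation (which lives in `feature_splatting/utils` as `compute_similarity`); the function name `paired_softmax_similarity` and the aggregation choices (worst-case negative, best positive) are ours.

```python
import torch
import torch.nn.functional as F

def paired_softmax_similarity(
    feat_hwc: torch.Tensor,   # (H, W, C) rendered, decoded CLIP-aligned features
    pos_embs: torch.Tensor,   # (P, C) encoded positive text queries
    neg_embs: torch.Tensor,   # (N, C) encoded negative text queries
    temperature: float = 0.05,
) -> torch.Tensor:
    """Per-pixel response to positive queries via 2-way softmax against negatives."""
    feat = F.normalize(feat_hwc, dim=-1)
    text = F.normalize(torch.cat([pos_embs, neg_embs], dim=0), dim=-1)
    sims = torch.einsum("hwc,nc->hwn", feat, text)  # cosine similarity per pixel and query
    P, N = pos_embs.shape[0], neg_embs.shape[0]
    pos, neg = sims[..., :P], sims[..., P:]
    # One 2-way softmax per (positive, negative) pair; keep the positive probability.
    pair = torch.stack(
        [pos[..., :, None].expand(-1, -1, P, N), neg[..., None, :].expand(-1, -1, P, N)],
        dim=-1,
    ) / temperature
    prob_pos = pair.softmax(dim=-1)[..., 0]  # (H, W, P, N)
    # Aggregate: worst case over the negatives, then best case over the positives.
    return prob_pos.min(dim=-1).values.max(dim=-1).values  # (H, W), values in [0, 1]
```

A thresholded score of this kind is also what the editing utilities use to select Gaussians (see `segment_gaussian` in `feature_splatting/model.py`, which thresholds at 0.5 by default).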
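Also note that the extracted 2D features are cached by the data manager (see `feature_splatting_datamgr.py`) to a `.pt` file in the dataset directory, keyed by the extractor settings and image filenames. A quick way to inspect such a cache, assuming the default SAMCLIP feature type and the `poster` dataset path from above:

```python
import torch

# Path follows the current data manager convention:
# <data_dir>/feature_splatting_<feature_type>_features.pt
cache = torch.load("data/nerfstudio/poster/feature_splatting_samclip_features.pt")

print(cache["args"])               # extractor settings used to build the cache
print(len(cache["image_fnames"]))  # train + eval images, in data manager order
for name, feats in cache["feature_dict"].items():
    print(name, tuple(feats.shape))  # e.g. samclip / dinov2 feature maps, stacked NCHW
```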
65 |
-------------------------------------------------------------------------------- /feature_splatting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vuer-ai/feature-splatting/e24870d26c62dfb70d9afbbd8361c86b5754b8d9/feature_splatting/__init__.py -------------------------------------------------------------------------------- /feature_splatting/feature_extractor.py: --------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import numpy as np
4 | import torch
5 | from argparse import ArgumentParser
6 | from PIL import Image
7 | from tqdm import tqdm, trange
8 | import cv2
9 | from typing import Any, Dict, Generator, List
10 |
11 | import torch.nn as nn
12 | import torchvision.transforms as T
13 | import torch.nn.functional as F
14 | from PIL import Image
15 | import matplotlib.pyplot as plt
16 | import maskclip_onnx
17 |
18 | def pytorch_gc():
19 | torch.cuda.empty_cache()
20 | torch.cuda.synchronize()
21 | gc.collect()
22 |
23 | def resize_image(img, longest_edge):
24 | # resize to have the longest edge equal to longest_edge
25 | width, height = img.size
26 | if width > height:
27 | ratio = longest_edge / width
28 | else:
29 | ratio = longest_edge / height
30 | new_width = int(width * ratio)
31 | new_height = int(height * ratio)
32 | return img.resize((new_width, new_height), Image.BILINEAR)
33 |
34 | def interpolate_to_patch_size(img_bchw, patch_size):
35 | # Interpolate the image so that H and W are multiples of the patch size
36 | _, _, H, W = img_bchw.shape
37 | target_H = H // patch_size * patch_size
38 | target_W = W // patch_size * patch_size
39 | img_bchw = torch.nn.functional.interpolate(img_bchw, size=(target_H, target_W))
40 | return img_bchw, target_H, target_W
41 |
42 | def is_valid_image(filename):
43 | ext_test_flag = any(filename.lower().endswith(extension) for extension in ['.png', '.jpg', '.jpeg'])
44 | is_file_flag = os.path.isfile(filename)
45 | return ext_test_flag and is_file_flag
46 |
47 | def show_anns(anns):
48 | if len(anns) == 0:
49 | return
50 | img = np.ones((anns.shape[1], anns.shape[2], 4))
51 | img[:, :, 3] = 0
52 | for ann in range(anns.shape[0]):
53 | m = anns[ann].bool()
54 | m = m.cpu().numpy()
55 | color_mask = np.concatenate([np.random.random(3), [1]])
56 | img[m] = color_mask
57 | return img
58 |
59 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
60 | assert len(args) > 0 and all(
61 | len(a) == len(args[0]) for a in args
62 | ), "Batched iteration must have inputs of all the same size."
63 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 64 | for b in range(n_batches): 65 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 66 | 67 | 68 | class MaskCLIPFeaturizer(nn.Module): 69 | def __init__(self, clip_model_name): 70 | super().__init__() 71 | self.model, self.preprocess = maskclip_onnx.clip.load(clip_model_name) 72 | self.model.eval() 73 | self.patch_size = self.model.visual.patch_size 74 | 75 | def forward(self, img): 76 | b, _, input_size_h, input_size_w = img.shape 77 | patch_h = input_size_h // self.patch_size 78 | patch_w = input_size_w // self.patch_size 79 | features = self.model.get_patch_encodings(img).to(torch.float32) 80 | return features.reshape(b, patch_h, patch_w, -1).permute(0, 3, 1, 2) 81 | 82 | @torch.no_grad() 83 | def batch_extract_feature(image_paths: List[str], args): 84 | norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 85 | yolo_iou = 0.9 86 | yolo_conf = 0.4 87 | 88 | # For part-level CLIP 89 | transform = T.Compose([ 90 | T.Resize((args.part_resolution, args.part_resolution)), 91 | T.ToTensor(), 92 | norm 93 | ]) 94 | 95 | # For object-level CLIP 96 | raw_transform = T.Compose([ 97 | T.ToTensor(), 98 | norm 99 | ]) 100 | 101 | dino_transform = T.Compose([ 102 | T.ToTensor(), 103 | T.Normalize(mean=[0.5], std=[0.5]), 104 | ]) 105 | 106 | mobilesamv2, ObjAwareModel, predictor = torch.hub.load("RogerQi/MobileSAMV2", args.mobilesamv2_encoder_name) 107 | device = "cuda" if torch.cuda.is_available() else "cpu" 108 | mobilesamv2.to(device=device) 109 | mobilesamv2.eval() 110 | 111 | ret_dict = {'samclip': [], 'dinov2': []} 112 | print(f"Computing features for {len(image_paths)} images.") 113 | 114 | print("Loading DINOv2 model...") 115 | dinov2 = torch.hub.load('facebookresearch/dinov2', args.dinov2_model_name) 116 | dinov2 = dinov2.to(device) 117 | 118 | for i in trange(len(image_paths)): 119 | image = Image.open(image_paths[i]) 120 | image = resize_image(image, args.dino_resolution) 121 | image = dino_transform(image)[:3].unsqueeze(0) 122 | image, target_H, target_W = interpolate_to_patch_size(image, dinov2.patch_size) 123 | image = image.cuda() 124 | with torch.no_grad(): 125 | features = dinov2.forward_features(image)["x_norm_patchtokens"][0] 126 | 127 | features = features.cpu() 128 | 129 | features_hwc = features.reshape((target_H // dinov2.patch_size, target_W // dinov2.patch_size, -1)) 130 | features_chw = features_hwc.permute((2, 0, 1)) 131 | 132 | ret_dict['dinov2'].append(features_chw) 133 | 134 | del dinov2 135 | pytorch_gc() 136 | 137 | clip_model = MaskCLIPFeaturizer(args.clip_model_name).cuda().eval() 138 | 139 | # ====================== 140 | for i in trange(len(image_paths)): 141 | image_file_path = str(image_paths[i]) 142 | 143 | image = cv2.imread(image_file_path) 144 | # resize to longest edge 145 | if max(image.shape[:2]) > args.sam_size: 146 | if image.shape[0] > image.shape[1]: 147 | image = cv2.resize(image, (int(args.sam_size * image.shape[1] / image.shape[0]), args.sam_size)) 148 | else: 149 | image = cv2.resize(image, (args.sam_size, int(args.sam_size * image.shape[0] / image.shape[1]))) 150 | image = image[:, :, ::-1] # BGR to RGB 151 | 152 | raw_input_image = raw_transform(Image.fromarray(image)) 153 | whole_image_feature = clip_model(raw_input_image[None].cuda())[0] 154 | clip_feat_dim = whole_image_feature.shape[0] 155 | 156 | raw_img_H, raw_img_W = image.shape[:2] 157 | 158 | # part level 159 | small_W = args.part_feat_res 160 | small_H = raw_img_H * small_W 
// raw_img_W
161 |
162 | # obj level
163 | object_W = args.obj_feat_res
164 | object_H = raw_img_H * object_W // raw_img_W
165 |
166 | final_W = args.final_feat_res
167 | final_H = raw_img_H * final_W // raw_img_W
168 |
169 | # ===== Object-aware Model =====
170 | obj_results = ObjAwareModel(image, device=device, imgsz=args.sam_size, conf=yolo_conf, iou=yolo_iou, verbose=False)
171 | if not obj_results:
172 | # Add an all-zero tensor if no object is detected
173 | ret_dict['samclip'].append(torch.zeros((clip_feat_dim, final_H, final_W)))
174 | continue
175 |
176 | predictor.set_image(image)
177 | input_boxes1 = obj_results[0].boxes.xyxy
178 | input_boxes = input_boxes1.cpu().numpy()
179 | input_boxes = predictor.transform.apply_boxes(input_boxes, predictor.original_size)
180 | input_boxes = torch.from_numpy(input_boxes).cuda()
181 | sam_mask = []
182 | image_embedding = predictor.features
183 | image_embedding = torch.repeat_interleave(image_embedding, 320, dim=0)
184 | prompt_embedding = mobilesamv2.prompt_encoder.get_dense_pe()
185 | prompt_embedding = torch.repeat_interleave(prompt_embedding, 320, dim=0)
186 | for (boxes,) in batch_iterator(320, input_boxes):
187 | with torch.no_grad():
188 | image_embedding = image_embedding[0:boxes.shape[0],:,:,:]
189 | prompt_embedding = prompt_embedding[0:boxes.shape[0],:,:,:]
190 | sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
191 | points=None,
192 | boxes=boxes,
193 | masks=None,)
194 | low_res_masks, _ = mobilesamv2.mask_decoder(
195 | image_embeddings=image_embedding,
196 | image_pe=prompt_embedding,
197 | sparse_prompt_embeddings=sparse_embeddings,
198 | dense_prompt_embeddings=dense_embeddings,
199 | multimask_output=False,
200 | simple_type=True,
201 | )
202 | low_res_masks = predictor.model.postprocess_masks(low_res_masks, predictor.input_size, predictor.original_size)
203 | sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold) * 1.0
204 | sam_mask.append(sam_mask_pre.squeeze(1))
205 |
206 | sam_mask = torch.cat(sam_mask)
207 | # Visualize SAM mask
208 | # annotation = sam_mask
209 | # areas = torch.sum(annotation, dim=(1, 2))
210 | # sorted_indices = torch.argsort(areas, descending=True)
211 | # show_img = annotation[sorted_indices]
212 | # ann_img = show_anns(show_img)
213 | # save_img_path = obj_feat_path_list[i].replace('.npy', '_mask.png')
214 | # Image.fromarray((ann_img * 255).astype(np.uint8)).save(save_img_path)
215 |
216 | # ===== Object-level CLIP feature =====
217 | # Interpolate CLIP features to image size
218 | resized_clip_feat_map_bchw = torch.nn.functional.interpolate(whole_image_feature.unsqueeze(0).float(),
219 | size=(object_H, object_W),
220 | mode='bilinear',
221 | align_corners=False)
222 |
223 | mask_tensor_bchw = sam_mask.unsqueeze(1)
224 |
225 | resized_mask_tensor_bchw = torch.nn.functional.interpolate(mask_tensor_bchw.float(),
226 | size=(object_H, object_W),
227 | mode='nearest').bool()
228 |
229 | aggregated_feat_map = torch.zeros((clip_feat_dim, object_H, object_W), dtype=torch.float32, device=device)
230 | aggregated_feat_cnt_map = torch.zeros((object_H, object_W), dtype=int, device=device)
231 |
232 | for mask_idx in range(resized_mask_tensor_bchw.shape[0]):
233 | aggregated_clip_feat = resized_clip_feat_map_bchw[0, :, resized_mask_tensor_bchw[mask_idx, 0]]
234 | aggregated_clip_feat = aggregated_clip_feat.mean(dim=1)  # mean CLIP feature within this SAM mask
235 |
236 | aggregated_feat_map[:, resized_mask_tensor_bchw[mask_idx, 0]] += aggregated_clip_feat[:, None]
237 |
aggregated_feat_cnt_map[resized_mask_tensor_bchw[mask_idx, 0]] += 1 238 | 239 | aggregated_feat_map = aggregated_feat_map / (aggregated_feat_cnt_map[None, :, :] + 1e-6) 240 | aggregated_feat_map = F.interpolate(aggregated_feat_map[None], (final_H, final_W), mode='bilinear', align_corners=False)[0] 241 | 242 | ret_dict['samclip'].append(aggregated_feat_map.detach().cpu()) 243 | 244 | gc.collect() 245 | 246 | del clip_model 247 | del mobilesamv2 248 | del ObjAwareModel 249 | pytorch_gc() 250 | 251 | for k in ret_dict.keys(): 252 | ret_dict[k] = torch.stack(ret_dict[k], dim=0) # BCHW 253 | 254 | return ret_dict 255 | 256 | if __name__ == "__main__": 257 | parser = ArgumentParser("Compute reference features for feature splatting") 258 | parser.add_argument("--source_path", "-s", required=True, type=str) 259 | parser.add_argument("--part_batch_size", type=int, default=32, help="Part-level CLIP inference batch size") 260 | parser.add_argument("--part_resolution", type=int, default=224, help="Part-level CLIP input image resolution") 261 | parser.add_argument("--sam_size", type=int, default=1024, help="Longest edge for MobileSAMV2 segmentation") 262 | parser.add_argument("--obj_feat_res", type=int, default=100, help="Intermediate (for MAP) SAM-enhanced Object-level feature resolution") 263 | parser.add_argument("--part_feat_res", type=int, default=300, help="Intermediate (for MAP) SAM-enhanced Part-level feature resolution") 264 | parser.add_argument("--final_feat_res", type=int, default=64, help="Final hierarchical CLIP feature resolution") 265 | parser.add_argument("--dino_resolution", type=int, default=800, help="Longest edge for DINOv2 feature generation") 266 | parser.add_argument("--dinov2_model_name", type=str, default='dinov2_vits14') 267 | parser.add_argument("--mobilesamv2_encoder_name", type=str, default='mobilesamv2_efficientvit_l2') 268 | parser.add_argument("--clip_model_name", type=str, default='ViT-L/14@336px') 269 | args = parser.parse_args() 270 | 271 | image_paths = [os.path.join(args.source_path, fn) for fn in os.listdir(args.source_path)] 272 | 273 | ret_dict = batch_extract_feature(image_paths, args) 274 | 275 | for k, v in ret_dict.items(): 276 | print(f"{k}: {v.shape}") 277 | -------------------------------------------------------------------------------- /feature_splatting/feature_extractor_cfg.py: -------------------------------------------------------------------------------- 1 | class SAMCLIPArgs: 2 | part_batch_size: int = 32 3 | part_resolution: int = 224 4 | sam_size: int = 1024 5 | obj_feat_res: int = 100 6 | part_feat_res: int = 300 7 | final_feat_res: int = 64 8 | dino_resolution: int = 800 9 | dinov2_model_name: str = 'dinov2_vits14' 10 | mobilesamv2_encoder_name: str = 'mobilesamv2_efficientvit_l2' 11 | clip_model_name: str = 'ViT-L/14@336px' 12 | 13 | @classmethod 14 | def id_dict(cls): 15 | """Return dict that identifies the CLIP model parameters.""" 16 | return { 17 | "part_resolution": cls.part_resolution, 18 | "sam_size": cls.sam_size, 19 | "obj_feat_res": cls.obj_feat_res, 20 | "part_feat_res": cls.part_feat_res, 21 | "final_feat_res": cls.final_feat_res, 22 | "dino_resolution": cls.dino_resolution, 23 | "dinov2_model_name": cls.dinov2_model_name, 24 | "mobilesamv2_encoder_name": cls.mobilesamv2_encoder_name, 25 | "clip_model_name": cls.clip_model_name, 26 | } 27 | -------------------------------------------------------------------------------- /feature_splatting/feature_splatting_config.py: 
-------------------------------------------------------------------------------- 1 | from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig 2 | from nerfstudio.configs.base_config import ViewerConfig 3 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig 4 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 5 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig 6 | from nerfstudio.engine.trainer import TrainerConfig 7 | from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig 8 | from nerfstudio.plugins.types import MethodSpecification 9 | 10 | # Gsplat config templates 11 | from feature_splatting.feature_splatting_datamgr import FeatureSplattingDataManagerConfig 12 | from feature_splatting.model import FeatureSplattingModelConfig 13 | 14 | # Trainer config is modified from the template at 15 | # https://github.com/nerfstudio-project/nerfstudio/blob/bf3664a19a89a61bcac83a9f69cbe2d6dc7c444d/nerfstudio/configs/method_configs.py#L594 16 | feature_splatting_method = MethodSpecification( 17 | config=TrainerConfig( 18 | method_name="feature-splatting", 19 | steps_per_eval_image=100, 20 | steps_per_eval_batch=0, 21 | steps_per_save=2000, 22 | steps_per_eval_all_images=1000, 23 | max_num_iterations=30000, 24 | mixed_precision=False, 25 | pipeline=VanillaPipelineConfig( 26 | datamanager=FeatureSplattingDataManagerConfig( 27 | dataparser=NerfstudioDataParserConfig(load_3D_points=True), 28 | cache_images_type="uint8", 29 | ), 30 | model=FeatureSplattingModelConfig(sh_degree=0), 31 | ), 32 | optimizers={ 33 | "means": { 34 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 35 | "scheduler": ExponentialDecaySchedulerConfig( 36 | lr_final=1.6e-6, 37 | max_steps=30000, 38 | ), 39 | }, 40 | "features_dc": { 41 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 42 | "scheduler": None, 43 | }, 44 | "features_rest": { 45 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), 46 | "scheduler": None, 47 | }, 48 | "opacities": { 49 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 50 | "scheduler": None, 51 | }, 52 | "scales": { 53 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), 54 | "scheduler": None, 55 | }, 56 | "distill_features": { 57 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 58 | "scheduler": ExponentialDecaySchedulerConfig( 59 | lr_final=5e-4, 60 | max_steps=10000, 61 | ), 62 | }, 63 | "feature_mlp": { 64 | "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), 65 | "scheduler": None, 66 | }, 67 | "quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None}, 68 | "camera_opt": { 69 | "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15), 70 | "scheduler": ExponentialDecaySchedulerConfig( 71 | lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0 72 | ), 73 | }, 74 | }, 75 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 76 | vis="viewer", 77 | ), 78 | description="Feature Splatting distills language-aligned features into 3D Gaussians.", 79 | ) 80 | -------------------------------------------------------------------------------- /feature_splatting/feature_splatting_datamgr.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from dataclasses import dataclass, field 3 | from typing import Dict, Literal, Tuple, Type 4 | from nerfstudio.cameras.cameras import Cameras, CameraType 5 | 6 | import numpy as np 7 | import torch 8 | from jaxtyping import Float 9 | from 
nerfstudio.data.datamanagers.full_images_datamanager import (
10 | FullImageDatamanager,
11 | FullImageDatamanagerConfig,
12 | )
13 | from nerfstudio.utils.rich_utils import CONSOLE
14 |
15 | from feature_splatting.feature_extractor_cfg import SAMCLIPArgs
16 |
17 | # SAMCLIP
18 | from feature_splatting.feature_extractor import batch_extract_feature
19 |
20 | feat_type_to_extract_fn = {
21 | "CLIP": None,
22 | "DINO": None,
23 | "SAMCLIP": batch_extract_feature,
24 | }
25 |
26 | feat_type_to_args = {
27 | "CLIP": None,
28 | "DINO": None,
29 | "SAMCLIP": SAMCLIPArgs,
30 | }
31 |
32 | feat_type_to_main_feature_name = {
33 | "CLIP": "clip",
34 | "DINO": "dino",
35 | "SAMCLIP": "samclip",
36 | }
37 |
38 | @dataclass
39 | class FeatureSplattingDataManagerConfig(FullImageDatamanagerConfig):
40 | _target: Type = field(default_factory=lambda: FeatureSplattingDataManager)
41 | feature_type: Literal["CLIP", "DINO", "SAMCLIP"] = "SAMCLIP"
42 | """Feature type to extract."""
43 | enable_cache: bool = True
44 | """Whether to cache extracted features."""
45 |
46 | class FeatureSplattingDataManager(FullImageDatamanager):
47 | config: FeatureSplattingDataManagerConfig
48 |
49 | def __init__(self, *args, **kwargs):
50 | super().__init__(*args, **kwargs)
51 | # Extract features
52 | self.feature_dict = self.extract_features()
53 |
54 | # Split into train and eval features
55 | self.train_feature_dict = {}
56 | self.eval_feature_dict = {}
57 | feature_dim_dict = {}
58 | for feature_name in self.feature_dict:
59 | assert len(self.feature_dict[feature_name]) == len(self.train_dataset) + len(self.eval_dataset)
60 | self.train_feature_dict[feature_name] = self.feature_dict[feature_name][: len(self.train_dataset)]
61 | self.eval_feature_dict[feature_name] = self.feature_dict[feature_name][len(self.train_dataset) :]
62 | feature_dim_dict[feature_name] = self.feature_dict[feature_name].shape[1:] # c, h, w
63 | assert len(self.eval_feature_dict[feature_name]) == len(self.eval_dataset)
64 |
65 | del self.feature_dict
66 |
67 | # Set metadata, so we can initialize model with feature dimensionality
68 | self.train_dataset.metadata["feature_type"] = self.config.feature_type
69 | self.train_dataset.metadata["feature_dim_dict"] = feature_dim_dict
70 | self.train_dataset.metadata["main_feature_name"] = feat_type_to_main_feature_name[self.config.feature_type]
71 | self.train_dataset.metadata["clip_model_name"] = feat_type_to_args[self.config.feature_type].clip_model_name
72 |
73 | # Garbage collect
74 | torch.cuda.empty_cache()
75 | gc.collect()
76 |
77 | def extract_features(self) -> Dict[str, Float[torch.Tensor, "n c h w"]]:
78 | # Extract features
79 | if self.config.feature_type not in feat_type_to_extract_fn:
80 | raise ValueError(f"Unknown feature type {self.config.feature_type}")
81 | extract_fn = feat_type_to_extract_fn[self.config.feature_type]
82 | extract_args = feat_type_to_args[self.config.feature_type]
83 | image_fnames = self.train_dataset.image_filenames + self.eval_dataset.image_filenames
84 | # For dev purposes: visually verified that the image_fnames order matches camera_idx. NS seems to internally sort valid image_fnames.
85 | # self.feature_image_fnames = image_fnames
86 |
87 | # If cache exists, load it and validate it. We save it to the dataset directory.
88 | cache_dir = self.config.dataparser.data 89 | cache_path = cache_dir / f"feature_splatting_{self.config.feature_type.lower()}_features.pt" 90 | if self.config.enable_cache and cache_path.exists(): 91 | cache_dict = torch.load(cache_path) 92 | if cache_dict.get("image_fnames") != image_fnames: 93 | CONSOLE.print("Image filenames have changed, cache invalidated...") 94 | elif cache_dict.get("args") != extract_args.id_dict(): 95 | CONSOLE.print("Feature extraction args have changed, cache invalidated...") 96 | else: 97 | return cache_dict["feature_dict"] 98 | 99 | # Cache is invalid or doesn't exist, so extract features 100 | CONSOLE.print(f"Extracting {self.config.feature_type} features for {len(image_fnames)} images...") 101 | feature_dict = extract_fn(image_fnames, extract_args) 102 | if self.config.enable_cache: 103 | cache_dict = {"args": extract_args.id_dict(), "image_fnames": image_fnames, "feature_dict": feature_dict} 104 | cache_dir.mkdir(exist_ok=True) 105 | torch.save(cache_dict, cache_path) 106 | CONSOLE.print(f"Saved {self.config.feature_type} features to cache at {cache_path}") 107 | return feature_dict 108 | 109 | def next_train(self, step: int) -> Tuple[Cameras, Dict]: 110 | camera, data = super().next_train(step) 111 | camera_idx = camera.metadata['cam_idx'] 112 | feature_dict = {} 113 | for feature_name in self.train_feature_dict: 114 | feature_dict[feature_name] = self.train_feature_dict[feature_name][camera_idx] 115 | data["feature_dict"] = feature_dict 116 | return camera, data 117 | 118 | def next_eval(self, step: int) -> Tuple[Cameras, Dict]: 119 | camera, data = super().next_eval(step) 120 | camera_idx = camera.metadata['cam_idx'] 121 | feature_dict = {} 122 | for feature_name in self.eval_feature_dict: 123 | feature_dict[feature_name] = self.eval_feature_dict[feature_name][camera_idx] 124 | data["feature_dict"] = feature_dict 125 | return camera, data 126 | -------------------------------------------------------------------------------- /feature_splatting/model.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataclasses import dataclass, field 3 | from typing import Dict, List, Literal, Optional, Tuple, Type, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from nerfstudio.models.splatfacto import SplatfactoModel, SplatfactoModelConfig, get_viewmat 10 | from nerfstudio.cameras.cameras import Cameras 11 | from nerfstudio.utils.rich_utils import CONSOLE 12 | from nerfstudio.viewer.server.viewer_elements import ( 13 | ViewerButton, 14 | ViewerNumber, 15 | ViewerText, 16 | ViewerCheckbox, 17 | ViewerSlider, 18 | ViewerVec3, 19 | ) 20 | from nerfstudio.data.scene_box import OrientedBox 21 | 22 | # Feature splatting functions 23 | from torch.nn import Parameter 24 | from feature_splatting.utils import ( 25 | ViewerUtils, 26 | apply_pca_colormap_return_proj, 27 | two_layer_mlp, 28 | clip_text_encoder, 29 | compute_similarity, 30 | cluster_instance, 31 | estimate_ground, 32 | get_ground_bbox_min_max, 33 | gaussian_editor 34 | ) 35 | try: 36 | from gsplat.cuda._torch_impl import _quat_to_rotmat 37 | from gsplat.rendering import rasterization 38 | except ImportError: 39 | print("Please install gsplat>=1.0.0") 40 | 41 | @dataclass 42 | class FeatureSplattingModelConfig(SplatfactoModelConfig): 43 | """Note: make sure to use naming that doesn't conflict with NerfactoModelConfig""" 44 | 45 | _target: Type = field(default_factory=lambda: 
FeatureSplattingModel) 46 | # Compute SHs in python 47 | python_compute_sh: bool = False 48 | # Weighing for the overall feature loss 49 | feat_loss_weight: float = 1e-3 50 | feat_aux_loss_weight: float = 0.1 51 | # Latent dimension for the feature field 52 | # TODO(roger): this feat_dim has to add up depth/color to a number that can be rasterized without padding 53 | # https://github.com/nerfstudio-project/gsplat/blob/main/gsplat/cuda/_wrapper.py#L431 54 | # gsplat's N-D implementation seems to have some bugs that cause padded tensors to have memory issues 55 | # we can create a PR to fix this. 56 | feat_latent_dim: int = 13 57 | # Feature Field MLP Head 58 | mlp_hidden_dim: int = 64 59 | 60 | def cosine_loss(network_output, gt): 61 | assert network_output.shape == gt.shape 62 | return (1 - F.cosine_similarity(network_output, gt, dim=0)).mean() 63 | 64 | class FeatureSplattingModel(SplatfactoModel): 65 | config: FeatureSplattingModelConfig 66 | 67 | def populate_modules(self): 68 | super().populate_modules() 69 | # Sanity check 70 | if self.config.python_compute_sh: 71 | raise NotImplementedError("Not implemented yet") 72 | if self.config.sh_degree > 0: 73 | assert self.config.python_compute_sh, "SH computation is only supported in python" 74 | else: 75 | assert not self.config.python_compute_sh, "SHs python compute flag should not be used with 0 SH degree" 76 | 77 | # Initialize per-Gaussian features 78 | distill_features = torch.nn.Parameter(torch.zeros((self.means.shape[0], self.config.feat_latent_dim))) 79 | self.gauss_params["distill_features"] = distill_features 80 | self.main_feature_name = self.kwargs["metadata"]["main_feature_name"] 81 | self.main_feature_shape_chw = self.kwargs["metadata"]["feature_dim_dict"][self.main_feature_name] 82 | 83 | # Initialize the multi-head feature MLP 84 | self.feature_mlp = two_layer_mlp(self.config.feat_latent_dim, 85 | self.config.mlp_hidden_dim, 86 | self.kwargs["metadata"]["feature_dim_dict"]) 87 | 88 | # Visualization utils 89 | self.maybe_populate_text_encoder() 90 | self.setup_gui() 91 | 92 | self.gaussian_editor = gaussian_editor() 93 | 94 | def maybe_populate_text_encoder(self): 95 | if "clip_model_name" in self.kwargs["metadata"]: 96 | assert "clip" in self.main_feature_name.lower(), "CLIP model name should only be used with CLIP features" 97 | self.clip_text_encoder = clip_text_encoder(self.kwargs["metadata"]["clip_model_name"], self.kwargs["device"]) 98 | self.text_encoding_func = self.clip_text_encoder.get_text_token 99 | else: 100 | self.text_encoding_func = None 101 | 102 | def setup_gui(self): 103 | self.viewer_utils = ViewerUtils(self.text_encoding_func) 104 | # Note: the GUI elements are shown based on alphabetical variable names 105 | self.btn_refresh_pca = ViewerButton("Refresh PCA Projection", cb_hook=lambda _: self.viewer_utils.reset_pca_proj()) 106 | if "clip" in self.main_feature_name.lower(): 107 | self.hint_text = ViewerText(name="Note:", disabled=True, default_value="Use , to separate labels") 108 | self.lang_1_pos_text = ViewerText( 109 | name="Positive Text Queries", 110 | default_value="", 111 | cb_hook=lambda elem: self.viewer_utils.update_text_embedding('positive', elem.value), 112 | ) 113 | self.lang_2_neg_text = ViewerText( 114 | name="Negative Text Queries", 115 | default_value="object", 116 | cb_hook=lambda elem: self.viewer_utils.update_text_embedding('negative', elem.value), 117 | ) 118 | # call the callback function with the default value 119 | self.viewer_utils.update_text_embedding('negative', 
self.lang_2_neg_text.default_value) 120 | self.lang_ground_text = ViewerText( 121 | name="Ground Text Queries", 122 | default_value="floor", 123 | cb_hook=lambda elem: self.viewer_utils.update_text_embedding('ground', elem.value), 124 | ) 125 | self.viewer_utils.update_text_embedding('ground', self.lang_ground_text.default_value) 126 | self.softmax_temp = ViewerNumber( 127 | name="Softmax temperature", 128 | default_value=self.viewer_utils.softmax_temp, 129 | cb_hook=lambda elem: self.viewer_utils.update_softmax_temp(elem.value), 130 | ) 131 | # ===== Start Editing utility ===== 132 | self.edit_checkbox = ViewerCheckbox("Enter Editing Mode", default_value=False, cb_hook=lambda _: self.start_editing()) 133 | # Ground estimation 134 | self.estimate_ground_btn = ViewerButton("Estimate Ground", cb_hook=lambda _: self.estimate_ground(), disabled=True, visible=False) 135 | # Main object segmentation 136 | self.segment_main_obj_btn = ViewerButton("Segment main obj", cb_hook=lambda _: self.segment_positive_obj(), disabled=True, visible=False) 137 | self.bbox_min_offset_vec = ViewerVec3("BBox Min", default_value=(0, 0, 0), disabled=True, visible=False) 138 | self.bbox_max_offset_vec = ViewerVec3("BBox Max", default_value=(0, 0, 0), disabled=True, visible=False) 139 | self.main_obj_only_checkbox = ViewerCheckbox("View main object only", default_value=True, disabled=True, visible=False) 140 | # Basic editing 141 | self.translation_vec = ViewerVec3("Translation", default_value=(0, 0, 0), disabled=True, visible=False) 142 | self.yaw_rotation = ViewerNumber("Yaw-only Rotation (deg)", default_value=0., disabled=True, visible=False) 143 | # Physics simulation 144 | self.physics_sim_checkbox = ViewerCheckbox("Physics Simulation", default_value=False, disabled=True, visible=False) 145 | self.physics_sim_step_btn = ViewerButton("Physics Simulation Step", disabled=True, visible=False, cb_hook=lambda _: self.physics_sim_step()) 146 | 147 | def physics_sim_step(self): 148 | # It's just a placeholder now. NS needs some user interaction to send rendering requests. 149 | # So I make a button that does nothing but to trigger rendering. 
150 | pass
151 |
152 | def estimate_ground(self):
153 | selected_obj_idx, sample_idx = self.segment_gaussian('ground', use_canonical=True, threshold=0.5)
154 | ground_means_xyz = self.means[sample_idx].detach().cpu().numpy()[selected_obj_idx]
155 | self.ground_R, self.ground_T, ground_inliers = estimate_ground(ground_means_xyz)
156 | self.gaussian_editor.register_ground_transform(self.ground_R, self.ground_T)
157 |
158 | # Enable next step
159 | self.segment_main_obj_btn.set_disabled(False)
160 | self.segment_main_obj_btn.set_visible(True)
161 |
162 | def start_editing(self):
163 | self.estimate_ground_btn.set_disabled(False)
164 | self.estimate_ground_btn.set_visible(True)
165 |
166 | def segment_positive_obj(self):
167 | selected_obj_idx, sample_idx = self.segment_gaussian('positive', use_canonical=False)
168 |
169 | all_xyz = self.means.detach().cpu().numpy()
170 | selected_xyz = all_xyz[sample_idx]
171 |
172 | selected_obj_idx = cluster_instance(selected_xyz, selected_obj_idx)
173 |
174 | # Get the boolean flag of selected particles (of all particles)
175 | subset_idx = np.zeros(self.means.shape[0], dtype=bool)
176 | subset_idx[sample_idx[selected_obj_idx]] = True
177 |
178 | ground_min, ground_max = get_ground_bbox_min_max(all_xyz, subset_idx, self.ground_R, self.ground_T)
179 |
180 | self.gaussian_editor.register_object_minimax(ground_min, ground_max)
181 |
182 | # Enable bbox editing
183 | self.bbox_min_offset_vec.set_disabled(False)
184 | self.bbox_min_offset_vec.set_visible(True)
185 | self.bbox_max_offset_vec.set_disabled(False)
186 | self.bbox_max_offset_vec.set_visible(True)
187 | self.main_obj_only_checkbox.set_disabled(False)
188 | self.main_obj_only_checkbox.set_visible(True)
189 |
190 | # Enable basic editing utilities
191 | self.translation_vec.set_disabled(False)
192 | self.translation_vec.set_visible(True)
193 | self.yaw_rotation.set_disabled(False)
194 | self.yaw_rotation.set_visible(True)
195 |
196 | # Enable physics simulation
197 | self.physics_sim_checkbox.set_disabled(False)
198 | self.physics_sim_checkbox.set_visible(True)
199 | self.physics_sim_step_btn.set_disabled(False)
200 | self.physics_sim_step_btn.set_visible(True)
201 |
202 | def segment_gaussian(self, field_name: str, use_canonical: bool, sample_size: Optional[int] = 2**15, threshold: Optional[float] = 0.5):
203 | if "clip" not in self.main_feature_name.lower():
204 | return
205 | if sample_size is not None:
206 | sample_size = min(sample_size, self.means.shape[0])  # honor the requested sample size, capped at the number of Gaussians
207 | sample_idx = torch.randperm(self.means.shape[0])[:sample_size]
208 | sampled_features = self.distill_features[sample_idx]
209 | else:
210 | sample_idx = torch.arange(self.means.shape[0])
211 | sampled_features = self.distill_features
212 | clip_feature_nc = self.feature_mlp.per_gaussian_forward(sampled_features)[self.main_feature_name]
213 | clip_feature_nc /= clip_feature_nc.norm(dim=1, keepdim=True)
214 | clip_feature_cn = clip_feature_nc.permute(1, 0)
215 |
216 | # Use paired softmax method as described in the paper with positive and negative texts
217 | if not use_canonical and self.viewer_utils.is_embed_valid('negative'):
218 | neg_embedding = self.viewer_utils.get_text_embed('negative')
219 | else:
220 | neg_embedding = self.viewer_utils.get_text_embed('canonical')
221 | text_embs = torch.cat([self.viewer_utils.get_text_embed(field_name), neg_embedding], dim=0)
222 | raw_sims = torch.einsum("cm,nc->nm", clip_feature_cn, text_embs)
223 | pos_sim = compute_similarity(raw_sims, self.viewer_utils.softmax_temp,
self.viewer_utils.get_embed_shape(field_name)[0]) 224 | 225 | # pos_sim -= pos_sim.min() 226 | # pos_sim /= pos_sim.max() 227 | 228 | selected_obj_idx = (pos_sim > threshold).cpu().numpy() 229 | 230 | return selected_obj_idx, sample_idx 231 | 232 | def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: 233 | """Takes in a camera and returns a dictionary of outputs. 234 | 235 | Args: 236 | camera: The camera(s) for which output images are rendered. It should have 237 | all the needed information to compute the outputs. 238 | 239 | Returns: 240 | Outputs of model. (ie. rendered colors) 241 | """ 242 | if not isinstance(camera, Cameras): 243 | print("Called get_outputs with not a camera") 244 | return {} 245 | 246 | if self.training: 247 | assert camera.shape[0] == 1, "Only one camera at a time" 248 | optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera) 249 | else: 250 | optimized_camera_to_world = camera.camera_to_worlds 251 | 252 | # cropping 253 | if self.crop_box is not None and not self.training: 254 | crop_ids = self.crop_box.within(self.means).squeeze() 255 | if crop_ids.sum() == 0: 256 | return self.get_empty_outputs( 257 | int(camera.width.item()), int(camera.height.item()), self.background_color 258 | ) 259 | else: 260 | crop_ids = None 261 | 262 | if crop_ids is not None: 263 | opacities_crop = self.opacities[crop_ids] 264 | means_crop = self.means[crop_ids] 265 | features_dc_crop = self.features_dc[crop_ids] 266 | features_rest_crop = self.features_rest[crop_ids] 267 | scales_crop = self.scales[crop_ids] 268 | quats_crop = self.quats[crop_ids] 269 | distill_features_crop = self.distill_features[crop_ids] 270 | else: 271 | opacities_crop = self.opacities 272 | means_crop = self.means 273 | features_dc_crop = self.features_dc 274 | features_rest_crop = self.features_rest 275 | scales_crop = self.scales 276 | quats_crop = self.quats 277 | distill_features_crop = self.distill_features 278 | 279 | # features_dc_crop.shape: [N, 3] 280 | # features_rest_crop.shape: [N, 15, 3] 281 | colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1) 282 | # colors_crop.shape: [N, 16, 3] 283 | 284 | BLOCK_WIDTH = 16 # this controls the tile size of rasterization, 16 is a good default 285 | camera_scale_fac = self._get_downscale_factor() 286 | camera.rescale_output_resolution(1 / camera_scale_fac) 287 | viewmat = get_viewmat(optimized_camera_to_world) 288 | K = camera.get_intrinsics_matrices().cuda() 289 | W, H = int(camera.width.item()), int(camera.height.item()) 290 | self.last_size = (H, W) 291 | camera.rescale_output_resolution(camera_scale_fac) # type: ignore 292 | 293 | # apply the compensation of screen space blurring to gaussians 294 | if self.config.rasterize_mode not in ["antialiased", "classic"]: 295 | raise ValueError("Unknown rasterize_mode: %s", self.config.rasterize_mode) 296 | 297 | if self.config.output_depth_during_training or not self.training: 298 | # Actually render RGB, features, and depth, but can't use RGB+FEAT+ED because we hack gsplat 299 | render_mode = "RGB+ED" 300 | else: 301 | render_mode = "RGB" 302 | 303 | if self.config.sh_degree > 0: 304 | sh_degree_to_use = min(self.step // self.config.sh_degree_interval, self.config.sh_degree) 305 | assert self.config.python_compute_sh, "SH computation is only supported in python" 306 | raise NotImplementedError("Python SHs computation not implemented yet") 307 | sh_degree_to_use = None 308 | else: 309 | colors_crop = torch.sigmoid(colors_crop).squeeze(1) # [N, 
1, 3] -> [N, 3]
310 | fused_render_properties = torch.cat((colors_crop, distill_features_crop), dim=1)
311 | sh_degree_to_use = None
312 |
313 | render, alpha, self.info = rasterization(
314 | means=means_crop,
315 | quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True),
316 | scales=torch.exp(scales_crop),
317 | opacities=torch.sigmoid(opacities_crop).squeeze(-1),
318 | colors=fused_render_properties,
319 | viewmats=viewmat, # [1, 4, 4]
320 | Ks=K, # [1, 3, 3]
321 | width=W,
322 | height=H,
323 | tile_size=BLOCK_WIDTH,
324 | packed=False,
325 | near_plane=0.01,
326 | far_plane=1e10,
327 | render_mode=render_mode,
328 | sh_degree=sh_degree_to_use,
329 | sparse_grad=False,
330 | absgrad=True,
331 | rasterize_mode=self.config.rasterize_mode,
332 | # set some threshold to disregard small gaussians for faster rendering.
333 | # radius_clip=3.0,
334 | )
335 |
336 | if self.training and self.info["means2d"].requires_grad:
337 | self.info["means2d"].retain_grad()
338 | self.xys = self.info["means2d"] # [1, N, 2]
339 | self.radii = self.info["radii"][0] # [N]
340 | alpha = alpha[:, ...]
341 |
342 | background = self._get_background_color()
343 | rgb = render[:, ..., :3] + (1 - alpha) * background
344 | rgb = torch.clamp(rgb, 0.0, 1.0)
345 |
346 | if render_mode == "RGB+ED":
347 | assert render.shape[3] == 3 + self.config.feat_latent_dim + 1
348 | depth_im = render[:, ..., -1:]
349 | depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
350 | else:
351 | assert render.shape[3] == 3 + self.config.feat_latent_dim
352 | depth_im = None
353 |
354 | if background.shape[0] == 3 and not self.training:
355 | background = background.expand(H, W, 3)
356 |
357 | feature = render[:, ..., 3:3 + self.config.feat_latent_dim]
358 |
359 | return {
360 | "rgb": rgb.squeeze(0), # type: ignore
361 | "depth": depth_im, # type: ignore
362 | "accumulation": alpha.squeeze(0), # type: ignore
363 | "background": background, # type: ignore
364 | "feature": feature.squeeze(0), # type: ignore
365 | } # type: ignore
366 |
367 | def decode_features(self, features_hwc: torch.Tensor, resize_factor: float = 1.)
-> Dict[str, torch.Tensor]: 368 | # Decode features 369 | feature_chw = features_hwc.permute(2, 0, 1) 370 | feature_shape_hw = (int(self.main_feature_shape_chw[1] * resize_factor), int(self.main_feature_shape_chw[2] * resize_factor)) 371 | rendered_feat = F.interpolate(feature_chw.unsqueeze(0), size=feature_shape_hw, mode="bilinear", align_corners=False) 372 | rendered_feat_dict = self.feature_mlp(rendered_feat) 373 | # Rest of the features 374 | for key, feat_shape_chw in self.kwargs["metadata"]["feature_dim_dict"].items(): 375 | if key != self.main_feature_name: 376 | rendered_feat_dict[key] = F.interpolate(rendered_feat_dict[key], size=feat_shape_chw[1:], mode="bilinear", align_corners=False) 377 | rendered_feat_dict[key] = rendered_feat_dict[key].squeeze(0) 378 | return rendered_feat_dict 379 | 380 | def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]: 381 | # Splatfacto computes the loss for the rgb image 382 | loss_dict = super().get_loss_dict(outputs, batch, metrics_dict) 383 | for k in batch['feature_dict']: 384 | batch['feature_dict'][k] = batch['feature_dict'][k].to(self.device) 385 | decoded_feature_dict = self.decode_features(outputs["feature"]) 386 | feature_loss = torch.tensor(0.0, device=self.device) 387 | for key, target_feat in batch['feature_dict'].items(): 388 | cur_loss_weight = 1.0 if key == self.main_feature_name else self.config.feat_aux_loss_weight 389 | ignore_feat_mask = (torch.sum(target_feat == 0, dim=0) == target_feat.shape[0]) 390 | target_feat[:, ignore_feat_mask] = decoded_feature_dict[key][:, ignore_feat_mask] 391 | feature_loss += cosine_loss(decoded_feature_dict[key], target_feat) * cur_loss_weight 392 | loss_dict["feature_loss"] = self.config.feat_loss_weight * feature_loss 393 | return loss_dict 394 | 395 | @torch.no_grad() 396 | def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]: 397 | """This function is not called during training, but used for visualization in browser. So we can use it to 398 | add visualization not needed during training. 399 | """ 400 | editing_dict = self.gaussian_editor.prepare_editing_dict(self.translation_vec.value, self.yaw_rotation.value, self.physics_sim_checkbox.value) 401 | if self.edit_checkbox.value: 402 | # Editing mode 403 | self.gaussian_editor.pre_rendering_process(self.means, self.opacities, self.scales, self.quats, 404 | editing_dict=editing_dict, 405 | min_offset=torch.tensor(self.bbox_min_offset_vec.value).float().cuda() / 10.0, 406 | max_offset=torch.tensor(self.bbox_max_offset_vec.value).float().cuda() / 10.0, 407 | view_main_obj_only=self.main_obj_only_checkbox.value) 408 | outs = super().get_outputs_for_camera(camera, obb_box) 409 | if self.edit_checkbox.value: 410 | self.gaussian_editor.post_rendering_process(self.means, self.opacities, self.quats, self.scales) 411 | if self.physics_sim_checkbox.value: 412 | # turn off feature rendering during physics sim for speed 413 | return outs 414 | # Consistent pca that does not flicker 415 | outs["consistent_latent_pca"], self.viewer_utils.pca_proj, *_ = apply_pca_colormap_return_proj( 416 | outs["feature"], self.viewer_utils.pca_proj 417 | ) 418 | # TODO(roger): this resize factor affects the resolution of similarity map. Maybe we should use a fixed size? 
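        # Note on the query pipeline below: the rendered per-pixel features live in a
        # low-dimensional latent space, so decode_features() first maps them back to the
        # full feature spaces (e.g., CLIP) with the small MLP decoder. Both the decoded
        # CLIP features and the text embeddings are L2-normalized, so the einsum that
        # follows amounts to a per-pixel cosine similarity between image and text.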
419 |         decoded_feature_dict = self.decode_features(outs["feature"], resize_factor=8)
420 |         if "clip" in self.main_feature_name.lower() and self.viewer_utils.is_embed_valid('positive'):
421 |             clip_features = decoded_feature_dict[self.main_feature_name]
422 |             clip_features /= clip_features.norm(dim=0, keepdim=True)
423 | 
424 |             # Contrast the positive text embeddings against negative (or canonical) ones; see compute_similarity
425 |             if self.viewer_utils.is_embed_valid('negative'):
426 |                 neg_embedding = self.viewer_utils.get_text_embed('negative')
427 |             else:
428 |                 neg_embedding = self.viewer_utils.get_text_embed('canonical')
429 |             text_embs = torch.cat([self.viewer_utils.get_text_embed('positive'), neg_embedding], dim=0)
430 |             raw_sims = torch.einsum("chw,nc->nhw", clip_features, text_embs)
431 |             sim_shape_hw = raw_sims.shape[1:]
432 | 
433 |             raw_sims = raw_sims.reshape(raw_sims.shape[0], -1)
434 |             pos_sim = compute_similarity(raw_sims, self.viewer_utils.softmax_temp, self.viewer_utils.get_embed_shape('positive')[0])
435 |             outs["similarity"] = pos_sim.reshape(sim_shape_hw + (1,))  # H, W, 1
436 | 
437 |             # Upsample heatmap to match size of RGB image
438 |             # It's a bit slow since we do it at full resolution, but lower-resolution interpolation seems to have aliasing issues
439 |             assert outs["similarity"].shape[2] == 1
440 |             if outs["similarity"].shape[:2] != outs["rgb"].shape[:2]:
441 |                 out_sim = outs["similarity"][:, :, 0]  # H, W
442 |                 out_sim = out_sim[None, None, ...]  # 1, 1, H, W
443 |                 outs["similarity"] = F.interpolate(out_sim, size=outs["rgb"].shape[:2], mode="bilinear", align_corners=False).squeeze()
444 |                 outs["similarity"] = outs["similarity"][:, :, None]
445 |         return outs
446 | 
447 |     # ===== Utility functions for managing the gaussians =====
448 | 
449 |     @property
450 |     def distill_features(self):
451 |         return self.gauss_params["distill_features"]
452 | 
453 |     def load_state_dict(self, dict, **kwargs):  # type: ignore
454 |         # resize the parameters to match the new number of points
455 |         self.step = 30000
456 |         if "means" in dict:
457 |             # For backwards compatibility, we remap the names of parameters from
458 |             # means->gauss_params.means since old checkpoints have that format
459 |             for p in self.gauss_params.keys():
460 |                 dict[f"gauss_params.{p}"] = dict[p]
461 |         newp = dict["gauss_params.means"].shape[0]
462 |         for name, param in self.gauss_params.items():
463 |             old_shape = param.shape
464 |             new_shape = (newp,) + old_shape[1:]
465 |             self.gauss_params[name] = torch.nn.Parameter(torch.zeros(new_shape, device=self.device))
466 |         super().load_state_dict(dict, **kwargs)
467 | 
468 |     def split_gaussians(self, split_mask, samps):
469 |         """
470 |         This function splits gaussians that are too large
471 |         """
472 |         n_splits = split_mask.sum().item()
473 |         CONSOLE.log(f"Splitting {n_splits / self.num_points:.2%} of gaussians: {n_splits}/{self.num_points}")
474 |         centered_samples = torch.randn((samps * n_splits, 3), device=self.device)  # Nx3 of axis-aligned scales
475 |         scaled_samples = (
476 |             torch.exp(self.scales[split_mask].repeat(samps, 1)) * centered_samples
477 |         )  # scale the axis-aligned samples by each gaussian's scales
478 |         quats = self.quats[split_mask] / self.quats[split_mask].norm(dim=-1, keepdim=True)  # normalize them first
479 |         rots = quat_to_rotmat(quats.repeat(samps, 1))  # how these scales are rotated
480 |         rotated_samples = torch.bmm(rots, scaled_samples[..., None]).squeeze()
481 |         new_means = rotated_samples + self.means[split_mask].repeat(samps, 1)
482 |         # step 2, sample new colors
483 |         new_features_dc = self.features_dc[split_mask].repeat(samps, 1)
484 |         new_features_rest = self.features_rest[split_mask].repeat(samps, 1, 1)
485 |         # step 3, sample new opacities
486 |         new_opacities = self.opacities[split_mask].repeat(samps, 1)
487 |         # step 4, sample new scales
488 |         size_fac = 1.6
489 |         new_scales = torch.log(torch.exp(self.scales[split_mask]) / size_fac).repeat(samps, 1)
490 |         self.scales[split_mask] = torch.log(torch.exp(self.scales[split_mask]) / size_fac)
491 |         # step 5, sample new quats
492 |         new_quats = self.quats[split_mask].repeat(samps, 1)
493 |         # step 6 (RQ, July 2024), sample new distill_features
494 |         new_distill_features = self.distill_features[split_mask].repeat(samps, 1)
495 |         out = {
496 |             "means": new_means,
497 |             "features_dc": new_features_dc,
498 |             "features_rest": new_features_rest,
499 |             "opacities": new_opacities,
500 |             "scales": new_scales,
501 |             "quats": new_quats,
502 |             "distill_features": new_distill_features,
503 |         }
504 |         for name, param in self.gauss_params.items():
505 |             if name not in out:
506 |                 out[name] = param[split_mask].repeat(samps, 1)
507 |         return out
508 | 
509 |     def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]:
510 |         # Here we explicitly use the means, scales as parameters so that the user can override this function and
511 |         # specify more if they want to add more optimizable params to gaussians.
512 |         return {
513 |             name: [self.gauss_params[name]]
514 |             for name in self.gauss_params.keys()
515 |         }
516 | 
517 |     def get_param_groups(self) -> Dict[str, List[Parameter]]:
518 |         # Gather Gaussian-related parameters
519 |         # The distill_features parameter is added via the get_gaussian_param_groups method
520 |         param_groups = super().get_param_groups()
521 |         param_groups["feature_mlp"] = list(self.feature_mlp.parameters())
522 |         return param_groups
523 | 
524 |     def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
525 |         """Computes and returns metrics.
526 | 527 | Args: 528 | outputs: the output to compute loss dict to 529 | batch: ground truth batch corresponding to outputs 530 | """ 531 | gt_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"]) 532 | metrics_dict = {} 533 | predicted_rgb = outputs["rgb"] 534 | metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb) 535 | 536 | metrics_dict["gaussian_count"] = self.num_points 537 | 538 | self.camera_optimizer.get_metrics_dict(metrics_dict) 539 | return metrics_dict 540 | -------------------------------------------------------------------------------- /feature_splatting/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_text_encoder import clip_text_encoder 2 | from .decoder_utils import two_layer_mlp, compute_similarity 3 | from .viewer_utils import ViewerUtils, apply_pca_colormap_return_proj 4 | from .segment_utils import cluster_instance, estimate_ground, get_ground_bbox_min_max 5 | from .gaussian_editor import gaussian_editor -------------------------------------------------------------------------------- /feature_splatting/utils/clip_text_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import torch 3 | import maskclip_onnx 4 | 5 | class clip_text_encoder: 6 | def __init__(self, clip_model_name: str, device: Union[str, torch.device]): 7 | self.clip_model_name = clip_model_name 8 | self.device = device 9 | self.clip, _ = maskclip_onnx.clip.load(self.clip_model_name, device=self.device) 10 | self.clip.eval() 11 | 12 | @torch.no_grad() 13 | def get_text_token(self, text_list: List[str]): 14 | """Compute CLIP embeddings based on queries and update state""" 15 | tokens = maskclip_onnx.clip.tokenize(text_list).to(self.device) 16 | embed = self.clip.encode_text(tokens).float() 17 | embed /= embed.norm(dim=-1, keepdim=True) 18 | return embed 19 | -------------------------------------------------------------------------------- /feature_splatting/utils/decoder_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class two_layer_mlp(nn.Module): 6 | def __init__(self, input_dim, hidden_dim, feature_dim_dict): 7 | super(two_layer_mlp, self).__init__() 8 | self.hidden_conv = nn.Conv2d(input_dim, hidden_dim, kernel_size=1, stride=1, padding=0) 9 | feature_branch_dict = {} 10 | for key, feat_dim_chw in feature_dim_dict.items(): 11 | feature_branch_dict[key] = nn.Conv2d(hidden_dim, feat_dim_chw[0], kernel_size=1, stride=1, padding=0) 12 | self.feature_branch_dict = nn.ModuleDict(feature_branch_dict) 13 | 14 | def forward(self, x): 15 | intermediate_feature = self.hidden_conv(x) 16 | intermediate_feature = F.relu(intermediate_feature) 17 | ret_dict = {} 18 | for key, nn_mod in self.feature_branch_dict.items(): 19 | ret_dict[key] = nn_mod(intermediate_feature) 20 | return ret_dict 21 | 22 | @torch.no_grad() 23 | def per_gaussian_forward(self, x): 24 | intermediate_feature = F.linear(x, self.hidden_conv.weight.view(self.hidden_conv.weight.size(0), -1), self.hidden_conv.bias) 25 | intermediate_feature = F.relu(intermediate_feature) 26 | ret_dict = {} 27 | for key, nn_mod in self.feature_branch_dict.items(): 28 | ret_dict[key] = F.linear(intermediate_feature, nn_mod.weight.view(nn_mod.weight.size(0), -1), nn_mod.bias) 29 | return ret_dict 30 | 31 | def compute_similarity(prob_mn, softmax_temp, 
num_pos, heatmap_method="standard_softmax"):
32 |     """
33 |     Compute the probability of an element being positive
34 | 
35 |     Args:
36 |         prob_mn: Tensor of shape (m, n), where m is the total number of classes and n is the number of elements
37 |         softmax_temp: float, softmax temperature
38 |         num_pos: int, the first num_pos rows of prob_mn are positive classes; the rest are negatives
39 |     """
40 |     assert num_pos <= prob_mn.shape[0]
41 |     if heatmap_method == "standard_softmax":  # Feature splatting uses this
42 |         prob_mn = prob_mn / softmax_temp
43 |         probs = prob_mn.softmax(dim=0)
44 |         pos_sim = probs[:num_pos].sum(dim=0)  # H, W
45 |         return pos_sim
46 |     elif heatmap_method == "pairwise_softmax":  # F3RM uses this
47 |         # Broadcast positive label similarities to all negative labels
48 |         pos_sims = prob_mn[:num_pos]
49 |         neg_sims = prob_mn[num_pos:]
50 |         pos_sims = pos_sims.mean(dim=0, keepdim=True)
51 |         pos_sims = pos_sims.broadcast_to(neg_sims.shape)
52 |         paired_sims = torch.cat([pos_sims, neg_sims], dim=0)
53 | 
54 |         # Compute paired softmax
55 |         probs = (paired_sims / softmax_temp).softmax(dim=0)[:1, ...]
56 |         torch.nan_to_num_(probs, nan=0.0)
57 |         sims, _ = probs.min(dim=0)
58 |         return sims
59 |     else:
60 |         raise ValueError(f"Unknown heatmap method: {heatmap_method}")
61 | 
--------------------------------------------------------------------------------
/feature_splatting/utils/gaussian_editor.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import threading
3 | import numpy as np
4 | import torch
5 | from scipy.spatial.transform import Rotation
6 | from feature_splatting.utils.mpm_engine.mpm_solver import MPMSolver
7 | 
8 | class gaussian_editor:
9 |     # TODO(roger): this should be integrated with viewer_utils
10 |     def __init__(self):
11 |         self.meta_editing_dict = {}
12 |         self.particle_modification_buffer = {}
13 |         # NS (nerfstudio) is asynchronous and the pre-process/post-process functions may be called
14 |         # in different threads
15 | 
16 |         # There may be a better way to do this, but for now we use a lock
17 |         self.editors_lock = threading.Lock()
18 | 
19 |     def register_object_minimax(self, xyz_min, xyz_max):
20 |         # Object bounding box minimax
21 |         self.meta_editing_dict['xyz_min'] = torch.tensor(xyz_min).cuda().float()
22 |         self.meta_editing_dict['xyz_max'] = torch.tensor(xyz_max).cuda().float()
23 | 
24 |     def register_ground_transform(self, ground_R, ground_T):
25 |         self.meta_editing_dict['ground_R_np'] = ground_R.copy()
26 |         self.meta_editing_dict['ground_T_np'] = ground_T.copy()
27 |         self.meta_editing_dict['ground_R'] = torch.tensor(ground_R).cuda().float()
28 |         self.meta_editing_dict['ground_T'] = torch.tensor(ground_T).cuda().float()
29 |         up_gravity_vec = np.array((0, 1, 0))
30 |         up_gravity_vec = ground_R.T @ up_gravity_vec
31 |         self.meta_editing_dict['up_gravity_vec_np'] = up_gravity_vec
32 | 
33 |     def prepare_editing_dict(self, translation_vec, yaw_deg, physics_sim_flag):
34 |         ret_editing_dict = {}
35 |         if 'ground_R' in self.meta_editing_dict and 'ground_T' in self.meta_editing_dict:
36 |             # Ground-aligned translation
37 |             trans_vec_np = np.array(translation_vec)
38 |             if trans_vec_np.any():
39 |                 trans_vec_gpu = torch.tensor(trans_vec_np).float().cuda()
40 |                 ret_editing_dict["translation"] = self.meta_editing_dict['ground_R'].T @ trans_vec_gpu
41 |             # Ground-aligned rotation
42 |             if yaw_deg:
43 |                 # TODO(roger): currently only support yaw rotation around gravity axis
44 |                 rot_axis = self.meta_editing_dict['up_gravity_vec_np'] / np.linalg.norm(self.meta_editing_dict['up_gravity_vec_np'])
45 |                 r = Rotation.from_rotvec(yaw_deg * rot_axis, degrees=True)
46 |                 rot_mat =
r.as_matrix() 47 | ret_editing_dict["rotation"] = torch.tensor(rot_mat).float().cuda() 48 | if physics_sim_flag: 49 | if "physics_sim" not in self.meta_editing_dict: 50 | print("Initializing physics simulation engine...") 51 | self.initialize_mpm_engine() 52 | ret_editing_dict["physics_sim"] = True 53 | if not physics_sim_flag: 54 | if "physics_sim" in self.meta_editing_dict: 55 | del self.meta_editing_dict["physics_sim"] 56 | gc.collect() 57 | 58 | return ret_editing_dict 59 | 60 | @torch.no_grad() 61 | def initialize_mpm_engine(self, youngs_modulus_scale=1, poisson_ratio=0.2): 62 | import taichi as ti 63 | ti.init(arch=ti.cuda, device_memory_GB=4.0) 64 | 65 | gui = ti.GUI("Taichi Elements", res=512, background_color=0x112F41) 66 | 67 | mpm = MPMSolver(res=(32, 32, 32), size=1, max_num_particles=2 ** 21, 68 | E_scale=youngs_modulus_scale, poisson_ratio=poisson_ratio) 69 | 70 | self.meta_editing_dict["physics_sim"] = { 71 | "mpm": mpm, 72 | "gui": gui, 73 | "initialized": False 74 | } 75 | 76 | @torch.no_grad() 77 | def initialize_mpm_w_particles(self, init_particles_positions, infilling_downsample_ratio=0.2, ground_level=0.05, gravity=4): 78 | assert not self.meta_editing_dict["physics_sim"]["initialized"] 79 | real_gaussian_particle = init_particles_positions 80 | real_obj_center = real_gaussian_particle.mean(axis=0) 81 | 82 | # Add pseudo points from center of the object to all particles 83 | support_per_particles = 10 84 | 85 | support_particles_list = [] 86 | 87 | for particles_idx in range(real_gaussian_particle.shape[0]): 88 | start_pos = real_obj_center 89 | end_pos = real_gaussian_particle[particles_idx] 90 | for support_idx in range(support_per_particles): 91 | # interpolate 92 | pos = (start_pos * (support_per_particles - support_idx) + end_pos * support_idx) / support_per_particles 93 | support_particles_list.append(pos) 94 | 95 | support_particles = np.array(support_particles_list) 96 | support_particles = np.random.permutation(support_particles)[:int(len(support_particles) * infilling_downsample_ratio)] 97 | 98 | all_particles = np.concatenate([real_gaussian_particle, support_particles], axis=0) 99 | 100 | # Align to ground 101 | particles = all_particles @ self.meta_editing_dict['ground_R_np'].T 102 | particles += self.meta_editing_dict['ground_T_np'] 103 | 104 | # Normalize everything to a unit world box; x-z coordinates are centered at 0.5 105 | particle_max = particles.max(axis=0) 106 | particle_min = particles.min(axis=0) 107 | particle_min[1] = min(particle_min[1], ground_level) 108 | 109 | longest_side = max(particle_max - particle_min) 110 | 111 | particles[:, 0] /= longest_side 112 | particles[:, 1] /= longest_side 113 | particles[:, 2] /= longest_side 114 | 115 | # Align centers of x and z to 0.5 and set the bottom of the object to 0 116 | shift_constant = np.array([ 117 | -particles[:,0].mean() + 0.5, 118 | -particles[:,1].min(), 119 | -particles[:,2].mean() + 0.5 120 | ]) 121 | 122 | particles += shift_constant 123 | 124 | self.meta_editing_dict["physics_sim"]["mpm"].add_particles(particles=particles, 125 | material=MPMSolver.material_sand, 126 | color=0xFFFF00) 127 | 128 | self.meta_editing_dict["physics_sim"]["mpm"].add_surface_collider(point=(0.0, ground_level, 0.0), 129 | normal=(0, 1, 0), 130 | surface=self.meta_editing_dict["physics_sim"]["mpm"].surface_sticky) 131 | 132 | self.meta_editing_dict["physics_sim"]["mpm"].set_gravity((0, -gravity, 0)) 133 | 134 | # Memorize constants 135 | self.meta_editing_dict["physics_sim"]["longest_side"] = longest_side 136 | 
self.meta_editing_dict["physics_sim"]["shift_constant"] = shift_constant 137 | self.meta_editing_dict["physics_sim"]["real_gaussian_particle_size"] = real_gaussian_particle.shape[0] 138 | 139 | @torch.no_grad() 140 | def physics_sim_step(self, timestep=4e-3): 141 | real_gaussian_particle_size = self.meta_editing_dict["physics_sim"]["real_gaussian_particle_size"] 142 | longest_side = self.meta_editing_dict["physics_sim"]["longest_side"] 143 | shift_constant = self.meta_editing_dict["physics_sim"]["shift_constant"] 144 | 145 | particles_info = self.meta_editing_dict["physics_sim"]["mpm"].particle_info() 146 | 147 | real_gaussian_pos = particles_info['position'][:real_gaussian_particle_size] 148 | 149 | ret_trajectory = real_gaussian_pos.copy() 150 | 151 | self.meta_editing_dict["physics_sim"]["mpm"].step(timestep) 152 | 153 | ret_trajectory -= shift_constant 154 | ret_trajectory *= longest_side 155 | 156 | # Reverse rigid transformation 157 | ret_trajectory = ret_trajectory - self.meta_editing_dict['ground_T_np'] 158 | ret_trajectory = ret_trajectory @ self.meta_editing_dict['ground_R_np'] 159 | 160 | return ret_trajectory 161 | 162 | @torch.no_grad() 163 | def pre_rendering_process(self, means, opacities, scales, quats, editing_dict, view_main_obj_only=False, **kwargs): 164 | self.editors_lock.acquire() 165 | assert not self.particle_modification_buffer, "Particle modification buffer is not empty" 166 | # If object bbox is set and view_main_obj_only is set, hide particles outside the bbox 167 | if 'xyz_min' in self.meta_editing_dict: 168 | assert 'min_offset' in kwargs and 'max_offset' in kwargs 169 | # Get object bounding box 170 | bbox_particle_idx = self.filter_particles_ground_bbox(means, 171 | kwargs['min_offset'], 172 | kwargs['max_offset']) 173 | 174 | # Hide particles outside the bounding box? 
175 | if view_main_obj_only: 176 | bg_idx = ~bbox_particle_idx 177 | if 'original_opacities' not in self.particle_modification_buffer: 178 | self.particle_modification_buffer['original_opacities'] = opacities.clone() 179 | opacities[bg_idx] = -5 # in-place modification 180 | 181 | if "translation" in editing_dict: 182 | if 'original_means' not in self.particle_modification_buffer: 183 | self.particle_modification_buffer["original_means"] = means.clone() 184 | means[bbox_particle_idx] = means[bbox_particle_idx] + editing_dict["translation"] 185 | 186 | if "rotation" in editing_dict: 187 | if 'original_means' not in self.particle_modification_buffer: 188 | self.particle_modification_buffer["original_means"] = means.clone() 189 | if 'original_quats' not in self.particle_modification_buffer: 190 | self.particle_modification_buffer["original_quats"] = quats.clone() 191 | 192 | rot_mat = editing_dict["rotation"] 193 | 194 | # Rotate x/y/z 195 | selected_pts = means[bbox_particle_idx] 196 | object_center = selected_pts.mean(dim=0) 197 | 198 | selected_pts = selected_pts - object_center 199 | selected_pts = rot_mat @ selected_pts.T 200 | selected_pts = selected_pts.T 201 | selected_pts = selected_pts + object_center 202 | 203 | means[bbox_particle_idx] = selected_pts 204 | 205 | # Rotate covariance 206 | r = quats[bbox_particle_idx] 207 | rot_mat = rot_mat.reshape((1, 3, 3)) # (N, 3, 3) 208 | 209 | r = get_gaussian_rotation(rot_mat.cpu().numpy(), r) 210 | 211 | quats[bbox_particle_idx] = r 212 | 213 | if "physics_sim" in editing_dict and editing_dict["physics_sim"]: 214 | if 'original_means' not in self.particle_modification_buffer: 215 | self.particle_modification_buffer["original_means"] = means.clone() 216 | if not self.meta_editing_dict["physics_sim"]["initialized"]: 217 | selected_particles = means[bbox_particle_idx] 218 | self.initialize_mpm_w_particles(selected_particles.cpu().numpy()) 219 | self.meta_editing_dict["physics_sim"]["initialized"] = True 220 | import time 221 | start_cp = time.time() 222 | particle_positions_np = self.physics_sim_step() 223 | print("Physics sim step time: ", time.time() - start_cp) 224 | assert particle_positions_np.shape[0] == bbox_particle_idx.sum() 225 | means[bbox_particle_idx] = torch.tensor(particle_positions_np).cuda().float() 226 | 227 | # TODO(roger): implement scaling? Below are scaling editing code for INRIA impl 228 | # gaussians._xyz[selected_obj_idx] = gaussians._xyz[selected_obj_idx] / scale 229 | # gaussians._scaling = gaussians.inverse_opacity_activation( 230 | # gaussians.scaling_activation(gaussians._scaling[selected_obj_idx]) / scale 231 | # ) 232 | 233 | @torch.no_grad() 234 | def post_rendering_process(self, means, opacities, quats, scales): 235 | """Inverse function of pre_rendering_process, which reverses 236 | the transformation applied to the particles. 
237 | """ 238 | if 'original_opacities' in self.particle_modification_buffer: 239 | opacities.copy_(self.particle_modification_buffer['original_opacities']) 240 | del self.particle_modification_buffer['original_opacities'] 241 | if 'original_means' in self.particle_modification_buffer: 242 | means.copy_(self.particle_modification_buffer['original_means']) 243 | del self.particle_modification_buffer['original_means'] 244 | if 'original_quats' in self.particle_modification_buffer: 245 | quats.copy_(self.particle_modification_buffer['original_quats']) 246 | del self.particle_modification_buffer['original_quats'] 247 | if 'original_scales' in self.particle_modification_buffer: 248 | scales.copy_(self.particle_modification_buffer['original_scales']) 249 | del self.particle_modification_buffer['original_scales'] 250 | self.editors_lock.release() 251 | 252 | def filter_particles_ground_bbox(self, means, min_offset, max_offset): 253 | ground_R = self.meta_editing_dict['ground_R'] 254 | ground_T = self.meta_editing_dict['ground_T'] 255 | particles = means @ ground_R.T 256 | particles += ground_T 257 | xyz_min = self.meta_editing_dict['xyz_min'] - min_offset 258 | xyz_max = self.meta_editing_dict['xyz_max'] + max_offset 259 | bbox_particles_idx = ((particles > xyz_min) & (particles < xyz_max)).all(dim=1) 260 | return bbox_particles_idx 261 | 262 | def get_gaussian_rotation(rot_mat, r): 263 | # Rotate unnormalized quaternion by rotation matrix, and gives back unnormalized quats 264 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 265 | q = r / norm[:, None] 266 | 267 | R = torch.zeros((q.size(0), 3, 3), device=r.device) 268 | 269 | r = q[:, 0] 270 | x = q[:, 1] 271 | y = q[:, 2] 272 | z = q[:, 3] 273 | 274 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 275 | R[:, 0, 1] = 2 * (x*y - r*z) 276 | R[:, 0, 2] = 2 * (x*z + r*y) 277 | R[:, 1, 0] = 2 * (x*y + r*z) 278 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 279 | R[:, 1, 2] = 2 * (y*z - r*x) 280 | R[:, 2, 0] = 2 * (x*z - r*y) 281 | R[:, 2, 1] = 2 * (y*z + r*x) 282 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 283 | 284 | R = rot_mat @ R.detach().cpu().numpy() 285 | 286 | # Convert back to quaternion 287 | r = Rotation.from_matrix(R).as_quat() 288 | r[:, [0, 1, 2, 3]] = r[:, [3, 0, 1, 2]] # x,y,z,w -> r,x,y,z 289 | r = torch.from_numpy(r).cuda().float() 290 | 291 | r = r * norm[:, None] 292 | return r 293 | 294 | -------------------------------------------------------------------------------- /feature_splatting/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def point_to_plane_distance(point, plane): 4 | """ 5 | Compute distance from point to plane 6 | :param point: point (x, y, z) 7 | :param plane: plane (A, B, C, D) 8 | :return: distance from point to plane 9 | """ 10 | x, y, z = point 11 | A, B, C, D = plane 12 | numerator = np.abs(A*x + B*y + C*z + D) 13 | denominator = np.sqrt(A**2 + B**2 + C**2) 14 | distance = numerator / denominator 15 | return distance 16 | 17 | def vector_angle(vec_a, vec_b): 18 | """ 19 | Calculate angle between two vectors 20 | :param vec_a: vector a 21 | :param vec_b: vector b 22 | :return: angle between two vectors 23 | """ 24 | dot = np.dot(vec_a, vec_b) 25 | norm_a = np.linalg.norm(vec_a) 26 | norm_b = np.linalg.norm(vec_b) 27 | cos_theta = dot / (norm_a * norm_b) 28 | theta = np.arccos(cos_theta) 29 | return theta 30 | -------------------------------------------------------------------------------- 
/feature_splatting/utils/mpm_engine/__init__.py: -------------------------------------------------------------------------------- 1 | from . import mpm_solver 2 | -------------------------------------------------------------------------------- /feature_splatting/utils/mpm_engine/mesh_io.py: -------------------------------------------------------------------------------- 1 | from plyfile import PlyData 2 | import numpy as np 3 | 4 | 5 | def load_mesh(fn, scale=1, offset=(0, 0, 0)): 6 | if isinstance(scale, (int, float)): 7 | scale = (scale, scale, scale) 8 | print(f'loading {fn}') 9 | plydata = PlyData.read(fn) 10 | x = plydata['vertex']['x'] 11 | y = plydata['vertex']['y'] 12 | z = plydata['vertex']['z'] 13 | elements = plydata['face'] 14 | num_tris = len(elements['vertex_indices']) 15 | triangles = np.zeros((num_tris, 9), dtype=np.float32) 16 | 17 | print(f'num vertices: {len(x)}') 18 | 19 | for i, face in enumerate(elements['vertex_indices']): 20 | assert len(face) == 3 21 | for d in range(3): 22 | triangles[i, d * 3 + 0] = x[face[d]] * scale[0] + offset[0] 23 | triangles[i, d * 3 + 1] = y[face[d]] * scale[1] + offset[1] 24 | triangles[i, d * 3 + 2] = z[face[d]] * scale[2] + offset[2] 25 | 26 | return triangles 27 | 28 | 29 | def write_point_cloud(fn, pos_and_color): 30 | num_particles = len(pos_and_color) 31 | with open(fn, 'wb') as f: 32 | header = f"""ply 33 | format binary_little_endian 1.0 34 | comment Created by taichi 35 | element vertex {num_particles} 36 | property float x 37 | property float y 38 | property float z 39 | property uchar red 40 | property uchar green 41 | property uchar blue 42 | property uchar placeholder 43 | end_header 44 | """ 45 | f.write(str.encode(header)) 46 | f.write(pos_and_color.tobytes()) 47 | -------------------------------------------------------------------------------- /feature_splatting/utils/mpm_engine/mpm_solver.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import numpy as np 3 | import time 4 | import numbers 5 | import math 6 | import multiprocessing as mp 7 | 8 | USE_IN_BLENDER = False 9 | 10 | 11 | # TODO: water needs Jp - fix this. 12 | # TODO(roger): remove routines (e.g., add_mesh; flags such as quant) 13 | # where particle rotation and override are not supported. 
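# A minimal driving loop for this solver, mirroring how gaussian_editor uses it
# (the argument values are illustrative, not tuned defaults; requires a
# CUDA-capable taichi install):
#
#   import numpy as np
#   import taichi as ti
#   ti.init(arch=ti.cuda)
#   mpm = MPMSolver(res=(32, 32, 32), size=1, max_num_particles=2**21)
#   mpm.add_particles(particles=np.random.rand(1000, 3) * 0.2 + 0.4,
#                     material=MPMSolver.material_sand, color=0xFFFF00)
#   mpm.add_surface_collider(point=(0.0, 0.05, 0.0), normal=(0, 1, 0),
#                            surface=mpm.surface_sticky)
#   mpm.set_gravity((0, -4, 0))
#   for _ in range(100):
#       mpm.step(4e-3)                               # advance one 4 ms frame
#       positions = mpm.particle_info()['position']  # unit-box particle coords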
14 | 15 | @ti.data_oriented 16 | class MPMSolver: 17 | material_water = 0 18 | material_elastic = 1 19 | material_snow = 2 20 | material_sand = 3 21 | material_stationary = 4 22 | materials = { 23 | 'WATER': material_water, 24 | 'ELASTIC': material_elastic, 25 | 'SNOW': material_snow, 26 | 'SAND': material_sand, 27 | 'STATIONARY': material_stationary, 28 | } 29 | 30 | # Surface boundary conditions 31 | 32 | # Stick to the boundary 33 | surface_sticky = 0 34 | # Slippy boundary 35 | surface_slip = 1 36 | # Slippy and free to separate 37 | surface_separate = 2 38 | 39 | surfaces = { 40 | 'STICKY': surface_sticky, 41 | 'SLIP': surface_slip, 42 | 'SEPARATE': surface_separate 43 | } 44 | 45 | def __init__( 46 | self, 47 | res, 48 | quant=False, 49 | use_voxelizer=True, 50 | size=1, 51 | max_num_particles=2**30, 52 | # Max 1 G particles 53 | padding=3, 54 | unbounded=False, 55 | dt_scale=1, 56 | E_scale=1, 57 | voxelizer_super_sample=2, 58 | use_g2p2g=False, # Ref: A massively parallel and scalable multi-GPU material point method 59 | v_clamp_g2p2g=True, 60 | use_bls=True, 61 | g2p2g_allowed_cfl=0.9, # 0.0 for no CFL limit 62 | water_density=1.0, 63 | support_plasticity=True, # Support snow and sand materials 64 | use_adaptive_dt=False, 65 | use_ggui=False, 66 | use_emitter_id=False, 67 | poisson_ratio=0.2 68 | ): 69 | self.dim = len(res) 70 | self.quant = quant 71 | self.use_g2p2g = use_g2p2g 72 | self.v_clamp_g2p2g = v_clamp_g2p2g 73 | self.use_bls = use_bls 74 | self.g2p2g_allowed_cfl = g2p2g_allowed_cfl 75 | self.water_density = water_density 76 | self.grid_size = 4096 77 | 78 | assert not self.quant, "Particle rotation is not supported in quant mode." 79 | assert self.dim == 3, "Rotation is only supported in 3D." 80 | assert not use_g2p2g, "Particle rotation is not supported in g2p2g mode." 81 | 82 | assert self.dim in ( 83 | 2, 3), "MPM solver supports only 2D and 3D simulations." 
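        # Grid spacing and the default substep both scale with resolution:
        # dx = size / res[0] and default_dt is proportional to dx (a CFL-style
        # constraint), so step() later derives the number of substeps per
        # rendered frame from frame_dt.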
84 | 85 | self.t = 0.0 86 | self.res = res 87 | self.n_particles = ti.field(ti.i32, shape=()) 88 | self.dx = size / res[0] 89 | self.inv_dx = 1.0 / self.dx 90 | self.default_dt = 2e-2 * self.dx / size * dt_scale 91 | self.p_vol = self.dx**self.dim 92 | self.p_rho = 1000 93 | self.p_mass = self.p_vol * self.p_rho 94 | self.max_num_particles = max_num_particles 95 | self.gravity = ti.Vector.field(self.dim, dtype=ti.f32, shape=()) 96 | self.source_bound = ti.Vector.field(self.dim, dtype=ti.f32, shape=2) 97 | self.source_velocity = ti.Vector.field(self.dim, 98 | dtype=ti.f32, 99 | shape=()) 100 | self.input_grid = 0 101 | self.all_time_max_velocity = 0 102 | self.support_plasticity = support_plasticity 103 | self.use_adaptive_dt = use_adaptive_dt 104 | self.use_ggui = use_ggui 105 | self.F_bound = 4.0 106 | 107 | # Affine velocity field 108 | if not self.use_g2p2g: 109 | self.C = ti.Matrix.field(self.dim, self.dim, dtype=ti.f32) 110 | # Deformation gradient 111 | 112 | if quant: 113 | qft = ti.types.quant.fixed(21, max_value=2.0) 114 | self.x = ti.Vector.field(self.dim, dtype=qft) 115 | 116 | qft = ti.types.quant.float(exp=7, frac=19) 117 | self.v = ti.Vector.field(self.dim, dtype=qft) 118 | 119 | qft = ti.types.quant.fixed(16, max_value=(self.F_bound + 0.1)) 120 | self.F = ti.Matrix.field(self.dim, self.dim, dtype=qft) 121 | else: 122 | self.v = ti.Vector.field(self.dim, dtype=ti.f32) 123 | self.x = ti.Vector.field(self.dim, dtype=ti.f32) 124 | self.F = ti.Matrix.field(self.dim, self.dim, dtype=ti.f32) 125 | self.particle_R = ti.Matrix.field(self.dim, self.dim, dtype=ti.f32) 126 | self.particle_motion_override_flag = ti.field(dtype=ti.i32) 127 | 128 | self.use_emitter_id = use_emitter_id 129 | if self.use_emitter_id: 130 | self.emitter_ids = ti.field(dtype=ti.i32) 131 | 132 | self.last_time_final_particles = ti.field(dtype=ti.i32, shape=()) 133 | # Material id 134 | if quant and self.dim == 3: 135 | self.material = ti.field(dtype=ti.types.quant.int(16, False)) 136 | else: 137 | self.material = ti.field(dtype=ti.i32) 138 | # Particle color 139 | self.color = ti.field(dtype=ti.i32) 140 | if self.use_ggui: 141 | self.color_with_alpha = ti.Vector.field(4, dtype=ti.f32) 142 | # Plastic deformation volume ratio 143 | if self.support_plasticity: 144 | self.Jp = ti.field(dtype=ti.f32) 145 | 146 | if self.dim == 2: 147 | indices = ti.ij 148 | else: 149 | indices = ti.ijk 150 | 151 | if unbounded: 152 | # The maximum grid size must be larger than twice of 153 | # simulation resolution in an unbounded simulation, 154 | # Otherwise the top and right sides will be bounded by grid size 155 | while self.grid_size <= 2 * max(self.res): 156 | self.grid_size *= 2 # keep it power of two 157 | offset = tuple(-self.grid_size // 2 for _ in range(self.dim)) 158 | self.offset = offset 159 | 160 | self.num_grids = 2 if self.use_g2p2g else 1 161 | 162 | grid_block_size = 128 163 | if self.dim == 2: 164 | self.leaf_block_size = 16 165 | else: 166 | # TODO: use 8? 
167 | self.leaf_block_size = 4 168 | 169 | self.grid = [] 170 | self.grid_v = [] 171 | self.grid_m = [] 172 | self.pid = [] 173 | 174 | for g in range(self.num_grids): 175 | # Grid node momentum/velocity 176 | grid_v = ti.Vector.field(self.dim, dtype=ti.f32) 177 | grid_m = ti.field(dtype=ti.f32) 178 | pid = ti.field(ti.i32) 179 | self.grid_v.append(grid_v) 180 | # Grid node mass 181 | self.grid_m.append(grid_m) 182 | grid = ti.root.pointer(indices, self.grid_size // grid_block_size) 183 | block = grid.pointer(indices, 184 | grid_block_size // self.leaf_block_size) 185 | self.block = block 186 | self.grid.append(grid) 187 | 188 | def block_component(c): 189 | block.dense(indices, self.leaf_block_size).place(c, 190 | offset=offset) 191 | 192 | block_component(grid_m) 193 | for d in range(self.dim): 194 | block_component(grid_v.get_scalar_field(d)) 195 | 196 | self.pid.append(pid) 197 | 198 | block_offset = tuple(o // self.leaf_block_size 199 | for o in self.offset) 200 | self.block_offset = block_offset 201 | block.dynamic(ti.axes(self.dim), 202 | 1024 * 1024, 203 | chunk_size=self.leaf_block_size**self.dim * 8).place( 204 | pid, offset=block_offset + (0, )) 205 | 206 | self.padding = padding 207 | 208 | # Young's modulus and Poisson's ratio 209 | self.E, self.nu = 1e6 * size * E_scale, poisson_ratio 210 | # Lame parameters 211 | self.mu_0, self.lambda_0 = self.E / ( 212 | 2 * (1 + self.nu)), self.E * self.nu / ((1 + self.nu) * 213 | (1 - 2 * self.nu)) 214 | 215 | # Sand parameters 216 | friction_angle = math.radians(45) 217 | sin_phi = math.sin(friction_angle) 218 | self.alpha = math.sqrt(2 / 3) * 2 * sin_phi / (3 - sin_phi) 219 | 220 | # An empirically optimal chunk size is 1/10 of the expected particle number 221 | chunk_size = 2**20 if self.dim == 2 else 2**23 222 | 223 | # https://docs.taichi-lang.org/docs/sparse#dynamic-snode 224 | self.particle = ti.root.dynamic(ti.i, max_num_particles, chunk_size) 225 | 226 | if self.quant: 227 | if not self.use_g2p2g: 228 | self.particle.place(self.C) 229 | if self.support_plasticity: 230 | self.particle.place(self.Jp) 231 | bitpack = ti.BitpackedFields(max_num_bits=64) 232 | bitpack.place(self.x) 233 | self.particle.place(bitpack) 234 | bitpack = ti.BitpackedFields(max_num_bits=64) 235 | bitpack.place(self.v, shared_exponent=True) 236 | self.particle.place(bitpack) 237 | 238 | if self.dim == 3: 239 | bitpack = ti.BitpackedFields(max_num_bits=32) 240 | bitpack.place(self.F.get_scalar_field(0, 0), 241 | self.F.get_scalar_field(0, 1)) 242 | self.particle.place(bitpack) 243 | bitpack = ti.BitpackedFields(max_num_bits=32) 244 | bitpack.place(self.F.get_scalar_field(0, 2), 245 | self.F.get_scalar_field(1, 0)) 246 | self.particle.place(bitpack) 247 | bitpack = ti.BitpackedFields(max_num_bits=32) 248 | bitpack.place(self.F.get_scalar_field(1, 1), 249 | self.F.get_scalar_field(1, 2)) 250 | self.particle.place(bitpack) 251 | bitpack = ti.BitpackedFields(max_num_bits=32) 252 | bitpack.place(self.F.get_scalar_field(2, 0), 253 | self.F.get_scalar_field(2, 1)) 254 | self.particle.place(bitpack) 255 | bitpack = ti.BitpackedFields(max_num_bits=32) 256 | bitpack.place(self.F.get_scalar_field(2, 2), self.material) 257 | self.particle.place(bitpack) 258 | else: 259 | assert self.dim == 2 260 | bitpack = ti.BitpackedFields(max_num_bits=32) 261 | bitpack.place(self.F.get_scalar_field(0, 0), 262 | self.F.get_scalar_field(0, 1)) 263 | self.particle.place(bitpack) 264 | bitpack = ti.BitpackedFields(max_num_bits=32) 265 | bitpack.place(self.F.get_scalar_field(1, 0), 266 | 
self.F.get_scalar_field(1, 1)) 267 | self.particle.place(bitpack) 268 | # No quantization on particle material in 2D 269 | self.particle.place(self.material) 270 | self.particle.place(self.color) 271 | if self.use_emitter_id: 272 | self.particle.place(self.emitter_ids) 273 | else: 274 | if self.use_emitter_id: 275 | self.particle.place(self.x, self.v, self.F, self.material, 276 | self.color, self.emitter_ids, self.particle_R, 277 | self.particle_motion_override_flag) 278 | else: 279 | self.particle.place(self.x, self.v, self.F, self.material, 280 | self.color, self.particle_R, self.particle_motion_override_flag) 281 | if self.support_plasticity: 282 | self.particle.place(self.Jp) 283 | if not self.use_g2p2g: 284 | self.particle.place(self.C) 285 | 286 | if self.use_ggui: 287 | self.particle.place(self.color_with_alpha) 288 | 289 | self.total_substeps = 0 290 | self.unbounded = unbounded 291 | 292 | if self.dim == 2: 293 | self.voxelizer = None 294 | self.set_gravity((0, -9.8)) 295 | else: 296 | if use_voxelizer: 297 | from .voxelizer import Voxelizer 298 | self.voxelizer = Voxelizer(res=self.res, 299 | dx=self.dx, 300 | padding=self.padding, 301 | super_sample=voxelizer_super_sample) 302 | else: 303 | self.voxelizer = None 304 | self.set_gravity((0, -9.8, 0)) 305 | 306 | self.voxelizer_super_sample = voxelizer_super_sample 307 | 308 | self.grid_postprocess = [] 309 | 310 | self.add_bounding_box(self.unbounded) 311 | 312 | self.writers = [] 313 | 314 | if not self.use_g2p2g: 315 | self.grid = self.grid[0] 316 | self.grid_v = self.grid_v[0] 317 | self.grid_m = self.grid_m[0] 318 | self.pid = self.pid[0] 319 | 320 | @ti.func 321 | def stencil_range(self): 322 | return ti.ndrange(*((3, ) * self.dim)) 323 | 324 | def set_gravity(self, g): 325 | assert isinstance(g, (tuple, list)) 326 | assert len(g) == self.dim 327 | self.gravity[None] = g 328 | 329 | @ti.func 330 | def sand_projection(self, sigma, p): 331 | sigma_out = ti.Matrix.zero(ti.f32, self.dim, self.dim) 332 | epsilon = ti.Vector.zero(ti.f32, self.dim) 333 | for i in ti.static(range(self.dim)): 334 | epsilon[i] = ti.log(max(abs(sigma[i, i]), 1e-4)) 335 | sigma_out[i, i] = 1 336 | tr = epsilon.sum() + self.Jp[p] 337 | epsilon_hat = epsilon - tr / self.dim 338 | epsilon_hat_norm = epsilon_hat.norm() + 1e-20 339 | if tr >= 0.0: 340 | self.Jp[p] = tr 341 | else: 342 | self.Jp[p] = 0.0 343 | delta_gamma = epsilon_hat_norm + ( 344 | self.dim * self.lambda_0 + 345 | 2 * self.mu_0) / (2 * self.mu_0) * tr * self.alpha 346 | for i in ti.static(range(self.dim)): 347 | sigma_out[i, i] = ti.exp(epsilon[i] - max(0, delta_gamma) / 348 | epsilon_hat_norm * epsilon_hat[i]) 349 | 350 | return sigma_out 351 | 352 | @ti.kernel 353 | def build_pid(self, pid: ti.template(), grid_m: ti.template(), 354 | offset: ti.template()): 355 | """ 356 | grid has blocking (e.g. 
4x4x4), we wish to put the particles from each block into a GPU block, 357 | then used shared memory (ti.block_local) to accelerate 358 | :param pid: 359 | :param grid_m: 360 | :param offset: 361 | :return: 362 | """ 363 | ti.loop_config(block_dim=64) 364 | for p in self.x: 365 | base = int(ti.floor(self.x[p] * self.inv_dx - 0.5)) \ 366 | - ti.Vector(self.offset) 367 | # Pid grandparent is `block` 368 | base_pid = ti.rescale_index(grid_m, pid.parent(2), base) 369 | ti.append(pid.parent(), base_pid, p) 370 | 371 | @ti.kernel 372 | def g2p2g(self, dt: ti.f32, pid: ti.template(), grid_v_in: ti.template(), 373 | grid_v_out: ti.template(), grid_m_out: ti.template()): 374 | ti.loop_config(block_dim=256) 375 | ti.no_activate(self.particle) 376 | if ti.static(self.use_bls): 377 | ti.block_local(grid_m_out) 378 | for d in ti.static(range(self.dim)): 379 | ti.block_local(grid_v_in.get_scalar_field(d)) 380 | ti.block_local(grid_v_out.get_scalar_field(d)) 381 | for I in ti.grouped(pid): 382 | p = pid[I] 383 | # G2P 384 | base = ti.floor(self.x[p] * self.inv_dx - 0.5).cast(int) 385 | Im = ti.rescale_index(pid, grid_m_out, I) 386 | for D in ti.static(range(self.dim)): 387 | base[D] = ti.assume_in_range(base[D], Im[D], 0, 1) 388 | fx = self.x[p] * self.inv_dx - base.cast(float) 389 | w = [ 390 | 0.5 * (1.5 - fx)**2, 0.75 - (fx - 1.0)**2, 0.5 * (fx - 0.5)**2 391 | ] 392 | new_v = ti.Vector.zero(ti.f32, self.dim) 393 | C = ti.Matrix.zero(ti.f32, self.dim, self.dim) 394 | # Loop over 3x3 grid node neighborhood 395 | for offset in ti.static(ti.grouped(self.stencil_range())): 396 | dpos = offset.cast(float) - fx 397 | g_v = grid_v_in[base + offset] 398 | weight = 1.0 399 | for d in ti.static(range(self.dim)): 400 | weight *= w[offset[d]][d] 401 | new_v += weight * g_v 402 | C += 4 * self.inv_dx * weight * g_v.outer_product(dpos) 403 | 404 | if p >= self.last_time_final_particles[None]: 405 | # New particles. No G2P. 406 | new_v = self.v[p] 407 | C = ti.Matrix.zero(ti.f32, self.dim, self.dim) 408 | 409 | if self.material[p] != self.material_stationary: 410 | self.v[p] = new_v 411 | self.x[p] += dt * self.v[p] # advection 412 | 413 | # P2G 414 | base = ti.floor(self.x[p] * self.inv_dx - 0.5).cast(int) 415 | for D in ti.static(range(self.dim)): 416 | base[D] = ti.assume_in_range(base[D], Im[D], -1, 2) 417 | 418 | fx = self.x[p] * self.inv_dx - float(base) 419 | # Quadratic kernels [http://mpm.graphics Eqn. 
123, with x=fx, fx-1,fx-2] 420 | w2 = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] 421 | # Deformation gradient update 422 | new_F = (ti.Matrix.identity(ti.f32, self.dim) + dt * C) @ self.F[p] 423 | if ti.static(self.quant): 424 | new_F = max(-self.F_bound, min(self.F_bound, new_F)) 425 | self.F[p] = new_F 426 | # Hardening coefficient: snow gets harder when compressed 427 | h = 1.0 428 | if ti.static(self.support_plasticity): 429 | h = ti.exp(10 * (1.0 - self.Jp[p])) 430 | if self.material[ 431 | p] == self.material_elastic: # Jelly, make it softer 432 | h = 0.3 433 | mu, la = self.mu_0 * h, self.lambda_0 * h 434 | if self.material[p] == self.material_water: # Liquid 435 | mu = 0.0 436 | U, sig, V = ti.svd(self.F[p]) 437 | J = 1.0 438 | if self.material[p] != self.material_sand: 439 | for d in ti.static(range(self.dim)): 440 | new_sig = sig[d, d] 441 | if self.material[p] == self.material_snow: # Snow 442 | new_sig = min(max(sig[d, d], 1 - 2.5e-2), 443 | 1 + 4.5e-3) # Plasticity 444 | if ti.static(self.support_plasticity): 445 | self.Jp[p] *= sig[d, d] / new_sig 446 | sig[d, d] = new_sig 447 | J *= new_sig 448 | if self.material[p] == self.material_water: 449 | # Reset deformation gradient to avoid numerical instability 450 | new_F = ti.Matrix.identity(ti.f32, self.dim) 451 | new_F[0, 0] = J 452 | self.F[p] = new_F 453 | elif self.material[p] == self.material_snow: 454 | # Reconstruct elastic deformation gradient after plasticity 455 | self.F[p] = U @ sig @ V.transpose() 456 | 457 | stress = ti.Matrix.zero(ti.f32, self.dim, self.dim) 458 | 459 | if self.material[p] != self.material_sand: 460 | stress = 2 * mu * ( 461 | self.F[p] - U @ V.transpose()) @ self.F[p].transpose( 462 | ) + ti.Matrix.identity(ti.f32, self.dim) * la * J * (J - 1) 463 | else: 464 | if ti.static(self.support_plasticity): 465 | sig = self.sand_projection(sig, p) 466 | self.F[p] = U @ sig @ V.transpose() 467 | log_sig_sum = 0.0 468 | center = ti.Matrix.zero(ti.f32, self.dim, self.dim) 469 | for i in ti.static(range(self.dim)): 470 | log_sig_sum += ti.log(sig[i, i]) 471 | center[i, i] = 2.0 * self.mu_0 * ti.log( 472 | sig[i, i]) * (1 / sig[i, i]) 473 | for i in ti.static(range(self.dim)): 474 | center[i, 475 | i] += self.lambda_0 * log_sig_sum * (1 / 476 | sig[i, i]) 477 | stress = U @ center @ V.transpose() @ self.F[p].transpose() 478 | 479 | stress = (-dt * self.p_vol * 4 * self.inv_dx**2) * stress 480 | affine = stress + self.p_mass * C 481 | 482 | # Loop over 3x3 grid node neighborhood 483 | for offset in ti.static(ti.grouped(self.stencil_range())): 484 | dpos = (offset.cast(float) - fx) * self.dx 485 | weight = 1.0 486 | for d in ti.static(range(self.dim)): 487 | weight *= w2[offset[d]][d] 488 | grid_v_out[base + 489 | offset] += weight * (self.p_mass * self.v[p] + 490 | affine @ dpos) 491 | grid_m_out[base + offset] += weight * self.p_mass 492 | 493 | self.last_time_final_particles[None] = self.n_particles[None] 494 | 495 | @ti.kernel 496 | def p2g(self, dt: ti.f32): 497 | ti.no_activate(self.particle) 498 | ti.loop_config(block_dim=256) 499 | if ti.static(self.use_bls): 500 | for d in ti.static(range(self.dim)): 501 | ti.block_local(self.grid_v.get_scalar_field(d)) 502 | ti.block_local(self.grid_m) 503 | for I in ti.grouped(self.pid): 504 | p = self.pid[I] 505 | base = ti.floor(self.x[p] * self.inv_dx - 0.5).cast(int) 506 | Im = ti.rescale_index(self.pid, self.grid_m, I) 507 | for D in ti.static(range(self.dim)): 508 | # For block shared memory: hint compiler that there is a connection between 
`base` and loop index `I` 509 | base[D] = ti.assume_in_range(base[D], Im[D], 0, 1) 510 | 511 | fx = self.x[p] * self.inv_dx - base.cast(float) 512 | # Quadratic kernels [http://mpm.graphics Eqn. 123, with x=fx, fx-1,fx-2] 513 | w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] 514 | # Deformation gradient update 515 | F = self.F[p] 516 | if self.material[p] == self.material_water: # liquid 517 | F = ti.Matrix.identity(ti.f32, self.dim) 518 | if ti.static(self.support_plasticity): 519 | F[0, 0] = self.Jp[p] 520 | 521 | F = (ti.Matrix.identity(ti.f32, self.dim) + dt * self.C[p]) @ F 522 | # Hardening coefficient: snow gets harder when compressed 523 | h = 1.0 524 | if ti.static(self.support_plasticity): 525 | if self.material[p] != self.material_water: 526 | h = ti.exp(10 * (1.0 - self.Jp[p])) 527 | if self.material[ 528 | p] == self.material_elastic: # jelly, make it softer 529 | h = 0.3 530 | mu, la = self.mu_0 * h, self.lambda_0 * h 531 | if self.material[p] == self.material_water: # liquid 532 | mu = 0.0 533 | U, sig, V = ti.svd(F) 534 | J = 1.0 535 | if self.material[p] != self.material_sand: 536 | for d in ti.static(range(self.dim)): 537 | new_sig = sig[d, d] 538 | if self.material[p] == self.material_snow: # Snow 539 | new_sig = min(max(sig[d, d], 1 - 2.5e-2), 540 | 1 + 4.5e-3) # Plasticity 541 | if ti.static(self.support_plasticity): 542 | self.Jp[p] *= sig[d, d] / new_sig 543 | sig[d, d] = new_sig 544 | J *= new_sig 545 | if self.material[p] == self.material_water: 546 | # Reset deformation gradient to avoid numerical instability 547 | F = ti.Matrix.identity(ti.f32, self.dim) 548 | F[0, 0] = J 549 | if ti.static(self.support_plasticity): 550 | self.Jp[p] = J 551 | elif self.material[p] == self.material_snow: 552 | # Reconstruct elastic deformation gradient after plasticity 553 | F = U @ sig @ V.transpose() 554 | 555 | stress = ti.Matrix.zero(ti.f32, self.dim, self.dim) 556 | 557 | if self.material[p] != self.material_sand: 558 | stress = 2 * mu * (F - U @ V.transpose()) @ F.transpose( 559 | ) + ti.Matrix.identity(ti.f32, self.dim) * la * J * (J - 1) 560 | else: 561 | if ti.static(self.support_plasticity): 562 | sig = self.sand_projection(sig, p) 563 | F = U @ sig @ V.transpose() 564 | log_sig_sum = 0.0 565 | center = ti.Matrix.zero(ti.f32, self.dim, self.dim) 566 | for i in ti.static(range(self.dim)): 567 | log_sig_sum += ti.log(sig[i, i]) 568 | center[i, i] = 2.0 * self.mu_0 * ti.log( 569 | sig[i, i]) * (1 / sig[i, i]) 570 | for i in ti.static(range(self.dim)): 571 | center[i, 572 | i] += self.lambda_0 * log_sig_sum * (1 / 573 | sig[i, i]) 574 | stress = U @ center @ V.transpose() @ F.transpose() 575 | self.F[p] = F 576 | 577 | # Compute particle rotation 578 | if ti.math.determinant(U) < 0.0: 579 | U[0, 2] = -U[0, 2] 580 | U[1, 2] = -U[1, 2] 581 | U[2, 2] = -U[2, 2] 582 | 583 | if ti.math.determinant(V) < 0.0: 584 | V[0, 2] = -V[0, 2] 585 | V[1, 2] = -V[1, 2] 586 | V[2, 2] = -V[2, 2] 587 | 588 | R = U @ V.transpose() 589 | self.particle_R[p] = R.transpose() 590 | 591 | stress = (-dt * self.p_vol * 4 * self.inv_dx**2) * stress 592 | # TODO: implement g2p2g pmass 593 | mass = self.p_mass 594 | if self.material[p] == self.material_water: 595 | mass *= self.water_density 596 | affine = stress + mass * self.C[p] 597 | 598 | # Loop over 3x3 grid node neighborhood 599 | for offset in ti.static(ti.grouped(self.stencil_range())): 600 | dpos = (offset.cast(float) - fx) * self.dx 601 | weight = 1.0 602 | for d in ti.static(range(self.dim)): 603 | weight *= w[offset[d]][d] 
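                # APIC/MLS-MPM scatter: deposit momentum (mass * velocity plus the
                # affine term, which folds the stress contribution into
                # `affine = stress + mass * C`) and mass onto each of the 3^dim
                # neighboring grid nodes, weighted by the quadratic B-spline kernel
                # accumulated in `weight`.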
604 | self.grid_v[base + offset] += weight * (mass * self.v[p] + 605 | affine @ dpos) 606 | self.grid_m[base + offset] += weight * mass 607 | 608 | @ti.kernel 609 | def grid_normalization_and_gravity(self, dt: ti.f32, grid_v: ti.template(), 610 | grid_m: ti.template()): 611 | v_allowed = self.dx * self.g2p2g_allowed_cfl / dt 612 | for I in ti.grouped(grid_m): 613 | if grid_m[I] > 0: # No need for epsilon here 614 | grid_v[I] = (1 / grid_m[I]) * grid_v[I] # Momentum to velocity 615 | grid_v[I] += dt * self.gravity[None] 616 | 617 | # Grid velocity clamping 618 | if ti.static(self.g2p2g_allowed_cfl > 0 and self.use_g2p2g 619 | and self.v_clamp_g2p2g): 620 | grid_v[I] = min(max(grid_v[I], -v_allowed), v_allowed) 621 | 622 | @ti.kernel 623 | def grid_bounding_box(self, t: ti.f32, dt: ti.f32, 624 | unbounded: ti.template(), grid_v: ti.template()): 625 | for I in ti.grouped(grid_v): 626 | for d in ti.static(range(self.dim)): 627 | if ti.static(unbounded): 628 | if I[d] < -self.grid_size // 2 + self.padding and grid_v[ 629 | I][d] < 0: 630 | grid_v[I][d] = 0 # Boundary conditions 631 | if I[d] >= self.grid_size // 2 - self.padding and grid_v[ 632 | I][d] > 0: 633 | grid_v[I][d] = 0 634 | else: 635 | if I[d] < self.padding and grid_v[I][d] < 0: 636 | grid_v[I][d] = 0 # Boundary conditions 637 | if I[d] >= self.res[d] - self.padding and grid_v[I][d] > 0: 638 | grid_v[I][d] = 0 639 | 640 | def add_sphere_collider(self, center, radius, surface=surface_sticky): 641 | center = list(center) 642 | 643 | @ti.kernel 644 | def collide(t: ti.f32, dt: ti.f32, grid_v: ti.template()): 645 | for I in ti.grouped(grid_v): 646 | offset = I * self.dx - ti.Vector(center) 647 | if offset.norm_sqr() < radius * radius: 648 | if ti.static(surface == self.surface_sticky): 649 | grid_v[I] = ti.Vector.zero(ti.f32, self.dim) 650 | else: 651 | v = grid_v[I] 652 | normal = offset.normalized(1e-5) 653 | normal_component = normal.dot(v) 654 | 655 | if ti.static(surface == self.surface_slip): 656 | # Project out all normal component 657 | v = v - normal * normal_component 658 | else: 659 | # Project out only inward normal component 660 | v = v - normal * min(normal_component, 0) 661 | 662 | grid_v[I] = v 663 | 664 | self.grid_postprocess.append(collide) 665 | 666 | def clear_grid_postprocess(self): 667 | self.grid_postprocess.clear() 668 | 669 | def add_surface_collider(self, 670 | point, 671 | normal, 672 | surface=surface_sticky, 673 | friction=0.0): 674 | point = list(point) 675 | # Normalize normal 676 | normal_scale = 1.0 / math.sqrt(sum(x**2 for x in normal)) 677 | normal = list(normal_scale * x for x in normal) 678 | 679 | if surface == self.surface_sticky and friction != 0: 680 | raise ValueError('friction must be 0 on sticky surfaces.') 681 | 682 | @ti.kernel 683 | def collide(t: ti.f32, dt: ti.f32, grid_v: ti.template()): 684 | for I in ti.grouped(grid_v): 685 | offset = I * self.dx - ti.Vector(point) 686 | n = ti.Vector(normal) 687 | if offset.dot(n) < 0: 688 | if ti.static(surface == self.surface_sticky): 689 | grid_v[I] = ti.Vector.zero(ti.f32, self.dim) 690 | else: 691 | v = grid_v[I] 692 | normal_component = n.dot(v) 693 | 694 | if ti.static(surface == self.surface_slip): 695 | # Project out all normal component 696 | v = v - n * normal_component 697 | else: 698 | # Project out only inward normal component 699 | v = v - n * min(normal_component, 0) 700 | 701 | if normal_component < 0 and v.norm() > 1e-30: 702 | # Apply friction here 703 | v = v.normalized() * max( 704 | 0, 705 | v.norm() + normal_component * 
friction) 706 | 707 | grid_v[I] = v 708 | 709 | self.grid_postprocess.append(collide) 710 | 711 | def add_bounding_box(self, unbounded): 712 | self.grid_postprocess.append( 713 | lambda t, dt, grid_v: self.grid_bounding_box( 714 | t, dt, unbounded, grid_v)) 715 | 716 | @ti.kernel 717 | def g2p(self, dt: ti.f32): 718 | ti.loop_config(block_dim=256) 719 | if ti.static(self.use_bls): 720 | for d in ti.static(range(self.dim)): 721 | ti.block_local(self.grid_v.get_scalar_field(d)) 722 | ti.no_activate(self.particle) 723 | for I in ti.grouped(self.pid): 724 | p = self.pid[I] 725 | base = ti.floor(self.x[p] * self.inv_dx - 0.5).cast(int) 726 | Im = ti.rescale_index(self.pid, self.grid_m, I) 727 | for D in ti.static(range(self.dim)): 728 | base[D] = ti.assume_in_range(base[D], Im[D], 0, 1) 729 | fx = self.x[p] * self.inv_dx - base.cast(float) 730 | w = [ 731 | 0.5 * (1.5 - fx)**2, 0.75 - (fx - 1.0)**2, 0.5 * (fx - 0.5)**2 732 | ] 733 | new_v = ti.Vector.zero(ti.f32, self.dim) 734 | new_C = ti.Matrix.zero(ti.f32, self.dim, self.dim) 735 | # Loop over 3x3 grid node neighborhood 736 | for offset in ti.static(ti.grouped(self.stencil_range())): 737 | dpos = offset.cast(float) - fx 738 | g_v = self.grid_v[base + offset] 739 | weight = 1.0 740 | for d in ti.static(range(self.dim)): 741 | weight *= w[offset[d]][d] 742 | new_v += weight * g_v 743 | new_C += 4 * self.inv_dx * weight * g_v.outer_product(dpos) 744 | if self.material[p] != self.material_stationary: 745 | self.v[p], self.C[p] = new_v, new_C 746 | 747 | @ti.kernel 748 | def apply_v_to_pos(self, dt: ti.f32): 749 | for I in ti.grouped(self.pid): 750 | p = self.pid[I] 751 | if self.material[p] != self.material_stationary: 752 | self.x[p] += dt * self.v[p] # advection 753 | 754 | @ti.kernel 755 | def particle_motion_override(self, dt: ti.f32, v_x: ti.f32, v_y: ti.f32, v_z: ti.f32): 756 | for I in ti.grouped(self.pid): 757 | p = self.pid[I] 758 | if self.particle_motion_override_flag[p] == 1: 759 | self.v[p][0] = v_x 760 | self.v[p][1] = v_y 761 | self.v[p][2] = v_z 762 | 763 | @ti.kernel 764 | def compute_max_velocity(self) -> ti.f32: 765 | max_velocity = 0.0 766 | for p in self.v: 767 | v = self.v[p] 768 | v_max = 0.0 769 | for i in ti.static(range(self.dim)): 770 | v_max = max(v_max, abs(v[i])) 771 | ti.atomic_max(max_velocity, v_max) 772 | return max_velocity 773 | 774 | @ti.kernel 775 | def compute_max_grid_velocity(self, grid_v: ti.template()) -> ti.f32: 776 | max_velocity = 0.0 777 | for I in ti.grouped(grid_v): 778 | v = grid_v[I] 779 | v_max = 0.0 780 | for i in ti.static(range(self.dim)): 781 | v_max = max(v_max, abs(v[i])) 782 | ti.atomic_max(max_velocity, v_max) 783 | return max_velocity 784 | 785 | def step(self, frame_dt, print_stat=False, smry_writer=None, override_velocity=None): 786 | begin_t = time.time() 787 | begin_substep = self.total_substeps 788 | 789 | substeps = int(frame_dt / self.default_dt) + 1 790 | 791 | dt = frame_dt / substeps 792 | frame_time_left = frame_dt 793 | if print_stat: 794 | print(f'needed substeps: {substeps}') 795 | 796 | if override_velocity is None: 797 | override_velocity = [0, 0, 0] 798 | 799 | while frame_time_left > 0: 800 | if print_stat: 801 | print('.', end='', flush=True) 802 | self.total_substeps += 1 803 | if self.use_adaptive_dt: 804 | if self.use_g2p2g: 805 | max_grid_v = self.compute_max_grid_velocity( 806 | self.grid_v[self.input_grid]) 807 | else: 808 | max_grid_v = self.compute_max_grid_velocity( 809 | self.grid_v) 810 | cfl_dt = self.g2p2g_allowed_cfl * self.dx / (max_grid_v + 1e-6) 811 
| dt = min(dt, cfl_dt, frame_time_left)
812 |             frame_time_left -= dt
813 | 
814 |             if self.use_g2p2g:
815 |                 output_grid = 1 - self.input_grid
816 |                 self.grid[output_grid].deactivate_all()
817 |                 self.build_pid(self.pid[self.input_grid],
818 |                                self.grid_m[self.input_grid], 0.5)
819 |                 self.g2p2g(dt, self.pid[self.input_grid],
820 |                            self.grid_v[self.input_grid],
821 |                            self.grid_v[output_grid], self.grid_m[output_grid])
822 |                 self.grid_normalization_and_gravity(dt,
823 |                                                     self.grid_v[output_grid],
824 |                                                     self.grid_m[output_grid])
825 |                 for p in self.grid_postprocess:
826 |                     p(self.t, dt, self.grid_v[output_grid])
827 |                 self.input_grid = output_grid
828 |                 self.t += dt
829 |             else:
830 |                 self.grid.deactivate_all()
831 |                 self.build_pid(self.pid, self.grid_m, 0.5)
832 |                 self.p2g(dt)
833 |                 self.grid_normalization_and_gravity(dt, self.grid_v,
834 |                                                     self.grid_m)
835 |                 for p in self.grid_postprocess:
836 |                     p(self.t, dt, self.grid_v)
837 |                 self.t += dt
838 |                 self.g2p(dt)
839 |                 self.particle_motion_override(dt, override_velocity[0], override_velocity[1], override_velocity[2])
840 |                 self.apply_v_to_pos(dt)
841 | 
842 |             cur_frame_velocity = self.compute_max_velocity()
843 |             if smry_writer is not None:
844 |                 smry_writer.add_scalar("substep_max_CFL",
845 |                                        cur_frame_velocity * dt / self.dx,
846 |                                        self.total_substeps)
847 |             self.all_time_max_velocity = max(self.all_time_max_velocity,
848 |                                              cur_frame_velocity)
849 | 
850 |         if print_stat:
851 |             print()
852 |             ti.profiler.print_kernel_profiler_info()
853 |             try:
854 |                 ti.profiler.print_memory_profiler_info()
855 |             except Exception:
856 |                 pass
857 |             cur_frame_velocity = self.compute_max_velocity()
858 |             print(f'CFL: {cur_frame_velocity * dt / self.dx}')
859 |             print(f'num particles={self.n_particles[None]}')
860 |             print(f'  frame time {time.time() - begin_t:.3f} s')
861 |             print(
862 |                 f'  substep time {1000 * (time.time() - begin_t) / (self.total_substeps - begin_substep):.3f} ms'
863 |             )
864 | 
865 |     @ti.func
866 |     def seed_particle(self, i, x, material, color, velocity, emitter_id, motion_override_flag):
867 |         self.x[i] = x
868 |         self.v[i] = velocity
869 |         self.F[i] = ti.Matrix.identity(ti.f32, self.dim)
870 |         self.particle_R[i] = ti.Matrix.identity(ti.f32, self.dim)
871 |         self.particle_motion_override_flag[i] = motion_override_flag
872 |         self.color[i] = color
873 |         self.material[i] = material
874 | 
875 |         if ti.static(self.support_plasticity):
876 |             if material == self.material_sand:
877 |                 self.Jp[i] = 0
878 |             else:
879 |                 self.Jp[i] = 1
880 | 
881 |         if ti.static(self.use_emitter_id):
882 |             self.emitter_ids[i] = emitter_id
883 | 
884 |     @ti.kernel
885 |     def seed(self, new_particles: ti.i32, new_material: ti.i32, color: ti.i32):
886 |         for i in range(self.n_particles[None],
887 |                        self.n_particles[None] + new_particles):
888 |             self.material[i] = new_material
889 |             x = ti.Vector.zero(ti.f32, self.dim)
890 |             for k in ti.static(range(self.dim)):
891 |                 x[k] = self.source_bound[0][k] + ti.random(
) * self.source_bound[1][k]
893 |             self.seed_particle(i, x, new_material, color,
894 |                                self.source_velocity[None], None, 0)
895 | 
896 |     def set_source_velocity(self, velocity):
897 |         if velocity is not None:
898 |             velocity = list(velocity)
899 |             assert len(velocity) == self.dim
900 |             self.source_velocity[None] = velocity
901 |         else:
902 |             for i in range(self.dim):
903 |                 self.source_velocity[None][i] = 0
904 | 
905 |     def add_cube(self,
906 |                  lower_corner,
907 |                  cube_size,
908 |                  material,
909 |                  color=0xFFFFFF,
910 |                  sample_density=None,
911 |                  velocity=None):
912 |         if sample_density is None:
913 | 
sample_density = 2**self.dim
914 |         vol = 1
915 |         for i in range(self.dim):
916 |             vol = vol * cube_size[i]
917 |         num_new_particles = int(sample_density * vol / self.dx**self.dim + 1)
918 |         assert self.n_particles[
919 |             None] + num_new_particles <= self.max_num_particles
920 | 
921 |         for i in range(self.dim):
922 |             self.source_bound[0][i] = lower_corner[i]
923 |             self.source_bound[1][i] = cube_size[i]
924 | 
925 |         self.set_source_velocity(velocity=velocity)
926 | 
927 |         self.seed(num_new_particles, material, color)
928 |         self.n_particles[None] += num_new_particles
929 | 
930 |     def add_ngon(
931 |             self,
932 |             sides,
933 |             center,
934 |             radius,
935 |             angle,
936 |             material,
937 |             color=0xFFFFFF,
938 |             sample_density=None,
939 |             velocity=None,
940 |     ):
941 |         if self.dim != 2:
942 |             raise ValueError("Add Ngon only works for 2D simulations")
943 | 
944 |         if sample_density is None:
945 |             sample_density = 2**self.dim
946 | 
947 |         num_particles = 0.5 * (radius * self.inv_dx)**2 * math.sin(
948 |             2 * math.pi / sides) * sides
949 | 
950 |         num_particles = int(math.ceil(num_particles * sample_density))
951 | 
952 |         self.source_bound[0] = center
953 |         self.source_bound[1] = [radius, radius]
954 | 
955 |         self.set_source_velocity(velocity=velocity)
956 | 
957 |         assert self.n_particles[None] + num_particles <= self.max_num_particles
958 | 
959 |         self.seed_polygon(num_particles, sides, angle, material, color)
960 |         self.n_particles[None] += num_particles
961 | 
962 |     @ti.func
963 |     def random_point_in_unit_polygon(self, sides, angle):
964 |         point = ti.Vector.zero(ti.f32, 2)
965 |         central_angle = 2 * math.pi / sides
966 |         while True:
967 |             point = ti.Vector([ti.random(), ti.random()]) * 2 - 1
968 |             point_angle = ti.atan2(point.y, point.x)
969 |             theta = (point_angle -
970 |                      angle) % central_angle  # polygon angle is from +X axis
971 |             phi = central_angle / 2
972 |             dist = ti.sqrt((point**2).sum())
973 |             if dist < ti.cos(phi) / ti.cos(phi - theta):
974 |                 break
975 |         return point
976 | 
977 |     @ti.kernel
978 |     def seed_polygon(self, new_particles: ti.i32, sides: ti.i32, angle: ti.f32,
979 |                      new_material: ti.i32, color: ti.i32):
980 |         for i in range(self.n_particles[None],
981 |                        self.n_particles[None] + new_particles):
982 |             x = self.random_point_in_unit_polygon(sides, angle)
983 |             x = self.source_bound[0] + x * self.source_bound[1]
984 |             self.seed_particle(i, x, new_material, color,
985 |                                self.source_velocity[None], None, 0)
986 | 
987 |     @ti.kernel
988 |     def add_texture_2d(
989 |             self,
990 |             offset_x: ti.f32,
991 |             offset_y: ti.f32,
992 |             texture: ti.types.ndarray(),
993 |             new_material: ti.i32,
994 |             color: ti.i32,
995 |     ):
996 |         for i, j in ti.ndrange(texture.shape[0], texture.shape[1]):
997 |             if texture[i, j] > 0.1:
998 |                 pid = ti.atomic_add(self.n_particles[None], 1)
999 |                 x = ti.Vector([offset_x + i * self.dx, offset_y + j * self.dx])
1000 |                 self.seed_particle(pid, x, new_material, color,
1001 |                                    self.source_velocity[None], None, 0)
1002 | 
1003 |     @ti.func
1004 |     def random_point_in_unit_sphere(self):
1005 |         ret = ti.Vector.zero(ti.f32, n=self.dim)
1006 |         while True:
1007 |             for i in ti.static(range(self.dim)):
1008 |                 ret[i] = ti.random(ti.f32) * 2 - 1
1009 |             if ret.norm_sqr() <= 1:
1010 |                 break
1011 |         return ret
1012 | 
1013 |     @ti.kernel
1014 |     def seed_ellipsoid(self, new_particles: ti.i32, new_material: ti.i32,
1015 |                        color: ti.i32):
1016 | 
1017 |         for i in range(self.n_particles[None],
1018 |                        self.n_particles[None] + new_particles):
1019 |             x = self.source_bound[0] + self.random_point_in_unit_sphere(
) * self.source_bound[1]
1021 |             self.seed_particle(i, x, new_material, color,
1022 |                                self.source_velocity[None], None, 0)
1023 | 
1024 |     def add_ellipsoid(self,
1025 |                       center,
1026 |                       radius,
1027 |                       material,
1028 |                       color=0xFFFFFF,
1029 |                       sample_density=None,
1030 |                       velocity=None):
1031 |         if sample_density is None:
1032 |             sample_density = 2**self.dim
1033 | 
1034 |         if isinstance(radius, numbers.Number):
1035 |             radius = [
1036 |                 radius,
1037 |             ] * self.dim
1038 | 
1039 |         radius = list(radius)
1040 | 
1041 |         if self.dim == 2:
1042 |             num_particles = math.pi
1043 |         else:
1044 |             num_particles = 4 / 3 * math.pi
1045 | 
1046 |         for i in range(self.dim):
1047 |             num_particles *= radius[i] * self.inv_dx
1048 | 
1049 |         num_particles = int(math.ceil(num_particles * sample_density))
1050 | 
1051 |         self.source_bound[0] = center
1052 |         self.source_bound[1] = radius
1053 | 
1054 |         self.set_source_velocity(velocity=velocity)
1055 | 
1056 |         assert self.n_particles[None] + num_particles <= self.max_num_particles
1057 | 
1058 |         self.seed_ellipsoid(num_particles, material, color)
1059 |         self.n_particles[None] += num_particles
1060 | 
1061 |     @ti.kernel
1062 |     def seed_from_voxels(
1063 |         self,
1064 |         material: ti.i32,
1065 |         color: ti.i32,
1066 |         sample_density: ti.i32,
1067 |         emitter_id: ti.u16
1068 |     ):
1069 |         for i, j, k in self.voxelizer.voxels:
1070 |             inside = 1
1071 |             for c in ti.static([i, j, k]):  # check all three axes, not just i
1072 |                 inside = inside and -self.grid_size // 2 + self.padding <= c and c < self.grid_size // 2 - self.padding
1073 |             if inside and self.voxelizer.voxels[i, j, k] > 0:
1074 |                 s = sample_density / self.voxelizer_super_sample**self.dim
1075 |                 for l in range(sample_density + 1):
1076 |                     if ti.random() + l < s:
1077 |                         x = ti.Vector([
1078 |                             ti.random() + i,
1079 |                             ti.random() + j,
1080 |                             ti.random() + k
1081 |                         ]) * (self.dx / self.voxelizer_super_sample
1082 |                               ) + self.source_bound[0]
1083 |                         p = ti.atomic_add(self.n_particles[None], 1)
1084 |                         self.seed_particle(
1085 |                             p,
1086 |                             x,
1087 |                             material,
1088 |                             color,
1089 |                             self.source_velocity[None],
1090 |                             emitter_id, 0
1091 |                         )
1092 | 
1093 |     def add_mesh(self,
1094 |                  triangles,
1095 |                  material,
1096 |                  color=0xFFFFFF,
1097 |                  sample_density=None,
1098 |                  velocity=None,
1099 |                  translation=None,
1100 |                  emmiter_id=0
1101 |                  ):
1102 |         assert self.dim == 3
1103 |         if sample_density is None:
1104 |             sample_density = 2**self.dim
1105 | 
1106 |         self.set_source_velocity(velocity=velocity)
1107 | 
1108 |         for i in range(self.dim):
1109 |             if translation:
1110 |                 self.source_bound[0][i] = translation[i]
1111 |             else:
1112 |                 self.source_bound[0][i] = 0
1113 | 
1114 |         self.voxelizer.voxelize(triangles)
1115 |         t = time.time()
1116 |         self.seed_from_voxels(
1117 |             material,
1118 |             color,
1119 |             sample_density,
1120 |             emmiter_id
1121 |         )
1122 |         ti.sync()
1123 |         # print('Voxelization time:', (time.time() - t) * 1000, 'ms')
1124 | 
1125 |     @ti.kernel
1126 |     def seed_from_external_array_single_mat(self, num_particles: ti.i32,
1127 |                               pos: ti.types.ndarray(), new_material: ti.i32,
1128 |                               color: ti.i32, motion_override_flag: ti.types.ndarray()):
1129 | 
1130 |         for i in range(num_particles):
1131 |             x = ti.Vector.zero(ti.f32, n=self.dim)
1132 |             if ti.static(self.dim == 3):
1133 |                 x = ti.Vector([pos[i, 0], pos[i, 1], pos[i, 2]])
1134 |             else:
1135 |                 x = ti.Vector([pos[i, 0], pos[i, 1]])
1136 |             self.seed_particle(self.n_particles[None] + i, x, new_material,
1137 |                                color, self.source_velocity[None], None, motion_override_flag[i])
1138 | 
1139 |         self.n_particles[None] += num_particles
1140 | 
1141 |     @ti.kernel
1142 |     def seed_from_external_array_multiple_mat(self, num_particles: ti.i32,
1143 |                               pos: ti.types.ndarray(), new_material: ti.types.ndarray(),
1144 |                               color: ti.i32, motion_override_flag: ti.types.ndarray()):
1145 | 
1146 |         for i in range(num_particles):
1147 |             x = ti.Vector.zero(ti.f32, n=self.dim)
1148 |             if ti.static(self.dim == 3):
1149 |                 x = ti.Vector([pos[i, 0], pos[i, 1], pos[i, 2]])
1150 |             else:
1151 |                 x = ti.Vector([pos[i, 0], pos[i, 1]])
1152 |             self.seed_particle(self.n_particles[None] + i, x, new_material[i],
1153 |                                color, self.source_velocity[None], None, motion_override_flag[i])
1154 | 
1155 |         self.n_particles[None] += num_particles
1156 | 
1157 |     def seed_from_external_array(self, num_particles, pos, material, color, motion_override_flag):
1158 |         if isinstance(material, int):
1159 |             self.seed_from_external_array_single_mat(num_particles, pos, material, color, motion_override_flag)
1160 |         else:
1161 |             self.seed_from_external_array_multiple_mat(num_particles, pos, material, color, motion_override_flag)
1162 | 
1163 |     def add_particles(self,
1164 |                       particles,
1165 |                       material,
1166 |                       color=0xFFFFFF,
1167 |                       velocity=None,
1168 |                       motion_override_flag_arr=None):
1169 |         if motion_override_flag_arr is None:
1170 |             motion_override_flag_arr = np.zeros(len(particles), dtype=np.int32)
1171 |         self.set_source_velocity(velocity=velocity)
1172 |         self.seed_from_external_array(len(particles), particles, material,
1173 |                                       color, motion_override_flag_arr)
1174 | 
1175 |     @ti.kernel
1176 |     def recover_from_external_array(
1177 |         self,
1178 |         num_particles: ti.i32,
1179 |         pos: ti.types.ndarray(),
1180 |         vel: ti.types.ndarray(),
1181 |         material: ti.types.ndarray(),
1182 |         color: ti.types.ndarray(),
1183 |     ):
1184 |         for i in range(num_particles):
1185 |             x = ti.Vector.zero(ti.f32, n=self.dim)
1186 |             v = ti.Vector.zero(ti.f32, n=self.dim)
1187 |             if ti.static(self.dim == 3):
1188 |                 x = ti.Vector([pos[i, 0], pos[i, 1], pos[i, 2]])
1189 |                 v = ti.Vector([vel[i, 0], vel[i, 1], vel[i, 2]])
1190 |             else:
1191 |                 x = ti.Vector([pos[i, 0], pos[i, 1]])
1192 |                 v = ti.Vector([vel[i, 0], vel[i, 1]])
1193 |             self.seed_particle(self.n_particles[None] + i, x, material[i],
1194 |                                color[i], v, None, 0)
1195 |         self.n_particles[None] += num_particles
1196 | 
1197 |     def read_restart(
1198 |         self,
1199 |         num_particles,
1200 |         pos,
1201 |         vel,
1202 |         material,
1203 |         color,
1204 |     ):
1205 |         slice_size = 50000
1206 |         num_slices = (num_particles + slice_size - 1) // slice_size
1207 |         for s in range(num_slices):
1208 |             begin = slice_size * s
1209 |             end = min(slice_size * (s + 1), num_particles)
1210 |             self.recover_from_external_array(end - begin, pos[begin:end],
1211 |                                              vel[begin:end],
1212 |                                              material[begin:end],
1213 |                                              color[begin:end])
1214 | 
1215 |     @ti.kernel
1216 |     def copy_dynamic_nd(self, np_x: ti.types.ndarray(), input_x: ti.template()):
1217 |         for i in self.x:
1218 |             for j in ti.static(range(self.dim)):
1219 |                 np_x[i, j] = input_x[i][j]
1220 | 
1221 |     @ti.kernel
1222 |     def copy_dynamic_nnd(self, np_x: ti.types.ndarray(), input_x: ti.template()):
1223 |         for i in self.x:
1224 |             for j in ti.static(range(self.dim)):
1225 |                 for k in ti.static(range(self.dim)):
1226 |                     np_x[i, j, k] = input_x[i][j, k]
1227 | 
1228 |     @ti.kernel
1229 |     def copy_dynamic(self, np_x: ti.types.ndarray(), input_x: ti.template()):
1230 |         for i in self.x:
1231 |             np_x[i] = input_x[i]
1232 | 
1233 |     @ti.kernel
1234 |     def copy_ranged(self, np_x: ti.types.ndarray(), input_x: ti.template(),
1235 |                     begin: ti.i32, end: ti.i32):
1236 |         ti.no_activate(self.particle)
1237 |         for i in
range(begin, end):
1238 |             np_x[i - begin] = input_x[i]
1239 | 
1240 |     @ti.kernel
1241 |     def copy_ranged_nd(self, np_x: ti.types.ndarray(), input_x: ti.template(),
1242 |                        begin: ti.i32, end: ti.i32):
1243 |         ti.no_activate(self.particle)
1244 |         for i in range(begin, end):
1245 |             for j in ti.static(range(self.dim)):
1246 |                 np_x[i - begin, j] = input_x[i, j]
1247 | 
1248 |     def particle_info(self):
1249 |         np_x = np.ndarray((self.n_particles[None], self.dim), dtype=np.float32)
1250 |         self.copy_dynamic_nd(np_x, self.x)
1251 |         np_v = np.ndarray((self.n_particles[None], self.dim), dtype=np.float32)
1252 |         self.copy_dynamic_nd(np_v, self.v)
1253 |         np_material = np.ndarray((self.n_particles[None], ), dtype=np.int32)
1254 |         self.copy_dynamic(np_material, self.material)
1255 |         np_color = np.ndarray((self.n_particles[None], ), dtype=np.int32)
1256 |         self.copy_dynamic(np_color, self.color)
1257 |         np_rotation = np.ndarray((self.n_particles[None], self.dim, self.dim), dtype=np.float32)
1258 |         self.copy_dynamic_nnd(np_rotation, self.particle_R)
1259 |         particles_data = {
1260 |             'position': np_x,
1261 |             'velocity': np_v,
1262 |             'material': np_material,
1263 |             'color': np_color,
1264 |             'rotation': np_rotation
1265 |         }
1266 |         if self.use_emitter_id:
1267 |             np_emitters = np.ndarray((self.n_particles[None], ), dtype=np.int32)
1268 |             self.copy_dynamic(np_emitters, self.emitter_ids)
1269 |             particles_data['emitter_ids'] = np_emitters
1270 |         return particles_data
1271 | 
1272 |     @ti.kernel
1273 |     def clear_particles(self):
1274 |         self.n_particles[None] = 0
1275 |         ti.deactivate(self.x.loop_range().parent().snode(), [])
1276 | 
1277 |     def write_particles(self, fn, slice_size=1000000):
1278 |         from .particle_io import ParticleIO
1279 |         ParticleIO.write_particles(self, fn, slice_size)
1280 | 
1281 |     def write_particles_ply(self, fn):
1282 |         np_x = np.ndarray((self.n_particles[None], self.dim), dtype=np.float32)
1283 |         self.copy_dynamic_nd(np_x, self.x)
1284 |         np_color = np.ndarray((self.n_particles[None]), dtype=np.uint32)
1285 |         self.copy_dynamic(np_color, self.color)
1286 |         data = np.hstack([np_x, (np_color[:, None]).view(np.float32)])
1287 |         from .mesh_io import write_point_cloud
1288 |         write_point_cloud(fn, data)
1289 | 
--------------------------------------------------------------------------------
/feature_splatting/utils/mpm_engine/particle_io.py:
--------------------------------------------------------------------------------
1 | from .mesh_io import write_point_cloud
2 | import numpy as np
3 | import taichi as ti
4 | import time
5 | import gc
6 | 
7 | 
8 | class ParticleIO:
9 |     v_bits = 8
10 |     x_bits = 32 - v_bits
11 | 
12 |     @staticmethod
13 |     def write_particles(solver, fn, slice_size=1000000):
14 |         t = time.time()
15 |         output_fn = fn
16 | 
17 |         n_particles = solver.n_particles[None]
18 | 
19 |         x_and_v = np.ndarray((n_particles, solver.dim), dtype=np.uint32)
20 |         # Value ranges of x and v components, for quantization
21 |         ranges = np.ndarray((2, solver.dim, 2), dtype=np.float32)
22 | 
23 |         # Fetch data slice after slice since we don't have the GPU memory to fetch them channel after channel...
24 | num_slices = (n_particles + slice_size - 1) // slice_size 25 | 26 | for d in range(solver.dim): 27 | np_x = np.ndarray((n_particles, ), dtype=np.float32) 28 | np_v = np.ndarray((n_particles, ), dtype=np.float32) 29 | 30 | np_x_slice = np.ndarray((slice_size, ), dtype=np.float32) 31 | np_v_slice = np.ndarray((slice_size, ), dtype=np.float32) 32 | 33 | for s in range(num_slices): 34 | begin = slice_size * s 35 | end = min(slice_size * (s + 1), n_particles) 36 | solver.copy_ranged(np_x_slice, solver.x.get_scalar_field(d), 37 | begin, end) 38 | solver.copy_ranged(np_v_slice, solver.v.get_scalar_field(d), 39 | begin, end) 40 | 41 | np_x[begin:end] = np_x_slice[:end - begin] 42 | np_v[begin:end] = np_v_slice[:end - begin] 43 | 44 | ranges[0, d] = [np.min(np_x), np.max(np_x)] 45 | ranges[1, d] = [np.min(np_v), np.max(np_v)] 46 | 47 | # Avoid too narrow ranges 48 | for c in range(2): 49 | ranges[c, d, 1] = max(ranges[c, d, 0] + 1e-5, ranges[c, d, 1]) 50 | np_x = ((np_x - ranges[0, d, 0]) * 51 | (1 / (ranges[0, d, 1] - ranges[0, d, 0])) * 52 | (2**ParticleIO.x_bits - 1) + 0.499).astype(np.uint32) 53 | np_v = ((np_v - ranges[1, d, 0]) * 54 | (1 / (ranges[1, d, 1] - ranges[1, d, 0])) * 55 | (2**ParticleIO.v_bits - 1) + 0.499).astype(np.uint32) 56 | x_and_v[:, d] = (np_x << ParticleIO.v_bits) + np_v 57 | del np_x, np_v 58 | 59 | color = np.ndarray((n_particles, 3), dtype=np.uint8) 60 | np_color = np.ndarray((n_particles, ), dtype=np.uint32) 61 | 62 | np_color_slice = np.ndarray((slice_size, ), dtype=np.float32) 63 | 64 | for s in range(num_slices): 65 | begin = slice_size * s 66 | end = min(slice_size * (s + 1), n_particles) 67 | 68 | solver.copy_ranged(np_color_slice, solver.color, begin, end) 69 | np_color[begin:end] = np_color_slice[:end - begin] 70 | 71 | for c in range(3): 72 | color[:, c] = (np_color >> (8 * (2 - c))) & 255 73 | 74 | np.savez(output_fn, ranges=ranges, x_and_v=x_and_v, color=color) 75 | 76 | print(f'Writing to disk: {time.time() - t:.3f} s') 77 | 78 | @staticmethod 79 | def read_particles_3d(fn): 80 | return ParticleIO.read_particles(fn, 3) 81 | 82 | @staticmethod 83 | def read_particles_2d(fn): 84 | return ParticleIO.read_particles(fn, 2) 85 | 86 | @staticmethod 87 | def read_particles(fn, dim): 88 | data = np.load(fn) 89 | ranges = data['ranges'] 90 | color = data['color'] 91 | x_and_v = data['x_and_v'] 92 | del data 93 | gc.collect() 94 | x = (x_and_v >> ParticleIO.v_bits).astype(np.float32) / ( 95 | (2**ParticleIO.x_bits - 1)) 96 | for c in range(dim): 97 | x[:, 98 | c] = x[:, c] * (ranges[0, c, 1] - ranges[0, c, 0]) + ranges[0, c, 99 | 0] 100 | v = (x_and_v & (2**ParticleIO.v_bits - 1)).astype( 101 | np.float32) / (2**ParticleIO.v_bits - 1) 102 | for c in range(dim): 103 | v[:, 104 | c] = v[:, c] * (ranges[1, c, 1] - ranges[1, c, 0]) + ranges[1, c, 105 | 0] 106 | return x, v, color 107 | 108 | @staticmethod 109 | def convert_particle_to_ply(fns): 110 | for fn in fns: 111 | print(f'Converting {fn}...') 112 | x, _, color = ParticleIO.read_particles_3d(fn) 113 | x = x.astype(np.float32) 114 | color = (color[:, 2].astype(np.uint32) << 16) + ( 115 | color[:, 1].astype(np.uint32) << 8) + color[:, 0] 116 | color = color[:, None] 117 | pos_color = np.hstack([x, color.view(np.float32)]) 118 | del x, color 119 | gc.collect() 120 | write_point_cloud(fn + ".ply", pos_color) 121 | 122 | 123 | if __name__ == '__main__': 124 | import sys 125 | ParticleIO.convert_particle_to_ply(sys.argv[1:]) 126 | -------------------------------------------------------------------------------- 
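For reference, the packing scheme above round-trips as follows. This is a self-contained NumPy illustration of the same 24/8-bit split (not code from the package): each component pair is stored as one uint32, with the position quantized to x_bits = 24 bits and the velocity to v_bits = 8 bits, each scaled into its per-dimension [min, max] range.

import numpy as np

x_bits, v_bits = 24, 8  # same split as ParticleIO: x_bits = 32 - v_bits

def quantize(vals, lo, hi, bits):
    # Map [lo, hi] to integers in [0, 2**bits - 1]; +0.499 mimics ParticleIO's rounding.
    return ((vals - lo) / (hi - lo) * (2**bits - 1) + 0.499).astype(np.uint32)

def dequantize(q, lo, hi, bits):
    return q.astype(np.float32) / (2**bits - 1) * (hi - lo) + lo

x = np.random.rand(5).astype(np.float32)             # positions
v = np.random.randn(5).astype(np.float32)            # velocities
x_lo, x_hi = x.min(), max(x.min() + 1e-5, x.max())   # avoid zero-width ranges
v_lo, v_hi = v.min(), max(v.min() + 1e-5, v.max())

packed = (quantize(x, x_lo, x_hi, x_bits) << v_bits) + quantize(v, v_lo, v_hi, v_bits)

x_rec = dequantize(packed >> v_bits, x_lo, x_hi, x_bits)            # near float32 precision
v_rec = dequantize(packed & (2**v_bits - 1), v_lo, v_hi, v_bits)    # coarse: 256 levels

Velocities get only 256 levels per component, which is acceptable for the renderer's motion blur below; storing the per-dimension ranges alongside the packed data is what keeps the quantization error bounded.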
/feature_splatting/utils/mpm_engine/renderer.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | import numpy as np
3 | import math
4 | import time
5 | from .renderer_utils import out_dir, ray_aabb_intersection, inf, eps, \
6 |     intersect_sphere, sphere_aabb_intersect_motion, inside_taichi
7 | 
8 | from .particle_io import ParticleIO
9 | 
10 | res = 1280, 720
11 | aspect_ratio = res[0] / res[1]
12 | 
13 | max_ray_depth = 4
14 | use_directional_light = True
15 | 
16 | dist_limit = 100
17 | # TODO: why doesn't it render normally when shutter_begin = -1?
18 | shutter_begin = -0.5
19 | 
20 | exposure = 1.5
21 | light_direction = [1.2, 0.3, 0.7]
22 | light_direction_noise = 0.03
23 | light_color = [1.0, 1.0, 1.0]
24 | 
25 | 
26 | @ti.data_oriented
27 | class Renderer:
28 |     def __init__(self,
29 |                  dx=1 / 1024,
30 |                  sphere_radius=0.3 / 1024,
31 |                  render_voxel=False,
32 |                  shutter_time=1e-3,
33 |                  taichi_logo=True,
34 |                  max_num_particles_million=128):
35 |         self.vignette_strength = 0.9
36 |         self.vignette_radius = 0.0
37 |         self.vignette_center = [0.5, 0.5]
38 |         self.taichi_logo = taichi_logo
39 |         self.shutter_time = shutter_time  # usually half the frame time
40 |         self.enable_motion_blur = self.shutter_time != 0.0
41 | 
42 |         self.color_buffer = ti.Vector.field(3, dtype=ti.f32)
43 |         self.bbox = ti.Vector.field(3, dtype=ti.f32, shape=2)
44 |         self.voxel_grid_density = ti.field(dtype=ti.f32)
45 |         self.voxel_has_particle = ti.field(dtype=ti.i32)
46 |         self.fov = ti.field(dtype=ti.f32, shape=())
47 | 
48 |         self.particle_x = ti.Vector.field(3, dtype=ti.f32)
49 |         if self.enable_motion_blur:
50 |             self.particle_v = ti.Vector.field(3, dtype=ti.f32)
51 |         self.particle_color = ti.Vector.field(3, dtype=ti.u8)
52 |         self.pid = ti.field(ti.i32)
53 |         self.num_particles = ti.field(ti.i32, shape=())
54 | 
55 |         self.render_voxel = render_voxel
56 | 
57 |         self.voxel_edges = 0.1
58 | 
59 |         self.particle_grid_res = 2048
60 | 
61 |         self.dx = dx
62 |         self.inv_dx = 1 / self.dx
63 | 
64 |         self.camera_pos = ti.Vector.field(3, dtype=ti.f32, shape=())
65 |         self.look_at = ti.Vector.field(3, dtype=ti.f32, shape=())
66 |         self.up = ti.Vector.field(3, dtype=ti.f32, shape=())
67 | 
68 |         self.floor_height = ti.field(dtype=ti.f32, shape=())
69 | 
70 |         self.supporter = 2
71 |         self.sphere_radius = sphere_radius
72 |         self.particle_grid_offset = [
73 |             -self.particle_grid_res // 2 for _ in range(3)
74 |         ]
75 | 
76 |         self.voxel_grid_res = self.particle_grid_res
77 |         voxel_grid_offset = [-self.voxel_grid_res // 2 for _ in range(3)]
78 |         self.max_num_particles_per_cell = 8192 * 1024
79 |         self.max_num_particles = 1024 * 1024 * max_num_particles_million
80 | 
81 |         self.voxel_dx = self.dx
82 |         self.voxel_inv_dx = 1 / self.voxel_dx
83 | 
84 |         assert self.sphere_radius * 2 < self.dx
85 | 
86 |         ti.root.dense(ti.ij, res).place(self.color_buffer)
87 | 
88 |         self.block_size = 8
89 |         self.block_offset = [
90 |             o // self.block_size for o in self.particle_grid_offset
91 |         ]
92 |         self.particle_bucket = ti.root.pointer(
93 |             ti.ijk, self.particle_grid_res // self.block_size)
94 | 
95 |         self.particle_bucket.dense(ti.ijk, self.block_size).dynamic(
96 |             ti.l, self.max_num_particles_per_cell,
97 |             chunk_size=32).place(self.pid,
98 |                                  offset=self.particle_grid_offset + [0])
99 | 
100 |         self.voxel_block_offset = [
101 |             o // self.block_size for o in voxel_grid_offset
102 |         ]
103 |         ti.root.pointer(ti.ijk,
104 |                         self.particle_grid_res // self.block_size).dense(
105 |                             ti.ijk,
106 | 
self.block_size).place(self.voxel_has_particle, 107 | offset=voxel_grid_offset) 108 | voxel_block = ti.root.pointer(ti.ijk, 109 | self.voxel_grid_res // self.block_size) 110 | 111 | voxel_block.dense(ti.ijk, 112 | self.block_size).place(self.voxel_grid_density, 113 | offset=voxel_grid_offset) 114 | 115 | particle = ti.root.dense(ti.l, self.max_num_particles) 116 | 117 | particle.place(self.particle_x) 118 | if self.enable_motion_blur: 119 | particle.place(self.particle_v) 120 | particle.place(self.particle_color) 121 | 122 | self.set_up(0, 1, 0) 123 | self.set_fov(0.23) 124 | 125 | @ti.func 126 | def inside_grid(self, ipos): 127 | return ipos.min() >= -self.voxel_grid_res // 2 and ipos.max( 128 | ) < self.voxel_grid_res // 2 129 | 130 | # The dda algorithm requires the voxel grid to have one surrounding layer of void region 131 | # to correctly render the outmost voxel faces 132 | @ti.func 133 | def inside_grid_loose(self, ipos): 134 | return ipos.min() >= -self.voxel_grid_res // 2 - 1 and ipos.max( 135 | ) <= self.voxel_grid_res // 2 136 | 137 | @ti.func 138 | def query_density(self, ipos): 139 | inside = self.inside_grid(ipos) 140 | ret = 0.0 141 | if inside: 142 | ret = self.voxel_grid_density[ipos] 143 | else: 144 | ret = 0.0 145 | return ret 146 | 147 | @ti.func 148 | def voxel_color(self, pos): 149 | p = pos * self.inv_dx 150 | 151 | p -= ti.floor(p) 152 | 153 | boundary = self.voxel_edges 154 | count = 0 155 | for i in ti.static(range(3)): 156 | if p[i] < boundary or p[i] > 1 - boundary: 157 | count += 1 158 | f = 0.0 159 | if count >= 2: 160 | f = 1.0 161 | return ti.Vector([0.9, 0.8, 1.0]) * (1.3 - 1.2 * f) 162 | 163 | @ti.func 164 | def sdf(self, o): 165 | dist = 0.0 166 | if ti.static(self.supporter == 0): 167 | o -= ti.Vector([0.5, 0.002, 0.5]) 168 | p = o 169 | h = 0.02 170 | ra = 0.29 171 | rb = 0.005 172 | d = (ti.Vector([p[0], p[2]]).norm() - 2.0 * ra + rb, abs(p[1]) - h) 173 | dist = min(max(d[0], d[1]), 0.0) + ti.Vector( 174 | [max(d[0], 0.0), max(d[1], 0)]).norm() - rb 175 | elif ti.static(self.supporter == 1): 176 | o -= ti.Vector([0.5, 0.002, 0.5]) 177 | dist = (o.abs() - ti.Vector([0.5, 0.02, 0.5])).max() 178 | else: 179 | dist = o[1] - self.floor_height[None] 180 | 181 | return dist 182 | 183 | @ti.func 184 | def ray_march(self, p, d): 185 | j = 0 186 | dist = 0.0 187 | limit = 200 188 | while j < limit and self.sdf(p + 189 | dist * d) > 1e-8 and dist < dist_limit: 190 | dist += self.sdf(p + dist * d) 191 | j += 1 192 | if dist > dist_limit: 193 | dist = inf 194 | return dist 195 | 196 | @ti.func 197 | def sdf_normal(self, p): 198 | d = 1e-3 199 | n = ti.Vector([0.0, 0.0, 0.0]) 200 | for i in ti.static(range(3)): 201 | inc = p 202 | dec = p 203 | inc[i] += d 204 | dec[i] -= d 205 | n[i] = (0.5 / d) * (self.sdf(inc) - self.sdf(dec)) 206 | return n.normalized() 207 | 208 | @ti.func 209 | def sdf_color(self, p): 210 | scale = 0.0 211 | if ti.static(self.taichi_logo): 212 | scale = 0.4 213 | if inside_taichi(ti.Vector([p[0], p[2]])): 214 | scale = 1 215 | else: 216 | scale = 1.0 217 | return ti.Vector([0.3, 0.5, 0.7]) * scale 218 | 219 | # Digital differential analyzer for the grid visualization (render_voxels=True) 220 | @ti.func 221 | def dda_voxel(self, eye_pos, d): 222 | for i in ti.static(range(3)): 223 | if abs(d[i]) < 1e-6: 224 | d[i] = 1e-6 225 | rinv = 1.0 / d 226 | rsign = ti.Vector([0, 0, 0]) 227 | for i in ti.static(range(3)): 228 | if d[i] > 0: 229 | rsign[i] = 1 230 | else: 231 | rsign[i] = -1 232 | 233 | bbox_min = self.bbox[0] 234 | bbox_max = 
self.bbox[1] 235 | inter, near, far = ray_aabb_intersection(bbox_min, bbox_max, eye_pos, 236 | d) 237 | hit_distance = inf 238 | normal = ti.Vector([0.0, 0.0, 0.0]) 239 | c = ti.Vector([0.0, 0.0, 0.0]) 240 | if inter: 241 | near = max(0, near) 242 | 243 | pos = eye_pos + d * (near + 5 * eps) 244 | 245 | o = self.voxel_inv_dx * pos 246 | ipos = int(ti.floor(o)) 247 | dis = (ipos - o + 0.5 + rsign * 0.5) * rinv 248 | running = 1 249 | i = 0 250 | hit_pos = ti.Vector([0.0, 0.0, 0.0]) 251 | while running: 252 | last_sample = int(self.query_density(ipos)) 253 | if not self.inside_particle_grid(ipos): 254 | running = 0 255 | 256 | if last_sample: 257 | mini = (ipos - o + ti.Vector([0.5, 0.5, 0.5]) - 258 | rsign * 0.5) * rinv 259 | hit_distance = mini.max() * self.voxel_dx + near 260 | hit_pos = eye_pos + hit_distance * d 261 | c = self.voxel_color(hit_pos) 262 | running = 0 263 | else: 264 | mm = ti.Vector([0, 0, 0]) 265 | if dis[0] <= dis[1] and dis[0] < dis[2]: 266 | mm[0] = 1 267 | elif dis[1] <= dis[0] and dis[1] <= dis[2]: 268 | mm[1] = 1 269 | else: 270 | mm[2] = 1 271 | dis += mm * rsign * rinv 272 | ipos += mm * rsign 273 | normal = -mm * rsign 274 | i += 1 275 | return hit_distance, normal, c 276 | 277 | @ti.func 278 | def inside_particle_grid(self, ipos): 279 | pos = ipos * self.dx 280 | return self.bbox[0][0] <= pos[0] and pos[0] < self.bbox[1][ 281 | 0] and self.bbox[0][1] <= pos[1] and pos[1] < self.bbox[1][ 282 | 1] and self.bbox[0][2] <= pos[2] and pos[2] < self.bbox[1][2] 283 | 284 | # DDA for the particle visualization (render_voxels=False) 285 | @ti.func 286 | def dda_particle(self, eye_pos, d, t): 287 | # bounding box 288 | bbox_min = self.bbox[0] 289 | bbox_max = self.bbox[1] 290 | 291 | hit_pos = ti.Vector([0.0, 0.0, 0.0]) 292 | normal = ti.Vector([0.0, 0.0, 0.0]) 293 | c = ti.Vector([0.0, 0.0, 0.0]) 294 | for i in ti.static(range(3)): 295 | if abs(d[i]) < 1e-6: 296 | d[i] = 1e-6 297 | 298 | inter, near, far = ray_aabb_intersection(bbox_min, bbox_max, eye_pos, 299 | d) 300 | near = max(0, near) 301 | 302 | closest_intersection = inf 303 | 304 | if inter: 305 | pos = eye_pos + d * (near + eps) 306 | 307 | rinv = 1.0 / d 308 | rsign = ti.Vector([0, 0, 0]) 309 | for i in ti.static(range(3)): 310 | if d[i] > 0: 311 | rsign[i] = 1 312 | else: 313 | rsign[i] = -1 314 | 315 | o = self.inv_dx * pos 316 | ipos = ti.floor(o).cast(int) 317 | dis = (ipos - o + 0.5 + rsign * 0.5) * rinv 318 | running = 1 319 | # DDA for voxels with at least one particle 320 | while running: 321 | inside = self.inside_particle_grid(ipos) 322 | 323 | if inside: 324 | # once we actually intersect with a voxel that contains at least one particle, loop over the particle list 325 | num_particles = self.voxel_has_particle[ipos] 326 | if num_particles != 0: 327 | num_particles = ti.length( 328 | self.pid.parent(), 329 | ipos - ti.Vector(self.particle_grid_offset)) 330 | for k in range(num_particles): 331 | p = self.pid[ipos, k] 332 | v = ti.Vector([0.0, 0.0, 0.0]) 333 | if ti.static(self.enable_motion_blur): 334 | v = self.particle_v[p] 335 | x = self.particle_x[p] + t * v 336 | color = ti.cast(self.particle_color[p], 337 | ti.u32) * (1 / 255.0) 338 | # ray-sphere intersection 339 | dist, poss = intersect_sphere(eye_pos, d, x, 340 | self.sphere_radius) 341 | hit_pos = poss 342 | if dist < closest_intersection and dist > 0: 343 | hit_pos = eye_pos + dist * d 344 | closest_intersection = dist 345 | normal = (hit_pos - x).normalized() 346 | c = color 347 | else: 348 | running = 0 349 | normal = [0, 0, 0] 350 | 351 
| if closest_intersection < inf:
352 |                     running = 0
353 |                 else:
354 |                     # hits nothing. Continue ray marching
355 |                     mm = ti.Vector([0, 0, 0])
356 |                     if dis[0] <= dis[1] and dis[0] <= dis[2]:
357 |                         mm[0] = 1
358 |                     elif dis[1] <= dis[0] and dis[1] <= dis[2]:
359 |                         mm[1] = 1
360 |                     else:
361 |                         mm[2] = 1
362 |                     dis += mm * rsign * rinv
363 |                     ipos += mm * rsign
364 | 
365 |         return closest_intersection, normal, c
366 | 
367 |     @ti.func
368 |     def next_hit(self, pos, d, t):
369 |         closest = inf
370 |         normal = ti.Vector([0.0, 0.0, 0.0])
371 |         c = ti.Vector([0.0, 0.0, 0.0])
372 |         if ti.static(self.render_voxel):
373 |             closest, normal, c = self.dda_voxel(pos, d)
374 |         else:
375 |             closest, normal, c = self.dda_particle(pos, d, t)
376 | 
377 |         if d[2] != 0:
378 |             ray_closest = -(pos[2] + 5.5) / d[2]
379 |             if ray_closest > 0 and ray_closest < closest:
380 |                 closest = ray_closest
381 |                 normal = ti.Vector([0.0, 0.0, 1.0])
382 |                 c = ti.Vector([0.6, 0.7, 0.7])
383 | 
384 |         ray_march_dist = self.ray_march(pos, d)
385 |         if ray_march_dist < dist_limit and ray_march_dist < closest:
386 |             closest = ray_march_dist
387 |             normal = self.sdf_normal(pos + d * closest)
388 |             c = self.sdf_color(pos + d * closest)
389 | 
390 |         return closest, normal, c
391 | 
392 |     @ti.kernel
393 |     def set_camera_pos(self, x: ti.f32, y: ti.f32, z: ti.f32):
394 |         self.camera_pos[None] = ti.Vector([x, y, z])
395 | 
396 |     @ti.kernel
397 |     def set_up(self, x: ti.f32, y: ti.f32, z: ti.f32):
398 |         self.up[None] = ti.Vector([x, y, z]).normalized()
399 | 
400 |     @ti.kernel
401 |     def set_look_at(self, x: ti.f32, y: ti.f32, z: ti.f32):  # renamed from look_at: the look_at field shadows a method of the same name
402 |         self.look_at[None] = ti.Vector([x, y, z])
403 | 
404 |     @ti.kernel
405 |     def set_fov(self, fov: ti.f32):
406 |         self.fov[None] = fov
407 | 
408 |     @ti.kernel
409 |     def render(self):
410 |         ti.loop_config(block_dim=256)
411 |         for u, v in self.color_buffer:
412 |             fov = self.fov[None]
413 |             pos = self.camera_pos[None]
414 |             d = (self.look_at[None] - self.camera_pos[None]).normalized()
415 |             fu = (2 * fov * (u + ti.random(ti.f32)) / res[1] -
416 |                   fov * aspect_ratio - 1e-5)
417 |             fv = 2 * fov * (v + ti.random(ti.f32)) / res[1] - fov - 1e-5
418 |             du = d.cross(self.up[None]).normalized()
419 |             dv = du.cross(d).normalized()
420 |             d = (d + fu * du + fv * dv).normalized()
421 |             t = (ti.random() + shutter_begin) * self.shutter_time
422 | 
423 |             contrib = ti.Vector([0.0, 0.0, 0.0])
424 |             throughput = ti.Vector([1.0, 1.0, 1.0])
425 | 
426 |             depth = 0
427 |             hit_sky = 1
428 |             ray_depth = 0
429 | 
430 |             while depth < max_ray_depth:
431 |                 closest, normal, c = self.next_hit(pos, d, t)
432 |                 hit_pos = pos + closest * d
433 |                 depth += 1
434 |                 ray_depth = depth
435 |                 if normal.norm() != 0:
436 |                     d = out_dir(normal)
437 |                     pos = hit_pos + 1e-4 * d
438 |                     throughput *= c
439 | 
440 |                     if ti.static(use_directional_light):
441 |                         dir_noise = ti.Vector([
442 |                             ti.random() - 0.5,
443 |                             ti.random() - 0.5,
444 |                             ti.random() - 0.5
445 |                         ]) * light_direction_noise
446 |                         direct = (ti.Vector(light_direction) +
447 |                                   dir_noise).normalized()
448 |                         dot = direct.dot(normal)
449 |                         if dot > 0:
450 |                             dist, _, _ = self.next_hit(pos, direct, t)
451 |                             if dist > dist_limit:
452 |                                 contrib += throughput * ti.Vector(
453 |                                     light_color) * dot
454 |                 else:  # hit sky
455 |                     hit_sky = 1
456 |                     depth = max_ray_depth
457 | 
458 |                 max_c = throughput.max()
459 |                 if ti.random() > max_c:
460 |                     depth = max_ray_depth
461 |                     throughput = [0, 0, 0]
462 |                 else:
463 |                     throughput /= max_c
464 | 
465 |             if hit_sky:
466 |                 if ray_depth != 1:
467 |                     # contrib *= max(d[1], 0.05)
468 |                     pass
469 |                 else:
470 |                     # directly
hit sky 471 | pass 472 | else: 473 | throughput *= 0 474 | 475 | # contrib += throughput 476 | self.color_buffer[u, v] += contrib 477 | 478 | @ti.kernel 479 | def initialize_particle_grid(self): 480 | for p in range(self.num_particles[None]): 481 | v = ti.Vector([0.0, 0.0, 0.0]) 482 | if ti.static(self.enable_motion_blur): 483 | v = self.particle_v[p] 484 | x = self.particle_x[p] 485 | ipos = ti.floor(x * self.inv_dx).cast(ti.i32) 486 | 487 | offset_begin = shutter_begin * self.shutter_time * v 488 | offset_end = (shutter_begin + 1.0) * self.shutter_time * v 489 | offset_begin_grid = offset_begin 490 | offset_end_grid = offset_end 491 | 492 | for k in ti.static(range(3)): 493 | if offset_end_grid[k] < offset_begin_grid[k]: 494 | t = offset_end_grid[k] 495 | offset_end_grid[k] = offset_begin_grid[k] 496 | offset_begin_grid[k] = t 497 | 498 | offset_begin_grid = int(ti.floor( 499 | offset_begin_grid * self.inv_dx)) - 1 500 | offset_end_grid = int(ti.ceil(offset_end_grid * self.inv_dx)) + 2 501 | 502 | for i in range(offset_begin_grid[0], offset_end_grid[0]): 503 | for j in range(offset_begin_grid[1], offset_end_grid[1]): 504 | for k in range(offset_begin_grid[2], offset_end_grid[2]): 505 | offset = ti.Vector([i, j, k]) 506 | box_ipos = ipos + offset 507 | if self.inside_particle_grid(box_ipos): 508 | box_min = box_ipos * self.dx 509 | box_max = (box_ipos + 510 | ti.Vector([1, 1, 1])) * self.dx 511 | if sphere_aabb_intersect_motion( 512 | box_min, box_max, x + offset_begin, 513 | x + offset_end, self.sphere_radius): 514 | self.voxel_has_particle[box_ipos] = 1 515 | self.voxel_grid_density[box_ipos] = 1 516 | ti.append( 517 | self.pid.parent(), box_ipos - 518 | ti.Vector(self.particle_grid_offset), p) 519 | 520 | @ti.kernel 521 | def copy(self, img: ti.types.ndarray(), samples: ti.i32): 522 | for i, j in self.color_buffer: 523 | u = 1.0 * i / res[0] 524 | v = 1.0 * j / res[1] 525 | 526 | darken = 1.0 - self.vignette_strength * max((ti.sqrt( 527 | (u - self.vignette_center[0])**2 + 528 | (v - self.vignette_center[1])**2) - self.vignette_radius), 0) 529 | 530 | for c in ti.static(range(3)): 531 | img[i, j, c] = ti.sqrt(self.color_buffer[i, j][c] * darken * 532 | exposure / samples) 533 | 534 | @ti.kernel 535 | def initialize_particle(self, x: ti.types.ndarray(), v: ti.types.ndarray(), 536 | color: ti.types.ndarray(), begin: ti.i32, end: ti.i32): 537 | for i in range(begin, end): 538 | for c in ti.static(range(3)): 539 | self.particle_x[i][c] = x[i - begin, c] 540 | if ti.static(self.enable_motion_blur): 541 | self.particle_v[i][c] = v[i - begin, c] 542 | self.particle_color[i][c] = color[i - begin, c] 543 | 544 | @ti.kernel 545 | def total_non_empty_voxels(self) -> ti.i32: 546 | counter = 0 547 | 548 | for I in ti.grouped(self.voxel_has_particle): 549 | if self.voxel_has_particle[I]: 550 | counter += 1 551 | 552 | return counter 553 | 554 | @ti.kernel 555 | def total_inserted_particles(self) -> ti.i32: 556 | counter = 0 557 | 558 | for I in ti.grouped(self.voxel_has_particle): 559 | if self.voxel_has_particle[I]: 560 | num_particles = ti.length( 561 | self.pid.parent(), 562 | I - ti.Vector(self.particle_grid_offset)) 563 | counter += num_particles 564 | 565 | return counter 566 | 567 | def reset(self): 568 | self.particle_bucket.deactivate_all() 569 | self.voxel_grid_density.snode.parent(n=2).deactivate_all() 570 | self.voxel_has_particle.snode.parent(n=2).deactivate_all() 571 | self.color_buffer.fill(0) 572 | 573 | def initialize_particles_from_taichi_elements(self, particle_fn): 574 | 
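        # Reads a .npz written by ParticleIO.write_particles (quantized positions and
        # velocities plus per-particle colors), rebuilds a padded bounding box snapped
        # to multiples of dx, and uploads the particles to Taichi fields slice by slice.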
self.reset() 575 | 576 | np_x, np_v, np_color = ParticleIO.read_particles_3d(particle_fn) 577 | num_part = len(np_x) 578 | 579 | assert num_part <= self.max_num_particles 580 | 581 | for i in range(3): 582 | # bbox values must be multiples of self.dx 583 | # bbox values are the min and max particle coordinates, with 3 self.dx margin 584 | self.bbox[0][i] = (math.floor(np_x[:, i].min() * self.inv_dx) - 585 | 3.0) * self.dx 586 | self.bbox[1][i] = (math.floor(np_x[:, i].max() * self.inv_dx) + 587 | 3.0) * self.dx 588 | print(f'Bounding box dim {i}: {self.bbox[0][i]} {self.bbox[1][i]}') 589 | 590 | # TODO: assert bounds 591 | 592 | self.num_particles[None] = num_part 593 | print('num_input_particles =', num_part) 594 | 595 | slice_size = 1000000 596 | num_slices = (num_part + slice_size - 1) // slice_size 597 | for i in range(num_slices): 598 | begin = slice_size * i 599 | end = min(num_part, begin + slice_size) 600 | self.initialize_particle(np_x[begin:end], np_v[begin:end], 601 | np_color[begin:end], begin, end) 602 | self.initialize_particle_grid() 603 | 604 | def render_frame(self, spp): 605 | last_t = 0 606 | for i in range(1, 1 + spp): 607 | self.render() 608 | 609 | interval = 20 610 | if i % interval == 0: 611 | if last_t != 0: 612 | ti.sync() 613 | print("time per spp = {:.2f} ms".format( 614 | (time.time() - last_t) * 1000 / interval)) 615 | last_t = time.time() 616 | 617 | img = np.zeros((res[0], res[1], 3), dtype=np.float32) 618 | self.copy(img, spp) 619 | return img 620 | -------------------------------------------------------------------------------- /feature_splatting/utils/mpm_engine/renderer_utils.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import math 3 | 4 | eps = 1e-4 5 | inf = 1e10 6 | 7 | 8 | @ti.func 9 | def out_dir(n): 10 | u = ti.Vector([1.0, 0.0, 0.0]) 11 | if ti.abs(n[1]) < 1 - 1e-3: 12 | u = n.cross(ti.Vector([0.0, 1.0, 0.0])).normalized() 13 | v = n.cross(u) 14 | phi = 2 * math.pi * ti.random(ti.f32) 15 | r = ti.random(ti.f32) 16 | ay = ti.sqrt(r) 17 | ax = ti.sqrt(1 - r) 18 | return ax * (ti.cos(phi) * u + ti.sin(phi) * v) + ay * n 19 | 20 | 21 | @ti.func 22 | def reflect(d, n): 23 | # Assuming |d| and |n| are normalized 24 | return d - 2.0 * d.dot(n) * n 25 | 26 | 27 | @ti.func 28 | def refract(d, n, ni_over_nt): 29 | # Assuming |d| and |n| are normalized 30 | has_r, rd = 0, d 31 | dt = d.dot(n) 32 | discr = 1.0 - ni_over_nt * ni_over_nt * (1.0 - dt * dt) 33 | if discr > 0.0: 34 | has_r = 1 35 | rd = (ni_over_nt * (d - n * dt) - n * ti.sqrt(discr)).normalized() 36 | else: 37 | rd *= 0.0 38 | return has_r, rd 39 | 40 | 41 | @ti.func 42 | def ray_aabb_intersection(box_min, box_max, o, d): 43 | intersect = 1 44 | 45 | near_int = -inf 46 | far_int = inf 47 | 48 | for i in ti.static(range(3)): 49 | if d[i] == 0: 50 | if o[i] < box_min[i] or o[i] > box_max[i]: 51 | intersect = 0 52 | else: 53 | i1 = (box_min[i] - o[i]) / d[i] 54 | i2 = (box_max[i] - o[i]) / d[i] 55 | 56 | new_far_int = ti.max(i1, i2) 57 | new_near_int = ti.min(i1, i2) 58 | 59 | far_int = ti.min(new_far_int, far_int) 60 | near_int = ti.max(new_near_int, near_int) 61 | 62 | if near_int > far_int: 63 | intersect = 0 64 | return intersect, near_int, far_int 65 | 66 | 67 | # (T + x d)(T + x d) = r * r 68 | # T*T + 2Td x + x^2 = r * r 69 | # x^2 + 2Td x + (T * T - r * r) = 0 70 | 71 | 72 | @ti.func 73 | def intersect_sphere(pos, d, center, radius): 74 | T = pos - center 75 | A = 1.0 76 | invA = 1 / A 77 | B = 2.0 * T.dot(d) 78 | hit_pos = 
ti.Vector([0.0, 0.0, 0.0]) 79 | ratio = 0.5 * invA 80 | ret1 = ratio * (-B) 81 | dist = ret1 82 | 83 | old_dist = dist 84 | new_pos = pos + d * dist 85 | T = new_pos - center 86 | A = 1.0 87 | B = 2.0 * T.dot(d) 88 | C = T.dot(T) - radius * radius 89 | delta = B * B - 4 * A * C 90 | if delta > 0: 91 | sdelta = ti.sqrt(delta) 92 | ratio = 0.5 * invA 93 | ret1 = ratio * (-B - sdelta) + old_dist 94 | if ret1 > 0: 95 | dist = ret1 96 | hit_pos = new_pos + ratio * (-B - sdelta) * d 97 | else: 98 | pass 99 | else: 100 | dist = inf 101 | 102 | return dist, hit_pos 103 | 104 | 105 | @ti.func 106 | def ray_plane_intersect(pos, d, pt_on_plane, norm): 107 | dist = inf 108 | hit_pos = ti.Vector([0.0, 0.0, 0.0]) 109 | denom = d.dot(norm) 110 | if abs(denom) > eps: 111 | dist = norm.dot(pt_on_plane - pos) / denom 112 | hit_pos = pos + d * dist 113 | return dist, hit_pos 114 | 115 | 116 | @ti.func 117 | def point_aabb_distance2(box_min, box_max, o): 118 | p = ti.Vector([0.0, 0.0, 0.0]) 119 | for i in ti.static(range(3)): 120 | p[i] = ti.max(ti.min(o[i], box_max[i]), box_min[i]) 121 | return (p - o).norm_sqr() 122 | 123 | 124 | @ti.func 125 | def sphere_aabb_intersect(box_min, box_max, o, radius): 126 | return point_aabb_distance2(box_min, box_max, o) < radius * radius 127 | 128 | 129 | @ti.func 130 | def sphere_aabb_intersect_motion(box_min, box_max, o1, o2, radius): 131 | lo = 0.0 132 | hi = 1.0 133 | while lo + 1e-5 < hi: 134 | m1 = 2 * lo / 3 + hi / 3 135 | m2 = lo / 3 + 2 * hi / 3 136 | d1 = point_aabb_distance2(box_min, box_max, (1 - m1) * o1 + m1 * o2) 137 | d2 = point_aabb_distance2(box_min, box_max, (1 - m2) * o1 + m2 * o2) 138 | if d2 > d1: 139 | hi = m2 140 | else: 141 | lo = m1 142 | 143 | return point_aabb_distance2(box_min, box_max, 144 | (1 - lo) * o1 + lo * o2) < radius * radius 145 | 146 | 147 | @ti.func 148 | def inside(p, c, r): 149 | return (p - c).norm_sqr() <= r * r 150 | 151 | 152 | @ti.func 153 | def inside_left(p, c, r): 154 | return inside(p, c, r) and p[0] < c[0] 155 | 156 | 157 | @ti.func 158 | def inside_right(p, c, r): 159 | return inside(p, c, r) and p[0] > c[0] 160 | 161 | 162 | def Vector2(x, y): 163 | return ti.Vector([x, y]) 164 | 165 | 166 | @ti.func 167 | def inside_taichi(p_): 168 | p = p_ 169 | ret = -1 170 | if not inside(p, Vector2(0.50, 0.50), 0.52): 171 | if ret == -1: 172 | ret = 0 173 | if not inside(p, Vector2(0.50, 0.50), 0.495): 174 | if ret == -1: 175 | ret = 1 176 | p = Vector2(0.5, 0.5) + (p - Vector2(0.5, 0.5)) 177 | if inside(p, Vector2(0.50, 0.25), 0.08): 178 | if ret == -1: 179 | ret = 1 180 | if inside(p, Vector2(0.50, 0.75), 0.08): 181 | if ret == -1: 182 | ret = 0 183 | if inside(p, Vector2(0.50, 0.25), 0.25): 184 | if ret == -1: 185 | ret = 0 186 | if inside(p, Vector2(0.50, 0.75), 0.25): 187 | if ret == -1: 188 | ret = 1 189 | if p[0] < 0.5: 190 | if ret == -1: 191 | ret = 1 192 | else: 193 | if ret == -1: 194 | ret = 0 195 | return ret 196 | -------------------------------------------------------------------------------- /feature_splatting/utils/mpm_engine/voxelizer.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import numpy as np 3 | 4 | 5 | @ti.func 6 | def cross2d(a, b): 7 | return a[0] * b[1] - a[1] * b[0] 8 | 9 | 10 | @ti.func 11 | def inside_ccw(p, a, b, c): 12 | return cross2d(a - p, b - p) >= 0 and cross2d( 13 | b - p, c - p) >= 0 and cross2d(c - p, a - p) >= 0 14 | 15 | 16 | @ti.data_oriented 17 | class Voxelizer: 18 | def __init__(self, res, dx, super_sample=2, 
precision=ti.f64, padding=3):
19 |         assert len(res) == 3
20 |         res = list(res)
21 |         for i in range(len(res)):
22 |             r = 1
23 |             while r < res[i]:
24 |                 r = r * 2
25 |             res[i] = r
26 |         print(f'Voxelizer resolution {res}')
27 |         # Super sample by 2x
28 |         self.res = (res[0] * super_sample, res[1] * super_sample,
29 |                     res[2] * super_sample)
30 |         self.dx = dx / super_sample
31 |         self.inv_dx = 1 / self.dx
32 |         self.voxels = ti.field(ti.i32)
33 |         self.block = ti.root.pointer(
34 |             ti.ijk, (self.res[0] // 8, self.res[1] // 8, self.res[2] // 8))
35 |         self.block.dense(ti.ijk, 8).place(self.voxels)
36 | 
37 |         assert precision in [ti.f32, ti.f64]
38 |         self.precision = precision
39 |         self.padding = padding
40 | 
41 |     @ti.func
42 |     def fill(self, p, q, height, inc):
43 |         for i in range(self.padding, height):
44 |             self.voxels[p, q, i] += inc
45 | 
46 |     @ti.kernel
47 |     def voxelize_triangles(self, num_triangles: ti.i32,
48 |                            triangles: ti.types.ndarray()):
49 |         for i in range(num_triangles):
50 |             jitter_scale = ti.cast(0, self.precision)
51 |             if ti.static(self.precision == ti.f32):
52 |                 jitter_scale = 1e-4
53 |             else:
54 |                 jitter_scale = 1e-8
55 |             # We jitter the vertices to prevent voxel samples from lying precisely at triangle edges
56 |             jitter = ti.Vector([
57 |                 -0.057616723909439505, -0.25608986292614977,
58 |                 0.06716309129743714
59 |             ]) * jitter_scale
60 |             a = ti.Vector([triangles[i, 0], triangles[i, 1], triangles[i, 2]
61 |                            ]) + jitter
62 |             b = ti.Vector([triangles[i, 3], triangles[i, 4], triangles[i, 5]
63 |                            ]) + jitter
64 |             c = ti.Vector([triangles[i, 6], triangles[i, 7], triangles[i, 8]
65 |                            ]) + jitter
66 | 
67 |             bound_min = ti.Vector.zero(self.precision, 3)
68 |             bound_max = ti.Vector.zero(self.precision, 3)
69 |             for k in ti.static(range(3)):
70 |                 bound_min[k] = min(a[k], b[k], c[k])
71 |                 bound_max[k] = max(a[k], b[k], c[k])
72 | 
73 |             p_min = int(ti.floor(bound_min[0] * self.inv_dx))
74 |             p_max = int(ti.floor(bound_max[0] * self.inv_dx)) + 1
75 | 
76 |             p_min = max(self.padding, p_min)
77 |             p_max = min(self.res[0] - self.padding, p_max)
78 | 
79 |             q_min = int(ti.floor(bound_min[1] * self.inv_dx))
80 |             q_max = int(ti.floor(bound_max[1] * self.inv_dx)) + 1
81 | 
82 |             q_min = max(self.padding, q_min)
83 |             q_max = min(self.res[1] - self.padding, q_max)
84 | 
85 |             normal = ((b - a).cross(c - a)).normalized()
86 | 
87 |             if abs(normal[2]) < 1e-10:
88 |                 continue
89 | 
90 |             a_proj = ti.Vector([a[0], a[1]])
91 |             b_proj = ti.Vector([b[0], b[1]])
92 |             c_proj = ti.Vector([c[0], c[1]])
93 | 
94 |             for p in range(p_min, p_max):
95 |                 for q in range(q_min, q_max):
96 |                     pos2d = ti.Vector([(p + 0.5) * self.dx,
97 |                                        (q + 0.5) * self.dx])
98 |                     if inside_ccw(pos2d, a_proj, b_proj, c_proj) or inside_ccw(
99 |                             pos2d, a_proj, c_proj, b_proj):
100 |                         base_voxel = ti.Vector([pos2d[0], pos2d[1], 0])
101 |                         height = int(-normal.dot(base_voxel - a) / normal[2] *
102 |                                      self.inv_dx + 0.5)
103 |                         height = min(height, self.res[2] - self.padding)  # clamp along z, the axis fill() iterates
104 |                         inc = 0
105 |                         if normal[2] > 0:
106 |                             inc = 1
107 |                         else:
108 |                             inc = -1
109 |                         self.fill(p, q, height, inc)
110 | 
111 |     def voxelize(self, triangles):
112 |         assert isinstance(triangles, np.ndarray)
113 |         triangles = triangles.astype(np.float64)
114 |         assert triangles.dtype in [np.float32, np.float64]
115 |         if self.precision is ti.f32:
116 |             triangles = triangles.astype(np.float32)
117 |         elif self.precision is ti.f64:
118 |             triangles = triangles.astype(np.float64)
119 |         else:
120 |             assert False
121 |         assert len(triangles.shape) == 2
122 |         assert triangles.shape[1] == 9
123 | 
124 | 
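        # Each row of `triangles` is one triangle flattened as
        # [ax, ay, az, bx, by, bz, cx, cy, cz], matching the column
        # indexing in voxelize_triangles above (0-2, 3-5, 6-8).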
self.block.deactivate_all() 125 | num_triangles = len(triangles) 126 | self.voxelize_triangles(num_triangles, triangles) 127 | 128 | 129 | if __name__ == '__main__': 130 | n = 256 131 | vox = Voxelizer((n, n, n), 1.0 / n) 132 | # triangle = np.array([[0.1, 0.1, 0.1, 0.6, 0.2, 0.1, 0.5, 0.7, 133 | # 0.7]]).astype(np.float32) 134 | triangles = np.fromfile('triangles.npy', dtype=np.float32) 135 | triangles = np.reshape(triangles, (len(triangles) // 9, 9)) * 0.306 + 0.501 136 | offsets = [0.0, 0.0, 0.0] 137 | for i in range(9): 138 | triangles[:, i] += offsets[i % 3] 139 | print(triangles.shape) 140 | print(triangles.max()) 141 | print(triangles.min()) 142 | 143 | vox.voxelize(triangles) 144 | 145 | voxels = vox.voxels.to_numpy().astype(np.float32) 146 | 147 | import os 148 | os.makedirs('outputs', exist_ok=True) 149 | gui = ti.GUI('cross section', (n, n)) 150 | for i in range(n): 151 | gui.set_image(voxels[:, :, i]) 152 | gui.show(f'outputs/{i:04d}.png') 153 | -------------------------------------------------------------------------------- /feature_splatting/utils/segment_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation as R 3 | from .math_utils import point_to_plane_distance, vector_angle 4 | 5 | def cluster_instance(all_xyz_n3, selected_obj_idx=None, min_sample=20, eps=0.1): 6 | """ 7 | Cluster points into instances using DBSCAN. 8 | Return the indices of the most populated cluster. 9 | """ 10 | from sklearn.cluster import DBSCAN 11 | if selected_obj_idx is None: 12 | selected_obj_idx = np.ones(all_xyz_n3.shape[0], dtype=bool) 13 | dbscan = DBSCAN(eps=eps, min_samples=min_sample).fit(all_xyz_n3[selected_obj_idx]) 14 | clustered_labels = dbscan.labels_ 15 | 16 | # Find the most populated cluster 17 | label_idx_list, label_count_list = np.unique(clustered_labels, return_counts=True) 18 | # Filter out -1 19 | label_count_list = label_count_list[label_idx_list != -1] 20 | label_idx_list = label_idx_list[label_idx_list != -1] 21 | max_count_label = label_idx_list[np.argmax(label_count_list)] 22 | 23 | clustered_idx = np.zeros_like(selected_obj_idx, dtype=bool) 24 | # Double assignment to make sure indices go into the right place 25 | arr = clustered_idx[selected_obj_idx] 26 | arr[clustered_labels == max_count_label] = True 27 | clustered_idx[selected_obj_idx] = arr 28 | return clustered_idx 29 | 30 | def estimate_ground(ground_pts, distance_threshold=0.005, rotation_flip=False): 31 | import open3d as o3d 32 | point_cloud = ground_pts.copy() 33 | 34 | pcd = o3d.geometry.PointCloud() 35 | pcd.points = o3d.utility.Vector3dVector(point_cloud) 36 | 37 | plane_model, inliers = pcd.segment_plane(distance_threshold=distance_threshold, 38 | ransac_n=3, 39 | num_iterations=2000) 40 | # [a, b, c, d] = plane_model 41 | # print(f"Plane equation: {a:.2f}x + {b:.2f}y + {c:.2f}z + {d:.2f} = 0") 42 | 43 | origin_plane_distance = point_to_plane_distance((0, 0, 0), plane_model) 44 | 45 | # Calculate rotation angle between plane normal & z-axis 46 | plane_normal = tuple(plane_model[:3]) 47 | plane_normal = np.array(plane_normal) / np.linalg.norm(plane_normal) 48 | 49 | # Taichi uses y-axis as up-axis (OpenGL convention) 50 | if rotation_flip: 51 | # Sometimes the estimated plane normal is flipped 52 | y_axis = np.array((0, -1, 0)) 53 | else: 54 | y_axis = np.array((0, 1, 0)) # Taichi uses y-axis as up-axis 55 | 56 | rotation_angle = vector_angle(plane_normal, y_axis) 57 | 58 | # Calculate rotation axis 
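    # Caveat: if the estimated plane normal is already (anti)parallel to the y-axis,
    # this cross product is near zero and the normalization below divides by ~0;
    # callers running estimate_ground on an already-level scene may want to
    # special-case that configuration.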
59 |     rotation_axis = np.cross(plane_normal, y_axis)
60 |     rotation_axis = rotation_axis / np.linalg.norm(rotation_axis)
61 | 
62 |     # Generate axis-angle representation
63 |     axis_angle = tuple([x * rotation_angle for x in rotation_axis])
64 | 
65 |     # Rotate point cloud
66 |     rotation_object = R.from_rotvec(axis_angle)
67 |     rotation_matrix = rotation_object.as_matrix()
68 | 
69 |     return (rotation_matrix, np.array((0, origin_plane_distance, 0)), inliers)
70 | 
71 | def get_ground_bbox_min_max(all_xyz_n3, selected_obj_idx, ground_R, ground_T):
72 |     """
73 |     Compute the min/max corners of the selected points' axis-aligned bounding box in ground-aligned coordinates.
74 |     """
75 |     particles = all_xyz_n3 @ ground_R.T
76 |     particles += ground_T
77 |     xyz_min = np.min(particles[selected_obj_idx], axis=0)
78 |     xyz_max = np.max(particles[selected_obj_idx], axis=0)
79 |     return xyz_min, xyz_max
80 | 
--------------------------------------------------------------------------------
/feature_splatting/utils/viewer_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from functools import cached_property
3 | from typing import Dict, List, Literal, Optional, Tuple, Type, Union
4 | from torchtyping import TensorType
5 | from nerfstudio.utils.rich_utils import CONSOLE
6 | 
7 | import torch
8 | 
9 | class ViewerUtils:
10 |     def __init__(self, text_encoding_func, softmax_temp: float = 0.05, canonical_words: str = 'object'):
11 |         self.text_encoding_func = text_encoding_func
12 |         self.text_embedding_dict = {}
13 |         self.update_text_embedding('canonical', canonical_words)
14 |         self.softmax_temp = softmax_temp
15 |         self.pca_proj = None
16 | 
17 |     @torch.no_grad()
18 |     def update_text_embedding(self, name_key: str, raw_text: str):
19 |         """Compute CLIP embeddings based on queries and update state"""
20 |         texts = [x.strip() for x in raw_text.split(",") if x.strip()]
21 |         if not texts:
22 |             self.text_embedding_dict[name_key] = (texts, None)
23 |         else:
24 |             # Embed text queries
25 |             embed = self.text_encoding_func(texts)
26 |             self.text_embedding_dict[name_key] = (texts, embed)
27 | 
28 |     def is_embed_valid(self, name_key: str) -> bool:
29 |         return name_key in self.text_embedding_dict and self.text_embedding_dict[name_key][1] is not None
30 | 
31 |     def get_text_embed(self, name_key: str) -> Optional[torch.Tensor]:
32 |         return self.text_embedding_dict[name_key][1]
33 | 
34 |     def get_embed_shape(self, name_key: str) -> Optional[Tuple[int, ...]]:
35 |         embed = self.get_text_embed(name_key)
36 |         if embed is not None:
37 |             return embed.shape
38 |         return None
39 | 
40 |     def update_softmax_temp(self, temp: float):
41 |         self.softmax_temp = temp
42 | 
43 |     def reset_pca_proj(self):
44 |         self.pca_proj = None
45 |         CONSOLE.print("Reset PCA projection")
46 | 
47 | def apply_pca_colormap_return_proj(
48 |     image: TensorType["bs":..., "d"],
49 |     proj_V: Optional[TensorType] = None,
50 |     low_rank_min: Optional[TensorType] = None,
51 |     low_rank_max: Optional[TensorType] = None,
52 |     niter: int = 5,
53 | ) -> Tuple[TensorType["bs":..., "rgb":3], TensorType, TensorType, TensorType]:
54 |     """Convert a multichannel image to color using PCA.
55 | 
56 |     Args:
57 |         image: Multichannel image.
58 |         proj_V: Projection matrix to use. If None, use torch low rank PCA.
59 | 
60 |     Returns:
61 |         Colored PCA image of the multichannel input image, together with the projection matrix and the min/max bounds used for normalization (so subsequent frames can reuse them).
62 | """ 63 | image_flat = image.reshape(-1, image.shape[-1]) 64 | 65 | # Modified from https://github.com/pfnet-research/distilled-feature-fields/blob/master/train.py 66 | if proj_V is None: 67 | mean = image_flat.mean(0) 68 | with torch.no_grad(): 69 | U, S, V = torch.pca_lowrank(image_flat - mean, niter=niter) 70 | proj_V = V[:, :3] 71 | 72 | low_rank = image_flat @ proj_V 73 | if low_rank_min is None: 74 | low_rank_min = torch.quantile(low_rank, 0.01, dim=0) 75 | if low_rank_max is None: 76 | low_rank_max = torch.quantile(low_rank, 0.99, dim=0) 77 | 78 | low_rank = (low_rank - low_rank_min) / (low_rank_max - low_rank_min) 79 | low_rank = torch.clamp(low_rank, 0, 1) 80 | 81 | colored_image = low_rank.reshape(image.shape[:-1] + (3,)) 82 | return colored_image, proj_V, low_rank_min, low_rank_max 83 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "feature_splatting" 3 | version = "0.0.3" 4 | description = "Feature Splatting" 5 | readme = "README.md" 6 | requires-python = ">=3.8" 7 | license = { file = "LICENSE" } 8 | keywords = ["gaussian splatting", "feature fields", "physics simulation", "3D segmentation"] 9 | authors = [ 10 | { name = "Ri-Zhao Qiu", email = "riqiu@ucsd.edu" }, 11 | ] 12 | 13 | dependencies = [ 14 | "einops", 15 | "ftfy", 16 | "gdown", 17 | "matplotlib", 18 | "nerfstudio", 19 | "numpy", 20 | "gsplat>=1.0.0", 21 | "pillow", 22 | "regex", 23 | "torchtyping", 24 | "tqdm", 25 | "segment-anything @ git+https://github.com/facebookresearch/segment-anything.git", 26 | "MobileSAMV2 @ git+https://github.com/RogerQi/MobileSAMV2.git", 27 | "maskclip_onnx @ git+https://github.com/RogerQi/maskclip_onnx.git", 28 | "taichi", 29 | ] 30 | 31 | [project.optional-dependencies] 32 | dev = [ 33 | "black", 34 | "isort", 35 | ] 36 | 37 | [tool.black] 38 | line-length = 120 39 | 40 | [project.urls] 41 | "Homepage" = "https://feature-splatting.github.io" 42 | "Source" = "https://github.com/vuer-ai/feature-splatting" 43 | 44 | [project.entry-points."nerfstudio.method_configs"] 45 | feature-splatting = "feature_splatting.feature_splatting_config:feature_splatting_method" 46 | 47 | [tool.setuptools.packages] 48 | find = { include = ["feature_splatting", "feature_splatting.*"] } 49 | 50 | [build-system] 51 | requires = ["setuptools>=61.0"] 52 | build-backend = "setuptools.build_meta" 53 | --------------------------------------------------------------------------------
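The `[project.entry-points."nerfstudio.method_configs"]` table above is how nerfstudio discovers the method: once this package is installed, nerfstudio enumerates that entry-point group and exposes `feature-splatting` as a trainable method (typically invoked as `ns-train feature-splatting --data <path>`). Below is a small standalone sketch of that discovery mechanism using only the standard library; it is illustrative, not nerfstudio's actual loader code.

import importlib.metadata as md

# Collect entry points registered under nerfstudio's method-config group.
eps = md.entry_points()
if hasattr(eps, "select"):  # Python 3.10+
    methods = eps.select(group="nerfstudio.method_configs")
else:                       # Python 3.8/3.9: entry_points() returns a dict of groups
    methods = eps.get("nerfstudio.method_configs", [])

for ep in methods:
    # e.g. feature-splatting -> feature_splatting.feature_splatting_config:feature_splatting_method
    print(ep.name, "->", ep.value)
    # ep.load() would import and return the registered method config object.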