├── .gitignore ├── LICENSE ├── NTC.py ├── README.md ├── assets ├── setting.png └── teaser.png ├── main.py ├── renderer_cuda.py ├── renderer_ogl.py ├── requirements.txt ├── shaders ├── gau_frag.glsl └── gau_vert.glsl ├── util.py ├── util_3dgstream.py └── util_gau.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.png 3 | !teaser.png 4 | !setting.png 5 | *.npz 6 | *.ini 7 | *.dump 8 | _example.py 9 | text.py 10 | toy_cuda_gs.py 11 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Li Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /NTC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | class NeuralTransformationCache(torch.nn.Module): 3 | def __init__(self, model, xyz_bound_min, xyz_bound_max): 4 | super(NeuralTransformationCache, self).__init__() 5 | self.model = model 6 | self.register_buffer('xyz_bound_min',xyz_bound_min) 7 | self.register_buffer('xyz_bound_max',xyz_bound_max) 8 | 9 | def dump(self, path): 10 | torch.save(self.state_dict(),path) 11 | 12 | def get_contracted_xyz(self, xyz): 13 | with torch.no_grad(): 14 | contracted_xyz=(xyz-self.xyz_bound_min)/(self.xyz_bound_max-self.xyz_bound_min) 15 | return contracted_xyz 16 | 17 | def forward(self, xyz:torch.Tensor): 18 | contracted_xyz=self.get_contracted_xyz(xyz) # Shape: [N, 3] 19 | 20 | mask = (contracted_xyz >= 0) & (contracted_xyz <= 1) 21 | mask = mask.all(dim=1) 22 | 23 | res_cache_inputs=torch.cat([contracted_xyz[mask]],dim=-1) 24 | resi=self.model(res_cache_inputs) 25 | 26 | masked_d_xyz=resi[:,:3] 27 | masked_d_rot=resi[:,3:7] 28 | # masked_d_opacity=resi[:,7:None] 29 | 30 | d_xyz = torch.full((xyz.shape[0], 3), 0.0, dtype=torch.half, device="cuda") 31 | d_rot = torch.full((xyz.shape[0], 4), 0.0, dtype=torch.half, device="cuda") 32 | d_rot[:, 0] = 1.0 33 | # d_opacity = self._origin_d_opacity.clone() 34 | 35 | d_xyz[mask] = masked_d_xyz 36 | d_rot[mask] = masked_d_rot 37 | 38 | return mask, d_xyz, d_rot 39 | 40 | 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tiny 3DGStreamViewer 2 | This is a fork of [GaussianSplattingViewer](https://github.com/limacv/GaussianSplattingViewer), as a simple viewer for our CVPR 2024 paper [3DGStream](https://sjojok.github.io/3dgstream/). 
3 | 4 | Note that this is not the renderer we used to evaluate the rendering performance of 3DGStream in our paper "3DGStream: On-the-Fly Training of 3D Gaussians for Efficient Streaming of Photo-Realistic Free-Viewpoint Videos", but it is still efficient enough to render the FVVs in real time. 5 | 6 | ## Available FVVs: 7 | 8 | [Flame Steak](https://drive.google.com/file/d/1AXDqSzSaT_uNu_DhKeSmZmrBAfuOhWYY/view?usp=drive_link) 9 | 10 | ## Usage 11 | Install the dependencies: 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | Install [PyTorch](https://pytorch.org/) with CUDA support. 16 | 17 | Install [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization) following the instructions [here](https://github.com/graphdeco-inria/gaussian-splatting). 18 | 19 | Install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn). 20 | 21 | Install the following package: 22 | ``` 23 | pip install cuda-python 24 | ``` 25 | 26 | Launch the viewer: 27 | ``` 28 | python main.py 29 | ``` 30 | 31 | ### To view Free-Viewpoint Videos of 3DGStream 32 | 33 | 1. Unzip the zip file anywhere you like. 34 | 35 | ![image](https://github.com/SJoJoK/3DGStreamViewer/assets/50450335/011675a5-d8d6-410e-ab82-5572e71fe6bd) 36 | 37 | 2. Launch the viewer: 38 | 39 | ``` 40 | python main.py 41 | ``` 42 | 43 | 3. Click `load ply` and open `init_3dgs.ply`. 44 | 45 | ![image](https://github.com/SJoJoK/3DGStreamViewer/assets/50450335/c5879abe-7752-4229-ae09-d71992ab3114) 46 | 47 | 4. Move the camera to a suitable position. 48 | 49 | ![image](https://github.com/SJoJoK/3DGStreamViewer/assets/50450335/3e4c437a-ba1e-40f8-b022-3e88090b2a97) 50 | 51 | 5. Click `load FVV` and choose the directory where you unzipped the FVV. 52 | 53 | 6. Click `Step` to advance to the next frame, `Play` or `Pause` to play or pause the FVV, and `Reset` to return to frame 0. 54 | 55 | Happy Hacking!
56 | 57 | ## TO-DO List (Not guaranteed) 58 | 59 | - [ ] View the results of Stage 1 and/or Stage 2 60 | - [ ] Support the OpenGL backend 61 | - [ ] Align the functionality with the original repo. 62 | 63 | ## Code-paper discrepancies 64 | 65 | 1. As stated in the paper, we discarded the SH rotation because it is costly and unnecessary. 66 | 2. The renderer we used to evaluate rendering performance is the official [SIBR Viewer](https://gitlab.inria.fr/sibr/sibr_core), which has a highly optimized OpenGL backend. However, we believe that an open-source viewer based on the CUDA rasterizer is more configurable and more helpful for researchers. 67 | 68 | ## Contributing 69 | 70 | This project is a tiny viewer designed for simplicity and ease of use. We welcome contributions that aim to improve performance or extend functionality. If you have ideas or improvements, please feel free to submit a pull request. 71 | 72 | ## Acknowledgements 73 | 74 | We would like to express our gratitude to the original repository [GaussianSplattingViewer](https://github.com/limacv/GaussianSplattingViewer) for providing the foundation upon which this work is built.
75 | 76 | -------------------------------------------------------------------------------- /assets/setting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStreamViewer/bb6879ca6c198eedd3b9331c1ba2e068aac76a0d/assets/setting.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJoJoK/3DGStreamViewer/bb6879ca6c198eedd3b9331c1ba2e068aac76a0d/assets/teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import glfw 2 | import OpenGL.GL as gl 3 | from imgui.integrations.glfw import GlfwRenderer 4 | import imgui 5 | import numpy as np 6 | import util 7 | import imageio 8 | import util_gau 9 | import util_3dgstream 10 | import time 11 | import tkinter as tk 12 | from tkinter import filedialog 13 | import os 14 | import sys 15 | import argparse 16 | from renderer_ogl import OpenGLRenderer, GaussianRenderBase 17 | 18 | 19 | # Add the directory containing main.py to the Python path 20 | dir_path = os.path.dirname(os.path.realpath(__file__)) 21 | sys.path.append(dir_path) 22 | 23 | # Change the current working directory to the script's directory 24 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 25 | 26 | 27 | g_camera = util.Camera(720, 1280) # NOTE(review): presumably (height, width), matching update_resolution(height, width) in the resize callback; confirm against util.Camera 28 | BACKEND_OGL=0 # index of the OpenGL renderer in g_renderer_list 29 | BACKEND_CUDA=1 # index of the CUDA renderer; main() appends it only if renderer_cuda imports successfully 30 | g_renderer_list = [ 31 | None, # ogl 32 | ] 33 | g_renderer_idx = BACKEND_OGL # main() reassigns this after creating the renderers 34 | g_renderer = g_renderer_list[g_renderer_idx] # None until main() instantiates the renderers 35 | g_scale_modifier = 1. # value passed to set_scale_modifier() on the active renderer
36 | g_auto_sort = False 37 | g_show_control_win = True 38 | g_show_help_win = True 39 | g_show_camera_win = False 40 | g_render_mode_tables = ["Gaussian Ball", "Flat Ball", "Billboard", "Depth", "SH:0", "SH:0~1", "SH:0~2", "SH:0~3 (default)"] 41 | g_render_mode = 7 42 | g_FVV_path="" 43 | VIDEO_FPS = 30.0 44 | VIDEO_INTERVAL = 1.0 / VIDEO_FPS 45 | 46 | g_last_frame_time = 0.0 47 | g_timestep = 0 48 | g_paused = True 49 | g_reset = False 50 | g_total_frame = 300 51 | def impl_glfw_init(): 52 | window_name = "Tiny 3DGStream Viewer" 53 | 54 | if not glfw.init(): 55 | print("Could not initialize OpenGL context") 56 | exit(1) 57 | 58 | glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 4) 59 | glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3) 60 | glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) 61 | # glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, gl.GL_TRUE) 62 | 63 | # Create a windowed mode window and its OpenGL context 64 | global window 65 | window = glfw.create_window( 66 | g_camera.w, g_camera.h, window_name, None, None 67 | ) 68 | glfw.make_context_current(window) 69 | glfw.swap_interval(0) 70 | # glfw.set_input_mode(window, glfw.CURSOR, glfw.CURSOR_NORMAL); 71 | if not window: 72 | glfw.terminate() 73 | print("Could not initialize Window") 74 | exit(1) 75 | 76 | return window 77 | 78 | def cursor_pos_callback(window, xpos, ypos): 79 | if imgui.get_io().want_capture_mouse: 80 | g_camera.is_leftmouse_pressed = False 81 | g_camera.is_rightmouse_pressed = False 82 | g_camera.process_mouse(xpos, ypos) 83 | 84 | def mouse_button_callback(window, button, action, mod): 85 | if imgui.get_io().want_capture_mouse: 86 | return 87 | pressed = action == glfw.PRESS 88 | g_camera.is_leftmouse_pressed = (button == glfw.MOUSE_BUTTON_LEFT and pressed) 89 | g_camera.is_rightmouse_pressed = (button == glfw.MOUSE_BUTTON_RIGHT and pressed) 90 | 91 | def wheel_callback(window, dx, dy): 92 | g_camera.process_wheel(dx, dy) 93 | 94 | def key_callback(window, key, scancode, action, 
mods): 95 | if action == glfw.REPEAT or action == glfw.PRESS: 96 | if key == glfw.KEY_Q: 97 | g_camera.process_roll_key(1) 98 | elif key == glfw.KEY_E: 99 | g_camera.process_roll_key(-1) 100 | 101 | def update_camera_pose_lazy(): 102 | if g_camera.is_pose_dirty: 103 | g_renderer.update_camera_pose(g_camera) 104 | g_camera.is_pose_dirty = False 105 | 106 | def update_camera_intrin_lazy(): 107 | if g_camera.is_intrin_dirty: 108 | g_renderer.update_camera_intrin(g_camera) 109 | g_camera.is_intrin_dirty = False 110 | 111 | def update_activated_renderer_state(gaus: util_gau.GaussianData): 112 | g_renderer.update_gaussian_data(gaus) 113 | g_renderer.sort_and_update(g_camera) 114 | g_renderer.set_scale_modifier(g_scale_modifier) 115 | g_renderer.set_render_mod(g_render_mode - 3) 116 | g_renderer.update_camera_pose(g_camera) 117 | g_renderer.update_camera_intrin(g_camera) 118 | g_renderer.set_render_reso(g_camera.w, g_camera.h) 119 | 120 | def window_resize_callback(window, width, height): 121 | gl.glViewport(0, 0, width, height) 122 | g_camera.update_resolution(height, width) 123 | g_renderer.set_render_reso(width, height) 124 | 125 | def main(): 126 | global g_camera, g_renderer, g_renderer_list, g_renderer_idx, g_scale_modifier, g_auto_sort, \ 127 | g_show_control_win, g_show_help_win, g_show_camera_win, \ 128 | g_render_mode, g_render_mode_tables, \ 129 | g_FVV_path, g_paused, g_reset, g_timestep, g_last_frame_time, g_total_frame 130 | 131 | imgui.create_context() 132 | if args.hidpi: 133 | imgui.get_io().font_global_scale = 1.5 134 | window = impl_glfw_init() 135 | impl = GlfwRenderer(window) 136 | root = tk.Tk() # used for file dialog 137 | root.withdraw() 138 | 139 | glfw.set_cursor_pos_callback(window, cursor_pos_callback) 140 | glfw.set_mouse_button_callback(window, mouse_button_callback) 141 | glfw.set_scroll_callback(window, wheel_callback) 142 | glfw.set_key_callback(window, key_callback) 143 | 144 | glfw.set_window_size_callback(window, window_resize_callback) 
145 | 146 | # init renderer 147 | g_renderer_list[BACKEND_OGL] = OpenGLRenderer(g_camera.w, g_camera.h) 148 | try: 149 | from renderer_cuda import CUDARenderer 150 | g_renderer_list += [CUDARenderer(g_camera.w, g_camera.h)] 151 | except ImportError: 152 | pass 153 | 154 | g_renderer_idx = BACKEND_CUDA 155 | g_renderer = g_renderer_list[g_renderer_idx] 156 | 157 | # gaussian data 158 | gaussians = util_gau.naive_gaussian() 159 | update_activated_renderer_state(gaussians) 160 | g_last_frame_time=time.time() 161 | 162 | #debug only 163 | # gaussians = util_gau.load_ply("F:\\3dgstream\\flame_steak\\init_3dgs.ply") 164 | # g_renderer.update_gaussian_data(gaussians) 165 | # g_renderer.sort_and_update(g_camera) 166 | # g_renderer.NTCs = util_3dgstream.load_NTCs("F:\\3dgstream\\flame_steak", g_renderer.gaussians, 300) 167 | # g_renderer.additional_3dgs = util_3dgstream.load_Additions("F:\\3dgstream\\flame_steak", 300) 168 | 169 | # settings 170 | while not glfw.window_should_close(window): 171 | glfw.poll_events() 172 | impl.process_inputs() 173 | imgui.new_frame() 174 | 175 | gl.glClearColor(0, 0, 0, 1.0) 176 | gl.glClear(gl.GL_COLOR_BUFFER_BIT) 177 | 178 | update_camera_pose_lazy() 179 | update_camera_intrin_lazy() 180 | current_time=time.time() 181 | if current_time - g_last_frame_time >= VIDEO_INTERVAL and not g_paused and g_timestep < g_total_frame-1: 182 | g_timestep+=1 183 | g_last_frame_time = current_time 184 | if g_reset: 185 | g_renderer.fvv_reset() 186 | g_reset = False 187 | g_last_frame_time = time.time() 188 | g_renderer.draw(g_timestep) 189 | 190 | # imgui ui 191 | if imgui.begin_main_menu_bar(): 192 | if imgui.begin_menu("Window", True): 193 | clicked, g_show_control_win = imgui.menu_item( 194 | "Show Control", None, g_show_control_win 195 | ) 196 | clicked, g_show_help_win = imgui.menu_item( 197 | "Show Help", None, g_show_help_win 198 | ) 199 | clicked, g_show_camera_win = imgui.menu_item( 200 | "Show Camera Control", None, g_show_camera_win 201 | ) 202 
| imgui.end_menu() 203 | imgui.end_main_menu_bar() 204 | 205 | if g_show_control_win: 206 | if imgui.begin("Control", True): 207 | # rendering backend 208 | changed, g_renderer_idx = imgui.combo("backend", g_renderer_idx, ["ogl", "cuda"][:len(g_renderer_list)]) 209 | if changed: 210 | g_renderer = g_renderer_list[g_renderer_idx] 211 | update_activated_renderer_state(gaussians) 212 | 213 | imgui.text(f"# of Gaus = {len(gaussians)}") 214 | 215 | imgui.text(f"Render FPS = {imgui.get_io().framerate:.1f}") 216 | imgui.text(f"Video FPS = {VIDEO_FPS:.1f}") 217 | imgui.text(f"FVV Dir:{g_FVV_path}") 218 | imgui.text(f"Frame {g_timestep}") 219 | imgui.text(f"#Frames: ") 220 | imgui.same_line() 221 | total_frame_changed, g_total_frame = imgui.slider_int( 222 | "frames", g_total_frame, 1, 300 223 | ) 224 | if imgui.button("Pause"): 225 | g_paused = True 226 | g_last_frame_time=time.time() 227 | 228 | imgui.same_line() 229 | 230 | if imgui.button("Play"): 231 | g_paused = False 232 | g_last_frame_time=time.time() 233 | 234 | imgui.same_line() 235 | 236 | if imgui.button("Reset"): 237 | g_paused = True 238 | g_reset = True 239 | g_timestep=0 240 | g_last_frame_time=time.time() 241 | 242 | imgui.same_line() 243 | 244 | if imgui.button("Step"): 245 | g_timestep+=1 246 | g_last_frame_time=time.time() 247 | 248 | if imgui.button(label='load ply'): 249 | file_path = filedialog.askopenfilename(title="load ply", 250 | initialdir="C:\\Users", 251 | filetypes=[('ply file', '.ply')] 252 | ) 253 | if file_path: 254 | try: 255 | gaussians = util_gau.load_ply(file_path) 256 | g_renderer.update_gaussian_data(gaussians) 257 | g_renderer.sort_and_update(g_camera) 258 | except RuntimeError as e: 259 | pass 260 | 261 | imgui.same_line() 262 | 263 | if imgui.button(label='save ply'): 264 | file_path = filedialog.asksaveasfilename(title="save ply", 265 | initialdir="C:\\Users\\", 266 | defaultextension=".txt", 267 | filetypes=[('ply file', '.ply')] 268 | ) 269 | if file_path: 270 | try: 271 | 
util_3dgstream.save_gau_cuda(g_renderer.gaussians, file_path) 272 | except RuntimeError as e: 273 | pass 274 | 275 | imgui.same_line() 276 | 277 | if imgui.button(label='load FVV'): 278 | dir_path = filedialog.askdirectory(title="load FVV", 279 | initialdir="C:\\Users" 280 | ) 281 | if dir_path: 282 | try: 283 | g_FVV_path = dir_path 284 | g_renderer.NTCs = util_3dgstream.load_NTCs(g_FVV_path, g_renderer.gaussians, g_total_frame) 285 | g_renderer.additional_3dgs = util_3dgstream.load_Additions(g_FVV_path, g_total_frame) 286 | except RuntimeError as e: 287 | pass 288 | # camera fov 289 | changed, g_camera.fovy = imgui.slider_float( 290 | "fov", g_camera.fovy, 0.001, np.pi - 0.001, "fov = %.3f" 291 | ) 292 | g_camera.is_intrin_dirty = changed 293 | update_camera_intrin_lazy() 294 | 295 | # scale modifier 296 | changed, g_scale_modifier = imgui.slider_float( 297 | "", g_scale_modifier, 0.1, 10, "scale modifier = %.3f" 298 | ) 299 | imgui.same_line() 300 | if imgui.button(label="reset"): 301 | g_scale_modifier = 1. 
302 | changed = True 303 | 304 | if changed: 305 | g_renderer.set_scale_modifier(g_scale_modifier) 306 | 307 | # render mode 308 | changed, g_render_mode = imgui.combo("shading", g_render_mode, g_render_mode_tables) 309 | if changed: 310 | g_renderer.set_render_mod(g_render_mode - 4) 311 | 312 | # sort button 313 | if imgui.button(label='sort Gaussians'): 314 | g_renderer.sort_and_update(g_camera) 315 | imgui.same_line() 316 | changed, g_auto_sort = imgui.checkbox( 317 | "auto sort", g_auto_sort, 318 | ) 319 | if g_auto_sort: 320 | g_renderer.sort_and_update(g_camera) 321 | 322 | if imgui.button(label='save image'): 323 | width, height = glfw.get_framebuffer_size(window) 324 | nrChannels = 3; 325 | stride = nrChannels * width; 326 | stride += (4 - stride % 4) if stride % 4 else 0 327 | gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 4) 328 | gl.glReadBuffer(gl.GL_FRONT) 329 | bufferdata = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) 330 | img = np.frombuffer(bufferdata, np.uint8, -1).reshape(height, width, 3) 331 | imageio.imwrite("save.png", img[::-1]) 332 | # save intermediate information 333 | # np.savez( 334 | # "save.npz", 335 | # gau_xyz=gaussians.xyz, 336 | # gau_s=gaussians.scale, 337 | # gau_rot=gaussians.rot, 338 | # gau_c=gaussians.sh, 339 | # gau_a=gaussians.opacity, 340 | # viewmat=g_camera.get_view_matrix(), 341 | # projmat=g_camera.get_project_matrix(), 342 | # hfovxyfocal=g_camera.get_htanfovxy_focal() 343 | # ) 344 | # Add buttons directly in the main menu bar for control actions 345 | 346 | imgui.end() 347 | 348 | if g_show_camera_win: 349 | if imgui.button(label='rot 180'): 350 | g_camera.flip_ground() 351 | 352 | changed, g_camera.target_dist = imgui.slider_float( 353 | "t", g_camera.target_dist, 1., 8., "target dist = %.3f" 354 | ) 355 | if changed: 356 | g_camera.update_target_distance() 357 | 358 | changed, g_camera.rot_sensitivity = imgui.slider_float( 359 | "r", g_camera.rot_sensitivity, 0.002, 0.1, "rotate speed = %.3f" 360 
| ) 361 | imgui.same_line() 362 | if imgui.button(label="reset r"): 363 | g_camera.rot_sensitivity = 0.02 364 | 365 | changed, g_camera.trans_sensitivity = imgui.slider_float( 366 | "m", g_camera.trans_sensitivity, 0.001, 0.03, "move speed = %.3f" 367 | ) 368 | imgui.same_line() 369 | if imgui.button(label="reset m"): 370 | g_camera.trans_sensitivity = 0.01 371 | 372 | changed, g_camera.zoom_sensitivity = imgui.slider_float( 373 | "z", g_camera.zoom_sensitivity, 0.001, 0.05, "zoom speed = %.3f" 374 | ) 375 | imgui.same_line() 376 | if imgui.button(label="reset z"): 377 | g_camera.zoom_sensitivity = 0.01 378 | 379 | changed, g_camera.roll_sensitivity = imgui.slider_float( 380 | "ro", g_camera.roll_sensitivity, 0.003, 0.1, "roll speed = %.3f" 381 | ) 382 | imgui.same_line() 383 | if imgui.button(label="reset ro"): 384 | g_camera.roll_sensitivity = 0.03 385 | 386 | if g_show_help_win: 387 | imgui.begin("Help", True) 388 | imgui.text("Open Gaussian Splatting PLY file \n by click 'open ply' button") 389 | imgui.text("Use left click & move to rotate camera") 390 | imgui.text("Use right click & move to translate camera") 391 | imgui.text("Press Q/E to roll camera") 392 | imgui.text("Use scroll to zoom in/out") 393 | imgui.text("Use control panel to change setting") 394 | imgui.end() 395 | 396 | imgui.render() 397 | impl.render(imgui.get_draw_data()) 398 | glfw.swap_buffers(window) 399 | 400 | impl.shutdown() 401 | glfw.terminate() 402 | 403 | 404 | if __name__ == "__main__": 405 | global args 406 | parser = argparse.ArgumentParser(description="Tiny 3DGStream Viewer.") 407 | parser.add_argument("--hidpi", action="store_true", help="Enable HiDPI scaling for the interface.") 408 | args = parser.parse_args() 409 | 410 | main() 411 | -------------------------------------------------------------------------------- /renderer_cuda.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Part of the code (CUDA and OpenGL memory transfer) is 
derived from https://github.com/jbaron34/torchwindow/tree/master 3 | ''' 4 | from OpenGL import GL as gl 5 | import OpenGL.GL.shaders as shaders 6 | import util 7 | import util_gau 8 | import numpy as np 9 | import torch 10 | from renderer_ogl import GaussianRenderBase 11 | from dataclasses import dataclass 12 | from cuda import cudart as cu 13 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 14 | 15 | 16 | VERTEX_SHADER_SOURCE = """ 17 | #version 450 18 | 19 | smooth out vec4 fragColor; 20 | smooth out vec2 texcoords; 21 | 22 | vec4 positions[3] = vec4[3]( 23 | vec4(-1.0, 1.0, 0.0, 1.0), 24 | vec4(3.0, 1.0, 0.0, 1.0), 25 | vec4(-1.0, -3.0, 0.0, 1.0) 26 | ); 27 | 28 | vec2 texpos[3] = vec2[3]( 29 | vec2(0, 0), 30 | vec2(2, 0), 31 | vec2(0, 2) 32 | ); 33 | 34 | void main() { 35 | gl_Position = positions[gl_VertexID]; 36 | texcoords = texpos[gl_VertexID]; 37 | } 38 | """ 39 | 40 | FRAGMENT_SHADER_SOURCE = """ 41 | #version 330 42 | 43 | smooth in vec2 texcoords; 44 | 45 | out vec4 outputColour; 46 | 47 | uniform sampler2D texSampler; 48 | 49 | void main() 50 | { 51 | outputColour = texture(texSampler, texcoords); 52 | } 53 | """ 54 | 55 | def quaternion_multiply(a, b): 56 | a_norm=torch.nn.functional.normalize(a) 57 | b_norm=torch.nn.functional.normalize(b) 58 | w1, x1, y1, z1 = a_norm[:, 0], a_norm[:, 1], a_norm[:, 2], a_norm[:, 3] 59 | w2, x2, y2, z2 = b_norm[:, 0], b_norm[:, 1], b_norm[:, 2], b_norm[:, 3] 60 | 61 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 62 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 63 | y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 64 | z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 65 | 66 | return torch.stack([w, x, y, z], dim=1) 67 | 68 | @dataclass 69 | class GaussianDataCUDA: 70 | xyz: torch.Tensor 71 | rot: torch.Tensor 72 | scale: torch.Tensor 73 | opacity: torch.Tensor 74 | sh: torch.Tensor 75 | 76 | def __len__(self): 77 | return len(self.xyz) 78 | 79 | @property 80 | def sh_dim(self): 81 | return 
self.sh.shape[-2] 82 | 83 | @torch.no_grad() 84 | def get_xyz_bound(self, percentile=86.6): 85 | half_percentile = (100 - percentile) / 200 86 | return torch.quantile(self.xyz,half_percentile,dim=0), torch.quantile(self.xyz,1 - half_percentile,dim=0) 87 | 88 | def clone(self): 89 | return GaussianDataCUDA( 90 | xyz=self.xyz.clone(), 91 | rot=self.rot.clone(), 92 | scale=self.scale.clone(), 93 | opacity=self.opacity.clone(), 94 | sh=self.sh.clone(), 95 | ) 96 | 97 | @dataclass 98 | class GaussianRasterizationSettingsStorage: 99 | image_height: int 100 | image_width: int 101 | tanfovx : float 102 | tanfovy : float 103 | bg : torch.Tensor 104 | scale_modifier : float 105 | viewmatrix : torch.Tensor 106 | projmatrix : torch.Tensor 107 | sh_degree : int 108 | campos : torch.Tensor 109 | prefiltered : bool 110 | debug : bool 111 | 112 | 113 | def gaus_cuda_from_cpu(gau: util_gau) -> GaussianDataCUDA: 114 | gaus = GaussianDataCUDA( 115 | xyz = torch.tensor(gau.xyz).float().cuda().requires_grad_(False), 116 | rot = torch.tensor(gau.rot).float().cuda().requires_grad_(False), 117 | scale = torch.tensor(gau.scale).float().cuda().requires_grad_(False), 118 | opacity = torch.tensor(gau.opacity).float().cuda().requires_grad_(False), 119 | sh = torch.tensor(gau.sh).float().cuda().requires_grad_(False) 120 | ) 121 | gaus.sh = gaus.sh.reshape(len(gaus), -1, 3).contiguous() 122 | return gaus 123 | 124 | 125 | class CUDARenderer(GaussianRenderBase): 126 | def __init__(self, w, h): 127 | super().__init__() 128 | self.raster_settings = { 129 | "image_height": int(h), 130 | "image_width": int(w), 131 | "tanfovx": 1, 132 | "tanfovy": 1, 133 | "bg": torch.Tensor([0., 0., 0]).float().cuda(), 134 | "scale_modifier": 1., 135 | "viewmatrix": None, 136 | "projmatrix": None, 137 | "sh_degree": 1, # ? 
138 | "campos": None, 139 | "prefiltered": False, 140 | "debug": False 141 | } 142 | gl.glViewport(0, 0, w, h) 143 | self.program = util.compile_shaders(VERTEX_SHADER_SOURCE, FRAGMENT_SHADER_SOURCE) 144 | # setup cuda 145 | err, *_ = cu.cudaGLGetDevices(1, cu.cudaGLDeviceList.cudaGLDeviceListAll) 146 | if err == cu.cudaError_t.cudaErrorUnknown: 147 | raise RuntimeError( 148 | "OpenGL context may be running on integrated graphics" 149 | ) 150 | 151 | self.vao = gl.glGenVertexArrays(1) 152 | self.tex = None 153 | self.NTC = None 154 | # the index of NTCs and additional_3dgs is the index of the current un-processed frame. 155 | self.NTCs = [] 156 | self.additional_3dgs = [] 157 | self.current_timestep=0 158 | self.set_gl_texture(h, w) 159 | 160 | gl.glDisable(gl.GL_CULL_FACE) 161 | gl.glEnable(gl.GL_BLEND) 162 | gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) 163 | 164 | def update_gaussian_data(self, gaus: util_gau.GaussianData): 165 | self.gaussians = gaus_cuda_from_cpu(gaus) 166 | self.init_gaussians = GaussianDataCUDA( 167 | xyz = self.gaussians.xyz.clone(), 168 | rot = self.gaussians.rot.clone(), 169 | scale = self.gaussians.scale.clone(), 170 | opacity = self.gaussians.opacity.clone(), 171 | sh = self.gaussians.sh.clone() 172 | ) 173 | self.raster_settings["sh_degree"] = int(np.round(np.sqrt(self.gaussians.sh_dim))) - 1 174 | 175 | def sort_and_update(self, camera: util.Camera): 176 | pass 177 | 178 | def set_scale_modifier(self, modifier): 179 | self.raster_settings["scale_modifier"] = float(modifier) 180 | 181 | def set_render_mod(self, mod: int): 182 | pass 183 | 184 | def set_gl_texture(self, h, w): 185 | self.tex = gl.glGenTextures(1) 186 | gl.glBindTexture(gl.GL_TEXTURE_2D, self.tex) 187 | gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_REPEAT) 188 | gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_REPEAT) 189 | gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) 190 | 
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) 191 | gl.glTexImage2D( 192 | gl.GL_TEXTURE_2D, 193 | 0, 194 | gl.GL_RGBA32F, 195 | w, 196 | h, 197 | 0, 198 | gl.GL_RGBA, 199 | gl.GL_FLOAT, 200 | None, 201 | ) 202 | gl.glBindTexture(gl.GL_TEXTURE_2D, 0) 203 | err, self.cuda_image = cu.cudaGraphicsGLRegisterImage( 204 | self.tex, 205 | gl.GL_TEXTURE_2D, 206 | cu.cudaGraphicsRegisterFlags.cudaGraphicsRegisterFlagsWriteDiscard, 207 | ) 208 | if err != cu.cudaError_t.cudaSuccess: 209 | raise RuntimeError("Unable to register opengl texture") 210 | 211 | def set_render_reso(self, w, h): 212 | self.raster_settings["image_height"] = int(h) 213 | self.raster_settings["image_width"] = int(w) 214 | gl.glViewport(0, 0, w, h) 215 | self.set_gl_texture(h, w) 216 | 217 | @torch.no_grad() 218 | def query_NTC(self, xyz, timestep): 219 | mask, d_xyz, d_rot = self.NTCs[timestep](xyz) 220 | self.gaussians.xyz += d_xyz 221 | self.gaussians.rot = quaternion_multiply(self.gaussians.rot, d_rot) 222 | 223 | @torch.no_grad() 224 | def cat_additions(self, timestep): 225 | additions=self.additional_3dgs[timestep] 226 | s2_gaussians=GaussianDataCUDA( 227 | xyz=torch.cat([additions.xyz, self.gaussians.xyz], dim=0), 228 | rot=torch.cat([additions.rot, self.gaussians.rot], dim=0), 229 | scale=torch.cat([additions.scale, self.gaussians.scale], dim=0), 230 | opacity=torch.cat([additions.opacity, self.gaussians.opacity], dim=0), 231 | sh=torch.cat([additions.sh, self.gaussians.sh], dim=0) 232 | ) 233 | return s2_gaussians 234 | 235 | def fvv_reset(self): 236 | self.gaussians = self.init_gaussians.clone() 237 | self.current_timestep=0 238 | 239 | def update_camera_pose(self, camera: util.Camera): 240 | view_matrix = camera.get_view_matrix() 241 | view_matrix[[0, 2], :] = -view_matrix[[0, 2], :] 242 | proj = camera.get_project_matrix() @ view_matrix 243 | self.raster_settings["viewmatrix"] = torch.tensor(view_matrix.T).float().cuda() 244 | self.raster_settings["campos"] = 
torch.tensor(camera.position).float().cuda() 245 | self.raster_settings["projmatrix"] = torch.tensor(proj.T).float().cuda() 246 | 247 | def update_camera_intrin(self, camera: util.Camera): 248 | view_matrix = camera.get_view_matrix() 249 | view_matrix[[0, 2], :] = -view_matrix[[0, 2], :] 250 | proj = camera.get_project_matrix() @ view_matrix 251 | self.raster_settings["projmatrix"] = torch.tensor(proj.T).float().cuda() 252 | hfovx, hfovy, focal = camera.get_htanfovxy_focal() 253 | self.raster_settings["tanfovx"] = hfovx 254 | self.raster_settings["tanfovy"] = hfovy 255 | 256 | def draw(self, timestep: int = 0): 257 | raster_settings = GaussianRasterizationSettings(**self.raster_settings) 258 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 259 | # means2D = torch.zeros_like(self.gaussians.xyz, dtype=self.gaussians.xyz.dtype, requires_grad=False, device="cuda") 260 | rendered_gaussians = self.gaussians 261 | with torch.no_grad(): 262 | while(timestep-self.current_timestep>0): 263 | self.query_NTC(self.gaussians.xyz, self.current_timestep) 264 | self.current_timestep+=1 265 | if self.current_timestep!=0: 266 | rendered_gaussians=self.cat_additions(self.current_timestep-1) 267 | img, radii = rasterizer( 268 | means3D = rendered_gaussians.xyz, 269 | means2D = None, 270 | shs = rendered_gaussians.sh, 271 | colors_precomp = None, 272 | opacities = rendered_gaussians.opacity, 273 | scales = rendered_gaussians.scale, 274 | rotations = rendered_gaussians.rot, 275 | cov3D_precomp = None 276 | ) 277 | img = img.permute(1, 2, 0) 278 | img = torch.concat([img, torch.ones_like(img[..., :1])], dim=-1) 279 | img = img.contiguous() 280 | height, width = img.shape[:2] 281 | # transfer 282 | (err,) = cu.cudaGraphicsMapResources(1, self.cuda_image, cu.cudaStreamLegacy) 283 | if err != cu.cudaError_t.cudaSuccess: 284 | raise RuntimeError("Unable to map graphics resource") 285 | err, array = cu.cudaGraphicsSubResourceGetMappedArray(self.cuda_image, 0, 0) 286 | if err 
!= cu.cudaError_t.cudaSuccess: 287 | raise RuntimeError("Unable to get mapped array") 288 | 289 | (err,) = cu.cudaMemcpy2DToArrayAsync( 290 | array, 291 | 0, 292 | 0, 293 | img.data_ptr(), 294 | 4 * 4 * width, 295 | 4 * 4 * width, 296 | height, 297 | cu.cudaMemcpyKind.cudaMemcpyDeviceToDevice, 298 | cu.cudaStreamLegacy, 299 | ) 300 | if err != cu.cudaError_t.cudaSuccess: 301 | raise RuntimeError("Unable to copy from tensor to texture") 302 | (err,) = cu.cudaGraphicsUnmapResources(1, self.cuda_image, cu.cudaStreamLegacy) 303 | if err != cu.cudaError_t.cudaSuccess: 304 | raise RuntimeError("Unable to unmap graphics resource") 305 | 306 | gl.glUseProgram(self.program) 307 | gl.glBindTexture(gl.GL_TEXTURE_2D, self.tex) 308 | gl.glBindVertexArray(self.vao) 309 | gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) 310 | -------------------------------------------------------------------------------- /renderer_ogl.py: -------------------------------------------------------------------------------- 1 | from OpenGL import GL as gl 2 | import util 3 | import util_gau 4 | import numpy as np 5 | 6 | _sort_buffer_xyz = None 7 | _sort_buffer_gausid = None # used to tell whether gaussian is reloaded 8 | 9 | def _sort_gaussian_cpu(gaus, view_mat): 10 | xyz = np.asarray(gaus.xyz) 11 | view_mat = np.asarray(view_mat) 12 | 13 | xyz_view = view_mat[None, :3, :3] @ xyz[..., None] + view_mat[None, :3, 3, None] 14 | depth = xyz_view[:, 2, 0] 15 | 16 | index = np.argsort(depth) 17 | index = index.astype(np.int32).reshape(-1, 1) 18 | return index 19 | 20 | 21 | def _sort_gaussian_cupy(gaus, view_mat): 22 | import cupy as cp 23 | global _sort_buffer_gausid, _sort_buffer_xyz 24 | if _sort_buffer_gausid != id(gaus): 25 | _sort_buffer_xyz = cp.asarray(gaus.xyz) 26 | _sort_buffer_gausid = id(gaus) 27 | 28 | xyz = _sort_buffer_xyz 29 | view_mat = cp.asarray(view_mat) 30 | 31 | xyz_view = view_mat[None, :3, :3] @ xyz[..., None] + view_mat[None, :3, 3, None] 32 | depth = xyz_view[:, 2, 0] 33 | 34 | index = 
cp.argsort(depth) 35 | index = index.astype(cp.int32).reshape(-1, 1) 36 | 37 | index = cp.asnumpy(index) # convert to numpy 38 | return index 39 | 40 | 41 | def _sort_gaussian_torch(gaus, view_mat): 42 | global _sort_buffer_gausid, _sort_buffer_xyz 43 | if _sort_buffer_gausid != id(gaus): 44 | _sort_buffer_xyz = torch.tensor(gaus.xyz).cuda() 45 | _sort_buffer_gausid = id(gaus) 46 | 47 | xyz = _sort_buffer_xyz 48 | view_mat = torch.tensor(view_mat).cuda() 49 | xyz_view = view_mat[None, :3, :3] @ xyz[..., None] + view_mat[None, :3, 3, None] 50 | depth = xyz_view[:, 2, 0] 51 | index = torch.argsort(depth) 52 | index = index.type(torch.int32).reshape(-1, 1).cpu().numpy() 53 | return index 54 | 55 | 56 | # Decide which sort to use 57 | _sort_gaussian = None 58 | try: 59 | import torch 60 | if not torch.cuda.is_available(): 61 | raise ImportError 62 | print("Detect torch cuda installed, will use torch as sorting backend") 63 | _sort_gaussian = _sort_gaussian_torch 64 | except ImportError: 65 | try: 66 | import cupy as cp 67 | print("Detect cupy installed, will use cupy as sorting backend") 68 | _sort_gaussian = _sort_gaussian_cupy 69 | except ImportError: 70 | _sort_gaussian = _sort_gaussian_cpu 71 | 72 | 73 | class GaussianRenderBase: 74 | def __init__(self): 75 | self.gaussians = None 76 | 77 | def update_gaussian_data(self, gaus: util_gau.GaussianData): 78 | raise NotImplementedError() 79 | 80 | def sort_and_update(self): 81 | raise NotImplementedError() 82 | 83 | def set_scale_modifier(self, modifier: float): 84 | raise NotImplementedError() 85 | 86 | def set_render_mod(self, mod: int): 87 | raise NotImplementedError() 88 | 89 | def update_camera_pose(self, camera: util.Camera): 90 | raise NotImplementedError() 91 | 92 | def update_camera_intrin(self, camera: util.Camera): 93 | raise NotImplementedError() 94 | 95 | def draw(self): 96 | raise NotImplementedError() 97 | 98 | def set_render_reso(self, w, h): 99 | raise NotImplementedError() 100 | 101 | 102 | class 
OpenGLRenderer(GaussianRenderBase): 103 | def __init__(self, w, h): 104 | super().__init__() 105 | gl.glViewport(0, 0, w, h) 106 | self.program = util.load_shaders('shaders/gau_vert.glsl', 'shaders/gau_frag.glsl') 107 | 108 | # Vertex data for a quad 109 | self.quad_v = np.array([ 110 | -1, 1, 111 | 1, 1, 112 | 1, -1, 113 | -1, -1 114 | ], dtype=np.float32).reshape(4, 2) 115 | self.quad_f = np.array([ 116 | 0, 1, 2, 117 | 0, 2, 3 118 | ], dtype=np.uint32).reshape(2, 3) 119 | 120 | # load quad geometry 121 | vao, buffer_id = util.set_attributes(self.program, ["position"], [self.quad_v]) 122 | util.set_faces_tovao(vao, self.quad_f) 123 | self.vao = vao 124 | self.gau_bufferid = None 125 | self.index_bufferid = None 126 | # opengl settings 127 | gl.glDisable(gl.GL_CULL_FACE) 128 | gl.glEnable(gl.GL_BLEND) 129 | gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) 130 | 131 | def update_gaussian_data(self, gaus: util_gau.GaussianData): 132 | self.gaussians = gaus 133 | # load gaussian geometry 134 | gaussian_data = gaus.flat() 135 | self.gau_bufferid = util.set_storage_buffer_data(self.program, "gaussian_data", gaussian_data, 136 | bind_idx=0, 137 | buffer_id=self.gau_bufferid) 138 | util.set_uniform_1int(self.program, gaus.sh_dim, "sh_dim") 139 | 140 | def sort_and_update(self, camera: util.Camera): 141 | index = _sort_gaussian(self.gaussians, camera.get_view_matrix()) 142 | self.index_bufferid = util.set_storage_buffer_data(self.program, "gi", index, 143 | bind_idx=1, 144 | buffer_id=self.index_bufferid) 145 | return 146 | 147 | def set_scale_modifier(self, modifier): 148 | util.set_uniform_1f(self.program, modifier, "scale_modifier") 149 | 150 | def set_render_mod(self, mod: int): 151 | util.set_uniform_1int(self.program, mod, "render_mod") 152 | 153 | def set_render_reso(self, w, h): 154 | gl.glViewport(0, 0, w, h) 155 | 156 | def update_camera_pose(self, camera: util.Camera): 157 | view_mat = camera.get_view_matrix() 158 | util.set_uniform_mat4(self.program, 
view_mat, "view_matrix") 159 | util.set_uniform_v3(self.program, camera.position, "cam_pos") 160 | 161 | def update_camera_intrin(self, camera: util.Camera): 162 | proj_mat = camera.get_project_matrix() 163 | util.set_uniform_mat4(self.program, proj_mat, "projection_matrix") 164 | util.set_uniform_v3(self.program, camera.get_htanfovxy_focal(), "hfovxy_focal") 165 | 166 | def draw(self, timestep: int = 0): 167 | # run opengl rasterizer to render FVV is implemented. 168 | gl.glUseProgram(self.program) 169 | gl.glBindVertexArray(self.vao) 170 | num_gau = len(self.gaussians) 171 | gl.glDrawElementsInstanced(gl.GL_TRIANGLES, len(self.quad_f.reshape(-1)), gl.GL_UNSIGNED_INT, None, num_gau) 172 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | glfw 2 | PyGLM 3 | imgui 4 | PyOpenGL 5 | numpy 6 | imageio 7 | plyfile 8 | -------------------------------------------------------------------------------- /shaders/gau_frag.glsl: -------------------------------------------------------------------------------- 1 | #version 430 core 2 | 3 | in vec3 color; 4 | in float alpha; 5 | in vec3 conic; 6 | in vec2 coordxy; // local coordinate in quad, unit in pixel 7 | 8 | uniform int render_mod; // > 0 render 0-ith SH dim, -1 depth, -2 bill board, -3 flat ball, -4 gaussian ball 9 | 10 | out vec4 FragColor; 11 | 12 | void main() 13 | { 14 | if (render_mod == -2) 15 | { 16 | FragColor = vec4(color, 1.f); 17 | return; 18 | } 19 | 20 | float power = -0.5f * (conic.x * coordxy.x * coordxy.x + conic.z * coordxy.y * coordxy.y) - conic.y * coordxy.x * coordxy.y; 21 | if (power > 0.f) 22 | discard; 23 | float opacity = min(0.99f, alpha * exp(power)); 24 | if (opacity < 1.f / 255.f) 25 | discard; 26 | FragColor = vec4(color, opacity); 27 | 28 | // handling special shading effect 29 | if (render_mod == -3) 30 | FragColor.a = FragColor.a > 0.22 ? 
1 : 0; 31 | else if (render_mod == -4) 32 | { 33 | FragColor.a = FragColor.a > 0.22 ? 1 : 0; 34 | FragColor.rgb = FragColor.rgb * exp(power); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /shaders/gau_vert.glsl: -------------------------------------------------------------------------------- 1 | #version 430 core 2 | 3 | #define SH_C0 0.28209479177387814f 4 | #define SH_C1 0.4886025119029199f 5 | 6 | #define SH_C2_0 1.0925484305920792f 7 | #define SH_C2_1 -1.0925484305920792f 8 | #define SH_C2_2 0.31539156525252005f 9 | #define SH_C2_3 -1.0925484305920792f 10 | #define SH_C2_4 0.5462742152960396f 11 | 12 | #define SH_C3_0 -0.5900435899266435f 13 | #define SH_C3_1 2.890611442640554f 14 | #define SH_C3_2 -0.4570457994644658f 15 | #define SH_C3_3 0.3731763325901154f 16 | #define SH_C3_4 -0.4570457994644658f 17 | #define SH_C3_5 1.445305721320277f 18 | #define SH_C3_6 -0.5900435899266435f 19 | 20 | layout(location = 0) in vec2 position; 21 | // layout(location = 1) in vec3 g_pos; 22 | // layout(location = 2) in vec4 g_rot; 23 | // layout(location = 3) in vec3 g_scale; 24 | // layout(location = 4) in vec3 g_dc_color; 25 | // layout(location = 5) in float g_opacity; 26 | 27 | 28 | #define POS_IDX 0 29 | #define ROT_IDX 3 30 | #define SCALE_IDX 7 31 | #define OPACITY_IDX 10 32 | #define SH_IDX 11 33 | 34 | layout (std430, binding=0) buffer gaussian_data { 35 | float g_data[]; 36 | // compact version of following data 37 | // vec3 g_pos[]; 38 | // vec4 g_rot[]; 39 | // vec3 g_scale[]; 40 | // float g_opacity[]; 41 | // vec3 g_sh[]; 42 | }; 43 | layout (std430, binding=1) buffer gaussian_order { 44 | int gi[]; 45 | }; 46 | 47 | uniform mat4 view_matrix; 48 | uniform mat4 projection_matrix; 49 | uniform vec3 hfovxy_focal; 50 | uniform vec3 cam_pos; 51 | uniform int sh_dim; 52 | uniform float scale_modifier; 53 | uniform int render_mod; // > 0 render 0-ith SH dim, -1 depth, -2 bill board, -3 gaussian 54 | 55 | out vec3 
color; 56 | out float alpha; 57 | out vec3 conic; 58 | out vec2 coordxy; // local coordinate in quad, unit in pixel 59 | 60 | mat3 computeCov3D(vec3 scale, vec4 q) // should be correct 61 | { 62 | mat3 S = mat3(0.f); 63 | S[0][0] = scale.x; 64 | S[1][1] = scale.y; 65 | S[2][2] = scale.z; 66 | float r = q.x; 67 | float x = q.y; 68 | float y = q.z; 69 | float z = q.w; 70 | 71 | mat3 R = mat3( 72 | 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y), 73 | 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x), 74 | 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y) 75 | ); 76 | 77 | mat3 M = S * R; 78 | mat3 Sigma = transpose(M) * M; 79 | return Sigma; 80 | } 81 | 82 | vec3 computeCov2D(vec4 mean_view, float focal_x, float focal_y, float tan_fovx, float tan_fovy, mat3 cov3D, mat4 viewmatrix) 83 | { 84 | vec4 t = mean_view; 85 | // why need this? Try remove this later 86 | float limx = 1.3f * tan_fovx; 87 | float limy = 1.3f * tan_fovy; 88 | float txtz = t.x / t.z; 89 | float tytz = t.y / t.z; 90 | t.x = min(limx, max(-limx, txtz)) * t.z; 91 | t.y = min(limy, max(-limy, tytz)) * t.z; 92 | 93 | mat3 J = mat3( 94 | focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z), 95 | 0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z), 96 | 0, 0, 0 97 | ); 98 | mat3 W = transpose(mat3(viewmatrix)); 99 | mat3 T = W * J; 100 | 101 | mat3 cov = transpose(T) * transpose(cov3D) * T; 102 | // Apply low-pass filter: every Gaussian should be at least 103 | // one pixel wide/high. Discard 3rd row and column. 
104 | cov[0][0] += 0.3f; 105 | cov[1][1] += 0.3f; 106 | return vec3(cov[0][0], cov[0][1], cov[1][1]); 107 | } 108 | 109 | vec3 get_vec3(int offset) 110 | { 111 | return vec3(g_data[offset], g_data[offset + 1], g_data[offset + 2]); 112 | } 113 | vec4 get_vec4(int offset) 114 | { 115 | return vec4(g_data[offset], g_data[offset + 1], g_data[offset + 2], g_data[offset + 3]); 116 | } 117 | 118 | void main() 119 | { 120 | int boxid = gi[gl_InstanceID]; 121 | int total_dim = 3 + 4 + 3 + 1 + sh_dim; 122 | int start = boxid * total_dim; 123 | vec4 g_pos = vec4(get_vec3(start + POS_IDX), 1.f); 124 | vec4 g_pos_view = view_matrix * g_pos; 125 | vec4 g_pos_screen = projection_matrix * g_pos_view; 126 | g_pos_screen.xyz = g_pos_screen.xyz / g_pos_screen.w; 127 | g_pos_screen.w = 1.f; 128 | // early culling 129 | if (any(greaterThan(abs(g_pos_screen.xyz), vec3(1.3)))) 130 | { 131 | gl_Position = vec4(-100, -100, -100, 1); 132 | return; 133 | } 134 | vec4 g_rot = get_vec4(start + ROT_IDX); 135 | vec3 g_scale = get_vec3(start + SCALE_IDX); 136 | float g_opacity = g_data[start + OPACITY_IDX]; 137 | 138 | mat3 cov3d = computeCov3D(g_scale * scale_modifier, g_rot); 139 | vec2 wh = 2 * hfovxy_focal.xy * hfovxy_focal.z; 140 | vec3 cov2d = computeCov2D(g_pos_view, 141 | hfovxy_focal.z, 142 | hfovxy_focal.z, 143 | hfovxy_focal.x, 144 | hfovxy_focal.y, 145 | cov3d, 146 | view_matrix); 147 | 148 | // Invert covariance (EWA algorithm) 149 | float det = (cov2d.x * cov2d.z - cov2d.y * cov2d.y); 150 | if (det == 0.0f) 151 | gl_Position = vec4(0.f, 0.f, 0.f, 0.f); 152 | 153 | float det_inv = 1.f / det; 154 | conic = vec3(cov2d.z * det_inv, -cov2d.y * det_inv, cov2d.x * det_inv); 155 | 156 | vec2 quadwh_scr = vec2(3.f * sqrt(cov2d.x), 3.f * sqrt(cov2d.z)); // screen space half quad height and width 157 | vec2 quadwh_ndc = quadwh_scr / wh * 2; // in ndc space 158 | g_pos_screen.xy = g_pos_screen.xy + position * quadwh_ndc; 159 | coordxy = position * quadwh_scr; 160 | gl_Position = g_pos_screen; 
161 | 162 | alpha = g_opacity; 163 | 164 | if (render_mod == -1) 165 | { 166 | float depth = -g_pos_view.z; 167 | depth = depth < 0.05 ? 1 : depth; 168 | depth = 1 / depth; 169 | color = vec3(depth, depth, depth); 170 | return; 171 | } 172 | 173 | // Covert SH to color 174 | int sh_start = start + SH_IDX; 175 | vec3 dir = g_pos.xyz - cam_pos; 176 | dir = normalize(dir); 177 | color = SH_C0 * get_vec3(sh_start); 178 | 179 | if (sh_dim > 3 && render_mod >= 1) // 1 * 3 180 | { 181 | float x = dir.x; 182 | float y = dir.y; 183 | float z = dir.z; 184 | color = color - SH_C1 * y * get_vec3(sh_start + 1 * 3) + SH_C1 * z * get_vec3(sh_start + 2 * 3) - SH_C1 * x * get_vec3(sh_start + 3 * 3); 185 | 186 | if (sh_dim > 12 && render_mod >= 2) // (1 + 3) * 3 187 | { 188 | float xx = x * x, yy = y * y, zz = z * z; 189 | float xy = x * y, yz = y * z, xz = x * z; 190 | color = color + 191 | SH_C2_0 * xy * get_vec3(sh_start + 4 * 3) + 192 | SH_C2_1 * yz * get_vec3(sh_start + 5 * 3) + 193 | SH_C2_2 * (2.0f * zz - xx - yy) * get_vec3(sh_start + 6 * 3) + 194 | SH_C2_3 * xz * get_vec3(sh_start + 7 * 3) + 195 | SH_C2_4 * (xx - yy) * get_vec3(sh_start + 8 * 3); 196 | 197 | if (sh_dim > 27 && render_mod >= 3) // (1 + 3 + 5) * 3 198 | { 199 | color = color + 200 | SH_C3_0 * y * (3.0f * xx - yy) * get_vec3(sh_start + 9 * 3) + 201 | SH_C3_1 * xy * z * get_vec3(sh_start + 10 * 3) + 202 | SH_C3_2 * y * (4.0f * zz - xx - yy) * get_vec3(sh_start + 11 * 3) + 203 | SH_C3_3 * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * get_vec3(sh_start + 12 * 3) + 204 | SH_C3_4 * x * (4.0f * zz - xx - yy) * get_vec3(sh_start + 13 * 3) + 205 | SH_C3_5 * z * (xx - yy) * get_vec3(sh_start + 14 * 3) + 206 | SH_C3_6 * x * (xx - 3.0f * yy) * get_vec3(sh_start + 15 * 3); 207 | } 208 | } 209 | } 210 | color += 0.5f; 211 | } 212 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from OpenGL.GL import 
* 2 | import OpenGL.GL.shaders as shaders 3 | import numpy as np 4 | import glm 5 | import ctypes 6 | 7 | class Camera: 8 | def __init__(self, h, w): 9 | self.znear = 0.01 10 | self.zfar = 100 11 | self.h = h 12 | self.w = w 13 | self.fovy = np.pi / 2 14 | self.position = np.array([0.0, 0.0, 3.0]).astype(np.float32) 15 | self.target = np.array([0.0, 0.0, 0.0]).astype(np.float32) 16 | self.up = np.array([0.0, -1.0, 0.0]).astype(np.float32) 17 | self.yaw = -np.pi / 2 18 | self.pitch = 0 19 | 20 | self.is_pose_dirty = True 21 | self.is_intrin_dirty = True 22 | 23 | self.last_x = 640 24 | self.last_y = 360 25 | self.first_mouse = True 26 | 27 | self.is_leftmouse_pressed = False 28 | self.is_rightmouse_pressed = False 29 | 30 | self.rot_sensitivity = 0.02 31 | self.trans_sensitivity = 0.01 32 | self.zoom_sensitivity = 0.08 33 | self.roll_sensitivity = 0.03 34 | self.target_dist = 3. 35 | 36 | def _global_rot_mat(self): 37 | x = np.array([1, 0, 0]) 38 | z = np.cross(x, self.up) 39 | z = z / np.linalg.norm(z) 40 | x = np.cross(self.up, z) 41 | return np.stack([x, self.up, z], axis=-1) 42 | 43 | def get_view_matrix(self): 44 | return np.array(glm.lookAt(self.position, self.target, self.up)) 45 | 46 | def get_project_matrix(self): 47 | # htanx, htany, focal = self.get_htanfovxy_focal() 48 | # f_n = self.zfar - self.znear 49 | # proj_mat = np.array([ 50 | # 1 / htanx, 0, 0, 0, 51 | # 0, 1 / htany, 0, 0, 52 | # 0, 0, self.zfar / f_n, - 2 * self.zfar * self.znear / f_n, 53 | # 0, 0, 1, 0 54 | # ]) 55 | project_mat = glm.perspective( 56 | self.fovy, 57 | self.w / self.h, 58 | self.znear, 59 | self.zfar 60 | ) 61 | return np.array(project_mat).astype(np.float32) 62 | 63 | def get_htanfovxy_focal(self): 64 | htany = np.tan(self.fovy / 2) 65 | htanx = htany / self.h * self.w 66 | focal = self.h / (2 * htany) 67 | return [htanx, htany, focal] 68 | 69 | def get_focal(self): 70 | return self.h / (2 * np.tan(self.fovy / 2)) 71 | 72 | def process_mouse(self, xpos, ypos): 73 | if 
self.first_mouse: 74 | self.last_x = xpos 75 | self.last_y = ypos 76 | self.first_mouse = False 77 | 78 | xoffset = xpos - self.last_x 79 | yoffset = self.last_y - ypos 80 | self.last_x = xpos 81 | self.last_y = ypos 82 | 83 | if self.is_leftmouse_pressed: 84 | self.yaw += xoffset * self.rot_sensitivity 85 | self.pitch += yoffset * self.rot_sensitivity 86 | 87 | self.pitch = np.clip(self.pitch, -np.pi / 2, np.pi / 2) 88 | 89 | front = np.array([np.cos(self.yaw) * np.cos(self.pitch), 90 | np.sin(self.pitch), np.sin(self.yaw) * 91 | np.cos(self.pitch)]) 92 | front = self._global_rot_mat() @ front.reshape(3, 1) 93 | front = front[:, 0] 94 | self.position[:] = - front * np.linalg.norm(self.position - self.target) + self.target 95 | 96 | self.is_pose_dirty = True 97 | 98 | if self.is_rightmouse_pressed: 99 | front = self.target - self.position 100 | front = front / np.linalg.norm(front) 101 | right = np.cross(self.up, front) 102 | self.position += right * xoffset * self.trans_sensitivity 103 | self.target += right * xoffset * self.trans_sensitivity 104 | cam_up = np.cross(right, front) 105 | self.position += cam_up * yoffset * self.trans_sensitivity 106 | self.target += cam_up * yoffset * self.trans_sensitivity 107 | 108 | self.is_pose_dirty = True 109 | 110 | def process_wheel(self, dx, dy): 111 | front = self.target - self.position 112 | front = front / np.linalg.norm(front) 113 | self.position += front * dy * self.zoom_sensitivity 114 | self.target += front * dy * self.zoom_sensitivity 115 | self.is_pose_dirty = True 116 | 117 | def process_roll_key(self, d): 118 | front = self.target - self.position 119 | right = np.cross(front, self.up) 120 | new_up = self.up + right * (d * self.roll_sensitivity / np.linalg.norm(right)) 121 | self.up = new_up / np.linalg.norm(new_up) 122 | self.is_pose_dirty = True 123 | 124 | def flip_ground(self): 125 | self.up = -self.up 126 | self.is_pose_dirty = True 127 | 128 | def update_target_distance(self): 129 | _dir = self.target - 
self.position 130 | _dir = _dir / np.linalg.norm(_dir) 131 | self.target = self.position + _dir * self.target_dist 132 | 133 | def update_resolution(self, height, width): 134 | self.h = max(height, 1) 135 | self.w = max(width, 1) 136 | self.is_intrin_dirty = True 137 | 138 | 139 | def load_shaders(vs, fs): 140 | vertex_shader = open(vs, 'r').read() 141 | fragment_shader = open(fs, 'r').read() 142 | 143 | active_shader = shaders.compileProgram( 144 | shaders.compileShader(vertex_shader, GL_VERTEX_SHADER), 145 | shaders.compileShader(fragment_shader, GL_FRAGMENT_SHADER), 146 | ) 147 | return active_shader 148 | 149 | 150 | def compile_shaders(vertex_shader, fragment_shader): 151 | active_shader = shaders.compileProgram( 152 | shaders.compileShader(vertex_shader, GL_VERTEX_SHADER), 153 | shaders.compileShader(fragment_shader, GL_FRAGMENT_SHADER), 154 | ) 155 | return active_shader 156 | 157 | 158 | def set_attributes(program, keys, values, vao=None, buffer_ids=None): 159 | glUseProgram(program) 160 | if vao is None: 161 | vao = glGenVertexArrays(1) 162 | glBindVertexArray(vao) 163 | 164 | if buffer_ids is None: 165 | buffer_ids = [None] * len(keys) 166 | for i, (key, value, b) in enumerate(zip(keys, values, buffer_ids)): 167 | if b is None: 168 | b = glGenBuffers(1) 169 | buffer_ids[i] = b 170 | glBindBuffer(GL_ARRAY_BUFFER, b) 171 | glBufferData(GL_ARRAY_BUFFER, value.nbytes, value.reshape(-1), GL_STATIC_DRAW) 172 | length = value.shape[-1] 173 | pos = glGetAttribLocation(program, key) 174 | glVertexAttribPointer(pos, length, GL_FLOAT, False, 0, None) 175 | glEnableVertexAttribArray(pos) 176 | 177 | glBindBuffer(GL_ARRAY_BUFFER,0) 178 | return vao, buffer_ids 179 | 180 | def set_attribute(program, key, value, vao=None, buffer_id=None): 181 | glUseProgram(program) 182 | if vao is None: 183 | vao = glGenVertexArrays(1) 184 | glBindVertexArray(vao) 185 | 186 | if buffer_id is None: 187 | buffer_id = glGenBuffers(1) 188 | glBindBuffer(GL_ARRAY_BUFFER, buffer_id) 189 | 
glBufferData(GL_ARRAY_BUFFER, value.nbytes, value.reshape(-1), GL_STATIC_DRAW) 190 | length = value.shape[-1] 191 | pos = glGetAttribLocation(program, key) 192 | glVertexAttribPointer(pos, length, GL_FLOAT, False, 0, None) 193 | glEnableVertexAttribArray(pos) 194 | glBindBuffer(GL_ARRAY_BUFFER,0) 195 | return vao, buffer_id 196 | 197 | def set_attribute_instanced(program, key, value, instance_stride=1, vao=None, buffer_id=None): 198 | glUseProgram(program) 199 | if vao is None: 200 | vao = glGenVertexArrays(1) 201 | glBindVertexArray(vao) 202 | 203 | if buffer_id is None: 204 | buffer_id = glGenBuffers(1) 205 | glBindBuffer(GL_ARRAY_BUFFER, buffer_id) 206 | glBufferData(GL_ARRAY_BUFFER, value.nbytes, value.reshape(-1), GL_STATIC_DRAW) 207 | length = value.shape[-1] 208 | pos = glGetAttribLocation(program, key) 209 | glVertexAttribPointer(pos, length, GL_FLOAT, False, 0, None) 210 | glEnableVertexAttribArray(pos) 211 | glVertexAttribDivisor(pos, instance_stride) 212 | glBindBuffer(GL_ARRAY_BUFFER,0) 213 | return vao, buffer_id 214 | 215 | def set_storage_buffer_data(program, key, value: np.ndarray, bind_idx, vao=None, buffer_id=None): 216 | glUseProgram(program) 217 | # if vao is None: # TODO: if this is really unnecessary? 218 | # vao = glGenVertexArrays(1) 219 | if vao is not None: 220 | glBindVertexArray(vao) 221 | 222 | if buffer_id is None: 223 | buffer_id = glGenBuffers(1) 224 | glBindBuffer(GL_SHADER_STORAGE_BUFFER, buffer_id) 225 | glBufferData(GL_SHADER_STORAGE_BUFFER, value.nbytes, value.reshape(-1), GL_STATIC_DRAW) 226 | # pos = glGetProgramResourceIndex(program, GL_SHADER_STORAGE_BLOCK, key) # TODO: ??? 227 | glBindBufferBase(GL_SHADER_STORAGE_BUFFER, bind_idx, buffer_id) 228 | # glShaderStorageBlockBinding(program, pos, pos) # TODO: ??? 
229 | glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0) 230 | return buffer_id 231 | 232 | def set_faces_tovao(vao, faces: np.ndarray): 233 | # faces 234 | glBindVertexArray(vao) 235 | element_buffer = glGenBuffers(1) 236 | glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, element_buffer) 237 | glBufferData(GL_ELEMENT_ARRAY_BUFFER, faces.nbytes, faces, GL_STATIC_DRAW) 238 | return element_buffer 239 | 240 | def set_gl_bindings(vertices, faces): 241 | # vertices 242 | vao = glGenVertexArrays(1) 243 | glBindVertexArray(vao) 244 | # vertex_buffer = glGenVertexArrays(1) 245 | vertex_buffer = glGenBuffers(1) 246 | glBindBuffer(GL_ARRAY_BUFFER, vertex_buffer) 247 | glBufferData(GL_ARRAY_BUFFER, vertices.nbytes, vertices, GL_STATIC_DRAW) 248 | glVertexAttribPointer(0, 4, GL_FLOAT, False, 0, None) 249 | glEnableVertexAttribArray(0) 250 | 251 | # faces 252 | element_buffer = glGenBuffers(1) 253 | glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, element_buffer) 254 | glBufferData(GL_ELEMENT_ARRAY_BUFFER, faces.nbytes, faces, GL_STATIC_DRAW) 255 | # glVertexAttribPointer(1, 3, GL_FLOAT, False, 36, ctypes.c_void_p(12)) 256 | # glEnableVertexAttribArray(1) 257 | # glVertexAttribPointer(2, 3, GL_FLOAT, False, 36, ctypes.c_void_p(12)) 258 | # glEnableVertexAttribArray(2) 259 | 260 | def set_uniform_mat4(shader, content, name): 261 | glUseProgram(shader) 262 | if isinstance(content, glm.mat4): 263 | content = np.array(content).astype(np.float32) 264 | else: 265 | content = content.T 266 | glUniformMatrix4fv( 267 | glGetUniformLocation(shader, name), 268 | 1, 269 | GL_FALSE, 270 | content.astype(np.float32) 271 | ) 272 | 273 | def set_uniform_1f(shader, content, name): 274 | glUseProgram(shader) 275 | glUniform1f( 276 | glGetUniformLocation(shader, name), 277 | content, 278 | ) 279 | 280 | def set_uniform_1int(shader, content, name): 281 | glUseProgram(shader) 282 | glUniform1i( 283 | glGetUniformLocation(shader, name), 284 | content 285 | ) 286 | 287 | def set_uniform_v3f(shader, contents, name): 288 | 
glUseProgram(shader) 289 | glUniform3fv( 290 | glGetUniformLocation(shader, name), 291 | len(contents), 292 | contents 293 | ) 294 | 295 | def set_uniform_v3(shader, contents, name): 296 | glUseProgram(shader) 297 | glUniform3f( 298 | glGetUniformLocation(shader, name), 299 | contents[0], contents[1], contents[2] 300 | ) 301 | 302 | def set_uniform_v1f(shader, contents, name): 303 | glUseProgram(shader) 304 | glUniform1fv( 305 | glGetUniformLocation(shader, name), 306 | len(contents), 307 | contents 308 | ) 309 | 310 | def set_uniform_v2(shader, contents, name): 311 | glUseProgram(shader) 312 | glUniform2f( 313 | glGetUniformLocation(shader, name), 314 | contents[0], contents[1] 315 | ) 316 | 317 | def set_texture2d(img, texid=None): 318 | h, w, c = img.shape 319 | assert img.dtype == np.uint8 320 | if texid is None: 321 | texid = glGenTextures(1) 322 | glBindTexture(GL_TEXTURE_2D, texid) 323 | glTexImage2D( 324 | GL_TEXTURE_2D, 0, GL_RGB, w, h, 0, 325 | GL_RGB, GL_UNSIGNED_BYTE, img 326 | ) 327 | glActiveTexture(GL_TEXTURE0) # can be removed 328 | # glGenerateMipmap(GL_TEXTURE_2D) 329 | glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR) 330 | glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR) 331 | glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_BORDER) 332 | glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_BORDER) 333 | return texid 334 | 335 | def update_texture2d(img, texid, offset): 336 | x1, y1 = offset 337 | h, w = img.shape[:2] 338 | glBindTexture(GL_TEXTURE_2D, texid) 339 | glTexSubImage2D( 340 | GL_TEXTURE_2D, 0, x1, y1, w, h, 341 | GL_RGB, GL_UNSIGNED_BYTE, img 342 | ) 343 | 344 | 345 | -------------------------------------------------------------------------------- /util_3dgstream.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tinycudann as tcnn 3 | import torch 4 | import json 5 | import os 6 | from plyfile import PlyData, 
PlyElement 7 | from NTC import NeuralTransformationCache 8 | from renderer_cuda import GaussianDataCUDA, gaus_cuda_from_cpu 9 | from util_gau import load_ply 10 | @torch.no_grad() 11 | def inverse_sigmoid(x): 12 | return torch.log(x/(1-x)) 13 | 14 | def construct_list_of_attributes(gau_cuda:GaussianDataCUDA): 15 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 16 | # All channels except the 3 DC 17 | for i in range(1*3): 18 | l.append('f_dc_{}'.format(i)) 19 | #TODO: SH > 1 20 | for i in range((gau_cuda.sh_dim-1)*3): 21 | l.append('f_rest_{}'.format(i)) 22 | l.append('opacity') 23 | for i in range(gau_cuda.scale.shape[1]): 24 | l.append('scale_{}'.format(i)) 25 | for i in range(gau_cuda.rot.shape[1]): 26 | l.append('rot_{}'.format(i)) 27 | return l 28 | 29 | def load_NTCs(FVV_path, gau_cuda:GaussianDataCUDA, total_frames:int = 150): 30 | NTC_paths=[os.path.join(FVV_path, 'NTCs', f'NTC_{frame_id:06}.pth') for frame_id in range(0, total_frames-1)] 31 | config_path=os.path.join(FVV_path, 'NTCs', 'config.json') 32 | xyz_bound = gau_cuda.get_xyz_bound() 33 | with open(config_path) as f: 34 | NTC_conf = json.load(f) 35 | models=[tcnn.NetworkWithInputEncoding(n_input_dims=3, n_output_dims=8, encoding_config=NTC_conf["encoding"], network_config=NTC_conf["network"]).to(torch.device("cuda")) for path in NTC_paths] 36 | NTCs=[NeuralTransformationCache(model,xyz_bound[0],xyz_bound[1]) for model in models] 37 | for frame_id, ntc in enumerate(NTCs): 38 | ntc.load_state_dict(torch.load(NTC_paths[frame_id])) 39 | return NTCs 40 | 41 | def load_Additions(FVV_path, total_frames:int = 150): 42 | addition_paths=[os.path.join(FVV_path, 'additional_3dgs', f'additions_{frame_id:06}.ply') for frame_id in range(0, total_frames-1)] 43 | additions_gaus=[load_ply(path) for path in addition_paths] 44 | additions_gaus_cuda=[gaus_cuda_from_cpu(gaus) for gaus in additions_gaus] 45 | return additions_gaus_cuda 46 | 47 | def get_per_frame_3dgs(FVV_path, gau_cuda:GaussianDataCUDA, total_frames:int = 150): 48 
| raise NotImplementedError("This function is not implemented yet") 49 | 50 | def save_gau_cuda(gau_cuda:GaussianDataCUDA, path:str): 51 | xyz = gau_cuda.xyz.cpu().numpy() 52 | rotation = gau_cuda.rot.cpu().numpy() 53 | normals = np.zeros_like(xyz) 54 | f_dc = gau_cuda.sh[:,0:1,:].transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 55 | f_rest = gau_cuda.sh[:,1:,:].transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 56 | opacities = inverse_sigmoid(gau_cuda.opacity).cpu().numpy() 57 | scale = torch.log(gau_cuda.scale).cpu().numpy() 58 | dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes(gau_cuda)] 59 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 60 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 61 | elements[:] = list(map(tuple, attributes)) 62 | el = PlyElement.describe(elements, 'vertex') 63 | PlyData([el]).write(path) -------------------------------------------------------------------------------- /util_gau.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from plyfile import PlyData 3 | from dataclasses import dataclass 4 | 5 | @dataclass 6 | class GaussianData: 7 | xyz: np.ndarray 8 | rot: np.ndarray 9 | scale: np.ndarray 10 | opacity: np.ndarray 11 | sh: np.ndarray 12 | def flat(self) -> np.ndarray: 13 | ret = np.concatenate([self.xyz, self.rot, self.scale, self.opacity, self.sh], axis=-1) 14 | return np.ascontiguousarray(ret) 15 | 16 | def __len__(self): 17 | return len(self.xyz) 18 | 19 | @property 20 | def sh_dim(self): 21 | return self.sh.shape[-1] 22 | 23 | 24 | def naive_gaussian(): 25 | gau_xyz = np.array([ 26 | 0, 0, 0, 27 | 1, 0, 0, 28 | 0, 1, 0, 29 | 0, 0, 1, 30 | ]).astype(np.float32).reshape(-1, 3) 31 | gau_rot = np.array([ 32 | 1, 0, 0, 0, 33 | 1, 0, 0, 0, 34 | 1, 0, 0, 0, 35 | 1, 0, 0, 0 36 | ]).astype(np.float32).reshape(-1, 4) 37 | gau_s = np.array([ 38 | 0.03, 0.03, 0.03, 39 
| 0.2, 0.03, 0.03, 40 | 0.03, 0.2, 0.03, 41 | 0.03, 0.03, 0.2 42 | ]).astype(np.float32).reshape(-1, 3) 43 | gau_c = np.array([ 44 | 1, 0, 1, 45 | 1, 0, 0, 46 | 0, 1, 0, 47 | 0, 0, 1, 48 | ]).astype(np.float32).reshape(-1, 3) 49 | gau_c = (gau_c - 0.5) / 0.28209 50 | gau_a = np.array([ 51 | 1, 1, 1, 1 52 | ]).astype(np.float32).reshape(-1, 1) 53 | return GaussianData( 54 | gau_xyz, 55 | gau_rot, 56 | gau_s, 57 | gau_a, 58 | gau_c 59 | ) 60 | 61 | 62 | def load_ply(path): 63 | max_sh_degree = 1 64 | plydata = PlyData.read(path) 65 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 66 | np.asarray(plydata.elements[0]["y"]), 67 | np.asarray(plydata.elements[0]["z"])), axis=1) 68 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 69 | 70 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 71 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 72 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 73 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 74 | 75 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 76 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 77 | # assert len(extra_f_names)==3 * (max_sh_degree + 1) ** 2 - 3 78 | max_sh_degree = int(np.sqrt((len(extra_f_names)+3)//3)) - 1 79 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 80 | for idx, attr_name in enumerate(extra_f_names): 81 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 82 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 83 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)) 84 | features_extra = np.transpose(features_extra, [0, 2, 1]) 85 | 86 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 87 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 88 | scales = np.zeros((xyz.shape[0], 
len(scale_names))) 89 | for idx, attr_name in enumerate(scale_names): 90 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 91 | 92 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 93 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 94 | rots = np.zeros((xyz.shape[0], len(rot_names))) 95 | for idx, attr_name in enumerate(rot_names): 96 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 97 | 98 | # pass activate function 99 | xyz = xyz.astype(np.float32) 100 | rots = rots / np.linalg.norm(rots, axis=-1, keepdims=True) 101 | rots = rots.astype(np.float32) 102 | scales = np.exp(scales) 103 | scales = scales.astype(np.float32) 104 | opacities = 1/(1 + np.exp(- opacities)) # sigmoid 105 | opacities = opacities.astype(np.float32) 106 | shs = np.concatenate([features_dc.reshape(-1, 3), 107 | features_extra.reshape(len(features_dc), -1)], axis=-1).astype(np.float32) 108 | shs = shs.astype(np.float32) 109 | return GaussianData(xyz, rots, scales, opacities, shs) 110 | 111 | if __name__ == "__main__": 112 | gs = load_ply("C:\\Users\\MSI_NB\\Downloads\\viewers\\models\\train\\point_cloud\\iteration_7000\\point_cloud.ply") 113 | a = gs.flat() 114 | print(a.shape) 115 | --------------------------------------------------------------------------------