├── src ├── polaris │ ├── __init__.py │ ├── py.typed │ ├── splat_renderer │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── point_utils.py │ │ │ ├── graphics_utils.py │ │ │ ├── sh_utils.py │ │ │ └── general_utils.py │ │ ├── __init__.py │ │ ├── scene │ │ │ └── cameras.py │ │ ├── gaussian_renderer.py │ │ └── splat_renderer.py │ ├── policy │ │ ├── __init__.py │ │ ├── abstract_client.py │ │ └── droid_jointpos_client.py │ ├── environments │ │ ├── rubrics │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── checkers.py │ │ ├── robot_cfg.py │ │ ├── __init__.py │ │ ├── manager_based_rl_splat_environment.py │ │ └── droid_cfg.py │ ├── config.py │ ├── utils.py │ └── hf_upload.py ├── simple-knn │ ├── spatial.h │ ├── ext.cpp │ ├── simple_knn.h │ ├── spatial.cu │ ├── setup.py │ ├── LICENSE.md │ ├── simple_knn │ │ └── __init__.py │ └── simple_knn.cu └── diff-surfel-rasterization │ ├── diff_surfel_rasterization │ ├── csrc │ │ ├── cuda_rasterizer │ │ │ ├── config.h │ │ │ ├── rasterizer_impl.h │ │ │ ├── backward.h │ │ │ ├── forward.h │ │ │ ├── rasterizer.h │ │ │ └── auxiliary.h │ │ ├── ext.cpp │ │ ├── rasterize_points.h │ │ └── rasterize_points.cu │ └── __init__.py │ ├── CMakeLists.txt │ ├── setup.py │ └── LICENSE.md ├── .python-version ├── docs ├── images │ ├── stack.png │ ├── foodbus.png │ ├── latte-cup.png │ ├── panclean.png │ ├── Teaser Figure.png │ ├── organize-tools.png │ └── tape-into-container.png ├── checkpoints_and_envs.md ├── custom_policies.md └── custom_environments.md ├── .gitmodules ├── scripts ├── upload_env_to_hf.py └── eval.py ├── LICENSE ├── pyproject.toml ├── experiments └── example.py ├── .gitignore └── README.md /src/polaris/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/polaris/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /src/polaris/splat_renderer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/polaris/splat_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | from .splat_renderer import SplatRenderer 2 | -------------------------------------------------------------------------------- /docs/images/stack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arhanjain/polaris/HEAD/docs/images/stack.png -------------------------------------------------------------------------------- /docs/images/foodbus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arhanjain/polaris/HEAD/docs/images/foodbus.png -------------------------------------------------------------------------------- /docs/images/latte-cup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arhanjain/polaris/HEAD/docs/images/latte-cup.png -------------------------------------------------------------------------------- /docs/images/panclean.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/arhanjain/polaris/HEAD/docs/images/panclean.png -------------------------------------------------------------------------------- /docs/images/Teaser Figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arhanjain/polaris/HEAD/docs/images/Teaser Figure.png -------------------------------------------------------------------------------- /docs/images/organize-tools.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arhanjain/polaris/HEAD/docs/images/organize-tools.png -------------------------------------------------------------------------------- /docs/images/tape-into-container.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arhanjain/polaris/HEAD/docs/images/tape-into-container.png -------------------------------------------------------------------------------- /src/polaris/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from polaris.config import PolicyArgs 2 | from .abstract_client import FakeClient, InferenceClient 3 | 4 | import polaris.policy.droid_jointpos_client 5 | 6 | __all__ = ["PolicyArgs", "FakeClient", "InferenceClient"] 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/diff-surfel-rasterization/third_party/glm"] 2 | path = src/diff-surfel-rasterization/third_party/glm 3 | url = git@github.com:g-truc/glm.git 4 | [submodule "third_party/openpi"] 5 | path = third_party/openpi 6 | url = git@github.com:Physical-Intelligence/openpi.git 7 | -------------------------------------------------------------------------------- /src/polaris/environments/rubrics/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Task evaluation rubrics. 3 | 4 | Rubrics compute success/progress by inspecting simulation state. 5 | """ 6 | 7 | from .base import Rubric, RubricResult 8 | # from .object_in_zone import ObjectInZoneRubric 9 | # from .stacking import StackingRubric 10 | 11 | __all__ = [ 12 | "Rubric", 13 | "RubricResult", 14 | ] 15 | -------------------------------------------------------------------------------- /src/simple-knn/spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); -------------------------------------------------------------------------------- /src/simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /src/simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/cuda_rasterizer/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3, RGB 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /scripts/upload_env_to_hf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Upload a PolaRiS environment folder to the Hugging Face dataset repository 4 | `PolaRiS-Evals/PolaRiS-Hub` after performing local validation. 5 | Validation tries to catch the most common mistakes (missing assets, malformed 6 | `initial_conditions.json`, unreadable USD), but it cannot guarantee runtime 7 | success inside Isaac Sim. Use this as a fast client-side gate before pushing. 8 | # Example commands: 9 | # uv run scripts/upload_env_to_hf.py ./PolaRiS-Hub/food_bussing --dry-run 10 | """ 11 | 12 | from polaris.hf_upload import main # type: ignore 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | #include "rasterize_points.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); 17 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); 18 | m.def("mark_visible", &markVisible); 19 | } -------------------------------------------------------------------------------- /src/simple-knn/spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | torch::Tensor 16 | distCUDA2(const torch::Tensor& points) 17 | { 18 | const int P = points.size(0); 19 | 20 | auto float_opts = points.options().dtype(torch::kFloat32); 21 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 22 | 23 | SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>()); 24 | 25 | return means; 26 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Arhan Jain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | cmake_minimum_required(VERSION 3.20) 13 | 14 | project(DiffRast LANGUAGES CUDA CXX) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_EXTENSIONS OFF) 18 | set(CMAKE_CUDA_STANDARD 17) 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 21 | 22 | add_library(CudaRasterizer 23 | cuda_rasterizer/backward.h 24 | cuda_rasterizer/backward.cu 25 | cuda_rasterizer/forward.h 26 | cuda_rasterizer/forward.cu 27 | cuda_rasterizer/auxiliary.h 28 | cuda_rasterizer/rasterizer_impl.cu 29 | cuda_rasterizer/rasterizer_impl.h 30 | cuda_rasterizer/rasterizer.h 31 | ) 32 | 33 | set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86") 34 | 35 | target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) 36 | target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | -------------------------------------------------------------------------------- /src/polaris/splat_renderer/utils/point_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def depths_to_points(view, depthmap): 5 | c2w = (view.world_view_transform.T).inverse() 6 | W, H = view.image_width, view.image_height 7 | ndc2pix = ( 8 | torch.tensor([[W / 2, 0, 0, (W) / 2], [0, H / 2, 0, (H) / 2], [0, 0, 0, 1]]) 9 | .float() 10 | .cuda() 11 | .T 12 | ) 13 | projection_matrix = c2w.T @ view.full_proj_transform 14 | intrins = (projection_matrix @ ndc2pix)[:3, :3].T 15 | 16 | grid_x, grid_y = torch.meshgrid( 17 | torch.arange(W, device="cuda").float(), 18 | torch.arange(H, device="cuda").float(), 19 | indexing="xy", 20 | ) 21 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape( 22 | -1, 3 23 | ) 24 | rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T 25 | rays_o = c2w[:3, 3] 26 | points = depthmap.reshape(-1, 1) * rays_d + rays_o 27 | return points 28 | 29 | 30 | def depth_to_normal(view, depth): 31 | """ 32 | view: view camera 33 | depth: depthmap 34 | """ 35 | points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3) 36 | output = torch.zeros_like(points) 37 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) 38 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) 39 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) 40 | output[1:-1, 1:-1, :] = normal_map 41 | return output 42 | -------------------------------------------------------------------------------- /docs/checkpoints_and_envs.md: -------------------------------------------------------------------------------- 1 | # Polaris Checkpoints 2 | All checkpoints for PolaRiS are based on DROID base policies. Checkpoints were produced by co-training with a mixture of 10% random simulated data and 90% DROID data for 1k steps. 
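To pull one of these checkpoints locally, you can copy it out of the GCS bucket. A minimal sketch, assuming `gsutil` is installed; the destination directory is illustrative, and any checkpoint path from the tables below works:

```bash
# Download the pi0.5 PolaRiS checkpoint to a local folder
gsutil -m cp -r gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris ./checkpoints/pi05_droid_jointpos_polaris
```

A local copy is optional: the policy servers in `experiments/example.py` pass the `gs://` paths directly via `--policy.dir`.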
3 | 4 | | Policy Name | Checkpoints Path | 5 | | :--- | :--- | 6 | | **π0.5 Polaris** | `gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris` | 7 | | **π0 Fast Polaris** | `gs://openpi-assets/checkpoints/polaris/pi0_fast_droid_jointpos_polaris` | 8 | | **π0 Polaris** | `gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_polaris` | 9 | | **π0 Polaris (100k)** | `gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_100k_polaris` | 10 | | **PaliGemma Polaris** | `gs://openpi-assets/checkpoints/polaris/paligemma_binning_droid_jointpos_polaris` | 11 | 12 | ## Base DROID Joint Position Checkpoints 13 | | Policy Name | Checkpoints Path | 14 | | :--- | :--- | 15 | | **π0.5 Base** | `gs://openpi-assets/checkpoints/pi05_droid_jointpos` | 16 | | **π0 Fast Base** | `gs://openpi-assets/checkpoints/pi0_fast_droid_jointpos` | 17 | | **π0 Base** | `gs://openpi-assets/checkpoints/pi0_droid_jointpos` | 18 | | **π0 Base (100k)** | `gs://openpi-assets/checkpoints/pi0_droid_jointpos_100k` | 19 | | **PaliGemma Base** | `gs://openpi-assets/checkpoints/paligemma_binning_droid_jointpos` | 20 | 21 | 22 | # Environments 23 | | Environment Name | Prompt | 24 | | :--- | :--- | 25 | | DROID-BlockStackKitchen | Place and stack the blocks on top of the green tray | 26 | | DROID-FoodBussing | Put all the foods in the bowl | 27 | | DROID-PanClean | Use the yellow sponge to scrub the blue handle frying pan | 28 | | DROID-MoveLatteCup | put the latte art cup on top of the cutting board | 29 | | DROID-OrganizeTools | put the scissor into the large container | 30 | | DROID-TapeIntoContainer | put the tape into the container | 31 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/cuda_rasterizer/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include <iostream> 15 | #include <vector> 16 | #include "rasterizer.h" 17 | #include <cuda_runtime_api.h> 18 | #include <cuda.h> 19 | 20 | namespace CudaRasterizer 21 | { 22 | template <typename T> 23 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 24 | { 25 | std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1); 26 | ptr = reinterpret_cast<T*>(offset); 27 | chunk = reinterpret_cast<char*>(ptr + count); 28 | } 29 | 30 | struct GeometryState 31 | { 32 | size_t scan_size; 33 | float* depths; 34 | char* scanning_space; 35 | bool* clamped; 36 | int* internal_radii; 37 | float2* means2D; 38 | float* transMat; 39 | float4* normal_opacity; 40 | float* rgb; 41 | uint32_t* point_offsets; 42 | uint32_t* tiles_touched; 43 | 44 | static GeometryState fromChunk(char*& chunk, size_t P); 45 | }; 46 | 47 | struct ImageState 48 | { 49 | uint2* ranges; 50 | uint32_t* n_contrib; 51 | float* accum_alpha; 52 | 53 | static ImageState fromChunk(char*& chunk, size_t N); 54 | }; 55 | 56 | struct BinningState 57 | { 58 | size_t sorting_size; 59 | uint64_t* point_list_keys_unsorted; 60 | uint64_t* point_list_keys; 61 | uint32_t* point_list_unsorted; 62 | uint32_t* point_list; 63 | char* list_sorting_space; 64 | 65 | static BinningState fromChunk(char*& chunk, size_t P); 66 | }; 67 | 68 | template <typename T> 69 | size_t required(size_t P) 70 | { 71 | char* size = nullptr; 72 | T::fromChunk(size, P); 73 | return ((size_t)size) + 128; 74 | } 75 | }; 76 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/cuda_rasterizer/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include <cuda.h> 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include <glm/glm.hpp> 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const dim3 grid, dim3 block, 25 | const uint2* ranges, 26 | const uint32_t* point_list, 27 | int W, int H, 28 | float focal_x, float focal_y, 29 | const float* bg_color, 30 | const float2* means2D, 31 | const float4* normal_opacity, 32 | const float* transMats, 33 | const float* colors, 34 | const float* depths, 35 | const float* final_Ts, 36 | const uint32_t* n_contrib, 37 | const float* dL_dpixels, 38 | const float* dL_depths, 39 | float * dL_dtransMat, 40 | float3* dL_dmean2D, 41 | float* dL_dnormal3D, 42 | float* dL_dopacity, 43 | float* dL_dcolors, 44 | float near_n , 45 | float far_n); 46 | 47 | void preprocess( 48 | int P, int D, int M, 49 | const float3* means, 50 | const int* radii, 51 | const float* shs, 52 | const bool* clamped, 53 | const glm::vec2* scales, 54 | const glm::vec4* rotations, 55 | const float scale_modifier, 56 | const float* transMats, 57 | const float* view, 58 | const float* proj, 59 | const float focal_x, const float focal_y, 60 | const float tan_fovx, const float tan_fovy, 61 | const glm::vec3* campos, 62 | float3* dL_dmean2D, 63 | const float* dL_dnormal3D, 64 | float* dL_dtransMat, 65 | float* dL_dcolor, 66 | float* dL_dsh, 67 | glm::vec3* dL_dmeans, 68 | glm::vec2* dL_dscale, 69 | glm::vec4* dL_drot); 70 | } 71 | 72 | #endif 73 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "polaris" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | authors = [ 7 | { name = "Arhan Jain", email = "arhan.j04@gmail.com" } 8 | ] 9 | requires-python = ">=3.11" 10 | dependencies = [ 11 | "isaaclab[all,isaacsim]==2.3.0", 12 | "mediapy>=1.2.4", 13 | "opencv-python>=4.11.0.86", 14 | "plyfile>=1.1.3", 15 | "tyro>=0.9.17", 16 | "diff-surfel-rasterization", 17 | "simple-knn", 18 | "openpi-client", 19 | #"torch>=2.9.0", # Change here for different CUDA version 20 | #"torchvision>=0.24.0", # Change here for different CUDA version 21 | ] 22 | 23 | [project.scripts] 24 | polaris = "polaris.cli:main" 25 | 26 | [tool.uv] 27 | override-dependencies = [ 28 | "pywin32==306; sys_platform == 'win32'", 29 | "torch>=2.9.0", # Change here for different CUDA version 30 | "torchvision>=0.24.0", # Change here for different CUDA version 31 | ] 32 | 33 | [build-system] 34 | requires = ["hatchling"] 35 | build-backend = "hatchling.build" 36 | 37 | [dependency-groups] 38 | dev = [ 39 | "ruff>=0.14.9", 40 | ] 41 | 42 | [tool.uv.sources] 43 | isaacsim = { index = "nvidia" } 44 | diff-surfel-rasterization = { path = "src/diff-surfel-rasterization", editable=true} 45 | simple-knn = { path = "src/simple-knn", editable=true} 46 | openpi-client = { path = "third_party/openpi/packages/openpi-client" } 47 | torch = { index = "torch" } 48 | torchvision = { index = "torch" } 49 | 50 | 51 | [[tool.uv.index]] 52 | name = "nvidia" 53 | url = "https://pypi.nvidia.com/" 54 | explicit = true 55 | 56 | [[tool.uv.index]] 57 | name = "torch" 58 | url = "https://download.pytorch.org/whl/cu130" # Change here for different CUDA version 59 | explicit = true 60 | 61 | 
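# Example of the "different CUDA version" change flagged above (an assumption, not an
# exhaustive list of supported versions): for CUDA 12.8 wheels, point the index above at
# "https://download.pytorch.org/whl/cu128" and relax the torch/torchvision pins in
# [tool.uv] override-dependencies to versions published on that index.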
62 | 63 | [tool.basedpyright] 64 | extraPaths = [ 65 | 66 | "./.venv/lib/python3.11/site-packages/isaaclab/source/isaaclab/", 67 | "./.venv/lib/python3.11/site-packages/isaaclab/source/isaaclab_assets/", 68 | "./.venv/lib/python3.11/site-packages/isaaclab/source/isaaclab_tasks/", 69 | ] 70 | typeCheckingMode = "standard" -------------------------------------------------------------------------------- /src/polaris/splat_renderer/utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | from typing import NamedTuple 5 | 6 | 7 | class BasicPointCloud(NamedTuple): 8 | points: np.array 9 | colors: np.array 10 | normals: np.array 11 | 12 | 13 | def geom_transform_points(points, transf_matrix): 14 | P, _ = points.shape 15 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 16 | points_hom = torch.cat([points, ones], dim=1) 17 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 18 | 19 | denom = points_out[..., 3:] + 0.0000001 20 | return (points_out[..., :3] / denom).squeeze(dim=0) 21 | 22 | 23 | def getWorld2View(R, t): 24 | Rt = np.zeros((4, 4)) 25 | Rt[:3, :3] = R.transpose() 26 | Rt[:3, 3] = t 27 | Rt[3, 3] = 1.0 28 | return np.float32(Rt) 29 | 30 | 31 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | 37 | C2W = np.linalg.inv(Rt) 38 | cam_center = C2W[:3, 3] 39 | cam_center = (cam_center + translate) * scale 40 | C2W[:3, 3] = cam_center 41 | Rt = np.linalg.inv(C2W) 42 | return np.float32(Rt) 43 | 44 | 45 | def getProjectionMatrix(znear, zfar, fovX, fovY): 46 | tanHalfFovY = math.tan((fovY / 2)) 47 | tanHalfFovX = math.tan((fovX / 2)) 48 | 49 | top = tanHalfFovY * znear 50 | bottom = -top 51 | right = tanHalfFovX * znear 52 | left = -right 53 | 54 | P = torch.zeros(4, 4) 55 | 56 | z_sign = 1.0 57 | 58 | P[0, 0] = 2.0 * znear / (right - left) 59 | P[1, 1] = 2.0 * znear / (top - bottom) 60 | P[0, 2] = (right + left) / (right - left) 61 | P[1, 2] = (top + bottom) / (top - bottom) 62 | P[3, 2] = z_sign 63 | P[2, 2] = z_sign * zfar / (zfar - znear) 64 | P[2, 3] = -(zfar * znear) / (zfar - znear) 65 | return P 66 | 67 | 68 | def fov2focal(fov, pixels): 69 | return pixels / (2 * math.tan(fov / 2)) 70 | 71 | 72 | def focal2fov(focal, pixels): 73 | return 2 * math.atan(pixels / (2 * focal)) 74 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/cuda_rasterizer/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include <cuda.h> 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include <glm/glm.hpp> 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 
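// (Reading from the signature below: per splat this produces the screen-space
// center, depth, color evaluated from SHs, the splat-to-screen transformation
// matrix, packed normal+opacity, and the count of image tiles touched, which
// sizes the subsequent binning/sorting pass.)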
24 | void preprocess(int P, int D, int M, 25 | const float* orig_points, 26 | const glm::vec2* scales, 27 | const float scale_modifier, 28 | const glm::vec4* rotations, 29 | const float* opacities, 30 | const float* shs, 31 | bool* clamped, 32 | const float* transMat_precomp, 33 | const float* colors_precomp, 34 | const float* viewmatrix, 35 | const float* projmatrix, 36 | const glm::vec3* cam_pos, 37 | const int W, int H, 38 | const float focal_x, float focal_y, 39 | const float tan_fovx, float tan_fovy, 40 | int* radii, 41 | float2* points_xy_image, 42 | float* depths, 43 | // float* isovals, 44 | // float3* normals, 45 | float* transMats, 46 | float* colors, 47 | float4* normal_opacity, 48 | const dim3 grid, 49 | uint32_t* tiles_touched, 50 | bool prefiltered, 51 | float near_n, 52 | float far_n); 53 | 54 | // Main rasterization method. 55 | void render( 56 | const dim3 grid, dim3 block, 57 | const uint2* ranges, 58 | const uint32_t* point_list, 59 | int W, int H, 60 | float focal_x, float focal_y, 61 | const float2* points_xy_image, 62 | const float* features, 63 | const float* transMats, 64 | const float* depths, 65 | const float4* normal_opacity, 66 | float* final_T, 67 | uint32_t* n_contrib, 68 | const float* bg_color, 69 | float* out_color, 70 | float* out_others, 71 | float near_n , 72 | float far_n); 73 | } 74 | 75 | 76 | #endif 77 | -------------------------------------------------------------------------------- /src/polaris/environments/robot_cfg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import isaaclab.sim as sim_utils 4 | from isaaclab.actuators import ImplicitActuatorCfg 5 | from isaaclab.assets import ArticulationCfg 6 | 7 | from polaris.utils import DATA_PATH 8 | 9 | NVIDIA_DROID = ArticulationCfg( 10 | prim_path="{ENV_REGEX_NS}/robot", 11 | spawn=sim_utils.UsdFileCfg( 12 | usd_path=str(DATA_PATH / "nvidia_droid/noninstanceable.usd"), 13 | activate_contact_sensors=True, 14 | rigid_props=sim_utils.RigidBodyPropertiesCfg( 15 | disable_gravity=True, 16 | max_depenetration_velocity=5.0, 17 | ), 18 | articulation_props=sim_utils.ArticulationRootPropertiesCfg( 19 | enabled_self_collisions=False, 20 | solver_position_iteration_count=64, 21 | solver_velocity_iteration_count=0, 22 | ), 23 | ), 24 | init_state=ArticulationCfg.InitialStateCfg( 25 | pos=(0, 0, 0), 26 | rot=(1, 0, 0, 0), 27 | joint_pos={ 28 | "panda_joint1": 0.0, 29 | "panda_joint2": -1 / 5 * np.pi, 30 | "panda_joint3": 0.0, 31 | "panda_joint4": -4 / 5 * np.pi, 32 | "panda_joint5": 0.0, 33 | "panda_joint6": 3 / 5 * np.pi, 34 | "panda_joint7": 0, 35 | "finger_joint": 0.0, 36 | "right_outer.*": 0.0, 37 | "left_inner.*": 0.0, 38 | "right_inner.*": 0.0, 39 | }, 40 | ), 41 | soft_joint_pos_limit_factor=1, 42 | actuators={ 43 | "panda_shoulder": ImplicitActuatorCfg( 44 | joint_names_expr=["panda_joint[1-4]"], 45 | effort_limit=87.0, 46 | velocity_limit=2.175, 47 | stiffness=400.0, 48 | damping=80.0, 49 | ), 50 | "panda_forearm": ImplicitActuatorCfg( 51 | joint_names_expr=["panda_joint[5-7]"], 52 | effort_limit=12.0, 53 | velocity_limit=2.61, 54 | stiffness=400.0, 55 | damping=80.0, 56 | ), 57 | "gripper": ImplicitActuatorCfg( 58 | joint_names_expr=["finger_joint"], 59 | stiffness=None, 60 | damping=None, 61 | effort_limit=200.0, 62 | velocity_limit=5.0, # 2.175, 63 | ), 64 | }, 65 | ) 66 | -------------------------------------------------------------------------------- /src/simple-knn/setup.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | 7 | from setuptools import setup, find_packages 8 | from pathlib import Path 9 | 10 | # Read README if available 11 | this_directory = Path(__file__).parent 12 | try: 13 | long_description = (this_directory / "README.md").read_text() 14 | except FileNotFoundError: 15 | long_description = "Simple KNN for 3D Gaussian Splatting" 16 | 17 | setup( 18 | name="simple_knn", 19 | version="0.1.0", 20 | author="Bernhard Kerbl", 21 | description="Simple KNN CUDA implementation for 3D Gaussian Splatting", 22 | long_description=long_description, 23 | long_description_content_type="text/markdown", 24 | url="https://gitlab.inria.fr/bkerbl/simple-knn", 25 | # Find the simple_knn package 26 | packages=find_packages(), 27 | # Include all source files (.cu, .cpp, .h) for JIT compilation 28 | package_data={ 29 | "": ["*.cu", "*.cpp", "*.h", "*.cuh"], 30 | }, 31 | # Also include source files at repo root 32 | include_package_data=True, 33 | # Dependencies 34 | python_requires=">=3.7", 35 | install_requires=[ 36 | "torch>=1.13.0", 37 | "ninja", # For faster JIT compilation 38 | ], 39 | # Classifiers 40 | classifiers=[ 41 | "Development Status :: 4 - Beta", 42 | "Intended Audience :: Science/Research", 43 | "Programming Language :: Python :: 3", 44 | "Programming Language :: C++", 45 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 46 | ], 47 | ) 48 | 49 | # ============================================================================== 50 | # NOTES: 51 | # ============================================================================== 52 | # 53 | # This setup.py does NOT use CUDAExtension to pre-compile the extension. 54 | # Instead, it ships source files which are JIT compiled on first import. 55 | # 56 | # BENEFITS: 57 | # ✓ pip install is instant (no compilation) 58 | # ✓ Works on any CUDA version (compiles for user's hardware) 59 | # ✓ No need to build wheels for every Python/CUDA combination 60 | # ✓ Smaller package size (source only, no binaries) 61 | # 62 | # ============================================================================== 63 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/rasterize_points.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include <torch/extension.h> 14 | #include <cstdio> 15 | #include <tuple> 16 | #include <string> 17 | 18 | std::tuple 19 | RasterizeGaussiansCUDA( 20 | const torch::Tensor& background, 21 | const torch::Tensor& means3D, 22 | const torch::Tensor& colors, 23 | const torch::Tensor& opacity, 24 | const torch::Tensor& scales, 25 | const torch::Tensor& rotations, 26 | const float scale_modifier, 27 | const torch::Tensor& transMat_precomp, 28 | const torch::Tensor& viewmatrix, 29 | const torch::Tensor& projmatrix, 30 | const float tan_fovx, 31 | const float tan_fovy, 32 | const int image_height, 33 | const int image_width, 34 | const torch::Tensor& sh, 35 | const int degree, 36 | const torch::Tensor& campos, 37 | const bool prefiltered, 38 | const bool debug, 39 | float near_n , 40 | float far_n); 41 | 42 | std::tuple 43 | RasterizeGaussiansBackwardCUDA( 44 | const torch::Tensor& background, 45 | const torch::Tensor& means3D, 46 | const torch::Tensor& radii, 47 | const torch::Tensor& colors, 48 | const torch::Tensor& scales, 49 | const torch::Tensor& rotations, 50 | const float scale_modifier, 51 | const torch::Tensor& transMat_precomp, 52 | const torch::Tensor& viewmatrix, 53 | const torch::Tensor& projmatrix, 54 | const float tan_fovx, 55 | const float tan_fovy, 56 | const torch::Tensor& dL_dout_color, 57 | const torch::Tensor& dL_dout_others, 58 | const torch::Tensor& sh, 59 | const int degree, 60 | const torch::Tensor& campos, 61 | const torch::Tensor& geomBuffer, 62 | const int R, 63 | const torch::Tensor& binningBuffer, 64 | const torch::Tensor& imageBuffer, 65 | const bool debug, 66 | float near_n , 67 | float far_n); 68 | 69 | torch::Tensor markVisible( 70 | torch::Tensor& means3D, 71 | torch::Tensor& viewmatrix, 72 | torch::Tensor& projmatrix, 73 | float near_n, 74 | float far_n); 75 | 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include <vector> 16 | #include <functional> 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | 24 | static void markVisible( 25 | int P, 26 | float* means3D, 27 | float* viewmatrix, 28 | float* projmatrix, 29 | bool* present , 30 | float near_n , 31 | float far_n); 32 | 33 | static int forward( 34 | std::function<char* (size_t)> geometryBuffer, 35 | std::function<char* (size_t)> binningBuffer, 36 | std::function<char* (size_t)> imageBuffer, 37 | const int P, int D, int M, 38 | const float* background, 39 | const int width, int height, 40 | const float* means3D, 41 | const float* shs, 42 | const float* colors_precomp, 43 | const float* opacities, 44 | const float* scales, 45 | const float scale_modifier, 46 | const float* rotations, 47 | const float* transMat_precomp, 48 | const float* viewmatrix, 49 | const float* projmatrix, 50 | const float* cam_pos, 51 | const float tan_fovx, float tan_fovy, 52 | const bool prefiltered, 53 | float* out_color, 54 | float* out_others, 55 | int* radii = nullptr, 56 | bool debug = false , 57 | float near_n = 0.2, 58 | float far_n = 100.0); 59 | 60 | static void backward( 61 | const int P, int D, int M, int R, 62 | const float* background, 63 | const int width, int height, 64 | const float* means3D, 65 | const float* shs, 66 | const float* colors_precomp, 67 | const float* scales, 68 | const float scale_modifier, 69 | const float* rotations, 70 | const float* transMat_precomp, 71 | const float* viewmatrix, 72 | const float* projmatrix, 73 | const float* campos, 74 | const float tan_fovx, float tan_fovy, 75 | const int* radii, 76 | char* geom_buffer, 77 | char* binning_buffer, 78 | char* image_buffer, 79 | const float* dL_dpix, 80 | const float* dL_depths, 81 | float* dL_dmean2D, 82 | float* dL_dnormal, 83 | float* dL_dopacity, 84 | float* dL_dcolor, 85 | float* dL_dmean3D, 86 | float* dL_dtransMat, 87 | float* dL_dsh, 88 | float* dL_dscale, 89 | float* dL_drot, 90 | bool debug, 91 | float near_n , 92 | float far_n); 93 | }; 94 | }; 95 | 96 | #endif 97 | -------------------------------------------------------------------------------- /src/polaris/policy/abstract_client.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Callable 3 | import numpy as np 4 | 5 | from polaris.config import PolicyArgs 6 | 7 | 8 | class InferenceClient(ABC): 9 | REGISTERED_CLIENTS = {} 10 | 11 | # def __init_subclass__(cls, client_name: str, *args, **kwargs) -> None: 12 | # super().__init_subclass__(*args, **kwargs) 13 | # InferenceClient.REGISTERED_CLIENTS[client_name] = cls 14 | 15 | @staticmethod 16 | def register(client_name: str) -> Callable[[type], type]: 17 | def decorator(cls: type): 18 | InferenceClient.REGISTERED_CLIENTS[client_name] = cls 19 | return cls 20 | 21 | return decorator 22 | 23 | @staticmethod 24 | def get_client(policy_args: PolicyArgs) -> "InferenceClient": 25 | if policy_args.client not in InferenceClient.REGISTERED_CLIENTS: 26 | raise ValueError( 27 | f"Client {policy_args.client} not found. Available clients: {list(InferenceClient.REGISTERED_CLIENTS.keys())}" 28 | ) 29 | return InferenceClient.REGISTERED_CLIENTS[policy_args.client](policy_args) 30 | 31 | @abstractmethod 32 | def __init__(self, args) -> None: 33 | """ 34 | Initializes the client. 
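        Typically this is where the client connects to its policy server,
        e.g. by constructing an openpi WebsocketClientPolicy from
        args.host / args.port (see DroidJointPosClient and
        docs/custom_policies.md for a working example).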
35 | """ 36 | pass 37 | 38 | @property 39 | def rerender(self) -> bool: 40 | """ 41 | Policy requests a rerender of the visualization. Optimization for less splat rendering 42 | for chunked policies. Can default to always True if optimization is not desired. 43 | """ 44 | return True 45 | 46 | @abstractmethod 47 | def infer( 48 | self, obs, instruction, return_viz: bool = False 49 | ) -> tuple[np.ndarray, np.ndarray | None]: 50 | """ 51 | Does inference on observation and returns action and visualization. If visualization is not needed, return None. 52 | """ 53 | 54 | pass 55 | 56 | @abstractmethod 57 | def reset(self): 58 | """ 59 | Resets the client to start a new episode. Useful if policy is stateful. 60 | """ 61 | pass 62 | 63 | 64 | class FakeClient(InferenceClient): 65 | """ 66 | Fake client that returns a dummy action and visualization. 67 | """ 68 | 69 | def __init__(self, *args, **kwargs) -> None: 70 | return 71 | 72 | def infer( 73 | self, obs, instruction, return_viz: bool = False 74 | ) -> tuple[np.ndarray, np.ndarray | None]: 75 | import cv2 76 | 77 | external = obs["splat"]["external_cam"] 78 | wrist = obs["splat"]["wrist_cam"] 79 | external = cv2.resize(external, (224, 224)) 80 | wrist = cv2.resize(wrist, (224, 224)) 81 | both = np.concatenate([external, wrist], axis=1) 82 | return np.zeros((8,)), both 83 | 84 | def reset(self, *args, **kwargs): 85 | return 86 | -------------------------------------------------------------------------------- /src/polaris/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lightweight config dataclasses for evaluation. 3 | No heavy dependencies - safe to import anywhere. 4 | """ 5 | 6 | from dataclasses import dataclass 7 | 8 | 9 | @dataclass 10 | class PolicyServer: 11 | """ 12 | Configuration for a policy server to co-launch. 13 | 14 | Use {port} placeholder in command - it will be replaced with an auto-assigned free port. 15 | Jobs using this server will automatically have their policy.port updated. 16 | 17 | Example: 18 | PolicyServer( 19 | name="pi0", 20 | command="CUDA_VISIBLE_DEVICES=0 python serve_policy.py --port {port}", 21 | ) 22 | """ 23 | 24 | name: str # Friendly name for logging (also used to match jobs to servers) 25 | command: str # Shell command with {port} placeholder 26 | ready_message: str = ( 27 | "Application startup complete" # Message indicating server is ready 28 | ) 29 | 30 | # Runtime-assigned (don't set manually) 31 | _assigned_port: int | None = None 32 | 33 | 34 | @dataclass 35 | class PolicyArgs: 36 | """Policy configuration.""" 37 | 38 | # name: str # Policy name (pi05_droid_jointpos, pi0_fast_droid_jointpos, etc.) 39 | client: str = "DroidJointPos" # Client name (DroidJointPos, Fake, etc.) 
40 | host: str = "0.0.0.0" 41 | port: int = 8000 42 | open_loop_horizon: int | None = 8 43 | 44 | 45 | @dataclass 46 | class EvalArgs: 47 | """Evaluation configuration.""" 48 | 49 | policy: PolicyArgs # Policy arguments 50 | environment: str # Which IsaacLab environment to use 51 | run_folder: str # Path to run folder 52 | headless: bool = True # Whether to run in headless mode 53 | initial_conditions_file: str | None = None # Path to initial conditions file 54 | instruction: str | None = None # Override language instruction 55 | rollouts: int | None = None # Number of rollouts to evaluate 56 | 57 | 58 | @dataclass 59 | class JobCfg: 60 | """A single evaluation job in a batch.""" 61 | 62 | eval_args: EvalArgs 63 | server: PolicyServer | None = None # Server to co-launch for this job 64 | 65 | 66 | @dataclass 67 | class BatchConfig: 68 | """Batch evaluation configuration.""" 69 | 70 | jobs: list[JobCfg] 71 | 72 | # @staticmethod # let users do this on their own if they want 73 | # def sweep(**kwargs: list[Any]) -> list[dict[str, Any]]: 74 | # """ 75 | # Helper to generate grid of configs from lists of values. 76 | 77 | # Example: 78 | # BatchConfig.sweep( 79 | # usd=["env1.usd", "env2.usd"], 80 | # policy=["pi0", "pi05"], 81 | # ) 82 | # # Returns 4 dicts: all combinations 83 | # """ 84 | # keys = list(kwargs.keys()) 85 | # values = [v if isinstance(v, list) else [v] for v in kwargs.values()] 86 | # return [dict(zip(keys, combo)) for combo in itertools.product(*values)] -------------------------------------------------------------------------------- /docs/custom_policies.md: -------------------------------------------------------------------------------- 1 | # Cotraining 2 | Pull the RLDS sim co-training dataset 3 | ```bash 4 | uvx hf download owhan/PolaRiS-datasets --repo-type=dataset --local-dir [path/to/rlds/datasets] 5 | ``` 6 | 7 | To run cotraining on an off-the-shelf policy, use the PolaRiS training configs in [openpi](https://github.com/Physical-Intelligence/openpi). Before running, make sure to update the `rlds_data_dir` for each config. Example run below. 8 | ```bash 9 | cd third_party/openpi 10 | uv run --group rlds scripts/train.py pi05_droid_jointpos_polaris --exp-name=polaris-pi05-droid --overwrite 11 | ``` 12 | 13 | 14 | # Evaluating Custom Policies 15 | 16 | PolaRiS provides a simple interface for evaluating custom policies. For simplicity, we employ a server-client setup where the policy is hosted in a different process from the evaluation process. This is especially helpful when policies require significant resources or have conflicting dependencies. 17 | 18 | We interface with policies through [openpi's WebsocketClientPolicy](https://github.com/Physical-Intelligence/openpi/blob/main/packages/openpi-client/src/openpi_client/websocket_client_policy.py). You may host the policy server however you want. To define a client you need to implement the [InferenceClient](src/polaris/policy/abstract_client.py) abstract class. See [DroidJointPosClient](src/polaris/policy/droid_jointpos_client.py) for a working example. 19 | 20 | Minimal Example: 21 | ```py 22 | @InferenceClient.register(client_name="CustomPolicy") 23 | class CustomPolicy(InferenceClient): 24 | def __init__(self, args: PolicyArgs): 25 | # initialize any necessary state (obs history, action chunks, etc.) 
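        # e.g. (hypothetical) a queue for chunked policies that return several
        # actions per server call:
        # self._action_queue: deque[np.ndarray] = deque()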
26 | self.client = websocket_client_policy.WebsocketClientPolicy( 27 | host=args.host, port=args.port 28 | ) 29 | 30 | @property 31 | def rerender(self) -> bool: 32 | """ 33 | Policy requests a rerender of the visualization. Optimization for less splat rendering 34 | for chunked policies. Can default to always True if optimization is not desired. 35 | """ 36 | return True 37 | 38 | def infer(self, obs, instruction, return_viz: bool = False) -> tuple[np.ndarray, np.ndarray | None]: 39 | """ 40 | Does inference on observation and returns action and visualization. If visualization is not needed, return None. 41 | """ 42 | request_data = { 43 | "external_image": obs["splat"]["external_cam"], 44 | "wrist_image": obs["splat"]["wrist_cam"], 45 | "instruction": instruction, 46 | } 47 | server_response = self.client.infer(request_data) 48 | return server_response["action"], None 49 | 50 | def reset(self): 51 | """ 52 | Resets the client to start a new episode. Useful if policy is stateful. 53 | """ 54 | pass 55 | ``` 56 | 57 | To run the policy, specify the client name and port: 58 | ```bash 59 | uv run scripts/eval.py --environment DROID-FoodBussing --policy.client CustomPolicy --policy.port 8000 --run-folder runs/test 60 | ``` 61 | -------------------------------------------------------------------------------- /src/polaris/environments/rubrics/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for task success/progress rubrics. 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import Callable 7 | from isaaclab.envs import ManagerBasedRLEnv 8 | 9 | 10 | @dataclass 11 | class RubricResult: 12 | """Result from evaluating a rubric.""" 13 | 14 | success: bool # Binary task success 15 | progress: float # Progress score 0.0 - 1.0 16 | metrics: dict[str, float] # Additional metrics for logging 17 | 18 | 19 | class Rubric: 20 | """ 21 | Rubrics compute success/progress by inspecting simulation state. 22 | They're called after each step and on reset to populate info dict. 23 | """ 24 | 25 | def __init__(self, criteria: list[Callable | tuple[Callable, list[int]]], **kwargs): 26 | """ 27 | Initialize the rubric. 28 | 29 | Args: 30 | criteria: Criterion callables, optionally paired with dependency indices as (callable, [dep_indices]) 31 | **kwargs: Task-specific configuration 32 | """ 33 | self.config = kwargs 34 | self.criteria = criteria 35 | self.criteria_reached = [False] * len(criteria) 36 | 37 | def evaluate(self, env: ManagerBasedRLEnv) -> RubricResult: 38 | """ 39 | Evaluate current simulation state and return result. 40 | 41 | Supports criteria with optional dependencies. 42 | Criteria can be: 43 | - callable (no dependency, can be achieved in any order) 44 | - (callable, [dep_indices]) (only counts if all deps by index are met) 45 | This allows for some to require others, but leaves most unconstrained. 46 | 47 | Tracks the max-ever reached state for each criterion using self.criteria_reached. 
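        Example (criterion functions are hypothetical):
            criteria = [cube_lifted, (cube_on_tray, [0])]
            # `cube_on_tray` is only evaluated once `cube_lifted` has been
            # reached at least once earlier in the episode.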
48 | """ 49 | metrics = {} 50 | num_criteria = len(self.criteria) 51 | 52 | criteria_met_now = [] 53 | for idx, c in enumerate(self.criteria): 54 | # Check if c is (callable, [deps]), else treat as callable only 55 | if isinstance(c, tuple): 56 | fn, deps = c 57 | # Only evaluate if all deps ever reached 58 | deps_met = all(self.criteria_reached[d] for d in deps) 59 | result = fn(env) if deps_met else False 60 | else: 61 | fn = c 62 | result = fn(env) 63 | # Update max-ever reached for this criterion 64 | self.criteria_reached[idx] = self.criteria_reached[idx] or bool(result) 65 | criteria_met_now.append(bool(result)) 66 | 67 | num_reached_ever = sum(self.criteria_reached) 68 | progress = num_reached_ever / num_criteria if num_criteria > 0 else 0.0 69 | metrics["criteria_ever_reached"] = num_reached_ever 70 | metrics["criteria_total"] = num_criteria 71 | 72 | success = num_reached_ever == num_criteria 73 | return RubricResult(success=success, progress=progress, metrics=metrics) 74 | 75 | def reset(self): 76 | """Called when environment resets. Override for stateful rubrics.""" 77 | self.criteria_reached = [False] * len(self.criteria) 78 | -------------------------------------------------------------------------------- /experiments/example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example experiment config that co-launches policy servers with each job. 3 | 4 | Usage: 5 | python scripts/batch_eval.py --config experiments/with_servers.py 6 | python scripts/batch_eval.py --config experiments/with_servers.py --dry-run 7 | """ 8 | 9 | from polaris.config import EvalArgs, PolicyArgs, BatchConfig, PolicyServer, JobCfg 10 | 11 | 12 | ''' 13 | Define reusable servers. Servers MUST accept a port argument in the command to 14 | avoid policy servers conflicting. `{port}` is auto-replaced with a free port 15 | determined at runtime. 16 | ''' 17 | PI0_FAST_SERVER = PolicyServer( 18 | name="pi0_fast", 19 | command=" ".join([ 20 | "XLA_PYTHON_CLIENT_MEM_FRACTION=0.35", 21 | "~/projects/PolaRiS/third_party/openpi/.venv/bin/python", 22 | "~/projects/PolaRiS/third_party/openpi/scripts/serve_policy.py", 23 | "--port {port}", 24 | "policy:checkpoint --policy.config pi0_fast_droid_jointpos", 25 | "--policy.dir gs://openpi-assets/checkpoints/polaris/pi0_fast_droid_jointpos_polaris", 26 | ]), 27 | ready_message="server listening on", 28 | ) 29 | 30 | PI05_SERVER = PolicyServer( 31 | name="pi05", 32 | command=" ".join([ 33 | "XLA_PYTHON_CLIENT_MEM_FRACTION=0.35", 34 | "~/projects/PolaRiS/third_party/openpi/.venv/bin/python", 35 | "~/projects/PolaRiS/third_party/openpi/scripts/serve_policy.py", 36 | "--port {port}", 37 | "policy:checkpoint --policy.config pi05_droid_jointpos", 38 | "--policy.dir gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris", 39 | ]), 40 | ready_message="server listening on", 41 | ) 42 | 43 | # Each job references its server. The server will be launched on the same GPU as the job. 
44 | config = BatchConfig( 45 | jobs=[ 46 | # pi05 jobs (run_folder is a local output directory of your choice) 47 | JobCfg( 48 | server=PI05_SERVER, 49 | eval_args=EvalArgs( 50 | environment="DROID-MoveLatteCup", 51 | run_folder="runs/pi05/DROID-MoveLatteCup", 52 | policy=PolicyArgs( 53 | client="DroidJointPos", 54 | open_loop_horizon=8, 55 | ), 56 | ), 57 | ), 58 | JobCfg( 59 | server=PI05_SERVER, 60 | eval_args=EvalArgs( 61 | environment="DROID-OrganizeTools", 62 | run_folder="runs/pi05/DROID-OrganizeTools", 63 | policy=PolicyArgs( 64 | client="DroidJointPos", 65 | open_loop_horizon=8, 66 | ), 67 | ), 68 | ), 69 | JobCfg( 70 | server=PI05_SERVER, 71 | eval_args=EvalArgs( 72 | environment="DROID-TapeIntoContainer", 73 | run_folder="runs/pi05/DROID-TapeIntoContainer", 74 | policy=PolicyArgs( 75 | client="DroidJointPos", 76 | open_loop_horizon=8, 77 | ), 78 | ), 79 | ), 80 | 81 | JobCfg( 82 | server=PI05_SERVER, 83 | eval_args=EvalArgs( 84 | environment="DROID-FoodBussing", 85 | run_folder="runs/pi05/DROID-FoodBussing", 86 | policy=PolicyArgs( 87 | client="DroidJointPos", 88 | open_loop_horizon=8, 89 | ), 90 | ), 91 | ), 92 | 93 | ], 94 | ) 95 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup, find_packages 13 | from pathlib import Path 14 | 15 | # Read README for long description 16 | this_directory = Path(__file__).parent 17 | try: 18 | long_description = (this_directory / "README.md").read_text() 19 | except FileNotFoundError: 20 | long_description = "Differentiable rasterizer for 2D Gaussian Splatting" 21 | 22 | setup( 23 | name="diff_surfel_rasterization", 24 | version="0.1.0", 25 | author="Binbin Huang", 26 | description="Differentiable rasterizer for 2D Gaussian Splatting", 27 | long_description=long_description, 28 | long_description_content_type="text/markdown", 29 | url="https://github.com/Tordjx/diff-surfel-rasterization", 30 | # Find the diff_surfel_rasterization package 31 | packages=find_packages(), 32 | # Include all source files (.cu, .cpp, .h) for JIT compilation 33 | package_data={ 34 | "": [ 35 | "*.cu", 36 | "*.cpp", 37 | "*.h", 38 | "*.cuh", 39 | "cuda_rasterizer/*.cu", 40 | "cuda_rasterizer/*.h", 41 | "cuda_rasterizer/*.cuh", 42 | ], 43 | }, 44 | # Also include source files at repo root level 45 | # These will be accessible from Path(__file__).parent.parent in __init__.py 46 | data_files=[ 47 | ( 48 | ".", 49 | [ 50 | "ext.cpp", 51 | # Add any root-level .cu files here 52 | ], 53 | ), 54 | ( 55 | "cuda_rasterizer", 56 | [ 57 | # List cuda_rasterizer files if they're at repo root 58 | ], 59 | ), 60 | ], 61 | include_package_data=True, 62 | # Dependencies 63 | python_requires=">=3.7", 64 | install_requires=[ 65 | "torch>=1.13.0", 66 | "ninja", # For faster JIT compilation 67 | ], 68 | # Classifiers 69 | classifiers=[ 70 | "Development Status :: 4 - Beta", 71 | "Intended Audience :: Science/Research", 72 | "Programming Language :: Python :: 3", 73 | "Programming Language :: C++", 74 | "Programming Language :: Python :: 3.7", 75 | "Programming Language :: Python :: 3.8", 76 | "Programming Language :: Python :: 3.9", 77 | "Programming Language 
:: Python :: 3.10", 78 | "Programming Language :: Python :: 3.11", 79 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 80 | ], 81 | ) 82 | 83 | # ============================================================================== 84 | # NOTES: 85 | # ============================================================================== 86 | # 87 | # This setup.py does NOT use CUDAExtension to pre-compile the extension. 88 | # Instead, it ships source files which are JIT compiled on first import. 89 | # 90 | # OLD APPROACH (removed): 91 | # from torch.utils.cpp_extension import CUDAExtension, BuildExtension 92 | # ext_modules=[ 93 | # CUDAExtension( 94 | # name='diff_surfel_rasterization._C', 95 | # sources=[...], 96 | # extra_compile_args={...} 97 | # ) 98 | # ], 99 | # cmdclass={'build_ext': BuildExtension} 100 | # 101 | # NEW APPROACH: 102 | # - Ship source files via package_data 103 | # - JIT compile in __init__.py using torch.utils.cpp_extension.load() 104 | # - First import takes 2-5 minutes, subsequent imports are instant 105 | # 106 | # BENEFITS: 107 | # ✓ pip install is instant (no compilation) 108 | # ✓ Works on any CUDA version (compiles for user's hardware) 109 | # ✓ No need to build wheels for every Python/CUDA combination 110 | # ✓ Smaller package size (source only, no binaries) 111 | # 112 | # ============================================================================== 113 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. 
Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 
84 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | import mediapy 3 | 4 | # import wandb 5 | import tqdm 6 | import gymnasium as gym 7 | import torch 8 | import argparse 9 | import pandas as pd 10 | 11 | 12 | from pathlib import Path 13 | from isaaclab.app import AppLauncher 14 | 15 | from polaris.config import EvalArgs 16 | 17 | 18 | def main(eval_args: EvalArgs): 19 | # This must be done before importing anything from IsaacLab 20 | # Inside main function to avoid launching IsaacLab in global scope 21 | # >>>> Isaac Sim App Launcher <<<< 22 | parser = argparse.ArgumentParser() 23 | args_cli, _ = parser.parse_known_args() 24 | args_cli.enable_cameras = True 25 | args_cli.headless = eval_args.headless 26 | app_launcher = AppLauncher(args_cli) 27 | simulation_app = app_launcher.app 28 | # >>>> Isaac Sim App Launcher <<<< 29 | 30 | from isaaclab_tasks.utils import parse_env_cfg # noqa: E402 31 | from polaris.environments.manager_based_rl_splat_environment import ( 32 | ManagerBasedRLSplatEnv, 33 | ) 34 | from polaris.utils import load_eval_initial_conditions 35 | from polaris.policy import InferenceClient 36 | # from real2simeval.autoscoring import TASK_TO_SUCCESS_CHECKER 37 | 38 | env_cfg = parse_env_cfg( 39 | eval_args.environment, 40 | device="cuda", 41 | num_envs=1, 42 | use_fabric=True, 43 | ) 44 | env: ManagerBasedRLSplatEnv = gym.make(eval_args.environment, cfg=env_cfg) # type: ignore 45 | 46 | language_instruction, initial_conditions = load_eval_initial_conditions( 47 | usd=env.usd_file, 48 | initial_conditions_file=eval_args.initial_conditions_file, 49 | rollouts=eval_args.rollouts, 50 | ) 51 | rollouts = len(initial_conditions) 52 | # Resume CSV logging 53 | run_folder = Path(eval_args.run_folder) 54 | run_folder.mkdir(parents=True, exist_ok=True) 55 | csv_path = run_folder / "eval_results.csv" 56 | if csv_path.exists(): 57 | episode_df = pd.read_csv(csv_path) 58 | else: 59 | episode_df = pd.DataFrame( 60 | { 61 | "episode": pd.Series(dtype="int"), 62 | "episode_length": pd.Series(dtype="int"), 63 | "success": pd.Series(dtype="bool"), 64 | "progress": pd.Series(dtype="float"), 65 | } 66 | ) 67 | episode = len(episode_df) 68 | if episode >= rollouts: 69 | print("All rollouts have been evaluated.
Exiting.") 70 | env.close() 71 | simulation_app.close() 72 | return 73 | 74 | policy_client: InferenceClient = InferenceClient.get_client(eval_args.policy) 75 | 76 | video = [] 77 | horizon = env.max_episode_length 78 | bar = tqdm.tqdm(range(horizon)) 79 | obs, info = env.reset( 80 | object_positions=initial_conditions[episode % len(initial_conditions)] 81 | ) 82 | policy_client.reset() 83 | print(f" >>> Starting eval job from episode {episode + 1} of {rollouts} <<< ") 84 | while True: 85 | action, viz = policy_client.infer(obs, language_instruction) 86 | if viz is not None: 87 | video.append(viz) 88 | obs, rew, term, trunc, info = env.step( 89 | torch.tensor(action).reshape(1, -1), expensive=policy_client.rerender 90 | ) 91 | 92 | bar.update(1) 93 | if term[0] or trunc[0] or bar.n >= horizon: 94 | policy_client.reset() 95 | 96 | # Save video and metadata 97 | filename = run_folder / f"episode_{episode}.mp4" 98 | mediapy.write_video(filename, video, fps=15) 99 | 100 | # Log episode results to CSV 101 | episode_data = { 102 | "episode": episode, 103 | "episode_length": bar.n, 104 | "success": info["rubric"]["success"], 105 | "progress": info["rubric"]["progress"], 106 | } 107 | episode_df = pd.concat( 108 | [episode_df, pd.DataFrame([episode_data])], ignore_index=True 109 | ) 110 | episode_df.to_csv(csv_path, index=False) 111 | 112 | bar.close() 113 | print(f"Episode {episode} finished. Episode length: {bar.n}") 114 | bar = tqdm.tqdm(range(horizon)) 115 | obs, info = env.reset( 116 | object_positions=initial_conditions[episode % len(initial_conditions)] 117 | ) 118 | 119 | episode += 1 120 | video = [] 121 | if episode >= rollouts: 122 | break 123 | 124 | env.close() 125 | simulation_app.close() 126 | 127 | 128 | if __name__ == "__main__": 129 | args: EvalArgs = tyro.cli(EvalArgs) 130 | main(args) 131 | -------------------------------------------------------------------------------- /src/polaris/policy/droid_jointpos_client.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from openpi_client import websocket_client_policy, image_tools 3 | from polaris.policy.abstract_client import InferenceClient, PolicyArgs 4 | 5 | 6 | # Joint Position Client for DROID 7 | @InferenceClient.register(client_name="DroidJointPos") 8 | class DroidJointPosClient(InferenceClient): 9 | def __init__(self, args: PolicyArgs) -> None: 10 | self.args = args 11 | if args.open_loop_horizon is None: 12 | raise ValueError("open_loop_horizon must be set for DroidJointPosClient") 13 | 14 | self.client = websocket_client_policy.WebsocketClientPolicy( 15 | host=args.host, port=args.port 16 | ) 17 | self.actions_from_chunk_completed = 0 18 | self.pred_action_chunk = None 19 | self.open_loop_horizon = args.open_loop_horizon 20 | 21 | @property 22 | def rerender(self) -> bool: 23 | return ( 24 | self.actions_from_chunk_completed == 0 25 | or self.actions_from_chunk_completed >= self.open_loop_horizon 26 | ) 27 | 28 | def visualize(self, request: dict): 29 | """ 30 | Return the camera views as the model sees them 31 | """ 32 | curr_obs = self._extract_observation(request) 33 | base_img = image_tools.resize_with_pad(curr_obs["right_image"], 224, 224) 34 | wrist_img = image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224) 35 | combined = np.concatenate([base_img, wrist_img], axis=1) 36 | return combined 37 | 38 | def reset(self): 39 | self.actions_from_chunk_completed = 0 40 | self.pred_action_chunk = None 41 | 42 | def infer( 43 | self, obs: dict, instruction:
str, return_viz: bool = False 44 | ) -> tuple[np.ndarray, np.ndarray | None]: 45 | """ 46 | Infer the next action from the policy in a server-client setup 47 | """ 48 | both = None 49 | ret = {} 50 | if ( 51 | self.actions_from_chunk_completed == 0 52 | or self.actions_from_chunk_completed >= self.open_loop_horizon 53 | ): 54 | curr_obs = self._extract_observation(obs) 55 | 56 | self.actions_from_chunk_completed = 0 57 | exterior_image = image_tools.resize_with_pad( 58 | curr_obs["right_image"], 224, 224 59 | ) 60 | wrist_image = image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224) 61 | request_data = { 62 | "observation/exterior_image_1_left": exterior_image, 63 | "observation/wrist_image_left": wrist_image, 64 | "observation/joint_position": curr_obs["joint_position"], 65 | "observation/gripper_position": curr_obs["gripper_position"], 66 | "prompt": instruction, 67 | } 68 | server_response = self.client.infer(request_data) 69 | self.pred_action_chunk = server_response["actions"] 70 | both = np.concatenate([exterior_image, wrist_image], axis=1) 71 | 72 | if return_viz and both is None: 73 | curr_obs = self._extract_observation(obs) 74 | both = np.concatenate( 75 | [ 76 | image_tools.resize_with_pad(curr_obs["right_image"], 224, 224), 77 | image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224), 78 | ], 79 | axis=1, 80 | ) 81 | 82 | if self.pred_action_chunk is None: 83 | raise ValueError("No action chunk predicted") 84 | 85 | action = self.pred_action_chunk[self.actions_from_chunk_completed] 86 | self.actions_from_chunk_completed += 1 87 | 88 | # binarize gripper action 89 | if action[-1].item() > 0.5: 90 | action = np.concatenate([action[:-1], np.ones((1,))]) 91 | else: 92 | action = np.concatenate([action[:-1], np.zeros((1,))]) 93 | 94 | return action, both 95 | 96 | def _extract_observation(self, obs_dict): 97 | # Assign images 98 | right_image = obs_dict["splat"]["external_cam"] 99 | wrist_image = obs_dict["splat"]["wrist_cam"] 100 | 101 | # Capture proprioceptive state 102 | robot_state = obs_dict["policy"] 103 | joint_position = robot_state["arm_joint_pos"].clone().detach().cpu().numpy()[0] 104 | gripper_position = robot_state["gripper_pos"].clone().detach().cpu().numpy()[0] 105 | 106 | return { 107 | "right_image": right_image, 108 | "wrist_image": wrist_image, 109 | "joint_position": joint_position, 110 | "gripper_position": gripper_position, 111 | } 112 | -------------------------------------------------------------------------------- /src/simple-knn/LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 
22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. 
UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | 85 | ## 6. Files subject to permissive licenses 86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 87 | 88 | Title: pytorch-ssim\ 89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ 90 | Copyright Evan Su, 2017\ 91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) -------------------------------------------------------------------------------- /src/polaris/splat_renderer/utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | 25 | C0 = 0.28209479177387814 26 | C1 = 0.4886025119029199 27 | C2 = [ 28 | 1.0925484305920792, 29 | -1.0925484305920792, 30 | 0.31539156525252005, 31 | -1.0925484305920792, 32 | 0.5462742152960396, 33 | ] 34 | C3 = [ 35 | -0.5900435899266435, 36 | 2.890611442640554, 37 | -0.4570457994644658, 38 | 0.3731763325901154, 39 | -0.4570457994644658, 40 | 1.445305721320277, 41 | -0.5900435899266435, 42 | ] 43 | C4 = [ 44 | 2.5033429417967046, 45 | -1.7701307697799304, 46 | 0.9461746957575601, 47 | -0.6690465435572892, 48 | 0.10578554691520431, 49 | -0.6690465435572892, 50 | 0.47308734787878004, 51 | -1.7701307697799304, 52 | 0.6258357354491761, 53 | ] 54 | 55 | 56 | def eval_sh(deg, sh, dirs): 57 | """ 58 | Evaluate spherical harmonics at unit directions 59 | using hardcoded SH polynomials. 60 | Works with torch/np/jnp. 61 | ... Can be 0 or more batch dimensions. 62 | Args: 63 | deg: int SH deg. 
Currently, 0-4 supported 64 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 65 | dirs: jnp.ndarray unit directions [..., 3] 66 | Returns: 67 | [..., C] 68 | """ 69 | assert deg <= 4 and deg >= 0 70 | coeff = (deg + 1) ** 2 71 | assert sh.shape[-1] >= coeff 72 | 73 | result = C0 * sh[..., 0] 74 | if deg > 0: 75 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 76 | result = ( 77 | result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] 78 | ) 79 | 80 | if deg > 1: 81 | xx, yy, zz = x * x, y * y, z * z 82 | xy, yz, xz = x * y, y * z, x * z 83 | result = ( 84 | result 85 | + C2[0] * xy * sh[..., 4] 86 | + C2[1] * yz * sh[..., 5] 87 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] 88 | + C2[3] * xz * sh[..., 7] 89 | + C2[4] * (xx - yy) * sh[..., 8] 90 | ) 91 | 92 | if deg > 2: 93 | result = ( 94 | result 95 | + C3[0] * y * (3 * xx - yy) * sh[..., 9] 96 | + C3[1] * xy * z * sh[..., 10] 97 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] 98 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] 99 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] 100 | + C3[5] * z * (xx - yy) * sh[..., 14] 101 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15] 102 | ) 103 | 104 | if deg > 3: 105 | result = ( 106 | result 107 | + C4[0] * xy * (xx - yy) * sh[..., 16] 108 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17] 109 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18] 110 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19] 111 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] 112 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21] 113 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] 114 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] 115 | + C4[8] 116 | * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 117 | * sh[..., 24] 118 | ) 119 | return result 120 | 121 | 122 | def RGB2SH(rgb): 123 | return (rgb - 0.5) / C0 124 | 125 | 126 | def SH2RGB(sh): 127 | return sh * C0 + 0.5 128 | -------------------------------------------------------------------------------- /src/polaris/environments/__init__.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from polaris.environments.manager_based_rl_splat_environment import ( 3 | ManagerBasedRLSplatEnv, 4 | ) 5 | from polaris.environments.droid_cfg import EnvCfg as DroidCfg 6 | from isaaclab.envs import ManagerBasedRLEnv 7 | 8 | # Import rubric system 9 | from polaris.environments.rubrics import Rubric 10 | from polaris.utils import DATA_PATH 11 | import polaris.environments.rubrics.checkers as checkers 12 | 13 | 14 | # ============================================================================= 15 | # Environment Registration 16 | # ============================================================================= 17 | 18 | gym.register( 19 | id="DROID-BlockStackKitchen", 20 | entry_point=ManagerBasedRLSplatEnv, 21 | kwargs={ 22 | "env_cfg_entry_point": DroidCfg, 23 | "usd_file": str(DATA_PATH / "block_stack_kitchen/scene.usda"), 24 | "rubric": Rubric( 25 | criteria=[ 26 | checkers.reach("green_cube", threshold=0.2), 27 | checkers.reach("wood_cube", threshold=0.2), 28 | (checkers.lift("green_cube", default_height=0.06, threshold=0.03), [0]), 29 | (checkers.lift("wood_cube", default_height=0.06, threshold=0.03), [1]), 30 | (checkers.is_within_xy("green_cube", "tray", 0.8), [2]), 31 | (checkers.is_within_xy("wood_cube", "tray", 0.8), [3]), 32 | (checkers.is_within_xy("green_cube", "wood_cube", 0.5), [4, 5]), 33 | ] 34 | ), 35 | }, 36 | disable_env_checker=True, 37 | order_enforce=False, 38 | ) 39 | 40 | 41 | gym.register(
42 | id="DROID-FoodBussing", 43 | entry_point=ManagerBasedRLSplatEnv, 44 | disable_env_checker=True, 45 | order_enforce=False, 46 | kwargs={ 47 | "env_cfg_entry_point": DroidCfg, 48 | "usd_file": str(DATA_PATH / "food_bussing/scene.usda"), 49 | "rubric": Rubric( 50 | criteria=[ 51 | checkers.reach("ice_cream_", threshold=0.2), 52 | checkers.reach("grapes", threshold=0.2), 53 | (checkers.lift("ice_cream_", threshold=0.06), [0]), 54 | (checkers.lift("grapes", threshold=0.06), [1]), 55 | ( 56 | checkers.is_within_xy("ice_cream_", "bowl", percent_threshold=0.8), 57 | [2], 58 | ), 59 | (checkers.is_within_xy("grapes", "bowl", percent_threshold=0.8), [3]), 60 | ] 61 | ), 62 | }, 63 | ) 64 | 65 | gym.register( 66 | id="DROID-PanClean", 67 | entry_point=ManagerBasedRLSplatEnv, 68 | disable_env_checker=True, 69 | order_enforce=False, 70 | kwargs={ 71 | "env_cfg_entry_point": DroidCfg, 72 | "usd_file": str(DATA_PATH / "pan_clean/scene.usda"), 73 | "rubric": Rubric( 74 | criteria=[ 75 | checkers.reach("sponge", threshold=0.2), 76 | (checkers.lift("sponge", threshold=0.09, default_height=0.0), [0]), 77 | (checkers.is_within_xy("sponge", "pan", percent_threshold=0.8), [1]), 78 | ] 79 | ), 80 | }, 81 | ) 82 | 83 | 84 | gym.register( 85 | id="DROID-MoveLatteCup", 86 | entry_point=ManagerBasedRLSplatEnv, 87 | disable_env_checker=True, 88 | order_enforce=False, 89 | kwargs={ 90 | "env_cfg_entry_point": DroidCfg, 91 | "usd_file": str(DATA_PATH / "move_latte_cup/scene.usda"), 92 | "rubric": Rubric( 93 | criteria=[ 94 | checkers.reach("latteartcup_eval", threshold=0.2), 95 | (checkers.lift("latteartcup_eval", threshold=0.04), [0]), 96 | (checkers.is_within_xy("latteartcup_eval", "cuttingboard_eval", percent_threshold=0.8), [1]), 97 | ] 98 | ), 99 | }, 100 | ) 101 | 102 | gym.register( 103 | id="DROID-OrganizeTools", 104 | entry_point=ManagerBasedRLSplatEnv, 105 | disable_env_checker=True, 106 | order_enforce=False, 107 | kwargs={ 108 | "env_cfg_entry_point": DroidCfg, 109 | "usd_file": str(DATA_PATH / "organize_tools/scene.usda"), 110 | "rubric": Rubric( 111 | criteria=[ 112 | checkers.reach("scissor", threshold=0.2), 113 | (checkers.lift("scissor", threshold=0.04), [0]), 114 | (checkers.is_within_xy("scissor", "container_01", percent_threshold=0.8), [1]), 115 | ] 116 | ), 117 | }, 118 | ) 119 | 120 | gym.register( 121 | id="DROID-TapeIntoContainer", 122 | entry_point=ManagerBasedRLSplatEnv, 123 | disable_env_checker=True, 124 | order_enforce=False, 125 | kwargs={ 126 | "env_cfg_entry_point": DroidCfg, 127 | "usd_file": str(DATA_PATH / "tape_into_container/scene.usda"), 128 | "rubric": Rubric( 129 | criteria=[ 130 | checkers.reach("tape_00", threshold=0.2), 131 | (checkers.lift("tape_00", threshold=0.04), [0]), 132 | (checkers.is_within_xy("tape_00", "container_02", percent_threshold=0.8), [1]), 133 | ] 134 | ), 135 | }, 136 | ) 137 | -------------------------------------------------------------------------------- /docs/custom_environments.md: -------------------------------------------------------------------------------- 1 | # Creating Custom Environments 2 | 3 | The environments we provide were scanned using ZED cameras, but the reconstruction pipeline is camera agnostic. 4 | 5 | Capture a dense view video of a scene without motion blur, and run it through [COLMAP](https://colmap.github.io/install.html) 6 | 7 | Once you have your COLMAP dataset, follow the instructions in [2DGS](https://github.com/hbb1/2d-gaussian-splatting) to obtain a splat and corresponding extracted mesh. 
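As a quick sanity check, you can inspect the 2DGS outputs in Python before converting anything. The snippet below is an illustrative sketch, not part of the pipeline: it assumes the `plyfile` package is installed and that the files use the names referenced in this guide (`splat.ply` from the asset layout below, `fuse_post.ply` as produced by 2DGS mesh extraction).

```python
# Illustrative sanity check of the 2DGS outputs (assumes the `plyfile` package).
from plyfile import PlyData

splat = PlyData.read("splat.ply")     # trained 2D Gaussian splat
mesh = PlyData.read("fuse_post.ply")  # mesh extracted by 2DGS

print(f"splat: {len(splat['vertex'].data)} Gaussians")
print(f"mesh: {len(mesh['vertex'].data)} vertices, {len(mesh['face'].data)} faces")
```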
8 | 9 | Turn the `fuse_post.ply` mesh into a USD, and create an asset directory that follows this structure. 10 | ``` 11 | new_asset/ 12 | ├── mesh.usd 13 | ├── splat.ply 14 | ├── textures/ (optional, if USD requires textures) 15 | └── config.yaml (optional, USD parameter configuration) 16 | ``` 17 | 18 | Using the [online scene composition GUI](https://polaris-evals.github.io/compose-environments/), create a USD stage that composes the objects in the scene. Export and unzip the USD with the command below. 19 | ``` 20 | unzip scene.zip -d PolaRiS-Hub/new_env/ 21 | ``` 22 | 23 | You should now have a directory that looks something like this: 24 | ``` 25 | PolaRiS-Hub/ 26 | └── new_env/ 27 | ├── assets/ 28 | │ ├── object_1/ 29 | │ │ ├── mesh.usd 30 | │ │ └── textures/ 31 | │ ├── object_2/ 32 | │ │ ├── mesh.usd 33 | │ │ └── textures/ 34 | │ └── scene_splat/ 35 | │ ├── config.yaml 36 | │ └── splat.ply 37 | ├── scene.usda # Main USD stage file 38 | └── initial_conditions.json (defined via GUI) 39 | ``` 40 | 41 | Add the new environment to the [environments file](../src/polaris/environments/__init__.py), following the same pattern as the default 6 environments. You can also see how to define a rubric to score rollouts with just a few lines of code. Now you can use this environment by changing the `--environment` flag in the eval script. 42 | 43 | After testing the environment, please consider submitting a PR to upload it to the [PolaRiS-Hub](https://huggingface.co/datasets/owhan/PolaRiS-Hub)! See below for instructions. 44 | 45 | ## Uploading Environments to HuggingFace 46 | 47 | Share your custom environments with the community by uploading them to the [PolaRiS-Hub](https://huggingface.co/datasets/owhan/PolaRiS-Hub) dataset. **All uploads are automatically submitted as pull requests** (not direct commits) for review and quality control. 48 | 49 | ### Environment Structure 50 | 51 | Your environment folder should look something like this: 52 | ``` 53 | PolaRiS-Hub/ 54 | └── new_env/ 55 | ├── assets/ 56 | │ ├── object_1/ 57 | │ │ ├── mesh.usd 58 | │ │ └── textures/ 59 | │ ├── object_2/ 60 | │ │ ├── mesh.usd 61 | │ │ └── textures/ 62 | │ └── scene_splat/ 63 | │ ├── config.yaml 64 | │ └── splat.ply 65 | ├── scene.usda # Main USD stage file 66 | └── initial_conditions.json (defined via GUI) 67 | ``` 68 | 69 | ### Upload Commands 70 | 71 | ```bash 72 | uv run scripts/upload_env_to_hf.py ./PolaRiS-Hub/new_env --pr-title "Add new_env" --pr-description "Description of the environment" 73 | ``` 74 | 75 | ### CLI Options 76 | 77 | | Flag | Description | 78 | |------|-------------| 79 | | `--dry-run` | Validate only, don't upload | 80 | | `--pr-title` | Title for the pull request | 81 | | `--pr-description` | Description/body for the PR | 82 | | `--repo-id` | Target HF dataset (default: `owhan/PolaRiS-Hub`) | 83 | | `--branch` | Target branch (default: `main`) | 84 | | `--token` | HF token (or use `HF_TOKEN` env var) | 85 | | `--strict` | Treat validation warnings as errors | 86 | | `--require-pxr` | Fail if USD files can't be opened (requires pxr) | 87 | | `--skip-validation` | Skip validation (not recommended) | 88 | 89 | ### How PRs Work for HuggingFace Datasets 90 | 91 | When you run the upload script, it automatically: 92 | 93 | 1. Validates your environment structure locally 94 | 2. Creates a pull request (not a direct commit) to the target dataset 95 | 3.
Returns the PR URL or instructions to view it 96 | 97 | **Viewing Your PR:** 98 | 99 | - After upload, the CLI will print the PR URL (e.g., `https://huggingface.co/datasets/owhan/PolaRiS-Hub/discussions/<pr_number>`) 100 | - You can also view all PRs at: `https://huggingface.co/datasets/owhan/PolaRiS-Hub/discussions` 101 | - PRs must be reviewed and merged by dataset maintainers before your environment appears in the dataset 102 | 103 | **Merging Your PR:** 104 | 105 | - Navigate to the PR URL in your browser 106 | - Review the changes in the "Files" tab 107 | - Click "Publish" when ready to merge (requires write access to the dataset) 108 | 109 | ### Managing Your PR Locally 110 | 111 | After creating a PR, you can check it out locally to make changes: 112 | 113 | ```bash 114 | # Clone the dataset repo 115 | git clone https://huggingface.co/datasets/owhan/PolaRiS-Hub 116 | cd PolaRiS-Hub 117 | 118 | # Fetch and checkout PR (replace <pr_number> with your PR number from the upload output) 119 | git fetch origin refs/pr/<pr_number>:pr/<pr_number> 120 | git checkout pr/<pr_number> 121 | 122 | # Make edits, then push back 123 | git add . 124 | git commit -m "Update environment" 125 | git push origin pr/<pr_number>:refs/pr/<pr_number> 126 | ``` 127 | -------------------------------------------------------------------------------- /src/polaris/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from datetime import datetime 5 | from pathlib import Path 6 | 7 | from isaaclab.envs import ManagerBasedRLEnvCfg 8 | from isaaclab_tasks.utils import load_cfg_from_registry 9 | 10 | DATA_PATH = ( 11 | Path("./PolaRiS-Hub").resolve() 12 | if "POLARIS_DATA_PATH" not in os.environ 13 | else Path(os.environ["POLARIS_DATA_PATH"]).resolve() 14 | ) 15 | 16 | 17 | def load_eval_initial_conditions( 18 | usd: str, initial_conditions_file: str | None = None, rollouts: int | None = None 19 | ) -> tuple[str, dict]: 20 | """ 21 | If initial_conditions_file is provided, load the initial conditions from the file. 22 | Otherwise, load the initial conditions from the USD's directory. If neither exists, raise an error. 23 | """ 24 | if initial_conditions_file is None: 25 | initial_conditions_file_path = Path(usd).parent / "initial_conditions.json" 26 | else: 27 | initial_conditions_file_path = Path(initial_conditions_file) 28 | 29 | if not initial_conditions_file_path.exists(): 30 | raise FileNotFoundError( 31 | "Either USD directory must have an initial_conditions.json file, or a custom initial_conditions_file must be provided." 32 | ) 33 | with open(initial_conditions_file_path, "r") as f: 34 | initial_conditions = json.load(f) 35 | 36 | # will have initial conditions and language instruction 37 | if "instruction" not in initial_conditions or "poses" not in initial_conditions: 38 | raise ValueError( 39 | "Initial conditions are ill-formatted. Must contain 'instruction' and 'poses' keys." 40 | ) 41 | instruction = initial_conditions["instruction"] 42 | initial_conditions = ( 43 | initial_conditions["poses"] 44 | if rollouts is None 45 | else initial_conditions["poses"][:rollouts] 46 | ) 47 | return instruction, initial_conditions 48 | 49 | 50 | def run_folder_path(run_folder: str | None, usd: str, policy: str) -> Path: 51 | """ 52 | If run_folder is not provided, create a new run folder in the runs directory with the current date and time. 53 | Otherwise, use the provided run folder.
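For example (illustrative values): run_folder_path(None, "scene.usda", "pi05") resolves to runs/<YYYY-MM-DD>/<HH:MM:SS AM-or-PM>/scene/pi05, since the USD stem and policy name are always appended.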
54 | """ 55 | if not run_folder: 56 | run_folder_path = f"runs/{datetime.now().strftime('%Y-%m-%d')}/{datetime.now().strftime('%I:%M:%S %p')}" 57 | else: 58 | run_folder_path = run_folder 59 | 60 | run_folder_path = Path(run_folder_path) / Path(usd).stem / policy 61 | print(f" >>> Saving to {run_folder_path} <<< ") 62 | run_folder_path.mkdir(parents=True, exist_ok=True) 63 | return run_folder_path 64 | 65 | 66 | def parse_env_cfg( 67 | task_name: str, 68 | usd_file: str, 69 | device: str = "cuda:0", 70 | num_envs: int | None = None, 71 | use_fabric: bool | None = None, 72 | ) -> ManagerBasedRLEnvCfg: 73 | """ 74 | Parse configuration for an environment and override based on inputs. 75 | Adapted from isaaclab_tasks.utils.parse_env_cfg. 76 | 77 | New Parameters 78 | -------------- 79 | usd_file: str 80 | Path to USD file we want to use 81 | """ 82 | # load the default configuration 83 | cfg = load_cfg_from_registry(task_name.split(":")[-1], "env_cfg_entry_point") 84 | 85 | # check that it is not a dict 86 | if isinstance(cfg, dict): 87 | raise RuntimeError( 88 | f"Configuration for the task: '{task_name}' is not a class. Please provide a class." 89 | ) 90 | 91 | cfg.dynamic_setup(usd_file) 92 | 93 | # simulation device 94 | cfg.sim.device = device 95 | # disable fabric to read/write through USD 96 | if use_fabric is not None: 97 | cfg.sim.use_fabric = use_fabric 98 | # number of environments 99 | if num_envs is not None: 100 | cfg.scene.num_envs = num_envs 101 | 102 | return cfg 103 | 104 | 105 | def rotate_vector_by_quaternion(q, v): 106 | """Rotate vectors by quaternions using the fast Hamilton product. 107 | 108 | Args: 109 | q: (4) tensor of quaternions in [w,x,y,z] format 110 | v: (..., 3) tensor of vectors to rotate 111 | Returns: 112 | (..., 3) tensor of rotated vectors 113 | """ 114 | q = q.repeat(v.shape[:-1] + (1,)) 115 | # Extract quaternion components 116 | qw = q[..., 0] 117 | qv = q[..., 1:] 118 | 119 | # uv = 2 * cross(qv, v) 120 | uv = 2 * torch.cross(qv, v, dim=-1) 121 | 122 | # return v + qw * uv + cross(qv, uv) 123 | return v + qw[..., None] * uv + torch.cross(qv, uv, dim=-1) 124 | 125 | 126 | def multiply_quaternions(q1, q2): 127 | """Fast quaternion multiplication using PyTorch. 128 | Assumes quaternions are in [w,x,y,z] format. 129 | 130 | Args: 131 | q1: (N, 4) tensor of quaternions 132 | q2: (N, 4) tensor of quaternions 133 | Returns: 134 | (N, 4) tensor of resulting quaternions 135 | """ 136 | w1, x1, y1, z1 = q1.unbind(-1) 137 | w2, x2, y2, z2 = q2.unbind(-1) 138 | 139 | # Compute components directly 140 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 141 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 142 | y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 143 | z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 144 | 145 | return torch.stack((w, x, y, z), dim=-1) 146 | -------------------------------------------------------------------------------- /src/polaris/splat_renderer/scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from polaris.splat_renderer.utils.graphics_utils import ( 16 | getWorld2View2, 17 | getProjectionMatrix, 18 | ) 19 | 20 | 21 | class Camera(nn.Module): 22 | def __init__( 23 | self, 24 | colmap_id, 25 | R, 26 | T, 27 | FoVx, 28 | FoVy, 29 | image, 30 | gt_alpha_mask, 31 | image_name, 32 | uid, 33 | trans=np.array([0.0, 0.0, 0.0]), 34 | scale=1.0, 35 | data_device="cuda", 36 | ): 37 | super(Camera, self).__init__() 38 | 39 | self.uid = uid 40 | self.colmap_id = colmap_id 41 | self.R = R 42 | self.T = T 43 | self.FoVx = FoVx 44 | self.FoVy = FoVy 45 | self.image_name = image_name 46 | 47 | try: 48 | self.data_device = torch.device(data_device) 49 | except Exception as e: 50 | print(e) 51 | print( 52 | f"[Warning] Custom device {data_device} failed, fallback to default cuda device" 53 | ) 54 | self.data_device = torch.device("cuda") 55 | 56 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 57 | self.image_width = self.original_image.shape[2] 58 | self.image_height = self.original_image.shape[1] 59 | 60 | if gt_alpha_mask is not None: 61 | # self.original_image *= gt_alpha_mask.to(self.data_device) 62 | self.gt_alpha_mask = gt_alpha_mask.to(self.data_device) 63 | else: 64 | self.original_image *= torch.ones( 65 | (1, self.image_height, self.image_width), device=self.data_device 66 | ) 67 | self.gt_alpha_mask = None 68 | 69 | self.zfar = 100 70 | self.znear = 0.05 71 | # self.znear = 10 72 | 73 | self.trans = trans 74 | self.scale = scale 75 | 76 | self.world_view_transform = ( 77 | torch.tensor(getWorld2View2(R, T, trans, scale)) 78 | .transpose(0, 1) 79 | .to(self.data_device) 80 | ) 81 | self.projection_matrix = ( 82 | getProjectionMatrix( 83 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 84 | ) 85 | .transpose(0, 1) 86 | .to(self.data_device) 87 | ) 88 | self.full_proj_transform = ( 89 | self.world_view_transform.unsqueeze(0).bmm( 90 | self.projection_matrix.unsqueeze(0) 91 | ) 92 | ).squeeze(0) 93 | self.camera_center = self.world_view_transform.inverse()[3, :3] 94 | 95 | def set_extrinsics(self, R, T): 96 | self.R = R 97 | self.T = T 98 | center = np.zeros(3) 99 | self.world_view_transform = ( 100 | torch.tensor(getWorld2View2(R, center, T, self.scale)) 101 | .transpose(0, 1) 102 | .to(self.data_device) 103 | ) 104 | # self.world_view_transform = torch.tensor(getWorld2View2(R, T, self.trans, self.scale)).transpose(0, 1).to(self.data_device) 105 | self.full_proj_transform = ( 106 | self.world_view_transform.unsqueeze(0).bmm( 107 | self.projection_matrix.unsqueeze(0) 108 | ) 109 | ).squeeze(0) 110 | self.camera_center = self.world_view_transform.inverse()[3, :3] 111 | 112 | # print(self.full_proj_transform) 113 | 114 | def set_extrinsics2(self, R, T): 115 | self.R = R 116 | self.T = T 117 | center = np.zeros(3) 118 | # self.world_view_transform = torch.tensor(getWorld2View2(R, center, T, self.scale)).transpose(0, 1).to(self.data_device) 119 | self.world_view_transform = ( 120 | torch.tensor(getWorld2View2(R, T, self.trans, self.scale)) 121 | .transpose(0, 1) 122 | .to(self.data_device) 123 | ) 124 | self.full_proj_transform = ( 125 | self.world_view_transform.unsqueeze(0).bmm( 126 | self.projection_matrix.unsqueeze(0) 127 | ) 128 | ).squeeze(0) 129 | self.camera_center = self.world_view_transform.inverse()[3, :3] 130 | 131 | # print(self.full_proj_transform) 132 | 133 | 134 | class MiniCam: 135 | def __init__( 
136 | self, 137 | width, 138 | height, 139 | fovy, 140 | fovx, 141 | znear, 142 | zfar, 143 | world_view_transform, 144 | full_proj_transform, 145 | ): 146 | self.image_width = width 147 | self.image_height = height 148 | self.FoVy = fovy 149 | self.FoVx = fovx 150 | self.znear = znear 151 | self.zfar = zfar 152 | self.world_view_transform = world_view_transform 153 | self.full_proj_transform = full_proj_transform 154 | view_inv = torch.inverse(self.world_view_transform) 155 | self.camera_center = view_inv[3][:3] 156 | -------------------------------------------------------------------------------- /src/simple-knn/simple_knn/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | 7 | import torch 8 | 9 | # ============================================================================== 10 | # JIT COMPILATION SUPPORT 11 | # ============================================================================== 12 | # Try to import pre-compiled _C module, if not available, compile via JIT 13 | # ============================================================================== 14 | 15 | try: 16 | # Try to import pre-compiled extension (from setup.py build) 17 | from . import _simple_knn as _C 18 | except ImportError: 19 | # If not available, compile via JIT on first import 20 | import os 21 | from pathlib import Path 22 | 23 | def _load_extension_jit(): 24 | """JIT compile the CUDA extension if pre-built version not available.""" 25 | from torch.utils.cpp_extension import load 26 | 27 | # Get source directory 28 | # simple-knn typically has structure: 29 | # simple-knn/ 30 | # simple_knn/ 31 | # __init__.py (this file) 32 | # spatial.cu 33 | # simple_knn.cu 34 | # ext.cpp 35 | 36 | _pkg_path = Path(__file__).parent 37 | _src_path = _pkg_path.parent # Go up to repo root where .cu files are 38 | 39 | # Find all source files 40 | sources = [] 41 | 42 | # Common simple-knn source files 43 | potential_files = [ 44 | _src_path / "ext.cpp", 45 | _src_path / "spatial.cu", 46 | _src_path / "simple_knn.cu", 47 | ] 48 | 49 | for f in potential_files: 50 | if f.exists(): 51 | sources.append(str(f)) 52 | 53 | # Also search recursively for any .cu/.cpp files we might have missed 54 | for ext in ["*.cu", "*.cpp"]: 55 | for p in _src_path.rglob(ext): 56 | p_str = str(p) 57 | if p_str not in sources and "test" not in p_str.lower(): 58 | sources.append(p_str) 59 | 60 | if not sources: 61 | raise FileNotFoundError( 62 | f"No source files found in {_src_path}. 
" 63 | "Make sure simple-knn is properly installed.\n" 64 | f"Package path: {_pkg_path}\n" 65 | f"Source path: {_src_path}" 66 | ) 67 | 68 | # Compilation settings 69 | extra_cuda_cflags = [ 70 | "-O3", 71 | "--use_fast_math", 72 | "-std=c++17", 73 | "--expt-relaxed-constexpr", 74 | ] 75 | 76 | extra_cflags = ["-O3", "-std=c++17"] 77 | 78 | # Include directories 79 | include_dirs = [str(_src_path)] 80 | 81 | # Build directory 82 | cuda_ver = ( 83 | torch.version.cuda.replace(".", "_") if torch.cuda.is_available() else "cpu" 84 | ) 85 | build_dir = os.path.join( 86 | os.path.expanduser("~"), 87 | ".cache", 88 | "torch_extensions", 89 | f"simple_knn_cu{cuda_ver}", 90 | ) 91 | 92 | # Create build directory if it doesn't exist 93 | os.makedirs(build_dir, exist_ok=True) 94 | 95 | is_first_build = not os.path.exists(os.path.join(build_dir, "build.ninja")) 96 | 97 | if is_first_build: 98 | print("\n" + "=" * 70) 99 | print("Compiling simple-knn (first time only)...") 100 | print("This will take 1-2 minutes.") 101 | print("=" * 70 + "\n") 102 | 103 | try: 104 | extension = load( 105 | name="simple_knn_cuda", 106 | sources=sources, 107 | extra_cflags=extra_cflags, 108 | extra_cuda_cflags=extra_cuda_cflags, 109 | extra_include_paths=include_dirs, 110 | build_directory=build_dir, 111 | verbose=is_first_build, 112 | with_cuda=True, 113 | ) 114 | 115 | if is_first_build: 116 | print("\n✓ Compilation successful! Cached for future use.\n") 117 | 118 | return extension 119 | 120 | except Exception as e: 121 | print("\n" + "=" * 70) 122 | print("ERROR: Failed to compile simple-knn") 123 | print("=" * 70) 124 | print(f"\n{e}\n") 125 | print("Requirements:") 126 | print(" - CUDA toolkit installed") 127 | print(" - Compatible C++ compiler (gcc 7-12)") 128 | print(" - PyTorch with CUDA support") 129 | print("=" * 70 + "\n") 130 | raise 131 | 132 | # Load via JIT 133 | if not torch.cuda.is_available(): 134 | raise RuntimeError( 135 | "CUDA not available. simple-knn requires CUDA.\n" 136 | f"PyTorch version: {torch.__version__}" 137 | ) 138 | 139 | _C = _load_extension_jit() 140 | 141 | # ============================================================================== 142 | # Export the distCUDA2 function (main API) 143 | # ============================================================================== 144 | 145 | 146 | def distCUDA2(points): 147 | """ 148 | Compute KNN distances for points using CUDA. 149 | 150 | Args: 151 | points: Tensor of shape (N, 3) containing 3D points 152 | 153 | Returns: 154 | Tensor of shape (N,) containing squared distances to nearest neighbors 155 | """ 156 | return _C.distCUDA2(points) 157 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | data/ 3 | runs/ 4 | PolaRiS-Hub/ 5 | PolaRiS-datasets/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[codz] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # UV 104 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | #uv.lock 108 | 109 | # poetry 110 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 111 | # This is especially recommended for binary packages to ensure reproducibility, and is more 112 | # commonly ignored for libraries. 113 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 114 | #poetry.lock 115 | #poetry.toml 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 120 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 121 | #pdm.lock 122 | #pdm.toml 123 | .pdm-python 124 | .pdm-build/ 125 | 126 | # pixi 127 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 128 | #pixi.lock 129 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 130 | # in the .venv directory. It is recommended not to include this directory in version control. 131 | .pixi 132 | 133 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 134 | __pypackages__/ 135 | 136 | # Celery stuff 137 | celerybeat-schedule 138 | celerybeat.pid 139 | 140 | # SageMath parsed files 141 | *.sage.py 142 | 143 | # Environments 144 | .env 145 | .envrc 146 | .venv 147 | env/ 148 | venv/ 149 | ENV/ 150 | env.bak/ 151 | venv.bak/ 152 | 153 | # Spyder project settings 154 | .spyderproject 155 | .spyproject 156 | 157 | # Rope project settings 158 | .ropeproject 159 | 160 | # mkdocs documentation 161 | /site 162 | 163 | # mypy 164 | .mypy_cache/ 165 | .dmypy.json 166 | dmypy.json 167 | 168 | # Pyre type checker 169 | .pyre/ 170 | 171 | # pytype static type analyzer 172 | .pytype/ 173 | 174 | # Cython debug symbols 175 | cython_debug/ 176 | 177 | # PyCharm 178 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 179 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 180 | # and can be added to the global gitignore or merged into this file. For a more nuclear 181 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 182 | #.idea/ 183 | 184 | # Abstra 185 | # Abstra is an AI-powered process automation framework. 186 | # Ignore directories containing user credentials, local state, and settings. 187 | # Learn more at https://abstra.io/docs 188 | .abstra/ 189 | 190 | # Visual Studio Code 191 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 192 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 193 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 194 | # you could uncomment the following to ignore the entire vscode folder 195 | # .vscode/ 196 | 197 | # Ruff stuff: 198 | .ruff_cache/ 199 | 200 | # PyPI configuration file 201 | .pypirc 202 | 203 | # Cursor 204 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 205 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 206 | # refer to https://docs.cursor.com/context/ignore-files 207 | .cursorignore 208 | .cursorindexingignore 209 | 210 | # Marimo 211 | marimo/_static/ 212 | marimo/_lsp/ 213 | __marimo__/ 214 | -------------------------------------------------------------------------------- /src/polaris/splat_renderer/gaussian_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from diff_surfel_rasterization import GaussianRasterizationSettings, GaussianRasterizer 4 | from polaris.splat_renderer.scene.gaussian_model import GaussianModel 5 | import polaris.splat_renderer.utils.sh_utils as sh_utils 6 | import polaris.splat_renderer.utils.point_utils as point_utils 7 | 8 | 9 | def render( 10 | viewpoint_camera, 11 | pc: GaussianModel, 12 | pipe, 13 | bg_color: torch.Tensor, 14 | scaling_modifier=1.0, 15 | override_color=None, 16 | ): 17 | """ 18 | Render the scene. 19 | 20 | Background tensor (bg_color) must be on GPU! 21 | """ 22 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 23 | screenspace_points = ( 24 | torch.zeros_like( 25 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" 26 | ) 27 | + 0 28 | ) 29 | try: 30 | screenspace_points.retain_grad() 31 | except: 32 | pass 33 | 34 | # Set up rasterization configuration 35 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 36 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 37 | 38 | raster_settings = GaussianRasterizationSettings( 39 | image_height=int(viewpoint_camera.image_height), 40 | image_width=int(viewpoint_camera.image_width), 41 | tanfovx=tanfovx, 42 | tanfovy=tanfovy, 43 | bg=bg_color, 44 | scale_modifier=scaling_modifier, 45 | viewmatrix=viewpoint_camera.world_view_transform, 46 | projmatrix=viewpoint_camera.full_proj_transform, 47 | sh_degree=pc.active_sh_degree, 48 | campos=viewpoint_camera.camera_center, 49 | prefiltered=False, 50 | debug=False, 51 | near_n=viewpoint_camera.znear, 52 | far_n=viewpoint_camera.zfar, 53 | # pipe.debug 54 | ) 55 | 56 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 57 | 58 | means3D = pc.get_xyz 59 | means2D = screenspace_points 60 | opacity = pc.get_opacity 61 | 62 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 63 | # scaling / rotation by the rasterizer. 64 | scales = None 65 | rotations = None 66 | cov3D_precomp = None 67 | if pipe.compute_cov3D_python: 68 | # currently don't support normal consistency loss if use precomputed covariance 69 | splat2world = pc.get_covariance(scaling_modifier) 70 | W, H = viewpoint_camera.image_width, viewpoint_camera.image_height 71 | near, far = viewpoint_camera.znear, viewpoint_camera.zfar 72 | 73 | ndc2pix = ( 74 | torch.tensor( 75 | [ 76 | [W / 2, 0, 0, (W - 1) / 2], 77 | [0, H / 2, 0, (H - 1) / 2], 78 | [0, 0, far - near, near], 79 | [0, 0, 0, 1], 80 | ] 81 | ) 82 | .float() 83 | .cuda() 84 | .T 85 | ) 86 | world2pix = viewpoint_camera.full_proj_transform @ ndc2pix 87 | cov3D_precomp = ( 88 | (splat2world[:, [0, 1, 3]] @ world2pix[:, [0, 1, 3]]) 89 | .permute(0, 2, 1) 90 | .reshape(-1, 9) 91 | ) # column major 92 | else: 93 | scales = pc.get_scaling 94 | rotations = pc.get_rotation 95 | 96 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 97 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 98 | pipe.convert_SHs_python = False 99 | shs = None 100 | colors_precomp = None 101 | if override_color is None: 102 | if pipe.convert_SHs_python: 103 | shs_view = pc.get_features.transpose(1, 2).view( 104 | -1, 3, (pc.max_sh_degree + 1) ** 2 105 | ) 106 | dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat( 107 | pc.get_features.shape[0], 1 108 | ) 109 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 110 | sh2rgb = sh_utils.eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 111 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 112 | else: 113 | shs = pc.get_features 114 | else: 115 | colors_precomp = override_color 116 | 117 | # breakpoint() 118 | 119 | rendered_image, radii, allmap = rasterizer( 120 | means3D=means3D, 121 | means2D=means2D, 122 | shs=shs, 123 | colors_precomp=colors_precomp, 124 | opacities=opacity, 125 | scales=scales, 126 | rotations=rotations, 127 | cov3D_precomp=cov3D_precomp, 128 | ) 129 | 130 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 
131 | # They will be excluded from value updates used in the splitting criteria. 132 | rets = { 133 | "render": rendered_image, 134 | "viewspace_points": means2D, 135 | "visibility_filter": radii > 0, 136 | "radii": radii, 137 | } 138 | 139 | # additional regularizations 140 | render_alpha = allmap[1:2] 141 | 142 | # get normal map 143 | # transform normal from view space to world space 144 | render_normal = allmap[2:5] 145 | render_normal = ( 146 | render_normal.permute(1, 2, 0) 147 | @ (viewpoint_camera.world_view_transform[:3, :3].T) 148 | ).permute(2, 0, 1) 149 | 150 | # get median depth map 151 | render_depth_median = allmap[5:6] 152 | render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) 153 | 154 | # get expected depth map 155 | render_depth_expected = allmap[0:1] 156 | render_depth_expected = render_depth_expected / render_alpha 157 | render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) 158 | 159 | # get depth distortion map 160 | render_dist = allmap[6:7] 161 | 162 | # pseudo surface attributes 163 | # surf depth is either median or expected by setting depth_ratio to 1 or 0 164 | # for bounded scenes, use median depth, i.e., depth_ratio = 1; 165 | # for unbounded scenes, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. 166 | surf_depth = ( 167 | render_depth_expected * (1 - pipe.depth_ratio) 168 | + (pipe.depth_ratio) * render_depth_median 169 | ) 170 | 171 | # assume the depth points form the 'surface' and generate pseudo surface normals for regularization. 172 | surf_normal = point_utils.depth_to_normal(viewpoint_camera, surf_depth) 173 | surf_normal = surf_normal.permute(2, 0, 1) 174 | # remember to multiply with accum_alpha since render_normal is unnormalized. 175 | surf_normal = surf_normal * (render_alpha).detach() 176 | 177 | rets.update( 178 | { 179 | "rend_alpha": render_alpha, 180 | "rend_normal": render_normal, 181 | "rend_dist": render_dist, 182 | "surf_depth": surf_depth, 183 | "surf_normal": surf_normal, 184 | } 185 | ) 186 | 187 | return rets 188 | -------------------------------------------------------------------------------- /src/polaris/splat_renderer/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x / (1 - x)) 21 | 22 | 23 | def PILtoTorch(pil_image, resolution): 24 | resized_image_PIL = pil_image.resize(resolution) 25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 26 | if len(resized_image.shape) == 3: 27 | return resized_image.permute(2, 0, 1) 28 | else: 29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 30 | 31 | 32 | def get_expon_lr_func( 33 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 34 | ): 35 | """ 36 | Copied from Plenoxels 37 | 38 | Continuous learning rate decay function. Adapted from JaxNeRF 39 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 40 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
41 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 42 | function of lr_delay_mult, such that the initial learning rate is 43 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 44 | to the normal learning rate when steps>lr_delay_steps. 45 | :param conf: config subtree 'lr' or similar 46 | :param max_steps: int, the number of steps during optimization. 47 | :return HoF which takes step as input 48 | """ 49 | 50 | def helper(step): 51 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 52 | # Disable this parameter 53 | return 0.0 54 | if lr_delay_steps > 0: 55 | # A kind of reverse cosine decay. 56 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 57 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 58 | ) 59 | else: 60 | delay_rate = 1.0 61 | t = np.clip(step / max_steps, 0, 1) 62 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 63 | return delay_rate * log_lerp 64 | 65 | return helper 66 | 67 | 68 | def strip_lowerdiag(L): 69 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 70 | 71 | uncertainty[:, 0] = L[:, 0, 0] 72 | uncertainty[:, 1] = L[:, 0, 1] 73 | uncertainty[:, 2] = L[:, 0, 2] 74 | uncertainty[:, 3] = L[:, 1, 1] 75 | uncertainty[:, 4] = L[:, 1, 2] 76 | uncertainty[:, 5] = L[:, 2, 2] 77 | return uncertainty 78 | 79 | 80 | def strip_symmetric(sym): 81 | return strip_lowerdiag(sym) 82 | 83 | 84 | def build_rotation(r): 85 | norm = torch.sqrt( 86 | r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3] 87 | ) 88 | 89 | q = r / norm[:, None] 90 | 91 | R = torch.zeros((q.size(0), 3, 3), device="cuda") 92 | 93 | r = q[:, 0] 94 | x = q[:, 1] 95 | y = q[:, 2] 96 | z = q[:, 3] 97 | 98 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 99 | R[:, 0, 1] = 2 * (x * y - r * z) 100 | R[:, 0, 2] = 2 * (x * z + r * y) 101 | R[:, 1, 0] = 2 * (x * y + r * z) 102 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 103 | R[:, 1, 2] = 2 * (y * z - r * x) 104 | R[:, 2, 0] = 2 * (x * z - r * y) 105 | R[:, 2, 1] = 2 * (y * z + r * x) 106 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 107 | return R 108 | 109 | 110 | def build_scaling_rotation(s, r): 111 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 112 | R = build_rotation(r) 113 | 114 | L[:, 0, 0] = s[:, 0] 115 | L[:, 1, 1] = s[:, 1] 116 | L[:, 2, 2] = s[:, 2] 117 | 118 | L = R @ L 119 | return L 120 | 121 | 122 | def safe_state(silent): 123 | old_f = sys.stdout 124 | 125 | class F: 126 | def __init__(self, silent): 127 | self.silent = silent 128 | 129 | def write(self, x): 130 | if not self.silent: 131 | if x.endswith("\n"): 132 | old_f.write( 133 | x.replace( 134 | "\n", 135 | " [{}]\n".format( 136 | str(datetime.now().strftime("%d/%m %H:%M:%S")) 137 | ), 138 | ) 139 | ) 140 | else: 141 | old_f.write(x) 142 | 143 | def flush(self): 144 | old_f.flush() 145 | 146 | sys.stdout = F(silent) 147 | 148 | random.seed(0) 149 | np.random.seed(0) 150 | torch.manual_seed(0) 151 | torch.cuda.set_device(torch.device("cuda:0")) 152 | 153 | 154 | def create_rotation_matrix_from_direction_vector_batch(direction_vectors): 155 | # Normalize the batch of direction vectors 156 | direction_vectors = direction_vectors / torch.norm( 157 | direction_vectors, dim=-1, keepdim=True 158 | ) 159 | # Create a batch of arbitrary vectors that are not collinear with the direction vectors 160 | v1 = ( 161 | torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32) 162 | .to(direction_vectors.device) 163 | .expand(direction_vectors.shape[0], -1) 164 | 
.clone() 165 | ) 166 | is_collinear = torch.all(torch.abs(direction_vectors - v1) < 1e-5, dim=-1) 167 | v1[is_collinear] = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32).to( 168 | direction_vectors.device 169 | ) 170 | 171 | # Calculate the first orthogonal vectors 172 | v1 = torch.cross(direction_vectors, v1) 173 | v1 = v1 / (torch.norm(v1, dim=-1, keepdim=True)) 174 | # Calculate the second orthogonal vectors by taking the cross product 175 | v2 = torch.cross(direction_vectors, v1) 176 | v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True)) 177 | # Create the batch of rotation matrices with the direction vectors as the last columns 178 | rotation_matrices = torch.stack((v1, v2, direction_vectors), dim=-1) 179 | return rotation_matrices 180 | 181 | 182 | # from kornia.geometry import conversions 183 | # def normal_to_rotation(normals): 184 | # rotations = create_rotation_matrix_from_direction_vector_batch(normals) 185 | # rotations = conversions.rotation_matrix_to_quaternion(rotations,eps=1e-5, order=conversions.QuaternionCoeffOrder.WXYZ) 186 | # return rotations 187 | 188 | 189 | def colormap(img, cmap="jet"): 190 | import matplotlib.pyplot as plt 191 | 192 | W, H = img.shape[:2] 193 | dpi = 300 194 | fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi) 195 | im = ax.imshow(img, cmap=cmap) 196 | ax.set_axis_off() 197 | fig.colorbar(im, ax=ax) 198 | fig.tight_layout() 199 | fig.canvas.draw() 200 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 201 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 202 | img = torch.from_numpy(data / 255.0).float().permute(2, 0, 1) 203 | plt.close() 204 | return img 205 | -------------------------------------------------------------------------------- /src/simple-knn/simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 |  *
9 |  * For inquiries contact  george.drettakis@inria.fr
10 |  */
11 | 
12 | #define BOX_SIZE 1024
13 | 
14 | #include "cuda_runtime.h"
15 | #include "device_launch_parameters.h"
16 | #include "simple_knn.h"
17 | #include <cub/cub.cuh>
18 | #include <cub/device/device_radix_sort.cuh>
19 | #include <vector>
20 | #include <cuda_runtime_api.h>
21 | #include <float.h>
22 | #include <thrust/device_vector.h>
23 | #include <thrust/sequence.h>
24 | #define __CUDACC__
25 | #include <cooperative_groups.h>
26 | #include <cooperative_groups/reduce.h>
27 | 
28 | namespace cg = cooperative_groups;
29 | 
30 | struct CustomMin
31 | {
32 |     __device__ __forceinline__
33 |     float3 operator()(const float3& a, const float3& b) const {
34 |         return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) };
35 |     }
36 | };
37 | 
38 | struct CustomMax
39 | {
40 |     __device__ __forceinline__
41 |     float3 operator()(const float3& a, const float3& b) const {
42 |         return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) };
43 |     }
44 | };
45 | 
46 | __host__ __device__ uint32_t prepMorton(uint32_t x)
47 | {
48 |     x = (x | (x << 16)) & 0x030000FF;
49 |     x = (x | (x << 8)) & 0x0300F00F;
50 |     x = (x | (x << 4)) & 0x030C30C3;
51 |     x = (x | (x << 2)) & 0x09249249;
52 |     return x;
53 | }
54 | 
55 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx)
56 | {
57 |     uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1));
58 |     uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1));
59 |     uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1));
60 | 
61 |     return x | (y << 1) | (z << 2);
62 | }
63 | 
64 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes)
65 | {
66 |     auto idx = cg::this_grid().thread_rank();
67 |     if (idx >= P)
68 |         return;
69 | 
70 |     codes[idx] = coord2Morton(points[idx], minn, maxx);
71 | }
72 | 
73 | struct MinMax
74 | {
75 |     float3 minn;
76 |     float3 maxx;
77 | };
78 | 
79 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes)
80 | {
81 |     auto idx = cg::this_grid().thread_rank();
82 | 
83 |     MinMax me;
84 |     if (idx < P)
85 |     {
86 |         me.minn = points[indices[idx]];
87 |         me.maxx = points[indices[idx]];
88 |     }
89 |     else
90 |     {
91 |         me.minn = { FLT_MAX, FLT_MAX, FLT_MAX };
92 |         me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX };
93 |     }
94 | 
95 |     __shared__ MinMax redResult[BOX_SIZE];
96 | 
97 |     for (int off = BOX_SIZE / 2; off >= 1; off /= 2)
98 |     {
99 |         if (threadIdx.x < 2 * off)
100 |             redResult[threadIdx.x] = me;
101 |         __syncthreads();
102 | 
103 |         if (threadIdx.x < off)
104 |         {
105 |             MinMax other = redResult[threadIdx.x + off];
106 |             me.minn.x = min(me.minn.x, other.minn.x);
107 |             me.minn.y = min(me.minn.y, other.minn.y);
108 |             me.minn.z = min(me.minn.z, other.minn.z);
109 |             me.maxx.x = max(me.maxx.x, other.maxx.x);
110 |             me.maxx.y = max(me.maxx.y, other.maxx.y);
111 |             me.maxx.z = max(me.maxx.z, other.maxx.z);
112 |         }
113 |         __syncthreads();
114 |     }
115 | 
116 |     if (threadIdx.x == 0)
117 |         boxes[blockIdx.x] = me;
118 | }
119 | 
120 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p)
121 | {
122 |     float3 diff = { 0, 0, 0 };
123 |     if (p.x < box.minn.x || p.x > box.maxx.x)
124 |         diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x));
125 |     if (p.y < box.minn.y || p.y > box.maxx.y)
126 |         diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y));
127 |     if (p.z < box.minn.z || p.z > box.maxx.z)
128 |         diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z));
129 |     return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z;
130 | }
131 | 
132 | template<int K>
133 | __device__ void updateKBest(const float3& ref, const float3& point,
float* knn)
134 | {
135 |     float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z };
136 |     float dist = d.x * d.x + d.y * d.y + d.z * d.z;
137 |     for (int j = 0; j < K; j++)
138 |     {
139 |         if (knn[j] > dist)
140 |         {
141 |             float t = knn[j];
142 |             knn[j] = dist;
143 |             dist = t;
144 |         }
145 |     }
146 | }
147 | 
148 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists)
149 | {
150 |     int idx = cg::this_grid().thread_rank();
151 |     if (idx >= P)
152 |         return;
153 | 
154 |     float3 point = points[indices[idx]];
155 |     float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX };
156 | 
157 |     for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++)
158 |     {
159 |         if (i == idx)
160 |             continue;
161 |         updateKBest<3>(point, points[indices[i]], best);
162 |     }
163 | 
164 |     float reject = best[2];
165 |     best[0] = FLT_MAX;
166 |     best[1] = FLT_MAX;
167 |     best[2] = FLT_MAX;
168 | 
169 |     for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++)
170 |     {
171 |         MinMax box = boxes[b];
172 |         float dist = distBoxPoint(box, point);
173 |         if (dist > reject || dist > best[2])
174 |             continue;
175 | 
176 |         for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++)
177 |         {
178 |             if (i == idx)
179 |                 continue;
180 |             updateKBest<3>(point, points[indices[i]], best);
181 |         }
182 |     }
183 |     dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f;
184 | }
185 | 
186 | void SimpleKNN::knn(int P, float3* points, float* meanDists)
187 | {
188 |     float3* result;
189 |     cudaMalloc(&result, sizeof(float3));
190 |     size_t temp_storage_bytes;
191 | 
192 |     float3 init = { 0, 0, 0 }, minn, maxx;
193 | 
194 |     cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init);
195 |     thrust::device_vector<char> temp_storage(temp_storage_bytes);
196 | 
197 |     cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init);
198 |     cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost);
199 | 
200 |     cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init);
201 |     cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost);
202 | 
203 |     thrust::device_vector<uint32_t> morton(P);
204 |     thrust::device_vector<uint32_t> morton_sorted(P);
205 |     coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get());
206 | 
207 |     thrust::device_vector<uint32_t> indices(P);
208 |     thrust::sequence(indices.begin(), indices.end());
209 |     thrust::device_vector<uint32_t> indices_sorted(P);
210 | 
211 |     cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
212 |     temp_storage.resize(temp_storage_bytes);
213 | 
214 |     cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
215 | 
216 |     uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE;
217 |     thrust::device_vector<MinMax> boxes(num_boxes);
218 |     boxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get());
219 |     boxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists);
220 | 
221 |     cudaFree(result);
222 | }
--------------------------------------------------------------------------------
/src/polaris/environments/rubrics/checkers.py:
--------------------------------------------------------------------------------
1 | from pxr import Usd, UsdGeom
2 | from pxr import Gf
3 | from omni.usd import get_context
4 | import numpy as np
5 | import torch
6 | 
7 | 
8 | def reach(obj_name, threshold=0.05):
9 |     """
10 |     Returns a checker function that expects (env).
11 |     Example: reach("bowl", threshold=0.04)(env)
12 |     """
13 | 
14 |     def checker(env):
15 |         obj_pos = env.scene[obj_name].data.root_pos_w[0]
16 |         ee_pos = env.scene["ee_frame"].data.target_pos_w[0]
17 |         dist = torch.norm(obj_pos - ee_pos)
18 |         return dist < threshold
19 | 
20 |     return checker
21 | 
22 | 
23 | def lift(obj_name, threshold=0.05, default_height=None):
24 |     def checker(env):
25 |         nonlocal default_height
26 |         object_pos = env.scene[obj_name].data.root_pos_w[0]
27 |         if default_height is None:
28 |             default_height = env.scene[obj_name].data.default_root_state[0, 2]
29 | 
30 |         return (object_pos[2] - default_height).item() > threshold
31 | 
32 |     return checker
33 | 
34 | 
35 | def is_within_xy(object1, object2, percent_threshold=0.5, open_finger_threshold=0.1):
36 |     """
37 |     Check if object1 is inside object2 (overlap in the xy-plane).
38 |     """
39 | 
40 |     def checker(env):
41 |         # the end-effector must be open (i.e., the object has been released)
42 |         stage = get_context().get_stage()
43 |         finger_joint = env.scene["robot"].data.joint_pos[0][
44 |             env.scene["robot"].data.joint_names.index("finger_joint")
45 |         ]
46 |         if finger_joint >= open_finger_threshold:
47 |             return False
48 | 
49 |         obj1_prim = stage.GetPrimAtPath(f"/World/envs/env_0/scene/{object1}")
50 |         obj2_prim = stage.GetPrimAtPath(f"/World/envs/env_0/scene/{object2}")
51 |         obj1_pos = env.scene[object1].data.root_pos_w[0]
52 |         obj2_pos = env.scene[object2].data.root_pos_w[0]
53 |         obj1_quat = env.scene[object1].data.root_quat_w[0]
54 |         obj2_quat = env.scene[object2].data.root_quat_w[0]
55 | 
56 |         obj1_corners, obj1_centroid = get_bbox(obj1_prim, pos=obj1_pos, quat=obj1_quat)
57 |         obj2_corners, obj2_centroid = get_bbox(obj2_prim, pos=obj2_pos, quat=obj2_quat)
58 |         obj1_corners = np.array(obj1_corners)  # [8, 3]
59 |         obj2_corners = np.array(obj2_corners)  # [8, 3]
60 | 
61 |         # compute intersection of xy planes
62 |         obj1_xy_corners = obj1_corners[:, :2]  # [8, 2]
63 |         obj2_xy_corners = obj2_corners[:, :2]  # [8, 2]
64 |         obj1_min_xy = np.min(obj1_xy_corners, axis=0)
65 |         obj1_max_xy = np.max(obj1_xy_corners, axis=0)
66 |         obj2_min_xy = np.min(obj2_xy_corners, axis=0)
67 |         obj2_max_xy = np.max(obj2_xy_corners, axis=0)
68 | 
69 |         # Overlap rectangle boundaries
70 |         overlap_min_x = max(obj1_min_xy[0], obj2_min_xy[0])
71 |         overlap_max_x = min(obj1_max_xy[0], obj2_max_xy[0])
72 |         overlap_min_y = max(obj1_min_xy[1], obj2_min_xy[1])
73 |         overlap_max_y = min(obj1_max_xy[1], obj2_max_xy[1])
74 | 
75 |         # Check if there's any actual overlap
76 |         if overlap_min_x >= overlap_max_x or overlap_min_y >= overlap_max_y:
77 |             return False
78 | 
79 |         # Areas
80 |         obj1_area = (obj1_max_xy[0] - obj1_min_xy[0]) * (
81 |             obj1_max_xy[1] - obj1_min_xy[1]
82 |         )
83 |         overlap_area = (overlap_max_x - overlap_min_x) * (overlap_max_y - overlap_min_y)
84 | 
85 |         # Percentage of object1 area that is inside object2
86 |         overlap_ratio = overlap_area / obj1_area
87 |         # print(f"{object1} is inside {object2} {overlap_ratio}")
88 | 
89 |         return overlap_ratio >= percent_threshold
90 | 
91 |     return checker
92 | 
93 | 
94 | def get_scale(prim: Usd.Prim) -> Gf.Vec3d:
95 |     """
96 |     Get the scale parameter applied to a Usd.Prim.
97 | 
98 |     This function tries multiple approaches to get the scale:
99 |     1. Directly from the 'xformOp:scale' attribute if it exists
100 |     2. From the transform matrix using ExtractScale() method
101 |     3. Returns (1, 1, 1) as default if no scale is found
102 | 
103 |     Args:
104 |         prim: The Usd.Prim to get scale from
105 | 
106 |     Returns:
107 |         Gf.Vec3d: The scale vector (x, y, z)
108 |     """
109 |     # First try to get scale directly from xformOp:scale attribute
110 |     # scale_attr = get_attribute(prim, "xformOp:scale")
111 |     scale_attr = prim.GetAttribute("xformOp:scale")
112 |     if scale_attr and scale_attr.IsValid():
113 |         scale_value = scale_attr.Get()
114 |         if scale_value is not None:
115 |             # Convert to Gf.Vec3d if it's not already
116 |             if isinstance(scale_value, (list, tuple)):
117 |                 return Gf.Vec3d(*scale_value)
118 |             elif hasattr(scale_value, "__len__") and len(scale_value) == 3:
119 |                 return Gf.Vec3d(scale_value[0], scale_value[1], scale_value[2])
120 |             else:
121 |                 return Gf.Vec3d(scale_value, scale_value, scale_value)
122 | 
123 |     # Default scale if nothing else works
124 |     return Gf.Vec3d(1.0, 1.0, 1.0)
125 | 
126 | 
127 | def get_bbox(body_prim: Usd.Prim, pos=None, quat=None, scalar_first=False):
128 |     pos = pos.cpu().numpy().astype(np.float64)
129 |     quat = quat.cpu().numpy().astype(np.float64)
130 |     ## TODO: add options: zero_centered, current
131 |     time_code = Usd.TimeCode.Default()
132 |     bbox_cache = UsdGeom.BBoxCache(
133 |         time_code, includedPurposes=[UsdGeom.Tokens.default_]
134 |     )
135 |     bbox_cache.Clear()
136 | 
137 |     prim_bbox = bbox_cache.ComputeLocalBound(body_prim)
138 | 
139 |     # Get corners and centroid  # corners = prim_bbox.GetCorners()  # List of 8 GfVec3d points
140 |     range3d = prim_bbox.GetRange()
141 |     matrix = prim_bbox.GetMatrix()
142 | 
143 |     corners = [matrix.Transform(range3d.GetCorner(i)) for i in range(8)]
144 |     centroid = prim_bbox.ComputeCentroid()  # GfVec3d
145 | 
146 |     # Transform to origin using the inverse of the prim's world transform
147 |     xform_cache = UsdGeom.XformCache(time_code)
148 |     world_xform = xform_cache.GetLocalToWorldTransform(body_prim)  # Gf.Matrix4d
149 |     transform = world_xform.GetInverse()
150 | 
151 |     transformed_corners = [transform.Transform(corner) for corner in corners]
152 |     transformed_centroid = transform.Transform(centroid)
153 | 
154 |     scale = get_scale(body_prim)
155 |     if scale != Gf.Vec3d(1.0, 1.0, 1.0):
156 |         # Scale corners directly
157 |         scaled_corners = []
158 |         for corner in transformed_corners:
159 |             # Scale each corner directly
160 |             scaled_corner = Gf.Vec3d(
161 |                 corner[0] * scale[0], corner[1] * scale[1], corner[2] * scale[2]
162 |             )
163 |             scaled_corners.append(scaled_corner)
164 | 
165 |         transformed_corners = scaled_corners
166 | 
167 |     # User-supplied transform, if any
168 |     if quat is not None and pos is not None:
169 |         if isinstance(pos, np.ndarray) or isinstance(pos, list):
170 |             # pos = np_to_gf_vec3d(pos)
171 |             pos = Gf.Vec3d(pos[0], pos[1], pos[2])
172 |         if isinstance(quat, np.ndarray) or isinstance(quat, list):
173 |             # quat = np_to_gf_quatf(quat, scalar_first)
174 |             quat = Gf.Quatd(quat[0], quat[1], quat[2], quat[3])
175 | 
176 |         additional_transform = Gf.Matrix4d().SetRotateOnly(quat)
177 |         additional_transform.SetTranslateOnly(pos)
178 |     else:
179 |         # No additional transform
180 |         return transformed_corners, transformed_centroid
181 | 
182 |     # Apply additional transform
183 |     new_corners = [
184 |         additional_transform.Transform(corner) for corner in transformed_corners
185 |     ]
186 |     new_centroid = additional_transform.Transform(transformed_centroid)
187 | 
188 |     return new_corners, new_centroid
189 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ---
4 | 
5 | # PolaRiS
6 | 
7 | **[🌐 Website](https://polaris-evals.github.io/)** • **[📄 Paper](https://arxiv.org/abs/2512.16881)** • **[🤗 PolaRiS Hub](https://huggingface.co/datasets/owhan/PolaRiS-Hub)**
8 | 
9 | 
10 | PolaRiS is an evaluation framework for generalist policies. It provides tooling for reconstructing environments, evaluating models, and running experiments with minimal setup.
11 | 
12 | ## Installation
13 | 
14 | ### Clone the repository (recursively)
15 | 
16 | ```bash
17 | git clone --recursive git@github.com:arhanjain/polaris.git
18 | cd PolaRiS
19 | ```
20 | 
21 | If you cloned without `--recursive`:
22 | 
23 | ```bash
24 | git submodule update --init --recursive
25 | ```
26 | 
27 | ### Setup environment with uv
28 | If you don't have uv installed, see the [installation instructions](https://docs.astral.sh/uv/getting-started/installation/).
29 | 
30 | 
31 | By default we support CUDA 13. If you have an older version of CUDA installed, downgrade the torch and torchvision versions and their package index in [pyproject.toml](pyproject.toml) to be compatible.
32 | ```bash
33 | uv sync
34 | ```
35 | 
36 | ## Getting Started
37 | First, download the PolaRiS environments (<2GB):
38 | ```bash
39 | uvx hf download owhan/PolaRiS-Hub --repo-type=dataset --local-dir ./PolaRiS-Hub
40 | ```
41 | 
42 | ### Minimal Code Example
43 | Next, let's test a simple random-action policy in a PolaRiS environment.
44 | ```python
45 | import torch
46 | import argparse
47 | import gymnasium as gym
48 | from isaaclab.app import AppLauncher
49 | # This must be done before importing anything that depends on Isaac Lab
50 | # >>>> Isaac Sim App Launcher <<<<
51 | parser = argparse.ArgumentParser()
52 | args_cli, _ = parser.parse_known_args()
53 | args_cli.enable_cameras = True
54 | args_cli.headless = True
55 | app_launcher = AppLauncher(args_cli)
56 | simulation_app = app_launcher.app
57 | # >>>> Isaac Sim App Launcher <<<<
58 | 
59 | import polaris.environments
60 | from isaaclab_tasks.utils import parse_env_cfg  # noqa: E402
61 | from polaris.environments.manager_based_rl_splat_environment import ManagerBasedRLSplatEnv
62 | from polaris.utils import load_eval_initial_conditions
63 | 
64 | env_cfg = parse_env_cfg(
65 |     "DROID-FoodBussing",
66 |     device="cuda",
67 |     num_envs=1,
68 |     use_fabric=True,
69 | )
70 | env: ManagerBasedRLSplatEnv = gym.make("DROID-FoodBussing", cfg=env_cfg)  # type: ignore
71 | language_instruction, initial_conditions = load_eval_initial_conditions(env.usd_file)
72 | obs, info = env.reset(object_positions=initial_conditions[0])
73 | 
74 | while True:
75 |     action = torch.tensor(env.action_space.sample())
76 |     obs, rew, term, trunc, info = env.step(action, expensive=True)
77 | 
78 |     if term[0] or trunc[0]:
79 |         break
80 | 
81 | print(f"Episode Finished. Success: {info['rubric']['success']}, Progress: {info['rubric']['progress']}")
82 | ```
83 | 
84 | ### Run a π0.5 Policy in PolaRiS
85 | *Note: The first run may take longer due to JIT compilation of the splat rasterization kernels. Ensure you have the NVIDIA drivers and CUDA Toolkit (nvcc) properly configured.*
86 | 
87 | Both the policy server and the evaluation process should fit onto a single GPU (tested on an RTX 3090, 24 GB).
88 | ```bash
89 | # Starting from the root of this repo. This will set up openpi and host a pi05 policy.
90 | cd third_party/openpi
91 | GIT_LFS_SKIP_SMUDGE=1 uv sync
92 | GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
93 | XLA_PYTHON_CLIENT_MEM_FRACTION=0.35 uv run scripts/serve_policy.py --port 8000 policy:checkpoint --policy.config pi05_droid_jointpos_polaris --policy.dir gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris
94 | 
95 | # In a separate process, start the evaluation
96 | sudo apt install ffmpeg # for saving videos
97 | uv run scripts/eval.py --environment DROID-FoodBussing --policy.port 8000 --run-folder runs/test
98 | ```
99 | Results include rollout videos and a CSV summarizing the success and normalized progress of each episode.
100 | 
101 | ### Off-the-shelf Evaluation Environments
102 | | Environment Name | Prompt | Image |
103 | | :--- | :--- | :--- |
104 | | DROID-BlockStackKitchen | Place and stack the blocks on top of the green tray | |
105 | | DROID-FoodBussing | Put all the foods in the bowl | |
106 | | DROID-PanClean | Use the yellow sponge to scrub the blue handle frying pan | |
107 | | DROID-MoveLatteCup | put the latte art cup on top of the cutting board | |
108 | | DROID-OrganizeTools | put the scissor into the large container | |
109 | | DROID-TapeIntoContainer | put the tape into the container | |
110 | 
111 | ### PolaRiS-Ready Policies
112 | 
113 | All checkpoints for PolaRiS were based on DROID base policies. Checkpoints were produced by cotraining with a mixture of 10% random simulated data and 90% DROID data for 1k steps.
114 | 
115 | | Policy Name | Checkpoints Path |
116 | | :--- | :--- |
117 | | **π0.5 Polaris** | `gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris` |
118 | | **π0 Fast Polaris** | `gs://openpi-assets/checkpoints/polaris/pi0_fast_droid_jointpos_polaris` |
119 | | **π0 Polaris** | `gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_polaris` |
120 | | **π0 Polaris (100k)** | `gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_100k_polaris` |
121 | | **PaliGemma Polaris** | `gs://openpi-assets/checkpoints/polaris/paligemma_binning_droid_jointpos_polaris` |
122 | 
123 | For the full list of all checkpoints, base policies, and environments we provide for evaluation, see [checkpoints_and_envs.md](docs/checkpoints_and_envs.md).
124 | 
125 | ## Cotraining and Evaluating Your Policies In PolaRiS
126 | 
127 | 1. Download the DROID simulated cotraining dataset
128 | 2. Cotrain a policy
130 |    1. Using OpenPI
137 |    2. Training a custom policy
139 |       - We recommend co-finetuning your policy with the provided sim dataset at 10% weightage
140 |       - You may need to define a custom policy client if your policy is not compatible with the provided DROID JointPosition client (see the sketch below)
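If your policy cannot reuse the DROID joint-position client, a thin adapter is usually enough. Below is a minimal sketch, assuming `InferenceClient` from `src/polaris/policy/abstract_client.py` is the base class to subclass; the method names (`reset`, `infer`), the observation keys, and the constructor arguments here are illustrative assumptions rather than the actual interface, so check that file for the real contract.

```python
# Hypothetical sketch of a custom policy client. The `reset`/`infer` hooks and
# the observation keys are assumptions for illustration -- see
# src/polaris/policy/abstract_client.py for the actual abstract interface.
import numpy as np

from polaris.policy import InferenceClient


class MyPolicyClient(InferenceClient):
    """Adapter between PolaRiS observations and a custom policy server."""

    def __init__(self, host: str = "localhost", port: int = 8000):
        # Wherever your policy server is listening.
        self.host = host
        self.port = port

    def reset(self) -> None:
        # Clear per-episode state, e.g. an action-chunking buffer.
        pass

    def infer(self, obs: dict) -> np.ndarray:
        # 1. Pack the splat-rendered camera images (obs["splat"]) and the
        #    robot proprioception into your model's input format.
        # 2. Query your policy server and decode its response.
        # 3. Return a joint-position action suitable for env.step(...).
        raise NotImplementedError
```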
145 | 
146 | See [custom_policies.md](docs/custom_policies.md) for more details.
147 | 
148 | 
149 | ## Creating Custom Evaluation Environments
150 | Time Estimate: 20 Minutes Human Time + 40 Minutes Offline Training
151 | 1. Take a video
152 | 2. Extract splat and mesh (we use 2DGS, but any method that produces both can work)
153 | 3. Compose environment USD using our provided Web GUI
154 | 4. Run evaluation :)
155 | 5. Contribute to the community pool of evaluation environments!
156 | 
157 | For detailed instructions, see [docs/custom_environments.md](docs/custom_environments.md).
158 | 
159 | ## Issues
160 | This codebase has been tested on CUDA 13 and CUDA 12 with NVIDIA 5090 and 3090 GPUs. Please open an issue if you run into any problems.
161 | 
162 | ## Citation
163 | If you find this repository useful, please consider citing it as:
164 | 
165 | ```bibtex
166 | @misc{jain2025polarisscalablerealtosimevaluations,
167 |       title={PolaRiS: Scalable Real-to-Sim Evaluations for Generalist Robot Policies},
168 |       author={Arhan Jain and Mingtong Zhang and Kanav Arora and William Chen and Marcel Torne and Muhammad Zubair Irshad and Sergey Zakharov and Yue Wang and Sergey Levine and Chelsea Finn and Wei-Chiu Ma and Dhruv Shah and Abhishek Gupta and Karl Pertsch},
169 |       year={2025},
170 |       eprint={2512.16881},
171 |       archivePrefix={arXiv},
172 |       primaryClass={cs.RO},
173 |       url={https://arxiv.org/abs/2512.16881},
174 | }
175 | ```
176 | 
--------------------------------------------------------------------------------
/src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/rasterize_points.cu:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2023, Inria
3 |  * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 |  * All rights reserved.
5 |  *
6 |  * This software is free for non-commercial, research and evaluation use
7 |  * under the terms of the LICENSE.md file.
8 |  *
9 |  * For inquiries contact  george.drettakis@inria.fr
10 |  */
11 | 
12 | #include <math.h>
13 | #include <torch/extension.h>
14 | #include <cstdio>
15 | #include <sstream>
16 | #include <iostream>
17 | #include <tuple>
18 | #include <stdio.h>
19 | #include <cuda_runtime_api.h>
20 | #include <memory>
21 | #include "cuda_rasterizer/config.h"
22 | #include "cuda_rasterizer/rasterizer.h"
23 | #include <fstream>
24 | #include <string>
25 | #include <functional>
26 | 
27 | #define CHECK_INPUT(x) \
28 |     AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
29 |     // AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
30 | 
31 | std::function<char* (size_t N)> resizeFunctional(torch::Tensor& t) {
32 |     auto lambda = [&t](size_t N) {
33 |         t.resize_({(long long)N});
34 |         return reinterpret_cast<char*>(t.contiguous().data_ptr());
35 |     };
36 |     return lambda;
37 | }
38 | 
39 | std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
40 | RasterizeGaussiansCUDA(
41 |     const torch::Tensor& background,
42 |     const torch::Tensor& means3D,
43 |     const torch::Tensor& colors,
44 |     const torch::Tensor& opacity,
45 |     const torch::Tensor& scales,
46 |     const torch::Tensor& rotations,
47 |     const float scale_modifier,
48 |     const torch::Tensor& transMat_precomp,
49 |     const torch::Tensor& viewmatrix,
50 |     const torch::Tensor& projmatrix,
51 |     const float tan_fovx,
52 |     const float tan_fovy,
53 |     const int image_height,
54 |     const int image_width,
55 |     const torch::Tensor& sh,
56 |     const int degree,
57 |     const torch::Tensor& campos,
58 |     const bool prefiltered,
59 |     const bool debug,
60 |     float near_n,
61 |     float far_n)
62 | {
63 |     if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
64 |         AT_ERROR("means3D must have dimensions (num_points, 3)");
65 |     }
66 | 
67 | 
68 |     const int P = means3D.size(0);
69 |     const int H = image_height;
70 |     const int W = image_width;
71 | 
72 |     CHECK_INPUT(background);
73 |     CHECK_INPUT(means3D);
74 |     CHECK_INPUT(colors);
75 |     CHECK_INPUT(opacity);
76 |     CHECK_INPUT(scales);
77 |     CHECK_INPUT(rotations);
78 |     CHECK_INPUT(transMat_precomp);
79 |     CHECK_INPUT(viewmatrix);
80 |     CHECK_INPUT(projmatrix);
81 |     CHECK_INPUT(sh);
82 |     CHECK_INPUT(campos);
83 | 
84 |     auto int_opts = means3D.options().dtype(torch::kInt32);
85 |     auto float_opts = means3D.options().dtype(torch::kFloat32);
86 | 
87 |     torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
88 |     torch::Tensor out_others = torch::full({3+3+1, H, W}, 0.0, float_opts);
89 |     torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
90 | 
91 |     torch::Device device(torch::kCUDA);
92 |     torch::TensorOptions options(torch::kByte);
93 |     torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
94 |     torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
95 |     torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
96 |     std::function<char* (size_t)> geomFunc = resizeFunctional(geomBuffer);
97 |     std::function<char* (size_t)> binningFunc = resizeFunctional(binningBuffer);
98 |     std::function<char* (size_t)> imgFunc = resizeFunctional(imgBuffer);
99 | 
100 |     int rendered = 0;
101 |     if(P != 0)
102 |     {
103 |         int M = 0;
104 |         if(sh.size(0) != 0)
105 |         {
106 |             M = sh.size(1);
107 |         }
108 | 
109 |         rendered = CudaRasterizer::Rasterizer::forward(
110 |             geomFunc,
111 |             binningFunc,
112 |             imgFunc,
113 |             P, degree, M,
114 |             background.contiguous().data<float>(),
115 |             W, H,
116 |             means3D.contiguous().data<float>(),
117 |             sh.contiguous().data_ptr<float>(),
118 |             colors.contiguous().data<float>(),
119 |             opacity.contiguous().data<float>(),
120 |             scales.contiguous().data_ptr<float>(),
121 |             scale_modifier,
122 |             rotations.contiguous().data_ptr<float>(),
123 |             transMat_precomp.contiguous().data<float>(),
124 |             viewmatrix.contiguous().data<float>(),
125 |             projmatrix.contiguous().data<float>(),
126 |             campos.contiguous().data<float>(),
127 |             tan_fovx,
128 |             tan_fovy,
129 |             prefiltered,
130 |             out_color.contiguous().data<float>(),
131 |             out_others.contiguous().data<float>(),
132 |             radii.contiguous().data<int>(),
133 |             debug,
134 |             near_n,
135 |             far_n);
136 |     }
137 |     return std::make_tuple(rendered, out_color, out_others, radii, geomBuffer, binningBuffer, imgBuffer);
138 | }
139 | 
140 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
141 | RasterizeGaussiansBackwardCUDA(
142 |     const torch::Tensor& background,
143 |     const torch::Tensor& means3D,
144 |     const torch::Tensor& radii,
145 |     const torch::Tensor& colors,
146 |     const torch::Tensor& scales,
147 |     const torch::Tensor& rotations,
148 |     const float scale_modifier,
149 |     const torch::Tensor& transMat_precomp,
150 |     const torch::Tensor& viewmatrix,
151 |     const torch::Tensor& projmatrix,
152 |     const float tan_fovx,
153 |     const float tan_fovy,
154 |     const torch::Tensor& dL_dout_color,
155 |     const torch::Tensor& dL_dout_others,
156 |     const torch::Tensor& sh,
157 |     const int degree,
158 |     const torch::Tensor& campos,
159 |     const torch::Tensor& geomBuffer,
160 |     const int R,
161 |     const torch::Tensor& binningBuffer,
162 |     const torch::Tensor& imageBuffer,
163 |     const bool debug,
164 |     float near_n,
165 |     float far_n)
166 | {
167 | 
168 |     CHECK_INPUT(background);
169 |     CHECK_INPUT(means3D);
170 |     CHECK_INPUT(radii);
171 |     CHECK_INPUT(colors);
172 |     CHECK_INPUT(scales);
173 |     CHECK_INPUT(rotations);
174 |     CHECK_INPUT(transMat_precomp);
175 |     CHECK_INPUT(viewmatrix);
176 |     CHECK_INPUT(projmatrix);
177 |     CHECK_INPUT(sh);
178 |     CHECK_INPUT(campos);
179 |     CHECK_INPUT(binningBuffer);
180 |     CHECK_INPUT(imageBuffer);
181 |     CHECK_INPUT(geomBuffer);
182 | 
183 |     const int P = means3D.size(0);
184 |     const int H = dL_dout_color.size(1);
185 |     const int W = dL_dout_color.size(2);
186 | 
187 |     int M = 0;
188 |     if(sh.size(0) != 0)
189 |     {
190 |         M = sh.size(1);
191 |     }
192 | 
193 |     torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
194 |     torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
195 |     torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
196 |     torch::Tensor dL_dnormal = torch::zeros({P, 3}, means3D.options());
197 |     torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
198 |     torch::Tensor dL_dtransMat = torch::zeros({P, 9}, means3D.options());
199 |     torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
200 |     torch::Tensor dL_dscales = torch::zeros({P, 2}, means3D.options());
201 |     torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
202 | 
203 |     if(P != 0)
204 |     {
205 |         CudaRasterizer::Rasterizer::backward(P, degree, M, R,
206 |             background.contiguous().data<float>(),
207 |             W, H,
208 |             means3D.contiguous().data<float>(),
209 |             sh.contiguous().data<float>(),
210 |             colors.contiguous().data<float>(),
211 |             scales.data_ptr<float>(),
212 |             scale_modifier,
213 |             rotations.data_ptr<float>(),
214 |             transMat_precomp.contiguous().data<float>(),
215 |             viewmatrix.contiguous().data<float>(),
216 |             projmatrix.contiguous().data<float>(),
217 |             campos.contiguous().data<float>(),
218 |             tan_fovx,
219 |             tan_fovy,
220 |             radii.contiguous().data<int>(),
221 |             reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
222 |             reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
223 |             reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
224 |             dL_dout_color.contiguous().data<float>(),
225 |             dL_dout_others.contiguous().data<float>(),
226 |             dL_dmeans2D.contiguous().data<float>(),
227 |             dL_dnormal.contiguous().data<float>(),
228 |             dL_dopacity.contiguous().data<float>(),
229 |             dL_dcolors.contiguous().data<float>(),
230 |             dL_dmeans3D.contiguous().data<float>(),
231 |             dL_dtransMat.contiguous().data<float>(),
232 |             dL_dsh.contiguous().data<float>(),
233 |             dL_dscales.contiguous().data<float>(),
234 |             dL_drotations.contiguous().data<float>(),
235 |             debug,
236 |             near_n,
237 |             far_n);
238 |     }
239 | 
240 |     return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dtransMat, dL_dsh, dL_dscales, dL_drotations);
241 | }
242 | 
243 | torch::Tensor markVisible(
244 |     torch::Tensor& means3D,
245 |     torch::Tensor& viewmatrix,
246 |     torch::Tensor& projmatrix,
247 |     float near_n,
248 |     float far_n)
249 | {
250 |     const int P = means3D.size(0);
251 | 
252 |     torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
253 | 
254 |     if(P != 0)
255 |     {
256 |         CudaRasterizer::Rasterizer::markVisible(P,
257 |             means3D.contiguous().data<float>(),
258 |             viewmatrix.contiguous().data<float>(),
259 |             projmatrix.contiguous().data<float>(),
260 |             present.contiguous().data<bool>(),
261 |             near_n,
262 |             far_n);
263 |     }
264 | 
265 |     return present;
266 | }
267 | 
--------------------------------------------------------------------------------
/src/polaris/splat_renderer/splat_renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import polaris.utils as utils
4 | from polaris.splat_renderer.gaussian_renderer import GaussianModel, render
5 | from polaris.splat_renderer.scene.cameras import Camera
6 | 
7 | 
8 | class DummyPipe:
9 |     convert_SHs_python = False
10 |     compute_cov3D_python = False
11 |     depth_ratio = 0.0
12 |     debug = False
13 | 
14 | 
15 | class SplatRenderer:
16 |     def __init__(self, splats, bg_color=[0.5, 0.5, 0.5], device=0):
17 |         # self.bg_color = bg_color
18 |         self.device = device
19 |         self.bg_color = torch.tensor(bg_color).to(self.device).float()
20 |         self.pcds = splats
21 |         self.big_model = GaussianModel(3)
22 |         self.original_big_model = GaussianModel(3)
23 |         self.splat_mapping = {}
24 | 
25 |         self.init_models()
26 |         print("Finished loading models!")
27 | 
28 |         self.pipe = DummyPipe()
29 |         # self.cameras = self.init_cams(fovx=fovx, fovy=fovy, res=res)
30 | 
31 |     def render_raw(self, extrinsics_dict):
32 |         images = {}
33 |         for name in self.cameras:
34 |             if name in extrinsics_dict:
35 |                 cam_t = extrinsics_dict[name]["pos"]
36 |                 cam_r = extrinsics_dict[name]["rot"]
37 | 
38 |                 self.cameras[name].set_extrinsics(cam_r, cam_t)
39 | 
40 |             render_pkg = render(
41 |                 self.cameras[name], self.big_model, self.pipe, self.bg_color
42 |             )
43 |             image = render_pkg["render"]
44 |             images[name] = image.permute(1, 2, 0).clone()
45 |         return images
46 | 
47 |     def render(self, extrinsics_dict):
48 |         """
49 |         extrinsics_dict: dict
50 |             {
51 |                 "name": {"pos": torch.Tensor, "rot": torch.Tensor}
52 |                 ...
53 |             }
54 |         """
55 | 
56 |         # permute axis to match coordinate frame
57 |         p_mat = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
58 | 
59 |         images = {}
60 |         # for name, extrinsics in extrinsics_dict.items():
61 |         for name in self.cameras:
62 |             if name in extrinsics_dict:
63 |                 cam_t = extrinsics_dict[name]["pos"]
64 |                 cam_r = extrinsics_dict[name]["rot"]
65 | 
66 |                 cam_r = cam_r @ p_mat
67 | 
68 |                 self.cameras[name].set_extrinsics(cam_r, cam_t)
69 | 
70 |             render_pkg = render(
71 |                 self.cameras[name], self.big_model, self.pipe, self.bg_color
72 |             )
73 |             image = render_pkg["render"]
74 |             images[name] = image.permute(1, 2, 0).clone()
75 |         return images
76 | 
77 |     def init_cameras(self, cam_dict):
78 |         """
79 |         cam_dict: dict
80 |             {
81 |                 "name": {"fovx": float, "fovy": float, "res": (height, width)}
82 |                 ...
83 | } 84 | """ 85 | self.cameras = {} 86 | for name, cam_params in cam_dict.items(): 87 | self.cameras[name] = Camera( 88 | colmap_id=0, 89 | R=np.eye(3), 90 | T=np.array([0.0, 0.0, 0.0]), 91 | FoVy=cam_params["fovy"], 92 | FoVx=cam_params["fovx"], 93 | image=torch.zeros(3, cam_params["res"][0], cam_params["res"][1]), 94 | gt_alpha_mask=None, 95 | image_name="test", 96 | uid=123, 97 | data_device=self.device, 98 | ) 99 | 100 | def init_models(self): 101 | self.big_model._xyz = self.big_model._xyz.to(self.device) 102 | self.big_model._rotation = self.big_model._rotation.to(self.device) 103 | self.big_model._opacity = self.big_model._opacity.to(self.device) 104 | self.big_model._features_rest = self.big_model._features_rest.to(self.device) 105 | self.big_model._features_dc = self.big_model._features_dc.to(self.device) 106 | self.big_model._scaling = self.big_model._scaling.to(self.device) 107 | for name, pcd_path in self.pcds.items(): 108 | model = GaussianModel(3) 109 | model.load_ply(pcd_path) 110 | 111 | # get mappings 112 | cur_len = self.big_model._xyz.shape 113 | self.splat_mapping[name] = (cur_len[0], cur_len[0] + model._xyz.shape[0]) 114 | 115 | self.big_model._xyz = torch.cat( 116 | [self.big_model._xyz, model._xyz], dim=0 117 | ).requires_grad_() 118 | self.big_model._rotation = torch.cat( 119 | [self.big_model._rotation, model._rotation], dim=0 120 | ).requires_grad_() 121 | self.big_model._opacity = torch.cat( 122 | [self.big_model._opacity, model._opacity], dim=0 123 | ).requires_grad_() 124 | self.big_model._features_rest = torch.cat( 125 | [self.big_model._features_rest, model._features_rest], dim=0 126 | ).requires_grad_() 127 | self.big_model._features_dc = torch.cat( 128 | [self.big_model._features_dc, model._features_dc], dim=0 129 | ).requires_grad_() 130 | self.big_model._scaling = torch.cat( 131 | [self.big_model._scaling, model._scaling], dim=0 132 | ).requires_grad_() 133 | 134 | self.original_big_model._xyz = self.big_model._xyz.clone() 135 | self.original_big_model._rotation = self.big_model._rotation.clone() 136 | self.original_big_model._opacity = self.big_model._opacity.clone() 137 | self.original_big_model._features_rest = self.big_model._features_rest.clone() 138 | self.original_big_model._features_dc = self.big_model._features_dc.clone() 139 | self.original_big_model._scaling = self.big_model._scaling.clone() 140 | 141 | def add_splats(self, splats): 142 | for name, pcd_path in splats.items(): 143 | model = GaussianModel(3) 144 | model.load_ply(pcd_path) 145 | 146 | cur_len = self.big_model._xyz.shape 147 | self.splat_mapping[name] = (cur_len[0], cur_len[0] + model._xyz.shape[0]) 148 | 149 | self.big_model._xyz = torch.cat( 150 | [self.big_model._xyz, model._xyz], dim=0 151 | ).requires_grad_() 152 | self.big_model._rotation = torch.cat( 153 | [self.big_model._rotation, model._rotation], dim=0 154 | ).requires_grad_() 155 | self.big_model._opacity = torch.cat( 156 | [self.big_model._opacity, model._opacity], dim=0 157 | ).requires_grad_() 158 | self.big_model._features_rest = torch.cat( 159 | [self.big_model._features_rest, model._features_rest], dim=0 160 | ).requires_grad_() 161 | self.big_model._features_dc = torch.cat( 162 | [self.big_model._features_dc, model._features_dc], dim=0 163 | ).requires_grad_() 164 | self.big_model._scaling = torch.cat( 165 | [self.big_model._scaling, model._scaling], dim=0 166 | ).requires_grad_() 167 | 168 | self.original_big_model._xyz = self.big_model._xyz.clone() 169 | self.original_big_model._rotation = 
self.big_model._rotation.clone() 170 | self.original_big_model._opacity = self.big_model._opacity.clone() 171 | self.original_big_model._features_rest = self.big_model._features_rest.clone() 172 | self.original_big_model._features_dc = self.big_model._features_dc.clone() 173 | self.original_big_model._scaling = self.big_model._scaling.clone() 174 | 175 | def transform_many(self, all_transforms): 176 | """ 177 | all_transforms: dict 178 | { 179 | "name": (pos (torch.Tensor), rot (torch.Tensor)) 180 | ... 181 | } 182 | 183 | """ 184 | with torch.no_grad(): 185 | indices = [] 186 | properties = [] 187 | for name, transform in all_transforms.items(): 188 | translate = transform[0].to(self.device) 189 | rotate = transform[1].to(self.device) 190 | 191 | start = self.splat_mapping[name][0] 192 | end = self.splat_mapping[name][1] 193 | 194 | new_xyz = ( 195 | utils.rotate_vector_by_quaternion( 196 | rotate, self.original_big_model._xyz[start:end] 197 | ) 198 | + translate 199 | ) 200 | new_rotation = utils.multiply_quaternions( 201 | rotate, self.original_big_model._rotation[start:end] 202 | ) 203 | 204 | new_features_rest = self.original_big_model._features_rest[start:end] 205 | 206 | indices.append(torch.arange(start, end)) 207 | properties.append( 208 | { 209 | "xyz": new_xyz, 210 | "rotation": new_rotation, 211 | "features_rest": new_features_rest, 212 | } 213 | ) 214 | 215 | indices = torch.cat(indices) 216 | xyzs = torch.cat([prop["xyz"] for prop in properties]) 217 | rotations = torch.cat([prop["rotation"] for prop in properties]) 218 | features_rests = torch.cat([prop["features_rest"] for prop in properties]) 219 | 220 | self.big_model._xyz[indices] = xyzs 221 | self.big_model._rotation[indices] = rotations 222 | self.big_model._features_rest[indices] = features_rests 223 | -------------------------------------------------------------------------------- /src/diff-surfel-rasterization/diff_surfel_rasterization/csrc/cuda_rasterizer/auxiliary.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 13 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 14 | 15 | #include "config.h" 16 | #include "stdio.h" 17 | 18 | #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) 19 | #define NUM_WARPS (BLOCK_SIZE/32) 20 | 21 | #define TIGHTBBOX 0 22 | #define RENDER_AXUTILITY 1 23 | #define DEPTH_OFFSET 0 24 | #define ALPHA_OFFSET 1 25 | #define NORMAL_OFFSET 2 26 | #define MIDDEPTH_OFFSET 5 27 | #define DISTORTION_OFFSET 6 28 | // #define MEDIAN_WEIGHT_OFFSET 7 29 | 30 | // distortion helper macros 31 | #define BACKFACE_CULL 1 32 | #define DUAL_VISIABLE 1 33 | // #define NEAR_PLANE 0.2 34 | // #define FAR_PLANE 100.0 35 | #define DETACH_WEIGHT 0 36 | 37 | //__device__ const float near_n = 0.2; 38 | //__device__ const float far_n = 100.0; 39 | __device__ const float FilterInvSquare = 2.0f; 40 | 41 | // Spherical harmonics coefficients 42 | __device__ const float SH_C0 = 0.28209479177387814f; 43 | __device__ const float SH_C1 = 0.4886025119029199f; 44 | __device__ const float SH_C2[] = { 45 | 1.0925484305920792f, 46 | -1.0925484305920792f, 47 | 0.31539156525252005f, 48 | -1.0925484305920792f, 49 | 0.5462742152960396f 50 | }; 51 | __device__ const float SH_C3[] = { 52 | -0.5900435899266435f, 53 | 2.890611442640554f, 54 | -0.4570457994644658f, 55 | 0.3731763325901154f, 56 | -0.4570457994644658f, 57 | 1.445305721320277f, 58 | -0.5900435899266435f 59 | }; 60 | 61 | __forceinline__ __device__ float ndc2Pix(float v, int S) 62 | { 63 | return ((v + 1.0) * S - 1.0) * 0.5; 64 | } 65 | 66 | __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) 67 | { 68 | rect_min = { 69 | min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), 70 | min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) 71 | }; 72 | rect_max = { 73 | min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))), 74 | min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y))) 75 | }; 76 | } 77 | 78 | __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) 79 | { 80 | float3 transformed = { 81 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 82 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 83 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 84 | }; 85 | return transformed; 86 | } 87 | 88 | __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) 89 | { 90 | float4 transformed = { 91 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 92 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 93 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 94 | matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] 95 | }; 96 | return transformed; 97 | } 98 | 99 | __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) 100 | { 101 | float3 transformed = { 102 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, 103 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, 104 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, 105 | }; 106 | return transformed; 107 | } 108 | 109 | __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) 110 | { 111 | float3 transformed = { 112 | matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, 113 | matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, 
114 | matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, 115 | }; 116 | return transformed; 117 | } 118 | 119 | __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) 120 | { 121 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 122 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 123 | float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 124 | return dnormvdz; 125 | } 126 | 127 | __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) 128 | { 129 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 130 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 131 | 132 | float3 dnormvdv; 133 | dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; 134 | dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; 135 | dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 136 | return dnormvdv; 137 | } 138 | 139 | __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) 140 | { 141 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; 142 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 143 | 144 | float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; 145 | float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; 146 | float4 dnormvdv; 147 | dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; 148 | dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; 149 | dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; 150 | dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; 151 | return dnormvdv; 152 | } 153 | 154 | __forceinline__ __device__ float3 cross(float3 a, float3 b){return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);} 155 | 156 | __forceinline__ __device__ float3 operator*(float3 a, float3 b){return make_float3(a.x * b.x, a.y * b.y, a.z*b.z);} 157 | 158 | __forceinline__ __device__ float2 operator*(float2 a, float2 b){return make_float2(a.x * b.x, a.y * b.y);} 159 | 160 | __forceinline__ __device__ float3 operator*(float f, float3 a){return make_float3(f * a.x, f * a.y, f * a.z);} 161 | 162 | __forceinline__ __device__ float2 operator*(float f, float2 a){return make_float2(f * a.x, f * a.y);} 163 | 164 | __forceinline__ __device__ float3 operator-(float3 a, float3 b){return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);} 165 | 166 | __forceinline__ __device__ float2 operator-(float2 a, float2 b){return make_float2(a.x - b.x, a.y - b.y);} 167 | 168 | __forceinline__ __device__ float sumf3(float3 a){return a.x + a.y + a.z;} 169 | 170 | __forceinline__ __device__ float sumf2(float2 a){return a.x + a.y;} 171 | 172 | __forceinline__ __device__ float3 sqrtf3(float3 a){return make_float3(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z));} 173 | 174 | __forceinline__ __device__ float2 sqrtf2(float2 a){return make_float2(sqrtf(a.x), sqrtf(a.y));} 175 | 176 | __forceinline__ __device__ float3 minf3(float f, float3 a){return make_float3(min(f, a.x), min(f, a.y), min(f, a.z));} 177 | 178 | __forceinline__ __device__ float2 minf2(float f, float2 a){return make_float2(min(f, a.x), min(f, a.y));} 179 | 180 | __forceinline__ __device__ float3 maxf3(float f, float3 a){return make_float3(max(f, a.x), max(f, a.y), max(f, a.z));} 181 | 182 | __forceinline__ __device__ float2 maxf2(float f, float2 a){return make_float2(max(f, a.x), max(f, a.y));} 183 | 184 | __forceinline__ __device__ bool in_frustum(int idx, 185 | 
const float* orig_points, 186 | const float* viewmatrix, 187 | const float* projmatrix, 188 | bool prefiltered, 189 | float3& p_view, 190 | float near_n , 191 | float far_n) 192 | { 193 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 194 | 195 | // Bring points to screen space 196 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 197 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 198 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 199 | p_view = transformPoint4x3(p_orig, viewmatrix); 200 | 201 | if ((p_view.z <= near_n) || (p_view.z > far_n) || (p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)) 202 | { 203 | if (prefiltered) 204 | { 205 | printf("Point is filtered although prefiltered is set. This shouldn't happen!"); 206 | __trap(); 207 | } 208 | return false; 209 | } 210 | return true; 211 | } 212 | 213 | // adopt from gsplat: https://github.com/nerfstudio-project/gsplat/blob/main/gsplat/cuda/csrc/forward.cu 214 | inline __device__ glm::mat3 quat_to_rotmat(const glm::vec4 quat) { 215 | // quat to rotation matrix 216 | float s = rsqrtf( 217 | quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z 218 | ); 219 | float w = quat.x * s; 220 | float x = quat.y * s; 221 | float y = quat.z * s; 222 | float z = quat.w * s; 223 | 224 | // glm matrices are column-major 225 | return glm::mat3( 226 | 1.f - 2.f * (y * y + z * z), 227 | 2.f * (x * y + w * z), 228 | 2.f * (x * z - w * y), 229 | 2.f * (x * y - w * z), 230 | 1.f - 2.f * (x * x + z * z), 231 | 2.f * (y * z + w * x), 232 | 2.f * (x * z + w * y), 233 | 2.f * (y * z - w * x), 234 | 1.f - 2.f * (x * x + y * y) 235 | ); 236 | } 237 | 238 | 239 | inline __device__ glm::vec4 240 | quat_to_rotmat_vjp(const glm::vec4 quat, const glm::mat3 v_R) { 241 | float s = rsqrtf( 242 | quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z 243 | ); 244 | float w = quat.x * s; 245 | float x = quat.y * s; 246 | float y = quat.z * s; 247 | float z = quat.w * s; 248 | 249 | glm::vec4 v_quat; 250 | // v_R is COLUMN MAJOR 251 | // w element stored in x field 252 | v_quat.x = 253 | 2.f * ( 254 | // v_quat.w = 2.f * ( 255 | x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) + 256 | z * (v_R[0][1] - v_R[1][0]) 257 | ); 258 | // x element in y field 259 | v_quat.y = 260 | 2.f * 261 | ( 262 | // v_quat.x = 2.f * ( 263 | -2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) + 264 | z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1]) 265 | ); 266 | // y element in z field 267 | v_quat.z = 268 | 2.f * 269 | ( 270 | // v_quat.y = 2.f * ( 271 | x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) + 272 | z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]) 273 | ); 274 | // z element in w field 275 | v_quat.w = 276 | 2.f * 277 | ( 278 | // v_quat.z = 2.f * ( 279 | x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) - 280 | 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0]) 281 | ); 282 | return v_quat; 283 | } 284 | 285 | 286 | inline __device__ glm::mat3 287 | scale_to_mat(const glm::vec2 scale, const float glob_scale) { 288 | glm::mat3 S = glm::mat3(1.f); 289 | S[0][0] = glob_scale * scale.x; 290 | S[1][1] = glob_scale * scale.y; 291 | // S[2][2] = glob_scale * scale.z; 292 | return S; 293 | } 294 | 295 | 296 | 297 | #define CHECK_CUDA(A, debug) \ 298 | A; if(debug) { \ 299 | auto ret = cudaDeviceSynchronize(); \ 300 | if (ret != cudaSuccess) { \ 301 | std::cerr << "\n[CUDA ERROR] in " << 
__FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ 302 | throw std::runtime_error(cudaGetErrorString(ret)); \ 303 | } \ 304 | } 305 | 306 | #endif 307 | -------------------------------------------------------------------------------- /src/polaris/environments/manager_based_rl_splat_environment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from pathlib import Path 4 | import numpy as np 5 | 6 | from isaaclab.sensors.camera.camera import Camera 7 | import isaaclab.utils.math as math 8 | from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg 9 | from isaacsim.core.prims import GeometryPrim 10 | from isaacsim.core.utils.stage import get_current_stage 11 | from pxr import Semantics 12 | 13 | from polaris.splat_renderer import SplatRenderer 14 | from polaris.environments.rubrics import Rubric 15 | 16 | 17 | class ManagerBasedRLSplatEnv(ManagerBasedRLEnv): 18 | rubric: Rubric | None = None 19 | _task_name: str | None = None 20 | 21 | def __init__( 22 | self, 23 | cfg: ManagerBasedRLEnvCfg, 24 | *args, 25 | rubric: Rubric | None = None, 26 | usd_file: str | None = None, 27 | **kwargs, 28 | ): 29 | # do dynamic setup here maybe 30 | if usd_file is not None: 31 | self.usd_file = usd_file 32 | cfg.dynamic_setup(usd_file) 33 | 34 | super().__init__(cfg=cfg, *args, **kwargs) 35 | self.setup_splat_world_and_robot_views() 36 | self.setup_splat_robot() 37 | self.rubric = rubric 38 | 39 | def _evaluate_rubric(self) -> dict: 40 | """Evaluate rubric and return results for info dict.""" 41 | if self.rubric is None: 42 | return { 43 | "rubric": { 44 | "success": False, 45 | "progress": -1.0, 46 | "metrics": {}, 47 | } 48 | } 49 | 50 | result = self.rubric.evaluate(self) 51 | return { 52 | "rubric": { 53 | "success": result.success, 54 | "progress": result.progress, 55 | "metrics": result.metrics, 56 | } 57 | } 58 | 59 | def reset(self, object_positions: dict = {}, expensive=True, *args, **kwargs): 60 | """ 61 | Reset the environment 62 | 63 | Parameters 64 | ---------- 65 | object_positions : dict 66 | A dictionary mapping object names to their desired poses (position and orientation). 67 | expensive : bool 68 | Whether to perform expensive (splat) rendering operations. 69 | """ 70 | obs, info = super().reset(*args, **kwargs) 71 | 72 | # Reset rubric state 73 | if self.rubric: 74 | self.rubric.reset() 75 | 76 | # Following predefined initial conditions 77 | for obj, pose in object_positions.items(): 78 | print(f"Setting initial condition for {obj} to {pose}") 79 | pose = torch.tensor(pose)[None] 80 | self.scene[obj].write_root_pose_to_sim(pose) 81 | self.sim.render() 82 | self.scene.update(0) 83 | obs = ( 84 | self.observation_manager.compute() 85 | ) # update observation after setting ICs if needed 86 | obs["splat"] = self.custom_render(expensive, transform_static=True) 87 | 88 | # Evaluate rubric and add to info 89 | info.update(self._evaluate_rubric()) 90 | 91 | return obs, info 92 | 93 | def step(self, action, expensive=True): 94 | """ 95 | Steps the environment 96 | 97 | Parameters 98 | ---------- 99 | action: torch.Tensor 100 | The action to take in the environment. 101 | expensive : bool 102 | Whether to perform expensive (splat) rendering operations. 
103 | """ 104 | obs, rew, done, trunc, info = super().step(action) 105 | obs["splat"] = self.custom_render(expensive) 106 | # obs["splat"] = {cam: self.get_robot_from_sim()[cam]["rgb"] for cam in self.get_robot_from_sim()} 107 | 108 | # Evaluate rubric and add to info 109 | info.update(self._evaluate_rubric()) 110 | 111 | return obs, rew, done, trunc, info 112 | 113 | def custom_render(self, expensive: bool, transform_static: bool = False): 114 | """ 115 | Render the environment 116 | """ 117 | if expensive: 118 | self.transform_sim_to_splat(transform_static=transform_static) 119 | rgb = self.render_splat() 120 | mask_and_rgb = self.get_robot_from_sim() 121 | for cam in mask_and_rgb: 122 | og_img = ( 123 | rgb[cam] if cam in rgb else np.zeros_like(mask_and_rgb[cam]["rgb"]) 124 | ) 125 | mask = mask_and_rgb[cam]["mask"] 126 | sim_img = mask_and_rgb[cam]["rgb"] 127 | new_img = np.where(mask, sim_img, og_img) 128 | rgb[cam] = new_img 129 | else: 130 | rgb = {} 131 | for cam in self.scene.sensors: 132 | if isinstance(self.scene.sensors[cam], Camera): 133 | rgb[cam] = ( 134 | self.scene[cam].data.output["rgb"][0].detach().cpu().numpy() 135 | ) 136 | return rgb 137 | 138 | def setup_splat_world_and_robot_views(self): 139 | splats = {} 140 | self.views = {} 141 | stage = get_current_stage() 142 | 143 | # Allocate splats for all rigid objects in the scene and raytrace semantic tags 144 | for name in self.scene.rigid_objects: 145 | path = Path(self.usd_file).parent / "assets" / name / "splat.ply" 146 | if path.exists(): 147 | splats[name] = path 148 | else: 149 | # apply semantic tags 150 | prim = stage.GetPrimAtPath(f"/World/envs/env_0/scene/{name}") 151 | semantic_type = "class" 152 | semantic_value = "raytraced" 153 | instance_name = f"{semantic_type}_{semantic_value}" 154 | sem = Semantics.SemanticsAPI.Apply(prim, instance_name) 155 | sem.CreateSemanticTypeAttr() 156 | sem.CreateSemanticDataAttr() 157 | sem.GetSemanticTypeAttr().Set(semantic_type) 158 | sem.GetSemanticDataAttr().Set(semantic_value) 159 | 160 | # Setup splat cameras with intrinsics and resolution from sim cameras 161 | camera_cfg = {} 162 | for name in self.scene.sensors: 163 | if not isinstance(self.scene.sensors[name], Camera): 164 | continue 165 | resolution = self.scene.sensors[name].image_shape 166 | h_aperture = ( 167 | self.scene[name]._sensor_prims[0].GetHorizontalApertureAttr().Get() 168 | ) 169 | v_aperture = ( 170 | self.scene[name]._sensor_prims[0].GetVerticalApertureAttr().Get() 171 | ) 172 | f = self.scene[name]._sensor_prims[0].GetFocalLengthAttr().Get() 173 | fovx = 2 * np.arctan(h_aperture / (2 * f)) 174 | fovy = 2 * np.arctan(v_aperture / (2 * f)) 175 | camera_cfg[name] = { 176 | "res": resolution, 177 | "fovx": fovx, 178 | "fovy": fovy, 179 | } 180 | self.splat_renderer = SplatRenderer(splats=splats, device=self.device) 181 | self.splat_renderer.init_cameras(camera_cfg) 182 | 183 | def setup_splat_robot(self): 184 | # Allocate robot splats and views on robot links to track 185 | more_splats = {} 186 | robot_asset_path = Path(self.cfg.scene.robot.spawn.usd_path).parent 187 | for ply in sorted(list(robot_asset_path.glob("SEGMENTED/*.ply"))): 188 | more_splats[ply.stem] = ply 189 | sim_path = ply.stem.replace("-", "/") 190 | view = GeometryPrim( 191 | prim_paths_expr=f"/World/envs/env_0/robot/{sim_path}", 192 | reset_xform_properties=False, 193 | ) 194 | print(f"/World/envs/env_0/robot/{sim_path}") 195 | self.views[ply.stem] = view 196 | self.splat_renderer.add_splats(more_splats) 197 | 198 | def 
198 |     def get_robot_from_sim(self):
199 |         # Per-camera sim RGB plus a mask of semantically tagged ("raytraced") content, i.e. the robot and any non-splat objects; IDs 0/1 are background/unlabelled.
200 |         ret = {}
201 |         for cam in self.scene.sensors:
202 |             if not isinstance(self.scene.sensors[cam], Camera):
203 |                 continue
204 |             base_cam = self.scene[cam]
205 |             mask = (
206 |                 base_cam.data.output["semantic_segmentation"][0].detach().cpu().numpy()
207 |             )
208 |             img = base_cam.data.output["rgb"][0].detach().cpu().numpy()
209 |             mask = np.where(mask >= 2, 1, 0)
210 | 
211 |             ret[cam] = {"rgb": img, "mask": mask}
212 | 
213 |         return ret
214 | 
215 |     def transform_sim_to_splat(self, transform_static=False):
216 |         """
217 |         Update splat renderer transforms from simulation.
218 | 
219 |         Parameters
220 |         ----------
221 |         transform_static : bool
222 |             Whether to also transform static objects (like the environment).
223 |         """
224 |         all_transforms = {}
225 | 
226 |         # rigid bodies
227 |         for name in self.scene.rigid_objects:
228 |             path = Path(self.usd_file).parent / "assets" / name / "splat.ply"
229 |             if (
230 |                 "static" not in name or transform_static
231 |             ) and path.exists():  # splat exists
232 |                 pos = self.scene[name].data.root_state_w[0, :3]
233 |                 quat = self.scene[name].data.root_state_w[0, 3:7]
234 |                 all_transforms[name] = (pos, quat)
235 | 
236 |         # robot; this only fires if setup_splat_robot has been called, otherwise views is empty
237 |         for v_name in self.views:
238 |             view = self.views[v_name]
239 |             pos, quat = view.get_world_poses(usd=False)
240 |             pos, quat = pos.squeeze(), quat.squeeze()
241 |             all_transforms[v_name] = (pos, quat)
242 | 
243 |         if len(all_transforms) > 0:  # only transform if there is something to transform
244 |             self.splat_renderer.transform_many(all_transforms)
245 | 
246 |         # refresh every camera's extrinsics so that static cameras get positioned
247 |         if transform_static:
248 |             cam_extrinsics_dict = {}
249 |             for name in self.splat_renderer.cameras:
250 |                 pos = self.scene[name].data.pos_w[0].detach().cpu().numpy()
251 |                 quat = self.scene[name].data.quat_w_world[0]
252 | 
253 |                 rot = math.matrix_from_quat(quat).detach().cpu().numpy()
254 |                 cam_extrinsics_dict[name] = {"pos": pos, "rot": rot}
255 | 
256 |             if len(self.splat_renderer.pcds) > 0:
257 |                 self.splat_renderer.render(cam_extrinsics_dict)
258 | 
259 |     def render_splat(self):
260 |         # get camera extrinsics; only moving (wrist) cameras need fresh poses, static ones were set at reset
261 |         cam_extrinsics_dict = {}
262 |         for name in self.splat_renderer.cameras:
263 |             if "wrist" in name:
264 |                 pos = self.scene[name].data.pos_w[0].detach().cpu().numpy()
265 |                 quat = self.scene[name].data.quat_w_world[0]
266 | 
267 |                 rot = math.matrix_from_quat(quat).detach().cpu().numpy()
268 |                 cam_extrinsics_dict[name] = {"pos": pos, "rot": rot}
269 | 
270 |         # perform splat rendering
271 |         if len(self.splat_renderer.pcds) > 0:
272 |             rgb = self.splat_renderer.render(cam_extrinsics_dict)
273 |         else:
274 |             rgb = {
275 |                 name: torch.zeros(
276 |                     (
277 |                         self.splat_renderer.cameras[name].image_height,
278 |                         self.splat_renderer.cameras[name].image_width,
279 |                         3,
280 |                     )
281 |                 )
282 |                 for name in cam_extrinsics_dict
283 |             }
284 | 
285 |         # process output
286 |         for k, v in rgb.items():
287 |             rgb[k] = v.detach().cpu().numpy()
288 |             rgb[k] = np.clip(rgb[k], 0, 1)
289 |             rgb[k] = (rgb[k] * 255).astype(np.uint8)
290 | 
291 |             # The down/up resize by 2x acts as a cheap low-pass filter, presumably kept to soften splat aliasing artifacts.
292 |             rgb[k] = cv2.resize(rgb[k], (rgb[k].shape[1] // 2, rgb[k].shape[0] // 2))
293 |             rgb[k] = cv2.resize(rgb[k], (rgb[k].shape[1] * 2, rgb[k].shape[0] * 2))
294 | 
295 |         return rgb
296 | 
--------------------------------------------------------------------------------
/src/polaris/hf_upload.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities to validate and upload a PolaRiS environment folder to Hugging Face.
3 | Default target dataset: `PolaRiS-Evals/PolaRiS-Hub`.
4 | """
5 | 
6 | from __future__ import annotations
7 | 
8 | import json
9 | import sys
10 | from dataclasses import dataclass
11 | from pathlib import Path
12 | from typing import Iterable, List, Tuple
13 | 
14 | import re
15 | import tyro
16 | 
17 | from huggingface_hub import CommitOperationAdd, HfApi  # type: ignore
18 | from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError  # type: ignore
19 | 
20 | ALLOWED_MESH_SUFFIXES = {".usdz", ".usd", ".glb", ".ply"}
21 | 
22 | 
23 | def _is_numeric_sequence(value: Iterable[object], expected_len: int = 7) -> bool:
24 |     try:
25 |         items = list(value)
26 |     except TypeError:
27 |         return False
28 |     if len(items) != expected_len:
29 |         return False
30 |     return all(isinstance(v, (int, float)) for v in items)
31 | 
32 | 
33 | def _validate_assets(assets_dir: Path) -> Tuple[List[str], List[str], List[str]]:
34 |     errors: List[str] = []
35 |     warnings: List[str] = []
36 |     asset_names: List[str] = []
37 | 
38 |     if not assets_dir.exists():
39 |         errors.append(f"Missing assets directory: {assets_dir}")
40 |         return errors, warnings, asset_names
41 |     if not assets_dir.is_dir():
42 |         errors.append(f"`assets` is not a directory: {assets_dir}")
43 |         return errors, warnings, asset_names
44 | 
45 |     for asset_dir in sorted(p for p in assets_dir.iterdir() if p.is_dir()):
46 |         asset_names.append(asset_dir.name)
47 |         mesh_candidates = [
48 |             p
49 |             for p in asset_dir.rglob("*")
50 |             if p.is_file() and p.suffix.lower() in ALLOWED_MESH_SUFFIXES
51 |         ]
52 |         if not mesh_candidates:
53 |             errors.append(
54 |                 f"Asset `{asset_dir.name}` is missing any mesh file "
55 |                 f"(expected one of {sorted(ALLOWED_MESH_SUFFIXES)})"
56 |             )
57 |     if not asset_names:
58 |         errors.append(f"No asset subfolders found in {assets_dir}")
59 |     return errors, warnings, asset_names
60 | 
61 | 
62 | def _objects_match_assets(obj_name: str, asset_names: Iterable[str]) -> bool:
63 |     normalized = obj_name.lower().rstrip("0123456789_")
64 |     for asset in asset_names:
65 |         asset_norm = asset.lower().rstrip("0123456789_")
66 |         if (
67 |             normalized.startswith(asset_norm)
68 |             or asset_norm.startswith(normalized)
69 |             or normalized in asset_norm
70 |             or asset_norm in normalized
71 |         ):
72 |             return True
73 |     return False
74 | 
75 | 
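# Shape that _validate_initial_conditions expects, sketched from its checks
# below; the object name and values here are hypothetical, and the validator
# only enforces 7 numbers (by convention position + quaternion) per object:
#
#   {
#       "instruction": "put the tape into the container",
#       "poses": [
#           {"tape_0": [0.4, 0.1, 0.05, 1.0, 0.0, 0.0, 0.0]}
#       ]
#   }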
def _validate_initial_conditions(
77 |     ic_path: Path, asset_names: Iterable[str]
78 | ) -> Tuple[List[str], List[str]]:
79 |     errors: List[str] = []
80 |     warnings: List[str] = []
81 | 
82 |     if not ic_path.exists():
83 |         errors.append(f"Missing initial_conditions.json at {ic_path}")
84 |         return errors, warnings
85 |     try:
86 |         with ic_path.open("r") as f:
87 |             data = json.load(f)
88 |     except Exception as exc:  # noqa: BLE001 - surfaced to user
89 |         errors.append(f"Failed to parse {ic_path}: {exc}")
90 |         return errors, warnings
91 | 
92 |     if not isinstance(data, dict):
93 |         errors.append("initial_conditions.json must be a JSON object")
94 |         return errors, warnings
95 | 
96 |     instruction = data.get("instruction")
97 |     if not isinstance(instruction, str) or not instruction.strip():
98 |         errors.append("`instruction` must be a non-empty string")
99 | 
100 |     poses = data.get("poses")
101 |     if not isinstance(poses, list) or not poses:
102 |         errors.append("`poses` must be a non-empty list")
103 |         return errors, warnings
104 | 
105 |     for idx, pose in enumerate(poses):
106 |         if not isinstance(pose, dict):
107 |             errors.append(f"Pose {idx} is not an object")
108 |             continue
109 |         for obj_name, obj_pose in pose.items():
110 |             if not _is_numeric_sequence(obj_pose, expected_len=7):
111 |                 errors.append(
112 |                     f"Pose {idx} for `{obj_name}` is not a 7-element numeric sequence"
113 |                 )
114 |             elif not _objects_match_assets(obj_name, asset_names):
115 |                 warnings.append(
116 |                     f"Pose {idx} references `{obj_name}` which does not obviously map to an asset "
117 |                     f"({', '.join(asset_names)})"
118 |                 )
119 |     return errors, warnings
120 | 
121 | 
122 | def _validate_usd_files(
123 |     env_dir: Path, require_pxr: bool = False
124 | ) -> Tuple[List[str], List[str]]:
125 |     errors: List[str] = []
126 |     warnings: List[str] = []
127 | 
128 |     usd_files = list(env_dir.glob("*.usda"))
129 |     if not usd_files:
130 |         errors.append(f"No stage .usda file found in {env_dir}")
131 |         return errors, warnings
132 | 
133 |     try:
134 |         from pxr import Usd  # type: ignore
135 |     except Exception as exc:  # noqa: BLE001 - library is optional
136 |         if require_pxr:
137 |             errors.append(f"pxr.Usd not available; cannot open USD files ({exc})")
138 |         # when not required, stay quiet to keep dry-runs clean
139 |         return errors, warnings
140 | 
141 |     for usd_file in usd_files:
142 |         stage = Usd.Stage.Open(str(usd_file))
143 |         if stage is None:
144 |             errors.append(f"Failed to open USD stage: {usd_file}")
145 |             continue
146 |         if stage.GetDefaultPrim() is None:
147 |             warnings.append(f"USD stage has no default prim set: {usd_file}")
148 |     return errors, warnings
149 | 
150 | 
151 | def validate_environment(
152 |     env_dir: Path, require_pxr: bool = False
153 | ) -> Tuple[List[str], List[str]]:
154 |     errors: List[str] = []
155 |     warnings: List[str] = []
156 | 
157 |     if not env_dir.exists():
158 |         return [f"Environment path does not exist: {env_dir}"], warnings
159 |     if not env_dir.is_dir():
160 |         return [f"Environment path is not a directory: {env_dir}"], warnings
161 | 
162 |     assets_errors, assets_warnings, asset_names = _validate_assets(env_dir / "assets")
163 |     errors.extend(assets_errors)
164 |     warnings.extend(assets_warnings)
165 | 
166 |     ic_errors, ic_warnings = _validate_initial_conditions(
167 |         env_dir / "initial_conditions.json", asset_names
168 |     )
169 |     errors.extend(ic_errors)
170 |     warnings.extend(ic_warnings)
171 | 
172 |     usd_errors, usd_warnings = _validate_usd_files(env_dir, require_pxr=require_pxr)
173 |     errors.extend(usd_errors)
174 |     warnings.extend(usd_warnings)
175 | 
176 |     return errors, warnings
177 | 
178 | 
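# Example (sketch): running validation on its own before an upload; the path
# is hypothetical.
#
#   errors, warnings = validate_environment(
#       Path("~/polaris/PolaRiS-Hub/food_bussing").expanduser()
#   )
#   if errors:
#       raise SystemExit("\n".join(errors))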
def upload_environment(
180 |     env_dir: Path,
181 |     repo_id: str,
182 |     token: str | None,
183 |     branch: str,
184 |     pr_branch: str | None,
185 |     commit_message: str | None,
186 |     pr_title: str | None,
187 |     pr_description: str | None,
188 | ) -> None:
189 |     env_name = env_dir.name
190 |     api = HfApi(token=token)
191 |     commit_message = commit_message or f"Add environment `{env_name}`"
192 |     if pr_title:
193 |         commit_message = pr_title
194 |     if pr_description:
195 |         commit_message = f"{commit_message}\n\n{pr_description}"
196 | 
197 |     operations = []
198 |     for file in env_dir.rglob("*"):
199 |         if not file.is_file():
200 |             continue
201 |         rel_path = file.relative_to(env_dir).as_posix()
202 |         path_in_repo = f"{env_name}/{rel_path}"
203 |         operations.append(
204 |             CommitOperationAdd(
205 |                 path_in_repo=path_in_repo,
206 |                 path_or_fileobj=str(file),
207 |             )
208 |         )
209 |     pr_title = pr_title or f"Add environment `{env_name}`"
210 |     pr_description = pr_description or ""
211 |     revision = pr_branch or branch
212 |     try:
213 |         commit_info = api.create_commit(  # type: ignore[arg-type]
214 |             repo_id=repo_id,
215 |             repo_type="dataset",
216 |             operations=operations,
217 |             revision=revision,
218 |             commit_message=commit_message,
219 |             create_pr=True,
220 |         )
221 |     except RepositoryNotFoundError as exc:
222 |         raise SystemExit(
223 |             f"Repository `{repo_id}` not found or unauthorized. "
224 |             "Ensure the dataset exists and your HF token has write access."
225 |         ) from exc
226 |     except HfHubHTTPError as exc:
227 |         raise SystemExit(f"Hugging Face API error while creating PR: {exc}") from exc
228 |     pr_url = getattr(commit_info, "pr_url", None)
229 |     pr_num = getattr(commit_info, "pr_num", None)
230 | 
231 |     # Try to extract PR number from URL if not directly available
232 |     if pr_url and not pr_num:
233 |         match = re.search(r"/(?:pull|pulls|discussions)/(\d+)", pr_url)
234 |         if match:
235 |             pr_num = match.group(1)
236 | 
237 |     if pr_url:
238 |         print(f"Pull request opened: {pr_url}")
239 |     elif pr_num:
240 |         pr_url = f"https://huggingface.co/datasets/{repo_id}/discussions/{pr_num}"
241 |         print(f"Pull request opened: {pr_url}")
242 |     else:
243 |         discussions_page = f"https://huggingface.co/datasets/{repo_id}/discussions"
244 |         print(
245 |             f"Pull request created (URL not returned by API). Check: {discussions_page}"
246 |         )
247 | 
248 |     if pr_num:
249 |         repo_name = repo_id.split("/")[-1]
250 |         print("\nTo check out and update this PR locally:")
251 |         print(f"  git clone https://huggingface.co/datasets/{repo_id}")
252 |         print(f"  cd {repo_name} && git fetch origin refs/pr/{pr_num}:pr/{pr_num}")
253 |         print(f"  git checkout pr/{pr_num}")
254 |         print("  # make edits, then:")
255 |         print(f"  git push origin pr/{pr_num}:refs/pr/{pr_num}")
256 |     print(f"PR source revision: {revision} -> target: {branch}")
257 | 
258 | 
259 | @dataclass
260 | class Args:
261 |     """Validate and upload a PolaRiS environment to Hugging Face."""
262 | 
263 |     env_dir: Path
264 |     """Path to the environment folder (e.g., ~/polaris/PolaRiS-Hub/food_bussing)"""
265 | 
266 |     repo_id: str = "PolaRiS-Evals/PolaRiS-Hub"
267 |     """Target Hugging Face dataset repository"""
268 | 
269 |     branch: str = "main"
270 |     """Target branch on the dataset repository"""
271 | 
272 |     pr_branch: str | None = None
273 |     """Optional source branch/ref for the PR (e.g., refs/pr/104); defaults to --branch"""
274 | 
275 |     token: str | None = None
276 |     """Hugging Face token (defaults to HF_TOKEN env var if omitted)"""
277 | 
278 |     skip_validation: bool = False
279 |     """Upload without running local validation (not recommended)"""
280 | 
281 |     strict: bool = False
282 |     """Treat validation warnings as errors"""
283 | 
284 |     require_pxr: bool = False
285 |     """Fail validation if pxr (USD) is unavailable for stage open checks"""
286 | 
287 |     dry_run: bool = False
288 |     """Only run validation; do not upload"""
289 | 
290 |     commit_message: str | None = None
291 |     """Optional commit message for the upload"""
292 | 
293 |     pr_title: str | None = None
294 |     """Pull request title"""
295 | 
296 |     pr_description: str | None = None
297 |     """Pull request description/body"""
298 | 
299 | 
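# Example invocation (sketch; tyro turns each Args field into a CLI flag, and
# the environment path is hypothetical):
#
#   python -m polaris.hf_upload --env-dir ~/polaris/PolaRiS-Hub/food_bussing --dry-run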
def main(args: Args | None = None) -> None:
301 |     if args is None:
302 |         args = tyro.cli(Args)
303 | 
304 |     env_dir: Path = args.env_dir.resolve()
305 | 
306 |     if not args.skip_validation:
307 |         errors, warnings = validate_environment(env_dir, require_pxr=args.require_pxr)
308 |         for warn in warnings:
309 |             print(f"[WARN] {warn}")
310 |         if errors:
311 |             for err in errors:
312 |                 print(f"[ERROR] {err}")
313 |             sys.exit(1)
314 |         if args.strict and warnings:
315 |             print("[ERROR] Warnings treated as errors because --strict is set")
316 |             sys.exit(1)
317 |     else:
318 |         print("Skipping validation as requested.")
319 | 
320 |     if args.dry_run:
321 |         print("Dry run complete; nothing uploaded.")
322 |         return
323 | 
324 |     upload_environment(
325 |         env_dir=env_dir,
326 |         repo_id=args.repo_id,
327 |         token=args.token,
328 |         branch=args.branch,
329 |         pr_branch=args.pr_branch,
330 |         commit_message=args.commit_message,
331 |         pr_title=args.pr_title,
332 |         pr_description=args.pr_description,
333 |     )
334 |     print(
335 |         f"Prepared PR for `{env_dir.name}` to {args.repo_id} (target branch: {args.branch})."
336 |     )
337 | 
338 | 
339 | if __name__ == "__main__":
340 |     main()
341 | 
--------------------------------------------------------------------------------
/src/polaris/environments/droid_cfg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from isaaclab.envs.mdp.actions.actions_cfg import BinaryJointPositionActionCfg
4 | from isaaclab.envs.mdp.actions.binary_joint_actions import BinaryJointPositionAction
5 | import isaaclab.sim as sim_utils
6 | import isaaclab.utils.math as math
7 | import isaaclab.envs.mdp as mdp
8 | import numpy as np
9 | from typing import Sequence
10 | 
11 | from polaris.environments.robot_cfg import NVIDIA_DROID
12 | 
13 | from pxr import Usd, UsdGeom, UsdPhysics
14 | from isaaclab.utils import configclass, noise
15 | from isaaclab.assets import AssetBaseCfg, RigidObjectCfg
16 | from isaaclab.managers import SceneEntityCfg
17 | from isaaclab.scene import InteractiveSceneCfg
18 | from isaaclab.managers import ObservationGroupCfg as ObsGroup
19 | from isaaclab.managers import TerminationTermCfg as DoneTerm
20 | from isaaclab.managers import EventTermCfg as EventTerm
21 | from isaaclab.managers import ObservationTermCfg as ObsTerm
22 | from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg
23 | from isaaclab.sensors import CameraCfg, Camera
24 | from isaaclab.sensors.frame_transformer.frame_transformer_cfg import (
25 |     FrameTransformerCfg,
26 |     OffsetCfg,
27 | )
28 | from isaaclab.markers.config import FRAME_MARKER_CFG
29 | 
30 | 
31 | # Patch to fix updating camera poses, since it's broken in IsaacLab 2.3
32 | class FixedCamera(Camera):
33 |     def _update_poses(self, env_ids: Sequence[int]):
34 |         """Computes the pose of the camera in the world frame with ROS convention.
35 | 
36 |         This method uses the ROS convention to resolve the input pose. In this convention,
37 |         we assume that the camera front-axis is +Z-axis and up-axis is -Y-axis.
38 | 
39 |         The resolved position (in meters) and quaternion (w, x, y, z) are written
40 |         into the camera's ``pos_w`` and ``quat_w_world`` data buffers (nothing is returned).
41 |         """
42 |         # check camera prim exists
43 |         if len(self._sensor_prims) == 0:
44 |             raise RuntimeError("Camera prim is None. Please call 'sim.play()' first.")
45 | 
46 |         # get the poses from the view
47 |         env_ids = env_ids.to(torch.int32)
48 |         poses, quat = self._view.get_world_poses(env_ids, usd=False)
49 |         self._data.pos_w[env_ids] = poses
50 |         self._data.quat_w_world[env_ids] = (
51 |             math.convert_camera_frame_orientation_convention(
52 |                 quat, origin="opengl", target="world"
53 |             )
54 |         )
55 | 
56 | 
57 | ### SceneCfg ###
58 | @configclass
59 | class SceneCfg(InteractiveSceneCfg):
60 |     """Configuration for the DROID tabletop scene."""
61 | 
62 |     robot = NVIDIA_DROID
63 | 
64 |     wrist_cam = CameraCfg(
65 |         class_type=FixedCamera,
66 |         prim_path="{ENV_REGEX_NS}/robot/Gripper/Robotiq_2F_85/base_link/wrist_cam",
67 |         height=720,
68 |         width=1280,
69 |         data_types=["rgb", "semantic_segmentation"],
70 |         colorize_semantic_segmentation=False,
71 |         update_latest_camera_pose=True,
72 |         spawn=sim_utils.PinholeCameraCfg(
73 |             focal_length=2.8,
74 |             focus_distance=28.0,
75 |             horizontal_aperture=5.376,
76 |             vertical_aperture=3.024,
77 |         ),
78 |         offset=CameraCfg.OffsetCfg(
79 |             pos=(0.011, -0.031, -0.074),
80 |             rot=(-0.420, 0.570, 0.576, -0.409),
81 |             convention="opengl",
82 |         ),
83 |     )
84 | 
85 |     sphere_light = AssetBaseCfg(
86 |         prim_path="/World/biglight",
87 |         spawn=sim_utils.DomeLightCfg(intensity=1000),
88 |     )
89 | 
90 |     def __post_init__(
91 |         self,
92 |     ):
93 |         marker_cfg = FRAME_MARKER_CFG.copy()
94 |         marker_cfg.markers["frame"].scale = (0.1, 0.1, 0.1)
95 |         marker_cfg.prim_path = "/Visuals/FrameTransformer"
96 |         self.ee_frame = FrameTransformerCfg(
97 |             prim_path="{ENV_REGEX_NS}/robot/panda_link0",
98 |             debug_vis=False,
99 |             visualizer_cfg=marker_cfg,
100 |             target_frames=[
101 |                 FrameTransformerCfg.FrameCfg(
102 |                     prim_path="{ENV_REGEX_NS}/robot/Gripper/Robotiq_2F_85/base_link",
103 |                     name="end_effector",
104 |                     offset=OffsetCfg(
105 |                         pos=[0.0, 0.0, 0.0],
106 |                     ),
107 |                 ),
108 |             ],
109 |         )
110 | 
111 |     def dynamic_setup(self, environment_path, robot_splat=True, nightmare="", **kwargs):
112 |         environment_path_ = Path(environment_path)
113 |         environment_path = str(environment_path_.resolve())
114 | 
115 |         scene = AssetBaseCfg(
116 |             prim_path="{ENV_REGEX_NS}/scene",
117 |             spawn=sim_utils.UsdFileCfg(
118 |                 usd_path=environment_path,
119 |                 activate_contact_sensors=False,
120 |             ),
121 |         )
122 |         self.scene = scene
123 |         if not robot_splat:
124 |             self.robot.spawn.semantic_tags = [("class", "raytraced")]
125 |         stage = Usd.Stage.Open(environment_path)
126 |         scene_prim = stage.GetPrimAtPath("/World")
127 |         children = scene_prim.GetChildren()
128 | 
129 |         for child in children:
130 |             name = child.GetName()
131 |             # each direct child of /World becomes a camera or rigid object below
132 | 
133 |             # if it's a camera, use the camera pose
134 |             if child.IsA(UsdGeom.Camera):
135 |                 pos = child.GetAttribute("xformOp:translate").Get()
136 |                 rot = child.GetAttribute("xformOp:orient").Get()
137 |                 rot = (
138 |                     rot.GetReal(),
139 |                     rot.GetImaginary()[0],
140 |                     rot.GetImaginary()[1],
141 |                     rot.GetImaginary()[2],
142 |                 )
143 |                 asset = CameraCfg(
144 |                     prim_path=f"{{ENV_REGEX_NS}}/scene/{name}",
145 |                     height=720,
146 |                     width=1280,
147 |                     data_types=["rgb", "semantic_segmentation"],
148 |                     colorize_semantic_segmentation=False,
149 |                     spawn=None,
150 |                     offset=CameraCfg.OffsetCfg(pos=pos, rot=rot, convention="opengl"),
151 |                 )
152 |                 setattr(self, name, asset)
153 |             elif UsdPhysics.RigidBodyAPI(child):
154 |                 pos = child.GetAttribute("xformOp:translate").Get()
155 |                 rot = child.GetAttribute("xformOp:orient").Get()
156 |                 rot = (
157 |                     rot.GetReal(),
158 |                     rot.GetImaginary()[0],
159 |                     rot.GetImaginary()[1],
160 |                     rot.GetImaginary()[2],
161 |                 )
162 |                 asset = RigidObjectCfg(
163 |                     prim_path=f"{{ENV_REGEX_NS}}/scene/{name}",
164 |                     spawn=None,
165 |                     init_state=RigidObjectCfg.InitialStateCfg(
166 |                         pos=pos,
167 |                         rot=rot,
168 |                     ),
169 |                 )
170 |                 setattr(self, name, asset)
171 | 
172 |         if not hasattr(self, "external_cam"):
173 |             self.external_cam = CameraCfg(
174 |                 prim_path="{ENV_REGEX_NS}/scene/external_cam",
175 |                 height=720,
176 |                 width=1280,
177 |                 data_types=["rgb", "semantic_segmentation"],
178 |                 colorize_semantic_segmentation=False,
179 |                 spawn=sim_utils.PinholeCameraCfg(
180 |                     focal_length=1.0476,
181 |                     horizontal_aperture=2.5452,
182 |                     vertical_aperture=1.4721,
183 |                 ),
184 |                 offset=CameraCfg.OffsetCfg(
185 |                     pos=(-0.01, -0.33, 0.48),
186 |                     rot=(0.76, 0.43, -0.24, -0.42),
187 |                     convention="opengl",
188 |                 ),
189 |             )
190 | 
191 | 
192 | ### SceneCfg ###
193 | 
194 | 
195 | ### ActionCfg ###
196 | class BinaryJointPositionZeroToOneAction(BinaryJointPositionAction):
197 |     # override
198 |     def process_actions(self, actions: torch.Tensor):
199 |         # store the raw actions
200 |         self._raw_actions[:] = actions
201 |         # compute the binary mask
202 |         if actions.dtype == torch.bool:
203 |             # true: close, false: open
204 |             binary_mask = actions
205 |         else:
206 |             # true: close, false: open
207 |             binary_mask = actions > 0.5
208 |         # compute the command
209 |         self._processed_actions = torch.where(
210 |             binary_mask, self._close_command, self._open_command
211 |         )
212 |         if self.cfg.clip is not None:
213 |             self._processed_actions = torch.clamp(
214 |                 self._processed_actions,
215 |                 min=self._clip[:, :, 0],
216 |                 max=self._clip[:, :, 1],
217 |             )
218 | 
219 | 
220 | @configclass
221 | class BinaryJointPositionZeroToOneActionCfg(BinaryJointPositionActionCfg):
222 |     """Configuration for the binary joint position action term.
223 | 
224 |     See :class:`BinaryJointPositionZeroToOneAction` for more details.
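    Unlike the base class, actions are read on a zero-to-one scale: values
    above 0.5 (or boolean True) close the gripper, anything else opens it.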
225 | """ 226 | 227 | class_type = BinaryJointPositionZeroToOneAction 228 | 229 | 230 | @configclass 231 | class ActionCfg: 232 | arm = mdp.JointPositionActionCfg( 233 | asset_name="robot", 234 | joint_names=["panda_joint.*"], 235 | preserve_order=True, 236 | use_default_offset=False, 237 | ) 238 | 239 | finger_joint = BinaryJointPositionZeroToOneActionCfg( 240 | asset_name="robot", 241 | joint_names=["finger_joint"], 242 | open_command_expr={"finger_joint": 0.0}, 243 | close_command_expr={"finger_joint": np.pi / 4}, 244 | ) 245 | 246 | 247 | ### ActionCfg ### 248 | 249 | 250 | ### ObsCfg ### 251 | def arm_joint_pos( 252 | env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") 253 | ): 254 | robot = env.scene[asset_cfg.name] 255 | joint_names = [ 256 | "panda_joint1", 257 | "panda_joint2", 258 | "panda_joint3", 259 | "panda_joint4", 260 | "panda_joint5", 261 | "panda_joint6", 262 | "panda_joint7", 263 | ] 264 | # get joint inidices 265 | joint_indices = [ 266 | i for i, name in enumerate(robot.data.joint_names) if name in joint_names 267 | ] 268 | joint_pos = robot.data.joint_pos[:, joint_indices] 269 | return joint_pos 270 | 271 | 272 | def gripper_pos( 273 | env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") 274 | ): 275 | robot = env.scene[asset_cfg.name] 276 | joint_names = ["finger_joint"] 277 | joint_indices = [ 278 | i for i, name in enumerate(robot.data.joint_names) if name in joint_names 279 | ] 280 | joint_pos = robot.data.joint_pos[:, joint_indices] 281 | 282 | # rescale 283 | joint_pos = joint_pos / (np.pi / 4) 284 | 285 | return joint_pos 286 | 287 | 288 | @configclass 289 | class ObservationCfg: 290 | @configclass 291 | class PolicyCfg(ObsGroup): 292 | """Observations for policy.""" 293 | 294 | arm_joint_pos = ObsTerm(func=arm_joint_pos) 295 | gripper_pos = ObsTerm( 296 | func=gripper_pos, noise=noise.GaussianNoiseCfg(std=0.05), clip=(0, 1) 297 | ) 298 | 299 | def __post_init__(self) -> None: 300 | self.enable_corruption = False 301 | self.concatenate_terms = False 302 | 303 | policy: PolicyCfg = PolicyCfg() 304 | 305 | 306 | ### ObsCfg ### 307 | 308 | 309 | @configclass 310 | class EventCfg: 311 | """Configuration for events.""" 312 | 313 | reset_all = EventTerm(func=mdp.reset_scene_to_default, mode="reset") 314 | 315 | 316 | @configclass 317 | class CommandsCfg: 318 | """Command terms for the MDP.""" 319 | 320 | 321 | @configclass 322 | class RewardsCfg: 323 | """Reward terms for the MDP.""" 324 | 325 | 326 | @configclass 327 | class TerminationsCfg: 328 | """Termination terms for the MDP.""" 329 | 330 | time_out = DoneTerm(func=mdp.time_out, time_out=True) 331 | 332 | 333 | @configclass 334 | class CurriculumCfg: 335 | """Curriculum configuration.""" 336 | 337 | 338 | @configclass 339 | class EnvCfg(ManagerBasedRLEnvCfg): 340 | scene = SceneCfg(num_envs=1, env_spacing=7.0) 341 | 342 | observations = ObservationCfg() 343 | actions = ActionCfg() 344 | rewards = RewardsCfg() 345 | 346 | terminations = TerminationsCfg() 347 | commands = CommandsCfg() 348 | events = EventCfg() 349 | curriculum = CurriculumCfg() 350 | 351 | def __post_init__(self): 352 | self.episode_length_s = 30 353 | 354 | self.viewer.eye = (4.5, 0.0, 6.0) 355 | self.viewer.lookat = (0.0, 0.0, 0.0) 356 | 357 | self.decimation = 4 * 2 358 | self.sim.dt = 1 / (60 * 2) 359 | self.sim.render_interval = 4 * 2 360 | 361 | self.rerender_on_reset = True 362 | 363 | def dynamic_setup(self, *args): 364 | self.scene.dynamic_setup(*args) 365 | 366 | 367 | #### END DROID #### 
368 | 
--------------------------------------------------------------------------------
/src/diff-surfel-rasterization/diff_surfel_rasterization/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | from typing import NamedTuple
13 | import torch.nn as nn
14 | import torch
15 | 
16 | # ==============================================================================
17 | # JIT COMPILATION SUPPORT
18 | # ==============================================================================
19 | # Try to import the pre-compiled _C module; if it is not available, compile via JIT
20 | # ==============================================================================
21 | 
22 | try:
23 |     # Try to import pre-compiled extension (from setup.py build)
24 |     from . import _C
25 | except ImportError:
26 |     # If not available, compile via JIT on first import
27 |     import os
28 | 
29 |     from pathlib import Path
30 | 
31 |     def _load_extension_jit():
32 |         """JIT compile the CUDA extension if the pre-built version is not available."""
33 |         from torch.utils.cpp_extension import load
34 | 
35 |         # Get source directory (parent of this __init__.py)
36 |         # _src_path = Path(__file__).parent.parent
37 |         _src_path = Path(__file__).parent / "csrc"
38 | 
39 |         # Find all source files
40 |         sources = []
41 |         for pattern in ["ext.cpp", "cuda_rasterizer/*.cu", "*.cu"]:
42 |             sources.extend([str(p) for p in _src_path.glob(pattern)])
43 | 
44 |         if not sources:
45 |             raise FileNotFoundError(
46 |                 f"No source files found in {_src_path}. "
47 |                 "Make sure diff-surfel-rasterization is properly installed."
48 |             )
49 | 
50 |         # Compilation settings
51 |         extra_cuda_cflags = [
52 |             "-O3",
53 |             "--use_fast_math",
54 |             "-std=c++17",
55 |             "--expt-relaxed-constexpr",
56 |             "-U__CUDA_NO_HALF_OPERATORS__",
57 |             "-U__CUDA_NO_HALF_CONVERSIONS__",
58 |             "-U__CUDA_NO_HALF2_OPERATORS__",
59 |         ]
60 | 
61 |         extra_cflags = ["-O3", "-std=c++17"]
62 | 
63 |         # Include directories
64 |         include_dirs = [
65 |             str(_src_path),
66 |             str(_src_path / "cuda_rasterizer"),
67 |             str(_src_path.parent.parent / "third_party" / "glm"),
68 |         ]
69 | 
70 |         # Build directory
71 |         cuda_ver = (
72 |             torch.version.cuda.replace(".", "_") if torch.cuda.is_available() else "cpu"
73 |         )
74 |         build_dir = os.path.join(
75 |             os.path.expanduser("~"),
76 |             ".cache",
77 |             "torch_extensions",
78 |             f"diff_surfel_rasterization_cu{cuda_ver}",
79 |         )
80 | 
81 |         # Create build directory if it doesn't exist
82 |         os.makedirs(build_dir, exist_ok=True)
83 | 
84 |         is_first_build = not os.path.exists(os.path.join(build_dir, "build.ninja"))
85 |         if is_first_build:
86 |             print("\n" + "=" * 70)
87 |             print("Compiling diff-surfel-rasterization (first time only)...")
88 |             print("This will take 2-5 minutes.")
89 |             print("=" * 70 + "\n")
90 | 
91 |         try:
92 |             extension = load(
93 |                 name="diff_surfel_rasterization_cuda",
94 |                 sources=sources,
95 |                 extra_cflags=extra_cflags,
96 |                 extra_cuda_cflags=extra_cuda_cflags,
97 |                 extra_include_paths=include_dirs,
98 |                 build_directory=build_dir,
99 |                 verbose=is_first_build,
100 |                 with_cuda=True,
101 |             )
102 | 
103 |             if is_first_build:
104 |                 print("\n✓ Compilation successful! Cached for future use.\n")
105 | 
106 |             return extension
107 | 
108 |         except Exception as e:
109 |             print("\n" + "=" * 70)
110 |             print("ERROR: Failed to compile diff-surfel-rasterization")
111 |             print("=" * 70)
112 |             print(f"\n{e}\n")
113 |             print("Requirements:")
114 |             print("  - CUDA toolkit installed")
115 |             print("  - Compatible C++ compiler (gcc 7-12)")
116 |             print("  - PyTorch with CUDA support")
117 |             print("=" * 70 + "\n")
118 |             raise
119 | 
120 |     # Load via JIT
121 |     if not torch.cuda.is_available():
122 |         raise RuntimeError(
123 |             "CUDA not available. diff-surfel-rasterization requires CUDA.\n"
124 |             f"PyTorch version: {torch.__version__}"
125 |         )
126 | 
127 |     _C = _load_extension_jit()
128 | 
129 | # ==============================================================================
130 | # REST OF ORIGINAL CODE (unchanged)
131 | # ==============================================================================
132 | 
133 | 
134 | def cpu_deep_copy_tuple(input_tuple):
135 |     copied_tensors = [
136 |         item.cpu().clone() if isinstance(item, torch.Tensor) else item
137 |         for item in input_tuple
138 |     ]
139 |     return tuple(copied_tensors)
140 | 
141 | 
142 | def rasterize_gaussians(
143 |     means3D,
144 |     means2D,
145 |     sh,
146 |     colors_precomp,
147 |     opacities,
148 |     scales,
149 |     rotations,
150 |     cov3Ds_precomp,
151 |     raster_settings,
152 | ):
153 |     return _RasterizeGaussians.apply(
154 |         means3D,
155 |         means2D,
156 |         sh,
157 |         colors_precomp,
158 |         opacities,
159 |         scales,
160 |         rotations,
161 |         cov3Ds_precomp,
162 |         raster_settings,
163 |     )
164 | 
165 | 
166 | class _RasterizeGaussians(torch.autograd.Function):
167 |     @staticmethod
168 |     def forward(
169 |         ctx,
170 |         means3D,
171 |         means2D,
172 |         sh,
173 |         colors_precomp,
174 |         opacities,
175 |         scales,
176 |         rotations,
177 |         cov3Ds_precomp,
178 |         raster_settings,
179 |     ):
180 |         # Restructure arguments the way that the C++ lib expects them
181 |         args = (
182 |             raster_settings.bg,
183 |             means3D,
184 |             colors_precomp,
185 |             opacities,
186 |             scales,
187 |             rotations,
188 |             raster_settings.scale_modifier,
189 |             cov3Ds_precomp,
190 |             raster_settings.viewmatrix,
191 |             raster_settings.projmatrix,
192 |             raster_settings.tanfovx,
193 |             raster_settings.tanfovy,
194 |             raster_settings.image_height,
195 |             raster_settings.image_width,
196 |             sh,
197 |             raster_settings.sh_degree,
198 |             raster_settings.campos,
199 |             raster_settings.prefiltered,
200 |             raster_settings.debug,
201 |             raster_settings.near_n,
202 |             raster_settings.far_n,
203 |         )
204 | 
205 |         # Invoke C++/CUDA rasterizer
206 |         if raster_settings.debug:
207 |             cpu_args = cpu_deep_copy_tuple(
208 |                 args
209 |             )  # Copy them before they can be corrupted
210 |             try:
211 |                 (
212 |                     num_rendered,
213 |                     color,
214 |                     depth,
215 |                     radii,
216 |                     geomBuffer,
217 |                     binningBuffer,
218 |                     imgBuffer,
219 |                 ) = _C.rasterize_gaussians(*args)
220 |             except Exception as ex:
221 |                 torch.save(cpu_args, "snapshot_fw.dump")
222 |                 print(
223 |                     "\nAn error occurred in forward. Please forward snapshot_fw.dump for debugging."
224 |                 )
225 |                 raise ex
226 |         else:
227 |             num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = (
228 |                 _C.rasterize_gaussians(*args)
229 |             )
230 | 
231 |         # Keep relevant tensors for backward
232 |         ctx.raster_settings = raster_settings
233 |         ctx.num_rendered = num_rendered
234 |         ctx.save_for_backward(
235 |             colors_precomp,
236 |             means3D,
237 |             scales,
238 |             rotations,
239 |             cov3Ds_precomp,
240 |             radii,
241 |             sh,
242 |             geomBuffer,
243 |             binningBuffer,
244 |             imgBuffer,
245 |         )
246 |         return color, radii, depth
247 | 
248 |     @staticmethod
249 |     def backward(ctx, grad_out_color, grad_radii, grad_depth):
250 |         # Restore necessary values from context
251 |         num_rendered = ctx.num_rendered
252 |         raster_settings = ctx.raster_settings
253 |         (
254 |             colors_precomp,
255 |             means3D,
256 |             scales,
257 |             rotations,
258 |             cov3Ds_precomp,
259 |             radii,
260 |             sh,
261 |             geomBuffer,
262 |             binningBuffer,
263 |             imgBuffer,
264 |         ) = ctx.saved_tensors
265 | 
266 |         # Restructure args as C++ method expects them
267 |         args = (
268 |             raster_settings.bg,
269 |             means3D,
270 |             radii,
271 |             colors_precomp,
272 |             scales,
273 |             rotations,
274 |             raster_settings.scale_modifier,
275 |             cov3Ds_precomp,
276 |             raster_settings.viewmatrix,
277 |             raster_settings.projmatrix,
278 |             raster_settings.tanfovx,
279 |             raster_settings.tanfovy,
280 |             grad_out_color,
281 |             grad_depth,
282 |             sh,
283 |             raster_settings.sh_degree,
284 |             raster_settings.campos,
285 |             geomBuffer,
286 |             num_rendered,
287 |             binningBuffer,
288 |             imgBuffer,
289 |             raster_settings.debug,
290 |             raster_settings.near_n,
291 |             raster_settings.far_n,
292 |         )
293 | 
294 |         # Compute gradients for relevant tensors by invoking backward method
295 |         if raster_settings.debug:
296 |             cpu_args = cpu_deep_copy_tuple(
297 |                 args
298 |             )  # Copy them before they can be corrupted
299 |             try:
300 |                 (
301 |                     grad_means2D,
302 |                     grad_colors_precomp,
303 |                     grad_opacities,
304 |                     grad_means3D,
305 |                     grad_cov3Ds_precomp,
306 |                     grad_sh,
307 |                     grad_scales,
308 |                     grad_rotations,
309 |                 ) = _C.rasterize_gaussians_backward(*args)
310 |             except Exception as ex:
311 |                 torch.save(cpu_args, "snapshot_bw.dump")
312 |                 print(
313 |                     "\nAn error occurred in backward. Writing snapshot_bw.dump for debugging.\n"
314 |                 )
315 |                 raise ex
316 |         else:
317 |             (
318 |                 grad_means2D,
319 |                 grad_colors_precomp,
320 |                 grad_opacities,
321 |                 grad_means3D,
322 |                 grad_cov3Ds_precomp,
323 |                 grad_sh,
324 |                 grad_scales,
325 |                 grad_rotations,
326 |             ) = _C.rasterize_gaussians_backward(*args)
327 | 
328 |         grads = (
329 |             grad_means3D,
330 |             grad_means2D,
331 |             grad_sh,
332 |             grad_colors_precomp,
333 |             grad_opacities,
334 |             grad_scales,
335 |             grad_rotations,
336 |             grad_cov3Ds_precomp,
337 |             None,
338 |         )
339 | 
340 |         return grads
341 | 
342 | 
343 | class GaussianRasterizationSettings(NamedTuple):
344 |     image_height: int
345 |     image_width: int
346 |     tanfovx: float
347 |     tanfovy: float
348 |     bg: torch.Tensor
349 |     scale_modifier: float
350 |     viewmatrix: torch.Tensor
351 |     projmatrix: torch.Tensor
352 |     sh_degree: int
353 |     campos: torch.Tensor
354 |     prefiltered: bool
355 |     debug: bool
356 |     near_n: float
357 |     far_n: float
358 | 
359 | 
360 | class GaussianRasterizer(nn.Module):
361 |     def __init__(self, raster_settings):
362 |         super().__init__()
363 |         self.raster_settings = raster_settings
364 | 
365 |     def markVisible(self, positions):
366 |         # Mark visible points (based on frustum culling for camera) with a boolean
367 |         with torch.no_grad():
368 |             raster_settings = self.raster_settings
369 |             visible = _C.mark_visible(
370 |                 positions,
371 |                 raster_settings.viewmatrix,
372 |                 raster_settings.projmatrix,
373 |                 raster_settings.near_n,
374 |                 raster_settings.far_n,
375 |             )
376 | 
377 |         return visible
378 | 
379 |     def forward(
380 |         self,
381 |         means3D,
382 |         means2D,
383 |         opacities,
384 |         shs=None,
385 |         colors_precomp=None,
386 |         scales=None,
387 |         rotations=None,
388 |         cov3D_precomp=None,
389 |     ):
390 |         raster_settings = self.raster_settings
391 | 
392 |         if (shs is None and colors_precomp is None) or (
393 |             shs is not None and colors_precomp is not None
394 |         ):
395 |             raise Exception(
396 |                 "Please provide exactly one of either SHs or precomputed colors!"
397 |             )
398 | 
399 |         if ((scales is None or rotations is None) and cov3D_precomp is None) or (
400 |             (scales is not None or rotations is not None) and cov3D_precomp is not None
401 |         ):
402 |             raise Exception(
403 |                 "Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!"
404 |             )
405 | 
406 |         if shs is None:
407 |             shs = torch.Tensor([]).cuda()
408 |         if colors_precomp is None:
409 |             colors_precomp = torch.Tensor([]).cuda()
410 | 
411 |         if scales is None:
412 |             scales = torch.Tensor([]).cuda()
413 |         if rotations is None:
414 |             rotations = torch.Tensor([]).cuda()
415 |         if cov3D_precomp is None:
416 |             cov3D_precomp = torch.Tensor([]).cuda()
417 | 
418 |         # Invoke C++/CUDA rasterization routine
419 |         return rasterize_gaussians(
420 |             means3D,
421 |             means2D,
422 |             shs,
423 |             colors_precomp,
424 |             opacities,
425 |             scales,
426 |             rotations,
427 |             cov3D_precomp,
428 |             raster_settings,
429 |         )
430 | 
--------------------------------------------------------------------------------
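# Usage sketch for the rasterizer above. Per the checks in forward(), exactly
# one of shs/colors_precomp and exactly one of scales+rotations/cov3D_precomp
# must be supplied; all values below (resolutions, fovs, near/far planes, and
# the caller-provided tensors) are hypothetical.
#
#   settings = GaussianRasterizationSettings(
#       image_height=720, image_width=1280,
#       tanfovx=0.9, tanfovy=0.5,
#       bg=torch.zeros(3, device="cuda"),
#       scale_modifier=1.0,
#       viewmatrix=viewmatrix, projmatrix=projmatrix,
#       sh_degree=3, campos=campos,
#       prefiltered=False, debug=False,
#       near_n=0.01, far_n=100.0,
#   )
#   rasterizer = GaussianRasterizer(settings)
#   color, radii, depth = rasterizer(
#       means3D=means3D, means2D=means2D, opacities=opacities,
#       shs=shs, scales=scales, rotations=rotations,
#   )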