├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── DATASET.md ├── LICENSE ├── README.md ├── README_OLD.md ├── SECURITY.md ├── SUPPORT.md ├── app.py ├── assets ├── example_image │ ├── T.png │ ├── typical_building_building.png │ ├── typical_building_castle.png │ ├── typical_building_colorful_cottage.png │ ├── typical_building_maya_pyramid.png │ ├── typical_building_mushroom.png │ ├── typical_building_space_station.png │ ├── typical_creature_dragon.png │ ├── typical_creature_elephant.png │ ├── typical_creature_furry.png │ ├── typical_creature_quadruped.png │ ├── typical_creature_robot_crab.png │ ├── typical_creature_robot_dinosour.png │ ├── typical_creature_rock_monster.png │ ├── typical_humanoid_block_robot.png │ ├── typical_humanoid_dragonborn.png │ ├── typical_humanoid_dwarf.png │ ├── typical_humanoid_goblin.png │ ├── typical_humanoid_mech.png │ ├── typical_misc_crate.png │ ├── typical_misc_fireplace.png │ ├── typical_misc_gate.png │ ├── typical_misc_lantern.png │ ├── typical_misc_magicbook.png │ ├── typical_misc_mailbox.png │ ├── typical_misc_monster_chest.png │ ├── typical_misc_paper_machine.png │ ├── typical_misc_phonograph.png │ ├── typical_misc_portal2.png │ ├── typical_misc_storage_chest.png │ ├── typical_misc_telephone.png │ ├── typical_misc_television.png │ ├── typical_misc_workbench.png │ ├── typical_vehicle_biplane.png │ ├── typical_vehicle_bulldozer.png │ ├── typical_vehicle_cart.png │ ├── typical_vehicle_excavator.png │ ├── typical_vehicle_helicopter.png │ ├── typical_vehicle_locomotive.png │ ├── typical_vehicle_pirate_ship.png │ └── weatherworn_misc_paper_machine3.png ├── example_multi_image │ ├── character_1.png │ ├── character_2.png │ ├── character_3.png │ ├── mushroom_1.png │ ├── mushroom_2.png │ ├── mushroom_3.png │ ├── orangeguy_1.png │ ├── orangeguy_2.png │ ├── orangeguy_3.png │ ├── popmart_1.png │ ├── popmart_2.png │ ├── popmart_3.png │ ├── rabbit_1.png │ ├── rabbit_2.png │ ├── rabbit_3.png │ ├── tiger_1.png │ ├── tiger_2.png │ ├── tiger_3.png │ ├── yoimiya_1.png │ ├── yoimiya_2.png │ └── yoimiya_3.png ├── logo.webp └── teaser.png ├── dataset_toolkits ├── blender_script │ ├── io_scene_usdz.zip │ └── render.py ├── build_metadata.py ├── datasets │ ├── 3D-FUTURE.py │ ├── ABO.py │ ├── HSSD.py │ ├── ObjaverseXL.py │ └── Toys4k.py ├── download.py ├── encode_latent.py ├── encode_ss_latent.py ├── extract_feature.py ├── render.py ├── render_cond.py ├── setup.sh ├── stat_latent.py ├── utils.py └── voxelize.py ├── example.py ├── example_multi_image.py ├── extensions └── vox2seq │ ├── benchmark.py │ ├── setup.py │ ├── src │ ├── api.cu │ ├── api.h │ ├── ext.cpp │ ├── hilbert.cu │ ├── hilbert.h │ ├── z_order.cu │ └── z_order.h │ ├── test.py │ └── vox2seq │ ├── __init__.py │ └── pytorch │ ├── __init__.py │ ├── default.py │ ├── hilbert.py │ └── z_order.py ├── setup.sh └── trellis ├── __init__.py ├── models ├── __init__.py ├── sparse_structure_flow.py ├── sparse_structure_vae.py ├── structured_latent_flow.py └── structured_latent_vae │ ├── __init__.py │ ├── base.py │ ├── decoder_gs.py │ ├── decoder_mesh.py │ ├── decoder_rf.py │ └── encoder.py ├── modules ├── attention │ ├── __init__.py │ ├── full_attn.py │ └── modules.py ├── norm.py ├── sparse │ ├── __init__.py │ ├── attention │ │ ├── __init__.py │ │ ├── full_attn.py │ │ ├── modules.py │ │ ├── serialized_attn.py │ │ └── windowed_attn.py │ ├── basic.py │ ├── conv │ │ ├── __init__.py │ │ ├── conv_spconv.py │ │ └── conv_torchsparse.py │ ├── linear.py │ ├── nonlinearity.py │ ├── norm.py │ ├── spatial.py │ └── transformer │ │ ├── 
__init__.py │ │ ├── blocks.py │ │ └── modulated.py ├── spatial.py ├── transformer │ ├── __init__.py │ ├── blocks.py │ └── modulated.py └── utils.py ├── pipelines ├── __init__.py ├── base.py ├── samplers │ ├── __init__.py │ ├── base.py │ ├── classifier_free_guidance_mixin.py │ ├── flow_euler.py │ └── guidance_interval_mixin.py └── trellis_image_to_3d.py ├── renderers ├── __init__.py ├── gsplat_renderer.py ├── mesh_renderer.py ├── octree_renderer.py └── sh_utils.py ├── representations ├── __init__.py ├── gaussian │ ├── __init__.py │ ├── gaussian_model.py │ └── general_utils.py ├── mesh │ ├── __init__.py │ ├── cube2mesh.py │ ├── flexicubes │ │ ├── DCO.txt │ │ ├── LICENSE.txt │ │ ├── README.md │ │ ├── examples │ │ │ ├── download_data.py │ │ │ ├── extraction.ipynb │ │ │ ├── loss.py │ │ │ ├── optimization.ipynb │ │ │ ├── optimize.py │ │ │ ├── render.py │ │ │ └── util.py │ │ ├── flexicubes.py │ │ ├── images │ │ │ ├── ablate_L_dev.jpg │ │ │ ├── block_final.png │ │ │ ├── block_init.png │ │ │ └── teaser_top.png │ │ └── tables.py │ └── utils_cube.py ├── octree │ ├── __init__.py │ └── octree_dfs.py └── radiance_field │ ├── __init__.py │ └── strivec.py └── utils ├── __init__.py ├── bake_texture.py ├── general_utils.py ├── postprocessing_utils.py ├── random_utils.py └── render_utils.py /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/.gitmodules -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

## WARNING
This repository is a work in progress under active development. Texturing is currently not optimal; a fix is expected in an upcoming commit. Thank you for your patience.

## Introduction
This repo makes two major modifications to the original TRELLIS library:
- implements gsplat for rendering
- replaces nvdiffrast with pytorch3d for texturing

## Initial Setup
CUDA 11.8 is recommended due to dependency constraints.
- Download the installer:
```wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run```
- Make it executable:
```chmod +x cuda_11.8.0_520.61.05_linux.run```
- Install:
```sudo ./cuda_11.8.0_520.61.05_linux.run```

## Conda
- Create a new conda environment:
```conda create --name trellis_refactored```
- Activate it:
```conda activate trellis_refactored```
- Make CUDA 11.8 the default for this environment:
```mkdir -p $CONDA_PREFIX/etc/conda/activate.d```\
```echo 'export CUDA_HOME=/usr/local/cuda-11.8' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh```\
```echo 'export PATH=$CUDA_HOME/bin:$PATH' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh```\
```echo 'export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh```
- Deactivate and reactivate the environment so the variables take effect.
- Run ```nvcc --version``` to confirm that CUDA 11.8 is active.

## Add conda-forge channel
```conda config --add channels conda-forge```\
```conda config --set channel_priority flexible```

## PyTorch
PyTorch 2.4.0 is recommended with CUDA 11.8. Install it with:
```conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia```

## Run the setup
```. ./setup.sh --basic --xformers --flash-attn --diffoctreerast --spconv --mipgaussian --gsplat --demo```

## PyTorch3D installation with Conda
```conda install pytorch3d -c pytorch3d```

## BUG ALERTS !!
1. ```python app.py``` fails with the Gradio error "argument of type 'bool' is not iterable".
   This is caused by a pydantic version mismatch. Uninstall pydantic:
   ```pip uninstall pydantic```
   Then install version 2.10.6:
   ```pip install pydantic==2.10.6```

2. Torchvision fails after installing pytorch3d.
   - Remove torchvision completely:
   ```conda remove torchvision```
   - Reinstall it with full dependencies:
   ```conda install pytorch==2.4.0 torchvision==0.19.0 pytorch-cuda=11.8 -c pytorch -c nvidia```
   - After installing tqdm, install pytorch3d again:
   ```conda install pytorch3d -c pytorch3d```
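## Optional: sanity-check the environment
Before launching the demo, it can help to confirm that the packages installed above still import cleanly — in particular that torchvision survived the PyTorch3D install and that pydantic is at the pinned version. The snippet below is a minimal check written for this guide (it is not a script shipped with the repo); it only assumes the packages installed in the steps above.

```python
# env_check.py - minimal environment sanity check (not part of the repo)
import importlib

import torch

# CUDA toolkit / GPU visibility as seen by PyTorch
print(f"torch {torch.__version__} | built for CUDA {torch.version.cuda} | "
      f"GPU available: {torch.cuda.is_available()}")

# Packages the setup steps above are expected to provide; each should import without errors.
for name in ["torchvision", "pytorch3d", "gsplat", "pydantic", "gradio"]:
    try:
        module = importlib.import_module(name)
        print(f"{name:12s} {getattr(module, '__version__', 'unknown version')}")
    except Exception as exc:
        print(f"{name:12s} FAILED to import: {exc}")
```

If pydantic reports a version other than 2.10.6 and ```python app.py``` raises the Gradio error described in Bug 1, re-pin it as shown above.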
## Run the Gradio demo and check it in the browser
```python app.py```
-------------------------------------------------------------------------------- /SECURITY.md: --------------------------------------------------------------------------------

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /assets/example_image/T.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/T.png -------------------------------------------------------------------------------- /assets/example_image/typical_building_building.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_building_building.png -------------------------------------------------------------------------------- /assets/example_image/typical_building_castle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_building_castle.png -------------------------------------------------------------------------------- /assets/example_image/typical_building_colorful_cottage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_building_colorful_cottage.png -------------------------------------------------------------------------------- /assets/example_image/typical_building_maya_pyramid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_building_maya_pyramid.png -------------------------------------------------------------------------------- /assets/example_image/typical_building_mushroom.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_building_mushroom.png -------------------------------------------------------------------------------- /assets/example_image/typical_building_space_station.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_building_space_station.png -------------------------------------------------------------------------------- /assets/example_image/typical_creature_dragon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_creature_dragon.png -------------------------------------------------------------------------------- /assets/example_image/typical_creature_elephant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_creature_elephant.png -------------------------------------------------------------------------------- /assets/example_image/typical_creature_furry.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_creature_furry.png -------------------------------------------------------------------------------- /assets/example_image/typical_creature_quadruped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_creature_quadruped.png -------------------------------------------------------------------------------- /assets/example_image/typical_creature_robot_crab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_creature_robot_crab.png -------------------------------------------------------------------------------- /assets/example_image/typical_creature_robot_dinosour.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_creature_robot_dinosour.png -------------------------------------------------------------------------------- /assets/example_image/typical_creature_rock_monster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_creature_rock_monster.png -------------------------------------------------------------------------------- /assets/example_image/typical_humanoid_block_robot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_humanoid_block_robot.png 
-------------------------------------------------------------------------------- /assets/example_image/typical_humanoid_dragonborn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_humanoid_dragonborn.png -------------------------------------------------------------------------------- /assets/example_image/typical_humanoid_dwarf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_humanoid_dwarf.png -------------------------------------------------------------------------------- /assets/example_image/typical_humanoid_goblin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_humanoid_goblin.png -------------------------------------------------------------------------------- /assets/example_image/typical_humanoid_mech.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_humanoid_mech.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_crate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_crate.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_fireplace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_fireplace.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_gate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_gate.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_lantern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_lantern.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_magicbook.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_magicbook.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_mailbox.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_mailbox.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_monster_chest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_monster_chest.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_paper_machine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_paper_machine.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_phonograph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_phonograph.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_portal2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_portal2.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_storage_chest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_storage_chest.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_telephone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_telephone.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_television.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_television.png -------------------------------------------------------------------------------- /assets/example_image/typical_misc_workbench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_misc_workbench.png -------------------------------------------------------------------------------- /assets/example_image/typical_vehicle_biplane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_vehicle_biplane.png -------------------------------------------------------------------------------- 
/assets/example_image/typical_vehicle_bulldozer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_vehicle_bulldozer.png -------------------------------------------------------------------------------- /assets/example_image/typical_vehicle_cart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_vehicle_cart.png -------------------------------------------------------------------------------- /assets/example_image/typical_vehicle_excavator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_vehicle_excavator.png -------------------------------------------------------------------------------- /assets/example_image/typical_vehicle_helicopter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_vehicle_helicopter.png -------------------------------------------------------------------------------- /assets/example_image/typical_vehicle_locomotive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_vehicle_locomotive.png -------------------------------------------------------------------------------- /assets/example_image/typical_vehicle_pirate_ship.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/typical_vehicle_pirate_ship.png -------------------------------------------------------------------------------- /assets/example_image/weatherworn_misc_paper_machine3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_image/weatherworn_misc_paper_machine3.png -------------------------------------------------------------------------------- /assets/example_multi_image/character_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/character_1.png -------------------------------------------------------------------------------- /assets/example_multi_image/character_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/character_2.png -------------------------------------------------------------------------------- /assets/example_multi_image/character_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/character_3.png -------------------------------------------------------------------------------- /assets/example_multi_image/mushroom_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/mushroom_1.png -------------------------------------------------------------------------------- /assets/example_multi_image/mushroom_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/mushroom_2.png -------------------------------------------------------------------------------- /assets/example_multi_image/mushroom_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/mushroom_3.png -------------------------------------------------------------------------------- /assets/example_multi_image/orangeguy_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/orangeguy_1.png -------------------------------------------------------------------------------- /assets/example_multi_image/orangeguy_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/orangeguy_2.png -------------------------------------------------------------------------------- /assets/example_multi_image/orangeguy_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/orangeguy_3.png -------------------------------------------------------------------------------- /assets/example_multi_image/popmart_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/popmart_1.png -------------------------------------------------------------------------------- /assets/example_multi_image/popmart_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/popmart_2.png -------------------------------------------------------------------------------- /assets/example_multi_image/popmart_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/popmart_3.png -------------------------------------------------------------------------------- /assets/example_multi_image/rabbit_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/rabbit_1.png -------------------------------------------------------------------------------- /assets/example_multi_image/rabbit_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/rabbit_2.png -------------------------------------------------------------------------------- /assets/example_multi_image/rabbit_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/rabbit_3.png -------------------------------------------------------------------------------- /assets/example_multi_image/tiger_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/tiger_1.png -------------------------------------------------------------------------------- /assets/example_multi_image/tiger_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/tiger_2.png -------------------------------------------------------------------------------- /assets/example_multi_image/tiger_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/tiger_3.png -------------------------------------------------------------------------------- /assets/example_multi_image/yoimiya_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/yoimiya_1.png -------------------------------------------------------------------------------- /assets/example_multi_image/yoimiya_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/yoimiya_2.png -------------------------------------------------------------------------------- /assets/example_multi_image/yoimiya_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/example_multi_image/yoimiya_3.png -------------------------------------------------------------------------------- /assets/logo.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/logo.webp -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/assets/teaser.png 
-------------------------------------------------------------------------------- /dataset_toolkits/blender_script/io_scene_usdz.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/dataset_toolkits/blender_script/io_scene_usdz.zip -------------------------------------------------------------------------------- /dataset_toolkits/datasets/3D-FUTURE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import zipfile 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | import pandas as pd 8 | from utils import get_file_hash 9 | 10 | 11 | def add_args(parser: argparse.ArgumentParser): 12 | pass 13 | 14 | 15 | def get_metadata(**kwargs): 16 | metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv") 17 | return metadata 18 | 19 | 20 | def download(metadata, output_dir, **kwargs): 21 | os.makedirs(output_dir, exist_ok=True) 22 | 23 | if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')): 24 | print("\033[93m") 25 | print("3D-FUTURE have to be downloaded manually") 26 | print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory") 27 | print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information") 28 | print("\033[0m") 29 | raise FileNotFoundError("3D-FUTURE-model.zip not found") 30 | 31 | downloaded = {} 32 | metadata = metadata.set_index("file_identifier") 33 | with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref: 34 | all_names = zip_ref.namelist() 35 | instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)] 36 | instances = list(filter(lambda x: x in metadata.index, instances)) 37 | 38 | with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ 39 | tqdm(total=len(instances), desc="Extracting") as pbar: 40 | def worker(instance: str) -> str: 41 | try: 42 | instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names)) 43 | zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files) 44 | sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg")) 45 | pbar.update() 46 | return sha256 47 | except Exception as e: 48 | pbar.update() 49 | print(f"Error extracting for {instance}: {e}") 50 | return None 51 | 52 | sha256s = executor.map(worker, instances) 53 | executor.shutdown(wait=True) 54 | 55 | for k, sha256 in zip(instances, sha256s): 56 | if sha256 is not None: 57 | if sha256 == metadata.loc[k, "sha256"]: 58 | downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj") 59 | else: 60 | print(f"Error downloading {k}: sha256s do not match") 61 | 62 | return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) 63 | 64 | 65 | def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: 66 | import os 67 | from concurrent.futures import ThreadPoolExecutor 68 | from tqdm import tqdm 69 | 70 | # load metadata 71 | metadata = metadata.to_dict('records') 72 | 73 | # processing objects 74 | records = [] 75 | max_workers = max_workers or os.cpu_count() 76 | try: 77 | with ThreadPoolExecutor(max_workers=max_workers) as executor, \ 78 | tqdm(total=len(metadata), desc=desc) as pbar: 79 | def 
worker(metadatum): 80 | try: 81 | local_path = metadatum['local_path'] 82 | sha256 = metadatum['sha256'] 83 | file = os.path.join(output_dir, local_path) 84 | record = func(file, sha256) 85 | if record is not None: 86 | records.append(record) 87 | pbar.update() 88 | except Exception as e: 89 | print(f"Error processing object {sha256}: {e}") 90 | pbar.update() 91 | 92 | executor.map(worker, metadata) 93 | executor.shutdown(wait=True) 94 | except: 95 | print("Error happened during processing.") 96 | 97 | return pd.DataFrame.from_records(records) 98 | -------------------------------------------------------------------------------- /dataset_toolkits/datasets/ABO.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import tarfile 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | import pandas as pd 8 | from utils import get_file_hash 9 | 10 | 11 | def add_args(parser: argparse.ArgumentParser): 12 | pass 13 | 14 | 15 | def get_metadata(**kwargs): 16 | metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv") 17 | return metadata 18 | 19 | 20 | def download(metadata, output_dir, **kwargs): 21 | os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) 22 | 23 | if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')): 24 | try: 25 | os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) 26 | os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar") 27 | except: 28 | print("\033[93m") 29 | print("Error downloading ABO dataset. Please check your internet connection and try again.") 30 | print("Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory") 31 | print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information") 32 | print("\033[0m") 33 | raise FileNotFoundError("Error downloading ABO dataset") 34 | 35 | downloaded = {} 36 | metadata = metadata.set_index("file_identifier") 37 | with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar: 38 | with ThreadPoolExecutor(max_workers=1) as executor, \ 39 | tqdm(total=len(metadata), desc="Extracting") as pbar: 40 | def worker(instance: str) -> str: 41 | try: 42 | tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw')) 43 | sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance)) 44 | pbar.update() 45 | return sha256 46 | except Exception as e: 47 | pbar.update() 48 | print(f"Error extracting for {instance}: {e}") 49 | return None 50 | 51 | sha256s = executor.map(worker, metadata.index) 52 | executor.shutdown(wait=True) 53 | 54 | for k, sha256 in zip(metadata.index, sha256s): 55 | if sha256 is not None: 56 | if sha256 == metadata.loc[k, "sha256"]: 57 | downloaded[sha256] = os.path.join('raw/3dmodels/original', k) 58 | else: 59 | print(f"Error downloading {k}: sha256s do not match") 60 | 61 | return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) 62 | 63 | 64 | def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: 65 | import os 66 | from concurrent.futures import ThreadPoolExecutor 67 | from tqdm import tqdm 68 | 69 | # load metadata 70 | metadata = metadata.to_dict('records') 71 | 72 | # processing objects 73 | records = [] 74 | max_workers = max_workers or os.cpu_count() 75 | try: 76 | with 
ThreadPoolExecutor(max_workers=max_workers) as executor, \ 77 | tqdm(total=len(metadata), desc=desc) as pbar: 78 | def worker(metadatum): 79 | try: 80 | local_path = metadatum['local_path'] 81 | sha256 = metadatum['sha256'] 82 | file = os.path.join(output_dir, local_path) 83 | record = func(file, sha256) 84 | if record is not None: 85 | records.append(record) 86 | pbar.update() 87 | except Exception as e: 88 | print(f"Error processing object {sha256}: {e}") 89 | pbar.update() 90 | 91 | executor.map(worker, metadata) 92 | executor.shutdown(wait=True) 93 | except: 94 | print("Error happened during processing.") 95 | 96 | return pd.DataFrame.from_records(records) 97 | -------------------------------------------------------------------------------- /dataset_toolkits/datasets/HSSD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import tarfile 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import huggingface_hub 9 | from utils import get_file_hash 10 | 11 | 12 | def add_args(parser: argparse.ArgumentParser): 13 | pass 14 | 15 | 16 | def get_metadata(**kwargs): 17 | metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv") 18 | return metadata 19 | 20 | 21 | def download(metadata, output_dir, **kwargs): 22 | os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) 23 | 24 | # check login 25 | try: 26 | huggingface_hub.whoami() 27 | except: 28 | print("\033[93m") 29 | print("Haven't logged in to the Hugging Face Hub.") 30 | print("Visit https://huggingface.co/settings/tokens to get a token.") 31 | print("\033[0m") 32 | huggingface_hub.login() 33 | 34 | try: 35 | huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset") 36 | except: 37 | print("\033[93m") 38 | print("Error downloading HSSD dataset.") 39 | print("Check if you have access to the HSSD dataset.") 40 | print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information") 41 | print("\033[0m") 42 | 43 | downloaded = {} 44 | metadata = metadata.set_index("file_identifier") 45 | with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ 46 | tqdm(total=len(metadata), desc="Downloading") as pbar: 47 | def worker(instance: str) -> str: 48 | try: 49 | huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw')) 50 | sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance)) 51 | pbar.update() 52 | return sha256 53 | except Exception as e: 54 | pbar.update() 55 | print(f"Error extracting for {instance}: {e}") 56 | return None 57 | 58 | sha256s = executor.map(worker, metadata.index) 59 | executor.shutdown(wait=True) 60 | 61 | for k, sha256 in zip(metadata.index, sha256s): 62 | if sha256 is not None: 63 | if sha256 == metadata.loc[k, "sha256"]: 64 | downloaded[sha256] = os.path.join('raw', k) 65 | else: 66 | print(f"Error downloading {k}: sha256s do not match") 67 | 68 | return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) 69 | 70 | 71 | def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: 72 | import os 73 | from concurrent.futures import ThreadPoolExecutor 74 | from tqdm import tqdm 75 | 76 | # load metadata 77 | metadata = metadata.to_dict('records') 78 | 79 | # processing objects 80 | records = [] 81 | max_workers = max_workers or 
os.cpu_count() 82 | try: 83 | with ThreadPoolExecutor(max_workers=max_workers) as executor, \ 84 | tqdm(total=len(metadata), desc=desc) as pbar: 85 | def worker(metadatum): 86 | try: 87 | local_path = metadatum['local_path'] 88 | sha256 = metadatum['sha256'] 89 | file = os.path.join(output_dir, local_path) 90 | record = func(file, sha256) 91 | if record is not None: 92 | records.append(record) 93 | pbar.update() 94 | except Exception as e: 95 | print(f"Error processing object {sha256}: {e}") 96 | pbar.update() 97 | 98 | executor.map(worker, metadata) 99 | executor.shutdown(wait=True) 100 | except: 101 | print("Error happened during processing.") 102 | 103 | return pd.DataFrame.from_records(records) 104 | -------------------------------------------------------------------------------- /dataset_toolkits/datasets/ObjaverseXL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from concurrent.futures import ThreadPoolExecutor 4 | from tqdm import tqdm 5 | import pandas as pd 6 | import objaverse.xl as oxl 7 | from utils import get_file_hash 8 | 9 | 10 | def add_args(parser: argparse.ArgumentParser): 11 | parser.add_argument('--source', type=str, default='sketchfab', 12 | help='Data source to download annotations from (github, sketchfab)') 13 | 14 | 15 | def get_metadata(source, **kwargs): 16 | if source == 'sketchfab': 17 | metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv") 18 | elif source == 'github': 19 | metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv") 20 | else: 21 | raise ValueError(f"Invalid source: {source}") 22 | return metadata 23 | 24 | 25 | def download(metadata, output_dir, **kwargs): 26 | os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) 27 | 28 | # download annotations 29 | annotations = oxl.get_annotations() 30 | annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)] 31 | 32 | # download and render objects 33 | file_paths = oxl.download_objects( 34 | annotations, 35 | download_dir=os.path.join(output_dir, "raw"), 36 | save_repo_format="zip", 37 | ) 38 | 39 | downloaded = {} 40 | metadata = metadata.set_index("file_identifier") 41 | for k, v in file_paths.items(): 42 | sha256 = metadata.loc[k, "sha256"] 43 | downloaded[sha256] = os.path.relpath(v, output_dir) 44 | 45 | return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) 46 | 47 | 48 | def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: 49 | import os 50 | from concurrent.futures import ThreadPoolExecutor 51 | from tqdm import tqdm 52 | import tempfile 53 | import zipfile 54 | 55 | # load metadata 56 | metadata = metadata.to_dict('records') 57 | 58 | # processing objects 59 | records = [] 60 | max_workers = max_workers or os.cpu_count() 61 | try: 62 | with ThreadPoolExecutor(max_workers=max_workers) as executor, \ 63 | tqdm(total=len(metadata), desc=desc) as pbar: 64 | def worker(metadatum): 65 | try: 66 | local_path = metadatum['local_path'] 67 | sha256 = metadatum['sha256'] 68 | if local_path.startswith('raw/github/repos/'): 69 | path_parts = local_path.split('/') 70 | file_name = os.path.join(*path_parts[5:]) 71 | zip_file = os.path.join(output_dir, *path_parts[:5]) 72 | with tempfile.TemporaryDirectory() as tmp_dir: 73 | with zipfile.ZipFile(zip_file, 'r') as zip_ref: 74 | zip_ref.extractall(tmp_dir) 75 | file = os.path.join(tmp_dir, file_name) 76 | record = 
func(file, sha256) 77 | else: 78 | file = os.path.join(output_dir, local_path) 79 | record = func(file, sha256) 80 | if record is not None: 81 | records.append(record) 82 | pbar.update() 83 | except Exception as e: 84 | print(f"Error processing object {sha256}: {e}") 85 | pbar.update() 86 | 87 | executor.map(worker, metadata) 88 | executor.shutdown(wait=True) 89 | except: 90 | print("Error happened during processing.") 91 | 92 | return pd.DataFrame.from_records(records) 93 | -------------------------------------------------------------------------------- /dataset_toolkits/datasets/Toys4k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import zipfile 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | import pandas as pd 8 | from utils import get_file_hash 9 | 10 | 11 | def add_args(parser: argparse.ArgumentParser): 12 | pass 13 | 14 | 15 | def get_metadata(**kwargs): 16 | metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv") 17 | return metadata 18 | 19 | 20 | def download(metadata, output_dir, **kwargs): 21 | os.makedirs(output_dir, exist_ok=True) 22 | 23 | if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')): 24 | print("\033[93m") 25 | print("Toys4k have to be downloaded manually") 26 | print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory") 27 | print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information") 28 | print("\033[0m") 29 | raise FileNotFoundError("toys4k_blend_files.zip not found") 30 | 31 | downloaded = {} 32 | metadata = metadata.set_index("file_identifier") 33 | with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref: 34 | with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ 35 | tqdm(total=len(metadata), desc="Extracting") as pbar: 36 | def worker(instance: str) -> str: 37 | try: 38 | zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw')) 39 | sha256 = get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance)) 40 | pbar.update() 41 | return sha256 42 | except Exception as e: 43 | pbar.update() 44 | print(f"Error extracting for {instance}: {e}") 45 | return None 46 | 47 | sha256s = executor.map(worker, metadata.index) 48 | executor.shutdown(wait=True) 49 | 50 | for k, sha256 in zip(metadata.index, sha256s): 51 | if sha256 is not None: 52 | if sha256 == metadata.loc[k, "sha256"]: 53 | downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k) 54 | else: 55 | print(f"Error downloading {k}: sha256s do not match") 56 | 57 | return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) 58 | 59 | 60 | def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: 61 | import os 62 | from concurrent.futures import ThreadPoolExecutor 63 | from tqdm import tqdm 64 | 65 | # load metadata 66 | metadata = metadata.to_dict('records') 67 | 68 | # processing objects 69 | records = [] 70 | max_workers = max_workers or os.cpu_count() 71 | try: 72 | with ThreadPoolExecutor(max_workers=max_workers) as executor, \ 73 | tqdm(total=len(metadata), desc=desc) as pbar: 74 | def worker(metadatum): 75 | try: 76 | local_path = metadatum['local_path'] 77 | sha256 = metadatum['sha256'] 78 | file = os.path.join(output_dir, local_path) 79 | record = func(file, sha256) 80 | if record 
is not None: 81 | records.append(record) 82 | pbar.update() 83 | except Exception as e: 84 | print(f"Error processing object {sha256}: {e}") 85 | pbar.update() 86 | 87 | executor.map(worker, metadata) 88 | executor.shutdown(wait=True) 89 | except: 90 | print("Error happened during processing.") 91 | 92 | return pd.DataFrame.from_records(records) 93 | -------------------------------------------------------------------------------- /dataset_toolkits/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import sys 4 | import importlib 5 | import argparse 6 | import pandas as pd 7 | from easydict import EasyDict as edict 8 | 9 | if __name__ == '__main__': 10 | dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--output_dir', type=str, required=True, 14 | help='Directory to save the metadata') 15 | parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, 16 | help='Filter objects with aesthetic score lower than this value') 17 | parser.add_argument('--instances', type=str, default=None, 18 | help='Instances to process') 19 | dataset_utils.add_args(parser) 20 | parser.add_argument('--rank', type=int, default=0) 21 | parser.add_argument('--world_size', type=int, default=1) 22 | opt = parser.parse_args(sys.argv[2:]) 23 | opt = edict(vars(opt)) 24 | 25 | os.makedirs(opt.output_dir, exist_ok=True) 26 | 27 | # get file list 28 | if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): 29 | raise ValueError('metadata.csv not found') 30 | metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) 31 | if opt.instances is None: 32 | if opt.filter_low_aesthetic_score is not None: 33 | metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] 34 | if 'local_path' in metadata.columns: 35 | metadata = metadata[metadata['local_path'].isna()] 36 | else: 37 | if os.path.exists(opt.instances): 38 | with open(opt.instances, 'r') as f: 39 | instances = f.read().splitlines() 40 | else: 41 | instances = opt.instances.split(',') 42 | metadata = metadata[metadata['sha256'].isin(instances)] 43 | 44 | start = len(metadata) * opt.rank // opt.world_size 45 | end = len(metadata) * (opt.rank + 1) // opt.world_size 46 | metadata = metadata[start:end] 47 | 48 | print(f'Processing {len(metadata)} objects...') 49 | 50 | # process objects 51 | downloaded = dataset_utils.download(metadata, **opt) 52 | downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False) 53 | -------------------------------------------------------------------------------- /dataset_toolkits/render.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import sys 5 | import importlib 6 | import argparse 7 | import pandas as pd 8 | from easydict import EasyDict as edict 9 | from functools import partial 10 | from subprocess import DEVNULL, call 11 | import numpy as np 12 | from utils import sphere_hammersley_sequence 13 | 14 | 15 | BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' 16 | BLENDER_INSTALLATION_PATH = '/tmp' 17 | BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' 18 | 19 | def _install_blender(): 20 | if not os.path.exists(BLENDER_PATH): 21 | os.system('sudo apt-get update') 22 | os.system('sudo apt-get install -y libxrender1 libxi6 
libxkbcommon-x11-0 libsm6') 23 | os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') 24 | os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}') 25 | 26 | 27 | def _render(file_path, sha256, output_dir, num_views): 28 | output_folder = os.path.join(output_dir, 'renders', sha256) 29 | 30 | # Build camera {yaw, pitch, radius, fov} 31 | yaws = [] 32 | pitchs = [] 33 | offset = (np.random.rand(), np.random.rand()) 34 | for i in range(num_views): 35 | y, p = sphere_hammersley_sequence(i, num_views, offset) 36 | yaws.append(y) 37 | pitchs.append(p) 38 | radius = [2] * num_views 39 | fov = [40 / 180 * np.pi] * num_views 40 | views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] 41 | 42 | args = [ 43 | BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), 44 | '--', 45 | '--views', json.dumps(views), 46 | '--object', os.path.expanduser(file_path), 47 | '--resolution', '512', 48 | '--output_folder', output_folder, 49 | '--engine', 'CYCLES', 50 | '--save_mesh', 51 | ] 52 | if file_path.endswith('.blend'): 53 | args.insert(1, file_path) 54 | 55 | call(args, stdout=DEVNULL, stderr=DEVNULL) 56 | 57 | if os.path.exists(os.path.join(output_folder, 'transforms.json')): 58 | return {'sha256': sha256, 'rendered': True} 59 | 60 | 61 | if __name__ == '__main__': 62 | dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--output_dir', type=str, required=True, 66 | help='Directory to save the metadata') 67 | parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, 68 | help='Filter objects with aesthetic score lower than this value') 69 | parser.add_argument('--instances', type=str, default=None, 70 | help='Instances to process') 71 | parser.add_argument('--num_views', type=int, default=150, 72 | help='Number of views to render') 73 | dataset_utils.add_args(parser) 74 | parser.add_argument('--rank', type=int, default=0) 75 | parser.add_argument('--world_size', type=int, default=1) 76 | parser.add_argument('--max_workers', type=int, default=8) 77 | opt = parser.parse_args(sys.argv[2:]) 78 | opt = edict(vars(opt)) 79 | 80 | os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True) 81 | 82 | # install blender 83 | print('Checking blender...', flush=True) 84 | _install_blender() 85 | 86 | # get file list 87 | if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): 88 | raise ValueError('metadata.csv not found') 89 | metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) 90 | if opt.instances is None: 91 | metadata = metadata[metadata['local_path'].notna()] 92 | if opt.filter_low_aesthetic_score is not None: 93 | metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] 94 | if 'rendered' in metadata.columns: 95 | metadata = metadata[metadata['rendered'] == False] 96 | else: 97 | if os.path.exists(opt.instances): 98 | with open(opt.instances, 'r') as f: 99 | instances = f.read().splitlines() 100 | else: 101 | instances = opt.instances.split(',') 102 | metadata = metadata[metadata['sha256'].isin(instances)] 103 | 104 | start = len(metadata) * opt.rank // opt.world_size 105 | end = len(metadata) * (opt.rank + 1) // opt.world_size 106 | metadata = metadata[start:end] 107 | records = [] 108 | 109 | # filter out objects that are already processed 110 | for sha256 in 
copy.copy(metadata['sha256'].values): 111 | if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')): 112 | records.append({'sha256': sha256, 'rendered': True}) 113 | metadata = metadata[metadata['sha256'] != sha256] 114 | 115 | print(f'Processing {len(metadata)} objects...') 116 | 117 | # process objects 118 | func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views) 119 | rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') 120 | rendered = pd.concat([rendered, pd.DataFrame.from_records(records)]) 121 | rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False) 122 | -------------------------------------------------------------------------------- /dataset_toolkits/render_cond.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import sys 5 | import importlib 6 | import argparse 7 | import pandas as pd 8 | from easydict import EasyDict as edict 9 | from functools import partial 10 | from subprocess import DEVNULL, call 11 | import numpy as np 12 | from utils import sphere_hammersley_sequence 13 | 14 | 15 | BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' 16 | BLENDER_INSTALLATION_PATH = '/tmp' 17 | BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' 18 | 19 | def _install_blender(): 20 | if not os.path.exists(BLENDER_PATH): 21 | os.system('sudo apt-get update') 22 | os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6') 23 | os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') 24 | os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}') 25 | 26 | 27 | def _render_cond(file_path, sha256, output_dir, num_views): 28 | output_folder = os.path.join(output_dir, 'renders_cond', sha256) 29 | 30 | # Build camera {yaw, pitch, radius, fov} 31 | yaws = [] 32 | pitchs = [] 33 | offset = (np.random.rand(), np.random.rand()) 34 | for i in range(num_views): 35 | y, p = sphere_hammersley_sequence(i, num_views, offset) 36 | yaws.append(y) 37 | pitchs.append(p) 38 | fov_min, fov_max = 10, 70 39 | radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi) 40 | radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi) 41 | k_min = 1 / radius_max**2 42 | k_max = 1 / radius_min**2 43 | ks = np.random.uniform(k_min, k_max, (1000000,)) 44 | radius = [1 / np.sqrt(k) for k in ks] 45 | fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius] 46 | views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] 47 | 48 | args = [ 49 | BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), 50 | '--', 51 | '--views', json.dumps(views), 52 | '--object', os.path.expanduser(file_path), 53 | '--output_folder', os.path.expanduser(output_folder), 54 | '--resolution', '1024', 55 | ] 56 | if file_path.endswith('.blend'): 57 | args.insert(1, file_path) 58 | 59 | call(args, stdout=DEVNULL) 60 | 61 | if os.path.exists(os.path.join(output_folder, 'transforms.json')): 62 | return {'sha256': sha256, 'cond_rendered': True} 63 | 64 | 65 | if __name__ == '__main__': 66 | dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--output_dir', type=str, required=True, 70 | 
help='Directory to save the metadata') 71 | parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, 72 | help='Filter objects with aesthetic score lower than this value') 73 | parser.add_argument('--instances', type=str, default=None, 74 | help='Instances to process') 75 | parser.add_argument('--num_views', type=int, default=24, 76 | help='Number of views to render') 77 | dataset_utils.add_args(parser) 78 | parser.add_argument('--rank', type=int, default=0) 79 | parser.add_argument('--world_size', type=int, default=1) 80 | parser.add_argument('--max_workers', type=int, default=8) 81 | opt = parser.parse_args(sys.argv[2:]) 82 | opt = edict(vars(opt)) 83 | 84 | os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True) 85 | 86 | # install blender 87 | print('Checking blender...', flush=True) 88 | _install_blender() 89 | 90 | # get file list 91 | if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): 92 | raise ValueError('metadata.csv not found') 93 | metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) 94 | if opt.instances is None: 95 | metadata = metadata[metadata['local_path'].notna()] 96 | if opt.filter_low_aesthetic_score is not None: 97 | metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] 98 | if 'cond_rendered' in metadata.columns: 99 | metadata = metadata[metadata['cond_rendered'] == False] 100 | else: 101 | if os.path.exists(opt.instances): 102 | with open(opt.instances, 'r') as f: 103 | instances = f.read().splitlines() 104 | else: 105 | instances = opt.instances.split(',') 106 | metadata = metadata[metadata['sha256'].isin(instances)] 107 | 108 | start = len(metadata) * opt.rank // opt.world_size 109 | end = len(metadata) * (opt.rank + 1) // opt.world_size 110 | metadata = metadata[start:end] 111 | records = [] 112 | 113 | # filter out objects that are already processed 114 | for sha256 in copy.copy(metadata['sha256'].values): 115 | if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')): 116 | records.append({'sha256': sha256, 'cond_rendered': True}) 117 | metadata = metadata[metadata['sha256'] != sha256] 118 | 119 | print(f'Processing {len(metadata)} objects...') 120 | 121 | # process objects 122 | func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views) 123 | cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') 124 | cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)]) 125 | cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False) 126 | -------------------------------------------------------------------------------- /dataset_toolkits/setup.sh: -------------------------------------------------------------------------------- 1 | pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub 2 | -------------------------------------------------------------------------------- /dataset_toolkits/stat_latent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from easydict import EasyDict as edict 8 | from concurrent.futures import ThreadPoolExecutor 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--output_dir', 
type=str, required=True, 14 | help='Directory to save the metadata') 15 | parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, 16 | help='Filter objects with aesthetic score lower than this value') 17 | parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16', 18 | help='Latent model to use') 19 | parser.add_argument('--num_samples', type=int, default=50000, 20 | help='Number of samples to use for calculating stats') 21 | opt = parser.parse_args() 22 | opt = edict(vars(opt)) 23 | 24 | # get file list 25 | if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): 26 | metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) 27 | else: 28 | raise ValueError('metadata.csv not found') 29 | if opt.filter_low_aesthetic_score is not None: 30 | metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] 31 | metadata = metadata[metadata[f'latent_{opt.model}'] == True] 32 | sha256s = metadata['sha256'].values 33 | sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False) 34 | 35 | # stats 36 | means = [] 37 | mean2s = [] 38 | with ThreadPoolExecutor(max_workers=16) as executor, \ 39 | tqdm(total=len(sha256s), desc="Extracting features") as pbar: 40 | def worker(sha256): 41 | try: 42 | feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz')) 43 | feats = feats['feats'] 44 | means.append(feats.mean(axis=0)) 45 | mean2s.append((feats ** 2).mean(axis=0)) 46 | pbar.update() 47 | except Exception as e: 48 | print(f"Error extracting features for {sha256}: {e}") 49 | pbar.update() 50 | 51 | executor.map(worker, sha256s) 52 | executor.shutdown(wait=True) 53 | 54 | mean = np.array(means).mean(axis=0) 55 | mean2 = np.array(mean2s).mean(axis=0) 56 | std = np.sqrt(mean2 - mean ** 2) 57 | 58 | print('mean:', mean) 59 | print('std:', std) 60 | 61 | with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f: 62 | json.dump({ 63 | 'mean': mean.tolist(), 64 | 'std': std.tolist(), 65 | }, f, indent=4) 66 | -------------------------------------------------------------------------------- /dataset_toolkits/utils.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import hashlib 3 | import numpy as np 4 | 5 | 6 | def get_file_hash(file: str) -> str: 7 | sha256 = hashlib.sha256() 8 | # Read the file from the path 9 | with open(file, "rb") as f: 10 | # Update the hash with the file content 11 | for byte_block in iter(lambda: f.read(4096), b""): 12 | sha256.update(byte_block) 13 | return sha256.hexdigest() 14 | 15 | # ===============LOW DISCREPANCY SEQUENCES================ 16 | 17 | PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] 18 | 19 | def radical_inverse(base, n): 20 | val = 0 21 | inv_base = 1.0 / base 22 | inv_base_n = inv_base 23 | while n > 0: 24 | digit = n % base 25 | val += digit * inv_base_n 26 | n //= base 27 | inv_base_n *= inv_base 28 | return val 29 | 30 | def halton_sequence(dim, n): 31 | return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] 32 | 33 | def hammersley_sequence(dim, n, num_samples): 34 | return [n / num_samples] + halton_sequence(dim - 1, n) 35 | 36 | def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)): 37 | u, v = hammersley_sequence(2, n, num_samples) 38 | u += offset[0] / num_samples 39 | v += offset[1] 40 | u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 41 | theta = np.arccos(1 - 2 * u) - np.pi / 2 42 
| phi = v * 2 * np.pi 43 | return [phi, theta] 44 | -------------------------------------------------------------------------------- /dataset_toolkits/voxelize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import sys 4 | import importlib 5 | import argparse 6 | import pandas as pd 7 | from easydict import EasyDict as edict 8 | from functools import partial 9 | import numpy as np 10 | import open3d as o3d 11 | import utils3d 12 | 13 | 14 | def _voxelize(file, sha256, output_dir): 15 | mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply')) 16 | # clamp vertices to the range [-0.5, 0.5] 17 | vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6) 18 | mesh.vertices = o3d.utility.Vector3dVector(vertices) 19 | voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5)) 20 | vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()]) 21 | assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds" 22 | vertices = (vertices + 0.5) / 64 - 0.5 23 | utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices) 24 | return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)} 25 | 26 | 27 | if __name__ == '__main__': 28 | dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--output_dir', type=str, required=True, 32 | help='Directory to save the metadata') 33 | parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, 34 | help='Filter objects with aesthetic score lower than this value') 35 | parser.add_argument('--instances', type=str, default=None, 36 | help='Instances to process') 37 | parser.add_argument('--num_views', type=int, default=150, 38 | help='Number of views to render') 39 | dataset_utils.add_args(parser) 40 | parser.add_argument('--rank', type=int, default=0) 41 | parser.add_argument('--world_size', type=int, default=1) 42 | parser.add_argument('--max_workers', type=int, default=None) 43 | opt = parser.parse_args(sys.argv[2:]) 44 | opt = edict(vars(opt)) 45 | 46 | os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True) 47 | 48 | # get file list 49 | if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): 50 | raise ValueError('metadata.csv not found') 51 | metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) 52 | if opt.instances is None: 53 | if opt.filter_low_aesthetic_score is not None: 54 | metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] 55 | if 'rendered' not in metadata.columns: 56 | raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first') 57 | metadata = metadata[metadata['rendered'] == True] 58 | if 'voxelized' in metadata.columns: 59 | metadata = metadata[metadata['voxelized'] == False] 60 | else: 61 | if os.path.exists(opt.instances): 62 | with open(opt.instances, 'r') as f: 63 | instances = f.read().splitlines() 64 | else: 65 | instances = opt.instances.split(',') 66 | metadata = metadata[metadata['sha256'].isin(instances)] 67 | 68 | start = len(metadata) * opt.rank // opt.world_size 69 | end = len(metadata) * (opt.rank + 1) // opt.world_size 70 | metadata = metadata[start:end] 71 | records = [] 72 | 73 | # filter out objects that are already processed 74 | 
for sha256 in copy.copy(metadata['sha256'].values): 75 | if os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')): 76 | pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0] 77 | records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)}) 78 | metadata = metadata[metadata['sha256'] != sha256] 79 | 80 | print(f'Processing {len(metadata)} objects...') 81 | 82 | # process objects 83 | func = partial(_voxelize, output_dir=opt.output_dir) 84 | voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing') 85 | voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)]) 86 | voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False) 87 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn' 3 | os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'. 4 | # 'auto' is faster but will do benchmarking at the beginning. 5 | # Recommended to set to 'native' if run only once. 6 | 7 | import imageio 8 | from PIL import Image 9 | from trellis.pipelines import TrellisImageTo3DPipeline 10 | from trellis.utils import render_utils, postprocessing_utils 11 | 12 | # Load a pipeline from a model folder or a Hugging Face model hub. 13 | pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") 14 | pipeline.cuda() 15 | 16 | # Load an image 17 | image = Image.open("assets/example_image/T.png") 18 | 19 | # Run the pipeline 20 | outputs = pipeline.run( 21 | image, 22 | seed=1, 23 | # Optional parameters 24 | # sparse_structure_sampler_params={ 25 | # "steps": 12, 26 | # "cfg_strength": 7.5, 27 | # }, 28 | # slat_sampler_params={ 29 | # "steps": 12, 30 | # "cfg_strength": 3, 31 | # }, 32 | ) 33 | # outputs is a dictionary containing generated 3D assets in different formats: 34 | # - outputs['gaussian']: a list of 3D Gaussians 35 | # - outputs['radiance_field']: a list of radiance fields 36 | # - outputs['mesh']: a list of meshes 37 | 38 | # Render the outputs 39 | video = render_utils.render_video(outputs['gaussian'][0])['color'] 40 | imageio.mimsave("sample_gs.mp4", video, fps=30) 41 | video = render_utils.render_video(outputs['radiance_field'][0])['color'] 42 | imageio.mimsave("sample_rf.mp4", video, fps=30) 43 | video = render_utils.render_video(outputs['mesh'][0])['normal'] 44 | imageio.mimsave("sample_mesh.mp4", video, fps=30) 45 | 46 | # GLB files can be extracted from the outputs 47 | glb = postprocessing_utils.to_glb( 48 | outputs['gaussian'][0], 49 | outputs['mesh'][0], 50 | # Optional parameters 51 | simplify=0.95, # Ratio of triangles to remove in the simplification process 52 | texture_size=1024, # Size of the texture used for the GLB 53 | ) 54 | glb.export("sample.glb") 55 | 56 | # Save Gaussians as PLY files 57 | outputs['gaussian'][0].save_ply("sample.ply") 58 | -------------------------------------------------------------------------------- /example_multi_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn' 3 | os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default 
is 'auto'. 4 | # 'auto' is faster but will do benchmarking at the beginning. 5 | # Recommended to set to 'native' if run only once. 6 | 7 | import numpy as np 8 | import imageio 9 | from PIL import Image 10 | from trellis.pipelines import TrellisImageTo3DPipeline 11 | from trellis.utils import render_utils 12 | 13 | # Load a pipeline from a model folder or a Hugging Face model hub. 14 | pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") 15 | pipeline.cuda() 16 | 17 | # Load an image 18 | images = [ 19 | Image.open("assets/example_multi_image/character_1.png"), 20 | Image.open("assets/example_multi_image/character_2.png"), 21 | Image.open("assets/example_multi_image/character_3.png"), 22 | ] 23 | 24 | # Run the pipeline 25 | outputs = pipeline.run_multi_image( 26 | images, 27 | seed=1, 28 | # Optional parameters 29 | sparse_structure_sampler_params={ 30 | "steps": 12, 31 | "cfg_strength": 7.5, 32 | }, 33 | slat_sampler_params={ 34 | "steps": 12, 35 | "cfg_strength": 3, 36 | }, 37 | ) 38 | # outputs is a dictionary containing generated 3D assets in different formats: 39 | # - outputs['gaussian']: a list of 3D Gaussians 40 | # - outputs['radiance_field']: a list of radiance fields 41 | # - outputs['mesh']: a list of meshes 42 | 43 | video_gs = render_utils.render_video(outputs['gaussian'][0])['color'] 44 | video_mesh = render_utils.render_video(outputs['mesh'][0])['normal'] 45 | video = [np.concatenate([frame_gs, frame_mesh], axis=1) for frame_gs, frame_mesh in zip(video_gs, video_mesh)] 46 | imageio.mimsave("sample_multi.mp4", video, fps=30) 47 | -------------------------------------------------------------------------------- /extensions/vox2seq/benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import vox2seq 4 | 5 | 6 | if __name__ == "__main__": 7 | stats = { 8 | 'z_order_cuda': [], 9 | 'z_order_pytorch': [], 10 | 'hilbert_cuda': [], 11 | 'hilbert_pytorch': [], 12 | } 13 | RES = [16, 32, 64, 128, 256] 14 | for res in RES: 15 | coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res)) 16 | coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() 17 | 18 | start = time.time() 19 | for _ in range(100): 20 | code_z_cuda = vox2seq.encode(coords, mode='z_order').cuda() 21 | torch.cuda.synchronize() 22 | stats['z_order_cuda'].append((time.time() - start) / 100) 23 | 24 | start = time.time() 25 | for _ in range(100): 26 | code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order').cuda() 27 | torch.cuda.synchronize() 28 | stats['z_order_pytorch'].append((time.time() - start) / 100) 29 | 30 | start = time.time() 31 | for _ in range(100): 32 | code_h_cuda = vox2seq.encode(coords, mode='hilbert').cuda() 33 | torch.cuda.synchronize() 34 | stats['hilbert_cuda'].append((time.time() - start) / 100) 35 | 36 | start = time.time() 37 | for _ in range(100): 38 | code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert').cuda() 39 | torch.cuda.synchronize() 40 | stats['hilbert_pytorch'].append((time.time() - start) / 100) 41 | 42 | print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}") 43 | for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']): 44 | print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}") 45 | 
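The benchmark above only times the raw encoders; the typical downstream use of these codes is to serialize an unordered sparse voxel set along a space-filling curve so that voxels that are close in 3D stay close in the 1D sequence. A minimal sketch of that use (not part of the repository; it assumes the CUDA extension is built and a GPU is available):

import torch
import vox2seq

# Unordered sparse voxel coordinates in a 64^3 grid (well within the 10 bits per axis the codes support).
coords = torch.randint(0, 64, (1024, 3), dtype=torch.int32, device='cuda')

# Map each voxel to its position on the Hilbert curve and sort by it.
codes = vox2seq.encode(coords, mode='hilbert')
order = torch.argsort(codes)
coords_serialized = coords[order]

# decode() inverts encode(), so the round trip recovers the original coordinates.
assert torch.equal(vox2seq.decode(codes, mode='hilbert'), coords)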
46 | -------------------------------------------------------------------------------- /extensions/vox2seq/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | os.path.dirname(os.path.abspath(__file__)) 16 | 17 | setup( 18 | name="vox2seq", 19 | packages=['vox2seq', 'vox2seq.pytorch'], 20 | ext_modules=[ 21 | CUDAExtension( 22 | name="vox2seq._C", 23 | sources=[ 24 | "src/api.cu", 25 | "src/z_order.cu", 26 | "src/hilbert.cu", 27 | "src/ext.cpp", 28 | ], 29 | ) 30 | ], 31 | cmdclass={ 32 | 'build_ext': BuildExtension 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /extensions/vox2seq/src/api.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "api.h" 3 | #include "z_order.h" 4 | #include "hilbert.h" 5 | 6 | 7 | torch::Tensor 8 | z_order_encode( 9 | const torch::Tensor& x, 10 | const torch::Tensor& y, 11 | const torch::Tensor& z 12 | ) { 13 | // Allocate output tensor 14 | torch::Tensor codes = torch::empty_like(x); 15 | 16 | // Call CUDA kernel 17 | z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( 18 | x.size(0), 19 | reinterpret_cast<uint32_t*>(x.contiguous().data_ptr()), 20 | reinterpret_cast<uint32_t*>(y.contiguous().data_ptr()), 21 | reinterpret_cast<uint32_t*>(z.contiguous().data_ptr()), 22 | reinterpret_cast<uint32_t*>(codes.data_ptr()) 23 | ); 24 | 25 | return codes; 26 | } 27 | 28 | 29 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> 30 | z_order_decode( 31 | const torch::Tensor& codes 32 | ) { 33 | // Allocate output tensors 34 | torch::Tensor x = torch::empty_like(codes); 35 | torch::Tensor y = torch::empty_like(codes); 36 | torch::Tensor z = torch::empty_like(codes); 37 | 38 | // Call CUDA kernel 39 | z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( 40 | codes.size(0), 41 | reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr()), 42 | reinterpret_cast<uint32_t*>(x.data_ptr()), 43 | reinterpret_cast<uint32_t*>(y.data_ptr()), 44 | reinterpret_cast<uint32_t*>(z.data_ptr()) 45 | ); 46 | 47 | return std::make_tuple(x, y, z); 48 | } 49 | 50 | 51 | torch::Tensor 52 | hilbert_encode( 53 | const torch::Tensor& x, 54 | const torch::Tensor& y, 55 | const torch::Tensor& z 56 | ) { 57 | // Allocate output tensor 58 | torch::Tensor codes = torch::empty_like(x); 59 | 60 | // Call CUDA kernel 61 | hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( 62 | x.size(0), 63 | reinterpret_cast<uint32_t*>(x.contiguous().data_ptr()), 64 | reinterpret_cast<uint32_t*>(y.contiguous().data_ptr()), 65 | reinterpret_cast<uint32_t*>(z.contiguous().data_ptr()), 66 | reinterpret_cast<uint32_t*>(codes.data_ptr()) 67 | ); 68 | 69 | return codes; 70 | } 71 | 72 | 73 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> 74 | hilbert_decode( 75 | const torch::Tensor& codes 76 | ) { 77 | // Allocate output tensors 78 | torch::Tensor x = torch::empty_like(codes); 79 | torch::Tensor y = torch::empty_like(codes); 80 | torch::Tensor z = torch::empty_like(codes); 81 | 82 | // Call CUDA kernel 83 | hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( 84 | codes.size(0), 85 | reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr()), 86 | reinterpret_cast<uint32_t*>(x.data_ptr()), 87 | reinterpret_cast<uint32_t*>(y.data_ptr()), 88 | reinterpret_cast<uint32_t*>(z.data_ptr()) 89 | ); 90 | 91 | return std::make_tuple(x, y, z); 92 | } 93 | -------------------------------------------------------------------------------- /extensions/vox2seq/src/api.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Serialize a voxel grid 3 | * 4 | * Copyright (C) 2024, Jianfeng XIANG 5 | * All rights reserved. 6 | * 7 | * Licensed under The MIT License [see LICENSE for details] 8 | * 9 | * Written by Jianfeng XIANG 10 | */ 11 | 12 | #pragma once 13 | #include <torch/extension.h> 14 | 15 | 16 | #define BLOCK_SIZE 256 17 | 18 | 19 | /** 20 | * Z-order encode 3D points 21 | * 22 | * @param x [N] tensor containing the x coordinates 23 | * @param y [N] tensor containing the y coordinates 24 | * @param z [N] tensor containing the z coordinates 25 | * 26 | * @return [N] tensor containing the z-order encoded values 27 | */ 28 | torch::Tensor 29 | z_order_encode( 30 | const torch::Tensor& x, 31 | const torch::Tensor& y, 32 | const torch::Tensor& z 33 | ); 34 | 35 | 36 | /** 37 | * Z-order decode 3D points 38 | * 39 | * @param codes [N] tensor containing the z-order encoded values 40 | * 41 | * @return 3 tensors [N] containing the x, y, z coordinates 42 | */ 43 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> 44 | z_order_decode( 45 | const torch::Tensor& codes 46 | ); 47 | 48 | 49 | /** 50 | * Hilbert encode 3D points 51 | * 52 | * @param x [N] tensor containing the x coordinates 53 | * @param y [N] tensor containing the y coordinates 54 | * @param z [N] tensor containing the z coordinates 55 | * 56 | * @return [N] tensor containing the Hilbert encoded values 57 | */ 58 | torch::Tensor 59 | hilbert_encode( 60 | const torch::Tensor& x, 61 | const torch::Tensor& y, 62 | const torch::Tensor& z 63 | ); 64 | 65 | 66 | /** 67 | * Hilbert decode 3D points 68 | * 69 | * @param codes [N] tensor containing the Hilbert encoded values 70 | * 71 | * @return 3 tensors [N] containing the x, y, z coordinates 72 | */ 73 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> 74 | hilbert_decode( 75 | const torch::Tensor& codes 76 | ); 77 | -------------------------------------------------------------------------------- /extensions/vox2seq/src/ext.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "api.h" 3 | 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("z_order_encode", &z_order_encode); 7 | m.def("z_order_decode", &z_order_decode); 8 | m.def("hilbert_encode", &hilbert_encode); 9 | m.def("hilbert_decode", &hilbert_decode); 10 | } -------------------------------------------------------------------------------- /extensions/vox2seq/src/hilbert.cu: -------------------------------------------------------------------------------- 1 | #include <cuda.h> 2 | #include <cuda_runtime.h> 3 | #include <device_launch_parameters.h> 4 | 5 | #include <cooperative_groups.h> 6 | #include <cooperative_groups/reduce.h> 7 | namespace cg = cooperative_groups; 8 | 9 | #include "hilbert.h" 10 | 11 | 12 | // Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. 13 | static __device__ uint32_t expandBits(uint32_t v) 14 | { 15 | v = (v * 0x00010001u) & 0xFF0000FFu; 16 | v = (v * 0x00000101u) & 0x0F00F00Fu; 17 | v = (v * 0x00000011u) & 0xC30C30C3u; 18 | v = (v * 0x00000005u) & 0x49249249u; 19 | return v; 20 | } 21 | 22 | 23 | // Removes 2 zeros after each bit in a 30-bit integer.
24 | static __device__ uint32_t extractBits(uint32_t v) 25 | { 26 | v = v & 0x49249249; 27 | v = (v ^ (v >> 2)) & 0x030C30C3u; 28 | v = (v ^ (v >> 4)) & 0x0300F00Fu; 29 | v = (v ^ (v >> 8)) & 0x030000FFu; 30 | v = (v ^ (v >> 16)) & 0x000003FFu; 31 | return v; 32 | } 33 | 34 | 35 | __global__ void hilbert_encode_cuda( 36 | size_t N, 37 | const uint32_t* x, 38 | const uint32_t* y, 39 | const uint32_t* z, 40 | uint32_t* codes 41 | ) { 42 | size_t thread_id = cg::this_grid().thread_rank(); 43 | if (thread_id >= N) return; 44 | 45 | uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]}; 46 | 47 | uint32_t m = 1 << 9, q, p, t; 48 | 49 | // Inverse undo excess work 50 | q = m; 51 | while (q > 1) { 52 | p = q - 1; 53 | for (int i = 0; i < 3; i++) { 54 | if (point[i] & q) { 55 | point[0] ^= p; // invert 56 | } else { 57 | t = (point[0] ^ point[i]) & p; 58 | point[0] ^= t; 59 | point[i] ^= t; 60 | } 61 | } 62 | q >>= 1; 63 | } 64 | 65 | // Gray encode 66 | for (int i = 1; i < 3; i++) { 67 | point[i] ^= point[i - 1]; 68 | } 69 | t = 0; 70 | q = m; 71 | while (q > 1) { 72 | if (point[2] & q) { 73 | t ^= q - 1; 74 | } 75 | q >>= 1; 76 | } 77 | for (int i = 0; i < 3; i++) { 78 | point[i] ^= t; 79 | } 80 | 81 | // Convert to 3D Hilbert code 82 | uint32_t xx = expandBits(point[0]); 83 | uint32_t yy = expandBits(point[1]); 84 | uint32_t zz = expandBits(point[2]); 85 | 86 | codes[thread_id] = xx * 4 + yy * 2 + zz; 87 | } 88 | 89 | 90 | __global__ void hilbert_decode_cuda( 91 | size_t N, 92 | const uint32_t* codes, 93 | uint32_t* x, 94 | uint32_t* y, 95 | uint32_t* z 96 | ) { 97 | size_t thread_id = cg::this_grid().thread_rank(); 98 | if (thread_id >= N) return; 99 | 100 | uint32_t point[3]; 101 | point[0] = extractBits(codes[thread_id] >> 2); 102 | point[1] = extractBits(codes[thread_id] >> 1); 103 | point[2] = extractBits(codes[thread_id]); 104 | 105 | uint32_t m = 2 << 9, q, p, t; 106 | 107 | // Gray decode by H ^ (H/2) 108 | t = point[2] >> 1; 109 | for (int i = 2; i > 0; i--) { 110 | point[i] ^= point[i - 1]; 111 | } 112 | point[0] ^= t; 113 | 114 | // Undo excess work 115 | q = 2; 116 | while (q != m) { 117 | p = q - 1; 118 | for (int i = 2; i >= 0; i--) { 119 | if (point[i] & q) { 120 | point[0] ^= p; 121 | } else { 122 | t = (point[0] ^ point[i]) & p; 123 | point[0] ^= t; 124 | point[i] ^= t; 125 | } 126 | } 127 | q <<= 1; 128 | } 129 | 130 | x[thread_id] = point[0]; 131 | y[thread_id] = point[1]; 132 | z[thread_id] = point[2]; 133 | } 134 | -------------------------------------------------------------------------------- /extensions/vox2seq/src/hilbert.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /** 4 | * Hilbert encode 3D points 5 | * 6 | * @param x [N] tensor containing the x coordinates 7 | * @param y [N] tensor containing the y coordinates 8 | * @param z [N] tensor containing the z coordinates 9 | * 10 | * @return [N] tensor containing the z-order encoded values 11 | */ 12 | __global__ void hilbert_encode_cuda( 13 | size_t N, 14 | const uint32_t* x, 15 | const uint32_t* y, 16 | const uint32_t* z, 17 | uint32_t* codes 18 | ); 19 | 20 | 21 | /** 22 | * Hilbert decode 3D points 23 | * 24 | * @param codes [N] tensor containing the z-order encoded values 25 | * @param x [N] tensor containing the x coordinates 26 | * @param y [N] tensor containing the y coordinates 27 | * @param z [N] tensor containing the z coordinates 28 | */ 29 | __global__ void hilbert_decode_cuda( 30 | size_t N, 31 | const uint32_t* codes, 32 | uint32_t* x, 
33 | uint32_t* y, 34 | uint32_t* z 35 | ); 36 | -------------------------------------------------------------------------------- /extensions/vox2seq/src/z_order.cu: -------------------------------------------------------------------------------- 1 | #include <cuda.h> 2 | #include <cuda_runtime.h> 3 | #include <device_launch_parameters.h> 4 | 5 | #include <cooperative_groups.h> 6 | #include <cooperative_groups/reduce.h> 7 | namespace cg = cooperative_groups; 8 | 9 | #include "z_order.h" 10 | 11 | 12 | // Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. 13 | static __device__ uint32_t expandBits(uint32_t v) 14 | { 15 | v = (v * 0x00010001u) & 0xFF0000FFu; 16 | v = (v * 0x00000101u) & 0x0F00F00Fu; 17 | v = (v * 0x00000011u) & 0xC30C30C3u; 18 | v = (v * 0x00000005u) & 0x49249249u; 19 | return v; 20 | } 21 | 22 | 23 | // Removes 2 zeros after each bit in a 30-bit integer. 24 | static __device__ uint32_t extractBits(uint32_t v) 25 | { 26 | v = v & 0x49249249; 27 | v = (v ^ (v >> 2)) & 0x030C30C3u; 28 | v = (v ^ (v >> 4)) & 0x0300F00Fu; 29 | v = (v ^ (v >> 8)) & 0x030000FFu; 30 | v = (v ^ (v >> 16)) & 0x000003FFu; 31 | return v; 32 | } 33 | 34 | 35 | __global__ void z_order_encode_cuda( 36 | size_t N, 37 | const uint32_t* x, 38 | const uint32_t* y, 39 | const uint32_t* z, 40 | uint32_t* codes 41 | ) { 42 | size_t thread_id = cg::this_grid().thread_rank(); 43 | if (thread_id >= N) return; 44 | 45 | uint32_t xx = expandBits(x[thread_id]); 46 | uint32_t yy = expandBits(y[thread_id]); 47 | uint32_t zz = expandBits(z[thread_id]); 48 | 49 | codes[thread_id] = xx * 4 + yy * 2 + zz; 50 | } 51 | 52 | 53 | __global__ void z_order_decode_cuda( 54 | size_t N, 55 | const uint32_t* codes, 56 | uint32_t* x, 57 | uint32_t* y, 58 | uint32_t* z 59 | ) { 60 | size_t thread_id = cg::this_grid().thread_rank(); 61 | if (thread_id >= N) return; 62 | 63 | x[thread_id] = extractBits(codes[thread_id] >> 2); 64 | y[thread_id] = extractBits(codes[thread_id] >> 1); 65 | z[thread_id] = extractBits(codes[thread_id]); 66 | } 67 | -------------------------------------------------------------------------------- /extensions/vox2seq/src/z_order.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /** 4 | * Z-order encode 3D points 5 | * 6 | * @param x [N] tensor containing the x coordinates 7 | * @param y [N] tensor containing the y coordinates 8 | * @param z [N] tensor containing the z coordinates 9 | * 10 | * @return [N] tensor containing the z-order encoded values 11 | */ 12 | __global__ void z_order_encode_cuda( 13 | size_t N, 14 | const uint32_t* x, 15 | const uint32_t* y, 16 | const uint32_t* z, 17 | uint32_t* codes 18 | ); 19 | 20 | 21 | /** 22 | * Z-order decode 3D points 23 | * 24 | * @param codes [N] tensor containing the z-order encoded values 25 | * @param x [N] tensor containing the x coordinates 26 | * @param y [N] tensor containing the y coordinates 27 | * @param z [N] tensor containing the z coordinates 28 | */ 29 | __global__ void z_order_decode_cuda( 30 | size_t N, 31 | const uint32_t* codes, 32 | uint32_t* x, 33 | uint32_t* y, 34 | uint32_t* z 35 | ); 36 | -------------------------------------------------------------------------------- /extensions/vox2seq/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import vox2seq 3 | 4 | 5 | if __name__ == "__main__": 6 | RES = 256 7 | coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES)) 8 | coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() 9 | code_z_cuda = vox2seq.encode(coords, mode='z_order')
10 | code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order') 11 | code_h_cuda = vox2seq.encode(coords, mode='hilbert') 12 | code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert') 13 | assert torch.equal(code_z_cuda, code_z_pytorch) 14 | assert torch.equal(code_h_cuda, code_h_pytorch) 15 | 16 | code = torch.arange(RES**3).int().cuda() 17 | coords_z_cuda = vox2seq.decode(code, mode='z_order') 18 | coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order') 19 | coords_h_cuda = vox2seq.decode(code, mode='hilbert') 20 | coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert') 21 | assert torch.equal(coords_z_cuda, coords_z_pytorch) 22 | assert torch.equal(coords_h_cuda, coords_h_pytorch) 23 | 24 | print("All tests passed.") 25 | 26 | -------------------------------------------------------------------------------- /extensions/vox2seq/vox2seq/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import * 3 | import torch 4 | from . import _C 5 | from . import pytorch 6 | 7 | 8 | @torch.no_grad() 9 | def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 10 | """ 11 | Encodes 3D coordinates into a 30-bit code. 12 | 13 | Args: 14 | coords: a tensor of shape [N, 3] containing the 3D coordinates. 15 | permute: the permutation of the coordinates. 16 | mode: the encoding mode to use. 17 | """ 18 | assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]" 19 | x = coords[:, permute[0]].int() 20 | y = coords[:, permute[1]].int() 21 | z = coords[:, permute[2]].int() 22 | if mode == 'z_order': 23 | return _C.z_order_encode(x, y, z) 24 | elif mode == 'hilbert': 25 | return _C.hilbert_encode(x, y, z) 26 | else: 27 | raise ValueError(f"Unknown encoding mode: {mode}") 28 | 29 | 30 | @torch.no_grad() 31 | def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 32 | """ 33 | Decodes a 30-bit code into 3D coordinates. 34 | 35 | Args: 36 | code: a tensor of shape [N] containing the 30-bit code. 37 | permute: the permutation of the coordinates. 38 | mode: the decoding mode to use. 39 | """ 40 | assert code.ndim == 1, "Input code must be of shape [N]" 41 | if mode == 'z_order': 42 | coords = _C.z_order_decode(code) 43 | elif mode == 'hilbert': 44 | coords = _C.hilbert_decode(code) 45 | else: 46 | raise ValueError(f"Unknown decoding mode: {mode}") 47 | x = coords[permute.index(0)] 48 | y = coords[permute.index(1)] 49 | z = coords[permute.index(2)] 50 | return torch.stack([x, y, z], dim=-1) 51 | -------------------------------------------------------------------------------- /extensions/vox2seq/vox2seq/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import * 3 | 4 | from .default import ( 5 | encode, 6 | decode, 7 | z_order_encode, 8 | z_order_decode, 9 | hilbert_encode, 10 | hilbert_decode, 11 | ) 12 | 13 | 14 | @torch.no_grad() 15 | def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 16 | """ 17 | Encodes 3D coordinates into a 30-bit code. 18 | 19 | Args: 20 | coords: a tensor of shape [N, 3] containing the 3D coordinates. 21 | permute: the permutation of the coordinates. 22 | mode: the encoding mode to use. 
23 | """ 24 | if mode == 'z_order': 25 | return z_order_encode(coords[:, permute], depth=10).int() 26 | elif mode == 'hilbert': 27 | return hilbert_encode(coords[:, permute], depth=10).int() 28 | else: 29 | raise ValueError(f"Unknown encoding mode: {mode}") 30 | 31 | 32 | @torch.no_grad() 33 | def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 34 | """ 35 | Decodes a 30-bit code into 3D coordinates. 36 | 37 | Args: 38 | code: a tensor of shape [N] containing the 30-bit code. 39 | permute: the permutation of the coordinates. 40 | mode: the decoding mode to use. 41 | """ 42 | if mode == 'z_order': 43 | return z_order_decode(code, depth=10)[:, permute].float() 44 | elif mode == 'hilbert': 45 | return hilbert_decode(code, depth=10)[:, permute].float() 46 | else: 47 | raise ValueError(f"Unknown decoding mode: {mode}") 48 | -------------------------------------------------------------------------------- /extensions/vox2seq/vox2seq/pytorch/default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .z_order import xyz2key as z_order_encode_ 3 | from .z_order import key2xyz as z_order_decode_ 4 | from .hilbert import encode as hilbert_encode_ 5 | from .hilbert import decode as hilbert_decode_ 6 | 7 | 8 | @torch.inference_mode() 9 | def encode(grid_coord, batch=None, depth=16, order="z"): 10 | assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} 11 | if order == "z": 12 | code = z_order_encode(grid_coord, depth=depth) 13 | elif order == "z-trans": 14 | code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) 15 | elif order == "hilbert": 16 | code = hilbert_encode(grid_coord, depth=depth) 17 | elif order == "hilbert-trans": 18 | code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) 19 | else: 20 | raise NotImplementedError 21 | if batch is not None: 22 | batch = batch.long() 23 | code = batch << depth * 3 | code 24 | return code 25 | 26 | 27 | @torch.inference_mode() 28 | def decode(code, depth=16, order="z"): 29 | assert order in {"z", "hilbert"} 30 | batch = code >> depth * 3 31 | code = code & ((1 << depth * 3) - 1) 32 | if order == "z": 33 | grid_coord = z_order_decode(code, depth=depth) 34 | elif order == "hilbert": 35 | grid_coord = hilbert_decode(code, depth=depth) 36 | else: 37 | raise NotImplementedError 38 | return grid_coord, batch 39 | 40 | 41 | def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): 42 | x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() 43 | # we block the support to batch, maintain batched code in Point class 44 | code = z_order_encode_(x, y, z, b=None, depth=depth) 45 | return code 46 | 47 | 48 | def z_order_decode(code: torch.Tensor, depth): 49 | x, y, z, _ = z_order_decode_(code, depth=depth) 50 | grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) 51 | return grid_coord 52 | 53 | 54 | def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): 55 | return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) 56 | 57 | 58 | def hilbert_decode(code: torch.Tensor, depth: int = 16): 59 | return hilbert_decode_(code, num_dims=3, num_bits=depth) -------------------------------------------------------------------------------- /extensions/vox2seq/vox2seq/pytorch/z_order.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 
| # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from typing import Optional, Union 10 | 11 | 12 | class KeyLUT: 13 | def __init__(self): 14 | r256 = torch.arange(256, dtype=torch.int64) 15 | r512 = torch.arange(512, dtype=torch.int64) 16 | zero = torch.zeros(256, dtype=torch.int64) 17 | device = torch.device("cpu") 18 | 19 | self._encode = { 20 | device: ( 21 | self.xyz2key(r256, zero, zero, 8), 22 | self.xyz2key(zero, r256, zero, 8), 23 | self.xyz2key(zero, zero, r256, 8), 24 | ) 25 | } 26 | self._decode = {device: self.key2xyz(r512, 9)} 27 | 28 | def encode_lut(self, device=torch.device("cpu")): 29 | if device not in self._encode: 30 | cpu = torch.device("cpu") 31 | self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) 32 | return self._encode[device] 33 | 34 | def decode_lut(self, device=torch.device("cpu")): 35 | if device not in self._decode: 36 | cpu = torch.device("cpu") 37 | self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) 38 | return self._decode[device] 39 | 40 | def xyz2key(self, x, y, z, depth): 41 | key = torch.zeros_like(x) 42 | for i in range(depth): 43 | mask = 1 << i 44 | key = ( 45 | key 46 | | ((x & mask) << (2 * i + 2)) 47 | | ((y & mask) << (2 * i + 1)) 48 | | ((z & mask) << (2 * i + 0)) 49 | ) 50 | return key 51 | 52 | def key2xyz(self, key, depth): 53 | x = torch.zeros_like(key) 54 | y = torch.zeros_like(key) 55 | z = torch.zeros_like(key) 56 | for i in range(depth): 57 | x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) 58 | y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) 59 | z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) 60 | return x, y, z 61 | 62 | 63 | _key_lut = KeyLUT() 64 | 65 | 66 | def xyz2key( 67 | x: torch.Tensor, 68 | y: torch.Tensor, 69 | z: torch.Tensor, 70 | b: Optional[Union[torch.Tensor, int]] = None, 71 | depth: int = 16, 72 | ): 73 | r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys 74 | based on pre-computed look up tables. The speed of this function is much 75 | faster than the method based on for-loop. 76 | 77 | Args: 78 | x (torch.Tensor): The x coordinate. 79 | y (torch.Tensor): The y coordinate. 80 | z (torch.Tensor): The z coordinate. 81 | b (torch.Tensor or int): The batch index of the coordinates, and should be 82 | smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of 83 | :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. 84 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 85 | """ 86 | 87 | EX, EY, EZ = _key_lut.encode_lut(x.device) 88 | x, y, z = x.long(), y.long(), z.long() 89 | 90 | mask = 255 if depth > 8 else (1 << depth) - 1 91 | key = EX[x & mask] | EY[y & mask] | EZ[z & mask] 92 | if depth > 8: 93 | mask = (1 << (depth - 8)) - 1 94 | key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] 95 | key = key16 << 24 | key 96 | 97 | if b is not None: 98 | b = b.long() 99 | key = b << 48 | key 100 | 101 | return key 102 | 103 | 104 | def key2xyz(key: torch.Tensor, depth: int = 16): 105 | r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates 106 | and the batch index based on pre-computed look up tables. 107 | 108 | Args: 109 | key (torch.Tensor): The shuffled key. 110 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 
111 | """ 112 | 113 | DX, DY, DZ = _key_lut.decode_lut(key.device) 114 | x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) 115 | 116 | b = key >> 48 117 | key = key & ((1 << 48) - 1) 118 | 119 | n = (depth + 2) // 3 120 | for i in range(n): 121 | k = key >> (i * 9) & 511 122 | x = x | (DX[k] << (i * 3)) 123 | y = y | (DY[k] << (i * 3)) 124 | z = z | (DZ[k] << (i * 3)) 125 | 126 | return x, y, z, b -------------------------------------------------------------------------------- /trellis/__init__.py: -------------------------------------------------------------------------------- 1 | from . import models 2 | from . import modules 3 | from . import pipelines 4 | from . import renderers 5 | from . import representations 6 | from . import utils 7 | -------------------------------------------------------------------------------- /trellis/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | __attributes = { 4 | 'SparseStructureEncoder': 'sparse_structure_vae', 5 | 'SparseStructureDecoder': 'sparse_structure_vae', 6 | 'SparseStructureFlowModel': 'sparse_structure_flow', 7 | 'SLatEncoder': 'structured_latent_vae', 8 | 'SLatGaussianDecoder': 'structured_latent_vae', 9 | 'SLatRadianceFieldDecoder': 'structured_latent_vae', 10 | 'SLatMeshDecoder': 'structured_latent_vae', 11 | 'SLatFlowModel': 'structured_latent_flow', 12 | } 13 | 14 | __submodules = [] 15 | 16 | __all__ = list(__attributes.keys()) + __submodules 17 | 18 | def __getattr__(name): 19 | if name not in globals(): 20 | if name in __attributes: 21 | module_name = __attributes[name] 22 | module = importlib.import_module(f".{module_name}", __name__) 23 | globals()[name] = getattr(module, name) 24 | elif name in __submodules: 25 | module = importlib.import_module(f".{name}", __name__) 26 | globals()[name] = module 27 | else: 28 | raise AttributeError(f"module {__name__} has no attribute {name}") 29 | return globals()[name] 30 | 31 | 32 | def from_pretrained(path: str, **kwargs): 33 | """ 34 | Load a model from a pretrained checkpoint. 35 | 36 | Args: 37 | path: The path to the checkpoint. Can be either local path or a Hugging Face model name. 38 | NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. 39 | **kwargs: Additional arguments for the model constructor. 
40 | """ 41 | import os 42 | import json 43 | from safetensors.torch import load_file 44 | is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") 45 | 46 | if is_local: 47 | config_file = f"{path}.json" 48 | model_file = f"{path}.safetensors" 49 | else: 50 | from huggingface_hub import hf_hub_download 51 | path_parts = path.split('/') 52 | repo_id = f'{path_parts[0]}/{path_parts[1]}' 53 | model_name = '/'.join(path_parts[2:]) 54 | config_file = hf_hub_download(repo_id, f"{model_name}.json") 55 | model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") 56 | 57 | with open(config_file, 'r') as f: 58 | config = json.load(f) 59 | model = __getattr__(config['name'])(**config['args'], **kwargs) 60 | model.load_state_dict(load_file(model_file)) 61 | 62 | return model 63 | 64 | 65 | # For Pylance 66 | if __name__ == '__main__': 67 | from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder 68 | from .sparse_structure_flow import SparseStructureFlowModel 69 | from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder 70 | from .structured_latent_flow import SLatFlowModel 71 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import SLatEncoder 2 | from .decoder_gs import SLatGaussianDecoder 3 | from .decoder_rf import SLatRadianceFieldDecoder 4 | from .decoder_mesh import SLatMeshDecoder 5 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/base.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from ...modules.utils import convert_module_to_f16, convert_module_to_f32 5 | from ...modules import sparse as sp 6 | from ...modules.transformer import AbsolutePositionEmbedder 7 | from ...modules.sparse.transformer import SparseTransformerBlock 8 | 9 | 10 | def block_attn_config(self): 11 | """ 12 | Return the attention configuration of the model. 13 | """ 14 | for i in range(self.num_blocks): 15 | if self.attn_mode == "shift_window": 16 | yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER 17 | elif self.attn_mode == "shift_sequence": 18 | yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER 19 | elif self.attn_mode == "shift_order": 20 | yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] 21 | elif self.attn_mode == "full": 22 | yield "full", None, None, None, None 23 | elif self.attn_mode == "swin": 24 | yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None 25 | 26 | 27 | class SparseTransformerBase(nn.Module): 28 | """ 29 | Sparse Transformer without output layers. 30 | Serve as the base class for encoder and decoder. 
31 | """ 32 | def __init__( 33 | self, 34 | in_channels: int, 35 | model_channels: int, 36 | num_blocks: int, 37 | num_heads: Optional[int] = None, 38 | num_head_channels: Optional[int] = 64, 39 | mlp_ratio: float = 4.0, 40 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 41 | window_size: Optional[int] = None, 42 | pe_mode: Literal["ape", "rope"] = "ape", 43 | use_fp16: bool = False, 44 | use_checkpoint: bool = False, 45 | qk_rms_norm: bool = False, 46 | ): 47 | super().__init__() 48 | self.in_channels = in_channels 49 | self.model_channels = model_channels 50 | self.num_blocks = num_blocks 51 | self.window_size = window_size 52 | self.num_heads = num_heads or model_channels // num_head_channels 53 | self.mlp_ratio = mlp_ratio 54 | self.attn_mode = attn_mode 55 | self.pe_mode = pe_mode 56 | self.use_fp16 = use_fp16 57 | self.use_checkpoint = use_checkpoint 58 | self.qk_rms_norm = qk_rms_norm 59 | self.dtype = torch.float16 if use_fp16 else torch.float32 60 | 61 | if pe_mode == "ape": 62 | self.pos_embedder = AbsolutePositionEmbedder(model_channels) 63 | 64 | self.input_layer = sp.SparseLinear(in_channels, model_channels) 65 | self.blocks = nn.ModuleList([ 66 | SparseTransformerBlock( 67 | model_channels, 68 | num_heads=self.num_heads, 69 | mlp_ratio=self.mlp_ratio, 70 | attn_mode=attn_mode, 71 | window_size=window_size, 72 | shift_sequence=shift_sequence, 73 | shift_window=shift_window, 74 | serialize_mode=serialize_mode, 75 | use_checkpoint=self.use_checkpoint, 76 | use_rope=(pe_mode == "rope"), 77 | qk_rms_norm=self.qk_rms_norm, 78 | ) 79 | for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) 80 | ]) 81 | 82 | @property 83 | def device(self) -> torch.device: 84 | """ 85 | Return the device of the model. 86 | """ 87 | return next(self.parameters()).device 88 | 89 | def convert_to_fp16(self) -> None: 90 | """ 91 | Convert the torso of the model to float16. 92 | """ 93 | self.blocks.apply(convert_module_to_f16) 94 | 95 | def convert_to_fp32(self) -> None: 96 | """ 97 | Convert the torso of the model to float32. 
98 | """ 99 | self.blocks.apply(convert_module_to_f32) 100 | 101 | def initialize_weights(self) -> None: 102 | # Initialize transformer layers: 103 | def _basic_init(module): 104 | if isinstance(module, nn.Linear): 105 | torch.nn.init.xavier_uniform_(module.weight) 106 | if module.bias is not None: 107 | nn.init.constant_(module.bias, 0) 108 | self.apply(_basic_init) 109 | 110 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: 111 | h = self.input_layer(x) 112 | if self.pe_mode == "ape": 113 | h = h + self.pos_embedder(x.coords[:, 1:]) 114 | h = h.type(self.dtype) 115 | for block in self.blocks: 116 | h = block(h) 117 | return h 118 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/decoder_gs.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from ...modules import sparse as sp 6 | from ...utils.random_utils import hammersley_sequence 7 | from .base import SparseTransformerBase 8 | from ...representations import Gaussian 9 | 10 | 11 | class SLatGaussianDecoder(SparseTransformerBase): 12 | def __init__( 13 | self, 14 | resolution: int, 15 | model_channels: int, 16 | latent_channels: int, 17 | num_blocks: int, 18 | num_heads: Optional[int] = None, 19 | num_head_channels: Optional[int] = 64, 20 | mlp_ratio: float = 4, 21 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 22 | window_size: int = 8, 23 | pe_mode: Literal["ape", "rope"] = "ape", 24 | use_fp16: bool = False, 25 | use_checkpoint: bool = False, 26 | qk_rms_norm: bool = False, 27 | representation_config: dict = None, 28 | ): 29 | super().__init__( 30 | in_channels=latent_channels, 31 | model_channels=model_channels, 32 | num_blocks=num_blocks, 33 | num_heads=num_heads, 34 | num_head_channels=num_head_channels, 35 | mlp_ratio=mlp_ratio, 36 | attn_mode=attn_mode, 37 | window_size=window_size, 38 | pe_mode=pe_mode, 39 | use_fp16=use_fp16, 40 | use_checkpoint=use_checkpoint, 41 | qk_rms_norm=qk_rms_norm, 42 | ) 43 | self.resolution = resolution 44 | self.rep_config = representation_config 45 | self._calc_layout() 46 | self.out_layer = sp.SparseLinear(model_channels, self.out_channels) 47 | self._build_perturbation() 48 | 49 | self.initialize_weights() 50 | if use_fp16: 51 | self.convert_to_fp16() 52 | 53 | def initialize_weights(self) -> None: 54 | super().initialize_weights() 55 | # Zero-out output layers: 56 | nn.init.constant_(self.out_layer.weight, 0) 57 | nn.init.constant_(self.out_layer.bias, 0) 58 | 59 | def _build_perturbation(self) -> None: 60 | perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] 61 | perturbation = torch.tensor(perturbation).float() * 2 - 1 62 | perturbation = perturbation / self.rep_config['voxel_size'] 63 | perturbation = torch.atanh(perturbation).to(self.device) 64 | self.register_buffer('offset_perturbation', perturbation) 65 | 66 | def _calc_layout(self) -> None: 67 | self.layout = { 68 | '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, 69 | '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, 70 | '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, 71 | '_rotation' : {'shape': 
(self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4}, 72 | '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, 73 | } 74 | start = 0 75 | for k, v in self.layout.items(): 76 | v['range'] = (start, start + v['size']) 77 | start += v['size'] 78 | self.out_channels = start 79 | 80 | def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: 81 | """ 82 | Convert a batch of network outputs to 3D representations. 83 | 84 | Args: 85 | x: The [N x * x C] sparse tensor output by the network. 86 | 87 | Returns: 88 | list of representations 89 | """ 90 | ret = [] 91 | for i in range(x.shape[0]): 92 | representation = Gaussian( 93 | sh_degree=0, 94 | aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], 95 | mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], 96 | scaling_bias = self.rep_config['scaling_bias'], 97 | opacity_bias = self.rep_config['opacity_bias'], 98 | scaling_activation = self.rep_config['scaling_activation'] 99 | ) 100 | xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution 101 | for k, v in self.layout.items(): 102 | if k == '_xyz': 103 | offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) 104 | offset = offset * self.rep_config['lr'][k] 105 | if self.rep_config['perturb_offset']: 106 | offset = offset + self.offset_perturbation 107 | offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] 108 | _xyz = xyz.unsqueeze(1) + offset 109 | setattr(representation, k, _xyz.flatten(0, 1)) 110 | else: 111 | feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) 112 | feats = feats * self.rep_config['lr'][k] 113 | setattr(representation, k, feats) 114 | ret.append(representation) 115 | return ret 116 | 117 | def forward(self, x: sp.SparseTensor) -> List[Gaussian]: 118 | h = super().forward(x) 119 | h = h.type(x.dtype) 120 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) 121 | h = self.out_layer(h) 122 | return self.to_representation(h) 123 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/decoder_rf.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from ...modules import sparse as sp 7 | from .base import SparseTransformerBase 8 | from ...representations import Strivec 9 | 10 | 11 | class SLatRadianceFieldDecoder(SparseTransformerBase): 12 | def __init__( 13 | self, 14 | resolution: int, 15 | model_channels: int, 16 | latent_channels: int, 17 | num_blocks: int, 18 | num_heads: Optional[int] = None, 19 | num_head_channels: Optional[int] = 64, 20 | mlp_ratio: float = 4, 21 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 22 | window_size: int = 8, 23 | pe_mode: Literal["ape", "rope"] = "ape", 24 | use_fp16: bool = False, 25 | use_checkpoint: bool = False, 26 | qk_rms_norm: bool = False, 27 | representation_config: dict = None, 28 | ): 29 | super().__init__( 30 | in_channels=latent_channels, 31 | model_channels=model_channels, 32 | num_blocks=num_blocks, 33 | num_heads=num_heads, 34 | num_head_channels=num_head_channels, 35 | mlp_ratio=mlp_ratio, 36 | attn_mode=attn_mode, 37 | window_size=window_size, 38 | pe_mode=pe_mode, 39 | use_fp16=use_fp16, 40 | use_checkpoint=use_checkpoint, 41 | 
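# Illustrative sketch, not part of the repository: the channel layout computed by
# SLatGaussianDecoder._calc_layout above. num_gaussians=32 is a hypothetical
# example value, not a claim about the released configuration.
num_gaussians = 32
sizes = {
    '_xyz':         num_gaussians * 3,
    '_features_dc': num_gaussians * 3,
    '_scaling':     num_gaussians * 3,
    '_rotation':    num_gaussians * 4,
    '_opacity':     num_gaussians * 1,
}
# ranges are assigned back-to-back, so out_channels is simply the total
assert sum(sizes.values()) == num_gaussians * 14   # 448 output channels here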
qk_rms_norm=qk_rms_norm, 42 | ) 43 | self.resolution = resolution 44 | self.rep_config = representation_config 45 | self._calc_layout() 46 | self.out_layer = sp.SparseLinear(model_channels, self.out_channels) 47 | 48 | self.initialize_weights() 49 | if use_fp16: 50 | self.convert_to_fp16() 51 | 52 | def initialize_weights(self) -> None: 53 | super().initialize_weights() 54 | # Zero-out output layers: 55 | nn.init.constant_(self.out_layer.weight, 0) 56 | nn.init.constant_(self.out_layer.bias, 0) 57 | 58 | def _calc_layout(self) -> None: 59 | self.layout = { 60 | 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']}, 61 | 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']}, 62 | 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3}, 63 | } 64 | start = 0 65 | for k, v in self.layout.items(): 66 | v['range'] = (start, start + v['size']) 67 | start += v['size'] 68 | self.out_channels = start 69 | 70 | def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: 71 | """ 72 | Convert a batch of network outputs to 3D representations. 73 | 74 | Args: 75 | x: The [N x * x C] sparse tensor output by the network. 76 | 77 | Returns: 78 | list of representations 79 | """ 80 | ret = [] 81 | for i in range(x.shape[0]): 82 | representation = Strivec( 83 | sh_degree=0, 84 | resolution=self.resolution, 85 | aabb=[-0.5, -0.5, -0.5, 1, 1, 1], 86 | rank=self.rep_config['rank'], 87 | dim=self.rep_config['dim'], 88 | device='cuda', 89 | ) 90 | representation.density_shift = 0.0 91 | representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution 92 | representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') 93 | for k, v in self.layout.items(): 94 | setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])) 95 | representation.trivec = representation.trivec + 1 96 | ret.append(representation) 97 | return ret 98 | 99 | def forward(self, x: sp.SparseTensor) -> List[Strivec]: 100 | h = super().forward(x) 101 | h = h.type(x.dtype) 102 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) 103 | h = self.out_layer(h) 104 | return self.to_representation(h) 105 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from ...modules import sparse as sp 6 | from .base import SparseTransformerBase 7 | 8 | 9 | class SLatEncoder(SparseTransformerBase): 10 | def __init__( 11 | self, 12 | resolution: int, 13 | in_channels: int, 14 | model_channels: int, 15 | latent_channels: int, 16 | num_blocks: int, 17 | num_heads: Optional[int] = None, 18 | num_head_channels: Optional[int] = 64, 19 | mlp_ratio: float = 4, 20 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 21 | window_size: int = 8, 22 | pe_mode: Literal["ape", "rope"] = "ape", 23 | use_fp16: bool = False, 24 | use_checkpoint: bool = False, 25 | qk_rms_norm: bool = False, 26 | ): 27 | super().__init__( 28 | in_channels=in_channels, 29 | model_channels=model_channels, 30 | num_blocks=num_blocks, 31 | num_heads=num_heads, 32 | 
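# Illustrative sketch, not part of the repository: the channel layout computed by
# SLatRadianceFieldDecoder._calc_layout above. rank=16 and dim=8 are hypothetical
# example values, not the released configuration.
rank, dim = 16, 8
trivec      = rank * 3 * dim   # 384
density     = rank             # 16
features_dc = rank * 3         # 48
out_channels = trivec + density + features_dc
assert out_channels == 448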
num_head_channels=num_head_channels, 33 | mlp_ratio=mlp_ratio, 34 | attn_mode=attn_mode, 35 | window_size=window_size, 36 | pe_mode=pe_mode, 37 | use_fp16=use_fp16, 38 | use_checkpoint=use_checkpoint, 39 | qk_rms_norm=qk_rms_norm, 40 | ) 41 | self.resolution = resolution 42 | self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) 43 | 44 | self.initialize_weights() 45 | if use_fp16: 46 | self.convert_to_fp16() 47 | 48 | def initialize_weights(self) -> None: 49 | super().initialize_weights() 50 | # Zero-out output layers: 51 | nn.init.constant_(self.out_layer.weight, 0) 52 | nn.init.constant_(self.out_layer.bias, 0) 53 | 54 | def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): 55 | h = super().forward(x) 56 | h = h.type(x.dtype) 57 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) 58 | h = self.out_layer(h) 59 | 60 | # Sample from the posterior distribution 61 | mean, logvar = h.feats.chunk(2, dim=-1) 62 | if sample_posterior: 63 | std = torch.exp(0.5 * logvar) 64 | z = mean + std * torch.randn_like(std) 65 | else: 66 | z = mean 67 | z = h.replace(z) 68 | 69 | if return_raw: 70 | return z, mean, logvar 71 | else: 72 | return z 73 | -------------------------------------------------------------------------------- /trellis/modules/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | BACKEND = 'flash_attn' 4 | DEBUG = False 5 | 6 | def __from_env(): 7 | import os 8 | 9 | global BACKEND 10 | global DEBUG 11 | 12 | env_attn_backend = os.environ.get('ATTN_BACKEND') 13 | env_sttn_debug = os.environ.get('ATTN_DEBUG') 14 | 15 | if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: 16 | BACKEND = env_attn_backend 17 | if env_sttn_debug is not None: 18 | DEBUG = env_sttn_debug == '1' 19 | 20 | print(f"[ATTENTION] Using backend: {BACKEND}") 21 | 22 | 23 | __from_env() 24 | 25 | 26 | def set_backend(backend: Literal['xformers', 'flash_attn']): 27 | global BACKEND 28 | BACKEND = backend 29 | 30 | def set_debug(debug: bool): 31 | global DEBUG 32 | DEBUG = debug 33 | 34 | 35 | from .full_attn import * 36 | from .modules import * 37 | -------------------------------------------------------------------------------- /trellis/modules/attention/full_attn.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import math 4 | from . import DEBUG, BACKEND 5 | 6 | if BACKEND == 'xformers': 7 | import xformers.ops as xops 8 | elif BACKEND == 'flash_attn': 9 | import flash_attn 10 | elif BACKEND == 'sdpa': 11 | from torch.nn.functional import scaled_dot_product_attention as sdpa 12 | elif BACKEND == 'naive': 13 | pass 14 | else: 15 | raise ValueError(f"Unknown attention backend: {BACKEND}") 16 | 17 | 18 | __all__ = [ 19 | 'scaled_dot_product_attention', 20 | ] 21 | 22 | 23 | def _naive_sdpa(q, k, v): 24 | """ 25 | Naive implementation of scaled dot product attention. 
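# Illustrative sketch, not part of the repository: selecting the dense attention
# backend read by trellis.modules.attention above. The variable must be set before
# the module is first imported, because __from_env() runs at import time; 'sdpa'
# only needs PyTorch, the other backends need their optional packages.
import os
os.environ["ATTN_BACKEND"] = "sdpa"   # one of: xformers, flash_attn, sdpa, naive

from trellis.modules import attention
print(attention.BACKEND)              # -> sdpa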
26 | """ 27 | q = q.permute(0, 2, 1, 3) # [N, H, L, C] 28 | k = k.permute(0, 2, 1, 3) # [N, H, L, C] 29 | v = v.permute(0, 2, 1, 3) # [N, H, L, C] 30 | scale_factor = 1 / math.sqrt(q.size(-1)) 31 | attn_weight = q @ k.transpose(-2, -1) * scale_factor 32 | attn_weight = torch.softmax(attn_weight, dim=-1) 33 | out = attn_weight @ v 34 | out = out.permute(0, 2, 1, 3) # [N, L, H, C] 35 | return out 36 | 37 | 38 | @overload 39 | def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: 40 | """ 41 | Apply scaled dot product attention. 42 | 43 | Args: 44 | qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. 45 | """ 46 | ... 47 | 48 | @overload 49 | def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: 50 | """ 51 | Apply scaled dot product attention. 52 | 53 | Args: 54 | q (torch.Tensor): A [N, L, H, C] tensor containing Qs. 55 | kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. 56 | """ 57 | ... 58 | 59 | @overload 60 | def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 61 | """ 62 | Apply scaled dot product attention. 63 | 64 | Args: 65 | q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. 66 | k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. 67 | v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. 68 | 69 | Note: 70 | k and v are assumed to have the same coordinate map. 71 | """ 72 | ... 73 | 74 | def scaled_dot_product_attention(*args, **kwargs): 75 | arg_names_dict = { 76 | 1: ['qkv'], 77 | 2: ['q', 'kv'], 78 | 3: ['q', 'k', 'v'] 79 | } 80 | num_all_args = len(args) + len(kwargs) 81 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" 82 | for key in arg_names_dict[num_all_args][len(args):]: 83 | assert key in kwargs, f"Missing argument {key}" 84 | 85 | if num_all_args == 1: 86 | qkv = args[0] if len(args) > 0 else kwargs['qkv'] 87 | assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" 88 | device = qkv.device 89 | 90 | elif num_all_args == 2: 91 | q = args[0] if len(args) > 0 else kwargs['q'] 92 | kv = args[1] if len(args) > 1 else kwargs['kv'] 93 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" 94 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" 95 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" 96 | device = q.device 97 | 98 | elif num_all_args == 3: 99 | q = args[0] if len(args) > 0 else kwargs['q'] 100 | k = args[1] if len(args) > 1 else kwargs['k'] 101 | v = args[2] if len(args) > 2 else kwargs['v'] 102 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" 103 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" 104 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" 105 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" 106 | device = q.device 107 | 108 | if BACKEND == 'xformers': 109 | if num_all_args == 1: 110 | q, k, v = qkv.unbind(dim=2) 111 | elif num_all_args == 2: 112 | k, v = kv.unbind(dim=2) 113 | out = xops.memory_efficient_attention(q, k, v) 114 | elif BACKEND == 'flash_attn': 115 | if num_all_args == 1: 116 | out = flash_attn.flash_attn_qkvpacked_func(qkv) 117 | elif num_all_args == 2: 118 | out = 
flash_attn.flash_attn_kvpacked_func(q, kv) 119 | elif num_all_args == 3: 120 | out = flash_attn.flash_attn_func(q, k, v) 121 | elif BACKEND == 'sdpa': 122 | if num_all_args == 1: 123 | q, k, v = qkv.unbind(dim=2) 124 | elif num_all_args == 2: 125 | k, v = kv.unbind(dim=2) 126 | q = q.permute(0, 2, 1, 3) # [N, H, L, C] 127 | k = k.permute(0, 2, 1, 3) # [N, H, L, C] 128 | v = v.permute(0, 2, 1, 3) # [N, H, L, C] 129 | out = sdpa(q, k, v) # [N, H, L, C] 130 | out = out.permute(0, 2, 1, 3) # [N, L, H, C] 131 | elif BACKEND == 'naive': 132 | if num_all_args == 1: 133 | q, k, v = qkv.unbind(dim=2) 134 | elif num_all_args == 2: 135 | k, v = kv.unbind(dim=2) 136 | out = _naive_sdpa(q, k, v) 137 | else: 138 | raise ValueError(f"Unknown attention module: {BACKEND}") 139 | 140 | return out 141 | -------------------------------------------------------------------------------- /trellis/modules/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LayerNorm32(nn.LayerNorm): 6 | def forward(self, x: torch.Tensor) -> torch.Tensor: 7 | return super().forward(x.float()).type(x.dtype) 8 | 9 | 10 | class GroupNorm32(nn.GroupNorm): 11 | """ 12 | A GroupNorm layer that converts to float32 before the forward pass. 13 | """ 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | return super().forward(x.float()).type(x.dtype) 16 | 17 | 18 | class ChannelLayerNorm32(LayerNorm32): 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | DIM = x.dim() 21 | x = x.permute(0, *range(2, DIM), 1).contiguous() 22 | x = super().forward(x) 23 | x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() 24 | return x 25 | -------------------------------------------------------------------------------- /trellis/modules/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | BACKEND = 'spconv' 4 | DEBUG = False 5 | ATTN = 'flash_attn' 6 | 7 | def __from_env(): 8 | import os 9 | 10 | global BACKEND 11 | global DEBUG 12 | global ATTN 13 | 14 | env_sparse_backend = os.environ.get('SPARSE_BACKEND') 15 | env_sparse_debug = os.environ.get('SPARSE_DEBUG') 16 | env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') 17 | if env_sparse_attn is None: 18 | env_sparse_attn = os.environ.get('ATTN_BACKEND') 19 | 20 | if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: 21 | BACKEND = env_sparse_backend 22 | if env_sparse_debug is not None: 23 | DEBUG = env_sparse_debug == '1' 24 | if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: 25 | ATTN = env_sparse_attn 26 | 27 | print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") 28 | 29 | 30 | __from_env() 31 | 32 | 33 | def set_backend(backend: Literal['spconv', 'torchsparse']): 34 | global BACKEND 35 | BACKEND = backend 36 | 37 | def set_debug(debug: bool): 38 | global DEBUG 39 | DEBUG = debug 40 | 41 | def set_attn(attn: Literal['xformers', 'flash_attn']): 42 | global ATTN 43 | ATTN = attn 44 | 45 | 46 | import importlib 47 | 48 | __attributes = { 49 | 'SparseTensor': 'basic', 50 | 'sparse_batch_broadcast': 'basic', 51 | 'sparse_batch_op': 'basic', 52 | 'sparse_cat': 'basic', 53 | 'sparse_unbind': 'basic', 54 | 'SparseGroupNorm': 'norm', 55 | 'SparseLayerNorm': 'norm', 56 | 'SparseGroupNorm32': 'norm', 57 | 'SparseLayerNorm32': 'norm', 58 | 'SparseReLU': 'nonlinearity', 59 | 'SparseSiLU': 'nonlinearity', 60 | 'SparseGELU': 'nonlinearity', 61 | 
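# Illustrative sketch, not part of the repository: calling the dispatcher above with
# a packed qkv tensor of shape [N, L, 3, H, C]. The 'naive' backend is chosen so the
# example has no optional dependencies; the shapes are arbitrary.
import os
os.environ["ATTN_BACKEND"] = "naive"
import torch
from trellis.modules.attention import scaled_dot_product_attention

qkv = torch.randn(2, 128, 3, 8, 64)            # [N, L, 3, H, C]
out = scaled_dot_product_attention(qkv)        # -> [N, L, H, C] == (2, 128, 8, 64)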
'SparseActivation': 'nonlinearity', 62 | 'SparseLinear': 'linear', 63 | 'sparse_scaled_dot_product_attention': 'attention', 64 | 'SerializeMode': 'attention', 65 | 'sparse_serialized_scaled_dot_product_self_attention': 'attention', 66 | 'sparse_windowed_scaled_dot_product_self_attention': 'attention', 67 | 'SparseMultiHeadAttention': 'attention', 68 | 'SparseConv3d': 'conv', 69 | 'SparseInverseConv3d': 'conv', 70 | 'SparseDownsample': 'spatial', 71 | 'SparseUpsample': 'spatial', 72 | 'SparseSubdivide' : 'spatial' 73 | } 74 | 75 | __submodules = ['transformer'] 76 | 77 | __all__ = list(__attributes.keys()) + __submodules 78 | 79 | def __getattr__(name): 80 | if name not in globals(): 81 | if name in __attributes: 82 | module_name = __attributes[name] 83 | module = importlib.import_module(f".{module_name}", __name__) 84 | globals()[name] = getattr(module, name) 85 | elif name in __submodules: 86 | module = importlib.import_module(f".{name}", __name__) 87 | globals()[name] = module 88 | else: 89 | raise AttributeError(f"module {__name__} has no attribute {name}") 90 | return globals()[name] 91 | 92 | 93 | # For Pylance 94 | if __name__ == '__main__': 95 | from .basic import * 96 | from .norm import * 97 | from .nonlinearity import * 98 | from .linear import * 99 | from .attention import * 100 | from .conv import * 101 | from .spatial import * 102 | import transformer 103 | -------------------------------------------------------------------------------- /trellis/modules/sparse/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .full_attn import * 2 | from .serialized_attn import * 3 | from .windowed_attn import * 4 | from .modules import * 5 | -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .. import BACKEND 2 | 3 | 4 | SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' 5 | 6 | def __from_env(): 7 | import os 8 | 9 | global SPCONV_ALGO 10 | env_spconv_algo = os.environ.get('SPCONV_ALGO') 11 | if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: 12 | SPCONV_ALGO = env_spconv_algo 13 | print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") 14 | 15 | 16 | __from_env() 17 | 18 | if BACKEND == 'torchsparse': 19 | from .conv_torchsparse import * 20 | elif BACKEND == 'spconv': 21 | from .conv_spconv import * 22 | -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/conv_spconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. import SparseTensor 4 | from .. import DEBUG 5 | from . 
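# Illustrative sketch, not part of the repository: configuring the sparse backends
# read at import time by trellis.modules.sparse and its conv submodule above.
# Set the variables before the package is first imported; the values shown are
# examples, not recommendations.
import os
os.environ["SPARSE_BACKEND"] = "spconv"            # or "torchsparse"
os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"   # or "xformers"
os.environ["SPCONV_ALGO"] = "native"               # "auto", "implicit_gemm", or "native"

from trellis.modules import sparse as sp
print(sp.BACKEND, sp.ATTN)                         # -> spconv flash_attn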
import SPCONV_ALGO 6 | 7 | class SparseConv3d(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): 9 | super(SparseConv3d, self).__init__() 10 | if 'spconv' not in globals(): 11 | import spconv.pytorch as spconv 12 | algo = None 13 | if SPCONV_ALGO == 'native': 14 | algo = spconv.ConvAlgo.Native 15 | elif SPCONV_ALGO == 'implicit_gemm': 16 | algo = spconv.ConvAlgo.MaskImplicitGemm 17 | if stride == 1 and (padding is None): 18 | self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) 19 | else: 20 | self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) 21 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) 22 | self.padding = padding 23 | 24 | def forward(self, x: SparseTensor) -> SparseTensor: 25 | spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) 26 | new_data = self.conv(x.data) 27 | new_shape = [x.shape[0], self.conv.out_channels] 28 | new_layout = None if spatial_changed else x.layout 29 | 30 | if spatial_changed and (x.shape[0] != 1): 31 | # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords 32 | fwd = new_data.indices[:, 0].argsort() 33 | bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) 34 | sorted_feats = new_data.features[fwd] 35 | sorted_coords = new_data.indices[fwd] 36 | unsorted_data = new_data 37 | new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore 38 | 39 | out = SparseTensor( 40 | new_data, shape=torch.Size(new_shape), layout=new_layout, 41 | scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), 42 | spatial_cache=x._spatial_cache, 43 | ) 44 | 45 | if spatial_changed and (x.shape[0] != 1): 46 | out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) 47 | out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) 48 | 49 | return out 50 | 51 | 52 | class SparseInverseConv3d(nn.Module): 53 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 54 | super(SparseInverseConv3d, self).__init__() 55 | if 'spconv' not in globals(): 56 | import spconv.pytorch as spconv 57 | self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) 58 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) 59 | 60 | def forward(self, x: SparseTensor) -> SparseTensor: 61 | spatial_changed = any(s != 1 for s in self.stride) 62 | if spatial_changed: 63 | # recover the original spconv order 64 | data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') 65 | bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') 66 | data = data.replace_feature(x.feats[bwd]) 67 | if DEBUG: 68 | assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' 69 | else: 70 | data = x.data 71 | 72 | new_data = self.conv(data) 73 | new_shape = [x.shape[0], self.conv.out_channels] 74 | new_layout = None if spatial_changed else x.layout 75 | out = SparseTensor( 76 | new_data, shape=torch.Size(new_shape), layout=new_layout, 77 | scale=tuple([s // stride for s, stride in zip(x._scale, 
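# Illustrative sketch, not part of the repository (assumes the spconv backend is
# installed and selected): how the wrapper above picks its spconv layer. stride=1
# with no padding builds a submanifold convolution that keeps the input coordinate
# set; any other stride/padding builds a regular sparse convolution whose output
# layout can change, which is why the strided path re-sorts by coordinates.
from trellis.modules import sparse as sp

subm = sp.SparseConv3d(32, 64, kernel_size=3)                        # SubMConv3d
down = sp.SparseConv3d(32, 64, kernel_size=3, stride=2, padding=1)   # SparseConv3d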
self.stride)]), 78 | spatial_cache=x._spatial_cache, 79 | ) 80 | return out 81 | -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/conv_torchsparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. import SparseTensor 4 | 5 | 6 | class SparseConv3d(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 8 | super(SparseConv3d, self).__init__() 9 | if 'torchsparse' not in globals(): 10 | import torchsparse 11 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) 12 | 13 | def forward(self, x: SparseTensor) -> SparseTensor: 14 | out = self.conv(x.data) 15 | new_shape = [x.shape[0], self.conv.out_channels] 16 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 17 | out._spatial_cache = x._spatial_cache 18 | out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) 19 | return out 20 | 21 | 22 | class SparseInverseConv3d(nn.Module): 23 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 24 | super(SparseInverseConv3d, self).__init__() 25 | if 'torchsparse' not in globals(): 26 | import torchsparse 27 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) 28 | 29 | def forward(self, x: SparseTensor) -> SparseTensor: 30 | out = self.conv(x.data) 31 | new_shape = [x.shape[0], self.conv.out_channels] 32 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 33 | out._spatial_cache = x._spatial_cache 34 | out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) 35 | return out 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /trellis/modules/sparse/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import SparseTensor 4 | 5 | __all__ = [ 6 | 'SparseLinear' 7 | ] 8 | 9 | 10 | class SparseLinear(nn.Linear): 11 | def __init__(self, in_features, out_features, bias=True): 12 | super(SparseLinear, self).__init__(in_features, out_features, bias) 13 | 14 | def forward(self, input: SparseTensor) -> SparseTensor: 15 | return input.replace(super().forward(input.feats)) 16 | -------------------------------------------------------------------------------- /trellis/modules/sparse/nonlinearity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import SparseTensor 4 | 5 | __all__ = [ 6 | 'SparseReLU', 7 | 'SparseSiLU', 8 | 'SparseGELU', 9 | 'SparseActivation' 10 | ] 11 | 12 | 13 | class SparseReLU(nn.ReLU): 14 | def forward(self, input: SparseTensor) -> SparseTensor: 15 | return input.replace(super().forward(input.feats)) 16 | 17 | 18 | class SparseSiLU(nn.SiLU): 19 | def forward(self, input: SparseTensor) -> SparseTensor: 20 | return input.replace(super().forward(input.feats)) 21 | 22 | 23 | class SparseGELU(nn.GELU): 24 | def forward(self, input: SparseTensor) -> SparseTensor: 25 | return input.replace(super().forward(input.feats)) 26 | 27 | 28 | class SparseActivation(nn.Module): 29 | def __init__(self, activation: nn.Module): 30 | super().__init__() 31 | self.activation = activation 32 | 33 | def forward(self, input: SparseTensor) -> SparseTensor: 34 | return input.replace(self.activation(input.feats)) 35 | 36 | -------------------------------------------------------------------------------- /trellis/modules/sparse/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import SparseTensor 4 | from . import DEBUG 5 | 6 | __all__ = [ 7 | 'SparseGroupNorm', 8 | 'SparseLayerNorm', 9 | 'SparseGroupNorm32', 10 | 'SparseLayerNorm32', 11 | ] 12 | 13 | 14 | class SparseGroupNorm(nn.GroupNorm): 15 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): 16 | super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) 17 | 18 | def forward(self, input: SparseTensor) -> SparseTensor: 19 | nfeats = torch.zeros_like(input.feats) 20 | for k in range(input.shape[0]): 21 | if DEBUG: 22 | assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" 23 | bfeats = input.feats[input.layout[k]] 24 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 25 | bfeats = super().forward(bfeats) 26 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 27 | nfeats[input.layout[k]] = bfeats 28 | return input.replace(nfeats) 29 | 30 | 31 | class SparseLayerNorm(nn.LayerNorm): 32 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 33 | super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) 34 | 35 | def forward(self, input: SparseTensor) -> SparseTensor: 36 | nfeats = torch.zeros_like(input.feats) 37 | for k in range(input.shape[0]): 38 | bfeats = input.feats[input.layout[k]] 39 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 40 | bfeats = super().forward(bfeats) 41 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 42 | nfeats[input.layout[k]] = bfeats 43 | return input.replace(nfeats) 44 | 45 | 46 | class SparseGroupNorm32(SparseGroupNorm): 47 | """ 48 | A GroupNorm layer that converts to float32 before the forward pass. 49 | """ 50 | def forward(self, x: SparseTensor) -> SparseTensor: 51 | return super().forward(x.float()).type(x.dtype) 52 | 53 | class SparseLayerNorm32(SparseLayerNorm): 54 | """ 55 | A LayerNorm layer that converts to float32 before the forward pass. 56 | """ 57 | def forward(self, x: SparseTensor) -> SparseTensor: 58 | return super().forward(x.float()).type(x.dtype) 59 | -------------------------------------------------------------------------------- /trellis/modules/sparse/spatial.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from . 
import SparseTensor 5 | 6 | __all__ = [ 7 | 'SparseDownsample', 8 | 'SparseUpsample', 9 | 'SparseSubdivide' 10 | ] 11 | 12 | 13 | class SparseDownsample(nn.Module): 14 | """ 15 | Downsample a sparse tensor by a factor of `factor`. 16 | Implemented as average pooling. 17 | """ 18 | def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): 19 | super(SparseDownsample, self).__init__() 20 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor 21 | 22 | def forward(self, input: SparseTensor) -> SparseTensor: 23 | DIM = input.coords.shape[-1] - 1 24 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM 25 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' 26 | 27 | coord = list(input.coords.unbind(dim=-1)) 28 | for i, f in enumerate(factor): 29 | coord[i+1] = coord[i+1] // f 30 | 31 | MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] 32 | OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] 33 | code = sum([c * o for c, o in zip(coord, OFFSET)]) 34 | code, idx = code.unique(return_inverse=True) 35 | 36 | new_feats = torch.scatter_reduce( 37 | torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), 38 | dim=0, 39 | index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), 40 | src=input.feats, 41 | reduce='mean' 42 | ) 43 | new_coords = torch.stack( 44 | [code // OFFSET[0]] + 45 | [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], 46 | dim=-1 47 | ) 48 | out = SparseTensor(new_feats, new_coords, input.shape,) 49 | out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) 50 | out._spatial_cache = input._spatial_cache 51 | 52 | out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) 53 | out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) 54 | out.register_spatial_cache(f'upsample_{factor}_idx', idx) 55 | 56 | return out 57 | 58 | 59 | class SparseUpsample(nn.Module): 60 | """ 61 | Upsample a sparse tensor by a factor of `factor`. 62 | Implemented as nearest neighbor interpolation. 63 | """ 64 | def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): 65 | super(SparseUpsample, self).__init__() 66 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor 67 | 68 | def forward(self, input: SparseTensor) -> SparseTensor: 69 | DIM = input.coords.shape[-1] - 1 70 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM 71 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' 72 | 73 | new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') 74 | new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') 75 | idx = input.get_spatial_cache(f'upsample_{factor}_idx') 76 | if any([x is None for x in [new_coords, new_layout, idx]]): 77 | raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') 78 | new_feats = input.feats[idx] 79 | out = SparseTensor(new_feats, new_coords, input.shape, new_layout) 80 | out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) 81 | out._spatial_cache = input._spatial_cache 82 | return out 83 | 84 | class SparseSubdivide(nn.Module): 85 | """ 86 | Upsample a sparse tensor by a factor of `factor`. 87 | Implemented as nearest neighbor interpolation. 
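# Illustrative sketch, not part of the repository (requires the selected sparse
# backend): SparseUpsample recovers the pre-pooling coordinates from the spatial
# cache registered by SparseDownsample, so the two modules must be used as a pair
# on the same tensor lineage; otherwise the "Upsample cache not found" error above
# is raised.
from trellis.modules import sparse as sp

down = sp.SparseDownsample(2)
up = sp.SparseUpsample(2)
# x: a sp.SparseTensor at full resolution (construction omitted here)
# y = down(x)   # average-pools features and remembers x's coords/layout
# z = up(y)     # scatters features back onto x's original coordinates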
88 | """ 89 | def __init__(self): 90 | super(SparseSubdivide, self).__init__() 91 | 92 | def forward(self, input: SparseTensor) -> SparseTensor: 93 | DIM = input.coords.shape[-1] - 1 94 | # upsample scale=2^DIM 95 | n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) 96 | n_coords = torch.nonzero(n_cube) 97 | n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) 98 | factor = n_coords.shape[0] 99 | assert factor == 2 ** DIM 100 | # print(n_coords.shape) 101 | new_coords = input.coords.clone() 102 | new_coords[:, 1:] *= 2 103 | new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) 104 | 105 | new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) 106 | out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) 107 | out._scale = input._scale * 2 108 | out._spatial_cache = input._spatial_cache 109 | return out 110 | 111 | -------------------------------------------------------------------------------- /trellis/modules/sparse/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | from .modulated import * -------------------------------------------------------------------------------- /trellis/modules/sparse/transformer/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from ..basic import SparseTensor 5 | from ..linear import SparseLinear 6 | from ..nonlinearity import SparseGELU 7 | from ..attention import SparseMultiHeadAttention, SerializeMode 8 | from ...norm import LayerNorm32 9 | 10 | 11 | class SparseFeedForwardNet(nn.Module): 12 | def __init__(self, channels: int, mlp_ratio: float = 4.0): 13 | super().__init__() 14 | self.mlp = nn.Sequential( 15 | SparseLinear(channels, int(channels * mlp_ratio)), 16 | SparseGELU(approximate="tanh"), 17 | SparseLinear(int(channels * mlp_ratio), channels), 18 | ) 19 | 20 | def forward(self, x: SparseTensor) -> SparseTensor: 21 | return self.mlp(x) 22 | 23 | 24 | class SparseTransformerBlock(nn.Module): 25 | """ 26 | Sparse Transformer block (MSA + FFN). 
27 | """ 28 | def __init__( 29 | self, 30 | channels: int, 31 | num_heads: int, 32 | mlp_ratio: float = 4.0, 33 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 34 | window_size: Optional[int] = None, 35 | shift_sequence: Optional[int] = None, 36 | shift_window: Optional[Tuple[int, int, int]] = None, 37 | serialize_mode: Optional[SerializeMode] = None, 38 | use_checkpoint: bool = False, 39 | use_rope: bool = False, 40 | qk_rms_norm: bool = False, 41 | qkv_bias: bool = True, 42 | ln_affine: bool = False, 43 | ): 44 | super().__init__() 45 | self.use_checkpoint = use_checkpoint 46 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 47 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 48 | self.attn = SparseMultiHeadAttention( 49 | channels, 50 | num_heads=num_heads, 51 | attn_mode=attn_mode, 52 | window_size=window_size, 53 | shift_sequence=shift_sequence, 54 | shift_window=shift_window, 55 | serialize_mode=serialize_mode, 56 | qkv_bias=qkv_bias, 57 | use_rope=use_rope, 58 | qk_rms_norm=qk_rms_norm, 59 | ) 60 | self.mlp = SparseFeedForwardNet( 61 | channels, 62 | mlp_ratio=mlp_ratio, 63 | ) 64 | 65 | def _forward(self, x: SparseTensor) -> SparseTensor: 66 | h = x.replace(self.norm1(x.feats)) 67 | h = self.attn(h) 68 | x = x + h 69 | h = x.replace(self.norm2(x.feats)) 70 | h = self.mlp(h) 71 | x = x + h 72 | return x 73 | 74 | def forward(self, x: SparseTensor) -> SparseTensor: 75 | if self.use_checkpoint: 76 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) 77 | else: 78 | return self._forward(x) 79 | 80 | 81 | class SparseTransformerCrossBlock(nn.Module): 82 | """ 83 | Sparse Transformer cross-attention block (MSA + MCA + FFN). 
84 | """ 85 | def __init__( 86 | self, 87 | channels: int, 88 | ctx_channels: int, 89 | num_heads: int, 90 | mlp_ratio: float = 4.0, 91 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 92 | window_size: Optional[int] = None, 93 | shift_sequence: Optional[int] = None, 94 | shift_window: Optional[Tuple[int, int, int]] = None, 95 | serialize_mode: Optional[SerializeMode] = None, 96 | use_checkpoint: bool = False, 97 | use_rope: bool = False, 98 | qk_rms_norm: bool = False, 99 | qk_rms_norm_cross: bool = False, 100 | qkv_bias: bool = True, 101 | ln_affine: bool = False, 102 | ): 103 | super().__init__() 104 | self.use_checkpoint = use_checkpoint 105 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 106 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 107 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 108 | self.self_attn = SparseMultiHeadAttention( 109 | channels, 110 | num_heads=num_heads, 111 | type="self", 112 | attn_mode=attn_mode, 113 | window_size=window_size, 114 | shift_sequence=shift_sequence, 115 | shift_window=shift_window, 116 | serialize_mode=serialize_mode, 117 | qkv_bias=qkv_bias, 118 | use_rope=use_rope, 119 | qk_rms_norm=qk_rms_norm, 120 | ) 121 | self.cross_attn = SparseMultiHeadAttention( 122 | channels, 123 | ctx_channels=ctx_channels, 124 | num_heads=num_heads, 125 | type="cross", 126 | attn_mode="full", 127 | qkv_bias=qkv_bias, 128 | qk_rms_norm=qk_rms_norm_cross, 129 | ) 130 | self.mlp = SparseFeedForwardNet( 131 | channels, 132 | mlp_ratio=mlp_ratio, 133 | ) 134 | 135 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): 136 | h = x.replace(self.norm1(x.feats)) 137 | h = self.self_attn(h) 138 | x = x + h 139 | h = x.replace(self.norm2(x.feats)) 140 | h = self.cross_attn(h, context) 141 | x = x + h 142 | h = x.replace(self.norm3(x.feats)) 143 | h = self.mlp(h) 144 | x = x + h 145 | return x 146 | 147 | def forward(self, x: SparseTensor, context: torch.Tensor): 148 | if self.use_checkpoint: 149 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) 150 | else: 151 | return self._forward(x, context) 152 | -------------------------------------------------------------------------------- /trellis/modules/spatial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: 5 | """ 6 | 3D pixel shuffle. 7 | """ 8 | B, C, H, W, D = x.shape 9 | C_ = C // scale_factor**3 10 | x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) 11 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) 12 | x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) 13 | return x 14 | 15 | 16 | def patchify(x: torch.Tensor, patch_size: int): 17 | """ 18 | Patchify a tensor. 
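# Illustrative sketch, not part of the repository: the shape contract of
# pixel_shuffle_3d above -- channels are folded into the three spatial axes.
import torch
from trellis.modules.spatial import pixel_shuffle_3d

x = torch.randn(1, 64, 4, 4, 4)           # C must be divisible by scale_factor**3
y = pixel_shuffle_3d(x, scale_factor=2)
assert y.shape == (1, 8, 8, 8, 8)         # (B, C // 2**3, H*2, W*2, D*2)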
19 | 20 | Args: 21 | x (torch.Tensor): (N, C, *spatial) tensor 22 | patch_size (int): Patch size 23 | """ 24 | DIM = x.dim() - 2 25 | for d in range(2, DIM + 2): 26 | assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" 27 | 28 | x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) 29 | x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) 30 | x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) 31 | return x 32 | 33 | 34 | def unpatchify(x: torch.Tensor, patch_size: int): 35 | """ 36 | Unpatchify a tensor. 37 | 38 | Args: 39 | x (torch.Tensor): (N, C, *spatial) tensor 40 | patch_size (int): Patch size 41 | """ 42 | DIM = x.dim() - 2 43 | assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" 44 | 45 | x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) 46 | x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) 47 | x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) 48 | return x 49 | -------------------------------------------------------------------------------- /trellis/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | from .modulated import * -------------------------------------------------------------------------------- /trellis/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..modules import sparse as sp 3 | 4 | FP16_MODULES = ( 5 | nn.Conv1d, 6 | nn.Conv2d, 7 | nn.Conv3d, 8 | nn.ConvTranspose1d, 9 | nn.ConvTranspose2d, 10 | nn.ConvTranspose3d, 11 | nn.Linear, 12 | sp.SparseConv3d, 13 | sp.SparseInverseConv3d, 14 | sp.SparseLinear, 15 | ) 16 | 17 | def convert_module_to_f16(l): 18 | """ 19 | Convert primitive modules to float16. 20 | """ 21 | if isinstance(l, FP16_MODULES): 22 | for p in l.parameters(): 23 | p.data = p.data.half() 24 | 25 | 26 | def convert_module_to_f32(l): 27 | """ 28 | Convert primitive modules to float32, undoing convert_module_to_f16(). 29 | """ 30 | if isinstance(l, FP16_MODULES): 31 | for p in l.parameters(): 32 | p.data = p.data.float() 33 | 34 | 35 | def zero_module(module): 36 | """ 37 | Zero out the parameters of a module and return it. 38 | """ 39 | for p in module.parameters(): 40 | p.detach().zero_() 41 | return module 42 | 43 | 44 | def scale_module(module, scale): 45 | """ 46 | Scale the parameters of a module and return it. 47 | """ 48 | for p in module.parameters(): 49 | p.detach().mul_(scale) 50 | return module 51 | 52 | 53 | def modulate(x, shift, scale): 54 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 55 | -------------------------------------------------------------------------------- /trellis/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from . import samplers 2 | from .trellis_image_to_3d import TrellisImageTo3DPipeline 3 | 4 | 5 | def from_pretrained(path: str): 6 | """ 7 | Load a pipeline from a model folder or a Hugging Face model hub. 8 | 9 | Args: 10 | path: The path to the model. Can be either local path or a Hugging Face model name. 
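# Illustrative sketch, not part of the repository: loading a pipeline through the
# helper whose body follows. "some-user/some-trellis-model" is a hypothetical
# Hugging Face repo id, not a released checkpoint name.
from trellis import pipelines

pipe = pipelines.from_pretrained("some-user/some-trellis-model")
pipe.cuda()   # Pipeline.to / Pipeline.cuda (defined below) move every sub-model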
11 | """ 12 | import os 13 | import json 14 | is_local = os.path.exists(f"{path}/pipeline.json") 15 | 16 | if is_local: 17 | config_file = f"{path}/pipeline.json" 18 | else: 19 | from huggingface_hub import hf_hub_download 20 | config_file = hf_hub_download(path, "pipeline.json") 21 | 22 | with open(config_file, 'r') as f: 23 | config = json.load(f) 24 | return globals()[config['name']].from_pretrained(path) 25 | -------------------------------------------------------------------------------- /trellis/pipelines/base.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from .. import models 5 | 6 | 7 | class Pipeline: 8 | """ 9 | A base class for pipelines. 10 | """ 11 | def __init__( 12 | self, 13 | models: dict[str, nn.Module] = None, 14 | ): 15 | if models is None: 16 | return 17 | self.models = models 18 | for model in self.models.values(): 19 | model.eval() 20 | 21 | @staticmethod 22 | def from_pretrained(path: str) -> "Pipeline": 23 | """ 24 | Load a pretrained model. 25 | """ 26 | import os 27 | import json 28 | is_local = os.path.exists(f"{path}/pipeline.json") 29 | 30 | if is_local: 31 | config_file = f"{path}/pipeline.json" 32 | else: 33 | from huggingface_hub import hf_hub_download 34 | config_file = hf_hub_download(path, "pipeline.json") 35 | 36 | with open(config_file, 'r') as f: 37 | args = json.load(f)['args'] 38 | 39 | _models = { 40 | k: models.from_pretrained(f"{path}/{v}") 41 | for k, v in args['models'].items() 42 | } 43 | 44 | new_pipeline = Pipeline(_models) 45 | new_pipeline._pretrained_args = args 46 | return new_pipeline 47 | 48 | @property 49 | def device(self) -> torch.device: 50 | for model in self.models.values(): 51 | if hasattr(model, 'device'): 52 | return model.device 53 | for model in self.models.values(): 54 | if hasattr(model, 'parameters'): 55 | return next(model.parameters()).device 56 | raise RuntimeError("No device found.") 57 | 58 | def to(self, device: torch.device) -> None: 59 | for model in self.models.values(): 60 | model.to(device) 61 | 62 | def cuda(self) -> None: 63 | self.to(torch.device("cuda")) 64 | 65 | def cpu(self) -> None: 66 | self.to(torch.device("cpu")) 67 | -------------------------------------------------------------------------------- /trellis/pipelines/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Sampler 2 | from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler -------------------------------------------------------------------------------- /trellis/pipelines/samplers/base.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class Sampler(ABC): 6 | """ 7 | A base class for samplers. 8 | """ 9 | 10 | @abstractmethod 11 | def sample( 12 | self, 13 | model, 14 | **kwargs 15 | ): 16 | """ 17 | Sample from a model. 18 | """ 19 | pass 20 | -------------------------------------------------------------------------------- /trellis/pipelines/samplers/classifier_free_guidance_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | 4 | class ClassifierFreeGuidanceSamplerMixin: 5 | """ 6 | A mixin class for samplers that apply classifier-free guidance. 
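# Illustrative sketch, not part of the repository: the classifier-free-guidance
# combination applied by this mixin. With cfg_strength = 0 the guided prediction
# reduces to the conditional one; larger values extrapolate away from the
# negative-condition prediction. Tensors below are stand-ins for model outputs.
import torch

cfg_strength = 3.0
pred = torch.randn(4, 8)
neg_pred = torch.randn(4, 8)
guided = (1 + cfg_strength) * pred - cfg_strength * neg_pred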
7 | """ 8 | 9 | def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): 10 | pred = super()._inference_model(model, x_t, t, cond, **kwargs) 11 | neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) 12 | return (1 + cfg_strength) * pred - cfg_strength * neg_pred 13 | -------------------------------------------------------------------------------- /trellis/pipelines/samplers/guidance_interval_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | 4 | class GuidanceIntervalSamplerMixin: 5 | """ 6 | A mixin class for samplers that apply classifier-free guidance with interval. 7 | """ 8 | 9 | def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): 10 | if cfg_interval[0] <= t <= cfg_interval[1]: 11 | pred = super()._inference_model(model, x_t, t, cond, **kwargs) 12 | neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) 13 | return (1 + cfg_strength) * pred - cfg_strength * neg_pred 14 | else: 15 | return super()._inference_model(model, x_t, t, cond, **kwargs) 16 | -------------------------------------------------------------------------------- /trellis/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | __attributes = { 4 | 'OctreeRenderer': 'octree_renderer', 5 | 'GSplatRenderer': 'gsplat_renderer', 6 | 'MeshRenderer': 'mesh_renderer', 7 | } 8 | 9 | __submodules = [] 10 | 11 | __all__ = list(__attributes.keys()) + __submodules 12 | 13 | def __getattr__(name): 14 | if name not in globals(): 15 | if name in __attributes: 16 | module_name = __attributes[name] 17 | module = importlib.import_module(f".{module_name}", __name__) 18 | globals()[name] = getattr(module, name) 19 | elif name in __submodules: 20 | module = importlib.import_module(f".{name}", __name__) 21 | globals()[name] = module 22 | else: 23 | raise AttributeError(f"module {__name__} has no attribute {name}") 24 | return globals()[name] 25 | 26 | 27 | # For Pylance 28 | if __name__ == '__main__': 29 | from .octree_renderer import OctreeRenderer 30 | from .gsplat_renderer import GSplatRenderer 31 | from .mesh_renderer import MeshRenderer -------------------------------------------------------------------------------- /trellis/renderers/gsplat_renderer.py: -------------------------------------------------------------------------------- 1 | import gsplat as gs 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from easydict import EasyDict as edict 6 | 7 | 8 | class GSplatRenderer: 9 | def __init__(self, rendering_options={}) -> None: 10 | self.pipe = edict({ 11 | "kernel_size": 0.1, 12 | "convert_SHs_python": False, 13 | "compute_cov3D_python": False, 14 | "scale_modifier": 1.0, 15 | "debug": False, 16 | "use_mip_gaussian": True 17 | }) 18 | self.rendering_options = edict({ 19 | "resolution": None, 20 | "near": None, 21 | "far": None, 22 | "ssaa": 1, 23 | "bg_color": 'random', 24 | }) 25 | self.rendering_options.update(rendering_options) 26 | self.bg_color = None 27 | 28 | def render( 29 | self, 30 | gaussian, 31 | extrinsics: torch.Tensor, 32 | intrinsics: torch.Tensor, 33 | colors_overwrite: torch.Tensor = None 34 | ) -> edict: 35 | 36 | resolution = self.rendering_options["resolution"] 37 | ssaa = self.rendering_options["ssaa"] 38 | 39 | if self.rendering_options["bg_color"] == 'random': 40 | self.bg_color = torch.zeros(3, dtype=torch.float32, 
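# Illustrative sketch, not part of the repository (requires gsplat): constructing
# the renderer above with explicit options. The values are arbitrary examples;
# "bg_color" may be 'random' (handled in render()) or an RGB triple.
from trellis.renderers import GSplatRenderer

renderer = GSplatRenderer(rendering_options={
    "resolution": 512,
    "ssaa": 2,               # rasterised at 1024, then resized back to 512
    "bg_color": (0, 0, 0),
})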
device="cuda") 41 | if np.random.rand() < 0.5: 42 | self.bg_color += 1 43 | else: 44 | self.bg_color = torch.tensor( 45 | self.rendering_options["bg_color"], 46 | dtype=torch.float32, 47 | device="cuda" 48 | ) 49 | 50 | height = resolution * ssaa 51 | width = resolution * ssaa 52 | 53 | # Set up background color 54 | if self.rendering_options["bg_color"] == 'random': 55 | self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") 56 | if np.random.rand() < 0.5: 57 | self.bg_color += 1 58 | else: 59 | self.bg_color = torch.tensor( 60 | self.rendering_options["bg_color"], 61 | dtype=torch.float32, 62 | device="cuda" 63 | ) 64 | 65 | Ks_scaled = intrinsics.clone() 66 | Ks_scaled[0, 0] *= width 67 | Ks_scaled[1, 1] *= height 68 | Ks_scaled[0, 2] *= width 69 | Ks_scaled[1, 2] *= height 70 | Ks_scaled = Ks_scaled.unsqueeze(0) 71 | 72 | near_plane = 0.01 73 | far_plane = 1000.0 74 | 75 | # Rasterize with gsplat 76 | render_colors, render_alphas, meta = gs.rasterization( 77 | means=gaussian.get_xyz, 78 | quats=F.normalize(gaussian.get_rotation, dim=-1), 79 | scales=gaussian.get_scaling / intrinsics[0, 0], 80 | opacities=gaussian.get_opacity.squeeze(-1), 81 | colors=colors_overwrite.unsqueeze(0) if colors_overwrite is not None else torch.sigmoid( 82 | gaussian.get_features.squeeze(1)).unsqueeze(0), 83 | viewmats=extrinsics.unsqueeze(0), 84 | Ks=Ks_scaled, 85 | width=width, 86 | height=height, 87 | near_plane=near_plane, 88 | far_plane=far_plane, 89 | radius_clip=3.0, 90 | eps2d=0.3, 91 | render_mode="RGB", 92 | backgrounds=self.bg_color.unsqueeze(0), 93 | camera_model="pinhole" 94 | ) 95 | 96 | rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1) 97 | 98 | # Apply supersampling if needed 99 | if ssaa > 1: 100 | rendered_image = F.interpolate( 101 | rendered_image[None], 102 | size=(resolution, resolution), 103 | mode='bilinear', 104 | align_corners=False, 105 | antialias=True 106 | ).squeeze() 107 | 108 | return edict({'color': rendered_image}) -------------------------------------------------------------------------------- /trellis/renderers/mesh_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from pytorch3d.renderer import ( 4 | FoVPerspectiveCameras, 5 | RasterizationSettings, 6 | MeshRasterizer, 7 | ) 8 | from pytorch3d.structures import Meshes 9 | from pytorch3d.renderer.blending import BlendParams 10 | from pytorch3d.renderer.mesh.shader import HardFlatShader 11 | from easydict import EasyDict as edict 12 | from ..representations.mesh import MeshExtractResult 13 | 14 | 15 | def intrinsics_to_pytorch3d_projection( 16 | intrinsics: torch.Tensor, 17 | image_size: tuple, 18 | near: float, 19 | far: float, 20 | ) -> torch.Tensor: 21 | """ 22 | Convert OpenCV intrinsics to PyTorch3D's camera convention. 23 | PyTorch3D uses screen coordinates in [-1, 1] and a different projection matrix. 
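# Illustrative sketch, not part of the repository: the intrinsics convention implied
# by the scaling in GSplatRenderer.render() above -- fx, fy, cx, cy are expected in
# units normalised by the image size and are converted to pixels before
# rasterisation. The numbers are arbitrary.
import torch

intrinsics = torch.tensor([[1.0, 0.0, 0.5],
                           [0.0, 1.0, 0.5],
                           [0.0, 0.0, 1.0]])   # normalised pinhole intrinsics
width = height = 512
K = intrinsics.clone()
K[0, 0] *= width; K[1, 1] *= height            # focal lengths in pixels
K[0, 2] *= width; K[1, 2] *= height            # principal point in pixels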
24 | """ 25 | fx, fy = intrinsics[0, 0], intrinsics[1, 1] 26 | cx, cy = intrinsics[0, 2], intrinsics[1, 2] 27 | H, W = image_size 28 | half_pix_center = 0.5 # PyTorch3D assumes pixel centers at 0.5 29 | 30 | # Adjust for PyTorch3D's NDC space (normalized device coordinates) 31 | proj = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) 32 | proj[0, 0] = 2 * fx / W 33 | proj[1, 1] = 2 * fy / H 34 | proj[0, 2] = 2 * (cx - half_pix_center) / W - 1 35 | proj[1, 2] = 2 * (cy - half_pix_center) / H - 1 36 | proj[2, 2] = -(far + near) / (far - near) 37 | proj[2, 3] = -2 * far * near / (far - near) 38 | proj[3, 2] = -1.0 39 | return proj 40 | 41 | class MeshRenderer: 42 | def __init__(self, rendering_options={}, device="cuda"): 43 | self.rendering_options = edict({ 44 | "resolution": 512, 45 | "near": 1.0, 46 | "far": 100.0, 47 | "ssaa": 1, 48 | "bg_color": (0, 0, 0) 49 | }) 50 | self.rendering_options.update(rendering_options) 51 | self.device = device 52 | 53 | def render( 54 | self, 55 | mesh: MeshExtractResult, 56 | extrinsics: torch.Tensor, 57 | intrinsics: torch.Tensor, 58 | return_types=["normal"], 59 | ) -> edict: 60 | resolution = self.rendering_options["resolution"] 61 | ssaa = self.rendering_options["ssaa"] 62 | 63 | if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: 64 | return {"normal": torch.zeros((3, resolution, resolution), device=self.device)} 65 | 66 | # Prepare mesh data 67 | vertices = mesh.vertices.unsqueeze(0).to(self.device) # (1, V, 3) 68 | faces = mesh.faces.unsqueeze(0).to(self.device) # (1, F, 3) 69 | meshes = Meshes(verts=vertices, faces=faces) 70 | 71 | # Camera setup 72 | R = extrinsics[:3, :3].unsqueeze(0).to(self.device) 73 | T = extrinsics[:3, 3].unsqueeze(0).to(self.device) 74 | cameras = FoVPerspectiveCameras( 75 | device=self.device, 76 | R=R, 77 | T=T, 78 | znear=self.rendering_options["near"], 79 | zfar=self.rendering_options["far"], 80 | ) 81 | 82 | # Rasterization settings 83 | raster_settings = RasterizationSettings( 84 | image_size=resolution * ssaa, 85 | blur_radius=0.0, 86 | faces_per_pixel=1, 87 | ) 88 | 89 | # Render mesh 90 | rasterizer = MeshRasterizer( 91 | cameras=cameras, 92 | raster_settings=raster_settings, 93 | ) 94 | fragments = rasterizer(meshes) 95 | 96 | # Process normals 97 | verts_normals = meshes.verts_normals_packed() # (V, 3) 98 | faces_normals = verts_normals[mesh.faces].mean(dim=1) # (F, 3) 99 | visible_faces = fragments.pix_to_face.clamp(min=0) # (1, H, W, 1) 100 | 101 | # Get normals for visible pixels 102 | normals = faces_normals[visible_faces.view(-1)] # (N, 3) 103 | normals = normals.view(1, resolution*ssaa, resolution*ssaa, 3) # (1, H, W, 3) 104 | 105 | # Normalize to [0,1] and reshape to (1, 3, H, W) 106 | normals = ((normals + 1) / 2).permute(0, 3, 1, 2) 107 | 108 | # Downsample if needed 109 | if ssaa > 1: 110 | normals = F.interpolate( 111 | normals, 112 | size=(resolution, resolution), 113 | mode="bilinear", 114 | align_corners=False, 115 | ) 116 | 117 | return {"normal": normals.squeeze(0)} # (3, H, W) -------------------------------------------------------------------------------- /trellis/renderers/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. 
Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /trellis/representations/__init__.py: -------------------------------------------------------------------------------- 1 | from .radiance_field import Strivec 2 | from .octree import DfsOctree as Octree 3 | from .gaussian import Gaussian 4 | from .mesh import MeshExtractResult 5 | -------------------------------------------------------------------------------- /trellis/representations/gaussian/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_model import Gaussian -------------------------------------------------------------------------------- /trellis/representations/gaussian/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " 
[{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /trellis/representations/mesh/__init__.py: -------------------------------------------------------------------------------- 1 | from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult 2 | -------------------------------------------------------------------------------- /trellis/representations/mesh/cube2mesh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ...modules.sparse import SparseTensor 3 | from easydict import EasyDict as edict 4 | from .utils_cube import * 5 | from .flexicubes.flexicubes import FlexiCubes 6 | 7 | 8 | class MeshExtractResult: 9 | def __init__(self, 10 | vertices, 11 | faces, 12 | vertex_attrs=None, 13 | res=64 14 | ): 15 | self.vertices = vertices 16 | self.faces = faces.long() 17 | self.vertex_attrs = vertex_attrs 18 | self.face_normal = self.comput_face_normals(vertices, faces) 19 | self.res = res 20 | self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) 21 | 22 | # training only 23 | self.tsdf_v = None 24 | self.tsdf_s = None 25 | self.reg_loss = None 26 | 27 | def comput_face_normals(self, verts, faces): 28 | i0 = faces[..., 0].long() 29 | i1 = faces[..., 1].long() 30 | i2 = faces[..., 2].long() 31 | 32 | v0 = verts[i0, :] 33 | v1 = verts[i1, :] 34 | v2 = verts[i2, :] 35 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) 36 | face_normals = torch.nn.functional.normalize(face_normals, dim=1) 37 | # print(face_normals.min(), face_normals.max(), face_normals.shape) 38 | return face_normals[:, None, :].repeat(1, 3, 1) 39 | 40 | def comput_v_normals(self, verts, faces): 41 | i0 = faces[..., 0].long() 42 | i1 = faces[..., 1].long() 43 | i2 = faces[..., 2].long() 44 | 45 | v0 = verts[i0, :] 46 | v1 = verts[i1, :] 47 | v2 = verts[i2, :] 48 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) 49 | v_normals = torch.zeros_like(verts) 50 | v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) 51 | v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) 52 | v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) 53 | 54 | v_normals = torch.nn.functional.normalize(v_normals, dim=1) 55 | return v_normals 56 | 57 | 58 | class SparseFeatures2Mesh: 59 | def __init__(self, device="cuda", res=64, use_color=True): 60 | ''' 61 | a model to generate a mesh from sparse features structures using flexicube 62 | ''' 63 | super().__init__() 64 | self.device=device 65 | self.res = res 66 | self.mesh_extractor = FlexiCubes(device=device) 67 | self.sdf_bias = -1.0 / res 68 | verts, cube = construct_dense_grid(self.res, self.device) 69 | self.reg_c = cube.to(self.device) 70 | self.reg_v = verts.to(self.device) 71 | self.use_color = use_color 72 | self._calc_layout() 73 | 74 | def _calc_layout(self): 75 | LAYOUTS = { 76 | 'sdf': {'shape': (8, 1), 'size': 8}, 77 | 'deform': {'shape': (8, 3), 'size': 8 * 3}, 78 | 'weights': {'shape': (21,), 'size': 21} 79 | } 80 | if self.use_color: 81 | ''' 82 | 6 channel color including normal map 83 | ''' 84 | LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} 85 | self.layouts = edict(LAYOUTS) 86 | start = 0 87 | for k, v in 
self.layouts.items(): 88 | v['range'] = (start, start + v['size']) 89 | start += v['size'] 90 | self.feats_channels = start 91 | 92 | def get_layout(self, feats : torch.Tensor, name : str): 93 | if name not in self.layouts: 94 | return None 95 | return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) 96 | 97 | def __call__(self, cubefeats : SparseTensor, training=False): 98 | """ 99 | Generates a mesh based on the specified sparse voxel structures. 100 | Args: 101 | cubefeats (SparseTensor): sparse cube features of width self.feats_channels (sdf, deform, weights and, if use_color, color; see _calc_layout) 102 | training (bool): whether to also compute regularization terms 103 | Returns: 104 | the extracted MeshExtractResult; when training, reg_loss, tsdf_v and tsdf_s are attached to it 105 | """ 106 | # add sdf bias to verts_attrs 107 | coords = cubefeats.coords[:, 1:] 108 | feats = cubefeats.feats 109 | 110 | sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] 111 | sdf += self.sdf_bias 112 | v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] 113 | v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) 114 | v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) 115 | weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) 116 | if self.use_color: 117 | sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] 118 | else: 119 | sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] 120 | colors_d = None 121 | 122 | x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) 123 | 124 | vertices, faces, L_dev, colors = self.mesh_extractor( 125 | voxelgrid_vertices=x_nx3, 126 | scalar_field=sdf_d, 127 | cube_idx=self.reg_c, 128 | resolution=self.res, 129 | beta=weights_d[:, :12], 130 | alpha=weights_d[:, 12:20], 131 | gamma_f=weights_d[:, 20], 132 | voxelgrid_colors=colors_d, 133 | training=training) 134 | 135 | mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) 136 | if training: 137 | if mesh.success: 138 | reg_loss += L_dev.mean() * 0.5 139 | reg_loss += (weights[:,:20]).abs().mean() * 0.2 140 | mesh.reg_loss = reg_loss 141 | mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) 142 | mesh.tsdf_s = v_attrs[:, 0] 143 | return mesh 144 | -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/DCO.txt: -------------------------------------------------------------------------------- 1 | Developer Certificate of Origin 2 | Version 1.1 3 | 4 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 5 | 6 | Everyone is permitted to copy and distribute verbatim copies of this 7 | license document, but changing it is not allowed.
8 | 9 | 10 | Developer's Certificate of Origin 1.1 11 | 12 | By making a contribution to this project, I certify that: 13 | 14 | (a) The contribution was created in whole or in part by me and I 15 | have the right to submit it under the open source license 16 | indicated in the file; or 17 | 18 | (b) The contribution is based upon previous work that, to the best 19 | of my knowledge, is covered under an appropriate open source 20 | license and I have the right under that license to submit that 21 | work with modifications, whether created in whole or in part 22 | by me, under the same open source license (unless I am 23 | permitted to submit under a different license), as indicated 24 | in the file; or 25 | 26 | (c) The contribution was provided directly to me by some other 27 | person who certified (a), (b) or (c) and I have not modified 28 | it. 29 | 30 | (d) I understand and agree that this project and the contribution 31 | are public and that a record of the contribution (including all 32 | personal information I submit with it, including my sign-off) is 33 | maintained indefinitely and may be redistributed consistent with 34 | this project or the open source license(s) involved. -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/examples/download_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import requests 16 | from zipfile import ZipFile 17 | from tqdm import tqdm 18 | import os 19 | 20 | def download_file(url, output_path): 21 | response = requests.get(url, stream=True) 22 | response.raise_for_status() 23 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 24 | block_size = 1024 #1 Kibibyte 25 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 26 | 27 | with open(output_path, 'wb') as file: 28 | for data in response.iter_content(block_size): 29 | progress_bar.update(len(data)) 30 | file.write(data) 31 | progress_bar.close() 32 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 33 | raise Exception("ERROR, something went wrong") 34 | 35 | 36 | url = "https://vcg.isti.cnr.it/Publications/2014/MPZ14/inputmodels.zip" 37 | zip_file_path = './data/inputmodels.zip' 38 | 39 | os.makedirs('./data', exist_ok=True) 40 | 41 | download_file(url, zip_file_path) 42 | 43 | with ZipFile(zip_file_path, 'r') as zip_ref: 44 | zip_ref.extractall('./data') 45 | 46 | os.remove(zip_file_path) 47 | 48 | print("Download and extraction complete.") 49 | -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/examples/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 2 | # All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import torch 16 | import torch_scatter 17 | 18 | ############################################################################### 19 | # Pytorch implementation of the developability regularizer introduced in paper 20 | # "Developability of Triangle Meshes" by Stein et al. 21 | ############################################################################### 22 | def mesh_developable_reg(mesh): 23 | 24 | verts = mesh.vertices 25 | tris = mesh.faces 26 | 27 | device = verts.device 28 | V = verts.shape[0] 29 | F = tris.shape[0] 30 | 31 | POS_EPS = 1e-6 32 | REL_EPS = 1e-6 33 | 34 | def normalize(vecs): 35 | return vecs / (torch.linalg.norm(vecs, dim=-1, keepdim=True) + POS_EPS) 36 | 37 | tri_pos = verts[tris] 38 | 39 | vert_normal_covariance_sum = torch.zeros((V, 9), device=device) 40 | vert_area = torch.zeros(V, device=device) 41 | vert_degree = torch.zeros(V, dtype=torch.int32, device=device) 42 | 43 | for iC in range(3): # loop over three corners of each triangle 44 | 45 | # gather tri verts 46 | pRoot = tri_pos[:, iC, :] 47 | pA = tri_pos[:, (iC + 1) % 3, :] 48 | pB = tri_pos[:, (iC + 2) % 3, :] 49 | 50 | # compute the corner angle & normal 51 | vA = pA - pRoot 52 | vAn = normalize(vA) 53 | vB = pB - pRoot 54 | vBn = normalize(vB) 55 | area_normal = torch.linalg.cross(vA, vB, dim=-1) 56 | face_area = 0.5 * torch.linalg.norm(area_normal, dim=-1) 57 | normal = normalize(area_normal) 58 | corner_angle = torch.acos(torch.clamp(torch.sum(vAn * vBn, dim=-1), min=-1., max=1.)) 59 | 60 | # add up the contribution to the covariance matrix 61 | outer = normal[:, :, None] @ normal[:, None, :] 62 | contrib = corner_angle[:, None] * outer.reshape(-1, 9) 63 | 64 | # scatter the result to the appropriate matrices 65 | vert_normal_covariance_sum = torch_scatter.scatter_add(src=contrib, 66 | index=tris[:, iC], 67 | dim=-2, 68 | out=vert_normal_covariance_sum) 69 | 70 | vert_area = torch_scatter.scatter_add(src=face_area / 3., 71 | index=tris[:, iC], 72 | dim=-1, 73 | out=vert_area) 74 | 75 | vert_degree = torch_scatter.scatter_add(src=torch.ones(F, dtype=torch.int32, device=device), 76 | index=tris[:, iC], 77 | dim=-1, 78 | out=vert_degree) 79 | 80 | # The energy is the smallest eigenvalue of the outer-product matrix 81 | vert_normal_covariance_sum = vert_normal_covariance_sum.reshape( 82 | -1, 3, 3) # reshape to a batch of matrices 83 | vert_normal_covariance_sum = vert_normal_covariance_sum + torch.eye( 84 | 3, device=device)[None, :, :] * REL_EPS 85 | 86 | min_eigvals = torch.min(torch.linalg.eigvals(vert_normal_covariance_sum).abs(), dim=-1).values 87 | 88 | # Mask out degree-3 vertices 89 | vert_area = torch.where(vert_degree == 3, torch.tensor(0, dtype=vert_area.dtype,device=vert_area.device), vert_area) 90 | 91 | # Adjust the vertex area weighting so it is unit-less, and 1 on average 92 | vert_area = vert_area * (V / torch.sum(vert_area, dim=-1, keepdim=True)) 93 | 94 | return vert_area * min_eigvals 95 | 96 | 
def sdf_reg_loss(sdf, all_edges): 97 | sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2) 98 | mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1]) 99 | sdf_f1x6x2 = sdf_f1x6x2[mask] 100 | sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \ 101 | torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float()) 102 | return sdf_diff -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/examples/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import numpy as np 16 | import torch 17 | import trimesh 18 | import kaolin 19 | import nvdiffrast.torch as dr 20 | 21 | ############################################################################### 22 | # Functions adapted from https://github.com/NVlabs/nvdiffrec 23 | ############################################################################### 24 | 25 | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 26 | return torch.sum(x*y, -1, keepdim=True) 27 | 28 | def length(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor: 29 | return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN 30 | 31 | def safe_normalize(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor: 32 | return x / length(x, eps) 33 | 34 | def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): 35 | y = np.tan(fovy / 2) 36 | return torch.tensor([[1/(y*aspect), 0, 0, 0], 37 | [ 0, 1/-y, 0, 0], 38 | [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], 39 | [ 0, 0, -1, 0]], dtype=torch.float32, device=device) 40 | 41 | def translate(x, y, z, device=None): 42 | return torch.tensor([[1, 0, 0, x], 43 | [0, 1, 0, y], 44 | [0, 0, 1, z], 45 | [0, 0, 0, 1]], dtype=torch.float32, device=device) 46 | 47 | @torch.no_grad() 48 | def random_rotation_translation(t, device=None): 49 | m = np.random.normal(size=[3, 3]) 50 | m[1] = np.cross(m[0], m[2]) 51 | m[2] = np.cross(m[0], m[1]) 52 | m = m / np.linalg.norm(m, axis=1, keepdims=True) 53 | m = np.pad(m, [[0, 1], [0, 1]], mode='constant') 54 | m[3, 3] = 1.0 55 | m[:3, 3] = np.random.uniform(-t, t, size=[3]) 56 | return torch.tensor(m, dtype=torch.float32, device=device) 57 | 58 | def rotate_x(a, device=None): 59 | s, c = np.sin(a), np.cos(a) 60 | return torch.tensor([[1, 0, 0, 0], 61 | [0, c, s, 0], 62 | [0, -s, c, 0], 63 | [0, 0, 0, 1]], dtype=torch.float32, device=device) 64 | 65 | def rotate_y(a, device=None): 66 | s, c = np.sin(a), np.cos(a) 67 | return torch.tensor([[ c, 0, s, 0], 68 | [ 0, 1, 0, 0], 69 | [-s, 0, c, 0], 70 | [ 0, 0, 0, 1]], dtype=torch.float32, device=device) 71 | 72 | class Mesh: 73 | def __init__(self, vertices, faces): 74 | self.vertices = vertices 75 | self.faces 
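    Example (illustrative only): with an identity transform the points come back
    unchanged apart from the appended homogeneous coordinate:
        pts = torch.rand(1, 8, 3)
        out = xfm_points(pts, torch.eye(4).unsqueeze(0))  # shape (1, 8, 4), last column all ones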
= faces 76 | 77 | def auto_normals(self): 78 | v0 = self.vertices[self.faces[:, 0], :] 79 | v1 = self.vertices[self.faces[:, 1], :] 80 | v2 = self.vertices[self.faces[:, 2], :] 81 | nrm = safe_normalize(torch.cross(v1 - v0, v2 - v0)) 82 | self.nrm = nrm 83 | 84 | def load_mesh(path, device): 85 | mesh_np = trimesh.load(path) 86 | vertices = torch.tensor(mesh_np.vertices, device=device, dtype=torch.float) 87 | faces = torch.tensor(mesh_np.faces, device=device, dtype=torch.long) 88 | 89 | # Normalize 90 | vmin, vmax = vertices.min(dim=0)[0], vertices.max(dim=0)[0] 91 | scale = 1.8 / torch.max(vmax - vmin).item() 92 | vertices = vertices - (vmax + vmin) / 2 # Center mesh on origin 93 | vertices = vertices * scale # Rescale to [-0.9, 0.9] 94 | return Mesh(vertices, faces) 95 | 96 | def compute_sdf(points, vertices, faces): 97 | face_vertices = kaolin.ops.mesh.index_vertices_by_faces(vertices.clone().unsqueeze(0), faces) 98 | distance = kaolin.metrics.trianglemesh.point_to_mesh_distance(points.unsqueeze(0), face_vertices)[0] 99 | with torch.no_grad(): 100 | sign = (kaolin.ops.mesh.check_sign(vertices.unsqueeze(0), faces, points.unsqueeze(0))<1).float() * 2 - 1 101 | sdf = (sign*distance).squeeze(0) 102 | return sdf 103 | 104 | def sample_random_points(n, mesh): 105 | pts_random = (torch.rand((n//2,3),device='cuda') - 0.5) * 2 106 | pts_surface = kaolin.ops.mesh.sample_points(mesh.vertices.unsqueeze(0), mesh.faces, 500)[0].squeeze(0) 107 | pts_surface += torch.randn_like(pts_surface) * 0.05 108 | pts = torch.cat([pts_random, pts_surface]) 109 | return pts 110 | 111 | def xfm_points(points, matrix): 112 | '''Transform points. 113 | Args: 114 | points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] 115 | matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] 116 | use_python: Use PyTorch's torch.matmul (for validation) 117 | Returns: 118 | Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. 
119 | ''' 120 | out = torch.matmul( 121 | torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) 122 | if torch.is_anomaly_enabled(): 123 | assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" 124 | return out 125 | 126 | def interpolate(attr, rast, attr_idx, rast_db=None): 127 | return dr.interpolate( 128 | attr, rast, attr_idx, rast_db=rast_db, 129 | diff_attrs=None if rast_db is None else 'all') -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/images/block_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/trellis/representations/mesh/flexicubes/images/block_final.png -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/images/block_init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/trellis/representations/mesh/flexicubes/images/block_init.png -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/images/teaser_top.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/trellis/representations/mesh/flexicubes/images/teaser_top.png -------------------------------------------------------------------------------- /trellis/representations/mesh/utils_cube.py: -------------------------------------------------------------------------------- 1 | import torch 2 | cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ 3 | 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) 4 | cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) 5 | cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 6 | 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) 7 | 8 | def construct_dense_grid(res, device='cuda'): 9 | '''construct a dense grid based on resolution''' 10 | res_v = res + 1 11 | vertsid = torch.arange(res_v ** 3, device=device) 12 | coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() 13 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] 14 | cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) 15 | verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) 16 | return verts, cube_fx8 17 | 18 | 19 | def construct_voxel_grid(coords): 20 | verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) 21 | verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) 22 | cubes = inverse_indices.reshape(-1, 8) 23 | return verts_unique, 
cubes 24 | 25 | 26 | def cubes_to_verts(num_verts, cubes, value, reduce='mean'): 27 | """ 28 | Args: 29 | cubes [Vx8] verts index for each cube 30 | value [Vx8xM] value to be scattered 31 | Operation: 32 | reduced[cubes[i][j]][k] += value[i][k] 33 | """ 34 | M = value.shape[2] # number of channels 35 | reduced = torch.zeros(num_verts, M, device=cubes.device) 36 | return torch.scatter_reduce(reduced, 0, 37 | cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), 38 | value.flatten(0, 1), reduce=reduce, include_self=False) 39 | 40 | def sparse_cube2verts(coords, feats, training=True): 41 | new_coords, cubes = construct_voxel_grid(coords) 42 | new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) 43 | if training: 44 | con_loss = torch.mean((feats - new_feats[cubes]) ** 2) 45 | else: 46 | con_loss = 0.0 47 | return new_coords, new_feats, con_loss 48 | 49 | 50 | def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): 51 | F = feats.shape[-1] 52 | dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) 53 | if sdf_init: 54 | dense_attrs[..., 0] = 1 # initial outside sdf value 55 | dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats 56 | return dense_attrs.reshape(-1, F) 57 | 58 | 59 | def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): 60 | return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) 61 | -------------------------------------------------------------------------------- /trellis/representations/octree/__init__.py: -------------------------------------------------------------------------------- 1 | from .octree_dfs import DfsOctree -------------------------------------------------------------------------------- /trellis/representations/radiance_field/__init__.py: -------------------------------------------------------------------------------- 1 | from .strivec import Strivec -------------------------------------------------------------------------------- /trellis/representations/radiance_field/strivec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from ..octree import DfsOctree as Octree 6 | 7 | 8 | class Strivec(Octree): 9 | def __init__( 10 | self, 11 | resolution: int, 12 | aabb: list, 13 | sh_degree: int = 0, 14 | rank: int = 8, 15 | dim: int = 8, 16 | device: str = "cuda", 17 | ): 18 | assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2" 19 | self.resolution = resolution 20 | depth = int(np.round(np.log2(resolution))) 21 | super().__init__( 22 | depth=depth, 23 | aabb=aabb, 24 | sh_degree=sh_degree, 25 | primitive="trivec", 26 | primitive_config={"rank": rank, "dim": dim}, 27 | device=device, 28 | ) 29 | -------------------------------------------------------------------------------- /trellis/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kg-git-dev/trellis-refactored/f7aa4fd508853a61e3567e61f58622ce725df3e4/trellis/utils/__init__.py -------------------------------------------------------------------------------- /trellis/utils/random_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] 4 | 5 | def radical_inverse(base, n): 6 | val = 0 7 | inv_base = 1.0 / base 8 | inv_base_n = inv_base 9 | while n > 0: 
10 | digit = n % base 11 | val += digit * inv_base_n 12 | n //= base 13 | inv_base_n *= inv_base 14 | return val 15 | 16 | def halton_sequence(dim, n): 17 | return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] 18 | 19 | def hammersley_sequence(dim, n, num_samples): 20 | return [n / num_samples] + halton_sequence(dim - 1, n) 21 | 22 | def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): 23 | u, v = hammersley_sequence(2, n, num_samples) 24 | u += offset[0] / num_samples 25 | v += offset[1] 26 | if remap: 27 | u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 28 | theta = np.arccos(1 - 2 * u) - np.pi / 2 29 | phi = v * 2 * np.pi 30 | return [phi, theta] -------------------------------------------------------------------------------- /trellis/utils/render_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | import utils3d 5 | from PIL import Image 6 | 7 | from ..renderers import OctreeRenderer, MeshRenderer, GSplatRenderer 8 | from ..representations import Octree, Gaussian, MeshExtractResult 9 | from ..modules import sparse as sp 10 | from .random_utils import sphere_hammersley_sequence 11 | 12 | 13 | def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): 14 | is_list = isinstance(yaws, list) 15 | if not is_list: 16 | yaws = [yaws] 17 | pitchs = [pitchs] 18 | if not isinstance(rs, list): 19 | rs = [rs] * len(yaws) 20 | if not isinstance(fovs, list): 21 | fovs = [fovs] * len(yaws) 22 | extrinsics = [] 23 | intrinsics = [] 24 | for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): 25 | fov = torch.deg2rad(torch.tensor(float(fov))).cuda() 26 | yaw = torch.tensor(float(yaw)).cuda() 27 | pitch = torch.tensor(float(pitch)).cuda() 28 | orig = torch.tensor([ 29 | torch.sin(yaw) * torch.cos(pitch), 30 | torch.cos(yaw) * torch.cos(pitch), 31 | torch.sin(pitch), 32 | ]).cuda() * r 33 | extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) 34 | intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) 35 | extrinsics.append(extr) 36 | intrinsics.append(intr) 37 | if not is_list: 38 | extrinsics = extrinsics[0] 39 | intrinsics = intrinsics[0] 40 | return extrinsics, intrinsics 41 | 42 | 43 | def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs): 44 | if isinstance(sample, Octree): 45 | renderer = OctreeRenderer() 46 | renderer.rendering_options.resolution = options.get('resolution', 512) 47 | renderer.rendering_options.near = options.get('near', 0.8) 48 | renderer.rendering_options.far = options.get('far', 1.6) 49 | renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) 50 | renderer.rendering_options.ssaa = options.get('ssaa', 4) 51 | renderer.pipe.primitive = sample.primitive 52 | elif isinstance(sample, Gaussian): 53 | renderer = GSplatRenderer() 54 | renderer.rendering_options.resolution = options.get('resolution', 512) 55 | renderer.rendering_options.near = options.get('near', 0.8) 56 | renderer.rendering_options.far = options.get('far', 1.6) 57 | renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) 58 | renderer.rendering_options.ssaa = options.get('ssaa', 1) 59 | renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1) 60 | renderer.pipe.use_mip_gaussian = True 61 | elif isinstance(sample, MeshExtractResult): 62 | renderer = MeshRenderer() 63 | renderer.rendering_options.resolution = 
options.get('resolution', 512) 64 | renderer.rendering_options.near = options.get('near', 1) 65 | renderer.rendering_options.far = options.get('far', 100) 66 | renderer.rendering_options.ssaa = options.get('ssaa', 4) 67 | else: 68 | raise ValueError(f'Unsupported sample type: {type(sample)}') 69 | 70 | rets = {} 71 | for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose): 72 | if not isinstance(sample, MeshExtractResult): 73 | res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite) 74 | if 'color' not in rets: rets['color'] = [] 75 | if 'depth' not in rets: rets['depth'] = [] 76 | rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) 77 | if 'percent_depth' in res: 78 | rets['depth'].append(res['percent_depth'].detach().cpu().numpy()) 79 | elif 'depth' in res: 80 | rets['depth'].append(res['depth'].detach().cpu().numpy()) 81 | else: 82 | rets['depth'].append(None) 83 | else: 84 | res = renderer.render(sample, extr, intr) 85 | if 'normal' not in rets: rets['normal'] = [] 86 | rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) 87 | return rets 88 | 89 | 90 | def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs): 91 | yaws = torch.linspace(0, 2 * 3.1415, num_frames) 92 | pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) 93 | yaws = yaws.tolist() 94 | pitch = pitch.tolist() 95 | extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) 96 | return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) 97 | 98 | 99 | def render_multiview(sample, resolution=512, nviews=30): 100 | r = 2 101 | fov = 40 102 | cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] 103 | yaws = [cam[0] for cam in cams] 104 | pitchs = [cam[1] for cam in cams] 105 | extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) 106 | res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) 107 | return res['color'], extrinsics, intrinsics 108 | 109 | 110 | def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs): 111 | yaw = [0, np.pi/2, np.pi, 3*np.pi/2] 112 | yaw_offset = offset[0] 113 | yaw = [y + yaw_offset for y in yaw] 114 | pitch = [offset[1] for _ in range(4)] 115 | extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) 116 | return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) --------------------------------------------------------------------------------
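A minimal usage sketch for the helpers in trellis/utils/render_utils.py above. The `outputs` dict and the output filename are assumptions for illustration (e.g. the result of an image-to-3D pipeline run); only `render_video` and `render_multiview` and their signatures come from this file, and writing the video assumes imageio with an ffmpeg backend is available.

    import imageio
    from trellis.utils import render_utils

    # 'outputs' is assumed to hold a Gaussian representation under outputs['gaussian'][0].
    # render_video orbits the camera around the object and returns per-frame arrays,
    # with uint8 RGB frames under the 'color' key.
    frames = render_utils.render_video(outputs['gaussian'][0], resolution=512, num_frames=120)['color']
    imageio.mimsave('sample_gs.mp4', frames, fps=30)

    # Fixed multi-view renders plus the camera parameters used to produce them;
    # the viewpoints are sampled internally via sphere_hammersley_sequence.
    images, extrinsics, intrinsics = render_utils.render_multiview(outputs['gaussian'][0], resolution=512, nviews=30)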