├── trellis
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── data_utils.cpython-310.pyc
│   │   │   ├── dist_utils.cpython-310.pyc
│   │   │   ├── loss_utils.cpython-310.pyc
│   │   │   ├── elastic_utils.cpython-310.pyc
│   │   │   ├── general_utils.cpython-310.pyc
│   │   │   ├── random_utils.cpython-310.pyc
│   │   │   ├── render_utils.cpython-310.pyc
│   │   │   ├── grad_clip_utils.cpython-310.pyc
│   │   │   └── postprocessing_utils.cpython-310.pyc
│   │   ├── random_utils.py
│   │   ├── loss_utils.py
│   │   ├── dist_utils.py
│   │   └── grad_clip_utils.py
│   ├── representations
│   │   ├── octree
│   │   │   ├── __init__.py
│   │   │   └── __pycache__
│   │   │       ├── __init__.cpython-310.pyc
│   │   │       └── octree_dfs.cpython-310.pyc
│   │   ├── radiance_field
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   └── strivec.cpython-310.pyc
│   │   │   └── strivec.py
│   │   ├── gaussian
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── gaussian_model.cpython-310.pyc
│   │   │   │   └── general_utils.cpython-310.pyc
│   │   │   └── general_utils.py
│   │   ├── mesh
│   │   │   ├── __init__.py
│   │   │   ├── flexicubes
│   │   │   │   ├── images
│   │   │   │   │   ├── ablate_L_dev.jpg
│   │   │   │   │   ├── block_final.png
│   │   │   │   │   ├── block_init.png
│   │   │   │   │   └── teaser_top.png
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── tables.cpython-310.pyc
│   │   │   │   │   └── flexicubes.cpython-310.pyc
│   │   │   │   ├── examples
│   │   │   │   │   ├── download_data.py
│   │   │   │   │   └── loss.py
│   │   │   │   └── LICENSE.txt
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── cube2mesh.cpython-310.pyc
│   │   │   │   └── utils_cube.cpython-310.pyc
│   │   │   └── utils_cube.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-310.pyc
│   │   └── __init__.py
│   ├── modules
│   │   ├── transformer
│   │   │   ├── __init__.py
│   │   │   └── __pycache__
│   │   │       ├── blocks.cpython-310.pyc
│   │   │       ├── __init__.cpython-310.pyc
│   │   │       └── modulated.cpython-310.pyc
│   │   ├── sparse
│   │   │   ├── transformer
│   │   │   │   ├── __init__.py
│   │   │   │   └── __pycache__
│   │   │   │       ├── blocks.cpython-310.pyc
│   │   │   │       ├── __init__.cpython-310.pyc
│   │   │   │       └── modulated.cpython-310.pyc
│   │   │   ├── __pycache__
│   │   │   │   ├── norm.cpython-310.pyc
│   │   │   │   ├── basic.cpython-310.pyc
│   │   │   │   ├── linear.cpython-310.pyc
│   │   │   │   ├── spatial.cpython-310.pyc
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   └── nonlinearity.cpython-310.pyc
│   │   │   ├── attention
│   │   │   │   ├── __init__.py
│   │   │   │   └── __pycache__
│   │   │   │       ├── modules.cpython-310.pyc
│   │   │   │       ├── __init__.cpython-310.pyc
│   │   │   │       ├── full_attn.cpython-310.pyc
│   │   │   │       ├── windowed_attn.cpython-310.pyc
│   │   │   │       └── serialized_attn.cpython-310.pyc
│   │   │   ├── conv
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   │   └── conv_spconv.cpython-310.pyc
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conv_torchsparse.py
│   │   │   │   └── conv_spconv.py
│   │   │   ├── linear.py
│   │   │   ├── nonlinearity.py
│   │   │   ├── norm.py
│   │   │   ├── __init__.py
│   │   │   └── spatial.py
│   │   ├── __pycache__
│   │   │   ├── norm.cpython-310.pyc
│   │   │   ├── utils.cpython-310.pyc
│   │   │   └── spatial.cpython-310.pyc
│   │   ├── attention
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── full_attn.cpython-310.pyc
│   │   │   │   └── modules.cpython-310.pyc
│   │   │   ├── __init__.py
│   │   │   └── full_attn.py
│   │   ├── norm.py
│   │   ├── utils.py
│   │   └── spatial.py
│   ├── __pycache__
│   │   └── __init__.cpython-310.pyc
│   ├── pipelines
│   │   ├── __pycache__
│   │   │   ├── base.cpython-310.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── trellis_image_to_3d.cpython-310.pyc
│   │   │   └── trellis_text_to_3d.cpython-310.pyc
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── base.cpython-310.pyc
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── flow_euler.cpython-310.pyc
│   │   │   │   ├── guidance_interval_mixin.cpython-310.pyc
│   │   │   │   └── classifier_free_guidance_mixin.cpython-310.pyc
│   │   │   ├── base.py
│   │   │   ├── classifier_free_guidance_mixin.py
│   │   │   └── guidance_interval_mixin.py
│   │   ├── __init__.py
│   │   └── base.py
│   ├── trainers
│   │   ├── __pycache__
│   │   │   ├── base.cpython-310.pyc
│   │   │   ├── basic.cpython-310.pyc
│   │   │   ├── utils.cpython-310.pyc
│   │   │   └── __init__.cpython-310.pyc
│   │   ├── flow_matching
│   │   │   ├── __pycache__
│   │   │   │   ├── flow_matching.cpython-310.pyc
│   │   │   │   ├── sparse_flow_matching.cpython-310.pyc
│   │   │   │   └── sparse_flow_matchingnew.cpython-310.pyc
│   │   │   └── mixins
│   │   │       ├── __pycache__
│   │   │       │   ├── image_conditioned.cpython-310.pyc
│   │   │       │   ├── text_conditioned.cpython-310.pyc
│   │   │       │   └── classifier_free_guidance.cpython-310.pyc
│   │   │       ├── classifier_free_guidance.py
│   │   │       ├── text_conditioned.py
│   │   │       └── image_conditioned.py
│   │   ├── vae
│   │   │   ├── __pycache__
│   │   │   │   ├── structured_latent_vae_mesh.cpython-310.pyc
│   │   │   │   ├── structured_latent_vae_gaussian.cpython-310.pyc
│   │   │   │   └── structured_latent_vae_mesh_dec.cpython-310.pyc
│   │   │   └── sparse_structure_vae.py
│   │   ├── utils.py
│   │   └── __init__.py
│   ├── datasets
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── components.cpython-310.pyc
│   │   │   ├── sparse_feat2render.cpython-310.pyc
│   │   │   ├── structured_latent.cpython-310.pyc
│   │   │   ├── structured_latent2render.cpython-310.pyc
│   │   │   └── structured_latent2render_mesh.cpython-310.pyc
│   │   ├── __init__.py
│   │   ├── sparse_structure.py
│   │   ├── sparse_feat2render.py
│   │   └── components.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── sparse_elastic_mixin.cpython-310.pyc
│   │   │   ├── sparse_structure_vae.cpython-310.pyc
│   │   │   ├── sparse_structure_flow.cpython-310.pyc
│   │   │   └── structured_latent_flow.cpython-310.pyc
│   │   ├── structured_latent_vae
│   │   │   ├── __pycache__
│   │   │   │   ├── base.cpython-310.pyc
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── encoder.cpython-310.pyc
│   │   │   │   ├── decoder_gs.cpython-310.pyc
│   │   │   │   ├── decoder_rf.cpython-310.pyc
│   │   │   │   ├── decoder_mesh.cpython-310.pyc
│   │   │   │   └── propertyencoder.cpython-310.pyc
│   │   │   ├── __init__.py
│   │   │   ├── encoder.py
│   │   │   ├── propertyencoder.py
│   │   │   └── decoder_rf.py
│   │   ├── sparse_elastic_mixin.py
│   │   └── __init__.py
│   ├── renderers
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── sh_utils.cpython-310.pyc
│   │   │   ├── mesh_renderer.cpython-310.pyc
│   │   │   ├── gaussian_render.cpython-310.pyc
│   │   │   └── octree_renderer.cpython-310.pyc
│   │   ├── __init__.py
│   │   └── sh_utils.py
│   └── __init__.py
├── img
│   └── teaser.png
├── example
│   ├── chair.png
│   └── table.png
├── val_test_list.npy
├── dataset_toolkits
│   ├── __pycache__
│   │   ├── tools.cpython-310.pyc
│   │   ├── utils.cpython-310.pyc
│   │   └── render.cpython-310.pyc
│   ├── blender_script
│   │   └── io_scene_usdz.zip
│   ├── datasets
│   │   ├── __pycache__
│   │   │   ├── Partnet.cpython-310.pyc
│   │   │   ├── PhysXNet.cpython-310.pyc
│   │   │   ├── ObjaverseXL.cpython-310.pyc
│   │   │   ├── ObjaverseXL.cpython-37.pyc
│   │   │   └── Partnet_old.cpython-310.pyc
│   │   └── PhysXNet.py
│   ├── tools.py
│   ├── precess.sh
│   ├── gen_csv.py
│   ├── utils.py
│   ├── stat_latent.py
│   └── render.py
├── download_pretrain.sh
├── vox2seq
│   ├── src
│   │   ├── ext.cpp
│   │   ├── hilbert.h
│   │   ├── z_order.h
│   │   ├── api.h
│   │   ├── z_order.cu
│   │   ├── api.cu
│   │   └── hilbert.cu
│   ├── setup.py
│   ├── test.py
│   ├── vox2seq
│   │   ├── pytorch
│   │   │   ├── __init__.py
│   │   │   ├── default.py
│   │   │   └── z_order.py
│   │   └── __init__.py
│   └── benchmark.py
├── tools
│   ├── README.md
│   └── merge_mesh.py
├── LICENCE
└── configs
    ├── vae
    │   └── slat_vae_enc_dec_mesh_phy.json
    └── generation
        └── slat_flow_img_dit_L_phy.json

/trellis/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/trellis/representations/octree/__init__.py:
--------------------------------------------------------------------------------
from .octree_dfs import DfsOctree
--------------------------------------------------------------------------------
/trellis/representations/radiance_field/__init__.py:
--------------------------------------------------------------------------------
from .strivec import Strivec
--------------------------------------------------------------------------------
/trellis/representations/gaussian/__init__.py:
--------------------------------------------------------------------------------
from .gaussian_model import Gaussian
--------------------------------------------------------------------------------
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/img/teaser.png
--------------------------------------------------------------------------------
/example/chair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/example/chair.png
--------------------------------------------------------------------------------
/example/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/example/table.png
--------------------------------------------------------------------------------
/trellis/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
from .blocks import *
from .modulated import *
--------------------------------------------------------------------------------
/val_test_list.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/val_test_list.npy
--------------------------------------------------------------------------------
/trellis/modules/sparse/transformer/__init__.py:
--------------------------------------------------------------------------------
from .blocks import *
from .modulated import *
--------------------------------------------------------------------------------
/trellis/representations/mesh/__init__.py:
--------------------------------------------------------------------------------
from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult
--------------------------------------------------------------------------------
/dataset_toolkits/blender_script/io_scene_usdz.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/dataset_toolkits/blender_script/io_scene_usdz.zip
--------------------------------------------------------------------------------
/trellis/pipelines/samplers/__init__.py:
--------------------------------------------------------------------------------
from .base import Sampler
from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
--------------------------------------------------------------------------------
/trellis/modules/sparse/attention/__init__.py:
--------------------------------------------------------------------------------
from .full_attn import *
from .serialized_attn import *
from .windowed_attn import *
from .modules import *
--------------------------------------------------------------------------------
/download_pretrain.sh:
--------------------------------------------------------------------------------
huggingface-cli download microsoft/TRELLIS-image-large --local-dir pretrain/trellis
huggingface-cli download Caoza/PhysXGen --local-dir pretrain
--------------------------------------------------------------------------------
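The same downloads can also be scripted from Python. A minimal sketch using the public huggingface_hub API (snapshot_download mirrors what huggingface-cli does here; the repo IDs and target directories are taken from the script above):

from huggingface_hub import snapshot_download

# Fetch the TRELLIS image-to-3D weights and the PhysXGen checkpoints.
snapshot_download(repo_id="microsoft/TRELLIS-image-large", local_dir="pretrain/trellis")
snapshot_download(repo_id="Caoza/PhysXGen", local_dir="pretrain")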
/trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg
--------------------------------------------------------------------------------
/trellis/representations/mesh/flexicubes/images/block_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/trellis/representations/mesh/flexicubes/images/block_final.png
--------------------------------------------------------------------------------
/trellis/representations/mesh/flexicubes/images/block_init.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/trellis/representations/mesh/flexicubes/images/block_init.png
--------------------------------------------------------------------------------
/trellis/representations/mesh/flexicubes/images/teaser_top.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziangcao0312/PhysX-3D/HEAD/trellis/representations/mesh/flexicubes/images/teaser_top.png
--------------------------------------------------------------------------------
/trellis/__init__.py:
--------------------------------------------------------------------------------
from . import models
from . import modules
from . import pipelines
from . import renderers
from . import representations
from . import utils
--------------------------------------------------------------------------------
/trellis/representations/__init__.py:
--------------------------------------------------------------------------------
from .radiance_field import Strivec
from .octree import DfsOctree as Octree
from .gaussian import Gaussian
from .mesh import MeshExtractResult
--------------------------------------------------------------------------------
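With these re-exports, all four output representations live under one namespace, and DfsOctree is aliased to the shorter Octree. A hypothetical import check (assumes the repo root is on PYTHONPATH and the CUDA-side dependencies are installed):

from trellis.representations import Gaussian, Strivec, Octree, MeshExtractResult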
/vox2seq/src/ext.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>
#include "api.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("z_order_encode", &z_order_encode);
    m.def("z_order_decode", &z_order_decode);
    m.def("hilbert_encode", &hilbert_encode);
    m.def("hilbert_decode", &hilbert_decode);
}
--------------------------------------------------------------------------------
/trellis/pipelines/samplers/base.py:
--------------------------------------------------------------------------------
from typing import *
from abc import ABC, abstractmethod


class Sampler(ABC):
    """
    A base class for samplers.
    """

    @abstractmethod
    def sample(
        self,
        model,
        **kwargs
    ):
        """
        Sample from a model.
        """
        pass
--------------------------------------------------------------------------------
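A minimal sketch of a concrete sampler satisfying this interface; the DummyEulerSampler name and its noise-to-data loop are illustrative only (the repo's real implementations live in flow_euler.py):

import torch
from abc import ABC, abstractmethod

class Sampler(ABC):  # same shape as the ABC above
    @abstractmethod
    def sample(self, model, **kwargs): ...

class DummyEulerSampler(Sampler):
    def sample(self, model, x_init: torch.Tensor, steps: int = 10, **kwargs):
        # Integrate from t=1 (noise) toward t=0 (data) with fixed Euler steps,
        # treating the model as a velocity-field predictor.
        x = x_init
        for i in range(steps):
            t = 1.0 - i / steps
            x = x - model(x, t) / steps
        return x

# Usage with a stand-in "model" that just decays the sample:
sampler = DummyEulerSampler()
out = sampler.sample(lambda x, t: x, x_init=torch.ones(3))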
/trellis/modules/sparse/linear.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from . import SparseTensor

__all__ = [
    'SparseLinear'
]


class SparseLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(SparseLinear, self).__init__(in_features, out_features, bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        # Apply the dense linear layer to the feature matrix and rebuild the
        # sparse tensor around the transformed features.
        return input.replace(super().forward(input.feats))
--------------------------------------------------------------------------------
/trellis/models/structured_latent_vae/__init__.py:
--------------------------------------------------------------------------------
from .encoder import SLatEncoder, ElasticSLatEncoder
from .propertyencoder import PropertyEncoder, ElasticPropertyEncoder
from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder, PropertyDecoder, ElasticPropertyDecoder, PropertyOutput, SLatMeshDecodernew, ElasticSLatMeshDecodernew
--------------------------------------------------------------------------------
/trellis/pipelines/samplers/classifier_free_guidance_mixin.py:
--------------------------------------------------------------------------------
from typing import *


class ClassifierFreeGuidanceSamplerMixin:
    """
    A mixin class for samplers that apply classifier-free guidance.
    """

    def _inference_model(self, model, x_t, t, cond_aes, cond, neg_cond, cfg_strength, **kwargs):
        pred = super()._inference_model(model, x_t, t, cond_aes, cond, **kwargs)
        neg_pred = super()._inference_model(model, x_t, t, cond_aes, neg_cond, **kwargs)
        # Extrapolate away from the negative prediction:
        # (1 + s) * pred - s * neg_pred == pred + s * (pred - neg_pred).
        return (1 + cfg_strength) * pred - cfg_strength * neg_pred
--------------------------------------------------------------------------------
/trellis/modules/sparse/conv/__init__.py:
--------------------------------------------------------------------------------
from .. import BACKEND


SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'

def __from_env():
    import os

    global SPCONV_ALGO
    env_spconv_algo = os.environ.get('SPCONV_ALGO')
    if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
        SPCONV_ALGO = env_spconv_algo
    print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")


__from_env()

if BACKEND == 'torchsparse':
    from .conv_torchsparse import *
elif BACKEND == 'spconv':
    from .conv_spconv import *
--------------------------------------------------------------------------------
/trellis/pipelines/samplers/guidance_interval_mixin.py:
--------------------------------------------------------------------------------
from typing import *


class GuidanceIntervalSamplerMixin:
    """
    A mixin class for samplers that apply classifier-free guidance within a timestep interval.
    """

    def _inference_model(self, model, x_t, t, cond_aes, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
        if cfg_interval[0] <= t <= cfg_interval[1]:
            pred = super()._inference_model(model, x_t, t, cond_aes, cond, **kwargs)
            neg_pred = super()._inference_model(model, x_t, t, cond_aes, neg_cond, **kwargs)
            return (1 + cfg_strength) * pred - cfg_strength * neg_pred
        else:
            return super()._inference_model(model, x_t, t, cond_aes, cond, **kwargs)
--------------------------------------------------------------------------------
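Both guidance mixins rely on Python's method-resolution order: they sit to the left of a base sampler in the class list, and their super() call falls through to the base implementation. A self-contained toy of that pattern (BaseSampler and the constant-prediction model are illustrative, not repo code):

class BaseSampler:
    def _inference_model(self, model, x_t, t, cond, **kwargs):
        return model(x_t, t, cond)

class CfgMixin:
    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
        pred = super()._inference_model(model, x_t, t, cond, **kwargs)
        neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
        return (1 + cfg_strength) * pred - cfg_strength * neg_pred

class CfgSampler(CfgMixin, BaseSampler):
    pass

model = lambda x_t, t, cond: 2.0 if cond == "pos" else 1.0
s = CfgSampler()
# pred = 2.0, neg_pred = 1.0, cfg_strength = 3 -> (1 + 3) * 2 - 3 * 1 = 5.0
print(s._inference_model(model, 1.0, 0.5, "pos", neg_cond="neg", cfg_strength=3))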
/dataset_toolkits/tools.py:
--------------------------------------------------------------------------------
from math import *
import numpy as np


def VaryPoint(data, axis, degree):
    # Rotation matrices about the X, Y, and Z axes.
    xyzArray = {
        'X': np.array([[1, 0, 0],
                       [0, cos(radians(degree)), -sin(radians(degree))],
                       [0, sin(radians(degree)), cos(radians(degree))]]),
        'Y': np.array([[cos(radians(degree)), 0, sin(radians(degree))],
                       [0, 1, 0],
                       [-sin(radians(degree)), 0, cos(radians(degree))]]),
        'Z': np.array([[cos(radians(degree)), -sin(radians(degree)), 0],
                       [sin(radians(degree)), cos(radians(degree)), 0],
                       [0, 0, 1]])}
    # Points are row vectors, so this right-multiplies by the rotation matrix.
    newData = np.dot(data, xyzArray[axis])
    return newData
--------------------------------------------------------------------------------
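For instance, rotating the unit x-vector 90 degrees about Z with this helper gives (0, -1, 0), because the row-vector convention data @ R is the transpose of the usual R @ v column convention (a quick check, not repo code):

import numpy as np
from math import cos, sin, radians

R_z = np.array([[cos(radians(90)), -sin(radians(90)), 0],
                [sin(radians(90)), cos(radians(90)), 0],
                [0, 0, 1]])
print(np.array([1.0, 0.0, 0.0]) @ R_z)   # -> [ 0. -1.  0.] (up to float eps)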
/trellis/modules/norm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class LayerNorm32(nn.LayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.float()).type(x.dtype)


class GroupNorm32(nn.GroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.float()).type(x.dtype)


class ChannelLayerNorm32(LayerNorm32):
    """
    LayerNorm32 applied over the channel dimension of a channels-first tensor.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        DIM = x.dim()
        x = x.permute(0, *range(2, DIM), 1).contiguous()        # to channels-last
        x = super().forward(x)
        x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()  # back to channels-first
        return x
--------------------------------------------------------------------------------
/trellis/pipelines/__init__.py:
--------------------------------------------------------------------------------
from . import samplers
from .trellis_image_to_3d import TrellisImageTo3DPipeline
from .trellis_text_to_3d import TrellisTextTo3DPipeline


def from_pretrained(path: str):
    """
    Load a pipeline from a model folder or the Hugging Face model hub.

    Args:
        path: The path to the model. Can be either a local path or a Hugging Face model name.
    """
    import os
    import json
    is_local = os.path.exists(f"{path}/pipeline.json")

    if is_local:
        config_file = f"{path}/pipeline.json"
    else:
        from huggingface_hub import hf_hub_download
        config_file = hf_hub_download(path, "pipeline.json")

    with open(config_file, 'r') as f:
        config = json.load(f)
    return globals()[config['name']].from_pretrained(path)
--------------------------------------------------------------------------------
/trellis/representations/radiance_field/strivec.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..octree import DfsOctree as Octree


class Strivec(Octree):
    def __init__(
        self,
        resolution: int,
        aabb: list,
        sh_degree: int = 0,
        rank: int = 8,
        dim: int = 8,
        device: str = "cuda",
    ):
        assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2"
        self.resolution = resolution
        depth = int(np.round(np.log2(resolution)))
        super().__init__(
            depth=depth,
            aabb=aabb,
            sh_degree=sh_degree,
            primitive="trivec",
            primitive_config={"rank": rank, "dim": dim},
            device=device,
        )
--------------------------------------------------------------------------------
/trellis/modules/attention/__init__.py:
--------------------------------------------------------------------------------
from typing import *

BACKEND = 'flash_attn'
DEBUG = False

def __from_env():
    import os

    global BACKEND
    global DEBUG

    env_attn_backend = os.environ.get('ATTN_BACKEND')
    env_attn_debug = os.environ.get('ATTN_DEBUG')

    if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
        BACKEND = env_attn_backend
    if env_attn_debug is not None:
        DEBUG = env_attn_debug == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")


__from_env()


def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
    global BACKEND
    BACKEND = backend

def set_debug(debug: bool):
    global DEBUG
    DEBUG = debug


from .full_attn import *
from .modules import *
--------------------------------------------------------------------------------
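Backend selection happens at import time, so the environment variable must be set before the package is first imported (a usage sketch; the module path assumes the repo root is on PYTHONPATH):

import os
os.environ['ATTN_BACKEND'] = 'sdpa'   # or: 'xformers', 'flash_attn', 'naive'
os.environ['ATTN_DEBUG'] = '0'

import trellis.modules.attention     # prints "[ATTENTION] Using backend: sdpa"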
/vox2seq/src/z_order.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | /**
4 |  * Z-order encode 3D points
5 |  *
6 |  * @param x [N] tensor containing the x coordinates
7 |  * @param y [N] tensor containing the y coordinates
8 |  * @param z [N] tensor containing the z coordinates
9 |  *
10 |  * @return [N] tensor containing the z-order encoded values
11 |  */
12 | __global__ void z_order_encode_cuda(
13 |     size_t N,
14 |     const uint32_t* x,
15 |     const uint32_t* y,
16 |     const uint32_t* z,
17 |     uint32_t* codes
18 | );
19 | 
20 | 
21 | /**
22 |  * Z-order decode 3D points
23 |  *
24 |  * @param codes [N] tensor containing the z-order encoded values
25 |  * @param x [N] tensor containing the x coordinates
26 |  * @param y [N] tensor containing the y coordinates
27 |  * @param z [N] tensor containing the z coordinates
28 |  */
29 | __global__ void z_order_decode_cuda(
30 |     size_t N,
31 |     const uint32_t* codes,
32 |     uint32_t* x,
33 |     uint32_t* y,
34 |     uint32_t* z
35 | );
36 | 
--------------------------------------------------------------------------------
/vox2seq/setup.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact  george.drettakis@inria.fr
10 | #
11 | 
12 | from setuptools import setup
13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
14 | 
15 | setup(
16 |     name="vox2seq",
17 |     packages=['vox2seq', 'vox2seq.pytorch'],
18 |     ext_modules=[
19 |         CUDAExtension(
20 |             name="vox2seq._C",
21 |             sources=[
22 |                 "src/api.cu",
23 |                 "src/z_order.cu",
24 |                 "src/hilbert.cu",
25 |                 "src/ext.cpp",
26 |             ],
27 |         )
28 |     ],
29 |     cmdclass={
30 |         'build_ext': BuildExtension
31 |     }
32 | )
33 | 
--------------------------------------------------------------------------------
/trellis/models/sparse_elastic_mixin.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from typing import *
3 | import math
4 | from ..modules import sparse as sp
5 | from ..utils.elastic_utils import ElasticModuleMixin
6 | 
7 | 
8 | class SparseTransformerElasticMixin(ElasticModuleMixin):
9 |     def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
10 |         return x.feats.shape[0]
11 | 
12 |     @contextmanager
13 |     def with_mem_ratio(self, mem_ratio=1.0):
14 |         if mem_ratio == 1.0:
15 |             yield 1.0
16 |             return
17 |         num_blocks = len(self.blocks)
18 |         num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
19 |         exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
20 |         for i in range(num_blocks):
21 |             self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
22 |         yield exact_mem_ratio
23 |         for i in range(num_blocks):
24 |             self.blocks[i].use_checkpoint = False
25 | 
--------------------------------------------------------------------------------
/trellis/utils/random_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
4 | 
5 | def radical_inverse(base, n):
6 |     val = 0
7 |     inv_base = 1.0 / base
8 |     inv_base_n = inv_base
9 |     while n > 0:
10 |         digit = n % base
11 |         val += digit * inv_base_n
12 |         n //= base
13 |         inv_base_n *= inv_base
14 |     return val
15 | 
16 | def halton_sequence(dim, n):
17 |     return [radical_inverse(PRIMES[d], n) for d in range(dim)]
18 | 
19 | def hammersley_sequence(dim, n, num_samples):
20 |     return [n / num_samples] + halton_sequence(dim - 1, n)
21 | 
22 | def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
23 |     u, v = hammersley_sequence(2, n, num_samples)
24 |     u += offset[0] / num_samples
25 |     v += offset[1]
26 |     if remap:
27 |         u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
28 |     theta = np.arccos(1 - 2 * u) - np.pi / 2
29 |     phi = v * 2 * np.pi
30 |     return [phi, theta]
--------------------------------------------------------------------------------
/dataset_toolkits/precess.sh:
--------------------------------------------------------------------------------
1 | # PhysXNet preprocessing pipeline (run from dataset_toolkits/).
2 | # build_metadata.py is re-run after each stage to refresh the metadata flags.
3 | python merge_property.py --datapath ./physxnet # your physxnet path
4 | 
5 | python gen_csv.py
6 | 
7 | python retrieval_texture_example.py
8 | 
9 | python build_metadata.py PhysXNet --output_dir ../datasets/PhysXNet
10 | 
11 | python render.py PhysXNet --output_dir ../datasets/PhysXNet
12 | 
13 | python build_metadata.py PhysXNet --output_dir ../datasets/PhysXNet
14 | python render_cond.py PhysXNet --output_dir ../datasets/PhysXNet
15 | 
16 | python build_metadata.py PhysXNet --output_dir ../datasets/PhysXNet
17 | 
18 | python voxelize.py PhysXNet --output_dir ../datasets/PhysXNet
19 | python build_metadata.py PhysXNet --output_dir ../datasets/PhysXNet
20 | 
21 | python extract_feature.py --output_dir ../datasets/PhysXNet
22 | python build_metadata.py PhysXNet --output_dir ../datasets/PhysXNet
23 | 
24 | python encode_latent.py --output_dir ../datasets/PhysXNet
25 | python encode_latent_phy.py --output_dir ../datasets/PhysXNet
26 | python build_metadata.py PhysXNet --output_dir ../datasets/PhysXNet
27 | 
--------------------------------------------------------------------------------
/trellis/modules/sparse/nonlinearity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .
import SparseTensor 4 | 5 | __all__ = [ 6 | 'SparseReLU', 7 | 'SparseSiLU', 8 | 'SparseGELU', 9 | 'SparseActivation' 10 | ] 11 | 12 | 13 | class SparseReLU(nn.ReLU): 14 | def forward(self, input: SparseTensor) -> SparseTensor: 15 | return input.replace(super().forward(input.feats)) 16 | 17 | 18 | class SparseSiLU(nn.SiLU): 19 | def forward(self, input: SparseTensor) -> SparseTensor: 20 | return input.replace(super().forward(input.feats)) 21 | 22 | 23 | class SparseGELU(nn.GELU): 24 | def forward(self, input: SparseTensor) -> SparseTensor: 25 | return input.replace(super().forward(input.feats)) 26 | 27 | 28 | class SparseActivation(nn.Module): 29 | def __init__(self, activation: nn.Module): 30 | super().__init__() 31 | self.activation = activation 32 | 33 | def forward(self, input: SparseTensor) -> SparseTensor: 34 | return input.replace(self.activation(input.feats)) 35 | 36 | -------------------------------------------------------------------------------- /trellis/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | __attributes = { 4 | 'OctreeRenderer': 'octree_renderer', 5 | 'GaussianRenderer': 'gaussian_render', 6 | 'MeshRenderer': 'mesh_renderer', 7 | } 8 | 9 | __submodules = [] 10 | 11 | __all__ = list(__attributes.keys()) + __submodules 12 | 13 | def __getattr__(name): 14 | if name not in globals(): 15 | if name in __attributes: 16 | module_name = __attributes[name] 17 | module = importlib.import_module(f".{module_name}", __name__) 18 | globals()[name] = getattr(module, name) 19 | elif name in __submodules: 20 | module = importlib.import_module(f".{name}", __name__) 21 | globals()[name] = module 22 | else: 23 | raise AttributeError(f"module {__name__} has no attribute {name}") 24 | return globals()[name] 25 | 26 | 27 | # For Pylance 28 | if __name__ == '__main__': 29 | from .octree_renderer import OctreeRenderer 30 | from .gaussian_render import GaussianRenderer 31 | from .mesh_renderer import MeshRenderer -------------------------------------------------------------------------------- /vox2seq/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import vox2seq 3 | 4 | 5 | if __name__ == "__main__": 6 | RES = 256 7 | coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES)) 8 | coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() 9 | code_z_cuda = vox2seq.encode(coords, mode='z_order') 10 | code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order') 11 | code_h_cuda = vox2seq.encode(coords, mode='hilbert') 12 | code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert') 13 | assert torch.equal(code_z_cuda, code_z_pytorch) 14 | assert torch.equal(code_h_cuda, code_h_pytorch) 15 | 16 | code = torch.arange(RES**3).int().cuda() 17 | coords_z_cuda = vox2seq.decode(code, mode='z_order') 18 | coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order') 19 | coords_h_cuda = vox2seq.decode(code, mode='hilbert') 20 | coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert') 21 | assert torch.equal(coords_z_cuda, coords_z_pytorch) 22 | assert torch.equal(coords_h_cuda, coords_h_pytorch) 23 | 24 | print("All tests passed.") 25 | 26 | -------------------------------------------------------------------------------- /dataset_toolkits/gen_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | import numpy as np 5 | import json 6 | 
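# Note (added for clarity): this script fabricates the initial metadata.csv for
# PhysXNet. aesthetic_score is fixed at 10 and the rendered/voxelized/
# cond_rendered flags start as False, so every object is picked up by the later
# dataset_toolkits stages (see precess.sh).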
savepath='../datasets/PhysXNet' 7 | finaldataset='./phy_dataset' 8 | 9 | os.makedirs(os.path.join(savepath), exist_ok=True) 10 | os.makedirs(os.path.join(savepath,'merged_records'), exist_ok=True) 11 | pathlist=[] 12 | namelist=os.listdir(finaldataset) 13 | namelist = sorted(namelist, key=lambda x: int(x)) 14 | for i in namelist: 15 | pathlist.append(os.path.join(finaldataset,i,'model_tex.obj')) 16 | 17 | namelist_=[] 18 | for i in namelist: 19 | namelist_.append(i+'_') 20 | zero=np.zeros((len(namelist_))).tolist() 21 | ten=(np.zeros((len(namelist_)))+10).tolist() 22 | false=(np.zeros((len(namelist_)))!=0).tolist() 23 | 24 | frame = pd.DataFrame({'sha256': namelist_, 'file_identifier': zero,'aesthetic_score': ten,'captions':zero,'rendered':false,'voxelized':false,'num_voxels': zero, 'cond_rendered': false,'local_path':pathlist},dtype='object') 25 | 26 | frame.to_csv(os.path.join(savepath,'metadata.csv'), index=False, sep=',') 27 | 28 | frame = pd.DataFrame({'sha256': namelist_, 'local_path':pathlist},dtype='object') 29 | 30 | frame.to_csv(os.path.join(savepath,'merged_records','1743666055_downloaded_0.csv'), index=False, sep=',') -------------------------------------------------------------------------------- /trellis/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..modules import sparse as sp 3 | 4 | FP16_MODULES = ( 5 | nn.Conv1d, 6 | nn.Conv2d, 7 | nn.Conv3d, 8 | nn.ConvTranspose1d, 9 | nn.ConvTranspose2d, 10 | nn.ConvTranspose3d, 11 | nn.Linear, 12 | sp.SparseConv3d, 13 | sp.SparseInverseConv3d, 14 | sp.SparseLinear, 15 | ) 16 | 17 | def convert_module_to_f16(l): 18 | """ 19 | Convert primitive modules to float16. 20 | """ 21 | if isinstance(l, FP16_MODULES): 22 | for p in l.parameters(): 23 | p.data = p.data.half() 24 | 25 | 26 | def convert_module_to_f32(l): 27 | """ 28 | Convert primitive modules to float32, undoing convert_module_to_f16(). 29 | """ 30 | if isinstance(l, FP16_MODULES): 31 | for p in l.parameters(): 32 | p.data = p.data.float() 33 | 34 | 35 | def zero_module(module): 36 | """ 37 | Zero out the parameters of a module and return it. 38 | """ 39 | for p in module.parameters(): 40 | p.detach().zero_() 41 | return module 42 | 43 | 44 | def scale_module(module, scale): 45 | """ 46 | Scale the parameters of a module and return it. 
47 | """ 48 | for p in module.parameters(): 49 | p.detach().mul_(scale) 50 | return module 51 | 52 | 53 | def modulate(x, shift, scale): 54 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 55 | -------------------------------------------------------------------------------- /dataset_toolkits/utils.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import hashlib 3 | import numpy as np 4 | 5 | 6 | def get_file_hash(file: str) -> str: 7 | sha256 = hashlib.sha256() 8 | # Read the file from the path 9 | with open(file, "rb") as f: 10 | # Update the hash with the file content 11 | for byte_block in iter(lambda: f.read(4096), b""): 12 | sha256.update(byte_block) 13 | return sha256.hexdigest() 14 | 15 | # ===============LOW DISCREPANCY SEQUENCES================ 16 | 17 | PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] 18 | 19 | def radical_inverse(base, n): 20 | val = 0 21 | inv_base = 1.0 / base 22 | inv_base_n = inv_base 23 | while n > 0: 24 | digit = n % base 25 | val += digit * inv_base_n 26 | n //= base 27 | inv_base_n *= inv_base 28 | return val 29 | 30 | def halton_sequence(dim, n): 31 | return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] 32 | 33 | def hammersley_sequence(dim, n, num_samples): 34 | return [n / num_samples] + halton_sequence(dim - 1, n) 35 | 36 | def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)): 37 | u, v = hammersley_sequence(2, n, num_samples) 38 | u += offset[0] / num_samples 39 | v += offset[1] 40 | u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 41 | theta = np.arccos(1 - 2 * u) - np.pi / 2 42 | phi = v * 2 * np.pi 43 | return [phi, theta] 44 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | # Annotation pipeline & Texture retrieval 2 | 3 | 4 | 5 | ## Annotation pipeline 6 | 7 | ### 1_gptanno.py 8 | 9 | ``` 10 | python gptanno.py 11 | ``` 12 | 13 | Get GPT-4o output results through the API and save them as JSON file 14 | 15 | ### 2_vis_kinematic.py 16 | 17 | ``` 18 | python 2_vis_kinematic.py 19 | ``` 20 | 21 | According to the preliminary results of the gpt annotation, the candidate axes and candidate points of the translation (B), rotation (C), and articulation (D) parts are saved as mesh 22 | 23 | **Note**: arrow_norm.obj is used for better visualization 24 | 25 | ### 3_human_in_the_loop annotation and check 26 | 27 | Manually select the candidate axes and points for B, C, and D 28 | 29 | 30 | 31 | ## Texture retrieval for PhysXNet 32 | 33 | Since [PartNet](https://huggingface.co/datasets/ShapeNet/PartNet-archive) has no texture information, you need to download the [ShapeNet](https://huggingface.co/datasets/ShapeNet/ShapeNetCore) dataset and save it to `./shapenet` to obtain texture information. 34 | 35 | ## merge_mesh.py 36 | 37 | ``` 38 | python merge_mesh.py 39 | ``` 40 | 41 | Merge the parts into one mesh in order to calculate the texture correspondence with the original mesh 42 | 43 | ## retrieval_texture_example.py 44 | 45 | ``` 46 | python retrieval_texture_example.py 47 | ``` 48 | 49 | Since the original meshes in ShapeNet are similar in shape to the meshes in PartNet, the nearest texture information can be obtained based on their coordinate correspondence. 50 | 51 | **Note**: finalindex.json is obtained from the metadata of PartNet. 
52 | -------------------------------------------------------------------------------- /vox2seq/vox2seq/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import * 3 | 4 | from .default import ( 5 | encode, 6 | decode, 7 | z_order_encode, 8 | z_order_decode, 9 | hilbert_encode, 10 | hilbert_decode, 11 | ) 12 | 13 | 14 | @torch.no_grad() 15 | def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 16 | """ 17 | Encodes 3D coordinates into a 30-bit code. 18 | 19 | Args: 20 | coords: a tensor of shape [N, 3] containing the 3D coordinates. 21 | permute: the permutation of the coordinates. 22 | mode: the encoding mode to use. 23 | """ 24 | if mode == 'z_order': 25 | return z_order_encode(coords[:, permute], depth=10).int() 26 | elif mode == 'hilbert': 27 | return hilbert_encode(coords[:, permute], depth=10).int() 28 | else: 29 | raise ValueError(f"Unknown encoding mode: {mode}") 30 | 31 | 32 | @torch.no_grad() 33 | def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 34 | """ 35 | Decodes a 30-bit code into 3D coordinates. 36 | 37 | Args: 38 | code: a tensor of shape [N] containing the 30-bit code. 39 | permute: the permutation of the coordinates. 40 | mode: the decoding mode to use. 41 | """ 42 | if mode == 'z_order': 43 | return z_order_decode(code, depth=10)[:, permute].float() 44 | elif mode == 'hilbert': 45 | return hilbert_decode(code, depth=10)[:, permute].float() 46 | else: 47 | raise ValueError(f"Unknown decoding mode: {mode}") 48 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 4. 
In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/examples/download_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | import requests 9 | from zipfile import ZipFile 10 | from tqdm import tqdm 11 | import os 12 | 13 | def download_file(url, output_path): 14 | response = requests.get(url, stream=True) 15 | response.raise_for_status() 16 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 17 | block_size = 1024 #1 Kibibyte 18 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 19 | 20 | with open(output_path, 'wb') as file: 21 | for data in response.iter_content(block_size): 22 | progress_bar.update(len(data)) 23 | file.write(data) 24 | progress_bar.close() 25 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 26 | raise Exception("ERROR, something went wrong") 27 | 28 | 29 | url = "https://vcg.isti.cnr.it/Publications/2014/MPZ14/inputmodels.zip" 30 | zip_file_path = './data/inputmodels.zip' 31 | 32 | os.makedirs('./data', exist_ok=True) 33 | 34 | download_file(url, zip_file_path) 35 | 36 | with ZipFile(zip_file_path, 'r') as zip_ref: 37 | zip_ref.extractall('./data') 38 | 39 | os.remove(zip_file_path) 40 | 41 | print("Download and extraction complete.") 42 | -------------------------------------------------------------------------------- /dataset_toolkits/datasets/PhysXNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import tarfile 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | import pandas as pd 8 | from utils import get_file_hash 9 | 10 | 11 | def add_args(parser: argparse.ArgumentParser): 12 | pass 13 | 14 | 15 | def get_metadata(**kwargs): 16 | metadata = pd.read_csv("./datasets/PhysXNet/metadata.csv") 17 | return metadata 18 | 19 | 20 | 21 | def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: 22 | import os 23 | from concurrent.futures import ThreadPoolExecutor 24 | from tqdm import tqdm 25 | 26 | # load metadata 27 | metadata = metadata.to_dict('records') 28 | 29 | # processing objects 30 | records = [] 31 | max_workers = max_workers or os.cpu_count() 32 | try: 33 | with ThreadPoolExecutor(max_workers=max_workers) as executor, \ 34 | tqdm(total=len(metadata), desc=desc) as pbar: 35 | def worker(metadatum): 36 | try: 37 | local_path = metadatum['local_path'] 38 | sha256 = metadatum['sha256'] 39 | file = os.path.join(local_path) 40 | record = func(file, sha256) 41 | if record is not None: 42 | records.append(record) 43 | pbar.update() 44 | except Exception as e: 45 | print(f"Error processing object {sha256}: {e}") 46 | pbar.update() 47 | 48 | 
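            # Executor.map submits every task eagerly; the returned iterator is
            # unused because worker() appends each record to `records` directly.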
executor.map(worker, metadata) 49 | executor.shutdown(wait=True) 50 | except: 51 | print("Error happened during processing.") 52 | 53 | return pd.DataFrame.from_records(records) 54 | -------------------------------------------------------------------------------- /vox2seq/vox2seq/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import * 3 | import torch 4 | from . import _C 5 | from . import pytorch 6 | 7 | 8 | @torch.no_grad() 9 | def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 10 | """ 11 | Encodes 3D coordinates into a 30-bit code. 12 | 13 | Args: 14 | coords: a tensor of shape [N, 3] containing the 3D coordinates. 15 | permute: the permutation of the coordinates. 16 | mode: the encoding mode to use. 17 | """ 18 | assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]" 19 | x = coords[:, permute[0]].int() 20 | y = coords[:, permute[1]].int() 21 | z = coords[:, permute[2]].int() 22 | if mode == 'z_order': 23 | return _C.z_order_encode(x, y, z) 24 | elif mode == 'hilbert': 25 | return _C.hilbert_encode(x, y, z) 26 | else: 27 | raise ValueError(f"Unknown encoding mode: {mode}") 28 | 29 | 30 | @torch.no_grad() 31 | def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: 32 | """ 33 | Decodes a 30-bit code into 3D coordinates. 34 | 35 | Args: 36 | code: a tensor of shape [N] containing the 30-bit code. 37 | permute: the permutation of the coordinates. 38 | mode: the decoding mode to use. 39 | """ 40 | assert code.ndim == 1, "Input code must be of shape [N]" 41 | if mode == 'z_order': 42 | coords = _C.z_order_decode(code) 43 | elif mode == 'hilbert': 44 | coords = _C.hilbert_decode(code) 45 | else: 46 | raise ValueError(f"Unknown decoding mode: {mode}") 47 | x = coords[permute.index(0)] 48 | y = coords[permute.index(1)] 49 | z = coords[permute.index(2)] 50 | return torch.stack([x, y, z], dim=-1) 51 | -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/conv_torchsparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. 
import SparseTensor 4 | 5 | 6 | class SparseConv3d(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 8 | super(SparseConv3d, self).__init__() 9 | if 'torchsparse' not in globals(): 10 | import torchsparse 11 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) 12 | 13 | def forward(self, x: SparseTensor) -> SparseTensor: 14 | out = self.conv(x.data) 15 | new_shape = [x.shape[0], self.conv.out_channels] 16 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 17 | out._spatial_cache = x._spatial_cache 18 | out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) 19 | return out 20 | 21 | 22 | class SparseInverseConv3d(nn.Module): 23 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 24 | super(SparseInverseConv3d, self).__init__() 25 | if 'torchsparse' not in globals(): 26 | import torchsparse 27 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) 28 | 29 | def forward(self, x: SparseTensor) -> SparseTensor: 30 | out = self.conv(x.data) 31 | new_shape = [x.shape[0], self.conv.out_channels] 32 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 33 | out._spatial_cache = x._spatial_cache 34 | out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) 35 | return out 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /vox2seq/src/api.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Serialize a voxel grid 3 | * 4 | * Copyright (C) 2024, Jianfeng XIANG 5 | * All rights reserved. 
6 |  *
7 |  * Licensed under The MIT License [see LICENSE for details]
8 |  *
9 |  * Written by Jianfeng XIANG
10 |  */
11 | 
12 | #pragma once
13 | #include <torch/extension.h>
14 | 
15 | 
16 | #define BLOCK_SIZE 256
17 | 
18 | 
19 | /**
20 |  * Z-order encode 3D points
21 |  *
22 |  * @param x [N] tensor containing the x coordinates
23 |  * @param y [N] tensor containing the y coordinates
24 |  * @param z [N] tensor containing the z coordinates
25 |  *
26 |  * @return [N] tensor containing the z-order encoded values
27 |  */
28 | torch::Tensor
29 | z_order_encode(
30 |     const torch::Tensor& x,
31 |     const torch::Tensor& y,
32 |     const torch::Tensor& z
33 | );
34 | 
35 | 
36 | /**
37 |  * Z-order decode 3D points
38 |  *
39 |  * @param codes [N] tensor containing the z-order encoded values
40 |  *
41 |  * @return 3 tensors [N] containing the x, y, z coordinates
42 |  */
43 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
44 | z_order_decode(
45 |     const torch::Tensor& codes
46 | );
47 | 
48 | 
49 | /**
50 |  * Hilbert encode 3D points
51 |  *
52 |  * @param x [N] tensor containing the x coordinates
53 |  * @param y [N] tensor containing the y coordinates
54 |  * @param z [N] tensor containing the z coordinates
55 |  *
56 |  * @return [N] tensor containing the Hilbert encoded values
57 |  */
58 | torch::Tensor
59 | hilbert_encode(
60 |     const torch::Tensor& x,
61 |     const torch::Tensor& y,
62 |     const torch::Tensor& z
63 | );
64 | 
65 | 
66 | /**
67 |  * Hilbert decode 3D points
68 |  *
69 |  * @param codes [N] tensor containing the Hilbert encoded values
70 |  *
71 |  * @return 3 tensors [N] containing the x, y, z coordinates
72 |  */
73 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
74 | hilbert_decode(
75 |     const torch::Tensor& codes
76 | );
77 | 
--------------------------------------------------------------------------------
/vox2seq/src/z_order.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 | #include <stdint.h>
4 | 
5 | #include <cooperative_groups.h>
6 | #include <cooperative_groups/reduce.h>
7 | namespace cg = cooperative_groups;
8 | 
9 | #include "z_order.h"
10 | 
11 | 
12 | // Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
13 | static __device__ uint32_t expandBits(uint32_t v)
14 | {
15 |     v = (v * 0x00010001u) & 0xFF0000FFu;
16 |     v = (v * 0x00000101u) & 0x0F00F00Fu;
17 |     v = (v * 0x00000011u) & 0xC30C30C3u;
18 |     v = (v * 0x00000005u) & 0x49249249u;
19 |     return v;
20 | }
21 | 
22 | 
23 | // Removes 2 zeros after each bit in a 30-bit integer.
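// (exact inverse of expandBits: e.g. extractBits(0b001001001) == 0b111; decode
// shifts the code right by 2/1/0 bits so x/y/z land back on contiguous bits)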
24 | static __device__ uint32_t extractBits(uint32_t v) 25 | { 26 | v = v & 0x49249249; 27 | v = (v ^ (v >> 2)) & 0x030C30C3u; 28 | v = (v ^ (v >> 4)) & 0x0300F00Fu; 29 | v = (v ^ (v >> 8)) & 0x030000FFu; 30 | v = (v ^ (v >> 16)) & 0x000003FFu; 31 | return v; 32 | } 33 | 34 | 35 | __global__ void z_order_encode_cuda( 36 | size_t N, 37 | const uint32_t* x, 38 | const uint32_t* y, 39 | const uint32_t* z, 40 | uint32_t* codes 41 | ) { 42 | size_t thread_id = cg::this_grid().thread_rank(); 43 | if (thread_id >= N) return; 44 | 45 | uint32_t xx = expandBits(x[thread_id]); 46 | uint32_t yy = expandBits(y[thread_id]); 47 | uint32_t zz = expandBits(z[thread_id]); 48 | 49 | codes[thread_id] = xx * 4 + yy * 2 + zz; 50 | } 51 | 52 | 53 | __global__ void z_order_decode_cuda( 54 | size_t N, 55 | const uint32_t* codes, 56 | uint32_t* x, 57 | uint32_t* y, 58 | uint32_t* z 59 | ) { 60 | size_t thread_id = cg::this_grid().thread_rank(); 61 | if (thread_id >= N) return; 62 | 63 | x[thread_id] = extractBits(codes[thread_id] >> 2); 64 | y[thread_id] = extractBits(codes[thread_id] >> 1); 65 | z[thread_id] = extractBits(codes[thread_id]); 66 | } 67 | -------------------------------------------------------------------------------- /trellis/modules/spatial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: 5 | """ 6 | 3D pixel shuffle. 7 | """ 8 | B, C, H, W, D = x.shape 9 | C_ = C // scale_factor**3 10 | x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) 11 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) 12 | x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) 13 | return x 14 | 15 | 16 | def patchify(x: torch.Tensor, patch_size: int): 17 | """ 18 | Patchify a tensor. 19 | 20 | Args: 21 | x (torch.Tensor): (N, C, *spatial) tensor 22 | patch_size (int): Patch size 23 | """ 24 | DIM = x.dim() - 2 25 | for d in range(2, DIM + 2): 26 | assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" 27 | 28 | x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) 29 | x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) 30 | x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) 31 | return x 32 | 33 | 34 | def unpatchify(x: torch.Tensor, patch_size: int): 35 | """ 36 | Unpatchify a tensor. 
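    (Inverse of patchify: folds the patch_size ** D factor of the channel
    dimension back into the D spatial dimensions.)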
37 | 38 | Args: 39 | x (torch.Tensor): (N, C, *spatial) tensor 40 | patch_size (int): Patch size 41 | """ 42 | DIM = x.dim() - 2 43 | assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" 44 | 45 | x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) 46 | x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) 47 | x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) 48 | return x 49 | -------------------------------------------------------------------------------- /vox2seq/benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import vox2seq 4 | 5 | 6 | if __name__ == "__main__": 7 | stats = { 8 | 'z_order_cuda': [], 9 | 'z_order_pytorch': [], 10 | 'hilbert_cuda': [], 11 | 'hilbert_pytorch': [], 12 | } 13 | RES = [16, 32, 64, 128, 256] 14 | for res in RES: 15 | coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res)) 16 | coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() 17 | 18 | start = time.time() 19 | for _ in range(100): 20 | code_z_cuda = vox2seq.encode(coords, mode='z_order').cuda() 21 | torch.cuda.synchronize() 22 | stats['z_order_cuda'].append((time.time() - start) / 100) 23 | 24 | start = time.time() 25 | for _ in range(100): 26 | code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order').cuda() 27 | torch.cuda.synchronize() 28 | stats['z_order_pytorch'].append((time.time() - start) / 100) 29 | 30 | start = time.time() 31 | for _ in range(100): 32 | code_h_cuda = vox2seq.encode(coords, mode='hilbert').cuda() 33 | torch.cuda.synchronize() 34 | stats['hilbert_cuda'].append((time.time() - start) / 100) 35 | 36 | start = time.time() 37 | for _ in range(100): 38 | code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert').cuda() 39 | torch.cuda.synchronize() 40 | stats['hilbert_pytorch'].append((time.time() - start) / 100) 41 | 42 | print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}") 43 | for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']): 44 | print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}") 45 | 46 | -------------------------------------------------------------------------------- /trellis/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | __attributes = { 4 | 'SparseStructure': 'sparse_structure', 5 | 6 | 'SparseFeat2Render': 'sparse_feat2render', 7 | 'SLat2Render':'structured_latent2render', 8 | 'Slat2RenderGeomesh': 'structured_latent2render_mesh', 9 | 'Slat2RenderGeo':'structured_latent2render', 10 | 11 | 'SparseStructureLatent': 'sparse_structure_latent', 12 | 'TextConditionedSparseStructureLatent': 'sparse_structure_latent', 13 | 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent', 14 | 15 | 'SLat': 'structured_latent', 16 | 'TextConditionedSLat': 'structured_latent', 17 | 'ImageConditionedSLat': 'structured_latent', 18 | } 19 | 20 | __submodules = [] 21 | 22 | __all__ = list(__attributes.keys()) + __submodules 23 | 24 | def __getattr__(name): 25 | if name not in globals(): 26 | 
if name in __attributes: 27 | module_name = __attributes[name] 28 | module = importlib.import_module(f".{module_name}", __name__) 29 | globals()[name] = getattr(module, name) 30 | elif name in __submodules: 31 | module = importlib.import_module(f".{name}", __name__) 32 | globals()[name] = module 33 | else: 34 | raise AttributeError(f"module {__name__} has no attribute {name}") 35 | return globals()[name] 36 | 37 | 38 | # For Pylance 39 | if __name__ == '__main__': 40 | from .sparse_structure import SparseStructure 41 | 42 | from .sparse_feat2render import SparseFeat2Render 43 | from .structured_latent2render import ( 44 | SLat2Render, 45 | Slat2RenderGeo, 46 | ) 47 | 48 | from .sparse_structure_latent import ( 49 | SparseStructureLatent, 50 | TextConditionedSparseStructureLatent, 51 | ImageConditionedSparseStructureLatent, 52 | ) 53 | 54 | from .structured_latent import ( 55 | SLat, 56 | TextConditionedSLat, 57 | ImageConditionedSLat, 58 | ) 59 | -------------------------------------------------------------------------------- /trellis/pipelines/base.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from .. import models 5 | 6 | class Pipeline: 7 | """ 8 | A base class for pipelines. 9 | """ 10 | def __init__( 11 | self, 12 | models: dict[str, nn.Module] = None, 13 | ): 14 | if models is None: 15 | return 16 | self.models = models 17 | for model in self.models.values(): 18 | model.eval() 19 | 20 | @staticmethod 21 | def from_pretrained(path: str) -> "Pipeline": 22 | """ 23 | Load a pretrained model. 24 | """ 25 | import os 26 | import json 27 | is_local = os.path.exists(f"{path}/pipeline.json") 28 | 29 | if is_local: 30 | config_file = f"{path}/pipeline.json" 31 | else: 32 | from huggingface_hub import hf_hub_download 33 | config_file = hf_hub_download(path, "pipeline.json") 34 | 35 | with open(config_file, 'r') as f: 36 | args = json.load(f)['args'] 37 | 38 | _models = {} 39 | 40 | for k, v in args['models'].items(): 41 | if 'ckpts_new' in v: 42 | _models[k] = models.from_pretrained_config(f"{path}/{v}") 43 | else: 44 | _models[k] = models.from_pretrained(f"{v}") 45 | 46 | new_pipeline = Pipeline(_models) 47 | new_pipeline._pretrained_args = args 48 | return new_pipeline 49 | 50 | @property 51 | def device(self) -> torch.device: 52 | for model in self.models.values(): 53 | if hasattr(model, 'device'): 54 | return model.device 55 | for model in self.models.values(): 56 | if hasattr(model, 'parameters'): 57 | return next(model.parameters()).device 58 | raise RuntimeError("No device found.") 59 | 60 | def to(self, device: torch.device) -> None: 61 | for model in self.models.values(): 62 | model.to(device) 63 | 64 | def cuda(self) -> None: 65 | self.to(torch.device("cuda")) 66 | 67 | def cpu(self) -> None: 68 | self.to(torch.device("cpu")) 69 | -------------------------------------------------------------------------------- /vox2seq/vox2seq/pytorch/default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .z_order import xyz2key as z_order_encode_ 3 | from .z_order import key2xyz as z_order_decode_ 4 | from .hilbert import encode as hilbert_encode_ 5 | from .hilbert import decode as hilbert_decode_ 6 | 7 | 8 | @torch.inference_mode() 9 | def encode(grid_coord, batch=None, depth=16, order="z"): 10 | assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} 11 | if order == "z": 12 | code = z_order_encode(grid_coord, 
depth=depth) 13 | elif order == "z-trans": 14 | code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) 15 | elif order == "hilbert": 16 | code = hilbert_encode(grid_coord, depth=depth) 17 | elif order == "hilbert-trans": 18 | code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) 19 | else: 20 | raise NotImplementedError 21 | if batch is not None: 22 | batch = batch.long() 23 | code = batch << depth * 3 | code 24 | return code 25 | 26 | 27 | @torch.inference_mode() 28 | def decode(code, depth=16, order="z"): 29 | assert order in {"z", "hilbert"} 30 | batch = code >> depth * 3 31 | code = code & ((1 << depth * 3) - 1) 32 | if order == "z": 33 | grid_coord = z_order_decode(code, depth=depth) 34 | elif order == "hilbert": 35 | grid_coord = hilbert_decode(code, depth=depth) 36 | else: 37 | raise NotImplementedError 38 | return grid_coord, batch 39 | 40 | 41 | def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): 42 | x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() 43 | # we block the support to batch, maintain batched code in Point class 44 | code = z_order_encode_(x, y, z, b=None, depth=depth) 45 | return code 46 | 47 | 48 | def z_order_decode(code: torch.Tensor, depth): 49 | x, y, z, _ = z_order_decode_(code, depth=depth) 50 | grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) 51 | return grid_coord 52 | 53 | 54 | def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): 55 | return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) 56 | 57 | 58 | def hilbert_decode(code: torch.Tensor, depth: int = 16): 59 | return hilbert_decode_(code, num_dims=3, num_bits=depth) -------------------------------------------------------------------------------- /trellis/modules/sparse/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import SparseTensor 4 | from . 
import DEBUG 5 | 6 | __all__ = [ 7 | 'SparseGroupNorm', 8 | 'SparseLayerNorm', 9 | 'SparseGroupNorm32', 10 | 'SparseLayerNorm32', 11 | ] 12 | 13 | 14 | class SparseGroupNorm(nn.GroupNorm): 15 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): 16 | super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) 17 | 18 | def forward(self, input: SparseTensor) -> SparseTensor: 19 | nfeats = torch.zeros_like(input.feats) 20 | for k in range(input.shape[0]): 21 | if DEBUG: 22 | assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" 23 | bfeats = input.feats[input.layout[k]] 24 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 25 | bfeats = super().forward(bfeats) 26 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 27 | nfeats[input.layout[k]] = bfeats 28 | return input.replace(nfeats) 29 | 30 | 31 | class SparseLayerNorm(nn.LayerNorm): 32 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 33 | super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) 34 | 35 | def forward(self, input: SparseTensor) -> SparseTensor: 36 | nfeats = torch.zeros_like(input.feats) 37 | for k in range(input.shape[0]): 38 | bfeats = input.feats[input.layout[k]] 39 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 40 | bfeats = super().forward(bfeats) 41 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 42 | nfeats[input.layout[k]] = bfeats 43 | return input.replace(nfeats) 44 | 45 | 46 | class SparseGroupNorm32(SparseGroupNorm): 47 | """ 48 | A GroupNorm layer that converts to float32 before the forward pass. 49 | """ 50 | def forward(self, x: SparseTensor) -> SparseTensor: 51 | return super().forward(x.float()).type(x.dtype) 52 | 53 | class SparseLayerNorm32(SparseLayerNorm): 54 | """ 55 | A LayerNorm layer that converts to float32 before the forward pass. 
56 | """ 57 | def forward(self, x: SparseTensor) -> SparseTensor: 58 | return super().forward(x.float()).type(x.dtype) 59 | -------------------------------------------------------------------------------- /tools/merge_mesh.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import trimesh 4 | import argparse 5 | import logging 6 | 7 | def get_logger(filename, verbosity=1, name=None): 8 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 9 | formatter = logging.Formatter( 10 | "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s" 11 | ) 12 | logger = logging.getLogger(name) 13 | logger.setLevel(level_dict[verbosity]) 14 | 15 | fh = logging.FileHandler(filename, "w") 16 | fh.setFormatter(formatter) 17 | logger.addHandler(fh) 18 | 19 | sh = logging.StreamHandler() 20 | sh.setFormatter(formatter) 21 | logger.addHandler(sh) 22 | 23 | return logger 24 | 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--index", type=int, default=0) 30 | parser.add_argument("--range", type=int, default=1000) 31 | args = parser.parse_args() 32 | 33 | 34 | 35 | fianljson='./physxnet/finaljson' 36 | partsegpath='./physxnet/partseg' 37 | 38 | savepath='./phy_dataset/' 39 | 40 | namelist=os.listdir(fianljson) 41 | namelist=namelist[args.index*args.range:(args.index+1)*args.range] 42 | 43 | os.makedirs(savepath, exist_ok=True) 44 | logger = get_logger(os.path.join('./output_physxnet','exp_merge'+str(args.index)+'.log'),verbosity=1) 45 | 46 | logger.info('start') 47 | 48 | existfile=os.listdir(savepath) 49 | 50 | for name in namelist: 51 | name=name[:-5] 52 | logger.info('begin: '+name) 53 | 54 | if not os.path.exists(os.path.join(savepath,name,'model.obj')): 55 | 56 | 57 | namepath=os.path.join(partsegpath,name,'objs') 58 | objlist = sorted(os.listdir(namepath), key=lambda x: int(x.split('.')[0])) 59 | 60 | for objname in range(len(objlist)): 61 | 62 | newrotationarrow=trimesh.load(os.path.join(namepath,objlist[objname]),force='mesh') 63 | 64 | if objname==0: 65 | combinedarrow=newrotationarrow 66 | 67 | else: 68 | combinedarrow = trimesh.util.concatenate([combinedarrow,newrotationarrow]) 69 | combinedarrow.merge_vertices() 70 | 71 | 72 | combinedarrow.export(os.path.join(savepath,name,'model.obj')) 73 | logger.info('success: '+name) 74 | else: 75 | logger.info('skip: '+name) 76 | -------------------------------------------------------------------------------- /trellis/trainers/flow_matching/mixins/classifier_free_guidance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from ....utils.general_utils import dict_foreach 4 | from ....pipelines import samplers 5 | 6 | 7 | class ClassifierFreeGuidanceMixin: 8 | def __init__(self, *args, p_uncond: float = 0.1, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.p_uncond = p_uncond 11 | 12 | def get_cond(self, cond, neg_cond=None, **kwargs): 13 | """ 14 | Get the conditioning data. 
15 | """ 16 | assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" 17 | 18 | if self.p_uncond > 0: 19 | # randomly drop the class label 20 | def get_batch_size(cond): 21 | if isinstance(cond, torch.Tensor): 22 | return cond.shape[0] 23 | elif isinstance(cond, list): 24 | return len(cond) 25 | else: 26 | raise ValueError(f"Unsupported type of cond: {type(cond)}") 27 | 28 | ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]] 29 | B = get_batch_size(ref_cond) 30 | 31 | def select(cond, neg_cond, mask): 32 | if isinstance(cond, torch.Tensor): 33 | mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1)) 34 | return torch.where(mask, neg_cond, cond) 35 | elif isinstance(cond, list): 36 | return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)] 37 | else: 38 | raise ValueError(f"Unsupported type of cond: {type(cond)}") 39 | 40 | mask = list(np.random.rand(B) < self.p_uncond) 41 | if not isinstance(cond, dict): 42 | cond = select(cond, neg_cond, mask) 43 | else: 44 | cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask)) 45 | 46 | return cond 47 | 48 | def get_inference_cond(self, cond, neg_cond=None, **kwargs): 49 | """ 50 | Get the conditioning data for inference. 51 | """ 52 | assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" 53 | return {'cond': cond, 'neg_cond': neg_cond, **kwargs} 54 | 55 | def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler: 56 | """ 57 | Get the sampler for the diffusion process. 58 | """ 59 | return samplers.FlowEulerCfgSampler(self.sigma_min) 60 | -------------------------------------------------------------------------------- /trellis/trainers/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # FP16 utils 5 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 6 | 7 | def make_master_params(model_params): 8 | """ 9 | Copy model parameters into a inflated tensor of full-precision parameters. 10 | """ 11 | master_params = _flatten_dense_tensors( 12 | [param.detach().float() for param in model_params] 13 | ) 14 | master_params = nn.Parameter(master_params) 15 | master_params.requires_grad = True 16 | return [master_params] 17 | 18 | 19 | def unflatten_master_params(model_params, master_params): 20 | """ 21 | Unflatten the master parameters to look like model_params. 22 | """ 23 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 24 | 25 | 26 | def model_params_to_master_params(model_params, master_params): 27 | """ 28 | Copy the model parameter data into the master parameters. 29 | """ 30 | master_params[0].detach().copy_( 31 | _flatten_dense_tensors([param.detach().float() for param in model_params]) 32 | ) 33 | 34 | 35 | def master_params_to_model_params(model_params, master_params): 36 | """ 37 | Copy the master parameter data back into the model parameters. 38 | """ 39 | for param, master_param in zip( 40 | model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params) 41 | ): 42 | param.detach().copy_(master_param) 43 | 44 | 45 | def model_grads_to_master_grads(model_params, master_params): 46 | """ 47 | Copy the gradients from the model parameters into the master parameters 48 | from make_master_params(). 
49 | """ 50 | master_params[0].grad = _flatten_dense_tensors( 51 | [param.grad.data.detach().float() for param in model_params] 52 | ) 53 | 54 | 55 | def zero_grad(model_params): 56 | for param in model_params: 57 | if param.grad is not None: 58 | if param.grad.grad_fn is not None: 59 | param.grad.detach_() 60 | else: 61 | param.grad.requires_grad_(False) 62 | param.grad.zero_() 63 | 64 | 65 | # LR Schedulers 66 | from torch.optim.lr_scheduler import LambdaLR 67 | 68 | class LinearWarmupLRScheduler(LambdaLR): 69 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 70 | self.warmup_steps = warmup_steps 71 | super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 72 | 73 | def lr_lambda(self, current_step): 74 | if current_step < self.warmup_steps: 75 | return float(current_step + 1) / self.warmup_steps 76 | return 1.0 77 | -------------------------------------------------------------------------------- /trellis/trainers/flow_matching/mixins/text_conditioned.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import os 3 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 4 | import torch 5 | from transformers import AutoTokenizer, CLIPTextModel 6 | 7 | from ....utils import dist_utils 8 | 9 | 10 | class TextConditionedMixin: 11 | """ 12 | Mixin for text-conditioned models. 13 | 14 | Args: 15 | text_cond_model: The text conditioning model. 16 | """ 17 | def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.text_cond_model_name = text_cond_model 20 | self.text_cond_model = None # the model is init lazily 21 | 22 | def _init_text_cond_model(self): 23 | """ 24 | Initialize the text conditioning model. 25 | """ 26 | # load model 27 | with dist_utils.local_master_first(): 28 | model = CLIPTextModel.from_pretrained(self.text_cond_model_name) 29 | tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name) 30 | model.eval() 31 | model = model.cuda() 32 | self.text_cond_model = { 33 | 'model': model, 34 | 'tokenizer': tokenizer, 35 | } 36 | self.text_cond_model['null_cond'] = self.encode_text(['']) 37 | 38 | @torch.no_grad() 39 | def encode_text(self, text: List[str]) -> torch.Tensor: 40 | """ 41 | Encode the text. 42 | """ 43 | assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond" 44 | if self.text_cond_model is None: 45 | self._init_text_cond_model() 46 | encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt') 47 | tokens = encoding['input_ids'].cuda() 48 | embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state 49 | 50 | return embeddings 51 | 52 | def get_cond(self, cond, **kwargs): 53 | """ 54 | Get the conditioning data. 55 | """ 56 | cond = self.encode_text(cond) 57 | kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) 58 | cond = super().get_cond(cond, **kwargs) 59 | return cond 60 | 61 | def get_inference_cond(self, cond, **kwargs): 62 | """ 63 | Get the conditioning data for inference. 
64 | """ 65 | cond = self.encode_text(cond) 66 | kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) 67 | cond = super().get_inference_cond(cond, **kwargs) 68 | return cond 69 | -------------------------------------------------------------------------------- /trellis/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | __attributes = { 4 | 'BasicTrainer': 'basic', 5 | 6 | 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae', 7 | 8 | 'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian', 9 | 'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec', 10 | 'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec', 11 | 'SLatVaeMeshTrainer': 'vae.structured_latent_vae_mesh', 12 | 13 | 'FlowMatchingTrainer': 'flow_matching.flow_matching', 14 | 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching', 15 | 'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', 16 | 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', 17 | 18 | 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching', 19 | 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', 20 | 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', 21 | 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', 22 | 'ImageConditionedSparseFlowMatchingCFGTrainerphy': 'flow_matching.sparse_flow_matchingnew', 23 | } 24 | 25 | __submodules = [] 26 | 27 | __all__ = list(__attributes.keys()) + __submodules 28 | 29 | def __getattr__(name): 30 | if name not in globals(): 31 | if name in __attributes: 32 | module_name = __attributes[name] 33 | module = importlib.import_module(f".{module_name}", __name__) 34 | globals()[name] = getattr(module, name) 35 | elif name in __submodules: 36 | module = importlib.import_module(f".{name}", __name__) 37 | globals()[name] = module 38 | else: 39 | raise AttributeError(f"module {__name__} has no attribute {name}") 40 | return globals()[name] 41 | 42 | 43 | # For Pylance 44 | if __name__ == '__main__': 45 | from .basic import BasicTrainer 46 | 47 | from .vae.sparse_structure_vae import SparseStructureVaeTrainer 48 | 49 | from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer 50 | from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer 51 | from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer 52 | 53 | from .flow_matching.flow_matching import ( 54 | FlowMatchingTrainer, 55 | FlowMatchingCFGTrainer, 56 | TextConditionedFlowMatchingCFGTrainer, 57 | ImageConditionedFlowMatchingCFGTrainer, 58 | ) 59 | 60 | from .flow_matching.sparse_flow_matching import ( 61 | SparseFlowMatchingTrainer, 62 | SparseFlowMatchingCFGTrainer, 63 | TextConditionedSparseFlowMatchingCFGTrainer, 64 | ImageConditionedSparseFlowMatchingCFGTrainer, 65 | ) 66 | -------------------------------------------------------------------------------- /dataset_toolkits/stat_latent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from easydict import EasyDict as edict 8 | from concurrent.futures import ThreadPoolExecutor 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--output_dir', type=str, required=True, 14 | help='Directory 
to save the metadata') 15 | parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, 16 | help='Filter objects with aesthetic score lower than this value') 17 | parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_vae_s2new_update_100000', 18 | help='Latent model to use') 19 | parser.add_argument('--num_samples', type=int, default=50000, 20 | help='Number of samples to use for calculating stats') 21 | opt = parser.parse_args() 22 | opt = edict(vars(opt)) 23 | 24 | # get file list 25 | if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): 26 | metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) 27 | else: 28 | raise ValueError('metadata.csv not found') 29 | if opt.filter_low_aesthetic_score is not None: 30 | metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] 31 | #metadata = metadata[metadata[f'latent_{opt.model}'] == True] 32 | sha256s = metadata['sha256'].values 33 | sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False) 34 | 35 | # stats 36 | means = [] 37 | mean2s = [] 38 | with ThreadPoolExecutor(max_workers=16) as executor, \ 39 | tqdm(total=len(sha256s), desc="Extracting features") as pbar: 40 | def worker(sha256): 41 | try: 42 | feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz')) 43 | feats = feats['feats'] 44 | means.append(feats.mean(axis=0)) 45 | mean2s.append((feats ** 2).mean(axis=0)) 46 | pbar.update() 47 | except Exception as e: 48 | print(f"Error extracting features for {sha256}: {e}") 49 | pbar.update() 50 | 51 | executor.map(worker, sha256s) 52 | executor.shutdown(wait=True) 53 | 54 | mean = np.array(means).mean(axis=0) 55 | mean2 = np.array(mean2s).mean(axis=0) 56 | std = np.sqrt(mean2 - mean ** 2) 57 | 58 | print('mean:', mean) 59 | print('std:', std) 60 | 61 | with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f: 62 | json.dump({ 63 | 'mean': mean.tolist(), 64 | 'std': std.tolist(), 65 | }, f, indent=4) 66 | -------------------------------------------------------------------------------- /trellis/representations/mesh/utils_cube.py: -------------------------------------------------------------------------------- 1 | import torch 2 | cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ 3 | 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) 4 | cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) 5 | cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 6 | 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) 7 | 8 | def construct_dense_grid(res, device='cuda'): 9 | '''construct a dense grid based on resolution''' 10 | res_v = res + 1 11 | vertsid = torch.arange(res_v ** 3, device=device) 12 | coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() 13 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] 14 | cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) 15 | verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) 16 | return verts, cube_fx8 17 | 18 | 19 | def construct_voxel_grid(coords): 20 | verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) 21 | verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) 22 | cubes = inverse_indices.reshape(-1, 8) 23 | return verts_unique, 
cubes
24 | 
25 | 
26 | def cubes_to_verts(num_verts, cubes, value, reduce='mean'):
27 |     """
28 |     Args:
29 |         cubes [Vx8] verts index for each cube
30 |         value [Vx8xM] value to be scattered
31 |     Operation:
32 |         reduced[cubes[i][j]][k] += value[i][j][k]
33 |     """
34 |     M = value.shape[2] # number of channels
35 |     reduced = torch.zeros(num_verts, M, device=cubes.device)
36 |     return torch.scatter_reduce(reduced, 0,
37 |                                 cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1),
38 |                                 value.flatten(0, 1), reduce=reduce, include_self=False)
39 | 
40 | def sparse_cube2verts(coords, feats, training=True):
41 |     new_coords, cubes = construct_voxel_grid(coords)
42 |     new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats)
43 |     if training:
44 |         con_loss = torch.mean((feats - new_feats[cubes]) ** 2)
45 |     else:
46 |         con_loss = 0.0
47 |     return new_coords, new_feats, con_loss
48 | 
49 | 
50 | def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True):
51 |     F = feats.shape[-1]
52 |     dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device)
53 |     if sdf_init:
54 |         dense_attrs[..., 0] = 1 # initial outside sdf value
55 |     dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats
56 |     return dense_attrs.reshape(-1, F)
57 | 
58 | 
59 | def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res):
60 |     return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform)
61 | 
--------------------------------------------------------------------------------
/trellis/models/structured_latent_vae/encoder.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from ...modules import sparse as sp
6 | from .base import SparseTransformerBase
7 | from ..sparse_elastic_mixin import SparseTransformerElasticMixin
8 | 
9 | 
10 | class SLatEncoder(SparseTransformerBase):
11 |     def __init__(
12 |         self,
13 |         resolution: int,
14 |         in_channels: int,
15 |         model_channels: int,
16 |         latent_channels: int,
17 |         num_blocks: int,
18 |         num_heads: Optional[int] = None,
19 |         num_head_channels: Optional[int] = 64,
20 |         mlp_ratio: float = 4,
21 |         attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
22 |         window_size: int = 8,
23 |         pe_mode: Literal["ape", "rope"] = "ape",
24 |         use_fp16: bool = False,
25 |         use_checkpoint: bool = False,
26 |         qk_rms_norm: bool = False,
27 |     ):
28 |         super().__init__(
29 |             in_channels=in_channels,
30 |             model_channels=model_channels,
31 |             num_blocks=num_blocks,
32 |             num_heads=num_heads,
33 |             num_head_channels=num_head_channels,
34 |             mlp_ratio=mlp_ratio,
35 |             attn_mode=attn_mode,
36 |             window_size=window_size,
37 |             pe_mode=pe_mode,
38 |             use_fp16=use_fp16,
39 |             use_checkpoint=use_checkpoint,
40 |             qk_rms_norm=qk_rms_norm,
41 |         )
42 |         self.resolution = resolution
43 |         self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
44 | 
45 |         self.initialize_weights()
46 |         if use_fp16:
47 |             self.convert_to_fp16()
48 | 
49 |     def initialize_weights(self) -> None:
50 |         super().initialize_weights()
51 |         # Zero-out output layers:
52 |         nn.init.constant_(self.out_layer.weight, 0)
53 |         nn.init.constant_(self.out_layer.bias, 0)
54 | 
55 |     def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
56 |         h = super().forward(x)
57 |         h = h.type(x.dtype)
58 |         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
59 |         h = self.out_layer(h)
60 | 
61 |         # Sample from the posterior distribution
62 |         mean, 
logvar = h.feats.chunk(2, dim=-1) 63 | if sample_posterior: 64 | std = torch.exp(0.5 * logvar) 65 | z = mean + std * torch.randn_like(std) 66 | else: 67 | z = mean 68 | z = h.replace(z) 69 | 70 | if return_raw: 71 | return z, mean, logvar 72 | else: 73 | return z 74 | 75 | 76 | class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder): 77 | """ 78 | SLat VAE encoder with elastic memory management. 79 | Used for training with low VRAM. 80 | """ 81 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/propertyencoder.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from ...modules import sparse as sp 6 | from .base import PropertyTransformerEncoderBase 7 | from ..sparse_elastic_mixin import SparseTransformerElasticMixin 8 | 9 | class PropertyEncoder(PropertyTransformerEncoderBase): 10 | def __init__( 11 | self, 12 | resolution: int, 13 | in_channels: int, 14 | in_channels_phy: int, 15 | model_channels: int, 16 | latent_channels: int, 17 | num_blocks: int, 18 | num_heads: Optional[int] = None, 19 | num_head_channels: Optional[int] = 64, 20 | mlp_ratio: float = 4, 21 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 22 | window_size: int = 8, 23 | pe_mode: Literal["ape", "rope"] = "ape", 24 | use_fp16: bool = False, 25 | use_checkpoint: bool = False, 26 | qk_rms_norm: bool = False, 27 | ): 28 | super().__init__( 29 | in_channels=in_channels, 30 | in_channels_phy=in_channels_phy, 31 | model_channels=model_channels, 32 | num_blocks=num_blocks, 33 | num_heads=num_heads, 34 | num_head_channels=num_head_channels, 35 | mlp_ratio=mlp_ratio, 36 | attn_mode=attn_mode, 37 | window_size=window_size, 38 | pe_mode=pe_mode, 39 | use_fp16=use_fp16, 40 | use_checkpoint=use_checkpoint, 41 | qk_rms_norm=qk_rms_norm, 42 | ) 43 | self.resolution = resolution 44 | self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) 45 | 46 | if use_fp16: 47 | self.convert_to_fp16() 48 | 49 | def initialize_weights(self) -> None: 50 | super().initialize_weights() 51 | # Zero-out output layers: 52 | nn.init.constant_(self.out_layer.weight, 0) 53 | nn.init.constant_(self.out_layer.bias, 0) 54 | 55 | def forward(self, lang: sp.SparseTensor,phy: sp.SparseTensor, sample_posterior=True, return_raw=False): 56 | h = super().forward(lang,phy) 57 | h = h.type(lang.dtype) 58 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) 59 | h = self.out_layer(h) 60 | 61 | # Sample from the posterior distribution 62 | mean, logvar = h.feats.chunk(2, dim=-1) 63 | if sample_posterior: 64 | std = torch.exp(0.5 * logvar) 65 | z = mean + std * torch.randn_like(std) 66 | else: 67 | z = mean 68 | z = h.replace(z) 69 | 70 | if return_raw: 71 | return z, mean, logvar 72 | else: 73 | return z 74 | 75 | 76 | class ElasticPropertyEncoder(SparseTransformerElasticMixin, PropertyEncoder): 77 | """ 78 | SLat VAE encoder with elastic memory management. 79 | Used for training with low VRAM. 
80 |     """
81 | 
--------------------------------------------------------------------------------
/vox2seq/src/api.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "api.h"
3 | #include "z_order.h"
4 | #include "hilbert.h"
5 | 
6 | 
7 | torch::Tensor
8 | z_order_encode(
9 |     const torch::Tensor& x,
10 |     const torch::Tensor& y,
11 |     const torch::Tensor& z
12 | ) {
13 |     // Allocate output tensor
14 |     torch::Tensor codes = torch::empty_like(x);
15 | 
16 |     // Call CUDA kernel
17 |     z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
18 |         x.size(0),
19 |         reinterpret_cast<const uint32_t*>(x.contiguous().data_ptr()),
20 |         reinterpret_cast<const uint32_t*>(y.contiguous().data_ptr()),
21 |         reinterpret_cast<const uint32_t*>(z.contiguous().data_ptr()),
22 |         reinterpret_cast<uint32_t*>(codes.data_ptr())
23 |     );
24 | 
25 |     return codes;
26 | }
27 | 
28 | 
29 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
30 | z_order_decode(
31 |     const torch::Tensor& codes
32 | ) {
33 |     // Allocate output tensors
34 |     torch::Tensor x = torch::empty_like(codes);
35 |     torch::Tensor y = torch::empty_like(codes);
36 |     torch::Tensor z = torch::empty_like(codes);
37 | 
38 |     // Call CUDA kernel
39 |     z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
40 |         codes.size(0),
41 |         reinterpret_cast<const uint32_t*>(codes.contiguous().data_ptr()),
42 |         reinterpret_cast<uint32_t*>(x.data_ptr()),
43 |         reinterpret_cast<uint32_t*>(y.data_ptr()),
44 |         reinterpret_cast<uint32_t*>(z.data_ptr())
45 |     );
46 | 
47 |     return std::make_tuple(x, y, z);
48 | }
49 | 
50 | 
51 | torch::Tensor
52 | hilbert_encode(
53 |     const torch::Tensor& x,
54 |     const torch::Tensor& y,
55 |     const torch::Tensor& z
56 | ) {
57 |     // Allocate output tensor
58 |     torch::Tensor codes = torch::empty_like(x);
59 | 
60 |     // Call CUDA kernel
61 |     hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
62 |         x.size(0),
63 |         reinterpret_cast<const uint32_t*>(x.contiguous().data_ptr()),
64 |         reinterpret_cast<const uint32_t*>(y.contiguous().data_ptr()),
65 |         reinterpret_cast<const uint32_t*>(z.contiguous().data_ptr()),
66 |         reinterpret_cast<uint32_t*>(codes.data_ptr())
67 |     );
68 | 
69 |     return codes;
70 | }
71 | 
72 | 
73 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
74 | hilbert_decode(
75 |     const torch::Tensor& codes
76 | ) {
77 |     // Allocate output tensors
78 |     torch::Tensor x = torch::empty_like(codes);
79 |     torch::Tensor y = torch::empty_like(codes);
80 |     torch::Tensor z = torch::empty_like(codes);
81 | 
82 |     // Call CUDA kernel
83 |     hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
84 |         codes.size(0),
85 |         reinterpret_cast<const uint32_t*>(codes.contiguous().data_ptr()),
86 |         reinterpret_cast<uint32_t*>(x.data_ptr()),
87 |         reinterpret_cast<uint32_t*>(y.data_ptr()),
88 |         reinterpret_cast<uint32_t*>(z.data_ptr())
89 |     );
90 | 
91 |     return std::make_tuple(x, y, z);
92 | }
93 | 
--------------------------------------------------------------------------------
/trellis/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | from math import exp
5 | from lpips import LPIPS
6 | 
7 | 
8 | def smooth_l1_loss(pred, target, beta=1.0):
9 |     diff = torch.abs(pred - target)
10 |     loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
11 |     return loss.mean()
12 | 
13 | 
14 | def l1_loss(network_output, gt):
15 |     return torch.abs((network_output - gt)).mean()
16 | 
17 | 
18 | def l2_loss(network_output, gt):
19 |     return ((network_output - gt) ** 2).mean()
20 | 
21 | 
22 | def gaussian(window_size, sigma):
23 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / 
float(2 * sigma ** 2)) for x in range(window_size)]) 24 | return gauss / gauss.sum() 25 | 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | 34 | def psnr(img1, img2, max_val=1.0): 35 | mse = F.mse_loss(img1, img2) 36 | return 20 * torch.log10(max_val / torch.sqrt(mse)) 37 | 38 | 39 | def ssim(img1, img2, window_size=11, size_average=True): 40 | channel = img1.size(-3) 41 | window = create_window(window_size, channel) 42 | 43 | if img1.is_cuda: 44 | window = window.cuda(img1.get_device()) 45 | window = window.type_as(img1) 46 | 47 | return _ssim(img1, img2, window, window_size, channel, size_average) 48 | 49 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 50 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 51 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 52 | 53 | mu1_sq = mu1.pow(2) 54 | mu2_sq = mu2.pow(2) 55 | mu1_mu2 = mu1 * mu2 56 | 57 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 58 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 59 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 60 | 61 | C1 = 0.01 ** 2 62 | C2 = 0.03 ** 2 63 | 64 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 65 | 66 | if size_average: 67 | return ssim_map.mean() 68 | else: 69 | return ssim_map.mean(1).mean(1).mean(1) 70 | 71 | 72 | loss_fn_vgg = None 73 | def lpips(img1, img2, value_range=(0, 1)): 74 | global loss_fn_vgg 75 | if loss_fn_vgg is None: 76 | loss_fn_vgg = LPIPS(net='vgg').cuda().eval() 77 | # normalize to [-1, 1] 78 | img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 79 | img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 80 | return loss_fn_vgg(img1, img2).mean() 81 | 82 | 83 | def normal_angle(pred, gt): 84 | pred = pred * 2.0 - 1.0 85 | gt = gt * 2.0 - 1.0 86 | norms = pred.norm(dim=-1) * gt.norm(dim=-1) 87 | cos_sim = (pred * gt).sum(-1) / (norms + 1e-9) 88 | cos_sim = torch.clamp(cos_sim, -1.0, 1.0) 89 | ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean() 90 | if ang.isnan(): 91 | return -1 92 | return ang 93 | -------------------------------------------------------------------------------- /trellis/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from contextlib import contextmanager 4 | import torch 5 | import torch.distributed as dist 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | import subprocess 8 | 9 | import socket 10 | from datetime import datetime 11 | 12 | def find_available_port(): 13 | port = 29500 14 | while True: 15 | try: 16 | with socket.socket() as s: 17 | s.bind(('', port)) 18 | return port 19 | except: 20 | port += 1 21 | 22 | 23 | def setup_dist(rank, local_rank, world_size, master_addr, master_port): 24 | os.environ['MASTER_ADDR'] = master_addr 25 | os.environ['MASTER_PORT'] = master_port 26 | os.environ['WORLD_SIZE'] = str(world_size) 27 | os.environ['RANK'] = str(rank) 28 | os.environ['LOCAL_RANK'] = str(local_rank) 29 | torch.cuda.set_device(local_rank) 30 | 
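# init_process_group() picks up MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from
# the environment variables set above (the default 'env://' init method), and
# set_device() runs first so each rank's NCCL communicator binds its own GPU.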
dist.init_process_group('nccl', rank=rank, world_size=world_size) 31 | 32 | 33 | def read_file_dist(path): 34 | """ 35 | Read the binary file distributedly. 36 | File is only read once by the rank 0 process and broadcasted to other processes. 37 | 38 | Returns: 39 | data (io.BytesIO): The binary data read from the file. 40 | """ 41 | if dist.is_initialized() and dist.get_world_size() > 1: 42 | # read file 43 | size = torch.LongTensor(1).cuda() 44 | if dist.get_rank() == 0: 45 | with open(path, 'rb') as f: 46 | data = f.read() 47 | data = torch.ByteTensor( 48 | torch.UntypedStorage.from_buffer(data, dtype=torch.uint8) 49 | ).cuda() 50 | size[0] = data.shape[0] 51 | # broadcast size 52 | dist.broadcast(size, src=0) 53 | if dist.get_rank() != 0: 54 | data = torch.ByteTensor(size[0].item()).cuda() 55 | # broadcast data 56 | dist.broadcast(data, src=0) 57 | # convert to io.BytesIO 58 | data = data.cpu().numpy().tobytes() 59 | data = io.BytesIO(data) 60 | return data 61 | else: 62 | with open(path, 'rb') as f: 63 | data = f.read() 64 | data = io.BytesIO(data) 65 | return data 66 | 67 | 68 | def unwrap_dist(model): 69 | """ 70 | Unwrap the model from distributed training. 71 | """ 72 | if isinstance(model, DDP): 73 | return model.module 74 | return model 75 | 76 | 77 | @contextmanager 78 | def master_first(): 79 | """ 80 | A context manager that ensures master process executes first. 81 | """ 82 | if not dist.is_initialized(): 83 | yield 84 | else: 85 | if dist.get_rank() == 0: 86 | yield 87 | dist.barrier() 88 | else: 89 | dist.barrier() 90 | yield 91 | 92 | 93 | @contextmanager 94 | def local_master_first(): 95 | """ 96 | A context manager that ensures local master process executes first. 97 | """ 98 | if not dist.is_initialized(): 99 | yield 100 | else: 101 | if dist.get_rank() % torch.cuda.device_count() == 0: 102 | yield 103 | dist.barrier() 104 | else: 105 | dist.barrier() 106 | yield 107 | -------------------------------------------------------------------------------- /trellis/modules/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | BACKEND = 'spconv' 4 | DEBUG = False 5 | ATTN = 'flash_attn' 6 | 7 | def __from_env(): 8 | import os 9 | 10 | global BACKEND 11 | global DEBUG 12 | global ATTN 13 | 14 | env_sparse_backend = os.environ.get('SPARSE_BACKEND') 15 | env_sparse_debug = os.environ.get('SPARSE_DEBUG') 16 | env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') 17 | if env_sparse_attn is None: 18 | env_sparse_attn = os.environ.get('ATTN_BACKEND') 19 | 20 | if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: 21 | BACKEND = env_sparse_backend 22 | if env_sparse_debug is not None: 23 | DEBUG = env_sparse_debug == '1' 24 | if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: 25 | ATTN = env_sparse_attn 26 | 27 | print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") 28 | 29 | 30 | __from_env() 31 | 32 | 33 | def set_backend(backend: Literal['spconv', 'torchsparse']): 34 | global BACKEND 35 | BACKEND = backend 36 | 37 | def set_debug(debug: bool): 38 | global DEBUG 39 | DEBUG = debug 40 | 41 | def set_attn(attn: Literal['xformers', 'flash_attn']): 42 | global ATTN 43 | ATTN = attn 44 | 45 | 46 | import importlib 47 | 48 | __attributes = { 49 | 'SparseTensor': 'basic', 50 | 'sparse_batch_broadcast': 'basic', 51 | 'sparse_batch_op': 'basic', 52 | 'sparse_cat': 'basic', 53 | 'sparse_unbind': 'basic', 54 | 'SparseGroupNorm': 'norm', 55 | 
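# Each entry maps a public name to the submodule that defines it; the module-level
# __getattr__ below (PEP 562) imports that submodule on first attribute access and
# caches the resolved object in globals(), so e.g. sp.SparseLinear triggers a
# single import of .linear.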
'SparseLayerNorm': 'norm', 56 | 'SparseGroupNorm32': 'norm', 57 | 'SparseLayerNorm32': 'norm', 58 | 'SparseReLU': 'nonlinearity', 59 | 'SparseSiLU': 'nonlinearity', 60 | 'SparseGELU': 'nonlinearity', 61 | 'SparseActivation': 'nonlinearity', 62 | 'SparseLinear': 'linear', 63 | 'sparse_scaled_dot_product_attention': 'attention', 64 | 'SerializeMode': 'attention', 65 | 'sparse_serialized_scaled_dot_product_self_attention': 'attention', 66 | 'sparse_windowed_scaled_dot_product_self_attention': 'attention', 67 | 'SparseMultiHeadAttention': 'attention', 68 | 'SparseConv3d': 'conv', 69 | 'SparseInverseConv3d': 'conv', 70 | 'SparseDownsample': 'spatial', 71 | 'SparseUpsample': 'spatial', 72 | 'SparseSubdivide' : 'spatial' 73 | } 74 | 75 | __submodules = ['transformer'] 76 | 77 | __all__ = list(__attributes.keys()) + __submodules 78 | 79 | def __getattr__(name): 80 | if name not in globals(): 81 | if name in __attributes: 82 | module_name = __attributes[name] 83 | module = importlib.import_module(f".{module_name}", __name__) 84 | globals()[name] = getattr(module, name) 85 | elif name in __submodules: 86 | module = importlib.import_module(f".{name}", __name__) 87 | globals()[name] = module 88 | else: 89 | raise AttributeError(f"module {__name__} has no attribute {name}") 90 | return globals()[name] 91 | 92 | 93 | # For Pylance 94 | if __name__ == '__main__': 95 | from .basic import * 96 | from .norm import * 97 | from .nonlinearity import * 98 | from .linear import * 99 | from .attention import * 100 | from .conv import * 101 | from .spatial import * 102 | import transformer 103 | -------------------------------------------------------------------------------- /trellis/utils/grad_clip_utils.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import numpy as np 4 | import torch.utils 5 | 6 | 7 | class AdaptiveGradClipper: 8 | """ 9 | Adaptive gradient clipping for training. 10 | """ 11 | def __init__( 12 | self, 13 | max_norm=None, 14 | clip_percentile=95.0, 15 | buffer_size=1000, 16 | ): 17 | self.max_norm = max_norm 18 | self.clip_percentile = clip_percentile 19 | self.buffer_size = buffer_size 20 | 21 | self._grad_norm = np.zeros(buffer_size, dtype=np.float32) 22 | self._max_norm = max_norm 23 | self._buffer_ptr = 0 24 | self._buffer_length = 0 25 | 26 | def __repr__(self): 27 | return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})' 28 | 29 | def state_dict(self): 30 | return { 31 | 'grad_norm': self._grad_norm, 32 | 'max_norm': self._max_norm, 33 | 'buffer_ptr': self._buffer_ptr, 34 | 'buffer_length': self._buffer_length, 35 | } 36 | 37 | def load_state_dict(self, state_dict): 38 | self._grad_norm = state_dict['grad_norm'] 39 | self._max_norm = state_dict['max_norm'] 40 | self._buffer_ptr = state_dict['buffer_ptr'] 41 | self._buffer_length = state_dict['buffer_length'] 42 | 43 | def log(self): 44 | return { 45 | 'max_norm': self._max_norm, 46 | } 47 | 48 | def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None): 49 | """Clip the gradient norm of an iterable of parameters. 50 | 51 | The norm is computed over all gradients together, as if they were 52 | concatenated into a single vector. Gradients are modified in-place. 53 | 54 | Args: 55 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 56 | single Tensor that will have gradients normalized 57 | norm_type (float): type of the used p-norm. Can be ``'inf'`` for 58 | infinity norm. 
58 |                 infinity norm.
59 |             error_if_nonfinite (bool): if True, an error is thrown if the total
60 |                 norm of the gradients from :attr:`parameters` is ``nan``,
61 |                 ``inf``, or ``-inf``. Default: False (will switch to True in the future)
62 |             foreach (bool): use the faster foreach-based implementation.
63 |                 If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
64 |                 fall back to the slow implementation for other device types.
65 |                 Default: ``None``
66 | 
67 |         Returns:
68 |             Total norm of the parameter gradients (viewed as a single vector).
69 |         """
70 |         max_norm = self._max_norm if self._max_norm is not None else float('inf')
71 |         grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)
72 | 
73 |         if torch.isfinite(grad_norm):
74 |             self._grad_norm[self._buffer_ptr] = grad_norm
75 |             self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
76 |             self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
77 |             if self._buffer_length == self.buffer_size:
78 |                 self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
79 |                 self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm
80 | 
81 |         return grad_norm
--------------------------------------------------------------------------------
/vox2seq/src/hilbert.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | 
5 | #include <cooperative_groups.h>
6 | #include <cooperative_groups/reduce.h>
7 | namespace cg = cooperative_groups;
8 | 
9 | #include "hilbert.h"
10 | 
11 | 
12 | // Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
13 | static __device__ uint32_t expandBits(uint32_t v)
14 | {
15 |     v = (v * 0x00010001u) & 0xFF0000FFu;
16 |     v = (v * 0x00000101u) & 0x0F00F00Fu;
17 |     v = (v * 0x00000011u) & 0xC30C30C3u;
18 |     v = (v * 0x00000005u) & 0x49249249u;
19 |     return v;
20 | }
21 | 
22 | 
23 | // Removes 2 zeros after each bit in a 30-bit integer.
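// For example, expandBits(0b111) == 0b001001001 (73), and extractBits below inverts
// it: extractBits(0b001001001) == 0b111. The encode kernel first applies the Hilbert
// transform to (x, y, z), then packs the result via this interleave as
// (expandBits(x) << 2) | (expandBits(y) << 1) | expandBits(z).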
24 | static __device__ uint32_t extractBits(uint32_t v) 25 | { 26 | v = v & 0x49249249; 27 | v = (v ^ (v >> 2)) & 0x030C30C3u; 28 | v = (v ^ (v >> 4)) & 0x0300F00Fu; 29 | v = (v ^ (v >> 8)) & 0x030000FFu; 30 | v = (v ^ (v >> 16)) & 0x000003FFu; 31 | return v; 32 | } 33 | 34 | 35 | __global__ void hilbert_encode_cuda( 36 | size_t N, 37 | const uint32_t* x, 38 | const uint32_t* y, 39 | const uint32_t* z, 40 | uint32_t* codes 41 | ) { 42 | size_t thread_id = cg::this_grid().thread_rank(); 43 | if (thread_id >= N) return; 44 | 45 | uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]}; 46 | 47 | uint32_t m = 1 << 9, q, p, t; 48 | 49 | // Inverse undo excess work 50 | q = m; 51 | while (q > 1) { 52 | p = q - 1; 53 | for (int i = 0; i < 3; i++) { 54 | if (point[i] & q) { 55 | point[0] ^= p; // invert 56 | } else { 57 | t = (point[0] ^ point[i]) & p; 58 | point[0] ^= t; 59 | point[i] ^= t; 60 | } 61 | } 62 | q >>= 1; 63 | } 64 | 65 | // Gray encode 66 | for (int i = 1; i < 3; i++) { 67 | point[i] ^= point[i - 1]; 68 | } 69 | t = 0; 70 | q = m; 71 | while (q > 1) { 72 | if (point[2] & q) { 73 | t ^= q - 1; 74 | } 75 | q >>= 1; 76 | } 77 | for (int i = 0; i < 3; i++) { 78 | point[i] ^= t; 79 | } 80 | 81 | // Convert to 3D Hilbert code 82 | uint32_t xx = expandBits(point[0]); 83 | uint32_t yy = expandBits(point[1]); 84 | uint32_t zz = expandBits(point[2]); 85 | 86 | codes[thread_id] = xx * 4 + yy * 2 + zz; 87 | } 88 | 89 | 90 | __global__ void hilbert_decode_cuda( 91 | size_t N, 92 | const uint32_t* codes, 93 | uint32_t* x, 94 | uint32_t* y, 95 | uint32_t* z 96 | ) { 97 | size_t thread_id = cg::this_grid().thread_rank(); 98 | if (thread_id >= N) return; 99 | 100 | uint32_t point[3]; 101 | point[0] = extractBits(codes[thread_id] >> 2); 102 | point[1] = extractBits(codes[thread_id] >> 1); 103 | point[2] = extractBits(codes[thread_id]); 104 | 105 | uint32_t m = 2 << 9, q, p, t; 106 | 107 | // Gray decode by H ^ (H/2) 108 | t = point[2] >> 1; 109 | for (int i = 2; i > 0; i--) { 110 | point[i] ^= point[i - 1]; 111 | } 112 | point[0] ^= t; 113 | 114 | // Undo excess work 115 | q = 2; 116 | while (q != m) { 117 | p = q - 1; 118 | for (int i = 2; i >= 0; i--) { 119 | if (point[i] & q) { 120 | point[0] ^= p; 121 | } else { 122 | t = (point[0] ^ point[i]) & p; 123 | point[0] ^= t; 124 | point[i] ^= t; 125 | } 126 | } 127 | q <<= 1; 128 | } 129 | 130 | x[thread_id] = point[0]; 131 | y[thread_id] = point[1]; 132 | z[thread_id] = point[2]; 133 | } 134 | -------------------------------------------------------------------------------- /trellis/trainers/flow_matching/mixins/image_conditioned.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn.functional as F 4 | from torchvision import transforms 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from ....utils import dist_utils 9 | 10 | 11 | class ImageConditionedMixin: 12 | """ 13 | Mixin for image-conditioned models. 14 | 15 | Args: 16 | image_cond_model: The image conditioning model. 17 | """ 18 | def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.image_cond_model_name = image_cond_model 21 | self.image_cond_model = None # the model is init lazily 22 | 23 | @staticmethod 24 | def prepare_for_training(image_cond_model: str, **kwargs): 25 | """ 26 | Prepare for training. 
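Pre-downloads the image-conditioning backbone via torch.hub before training
starts, so that the lazy per-rank initialization below can load it from the
local hub cache instead of racing to download it.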
27 | """ 28 | if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'): 29 | super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs) 30 | # download the model 31 | torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True) 32 | 33 | def _init_image_cond_model(self): 34 | """ 35 | Initialize the image conditioning model. 36 | """ 37 | with dist_utils.local_master_first(): 38 | dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True) 39 | dinov2_model.eval().cuda() 40 | transform = transforms.Compose([ 41 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 42 | ]) 43 | self.image_cond_model = { 44 | 'model': dinov2_model, 45 | 'transform': transform, 46 | } 47 | 48 | @torch.no_grad() 49 | def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: 50 | """ 51 | Encode the image. 52 | """ 53 | if isinstance(image, torch.Tensor): 54 | assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" 55 | elif isinstance(image, list): 56 | assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" 57 | image = [i.resize((518, 518), Image.LANCZOS) for i in image] 58 | image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] 59 | image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] 60 | image = torch.stack(image).cuda() 61 | else: 62 | raise ValueError(f"Unsupported type of image: {type(image)}") 63 | 64 | if self.image_cond_model is None: 65 | self._init_image_cond_model() 66 | image = self.image_cond_model['transform'](image).cuda() 67 | features = self.image_cond_model['model'](image, is_training=True)['x_prenorm'] 68 | patchtokens = F.layer_norm(features, features.shape[-1:]) 69 | return patchtokens 70 | 71 | def get_cond(self, cond, **kwargs): 72 | """ 73 | Get the conditioning data. 74 | """ 75 | cond = self.encode_image(cond) 76 | kwargs['neg_cond'] = torch.zeros_like(cond) 77 | cond = super().get_cond(cond, **kwargs) 78 | return cond 79 | 80 | def get_inference_cond(self, cond, **kwargs): 81 | """ 82 | Get the conditioning data for inference. 83 | """ 84 | cond = self.encode_image(cond) 85 | kwargs['neg_cond'] = torch.zeros_like(cond) 86 | cond = super().get_inference_cond(cond, **kwargs) 87 | return cond 88 | 89 | def vis_cond(self, cond, **kwargs): 90 | """ 91 | Visualize the conditioning data. 
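Returns a dict mapping each conditioning name to a record with 'value' and
'type' keys (here, the conditioning images themselves).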
92 | """ 93 | return {'image': {'value': cond, 'type': 'image'}} 94 | -------------------------------------------------------------------------------- /configs/vae/slat_vae_enc_dec_mesh_phy.json: -------------------------------------------------------------------------------- 1 | { 2 | "models": { 3 | "property_encoder": { 4 | "name": "ElasticPropertyEncoder", 5 | "args": { 6 | "resolution": 64, 7 | "in_channels": 3072, 8 | "in_channels_phy": 14, 9 | "model_channels": 768, 10 | "latent_channels": 8, 11 | "num_blocks": 4, 12 | "num_heads": 12, 13 | "mlp_ratio": 4, 14 | "attn_mode": "swin", 15 | "window_size": 8, 16 | "use_fp16": true 17 | } 18 | }, 19 | "property_decoder": { 20 | "name": "ElasticPropertyDecoder", 21 | "args": { 22 | "resolution": 64, 23 | "model_channels": 2048, 24 | "latent_channels": 8, 25 | "num_blocks": 4, 26 | "num_heads": 16, 27 | "mlp_ratio": 4, 28 | "attn_mode": "swin", 29 | "window_size": 8, 30 | "use_fp16": true, 31 | "representation_config": { 32 | "use_color": true 33 | } 34 | } 35 | }, 36 | "property_output": { 37 | "name": "PropertyOutput", 38 | "args": { 39 | "model_channels": 32, 40 | "output_channels_lang": 3072, 41 | "output_channels_phy": 14, 42 | "use_fp16": true 43 | } 44 | }, 45 | 46 | "decoder": { 47 | "name": "ElasticSLatMeshDecodernew", 48 | "args": { 49 | "resolution": 64, 50 | "model_channels": 768, 51 | "phy_channels": 2048, 52 | "latent_channels": 8, 53 | "num_blocks": 12, 54 | "num_heads": 12, 55 | "mlp_ratio": 4, 56 | "attn_mode": "swin", 57 | "window_size": 8, 58 | "use_fp16": true, 59 | "representation_config": { 60 | "use_color": true 61 | } 62 | } 63 | } 64 | }, 65 | "dataset": { 66 | "name": "Slat2RenderGeomesh", 67 | "args": { 68 | "image_size": 384, 69 | "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16", 70 | "min_aesthetic_score": 4.5, 71 | "max_num_voxels": 28000 72 | } 73 | }, 74 | "trainer": { 75 | "name": "SLatVaeMeshTrainer", 76 | "args": { 77 | "onlyphy_property": true, 78 | "max_steps": 1000000, 79 | "batch_size_per_gpu": 4, 80 | "batch_split": 4, 81 | "optimizer": { 82 | "name": "AdamW", 83 | "args": { 84 | "lr": 1e-4, 85 | "weight_decay": 0.0 86 | } 87 | }, 88 | "ema_rate": [ 89 | 0.9999 90 | ], 91 | "fp16_mode": "inflat_all", 92 | "fp16_scale_growth": 0.001, 93 | "elastic": { 94 | "name": "LinearMemoryController", 95 | "args": { 96 | "target_ratio": 0.6, 97 | "max_mem_ratio_start": 0.5 98 | } 99 | }, 100 | "grad_clip": { 101 | "name": "AdaptiveGradClipper", 102 | "args": { 103 | "max_norm": 1.0, 104 | "clip_percentile": 95 105 | } 106 | }, 107 | "i_log": 10, 108 | "i_sample": 5000, 109 | "i_save": 10000, 110 | "lambda_ssim": 0.2, 111 | "lambda_lpips": 0.2, 112 | "lambda_tsdf": 0.01, 113 | "lambda_depth": 10.0, 114 | "lambda_color": 0.1, 115 | "lambda_kl": 1e-06, 116 | "depth_loss_type": "smooth_l1" 117 | } 118 | } 119 | } -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/conv_spconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. import SparseTensor 4 | from .. import DEBUG 5 | from . 
import SPCONV_ALGO
6 | 
7 | class SparseConv3d(nn.Module):
8 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
9 |         super(SparseConv3d, self).__init__()
10 |         if 'spconv' not in globals():
11 |             import spconv.pytorch as spconv; globals()['spconv'] = spconv  # expose the lazy import at module scope: forward() below also references spconv
12 |         algo = None
13 |         if SPCONV_ALGO == 'native':
14 |             algo = spconv.ConvAlgo.Native
15 |         elif SPCONV_ALGO == 'implicit_gemm':
16 |             algo = spconv.ConvAlgo.MaskImplicitGemm
17 |         if stride == 1 and (padding is None):
18 |             self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
19 |         else:
20 |             self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
21 |         self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
22 |         self.padding = padding
23 | 
24 |     def forward(self, x: SparseTensor) -> SparseTensor:
25 |         spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
26 |         new_data = self.conv(x.data)
27 |         new_shape = [x.shape[0], self.conv.out_channels]
28 |         new_layout = None if spatial_changed else x.layout
29 | 
30 |         if spatial_changed and (x.shape[0] != 1):
31 |             # spconv with non-1 stride breaks the contiguity of the output tensor; sort by coords to restore it
32 |             fwd = new_data.indices[:, 0].argsort()
33 |             bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
34 |             sorted_feats = new_data.features[fwd]
35 |             sorted_coords = new_data.indices[fwd]
36 |             unsorted_data = new_data
37 |             new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size)  # type: ignore
38 | 
39 |         out = SparseTensor(
40 |             new_data, shape=torch.Size(new_shape), layout=new_layout,
41 |             scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
42 |             spatial_cache=x._spatial_cache,
43 |         )
44 | 
45 |         if spatial_changed and (x.shape[0] != 1):
46 |             out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
47 |             out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
48 | 
49 |         return out
50 | 
51 | 
52 | class SparseInverseConv3d(nn.Module):
53 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
54 |         super(SparseInverseConv3d, self).__init__()
55 |         if 'spconv' not in globals():
56 |             import spconv.pytorch as spconv; globals()['spconv'] = spconv
57 |         self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
58 |         self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
59 | 
60 |     def forward(self, x: SparseTensor) -> SparseTensor:
61 |         spatial_changed = any(s != 1 for s in self.stride)
62 |         if spatial_changed:
63 |             # recover the original spconv order
64 |             data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
65 |             bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
66 |             data = data.replace_feature(x.feats[bwd])
67 |             if DEBUG:
68 |                 assert torch.equal(data.indices, x.coords[bwd]), 'Failed to recover the original order'
69 |         else:
70 |             data = x.data
71 | 
72 |         new_data = self.conv(data)
73 |         new_shape = [x.shape[0], self.conv.out_channels]
74 |         new_layout = None if spatial_changed else x.layout
75 |         out = SparseTensor(
76 |             new_data, shape=torch.Size(new_shape), layout=new_layout,
77 |             scale=tuple([s // stride for s, stride in zip(x._scale, 
self.stride)]), 78 | spatial_cache=x._spatial_cache, 79 | ) 80 | return out 81 | -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/examples/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | import torch 9 | import torch_scatter 10 | 11 | ############################################################################### 12 | # Pytorch implementation of the developability regularizer introduced in paper 13 | # "Developability of Triangle Meshes" by Stein et al. 14 | ############################################################################### 15 | def mesh_developable_reg(mesh): 16 | 17 | verts = mesh.vertices 18 | tris = mesh.faces 19 | 20 | device = verts.device 21 | V = verts.shape[0] 22 | F = tris.shape[0] 23 | 24 | POS_EPS = 1e-6 25 | REL_EPS = 1e-6 26 | 27 | def normalize(vecs): 28 | return vecs / (torch.linalg.norm(vecs, dim=-1, keepdim=True) + POS_EPS) 29 | 30 | tri_pos = verts[tris] 31 | 32 | vert_normal_covariance_sum = torch.zeros((V, 9), device=device) 33 | vert_area = torch.zeros(V, device=device) 34 | vert_degree = torch.zeros(V, dtype=torch.int32, device=device) 35 | 36 | for iC in range(3): # loop over three corners of each triangle 37 | 38 | # gather tri verts 39 | pRoot = tri_pos[:, iC, :] 40 | pA = tri_pos[:, (iC + 1) % 3, :] 41 | pB = tri_pos[:, (iC + 2) % 3, :] 42 | 43 | # compute the corner angle & normal 44 | vA = pA - pRoot 45 | vAn = normalize(vA) 46 | vB = pB - pRoot 47 | vBn = normalize(vB) 48 | area_normal = torch.linalg.cross(vA, vB, dim=-1) 49 | face_area = 0.5 * torch.linalg.norm(area_normal, dim=-1) 50 | normal = normalize(area_normal) 51 | corner_angle = torch.acos(torch.clamp(torch.sum(vAn * vBn, dim=-1), min=-1., max=1.)) 52 | 53 | # add up the contribution to the covariance matrix 54 | outer = normal[:, :, None] @ normal[:, None, :] 55 | contrib = corner_angle[:, None] * outer.reshape(-1, 9) 56 | 57 | # scatter the result to the appropriate matrices 58 | vert_normal_covariance_sum = torch_scatter.scatter_add(src=contrib, 59 | index=tris[:, iC], 60 | dim=-2, 61 | out=vert_normal_covariance_sum) 62 | 63 | vert_area = torch_scatter.scatter_add(src=face_area / 3., 64 | index=tris[:, iC], 65 | dim=-1, 66 | out=vert_area) 67 | 68 | vert_degree = torch_scatter.scatter_add(src=torch.ones(F, dtype=torch.int32, device=device), 69 | index=tris[:, iC], 70 | dim=-1, 71 | out=vert_degree) 72 | 73 | # The energy is the smallest eigenvalue of the outer-product matrix 74 | vert_normal_covariance_sum = vert_normal_covariance_sum.reshape( 75 | -1, 3, 3) # reshape to a batch of matrices 76 | vert_normal_covariance_sum = vert_normal_covariance_sum + torch.eye( 77 | 3, device=device)[None, :, :] * REL_EPS 78 | 79 | min_eigvals = torch.min(torch.linalg.eigvals(vert_normal_covariance_sum).abs(), dim=-1).values 80 | 81 | # Mask out degree-3 vertices 82 | vert_area = torch.where(vert_degree == 3, torch.tensor(0, dtype=vert_area.dtype,device=vert_area.device), vert_area) 83 | 84 | # 
Adjust the vertex area weighting so it is unit-less, and 1 on average 85 | vert_area = vert_area * (V / torch.sum(vert_area, dim=-1, keepdim=True)) 86 | 87 | return vert_area * min_eigvals 88 | 89 | def sdf_reg_loss(sdf, all_edges): 90 | sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2) 91 | mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1]) 92 | sdf_f1x6x2 = sdf_f1x6x2[mask] 93 | sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \ 94 | torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float()) 95 | return sdf_diff -------------------------------------------------------------------------------- /trellis/representations/gaussian/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 
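# delay_rate ramps smoothly from lr_delay_mult (at step 0) up to 1.0 (once
# step >= lr_delay_steps): sin(0) = 0 at the start of the warmup, sin(pi/2) = 1
# at its end; after that the log-linear interpolation below runs at full rate.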
53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /trellis/datasets/sparse_structure.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Union 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import Dataset 8 | import utils3d 9 | from .components import StandardDatasetBase 10 | from ..representations.octree import DfsOctree as Octree 11 | from ..renderers import OctreeRenderer 12 | 13 | 14 | class SparseStructure(StandardDatasetBase): 15 | """ 16 | Sparse structure dataset 17 | 18 | Args: 19 | roots (str): path to the dataset 20 | resolution (int): resolution of the voxel grid 21 | min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset 22 | """ 23 | 24 | def __init__(self, 25 | roots, 26 | resolution: int = 64, 27 | min_aesthetic_score: float = 5.0, 28 | ): 29 | self.resolution = resolution 30 | self.min_aesthetic_score = min_aesthetic_score 31 | self.value_range = (0, 1) 32 | 33 | super().__init__(roots) 34 | 35 | def filter_metadata(self, metadata): 36 | stats = {} 37 | metadata = metadata[metadata[f'voxelized']] 38 | stats['Voxelized'] = len(metadata) 39 | metadata = metadata[metadata['aesthetic_score'] >= 
self.min_aesthetic_score] 40 | stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) 41 | return metadata, stats 42 | 43 | def get_instance(self, root, instance): 44 | position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0] 45 | coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous() 46 | ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long) 47 | ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1 48 | return {'ss': ss} 49 | 50 | @torch.no_grad() 51 | def visualize_sample(self, ss: Union[torch.Tensor, dict]): 52 | ss = ss if isinstance(ss, torch.Tensor) else ss['ss'] 53 | 54 | renderer = OctreeRenderer() 55 | renderer.rendering_options.resolution = 512 56 | renderer.rendering_options.near = 0.8 57 | renderer.rendering_options.far = 1.6 58 | renderer.rendering_options.bg_color = (0, 0, 0) 59 | renderer.rendering_options.ssaa = 4 60 | renderer.pipe.primitive = 'voxel' 61 | 62 | # Build camera 63 | yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] 64 | yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) 65 | yaws = [y + yaws_offset for y in yaws] 66 | pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] 67 | 68 | exts = [] 69 | ints = [] 70 | for yaw, pitch in zip(yaws, pitch): 71 | orig = torch.tensor([ 72 | np.sin(yaw) * np.cos(pitch), 73 | np.cos(yaw) * np.cos(pitch), 74 | np.sin(pitch), 75 | ]).float().cuda() * 2 76 | fov = torch.deg2rad(torch.tensor(30)).cuda() 77 | extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) 78 | intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) 79 | exts.append(extrinsics) 80 | ints.append(intrinsics) 81 | 82 | images = [] 83 | 84 | # Build each representation 85 | ss = ss.cuda() 86 | for i in range(ss.shape[0]): 87 | representation = Octree( 88 | depth=10, 89 | aabb=[-0.5, -0.5, -0.5, 1, 1, 1], 90 | device='cuda', 91 | primitive='voxel', 92 | sh_degree=0, 93 | primitive_config={'solid': True}, 94 | ) 95 | coords = torch.nonzero(ss[i, 0], as_tuple=False) 96 | representation.position = coords.float() / self.resolution 97 | representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') 98 | 99 | image = torch.zeros(3, 1024, 1024).cuda() 100 | tile = [2, 2] 101 | for j, (ext, intr) in enumerate(zip(exts, ints)): 102 | res = renderer.render(representation, ext, intr, colors_overwrite=representation.position) 103 | image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] 104 | images.append(image) 105 | 106 | return torch.stack(images) 107 | -------------------------------------------------------------------------------- /trellis/representations/mesh/flexicubes/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for FlexiCubes 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | “Licensor” means any person or entity that distributes its Work. 12 | 13 | “Work” means (a) the original work of authorship made available under 14 | this license, which may include software, documentation, or other files, 15 | and (b) any additions to or derivative works thereof that are made 16 | available under this license. 
17 | 18 | The terms “reproduce,” “reproduction,” “derivative works,” and 19 | “distribution” have the meaning as provided under U.S. copyright law; 20 | provided, however, that for the purposes of this license, derivative works 21 | shall not include works that remain separable from, or merely link 22 | (or bind by name) to the interfaces of, the Work. 23 | 24 | Works are “made available” under this license by including in or with 25 | the Work either (a) a copyright notice referencing the applicability of 26 | this license to the Work, or (b) a copy of this license. 27 | 28 | 2. License Grant 29 | 30 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, 31 | each Licensor grants to you a perpetual, worldwide, non-exclusive, 32 | royalty-free, copyright license to use, reproduce, prepare derivative 33 | works of, publicly display, publicly perform, sublicense and distribute 34 | its Work and any resulting derivative works in any form. 35 | 36 | 3. Limitations 37 | 38 | 3.1 Redistribution. You may reproduce or distribute the Work only if 39 | (a) you do so under this license, (b) you include a complete copy of 40 | this license with your distribution, and (c) you retain without 41 | modification any copyright, patent, trademark, or attribution notices 42 | that are present in the Work. 43 | 44 | 3.2 Derivative Works. You may specify that additional or different terms 45 | apply to the use, reproduction, and distribution of your derivative 46 | works of the Work (“Your Terms”) only if (a) Your Terms provide that the 47 | use limitation in Section 3.3 applies to your derivative works, and (b) 48 | you identify the specific derivative works that are subject to Your Terms. 49 | Notwithstanding Your Terms, this license (including the redistribution 50 | requirements in Section 3.1) will continue to apply to the Work itself. 51 | 52 | 3.3 Use Limitation. The Work and any derivative works thereof only may be 53 | used or intended for use non-commercially. Notwithstanding the foregoing, 54 | NVIDIA Corporation and its affiliates may use the Work and any derivative 55 | works commercially. As used herein, “non-commercially” means for research 56 | or evaluation purposes only. 57 | 58 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against 59 | any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) 60 | to enforce any patents that you allege are infringed by any Work, then your 61 | rights under this license from such Licensor (including the grant in 62 | Section 2.1) will terminate immediately. 63 | 64 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s 65 | or its affiliates’ names, logos, or trademarks, except as necessary to 66 | reproduce the notices described in this license. 67 | 68 | 3.6 Termination. If you violate any term of this license, then your rights 69 | under this license (including the grant in Section 2.1) will terminate 70 | immediately. 71 | 72 | 4. Disclaimer of Warranty. 73 | 74 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 75 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 76 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. 77 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 78 | 79 | 5. Limitation of Liability. 
80 | 81 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, 82 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY 83 | LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, 84 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, 85 | THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF 86 | GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR 87 | MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 88 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 89 | 90 | ======================================================================= -------------------------------------------------------------------------------- /vox2seq/vox2seq/pytorch/z_order.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from typing import Optional, Union 10 | 11 | 12 | class KeyLUT: 13 | def __init__(self): 14 | r256 = torch.arange(256, dtype=torch.int64) 15 | r512 = torch.arange(512, dtype=torch.int64) 16 | zero = torch.zeros(256, dtype=torch.int64) 17 | device = torch.device("cpu") 18 | 19 | self._encode = { 20 | device: ( 21 | self.xyz2key(r256, zero, zero, 8), 22 | self.xyz2key(zero, r256, zero, 8), 23 | self.xyz2key(zero, zero, r256, 8), 24 | ) 25 | } 26 | self._decode = {device: self.key2xyz(r512, 9)} 27 | 28 | def encode_lut(self, device=torch.device("cpu")): 29 | if device not in self._encode: 30 | cpu = torch.device("cpu") 31 | self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) 32 | return self._encode[device] 33 | 34 | def decode_lut(self, device=torch.device("cpu")): 35 | if device not in self._decode: 36 | cpu = torch.device("cpu") 37 | self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) 38 | return self._decode[device] 39 | 40 | def xyz2key(self, x, y, z, depth): 41 | key = torch.zeros_like(x) 42 | for i in range(depth): 43 | mask = 1 << i 44 | key = ( 45 | key 46 | | ((x & mask) << (2 * i + 2)) 47 | | ((y & mask) << (2 * i + 1)) 48 | | ((z & mask) << (2 * i + 0)) 49 | ) 50 | return key 51 | 52 | def key2xyz(self, key, depth): 53 | x = torch.zeros_like(key) 54 | y = torch.zeros_like(key) 55 | z = torch.zeros_like(key) 56 | for i in range(depth): 57 | x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) 58 | y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) 59 | z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) 60 | return x, y, z 61 | 62 | 63 | _key_lut = KeyLUT() 64 | 65 | 66 | def xyz2key( 67 | x: torch.Tensor, 68 | y: torch.Tensor, 69 | z: torch.Tensor, 70 | b: Optional[Union[torch.Tensor, int]] = None, 71 | depth: int = 16, 72 | ): 73 | r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys 74 | based on pre-computed look up tables. The speed of this function is much 75 | faster than the method based on for-loop. 76 | 77 | Args: 78 | x (torch.Tensor): The x coordinate. 79 | y (torch.Tensor): The y coordinate. 80 | z (torch.Tensor): The z coordinate. 81 | b (torch.Tensor or int): The batch index of the coordinates, and should be 82 | smaller than 32768. 
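(The batch index is packed into the high bits of the returned key:
``key = b << 48 | code``.)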
If :attr:`b` is :obj:`torch.Tensor`, the size of 83 | :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. 84 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 85 | """ 86 | 87 | EX, EY, EZ = _key_lut.encode_lut(x.device) 88 | x, y, z = x.long(), y.long(), z.long() 89 | 90 | mask = 255 if depth > 8 else (1 << depth) - 1 91 | key = EX[x & mask] | EY[y & mask] | EZ[z & mask] 92 | if depth > 8: 93 | mask = (1 << (depth - 8)) - 1 94 | key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] 95 | key = key16 << 24 | key 96 | 97 | if b is not None: 98 | b = b.long() 99 | key = b << 48 | key 100 | 101 | return key 102 | 103 | 104 | def key2xyz(key: torch.Tensor, depth: int = 16): 105 | r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates 106 | and the batch index based on pre-computed look up tables. 107 | 108 | Args: 109 | key (torch.Tensor): The shuffled key. 110 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 111 | """ 112 | 113 | DX, DY, DZ = _key_lut.decode_lut(key.device) 114 | x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) 115 | 116 | b = key >> 48 117 | key = key & ((1 << 48) - 1) 118 | 119 | n = (depth + 2) // 3 120 | for i in range(n): 121 | k = key >> (i * 9) & 511 122 | x = x | (DX[k] << (i * 3)) 123 | y = y | (DY[k] << (i * 3)) 124 | z = z | (DZ[k] << (i * 3)) 125 | 126 | return x, y, z, b -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/decoder_rf.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from ...modules import sparse as sp 7 | from .base import SparseTransformerBase 8 | from ...representations import Strivec 9 | from ..sparse_elastic_mixin import SparseTransformerElasticMixin 10 | 11 | 12 | class SLatRadianceFieldDecoder(SparseTransformerBase): 13 | def __init__( 14 | self, 15 | resolution: int, 16 | model_channels: int, 17 | latent_channels: int, 18 | num_blocks: int, 19 | num_heads: Optional[int] = None, 20 | num_head_channels: Optional[int] = 64, 21 | mlp_ratio: float = 4, 22 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 23 | window_size: int = 8, 24 | pe_mode: Literal["ape", "rope"] = "ape", 25 | use_fp16: bool = False, 26 | use_checkpoint: bool = False, 27 | qk_rms_norm: bool = False, 28 | representation_config: dict = None, 29 | ): 30 | super().__init__( 31 | in_channels=latent_channels, 32 | model_channels=model_channels, 33 | num_blocks=num_blocks, 34 | num_heads=num_heads, 35 | num_head_channels=num_head_channels, 36 | mlp_ratio=mlp_ratio, 37 | attn_mode=attn_mode, 38 | window_size=window_size, 39 | pe_mode=pe_mode, 40 | use_fp16=use_fp16, 41 | use_checkpoint=use_checkpoint, 42 | qk_rms_norm=qk_rms_norm, 43 | ) 44 | self.resolution = resolution 45 | self.rep_config = representation_config 46 | self._calc_layout() 47 | self.out_layer = sp.SparseLinear(model_channels, self.out_channels) 48 | 49 | self.initialize_weights() 50 | if use_fp16: 51 | self.convert_to_fp16() 52 | 53 | def initialize_weights(self) -> None: 54 | super().initialize_weights() 55 | # Zero-out output layers: 56 | nn.init.constant_(self.out_layer.weight, 0) 57 | nn.init.constant_(self.out_layer.bias, 0) 58 | 59 | def _calc_layout(self) -> None: 60 | self.layout = { 61 | 

--------------------------------------------------------------------------------
/trellis/models/structured_latent_vae/decoder_rf.py:
--------------------------------------------------------------------------------

from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ...representations import Strivec
from ..sparse_elastic_mixin import SparseTransformerElasticMixin


class SLatRadianceFieldDecoder(SparseTransformerBase):
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _calc_layout(self) -> None:
        self.layout = {
            'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
            'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
            'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
        }
        start = 0
        for k, v in self.layout.items():
            v['range'] = (start, start + v['size'])
            start += v['size']
        self.out_channels = start

    def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            representation = Strivec(
                sh_degree=0,
                resolution=self.resolution,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                rank=self.rep_config['rank'],
                dim=self.rep_config['dim'],
                device='cuda',
            )
            representation.density_shift = 0.0
            representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
            for k, v in self.layout.items():
                setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
            representation.trivec = representation.trivec + 1
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Strivec]:
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)


class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
    """
    SLat VAE radiance field decoder with elastic memory management.
    Used for training with low VRAM.
    """
    pass
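
# Editor's note (illustrative): with a hypothetical representation_config of
# {'rank': 16, 'dim': 8}, _calc_layout yields trivec -> 16*3*8 = 384 channels
# at range (0, 384), density -> 16 at (384, 400), and features_dc -> 16*3 = 48
# at (400, 448), so out_channels = 448 and to_representation slices the decoder
# output feature-wise along exactly these ranges.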

--------------------------------------------------------------------------------
/trellis/renderers/sh_utils.py:
--------------------------------------------------------------------------------

# Copyright 2021 The PlenOctree Authors.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import torch

C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Currently, 0-4 supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                  C1 * y * sh[..., 1] +
                  C1 * z * sh[..., 2] -
                  C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                      C2[0] * xy * sh[..., 4] +
                      C2[1] * yz * sh[..., 5] +
                      C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                      C2[3] * xz * sh[..., 7] +
                      C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                          C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                          C3[1] * xy * z * sh[..., 10] +
                          C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                          C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                          C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                          C3[5] * z * (xx - yy) * sh[..., 14] +
                          C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result


def RGB2SH(rgb):
    return (rgb - 0.5) / C0


def SH2RGB(sh):
    return sh * C0 + 0.5
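

# Editor's sketch (not part of the original file): RGB2SH and SH2RGB are exact
# affine inverses, and at deg=0 eval_sh reduces to C0 * sh[..., 0].
if __name__ == "__main__":
    rgb = torch.rand(4, 3)
    assert torch.allclose(SH2RGB(RGB2SH(rgb)), rgb)
    sh = torch.rand(4, 3, 1)
    dirs = torch.rand(4, 3)
    assert torch.allclose(eval_sh(0, sh, dirs), C0 * sh[..., 0])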

--------------------------------------------------------------------------------
/trellis/models/__init__.py:
--------------------------------------------------------------------------------

import importlib

__attributes = {
    'SparseStructureEncoder': 'sparse_structure_vae',
    'SparseStructureDecoder': 'sparse_structure_vae',

    'SparseStructureFlowModel': 'sparse_structure_flow',

    'SLatEncoder': 'structured_latent_vae',
    'SLatGaussianDecoder': 'structured_latent_vae',
    'SLatRadianceFieldDecoder': 'structured_latent_vae',
    'SLatMeshDecoder': 'structured_latent_vae',
    'SLatMeshDecodernew': 'structured_latent_vae',
    'ElasticSLatEncoder': 'structured_latent_vae',
    'ElasticSLatGaussianDecoder': 'structured_latent_vae',
    'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
    'ElasticSLatMeshDecoder': 'structured_latent_vae',
    'ElasticSLatMeshDecodernew': 'structured_latent_vae',

    'PropertyEncoder': 'structured_latent_vae',
    'ElasticPropertyEncoder': 'structured_latent_vae',
    'PropertyOutput': 'structured_latent_vae',
    'PropertyDecoder': 'structured_latent_vae',
    'ElasticPropertyDecoder': 'structured_latent_vae',

    'SLatFlowModel': 'structured_latent_flow',
    'ElasticSLatFlowModel': 'structured_latent_flow',
    'SLatFlowModelphy': 'structured_latent_flow',
    'ElasticSLatFlowModelphy': 'structured_latent_flow',
}

__submodules = []

__all__ = list(__attributes.keys()) + __submodules


def __getattr__(name):
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]


def from_pretrained(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    Args:
        path: The path to the checkpoint. Can be either a local path or a Hugging Face model name.
            NOTE: the config file and model file should take the names f'{path}.json' and
            f'{path}.safetensors' respectively.
        **kwargs: Additional arguments for the model constructor.
    """
    import os
    import json
    from safetensors.torch import load_file
    is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")

    if is_local:
        config_file = f"{path}.json"
        model_file = f"{path}.safetensors"
    else:
        from huggingface_hub import hf_hub_download
        path_parts = path.split('/')
        repo_id = f'{path_parts[0]}/{path_parts[1]}'
        model_name = '/'.join(path_parts[2:])
        config_file = hf_hub_download(repo_id, f"{model_name}.json")
        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")

    with open(config_file, 'r') as f:
        config = json.load(f)
    model = __getattr__(config['name'])(**config['args'], **kwargs)
    model.load_state_dict(load_file(model_file))

    return model
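
# Editor's usage sketch (the checkpoint path below is illustrative): a Hugging
# Face path of the form '<org>/<repo>/<subdir>/<ckpt_name>' is resolved via
# hf_hub_download, e.g.
#     model = from_pretrained('microsoft/TRELLIS-image-large/ckpts/slat_dec_rf_swin8_B_64l8r16_fp16')
# while a local prefix with matching .json/.safetensors files loads directly.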

def from_pretrained_config(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    Args:
        path: Path to a torch checkpoint (e.g. f'{name}.pt'). The model config
            is read from the sibling file f'{name}.json'.
        **kwargs: Additional arguments for the model constructor.
    """
    import os
    import json
    import torch

    config_file = path[:-3] + '.json'
    model_file = path

    with open(config_file, 'r') as f:
        config = json.load(f)
    model = __getattr__(config['name'])(**config['args'], **kwargs)
    model.load_state_dict(torch.load(model_file, weights_only=True))

    return model


# For Pylance
if __name__ == '__main__':
    from .sparse_structure_vae import (
        SparseStructureEncoder,
        SparseStructureDecoder,
    )

    from .sparse_structure_flow import SparseStructureFlowModel

    from .structured_latent_vae import (
        SLatEncoder,
        SLatGaussianDecoder,
        SLatRadianceFieldDecoder,
        SLatMeshDecoder,
        ElasticSLatEncoder,
        ElasticSLatGaussianDecoder,
        ElasticSLatRadianceFieldDecoder,
        ElasticSLatMeshDecoder,
    )

    from .structured_latent_flow import (
        SLatFlowModel,
        ElasticSLatFlowModel,
    )

--------------------------------------------------------------------------------
/trellis/modules/sparse/spatial.py:
--------------------------------------------------------------------------------

from typing import *
import torch
import torch.nn as nn
from . import SparseTensor

__all__ = [
    'SparseDownsample',
    'SparseUpsample',
    'SparseSubdivide'
]


class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.
    Implemented as average pooling.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
        super(SparseDownsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'

        coord = list(input.coords.unbind(dim=-1))
        for i, f in enumerate(factor):
            coord[i + 1] = coord[i + 1] // f

        MAX = [coord[i + 1].max().item() + 1 for i in range(DIM)]
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum([c * o for c, o in zip(coord, OFFSET)])
        code, idx = code.unique(return_inverse=True)

        new_feats = torch.scatter_reduce(
            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
            src=input.feats,
            reduce='mean'
        )
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)],
            dim=-1
        )
        out = SparseTensor(new_feats, new_coords, input.shape)
        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache

        out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
        out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
        out.register_spatial_cache(f'upsample_{factor}_idx', idx)

        return out
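
# Editor's note (illustrative): the code/OFFSET trick above packs each
# (batch, x, y, z) tuple into a single integer using mixed-radix offsets so
# that `unique` can merge fine voxels that share a coarse cell. E.g. with
# MAX = [4, 4, 4], OFFSET = [64, 16, 4, 1], and the voxel (b=1, x=2, y=3, z=1)
# gets code 1*64 + 2*16 + 3*4 + 1 = 109; new_coords recovers the tuple by
# integer division and modulo.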

class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.
    Implemented as nearest neighbor interpolation.
    """
    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
        super(SparseUpsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'

        new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
        new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
        idx = input.get_spatial_cache(f'upsample_{factor}_idx')
        if any([x is None for x in [new_coords, new_layout, idx]]):
            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
        new_feats = input.feats[idx]
        out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
        out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache
        return out


class SparseSubdivide(nn.Module):
    """
    Subdivide a sparse tensor: each voxel is split into 2**DIM children that
    inherit its features, doubling the spatial resolution.
    """
    def __init__(self):
        super(SparseSubdivide, self).__init__()

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        # upsample scale = 2**DIM
        n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM
        new_coords = input.coords.clone()
        new_coords[:, 1:] *= 2
        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)

        new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
        out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
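
# Editor's note (illustrative; `st` is a hypothetical SparseTensor): the
# down/up modules must be used as a matched pair, since SparseDownsample caches
# the fine-level coords/layout/idx that SparseUpsample replays:
#     down, up = SparseDownsample(2), SparseUpsample(2)
#     restored = up(down(st))   # same coords as st, mean-pooled feats scattered back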

--------------------------------------------------------------------------------
/configs/generation/slat_flow_img_dit_L_phy.json:
--------------------------------------------------------------------------------

{
    "models": {
        "denoiser": {
            "name": "ElasticSLatFlowModel",
            "args": {
                "resolution": 64,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1024,
                "cond_channels": 1024,
                "num_blocks": 24,
                "num_heads": 16,
                "mlp_ratio": 4,
                "patch_size": 2,
                "num_io_res_blocks": 2,
                "io_block_channels": [128],
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "use_fp16": true
            }
        },
        "denoiser_phy": {
            "name": "ElasticSLatFlowModelphy",
            "args": {
                "resolution": 64,
                "in_channels": 8,
                "out_channels": 8,
                "model_channels": 1024,
                "cond_channels": 1024,
                "num_blocks": 14,
                "num_heads": 16,
                "mlp_ratio": 4,
                "patch_size": 2,
                "num_io_res_blocks": 2,
                "io_block_channels": [128],
                "pe_mode": "ape",
                "qk_rms_norm": true,
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "ImageConditionedSLat",
        "args": {
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 28000,
            "image_size": 518,
            "normalization": {
                "mean": [
                    -2.1687545776367188,
                    -0.004347046371549368,
                    -0.13352349400520325,
                    -0.08418072760105133,
                    -0.5271206498146057,
                    0.7238689064979553,
                    -1.1414450407028198,
                    1.2039363384246826
                ],
                "std": [
                    2.377650737762451,
                    2.386378288269043,
                    2.124418020248413,
                    2.1748552322387695,
                    2.663944721221924,
                    2.371192216873169,
                    2.6217446327209473,
                    2.684523105621338
                ],
                "mean_phy": [
                    -2.1507165,
                    -0.9456348,
                    -2.0234883,
                    -0.5949867,
                    -3.608296,
                    -1.062877,
                    -3.288852,
                    -1.0749111
                ],
                "std_phy": [
                    0.6931998,
                    0.9221464,
                    0.6542199,
                    0.6594776,
                    0.8451334,
                    0.594917,
                    0.69759405,
                    1.1614994
                ]
            },
            "slat_dec_path": "./pretrain/vae",
            "slat_dec_ckpt": "100000",
            "latent_model_phy": "dinov2_vitl14_reg_physxgen_100000"
        }
    },
    "trainer": {
        "name": "ImageConditionedSparseFlowMatchingCFGTrainerphy",
        "args": {
            "onlyphy_property_gen": true,
            "max_steps": 1000000,
            "batch_size_per_gpu": 16,
            "batch_split": 16,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 0.0001,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [0.9999],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.6,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 10,
            "i_sample": 10000,
            "i_save": 10000,
            "p_uncond": 0.1,
            "t_schedule": {
                "name": "logitNormal",
                "args": {
                    "mean": 1.0,
                    "std": 1.0
                }
            },
            "sigma_min": 1e-5,
            "image_cond_model": "dinov2_vitl14_reg"
        }
    }
}
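
Editor's note (not part of the config): the "normalization" block above
standardizes each of the 8 structured-latent channels before flow matching,
i.e. z_norm[c] = (z[c] - mean[c]) / std[c], and "mean_phy"/"std_phy" play the
same role for the physical-property latents consumed by the "denoiser_phy"
branch.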

--------------------------------------------------------------------------------
/trellis/datasets/sparse_feat2render.py:
--------------------------------------------------------------------------------

import os
from PIL import Image
import json
import numpy as np
import pandas as pd
import torch
import utils3d.torch
from ..modules.sparse.basic import SparseTensor
from .components import StandardDatasetBase


class SparseFeat2Render(StandardDatasetBase):
    """
    SparseFeat2Render dataset.

    Args:
        roots (str): paths to the dataset
        image_size (int): size of the image
        model (str): model name
        resolution (int): resolution of the data
        min_aesthetic_score (float): minimum aesthetic score
        max_num_voxels (int): maximum number of voxels
    """
    def __init__(
        self,
        roots: str,
        image_size: int,
        model: str = 'dinov2_vitl14_reg',
        resolution: int = 64,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
    ):
        self.image_size = image_size
        self.model = model
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        self.max_num_voxels = max_num_voxels
        self.value_range = (0, 1)

        super().__init__(roots)

    def filter_metadata(self, metadata):
        stats = {}
        metadata = metadata[metadata[f'feature_{self.model}']]
        stats['With features'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
        stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
        return metadata, stats

    def _get_image(self, root, instance):
        with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
            metadata = json.load(f)
        n_views = len(metadata['frames'])
        view = np.random.randint(n_views)
        metadata = metadata['frames'][view]
        fov = metadata['camera_angle_x']
        intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
        c2w = torch.tensor(metadata['transform_matrix'])
        c2w[:3, 1:3] *= -1
        extrinsics = torch.inverse(c2w)

        image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
        image = Image.open(image_path)
        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0

        return {
            'image': image,
            'alpha': alpha,
            'extrinsics': extrinsics,
            'intrinsics': intrinsics,
        }

    def _get_feat(self, root, instance):
        DATA_RESOLUTION = 64
        feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
        feats = np.load(feats_path, allow_pickle=True)
        coords = torch.tensor(feats['indices']).int()
        feats = torch.tensor(feats['patchtokens']).float()

        if self.resolution != DATA_RESOLUTION:
            factor = DATA_RESOLUTION // self.resolution
            coords = coords // factor
            coords, idx = coords.unique(return_inverse=True, dim=0)
            feats = torch.scatter_reduce(
                torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
                dim=0,
                index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
                src=feats,
                reduce='mean'
            )

        return {
            'coords': coords,
            'feats': feats,
        }
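
    # Editor's note: the scatter_reduce above average-pools the patch features
    # of fine voxels that collapse onto the same coarse voxel after the integer
    # division by `factor`, the same unique/scatter pattern used by
    # SparseDownsample in trellis/modules/sparse/spatial.py.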

    @torch.no_grad()
    def visualize_sample(self, sample: dict):
        return sample['image']

    @staticmethod
    def collate_fn(batch):
        pack = {}
        coords = []
        for i, b in enumerate(batch):
            coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
        coords = torch.cat(coords)
        feats = torch.cat([b['feats'] for b in batch])
        pack['feats'] = SparseTensor(
            coords=coords,
            feats=feats,
        )

        pack['image'] = torch.stack([b['image'] for b in batch])
        pack['alpha'] = torch.stack([b['alpha'] for b in batch])
        pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
        pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])

        return pack

    def get_instance(self, root, instance):
        image = self._get_image(root, instance)
        feat = self._get_feat(root, instance)
        return {
            **image,
            **feat,
        }

--------------------------------------------------------------------------------
/trellis/datasets/components.py:
--------------------------------------------------------------------------------

from typing import *
from abc import abstractmethod
import os
import json
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class StandardDatasetBase(Dataset):
    """
    Base class for standard datasets.

    Args:
        roots (str): comma-separated paths to the dataset roots
    """

    def __init__(self, roots: str):
        super().__init__()
        self.roots = roots.split(',')
        self.instances = []
        self.metadata = pd.DataFrame()

        self._stats = {}
        for root in self.roots:
            key = os.path.basename(root)
            self._stats[key] = {}
            metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
            self._stats[key]['Total'] = len(metadata)
            metadata, stats = self.filter_metadata(metadata)
            self._stats[key].update(stats)
            self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
            metadata.set_index('sha256', inplace=True)

            self.metadata = pd.concat([self.metadata, metadata])

    @abstractmethod
    def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
        pass

    @abstractmethod
    def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
        pass

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, index) -> Dict[str, Any]:
        try:
            root, instance = self.instances[index]
            return self.get_instance(root, instance)
        except Exception as e:
            print(e)
            return self.__getitem__(np.random.randint(0, len(self)))

    def __str__(self):
        lines = []
        lines.append(self.__class__.__name__)
        lines.append(f' - Total instances: {len(self)}')
        lines.append(' - Sources:')
        for key, stats in self._stats.items():
            lines.append(f' - {key}:')
            for k, v in stats.items():
                lines.append(f' - {k}: {v}')
        return '\n'.join(lines)
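
# Editor's note: the base class's __getitem__ above deliberately swallows
# per-sample failures and retries a random index, so one corrupt instance
# degrades to a printed warning instead of aborting a long training run;
# subclasses only implement filter_metadata and get_instance.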

class TextConditionedMixin:
    def __init__(self, roots, **kwargs):
        super().__init__(roots, **kwargs)
        self.captions = {}
        for instance in self.instances:
            sha256 = instance[1]
            self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])

    def filter_metadata(self, metadata):
        metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['captions'].notna()]
        stats['With captions'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        pack = super().get_instance(root, instance)
        text = np.random.choice(self.captions[instance])
        pack['cond'] = text
        return pack


class ImageConditionedMixin:
    def __init__(self, roots, *, image_size=518, **kwargs):
        self.image_size = image_size
        super().__init__(roots, **kwargs)

    def filter_metadata(self, metadata):
        metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['cond_rendered']]
        stats['Cond rendered'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        pack = super().get_instance(root, instance)

        image_root = os.path.join(root, 'renders_cond', instance)
        with open(os.path.join(image_root, 'transforms.json')) as f:
            metadata = json.load(f)
        n_views = len(metadata['frames'])
        view = np.random.randint(n_views)
        metadata = metadata['frames'][view]

        image_path = os.path.join(image_root, metadata['file_path'])
        image = Image.open(image_path)

        alpha = np.array(image.getchannel(3))
        bbox = alpha.nonzero()
        bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
        center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
        hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
        aug_size_ratio = 1.2
        aug_hsize = hsize * aug_size_ratio
        aug_center_offset = [0, 0]
        aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
        aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
        image = image.crop(aug_bbox)

        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0
        image = image * alpha.unsqueeze(0)
        pack['cond'] = image

        return pack
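
# Editor's composition sketch (the class name is hypothetical): the mixins are
# designed to sit to the left of a concrete dataset base so that their
# filter_metadata/get_instance implementations chain via super(), e.g.:
#
#     class ImageCondSparseFeat2Render(ImageConditionedMixin, SparseFeat2Render):
#         pass  # yields samples with an extra 'cond' image tensor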

--------------------------------------------------------------------------------
/trellis/trainers/vae/sparse_structure_vae.py:
--------------------------------------------------------------------------------

from typing import *
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from easydict import EasyDict as edict

from ..basic import BasicTrainer


class SparseStructureVaeTrainer(BasicTrainer):
    """
    Trainer for Sparse Structure VAE.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.

        loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
        lambda_kl (float): KL divergence loss weight.
    """

    def __init__(
        self,
        *args,
        loss_type='bce',
        lambda_kl=1e-6,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        self.lambda_kl = lambda_kl

    def training_losses(
        self,
        ss: torch.Tensor,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses.

        Args:
            ss: The [N x 1 x H x W x D] tensor of binary sparse structure.

        Returns:
            a dict with the key "loss" containing a scalar tensor.
            may also contain other keys for different terms.
        """
        z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
        logits = self.training_models['decoder'](z)

        terms = edict(loss=0.0)
        if self.loss_type == 'bce':
            terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
            terms["loss"] = terms["loss"] + terms["bce"]
        elif self.loss_type == 'l1':
            terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
            terms["loss"] = terms["loss"] + terms["l1"]
        elif self.loss_type == 'dice':
            logits = F.sigmoid(logits)
            terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
            terms["loss"] = terms["loss"] + terms["dice"]
        else:
            raise ValueError(f'Invalid loss type {self.loss_type}')
        terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
        terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]

        return terms, {}

    @torch.no_grad()
    def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
        super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)

    @torch.no_grad()
    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        dataloader = DataLoader(
            copy.deepcopy(self.dataset),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )

        # inference
        gts = []
        recons = []
        for i in range(0, num_samples, batch_size):
            batch = min(batch_size, num_samples - i)
            data = next(iter(dataloader))
            args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
            z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
            logits = self.models['decoder'](z)
            recon = (logits > 0).long()
            gts.append(args['ss'])
            recons.append(recon)

        sample_dict = {
            'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
            'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
        }
        return sample_dict
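
# Editor's note: terms["kl"] in training_losses above is the closed-form KL
# divergence between the diagonal-Gaussian posterior N(mean, exp(logvar)) and
# the unit prior N(0, I), KL = 0.5 * (mean^2 + exp(logvar) - logvar - 1) per
# element, so no sampling is needed for that term.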

--------------------------------------------------------------------------------
/dataset_toolkits/render.py:
--------------------------------------------------------------------------------

import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
from utils import sphere_hammersley_sequence


BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'


def _install_blender():
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')


def _render(file_path, sha256, output_dir, num_views):
    output_folder = os.path.join(output_dir, 'renders', sha256)

    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    radius = [2] * num_views
    fov = [40 / 180 * np.pi] * num_views
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--resolution', '512',
        '--output_folder', output_folder,
        '--engine', 'CYCLES',
        '--save_mesh',
    ]
    if file_path.endswith('.blend'):
        args.insert(1, file_path)

    call(args, stdout=DEVNULL, stderr=DEVNULL)

    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'rendered': True}
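
# Editor's note (illustrative values): each entry of the --views payload is a
# camera in spherical coordinates drawn from a randomly offset Hammersley
# sequence, e.g.
#     [{"yaw": 1.23, "pitch": -0.45, "radius": 2, "fov": 0.6981}]
# where fov is 40 degrees expressed in radians.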

if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=0)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))

    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' in metadata.columns:
            metadata = metadata[metadata['rendered'] == False]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views)
    rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    rendered = pd.concat([rendered, pd.DataFrame.from_records(records)])
    rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False)

--------------------------------------------------------------------------------
/trellis/modules/attention/full_attn.py:
--------------------------------------------------------------------------------

from typing import *
import torch
import math
from . import DEBUG, BACKEND

if BACKEND == 'xformers':
    import xformers.ops as xops
elif BACKEND == 'flash_attn':
    import flash_attn
elif BACKEND == 'sdpa':
    from torch.nn.functional import scaled_dot_product_attention as sdpa
elif BACKEND == 'naive':
    pass
else:
    raise ValueError(f"Unknown attention backend: {BACKEND}")


__all__ = [
    'scaled_dot_product_attention',
]


def _naive_sdpa(q, k, v):
    """
    Naive implementation of scaled dot product attention.
    """
    q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
    k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
    v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
    scale_factor = 1 / math.sqrt(q.size(-1))
    attn_weight = q @ k.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    out = attn_weight @ v
    out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
    return out
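
# Editor's note: the dispatcher below accepts three calling conventions and
# routes them to the selected backend:
#     scaled_dot_product_attention(qkv)       # packed qkv: [N, L, 3, H, C]
#     scaled_dot_product_attention(q, kv)     # q: [N, L, H, C], kv: [N, L, 2, H, C]
#     scaled_dot_product_attention(q, k, v)   # q/k: [N, L, H, Ci], v: [N, L, H, Co]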

@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

def scaled_dot_product_attention(*args, **kwargs):
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
        device = qkv.device

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
        device = q.device

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
        device = q.device

    if BACKEND == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = xops.memory_efficient_attention(q, k, v)
    elif BACKEND == 'flash_attn':
        if num_all_args == 1:
            out = flash_attn.flash_attn_qkvpacked_func(qkv)
        elif num_all_args == 2:
            out = flash_attn.flash_attn_kvpacked_func(q, kv)
        elif num_all_args == 3:
            out = flash_attn.flash_attn_func(q, k, v)
    elif BACKEND == 'sdpa':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)       # [N, H, L, C]
        k = k.permute(0, 2, 1, 3)       # [N, H, L, C]
        v = v.permute(0, 2, 1, 3)       # [N, H, L, C]
        out = sdpa(q, k, v)             # [N, H, L, C]
        out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
    elif BACKEND == 'naive':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = _naive_sdpa(q, k, v)
    else:
        raise ValueError(f"Unknown attention backend: {BACKEND}")

    return out

--------------------------------------------------------------------------------