├── assets └── docs │ ├── teaser.png │ └── local_gradio_example.png ├── requirements.txt ├── NOTICE ├── misc.py ├── triposf ├── modules │ ├── norm.py │ ├── transformer │ │ ├── __init__.py │ │ ├── modulated.py │ │ └── blocks.py │ ├── sparse │ │ ├── transformer │ │ │ ├── __init__.py │ │ │ ├── blocks.py │ │ │ └── modulated.py │ │ ├── attention │ │ │ ├── __init__.py │ │ │ ├── windowed_attn.py │ │ │ ├── modules.py │ │ │ ├── serialized_attn.py │ │ │ └── full_attn.py │ │ ├── linear.py │ │ ├── conv │ │ │ ├── __init__.py │ │ │ ├── conv_torchsparse.py │ │ │ └── conv_spconv.py │ │ ├── nonlinearity.py │ │ ├── norm.py │ │ ├── __init__.py │ │ └── spatial.py │ ├── spatial.py │ ├── attention │ │ ├── __init__.py │ │ ├── full_attn.py │ │ └── modules.py │ ├── utils.py │ └── pointclouds │ │ └── pointnet.py ├── representations │ ├── __init__.py │ └── mesh │ │ ├── __init__.py │ │ ├── utils_cube.py │ │ ├── flexicubes │ │ └── LICENSE.txt │ │ └── cube2mesh.py ├── models │ ├── triposf_vae │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── base.py │ │ └── decoder.py │ └── __init__.py └── __init__.py ├── LICENSE ├── configs └── TripoSFVAE_1024.yaml ├── README.md ├── inference.py └── app.py /assets/docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TripoSF/HEAD/assets/docs/teaser.png -------------------------------------------------------------------------------- /assets/docs/local_gradio_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TripoSF/HEAD/assets/docs/local_gradio_example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trimesh==4.5.3 2 | torch-scatter==2.1.2 3 | open3d==0.18.0 4 | numpy==1.22.2 5 | omegaconf==2.3.0 6 | flash-attn==2.5.9post1 7 | spconv 8 | safetensors 9 | easydict 10 | jaxtyping 11 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | TripoSF 2 | Copyright (c) 2025 VAST-AI-Research and contributors 3 | 4 | This project includes code from the following open source projects: 5 | 6 | TRELLIS 7 | Copyright (c) Microsoft Corporation 8 | License: MIT 9 | Source: https://github.com/Microsoft/TRELLIS 10 | 11 | FlexiCubes 12 | Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES 13 | License: Nvidia Source Code License 14 | Source: https://github.com/MaxtirError/FlexiCubes -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import importlib 4 | 5 | def find(cls_string): 6 | module_string = ".".join(cls_string.split(".")[:-1]) 7 | cls_name = cls_string.split(".")[-1] 8 | module = importlib.import_module(module_string, package=None) 9 | cls = getattr(module, cls_name) 10 | return cls 11 | 12 | def get_rank(): 13 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 14 | # therefore LOCAL_RANK needs to be checked first 15 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 16 | for key in rank_keys: 17 | rank = os.environ.get(key) 18 | if rank is not None: 19 | return int(rank) 20 | return 0 21 | 22 | def get_device(): 23 | return 
torch.device(f"cuda:{get_rank()}") -------------------------------------------------------------------------------- /triposf/modules/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LayerNorm32(nn.LayerNorm): 6 | def forward(self, x: torch.Tensor) -> torch.Tensor: 7 | return super().forward(x.float()).type(x.dtype) 8 | 9 | 10 | class GroupNorm32(nn.GroupNorm): 11 | """ 12 | A GroupNorm layer that converts to float32 before the forward pass. 13 | """ 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | return super().forward(x.float()).type(x.dtype) 16 | 17 | 18 | class ChannelLayerNorm32(LayerNorm32): 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | DIM = x.dim() 21 | x = x.permute(0, *range(2, DIM), 1).contiguous() 22 | x = super().forward(x) 23 | x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() 24 | return x 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2025 VAST-AI-Research and contributors. 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE 22 | -------------------------------------------------------------------------------- /configs/TripoSFVAE_1024.yaml: -------------------------------------------------------------------------------- 1 | weight: ckpts/pretrained_TripoSFVAE_256i1024o.safetensors 2 | 3 | resolution: 256 # 1024 // 4 4 | sample_points_num: 819_200 5 | pruning: false 6 | use_normals: true 7 | 8 | local_pc_encoder_cls: triposf.modules.pointclouds.pointnet.LocalPoolPointnet 9 | local_pc_encoder: 10 | in_channels: 6 # 3 + 3 11 | out_channels: 1024 12 | hidden_dim: 256 13 | scatter_type: mean 14 | n_blocks: 5 15 | 16 | encoder_cls: triposf.models.triposf_vae.encoder.TripoSFVAEEncoder 17 | encoder: 18 | resolution: ${resolution} 19 | in_channels: 1024 20 | model_channels: 768 21 | latent_channels: 8 22 | num_blocks: 12 23 | num_heads: 12 24 | num_head_channels: 64 25 | mlp_ratio: 4 26 | attn_mode: swin 27 | window_size: 8 28 | pe_mode: ape 29 | use_fp16: false 30 | use_checkpoint: true 31 | qk_rms_norm: false 32 | 33 | decoder_cls: triposf.models.triposf_vae.decoder.TripoSFVAEDecoder 34 | decoder: 35 | resolution: ${resolution} 36 | model_channels: 768 37 | latent_channels: 8 38 | num_blocks: 12 39 | num_heads: 12 40 | num_head_channels: 64 41 | mlp_ratio: 4 42 | attn_mode: swin 43 | window_size: 8 44 | pe_mode: ape 45 | use_fp16: false 46 | use_checkpoint: true 47 | qk_rms_norm: false 48 | representation_config: 49 | use_color: false 50 | -------------------------------------------------------------------------------- /triposf/representations/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from .mesh import MeshExtractResult 25 | -------------------------------------------------------------------------------- /triposf/models/triposf_vae/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE 22 | 23 | from .encoder import TripoSFVAEEncoder 24 | from .decoder import TripoSFVAEDecoder 25 | -------------------------------------------------------------------------------- /triposf/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from .blocks import * 25 | from .modulated import * -------------------------------------------------------------------------------- /triposf/modules/sparse/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from .blocks import * 25 | from .modulated import * -------------------------------------------------------------------------------- /triposf/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from . import models 25 | from . import modules 26 | from . import representations -------------------------------------------------------------------------------- /triposf/representations/mesh/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult 25 | -------------------------------------------------------------------------------- /triposf/modules/sparse/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from .full_attn import * 25 | from .serialized_attn import * 26 | from .windowed_attn import * 27 | from .modules import * 28 | -------------------------------------------------------------------------------- /triposf/modules/sparse/linear.py: -------------------------------------------------------------------------------- 1 | 2 | # MIT License 3 | 4 | # Copyright (c) Microsoft Corporation. 5 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE 24 | 25 | import torch 26 | import torch.nn as nn 27 | from . import SparseTensor 28 | 29 | __all__ = [ 30 | 'SparseLinear' 31 | ] 32 | 33 | class SparseLinear(nn.Linear): 34 | def __init__(self, in_features, out_features, bias=True): 35 | super(SparseLinear, self).__init__(in_features, out_features, bias) 36 | 37 | def forward(self, input: SparseTensor) -> SparseTensor: 38 | return input.replace(super().forward(input.feats)) 39 | -------------------------------------------------------------------------------- /triposf/modules/sparse/conv/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from .. 
import BACKEND 25 | 26 | 27 | SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' 28 | 29 | def __from_env(): 30 | import os 31 | 32 | global SPCONV_ALGO 33 | env_spconv_algo = os.environ.get('SPCONV_ALGO') 34 | if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: 35 | SPCONV_ALGO = env_spconv_algo 36 | print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") 37 | 38 | 39 | __from_env() 40 | 41 | if BACKEND == 'torchsparse': 42 | from .conv_torchsparse import * 43 | elif BACKEND == 'spconv': 44 | from .conv_spconv import * 45 | -------------------------------------------------------------------------------- /triposf/modules/spatial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: 5 | """ 6 | 3D pixel shuffle. 7 | """ 8 | B, C, H, W, D = x.shape 9 | C_ = C // scale_factor**3 10 | x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) 11 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) 12 | x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) 13 | return x 14 | 15 | 16 | def patchify(x: torch.Tensor, patch_size: int): 17 | """ 18 | Patchify a tensor. 19 | 20 | Args: 21 | x (torch.Tensor): (N, C, *spatial) tensor 22 | patch_size (int): Patch size 23 | """ 24 | DIM = x.dim() - 2 25 | for d in range(2, DIM + 2): 26 | assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" 27 | 28 | x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) 29 | x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) 30 | x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) 31 | return x 32 | 33 | 34 | def unpatchify(x: torch.Tensor, patch_size: int): 35 | """ 36 | Unpatchify a tensor. 37 | 38 | Args: 39 | x (torch.Tensor): (N, C, *spatial) tensor 40 | patch_size (int): Patch size 41 | """ 42 | DIM = x.dim() - 2 43 | assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" 44 | 45 | x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) 46 | x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) 47 | x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) 48 | return x 49 | -------------------------------------------------------------------------------- /triposf/models/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 
15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | import importlib 25 | 26 | __attributes = { 27 | 'TripoSFVAEEncoder': 'triposf_vae', 28 | 'TripoSFVAEDecoder': 'triposf_vae', 29 | } 30 | 31 | __submodules = [] 32 | 33 | __all__ = list(__attributes.keys()) + __submodules 34 | 35 | def __getattr__(name): 36 | if name not in globals(): 37 | if name in __attributes: 38 | module_name = __attributes[name] 39 | module = importlib.import_module(f".{module_name}", __name__) 40 | globals()[name] = getattr(module, name) 41 | elif name in __submodules: 42 | module = importlib.import_module(f".{name}", __name__) 43 | globals()[name] = module 44 | else: 45 | raise AttributeError(f"module {__name__} has no attribute {name}") 46 | return globals()[name] 47 | 48 | # For Pylance 49 | if __name__ == '__main__': 50 | from .triposf_vae import TripoSFVAEEncoder, TripoSFVAEDecoder 51 | -------------------------------------------------------------------------------- /triposf/modules/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | 26 | BACKEND = 'flash_attn' 27 | DEBUG = False 28 | 29 | def __from_env(): 30 | import os 31 | 32 | global BACKEND 33 | global DEBUG 34 | 35 | env_attn_backend = os.environ.get('ATTN_BACKEND') 36 | env_sttn_debug = os.environ.get('ATTN_DEBUG') 37 | 38 | if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: 39 | BACKEND = env_attn_backend 40 | if env_sttn_debug is not None: 41 | DEBUG = env_sttn_debug == '1' 42 | 43 | print(f"[ATTENTION] Using backend: {BACKEND}") 44 | 45 | 46 | __from_env() 47 | 48 | 49 | def set_backend(backend: Literal['xformers', 'flash_attn']): 50 | global BACKEND 51 | BACKEND = backend 52 | 53 | def set_debug(debug: bool): 54 | global DEBUG 55 | DEBUG = debug 56 | 57 | 58 | from .full_attn import * 59 | from .modules import * 60 | -------------------------------------------------------------------------------- /triposf/modules/sparse/nonlinearity.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | import torch 25 | import torch.nn as nn 26 | from . 
import SparseTensor 27 | 28 | __all__ = [ 29 | 'SparseReLU', 30 | 'SparseSiLU', 31 | 'SparseGELU', 32 | 'SparseActivation' 33 | ] 34 | 35 | 36 | class SparseReLU(nn.ReLU): 37 | def forward(self, input: SparseTensor) -> SparseTensor: 38 | return input.replace(super().forward(input.feats)) 39 | 40 | 41 | class SparseSiLU(nn.SiLU): 42 | def forward(self, input: SparseTensor) -> SparseTensor: 43 | return input.replace(super().forward(input.feats)) 44 | 45 | 46 | class SparseGELU(nn.GELU): 47 | def forward(self, input: SparseTensor) -> SparseTensor: 48 | return input.replace(super().forward(input.feats)) 49 | 50 | 51 | class SparseActivation(nn.Module): 52 | def __init__(self, activation: nn.Module): 53 | super().__init__() 54 | self.activation = activation 55 | 56 | def forward(self, input: SparseTensor) -> SparseTensor: 57 | return input.replace(self.activation(input.feats)) 58 | 59 | -------------------------------------------------------------------------------- /triposf/modules/sparse/conv/conv_torchsparse.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | import torch 25 | import torch.nn as nn 26 | from .. 
import SparseTensor 27 | 28 | class SparseConv3d(nn.Module): 29 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 30 | super(SparseConv3d, self).__init__() 31 | if 'torchsparse' not in globals(): 32 | import torchsparse 33 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) 34 | 35 | def forward(self, x: SparseTensor) -> SparseTensor: 36 | out = self.conv(x.data) 37 | new_shape = [x.shape[0], self.conv.out_channels] 38 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 39 | out._spatial_cache = x._spatial_cache 40 | out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) 41 | return out 42 | 43 | 44 | class SparseInverseConv3d(nn.Module): 45 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 46 | super(SparseInverseConv3d, self).__init__() 47 | if 'torchsparse' not in globals(): 48 | import torchsparse 49 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) 50 | 51 | def forward(self, x: SparseTensor) -> SparseTensor: 52 | out = self.conv(x.data) 53 | new_shape = [x.shape[0], self.conv.out_channels] 54 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 55 | out._spatial_cache = x._spatial_cache 56 | out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) 57 | return out 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /triposf/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import * 4 | import numpy as np 5 | from ..modules import sparse as sp 6 | 7 | FP16_MODULES = ( 8 | nn.Conv1d, 9 | nn.Conv2d, 10 | nn.Conv3d, 11 | nn.ConvTranspose1d, 12 | nn.ConvTranspose2d, 13 | nn.ConvTranspose3d, 14 | nn.Linear, 15 | sp.SparseConv3d, 16 | sp.SparseInverseConv3d, 17 | sp.SparseLinear, 18 | ) 19 | 20 | def convert_module_to_f16(l): 21 | """ 22 | Convert primitive modules to float16. 23 | """ 24 | if isinstance(l, FP16_MODULES): 25 | for p in l.parameters(): 26 | p.data = p.data.half() 27 | 28 | 29 | def convert_module_to_f32(l): 30 | """ 31 | Convert primitive modules to float32, undoing convert_module_to_f16(). 32 | """ 33 | if isinstance(l, FP16_MODULES): 34 | for p in l.parameters(): 35 | p.data = p.data.float() 36 | 37 | 38 | def zero_module(module): 39 | """ 40 | Zero out the parameters of a module and return it. 41 | """ 42 | for p in module.parameters(): 43 | p.detach().zero_() 44 | return module 45 | 46 | 47 | def scale_module(module, scale): 48 | """ 49 | Scale the parameters of a module and return it.
50 | """ 51 | for p in module.parameters(): 52 | p.detach().mul_(scale) 53 | return module 54 | 55 | 56 | def modulate(x, shift, scale): 57 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 58 | 59 | class DiagonalGaussianDistribution(object): 60 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): 61 | self.feat_dim = feat_dim 62 | self.parameters = parameters 63 | 64 | if isinstance(parameters, list): 65 | self.mean = parameters[0] 66 | self.logvar = parameters[1] 67 | else: 68 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 69 | 70 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 71 | self.deterministic = deterministic 72 | self.std = torch.exp(0.5 * self.logvar) 73 | self.var = torch.exp(self.logvar) 74 | if self.deterministic: 75 | self.var = self.std = torch.zeros_like(self.mean) 76 | 77 | def sample(self): 78 | x = self.mean + self.std * torch.randn_like(self.mean) 79 | return x 80 | 81 | def kl(self, other=None, dims=(1, 2, 3)): 82 | if self.deterministic: 83 | return torch.Tensor([0.]) 84 | else: 85 | if other is None: 86 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 87 | + self.var - 1.0 - self.logvar, 88 | dim=dims) 89 | else: 90 | return 0.5 * torch.mean( 91 | torch.pow(self.mean - other.mean, 2) / other.var 92 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 93 | dim=dims) 94 | 95 | def nll(self, sample, dims=(1, 2, 3)): 96 | if self.deterministic: 97 | return torch.Tensor([0.]) 98 | logtwopi = np.log(2.0 * np.pi) 99 | return 0.5 * torch.sum( 100 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 101 | dim=dims) 102 | 103 | def mode(self): 104 | return self.mean -------------------------------------------------------------------------------- /triposf/modules/sparse/norm.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | import torch 25 | import torch.nn as nn 26 | from . import SparseTensor 27 | from . 
import DEBUG 28 | 29 | __all__ = [ 30 | 'SparseGroupNorm', 31 | 'SparseLayerNorm', 32 | 'SparseGroupNorm32', 33 | 'SparseLayerNorm32', 34 | ] 35 | 36 | 37 | class SparseGroupNorm(nn.GroupNorm): 38 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): 39 | super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) 40 | 41 | def forward(self, input: SparseTensor) -> SparseTensor: 42 | nfeats = torch.zeros_like(input.feats) 43 | for k in range(input.shape[0]): 44 | if DEBUG: 45 | assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" 46 | bfeats = input.feats[input.layout[k]] 47 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 48 | bfeats = super().forward(bfeats) 49 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 50 | nfeats[input.layout[k]] = bfeats 51 | return input.replace(nfeats) 52 | 53 | 54 | class SparseLayerNorm(nn.LayerNorm): 55 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 56 | super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) 57 | 58 | def forward(self, input: SparseTensor) -> SparseTensor: 59 | nfeats = torch.zeros_like(input.feats) 60 | for k in range(input.shape[0]): 61 | bfeats = input.feats[input.layout[k]] 62 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 63 | bfeats = super().forward(bfeats) 64 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 65 | nfeats[input.layout[k]] = bfeats 66 | return input.replace(nfeats) 67 | 68 | 69 | class SparseGroupNorm32(SparseGroupNorm): 70 | """ 71 | A GroupNorm layer that converts to float32 before the forward pass. 72 | """ 73 | def forward(self, x: SparseTensor) -> SparseTensor: 74 | return super().forward(x.float()).type(x.dtype) 75 | 76 | class SparseLayerNorm32(SparseLayerNorm): 77 | """ 78 | A LayerNorm layer that converts to float32 before the forward pass. 79 | """ 80 | def forward(self, x: SparseTensor) -> SparseTensor: 81 | return super().forward(x.float()).type(x.dtype) 82 | -------------------------------------------------------------------------------- /triposf/models/triposf_vae/encoder.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | from ...modules import sparse as sp 28 | from .base import SparseTransformerBase 29 | from ...modules.utils import DiagonalGaussianDistribution 30 | 31 | class TripoSFVAEEncoder(SparseTransformerBase): 32 | def __init__( 33 | self, 34 | resolution: int, 35 | in_channels: int, 36 | model_channels: int, 37 | latent_channels: int, 38 | num_blocks: int, 39 | num_heads: Optional[int] = None, 40 | num_head_channels: Optional[int] = 64, 41 | mlp_ratio: float = 4, 42 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 43 | window_size: int = 8, 44 | pe_mode: Literal["ape", "rope"] = "ape", 45 | use_fp16: bool = False, 46 | use_checkpoint: bool = False, 47 | qk_rms_norm: bool = False 48 | ): 49 | super().__init__( 50 | in_channels=in_channels, 51 | model_channels=model_channels, 52 | num_blocks=num_blocks, 53 | num_heads=num_heads, 54 | num_head_channels=num_head_channels, 55 | mlp_ratio=mlp_ratio, 56 | attn_mode=attn_mode, 57 | window_size=window_size, 58 | pe_mode=pe_mode, 59 | use_fp16=use_fp16, 60 | use_checkpoint=use_checkpoint, 61 | qk_rms_norm=qk_rms_norm, 62 | ) 63 | self.resolution = resolution 64 | self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) 65 | self.initialize_weights() 66 | # if use_fp16: 67 | # self.convert_to_fp16() 68 | 69 | def initialize_weights(self) -> None: 70 | super().initialize_weights() 71 | # Zero-out output layers: 72 | nn.init.constant_(self.out_layer.weight, 0) 73 | nn.init.constant_(self.out_layer.bias, 0) 74 | 75 | def forward(self, x: sp.SparseTensor, sample_posterior=True): 76 | h = super().forward(x) 77 | h = h.type(x.dtype) 78 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) 79 | h = self.out_layer(h) 80 | 81 | posterior = DiagonalGaussianDistribution(h.feats, feat_dim=-1) 82 | 83 | if sample_posterior: 84 | z = posterior.sample() 85 | else: 86 | z = posterior.mode() 87 | 88 | z = h.replace(z) 89 | return z, posterior 90 | -------------------------------------------------------------------------------- /triposf/representations/mesh/utils_cube.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | import torch 25 | cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ 26 | 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) 27 | cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) 28 | cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 29 | 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) 30 | 31 | def construct_dense_grid(res, device='cuda'): 32 | '''construct a dense grid based on resolution''' 33 | res_v = res + 1 34 | vertsid = torch.arange(res_v ** 3, device=device) 35 | coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() 36 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] 37 | cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) 38 | verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) 39 | return verts, cube_fx8 40 | 41 | 42 | def construct_voxel_grid(coords): 43 | verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) 44 | verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) 45 | cubes = inverse_indices.reshape(-1, 8) 46 | return verts_unique, cubes 47 | 48 | 49 | def cubes_to_verts(num_verts, cubes, value, reduce='mean'): 50 | """ 51 | Args: 52 | cubes [Vx8] verts index for each cube 53 | value [Vx8xM] value to be scattered 54 | Operation: 55 | reduced[cubes[i][j]][k] += value[i][k] 56 | """ 57 | M = value.shape[2] # number of channels 58 | reduced = torch.zeros(num_verts, M, device=cubes.device,dtype=value.dtype) 59 | return torch.scatter_reduce(reduced, 0, 60 | cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), 61 | value.flatten(0, 1), reduce=reduce, include_self=False) 62 | 63 | def sparse_cube2verts(coords, feats, training=True): 64 | new_coords, cubes = construct_voxel_grid(coords) 65 | new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) 66 | if training: 67 | con_loss = torch.mean((feats - new_feats[cubes]) ** 2) 68 | else: 69 | con_loss = 0.0 70 | return new_coords, new_feats, con_loss 71 | 72 | def get_sparse_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): 73 | verts = coords 74 | verts, masks = torch.unique(verts, dim=0, return_inverse=True) 75 | feats_sparse = torch.zeros((len(verts), feats.shape[-1]), device=feats.device, dtype=feats.dtype) 76 | feats_sparse[masks] = feats 77 | return feats_sparse, verts 78 | 79 | def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): 80 | return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TripoSF: High-Resolution and Arbitrary-Topology 3D Shape Modeling with SparseFlex 2 | 3 |
4 | 5 | [![Project Page](https://img.shields.io/badge/🏠-Project%20Page-blue.svg)](https://XianglongHe.github.io/TripoSF/index.html) 6 | [![Paper](https://img.shields.io/badge/📑-Paper-green.svg)](https://arxiv.org/abs/2503.21732) 7 | [![Model](https://img.shields.io/badge/🤗-Model-yellow.svg)](https://huggingface.co/VAST-AI/TripoSF) 8 | 9 | **By [Tripo](https://www.tripo3d.ai)** 10 | 11 |
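If you want a quick sense of how the pieces of this repository fit together before reading on, the sketch below shows how the `*_cls` strings in `configs/TripoSFVAE_1024.yaml` are resolved with the `find` helper from `misc.py`. It is only an illustration: `inference.py` is the supported entry point, and the snippet assumes each constructor accepts its config block as keyword arguments.

```python
# Illustrative sketch only -- see inference.py for the supported pipeline.
from omegaconf import OmegaConf
from misc import find  # resolves a dotted path such as "triposf.models.triposf_vae.encoder.TripoSFVAEEncoder"

cfg = OmegaConf.load("configs/TripoSFVAE_1024.yaml")

# Build the point-cloud encoder and the sparse VAE encoder/decoder from the
# class paths named in the config (assumes kwargs match each config block).
pc_encoder = find(cfg.local_pc_encoder_cls)(**cfg.local_pc_encoder)
encoder = find(cfg.encoder_cls)(**cfg.encoder)
decoder = find(cfg.decoder_cls)(**cfg.decoder)
```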
12 | 13 | ![teaser](assets/docs/teaser.png) 14 | 15 | ## 🌟 Overview 16 | 17 | TripoSF represents a significant leap forward in 3D shape modeling, combining high-resolution capabilities with arbitrary topology support. Our approach enables: 18 | 19 | - 📈 Ultra-high resolution mesh modeling (up to $1024^3$) 20 | - 🎯 Direct optimization from rendering losses 21 | - 🌐 Efficient handling of open surfaces and complex topologies 22 | - 💾 Dramatic memory reduction through sparse computation 23 | - 🔄 Differentiable mesh extraction with sharp features 24 | 25 | ### SparseFlex 26 | 27 | SparseFlex, the core design powering TripoSF, introduces a sparse voxel structure that: 28 | - Focuses computational resources only on surface-adjacent regions 29 | - Enables natural handling of open surfaces (like cloth or leaves) 30 | - Supports complex internal structures without compromises 31 | - Achieves massive memory reduction compared to dense representations 32 | 33 | ## 🔥 Updates 34 | 35 | * [2025-03] Initial Release: 36 | - Pretrained VAE model weights ($1024^3$ reconstruction) 37 | - Inference scripts and examples 38 | - SparseFlex implementation 39 | 40 | ## 🚀 Getting Started 41 | 42 | ### System Requirements 43 | - CUDA-capable GPU (≥12GB VRAM for $1024^3$ resolution) 44 | - PyTorch 2.0+ 45 | 46 | ### Installation 47 | 48 | 1. Clone the repository: 49 | ```bash 50 | git clone https://github.com/VAST-AI-Research/TripoSF.git 51 | cd TripoSF 52 | ``` 53 | 54 | 2. Install dependencies: 55 | ```bash 56 | # Install PyTorch (select the correct CUDA version) 57 | pip install torch torchvision --index-url https://download.pytorch.org/whl/{your-cuda-version} 58 | 59 | # Install other dependencies 60 | pip install -r requirements.txt 61 | ``` 62 | 63 | ## 💫 Usage 64 | 65 | ### Pretrained Model Setup 66 | 1. Download our pretrained models from [Hugging Face](https://huggingface.co/VAST-AI/TripoSF) 67 | 2. 
Place the models in the `ckpts/` directory 68 | 69 | ### Running Inference 70 | Basic reconstruction using TripoSFVAE: 71 | ```bash 72 | python inference.py --mesh-path "assets/examples/jacket.obj" \ 73 | --output-dir "outputs/" \ 74 | --config "configs/TripoSFVAE_1024.yaml" 75 | ``` 76 | 77 | ### Local Gradio Example 78 | ```bash 79 | python app.py 80 | ``` 81 | ![gradio_example](assets/docs/local_gradio_example.png) 82 | 83 | ### Optimization Tips 💡 84 | 85 | #### For Open Surfaces 86 | - Enable `pruning` in the configuration: 87 | ```yaml 88 | pruning: true 89 | ``` 90 | - Benefits: 91 | - Higher-fidelity reconstruction 92 | - Faster processing 93 | - Better memory efficiency 94 | 95 | #### For Complex Shapes 96 | - Increase sampling density: 97 | ```yaml 98 | sample_points_num: 1638400 # Default: 819200 99 | ``` 100 | - Adjust resolution based on detail requirements: 101 | ```yaml 102 | resolution: 1024 # Options: 256, 512, 1024 103 | ``` 104 | 105 | 106 | ## 📊 Technical Details 107 | 108 | TripoSF VAE Architecture: 109 | - **Input**: Point clouds (preserving source geometry details) 110 | - **Encoder**: Sparse transformer for efficient geometry encoding 111 | - **Decoder**: Self-pruning upsampling modules maintaining sparsity 112 | - **Output**: High-resolution SparseFlex parameters for mesh extraction 113 | 114 | ## 📝 Citation 115 | 116 | ```bibtex 117 | @article{he2025triposf, 118 | title={SparseFlex: High-Resolution and Arbitrary-Topology 3D Shape Modeling}, 119 | author={He, Xianglong and Zou, Zi-Xin and Chen, Chia-Hao and Guo, Yuan-Chen and Liang, Ding and Yuan, Chun and Ouyang, Wanli and Cao, Yan-Pei and Li, Yangguang}, 120 | journal={arXiv preprint arXiv:2503.21732}, 121 | year={2025} 122 | } 123 | ``` 124 | 125 | 126 | ## 📚 Acknowledgements 127 | 128 | Our work builds upon these excellent repositories: 129 | - [Trellis](https://github.com/Microsoft/TRELLIS) 130 | - [Flexicubes](https://github.com/MaxtirError/FlexiCubes) 131 | 132 | ## 📄 License 133 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 134 | -------------------------------------------------------------------------------- /triposf/modules/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | 26 | BACKEND = 'spconv' 27 | DEBUG = False 28 | ATTN = 'flash_attn' 29 | 30 | def __from_env(): 31 | import os 32 | 33 | global BACKEND 34 | global DEBUG 35 | global ATTN 36 | 37 | env_sparse_backend = os.environ.get('SPARSE_BACKEND') 38 | env_sparse_debug = os.environ.get('SPARSE_DEBUG') 39 | env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') 40 | if env_sparse_attn is None: 41 | env_sparse_attn = os.environ.get('ATTN_BACKEND') 42 | 43 | if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: 44 | BACKEND = env_sparse_backend 45 | if env_sparse_debug is not None: 46 | DEBUG = env_sparse_debug == '1' 47 | if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: 48 | ATTN = env_sparse_attn 49 | 50 | print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") 51 | 52 | 53 | __from_env() 54 | 55 | 56 | def set_backend(backend: Literal['spconv', 'torchsparse']): 57 | global BACKEND 58 | BACKEND = backend 59 | 60 | def set_debug(debug: bool): 61 | global DEBUG 62 | DEBUG = debug 63 | 64 | def set_attn(attn: Literal['xformers', 'flash_attn']): 65 | global ATTN 66 | ATTN = attn 67 | 68 | 69 | import importlib 70 | 71 | __attributes = { 72 | 'SparseTensor': 'basic', 73 | 'sparse_batch_broadcast': 'basic', 74 | 'sparse_batch_op': 'basic', 75 | 'sparse_cat': 'basic', 76 | 'sparse_unbind': 'basic', 77 | 'SparseGroupNorm': 'norm', 78 | 'SparseLayerNorm': 'norm', 79 | 'SparseGroupNorm32': 'norm', 80 | 'SparseLayerNorm32': 'norm', 81 | 'SparseReLU': 'nonlinearity', 82 | 'SparseSiLU': 'nonlinearity', 83 | 'SparseGELU': 'nonlinearity', 84 | 'SparseActivation': 'nonlinearity', 85 | 'SparseLinear': 'linear', 86 | 'sparse_scaled_dot_product_attention': 'attention', 87 | 'SerializeMode': 'attention', 88 | 'sparse_serialized_scaled_dot_product_self_attention': 'attention', 89 | 'sparse_windowed_scaled_dot_product_self_attention': 'attention', 90 | 'SparseMultiHeadAttention': 'attention', 91 | 'SparseConv3d': 'conv', 92 | 'SparseInverseConv3d': 'conv', 93 | 'SparseDownsample': 'spatial', 94 | 'SparseUpsample': 'spatial', 95 | 'SparseSubdivide' : 'spatial' 96 | } 97 | 98 | __submodules = ['transformer'] 99 | 100 | __all__ = list(__attributes.keys()) + __submodules 101 | 102 | def __getattr__(name): 103 | if name not in globals(): 104 | if name in __attributes: 105 | module_name = __attributes[name] 106 | module = importlib.import_module(f".{module_name}", __name__) 107 | globals()[name] = getattr(module, name) 108 | elif name in __submodules: 109 | module = importlib.import_module(f".{name}", __name__) 110 | globals()[name] = module 111 | else: 112 | raise AttributeError(f"module {__name__} has no attribute {name}") 113 | return globals()[name] 114 | 115 | 116 | # For Pylance 117 | if __name__ == '__main__': 118 | from .basic import * 119 | from .norm import * 120 | from .nonlinearity import * 121 | from .linear import * 122 | from .attention import * 123 | from .conv import * 124 | from .spatial import * 125 | import transformer 126 | -------------------------------------------------------------------------------- /triposf/representations/mesh/flexicubes/LICENSE.txt: 
-------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | 5 | NVIDIA Source Code License for FlexiCubes 6 | 7 | 8 | ======================================================================= 9 | 10 | 1. Definitions 11 | 12 | “Licensor” means any person or entity that distributes its Work. 13 | 14 | “Work” means (a) the original work of authorship made available under 15 | this license, which may include software, documentation, or other files, 16 | and (b) any additions to or derivative works thereof that are made 17 | available under this license. 18 | 19 | The terms “reproduce,” “reproduction,” “derivative works,” and 20 | “distribution” have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this license, derivative works 22 | shall not include works that remain separable from, or merely link 23 | (or bind by name) to the interfaces of, the Work. 24 | 25 | Works are “made available” under this license by including in or with 26 | the Work either (a) a copyright notice referencing the applicability of 27 | this license to the Work, or (b) a copy of this license. 28 | 29 | 2. License Grant 30 | 31 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, 32 | each Licensor grants to you a perpetual, worldwide, non-exclusive, 33 | royalty-free, copyright license to use, reproduce, prepare derivative 34 | works of, publicly display, publicly perform, sublicense and distribute 35 | its Work and any resulting derivative works in any form. 36 | 37 | 3. Limitations 38 | 39 | 3.1 Redistribution. You may reproduce or distribute the Work only if 40 | (a) you do so under this license, (b) you include a complete copy of 41 | this license with your distribution, and (c) you retain without 42 | modification any copyright, patent, trademark, or attribution notices 43 | that are present in the Work. 44 | 45 | 3.2 Derivative Works. You may specify that additional or different terms 46 | apply to the use, reproduction, and distribution of your derivative 47 | works of the Work (“Your Terms”) only if (a) Your Terms provide that the 48 | use limitation in Section 3.3 applies to your derivative works, and (b) 49 | you identify the specific derivative works that are subject to Your Terms. 50 | Notwithstanding Your Terms, this license (including the redistribution 51 | requirements in Section 3.1) will continue to apply to the Work itself. 52 | 53 | 3.3 Use Limitation. The Work and any derivative works thereof only may be 54 | used or intended for use non-commercially. Notwithstanding the foregoing, 55 | NVIDIA Corporation and its affiliates may use the Work and any derivative 56 | works commercially. As used herein, “non-commercially” means for research 57 | or evaluation purposes only. 58 | 59 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against 60 | any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) 61 | to enforce any patents that you allege are infringed by any Work, then your 62 | rights under this license from such Licensor (including the grant in 63 | Section 2.1) will terminate immediately. 64 | 65 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s 66 | or its affiliates’ names, logos, or trademarks, except as necessary to 67 | reproduce the notices described in this license. 68 | 69 | 3.6 Termination. 
If you violate any term of this license, then your rights 70 | under this license (including the grant in Section 2.1) will terminate 71 | immediately. 72 | 73 | 4. Disclaimer of Warranty. 74 | 75 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 76 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 77 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. 78 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 79 | 80 | 5. Limitation of Liability. 81 | 82 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, 83 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY 84 | LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, 85 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, 86 | THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF 87 | GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR 88 | MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 89 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 90 | 91 | ======================================================================= 92 | -------------------------------------------------------------------------------- /triposf/modules/sparse/conv/conv_spconv.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | import torch 25 | import torch.nn as nn 26 | from .. import SparseTensor 27 | from .. import DEBUG 28 | from . 
import SPCONV_ALGO 29 | 30 | class SparseConv3d(nn.Module): 31 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): 32 | super(SparseConv3d, self).__init__() 33 | if 'spconv' not in globals(): 34 | import spconv.pytorch as spconv 35 | algo = None 36 | if SPCONV_ALGO == 'native': 37 | algo = spconv.ConvAlgo.Native 38 | elif SPCONV_ALGO == 'implicit_gemm': 39 | algo = spconv.ConvAlgo.MaskImplicitGemm 40 | if stride == 1 and (padding is None): 41 | self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) 42 | else: 43 | self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) 44 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) 45 | self.padding = padding 46 | 47 | def forward(self, x: SparseTensor) -> SparseTensor: 48 | spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) 49 | 50 | dtype_ = x.feats.dtype 51 | x = x.replace(x.feats.type(torch.float32)) 52 | new_data = self.conv(x.data) 53 | new_shape = [x.shape[0], self.conv.out_channels] 54 | new_layout = None if spatial_changed else x.layout 55 | 56 | if spatial_changed and (x.shape[0] != 1): 57 | # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords 58 | fwd = new_data.indices[:, 0].argsort() 59 | bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) 60 | sorted_feats = new_data.features[fwd] 61 | sorted_coords = new_data.indices[fwd] 62 | unsorted_data = new_data 63 | new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore 64 | 65 | out = SparseTensor( 66 | new_data, shape=torch.Size(new_shape), layout=new_layout, 67 | scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), 68 | spatial_cache=x._spatial_cache, 69 | ) 70 | out = out.replace(out.feats.type(dtype_)) 71 | 72 | if spatial_changed and (x.shape[0] != 1): 73 | out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) 74 | out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) 75 | 76 | return out 77 | 78 | class SparseInverseConv3d(nn.Module): 79 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 80 | super(SparseInverseConv3d, self).__init__() 81 | if 'spconv' not in globals(): 82 | import spconv.pytorch as spconv 83 | self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) 84 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) 85 | 86 | def forward(self, x: SparseTensor) -> SparseTensor: 87 | spatial_changed = any(s != 1 for s in self.stride) 88 | if spatial_changed: 89 | # recover the original spconv order 90 | data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') 91 | bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') 92 | data = data.replace_feature(x.feats[bwd]) 93 | if DEBUG: 94 | assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' 95 | else: 96 | data = x.data 97 | 98 | new_data = self.conv(data) 99 | new_shape = [x.shape[0], self.conv.out_channels] 100 | new_layout = None if spatial_changed else x.layout 101 | out = SparseTensor( 102 | 
new_data, shape=torch.Size(new_shape), layout=new_layout, 103 | scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), 104 | spatial_cache=x._spatial_cache, 105 | ) 106 | return out 107 | -------------------------------------------------------------------------------- /triposf/models/triposf_vae/base.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | from ...modules.utils import convert_module_to_f16, convert_module_to_f32 28 | from ...modules import sparse as sp 29 | from ...modules.transformer import AbsolutePositionEmbedder 30 | from ...modules.sparse.transformer import SparseTransformerBlock 31 | 32 | 33 | def block_attn_config(self): 34 | """ 35 | Return the attention configuration of the model. 36 | """ 37 | for i in range(self.num_blocks): 38 | if self.attn_mode == "shift_window": 39 | yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER 40 | elif self.attn_mode == "shift_sequence": 41 | yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER 42 | elif self.attn_mode == "shift_order": 43 | yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] 44 | elif self.attn_mode == "full": 45 | yield "full", None, None, None, None 46 | elif self.attn_mode == "swin": 47 | yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None 48 | 49 | 50 | class SparseTransformerBase(nn.Module): 51 | """ 52 | Sparse Transformer without output layers. 53 | Serve as the base class for encoder and decoder. 
54 | """ 55 | def __init__( 56 | self, 57 | in_channels: int, 58 | model_channels: int, 59 | num_blocks: int, 60 | num_heads: Optional[int] = None, 61 | num_head_channels: Optional[int] = 64, 62 | mlp_ratio: float = 4.0, 63 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 64 | window_size: Optional[int] = None, 65 | pe_mode: Literal["ape", "rope"] = "ape", 66 | use_fp16: bool = False, 67 | use_checkpoint: bool = False, 68 | qk_rms_norm: bool = False, 69 | ): 70 | super().__init__() 71 | self.in_channels = in_channels 72 | self.model_channels = model_channels 73 | self.num_blocks = num_blocks 74 | self.window_size = window_size 75 | self.num_heads = num_heads or model_channels // num_head_channels 76 | self.mlp_ratio = mlp_ratio 77 | self.attn_mode = attn_mode 78 | self.pe_mode = pe_mode 79 | self.use_fp16 = use_fp16 80 | self.use_checkpoint = use_checkpoint 81 | self.qk_rms_norm = qk_rms_norm 82 | self.dtype = torch.float16 if use_fp16 else torch.float32 83 | 84 | if pe_mode == "ape": 85 | self.pos_embedder = AbsolutePositionEmbedder(model_channels) 86 | 87 | self.input_layer = sp.SparseLinear(in_channels, model_channels) 88 | self.blocks = nn.ModuleList([ 89 | SparseTransformerBlock( 90 | model_channels, 91 | num_heads=self.num_heads, 92 | mlp_ratio=self.mlp_ratio, 93 | attn_mode=attn_mode, 94 | window_size=window_size, 95 | shift_sequence=shift_sequence, 96 | shift_window=shift_window, 97 | serialize_mode=serialize_mode, 98 | use_checkpoint=self.use_checkpoint, 99 | use_rope=(pe_mode == "rope"), 100 | qk_rms_norm=self.qk_rms_norm, 101 | ) 102 | for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) 103 | ]) 104 | 105 | @property 106 | def device(self) -> torch.device: 107 | """ 108 | Return the device of the model. 109 | """ 110 | return next(self.parameters()).device 111 | 112 | def convert_to_fp16(self) -> None: 113 | """ 114 | Convert the torso of the model to float16. 115 | """ 116 | self.blocks.apply(convert_module_to_f16) 117 | 118 | def convert_to_fp32(self) -> None: 119 | """ 120 | Convert the torso of the model to float32. 121 | """ 122 | self.blocks.apply(convert_module_to_f32) 123 | 124 | def initialize_weights(self) -> None: 125 | # Initialize transformer layers: 126 | def _basic_init(module): 127 | if isinstance(module, nn.Linear): 128 | torch.nn.init.xavier_uniform_(module.weight) 129 | if module.bias is not None: 130 | nn.init.constant_(module.bias, 0) 131 | self.apply(_basic_init) 132 | 133 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: 134 | h = self.input_layer(x) 135 | if self.pe_mode == "ape": 136 | h = h + self.pos_embedder(x.coords[:, 1:]) 137 | # h = h.type(self.dtype) 138 | for block in self.blocks: 139 | h = block(h) 140 | return h 141 | -------------------------------------------------------------------------------- /triposf/modules/attention/full_attn.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import math 27 | from . import DEBUG, BACKEND 28 | 29 | if BACKEND == 'xformers': 30 | import xformers.ops as xops 31 | elif BACKEND == 'flash_attn': 32 | import flash_attn 33 | elif BACKEND == 'sdpa': 34 | from torch.nn.functional import scaled_dot_product_attention as sdpa 35 | elif BACKEND == 'naive': 36 | pass 37 | else: 38 | raise ValueError(f"Unknown attention backend: {BACKEND}") 39 | 40 | 41 | __all__ = [ 42 | 'scaled_dot_product_attention', 43 | ] 44 | 45 | 46 | def _naive_sdpa(q, k, v): 47 | """ 48 | Naive implementation of scaled dot product attention. 49 | """ 50 | q = q.permute(0, 2, 1, 3) # [N, H, L, C] 51 | k = k.permute(0, 2, 1, 3) # [N, H, L, C] 52 | v = v.permute(0, 2, 1, 3) # [N, H, L, C] 53 | scale_factor = 1 / math.sqrt(q.size(-1)) 54 | attn_weight = q @ k.transpose(-2, -1) * scale_factor 55 | attn_weight = torch.softmax(attn_weight, dim=-1) 56 | out = attn_weight @ v 57 | out = out.permute(0, 2, 1, 3) # [N, L, H, C] 58 | return out 59 | 60 | 61 | @overload 62 | def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Apply scaled dot product attention. 65 | 66 | Args: 67 | qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. 68 | """ 69 | ... 70 | 71 | @overload 72 | def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: 73 | """ 74 | Apply scaled dot product attention. 75 | 76 | Args: 77 | q (torch.Tensor): A [N, L, H, C] tensor containing Qs. 78 | kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. 79 | """ 80 | ... 81 | 82 | @overload 83 | def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 84 | """ 85 | Apply scaled dot product attention. 86 | 87 | Args: 88 | q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. 89 | k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. 90 | v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. 91 | 92 | Note: 93 | k and v are assumed to have the same coordinate map. 94 | """ 95 | ... 
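# A quick usage summary of the dispatcher defined below: it accepts the three calling
# conventions declared by the overloads above and routes them to the attention backend
# selected by BACKEND at import time ('xformers', 'flash_attn', 'sdpa', or 'naive'):
#   out = scaled_dot_product_attention(qkv)      # packed qkv: [N, L, 3, H, C]
#   out = scaled_dot_product_attention(q, kv)    # q: [N, L, H, C], packed kv: [N, L, 2, H, C]
#   out = scaled_dot_product_attention(q, k, v)  # separate q/k/v: [N, L, H, Ci] and [N, L, H, Co]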
96 | 97 | def scaled_dot_product_attention(*args, **kwargs): 98 | arg_names_dict = { 99 | 1: ['qkv'], 100 | 2: ['q', 'kv'], 101 | 3: ['q', 'k', 'v'] 102 | } 103 | num_all_args = len(args) + len(kwargs) 104 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" 105 | for key in arg_names_dict[num_all_args][len(args):]: 106 | assert key in kwargs, f"Missing argument {key}" 107 | 108 | if num_all_args == 1: 109 | qkv = args[0] if len(args) > 0 else kwargs['qkv'] 110 | assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" 111 | device = qkv.device 112 | 113 | elif num_all_args == 2: 114 | q = args[0] if len(args) > 0 else kwargs['q'] 115 | kv = args[1] if len(args) > 1 else kwargs['kv'] 116 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" 117 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" 118 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" 119 | device = q.device 120 | 121 | elif num_all_args == 3: 122 | q = args[0] if len(args) > 0 else kwargs['q'] 123 | k = args[1] if len(args) > 1 else kwargs['k'] 124 | v = args[2] if len(args) > 2 else kwargs['v'] 125 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" 126 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" 127 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" 128 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" 129 | device = q.device 130 | 131 | if BACKEND == 'xformers': 132 | if num_all_args == 1: 133 | q, k, v = qkv.unbind(dim=2) 134 | elif num_all_args == 2: 135 | k, v = kv.unbind(dim=2) 136 | out = xops.memory_efficient_attention(q, k, v) 137 | elif BACKEND == 'flash_attn': 138 | if num_all_args == 1: 139 | out = flash_attn.flash_attn_qkvpacked_func(qkv) 140 | elif num_all_args == 2: 141 | out = flash_attn.flash_attn_kvpacked_func(q, kv) 142 | elif num_all_args == 3: 143 | out = flash_attn.flash_attn_func(q, k, v) 144 | elif BACKEND == 'sdpa': 145 | if num_all_args == 1: 146 | q, k, v = qkv.unbind(dim=2) 147 | elif num_all_args == 2: 148 | k, v = kv.unbind(dim=2) 149 | q = q.permute(0, 2, 1, 3) # [N, H, L, C] 150 | k = k.permute(0, 2, 1, 3) # [N, H, L, C] 151 | v = v.permute(0, 2, 1, 3) # [N, H, L, C] 152 | out = sdpa(q, k, v) # [N, H, L, C] 153 | out = out.permute(0, 2, 1, 3) # [N, L, H, C] 154 | elif BACKEND == 'naive': 155 | if num_all_args == 1: 156 | q, k, v = qkv.unbind(dim=2) 157 | elif num_all_args == 2: 158 | k, v = kv.unbind(dim=2) 159 | out = _naive_sdpa(q, k, v) 160 | else: 161 | raise ValueError(f"Unknown attention module: {BACKEND}") 162 | 163 | return out 164 | -------------------------------------------------------------------------------- /triposf/modules/sparse/spatial.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | from . import SparseTensor 28 | 29 | __all__ = [ 30 | 'SparseDownsample', 31 | 'SparseUpsample', 32 | 'SparseSubdivide' 33 | ] 34 | 35 | 36 | class SparseDownsample(nn.Module): 37 | """ 38 | Downsample a sparse tensor by a factor of `factor`. 39 | Implemented as average pooling. 40 | """ 41 | def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): 42 | super(SparseDownsample, self).__init__() 43 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor 44 | 45 | def forward(self, input: SparseTensor) -> SparseTensor: 46 | DIM = input.coords.shape[-1] - 1 47 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM 48 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' 
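# Overview of the pooling performed below: spatial coordinates are floored by the
# pooling factor, each (batch, x, y, z) tuple is packed into a single linear code,
# duplicate coarse voxels are merged with unique(), their features are averaged via
# scatter_reduce(..., reduce='mean'), and the surviving codes are unpacked back into
# coordinates. The original coords, layout and per-point indices are cached so that
# SparseUpsample can later invert the pooling.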
49 | 50 | coord = list(input.coords.unbind(dim=-1)) 51 | for i, f in enumerate(factor): 52 | coord[i+1] = coord[i+1] // f 53 | 54 | MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] 55 | OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] 56 | code = sum([c * o for c, o in zip(coord, OFFSET)]) 57 | code, idx = code.unique(return_inverse=True) 58 | 59 | new_feats = torch.scatter_reduce( 60 | torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), 61 | dim=0, 62 | index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), 63 | src=input.feats, 64 | reduce='mean' 65 | ) 66 | new_coords = torch.stack( 67 | [code // OFFSET[0]] + 68 | [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], 69 | dim=-1 70 | ) 71 | out = SparseTensor(new_feats, new_coords, input.shape,) 72 | out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) 73 | out._spatial_cache = input._spatial_cache 74 | 75 | if out.get_spatial_cache(f'upsample_{factor}_coords') is not None: 76 | out.register_spatial_cache(f'upsample_{factor}_coords', [*out.get_spatial_cache(f'upsample_{factor}_coords'), input.coords]) 77 | out.register_spatial_cache(f'upsample_{factor}_layout', [*out.get_spatial_cache(f'upsample_{factor}_layout'), input.layout]) 78 | out.register_spatial_cache(f'upsample_{factor}_idx', [*out.get_spatial_cache(f'upsample_{factor}_idx'), idx]) 79 | else: 80 | out.register_spatial_cache(f'upsample_{factor}_coords', [input.coords]) 81 | out.register_spatial_cache(f'upsample_{factor}_layout', [input.layout]) 82 | out.register_spatial_cache(f'upsample_{factor}_idx', [idx]) 83 | 84 | return out 85 | 86 | 87 | class SparseUpsample(nn.Module): 88 | """ 89 | Upsample a sparse tensor by a factor of `factor`. 90 | Implemented as nearest neighbor interpolation. 91 | """ 92 | def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): 93 | super(SparseUpsample, self).__init__() 94 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor 95 | 96 | def forward(self, input: SparseTensor) -> SparseTensor: 97 | DIM = input.coords.shape[-1] - 1 98 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM 99 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' 100 | 101 | new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') 102 | new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') 103 | idx = input.get_spatial_cache(f'upsample_{factor}_idx') 104 | # print(len(new_coords)) 105 | new_coords = new_coords.pop(-1) 106 | new_layout = new_layout.pop(-1) 107 | idx = idx.pop(-1) 108 | if any([x is None for x in [new_coords, new_layout, idx]]): 109 | raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') 110 | new_feats = input.feats[idx] 111 | out = SparseTensor(new_feats, new_coords, input.shape, new_layout) 112 | out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) 113 | out._spatial_cache = input._spatial_cache 114 | return out 115 | 116 | class SparseSubdivide(nn.Module): 117 | """ 118 | Subdivide a sparse tensor. 119 | Each voxel is split into 2^D child voxels at twice the spatial resolution.
120 | """ 121 | def __init__(self): 122 | super(SparseSubdivide, self).__init__() 123 | 124 | def forward(self, input: SparseTensor) -> SparseTensor: 125 | DIM = input.coords.shape[-1] - 1 126 | # upsample scale=2^DIM 127 | n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) 128 | n_coords = torch.nonzero(n_cube) 129 | n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) 130 | factor = n_coords.shape[0] 131 | assert factor == 2 ** DIM 132 | # print(n_coords.shape) 133 | new_coords = input.coords.clone() 134 | new_coords[:, 1:] *= 2 135 | new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) 136 | 137 | new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) 138 | out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) 139 | out._scale = input._scale * 2 140 | out._spatial_cache = input._spatial_cache 141 | return out 142 | 143 | -------------------------------------------------------------------------------- /triposf/modules/sparse/transformer/blocks.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | import torch.utils.checkpoint 28 | from ..basic import SparseTensor 29 | from ..linear import SparseLinear 30 | from ..nonlinearity import SparseGELU 31 | from ..attention import SparseMultiHeadAttention, SerializeMode 32 | from ...norm import LayerNorm32 33 | 34 | 35 | class SparseFeedForwardNet(nn.Module): 36 | def __init__(self, channels: int, mlp_ratio: float = 4.0): 37 | super().__init__() 38 | self.mlp = nn.Sequential( 39 | SparseLinear(channels, int(channels * mlp_ratio)), 40 | SparseGELU(approximate="tanh"), 41 | SparseLinear(int(channels * mlp_ratio), channels), 42 | ) 43 | 44 | def forward(self, x: SparseTensor) -> SparseTensor: 45 | return self.mlp(x) 46 | 47 | 48 | class SparseTransformerBlock(nn.Module): 49 | """ 50 | Sparse Transformer block (MSA + FFN). 
51 | """ 52 | def __init__( 53 | self, 54 | channels: int, 55 | num_heads: int, 56 | mlp_ratio: float = 4.0, 57 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 58 | window_size: Optional[int] = None, 59 | shift_sequence: Optional[int] = None, 60 | shift_window: Optional[Tuple[int, int, int]] = None, 61 | serialize_mode: Optional[SerializeMode] = None, 62 | use_checkpoint: bool = False, 63 | use_rope: bool = False, 64 | qk_rms_norm: bool = False, 65 | qkv_bias: bool = True, 66 | ln_affine: bool = False, 67 | ): 68 | super().__init__() 69 | self.use_checkpoint = use_checkpoint 70 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 71 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 72 | self.attn = SparseMultiHeadAttention( 73 | channels, 74 | num_heads=num_heads, 75 | attn_mode=attn_mode, 76 | window_size=window_size, 77 | shift_sequence=shift_sequence, 78 | shift_window=shift_window, 79 | serialize_mode=serialize_mode, 80 | qkv_bias=qkv_bias, 81 | use_rope=use_rope, 82 | qk_rms_norm=qk_rms_norm, 83 | ) 84 | self.mlp = SparseFeedForwardNet( 85 | channels, 86 | mlp_ratio=mlp_ratio, 87 | ) 88 | 89 | def _forward(self, x: SparseTensor) -> SparseTensor: 90 | h = x.replace(self.norm1(x.feats)) 91 | h = self.attn(h) 92 | x = x + h 93 | h = x.replace(self.norm2(x.feats)) 94 | h = self.mlp(h) 95 | x = x + h 96 | return x 97 | 98 | def forward(self, x: SparseTensor) -> SparseTensor: 99 | if self.use_checkpoint: 100 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) 101 | else: 102 | return self._forward(x) 103 | 104 | 105 | class SparseTransformerCrossBlock(nn.Module): 106 | """ 107 | Sparse Transformer cross-attention block (MSA + MCA + FFN). 
108 | """ 109 | def __init__( 110 | self, 111 | channels: int, 112 | ctx_channels: int, 113 | num_heads: int, 114 | mlp_ratio: float = 4.0, 115 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 116 | window_size: Optional[int] = None, 117 | shift_sequence: Optional[int] = None, 118 | shift_window: Optional[Tuple[int, int, int]] = None, 119 | serialize_mode: Optional[SerializeMode] = None, 120 | use_checkpoint: bool = False, 121 | use_rope: bool = False, 122 | qk_rms_norm: bool = False, 123 | qk_rms_norm_cross: bool = False, 124 | qkv_bias: bool = True, 125 | ln_affine: bool = False, 126 | ): 127 | super().__init__() 128 | self.use_checkpoint = use_checkpoint 129 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 130 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 131 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 132 | self.self_attn = SparseMultiHeadAttention( 133 | channels, 134 | num_heads=num_heads, 135 | type="self", 136 | attn_mode=attn_mode, 137 | window_size=window_size, 138 | shift_sequence=shift_sequence, 139 | shift_window=shift_window, 140 | serialize_mode=serialize_mode, 141 | qkv_bias=qkv_bias, 142 | use_rope=use_rope, 143 | qk_rms_norm=qk_rms_norm, 144 | ) 145 | self.cross_attn = SparseMultiHeadAttention( 146 | channels, 147 | ctx_channels=ctx_channels, 148 | num_heads=num_heads, 149 | type="cross", 150 | attn_mode="full", 151 | qkv_bias=qkv_bias, 152 | qk_rms_norm=qk_rms_norm_cross, 153 | ) 154 | self.mlp = SparseFeedForwardNet( 155 | channels, 156 | mlp_ratio=mlp_ratio, 157 | ) 158 | 159 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): 160 | h = x.replace(self.norm1(x.feats)) 161 | h = self.self_attn(h) 162 | x = x + h 163 | h = x.replace(self.norm2(x.feats)) 164 | h = self.cross_attn(h, context) 165 | x = x + h 166 | h = x.replace(self.norm3(x.feats)) 167 | h = self.mlp(h) 168 | x = x + h 169 | return x 170 | 171 | def forward(self, x: SparseTensor, context: torch.Tensor): 172 | if self.use_checkpoint: 173 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) 174 | else: 175 | return self._forward(x, context) 176 | -------------------------------------------------------------------------------- /triposf/modules/transformer/modulated.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | import torch.utils.checkpoint 28 | from ..attention import MultiHeadAttention 29 | from ..norm import LayerNorm32 30 | from .blocks import FeedForwardNet 31 | 32 | 33 | class ModulatedTransformerBlock(nn.Module): 34 | """ 35 | Transformer block (MSA + FFN) with adaptive layer norm conditioning. 36 | """ 37 | def __init__( 38 | self, 39 | channels: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | attn_mode: Literal["full", "windowed"] = "full", 43 | window_size: Optional[int] = None, 44 | shift_window: Optional[Tuple[int, int, int]] = None, 45 | use_checkpoint: bool = False, 46 | use_rope: bool = False, 47 | qk_rms_norm: bool = False, 48 | qkv_bias: bool = True, 49 | share_mod: bool = False, 50 | ): 51 | super().__init__() 52 | self.use_checkpoint = use_checkpoint 53 | self.share_mod = share_mod 54 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 55 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 56 | self.attn = MultiHeadAttention( 57 | channels, 58 | num_heads=num_heads, 59 | attn_mode=attn_mode, 60 | window_size=window_size, 61 | shift_window=shift_window, 62 | qkv_bias=qkv_bias, 63 | use_rope=use_rope, 64 | qk_rms_norm=qk_rms_norm, 65 | ) 66 | self.mlp = FeedForwardNet( 67 | channels, 68 | mlp_ratio=mlp_ratio, 69 | ) 70 | if not share_mod: 71 | self.adaLN_modulation = nn.Sequential( 72 | nn.SiLU(), 73 | nn.Linear(channels, 6 * channels, bias=True) 74 | ) 75 | 76 | def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: 77 | if self.share_mod: 78 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 79 | else: 80 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 81 | h = self.norm1(x) 82 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) 83 | h = self.attn(h) 84 | h = h * gate_msa.unsqueeze(1) 85 | x = x + h 86 | h = self.norm2(x) 87 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) 88 | h = self.mlp(h) 89 | h = h * gate_mlp.unsqueeze(1) 90 | x = x + h 91 | return x 92 | 93 | def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: 94 | if self.use_checkpoint: 95 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) 96 | else: 97 | return self._forward(x, mod) 98 | 99 | 100 | class ModulatedTransformerCrossBlock(nn.Module): 101 | """ 102 | Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
103 | """ 104 | def __init__( 105 | self, 106 | channels: int, 107 | ctx_channels: int, 108 | num_heads: int, 109 | mlp_ratio: float = 4.0, 110 | attn_mode: Literal["full", "windowed"] = "full", 111 | window_size: Optional[int] = None, 112 | shift_window: Optional[Tuple[int, int, int]] = None, 113 | use_checkpoint: bool = False, 114 | use_rope: bool = False, 115 | qk_rms_norm: bool = False, 116 | qk_rms_norm_cross: bool = False, 117 | qkv_bias: bool = True, 118 | share_mod: bool = False, 119 | ): 120 | super().__init__() 121 | self.use_checkpoint = use_checkpoint 122 | self.share_mod = share_mod 123 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 124 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) 125 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 126 | self.self_attn = MultiHeadAttention( 127 | channels, 128 | num_heads=num_heads, 129 | type="self", 130 | attn_mode=attn_mode, 131 | window_size=window_size, 132 | shift_window=shift_window, 133 | qkv_bias=qkv_bias, 134 | use_rope=use_rope, 135 | qk_rms_norm=qk_rms_norm, 136 | ) 137 | self.cross_attn = MultiHeadAttention( 138 | channels, 139 | ctx_channels=ctx_channels, 140 | num_heads=num_heads, 141 | type="cross", 142 | attn_mode="full", 143 | qkv_bias=qkv_bias, 144 | qk_rms_norm=qk_rms_norm_cross, 145 | ) 146 | self.mlp = FeedForwardNet( 147 | channels, 148 | mlp_ratio=mlp_ratio, 149 | ) 150 | if not share_mod: 151 | self.adaLN_modulation = nn.Sequential( 152 | nn.SiLU(), 153 | nn.Linear(channels, 6 * channels, bias=True) 154 | ) 155 | 156 | def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): 157 | if self.share_mod: 158 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 159 | else: 160 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 161 | h = self.norm1(x) 162 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) 163 | h = self.self_attn(h) 164 | h = h * gate_msa.unsqueeze(1) 165 | x = x + h 166 | h = self.norm2(x) 167 | h = self.cross_attn(h, context) 168 | x = x + h 169 | h = self.norm3(x) 170 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) 171 | h = self.mlp(h) 172 | h = h * gate_mlp.unsqueeze(1) 173 | x = x + h 174 | return x 175 | 176 | def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): 177 | if self.use_checkpoint: 178 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) 179 | else: 180 | return self._forward(x, mod, context) 181 | -------------------------------------------------------------------------------- /triposf/representations/mesh/cube2mesh.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | import torch 25 | from ...modules.sparse import SparseTensor 26 | from easydict import EasyDict as edict 27 | from .utils_cube import * 28 | from .flexicubes.flexicubes import FlexiCubes 29 | 30 | class MeshExtractResult: 31 | def __init__(self, 32 | vertices, 33 | faces, 34 | vertex_attrs=None, 35 | res=64 36 | ): 37 | self.vertices = vertices 38 | self.faces = faces.long() 39 | self.vertex_attrs = vertex_attrs 40 | self.face_normal = self.compute_face_normals(vertices, faces) 41 | self.res = res 42 | self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) 43 | 44 | # training only 45 | self.tsdf_v = None 46 | self.tsdf_s = None 47 | 48 | def compute_face_normals(self, verts, faces): 49 | i0 = faces[..., 0].long() 50 | i1 = faces[..., 1].long() 51 | i2 = faces[..., 2].long() 52 | 53 | v0 = verts[i0, :] 54 | v1 = verts[i1, :] 55 | v2 = verts[i2, :] 56 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) 57 | face_normals = torch.nn.functional.normalize(face_normals, dim=1) 58 | 59 | return face_normals[:, None, :].repeat(1, 3, 1) 60 | 61 | def comput_v_normals(self, verts, faces): 62 | i0 = faces[..., 0].long() 63 | i1 = faces[..., 1].long() 64 | i2 = faces[..., 2].long() 65 | 66 | v0 = verts[i0, :] 67 | v1 = verts[i1, :] 68 | v2 = verts[i2, :] 69 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) 70 | v_normals = torch.zeros_like(verts) 71 | v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) 72 | v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) 73 | v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) 74 | 75 | v_normals = torch.nn.functional.normalize(v_normals, dim=1) 76 | return v_normals 77 | 78 | 79 | class SparseFeatures2Mesh: 80 | def __init__(self, device="cuda", res=64, use_color=False, use_sparse_flexicube=True, use_sparse_sparse_flexicube=False): 81 | ''' 82 | a model to generate a mesh from sparse features structures using flexicube 83 | ''' 84 | super().__init__() 85 | self.device=device 86 | self.res = res 87 | self.mesh_extractor = FlexiCubes(device=device) 88 | self.sdf_bias = -1.0 / res 89 | 90 | self.use_sparse_flexicube = use_sparse_flexicube 91 | self.use_sparse_sparse_flexicube = use_sparse_sparse_flexicube 92 | 93 | self.use_color = use_color 94 | self._calc_layout() 95 | 96 | def _calc_layout(self): 97 | LAYOUTS = { 98 | 'sdf': {'shape': (8, 1), 'size': 8}, 99 | 'deform': 
{'shape': (8, 3), 'size': 8 * 3}, 100 | 'weights': {'shape': (21,), 'size': 21} 101 | } 102 | if self.use_color: 103 | ''' 104 | 6 channel color including normal map 105 | ''' 106 | LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} 107 | self.layouts = edict(LAYOUTS) 108 | start = 0 109 | for k, v in self.layouts.items(): 110 | v['range'] = (start, start + v['size']) 111 | start += v['size'] 112 | self.feats_channels = start 113 | 114 | def get_layout(self, feats : torch.Tensor, name : str): 115 | if name not in self.layouts: 116 | return None 117 | return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) 118 | 119 | def __call__(self, cubefeats : SparseTensor, training=False): 120 | """ 121 | Generates a mesh based on the specified sparse voxel structures. 122 | Args: 123 | cubefeats (SparseTensor): sparse cube features whose channels pack per-corner SDF (8), per-corner deformation (8x3) and FlexiCubes weights (21) 124 | training (bool): whether the extraction is run in training mode 125 | Returns: 126 | MeshExtractResult: the extracted mesh (its success flag is set when vertices and faces are non-empty) 127 | """ 128 | assert self.use_color == False 129 | 130 | coords = cubefeats.coords[:, 1:] 131 | feats = cubefeats.feats 132 | 133 | sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] 134 | sdf = sdf * (4. / self.res) 135 | sdf += self.sdf_bias 136 | 137 | v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] 138 | v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) 139 | 140 | res_v = self.res + 1 141 | v_attrs_d_sparse, v_pos_dilate = get_sparse_attrs(v_pos, v_attrs, res=res_v, sdf_init=True) 142 | weights_d_sparse, coords_dilate = get_sparse_attrs(coords, weights, res=self.res, sdf_init=False) 143 | 144 | sdf_d, deform_d = v_attrs_d_sparse[..., 0], v_attrs_d_sparse[..., 1:4] 145 | 146 | x_nx3 = get_defomed_verts(v_pos_dilate, deform_d, self.res) 147 | x_nx3 = torch.cat((x_nx3, torch.ones((1, 3), dtype=x_nx3.dtype, device=x_nx3.device) * 0.5)) 148 | sdf_d = torch.cat((sdf_d, torch.ones((1), dtype=sdf_d.dtype, device=sdf_d.device))) 149 | 150 | mask_reg_c_sparse = (v_pos_dilate[..., 0] * res_v + v_pos_dilate[..., 1]) * res_v + v_pos_dilate[..., 2] 151 | reg_c_sparse = (coords_dilate[..., 0] * res_v + coords_dilate[..., 1]) * res_v + coords_dilate[..., 2] 152 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] 153 | reg_c_value = (reg_c_sparse.unsqueeze(1) + cube_corners_bias.unsqueeze(0).cuda()).reshape(-1) 154 | reg_c = torch.searchsorted(mask_reg_c_sparse, reg_c_value) 155 | exact_match_mask = mask_reg_c_sparse[reg_c] == reg_c_value 156 | reg_c[exact_match_mask == 0] = len(mask_reg_c_sparse) 157 | reg_c = reg_c.reshape(-1, 8) 158 | 159 | vertices, faces, L_dev, colors = self.mesh_extractor( 160 | voxelgrid_vertices=x_nx3, 161 | scalar_field=sdf_d, 162 | cube_idx=reg_c, 163 | resolution=self.res, 164 | beta=weights_d_sparse[:, :12], 165 | alpha=weights_d_sparse[:, 12:20], 166 | gamma_f=weights_d_sparse[:, 20], 167 | cube_index_map=coords_dilate, 168 | training=training) 169 | 170 | mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) 171 | 172 | return mesh -------------------------------------------------------------------------------- /triposf/modules/sparse/attention/windowed_attn.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import math 27 | from .. import SparseTensor 28 | from .. import DEBUG, ATTN 29 | 30 | if ATTN == 'xformers': 31 | import xformers.ops as xops 32 | elif ATTN == 'flash_attn': 33 | import flash_attn 34 | else: 35 | raise ValueError(f"Unknown attention module: {ATTN}") 36 | 37 | 38 | __all__ = [ 39 | 'sparse_windowed_scaled_dot_product_self_attention', 40 | ] 41 | 42 | 43 | def calc_window_partition( 44 | tensor: SparseTensor, 45 | window_size: Union[int, Tuple[int, ...]], 46 | shift_window: Union[int, Tuple[int, ...]] = 0 47 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: 48 | """ 49 | Calculate serialization and partitioning for a set of coordinates. 50 | 51 | Args: 52 | tensor (SparseTensor): The input tensor. 53 | window_size (int): The window size to use. 54 | shift_window (Tuple[int, ...]): The shift of serialized coordinates. 55 | 56 | Returns: 57 | (torch.Tensor): Forwards indices. 58 | (torch.Tensor): Backwards indices. 59 | (List[int]): Sequence lengths. 60 | (List[int]): Sequence batch indices. 
61 | """ 62 | DIM = tensor.coords.shape[1] - 1 63 | shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window 64 | window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size 65 | shifted_coords = tensor.coords.clone().detach() 66 | shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) 67 | 68 | MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() 69 | NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] 70 | OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] 71 | 72 | shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) 73 | shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) 74 | fwd_indices = torch.argsort(shifted_indices) 75 | bwd_indices = torch.empty_like(fwd_indices) 76 | bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) 77 | seq_lens = torch.bincount(shifted_indices) 78 | seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] 79 | mask = seq_lens != 0 80 | seq_lens = seq_lens[mask].tolist() 81 | seq_batch_indices = seq_batch_indices[mask].tolist() 82 | 83 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices 84 | 85 | 86 | def sparse_windowed_scaled_dot_product_self_attention( 87 | qkv: SparseTensor, 88 | window_size: int, 89 | shift_window: Tuple[int, int, int] = (0, 0, 0) 90 | ) -> SparseTensor: 91 | """ 92 | Apply windowed scaled dot product self attention to a sparse tensor. 93 | 94 | Args: 95 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. 96 | window_size (int): The window size to use. 97 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates. 98 | Returns: a SparseTensor with the attention output for each point, matching the layout of `qkv`.
99 | """ 100 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" 101 | 102 | serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' 103 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) 104 | if serialization_spatial_cache is None: 105 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) 106 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) 107 | else: 108 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache 109 | 110 | M = fwd_indices.shape[0] 111 | T = qkv.feats.shape[0] 112 | H = qkv.feats.shape[2] 113 | C = qkv.feats.shape[3] 114 | 115 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] 116 | 117 | if DEBUG: 118 | start = 0 119 | qkv_coords = qkv.coords[fwd_indices] 120 | for i in range(len(seq_lens)): 121 | seq_coords = qkv_coords[start:start+seq_lens[i]] 122 | assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" 123 | assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ 124 | f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" 125 | start += seq_lens[i] 126 | 127 | if all([seq_len == window_size for seq_len in seq_lens]): 128 | B = len(seq_lens) 129 | N = window_size 130 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C) 131 | if ATTN == 'xformers': 132 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] 133 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] 134 | elif ATTN == 'flash_attn': 135 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] 136 | else: 137 | raise ValueError(f"Unknown attention module: {ATTN}") 138 | out = out.reshape(B * N, H, C) # [M, H, C] 139 | else: 140 | if ATTN == 'xformers': 141 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] 142 | q = q.unsqueeze(0) # [1, M, H, C] 143 | k = k.unsqueeze(0) # [1, M, H, C] 144 | v = v.unsqueeze(0) # [1, M, H, C] 145 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) 146 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] 147 | elif ATTN == 'flash_attn': 148 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ 149 | .to(qkv.device).int() 150 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] 151 | 152 | out = out[bwd_indices] # [T, H, C] 153 | 154 | if DEBUG: 155 | qkv_coords = qkv_coords[bwd_indices] 156 | assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" 157 | 158 | return qkv.replace(out) 159 | -------------------------------------------------------------------------------- /triposf/modules/sparse/attention/modules.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | from .. import SparseTensor 29 | from .full_attn import sparse_scaled_dot_product_attention 30 | from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention 31 | from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention 32 | from ...attention import RotaryPositionEmbedder 33 | 34 | 35 | class SparseMultiHeadRMSNorm(nn.Module): 36 | def __init__(self, dim: int, heads: int): 37 | super().__init__() 38 | self.scale = dim ** 0.5 39 | self.gamma = nn.Parameter(torch.ones(heads, dim)) 40 | 41 | def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: 42 | x_type = x.dtype 43 | x = x.float() 44 | if isinstance(x, SparseTensor): 45 | x = x.replace(F.normalize(x.feats, dim=-1)) 46 | else: 47 | x = F.normalize(x, dim=-1) 48 | return (x * self.gamma * self.scale).to(x_type) 49 | 50 | 51 | class SparseMultiHeadAttention(nn.Module): 52 | def __init__( 53 | self, 54 | channels: int, 55 | num_heads: int, 56 | ctx_channels: Optional[int] = None, 57 | type: Literal["self", "cross"] = "self", 58 | attn_mode: Literal["full", "serialized", "windowed"] = "full", 59 | window_size: Optional[int] = None, 60 | shift_sequence: Optional[int] = None, 61 | shift_window: Optional[Tuple[int, int, int]] = None, 62 | serialize_mode: Optional[SerializeMode] = None, 63 | qkv_bias: bool = True, 64 | use_rope: bool = False, 65 | qk_rms_norm: bool = False, 66 | ): 67 | super().__init__() 68 | assert channels % num_heads == 0 69 | assert type in ["self", "cross"], f"Invalid attention type: {type}" 70 | assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" 71 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" 72 | assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" 73 | self.channels = channels 74 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels 75 | self.num_heads = num_heads 76 | self._type = type 77 | self.attn_mode = attn_mode 78 | self.window_size = window_size 79 | self.shift_sequence = shift_sequence 80 | self.shift_window = shift_window 81 | self.serialize_mode = serialize_mode 82 | self.use_rope = use_rope 83 | self.qk_rms_norm = 
qk_rms_norm 84 | 85 | if self._type == "self": 86 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) 87 | else: 88 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias) 89 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) 90 | 91 | if self.qk_rms_norm: 92 | self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) 93 | self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) 94 | 95 | self.to_out = nn.Linear(channels, channels) 96 | 97 | if use_rope: 98 | self.rope = RotaryPositionEmbedder(channels) 99 | 100 | @staticmethod 101 | def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: 102 | if isinstance(x, SparseTensor): 103 | return x.replace(module(x.feats)) 104 | else: 105 | return module(x) 106 | 107 | @staticmethod 108 | def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: 109 | if isinstance(x, SparseTensor): 110 | return x.reshape(*shape) 111 | else: 112 | return x.reshape(*x.shape[:2], *shape) 113 | 114 | def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: 115 | if isinstance(x, SparseTensor): 116 | x_feats = x.feats.unsqueeze(0) 117 | else: 118 | x_feats = x 119 | x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) 120 | return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats 121 | 122 | def _rope(self, qkv: SparseTensor) -> SparseTensor: 123 | q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] 124 | q, k = self.rope(q, k, qkv.coords[:, 1:]) 125 | qkv = qkv.replace(torch.stack([q, k, v], dim=1)) 126 | return qkv 127 | 128 | def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: 129 | if self._type == "self": 130 | qkv = self._linear(self.to_qkv, x) 131 | qkv = self._fused_pre(qkv, num_fused=3) 132 | if self.use_rope: 133 | qkv = self._rope(qkv) 134 | if self.qk_rms_norm: 135 | q, k, v = qkv.unbind(dim=1) 136 | q = self.q_rms_norm(q) 137 | k = self.k_rms_norm(k) 138 | qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) 139 | if self.attn_mode == "full": 140 | h = sparse_scaled_dot_product_attention(qkv) 141 | elif self.attn_mode == "serialized": 142 | h = sparse_serialized_scaled_dot_product_self_attention( 143 | qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window 144 | ) 145 | elif self.attn_mode == "windowed": 146 | h = sparse_windowed_scaled_dot_product_self_attention( 147 | qkv, self.window_size, shift_window=self.shift_window 148 | ) 149 | else: 150 | q = self._linear(self.to_q, x) 151 | q = self._reshape_chs(q, (self.num_heads, -1)) 152 | kv = self._linear(self.to_kv, context) 153 | kv = self._fused_pre(kv, num_fused=2) 154 | if self.qk_rms_norm: 155 | q = self.q_rms_norm(q) 156 | k, v = kv.unbind(dim=1) 157 | k = self.k_rms_norm(k) 158 | kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) 159 | h = sparse_scaled_dot_product_attention(q, kv) 160 | h = self._reshape_chs(h, (-1,)) 161 | h = self._linear(self.to_out, h) 162 | return h 163 | -------------------------------------------------------------------------------- /triposf/modules/transformer/blocks.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft 
Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | import torch.utils.checkpoint 28 | from ..attention import MultiHeadAttention 29 | from ..norm import LayerNorm32 30 | 31 | 32 | class AbsolutePositionEmbedder(nn.Module): 33 | """ 34 | Embeds spatial positions into vector representations. 35 | """ 36 | def __init__(self, channels: int, in_channels: int = 3): 37 | super().__init__() 38 | self.channels = channels 39 | self.in_channels = in_channels 40 | self.freq_dim = channels // in_channels // 2 41 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim 42 | self.freqs = 1.0 / (10000 ** self.freqs) 43 | 44 | def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Create sinusoidal position embeddings. 47 | 48 | Args: 49 | x: a 1-D Tensor of N indices 50 | 51 | Returns: 52 | an (N, D) Tensor of positional embeddings. 53 | """ 54 | self.freqs = self.freqs.to(x.device) 55 | out = torch.outer(x, self.freqs) 56 | out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) 57 | return out 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | """ 61 | Args: 62 | x (torch.Tensor): (N, D) tensor of spatial positions 63 | """ 64 | N, D = x.shape 65 | assert D == self.in_channels, "Input dimension must match number of input channels" 66 | embed = self._sin_cos_embedding(x.reshape(-1)) 67 | embed = embed.reshape(N, -1) 68 | if embed.shape[1] < self.channels: 69 | embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) 70 | return embed 71 | 72 | 73 | class FeedForwardNet(nn.Module): 74 | def __init__(self, channels: int, mlp_ratio: float = 4.0): 75 | super().__init__() 76 | self.mlp = nn.Sequential( 77 | nn.Linear(channels, int(channels * mlp_ratio)), 78 | nn.GELU(approximate="tanh"), 79 | nn.Linear(int(channels * mlp_ratio), channels), 80 | ) 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | return self.mlp(x) 84 | 85 | 86 | class TransformerBlock(nn.Module): 87 | """ 88 | Transformer block (MSA + FFN). 
89 | """ 90 | def __init__( 91 | self, 92 | channels: int, 93 | num_heads: int, 94 | mlp_ratio: float = 4.0, 95 | attn_mode: Literal["full", "windowed"] = "full", 96 | window_size: Optional[int] = None, 97 | shift_window: Optional[int] = None, 98 | use_checkpoint: bool = False, 99 | use_rope: bool = False, 100 | qk_rms_norm: bool = False, 101 | qkv_bias: bool = True, 102 | ln_affine: bool = False, 103 | ): 104 | super().__init__() 105 | self.use_checkpoint = use_checkpoint 106 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 107 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 108 | self.attn = MultiHeadAttention( 109 | channels, 110 | num_heads=num_heads, 111 | attn_mode=attn_mode, 112 | window_size=window_size, 113 | shift_window=shift_window, 114 | qkv_bias=qkv_bias, 115 | use_rope=use_rope, 116 | qk_rms_norm=qk_rms_norm, 117 | ) 118 | self.mlp = FeedForwardNet( 119 | channels, 120 | mlp_ratio=mlp_ratio, 121 | ) 122 | 123 | def _forward(self, x: torch.Tensor) -> torch.Tensor: 124 | h = self.norm1(x) 125 | h = self.attn(h) 126 | x = x + h 127 | h = self.norm2(x) 128 | h = self.mlp(h) 129 | x = x + h 130 | return x 131 | 132 | def forward(self, x: torch.Tensor) -> torch.Tensor: 133 | if self.use_checkpoint: 134 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) 135 | else: 136 | return self._forward(x) 137 | 138 | 139 | class TransformerCrossBlock(nn.Module): 140 | """ 141 | Transformer cross-attention block (MSA + MCA + FFN). 142 | """ 143 | def __init__( 144 | self, 145 | channels: int, 146 | ctx_channels: int, 147 | num_heads: int, 148 | mlp_ratio: float = 4.0, 149 | attn_mode: Literal["full", "windowed"] = "full", 150 | window_size: Optional[int] = None, 151 | shift_window: Optional[Tuple[int, int, int]] = None, 152 | use_checkpoint: bool = False, 153 | use_rope: bool = False, 154 | qk_rms_norm: bool = False, 155 | qk_rms_norm_cross: bool = False, 156 | qkv_bias: bool = True, 157 | ln_affine: bool = False, 158 | ): 159 | super().__init__() 160 | self.use_checkpoint = use_checkpoint 161 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 162 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 163 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 164 | self.self_attn = MultiHeadAttention( 165 | channels, 166 | num_heads=num_heads, 167 | type="self", 168 | attn_mode=attn_mode, 169 | window_size=window_size, 170 | shift_window=shift_window, 171 | qkv_bias=qkv_bias, 172 | use_rope=use_rope, 173 | qk_rms_norm=qk_rms_norm, 174 | ) 175 | self.cross_attn = MultiHeadAttention( 176 | channels, 177 | ctx_channels=ctx_channels, 178 | num_heads=num_heads, 179 | type="cross", 180 | attn_mode="full", 181 | qkv_bias=qkv_bias, 182 | qk_rms_norm=qk_rms_norm_cross, 183 | ) 184 | self.mlp = FeedForwardNet( 185 | channels, 186 | mlp_ratio=mlp_ratio, 187 | ) 188 | 189 | def _forward(self, x: torch.Tensor, context: torch.Tensor): 190 | h = self.norm1(x) 191 | h = self.self_attn(h) 192 | x = x + h 193 | h = self.norm2(x) 194 | h = self.cross_attn(h, context) 195 | x = x + h 196 | h = self.norm3(x) 197 | h = self.mlp(h) 198 | x = x + h 199 | return x 200 | 201 | def forward(self, x: torch.Tensor, context: torch.Tensor): 202 | if self.use_checkpoint: 203 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) 204 | else: 205 | return self._forward(x, context) 206 | 
-------------------------------------------------------------------------------- /triposf/modules/attention/modules.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | from .full_attn import scaled_dot_product_attention 29 | 30 | 31 | class MultiHeadRMSNorm(nn.Module): 32 | def __init__(self, dim: int, heads: int): 33 | super().__init__() 34 | self.scale = dim ** 0.5 35 | self.gamma = nn.Parameter(torch.ones(heads, dim)) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) 39 | 40 | 41 | class RotaryPositionEmbedder(nn.Module): 42 | def __init__(self, hidden_size: int, in_channels: int = 3): 43 | super().__init__() 44 | assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" 45 | self.hidden_size = hidden_size 46 | self.in_channels = in_channels 47 | self.freq_dim = hidden_size // in_channels // 2 48 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim 49 | self.freqs = 1.0 / (10000 ** self.freqs) 50 | 51 | def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: 52 | self.freqs = self.freqs.to(indices.device) 53 | phases = torch.outer(indices, self.freqs) 54 | phases = torch.polar(torch.ones_like(phases), phases) 55 | return phases 56 | 57 | def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: 58 | x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) 59 | x_rotated = x_complex * phases 60 | x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) 61 | return x_embed 62 | 63 | def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: 64 | """ 65 | Args: 66 | q (sp.SparseTensor): [..., N, D] tensor of queries 67 | k (sp.SparseTensor): [..., N, D] tensor of keys 68 | indices (torch.Tensor): [..., N, C] tensor of spatial positions 69 | """ 70 | if indices is None: 71 | indices = torch.arange(q.shape[-2], device=q.device) 72 | if len(q.shape) > 2: 73 | indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) 74 | 75 | 
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) 76 | if phases.shape[1] < self.hidden_size // 2: 77 | phases = torch.cat([phases, torch.polar( 78 | torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), 79 | torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) 80 | )], dim=-1) 81 | q_embed = self._rotary_embedding(q, phases) 82 | k_embed = self._rotary_embedding(k, phases) 83 | return q_embed, k_embed 84 | 85 | 86 | class MultiHeadAttention(nn.Module): 87 | def __init__( 88 | self, 89 | channels: int, 90 | num_heads: int, 91 | ctx_channels: Optional[int]=None, 92 | type: Literal["self", "cross"] = "self", 93 | attn_mode: Literal["full", "windowed"] = "full", 94 | window_size: Optional[int] = None, 95 | shift_window: Optional[Tuple[int, int, int]] = None, 96 | qkv_bias: bool = True, 97 | use_rope: bool = False, 98 | qk_rms_norm: bool = False, 99 | ): 100 | super().__init__() 101 | assert channels % num_heads == 0 102 | assert type in ["self", "cross"], f"Invalid attention type: {type}" 103 | assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" 104 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" 105 | 106 | if attn_mode == "windowed": 107 | raise NotImplementedError("Windowed attention is not yet implemented") 108 | 109 | self.channels = channels 110 | self.head_dim = channels // num_heads 111 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels 112 | self.num_heads = num_heads 113 | self._type = type 114 | self.attn_mode = attn_mode 115 | self.window_size = window_size 116 | self.shift_window = shift_window 117 | self.use_rope = use_rope 118 | self.qk_rms_norm = qk_rms_norm 119 | 120 | if self._type == "self": 121 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) 122 | else: 123 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias) 124 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) 125 | 126 | if self.qk_rms_norm: 127 | self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) 128 | self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) 129 | 130 | self.to_out = nn.Linear(channels, channels) 131 | 132 | if use_rope: 133 | self.rope = RotaryPositionEmbedder(channels) 134 | 135 | def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: 136 | B, L, C = x.shape 137 | if self._type == "self": 138 | qkv = self.to_qkv(x) 139 | qkv = qkv.reshape(B, L, 3, self.num_heads, -1) 140 | if self.use_rope: 141 | q, k, v = qkv.unbind(dim=2) 142 | q, k = self.rope(q, k, indices) 143 | qkv = torch.stack([q, k, v], dim=2) 144 | if self.attn_mode == "full": 145 | if self.qk_rms_norm: 146 | q, k, v = qkv.unbind(dim=2) 147 | q = self.q_rms_norm(q) 148 | k = self.k_rms_norm(k) 149 | h = scaled_dot_product_attention(q, k, v) 150 | else: 151 | h = scaled_dot_product_attention(qkv) 152 | elif self.attn_mode == "windowed": 153 | raise NotImplementedError("Windowed attention is not yet implemented") 154 | else: 155 | Lkv = context.shape[1] 156 | q = self.to_q(x) 157 | kv = self.to_kv(context) 158 | q = q.reshape(B, L, self.num_heads, -1) 159 | kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) 160 | if self.qk_rms_norm: 161 | q = self.q_rms_norm(q) 162 | k, v = kv.unbind(dim=2) 163 | k = self.k_rms_norm(k) 164 | h = scaled_dot_product_attention(q, k, v) 165 | else: 166 | h = 
scaled_dot_product_attention(q, kv) 167 | h = h.reshape(B, L, -1) 168 | h = self.to_out(h) 169 | return h 170 | -------------------------------------------------------------------------------- /triposf/modules/sparse/transformer/modulated.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | import torch.utils.checkpoint 28 | from ..basic import SparseTensor 29 | from ..attention import SparseMultiHeadAttention, SerializeMode 30 | from ...norm import LayerNorm32 31 | from .blocks import SparseFeedForwardNet 32 | 33 | 34 | class ModulatedSparseTransformerBlock(nn.Module): 35 | """ 36 | Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. 
37 | """ 38 | def __init__( 39 | self, 40 | channels: int, 41 | num_heads: int, 42 | mlp_ratio: float = 4.0, 43 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 44 | window_size: Optional[int] = None, 45 | shift_sequence: Optional[int] = None, 46 | shift_window: Optional[Tuple[int, int, int]] = None, 47 | serialize_mode: Optional[SerializeMode] = None, 48 | use_checkpoint: bool = False, 49 | use_rope: bool = False, 50 | qk_rms_norm: bool = False, 51 | qkv_bias: bool = True, 52 | share_mod: bool = False, 53 | ): 54 | super().__init__() 55 | self.use_checkpoint = use_checkpoint 56 | self.share_mod = share_mod 57 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 58 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 59 | self.attn = SparseMultiHeadAttention( 60 | channels, 61 | num_heads=num_heads, 62 | attn_mode=attn_mode, 63 | window_size=window_size, 64 | shift_sequence=shift_sequence, 65 | shift_window=shift_window, 66 | serialize_mode=serialize_mode, 67 | qkv_bias=qkv_bias, 68 | use_rope=use_rope, 69 | qk_rms_norm=qk_rms_norm, 70 | ) 71 | self.mlp = SparseFeedForwardNet( 72 | channels, 73 | mlp_ratio=mlp_ratio, 74 | ) 75 | if not share_mod: 76 | self.adaLN_modulation = nn.Sequential( 77 | nn.SiLU(), 78 | nn.Linear(channels, 6 * channels, bias=True) 79 | ) 80 | 81 | def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: 82 | if self.share_mod: 83 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 84 | else: 85 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 86 | h = x.replace(self.norm1(x.feats)) 87 | h = h * (1 + scale_msa) + shift_msa 88 | h = self.attn(h) 89 | h = h * gate_msa 90 | x = x + h 91 | h = x.replace(self.norm2(x.feats)) 92 | h = h * (1 + scale_mlp) + shift_mlp 93 | h = self.mlp(h) 94 | h = h * gate_mlp 95 | x = x + h 96 | return x 97 | 98 | def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: 99 | if self.use_checkpoint: 100 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) 101 | else: 102 | return self._forward(x, mod) 103 | 104 | 105 | class ModulatedSparseTransformerCrossBlock(nn.Module): 106 | """ 107 | Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
108 | """ 109 | def __init__( 110 | self, 111 | channels: int, 112 | ctx_channels: int, 113 | num_heads: int, 114 | mlp_ratio: float = 4.0, 115 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 116 | window_size: Optional[int] = None, 117 | shift_sequence: Optional[int] = None, 118 | shift_window: Optional[Tuple[int, int, int]] = None, 119 | serialize_mode: Optional[SerializeMode] = None, 120 | use_checkpoint: bool = False, 121 | use_rope: bool = False, 122 | qk_rms_norm: bool = False, 123 | qk_rms_norm_cross: bool = False, 124 | qkv_bias: bool = True, 125 | share_mod: bool = False, 126 | 127 | ): 128 | super().__init__() 129 | self.use_checkpoint = use_checkpoint 130 | self.share_mod = share_mod 131 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 132 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) 133 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 134 | self.self_attn = SparseMultiHeadAttention( 135 | channels, 136 | num_heads=num_heads, 137 | type="self", 138 | attn_mode=attn_mode, 139 | window_size=window_size, 140 | shift_sequence=shift_sequence, 141 | shift_window=shift_window, 142 | serialize_mode=serialize_mode, 143 | qkv_bias=qkv_bias, 144 | use_rope=use_rope, 145 | qk_rms_norm=qk_rms_norm, 146 | ) 147 | self.cross_attn = SparseMultiHeadAttention( 148 | channels, 149 | ctx_channels=ctx_channels, 150 | num_heads=num_heads, 151 | type="cross", 152 | attn_mode="full", 153 | qkv_bias=qkv_bias, 154 | qk_rms_norm=qk_rms_norm_cross, 155 | ) 156 | self.mlp = SparseFeedForwardNet( 157 | channels, 158 | mlp_ratio=mlp_ratio, 159 | ) 160 | if not share_mod: 161 | self.adaLN_modulation = nn.Sequential( 162 | nn.SiLU(), 163 | nn.Linear(channels, 6 * channels, bias=True) 164 | ) 165 | 166 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: 167 | if self.share_mod: 168 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 169 | else: 170 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 171 | h = x.replace(self.norm1(x.feats)) 172 | h = h * (1 + scale_msa) + shift_msa 173 | h = self.self_attn(h) 174 | h = h * gate_msa 175 | x = x + h 176 | h = x.replace(self.norm2(x.feats)) 177 | h = self.cross_attn(h, context) 178 | x = x + h 179 | h = x.replace(self.norm3(x.feats)) 180 | h = h * (1 + scale_mlp) + shift_mlp 181 | h = self.mlp(h) 182 | h = h * gate_mlp 183 | x = x + h 184 | return x 185 | 186 | def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: 187 | if self.use_checkpoint: 188 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) 189 | else: 190 | return self._forward(x, mod, context) 191 | -------------------------------------------------------------------------------- /triposf/modules/pointclouds/pointnet.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | # modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py 25 | 26 | import torch 27 | import torch.nn as nn 28 | import copy 29 | from torch import Tensor 30 | from torch_scatter import scatter_mean 31 | 32 | def scale_tensor( 33 | dat, inp_scale=None, tgt_scale=None 34 | ): 35 | if inp_scale is None: 36 | inp_scale = (-0.5, 0.5) 37 | if tgt_scale is None: 38 | tgt_scale = (0, 1) 39 | assert tgt_scale[1] > tgt_scale[0] and inp_scale[1] > inp_scale[0] 40 | if isinstance(tgt_scale, Tensor): 41 | assert dat.shape[-1] == tgt_scale.shape[-1] 42 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 43 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 44 | return dat.clamp(tgt_scale[0] + 1e-6, tgt_scale[1] - 1e-6) 45 | 46 | # Resnet Blocks for pointnet 47 | class ResnetBlockFC(nn.Module): 48 | ''' Fully connected ResNet Block class. 
49 | 50 | Args: 51 | size_in (int): input dimension 52 | size_out (int): output dimension 53 | size_h (int): hidden dimension 54 | ''' 55 | 56 | def __init__(self, size_in, size_out=None, size_h=None): 57 | super().__init__() 58 | # Attributes 59 | if size_out is None: 60 | size_out = size_in 61 | 62 | if size_h is None: 63 | size_h = min(size_in, size_out) 64 | 65 | self.size_in = size_in 66 | self.size_h = size_h 67 | self.size_out = size_out 68 | # Submodules 69 | self.fc_0 = nn.Linear(size_in, size_h) 70 | self.fc_1 = nn.Linear(size_h, size_out) 71 | self.actvn = nn.GELU(approximate="tanh") 72 | 73 | if size_in == size_out: 74 | self.shortcut = None 75 | else: 76 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 77 | # Initialization 78 | nn.init.xavier_uniform_(self.fc_0.weight) 79 | if self.fc_0.bias is not None: 80 | nn.init.constant_(self.fc_0.bias, 0) 81 | if self.shortcut is not None: 82 | nn.init.xavier_uniform_(self.shortcut.weight) 83 | if self.shortcut.bias is not None: 84 | nn.init.constant_(self.shortcut.bias, 0) 85 | 86 | nn.init.xavier_uniform_(self.fc_1.weight) 87 | if self.fc_1.bias is not None: 88 | nn.init.constant_(self.fc_1.bias, 0) 89 | 90 | 91 | def forward(self, x): 92 | net = self.fc_0(self.actvn(x)) 93 | dx = self.fc_1(self.actvn(net)) 94 | 95 | if self.shortcut is not None: 96 | x_s = self.shortcut(x) 97 | else: 98 | x_s = x 99 | 100 | return x_s + dx 101 | 102 | class LocalPoolPointnet(nn.Module): 103 | def __init__(self, in_channels=3, out_channels=128, hidden_dim=128, scatter_type='mean', n_blocks=5): 104 | super().__init__() 105 | self.scatter_type = scatter_type 106 | self.in_channels = in_channels 107 | self.hidden_dim = hidden_dim 108 | self.out_channels = out_channels 109 | self.fc_pos = nn.Linear(in_channels, 2*hidden_dim) 110 | self.blocks = nn.ModuleList([ 111 | ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) 112 | ]) 113 | self.fc_c = nn.Linear(hidden_dim, out_channels) 114 | 115 | if self.scatter_type == 'mean': 116 | self.scatter = scatter_mean 117 | else: 118 | raise ValueError('Incorrect scatter type') 119 | self.initialize_weights() 120 | 121 | def initialize_weights(self): 122 | 123 | nn.init.xavier_uniform_(self.fc_pos.weight) 124 | if self.fc_pos.bias is not None: 125 | nn.init.constant_(self.fc_pos.bias, 0) 126 | 127 | nn.init.xavier_uniform_(self.fc_c.weight) 128 | if self.fc_c.bias is not None: 129 | nn.init.constant_(self.fc_c.bias, 0) 130 | 131 | def convert_to_sparse_feats(self, c, sparse_coords): 132 | ''' 133 | Input: 134 | sparse_coords: Tensor [Nx, 4], point to sparse indices 135 | c: Tensor [B, res, C], input feats of each grid 136 | Output: 137 | c_out: Tensor [B, Np, C], aggregated grid feats of each point 138 | ''' 139 | feats_new = torch.zeros((sparse_coords.shape[0], c.shape[-1]), device=c.device, dtype=c.dtype) 140 | offsets = 0 141 | 142 | batch_nums = copy.deepcopy(sparse_coords[..., 0]) 143 | for i in range(len(c)): 144 | coords_num_i = (batch_nums == i).sum() 145 | feats_new[offsets: offsets + coords_num_i] = c[i, : coords_num_i] 146 | offsets += coords_num_i 147 | return feats_new 148 | 149 | def generate_sparse_grid_features(self, index, c, max_coord_num): 150 | # scatter grid features from points 151 | bs, fea_dim = c.size(0), c.size(2) 152 | res = max_coord_num 153 | c_out = c.new_zeros(bs, self.out_channels, res) 154 | c_out = scatter_mean(c.permute(0, 2, 1), index, out=c_out).permute(0, 2, 1) # B x res X C 155 | return c_out 156 | 157 | def pool_sparse_local(self, index, c, 
max_coord_num): 158 | ''' 159 | Input: 160 | index: Tensor [B, 1, Np], sparse indices of each point 161 | c: Tensor [B, Np, C], input feats of each point 162 | Output: 163 | c_out: Tensor [B, Np, C], aggregated grid feats of each point 164 | ''' 165 | 166 | bs, fea_dim = c.size(0), c.size(2) 167 | res = max_coord_num 168 | c_out = c.new_zeros(bs, fea_dim, res) 169 | c_out = self.scatter(c.permute(0, 2, 1), index, out=c_out) 170 | 171 | # gather feature back to points 172 | c_out = c_out.gather(dim=2, index=index.expand(-1, fea_dim, -1)) 173 | return c_out.permute(0, 2, 1) 174 | 175 | @torch.no_grad() 176 | def coordinate2sparseindex(self, x, sparse_coords, res): 177 | ''' 178 | Input: 179 | x: Tensor [B, Np, 3], points scaled at ([0, 1] * res) 180 | sparse_coords: Tensor [Nx, 4] ([batch_number, x, y, z]) 181 | res: Int, resolution of the grid index 182 | Output: 183 | sparse_index: Tensor [B, 1, Np], sparse indices of each point 184 | ''' 185 | B = x.shape[0] 186 | sparse_index = torch.zeros((B, x.shape[1]), device=x.device, dtype=torch.int64) 187 | 188 | index = (x[..., 0] * res + x[..., 1]) * res + x[..., 2] 189 | sparse_indices = copy.deepcopy(sparse_coords) 190 | sparse_indices[..., 1] = (sparse_indices[..., 1] * res + sparse_indices[..., 2]) * res + sparse_indices[..., 3] 191 | sparse_indices = sparse_indices[..., :2] 192 | 193 | for i in range(B): 194 | mask_i = sparse_indices[..., 0] == i 195 | coords_i = sparse_indices[mask_i, 1] 196 | coords_num_i = len(coords_i) 197 | sparse_index[i] = torch.searchsorted(coords_i, index[i]) 198 | 199 | return sparse_index[:, None, :] 200 | 201 | def forward(self, p, sparse_coords, res=64, bbox_size=(-0.5, 0.5)): 202 | ''' 203 | Input: 204 | p : Tensor [B, Np(819_200), 3] 205 | sparse_coords: Tensor [Nx, 4] ([batch_number, x, y, z]) 206 | 207 | Output: 208 | sparse_pc_feats: [Nx, self.out_channels] 209 | ''' 210 | batch_size, T, D = p.size() 211 | max_coord_num = 0 212 | for i in range(batch_size): 213 | max_coord_num = max(max_coord_num, (sparse_coords[..., 0] == i).sum().item() + 5) 214 | 215 | if D == 6: 216 | p, normals = p[..., :3], p[..., 3:] 217 | 218 | coord = (scale_tensor(p, inp_scale=bbox_size) * res) 219 | p = 2 * (coord - (coord.floor() + 0.5)) # dist to the centrios, [-1., 1.] 220 | index = self.coordinate2sparseindex(coord.long(), sparse_coords, res) 221 | 222 | if D == 6: 223 | p = torch.cat((p, normals), dim=-1) 224 | net = self.fc_pos(p) 225 | net = self.blocks[0](net) 226 | for block in self.blocks[1:]: 227 | pooled = self.pool_sparse_local(index, net, max_coord_num=max_coord_num) 228 | 229 | net = torch.cat([net, pooled], dim=2) 230 | net = block(net) 231 | c = self.fc_c(net) 232 | feats = self.generate_sparse_grid_features(index, c, max_coord_num=max_coord_num) 233 | feats = self.convert_to_sparse_feats(feats, sparse_coords) 234 | 235 | torch.cuda.empty_cache() 236 | return feats 237 | -------------------------------------------------------------------------------- /triposf/modules/sparse/attention/serialized_attn.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | from enum import Enum 26 | import torch 27 | import math 28 | from .. import SparseTensor 29 | from .. import DEBUG, ATTN 30 | 31 | if ATTN == 'xformers': 32 | import xformers.ops as xops 33 | elif ATTN == 'flash_attn': 34 | import flash_attn 35 | else: 36 | raise ValueError(f"Unknown attention module: {ATTN}") 37 | 38 | 39 | __all__ = [ 40 | 'sparse_serialized_scaled_dot_product_self_attention', 41 | ] 42 | 43 | 44 | class SerializeMode(Enum): 45 | Z_ORDER = 0 46 | Z_ORDER_TRANSPOSED = 1 47 | HILBERT = 2 48 | HILBERT_TRANSPOSED = 3 49 | 50 | 51 | SerializeModes = [ 52 | SerializeMode.Z_ORDER, 53 | SerializeMode.Z_ORDER_TRANSPOSED, 54 | SerializeMode.HILBERT, 55 | SerializeMode.HILBERT_TRANSPOSED 56 | ] 57 | 58 | 59 | def calc_serialization( 60 | tensor: SparseTensor, 61 | window_size: int, 62 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER, 63 | shift_sequence: int = 0, 64 | shift_window: Tuple[int, int, int] = (0, 0, 0) 65 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: 66 | """ 67 | Calculate serialization and partitioning for a set of coordinates. 68 | 69 | Args: 70 | tensor (SparseTensor): The input tensor. 71 | window_size (int): The window size to use. 72 | serialize_mode (SerializeMode): The serialization mode to use. 73 | shift_sequence (int): The shift of serialized sequence. 74 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates. 75 | 76 | Returns: 77 | (torch.Tensor, torch.Tensor): Forwards and backwards indices. 
78 | """ 79 | fwd_indices = [] 80 | bwd_indices = [] 81 | seq_lens = [] 82 | seq_batch_indices = [] 83 | offsets = [0] 84 | 85 | if 'vox2seq' not in globals(): 86 | import vox2seq 87 | 88 | # Serialize the input 89 | serialize_coords = tensor.coords[:, 1:].clone() 90 | serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) 91 | if serialize_mode == SerializeMode.Z_ORDER: 92 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) 93 | elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: 94 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) 95 | elif serialize_mode == SerializeMode.HILBERT: 96 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) 97 | elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: 98 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) 99 | else: 100 | raise ValueError(f"Unknown serialize mode: {serialize_mode}") 101 | 102 | for bi, s in enumerate(tensor.layout): 103 | num_points = s.stop - s.start 104 | num_windows = (num_points + window_size - 1) // window_size 105 | valid_window_size = num_points / num_windows 106 | to_ordered = torch.argsort(code[s.start:s.stop]) 107 | if num_windows == 1: 108 | fwd_indices.append(to_ordered) 109 | bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) 110 | fwd_indices[-1] += s.start 111 | bwd_indices[-1] += offsets[-1] 112 | seq_lens.append(num_points) 113 | seq_batch_indices.append(bi) 114 | offsets.append(offsets[-1] + seq_lens[-1]) 115 | else: 116 | # Partition the input 117 | offset = 0 118 | mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] 119 | split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] 120 | bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) 121 | for i in range(num_windows): 122 | mid = mids[i] 123 | valid_start = split[i] 124 | valid_end = split[i + 1] 125 | padded_start = math.floor(mid - 0.5 * window_size) 126 | padded_end = padded_start + window_size 127 | fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) 128 | offset += valid_start - padded_start 129 | bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) 130 | offset += padded_end - valid_start 131 | fwd_indices[-1] += s.start 132 | seq_lens.extend([window_size] * num_windows) 133 | seq_batch_indices.extend([bi] * num_windows) 134 | bwd_indices.append(bwd_index + offsets[-1]) 135 | offsets.append(offsets[-1] + num_windows * window_size) 136 | 137 | fwd_indices = torch.cat(fwd_indices) 138 | bwd_indices = torch.cat(bwd_indices) 139 | 140 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices 141 | 142 | 143 | def sparse_serialized_scaled_dot_product_self_attention( 144 | qkv: SparseTensor, 145 | window_size: int, 146 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER, 147 | shift_sequence: int = 0, 148 | shift_window: Tuple[int, int, int] = (0, 0, 0) 149 | ) -> SparseTensor: 150 | """ 151 | Apply serialized scaled dot product self attention to a sparse tensor. 152 | 153 | Args: 154 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. 155 | window_size (int): The window size to use. 
156 | serialize_mode (SerializeMode): The serialization mode to use. 157 | shift_sequence (int): The shift of serialized sequence. 158 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates. 159 | shift (int): The shift to use. 160 | """ 161 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" 162 | 163 | serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' 164 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) 165 | if serialization_spatial_cache is None: 166 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) 167 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) 168 | else: 169 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache 170 | 171 | M = fwd_indices.shape[0] 172 | T = qkv.feats.shape[0] 173 | H = qkv.feats.shape[2] 174 | C = qkv.feats.shape[3] 175 | 176 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] 177 | 178 | if DEBUG: 179 | start = 0 180 | qkv_coords = qkv.coords[fwd_indices] 181 | for i in range(len(seq_lens)): 182 | assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" 183 | start += seq_lens[i] 184 | 185 | if all([seq_len == window_size for seq_len in seq_lens]): 186 | B = len(seq_lens) 187 | N = window_size 188 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C) 189 | if ATTN == 'xformers': 190 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] 191 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] 192 | elif ATTN == 'flash_attn': 193 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] 194 | else: 195 | raise ValueError(f"Unknown attention module: {ATTN}") 196 | out = out.reshape(B * N, H, C) # [M, H, C] 197 | else: 198 | if ATTN == 'xformers': 199 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] 200 | q = q.unsqueeze(0) # [1, M, H, C] 201 | k = k.unsqueeze(0) # [1, M, H, C] 202 | v = v.unsqueeze(0) # [1, M, H, C] 203 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) 204 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] 205 | elif ATTN == 'flash_attn': 206 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ 207 | .to(qkv.device).int() 208 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] 209 | 210 | out = out[bwd_indices] # [T, H, C] 211 | 212 | if DEBUG: 213 | qkv_coords = qkv_coords[bwd_indices] 214 | assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" 215 | 216 | return qkv.replace(out) 217 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2025 VAST-AI-Research and contributors. 
4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE 22 | 23 | from safetensors.torch import load_file 24 | import torch 25 | from dataclasses import dataclass, field 26 | import numpy as np 27 | import open3d as o3d 28 | import trimesh 29 | import os 30 | import time 31 | import argparse 32 | from omegaconf import OmegaConf 33 | from typing import * 34 | 35 | from triposf.modules import sparse as sp 36 | from misc import get_device, find 37 | 38 | def normalize_mesh(mesh_path): 39 | scene = trimesh.load(mesh_path, process=False, force='scene') 40 | meshes = [] 41 | for node_name in scene.graph.nodes_geometry: 42 | geom_name = scene.graph[node_name][1] 43 | geometry = scene.geometry[geom_name] 44 | transform = scene.graph[node_name][0] 45 | if isinstance(geometry, trimesh.Trimesh): 46 | geometry.apply_transform(transform) 47 | meshes.append(geometry) 48 | 49 | mesh = trimesh.util.concatenate(meshes) 50 | 51 | center = mesh.bounding_box.centroid 52 | mesh.apply_translation(-center) 53 | scale = max(mesh.bounding_box.extents) 54 | mesh.apply_scale(2.0 / scale * 0.5) 55 | 56 | return mesh 57 | 58 | def load_quantized_mesh_original( 59 | mesh_path, 60 | volume_resolution=256, 61 | use_normals=True, 62 | pc_sample_number=4096000, 63 | ): 64 | cube_dilate = np.array( 65 | [ 66 | [0, 0, 0], 67 | [0, 0, 1], 68 | [0, 1, 0], 69 | [0, 0, -1], 70 | [0, -1, 0], 71 | [0, 1, 1], 72 | [0, -1, 1], 73 | [0, 1, -1], 74 | [0, -1, -1], 75 | 76 | [1, 0, 0], 77 | [1, 0, 1], 78 | [1, 1, 0], 79 | [1, 0, -1], 80 | [1, -1, 0], 81 | [1, 1, 1], 82 | [1, -1, 1], 83 | [1, 1, -1], 84 | [1, -1, -1], 85 | 86 | [-1, 0, 0], 87 | [-1, 0, 1], 88 | [-1, 1, 0], 89 | [-1, 0, -1], 90 | [-1, -1, 0], 91 | [-1, 1, 1], 92 | [-1, -1, 1], 93 | [-1, 1, -1], 94 | [-1, -1, -1], 95 | ] 96 | ) / (volume_resolution * 4 - 1) 97 | 98 | 99 | mesh = o3d.io.read_triangle_mesh(mesh_path) 100 | vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6) 101 | faces = np.asarray(mesh.triangles) 102 | mesh.vertices = o3d.utility.Vector3dVector(vertices) 103 | 104 | voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds( 105 | mesh, 106 | voxel_size=1. 
/ volume_resolution, 107 | min_bound=[-0.5, -0.5, -0.5], 108 | max_bound=[0.5, 0.5, 0.5] 109 | ) 110 | voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()]) 111 | 112 | points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True) 113 | points_sample = points_normals_sample[0].astype(np.float32) 114 | voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds( 115 | o3d.geometry.PointCloud( 116 | o3d.utility.Vector3dVector( 117 | np.clip( 118 | (points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3), 119 | -0.5 + 1e-6, 0.5 - 1e-6) 120 | ) 121 | ), 122 | voxel_size=1. / volume_resolution, 123 | min_bound=[-0.5, -0.5, -0.5], 124 | max_bound=[0.5, 0.5, 0.5] 125 | ) 126 | voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()]) 127 | voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0)) 128 | 129 | if use_normals: 130 | mesh.compute_triangle_normals() 131 | normals_sample = np.asarray( 132 | mesh.triangle_normals 133 | )[points_normals_sample[1]].astype(np.float32) 134 | points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1) 135 | 136 | return voxels, points_sample 137 | 138 | class TripoSFVAEInference(torch.nn.Module): 139 | @dataclass 140 | class Config: 141 | local_pc_encoder_cls: str = "" 142 | local_pc_encoder: dict = field(default_factory=dict) 143 | 144 | encoder_cls: str = "" 145 | encoder: dict = field(default_factory=dict) 146 | 147 | decoder_cls: str = "" 148 | decoder: dict = field(default_factory=dict) 149 | 150 | resolution: int = 256 151 | sample_points_num: int = 819_200 152 | use_normals: bool = True 153 | pruning: bool = False 154 | 155 | weight: Optional[str] = None 156 | 157 | cfg: Config 158 | 159 | def __init__(self, cfg): 160 | super().__init__() 161 | self.cfg = cfg 162 | self.configure() 163 | 164 | def load_weights(self): 165 | if self.cfg.weight is not None: 166 | print("Pretrained VAE Loading...") 167 | state_dict = load_file(self.cfg.weight) 168 | self.load_state_dict(state_dict) 169 | 170 | def configure(self) -> None: 171 | self.local_pc_encoder = find(self.cfg.local_pc_encoder_cls)(**self.cfg.local_pc_encoder).eval() 172 | for p in self.local_pc_encoder.parameters(): 173 | p.requires_grad = False 174 | 175 | self.encoder = find(self.cfg.encoder_cls)(**self.cfg.encoder).eval() 176 | for p in self.encoder.parameters(): 177 | p.requires_grad = False 178 | 179 | self.decoder = find(self.cfg.decoder_cls)(**self.cfg.decoder).eval() 180 | for p in self.decoder.parameters(): 181 | p.requires_grad = False 182 | 183 | self.load_weights() 184 | 185 | @torch.no_grad() 186 | def forward(self, points_sample, sparse_voxel_coords): 187 | with torch.autocast("cuda", dtype=torch.float32): 188 | sparse_pc_features = self.local_pc_encoder(points_sample, sparse_voxel_coords, res=self.cfg.resolution, bbox_size=(-0.5, 0.5)) 189 | sparse_tensor = sp.SparseTensor(sparse_pc_features, sparse_voxel_coords) 190 | latent, posterior = self.encoder(sparse_tensor) 191 | mesh = self.decoder(latent, pruning=self.cfg.pruning) 192 | return mesh 193 | 194 | @classmethod 195 | def from_config(cls, config_path): 196 | config = OmegaConf.load(config_path) 197 | cfg = OmegaConf.merge(OmegaConf.structured(TripoSFVAEInference.Config), config) 198 | return cls(cfg) 199 | 200 | if __name__ == "__main__": 201 | # Usage: `python inference.py --mesh-path 
"assets/examples/loong.obj" --output-dir "outputs/" --config "configs/triposfVAE_1024.yaml"` 202 | parser = argparse.ArgumentParser("TripoSF Reconstruction") 203 | parser.add_argument("--config", required=True, help="path to config file") 204 | parser.add_argument("--output-dir", default="outputs/", help="path to output folder") 205 | parser.add_argument("--mesh-path", type=str, help="the input mesh to be reconstructed") 206 | 207 | args, extras = parser.parse_known_args() 208 | device = get_device() 209 | save_name = os.path.split(args.mesh_path)[-1].split(".")[0] 210 | 211 | model = TripoSFVAEInference.from_config(args.config).to(device) 212 | 213 | os.makedirs(args.output_dir, exist_ok=True) 214 | 215 | print(f"Mesh Normalizing...") 216 | preprocess_start = time.time() 217 | mesh_gt = normalize_mesh(args.mesh_path) 218 | save_path_gt = f"{args.output_dir}/{save_name}_gt.obj" 219 | trimesh.Trimesh(vertices=mesh_gt.vertices.tolist(), faces=mesh_gt.faces.tolist()).export(save_path_gt) 220 | preprocess_end = time.time() 221 | print(f"Mesh Normalizing Time: {(preprocess_end - preprocess_start):.2f}") 222 | 223 | print(f"Mesh Loading...") 224 | input_loading_start = time.time() 225 | sparse_voxels, points_sample = load_quantized_mesh_original( 226 | save_path_gt, 227 | volume_resolution=model.cfg.resolution, 228 | use_normals=model.cfg.use_normals, 229 | pc_sample_number=model.cfg.sample_points_num, 230 | ) 231 | input_loading_end = time.time() 232 | print(f"Mesh Loading Time: {(input_loading_end - input_loading_start):.2f}") 233 | 234 | print(f"Mesh Reconstructing...") 235 | sparse_voxels, points_sample = sparse_voxels.to(device), points_sample.to(device) 236 | sparse_voxels_sp = torch.cat([torch.zeros_like(sparse_voxels[..., :1]), sparse_voxels], dim=-1).int() 237 | 238 | inference_start = time.time() 239 | with torch.cuda.amp.autocast(dtype=torch.float16): 240 | mesh_recon = model(points_sample[None], sparse_voxels_sp)[0] 241 | inference_end = time.time() 242 | print(f"Mesh Reconstructing Time: {(inference_end - inference_start):.2f}") 243 | 244 | save_path_recon = f"{args.output_dir}/{save_name}_reconstruction.obj" 245 | trimesh.Trimesh(vertices=mesh_recon.vertices.tolist(), faces=mesh_recon.faces.tolist()).export(save_path_recon) -------------------------------------------------------------------------------- /triposf/modules/sparse/attention/full_attn.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | from .. import SparseTensor 27 | from .. import DEBUG, ATTN 28 | 29 | if ATTN == 'xformers': 30 | import xformers.ops as xops 31 | elif ATTN == 'flash_attn': 32 | import flash_attn 33 | else: 34 | raise ValueError(f"Unknown attention module: {ATTN}") 35 | 36 | 37 | __all__ = [ 38 | 'sparse_scaled_dot_product_attention', 39 | ] 40 | 41 | 42 | @overload 43 | def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: 44 | """ 45 | Apply scaled dot product attention to a sparse tensor. 46 | 47 | Args: 48 | qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. 49 | """ 50 | ... 51 | 52 | @overload 53 | def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: 54 | """ 55 | Apply scaled dot product attention to a sparse tensor. 56 | 57 | Args: 58 | q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. 59 | kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. 60 | """ 61 | ... 62 | 63 | @overload 64 | def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: 65 | """ 66 | Apply scaled dot product attention to a sparse tensor. 67 | 68 | Args: 69 | q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. 70 | kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. 71 | """ 72 | ... 73 | 74 | @overload 75 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: 76 | """ 77 | Apply scaled dot product attention to a sparse tensor. 78 | 79 | Args: 80 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. 81 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. 82 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. 83 | 84 | Note: 85 | k and v are assumed to have the same coordinate map. 86 | """ 87 | ... 88 | 89 | @overload 90 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: 91 | """ 92 | Apply scaled dot product attention to a sparse tensor. 93 | 94 | Args: 95 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. 96 | k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. 97 | v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. 98 | """ 99 | ... 100 | 101 | @overload 102 | def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: 103 | """ 104 | Apply scaled dot product attention to a sparse tensor. 105 | 106 | Args: 107 | q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. 108 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. 109 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. 110 | """ 111 | ... 
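# ---------------------------------------------------------------------------
# Illustrative call patterns (editorial annotation for this document; not part
# of the upstream full_attn.py). The first two forms mirror how
# SparseMultiHeadAttention in sparse/attention/modules.py invokes the
# dispatcher defined below; the third form is also accepted, with the shapes
# documented by the overloads above:
#
#   h = sparse_scaled_dot_product_attention(qkv)      # self-attention: packed [N, *, 3, H, C] SparseTensor
#   h = sparse_scaled_dot_product_attention(q, kv)    # cross-attention: q plus packed [.., 2, H, C] context
#   h = sparse_scaled_dot_product_attention(q, k, v)  # separate Q/K/V, sparse or dense as documented above
# ---------------------------------------------------------------------------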
112 | 113 | def sparse_scaled_dot_product_attention(*args, **kwargs): 114 | arg_names_dict = { 115 | 1: ['qkv'], 116 | 2: ['q', 'kv'], 117 | 3: ['q', 'k', 'v'] 118 | } 119 | num_all_args = len(args) + len(kwargs) 120 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" 121 | for key in arg_names_dict[num_all_args][len(args):]: 122 | assert key in kwargs, f"Missing argument {key}" 123 | 124 | if num_all_args == 1: 125 | qkv = args[0] if len(args) > 0 else kwargs['qkv'] 126 | assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" 127 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" 128 | device = qkv.device 129 | 130 | s = qkv 131 | q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] 132 | kv_seqlen = q_seqlen 133 | qkv = qkv.feats # [T, 3, H, C] 134 | 135 | elif num_all_args == 2: 136 | q = args[0] if len(args) > 0 else kwargs['q'] 137 | kv = args[1] if len(args) > 1 else kwargs['kv'] 138 | assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ 139 | isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ 140 | f"Invalid types, got {type(q)} and {type(kv)}" 141 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" 142 | device = q.device 143 | 144 | if isinstance(q, SparseTensor): 145 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" 146 | s = q 147 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] 148 | q = q.feats # [T_Q, H, C] 149 | else: 150 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" 151 | s = None 152 | N, L, H, C = q.shape 153 | q_seqlen = [L] * N 154 | q = q.reshape(N * L, H, C) # [T_Q, H, C] 155 | 156 | if isinstance(kv, SparseTensor): 157 | assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" 158 | kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] 159 | kv = kv.feats # [T_KV, 2, H, C] 160 | else: 161 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" 162 | N, L, _, H, C = kv.shape 163 | kv_seqlen = [L] * N 164 | kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] 165 | 166 | elif num_all_args == 3: 167 | q = args[0] if len(args) > 0 else kwargs['q'] 168 | k = args[1] if len(args) > 1 else kwargs['k'] 169 | v = args[2] if len(args) > 2 else kwargs['v'] 170 | assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ 171 | isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ 172 | f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" 173 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" 174 | device = q.device 175 | 176 | if isinstance(q, SparseTensor): 177 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" 178 | s = q 179 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] 180 | q = q.feats # [T_Q, H, Ci] 181 | else: 182 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" 183 | s = None 184 | N, L, H, CI = q.shape 185 | q_seqlen = [L] * N 186 | q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] 187 | 188 | if 
isinstance(k, SparseTensor): 189 | assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" 190 | assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" 191 | kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] 192 | k = k.feats # [T_KV, H, Ci] 193 | v = v.feats # [T_KV, H, Co] 194 | else: 195 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" 196 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" 197 | N, L, H, CI, CO = *k.shape, v.shape[-1] 198 | kv_seqlen = [L] * N 199 | k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] 200 | v = v.reshape(N * L, H, CO) # [T_KV, H, Co] 201 | 202 | if DEBUG: 203 | if s is not None: 204 | for i in range(s.shape[0]): 205 | assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" 206 | if num_all_args in [2, 3]: 207 | assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" 208 | if num_all_args == 3: 209 | assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" 210 | assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" 211 | 212 | if ATTN == 'xformers': 213 | if num_all_args == 1: 214 | q, k, v = qkv.unbind(dim=1) 215 | elif num_all_args == 2: 216 | k, v = kv.unbind(dim=1) 217 | q = q.unsqueeze(0) 218 | k = k.unsqueeze(0) 219 | v = v.unsqueeze(0) 220 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) 221 | out = xops.memory_efficient_attention(q, k, v, mask)[0] 222 | elif ATTN == 'flash_attn': 223 | cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) 224 | if num_all_args in [2, 3]: 225 | cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) 226 | if num_all_args == 1: 227 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) 228 | elif num_all_args == 2: 229 | out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) 230 | elif num_all_args == 3: 231 | out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) 232 | else: 233 | raise ValueError(f"Unknown attention module: {ATTN}") 234 | 235 | if s is not None: 236 | return s.replace(out) 237 | else: 238 | return out.reshape(N, L, H, -1) 239 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2025 VAST-AI-Research and contributors. 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE 22 | 23 | import random 24 | import gradio as gr 25 | from safetensors.torch import load_file 26 | import torch 27 | from dataclasses import dataclass, field 28 | import numpy as np 29 | import open3d as o3d 30 | import trimesh 31 | import os 32 | import time 33 | import argparse 34 | from omegaconf import OmegaConf 35 | from typing import * 36 | 37 | from triposf.modules import sparse as sp 38 | from misc import get_device, find 39 | 40 | def get_random_hex(): 41 | random_bytes = os.urandom(8) 42 | random_hex = random_bytes.hex() 43 | return random_hex 44 | 45 | def get_random_seed(randomize_seed, seed): 46 | if randomize_seed: 47 | seed = random.randint(0, MAX_SEED) 48 | return seed 49 | 50 | def normalize_mesh(mesh_path): 51 | scene = trimesh.load(mesh_path, process=False, force='scene') 52 | meshes = [] 53 | for node_name in scene.graph.nodes_geometry: 54 | geom_name = scene.graph[node_name][1] 55 | geometry = scene.geometry[geom_name] 56 | transform = scene.graph[node_name][0] 57 | if isinstance(geometry, trimesh.Trimesh): 58 | geometry.apply_transform(transform) 59 | meshes.append(geometry) 60 | 61 | mesh = trimesh.util.concatenate(meshes) 62 | 63 | center = mesh.bounding_box.centroid 64 | mesh.apply_translation(-center) 65 | scale = max(mesh.bounding_box.extents) 66 | mesh.apply_scale(2.0 / scale * 0.5) 67 | 68 | angle = np.radians(90) 69 | rotation_matrix = trimesh.transformations.rotation_matrix(angle, [-1, 0, 0]) 70 | mesh.apply_transform(rotation_matrix) 71 | return mesh 72 | 73 | def load_quantized_mesh_original( 74 | mesh_path, 75 | volume_resolution=256, 76 | use_normals=True, 77 | pc_sample_number=4096000, 78 | ): 79 | cube_dilate = np.array( 80 | [ 81 | [0, 0, 0], 82 | [0, 0, 1], 83 | [0, 1, 0], 84 | [0, 0, -1], 85 | [0, -1, 0], 86 | [0, 1, 1], 87 | [0, -1, 1], 88 | [0, 1, -1], 89 | [0, -1, -1], 90 | 91 | [1, 0, 0], 92 | [1, 0, 1], 93 | [1, 1, 0], 94 | [1, 0, -1], 95 | [1, -1, 0], 96 | [1, 1, 1], 97 | [1, -1, 1], 98 | [1, 1, -1], 99 | [1, -1, -1], 100 | 101 | [-1, 0, 0], 102 | [-1, 0, 1], 103 | [-1, 1, 0], 104 | [-1, 0, -1], 105 | [-1, -1, 0], 106 | [-1, 1, 1], 107 | [-1, -1, 1], 108 | [-1, 1, -1], 109 | [-1, -1, -1], 110 | ] 111 | ) / (volume_resolution * 4 - 1) 112 | 113 | 114 | mesh = o3d.io.read_triangle_mesh(mesh_path) 115 | vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6) 116 | faces = np.asarray(mesh.triangles) 117 | mesh.vertices = o3d.utility.Vector3dVector(vertices) 118 | 119 | voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds( 120 | mesh, 121 | voxel_size=1. 
/ volume_resolution, 122 | min_bound=[-0.5, -0.5, -0.5], 123 | max_bound=[0.5, 0.5, 0.5] 124 | ) 125 | voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()]) 126 | 127 | points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True) 128 | points_sample = points_normals_sample[0].astype(np.float32) 129 | voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds( 130 | o3d.geometry.PointCloud( 131 | o3d.utility.Vector3dVector( 132 | np.clip( 133 | (points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3), 134 | -0.5 + 1e-6, 0.5 - 1e-6) 135 | ) 136 | ), 137 | voxel_size=1. / volume_resolution, 138 | min_bound=[-0.5, -0.5, -0.5], 139 | max_bound=[0.5, 0.5, 0.5] 140 | ) 141 | voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()]) 142 | voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0)) 143 | 144 | if use_normals: 145 | mesh.compute_triangle_normals() 146 | normals_sample = np.asarray( 147 | mesh.triangle_normals 148 | )[points_normals_sample[1]].astype(np.float32) 149 | points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1) 150 | 151 | return voxels, points_sample 152 | 153 | class TripoSFVAEInference(torch.nn.Module): 154 | @dataclass 155 | class Config: 156 | local_pc_encoder_cls: str = "" 157 | local_pc_encoder: dict = field(default_factory=dict) 158 | 159 | encoder_cls: str = "" 160 | encoder: dict = field(default_factory=dict) 161 | 162 | decoder_cls: str = "" 163 | decoder: dict = field(default_factory=dict) 164 | 165 | resolution: int = 256 166 | sample_points_num: int = 819_200 167 | use_normals: bool = True 168 | pruning: bool = False 169 | 170 | weight: Optional[str] = None 171 | 172 | cfg: Config 173 | 174 | def __init__(self, cfg): 175 | super().__init__() 176 | self.cfg = cfg 177 | self.configure() 178 | 179 | def load_weights(self): 180 | if self.cfg.weight is not None: 181 | print("Pretrained VAE Loading...") 182 | state_dict = load_file(self.cfg.weight) 183 | self.load_state_dict(state_dict) 184 | 185 | def configure(self) -> None: 186 | self.local_pc_encoder = find(self.cfg.local_pc_encoder_cls)(**self.cfg.local_pc_encoder).eval() 187 | for p in self.local_pc_encoder.parameters(): 188 | p.requires_grad = False 189 | 190 | self.encoder = find(self.cfg.encoder_cls)(**self.cfg.encoder).eval() 191 | for p in self.encoder.parameters(): 192 | p.requires_grad = False 193 | 194 | self.decoder = find(self.cfg.decoder_cls)(**self.cfg.decoder).eval() 195 | for p in self.decoder.parameters(): 196 | p.requires_grad = False 197 | 198 | self.load_weights() 199 | 200 | @torch.no_grad() 201 | def forward(self, points_sample, sparse_voxel_coords): 202 | with torch.autocast("cuda", dtype=torch.float32): 203 | sparse_pc_features = self.local_pc_encoder(points_sample, sparse_voxel_coords, res=self.cfg.resolution, bbox_size=(-0.5, 0.5)) 204 | sparse_tensor = sp.SparseTensor(sparse_pc_features, sparse_voxel_coords) 205 | latent, posterior = self.encoder(sparse_tensor) 206 | mesh = self.decoder(latent, pruning=self.cfg.pruning) 207 | return mesh 208 | 209 | @classmethod 210 | def from_config(cls, config_path): 211 | config = OmegaConf.load(config_path) 212 | cfg = OmegaConf.merge(OmegaConf.structured(TripoSFVAEInference.Config), config) 213 | return cls(cfg) 214 | 215 | HEADER = """ 216 | # TripoSF VAE Reconstruction 
[TripoSF](https://github.com/VAST-AI-Research/TripoSF) 217 | ## TripoSF represents a significant leap forward in 3D shape modeling, combining high-resolution capabilities with arbitrary topology support. 218 | ## 📋 Some Tips: 219 | 220 | 1. It is recommended to enable `pruning` for open-surface objects 221 | 222 | 2. Increasing the number of sampled points helps when reconstructing complex shapes 223 | 224 |
By Tripo
225 | 226 | """ 227 | MAX_SEED = np.iinfo(np.int32).max 228 | device = "cuda" if torch.cuda.is_available() else "cpu" 229 | config = "configs/TripoSFVAE_1024.yaml" 230 | 231 | random_hex = get_random_hex() 232 | output_dir = "outputs" 233 | os.makedirs(output_dir, exist_ok=True) 234 | model = TripoSFVAEInference.from_config(config).to(device) 235 | 236 | def run_normalize_mesh(input_mesh): 237 | mesh_gt = normalize_mesh(input_mesh) 238 | mesh_name = os.path.basename(input_mesh).split('.')[0] 239 | mesh_path_gt = f"{output_dir}/{mesh_name}_{random_hex}_normalized.obj" 240 | mesh_normalized = trimesh.Trimesh(vertices=mesh_gt.vertices.tolist(), faces=mesh_gt.faces.tolist()) 241 | mesh_normalized.export(mesh_path_gt) 242 | return mesh_path_gt 243 | 244 | def run_reconstruction(input_mesh, sample_points_num, pruning, seed): 245 | model.cfg.pruning = pruning 246 | model.cfg.sample_points_num = sample_points_num 247 | mesh_name = os.path.basename(input_mesh).split('.')[0] 248 | mesh_path_gt = f"{output_dir}/{mesh_name}_{random_hex}_normalized.obj" 249 | sparse_voxels, points_sample = load_quantized_mesh_original( 250 | mesh_path_gt, 251 | volume_resolution=model.cfg.resolution, 252 | use_normals=model.cfg.use_normals, 253 | pc_sample_number=model.cfg.sample_points_num, 254 | ) 255 | 256 | sparse_voxels, points_sample = sparse_voxels.to(device), points_sample.to(device) 257 | sparse_voxels_sp = torch.cat([torch.zeros_like(sparse_voxels[..., :1]), sparse_voxels], dim=-1).int() 258 | 259 | with torch.cuda.amp.autocast(dtype=torch.float16): 260 | mesh_recon = model(points_sample[None], sparse_voxels_sp)[0] 261 | 262 | mesh_path_recon = f"{output_dir}/{mesh_name}_{random_hex}_reconstructed.obj" 263 | mesh_reconstructed = trimesh.Trimesh(vertices=mesh_recon.vertices.tolist(), faces=mesh_recon.faces.tolist()) 264 | mesh_reconstructed.export(mesh_path_recon) 265 | return mesh_path_recon 266 | 267 | with gr.Blocks(title="TripoSFRecon") as demo: 268 | gr.Markdown(HEADER) 269 | 270 | with gr.Row(): 271 | with gr.Column(): 272 | with gr.Row(): 273 | input_mesh_path = gr.Textbox(label="Please input local path to the mesh to be reconstructed.", placeholder="For example: assets/examples/jacket.obj") 274 | 275 | with gr.Accordion("Reconstruction Settings", open=True): 276 | use_pruning = gr.Checkbox(label="Pruning", value=False) 277 | recon_button = gr.Button("Reconstruct Mesh", variant="primary") 278 | seed = gr.Slider( 279 | label="Seed", 280 | minimum=0, 281 | maximum=MAX_SEED, 282 | step=0, 283 | value=0 284 | ) 285 | sample_points_num = gr.Slider( 286 | label="Sample point number", 287 | minimum=819200, 288 | maximum=8192000, 289 | step=100, 290 | value=819200 291 | ) 292 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 293 | with gr.Column(): 294 | normalized_model_output = gr.Model3D(label="Normalized Mesh", interactive=True) 295 | reconstructed_model_output = gr.Model3D(label="Reconstructed Mesh", interactive=True) 296 | 297 | 298 | recon_button.click( 299 | run_normalize_mesh, 300 | inputs=[input_mesh_path], 301 | outputs=[normalized_model_output] 302 | ).then( 303 | get_random_seed, 304 | inputs=[randomize_seed, seed], 305 | outputs=[seed], 306 | ).then( 307 | run_reconstruction, 308 | inputs=[input_mesh_path, sample_points_num, use_pruning, seed], 309 | outputs=[reconstructed_model_output], 310 | ) 311 | 312 | demo.launch(server_name="0.0.0.0", server_port=12345) 313 | -------------------------------------------------------------------------------- 
/triposf/models/triposf_vae/decoder.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) Microsoft Corporation. 4 | # Copyright (c) 2025 VAST-AI-Research and contributors. 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE 23 | 24 | from typing import * 25 | import torch 26 | import torch.nn as nn 27 | from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 28 | from ...modules import sparse as sp 29 | from .base import SparseTransformerBase 30 | from ...representations import MeshExtractResult 31 | from ...representations.mesh import SparseFeatures2Mesh 32 | from ...modules.sparse.linear import SparseLinear 33 | from ...modules.sparse.nonlinearity import SparseGELU 34 | 35 | class SparseOccHead(nn.Module): 36 | def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0): 37 | super().__init__() 38 | self.mlp = nn.Sequential( 39 | SparseLinear(channels, int(channels * mlp_ratio)), 40 | SparseGELU(approximate="tanh"), 41 | SparseLinear(int(channels * mlp_ratio), out_channels), 42 | ) 43 | 44 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: 45 | return self.mlp(x) 46 | 47 | class SparseSubdivideBlock3d(nn.Module): 48 | """ 49 | A 3D subdivide block that can subdivide the sparse tensor. 50 | 51 | Args: 52 | channels: channels in the inputs and outputs. 53 | out_channels: if specified, the number of output channels. 54 | num_groups: the number of groups for the group norm. 
55 | """ 56 | def __init__( 57 | self, 58 | channels: int, 59 | resolution: int, 60 | out_channels: Optional[int] = None, 61 | num_groups: int = 32, 62 | ): 63 | super().__init__() 64 | self.channels = channels 65 | self.resolution = resolution 66 | self.out_resolution = resolution * 2 67 | self.out_channels = out_channels or channels 68 | 69 | self.act_layers = nn.Sequential( 70 | sp.SparseGroupNorm32(num_groups, channels), 71 | sp.SparseSiLU() 72 | ) 73 | 74 | self.sub = sp.SparseSubdivide() 75 | 76 | self.out_layers = nn.Sequential( 77 | sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), 78 | sp.SparseGroupNorm32(num_groups, self.out_channels), 79 | sp.SparseSiLU(), 80 | zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), 81 | ) 82 | 83 | if self.out_channels == channels: 84 | self.skip_connection = nn.Identity() 85 | else: 86 | self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") 87 | 88 | self.pruning_head = SparseOccHead(self.out_channels, out_channels=1) 89 | 90 | def forward(self, x: sp.SparseTensor, pruning=False, training=True) -> sp.SparseTensor: 91 | """ 92 | Apply the block to a Tensor, conditioned on a timestep embedding. 93 | 94 | Args: 95 | x: an [N x C x ...] Tensor of features. 96 | Returns: 97 | an [N x C x ...] Tensor of outputs. 98 | """ 99 | h = self.act_layers(x) 100 | h = self.sub(h) 101 | x = self.sub(x) 102 | h = self.out_layers(h) 103 | h = h + self.skip_connection(x) 104 | if pruning: 105 | occ_prob = self.pruning_head(h) 106 | occ_mask = (occ_prob.feats >= 0.).squeeze(-1) 107 | if training == False: 108 | h = sp.SparseTensor(feats=h.feats[occ_mask], coords=h.coords[occ_mask]) 109 | return h, occ_prob 110 | else: 111 | return h, None 112 | 113 | 114 | class TripoSFVAEDecoder(SparseTransformerBase): 115 | def __init__( 116 | self, 117 | resolution: int, 118 | model_channels: int, 119 | latent_channels: int, 120 | num_blocks: int, 121 | num_heads: Optional[int] = None, 122 | num_head_channels: Optional[int] = 64, 123 | mlp_ratio: float = 4, 124 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 125 | window_size: int = 8, 126 | pe_mode: Literal["ape", "rope"] = "ape", 127 | use_fp16: bool = False, 128 | use_checkpoint: bool = False, 129 | qk_rms_norm: bool = False, 130 | use_sparse_flexicube: bool = True, 131 | use_sparse_sparse_flexicube: bool = False, 132 | representation_config: dict = None, 133 | ): 134 | super().__init__( 135 | in_channels=latent_channels, 136 | model_channels=model_channels, 137 | num_blocks=num_blocks, 138 | num_heads=num_heads, 139 | num_head_channels=num_head_channels, 140 | mlp_ratio=mlp_ratio, 141 | attn_mode=attn_mode, 142 | window_size=window_size, 143 | pe_mode=pe_mode, 144 | use_fp16=use_fp16, 145 | use_checkpoint=use_checkpoint, 146 | qk_rms_norm=qk_rms_norm, 147 | ) 148 | assert not (use_sparse_flexicube == False and use_sparse_sparse_flexicube == True) 149 | self.resolution = resolution 150 | self.rep_config = representation_config 151 | self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False), use_sparse_flexicube=use_sparse_flexicube, use_sparse_sparse_flexicube=use_sparse_sparse_flexicube) 152 | self.out_channels = self.mesh_extractor.feats_channels 153 | self.upsample = nn.ModuleList([ 154 | SparseSubdivideBlock3d( 155 | channels=model_channels, 156 | 
resolution=resolution, 157 | out_channels=model_channels // 4, 158 | ), 159 | SparseSubdivideBlock3d( 160 | channels=model_channels // 4, 161 | resolution=resolution * 2, 162 | out_channels=model_channels // 8, 163 | ) 164 | ]) 165 | self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) 166 | 167 | self.initialize_weights() 168 | if use_fp16: 169 | self.convert_to_fp16() 170 | 171 | def initialize_weights(self) -> None: 172 | super().initialize_weights() 173 | # Zero-out output layers: 174 | nn.init.constant_(self.out_layer.weight, 0) 175 | nn.init.constant_(self.out_layer.bias, 0) 176 | for module in self.upsample: 177 | nn.init.constant_(module.pruning_head.mlp[-1].weight, 0) 178 | nn.init.constant_(module.pruning_head.mlp[-1].bias, 0) 179 | def convert_to_fp16(self) -> None: 180 | """ 181 | Convert the torso of the model to float16. 182 | """ 183 | super().convert_to_fp16() 184 | self.upsample.apply(convert_module_to_f16) 185 | 186 | def convert_to_fp32(self) -> None: 187 | """ 188 | Convert the torso of the model to float32. 189 | """ 190 | super().convert_to_fp32() 191 | self.upsample.apply(convert_module_to_f32) 192 | 193 | def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: 194 | """ 195 | Convert a batch of network outputs to 3D representations. 196 | 197 | Args: 198 | x: The [N x * x C] sparse tensor output by the network. 199 | 200 | Returns: 201 | list of representations 202 | """ 203 | ret = [] 204 | for i in range(x.shape[0]): 205 | mesh = self.mesh_extractor(x[i].float(), training=self.training) 206 | ret.append(mesh) 207 | 208 | return ret 209 | 210 | @torch.no_grad() 211 | def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4, verbose=False): 212 | 213 | sub_resolution = self.resolution // chunk_size 214 | upsample_ratio = 4 # hard-coded here 215 | assert sub_resolution % padding == 0 216 | out = [] 217 | if verbose: 218 | print(f"Input coords range: x[{x.coords[:, 1].min()}, {x.coords[:, 1].max()}], " 219 | f"y[{x.coords[:, 2].min()}, {x.coords[:, 2].max()}], " 220 | f"z[{x.coords[:, 3].min()}, {x.coords[:, 3].max()}]") 221 | print(f"Resolution: {self.resolution}, sub_resolution: {sub_resolution}") 222 | 223 | for i in range(chunk_size): 224 | for j in range(chunk_size): 225 | for k in range(chunk_size): 226 | # Calculate padded boundaries 227 | start_x = max(0, i * sub_resolution - padding) 228 | end_x = min((i + 1) * sub_resolution + padding, self.resolution) 229 | start_y = max(0, j * sub_resolution - padding) 230 | end_y = min((j + 1) * sub_resolution + padding, self.resolution) 231 | start_z = max(0, k * sub_resolution - padding) 232 | end_z = min((k + 1) * sub_resolution + padding, self.resolution) 233 | 234 | # Store original (unpadded) boundaries for later cropping 235 | orig_start_x = i * sub_resolution 236 | orig_end_x = (i + 1) * sub_resolution 237 | orig_start_y = j * sub_resolution 238 | orig_end_y = (j + 1) * sub_resolution 239 | orig_start_z = k * sub_resolution 240 | orig_end_z = (k + 1) * sub_resolution 241 | 242 | if verbose: 243 | print(f"\nChunk ({i},{j},{k}):") 244 | print(f"Padded bounds: x[{start_x}, {end_x}], y[{start_y}, {end_y}], z[{start_z}, {end_z}]") 245 | print(f"Original bounds: x[{orig_start_x}, {orig_end_x}], y[{orig_start_y}, {orig_end_y}], z[{orig_start_z}, {orig_end_z}]") 246 | 247 | mask = torch.logical_and( 248 | torch.logical_and( 249 | torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x), 250 | torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < 
end_y) 251 | ), 252 | torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z) 253 | ) 254 | 255 | if mask.sum() > 0: 256 | # Get the coordinates and shift them to local space 257 | coords = x.coords[mask].clone() 258 | if verbose: 259 | print(f"Before local shift - coords range: x[{coords[:, 1].min()}, {coords[:, 1].max()}], " 260 | f"y[{coords[:, 2].min()}, {coords[:, 2].max()}], " 261 | f"z[{coords[:, 3].min()}, {coords[:, 3].max()}]") 262 | 263 | # Shift to local coordinates 264 | coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z], 265 | device=coords.device).view(1, 3) 266 | if verbose: 267 | print(f"After local shift - coords range: x[{coords[:, 1].min()}, {coords[:, 1].max()}], " 268 | f"y[{coords[:, 2].min()}, {coords[:, 2].max()}], " 269 | f"z[{coords[:, 3].min()}, {coords[:, 3].max()}]") 270 | 271 | chunk_tensor = sp.SparseTensor(x.feats[mask], coords) 272 | # Store the boundaries and offsets as metadata for later reconstruction 273 | chunk_tensor.bounds = { 274 | 'padded': (start_x * upsample_ratio, end_x * upsample_ratio + (upsample_ratio - 1), start_y * upsample_ratio, end_y * upsample_ratio + (upsample_ratio - 1), start_z * upsample_ratio, end_z * upsample_ratio + (upsample_ratio - 1)), 275 | 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)), 276 | 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction 277 | } 278 | out.append(chunk_tensor) 279 | 280 | del mask 281 | torch.cuda.empty_cache() 282 | return out 283 | 284 | @torch.no_grad() 285 | def upsamples(self, chunk: sp.SparseTensor, pruning=False): # Only used at inference 286 | dtype = chunk.dtype 287 | for block in self.upsample: 288 | chunk, _ = block(chunk, pruning=pruning, training=False) 289 | chunk = chunk.type(dtype) 290 | chunk = self.out_layer(chunk) 291 | return chunk 292 | 293 | def forward(self, x: sp.SparseTensor, pruning=False): 294 | batch_size = x.shape[0] 295 | chunk_size = 8 # hard-coded to balance memory usage and reconstruction speed during inference 296 | if x.coords.shape[0] < 150000: 297 | chunk_size = 4 298 | 299 | h = super().forward(x) 300 | chunks = self.split_for_meshing(h, chunk_size=chunk_size, verbose=False) 301 | all_coords, all_feats = [], [] 302 | 303 | for chunk_idx, chunk in enumerate(chunks): 304 | try: 305 | chunk_result = self.upsamples(chunk, pruning=pruning) 306 | except Exception as e: 307 | print(f"Failed to process chunk {chunk_idx}: {e}") 308 | continue 309 | 310 | for b in range(batch_size): 311 | mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1) 312 | if mask.numel() > 0: 313 | coords = chunk_result.coords[mask].clone() 314 | 315 | # Restore global coordinates 316 | offsets = torch.tensor(chunk.bounds['offsets'], 317 | device=coords.device).view(1, 3) 318 | coords[:, 1:] = coords[:, 1:] + offsets 319 | 320 | # Filter points within original bounds 321 | bounds = chunk.bounds['original'] 322 | within_bounds = torch.logical_and( 323 | torch.logical_and( 324 | torch.logical_and( 325 | coords[:, 1] >= bounds[0], 326 | coords[:, 1] < bounds[1] 327 | ), 328 | torch.logical_and( 329 | coords[:, 2] >= bounds[2], 330 | coords[:, 2] < bounds[3] 331 | ) 332 | ), 333 | torch.logical_and( 334 | coords[:, 3] >= bounds[4], 335 | coords[:, 3] < bounds[5] 336 | ) 337 | ) 338 | 339 | if 
within_bounds.any(): 340 | all_coords.append(coords[within_bounds]) 341 | all_feats.append(chunk_result.feats[mask][within_bounds]) 342 | 343 | torch.cuda.empty_cache() 344 | 345 | if len(all_coords) > 0: 346 | final_coords = torch.cat(all_coords) 347 | final_feats = torch.cat(all_feats) 348 | 349 | return self.to_representation(sp.SparseTensor(final_feats, final_coords)) 350 | else: 351 | return self.to_representation(sp.SparseTensor(x.feats[:0], x.coords[:0])) --------------------------------------------------------------------------------
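For readers skimming the decoder, the chunked-meshing idea in split_for_meshing/forward can be summarized independently of the SparseTensor machinery: voxels are grouped into chunk_size^3 sub-volumes with a small halo of padding so each chunk sees its neighbors, every chunk is processed in chunk-local coordinates, and after shifting back to global coordinates only voxels inside the unpadded bounds are kept, so the overlapping halos are never duplicated. The sketch below is a simplified illustration on plain tensors, with no batch dimension and no 4x upsampling factor; all names are illustrative and not part of the repository.

import torch

def split_process_merge(coords, feats, resolution=256, chunk_size=4, padding=4,
                        process=lambda c, f: (c, f)):
    """coords: [T, 3] integer voxel coords in [0, resolution); feats: [T, C]."""
    sub = resolution // chunk_size
    out_coords, out_feats = [], []
    for i in range(chunk_size):
        for j in range(chunk_size):
            for k in range(chunk_size):
                lo = torch.tensor([i, j, k], device=coords.device) * sub  # unpadded lower corner
                start = (lo - padding).clamp(min=0)                       # padded lower bound
                end = torch.clamp(lo + sub + padding, max=resolution)     # padded upper bound
                mask = ((coords >= start) & (coords < end)).all(dim=-1)
                if not mask.any():
                    continue
                local = coords[mask] - start                # shift into chunk-local space
                c, f = process(local, feats[mask])          # e.g. per-chunk upsampling / pruning
                c = c + start                               # restore global coordinates
                keep = ((c >= lo) & (c < lo + sub)).all(dim=-1)  # drop the padded halo
                out_coords.append(c[keep])
                out_feats.append(f[keep])
    if not out_coords:
        return coords[:0], feats[:0]
    return torch.cat(out_coords), torch.cat(out_feats)

In the real decoder the kept bounds and offsets are additionally scaled by the hard-coded upsample ratio of 4 before cropping, because each SparseSubdivideBlock3d doubles the resolution; the sketch omits that detail to keep the halo-and-crop logic visible.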