├── assets
│   └── docs
│       ├── teaser.png
│       └── local_gradio_example.png
├── requirements.txt
├── NOTICE
├── misc.py
├── triposf
│   ├── modules
│   │   ├── norm.py
│   │   ├── transformer
│   │   │   ├── __init__.py
│   │   │   ├── modulated.py
│   │   │   └── blocks.py
│   │   ├── sparse
│   │   │   ├── transformer
│   │   │   │   ├── __init__.py
│   │   │   │   ├── blocks.py
│   │   │   │   └── modulated.py
│   │   │   ├── attention
│   │   │   │   ├── __init__.py
│   │   │   │   ├── windowed_attn.py
│   │   │   │   ├── modules.py
│   │   │   │   ├── serialized_attn.py
│   │   │   │   └── full_attn.py
│   │   │   ├── linear.py
│   │   │   ├── conv
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conv_torchsparse.py
│   │   │   │   └── conv_spconv.py
│   │   │   ├── nonlinearity.py
│   │   │   ├── norm.py
│   │   │   ├── __init__.py
│   │   │   └── spatial.py
│   │   ├── spatial.py
│   │   ├── attention
│   │   │   ├── __init__.py
│   │   │   ├── full_attn.py
│   │   │   └── modules.py
│   │   ├── utils.py
│   │   └── pointclouds
│   │       └── pointnet.py
│   ├── representations
│   │   ├── __init__.py
│   │   └── mesh
│   │       ├── __init__.py
│   │       ├── utils_cube.py
│   │       ├── flexicubes
│   │       │   └── LICENSE.txt
│   │       └── cube2mesh.py
│   ├── models
│   │   ├── triposf_vae
│   │   │   ├── __init__.py
│   │   │   ├── encoder.py
│   │   │   ├── base.py
│   │   │   └── decoder.py
│   │   └── __init__.py
│   └── __init__.py
├── LICENSE
├── configs
│   └── TripoSFVAE_1024.yaml
├── README.md
├── inference.py
└── app.py
/assets/docs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VAST-AI-Research/TripoSF/HEAD/assets/docs/teaser.png
--------------------------------------------------------------------------------
/assets/docs/local_gradio_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VAST-AI-Research/TripoSF/HEAD/assets/docs/local_gradio_example.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | trimesh==4.5.3
2 | torch-scatter==2.1.2
3 | open3d==0.18.0
4 | numpy==1.22.2
5 | omegaconf==2.3.0
6 | flash-attn==2.5.9.post1
7 | spconv
8 | safetensors
9 | easydict
10 | jaxtyping
11 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | TripoSF
2 | Copyright (c) 2025 VAST-AI-Research and contributors
3 |
4 | This project includes code from the following open source projects:
5 |
6 | TRELLIS
7 | Copyright (c) Microsoft Corporation
8 | License: MIT
9 | Source: https://github.com/Microsoft/TRELLIS
10 |
11 | FlexiCubes
12 | Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES
13 | License: Nvidia Source Code License
14 | Source: https://github.com/MaxtirError/FlexiCubes
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import importlib
4 |
5 | def find(cls_string):
6 | module_string = ".".join(cls_string.split(".")[:-1])
7 | cls_name = cls_string.split(".")[-1]
8 | module = importlib.import_module(module_string, package=None)
9 | cls = getattr(module, cls_name)
10 | return cls
11 |
12 | def get_rank():
13 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
14 | # therefore LOCAL_RANK needs to be checked first
15 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
16 | for key in rank_keys:
17 | rank = os.environ.get(key)
18 | if rank is not None:
19 | return int(rank)
20 | return 0
21 |
22 | def get_device():
23 | return torch.device(f"cuda:{get_rank()}")
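24 | 
25 | # Usage sketch: resolve a class from its dotted path and pick the local device.
26 | # The class path below mirrors encoder_cls in configs/TripoSFVAE_1024.yaml.
27 | #   encoder_cls = find("triposf.models.triposf_vae.encoder.TripoSFVAEEncoder")
28 | #   device = get_device()  # cuda:<rank>, with the rank read from RANK/LOCAL_RANK/...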
--------------------------------------------------------------------------------
/triposf/modules/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class LayerNorm32(nn.LayerNorm):
6 | def forward(self, x: torch.Tensor) -> torch.Tensor:
7 | return super().forward(x.float()).type(x.dtype)
8 |
9 |
10 | class GroupNorm32(nn.GroupNorm):
11 | """
12 | A GroupNorm layer that converts to float32 before the forward pass.
13 | """
14 | def forward(self, x: torch.Tensor) -> torch.Tensor:
15 | return super().forward(x.float()).type(x.dtype)
16 |
17 |
18 | class ChannelLayerNorm32(LayerNorm32):
19 | def forward(self, x: torch.Tensor) -> torch.Tensor:
20 | DIM = x.dim()
21 | x = x.permute(0, *range(2, DIM), 1).contiguous()
22 | x = super().forward(x)
23 | x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
24 | return x
25 |
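26 | # Usage sketch: ChannelLayerNorm32 normalizes an (N, C, *spatial) tensor over its
27 | # channel dimension by temporarily moving channels to the last axis, e.g.
28 | #   norm = ChannelLayerNorm32(16)
29 | #   y = norm(torch.randn(2, 16, 8, 8, 8))  # output keeps the (2, 16, 8, 8, 8) shape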
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2025 VAST-AI-Research and contributors.
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE
22 |
--------------------------------------------------------------------------------
/configs/TripoSFVAE_1024.yaml:
--------------------------------------------------------------------------------
1 | weight: ckpts/pretrained_TripoSFVAE_256i1024o.safetensors
2 |
3 | resolution: 256 # 1024 // 4
4 | sample_points_num: 819_200
5 | pruning: false
6 | use_normals: true
7 |
8 | local_pc_encoder_cls: triposf.modules.pointclouds.pointnet.LocalPoolPointnet
9 | local_pc_encoder:
10 | in_channels: 6 # 3 + 3
11 | out_channels: 1024
12 | hidden_dim: 256
13 | scatter_type: mean
14 | n_blocks: 5
15 |
16 | encoder_cls: triposf.models.triposf_vae.encoder.TripoSFVAEEncoder
17 | encoder:
18 | resolution: ${resolution}
19 | in_channels: 1024
20 | model_channels: 768
21 | latent_channels: 8
22 | num_blocks: 12
23 | num_heads: 12
24 | num_head_channels: 64
25 | mlp_ratio: 4
26 | attn_mode: swin
27 | window_size: 8
28 | pe_mode: ape
29 | use_fp16: false
30 | use_checkpoint: true
31 | qk_rms_norm: false
32 |
33 | decoder_cls: triposf.models.triposf_vae.decoder.TripoSFVAEDecoder
34 | decoder:
35 | resolution: ${resolution}
36 | model_channels: 768
37 | latent_channels: 8
38 | num_blocks: 12
39 | num_heads: 12
40 | num_head_channels: 64
41 | mlp_ratio: 4
42 | attn_mode: swin
43 | window_size: 8
44 | pe_mode: ape
45 | use_fp16: false
46 | use_checkpoint: true
47 | qk_rms_norm: false
48 | representation_config:
49 | use_color: false
50 |
--------------------------------------------------------------------------------
/triposf/representations/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from .mesh import MeshExtractResult
25 |
--------------------------------------------------------------------------------
/triposf/models/triposf_vae/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2025 VAST-AI-Research and contributors.
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE
22 |
23 | from .encoder import TripoSFVAEEncoder
24 | from .decoder import TripoSFVAEDecoder
25 |
--------------------------------------------------------------------------------
/triposf/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from .blocks import *
25 | from .modulated import *
--------------------------------------------------------------------------------
/triposf/modules/sparse/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from .blocks import *
25 | from .modulated import *
--------------------------------------------------------------------------------
/triposf/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from . import models
25 | from . import modules
26 | from . import representations
--------------------------------------------------------------------------------
/triposf/representations/mesh/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult
25 |
--------------------------------------------------------------------------------
/triposf/modules/sparse/attention/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from .full_attn import *
25 | from .serialized_attn import *
26 | from .windowed_attn import *
27 | from .modules import *
28 |
--------------------------------------------------------------------------------
/triposf/modules/sparse/linear.py:
--------------------------------------------------------------------------------
1 |
2 | # MIT License
3 |
4 | # Copyright (c) Microsoft Corporation.
5 | # Copyright (c) 2025 VAST-AI-Research and contributors.
6 |
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 |
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 |
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE
24 |
25 | import torch
26 | import torch.nn as nn
27 | from . import SparseTensor
28 |
29 | __all__ = [
30 | 'SparseLinear'
31 | ]
32 |
33 | class SparseLinear(nn.Linear):
34 | def __init__(self, in_features, out_features, bias=True):
35 | super(SparseLinear, self).__init__(in_features, out_features, bias)
36 |
37 | def forward(self, input: SparseTensor) -> SparseTensor:
38 | return input.replace(super().forward(input.feats))
39 |
--------------------------------------------------------------------------------
/triposf/modules/sparse/conv/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from .. import BACKEND
25 |
26 |
27 | SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
28 |
29 | def __from_env():
30 | import os
31 |
32 | global SPCONV_ALGO
33 | env_spconv_algo = os.environ.get('SPCONV_ALGO')
34 | if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
35 | SPCONV_ALGO = env_spconv_algo
36 | print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
37 |
38 |
39 | __from_env()
40 |
41 | if BACKEND == 'torchsparse':
42 | from .conv_torchsparse import *
43 | elif BACKEND == 'spconv':
44 | from .conv_spconv import *
45 |
--------------------------------------------------------------------------------
/triposf/modules/spatial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
5 | """
6 | 3D pixel shuffle.
7 | """
8 | B, C, H, W, D = x.shape
9 | C_ = C // scale_factor**3
10 | x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
11 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
12 | x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
13 | return x
14 |
15 |
16 | def patchify(x: torch.Tensor, patch_size: int):
17 | """
18 | Patchify a tensor.
19 |
20 | Args:
21 | x (torch.Tensor): (N, C, *spatial) tensor
22 | patch_size (int): Patch size
23 | """
24 | DIM = x.dim() - 2
25 | for d in range(2, DIM + 2):
26 | assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
27 |
28 | x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
29 | x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
30 | x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
31 | return x
32 |
33 |
34 | def unpatchify(x: torch.Tensor, patch_size: int):
35 | """
36 | Unpatchify a tensor.
37 |
38 | Args:
39 | x (torch.Tensor): (N, C, *spatial) tensor
40 | patch_size (int): Patch size
41 | """
42 | DIM = x.dim() - 2
43 | assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
44 |
45 | x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
46 | x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
47 | x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
48 | return x
49 |
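50 | 
51 | # Usage sketch: patchify/unpatchify are inverse reshapes over the spatial dims.
52 | #   x = torch.randn(1, 4, 16, 16, 16)
53 | #   p = patchify(x, 2)                       # -> (1, 4 * 2**3, 8, 8, 8)
54 | #   assert torch.equal(unpatchify(p, 2), x)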
--------------------------------------------------------------------------------
/triposf/models/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | import importlib
25 |
26 | __attributes = {
27 | 'TripoSFVAEEncoder': 'triposf_vae',
28 | 'TripoSFVAEDecoder': 'triposf_vae',
29 | }
30 |
31 | __submodules = []
32 |
33 | __all__ = list(__attributes.keys()) + __submodules
34 |
35 | def __getattr__(name):
36 | if name not in globals():
37 | if name in __attributes:
38 | module_name = __attributes[name]
39 | module = importlib.import_module(f".{module_name}", __name__)
40 | globals()[name] = getattr(module, name)
41 | elif name in __submodules:
42 | module = importlib.import_module(f".{name}", __name__)
43 | globals()[name] = module
44 | else:
45 | raise AttributeError(f"module {__name__} has no attribute {name}")
46 | return globals()[name]
47 |
48 | # For Pylance
49 | if __name__ == '__main__':
50 |         from .triposf_vae import TripoSFVAEEncoder, TripoSFVAEDecoder
51 |
--------------------------------------------------------------------------------
/triposf/modules/attention/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 |
26 | BACKEND = 'flash_attn'
27 | DEBUG = False
28 |
29 | def __from_env():
30 | import os
31 |
32 | global BACKEND
33 | global DEBUG
34 |
35 |     env_attn_backend = os.environ.get('ATTN_BACKEND')
36 |     env_attn_debug = os.environ.get('ATTN_DEBUG')
37 | 
38 |     if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
39 |         BACKEND = env_attn_backend
40 |     if env_attn_debug is not None:
41 |         DEBUG = env_attn_debug == '1'
42 |
43 | print(f"[ATTENTION] Using backend: {BACKEND}")
44 |
45 |
46 | __from_env()
47 |
48 |
49 | def set_backend(backend: Literal['xformers', 'flash_attn']):
50 | global BACKEND
51 | BACKEND = backend
52 |
53 | def set_debug(debug: bool):
54 | global DEBUG
55 | DEBUG = debug
56 |
57 |
58 | from .full_attn import *
59 | from .modules import *
60 |
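61 | # The backend can also be chosen via environment variables before import, e.g.
62 | #   ATTN_BACKEND=sdpa ATTN_DEBUG=1 python inference.py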
--------------------------------------------------------------------------------
/triposf/modules/sparse/nonlinearity.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | import torch
25 | import torch.nn as nn
26 | from . import SparseTensor
27 |
28 | __all__ = [
29 | 'SparseReLU',
30 | 'SparseSiLU',
31 | 'SparseGELU',
32 | 'SparseActivation'
33 | ]
34 |
35 |
36 | class SparseReLU(nn.ReLU):
37 | def forward(self, input: SparseTensor) -> SparseTensor:
38 | return input.replace(super().forward(input.feats))
39 |
40 |
41 | class SparseSiLU(nn.SiLU):
42 | def forward(self, input: SparseTensor) -> SparseTensor:
43 | return input.replace(super().forward(input.feats))
44 |
45 |
46 | class SparseGELU(nn.GELU):
47 | def forward(self, input: SparseTensor) -> SparseTensor:
48 | return input.replace(super().forward(input.feats))
49 |
50 |
51 | class SparseActivation(nn.Module):
52 | def __init__(self, activation: nn.Module):
53 | super().__init__()
54 | self.activation = activation
55 |
56 | def forward(self, input: SparseTensor) -> SparseTensor:
57 | return input.replace(self.activation(input.feats))
58 |
59 |
--------------------------------------------------------------------------------
/triposf/modules/sparse/conv/conv_torchsparse.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | import torch
25 | import torch.nn as nn
26 | from .. import SparseTensor
27 |
28 | class SparseConv3d(nn.Module):
29 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
30 | super(SparseConv3d, self).__init__()
31 | if 'torchsparse' not in globals():
32 | import torchsparse
33 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
34 |
35 | def forward(self, x: SparseTensor) -> SparseTensor:
36 | out = self.conv(x.data)
37 | new_shape = [x.shape[0], self.conv.out_channels]
38 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
39 | out._spatial_cache = x._spatial_cache
40 | out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
41 | return out
42 |
43 |
44 | class SparseInverseConv3d(nn.Module):
45 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
46 | super(SparseInverseConv3d, self).__init__()
47 | if 'torchsparse' not in globals():
48 | import torchsparse
49 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
50 |
51 | def forward(self, x: SparseTensor) -> SparseTensor:
52 | out = self.conv(x.data)
53 | new_shape = [x.shape[0], self.conv.out_channels]
54 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
55 | out._spatial_cache = x._spatial_cache
56 | out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
57 | return out
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/triposf/modules/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from typing import *
5 | from ..modules import sparse as sp
6 |
7 | FP16_MODULES = (
8 | nn.Conv1d,
9 | nn.Conv2d,
10 | nn.Conv3d,
11 | nn.ConvTranspose1d,
12 | nn.ConvTranspose2d,
13 | nn.ConvTranspose3d,
14 | nn.Linear,
15 | sp.SparseConv3d,
16 | sp.SparseInverseConv3d,
17 | sp.SparseLinear,
18 | )
19 |
20 | def convert_module_to_f16(l):
21 | """
22 | Convert primitive modules to float16.
23 | """
24 | if isinstance(l, FP16_MODULES):
25 | for p in l.parameters():
26 | p.data = p.data.half()
27 |
28 |
29 | def convert_module_to_f32(l):
30 | """
31 | Convert primitive modules to float32, undoing convert_module_to_f16().
32 | """
33 | if isinstance(l, FP16_MODULES):
34 | for p in l.parameters():
35 | p.data = p.data.float()
36 |
37 |
38 | def zero_module(module):
39 | """
40 | Zero out the parameters of a module and return it.
41 | """
42 | for p in module.parameters():
43 | p.detach().zero_()
44 | return module
45 |
46 |
47 | def scale_module(module, scale):
48 | """
49 | Scale the parameters of a module and return it.
50 | """
51 | for p in module.parameters():
52 | p.detach().mul_(scale)
53 | return module
54 |
55 |
56 | def modulate(x, shift, scale):
57 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
58 |
59 | class DiagonalGaussianDistribution(object):
60 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
61 | self.feat_dim = feat_dim
62 | self.parameters = parameters
63 |
64 | if isinstance(parameters, list):
65 | self.mean = parameters[0]
66 | self.logvar = parameters[1]
67 | else:
68 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
69 |
70 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
71 | self.deterministic = deterministic
72 | self.std = torch.exp(0.5 * self.logvar)
73 | self.var = torch.exp(self.logvar)
74 | if self.deterministic:
75 | self.var = self.std = torch.zeros_like(self.mean)
76 |
77 | def sample(self):
78 | x = self.mean + self.std * torch.randn_like(self.mean)
79 | return x
80 |
81 | def kl(self, other=None, dims=(1, 2, 3)):
82 | if self.deterministic:
83 | return torch.Tensor([0.])
84 | else:
85 | if other is None:
86 | return 0.5 * torch.mean(torch.pow(self.mean, 2)
87 | + self.var - 1.0 - self.logvar,
88 | dim=dims)
89 | else:
90 | return 0.5 * torch.mean(
91 | torch.pow(self.mean - other.mean, 2) / other.var
92 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
93 | dim=dims)
94 |
95 | def nll(self, sample, dims=(1, 2, 3)):
96 | if self.deterministic:
97 | return torch.Tensor([0.])
98 | logtwopi = np.log(2.0 * np.pi)
99 | return 0.5 * torch.sum(
100 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
101 | dim=dims)
102 |
103 | def mode(self):
104 | return self.mean
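105 | 
106 | # Usage sketch: split 2*C latent channels into a diagonal Gaussian posterior,
107 | # then draw a latent sample and compute the KL term against a standard normal.
108 | #   h = torch.randn(4, 2048, 16)                              # (..., 2 * latent_channels)
109 | #   posterior = DiagonalGaussianDistribution(h, feat_dim=-1)
110 | #   z = posterior.sample()                                    # (4, 2048, 8)
111 | #   kl = posterior.kl(dims=(1, 2))                            # one value per batch item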
--------------------------------------------------------------------------------
/triposf/modules/sparse/norm.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | import torch
25 | import torch.nn as nn
26 | from . import SparseTensor
27 | from . import DEBUG
28 |
29 | __all__ = [
30 | 'SparseGroupNorm',
31 | 'SparseLayerNorm',
32 | 'SparseGroupNorm32',
33 | 'SparseLayerNorm32',
34 | ]
35 |
36 |
37 | class SparseGroupNorm(nn.GroupNorm):
38 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
39 | super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
40 |
41 | def forward(self, input: SparseTensor) -> SparseTensor:
42 | nfeats = torch.zeros_like(input.feats)
43 | for k in range(input.shape[0]):
44 | if DEBUG:
45 | assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
46 | bfeats = input.feats[input.layout[k]]
47 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
48 | bfeats = super().forward(bfeats)
49 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
50 | nfeats[input.layout[k]] = bfeats
51 | return input.replace(nfeats)
52 |
53 |
54 | class SparseLayerNorm(nn.LayerNorm):
55 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
56 | super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
57 |
58 | def forward(self, input: SparseTensor) -> SparseTensor:
59 | nfeats = torch.zeros_like(input.feats)
60 | for k in range(input.shape[0]):
61 | bfeats = input.feats[input.layout[k]]
62 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
63 | bfeats = super().forward(bfeats)
64 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
65 | nfeats[input.layout[k]] = bfeats
66 | return input.replace(nfeats)
67 |
68 |
69 | class SparseGroupNorm32(SparseGroupNorm):
70 | """
71 | A GroupNorm layer that converts to float32 before the forward pass.
72 | """
73 | def forward(self, x: SparseTensor) -> SparseTensor:
74 | return super().forward(x.float()).type(x.dtype)
75 |
76 | class SparseLayerNorm32(SparseLayerNorm):
77 | """
78 | A LayerNorm layer that converts to float32 before the forward pass.
79 | """
80 | def forward(self, x: SparseTensor) -> SparseTensor:
81 | return super().forward(x.float()).type(x.dtype)
82 |
--------------------------------------------------------------------------------
/triposf/models/triposf_vae/encoder.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | from ...modules import sparse as sp
28 | from .base import SparseTransformerBase
29 | from ...modules.utils import DiagonalGaussianDistribution
30 |
31 | class TripoSFVAEEncoder(SparseTransformerBase):
32 | def __init__(
33 | self,
34 | resolution: int,
35 | in_channels: int,
36 | model_channels: int,
37 | latent_channels: int,
38 | num_blocks: int,
39 | num_heads: Optional[int] = None,
40 | num_head_channels: Optional[int] = 64,
41 | mlp_ratio: float = 4,
42 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
43 | window_size: int = 8,
44 | pe_mode: Literal["ape", "rope"] = "ape",
45 | use_fp16: bool = False,
46 | use_checkpoint: bool = False,
47 | qk_rms_norm: bool = False
48 | ):
49 | super().__init__(
50 | in_channels=in_channels,
51 | model_channels=model_channels,
52 | num_blocks=num_blocks,
53 | num_heads=num_heads,
54 | num_head_channels=num_head_channels,
55 | mlp_ratio=mlp_ratio,
56 | attn_mode=attn_mode,
57 | window_size=window_size,
58 | pe_mode=pe_mode,
59 | use_fp16=use_fp16,
60 | use_checkpoint=use_checkpoint,
61 | qk_rms_norm=qk_rms_norm,
62 | )
63 | self.resolution = resolution
64 | self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
65 | self.initialize_weights()
66 | # if use_fp16:
67 | # self.convert_to_fp16()
68 |
69 | def initialize_weights(self) -> None:
70 | super().initialize_weights()
71 | # Zero-out output layers:
72 | nn.init.constant_(self.out_layer.weight, 0)
73 | nn.init.constant_(self.out_layer.bias, 0)
74 |
75 | def forward(self, x: sp.SparseTensor, sample_posterior=True):
76 | h = super().forward(x)
77 | h = h.type(x.dtype)
78 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
79 | h = self.out_layer(h)
80 |
81 | posterior = DiagonalGaussianDistribution(h.feats, feat_dim=-1)
82 |
83 | if sample_posterior:
84 | z = posterior.sample()
85 | else:
86 | z = posterior.mode()
87 |
88 | z = h.replace(z)
89 | return z, posterior
90 |
--------------------------------------------------------------------------------
/triposf/representations/mesh/utils_cube.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | import torch
25 | cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
26 | 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int)
27 | cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]])
28 | cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
29 | 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False)
30 |
31 | def construct_dense_grid(res, device='cuda'):
32 | '''construct a dense grid based on resolution'''
33 | res_v = res + 1
34 | vertsid = torch.arange(res_v ** 3, device=device)
35 | coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten()
36 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2]
37 | cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device))
38 | verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1)
39 | return verts, cube_fx8
40 |
41 |
42 | def construct_voxel_grid(coords):
43 | verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3)
44 | verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True)
45 | cubes = inverse_indices.reshape(-1, 8)
46 | return verts_unique, cubes
47 |
48 |
49 | def cubes_to_verts(num_verts, cubes, value, reduce='mean'):
50 | """
51 | Args:
52 | cubes [Vx8] verts index for each cube
53 | value [Vx8xM] value to be scattered
54 | Operation:
55 | reduced[cubes[i][j]][k] += value[i][k]
56 | """
57 | M = value.shape[2] # number of channels
58 | reduced = torch.zeros(num_verts, M, device=cubes.device,dtype=value.dtype)
59 | return torch.scatter_reduce(reduced, 0,
60 | cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1),
61 | value.flatten(0, 1), reduce=reduce, include_self=False)
62 |
63 | def sparse_cube2verts(coords, feats, training=True):
64 | new_coords, cubes = construct_voxel_grid(coords)
65 | new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats)
66 | if training:
67 | con_loss = torch.mean((feats - new_feats[cubes]) ** 2)
68 | else:
69 | con_loss = 0.0
70 | return new_coords, new_feats, con_loss
71 |
72 | def get_sparse_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True):
73 | verts = coords
74 | verts, masks = torch.unique(verts, dim=0, return_inverse=True)
75 | feats_sparse = torch.zeros((len(verts), feats.shape[-1]), device=feats.device, dtype=feats.dtype)
76 | feats_sparse[masks] = feats
77 | return feats_sparse, verts
78 |
79 | def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res):
80 | return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform)
81 |
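82 | # Usage sketch: scatter per-cube corner features onto the shared voxel-grid vertices.
83 | #   coords = torch.tensor([[0, 0, 0], [1, 0, 0]])           # two face-adjacent voxels
84 | #   verts, cubes = construct_voxel_grid(coords)              # verts: (12, 3), cubes: (2, 8)
85 | #   feats = torch.rand(2, 8, 4)                              # one feature vector per corner
86 | #   vert_feats = cubes_to_verts(len(verts), cubes, feats)    # (12, 4), averaged at shared verts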
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TripoSF: High-Resolution and Arbitrary-Topology 3D Shape Modeling with SparseFlex
2 |
3 |
4 |
5 | [](https://XianglongHe.github.io/TripoSF/index.html)
6 | [](https://arxiv.org/abs/2503.21732)
7 | [](https://huggingface.co/VAST-AI/TripoSF)
8 |
9 | **By [Tripo](https://www.tripo3d.ai)**
10 |
11 |
12 |
13 | 
14 |
15 | ## 🌟 Overview
16 |
17 | TripoSF represents a significant leap forward in 3D shape modeling, combining high-resolution capabilities with arbitrary topology support. Our approach enables:
18 |
19 | - 📈 Ultra-high resolution mesh modeling (up to $1024^3$)
20 | - 🎯 Direct optimization from rendering losses
21 | - 🌐 Efficient handling of open surfaces and complex topologies
22 | - 💾 Dramatic memory reduction through sparse computation
23 | - 🔄 Differentiable mesh extraction with sharp features
24 |
25 | ### SparseFlex
26 |
27 | SparseFlex, the core design powering TripoSF, introduces a sparse voxel structure that:
28 | - Focuses computational resources only on surface-adjacent regions
29 | - Enables natural handling of open surfaces (like cloth or leaves)
30 | - Supports complex internal structures without compromises
31 | - Achieves massive memory reduction compared to dense representations
32 |
33 | ## 🔥 Updates
34 |
35 | * [2025-03] Initial Release:
36 | - Pretrained VAE model weights ($1024^3$ reconstruction)
37 | - Inference scripts and examples
38 | - SparseFlex implementation
39 |
40 | ## 🚀 Getting Started
41 |
42 | ### System Requirements
43 | - CUDA-capable GPU (≥12GB VRAM for $1024^3$ resolution)
44 | - PyTorch 2.0+
45 |
46 | ### Installation
47 |
48 | 1. Clone the repository:
49 | ```bash
50 | git clone https://github.com/VAST-AI-Research/TripoSF.git
51 | cd TripoSF
52 | ```
53 |
54 | 2. Install dependencies:
55 | ```bash
56 | # Install PyTorch (select the correct CUDA version)
57 | pip install torch torchvision --index-url https://download.pytorch.org/whl/{your-cuda-version}
58 |
59 | # Install other dependencies
60 | pip install -r requirements.txt
61 | ```
62 |
63 | ## 💫 Usage
64 |
65 | ### Pretrained Model Setup
66 | 1. Download our pretrained models from [Hugging Face](https://huggingface.co/VAST-AI/TripoSF)
67 | 2. Place the models in the `ckpts/` directory
68 |
69 | ### Running Inference
70 | Basic reconstruction using TripoSFVAE:
71 | ```bash
72 | python inference.py --mesh-path "assets/examples/jacket.obj" \
73 | --output-dir "outputs/" \
74 | --config "configs/TripoSFVAE_1024.yaml"
75 | ```
76 |
77 | ### Local Gradio Example
78 | ```bash
79 | python app.py
80 | ```
81 | 
82 |
83 | ### Optimization Tips 💡
84 |
85 | #### For Open Surfaces
86 | - Enable `pruning` in the configuration:
87 | ```yaml
88 | pruning: true
89 | ```
90 | - Benefits:
91 | - Higher-fidelity reconstruction
92 | - Faster processing
93 | - Better memory efficiency
94 |
95 | #### For Complex Shapes
96 | - Increase sampling density:
97 | ```yaml
98 | sample_points_num: 1638400 # Default: 819200
99 | ```
100 | - Adjust resolution based on detail requirements:
101 | ```yaml
102 | resolution: 1024 # Options: 256, 512, 1024
103 | ```
104 |
105 |
106 | ## 📊 Technical Details
107 |
108 | TripoSF VAE Architecture:
109 | - **Input**: Point clouds (preserving source geometry details)
110 | - **Encoder**: Sparse transformer for efficient geometry encoding
111 | - **Decoder**: Self-pruning upsampling modules maintaining sparsity
112 | - **Output**: High-resolution SparseFlex parameters for mesh extraction
113 |
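A minimal sketch of wiring these pieces together from the YAML config (the bundled `inference.py` is the full reference; this sketch assumes each class accepts its config section as keyword arguments):

```python
from omegaconf import OmegaConf

from misc import find

cfg = OmegaConf.load("configs/TripoSFVAE_1024.yaml")

# Resolve the classes named in the config and build them from their sections.
pc_encoder = find(cfg.local_pc_encoder_cls)(**cfg.local_pc_encoder)
encoder = find(cfg.encoder_cls)(**cfg.encoder)    # TripoSFVAEEncoder
decoder = find(cfg.decoder_cls)(**cfg.decoder)    # TripoSFVAEDecoder
```

The `weight` entry in the config points at the pretrained `safetensors` checkpoint (`ckpts/pretrained_TripoSFVAE_256i1024o.safetensors`), which `inference.py` loads into these modules before reconstruction.
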
114 | ## 📝 Citation
115 |
116 | ```bibtex
117 | @article{he2025triposf,
118 | title={SparseFlex: High-Resolution and Arbitrary-Topology 3D Shape Modeling},
119 | author={He, Xianglong and Zou, Zi-Xin and Chen, Chia-Hao and Guo, Yuan-Chen and Liang, Ding and Yuan, Chun and Ouyang, Wanli and Cao, Yan-Pei and Li, Yangguang},
120 | journal={arXiv preprint arXiv:2503.21732},
121 | year={2025}
122 | }
123 | ```
124 |
125 |
126 | ## 📚 Acknowledgements
127 |
128 | Our work builds upon these excellent repositories:
129 | - [Trellis](https://github.com/Microsoft/TRELLIS)
130 | - [Flexicubes](https://github.com/MaxtirError/FlexiCubes)
131 |
132 | ## 📄 License
133 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
134 |
--------------------------------------------------------------------------------
/triposf/modules/sparse/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 |
26 | BACKEND = 'spconv'
27 | DEBUG = False
28 | ATTN = 'flash_attn'
29 |
30 | def __from_env():
31 | import os
32 |
33 | global BACKEND
34 | global DEBUG
35 | global ATTN
36 |
37 | env_sparse_backend = os.environ.get('SPARSE_BACKEND')
38 | env_sparse_debug = os.environ.get('SPARSE_DEBUG')
39 | env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
40 | if env_sparse_attn is None:
41 | env_sparse_attn = os.environ.get('ATTN_BACKEND')
42 |
43 | if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
44 | BACKEND = env_sparse_backend
45 | if env_sparse_debug is not None:
46 | DEBUG = env_sparse_debug == '1'
47 | if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
48 | ATTN = env_sparse_attn
49 |
50 | print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
51 |
52 |
53 | __from_env()
54 |
55 |
56 | def set_backend(backend: Literal['spconv', 'torchsparse']):
57 | global BACKEND
58 | BACKEND = backend
59 |
60 | def set_debug(debug: bool):
61 | global DEBUG
62 | DEBUG = debug
63 |
64 | def set_attn(attn: Literal['xformers', 'flash_attn']):
65 | global ATTN
66 | ATTN = attn
67 |
68 |
69 | import importlib
70 |
71 | __attributes = {
72 | 'SparseTensor': 'basic',
73 | 'sparse_batch_broadcast': 'basic',
74 | 'sparse_batch_op': 'basic',
75 | 'sparse_cat': 'basic',
76 | 'sparse_unbind': 'basic',
77 | 'SparseGroupNorm': 'norm',
78 | 'SparseLayerNorm': 'norm',
79 | 'SparseGroupNorm32': 'norm',
80 | 'SparseLayerNorm32': 'norm',
81 | 'SparseReLU': 'nonlinearity',
82 | 'SparseSiLU': 'nonlinearity',
83 | 'SparseGELU': 'nonlinearity',
84 | 'SparseActivation': 'nonlinearity',
85 | 'SparseLinear': 'linear',
86 | 'sparse_scaled_dot_product_attention': 'attention',
87 | 'SerializeMode': 'attention',
88 | 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
89 | 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
90 | 'SparseMultiHeadAttention': 'attention',
91 | 'SparseConv3d': 'conv',
92 | 'SparseInverseConv3d': 'conv',
93 | 'SparseDownsample': 'spatial',
94 | 'SparseUpsample': 'spatial',
95 | 'SparseSubdivide' : 'spatial'
96 | }
97 |
98 | __submodules = ['transformer']
99 |
100 | __all__ = list(__attributes.keys()) + __submodules
101 |
102 | def __getattr__(name):
103 | if name not in globals():
104 | if name in __attributes:
105 | module_name = __attributes[name]
106 | module = importlib.import_module(f".{module_name}", __name__)
107 | globals()[name] = getattr(module, name)
108 | elif name in __submodules:
109 | module = importlib.import_module(f".{name}", __name__)
110 | globals()[name] = module
111 | else:
112 | raise AttributeError(f"module {__name__} has no attribute {name}")
113 | return globals()[name]
114 |
115 |
116 | # For Pylance
117 | if __name__ == '__main__':
118 | from .basic import *
119 | from .norm import *
120 | from .nonlinearity import *
121 | from .linear import *
122 | from .attention import *
123 | from .conv import *
124 | from .spatial import *
125 | import transformer
126 |
--------------------------------------------------------------------------------
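The sparse package above resolves its backend settings once at import time from the SPARSE_BACKEND, SPARSE_DEBUG, and SPARSE_ATTN_BACKEND (falling back to ATTN_BACKEND) environment variables, exposes set_backend/set_debug/set_attn for overriding them in code, and lazy-loads its public symbols through the module-level __getattr__. A minimal usage sketch, assuming the package is importable as triposf.modules.sparse (the exact import path is an assumption):

    import os

    # Environment variables only take effect if set before the first import,
    # because __from_env() runs at import time.
    os.environ["SPARSE_BACKEND"] = "spconv"            # or "torchsparse"
    os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"   # or "xformers"

    from triposf.modules import sparse as sp

    # Programmatic overrides after import.
    sp.set_backend("spconv")
    sp.set_attn("flash_attn")
    sp.set_debug(False)

    # Attribute access is lazy: touching sp.SparseTensor triggers the import
    # of the .basic submodule through __getattr__.
    print(sp.BACKEND, sp.ATTN)
    print(sp.SparseTensor)
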
/triposf/representations/mesh/flexicubes/LICENSE.txt:
--------------------------------------------------------------------------------
1 |
2 | Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |
4 |
5 | NVIDIA Source Code License for FlexiCubes
6 |
7 |
8 | =======================================================================
9 |
10 | 1. Definitions
11 |
12 | “Licensor” means any person or entity that distributes its Work.
13 |
14 | “Work” means (a) the original work of authorship made available under
15 | this license, which may include software, documentation, or other files,
16 | and (b) any additions to or derivative works thereof that are made
17 | available under this license.
18 |
19 | The terms “reproduce,” “reproduction,” “derivative works,” and
20 | “distribution” have the meaning as provided under U.S. copyright law;
21 | provided, however, that for the purposes of this license, derivative works
22 | shall not include works that remain separable from, or merely link
23 | (or bind by name) to the interfaces of, the Work.
24 |
25 | Works are “made available” under this license by including in or with
26 | the Work either (a) a copyright notice referencing the applicability of
27 | this license to the Work, or (b) a copy of this license.
28 |
29 | 2. License Grant
30 |
31 | 2.1 Copyright Grant. Subject to the terms and conditions of this license,
32 | each Licensor grants to you a perpetual, worldwide, non-exclusive,
33 | royalty-free, copyright license to use, reproduce, prepare derivative
34 | works of, publicly display, publicly perform, sublicense and distribute
35 | its Work and any resulting derivative works in any form.
36 |
37 | 3. Limitations
38 |
39 | 3.1 Redistribution. You may reproduce or distribute the Work only if
40 | (a) you do so under this license, (b) you include a complete copy of
41 | this license with your distribution, and (c) you retain without
42 | modification any copyright, patent, trademark, or attribution notices
43 | that are present in the Work.
44 |
45 | 3.2 Derivative Works. You may specify that additional or different terms
46 | apply to the use, reproduction, and distribution of your derivative
47 | works of the Work (“Your Terms”) only if (a) Your Terms provide that the
48 | use limitation in Section 3.3 applies to your derivative works, and (b)
49 | you identify the specific derivative works that are subject to Your Terms.
50 | Notwithstanding Your Terms, this license (including the redistribution
51 | requirements in Section 3.1) will continue to apply to the Work itself.
52 |
53 | 3.3 Use Limitation. The Work and any derivative works thereof only may be
54 | used or intended for use non-commercially. Notwithstanding the foregoing,
55 | NVIDIA Corporation and its affiliates may use the Work and any derivative
56 | works commercially. As used herein, “non-commercially” means for research
57 | or evaluation purposes only.
58 |
59 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against
60 | any Licensor (including any claim, cross-claim or counterclaim in a lawsuit)
61 | to enforce any patents that you allege are infringed by any Work, then your
62 | rights under this license from such Licensor (including the grant in
63 | Section 2.1) will terminate immediately.
64 |
65 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s
66 | or its affiliates’ names, logos, or trademarks, except as necessary to
67 | reproduce the notices described in this license.
68 |
69 | 3.6 Termination. If you violate any term of this license, then your rights
70 | under this license (including the grant in Section 2.1) will terminate
71 | immediately.
72 |
73 | 4. Disclaimer of Warranty.
74 |
75 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
76 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
77 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
78 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
79 |
80 | 5. Limitation of Liability.
81 |
82 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY,
83 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY
84 | LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL,
85 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE,
86 | THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF
87 | GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR
88 | MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
89 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
90 |
91 | =======================================================================
92 |
--------------------------------------------------------------------------------
/triposf/modules/sparse/conv/conv_spconv.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | import torch
25 | import torch.nn as nn
26 | from .. import SparseTensor
27 | from .. import DEBUG
28 | from . import SPCONV_ALGO
29 |
30 | class SparseConv3d(nn.Module):
31 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
32 | super(SparseConv3d, self).__init__()
33 | if 'spconv' not in globals():
34 | import spconv.pytorch as spconv
35 | algo = None
36 | if SPCONV_ALGO == 'native':
37 | algo = spconv.ConvAlgo.Native
38 | elif SPCONV_ALGO == 'implicit_gemm':
39 | algo = spconv.ConvAlgo.MaskImplicitGemm
40 | if stride == 1 and (padding is None):
41 | self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
42 | else:
43 | self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
44 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
45 | self.padding = padding
46 |
47 | def forward(self, x: SparseTensor) -> SparseTensor:
48 | spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
49 |
50 | dtype_ = x.feats.dtype
51 | x = x.replace(x.feats.type(torch.float32))
52 | new_data = self.conv(x.data)
53 | new_shape = [x.shape[0], self.conv.out_channels]
54 | new_layout = None if spatial_changed else x.layout
55 |
56 | if spatial_changed and (x.shape[0] != 1):
57 |             # spconv with a non-1 stride breaks the coordinate ordering of the output, so sort by batch index and cache the inverse permutation
58 | fwd = new_data.indices[:, 0].argsort()
59 | bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
60 | sorted_feats = new_data.features[fwd]
61 | sorted_coords = new_data.indices[fwd]
62 | unsorted_data = new_data
63 | new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
64 |
65 | out = SparseTensor(
66 | new_data, shape=torch.Size(new_shape), layout=new_layout,
67 | scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
68 | spatial_cache=x._spatial_cache,
69 | )
70 | out = out.replace(out.feats.type(dtype_))
71 |
72 | if spatial_changed and (x.shape[0] != 1):
73 | out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
74 | out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
75 |
76 | return out
77 |
78 | class SparseInverseConv3d(nn.Module):
79 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
80 | super(SparseInverseConv3d, self).__init__()
81 | if 'spconv' not in globals():
82 | import spconv.pytorch as spconv
83 | self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
84 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
85 |
86 | def forward(self, x: SparseTensor) -> SparseTensor:
87 | spatial_changed = any(s != 1 for s in self.stride)
88 | if spatial_changed:
89 | # recover the original spconv order
90 | data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
91 | bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
92 | data = data.replace_feature(x.feats[bwd])
93 | if DEBUG:
94 | assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
95 | else:
96 | data = x.data
97 |
98 | new_data = self.conv(data)
99 | new_shape = [x.shape[0], self.conv.out_channels]
100 | new_layout = None if spatial_changed else x.layout
101 | out = SparseTensor(
102 | new_data, shape=torch.Size(new_shape), layout=new_layout,
103 | scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
104 | spatial_cache=x._spatial_cache,
105 | )
106 | return out
107 |
--------------------------------------------------------------------------------
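When the stride changes the spatial layout, SparseConv3d sorts the spconv output by batch index and caches both the unsorted tensor and the inverse permutation so that SparseInverseConv3d can later restore spconv's original ordering. The fwd/bwd index pair is plain PyTorch; a self-contained sketch of the same trick on a toy tensor (all names here are illustrative, not part of the API):

    import torch

    batch_idx = torch.tensor([2, 0, 1, 0, 2])   # first column of spconv indices
    feats = torch.arange(5 * 3, dtype=torch.float32).reshape(5, 3)

    # Forward permutation: group rows by batch index.
    fwd = batch_idx.argsort()
    sorted_feats = feats[fwd]

    # Backward permutation satisfies bwd[fwd[i]] = i, so indexing with bwd
    # restores the original row order.
    bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0]))
    assert torch.equal(sorted_feats[bwd], feats)
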
/triposf/models/triposf_vae/base.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | from ...modules.utils import convert_module_to_f16, convert_module_to_f32
28 | from ...modules import sparse as sp
29 | from ...modules.transformer import AbsolutePositionEmbedder
30 | from ...modules.sparse.transformer import SparseTransformerBlock
31 |
32 |
33 | def block_attn_config(self):
34 | """
35 | Return the attention configuration of the model.
36 | """
37 | for i in range(self.num_blocks):
38 | if self.attn_mode == "shift_window":
39 | yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
40 | elif self.attn_mode == "shift_sequence":
41 | yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
42 | elif self.attn_mode == "shift_order":
43 | yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
44 | elif self.attn_mode == "full":
45 | yield "full", None, None, None, None
46 | elif self.attn_mode == "swin":
47 | yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
48 |
49 |
50 | class SparseTransformerBase(nn.Module):
51 | """
52 | Sparse Transformer without output layers.
53 | Serve as the base class for encoder and decoder.
54 | """
55 | def __init__(
56 | self,
57 | in_channels: int,
58 | model_channels: int,
59 | num_blocks: int,
60 | num_heads: Optional[int] = None,
61 | num_head_channels: Optional[int] = 64,
62 | mlp_ratio: float = 4.0,
63 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
64 | window_size: Optional[int] = None,
65 | pe_mode: Literal["ape", "rope"] = "ape",
66 | use_fp16: bool = False,
67 | use_checkpoint: bool = False,
68 | qk_rms_norm: bool = False,
69 | ):
70 | super().__init__()
71 | self.in_channels = in_channels
72 | self.model_channels = model_channels
73 | self.num_blocks = num_blocks
74 | self.window_size = window_size
75 | self.num_heads = num_heads or model_channels // num_head_channels
76 | self.mlp_ratio = mlp_ratio
77 | self.attn_mode = attn_mode
78 | self.pe_mode = pe_mode
79 | self.use_fp16 = use_fp16
80 | self.use_checkpoint = use_checkpoint
81 | self.qk_rms_norm = qk_rms_norm
82 | self.dtype = torch.float16 if use_fp16 else torch.float32
83 |
84 | if pe_mode == "ape":
85 | self.pos_embedder = AbsolutePositionEmbedder(model_channels)
86 |
87 | self.input_layer = sp.SparseLinear(in_channels, model_channels)
88 | self.blocks = nn.ModuleList([
89 | SparseTransformerBlock(
90 | model_channels,
91 | num_heads=self.num_heads,
92 | mlp_ratio=self.mlp_ratio,
93 | attn_mode=attn_mode,
94 | window_size=window_size,
95 | shift_sequence=shift_sequence,
96 | shift_window=shift_window,
97 | serialize_mode=serialize_mode,
98 | use_checkpoint=self.use_checkpoint,
99 | use_rope=(pe_mode == "rope"),
100 | qk_rms_norm=self.qk_rms_norm,
101 | )
102 | for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
103 | ])
104 |
105 | @property
106 | def device(self) -> torch.device:
107 | """
108 | Return the device of the model.
109 | """
110 | return next(self.parameters()).device
111 |
112 | def convert_to_fp16(self) -> None:
113 | """
114 | Convert the torso of the model to float16.
115 | """
116 | self.blocks.apply(convert_module_to_f16)
117 |
118 | def convert_to_fp32(self) -> None:
119 | """
120 | Convert the torso of the model to float32.
121 | """
122 | self.blocks.apply(convert_module_to_f32)
123 |
124 | def initialize_weights(self) -> None:
125 | # Initialize transformer layers:
126 | def _basic_init(module):
127 | if isinstance(module, nn.Linear):
128 | torch.nn.init.xavier_uniform_(module.weight)
129 | if module.bias is not None:
130 | nn.init.constant_(module.bias, 0)
131 | self.apply(_basic_init)
132 |
133 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
134 | h = self.input_layer(x)
135 | if self.pe_mode == "ape":
136 | h = h + self.pos_embedder(x.coords[:, 1:])
137 | # h = h.type(self.dtype)
138 | for block in self.blocks:
139 | h = block(h)
140 | return h
141 |
--------------------------------------------------------------------------------
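block_attn_config above yields one (attn_mode, window_size, shift_sequence, shift_window, serialize_mode) tuple per block, alternating the shift on every other block so that neighbouring windows overlap across depth. A standalone sketch of the "swin" schedule for a four-block model with window_size=8 (values are illustrative):

    def swin_schedule(num_blocks: int, window_size: int):
        # Mirrors the "swin" branch of block_attn_config: even blocks use an
        # unshifted window, odd blocks shift by half the window size.
        for i in range(num_blocks):
            yield "windowed", window_size, None, window_size // 2 * (i % 2), None

    for cfg in swin_schedule(4, 8):
        print(cfg)
    # ('windowed', 8, None, 0, None)
    # ('windowed', 8, None, 4, None)
    # ('windowed', 8, None, 0, None)
    # ('windowed', 8, None, 4, None)
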
/triposf/modules/attention/full_attn.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import math
27 | from . import DEBUG, BACKEND
28 |
29 | if BACKEND == 'xformers':
30 | import xformers.ops as xops
31 | elif BACKEND == 'flash_attn':
32 | import flash_attn
33 | elif BACKEND == 'sdpa':
34 | from torch.nn.functional import scaled_dot_product_attention as sdpa
35 | elif BACKEND == 'naive':
36 | pass
37 | else:
38 | raise ValueError(f"Unknown attention backend: {BACKEND}")
39 |
40 |
41 | __all__ = [
42 | 'scaled_dot_product_attention',
43 | ]
44 |
45 |
46 | def _naive_sdpa(q, k, v):
47 | """
48 | Naive implementation of scaled dot product attention.
49 | """
50 | q = q.permute(0, 2, 1, 3) # [N, H, L, C]
51 | k = k.permute(0, 2, 1, 3) # [N, H, L, C]
52 | v = v.permute(0, 2, 1, 3) # [N, H, L, C]
53 | scale_factor = 1 / math.sqrt(q.size(-1))
54 | attn_weight = q @ k.transpose(-2, -1) * scale_factor
55 | attn_weight = torch.softmax(attn_weight, dim=-1)
56 | out = attn_weight @ v
57 | out = out.permute(0, 2, 1, 3) # [N, L, H, C]
58 | return out
59 |
60 |
61 | @overload
62 | def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
63 | """
64 | Apply scaled dot product attention.
65 |
66 | Args:
67 | qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
68 | """
69 | ...
70 |
71 | @overload
72 | def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
73 | """
74 | Apply scaled dot product attention.
75 |
76 | Args:
77 | q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
78 | kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
79 | """
80 | ...
81 |
82 | @overload
83 | def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
84 | """
85 | Apply scaled dot product attention.
86 |
87 | Args:
88 | q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
89 | k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
90 | v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
91 |
92 | Note:
93 | k and v are assumed to have the same coordinate map.
94 | """
95 | ...
96 |
97 | def scaled_dot_product_attention(*args, **kwargs):
98 | arg_names_dict = {
99 | 1: ['qkv'],
100 | 2: ['q', 'kv'],
101 | 3: ['q', 'k', 'v']
102 | }
103 | num_all_args = len(args) + len(kwargs)
104 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
105 | for key in arg_names_dict[num_all_args][len(args):]:
106 | assert key in kwargs, f"Missing argument {key}"
107 |
108 | if num_all_args == 1:
109 | qkv = args[0] if len(args) > 0 else kwargs['qkv']
110 | assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
111 | device = qkv.device
112 |
113 | elif num_all_args == 2:
114 | q = args[0] if len(args) > 0 else kwargs['q']
115 | kv = args[1] if len(args) > 1 else kwargs['kv']
116 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
117 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
118 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
119 | device = q.device
120 |
121 | elif num_all_args == 3:
122 | q = args[0] if len(args) > 0 else kwargs['q']
123 | k = args[1] if len(args) > 1 else kwargs['k']
124 | v = args[2] if len(args) > 2 else kwargs['v']
125 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
126 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
127 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
128 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
129 | device = q.device
130 |
131 | if BACKEND == 'xformers':
132 | if num_all_args == 1:
133 | q, k, v = qkv.unbind(dim=2)
134 | elif num_all_args == 2:
135 | k, v = kv.unbind(dim=2)
136 | out = xops.memory_efficient_attention(q, k, v)
137 | elif BACKEND == 'flash_attn':
138 | if num_all_args == 1:
139 | out = flash_attn.flash_attn_qkvpacked_func(qkv)
140 | elif num_all_args == 2:
141 | out = flash_attn.flash_attn_kvpacked_func(q, kv)
142 | elif num_all_args == 3:
143 | out = flash_attn.flash_attn_func(q, k, v)
144 | elif BACKEND == 'sdpa':
145 | if num_all_args == 1:
146 | q, k, v = qkv.unbind(dim=2)
147 | elif num_all_args == 2:
148 | k, v = kv.unbind(dim=2)
149 | q = q.permute(0, 2, 1, 3) # [N, H, L, C]
150 | k = k.permute(0, 2, 1, 3) # [N, H, L, C]
151 | v = v.permute(0, 2, 1, 3) # [N, H, L, C]
152 | out = sdpa(q, k, v) # [N, H, L, C]
153 | out = out.permute(0, 2, 1, 3) # [N, L, H, C]
154 | elif BACKEND == 'naive':
155 | if num_all_args == 1:
156 | q, k, v = qkv.unbind(dim=2)
157 | elif num_all_args == 2:
158 | k, v = kv.unbind(dim=2)
159 | out = _naive_sdpa(q, k, v)
160 | else:
161 | raise ValueError(f"Unknown attention module: {BACKEND}")
162 |
163 | return out
164 |
--------------------------------------------------------------------------------
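The dispatcher accepts packed qkv, (q, kv), or separate (q, k, v) tensors in [N, L, H, C] layout and routes them to the configured backend; _naive_sdpa is the pure-PyTorch reference path. A quick sanity check, independent of flash-attn/xformers, that the naive formulation matches PyTorch's built-in scaled_dot_product_attention (this is a self-contained sketch, not part of the module):

    import math
    import torch
    import torch.nn.functional as F

    def naive_sdpa(q, k, v):
        # Same math as _naive_sdpa above, on [N, L, H, C] inputs.
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))   # -> [N, H, L, C]
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
        return (attn @ v).permute(0, 2, 1, 3)                  # -> [N, L, H, C]

    N, L, H, C = 2, 16, 4, 32
    q, k, v = (torch.randn(N, L, H, C) for _ in range(3))

    ref = F.scaled_dot_product_attention(
        q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
    ).permute(0, 2, 1, 3)

    assert torch.allclose(naive_sdpa(q, k, v), ref, atol=1e-4)
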
/triposf/modules/sparse/spatial.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | from . import SparseTensor
28 |
29 | __all__ = [
30 | 'SparseDownsample',
31 | 'SparseUpsample',
32 | 'SparseSubdivide'
33 | ]
34 |
35 |
36 | class SparseDownsample(nn.Module):
37 | """
38 | Downsample a sparse tensor by a factor of `factor`.
39 | Implemented as average pooling.
40 | """
41 | def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
42 | super(SparseDownsample, self).__init__()
43 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
44 |
45 | def forward(self, input: SparseTensor) -> SparseTensor:
46 | DIM = input.coords.shape[-1] - 1
47 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
48 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
49 |
50 | coord = list(input.coords.unbind(dim=-1))
51 | for i, f in enumerate(factor):
52 | coord[i+1] = coord[i+1] // f
53 |
54 | MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
55 | OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
56 | code = sum([c * o for c, o in zip(coord, OFFSET)])
57 | code, idx = code.unique(return_inverse=True)
58 |
59 | new_feats = torch.scatter_reduce(
60 | torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
61 | dim=0,
62 | index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
63 | src=input.feats,
64 | reduce='mean'
65 | )
66 | new_coords = torch.stack(
67 | [code // OFFSET[0]] +
68 | [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
69 | dim=-1
70 | )
71 | out = SparseTensor(new_feats, new_coords, input.shape,)
72 | out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
73 | out._spatial_cache = input._spatial_cache
74 |
75 | if out.get_spatial_cache(f'upsample_{factor}_coords') is not None:
76 | out.register_spatial_cache(f'upsample_{factor}_coords', [*out.get_spatial_cache(f'upsample_{factor}_coords'), input.coords])
77 | out.register_spatial_cache(f'upsample_{factor}_layout', [*out.get_spatial_cache(f'upsample_{factor}_layout'), input.layout])
78 | out.register_spatial_cache(f'upsample_{factor}_idx', [*out.get_spatial_cache(f'upsample_{factor}_idx'), idx])
79 | else:
80 | out.register_spatial_cache(f'upsample_{factor}_coords', [input.coords])
81 | out.register_spatial_cache(f'upsample_{factor}_layout', [input.layout])
82 | out.register_spatial_cache(f'upsample_{factor}_idx', [idx])
83 |
84 | return out
85 |
86 |
87 | class SparseUpsample(nn.Module):
88 | """
89 | Upsample a sparse tensor by a factor of `factor`.
90 | Implemented as nearest neighbor interpolation.
91 | """
92 | def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
93 | super(SparseUpsample, self).__init__()
94 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
95 |
96 | def forward(self, input: SparseTensor) -> SparseTensor:
97 | DIM = input.coords.shape[-1] - 1
98 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
99 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
100 |
101 | new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
102 | new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
103 | idx = input.get_spatial_cache(f'upsample_{factor}_idx')
104 |         if any(x is None for x in [new_coords, new_layout, idx]):
105 |             raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
106 |         # consume the most recent cache entry pushed by the matching SparseDownsample
107 |         new_coords = new_coords.pop(-1)
108 |         new_layout = new_layout.pop(-1)
109 |         idx = idx.pop(-1)
110 | new_feats = input.feats[idx]
111 | out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
112 | out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
113 | out._spatial_cache = input._spatial_cache
114 | return out
115 |
116 | class SparseSubdivide(nn.Module):
117 | """
118 |     Subdivide a sparse tensor by splitting each voxel into its 2^D children.
119 |     Child voxels inherit the parent's features; the resolution doubles along every axis.
120 | """
121 | def __init__(self):
122 | super(SparseSubdivide, self).__init__()
123 |
124 | def forward(self, input: SparseTensor) -> SparseTensor:
125 | DIM = input.coords.shape[-1] - 1
126 | # upsample scale=2^DIM
127 | n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
128 | n_coords = torch.nonzero(n_cube)
129 | n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
130 | factor = n_coords.shape[0]
131 | assert factor == 2 ** DIM
132 | # print(n_coords.shape)
133 | new_coords = input.coords.clone()
134 | new_coords[:, 1:] *= 2
135 | new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
136 |
137 | new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
138 | out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
139 | out._scale = input._scale * 2
140 | out._spatial_cache = input._spatial_cache
141 | return out
142 |
143 |
--------------------------------------------------------------------------------
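SparseDownsample merges voxels by integer-dividing their coordinates, flattening (batch, x, y, z) into a single linear code, and mean-reducing features that share a code; the index map it caches is what SparseUpsample replays later. A toy sketch of the code/offset arithmetic (coordinates and factor are illustrative):

    import torch

    coords = torch.tensor([    # (batch, x, y, z)
        [0, 0, 0, 0],
        [0, 1, 1, 0],
        [0, 2, 3, 1],
        [0, 3, 3, 1],
    ])
    factor = (2, 2, 2)

    # Integer-divide the spatial coordinates by the pooling factor.
    down = coords.clone()
    down[:, 1:] = down[:, 1:] // torch.tensor(factor)

    # Flatten (batch, x, y, z) into one linear code, as SparseDownsample does.
    MAX = (down[:, 1:].max(dim=0).values + 1).tolist()
    OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
    code = sum(c * o for c, o in zip(down.unbind(dim=-1), OFFSET))

    # Voxels with equal codes collapse into one output voxel; idx is the
    # per-input index that later scatters features back on upsampling.
    code, idx = code.unique(return_inverse=True)
    print(code.tolist(), idx.tolist())   # [0, 3] [0, 0, 1, 1]
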
/triposf/modules/sparse/transformer/blocks.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | import torch.utils.checkpoint
28 | from ..basic import SparseTensor
29 | from ..linear import SparseLinear
30 | from ..nonlinearity import SparseGELU
31 | from ..attention import SparseMultiHeadAttention, SerializeMode
32 | from ...norm import LayerNorm32
33 |
34 |
35 | class SparseFeedForwardNet(nn.Module):
36 | def __init__(self, channels: int, mlp_ratio: float = 4.0):
37 | super().__init__()
38 | self.mlp = nn.Sequential(
39 | SparseLinear(channels, int(channels * mlp_ratio)),
40 | SparseGELU(approximate="tanh"),
41 | SparseLinear(int(channels * mlp_ratio), channels),
42 | )
43 |
44 | def forward(self, x: SparseTensor) -> SparseTensor:
45 | return self.mlp(x)
46 |
47 |
48 | class SparseTransformerBlock(nn.Module):
49 | """
50 | Sparse Transformer block (MSA + FFN).
51 | """
52 | def __init__(
53 | self,
54 | channels: int,
55 | num_heads: int,
56 | mlp_ratio: float = 4.0,
57 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
58 | window_size: Optional[int] = None,
59 | shift_sequence: Optional[int] = None,
60 | shift_window: Optional[Tuple[int, int, int]] = None,
61 | serialize_mode: Optional[SerializeMode] = None,
62 | use_checkpoint: bool = False,
63 | use_rope: bool = False,
64 | qk_rms_norm: bool = False,
65 | qkv_bias: bool = True,
66 | ln_affine: bool = False,
67 | ):
68 | super().__init__()
69 | self.use_checkpoint = use_checkpoint
70 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
71 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
72 | self.attn = SparseMultiHeadAttention(
73 | channels,
74 | num_heads=num_heads,
75 | attn_mode=attn_mode,
76 | window_size=window_size,
77 | shift_sequence=shift_sequence,
78 | shift_window=shift_window,
79 | serialize_mode=serialize_mode,
80 | qkv_bias=qkv_bias,
81 | use_rope=use_rope,
82 | qk_rms_norm=qk_rms_norm,
83 | )
84 | self.mlp = SparseFeedForwardNet(
85 | channels,
86 | mlp_ratio=mlp_ratio,
87 | )
88 |
89 | def _forward(self, x: SparseTensor) -> SparseTensor:
90 | h = x.replace(self.norm1(x.feats))
91 | h = self.attn(h)
92 | x = x + h
93 | h = x.replace(self.norm2(x.feats))
94 | h = self.mlp(h)
95 | x = x + h
96 | return x
97 |
98 | def forward(self, x: SparseTensor) -> SparseTensor:
99 | if self.use_checkpoint:
100 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
101 | else:
102 | return self._forward(x)
103 |
104 |
105 | class SparseTransformerCrossBlock(nn.Module):
106 | """
107 | Sparse Transformer cross-attention block (MSA + MCA + FFN).
108 | """
109 | def __init__(
110 | self,
111 | channels: int,
112 | ctx_channels: int,
113 | num_heads: int,
114 | mlp_ratio: float = 4.0,
115 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
116 | window_size: Optional[int] = None,
117 | shift_sequence: Optional[int] = None,
118 | shift_window: Optional[Tuple[int, int, int]] = None,
119 | serialize_mode: Optional[SerializeMode] = None,
120 | use_checkpoint: bool = False,
121 | use_rope: bool = False,
122 | qk_rms_norm: bool = False,
123 | qk_rms_norm_cross: bool = False,
124 | qkv_bias: bool = True,
125 | ln_affine: bool = False,
126 | ):
127 | super().__init__()
128 | self.use_checkpoint = use_checkpoint
129 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
130 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
131 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
132 | self.self_attn = SparseMultiHeadAttention(
133 | channels,
134 | num_heads=num_heads,
135 | type="self",
136 | attn_mode=attn_mode,
137 | window_size=window_size,
138 | shift_sequence=shift_sequence,
139 | shift_window=shift_window,
140 | serialize_mode=serialize_mode,
141 | qkv_bias=qkv_bias,
142 | use_rope=use_rope,
143 | qk_rms_norm=qk_rms_norm,
144 | )
145 | self.cross_attn = SparseMultiHeadAttention(
146 | channels,
147 | ctx_channels=ctx_channels,
148 | num_heads=num_heads,
149 | type="cross",
150 | attn_mode="full",
151 | qkv_bias=qkv_bias,
152 | qk_rms_norm=qk_rms_norm_cross,
153 | )
154 | self.mlp = SparseFeedForwardNet(
155 | channels,
156 | mlp_ratio=mlp_ratio,
157 | )
158 |
159 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor):
160 | h = x.replace(self.norm1(x.feats))
161 | h = self.self_attn(h)
162 | x = x + h
163 | h = x.replace(self.norm2(x.feats))
164 | h = self.cross_attn(h, context)
165 | x = x + h
166 | h = x.replace(self.norm3(x.feats))
167 | h = self.mlp(h)
168 | x = x + h
169 | return x
170 |
171 | def forward(self, x: SparseTensor, context: torch.Tensor):
172 | if self.use_checkpoint:
173 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
174 | else:
175 | return self._forward(x, context)
176 |
--------------------------------------------------------------------------------
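SparseTransformerBlock._forward above follows the standard pre-norm ordering: normalize, run the sub-layer, add back the residual, with x.replace(...) swapping only the feature tensor while the sparse coordinates stay fixed. A dense analogue of the same block, a sketch to make the residual structure explicit (the dense modules used here are stand-ins, not the sparse ones from this package):

    import torch
    import torch.nn as nn

    class DensePreNormBlock(nn.Module):
        # Dense counterpart of SparseTransformerBlock._forward: pre-LN + residual.
        def __init__(self, channels: int, num_heads: int, mlp_ratio: float = 4.0):
            super().__init__()
            self.norm1 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
            self.norm2 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
            self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
            self.mlp = nn.Sequential(
                nn.Linear(channels, int(channels * mlp_ratio)),
                nn.GELU(approximate="tanh"),
                nn.Linear(int(channels * mlp_ratio), channels),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            h = self.norm1(x)
            x = x + self.attn(h, h, h, need_weights=False)[0]   # MSA + residual
            x = x + self.mlp(self.norm2(x))                     # FFN + residual
            return x

    x = torch.randn(2, 64, 128)
    print(DensePreNormBlock(128, 4)(x).shape)   # torch.Size([2, 64, 128])
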
/triposf/modules/transformer/modulated.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | import torch.utils.checkpoint
28 | from ..attention import MultiHeadAttention
29 | from ..norm import LayerNorm32
30 | from .blocks import FeedForwardNet
31 |
32 |
33 | class ModulatedTransformerBlock(nn.Module):
34 | """
35 | Transformer block (MSA + FFN) with adaptive layer norm conditioning.
36 | """
37 | def __init__(
38 | self,
39 | channels: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | attn_mode: Literal["full", "windowed"] = "full",
43 | window_size: Optional[int] = None,
44 | shift_window: Optional[Tuple[int, int, int]] = None,
45 | use_checkpoint: bool = False,
46 | use_rope: bool = False,
47 | qk_rms_norm: bool = False,
48 | qkv_bias: bool = True,
49 | share_mod: bool = False,
50 | ):
51 | super().__init__()
52 | self.use_checkpoint = use_checkpoint
53 | self.share_mod = share_mod
54 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
55 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
56 | self.attn = MultiHeadAttention(
57 | channels,
58 | num_heads=num_heads,
59 | attn_mode=attn_mode,
60 | window_size=window_size,
61 | shift_window=shift_window,
62 | qkv_bias=qkv_bias,
63 | use_rope=use_rope,
64 | qk_rms_norm=qk_rms_norm,
65 | )
66 | self.mlp = FeedForwardNet(
67 | channels,
68 | mlp_ratio=mlp_ratio,
69 | )
70 | if not share_mod:
71 | self.adaLN_modulation = nn.Sequential(
72 | nn.SiLU(),
73 | nn.Linear(channels, 6 * channels, bias=True)
74 | )
75 |
76 | def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
77 | if self.share_mod:
78 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
79 | else:
80 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
81 | h = self.norm1(x)
82 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
83 | h = self.attn(h)
84 | h = h * gate_msa.unsqueeze(1)
85 | x = x + h
86 | h = self.norm2(x)
87 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
88 | h = self.mlp(h)
89 | h = h * gate_mlp.unsqueeze(1)
90 | x = x + h
91 | return x
92 |
93 | def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
94 | if self.use_checkpoint:
95 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
96 | else:
97 | return self._forward(x, mod)
98 |
99 |
100 | class ModulatedTransformerCrossBlock(nn.Module):
101 | """
102 | Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
103 | """
104 | def __init__(
105 | self,
106 | channels: int,
107 | ctx_channels: int,
108 | num_heads: int,
109 | mlp_ratio: float = 4.0,
110 | attn_mode: Literal["full", "windowed"] = "full",
111 | window_size: Optional[int] = None,
112 | shift_window: Optional[Tuple[int, int, int]] = None,
113 | use_checkpoint: bool = False,
114 | use_rope: bool = False,
115 | qk_rms_norm: bool = False,
116 | qk_rms_norm_cross: bool = False,
117 | qkv_bias: bool = True,
118 | share_mod: bool = False,
119 | ):
120 | super().__init__()
121 | self.use_checkpoint = use_checkpoint
122 | self.share_mod = share_mod
123 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
124 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
125 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
126 | self.self_attn = MultiHeadAttention(
127 | channels,
128 | num_heads=num_heads,
129 | type="self",
130 | attn_mode=attn_mode,
131 | window_size=window_size,
132 | shift_window=shift_window,
133 | qkv_bias=qkv_bias,
134 | use_rope=use_rope,
135 | qk_rms_norm=qk_rms_norm,
136 | )
137 | self.cross_attn = MultiHeadAttention(
138 | channels,
139 | ctx_channels=ctx_channels,
140 | num_heads=num_heads,
141 | type="cross",
142 | attn_mode="full",
143 | qkv_bias=qkv_bias,
144 | qk_rms_norm=qk_rms_norm_cross,
145 | )
146 | self.mlp = FeedForwardNet(
147 | channels,
148 | mlp_ratio=mlp_ratio,
149 | )
150 | if not share_mod:
151 | self.adaLN_modulation = nn.Sequential(
152 | nn.SiLU(),
153 | nn.Linear(channels, 6 * channels, bias=True)
154 | )
155 |
156 | def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
157 | if self.share_mod:
158 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
159 | else:
160 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
161 | h = self.norm1(x)
162 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
163 | h = self.self_attn(h)
164 | h = h * gate_msa.unsqueeze(1)
165 | x = x + h
166 | h = self.norm2(x)
167 | h = self.cross_attn(h, context)
168 | x = x + h
169 | h = self.norm3(x)
170 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
171 | h = self.mlp(h)
172 | h = h * gate_mlp.unsqueeze(1)
173 | x = x + h
174 | return x
175 |
176 | def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
177 | if self.use_checkpoint:
178 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
179 | else:
180 | return self._forward(x, mod, context)
181 |
--------------------------------------------------------------------------------
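Both modulated blocks above derive six per-sample vectors (shift/scale/gate for the attention and MLP branches) from the conditioning input through a SiLU + Linear head, apply h * (1 + scale) + shift before each sub-layer, and gate the sub-layer output before the residual add. A minimal sketch of just that modulation arithmetic on a dense tensor (the sub-layer itself is omitted; every name here is illustrative):

    import torch
    import torch.nn as nn

    channels = 8
    adaLN = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True))

    x = torch.randn(2, 5, channels)      # [batch, tokens, channels]
    cond = torch.randn(2, channels)      # one conditioning vector per sample

    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
        adaLN(cond).chunk(6, dim=1)      # each [batch, channels]

    # Pre-sublayer modulation, broadcast over the token dimension ...
    h = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
    # ... the sub-layer would run on h here; its output is gated and added back.
    x = x + h * gate_msa.unsqueeze(1)
    print(x.shape)                       # torch.Size([2, 5, 8])
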
/triposf/representations/mesh/cube2mesh.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | import torch
25 | from ...modules.sparse import SparseTensor
26 | from easydict import EasyDict as edict
27 | from .utils_cube import *
28 | from .flexicubes.flexicubes import FlexiCubes
29 |
30 | class MeshExtractResult:
31 | def __init__(self,
32 | vertices,
33 | faces,
34 | vertex_attrs=None,
35 | res=64
36 | ):
37 | self.vertices = vertices
38 | self.faces = faces.long()
39 | self.vertex_attrs = vertex_attrs
40 | self.face_normal = self.compute_face_normals(vertices, faces)
41 | self.res = res
42 | self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0)
43 |
44 | # training only
45 | self.tsdf_v = None
46 | self.tsdf_s = None
47 |
48 | def compute_face_normals(self, verts, faces):
49 | i0 = faces[..., 0].long()
50 | i1 = faces[..., 1].long()
51 | i2 = faces[..., 2].long()
52 |
53 | v0 = verts[i0, :]
54 | v1 = verts[i1, :]
55 | v2 = verts[i2, :]
56 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
57 | face_normals = torch.nn.functional.normalize(face_normals, dim=1)
58 |
59 | return face_normals[:, None, :].repeat(1, 3, 1)
60 |
61 | def comput_v_normals(self, verts, faces):
62 | i0 = faces[..., 0].long()
63 | i1 = faces[..., 1].long()
64 | i2 = faces[..., 2].long()
65 |
66 | v0 = verts[i0, :]
67 | v1 = verts[i1, :]
68 | v2 = verts[i2, :]
69 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
70 | v_normals = torch.zeros_like(verts)
71 | v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
72 | v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
73 | v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
74 |
75 | v_normals = torch.nn.functional.normalize(v_normals, dim=1)
76 | return v_normals
77 |
78 |
79 | class SparseFeatures2Mesh:
80 | def __init__(self, device="cuda", res=64, use_color=False, use_sparse_flexicube=True, use_sparse_sparse_flexicube=False):
81 | '''
82 |         A model that extracts a mesh from sparse feature structures using FlexiCubes.
83 | '''
84 | super().__init__()
85 | self.device=device
86 | self.res = res
87 | self.mesh_extractor = FlexiCubes(device=device)
88 | self.sdf_bias = -1.0 / res
89 |
90 | self.use_sparse_flexicube = use_sparse_flexicube
91 | self.use_sparse_sparse_flexicube = use_sparse_sparse_flexicube
92 |
93 | self.use_color = use_color
94 | self._calc_layout()
95 |
96 | def _calc_layout(self):
97 | LAYOUTS = {
98 | 'sdf': {'shape': (8, 1), 'size': 8},
99 | 'deform': {'shape': (8, 3), 'size': 8 * 3},
100 | 'weights': {'shape': (21,), 'size': 21}
101 | }
102 | if self.use_color:
103 | '''
104 | 6 channel color including normal map
105 | '''
106 | LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
107 | self.layouts = edict(LAYOUTS)
108 | start = 0
109 | for k, v in self.layouts.items():
110 | v['range'] = (start, start + v['size'])
111 | start += v['size']
112 | self.feats_channels = start
113 |
114 | def get_layout(self, feats : torch.Tensor, name : str):
115 | if name not in self.layouts:
116 | return None
117 | return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape'])
118 |
119 | def __call__(self, cubefeats : SparseTensor, training=False):
120 | """
121 | Generates a mesh based on the specified sparse voxel structures.
122 | Args:
123 |             cubefeats (SparseTensor): per-voxel features packed as [N, 53] -- 8 SDF values, 8x3 deformations, 21 FlexiCubes weights (plus 8x6 colors when use_color is enabled)
124 |             training (bool): whether the call is made during training
125 |         Returns:
126 |             a MeshExtractResult with the extracted vertices, faces, and vertex attributes
127 | """
128 | assert self.use_color == False
129 |
130 | coords = cubefeats.coords[:, 1:]
131 | feats = cubefeats.feats
132 |
133 | sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']]
134 | sdf = sdf * (4. / self.res)
135 | sdf += self.sdf_bias
136 |
137 | v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
138 | v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training)
139 |
140 | res_v = self.res + 1
141 | v_attrs_d_sparse, v_pos_dilate = get_sparse_attrs(v_pos, v_attrs, res=res_v, sdf_init=True)
142 | weights_d_sparse, coords_dilate = get_sparse_attrs(coords, weights, res=self.res, sdf_init=False)
143 |
144 | sdf_d, deform_d = v_attrs_d_sparse[..., 0], v_attrs_d_sparse[..., 1:4]
145 |
146 | x_nx3 = get_defomed_verts(v_pos_dilate, deform_d, self.res)
147 | x_nx3 = torch.cat((x_nx3, torch.ones((1, 3), dtype=x_nx3.dtype, device=x_nx3.device) * 0.5))
148 | sdf_d = torch.cat((sdf_d, torch.ones((1), dtype=sdf_d.dtype, device=sdf_d.device)))
149 |
150 | mask_reg_c_sparse = (v_pos_dilate[..., 0] * res_v + v_pos_dilate[..., 1]) * res_v + v_pos_dilate[..., 2]
151 | reg_c_sparse = (coords_dilate[..., 0] * res_v + coords_dilate[..., 1]) * res_v + coords_dilate[..., 2]
152 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2]
153 | reg_c_value = (reg_c_sparse.unsqueeze(1) + cube_corners_bias.unsqueeze(0).cuda()).reshape(-1)
154 | reg_c = torch.searchsorted(mask_reg_c_sparse, reg_c_value)
155 | exact_match_mask = mask_reg_c_sparse[reg_c] == reg_c_value
156 | reg_c[exact_match_mask == 0] = len(mask_reg_c_sparse)
157 | reg_c = reg_c.reshape(-1, 8)
158 |
159 | vertices, faces, L_dev, colors = self.mesh_extractor(
160 | voxelgrid_vertices=x_nx3,
161 | scalar_field=sdf_d,
162 | cube_idx=reg_c,
163 | resolution=self.res,
164 | beta=weights_d_sparse[:, :12],
165 | alpha=weights_d_sparse[:, 12:20],
166 | gamma_f=weights_d_sparse[:, 20],
167 | cube_index_map=coords_dilate,
168 | training=training)
169 |
170 | mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res)
171 |
172 | return mesh
--------------------------------------------------------------------------------
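_calc_layout packs the per-voxel FlexiCubes attributes (8 SDF samples, 8x3 deformations, 21 cube weights, plus 8x6 colors when enabled) into a single flat feature vector and records a channel range for each group, which get_layout then slices and reshapes by name. A standalone reproduction of the resulting ranges without colors (not the class itself, just the arithmetic):

    layouts = {
        "sdf":     {"shape": (8, 1), "size": 8},
        "deform":  {"shape": (8, 3), "size": 8 * 3},
        "weights": {"shape": (21,),  "size": 21},
    }
    start = 0
    for name, spec in layouts.items():
        spec["range"] = (start, start + spec["size"])
        start += spec["size"]

    print({k: v["range"] for k, v in layouts.items()}, "total:", start)
    # {'sdf': (0, 8), 'deform': (8, 32), 'weights': (32, 53)} total: 53
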
/triposf/modules/sparse/attention/windowed_attn.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import math
27 | from .. import SparseTensor
28 | from .. import DEBUG, ATTN
29 |
30 | if ATTN == 'xformers':
31 | import xformers.ops as xops
32 | elif ATTN == 'flash_attn':
33 | import flash_attn
34 | else:
35 | raise ValueError(f"Unknown attention module: {ATTN}")
36 |
37 |
38 | __all__ = [
39 | 'sparse_windowed_scaled_dot_product_self_attention',
40 | ]
41 |
42 |
43 | def calc_window_partition(
44 | tensor: SparseTensor,
45 | window_size: Union[int, Tuple[int, ...]],
46 | shift_window: Union[int, Tuple[int, ...]] = 0
47 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
48 | """
49 | Calculate serialization and partitioning for a set of coordinates.
50 |
51 | Args:
52 | tensor (SparseTensor): The input tensor.
53 | window_size (int): The window size to use.
54 | shift_window (Tuple[int, ...]): The shift of serialized coordinates.
55 |
56 | Returns:
57 | (torch.Tensor): Forwards indices.
58 | (torch.Tensor): Backwards indices.
59 | (List[int]): Sequence lengths.
60 | (List[int]): Sequence batch indices.
61 | """
62 | DIM = tensor.coords.shape[1] - 1
63 | shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
64 | window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
65 | shifted_coords = tensor.coords.clone().detach()
66 | shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
67 |
68 | MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
69 | NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
70 | OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
71 |
72 | shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
73 | shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
74 | fwd_indices = torch.argsort(shifted_indices)
75 | bwd_indices = torch.empty_like(fwd_indices)
76 | bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
77 | seq_lens = torch.bincount(shifted_indices)
78 | seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
79 | mask = seq_lens != 0
80 | seq_lens = seq_lens[mask].tolist()
81 | seq_batch_indices = seq_batch_indices[mask].tolist()
82 |
83 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
84 |
85 |
86 | def sparse_windowed_scaled_dot_product_self_attention(
87 | qkv: SparseTensor,
88 | window_size: int,
89 | shift_window: Tuple[int, int, int] = (0, 0, 0)
90 | ) -> SparseTensor:
91 | """
92 | Apply windowed scaled dot product self attention to a sparse tensor.
93 |
94 | Args:
95 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
96 | window_size (int): The window size to use.
97 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
98 | shift (int): The shift to use.
99 | """
100 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
101 |
102 | serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
103 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
104 | if serialization_spatial_cache is None:
105 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
106 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
107 | else:
108 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
109 |
110 | M = fwd_indices.shape[0]
111 | T = qkv.feats.shape[0]
112 | H = qkv.feats.shape[2]
113 | C = qkv.feats.shape[3]
114 |
115 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
116 |
117 | if DEBUG:
118 | start = 0
119 | qkv_coords = qkv.coords[fwd_indices]
120 | for i in range(len(seq_lens)):
121 | seq_coords = qkv_coords[start:start+seq_lens[i]]
122 | assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
123 | assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
124 | f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
125 | start += seq_lens[i]
126 |
127 | if all([seq_len == window_size for seq_len in seq_lens]):
128 | B = len(seq_lens)
129 | N = window_size
130 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
131 | if ATTN == 'xformers':
132 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
133 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
134 | elif ATTN == 'flash_attn':
135 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
136 | else:
137 | raise ValueError(f"Unknown attention module: {ATTN}")
138 | out = out.reshape(B * N, H, C) # [M, H, C]
139 | else:
140 | if ATTN == 'xformers':
141 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
142 | q = q.unsqueeze(0) # [1, M, H, C]
143 | k = k.unsqueeze(0) # [1, M, H, C]
144 | v = v.unsqueeze(0) # [1, M, H, C]
145 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
146 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
147 | elif ATTN == 'flash_attn':
148 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
149 | .to(qkv.device).int()
150 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
151 |
152 | out = out[bwd_indices] # [T, H, C]
153 |
154 | if DEBUG:
155 | qkv_coords = qkv_coords[bwd_indices]
156 | assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
157 |
158 | return qkv.replace(out)
159 |
--------------------------------------------------------------------------------
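Note (illustrative sketch, not a repository file): the fast path at the end of sparse_windowed_scaled_dot_product_self_attention relies on every window holding exactly window_size tokens, so the gathered [M, 3, H, C] buffer can be reshaped into a regular batch and handed to a single attention call. The sketch below reproduces that reshape, using PyTorch's built-in scaled_dot_product_attention as a stand-in for the xformers/flash-attn kernels used above; all sizes are placeholders.

    import torch
    import torch.nn.functional as F

    # toy sizes: M gathered tokens, H heads, C channels per head, fixed window length
    M, H, C, window_size = 64, 4, 16, 8
    qkv_flat = torch.randn(M, 3, H, C)            # [M, 3, H, C], as gathered by fwd_indices

    B, N = M // window_size, window_size
    qkv = qkv_flat.reshape(B, N, 3, H, C)         # one row per window
    q, k, v = qkv.unbind(dim=2)                   # each [B, N, H, C]

    # stand-in for xops.memory_efficient_attention / flash_attn_qkvpacked_func
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)                             # [B, N, H, C]
    out = out.reshape(B * N, H, C)                # back to the flat [M, H, C] layout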
/triposf/modules/sparse/attention/modules.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 | from .. import SparseTensor
29 | from .full_attn import sparse_scaled_dot_product_attention
30 | from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
31 | from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
32 | from ...attention import RotaryPositionEmbedder
33 |
34 |
35 | class SparseMultiHeadRMSNorm(nn.Module):
36 | def __init__(self, dim: int, heads: int):
37 | super().__init__()
38 | self.scale = dim ** 0.5
39 | self.gamma = nn.Parameter(torch.ones(heads, dim))
40 |
41 | def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
42 | x_type = x.dtype
43 | x = x.float()
44 | if isinstance(x, SparseTensor):
45 | x = x.replace(F.normalize(x.feats, dim=-1))
46 | else:
47 | x = F.normalize(x, dim=-1)
48 | return (x * self.gamma * self.scale).to(x_type)
49 |
50 |
51 | class SparseMultiHeadAttention(nn.Module):
52 | def __init__(
53 | self,
54 | channels: int,
55 | num_heads: int,
56 | ctx_channels: Optional[int] = None,
57 | type: Literal["self", "cross"] = "self",
58 | attn_mode: Literal["full", "serialized", "windowed"] = "full",
59 | window_size: Optional[int] = None,
60 | shift_sequence: Optional[int] = None,
61 | shift_window: Optional[Tuple[int, int, int]] = None,
62 | serialize_mode: Optional[SerializeMode] = None,
63 | qkv_bias: bool = True,
64 | use_rope: bool = False,
65 | qk_rms_norm: bool = False,
66 | ):
67 | super().__init__()
68 | assert channels % num_heads == 0
69 | assert type in ["self", "cross"], f"Invalid attention type: {type}"
70 | assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
71 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
72 | assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
73 | self.channels = channels
74 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels
75 | self.num_heads = num_heads
76 | self._type = type
77 | self.attn_mode = attn_mode
78 | self.window_size = window_size
79 | self.shift_sequence = shift_sequence
80 | self.shift_window = shift_window
81 | self.serialize_mode = serialize_mode
82 | self.use_rope = use_rope
83 | self.qk_rms_norm = qk_rms_norm
84 |
85 | if self._type == "self":
86 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
87 | else:
88 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
89 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
90 |
91 | if self.qk_rms_norm:
92 | self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
93 | self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
94 |
95 | self.to_out = nn.Linear(channels, channels)
96 |
97 | if use_rope:
98 | self.rope = RotaryPositionEmbedder(channels)
99 |
100 | @staticmethod
101 | def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
102 | if isinstance(x, SparseTensor):
103 | return x.replace(module(x.feats))
104 | else:
105 | return module(x)
106 |
107 | @staticmethod
108 | def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
109 | if isinstance(x, SparseTensor):
110 | return x.reshape(*shape)
111 | else:
112 | return x.reshape(*x.shape[:2], *shape)
113 |
114 | def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
115 | if isinstance(x, SparseTensor):
116 | x_feats = x.feats.unsqueeze(0)
117 | else:
118 | x_feats = x
119 | x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
120 | return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
121 |
122 | def _rope(self, qkv: SparseTensor) -> SparseTensor:
123 | q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
124 | q, k = self.rope(q, k, qkv.coords[:, 1:])
125 | qkv = qkv.replace(torch.stack([q, k, v], dim=1))
126 | return qkv
127 |
128 | def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
129 | if self._type == "self":
130 | qkv = self._linear(self.to_qkv, x)
131 | qkv = self._fused_pre(qkv, num_fused=3)
132 | if self.use_rope:
133 | qkv = self._rope(qkv)
134 | if self.qk_rms_norm:
135 | q, k, v = qkv.unbind(dim=1)
136 | q = self.q_rms_norm(q)
137 | k = self.k_rms_norm(k)
138 | qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
139 | if self.attn_mode == "full":
140 | h = sparse_scaled_dot_product_attention(qkv)
141 | elif self.attn_mode == "serialized":
142 | h = sparse_serialized_scaled_dot_product_self_attention(
143 | qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
144 | )
145 | elif self.attn_mode == "windowed":
146 | h = sparse_windowed_scaled_dot_product_self_attention(
147 | qkv, self.window_size, shift_window=self.shift_window
148 | )
149 | else:
150 | q = self._linear(self.to_q, x)
151 | q = self._reshape_chs(q, (self.num_heads, -1))
152 | kv = self._linear(self.to_kv, context)
153 | kv = self._fused_pre(kv, num_fused=2)
154 | if self.qk_rms_norm:
155 | q = self.q_rms_norm(q)
156 | k, v = kv.unbind(dim=1)
157 | k = self.k_rms_norm(k)
158 | kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
159 | h = sparse_scaled_dot_product_attention(q, kv)
160 | h = self._reshape_chs(h, (-1,))
161 | h = self._linear(self.to_out, h)
162 | return h
163 |
--------------------------------------------------------------------------------
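Note (illustrative sketch, not a repository file): a minimal self-attention pass with SparseMultiHeadAttention on a toy SparseTensor. It assumes a CUDA device, a working sparse backend (spconv or torchsparse) and either flash-attn or xformers installed, matching requirements.txt; the random coordinates and sizes are placeholders.

    import torch
    from triposf.modules import sparse as sp
    from triposf.modules.sparse.attention.modules import SparseMultiHeadAttention

    C = 512
    coords_xyz = torch.unique(
        torch.randint(0, 64, (4096, 3), dtype=torch.int32, device="cuda"), dim=0
    )                                             # unique occupied voxels
    coords = torch.cat(
        [torch.zeros(len(coords_xyz), 1, dtype=torch.int32, device="cuda"), coords_xyz], dim=1
    )                                             # (batch, x, y, z), single batch here
    feats = torch.randn(len(coords_xyz), C, device="cuda", dtype=torch.float16)
    x = sp.SparseTensor(feats, coords)            # one feature row per voxel

    attn = SparseMultiHeadAttention(C, num_heads=8, attn_mode="full").cuda().half()
    y = attn(x)                                   # SparseTensor with the same coords, attended features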
/triposf/modules/transformer/blocks.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | import torch.utils.checkpoint
28 | from ..attention import MultiHeadAttention
29 | from ..norm import LayerNorm32
30 |
31 |
32 | class AbsolutePositionEmbedder(nn.Module):
33 | """
34 | Embeds spatial positions into vector representations.
35 | """
36 | def __init__(self, channels: int, in_channels: int = 3):
37 | super().__init__()
38 | self.channels = channels
39 | self.in_channels = in_channels
40 | self.freq_dim = channels // in_channels // 2
41 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
42 | self.freqs = 1.0 / (10000 ** self.freqs)
43 |
44 | def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
45 | """
46 | Create sinusoidal position embeddings.
47 |
48 | Args:
49 | x: a 1-D Tensor of N indices
50 |
51 | Returns:
52 | an (N, D) Tensor of positional embeddings.
53 | """
54 | self.freqs = self.freqs.to(x.device)
55 | out = torch.outer(x, self.freqs)
56 | out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
57 | return out
58 |
59 | def forward(self, x: torch.Tensor) -> torch.Tensor:
60 | """
61 | Args:
62 | x (torch.Tensor): (N, D) tensor of spatial positions
63 | """
64 | N, D = x.shape
65 | assert D == self.in_channels, "Input dimension must match number of input channels"
66 | embed = self._sin_cos_embedding(x.reshape(-1))
67 | embed = embed.reshape(N, -1)
68 | if embed.shape[1] < self.channels:
69 | embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
70 | return embed
71 |
72 |
73 | class FeedForwardNet(nn.Module):
74 | def __init__(self, channels: int, mlp_ratio: float = 4.0):
75 | super().__init__()
76 | self.mlp = nn.Sequential(
77 | nn.Linear(channels, int(channels * mlp_ratio)),
78 | nn.GELU(approximate="tanh"),
79 | nn.Linear(int(channels * mlp_ratio), channels),
80 | )
81 |
82 | def forward(self, x: torch.Tensor) -> torch.Tensor:
83 | return self.mlp(x)
84 |
85 |
86 | class TransformerBlock(nn.Module):
87 | """
88 | Transformer block (MSA + FFN).
89 | """
90 | def __init__(
91 | self,
92 | channels: int,
93 | num_heads: int,
94 | mlp_ratio: float = 4.0,
95 | attn_mode: Literal["full", "windowed"] = "full",
96 | window_size: Optional[int] = None,
 97 |         shift_window: Optional[Tuple[int, int, int]] = None,
98 | use_checkpoint: bool = False,
99 | use_rope: bool = False,
100 | qk_rms_norm: bool = False,
101 | qkv_bias: bool = True,
102 | ln_affine: bool = False,
103 | ):
104 | super().__init__()
105 | self.use_checkpoint = use_checkpoint
106 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
107 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
108 | self.attn = MultiHeadAttention(
109 | channels,
110 | num_heads=num_heads,
111 | attn_mode=attn_mode,
112 | window_size=window_size,
113 | shift_window=shift_window,
114 | qkv_bias=qkv_bias,
115 | use_rope=use_rope,
116 | qk_rms_norm=qk_rms_norm,
117 | )
118 | self.mlp = FeedForwardNet(
119 | channels,
120 | mlp_ratio=mlp_ratio,
121 | )
122 |
123 | def _forward(self, x: torch.Tensor) -> torch.Tensor:
124 | h = self.norm1(x)
125 | h = self.attn(h)
126 | x = x + h
127 | h = self.norm2(x)
128 | h = self.mlp(h)
129 | x = x + h
130 | return x
131 |
132 | def forward(self, x: torch.Tensor) -> torch.Tensor:
133 | if self.use_checkpoint:
134 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
135 | else:
136 | return self._forward(x)
137 |
138 |
139 | class TransformerCrossBlock(nn.Module):
140 | """
141 | Transformer cross-attention block (MSA + MCA + FFN).
142 | """
143 | def __init__(
144 | self,
145 | channels: int,
146 | ctx_channels: int,
147 | num_heads: int,
148 | mlp_ratio: float = 4.0,
149 | attn_mode: Literal["full", "windowed"] = "full",
150 | window_size: Optional[int] = None,
151 | shift_window: Optional[Tuple[int, int, int]] = None,
152 | use_checkpoint: bool = False,
153 | use_rope: bool = False,
154 | qk_rms_norm: bool = False,
155 | qk_rms_norm_cross: bool = False,
156 | qkv_bias: bool = True,
157 | ln_affine: bool = False,
158 | ):
159 | super().__init__()
160 | self.use_checkpoint = use_checkpoint
161 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
162 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
163 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
164 | self.self_attn = MultiHeadAttention(
165 | channels,
166 | num_heads=num_heads,
167 | type="self",
168 | attn_mode=attn_mode,
169 | window_size=window_size,
170 | shift_window=shift_window,
171 | qkv_bias=qkv_bias,
172 | use_rope=use_rope,
173 | qk_rms_norm=qk_rms_norm,
174 | )
175 | self.cross_attn = MultiHeadAttention(
176 | channels,
177 | ctx_channels=ctx_channels,
178 | num_heads=num_heads,
179 | type="cross",
180 | attn_mode="full",
181 | qkv_bias=qkv_bias,
182 | qk_rms_norm=qk_rms_norm_cross,
183 | )
184 | self.mlp = FeedForwardNet(
185 | channels,
186 | mlp_ratio=mlp_ratio,
187 | )
188 |
189 | def _forward(self, x: torch.Tensor, context: torch.Tensor):
190 | h = self.norm1(x)
191 | h = self.self_attn(h)
192 | x = x + h
193 | h = self.norm2(x)
194 | h = self.cross_attn(h, context)
195 | x = x + h
196 | h = self.norm3(x)
197 | h = self.mlp(h)
198 | x = x + h
199 | return x
200 |
201 | def forward(self, x: torch.Tensor, context: torch.Tensor):
202 | if self.use_checkpoint:
203 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
204 | else:
205 | return self._forward(x, context)
206 |
--------------------------------------------------------------------------------
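Note (illustrative sketch, not a repository file): AbsolutePositionEmbedder can be exercised standalone; the sizes below are placeholders. Importing the module also pulls in the dense attention backend (flash-attn or xformers per requirements.txt), which is assumed to be installed.

    import torch
    from triposf.modules.transformer.blocks import AbsolutePositionEmbedder

    embedder = AbsolutePositionEmbedder(channels=384, in_channels=3)   # 384 = 3 axes * 2 (sin, cos) * 64 freqs
    coords = torch.randint(0, 64, (1024, 3)).float()                   # N spatial positions
    emb = embedder(coords)                                             # [1024, 384], zero-padded if the channels do not divide evenly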
/triposf/modules/attention/modules.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 | from .full_attn import scaled_dot_product_attention
29 |
30 |
31 | class MultiHeadRMSNorm(nn.Module):
32 | def __init__(self, dim: int, heads: int):
33 | super().__init__()
34 | self.scale = dim ** 0.5
35 | self.gamma = nn.Parameter(torch.ones(heads, dim))
36 |
37 | def forward(self, x: torch.Tensor) -> torch.Tensor:
38 | return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
39 |
40 |
41 | class RotaryPositionEmbedder(nn.Module):
42 | def __init__(self, hidden_size: int, in_channels: int = 3):
43 | super().__init__()
44 | assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
45 | self.hidden_size = hidden_size
46 | self.in_channels = in_channels
47 | self.freq_dim = hidden_size // in_channels // 2
48 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
49 | self.freqs = 1.0 / (10000 ** self.freqs)
50 |
51 | def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
52 | self.freqs = self.freqs.to(indices.device)
53 | phases = torch.outer(indices, self.freqs)
54 | phases = torch.polar(torch.ones_like(phases), phases)
55 | return phases
56 |
57 | def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
58 | x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
59 | x_rotated = x_complex * phases
60 | x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
61 | return x_embed
62 |
63 | def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
64 | """
65 | Args:
 66 |             q (torch.Tensor): [..., N, D] tensor of queries
 67 |             k (torch.Tensor): [..., N, D] tensor of keys
68 | indices (torch.Tensor): [..., N, C] tensor of spatial positions
69 | """
70 | if indices is None:
71 | indices = torch.arange(q.shape[-2], device=q.device)
72 | if len(q.shape) > 2:
73 | indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
74 |
75 | phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
76 | if phases.shape[1] < self.hidden_size // 2:
77 | phases = torch.cat([phases, torch.polar(
78 | torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
79 | torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
80 | )], dim=-1)
81 | q_embed = self._rotary_embedding(q, phases)
82 | k_embed = self._rotary_embedding(k, phases)
83 | return q_embed, k_embed
84 |
85 |
86 | class MultiHeadAttention(nn.Module):
87 | def __init__(
88 | self,
89 | channels: int,
90 | num_heads: int,
91 | ctx_channels: Optional[int]=None,
92 | type: Literal["self", "cross"] = "self",
93 | attn_mode: Literal["full", "windowed"] = "full",
94 | window_size: Optional[int] = None,
95 | shift_window: Optional[Tuple[int, int, int]] = None,
96 | qkv_bias: bool = True,
97 | use_rope: bool = False,
98 | qk_rms_norm: bool = False,
99 | ):
100 | super().__init__()
101 | assert channels % num_heads == 0
102 | assert type in ["self", "cross"], f"Invalid attention type: {type}"
103 | assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
104 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
105 |
106 | if attn_mode == "windowed":
107 | raise NotImplementedError("Windowed attention is not yet implemented")
108 |
109 | self.channels = channels
110 | self.head_dim = channels // num_heads
111 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels
112 | self.num_heads = num_heads
113 | self._type = type
114 | self.attn_mode = attn_mode
115 | self.window_size = window_size
116 | self.shift_window = shift_window
117 | self.use_rope = use_rope
118 | self.qk_rms_norm = qk_rms_norm
119 |
120 | if self._type == "self":
121 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
122 | else:
123 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
124 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
125 |
126 | if self.qk_rms_norm:
127 | self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
128 | self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
129 |
130 | self.to_out = nn.Linear(channels, channels)
131 |
132 | if use_rope:
133 | self.rope = RotaryPositionEmbedder(channels)
134 |
135 | def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
136 | B, L, C = x.shape
137 | if self._type == "self":
138 | qkv = self.to_qkv(x)
139 | qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
140 | if self.use_rope:
141 | q, k, v = qkv.unbind(dim=2)
142 | q, k = self.rope(q, k, indices)
143 | qkv = torch.stack([q, k, v], dim=2)
144 | if self.attn_mode == "full":
145 | if self.qk_rms_norm:
146 | q, k, v = qkv.unbind(dim=2)
147 | q = self.q_rms_norm(q)
148 | k = self.k_rms_norm(k)
149 | h = scaled_dot_product_attention(q, k, v)
150 | else:
151 | h = scaled_dot_product_attention(qkv)
152 | elif self.attn_mode == "windowed":
153 | raise NotImplementedError("Windowed attention is not yet implemented")
154 | else:
155 | Lkv = context.shape[1]
156 | q = self.to_q(x)
157 | kv = self.to_kv(context)
158 | q = q.reshape(B, L, self.num_heads, -1)
159 | kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
160 | if self.qk_rms_norm:
161 | q = self.q_rms_norm(q)
162 | k, v = kv.unbind(dim=2)
163 | k = self.k_rms_norm(k)
164 | h = scaled_dot_product_attention(q, k, v)
165 | else:
166 | h = scaled_dot_product_attention(q, kv)
167 | h = h.reshape(B, L, -1)
168 | h = self.to_out(h)
169 | return h
170 |
--------------------------------------------------------------------------------
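Note (illustrative sketch, not a repository file): RotaryPositionEmbedder applied to flat per-token queries/keys with 3-D integer coordinates. hidden_size is chosen so that 3 axes * 2 * freq_dim fills it exactly and no phase padding is needed; all sizes are placeholders.

    import torch
    from triposf.modules.attention.modules import RotaryPositionEmbedder

    rope = RotaryPositionEmbedder(hidden_size=48, in_channels=3)   # freq_dim = 48 // 3 // 2 = 8
    coords = torch.randint(0, 64, (100, 3)).float()                # one (x, y, z) per token
    q = torch.randn(100, 48)
    k = torch.randn(100, 48)
    q_rot, k_rot = rope(q, k, coords)                              # same shapes, channel pairs rotated by coordinate-dependent phases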
/triposf/modules/sparse/transformer/modulated.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | import torch.utils.checkpoint
28 | from .. import SparseTensor
29 | from ..attention import SparseMultiHeadAttention, SerializeMode
30 | from ...norm import LayerNorm32
31 | from .blocks import SparseFeedForwardNet
32 |
33 |
34 | class ModulatedSparseTransformerBlock(nn.Module):
35 | """
36 | Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
37 | """
38 | def __init__(
39 | self,
40 | channels: int,
41 | num_heads: int,
42 | mlp_ratio: float = 4.0,
43 |         attn_mode: Literal["full", "serialized", "windowed"] = "full",
44 | window_size: Optional[int] = None,
45 | shift_sequence: Optional[int] = None,
46 | shift_window: Optional[Tuple[int, int, int]] = None,
47 | serialize_mode: Optional[SerializeMode] = None,
48 | use_checkpoint: bool = False,
49 | use_rope: bool = False,
50 | qk_rms_norm: bool = False,
51 | qkv_bias: bool = True,
52 | share_mod: bool = False,
53 | ):
54 | super().__init__()
55 | self.use_checkpoint = use_checkpoint
56 | self.share_mod = share_mod
57 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
58 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
59 | self.attn = SparseMultiHeadAttention(
60 | channels,
61 | num_heads=num_heads,
62 | attn_mode=attn_mode,
63 | window_size=window_size,
64 | shift_sequence=shift_sequence,
65 | shift_window=shift_window,
66 | serialize_mode=serialize_mode,
67 | qkv_bias=qkv_bias,
68 | use_rope=use_rope,
69 | qk_rms_norm=qk_rms_norm,
70 | )
71 | self.mlp = SparseFeedForwardNet(
72 | channels,
73 | mlp_ratio=mlp_ratio,
74 | )
75 | if not share_mod:
76 | self.adaLN_modulation = nn.Sequential(
77 | nn.SiLU(),
78 | nn.Linear(channels, 6 * channels, bias=True)
79 | )
80 |
81 | def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
82 | if self.share_mod:
83 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
84 | else:
85 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
86 | h = x.replace(self.norm1(x.feats))
87 | h = h * (1 + scale_msa) + shift_msa
88 | h = self.attn(h)
89 | h = h * gate_msa
90 | x = x + h
91 | h = x.replace(self.norm2(x.feats))
92 | h = h * (1 + scale_mlp) + shift_mlp
93 | h = self.mlp(h)
94 | h = h * gate_mlp
95 | x = x + h
96 | return x
97 |
98 | def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
99 | if self.use_checkpoint:
100 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
101 | else:
102 | return self._forward(x, mod)
103 |
104 |
105 | class ModulatedSparseTransformerCrossBlock(nn.Module):
106 | """
107 | Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
108 | """
109 | def __init__(
110 | self,
111 | channels: int,
112 | ctx_channels: int,
113 | num_heads: int,
114 | mlp_ratio: float = 4.0,
115 |         attn_mode: Literal["full", "serialized", "windowed"] = "full",
116 | window_size: Optional[int] = None,
117 | shift_sequence: Optional[int] = None,
118 | shift_window: Optional[Tuple[int, int, int]] = None,
119 | serialize_mode: Optional[SerializeMode] = None,
120 | use_checkpoint: bool = False,
121 | use_rope: bool = False,
122 | qk_rms_norm: bool = False,
123 | qk_rms_norm_cross: bool = False,
124 | qkv_bias: bool = True,
125 | share_mod: bool = False,
126 |
127 | ):
128 | super().__init__()
129 | self.use_checkpoint = use_checkpoint
130 | self.share_mod = share_mod
131 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
132 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
133 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
134 | self.self_attn = SparseMultiHeadAttention(
135 | channels,
136 | num_heads=num_heads,
137 | type="self",
138 | attn_mode=attn_mode,
139 | window_size=window_size,
140 | shift_sequence=shift_sequence,
141 | shift_window=shift_window,
142 | serialize_mode=serialize_mode,
143 | qkv_bias=qkv_bias,
144 | use_rope=use_rope,
145 | qk_rms_norm=qk_rms_norm,
146 | )
147 | self.cross_attn = SparseMultiHeadAttention(
148 | channels,
149 | ctx_channels=ctx_channels,
150 | num_heads=num_heads,
151 | type="cross",
152 | attn_mode="full",
153 | qkv_bias=qkv_bias,
154 | qk_rms_norm=qk_rms_norm_cross,
155 | )
156 | self.mlp = SparseFeedForwardNet(
157 | channels,
158 | mlp_ratio=mlp_ratio,
159 | )
160 | if not share_mod:
161 | self.adaLN_modulation = nn.Sequential(
162 | nn.SiLU(),
163 | nn.Linear(channels, 6 * channels, bias=True)
164 | )
165 |
166 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
167 | if self.share_mod:
168 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
169 | else:
170 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
171 | h = x.replace(self.norm1(x.feats))
172 | h = h * (1 + scale_msa) + shift_msa
173 | h = self.self_attn(h)
174 | h = h * gate_msa
175 | x = x + h
176 | h = x.replace(self.norm2(x.feats))
177 | h = self.cross_attn(h, context)
178 | x = x + h
179 | h = x.replace(self.norm3(x.feats))
180 | h = h * (1 + scale_mlp) + shift_mlp
181 | h = self.mlp(h)
182 | h = h * gate_mlp
183 | x = x + h
184 | return x
185 |
186 | def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
187 | if self.use_checkpoint:
188 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
189 | else:
190 | return self._forward(x, mod, context)
191 |
--------------------------------------------------------------------------------
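Note (illustrative sketch, not a repository file): the adaLN recipe used by the modulated blocks above, written with plain dense tensors. One conditioning vector produces shift/scale/gate triples for the attention and MLP branches; in the sparse blocks the same broadcasting happens through SparseTensor arithmetic.

    import torch
    import torch.nn as nn

    channels = 512
    adaLN = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True))
    mod = torch.randn(1, channels)                         # per-sample conditioning vector
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(mod).chunk(6, dim=1)

    h = torch.randn(1, channels)                           # a LayerNorm'd token feature
    h = h * (1 + scale_msa) + shift_msa                    # modulate before the attention branch
    h = h * gate_msa                                       # gate the branch output before the residual add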
/triposf/modules/pointclouds/pointnet.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | # modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py
25 |
26 | import torch
27 | import torch.nn as nn
28 | import copy
29 | from torch import Tensor
30 | from torch_scatter import scatter_mean
31 |
32 | def scale_tensor(
33 | dat, inp_scale=None, tgt_scale=None
34 | ):
35 | if inp_scale is None:
36 | inp_scale = (-0.5, 0.5)
37 | if tgt_scale is None:
38 | tgt_scale = (0, 1)
39 | assert tgt_scale[1] > tgt_scale[0] and inp_scale[1] > inp_scale[0]
40 | if isinstance(tgt_scale, Tensor):
41 | assert dat.shape[-1] == tgt_scale.shape[-1]
42 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
43 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
44 | return dat.clamp(tgt_scale[0] + 1e-6, tgt_scale[1] - 1e-6)
45 |
46 | # Resnet Blocks for pointnet
47 | class ResnetBlockFC(nn.Module):
48 | ''' Fully connected ResNet Block class.
49 |
50 | Args:
51 | size_in (int): input dimension
52 | size_out (int): output dimension
53 | size_h (int): hidden dimension
54 | '''
55 |
56 | def __init__(self, size_in, size_out=None, size_h=None):
57 | super().__init__()
58 | # Attributes
59 | if size_out is None:
60 | size_out = size_in
61 |
62 | if size_h is None:
63 | size_h = min(size_in, size_out)
64 |
65 | self.size_in = size_in
66 | self.size_h = size_h
67 | self.size_out = size_out
68 | # Submodules
69 | self.fc_0 = nn.Linear(size_in, size_h)
70 | self.fc_1 = nn.Linear(size_h, size_out)
71 | self.actvn = nn.GELU(approximate="tanh")
72 |
73 | if size_in == size_out:
74 | self.shortcut = None
75 | else:
76 | self.shortcut = nn.Linear(size_in, size_out, bias=False)
77 | # Initialization
78 | nn.init.xavier_uniform_(self.fc_0.weight)
79 | if self.fc_0.bias is not None:
80 | nn.init.constant_(self.fc_0.bias, 0)
81 | if self.shortcut is not None:
82 | nn.init.xavier_uniform_(self.shortcut.weight)
83 | if self.shortcut.bias is not None:
84 | nn.init.constant_(self.shortcut.bias, 0)
85 |
86 | nn.init.xavier_uniform_(self.fc_1.weight)
87 | if self.fc_1.bias is not None:
88 | nn.init.constant_(self.fc_1.bias, 0)
89 |
90 |
91 | def forward(self, x):
92 | net = self.fc_0(self.actvn(x))
93 | dx = self.fc_1(self.actvn(net))
94 |
95 | if self.shortcut is not None:
96 | x_s = self.shortcut(x)
97 | else:
98 | x_s = x
99 |
100 | return x_s + dx
101 |
102 | class LocalPoolPointnet(nn.Module):
103 | def __init__(self, in_channels=3, out_channels=128, hidden_dim=128, scatter_type='mean', n_blocks=5):
104 | super().__init__()
105 | self.scatter_type = scatter_type
106 | self.in_channels = in_channels
107 | self.hidden_dim = hidden_dim
108 | self.out_channels = out_channels
109 | self.fc_pos = nn.Linear(in_channels, 2*hidden_dim)
110 | self.blocks = nn.ModuleList([
111 | ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
112 | ])
113 | self.fc_c = nn.Linear(hidden_dim, out_channels)
114 |
115 | if self.scatter_type == 'mean':
116 | self.scatter = scatter_mean
117 | else:
118 | raise ValueError('Incorrect scatter type')
119 | self.initialize_weights()
120 |
121 | def initialize_weights(self):
122 |
123 | nn.init.xavier_uniform_(self.fc_pos.weight)
124 | if self.fc_pos.bias is not None:
125 | nn.init.constant_(self.fc_pos.bias, 0)
126 |
127 | nn.init.xavier_uniform_(self.fc_c.weight)
128 | if self.fc_c.bias is not None:
129 | nn.init.constant_(self.fc_c.bias, 0)
130 |
131 | def convert_to_sparse_feats(self, c, sparse_coords):
132 | '''
133 | Input:
134 |             sparse_coords: Tensor [Nx, 4], occupied-voxel coordinates (batch, x, y, z)
135 | c: Tensor [B, res, C], input feats of each grid
136 | Output:
137 |             feats_new: Tensor [Nx, C], grid features gathered per occupied sparse voxel
138 | '''
139 | feats_new = torch.zeros((sparse_coords.shape[0], c.shape[-1]), device=c.device, dtype=c.dtype)
140 | offsets = 0
141 |
142 | batch_nums = copy.deepcopy(sparse_coords[..., 0])
143 | for i in range(len(c)):
144 | coords_num_i = (batch_nums == i).sum()
145 | feats_new[offsets: offsets + coords_num_i] = c[i, : coords_num_i]
146 | offsets += coords_num_i
147 | return feats_new
148 |
149 | def generate_sparse_grid_features(self, index, c, max_coord_num):
150 | # scatter grid features from points
151 | bs, fea_dim = c.size(0), c.size(2)
152 | res = max_coord_num
153 | c_out = c.new_zeros(bs, self.out_channels, res)
154 | c_out = scatter_mean(c.permute(0, 2, 1), index, out=c_out).permute(0, 2, 1) # B x res X C
155 | return c_out
156 |
157 | def pool_sparse_local(self, index, c, max_coord_num):
158 | '''
159 | Input:
160 | index: Tensor [B, 1, Np], sparse indices of each point
161 | c: Tensor [B, Np, C], input feats of each point
162 | Output:
163 | c_out: Tensor [B, Np, C], aggregated grid feats of each point
164 | '''
165 |
166 | bs, fea_dim = c.size(0), c.size(2)
167 | res = max_coord_num
168 | c_out = c.new_zeros(bs, fea_dim, res)
169 | c_out = self.scatter(c.permute(0, 2, 1), index, out=c_out)
170 |
171 | # gather feature back to points
172 | c_out = c_out.gather(dim=2, index=index.expand(-1, fea_dim, -1))
173 | return c_out.permute(0, 2, 1)
174 |
175 | @torch.no_grad()
176 | def coordinate2sparseindex(self, x, sparse_coords, res):
177 | '''
178 | Input:
179 | x: Tensor [B, Np, 3], points scaled at ([0, 1] * res)
180 | sparse_coords: Tensor [Nx, 4] ([batch_number, x, y, z])
181 | res: Int, resolution of the grid index
182 | Output:
183 | sparse_index: Tensor [B, 1, Np], sparse indices of each point
184 | '''
185 | B = x.shape[0]
186 | sparse_index = torch.zeros((B, x.shape[1]), device=x.device, dtype=torch.int64)
187 |
188 | index = (x[..., 0] * res + x[..., 1]) * res + x[..., 2]
189 | sparse_indices = copy.deepcopy(sparse_coords)
190 | sparse_indices[..., 1] = (sparse_indices[..., 1] * res + sparse_indices[..., 2]) * res + sparse_indices[..., 3]
191 | sparse_indices = sparse_indices[..., :2]
192 |
193 | for i in range(B):
194 | mask_i = sparse_indices[..., 0] == i
195 | coords_i = sparse_indices[mask_i, 1]
196 | coords_num_i = len(coords_i)
197 | sparse_index[i] = torch.searchsorted(coords_i, index[i])
198 |
199 | return sparse_index[:, None, :]
200 |
201 | def forward(self, p, sparse_coords, res=64, bbox_size=(-0.5, 0.5)):
202 | '''
203 | Input:
204 |             p : Tensor [B, Np, 3] or [B, Np, 6] (xyz + normals); Np defaults to 819_200 samples
205 | sparse_coords: Tensor [Nx, 4] ([batch_number, x, y, z])
206 |
207 | Output:
208 | sparse_pc_feats: [Nx, self.out_channels]
209 | '''
210 | batch_size, T, D = p.size()
211 | max_coord_num = 0
212 | for i in range(batch_size):
213 | max_coord_num = max(max_coord_num, (sparse_coords[..., 0] == i).sum().item() + 5)
214 |
215 | if D == 6:
216 | p, normals = p[..., :3], p[..., 3:]
217 |
218 | coord = (scale_tensor(p, inp_scale=bbox_size) * res)
219 |         p = 2 * (coord - (coord.floor() + 0.5)) # distance to the voxel centroids, in [-1., 1.]
220 | index = self.coordinate2sparseindex(coord.long(), sparse_coords, res)
221 |
222 | if D == 6:
223 | p = torch.cat((p, normals), dim=-1)
224 | net = self.fc_pos(p)
225 | net = self.blocks[0](net)
226 | for block in self.blocks[1:]:
227 | pooled = self.pool_sparse_local(index, net, max_coord_num=max_coord_num)
228 |
229 | net = torch.cat([net, pooled], dim=2)
230 | net = block(net)
231 | c = self.fc_c(net)
232 | feats = self.generate_sparse_grid_features(index, c, max_coord_num=max_coord_num)
233 | feats = self.convert_to_sparse_feats(feats, sparse_coords)
234 |
235 | torch.cuda.empty_cache()
236 | return feats
237 |
--------------------------------------------------------------------------------
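Note (illustrative sketch, not a repository file): a toy run of LocalPoolPointnet, assuming the project requirements (in particular torch-scatter) are installed. The occupied-voxel list must be sorted per batch by the flattened index (x * res + y) * res + z, which torch.unique's lexicographic sort provides here; all sizes are placeholders.

    import torch
    from triposf.modules.pointclouds.pointnet import LocalPoolPointnet

    res = 16
    points = torch.rand(1, 2048, 3) - 0.5                           # one batch of points in [-0.5, 0.5)
    grid = ((points[0] + 0.5).clamp(1e-6, 1 - 1e-6) * res).floor().long()
    occupied = torch.unique(grid, dim=0)                            # [Nx, 3] occupied voxels, lexicographically sorted
    sparse_coords = torch.cat(
        [torch.zeros(len(occupied), 1, dtype=torch.long), occupied], dim=1
    )                                                               # [Nx, 4] = (batch, x, y, z)

    encoder = LocalPoolPointnet(in_channels=3, out_channels=32, hidden_dim=32)
    feats = encoder(points, sparse_coords, res=res)                 # [Nx, 32] feature per occupied voxel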
/triposf/modules/sparse/attention/serialized_attn.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | from enum import Enum
26 | import torch
27 | import math
28 | from .. import SparseTensor
29 | from .. import DEBUG, ATTN
30 |
31 | if ATTN == 'xformers':
32 | import xformers.ops as xops
33 | elif ATTN == 'flash_attn':
34 | import flash_attn
35 | else:
36 | raise ValueError(f"Unknown attention module: {ATTN}")
37 |
38 |
39 | __all__ = [
40 | 'sparse_serialized_scaled_dot_product_self_attention',
41 | ]
42 |
43 |
44 | class SerializeMode(Enum):
45 | Z_ORDER = 0
46 | Z_ORDER_TRANSPOSED = 1
47 | HILBERT = 2
48 | HILBERT_TRANSPOSED = 3
49 |
50 |
51 | SerializeModes = [
52 | SerializeMode.Z_ORDER,
53 | SerializeMode.Z_ORDER_TRANSPOSED,
54 | SerializeMode.HILBERT,
55 | SerializeMode.HILBERT_TRANSPOSED
56 | ]
57 |
58 |
59 | def calc_serialization(
60 | tensor: SparseTensor,
61 | window_size: int,
62 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
63 | shift_sequence: int = 0,
64 | shift_window: Tuple[int, int, int] = (0, 0, 0)
65 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
66 | """
67 | Calculate serialization and partitioning for a set of coordinates.
68 |
69 | Args:
70 | tensor (SparseTensor): The input tensor.
71 | window_size (int): The window size to use.
72 | serialize_mode (SerializeMode): The serialization mode to use.
73 | shift_sequence (int): The shift of serialized sequence.
74 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
75 |
76 | Returns:
77 |         (torch.Tensor, torch.Tensor, List[int], List[int]): Forward indices, backward indices, sequence lengths, and sequence batch indices.
78 | """
79 | fwd_indices = []
80 | bwd_indices = []
81 | seq_lens = []
82 | seq_batch_indices = []
83 | offsets = [0]
84 |
85 | if 'vox2seq' not in globals():
86 | import vox2seq
87 |
88 | # Serialize the input
89 | serialize_coords = tensor.coords[:, 1:].clone()
90 | serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
91 | if serialize_mode == SerializeMode.Z_ORDER:
92 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
93 | elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
94 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
95 | elif serialize_mode == SerializeMode.HILBERT:
96 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
97 | elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
98 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
99 | else:
100 | raise ValueError(f"Unknown serialize mode: {serialize_mode}")
101 |
102 | for bi, s in enumerate(tensor.layout):
103 | num_points = s.stop - s.start
104 | num_windows = (num_points + window_size - 1) // window_size
105 | valid_window_size = num_points / num_windows
106 | to_ordered = torch.argsort(code[s.start:s.stop])
107 | if num_windows == 1:
108 | fwd_indices.append(to_ordered)
109 | bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
110 | fwd_indices[-1] += s.start
111 | bwd_indices[-1] += offsets[-1]
112 | seq_lens.append(num_points)
113 | seq_batch_indices.append(bi)
114 | offsets.append(offsets[-1] + seq_lens[-1])
115 | else:
116 | # Partition the input
117 | offset = 0
118 | mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
119 | split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
120 | bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
121 | for i in range(num_windows):
122 | mid = mids[i]
123 | valid_start = split[i]
124 | valid_end = split[i + 1]
125 | padded_start = math.floor(mid - 0.5 * window_size)
126 | padded_end = padded_start + window_size
127 | fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
128 | offset += valid_start - padded_start
129 | bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
130 | offset += padded_end - valid_start
131 | fwd_indices[-1] += s.start
132 | seq_lens.extend([window_size] * num_windows)
133 | seq_batch_indices.extend([bi] * num_windows)
134 | bwd_indices.append(bwd_index + offsets[-1])
135 | offsets.append(offsets[-1] + num_windows * window_size)
136 |
137 | fwd_indices = torch.cat(fwd_indices)
138 | bwd_indices = torch.cat(bwd_indices)
139 |
140 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
141 |
142 |
143 | def sparse_serialized_scaled_dot_product_self_attention(
144 | qkv: SparseTensor,
145 | window_size: int,
146 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
147 | shift_sequence: int = 0,
148 | shift_window: Tuple[int, int, int] = (0, 0, 0)
149 | ) -> SparseTensor:
150 | """
151 | Apply serialized scaled dot product self attention to a sparse tensor.
152 |
153 | Args:
154 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
155 | window_size (int): The window size to use.
156 | serialize_mode (SerializeMode): The serialization mode to use.
157 | shift_sequence (int): The shift of serialized sequence.
158 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
159 | 
160 | """
161 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
162 |
163 | serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
164 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
165 | if serialization_spatial_cache is None:
166 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
167 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
168 | else:
169 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
170 |
171 | M = fwd_indices.shape[0]
172 | T = qkv.feats.shape[0]
173 | H = qkv.feats.shape[2]
174 | C = qkv.feats.shape[3]
175 |
176 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
177 |
178 | if DEBUG:
179 | start = 0
180 | qkv_coords = qkv.coords[fwd_indices]
181 | for i in range(len(seq_lens)):
182 |             assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseSerializedScaledDotProductSelfAttention: batch index mismatch"
183 | start += seq_lens[i]
184 |
185 | if all([seq_len == window_size for seq_len in seq_lens]):
186 | B = len(seq_lens)
187 | N = window_size
188 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
189 | if ATTN == 'xformers':
190 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
191 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
192 | elif ATTN == 'flash_attn':
193 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
194 | else:
195 | raise ValueError(f"Unknown attention module: {ATTN}")
196 | out = out.reshape(B * N, H, C) # [M, H, C]
197 | else:
198 | if ATTN == 'xformers':
199 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
200 | q = q.unsqueeze(0) # [1, M, H, C]
201 | k = k.unsqueeze(0) # [1, M, H, C]
202 | v = v.unsqueeze(0) # [1, M, H, C]
203 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
204 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
205 | elif ATTN == 'flash_attn':
206 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
207 | .to(qkv.device).int()
208 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
209 |
210 | out = out[bwd_indices] # [T, H, C]
211 |
212 | if DEBUG:
213 | qkv_coords = qkv_coords[bwd_indices]
214 |         assert torch.equal(qkv_coords, qkv.coords), "SparseSerializedScaledDotProductSelfAttention: coordinate mismatch"
215 |
216 | return qkv.replace(out)
217 |
--------------------------------------------------------------------------------
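Note (illustrative sketch, not a repository file): how the variable-length path above packages the serialized window lengths for flash-attn. cu_seqlens holds cumulative boundaries over the concatenated windows, so each [start, end) slice of the packed [M, 3, H, C] buffer is attended independently without padding the shorter windows.

    import torch

    seq_lens = [5, 3, 7]                               # tokens per serialized window (example values)
    cu_seqlens = torch.cat(
        [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0
    ).int()
    print(cu_seqlens)                                  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
    # flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))
    # then runs one attention per slice of the packed buffer.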
/inference.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2025 VAST-AI-Research and contributors.
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE
22 |
23 | from safetensors.torch import load_file
24 | import torch
25 | from dataclasses import dataclass, field
26 | import numpy as np
27 | import open3d as o3d
28 | import trimesh
29 | import os
30 | import time
31 | import argparse
32 | from omegaconf import OmegaConf
33 | from typing import *
34 |
35 | from triposf.modules import sparse as sp
36 | from misc import get_device, find
37 |
38 | def normalize_mesh(mesh_path):
39 | scene = trimesh.load(mesh_path, process=False, force='scene')
40 | meshes = []
41 | for node_name in scene.graph.nodes_geometry:
42 | geom_name = scene.graph[node_name][1]
43 | geometry = scene.geometry[geom_name]
44 | transform = scene.graph[node_name][0]
45 | if isinstance(geometry, trimesh.Trimesh):
46 | geometry.apply_transform(transform)
47 | meshes.append(geometry)
48 |
49 | mesh = trimesh.util.concatenate(meshes)
50 |
51 | center = mesh.bounding_box.centroid
52 | mesh.apply_translation(-center)
53 | scale = max(mesh.bounding_box.extents)
54 | mesh.apply_scale(2.0 / scale * 0.5)
55 |
56 | return mesh
57 |
58 | def load_quantized_mesh_original(
59 | mesh_path,
60 | volume_resolution=256,
61 | use_normals=True,
62 | pc_sample_number=4096000,
63 | ):
64 | cube_dilate = np.array(
65 | [
66 | [0, 0, 0],
67 | [0, 0, 1],
68 | [0, 1, 0],
69 | [0, 0, -1],
70 | [0, -1, 0],
71 | [0, 1, 1],
72 | [0, -1, 1],
73 | [0, 1, -1],
74 | [0, -1, -1],
75 |
76 | [1, 0, 0],
77 | [1, 0, 1],
78 | [1, 1, 0],
79 | [1, 0, -1],
80 | [1, -1, 0],
81 | [1, 1, 1],
82 | [1, -1, 1],
83 | [1, 1, -1],
84 | [1, -1, -1],
85 |
86 | [-1, 0, 0],
87 | [-1, 0, 1],
88 | [-1, 1, 0],
89 | [-1, 0, -1],
90 | [-1, -1, 0],
91 | [-1, 1, 1],
92 | [-1, -1, 1],
93 | [-1, 1, -1],
94 | [-1, -1, -1],
95 | ]
96 | ) / (volume_resolution * 4 - 1)
97 |
98 |
99 | mesh = o3d.io.read_triangle_mesh(mesh_path)
100 | vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
101 | faces = np.asarray(mesh.triangles)
102 | mesh.vertices = o3d.utility.Vector3dVector(vertices)
103 |
104 | voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
105 | mesh,
106 | voxel_size=1. / volume_resolution,
107 | min_bound=[-0.5, -0.5, -0.5],
108 | max_bound=[0.5, 0.5, 0.5]
109 | )
110 | voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
111 |
112 | points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True)
113 | points_sample = points_normals_sample[0].astype(np.float32)
114 | voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
115 | o3d.geometry.PointCloud(
116 | o3d.utility.Vector3dVector(
117 | np.clip(
118 | (points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
119 | -0.5 + 1e-6, 0.5 - 1e-6)
120 | )
121 | ),
122 | voxel_size=1. / volume_resolution,
123 | min_bound=[-0.5, -0.5, -0.5],
124 | max_bound=[0.5, 0.5, 0.5]
125 | )
126 | voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
127 | voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
128 |
129 | if use_normals:
130 | mesh.compute_triangle_normals()
131 | normals_sample = np.asarray(
132 | mesh.triangle_normals
133 | )[points_normals_sample[1]].astype(np.float32)
134 | points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1)
135 |
136 | return voxels, points_sample
137 |
138 | class TripoSFVAEInference(torch.nn.Module):
139 | @dataclass
140 | class Config:
141 | local_pc_encoder_cls: str = ""
142 | local_pc_encoder: dict = field(default_factory=dict)
143 |
144 | encoder_cls: str = ""
145 | encoder: dict = field(default_factory=dict)
146 |
147 | decoder_cls: str = ""
148 | decoder: dict = field(default_factory=dict)
149 |
150 | resolution: int = 256
151 | sample_points_num: int = 819_200
152 | use_normals: bool = True
153 | pruning: bool = False
154 |
155 | weight: Optional[str] = None
156 |
157 | cfg: Config
158 |
159 | def __init__(self, cfg):
160 | super().__init__()
161 | self.cfg = cfg
162 | self.configure()
163 |
164 | def load_weights(self):
165 | if self.cfg.weight is not None:
166 | print("Pretrained VAE Loading...")
167 | state_dict = load_file(self.cfg.weight)
168 | self.load_state_dict(state_dict)
169 |
170 | def configure(self) -> None:
171 | self.local_pc_encoder = find(self.cfg.local_pc_encoder_cls)(**self.cfg.local_pc_encoder).eval()
172 | for p in self.local_pc_encoder.parameters():
173 | p.requires_grad = False
174 |
175 | self.encoder = find(self.cfg.encoder_cls)(**self.cfg.encoder).eval()
176 | for p in self.encoder.parameters():
177 | p.requires_grad = False
178 |
179 | self.decoder = find(self.cfg.decoder_cls)(**self.cfg.decoder).eval()
180 | for p in self.decoder.parameters():
181 | p.requires_grad = False
182 |
183 | self.load_weights()
184 |
185 | @torch.no_grad()
186 | def forward(self, points_sample, sparse_voxel_coords):
187 | with torch.autocast("cuda", dtype=torch.float32):
188 | sparse_pc_features = self.local_pc_encoder(points_sample, sparse_voxel_coords, res=self.cfg.resolution, bbox_size=(-0.5, 0.5))
189 | sparse_tensor = sp.SparseTensor(sparse_pc_features, sparse_voxel_coords)
190 | latent, posterior = self.encoder(sparse_tensor)
191 | mesh = self.decoder(latent, pruning=self.cfg.pruning)
192 | return mesh
193 |
194 | @classmethod
195 | def from_config(cls, config_path):
196 | config = OmegaConf.load(config_path)
197 | cfg = OmegaConf.merge(OmegaConf.structured(TripoSFVAEInference.Config), config)
198 | return cls(cfg)
199 |
200 | if __name__ == "__main__":
201 |     # Usage: `python inference.py --mesh-path "assets/examples/loong.obj" --output-dir "outputs/" --config "configs/TripoSFVAE_1024.yaml"`
202 | parser = argparse.ArgumentParser("TripoSF Reconstruction")
203 | parser.add_argument("--config", required=True, help="path to config file")
204 | parser.add_argument("--output-dir", default="outputs/", help="path to output folder")
205 | parser.add_argument("--mesh-path", type=str, help="the input mesh to be reconstructed")
206 |
207 | args, extras = parser.parse_known_args()
208 | device = get_device()
209 | save_name = os.path.split(args.mesh_path)[-1].split(".")[0]
210 |
211 | model = TripoSFVAEInference.from_config(args.config).to(device)
212 |
213 | os.makedirs(args.output_dir, exist_ok=True)
214 |
215 | print(f"Mesh Normalizing...")
216 | preprocess_start = time.time()
217 | mesh_gt = normalize_mesh(args.mesh_path)
218 | save_path_gt = f"{args.output_dir}/{save_name}_gt.obj"
219 | trimesh.Trimesh(vertices=mesh_gt.vertices.tolist(), faces=mesh_gt.faces.tolist()).export(save_path_gt)
220 | preprocess_end = time.time()
221 | print(f"Mesh Normalizing Time: {(preprocess_end - preprocess_start):.2f}")
222 |
223 | print(f"Mesh Loading...")
224 | input_loading_start = time.time()
225 | sparse_voxels, points_sample = load_quantized_mesh_original(
226 | save_path_gt,
227 | volume_resolution=model.cfg.resolution,
228 | use_normals=model.cfg.use_normals,
229 | pc_sample_number=model.cfg.sample_points_num,
230 | )
231 | input_loading_end = time.time()
232 | print(f"Mesh Loading Time: {(input_loading_end - input_loading_start):.2f}")
233 |
234 | print(f"Mesh Reconstructing...")
235 | sparse_voxels, points_sample = sparse_voxels.to(device), points_sample.to(device)
236 | sparse_voxels_sp = torch.cat([torch.zeros_like(sparse_voxels[..., :1]), sparse_voxels], dim=-1).int()
237 |
238 | inference_start = time.time()
239 | with torch.cuda.amp.autocast(dtype=torch.float16):
240 | mesh_recon = model(points_sample[None], sparse_voxels_sp)[0]
241 | inference_end = time.time()
242 | print(f"Mesh Reconstructing Time: {(inference_end - inference_start):.2f}")
243 |
244 | save_path_recon = f"{args.output_dir}/{save_name}_reconstruction.obj"
245 | trimesh.Trimesh(vertices=mesh_recon.vertices.tolist(), faces=mesh_recon.faces.tolist()).export(save_path_recon)
--------------------------------------------------------------------------------
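Note (illustrative sketch, not a repository file): the __main__ block above driven programmatically; the mesh path is the example referenced in the script, other paths are placeholders, and a CUDA device with the project requirements installed is assumed.

    import os
    import torch
    from misc import get_device
    from inference import TripoSFVAEInference, normalize_mesh, load_quantized_mesh_original

    device = get_device()
    model = TripoSFVAEInference.from_config("configs/TripoSFVAE_1024.yaml").to(device)

    os.makedirs("outputs", exist_ok=True)
    normalize_mesh("assets/examples/loong.obj").export("outputs/loong_gt.obj")

    voxels, points = load_quantized_mesh_original(
        "outputs/loong_gt.obj",
        volume_resolution=model.cfg.resolution,
        use_normals=model.cfg.use_normals,
        pc_sample_number=model.cfg.sample_points_num,
    )
    voxels, points = voxels.to(device), points.to(device)
    voxels_sp = torch.cat([torch.zeros_like(voxels[..., :1]), voxels], dim=-1).int()

    with torch.cuda.amp.autocast(dtype=torch.float16):
        mesh = model(points[None], voxels_sp)[0]       # mesh with .vertices / .faces, ready to export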
/triposf/modules/sparse/attention/full_attn.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | from .. import SparseTensor
27 | from .. import DEBUG, ATTN
28 |
29 | if ATTN == 'xformers':
30 | import xformers.ops as xops
31 | elif ATTN == 'flash_attn':
32 | import flash_attn
33 | else:
34 | raise ValueError(f"Unknown attention module: {ATTN}")
35 |
36 |
37 | __all__ = [
38 | 'sparse_scaled_dot_product_attention',
39 | ]
40 |
41 |
42 | @overload
43 | def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
44 | """
45 | Apply scaled dot product attention to a sparse tensor.
46 |
47 | Args:
48 | qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
49 | """
50 | ...
51 |
52 | @overload
53 | def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
54 | """
55 | Apply scaled dot product attention to a sparse tensor.
56 |
57 | Args:
58 | q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
59 | kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
60 | """
61 | ...
62 |
63 | @overload
64 | def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
65 | """
66 | Apply scaled dot product attention to a sparse tensor.
67 |
68 | Args:
69 |         q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
70 |         kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
71 | """
72 | ...
73 |
74 | @overload
75 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
76 | """
77 | Apply scaled dot product attention to a sparse tensor.
78 |
79 | Args:
80 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
81 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
82 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
83 |
84 | Note:
85 | k and v are assumed to have the same coordinate map.
86 | """
87 | ...
88 |
89 | @overload
90 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
91 | """
92 | Apply scaled dot product attention to a sparse tensor.
93 |
94 | Args:
95 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
96 | k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
97 | v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
98 | """
99 | ...
100 |
101 | @overload
102 | def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
103 | """
104 | Apply scaled dot product attention to a sparse tensor.
105 |
106 | Args:
107 | q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
108 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
109 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
110 | """
111 | ...
112 |
113 | def sparse_scaled_dot_product_attention(*args, **kwargs):
114 | arg_names_dict = {
115 | 1: ['qkv'],
116 | 2: ['q', 'kv'],
117 | 3: ['q', 'k', 'v']
118 | }
119 | num_all_args = len(args) + len(kwargs)
120 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
121 | for key in arg_names_dict[num_all_args][len(args):]:
122 | assert key in kwargs, f"Missing argument {key}"
123 |
124 | if num_all_args == 1:
125 | qkv = args[0] if len(args) > 0 else kwargs['qkv']
126 | assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
127 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
128 | device = qkv.device
129 |
130 | s = qkv
131 | q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
132 | kv_seqlen = q_seqlen
133 | qkv = qkv.feats # [T, 3, H, C]
134 |
135 | elif num_all_args == 2:
136 | q = args[0] if len(args) > 0 else kwargs['q']
137 | kv = args[1] if len(args) > 1 else kwargs['kv']
138 | assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
139 | isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
140 | f"Invalid types, got {type(q)} and {type(kv)}"
141 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
142 | device = q.device
143 |
144 | if isinstance(q, SparseTensor):
145 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
146 | s = q
147 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
148 | q = q.feats # [T_Q, H, C]
149 | else:
150 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
151 | s = None
152 | N, L, H, C = q.shape
153 | q_seqlen = [L] * N
154 | q = q.reshape(N * L, H, C) # [T_Q, H, C]
155 |
156 | if isinstance(kv, SparseTensor):
157 | assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
158 | kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
159 | kv = kv.feats # [T_KV, 2, H, C]
160 | else:
161 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
162 | N, L, _, H, C = kv.shape
163 | kv_seqlen = [L] * N
164 | kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
165 |
166 | elif num_all_args == 3:
167 | q = args[0] if len(args) > 0 else kwargs['q']
168 | k = args[1] if len(args) > 1 else kwargs['k']
169 | v = args[2] if len(args) > 2 else kwargs['v']
170 | assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
171 | isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
172 | f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
173 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
174 | device = q.device
175 |
176 | if isinstance(q, SparseTensor):
177 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
178 | s = q
179 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
180 | q = q.feats # [T_Q, H, Ci]
181 | else:
182 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
183 | s = None
184 | N, L, H, CI = q.shape
185 | q_seqlen = [L] * N
186 | q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
187 |
188 | if isinstance(k, SparseTensor):
189 | assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
190 | assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
191 | kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
192 | k = k.feats # [T_KV, H, Ci]
193 | v = v.feats # [T_KV, H, Co]
194 | else:
195 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
196 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
197 | N, L, H, CI, CO = *k.shape, v.shape[-1]
198 | kv_seqlen = [L] * N
199 | k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
200 | v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
201 |
202 | if DEBUG:
203 | if s is not None:
204 | for i in range(s.shape[0]):
205 |                 assert (s.coords[s.layout[i], 0] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
206 | if num_all_args in [2, 3]:
207 |             assert q.shape[0] == sum(q_seqlen), f"SparseScaledDotProductSelfAttention: q shape mismatch"
208 | if num_all_args == 3:
209 |             assert k.shape[0] == sum(kv_seqlen), f"SparseScaledDotProductSelfAttention: k shape mismatch"
210 |             assert v.shape[0] == sum(kv_seqlen), f"SparseScaledDotProductSelfAttention: v shape mismatch"
211 |
212 | if ATTN == 'xformers':
213 | if num_all_args == 1:
214 | q, k, v = qkv.unbind(dim=1)
215 | elif num_all_args == 2:
216 | k, v = kv.unbind(dim=1)
217 | q = q.unsqueeze(0)
218 | k = k.unsqueeze(0)
219 | v = v.unsqueeze(0)
220 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
221 | out = xops.memory_efficient_attention(q, k, v, mask)[0]
222 | elif ATTN == 'flash_attn':
223 | cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
224 | if num_all_args in [2, 3]:
225 | cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
226 | if num_all_args == 1:
227 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
228 | elif num_all_args == 2:
229 | out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
230 | elif num_all_args == 3:
231 | out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
232 | else:
233 | raise ValueError(f"Unknown attention module: {ATTN}")
234 |
235 | if s is not None:
236 | return s.replace(out)
237 | else:
238 | return out.reshape(N, L, H, -1)
239 |
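# Usage sketch (assuming sp.SparseTensor(feats, coords) is constructed as elsewhere
# in this repo, with coords laid out as [batch, x, y, z]):
#
#   feats = torch.randn(num_tokens, 3, num_heads, head_dim)   # packed Q, K, V
#   qkv = sp.SparseTensor(feats, coords)                       # coords: [num_tokens, 4] int tensor
#   out = sparse_scaled_dot_product_attention(qkv)             # SparseTensor with feats [num_tokens, num_heads, head_dim]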
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2025 VAST-AI-Research and contributors.
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE
22 |
23 | import random
24 | import gradio as gr
25 | from safetensors.torch import load_file
26 | import torch
27 | from dataclasses import dataclass, field
28 | import numpy as np
29 | import open3d as o3d
30 | import trimesh
31 | import os
32 | import time
33 | import argparse
34 | from omegaconf import OmegaConf
35 | from typing import *
36 |
37 | from triposf.modules import sparse as sp
38 | from misc import get_device, find
39 |
40 | def get_random_hex():
41 | random_bytes = os.urandom(8)
42 | random_hex = random_bytes.hex()
43 | return random_hex
44 |
45 | def get_random_seed(randomize_seed, seed):
46 | if randomize_seed:
47 | seed = random.randint(0, MAX_SEED)
48 | return seed
49 |
50 | def normalize_mesh(mesh_path):
51 | scene = trimesh.load(mesh_path, process=False, force='scene')
52 | meshes = []
53 | for node_name in scene.graph.nodes_geometry:
54 | geom_name = scene.graph[node_name][1]
55 | geometry = scene.geometry[geom_name]
56 | transform = scene.graph[node_name][0]
57 | if isinstance(geometry, trimesh.Trimesh):
58 | geometry.apply_transform(transform)
59 | meshes.append(geometry)
60 |
61 | mesh = trimesh.util.concatenate(meshes)
62 |
63 | center = mesh.bounding_box.centroid
64 | mesh.apply_translation(-center)
65 | scale = max(mesh.bounding_box.extents)
66 | mesh.apply_scale(2.0 / scale * 0.5)
67 |
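    # Rotate 90 degrees about the -x axis; this presumably matches the axis convention
    # the pretrained VAE expects for its inputs.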
68 | angle = np.radians(90)
69 | rotation_matrix = trimesh.transformations.rotation_matrix(angle, [-1, 0, 0])
70 | mesh.apply_transform(rotation_matrix)
71 | return mesh
72 |
73 | def load_quantized_mesh_original(
74 | mesh_path,
75 | volume_resolution=256,
76 | use_normals=True,
77 | pc_sample_number=4096000,
78 | ):
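    # All 27 offsets of a 3x3x3 neighborhood; scaled below to roughly a quarter of a voxel,
    # they dilate the sampled points before voxelization (presumably so thin surfaces still
    # yield a contiguous shell of occupied voxels).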
79 | cube_dilate = np.array(
80 | [
81 | [0, 0, 0],
82 | [0, 0, 1],
83 | [0, 1, 0],
84 | [0, 0, -1],
85 | [0, -1, 0],
86 | [0, 1, 1],
87 | [0, -1, 1],
88 | [0, 1, -1],
89 | [0, -1, -1],
90 |
91 | [1, 0, 0],
92 | [1, 0, 1],
93 | [1, 1, 0],
94 | [1, 0, -1],
95 | [1, -1, 0],
96 | [1, 1, 1],
97 | [1, -1, 1],
98 | [1, 1, -1],
99 | [1, -1, -1],
100 |
101 | [-1, 0, 0],
102 | [-1, 0, 1],
103 | [-1, 1, 0],
104 | [-1, 0, -1],
105 | [-1, -1, 0],
106 | [-1, 1, 1],
107 | [-1, -1, 1],
108 | [-1, 1, -1],
109 | [-1, -1, -1],
110 | ]
111 | ) / (volume_resolution * 4 - 1)
112 |
113 |
114 | mesh = o3d.io.read_triangle_mesh(mesh_path)
115 | vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
116 | faces = np.asarray(mesh.triangles)
117 | mesh.vertices = o3d.utility.Vector3dVector(vertices)
118 |
119 | voxelization_mesh = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
120 | mesh,
121 | voxel_size=1. / volume_resolution,
122 | min_bound=[-0.5, -0.5, -0.5],
123 | max_bound=[0.5, 0.5, 0.5]
124 | )
125 | voxel_mesh = np.asarray([voxel.grid_index for voxel in voxelization_mesh.get_voxels()])
126 |
127 | points_normals_sample = trimesh.Trimesh(vertices=vertices, faces=faces).sample(count=pc_sample_number, return_index=True)
128 | points_sample = points_normals_sample[0].astype(np.float32)
129 | voxelization_points = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
130 | o3d.geometry.PointCloud(
131 | o3d.utility.Vector3dVector(
132 | np.clip(
133 | (points_sample[np.newaxis] + cube_dilate[..., np.newaxis, :]).reshape(-1, 3),
134 | -0.5 + 1e-6, 0.5 - 1e-6)
135 | )
136 | ),
137 | voxel_size=1. / volume_resolution,
138 | min_bound=[-0.5, -0.5, -0.5],
139 | max_bound=[0.5, 0.5, 0.5]
140 | )
141 | voxel_points = np.asarray([voxel.grid_index for voxel in voxelization_points.get_voxels()])
142 | voxels = torch.Tensor(np.unique(np.concatenate([voxel_mesh, voxel_points]), axis=0))
143 |
144 | if use_normals:
145 | mesh.compute_triangle_normals()
146 | normals_sample = np.asarray(
147 | mesh.triangle_normals
148 | )[points_normals_sample[1]].astype(np.float32)
149 | points_sample = torch.cat((torch.Tensor(points_sample), torch.Tensor(normals_sample)), axis=-1)
150 |
151 | return voxels, points_sample
152 |
153 | class TripoSFVAEInference(torch.nn.Module):
154 | @dataclass
155 | class Config:
156 | local_pc_encoder_cls: str = ""
157 | local_pc_encoder: dict = field(default_factory=dict)
158 |
159 | encoder_cls: str = ""
160 | encoder: dict = field(default_factory=dict)
161 |
162 | decoder_cls: str = ""
163 | decoder: dict = field(default_factory=dict)
164 |
165 | resolution: int = 256
166 | sample_points_num: int = 819_200
167 | use_normals: bool = True
168 | pruning: bool = False
169 |
170 | weight: Optional[str] = None
171 |
172 | cfg: Config
173 |
174 | def __init__(self, cfg):
175 | super().__init__()
176 | self.cfg = cfg
177 | self.configure()
178 |
179 | def load_weights(self):
180 | if self.cfg.weight is not None:
181 | print("Pretrained VAE Loading...")
182 | state_dict = load_file(self.cfg.weight)
183 | self.load_state_dict(state_dict)
184 |
185 | def configure(self) -> None:
186 | self.local_pc_encoder = find(self.cfg.local_pc_encoder_cls)(**self.cfg.local_pc_encoder).eval()
187 | for p in self.local_pc_encoder.parameters():
188 | p.requires_grad = False
189 |
190 | self.encoder = find(self.cfg.encoder_cls)(**self.cfg.encoder).eval()
191 | for p in self.encoder.parameters():
192 | p.requires_grad = False
193 |
194 | self.decoder = find(self.cfg.decoder_cls)(**self.cfg.decoder).eval()
195 | for p in self.decoder.parameters():
196 | p.requires_grad = False
197 |
198 | self.load_weights()
199 |
200 | @torch.no_grad()
201 | def forward(self, points_sample, sparse_voxel_coords):
202 | with torch.autocast("cuda", dtype=torch.float32):
203 | sparse_pc_features = self.local_pc_encoder(points_sample, sparse_voxel_coords, res=self.cfg.resolution, bbox_size=(-0.5, 0.5))
204 | sparse_tensor = sp.SparseTensor(sparse_pc_features, sparse_voxel_coords)
205 | latent, posterior = self.encoder(sparse_tensor)
206 | mesh = self.decoder(latent, pruning=self.cfg.pruning)
207 | return mesh
208 |
209 | @classmethod
210 | def from_config(cls, config_path):
211 | config = OmegaConf.load(config_path)
212 | cfg = OmegaConf.merge(OmegaConf.structured(TripoSFVAEInference.Config), config)
213 | return cls(cfg)
214 |
215 | HEADER = """
216 | # TripoSF VAE Reconstruction [TripoSF](https://github.com/VAST-AI-Research/TripoSF)
217 | ## TripoSF represents a significant leap forward in 3D shape modeling, combining high-resolution capabilities with arbitrary topology support.
218 | ## 📋 Some Tips:
219 |
220 | 1. It is recommended to enable `pruning` for open-surface objects
221 |
222 | 2. Increasing sampling points is helpful for reconstructing complex shapes
223 |
224 | By Tripo
225 |
226 | """
227 | MAX_SEED = np.iinfo(np.int32).max
228 | device = "cuda" if torch.cuda.is_available() else "cpu"
229 | config = "configs/TripoSFVAE_1024.yaml"
230 |
231 | random_hex = get_random_hex()
232 | output_dir = "outputs"
233 | os.makedirs(output_dir, exist_ok=True)
234 | model = TripoSFVAEInference.from_config(config).to(device)
235 |
236 | def run_normalize_mesh(input_mesh):
237 | mesh_gt = normalize_mesh(input_mesh)
238 | mesh_name = os.path.basename(input_mesh).split('.')[0]
239 | mesh_path_gt = f"{output_dir}/{mesh_name}_{random_hex}_normalized.obj"
240 | mesh_normalized = trimesh.Trimesh(vertices=mesh_gt.vertices.tolist(), faces=mesh_gt.faces.tolist())
241 | mesh_normalized.export(mesh_path_gt)
242 | return mesh_path_gt
243 |
244 | def run_reconstruction(input_mesh, sample_points_num, pruning, seed):
245 | model.cfg.pruning = pruning
246 | model.cfg.sample_points_num = sample_points_num
247 | mesh_name = os.path.basename(input_mesh).split('.')[0]
248 | mesh_path_gt = f"{output_dir}/{mesh_name}_{random_hex}_normalized.obj"
249 | sparse_voxels, points_sample = load_quantized_mesh_original(
250 | mesh_path_gt,
251 | volume_resolution=model.cfg.resolution,
252 | use_normals=model.cfg.use_normals,
253 | pc_sample_number=model.cfg.sample_points_num,
254 | )
255 |
256 | sparse_voxels, points_sample = sparse_voxels.to(device), points_sample.to(device)
257 | sparse_voxels_sp = torch.cat([torch.zeros_like(sparse_voxels[..., :1]), sparse_voxels], dim=-1).int()
258 |
259 | with torch.cuda.amp.autocast(dtype=torch.float16):
260 | mesh_recon = model(points_sample[None], sparse_voxels_sp)[0]
261 |
262 | mesh_path_recon = f"{output_dir}/{mesh_name}_{random_hex}_reconstructed.obj"
263 | mesh_reconstructed = trimesh.Trimesh(vertices=mesh_recon.vertices.tolist(), faces=mesh_recon.faces.tolist())
264 | mesh_reconstructed.export(mesh_path_recon)
265 | return mesh_path_recon
266 |
267 | with gr.Blocks(title="TripoSFRecon") as demo:
268 | gr.Markdown(HEADER)
269 |
270 | with gr.Row():
271 | with gr.Column():
272 | with gr.Row():
273 |                 input_mesh_path = gr.Textbox(label="Please input the local path to the mesh to be reconstructed.", placeholder="For example: assets/examples/jacket.obj")
274 |
275 | with gr.Accordion("Reconstruction Settings", open=True):
276 | use_pruning = gr.Checkbox(label="Pruning", value=False)
277 | recon_button = gr.Button("Reconstruct Mesh", variant="primary")
278 | seed = gr.Slider(
279 | label="Seed",
280 | minimum=0,
281 | maximum=MAX_SEED,
282 |                     step=1,
283 | value=0
284 | )
285 | sample_points_num = gr.Slider(
286 | label="Sample point number",
287 | minimum=819200,
288 | maximum=8192000,
289 | step=100,
290 | value=819200
291 | )
292 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
293 | with gr.Column():
294 | normalized_model_output = gr.Model3D(label="Normalized Mesh", interactive=True)
295 | reconstructed_model_output = gr.Model3D(label="Reconstructed Mesh", interactive=True)
296 |
297 |
298 | recon_button.click(
299 | run_normalize_mesh,
300 | inputs=[input_mesh_path],
301 | outputs=[normalized_model_output]
302 | ).then(
303 | get_random_seed,
304 | inputs=[randomize_seed, seed],
305 | outputs=[seed],
306 | ).then(
307 | run_reconstruction,
308 | inputs=[input_mesh_path, sample_points_num, use_pruning, seed],
309 | outputs=[reconstructed_model_output],
310 | )
311 |
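# Serve the demo on all network interfaces at port 12345; open http://localhost:12345 when running locally.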
312 | demo.launch(server_name="0.0.0.0", server_port=12345)
313 |
--------------------------------------------------------------------------------
/triposf/models/triposf_vae/decoder.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) Microsoft Corporation.
4 | # Copyright (c) 2025 VAST-AI-Research and contributors.
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE
23 |
24 | from typing import *
25 | import torch
26 | import torch.nn as nn
27 | from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
28 | from ...modules import sparse as sp
29 | from .base import SparseTransformerBase
30 | from ...representations import MeshExtractResult
31 | from ...representations.mesh import SparseFeatures2Mesh
32 | from ...modules.sparse.linear import SparseLinear
33 | from ...modules.sparse.nonlinearity import SparseGELU
34 |
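# Small MLP head that outputs one occupancy logit per voxel; SparseSubdivideBlock3d
# thresholds these logits at zero to prune empty voxels at inference time when
# pruning is enabled.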
35 | class SparseOccHead(nn.Module):
36 | def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0):
37 | super().__init__()
38 | self.mlp = nn.Sequential(
39 | SparseLinear(channels, int(channels * mlp_ratio)),
40 | SparseGELU(approximate="tanh"),
41 | SparseLinear(int(channels * mlp_ratio), out_channels),
42 | )
43 |
44 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
45 | return self.mlp(x)
46 |
47 | class SparseSubdivideBlock3d(nn.Module):
48 | """
49 | A 3D subdivide block that can subdivide the sparse tensor.
50 |
51 | Args:
52 | channels: channels in the inputs and outputs.
53 | out_channels: if specified, the number of output channels.
54 | num_groups: the number of groups for the group norm.
55 | """
56 | def __init__(
57 | self,
58 | channels: int,
59 | resolution: int,
60 | out_channels: Optional[int] = None,
61 | num_groups: int = 32,
62 | ):
63 | super().__init__()
64 | self.channels = channels
65 | self.resolution = resolution
66 | self.out_resolution = resolution * 2
67 | self.out_channels = out_channels or channels
68 |
69 | self.act_layers = nn.Sequential(
70 | sp.SparseGroupNorm32(num_groups, channels),
71 | sp.SparseSiLU()
72 | )
73 |
74 | self.sub = sp.SparseSubdivide()
75 |
76 | self.out_layers = nn.Sequential(
77 | sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
78 | sp.SparseGroupNorm32(num_groups, self.out_channels),
79 | sp.SparseSiLU(),
80 | zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
81 | )
82 |
83 | if self.out_channels == channels:
84 | self.skip_connection = nn.Identity()
85 | else:
86 | self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
87 |
88 | self.pruning_head = SparseOccHead(self.out_channels, out_channels=1)
89 |
90 |     def forward(self, x: sp.SparseTensor, pruning=False, training=True) -> Tuple[sp.SparseTensor, Optional[sp.SparseTensor]]:
91 |         """
92 |         Subdivide the sparse tensor and refine its features.
93 |
94 |         Args:
95 |             x: a sparse tensor of input features.
96 |         Returns:
97 |             a tuple (h, occ_prob): the subdivided sparse tensor and the occupancy logits (None unless pruning is enabled).
98 |         """
99 | h = self.act_layers(x)
100 | h = self.sub(h)
101 | x = self.sub(x)
102 | h = self.out_layers(h)
103 | h = h + self.skip_connection(x)
104 | if pruning:
105 | occ_prob = self.pruning_head(h)
106 | occ_mask = (occ_prob.feats >= 0.).squeeze(-1)
107 |             if not training:
108 | h = sp.SparseTensor(feats=h.feats[occ_mask], coords=h.coords[occ_mask])
109 | return h, occ_prob
110 | else:
111 | return h, None
112 |
113 |
114 | class TripoSFVAEDecoder(SparseTransformerBase):
115 | def __init__(
116 | self,
117 | resolution: int,
118 | model_channels: int,
119 | latent_channels: int,
120 | num_blocks: int,
121 | num_heads: Optional[int] = None,
122 | num_head_channels: Optional[int] = 64,
123 | mlp_ratio: float = 4,
124 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
125 | window_size: int = 8,
126 | pe_mode: Literal["ape", "rope"] = "ape",
127 | use_fp16: bool = False,
128 | use_checkpoint: bool = False,
129 | qk_rms_norm: bool = False,
130 | use_sparse_flexicube: bool = True,
131 | use_sparse_sparse_flexicube: bool = False,
132 | representation_config: dict = None,
133 | ):
134 | super().__init__(
135 | in_channels=latent_channels,
136 | model_channels=model_channels,
137 | num_blocks=num_blocks,
138 | num_heads=num_heads,
139 | num_head_channels=num_head_channels,
140 | mlp_ratio=mlp_ratio,
141 | attn_mode=attn_mode,
142 | window_size=window_size,
143 | pe_mode=pe_mode,
144 | use_fp16=use_fp16,
145 | use_checkpoint=use_checkpoint,
146 | qk_rms_norm=qk_rms_norm,
147 | )
148 |         assert use_sparse_flexicube or not use_sparse_sparse_flexicube, "use_sparse_sparse_flexicube requires use_sparse_flexicube"
149 | self.resolution = resolution
150 | self.rep_config = representation_config
151 | self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False), use_sparse_flexicube=use_sparse_flexicube, use_sparse_sparse_flexicube=use_sparse_sparse_flexicube)
152 | self.out_channels = self.mesh_extractor.feats_channels
153 | self.upsample = nn.ModuleList([
154 | SparseSubdivideBlock3d(
155 | channels=model_channels,
156 | resolution=resolution,
157 | out_channels=model_channels // 4,
158 | ),
159 | SparseSubdivideBlock3d(
160 | channels=model_channels // 4,
161 | resolution=resolution * 2,
162 | out_channels=model_channels // 8,
163 | )
164 | ])
165 | self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
166 |
167 | self.initialize_weights()
168 | if use_fp16:
169 | self.convert_to_fp16()
170 |
171 | def initialize_weights(self) -> None:
172 | super().initialize_weights()
173 | # Zero-out output layers:
174 | nn.init.constant_(self.out_layer.weight, 0)
175 | nn.init.constant_(self.out_layer.bias, 0)
176 | for module in self.upsample:
177 | nn.init.constant_(module.pruning_head.mlp[-1].weight, 0)
178 | nn.init.constant_(module.pruning_head.mlp[-1].bias, 0)
179 | def convert_to_fp16(self) -> None:
180 | """
181 | Convert the torso of the model to float16.
182 | """
183 | super().convert_to_fp16()
184 | self.upsample.apply(convert_module_to_f16)
185 |
186 | def convert_to_fp32(self) -> None:
187 | """
188 | Convert the torso of the model to float32.
189 | """
190 | super().convert_to_fp32()
191 | self.upsample.apply(convert_module_to_f32)
192 |
193 | def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
194 | """
195 | Convert a batch of network outputs to 3D representations.
196 |
197 | Args:
198 | x: The [N x * x C] sparse tensor output by the network.
199 |
200 | Returns:
201 | list of representations
202 | """
203 | ret = []
204 | for i in range(x.shape[0]):
205 | mesh = self.mesh_extractor(x[i].float(), training=self.training)
206 | ret.append(mesh)
207 |
208 | return ret
209 |
210 | @torch.no_grad()
211 | def split_for_meshing(self, x: sp.SparseTensor, chunk_size=4, padding=4, verbose=False):
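        # Split the latent grid into chunk_size**3 spatial chunks, each padded by `padding`
        # voxels on every side, so the memory-heavy upsampling and meshing can run per chunk.
        # The padded/original bounds and offsets stored on each chunk (scaled by the x4
        # upsampling) let forward() restore global coordinates and crop away the overlap.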
212 |
213 | sub_resolution = self.resolution // chunk_size
214 |         upsample_ratio = 4  # hard-coded: two SparseSubdivideBlock3d stages, each doubling the resolution
215 | assert sub_resolution % padding == 0
216 | out = []
217 | if verbose:
218 | print(f"Input coords range: x[{x.coords[:, 1].min()}, {x.coords[:, 1].max()}], "
219 | f"y[{x.coords[:, 2].min()}, {x.coords[:, 2].max()}], "
220 | f"z[{x.coords[:, 3].min()}, {x.coords[:, 3].max()}]")
221 | print(f"Resolution: {self.resolution}, sub_resolution: {sub_resolution}")
222 |
223 | for i in range(chunk_size):
224 | for j in range(chunk_size):
225 | for k in range(chunk_size):
226 | # Calculate padded boundaries
227 | start_x = max(0, i * sub_resolution - padding)
228 | end_x = min((i + 1) * sub_resolution + padding, self.resolution)
229 | start_y = max(0, j * sub_resolution - padding)
230 | end_y = min((j + 1) * sub_resolution + padding, self.resolution)
231 | start_z = max(0, k * sub_resolution - padding)
232 | end_z = min((k + 1) * sub_resolution + padding, self.resolution)
233 |
234 | # Store original (unpadded) boundaries for later cropping
235 | orig_start_x = i * sub_resolution
236 | orig_end_x = (i + 1) * sub_resolution
237 | orig_start_y = j * sub_resolution
238 | orig_end_y = (j + 1) * sub_resolution
239 | orig_start_z = k * sub_resolution
240 | orig_end_z = (k + 1) * sub_resolution
241 |
242 | if verbose:
243 | print(f"\nChunk ({i},{j},{k}):")
244 | print(f"Padded bounds: x[{start_x}, {end_x}], y[{start_y}, {end_y}], z[{start_z}, {end_z}]")
245 | print(f"Original bounds: x[{orig_start_x}, {orig_end_x}], y[{orig_start_y}, {orig_end_y}], z[{orig_start_z}, {orig_end_z}]")
246 |
247 | mask = torch.logical_and(
248 | torch.logical_and(
249 | torch.logical_and(x.coords[:, 1] >= start_x, x.coords[:, 1] < end_x),
250 | torch.logical_and(x.coords[:, 2] >= start_y, x.coords[:, 2] < end_y)
251 | ),
252 | torch.logical_and(x.coords[:, 3] >= start_z, x.coords[:, 3] < end_z)
253 | )
254 |
255 | if mask.sum() > 0:
256 | # Get the coordinates and shift them to local space
257 | coords = x.coords[mask].clone()
258 | if verbose:
259 | print(f"Before local shift - coords range: x[{coords[:, 1].min()}, {coords[:, 1].max()}], "
260 | f"y[{coords[:, 2].min()}, {coords[:, 2].max()}], "
261 | f"z[{coords[:, 3].min()}, {coords[:, 3].max()}]")
262 |
263 | # Shift to local coordinates
264 | coords[:, 1:] = coords[:, 1:] - torch.tensor([start_x, start_y, start_z],
265 | device=coords.device).view(1, 3)
266 | if verbose:
267 | print(f"After local shift - coords range: x[{coords[:, 1].min()}, {coords[:, 1].max()}], "
268 | f"y[{coords[:, 2].min()}, {coords[:, 2].max()}], "
269 | f"z[{coords[:, 3].min()}, {coords[:, 3].max()}]")
270 |
271 | chunk_tensor = sp.SparseTensor(x.feats[mask], coords)
272 | # Store the boundaries and offsets as metadata for later reconstruction
273 | chunk_tensor.bounds = {
274 | 'padded': (start_x * upsample_ratio, end_x * upsample_ratio + (upsample_ratio - 1), start_y * upsample_ratio, end_y * upsample_ratio + (upsample_ratio - 1), start_z * upsample_ratio, end_z * upsample_ratio + (upsample_ratio - 1)),
275 | 'original': (orig_start_x * upsample_ratio, orig_end_x * upsample_ratio + (upsample_ratio - 1), orig_start_y * upsample_ratio, orig_end_y * upsample_ratio + (upsample_ratio - 1), orig_start_z * upsample_ratio, orig_end_z * upsample_ratio + (upsample_ratio - 1)),
276 | 'offsets': (start_x * upsample_ratio, start_y * upsample_ratio, start_z * upsample_ratio) # Store offsets for reconstruction
277 | }
278 | out.append(chunk_tensor)
279 |
280 | del mask
281 | torch.cuda.empty_cache()
282 | return out
283 |
284 | @torch.no_grad()
285 | def upsamples(self, chunk: sp.SparseTensor, pruning=False): # Only for inferencing
286 | dtype = chunk.dtype
287 | for block in self.upsample:
288 | chunk, _ = block(chunk, pruning=pruning, training=False)
289 | chunk = chunk.type(dtype)
290 | chunk = self.out_layer(chunk)
291 | return chunk
292 |
293 | def forward(self, x: sp.SparseTensor, pruning=False):
294 | batch_size = x.shape[0]
295 |         chunk_size = 8  # hard-coded to balance memory usage and reconstruction speed during inference
296 | if x.coords.shape[0] < 150000:
297 | chunk_size = 4
298 |
299 | h = super().forward(x)
300 | chunks = self.split_for_meshing(h, chunk_size=chunk_size, verbose=False)
301 | all_coords, all_feats = [], []
302 |
303 | for chunk_idx, chunk in enumerate(chunks):
304 |             try:
305 |                 chunk_result = self.upsamples(chunk, pruning=pruning)
306 |             except Exception as e:
307 |                 print(f"Failed to process chunk {chunk_idx}: {e}")
308 |                 continue
309 |
310 | for b in range(batch_size):
311 | mask = torch.nonzero(chunk_result.coords[:, 0] == b).squeeze(-1)
312 | if mask.numel() > 0:
313 | coords = chunk_result.coords[mask].clone()
314 |
315 | # Restore global coordinates
316 | offsets = torch.tensor(chunk.bounds['offsets'],
317 | device=coords.device).view(1, 3)
318 | coords[:, 1:] = coords[:, 1:] + offsets
319 |
320 | # Filter points within original bounds
321 | bounds = chunk.bounds['original']
322 | within_bounds = torch.logical_and(
323 | torch.logical_and(
324 | torch.logical_and(
325 | coords[:, 1] >= bounds[0],
326 | coords[:, 1] < bounds[1]
327 | ),
328 | torch.logical_and(
329 | coords[:, 2] >= bounds[2],
330 | coords[:, 2] < bounds[3]
331 | )
332 | ),
333 | torch.logical_and(
334 | coords[:, 3] >= bounds[4],
335 | coords[:, 3] < bounds[5]
336 | )
337 | )
338 |
339 | if within_bounds.any():
340 | all_coords.append(coords[within_bounds])
341 | all_feats.append(chunk_result.feats[mask][within_bounds])
342 |
343 | torch.cuda.empty_cache()
344 |
345 | if len(all_coords) > 0:
346 | final_coords = torch.cat(all_coords)
347 | final_feats = torch.cat(all_feats)
348 |
349 | return self.to_representation(sp.SparseTensor(final_feats, final_coords))
350 | else:
351 | return self.to_representation(sp.SparseTensor(x.feats[:0], x.coords[:0]))
--------------------------------------------------------------------------------