├── .gitignore
├── IF_Trellis.py
├── IF_TrellisCheckpointLoader.py
├── LICENSE
├── README.md
├── StableXWrapper.py
├── __init__.py
├── assets
│   └── teaser.png
├── linux_requirements.txt
├── pyproject.toml
├── requirements.txt
├── stablex
│   ├── __init__.py
│   ├── controlnetvae.py
│   └── pipeline_yoso.py
├── trellis
│   ├── __init__.py
│   ├── backend_config.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── sparse_structure_flow.py
│   │   ├── sparse_structure_vae.py
│   │   ├── structured_latent_flow.py
│   │   └── structured_latent_vae
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── decoder_mesh.py
│   │       └── encoder.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention
│   │   │   ├── __init__.py
│   │   │   ├── full_attn.py
│   │   │   └── modules.py
│   │   ├── attention_utils.py
│   │   ├── norm.py
│   │   ├── sparse
│   │   │   ├── __init__.py
│   │   │   ├── attention
│   │   │   │   ├── __init__.py
│   │   │   │   ├── full_attn.py
│   │   │   │   ├── modules.py
│   │   │   │   ├── serialized_attn.py
│   │   │   │   └── windowed_attn.py
│   │   │   ├── basic.py
│   │   │   ├── conv
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conv_spconv.py
│   │   │   │   └── conv_torchsparse.py
│   │   │   ├── linear.py
│   │   │   ├── nonlinearity.py
│   │   │   ├── norm.py
│   │   │   ├── spatial.py
│   │   │   └── transformer
│   │   │       ├── __init__.py
│   │   │       ├── blocks.py
│   │   │       └── modulated.py
│   │   ├── spatial.py
│   │   ├── transformer
│   │   │   ├── __init__.py
│   │   │   ├── blocks.py
│   │   │   └── modulated.py
│   │   └── utils.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── classifier_free_guidance_mixin.py
│   │   │   ├── flow_euler.py
│   │   │   └── guidance_interval_mixin.py
│   │   └── trellis_image_to_3d.py
│   ├── representations
│   │   ├── __init__.py
│   │   └── mesh
│   │       ├── __init__.py
│   │       ├── cube2mesh.py
│   │       ├── flexicube.py
│   │       ├── tables.py
│   │       └── utils_cube.py
│   └── utils
│       ├── __init__.py
│       ├── _rasterization.py
│       ├── general_utils.py
│       └── random_utils.py
├── trellis_model_manager.py
├── win_requirements.txt
└── workflow
    └── Hi3DGen_WF_single.json
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/IF_TrellisCheckpointLoader.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/if-ai/ComfyUI-IF_Trellis/blob/main/IF_TrellisCheckpointLoader.py
2 | import os
3 | import logging
4 | import torch
5 | import huggingface_hub
6 | import folder_paths
7 | from trellis_model_manager import TrellisModelManager
8 | from trellis.pipelines.trellis_image_to_3d import TrellisImageTo3DPipeline
9 | from trellis.backend_config import (
10 | set_attention_backend,
11 | set_sparse_backend,
12 | get_available_backends,
13 | get_available_sparse_backends
14 | )
15 | from typing import Literal
16 | from torchvision import transforms
17 |
18 | logger = logging.getLogger("IF_Trellis")
19 |
20 | class IF_TrellisCheckpointLoader:
21 | """
22 | Node to manage the loading of the TRELLIS model with lazy backend selection.
23 | """
24 | def __init__(self):
25 | self.logger = logger
26 | self.model_manager = None
27 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 |
29 | # We might call these to figure out what's actually installed,
30 | # if we want to populate UI dropdowns:
31 | self.attn_backends = get_available_backends() # e.g. { 'xformers': True, 'flash_attn': False, ... }
32 |         self.sparse_backends = get_available_sparse_backends() # e.g. { 'spconv': True, 'torchsparse': True }
33 |
34 | @classmethod
35 | def INPUT_TYPES(cls):
36 | """Define input types with device-specific options."""
37 | # Filter only available backends
38 | attn_backends = get_available_backends()
39 | sparse_backends = get_available_sparse_backends()
40 |
41 | # e.g. create a list of names that are True:
42 | available_attn = [k for k, v in attn_backends.items() if v]
43 | if not available_attn:
44 | available_attn = ['flash_attn'] # fallback
45 |
46 | available_sparse = [k for k, v in sparse_backends.items() if v]
47 | if not available_sparse:
48 | available_sparse = ['spconv'] # fallback
49 |
50 | return {
51 | "required": {
52 | "model_name": (["trellis-normal-v0-1"],),
53 | "dinov2_model": (["dinov2_vitl14_reg"],
54 | {"default": "dinov2_vitl14_reg",
55 | "tooltip": "Select which Dinov2 model to use."}),
56 | "use_fp16": ("BOOLEAN", {"default": True}),
57 | #
58 | # The user picks from the actually installed backends
59 | #
60 | "attn_backend": (available_attn,
61 | {"default": "flash_attn" if "flash_attn" in available_attn else available_attn[0],
62 | "tooltip": "Select attention backend."}),
63 | "sparse_backend": (available_sparse,
64 | {"default": "spconv" if "spconv" in available_sparse else available_sparse[0],
65 | "tooltip": "Select sparse backend."}),
66 | "spconv_algo": (["implicit_gemm", "native", "auto"],
67 | {"default": "implicit_gemm",
68 | "tooltip": "Spconv algorithm. 'implicit_gemm' is slower but more robust."}),
69 | "smooth_k": ("BOOLEAN",
70 | {"default": True,
71 | "tooltip": "Smooth-k for SageAttention. Only relevant if attn_backend=sage."}),
72 | },
73 | }
74 |
75 | RETURN_TYPES = ("TRELLIS_MODEL",)
76 | RETURN_NAMES = ("model",)
77 | FUNCTION = "load_model"
78 | CATEGORY = "ImpactFrames💥🎞️/Trellis"
79 |
80 | def _setup_environment(self, attn_backend: str, sparse_backend: str, spconv_algo: str, smooth_k: bool):
81 | """
82 | Set up environment variables and backends lazily.
83 | This is the main difference: we call our new lazy set_*_backend funcs.
84 | """
85 | # Try attention
86 | success = set_attention_backend(attn_backend)
87 | if not success:
88 | self.logger.warning(f"Failed to set {attn_backend} or not installed, fallback to sdpa.")
89 |
90 | # Try sparse
91 | success2 = set_sparse_backend(sparse_backend, spconv_algo)
92 | if not success2:
93 | self.logger.warning(f"Failed to set {sparse_backend} or not installed, fallback to default.")
94 |
95 | # If user wants SageAttn smooth_k, we set environment var (if they'd want that):
96 | os.environ['SAGEATTN_SMOOTH_K'] = '1' if smooth_k else '0'
97 |
98 | def _initialize_transforms(self):
99 | """Initialize image transforms if needed."""
100 | return transforms.Compose([
101 | transforms.Normalize(
102 | mean=[0.485, 0.456, 0.406],
103 | std=[0.229, 0.224, 0.225]
104 | )
105 | ])
106 |
107 | def _optimize_pipeline(self, pipeline, use_fp16: bool = True):
108 | """
109 | Apply typical optimizations, half-precision, etc.
110 | """
111 | if self.device.type == "cuda":
112 | try:
113 | if hasattr(pipeline, 'cuda'):
114 | pipeline.cuda()
115 |
116 | if use_fp16:
117 | if hasattr(pipeline, 'enable_attention_slicing'):
118 | pipeline.enable_attention_slicing(slice_size="auto")
119 | if hasattr(pipeline, 'half'):
120 | pipeline.half()
121 | except Exception as e:
122 | logger.warning(f"Some pipeline optimizations failed: {str(e)}")
123 |
124 | return pipeline
125 |
126 | def load_model(
127 | self,
128 | model_name: str,
129 | dinov2_model: str = "dinov2_vitl14_reg",
130 | attn_backend: str = "sdpa",
131 | sparse_backend: str = "spconv",
132 | spconv_algo: str = "implicit_gemm",
133 | use_fp16: bool = True,
134 | smooth_k: bool = True,
135 | ) -> tuple:
136 | """
137 | Load and configure the TRELLIS pipeline.
138 | This is typically the main function invoked by ComfyUI at node execution time.
139 | """
140 | try:
141 | # 1) Setup environment + backends
142 | self._setup_environment(attn_backend, sparse_backend, spconv_algo, smooth_k)
143 |
144 | # 2) Get model paths, download if needed
145 | model_path = os.path.join(folder_paths.models_dir, "checkpoints", model_name)
146 | if not os.path.exists(model_path) or not os.listdir(model_path):
147 | repo_id = "Stable-X"
148 | try:
149 | huggingface_hub.snapshot_download(
150 | f"{repo_id}/{model_name}",
151 | repo_type="model",
152 | local_dir=model_path
153 | )
154 | except Exception as e:
155 | raise RuntimeError(f"Failed to download {repo_id}/{model_name} to: {model_path}, {e}")
156 |
157 | # 3) Create pipeline with the config
158 | pipeline = TrellisImageTo3DPipeline.from_pretrained(
159 | model_path,
160 | dinov2_model=dinov2_model
161 | )
162 | pipeline._device = self.device # ensure pipeline uses our same device
163 |
164 | # 4) Apply optimizations
165 | pipeline = self._optimize_pipeline(pipeline, use_fp16)
166 |
167 | return (pipeline,)
168 |
169 | except Exception as e:
170 | logger.error(f"Error loading TRELLIS model: {str(e)}")
171 | raise
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-Hi3DGen
2 |
3 |
4 | [](https://stable-x.github.io/Hi3DGen/)
5 | [](https://arxiv.org/abs/2503.22236)
6 | [](https://huggingface.co/spaces/Stable-X/Hi3DGen)
7 | [](https://huggingface.co/Stable-X/trellis-normal-v0-1)
8 |
9 |
10 | This extension integrates [Hi3DGen](https://github.com/Stable-X/Hi3DGen) into ComfyUI, allowing users to generate high-fidelity 3D geometry from images.
11 | 
12 |
13 | ## 🔥 Feature Updates
14 | * [2025-04-02] Support Windows installation
15 | * [2025-03-31] Support single-view image to 3D geometry generation
16 |
17 | ## Installation
18 |
19 | ### From Source (Linux)
20 | * Clone or download this repository into your `ComfyUI/custom_nodes/` directory.
21 | * Install the required dependencies by running `pip install -r linux_requirements.txt`.
22 |
23 | ### From Source (Windows)
24 | * Clone or download this repository into your `ComfyUI/custom_nodes/` directory.
25 | * Install the required dependencies by running `pip install -r win_requirements.txt`.
26 |
27 | ## Notes
28 |
29 | ### Workflows
30 |
31 | We provide example workflows in the `workflow` directory.
32 |
33 | Note that our code depends on diffusers and will automatically download the required model weights from Hugging Face the first time a node is run (the TRELLIS weights go to `models/checkpoints` and the StableX weights to `models/diffusers` inside your ComfyUI models folder). The `model_name` in the node corresponds to the model name on Hugging Face, such as `yoso-normal-v1-8-1`.
34 |
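If you prefer to fetch the weights ahead of time, the sketch below mirrors what the loader nodes do on first run (download targets taken from `IF_TrellisCheckpointLoader.py` and `StableXWrapper.py` in this repository); adjust `models_dir` if your ComfyUI models folder lives elsewhere:

```python
# Optional pre-download sketch; mirrors what the loader nodes do on first run.
import os
from huggingface_hub import snapshot_download

models_dir = "ComfyUI/models"  # assumption: default ComfyUI layout

# TRELLIS geometry model used by the "Trellis Model Loader" node
snapshot_download(
    "Stable-X/trellis-normal-v0-1",
    repo_type="model",
    local_dir=os.path.join(models_dir, "checkpoints", "trellis-normal-v0-1"),
)

# StableX normal estimator used by the "(Down)load StableX Model" node
snapshot_download(
    "Stable-X/yoso-normal-v1-8-1",
    ignore_patterns=["*text_encoder*", "tokenizer*", "*scheduler*"],
    local_dir=os.path.join(models_dir, "diffusers", "yoso-normal-v1-8-1"),
)
```
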
35 | ## Usage
36 | ### Single-view Image to 3D
37 | * `workflow/Hi3DGen_WF_single.json`
38 |
39 | ## Acknowledgement
40 | This repository builds upon the excellent work in [Trellis](https://github.com/microsoft/TRELLIS), [ComfyUI-IF_Trellis](https://github.com/if-ai/ComfyUI-IF_Trellis) and [ComfyUI-StableXWrapper](https://github.com/kijai/ComfyUI-StableXWrapper). Special thanks to their developers for the foundational work that made this project possible.
41 |
--------------------------------------------------------------------------------
/StableXWrapper.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/kijai/ComfyUI-StableXWrapper/blob/main/nodes.py
2 | import os
3 | import torch
4 | import json
5 |
6 | from accelerate import init_empty_weights
7 | from accelerate.utils import set_module_tensor_to_device
8 |
9 | from stablex.pipeline_yoso import YosoPipeline
10 | from stablex.controlnetvae import ControlNetVAEModel
11 |
12 | from diffusers.models import (
13 | AutoencoderKL,
14 | UNet2DConditionModel,
15 | )
16 |
17 | import folder_paths
18 | import comfy.model_management as mm
19 | from comfy.utils import load_torch_file, ProgressBar
20 |
21 | script_directory = os.path.dirname(os.path.abspath(__file__))
22 | import logging
23 | log = logging.getLogger(__name__)
24 |
25 | #region Model loading
26 |
27 | class DownloadAndLoadStableXModel:
28 | @classmethod
29 | def INPUT_TYPES(s):
30 | return {
31 | "required": {
32 | "model": (["yoso-normal-v1-8-1"],),
33 | },
34 | }
35 |
36 | RETURN_TYPES = ("YOSOPIPE",)
37 | RETURN_NAMES = ("pipeline", )
38 | FUNCTION = "loadmodel"
39 | CATEGORY = "StableXWrapper"
40 |
41 | def loadmodel(self, model):
42 | device = mm.get_torch_device()
43 | offload_device = mm.unet_offload_device()
44 |
45 | download_path = os.path.join(folder_paths.models_dir,"diffusers")
46 | model_path = os.path.join(download_path, model)
47 |
48 | if not os.path.exists(model_path):
49 | log.info(f"Downloading model to: {model_path}")
50 | from huggingface_hub import snapshot_download
51 | snapshot_download(
52 | repo_id=f"Stable-X/{model}",
53 | #allow_patterns=[f"*{model}*"],
54 | ignore_patterns=["*text_encoder*", "tokenizer*", "*scheduler*"],
55 | local_dir=model_path,
56 | local_dir_use_symlinks=False,
57 | )
58 |
59 | torch_dtype = torch.float16
60 | config_path = os.path.join(model_path, 'unet', 'config.json')
61 | unet_ckpt_path_safetensors = os.path.join(model_path, 'unet','diffusion_pytorch_model.fp16.safetensors')
62 |
63 | if not os.path.exists(config_path):
64 | raise FileNotFoundError(f"Config not found at {config_path}")
65 |
66 | with open(config_path, 'r', encoding='utf-8') as file:
67 | config = json.load(file)
68 |
69 | with init_empty_weights():
70 | unet = UNet2DConditionModel(**config)
71 |
72 | if os.path.exists(unet_ckpt_path_safetensors):
73 | import safetensors.torch
74 | unet_sd = safetensors.torch.load_file(unet_ckpt_path_safetensors)
75 | else:
76 | raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors}")
77 |
78 | for name, param in unet.named_parameters():
79 | set_module_tensor_to_device(unet, name, device=offload_device, dtype=torch_dtype, value=unet_sd[name])
80 |
81 | vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", variant="fp16", device=device, torch_dtype=torch_dtype)
82 | controlnet = ControlNetVAEModel.from_pretrained(model_path, subfolder="controlnet", variant="fp16", device=device, torch_dtype=torch_dtype)
83 |
84 | pipeline = YosoPipeline(
85 | unet=unet,
86 | vae = vae,
87 | controlnet = controlnet,
88 | )
89 |
90 | #pipeline.enable_model_cpu_offload()
91 | return (pipeline,)
92 |
93 | class StableXProcessImage:
94 | @classmethod
95 | def INPUT_TYPES(s):
96 | return {
97 | "required": {
98 | "pipeline": ("YOSOPIPE",),
99 | "image": ("IMAGE", ),
100 | "processing_resolution": ("INT", {"default": 2048, "min": 64, "max": 4096, "step": 16}),
101 | "controlnet_strength": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 10.0, "step": 0.01, "tooltip": "controlnet condition scale"}),
102 | "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Seed only affects normal prediction mode"}),
103 | },
104 | }
105 |
106 | RETURN_TYPES = ("IMAGE",)
107 | RETURN_NAMES = ("image",)
108 | FUNCTION = "process"
109 | CATEGORY = "StableXWrapper"
110 |
111 |     def process(self, pipeline, image, processing_resolution, controlnet_strength, seed):
112 |
113 | device = mm.get_torch_device()
114 | offload_device = mm.unet_offload_device()
115 |
116 | image = image.permute(0, 3, 1, 2).to(device).to(torch.float16)
117 |
118 | pipeline.unet.to(device)
119 | pipeline.vae.to(device)
120 | pipeline.controlnet.to(device)
121 |
122 | pipe_out = pipeline(
123 | image,
124 | controlnet_conditioning_scale=controlnet_strength,
125 | processing_resolution=processing_resolution,
126 | generator = torch.Generator(device=device).manual_seed(seed),
127 | output_type="pt",
128 | )
129 | pipeline.unet.to(offload_device)
130 | pipeline.vae.to(offload_device)
131 | pipeline.controlnet.to(offload_device)
132 | pipe_out = (pipe_out.prediction.clip(-1, 1) + 1) / 2
133 |
134 | out_tensor = pipe_out.permute(0, 2, 3, 1).cpu().float()
135 |
136 | return (out_tensor, )
137 |
138 | class DifferenceExtractorNode:
139 | @classmethod
140 | def INPUT_TYPES(s):
141 | return {
142 | "required": {
143 | "original_image": ("IMAGE",),
144 | "processed_image": ("IMAGE",),
145 | "amplification": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
146 | }
147 | }
148 |
149 | RETURN_TYPES = ("IMAGE",)
150 | FUNCTION = "extract_luminosity_difference"
151 |     CATEGORY = "StableXWrapper"
152 |
153 | def extract_luminosity_difference(self, original_image, processed_image, amplification=1.0):
154 | import torch
155 |
156 |         # RGB to luminosity conversion weights (ITU-R BT.709 luma coefficients)
157 | rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722]).to(original_image.device)
158 |
159 | # Convert images to luminosity (shape: B,H,W)
160 | original_lum = torch.sum(original_image * rgb_weights[None, None, None, :], dim=3)
161 | processed_lum = torch.sum(processed_image * rgb_weights[None, None, None, :], dim=3)
162 |
163 | # Calculate luminosity difference
164 | difference = (original_lum - processed_lum) * amplification
165 |
166 | # Normalize and clamp
167 | difference = torch.clamp(difference, 0, 1)
168 |
169 | # Convert back to RGB format (all channels identical)
170 | difference = difference.unsqueeze(3).repeat(1, 1, 1, 3)
171 |
172 | return (difference,)
173 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #__init__.py
2 | import os
3 | import sys
4 | import torch
5 | import logging
6 | import platform
7 | import folder_paths
8 |
9 | # Configure logging
10 | logging.basicConfig(level=logging.INFO)
11 | logger = logging.getLogger('ComfyUI-Hi3DGen')
12 |
13 | # Add parent directory to Python path
14 | current_dir = os.path.dirname(os.path.abspath(__file__))
15 | parent_dir = os.path.dirname(current_dir)
16 |
17 | # Add both current and parent dir to handle different installation scenarios
18 | if current_dir not in sys.path:
19 | sys.path.insert(0, current_dir)
20 | if parent_dir not in sys.path:
21 | sys.path.insert(0, parent_dir)
22 |
23 | # Add trellis package path
24 | trellis_path = os.path.join(current_dir, "trellis")
25 | if os.path.exists(trellis_path) and trellis_path not in sys.path:
26 | sys.path.insert(0, trellis_path)
27 | logger.info(f"Added trellis path to sys.path: {trellis_path}")
28 |
29 | # Add stablex package path
30 | stablex_path = os.path.join(current_dir, "stablex")
31 | if os.path.exists(stablex_path) and stablex_path not in sys.path:
32 |     sys.path.insert(0, stablex_path)
33 |     logger.info(f"Added stablex path to sys.path: {stablex_path}")
34 |
35 | # Verify trellis package is importable
36 | try:
37 | import trellis
38 | logger.info("Trellis package imported successfully")
39 | except ImportError as e:
40 | logger.error(f"Failed to import trellis package: {e}")
41 | logger.error(f"Current sys.path: {sys.path}")
42 | raise
43 |
44 | # Verify stablex package is importable
45 | try:
46 | import stablex
47 | logger.info("stablex package imported successfully")
48 | except ImportError as e:
49 | logger.error(f"Failed to import stablex package: {e}")
50 | logger.error(f"Current sys.path: {sys.path}")
51 | raise
52 |
53 | # Register model paths with ComfyUI
54 | try:
55 | folder_paths.add_model_folder_path("trellis", os.path.join(folder_paths.models_dir, "trellis"))
56 | folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.models_dir, "checkpoints"))
57 | except Exception as e:
58 | logger.error(f"Error registering model paths: {e}")
59 |
60 | # Register stablex model path with ComfyUI
61 | try:
62 | folder_paths.add_model_folder_path("stablex", os.path.join(folder_paths.models_dir, "stablex"))
63 | except Exception as e:
64 | logger.error(f"Error registering model paths: {e}")
65 |
66 | try:
67 | from IF_TrellisCheckpointLoader import IF_TrellisCheckpointLoader
68 | from IF_Trellis import IF_TrellisImageTo3D
69 | from StableXWrapper import DownloadAndLoadStableXModel, StableXProcessImage, DifferenceExtractorNode
70 | NODE_CLASS_MAPPINGS = {
71 | "IF_TrellisCheckpointLoader": IF_TrellisCheckpointLoader,
72 | "IF_TrellisImageTo3D": IF_TrellisImageTo3D,
73 | "DownloadAndLoadStableXModel": DownloadAndLoadStableXModel,
74 | "StableXProcessImage": StableXProcessImage,
75 | "DifferenceExtractorNode": DifferenceExtractorNode
76 | }
77 |
78 | NODE_DISPLAY_NAME_MAPPINGS = {
79 | "IF_TrellisCheckpointLoader": "Trellis Model Loader 💾",
80 | "IF_TrellisImageTo3D": "Trellis Image to 3D 🖼️➡️🎲",
81 | "DownloadAndLoadStableXModel": "(Down)load StableX Model",
82 | "StableXProcessImage": "StableX Process Image",
83 | "DifferenceExtractorNode": "Extract Difference"
84 | }
85 |
86 | except Exception as e:
87 | logger.error(f"Error importing node classes: {e}")
88 | raise
89 |
90 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-X/ComfyUI-Hi3DGen/99621aa48dd05203bee91cf3d5813669e0b5721d/assets/teaser.png
--------------------------------------------------------------------------------
/linux_requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | torchvision==0.19.0
3 | pillow==10.4.0
4 | imageio==2.36.1
5 | imageio-ffmpeg==0.5.1
6 | tqdm==4.67.1
7 | easydict==1.13
8 | opencv-python-headless==4.10.0.84
9 | scipy==1.14.1
10 | rembg==2.0.60
11 | onnxruntime==1.20.1
12 | trimesh==4.5.3
13 | git+https://github.com/EasternJournalist/utils3d.git
14 | xformers==0.0.27.post2
15 | spconv-cu120==2.3.6
16 | transformers==4.46.3
17 | accelerate==1.5.2
18 | diffusers==0.32.2
19 |
20 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui-hi3dgen"
3 | description = "ComfyUI Hi3DGen creates a high-fidelity 3D mesh from single-view or multi-angle pictures"
4 | version = "0.2.5"
5 | license = { file = "LICENSE" }
6 | dependencies = [
7 | "pillow",
8 | "imageio",
9 | "imageio-ffmpeg",
10 | "tqdm",
11 | "easydict",
12 | "opencv-python-headless",
13 | "scipy",
14 | "ninja",
15 | "plyfile",
16 | "rembg",
17 | "onnxruntime-gpu",
18 | "trimesh",
19 | "git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8",
20 | "xformers",
21 | "spconv-cu124",
22 | "diffrp-nvdiffrast",
23 |
24 | # triton for linux
25 | 'triton; sys_platform == "linux"',
26 | # triton for windows
27 |     'triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp312-cp312-win_amd64.whl; sys_platform == "win32" and (python_version >= "3.12" and python_version < "3.13")',
28 |     'triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp311-cp311-win_amd64.whl; sys_platform == "win32" and (python_version >= "3.11" and python_version < "3.12")',
29 |     'triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp310-cp310-win_amd64.whl; sys_platform == "win32" and (python_version >= "3.10" and python_version < "3.11")',
30 |     'triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp38-cp38-win_amd64.whl; sys_platform == "win32" and (python_version >= "3.8" and python_version < "3.9")'
31 | ]
32 |
33 | [project.urls]
34 | Repository = "https://github.com/Stable-X/ComfyUI-Hi3DGen"
35 | # Used by Comfy Registry https://comfyregistry.org
36 |
37 | [tool.comfy]
38 | PublisherId = "impactframes"
39 | DisplayName = "IF_Hi3DGen"
40 | Icon = "https://impactframes.ai/System/Icons/48x48/if.png"
41 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | torchvision==0.19.0
3 | diffusers==0.32.2
4 | accelerate==1.5.2
5 | pillow
6 | imageio
7 | imageio-ffmpeg
8 | tqdm
9 | easydict
10 | opencv-python-headless
11 | scipy
12 | rembg
13 | ninja
14 | plyfile
15 | trimesh
16 | git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
17 | xformers
18 | triton
19 | spconv-cu124>=2.3.6
20 |
--------------------------------------------------------------------------------
/stablex/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-X/ComfyUI-Hi3DGen/99621aa48dd05203bee91cf3d5813669e0b5721d/stablex/__init__.py
--------------------------------------------------------------------------------
/trellis/__init__.py:
--------------------------------------------------------------------------------
1 | from . import models
2 | from . import modules
3 | from . import pipelines
4 | from . import representations
5 | from . import utils
6 |
--------------------------------------------------------------------------------
/trellis/backend_config.py:
--------------------------------------------------------------------------------
1 | # trellis/backend_config.py
2 | from typing import *
3 | import os
4 | import logging
5 | import importlib
6 |
7 | # Global variables
8 | BACKEND = 'spconv' # Default sparse backend
9 | DEBUG = False # Debug mode flag
10 | ATTN = 'xformers' # Default attention backend
11 | SPCONV_ALGO = 'implicit_gemm' # Default algorithm
12 |
13 | def get_spconv_algo() -> str:
14 | """Get current spconv algorithm."""
15 | global SPCONV_ALGO
16 | return SPCONV_ALGO
17 |
18 | def set_spconv_algo(algo: Literal['implicit_gemm', 'native', 'auto']) -> bool:
19 | """Set spconv algorithm with validation."""
20 | global SPCONV_ALGO
21 |
22 | if algo not in ['implicit_gemm', 'native', 'auto']:
23 | logger.warning(f"Invalid spconv algorithm: {algo}. Must be 'implicit_gemm', 'native', or 'auto'")
24 | return False
25 |
26 | SPCONV_ALGO = algo
27 | os.environ['SPCONV_ALGO'] = algo
28 | logger.info(f"Set spconv algorithm to: {algo}")
29 | return True
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | def _try_import_xformers() -> bool:
34 | try:
35 | import xformers.ops
36 | return True
37 | except ImportError:
38 | return False
39 |
40 | def _try_import_flash_attn() -> bool:
41 | try:
42 | import flash_attn
43 | return True
44 | except ImportError:
45 | return False
46 |
47 | def _try_import_sageattention() -> bool:
48 | try:
49 | import torch.nn.functional as F
50 | from sageattention import sageattn
51 | F.scaled_dot_product_attention = sageattn
52 |         # NOTE: this availability probe also replaces F.scaled_dot_product_attention with sageattn globally
53 | return True
54 | except ImportError:
55 | return False
56 |
57 | def _try_import_spconv() -> bool:
58 | try:
59 | import spconv
60 | return True
61 | except ImportError:
62 | return False
63 |
64 | def _try_import_torchsparse() -> bool:
65 | try:
66 | import torchsparse
67 | return True
68 | except ImportError:
69 | return False
70 |
71 | def get_available_backends() -> Dict[str, bool]:
72 | """Return dict of available attention backends and their status"""
73 | return {
74 | 'xformers': _try_import_xformers(),
75 | 'flash_attn': _try_import_flash_attn(),
76 | 'sage': _try_import_sageattention(),
77 | 'naive': True,
78 | 'sdpa': True # Always available with PyTorch >= 2.0
79 | }
80 |
81 | def get_available_sparse_backends() -> Dict[str, bool]:
82 | """Return dict of available sparse backends and their status"""
83 | return {
84 | 'spconv': _try_import_spconv(),
85 | 'torchsparse': _try_import_torchsparse()
86 | }
87 |
88 | def get_attention_backend() -> str:
89 | """Get current attention backend"""
90 | global ATTN
91 | return ATTN
92 |
93 | def get_sparse_backend() -> str:
94 | """Get current sparse backend"""
95 | global BACKEND
96 | return BACKEND
97 |
98 | def get_debug_mode() -> bool:
99 | """Get current debug mode status"""
100 | global DEBUG
101 | return DEBUG
102 |
103 | def __from_env():
104 | """Initialize settings from environment variables"""
105 | global BACKEND
106 | global DEBUG
107 | global ATTN
108 |
109 | env_sparse_backend = os.environ.get('SPARSE_BACKEND')
110 | env_sparse_debug = os.environ.get('SPARSE_DEBUG')
111 | env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
112 |
113 | if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
114 | BACKEND = env_sparse_backend
115 | if env_sparse_debug is not None:
116 | DEBUG = env_sparse_debug == '1'
117 | if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sage', 'sdpa', 'naive']:
118 | ATTN = env_sparse_attn
119 | os.environ['SPARSE_ATTN_BACKEND'] = env_sparse_attn
120 | os.environ['ATTN_BACKEND'] = env_sparse_attn
121 |
122 | logger.info(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
123 |
124 | def set_backend(backend: Literal['spconv', 'torchsparse']) -> bool:
125 | """Set sparse backend with validation"""
126 | global BACKEND
127 |
128 | backend = backend.lower().strip()
129 | logger.info(f"Setting sparse backend to: {backend}")
130 |
131 | if backend == 'spconv':
132 | try:
133 | import spconv
134 | BACKEND = 'spconv'
135 | os.environ['SPARSE_BACKEND'] = 'spconv'
136 | return True
137 | except ImportError:
138 | logger.warning("spconv not available")
139 | return False
140 |
141 | elif backend == 'torchsparse':
142 | try:
143 | import torchsparse
144 | BACKEND = 'torchsparse'
145 | os.environ['SPARSE_BACKEND'] = 'torchsparse'
146 | return True
147 | except ImportError:
148 | logger.warning("torchsparse not available")
149 | return False
150 |
151 | return False
152 |
153 | def set_sparse_backend(backend: Literal['spconv', 'torchsparse'], algo: str = None) -> bool:
154 | """Alias for set_backend for backwards compatibility
155 |
156 | Parameters:
157 | backend: The sparse backend to use
158 | algo: The algorithm to use (only relevant for spconv backend)
159 | """
160 | # Call set_backend first
161 | result = set_backend(backend)
162 |
163 | # If algorithm is provided and backend was set successfully
164 | if algo is not None and result:
165 | set_spconv_algo(algo)
166 |
167 | return result
168 |
169 | def set_debug(debug: bool):
170 | """Set debug mode"""
171 | global DEBUG
172 | DEBUG = debug
173 | if debug:
174 | os.environ['SPARSE_DEBUG'] = '1'
175 | else:
176 | os.environ['SPARSE_DEBUG'] = '0'
177 |
178 | def set_attn(attn: Literal['xformers', 'flash_attn', 'sage', 'sdpa', 'naive']) -> bool:
179 | """Set attention backend with validation"""
180 | global ATTN
181 |
182 | attn = attn.lower().strip()
183 | logger.info(f"Setting attention backend to: {attn}")
184 |
185 | if attn == 'xformers' and _try_import_xformers():
186 | ATTN = 'xformers'
187 | os.environ['SPARSE_ATTN_BACKEND'] = 'xformers'
188 | os.environ['ATTN_BACKEND'] = 'xformers'
189 | return True
190 |
191 | elif attn == 'flash_attn' and _try_import_flash_attn():
192 | ATTN = 'flash_attn'
193 | os.environ['SPARSE_ATTN_BACKEND'] = 'flash_attn'
194 | os.environ['ATTN_BACKEND'] = 'flash_attn'
195 | return True
196 |
197 | elif attn == 'sage' and _try_import_sageattention():
198 | ATTN = 'sage'
199 | os.environ['SPARSE_ATTN_BACKEND'] = 'sage'
200 | os.environ['ATTN_BACKEND'] = 'sage'
201 | return True
202 |
203 | elif attn == 'sdpa':
204 | ATTN = 'sdpa'
205 | os.environ['SPARSE_ATTN_BACKEND'] = 'sdpa'
206 | os.environ['ATTN_BACKEND'] = 'sdpa'
207 | return True
208 |
209 | elif attn == 'naive':
210 | ATTN = 'naive'
211 | os.environ['SPARSE_ATTN_BACKEND'] = 'naive'
212 | os.environ['ATTN_BACKEND'] = 'naive'
213 | return True
214 |
215 |
216 | logger.warning(f"Attention backend {attn} not available")
217 | return False
218 |
219 | # Add alias for backwards compatibility
220 | def set_attention_backend(backend: Literal['xformers', 'flash_attn', 'sage', 'sdpa']) -> bool:
221 | """Alias for set_attn for backwards compatibility"""
222 | return set_attn(backend)
223 |
224 | # Initialize from environment variables on module import
225 | __from_env()
226 |
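227 | # Usage sketch (illustrative only; nothing below is executed by this module):
228 | #
229 | #   from trellis.backend_config import get_available_backends, set_attention_backend
230 | #
231 | #   backends = get_available_backends()           # e.g. {'xformers': True, 'flash_attn': False, ...}
232 | #   if not set_attention_backend('flash_attn'):   # returns False when flash_attn is not installed
233 | #       set_attention_backend('sdpa')             # 'sdpa' is always available with PyTorch >= 2.0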
--------------------------------------------------------------------------------
/trellis/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch
3 | import torch.nn as nn
4 |
5 | __attributes = {
6 | 'SparseStructureEncoder': 'sparse_structure_vae',
7 | 'SparseStructureDecoder': 'sparse_structure_vae',
8 | 'SparseStructureFlowModel': 'sparse_structure_flow',
9 | 'SLatEncoder': 'structured_latent_vae',
10 | 'SLatGaussianDecoder': 'structured_latent_vae',
11 | 'SLatRadianceFieldDecoder': 'structured_latent_vae',
12 | 'SLatMeshDecoder': 'structured_latent_vae',
13 | 'SLatFlowModel': 'structured_latent_flow',
14 | }
15 |
16 | __submodules = []
17 |
18 | __all__ = list(__attributes.keys()) + __submodules
19 |
20 | def __getattr__(name):
21 | if name not in globals():
22 | if name in __attributes:
23 | module_name = __attributes[name]
24 | module = importlib.import_module(f".{module_name}", __name__)
25 | globals()[name] = getattr(module, name)
26 | elif name in __submodules:
27 | module = importlib.import_module(f".{name}", __name__)
28 | globals()[name] = module
29 | else:
30 | raise AttributeError(f"module {__name__} has no attribute {name}")
31 | return globals()[name]
32 |
33 |
34 | def from_pretrained(path: str) -> nn.Module:
35 | """
36 | Load a pretrained model.
37 |
38 | Args:
39 | path (str): Full path to model file or HuggingFace repo ID with model name
40 | """
41 | import os
42 | import json
43 | from safetensors.torch import load_file
44 |
45 | # Split path into directory and model name
46 | path = os.path.normpath(path)
47 | model_dir = os.path.dirname(os.path.dirname(path)) # Go up two levels (past ckpts/)
48 | model_name = os.path.basename(path)
49 |
50 | is_local = os.path.exists(model_dir)
51 |
52 | if is_local:
53 | # For local paths
54 | print(f"Loading local model: {model_name}")
55 | model_name = model_name.replace('ckpts/', '').replace('ckpts\\', '')
56 | config_path = os.path.normpath(os.path.join(model_dir, "ckpts", f"{model_name}.json"))
57 | weights_path = os.path.normpath(os.path.join(model_dir, "ckpts", f"{model_name}.safetensors"))
58 |
59 | if not os.path.exists(config_path):
60 | raise FileNotFoundError(f"Config file not found: {config_path}")
61 | if not os.path.exists(weights_path):
62 | raise FileNotFoundError(f"Weights file not found: {weights_path}")
63 |
64 | # Load config
65 | with open(config_path, 'r') as f:
66 | config = json.load(f)
67 |
68 | # Create model
69 | model = create_model_from_config(config)
70 |
71 | # Load weights
72 | state_dict = load_file(weights_path)
73 | model.load_state_dict(state_dict)
74 |
75 | else:
76 | # For HuggingFace paths
77 | from huggingface_hub import hf_hub_download
78 |
79 | config_file = hf_hub_download(path, f"{model_name}.json")
80 | with open(config_file, 'r') as f:
81 | config = json.load(f)
82 |
83 | model = create_model_from_config(config)
84 |
85 | weights_file = hf_hub_download(path, f"{model_name}.safetensors")
86 | state_dict = load_file(weights_file)
87 | model.load_state_dict(state_dict)
88 |
89 | return model
90 |
91 | def create_model_from_config(config):
92 | """Helper function to create model from config"""
93 | #print(f"Creating model from config: {config}")
94 | model_type = config.get('type') or config.get('name')
95 | #print(f"Model type: {model_type}")
96 | #print(f"Available model types: {list(__attributes.keys())}")
97 |     if model_type not in __attributes:
98 | raise ValueError(f"Unknown model type: {model_type}")
99 |
100 | model_class = __getattr__(model_type)
101 | #print(f"Model class: {model_class}")
102 | args = config.get('args', {})
103 | #print(f"Model args: {args}")
104 | return model_class(**args)
105 |
106 |
107 | # For Pylance
108 | if __name__ == '__main__':
109 | from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
110 | from .sparse_structure_flow import SparseStructureFlowModel
111 | from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder
112 | from .structured_latent_flow import SLatFlowModel
113 |
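114 | # Usage sketch (illustrative; "<model_name>" below is a placeholder, not a real file):
115 | #   model = from_pretrained("/path/to/trellis-normal-v0-1/ckpts/<model_name>")
116 | # For local paths, "<model_name>.json" and "<model_name>.safetensors" are expected to sit side by side
117 | # in the ckpts/ directory; otherwise the path is treated as a HuggingFace repo ID.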
--------------------------------------------------------------------------------
/trellis/models/sparse_structure_flow.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7 | from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
8 | from ..modules.spatial import patchify, unpatchify
9 |
10 |
11 | class TimestepEmbedder(nn.Module):
12 | """
13 | Embeds scalar timesteps into vector representations.
14 | """
15 | def __init__(self, hidden_size, frequency_embedding_size=256):
16 | super().__init__()
17 | self.mlp = nn.Sequential(
18 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
19 | nn.SiLU(),
20 | nn.Linear(hidden_size, hidden_size, bias=True),
21 | )
22 | self.frequency_embedding_size = frequency_embedding_size
23 |
24 | @staticmethod
25 | def timestep_embedding(t, dim, max_period=10000):
26 | """
27 | Create sinusoidal timestep embeddings.
28 |
29 | Args:
30 | t: a 1-D Tensor of N indices, one per batch element.
31 | These may be fractional.
32 | dim: the dimension of the output.
33 | max_period: controls the minimum frequency of the embeddings.
34 |
35 | Returns:
36 | an (N, D) Tensor of positional embeddings.
37 | """
38 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
39 | half = dim // 2
40 | freqs = torch.exp(
41 | -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
42 | ).to(device=t.device)
43 | args = t[:, None].float() * freqs[None]
44 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45 | if dim % 2:
46 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47 | return embedding
48 |
49 | def forward(self, t):
50 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
51 | t_emb = self.mlp(t_freq)
52 | return t_emb
53 |
54 |
55 | class SparseStructureFlowModel(nn.Module):
56 | def __init__(
57 | self,
58 | resolution: int,
59 | in_channels: int,
60 | model_channels: int,
61 | cond_channels: int,
62 | out_channels: int,
63 | num_blocks: int,
64 | num_heads: Optional[int] = None,
65 | num_head_channels: Optional[int] = 64,
66 | mlp_ratio: float = 4,
67 | patch_size: int = 2,
68 | pe_mode: Literal["ape", "rope"] = "ape",
69 | use_fp16: bool = False,
70 | use_checkpoint: bool = False,
71 | share_mod: bool = False,
72 | qk_rms_norm: bool = False,
73 | qk_rms_norm_cross: bool = False,
74 | ):
75 | super().__init__()
76 | self.resolution = resolution
77 | self.in_channels = in_channels
78 | self.model_channels = model_channels
79 | self.cond_channels = cond_channels
80 | self.out_channels = out_channels
81 | self.num_blocks = num_blocks
82 | self.num_heads = num_heads or model_channels // num_head_channels
83 | self.mlp_ratio = mlp_ratio
84 | self.patch_size = patch_size
85 | self.pe_mode = pe_mode
86 | self.use_fp16 = use_fp16
87 | self.use_checkpoint = use_checkpoint
88 | self.share_mod = share_mod
89 | self.qk_rms_norm = qk_rms_norm
90 | self.qk_rms_norm_cross = qk_rms_norm_cross
91 | self.dtype = torch.float16 if use_fp16 else torch.float32
92 |
93 | self.t_embedder = TimestepEmbedder(model_channels)
94 | if share_mod:
95 | self.adaLN_modulation = nn.Sequential(
96 | nn.SiLU(),
97 | nn.Linear(model_channels, 6 * model_channels, bias=True)
98 | )
99 |
100 | if pe_mode == "ape":
101 | pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
102 | coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
103 | coords = torch.stack(coords, dim=-1).reshape(-1, 3)
104 | pos_emb = pos_embedder(coords)
105 | self.register_buffer("pos_emb", pos_emb)
106 |
107 | self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
108 |
109 | self.blocks = nn.ModuleList([
110 | ModulatedTransformerCrossBlock(
111 | model_channels,
112 | cond_channels,
113 | num_heads=self.num_heads,
114 | mlp_ratio=self.mlp_ratio,
115 | attn_mode='full',
116 | use_checkpoint=self.use_checkpoint,
117 | use_rope=(pe_mode == "rope"),
118 | share_mod=share_mod,
119 | qk_rms_norm=self.qk_rms_norm,
120 | qk_rms_norm_cross=self.qk_rms_norm_cross,
121 | )
122 | for _ in range(num_blocks)
123 | ])
124 |
125 | self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
126 |
127 | self.initialize_weights()
128 | if use_fp16:
129 | self.convert_to_fp16()
130 |
131 | @property
132 | def device(self) -> torch.device:
133 | """
134 | Return the device of the model.
135 | """
136 | return next(self.parameters()).device
137 |
138 | def convert_to_fp16(self) -> None:
139 | """
140 | Convert the torso of the model to float16.
141 | """
142 | self.blocks.apply(convert_module_to_f16)
143 |
144 | def convert_to_fp32(self) -> None:
145 | """
146 | Convert the torso of the model to float32.
147 | """
148 | self.blocks.apply(convert_module_to_f32)
149 |
150 | def initialize_weights(self) -> None:
151 | # Initialize transformer layers:
152 | def _basic_init(module):
153 | if isinstance(module, nn.Linear):
154 | torch.nn.init.xavier_uniform_(module.weight)
155 | if module.bias is not None:
156 | nn.init.constant_(module.bias, 0)
157 | self.apply(_basic_init)
158 |
159 | # Initialize timestep embedding MLP:
160 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
161 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
162 |
163 | # Zero-out adaLN modulation layers in DiT blocks:
164 | if self.share_mod:
165 | nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
166 | nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
167 | else:
168 | for block in self.blocks:
169 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
170 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
171 |
172 | # Zero-out output layers:
173 | nn.init.constant_(self.out_layer.weight, 0)
174 | nn.init.constant_(self.out_layer.bias, 0)
175 |
176 | def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
177 | assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
178 | f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
179 |
180 | h = patchify(x, self.patch_size)
181 | h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
182 |
183 | h = self.input_layer(h)
184 | h = h + self.pos_emb[None]
185 | t_emb = self.t_embedder(t)
186 | if self.share_mod:
187 | t_emb = self.adaLN_modulation(t_emb)
188 | t_emb = t_emb.type(self.dtype)
189 | h = h.type(self.dtype)
190 | cond = cond.type(self.dtype)
191 | for block in self.blocks:
192 | h = block(h, t_emb, cond)
193 | h = h.type(x.dtype)
194 | h = F.layer_norm(h, h.shape[-1:])
195 | h = self.out_layer(h)
196 |
197 | h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
198 | h = unpatchify(h, self.patch_size).contiguous()
199 |
200 | return h
201 |
--------------------------------------------------------------------------------
/trellis/models/structured_latent_flow.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7 | from ..modules.transformer import AbsolutePositionEmbedder
8 | from ..modules.norm import LayerNorm32
9 | from ..modules import sparse as sp
10 | from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
11 | from .sparse_structure_flow import TimestepEmbedder
12 |
13 |
14 | class SparseResBlock3d(nn.Module):
15 | def __init__(
16 | self,
17 | channels: int,
18 | emb_channels: int,
19 | out_channels: Optional[int] = None,
20 | downsample: bool = False,
21 | upsample: bool = False,
22 | ):
23 | super().__init__()
24 | self.channels = channels
25 | self.emb_channels = emb_channels
26 | self.out_channels = out_channels or channels
27 | self.downsample = downsample
28 | self.upsample = upsample
29 |
30 | assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
31 |
32 | self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
33 | self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
34 | self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
35 | self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
36 | self.emb_layers = nn.Sequential(
37 | nn.SiLU(),
38 | nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
39 | )
40 | self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
41 | self.updown = None
42 | if self.downsample:
43 | self.updown = sp.SparseDownsample(2)
44 | elif self.upsample:
45 | self.updown = sp.SparseUpsample(2)
46 |
47 | def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
48 | if self.updown is not None:
49 | x = self.updown(x)
50 | return x
51 |
52 | def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
53 | emb_out = self.emb_layers(emb).type(x.dtype)
54 | scale, shift = torch.chunk(emb_out, 2, dim=1)
55 |
56 | x = self._updown(x)
57 | h = x.replace(self.norm1(x.feats))
58 | h = h.replace(F.silu(h.feats))
59 | h = self.conv1(h)
60 | h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
61 | h = h.replace(F.silu(h.feats))
62 | h = self.conv2(h)
63 | h = h + self.skip_connection(x)
64 |
65 | return h
66 |
67 |
68 | class SLatFlowModel(nn.Module):
69 | def __init__(
70 | self,
71 | resolution: int,
72 | in_channels: int,
73 | model_channels: int,
74 | cond_channels: int,
75 | out_channels: int,
76 | num_blocks: int,
77 | num_heads: Optional[int] = None,
78 | num_head_channels: Optional[int] = 64,
79 | mlp_ratio: float = 4,
80 | patch_size: int = 2,
81 | num_io_res_blocks: int = 2,
82 | io_block_channels: List[int] = None,
83 | pe_mode: Literal["ape", "rope"] = "ape",
84 | use_fp16: bool = False,
85 | use_checkpoint: bool = False,
86 | use_skip_connection: bool = True,
87 | share_mod: bool = False,
88 | qk_rms_norm: bool = False,
89 | qk_rms_norm_cross: bool = False,
90 | ):
91 | super().__init__()
92 | self.resolution = resolution
93 | self.in_channels = in_channels
94 | self.model_channels = model_channels
95 | self.cond_channels = cond_channels
96 | self.out_channels = out_channels
97 | self.num_blocks = num_blocks
98 | self.num_heads = num_heads or model_channels // num_head_channels
99 | self.mlp_ratio = mlp_ratio
100 | self.patch_size = patch_size
101 | self.num_io_res_blocks = num_io_res_blocks
102 | self.io_block_channels = io_block_channels
103 | self.pe_mode = pe_mode
104 | self.use_fp16 = use_fp16
105 | self.use_checkpoint = use_checkpoint
106 | self.use_skip_connection = use_skip_connection
107 | self.share_mod = share_mod
108 | self.qk_rms_norm = qk_rms_norm
109 | self.qk_rms_norm_cross = qk_rms_norm_cross
110 | self.dtype = torch.float16 if use_fp16 else torch.float32
111 |
112 | assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
113 | assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
114 |
115 | self.t_embedder = TimestepEmbedder(model_channels)
116 | if share_mod:
117 | self.adaLN_modulation = nn.Sequential(
118 | nn.SiLU(),
119 | nn.Linear(model_channels, 6 * model_channels, bias=True)
120 | )
121 |
122 | if pe_mode == "ape":
123 | self.pos_embedder = AbsolutePositionEmbedder(model_channels)
124 |
125 | self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
126 | self.input_blocks = nn.ModuleList([])
127 | for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
128 | self.input_blocks.extend([
129 | SparseResBlock3d(
130 | chs,
131 | model_channels,
132 | out_channels=chs,
133 | )
134 | for _ in range(num_io_res_blocks-1)
135 | ])
136 | self.input_blocks.append(
137 | SparseResBlock3d(
138 | chs,
139 | model_channels,
140 | out_channels=next_chs,
141 | downsample=True,
142 | )
143 | )
144 |
145 | self.blocks = nn.ModuleList([
146 | ModulatedSparseTransformerCrossBlock(
147 | model_channels,
148 | cond_channels,
149 | num_heads=self.num_heads,
150 | mlp_ratio=self.mlp_ratio,
151 | attn_mode='full',
152 | use_checkpoint=self.use_checkpoint,
153 | use_rope=(pe_mode == "rope"),
154 | share_mod=self.share_mod,
155 | qk_rms_norm=self.qk_rms_norm,
156 | qk_rms_norm_cross=self.qk_rms_norm_cross,
157 | )
158 | for _ in range(num_blocks)
159 | ])
160 |
161 | self.out_blocks = nn.ModuleList([])
162 | for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
163 | self.out_blocks.append(
164 | SparseResBlock3d(
165 | prev_chs * 2 if self.use_skip_connection else prev_chs,
166 | model_channels,
167 | out_channels=chs,
168 | upsample=True,
169 | )
170 | )
171 | self.out_blocks.extend([
172 | SparseResBlock3d(
173 | chs * 2 if self.use_skip_connection else chs,
174 | model_channels,
175 | out_channels=chs,
176 | )
177 | for _ in range(num_io_res_blocks-1)
178 | ])
179 | self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
180 |
181 | self.initialize_weights()
182 | if use_fp16:
183 | self.convert_to_fp16()
184 |
185 | @property
186 | def device(self) -> torch.device:
187 | """
188 | Return the device of the model.
189 | """
190 | return next(self.parameters()).device
191 |
192 | def convert_to_fp16(self) -> None:
193 | """
194 | Convert the torso of the model to float16.
195 | """
196 | self.input_blocks.apply(convert_module_to_f16)
197 | self.blocks.apply(convert_module_to_f16)
198 | self.out_blocks.apply(convert_module_to_f16)
199 |
200 | def convert_to_fp32(self) -> None:
201 | """
202 | Convert the torso of the model to float32.
203 | """
204 | self.input_blocks.apply(convert_module_to_f32)
205 | self.blocks.apply(convert_module_to_f32)
206 | self.out_blocks.apply(convert_module_to_f32)
207 |
208 | def initialize_weights(self) -> None:
209 | # Initialize transformer layers:
210 | def _basic_init(module):
211 | if isinstance(module, nn.Linear):
212 | torch.nn.init.xavier_uniform_(module.weight)
213 | if module.bias is not None:
214 | nn.init.constant_(module.bias, 0)
215 | self.apply(_basic_init)
216 |
217 | # Initialize timestep embedding MLP:
218 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
219 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
220 |
221 | # Zero-out adaLN modulation layers in DiT blocks:
222 | if self.share_mod:
223 | nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
224 | nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
225 | else:
226 | for block in self.blocks:
227 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
228 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
229 |
230 | # Zero-out output layers:
231 | nn.init.constant_(self.out_layer.weight, 0)
232 | nn.init.constant_(self.out_layer.bias, 0)
233 |
234 | def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
235 | h = self.input_layer(x).type(self.dtype)
236 | t_emb = self.t_embedder(t)
237 | if self.share_mod:
238 | t_emb = self.adaLN_modulation(t_emb)
239 | t_emb = t_emb.type(self.dtype)
240 | cond = cond.type(self.dtype)
241 |
242 | skips = []
243 | # pack with input blocks
244 | for block in self.input_blocks:
245 | h = block(h, t_emb)
246 | skips.append(h.feats)
247 |
248 | if self.pe_mode == "ape":
249 | h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
250 | for block in self.blocks:
251 | h = block(h, t_emb, cond)
252 |
253 | # unpack with output blocks
254 | for block, skip in zip(self.out_blocks, reversed(skips)):
255 | if self.use_skip_connection:
256 | h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
257 | else:
258 | h = block(h, t_emb)
259 |
260 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
261 | h = self.out_layer(h.type(x.dtype))
262 | return h
263 |
--------------------------------------------------------------------------------
/trellis/models/structured_latent_vae/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import SLatEncoder
2 | from .decoder_mesh import SLatMeshDecoder
3 |
--------------------------------------------------------------------------------
/trellis/models/structured_latent_vae/base.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | from ...modules.utils import convert_module_to_f16, convert_module_to_f32
5 | from ...modules import sparse as sp
6 | from ...modules.transformer import AbsolutePositionEmbedder
7 | from ...modules.sparse.transformer import SparseTransformerBlock
8 |
9 |
10 | def block_attn_config(self):
11 | """
12 | Return the attention configuration of the model.
13 | """
14 | for i in range(self.num_blocks):
15 | if self.attn_mode == "shift_window":
16 | yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
17 | elif self.attn_mode == "shift_sequence":
18 | yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
19 | elif self.attn_mode == "shift_order":
20 | yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
21 | elif self.attn_mode == "full":
22 | yield "full", None, None, None, None
23 | elif self.attn_mode == "swin":
24 | yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
25 |
26 |
27 | class SparseTransformerBase(nn.Module):
28 | """
29 | Sparse Transformer without output layers.
30 | Serve as the base class for encoder and decoder.
31 | """
32 | def __init__(
33 | self,
34 | in_channels: int,
35 | model_channels: int,
36 | num_blocks: int,
37 | num_heads: Optional[int] = None,
38 | num_head_channels: Optional[int] = 64,
39 | mlp_ratio: float = 4.0,
40 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
41 | window_size: Optional[int] = None,
42 | pe_mode: Literal["ape", "rope"] = "ape",
43 | use_fp16: bool = False,
44 | use_checkpoint: bool = False,
45 | qk_rms_norm: bool = False,
46 | ):
47 | super().__init__()
48 | self.in_channels = in_channels
49 | self.model_channels = model_channels
50 | self.num_blocks = num_blocks
51 | self.window_size = window_size
52 | self.num_heads = num_heads or model_channels // num_head_channels
53 | self.mlp_ratio = mlp_ratio
54 | self.attn_mode = attn_mode
55 | self.pe_mode = pe_mode
56 | self.use_fp16 = use_fp16
57 | self.use_checkpoint = use_checkpoint
58 | self.qk_rms_norm = qk_rms_norm
59 | self.dtype = torch.float16 if use_fp16 else torch.float32
60 |
61 | if pe_mode == "ape":
62 | self.pos_embedder = AbsolutePositionEmbedder(model_channels)
63 |
64 | self.input_layer = sp.SparseLinear(in_channels, model_channels)
65 | self.blocks = nn.ModuleList([
66 | SparseTransformerBlock(
67 | model_channels,
68 | num_heads=self.num_heads,
69 | mlp_ratio=self.mlp_ratio,
70 | attn_mode=attn_mode,
71 | window_size=window_size,
72 | shift_sequence=shift_sequence,
73 | shift_window=shift_window,
74 | serialize_mode=serialize_mode,
75 | use_checkpoint=self.use_checkpoint,
76 | use_rope=(pe_mode == "rope"),
77 | qk_rms_norm=self.qk_rms_norm,
78 | )
79 | for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
80 | ])
81 |
82 | @property
83 | def device(self) -> torch.device:
84 | """
85 | Return the device of the model.
86 | """
87 | return next(self.parameters()).device
88 |
89 | def convert_to_fp16(self) -> None:
90 | """
91 | Convert the torso of the model to float16.
92 | """
93 | self.blocks.apply(convert_module_to_f16)
94 |
95 | def convert_to_fp32(self) -> None:
96 | """
97 | Convert the torso of the model to float32.
98 | """
99 | self.blocks.apply(convert_module_to_f32)
100 |
101 | def initialize_weights(self) -> None:
102 | # Initialize transformer layers:
103 | def _basic_init(module):
104 | if isinstance(module, nn.Linear):
105 | torch.nn.init.xavier_uniform_(module.weight)
106 | if module.bias is not None:
107 | nn.init.constant_(module.bias, 0)
108 | self.apply(_basic_init)
109 |
110 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
111 | h = self.input_layer(x)
112 | if self.pe_mode == "ape":
113 | h = h + self.pos_embedder(x.coords[:, 1:])
114 | h = h.type(self.dtype)
115 | for block in self.blocks:
116 | h = block(h)
117 | return h
118 |
--------------------------------------------------------------------------------
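Note on the attention schedule above: for the default "swin" mode, block_attn_config alternates between an unshifted and a half-window-shifted windowed block. A minimal, self-contained sketch of that alternation (plain Python, not part of the repository; the tuple layout mirrors the (attn_mode, window_size, shift_sequence, shift_window, serialize_mode) values consumed by SparseTransformerBlock):

    def swin_schedule(num_blocks: int, window_size: int):
        # Mirrors the "swin" branch of block_attn_config: even blocks use an
        # unshifted window, odd blocks shift by half a window.
        for i in range(num_blocks):
            yield "windowed", window_size, None, window_size // 2 * (i % 2), None

    print(list(swin_schedule(num_blocks=4, window_size=8)))
    # [('windowed', 8, None, 0, None), ('windowed', 8, None, 4, None),
    #  ('windowed', 8, None, 0, None), ('windowed', 8, None, 4, None)]

--------------------------------------------------------------------------------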
/trellis/models/structured_latent_vae/decoder_mesh.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7 | from ...modules import sparse as sp
8 | from .base import SparseTransformerBase
9 | from ...representations import MeshExtractResult
10 | from ...representations.mesh import SparseFeatures2Mesh
11 |
12 |
13 | class SparseSubdivideBlock3d(nn.Module):
14 | """
15 | A 3D subdivide block that can subdivide the sparse tensor.
16 |
17 | Args:
18 | channels: channels in the inputs and outputs.
19 | out_channels: if specified, the number of output channels.
20 | num_groups: the number of groups for the group norm.
21 | """
22 | def __init__(
23 | self,
24 | channels: int,
25 | resolution: int,
26 | out_channels: Optional[int] = None,
27 | num_groups: int = 32
28 | ):
29 | super().__init__()
30 | self.channels = channels
31 | self.resolution = resolution
32 | self.out_resolution = resolution * 2
33 | self.out_channels = out_channels or channels
34 |
35 | self.act_layers = nn.Sequential(
36 | sp.SparseGroupNorm32(num_groups, channels),
37 | sp.SparseSiLU()
38 | )
39 |
40 | self.sub = sp.SparseSubdivide()
41 |
42 | self.out_layers = nn.Sequential(
43 | sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
44 | sp.SparseGroupNorm32(num_groups, self.out_channels),
45 | sp.SparseSiLU(),
46 | zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
47 | )
48 |
49 | if self.out_channels == channels:
50 | self.skip_connection = nn.Identity()
51 | else:
52 | self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
53 |
54 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
55 | """
 56 |         Apply the block to a sparse tensor.
57 |
58 | Args:
59 | x: an [N x C x ...] Tensor of features.
60 | Returns:
61 | an [N x C x ...] Tensor of outputs.
62 | """
63 | h = self.act_layers(x)
64 | h = self.sub(h)
65 | x = self.sub(x)
66 | h = self.out_layers(h)
67 | h = h + self.skip_connection(x)
68 | return h
69 |
70 |
71 | class SLatMeshDecoder(SparseTransformerBase):
72 | def __init__(
73 | self,
74 | resolution: int,
75 | model_channels: int,
76 | latent_channels: int,
77 | num_blocks: int,
78 | num_heads: Optional[int] = None,
79 | num_head_channels: Optional[int] = 64,
80 | mlp_ratio: float = 4,
81 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
82 | window_size: int = 8,
83 | pe_mode: Literal["ape", "rope"] = "ape",
84 | use_fp16: bool = False,
85 | use_checkpoint: bool = False,
86 | qk_rms_norm: bool = False,
87 | representation_config: dict = None,
88 | ):
89 | super().__init__(
90 | in_channels=latent_channels,
91 | model_channels=model_channels,
92 | num_blocks=num_blocks,
93 | num_heads=num_heads,
94 | num_head_channels=num_head_channels,
95 | mlp_ratio=mlp_ratio,
96 | attn_mode=attn_mode,
97 | window_size=window_size,
98 | pe_mode=pe_mode,
99 | use_fp16=use_fp16,
100 | use_checkpoint=use_checkpoint,
101 | qk_rms_norm=qk_rms_norm,
102 | )
103 | self.resolution = resolution
104 | self.rep_config = representation_config
105 | self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
106 | self.out_channels = self.mesh_extractor.feats_channels
107 | self.upsample = nn.ModuleList([
108 | SparseSubdivideBlock3d(
109 | channels=model_channels,
110 | resolution=resolution,
111 | out_channels=model_channels // 4
112 | ),
113 | SparseSubdivideBlock3d(
114 | channels=model_channels // 4,
115 | resolution=resolution * 2,
116 | out_channels=model_channels // 8
117 | )
118 | ])
119 | self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
120 |
121 | self.initialize_weights()
122 | if use_fp16:
123 | self.convert_to_fp16()
124 |
125 | def initialize_weights(self) -> None:
126 | super().initialize_weights()
127 | # Zero-out output layers:
128 | nn.init.constant_(self.out_layer.weight, 0)
129 | nn.init.constant_(self.out_layer.bias, 0)
130 |
131 | def convert_to_fp16(self) -> None:
132 | """
133 | Convert the torso of the model to float16.
134 | """
135 | super().convert_to_fp16()
136 | self.upsample.apply(convert_module_to_f16)
137 |
138 | def convert_to_fp32(self) -> None:
139 | """
140 | Convert the torso of the model to float32.
141 | """
142 | super().convert_to_fp32()
143 | self.upsample.apply(convert_module_to_f32)
144 |
145 | def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
146 | """
147 | Convert a batch of network outputs to 3D representations.
148 |
149 | Args:
150 | x: The [N x * x C] sparse tensor output by the network.
151 |
152 | Returns:
153 | list of representations
154 | """
155 | ret = []
156 | for i in range(x.shape[0]):
157 | mesh = self.mesh_extractor(x[i], training=self.training)
158 | ret.append(mesh)
159 | return ret
160 |
161 | def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
162 | h = super().forward(x)
163 | for block in self.upsample:
164 | h = block(h)
165 | h = h.type(x.dtype)
166 | h = self.out_layer(h)
167 | return self.to_representation(h)
168 |
--------------------------------------------------------------------------------
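For orientation, the decoder's two SparseSubdivideBlock3d stages each double the grid resolution while narrowing the channel width, which is why the mesh extractor is constructed for self.resolution * 4. A small illustrative tally (the channel and resolution numbers here are made up, not taken from any shipped config):

    model_channels, resolution = 768, 64   # hypothetical values for illustration only
    stages = [
        ("transformer torso", model_channels,      resolution),
        ("subdivide block 1", model_channels // 4, resolution * 2),
        ("subdivide block 2", model_channels // 8, resolution * 4),  # matches mesh_extractor res
    ]
    for name, ch, res in stages:
        print(f"{name}: {ch} channels at {res}^3")

--------------------------------------------------------------------------------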
/trellis/models/structured_latent_vae/encoder.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from ...modules import sparse as sp
6 | from .base import SparseTransformerBase
7 |
8 |
9 | class SLatEncoder(SparseTransformerBase):
10 | def __init__(
11 | self,
12 | resolution: int,
13 | in_channels: int,
14 | model_channels: int,
15 | latent_channels: int,
16 | num_blocks: int,
17 | num_heads: Optional[int] = None,
18 | num_head_channels: Optional[int] = 64,
19 | mlp_ratio: float = 4,
20 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
21 | window_size: int = 8,
22 | pe_mode: Literal["ape", "rope"] = "ape",
23 | use_fp16: bool = False,
24 | use_checkpoint: bool = False,
25 | qk_rms_norm: bool = False,
26 | ):
27 | super().__init__(
28 | in_channels=in_channels,
29 | model_channels=model_channels,
30 | num_blocks=num_blocks,
31 | num_heads=num_heads,
32 | num_head_channels=num_head_channels,
33 | mlp_ratio=mlp_ratio,
34 | attn_mode=attn_mode,
35 | window_size=window_size,
36 | pe_mode=pe_mode,
37 | use_fp16=use_fp16,
38 | use_checkpoint=use_checkpoint,
39 | qk_rms_norm=qk_rms_norm,
40 | )
41 | self.resolution = resolution
42 | self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
43 |
44 | self.initialize_weights()
45 | if use_fp16:
46 | self.convert_to_fp16()
47 |
48 | def initialize_weights(self) -> None:
49 | super().initialize_weights()
50 | # Zero-out output layers:
51 | nn.init.constant_(self.out_layer.weight, 0)
52 | nn.init.constant_(self.out_layer.bias, 0)
53 |
54 | def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
55 | h = super().forward(x)
56 | h = h.type(x.dtype)
57 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
58 | h = self.out_layer(h)
59 |
60 | # Sample from the posterior distribution
61 | mean, logvar = h.feats.chunk(2, dim=-1)
62 | if sample_posterior:
63 | std = torch.exp(0.5 * logvar)
64 | z = mean + std * torch.randn_like(std)
65 | else:
66 | z = mean
67 | z = h.replace(z)
68 |
69 | if return_raw:
70 | return z, mean, logvar
71 | else:
72 | return z
73 |
--------------------------------------------------------------------------------
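The posterior sampling in SLatEncoder.forward is the standard VAE reparameterization trick applied to the feature tensor. A self-contained dense-tensor sketch of that step (plain PyTorch, no SparseTensor; the shapes are illustrative):

    import torch

    latent_channels = 8
    feats = torch.randn(1024, 2 * latent_channels)    # stand-in for h.feats after out_layer

    mean, logvar = feats.chunk(2, dim=-1)
    std = torch.exp(0.5 * logvar)
    z = mean + std * torch.randn_like(std)             # sample_posterior=True
    z_deterministic = mean                              # sample_posterior=False
    print(z.shape, z_deterministic.shape)               # torch.Size([1024, 8]) twice

--------------------------------------------------------------------------------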
/trellis/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #from .attention_utils import enable_sage_attention, disable_sage_attention
2 | from .attention import (
3 | scaled_dot_product_attention,
4 | BACKEND,
5 | DEBUG,
6 | MultiHeadAttention,
7 | RotaryPositionEmbedder
8 | )
9 |
10 | __all__ = [
11 | 'scaled_dot_product_attention',
12 | 'BACKEND',
13 | 'DEBUG',
14 | 'MultiHeadAttention',
15 | 'RotaryPositionEmbedder'
16 | ]
--------------------------------------------------------------------------------
/trellis/modules/attention/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from typing import Literal
4 | from trellis.backend_config import (
5 | get_attention_backend,
6 | get_debug_mode,
7 | )
 8 | 
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | #ATTN = get_attention_backend()
13 | BACKEND = get_attention_backend()
14 | DEBUG = get_debug_mode()
15 |
16 | def __from_env():
17 | """Read current backend configuration"""
18 | #global ATTN
19 | global BACKEND
20 | global DEBUG
21 |
22 | # Get current settings from central config
23 | #ATTN =
24 | BACKEND = get_attention_backend()
25 | DEBUG = get_debug_mode()
26 |
27 | print(f"[ATTENTION] Using backend: {BACKEND}")
28 |
29 | from .modules import MultiHeadAttention, RotaryPositionEmbedder
30 | from .full_attn import scaled_dot_product_attention
31 |
32 | __all__ = [
33 | 'scaled_dot_product_attention',
34 | 'BACKEND',
35 | 'DEBUG',
36 | 'MultiHeadAttention',
37 | 'RotaryPositionEmbedder'
38 | ]
39 |
--------------------------------------------------------------------------------
/trellis/modules/attention/full_attn.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import math
4 | from trellis.backend_config import (
5 | get_attention_backend,
6 | get_debug_mode,
7 | get_available_backends
8 | )
9 | import logging
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | # Get configuration from central config
14 | BACKEND = get_attention_backend()
15 | DEBUG = get_debug_mode()
16 |
17 | # Get available backends and import if active
18 | available_backends = get_available_backends()
19 |
20 | if BACKEND == "xformers" and available_backends['xformers']:
21 | import xformers.ops as xops
22 | elif BACKEND == "flash_attn" and available_backends['flash_attn']:
23 | import flash_attn
24 | elif BACKEND == "sage" and available_backends['sage']:
25 | import torch.nn.functional as F
26 | from sageattention import sageattn
 27 |     sdpa = F.scaled_dot_product_attention = sageattn  # also expose SageAttention via the sdpa dispatch path below
28 | elif BACKEND == "sdpa":
29 | from torch.nn.functional import scaled_dot_product_attention as sdpa
30 | elif BACKEND == "naive":
 31 |     pass  # the 'naive' path uses the pure-PyTorch _naive_sdpa defined below
32 | else:
33 | raise ValueError(f"Unknown attention module: {BACKEND}")
34 |
35 |
36 | __all__ = [
37 | 'scaled_dot_product_attention',
38 | ]
39 |
40 |
41 | def _naive_sdpa(q, k, v):
42 | """
43 | Naive implementation of scaled dot product attention.
44 | """
45 | q = q.permute(0, 2, 1, 3) # [N, H, L, C]
46 | k = k.permute(0, 2, 1, 3) # [N, H, L, C]
47 | v = v.permute(0, 2, 1, 3) # [N, H, L, C]
48 | scale_factor = 1 / math.sqrt(q.size(-1))
49 | attn_weight = q @ k.transpose(-2, -1) * scale_factor
50 | attn_weight = torch.softmax(attn_weight, dim=-1)
51 | out = attn_weight @ v
52 | out = out.permute(0, 2, 1, 3) # [N, L, H, C]
53 | return out
54 |
55 |
56 | @overload
57 | def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
58 | """
59 | Apply scaled dot product attention.
60 |
61 | Args:
62 | qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
63 | """
64 | ...
65 |
66 | @overload
67 | def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
68 | """
69 | Apply scaled dot product attention.
70 |
71 | Args:
72 | q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
73 | kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
74 | """
75 | ...
76 |
77 | @overload
78 | def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
79 | """
80 | Apply scaled dot product attention.
81 |
82 | Args:
83 | q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
84 | k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
85 | v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
86 |
87 | Note:
88 | k and v are assumed to have the same coordinate map.
89 | """
90 | ...
91 |
92 | def scaled_dot_product_attention(*args, **kwargs):
93 | arg_names_dict = {
94 | 1: ['qkv'],
95 | 2: ['q', 'kv'],
96 | 3: ['q', 'k', 'v']
97 | }
98 | num_all_args = len(args) + len(kwargs)
99 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
100 | for key in arg_names_dict[num_all_args][len(args):]:
101 | assert key in kwargs, f"Missing argument {key}"
102 |
103 | if num_all_args == 1:
104 | qkv = args[0] if len(args) > 0 else kwargs['qkv']
105 | assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
106 | device = qkv.device
107 |
108 | elif num_all_args == 2:
109 | q = args[0] if len(args) > 0 else kwargs['q']
110 | kv = args[1] if len(args) > 1 else kwargs['kv']
111 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
112 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
113 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
114 | device = q.device
115 |
116 | elif num_all_args == 3:
117 | q = args[0] if len(args) > 0 else kwargs['q']
118 | k = args[1] if len(args) > 1 else kwargs['k']
119 | v = args[2] if len(args) > 2 else kwargs['v']
120 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
121 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
122 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
123 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
124 | device = q.device
125 |
126 | if BACKEND == 'xformers':
127 | if num_all_args == 1:
128 | q, k, v = qkv.unbind(dim=2)
129 | elif num_all_args == 2:
130 | k, v = kv.unbind(dim=2)
131 | out = xops.memory_efficient_attention(q, k, v)
132 | elif BACKEND == 'flash_attn':
133 | if num_all_args == 1:
134 | out = flash_attn.flash_attn_qkvpacked_func(qkv)
135 | elif num_all_args == 2:
136 | out = flash_attn.flash_attn_kvpacked_func(q, kv)
137 | elif num_all_args == 3:
138 | out = flash_attn.flash_attn_func(q, k, v)
139 |     elif BACKEND in ('sdpa', 'sage'):
140 | if num_all_args == 1:
141 | q, k, v = qkv.unbind(dim=2)
142 | elif num_all_args == 2:
143 | k, v = kv.unbind(dim=2)
144 | q = q.permute(0, 2, 1, 3) # [N, H, L, C]
145 | k = k.permute(0, 2, 1, 3) # [N, H, L, C]
146 | v = v.permute(0, 2, 1, 3) # [N, H, L, C]
147 | out = sdpa(q, k, v) # [N, H, L, C]
148 | out = out.permute(0, 2, 1, 3) # [N, L, H, C]
149 | elif BACKEND == 'naive':
150 | if num_all_args == 1:
151 | q, k, v = qkv.unbind(dim=2)
152 | elif num_all_args == 2:
153 | k, v = kv.unbind(dim=2)
154 | out = _naive_sdpa(q, k, v)
155 | else:
156 | raise ValueError(f"Unknown attention module: {BACKEND}")
157 |
158 | return out
--------------------------------------------------------------------------------
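The _naive_sdpa fallback above is ordinary softmax attention in [N, L, H, C] layout. As a sanity check, a self-contained snippet (plain PyTorch, independent of the backend selection above) showing that the same computation matches torch.nn.functional.scaled_dot_product_attention up to numerical tolerance:

    import math
    import torch
    import torch.nn.functional as F

    def naive_sdpa(q, k, v):                                     # q, k, v: [N, L, H, C]
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))     # -> [N, H, L, C]
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
        return (attn @ v).permute(0, 2, 1, 3)                    # -> [N, L, H, C]

    q, k, v = (torch.randn(2, 16, 4, 32) for _ in range(3))
    ref = F.scaled_dot_product_attention(
        *(t.permute(0, 2, 1, 3) for t in (q, k, v))
    ).permute(0, 2, 1, 3)
    print(torch.allclose(naive_sdpa(q, k, v), ref, atol=1e-5))   # True

--------------------------------------------------------------------------------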
/trellis/modules/attention/modules.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .full_attn import scaled_dot_product_attention
6 |
7 |
8 | class MultiHeadRMSNorm(nn.Module):
9 | def __init__(self, dim: int, heads: int):
10 | super().__init__()
11 | self.scale = dim ** 0.5
12 | self.gamma = nn.Parameter(torch.ones(heads, dim))
13 |
14 | def forward(self, x: torch.Tensor) -> torch.Tensor:
15 | return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
16 |
17 |
18 | class RotaryPositionEmbedder(nn.Module):
19 | def __init__(self, hidden_size: int, in_channels: int = 3):
20 | super().__init__()
21 | assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
22 | self.hidden_size = hidden_size
23 | self.in_channels = in_channels
24 | self.freq_dim = hidden_size // in_channels // 2
25 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
26 | self.freqs = 1.0 / (10000 ** self.freqs)
27 |
28 | def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
29 | self.freqs = self.freqs.to(indices.device)
30 | phases = torch.outer(indices, self.freqs)
31 | phases = torch.polar(torch.ones_like(phases), phases)
32 | return phases
33 |
34 | def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
35 | x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
36 | x_rotated = x_complex * phases
37 | x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
38 | return x_embed
39 |
40 | def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
41 | """
42 | Args:
 43 |             q (torch.Tensor): [..., N, D] tensor of queries
 44 |             k (torch.Tensor): [..., N, D] tensor of keys
45 | indices (torch.Tensor): [..., N, C] tensor of spatial positions
46 | """
47 | if indices is None:
48 | indices = torch.arange(q.shape[-2], device=q.device)
49 | if len(q.shape) > 2:
50 | indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
51 |
52 | phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
53 | if phases.shape[1] < self.hidden_size // 2:
54 | phases = torch.cat([phases, torch.polar(
55 | torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
56 | torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
57 | )], dim=-1)
58 | q_embed = self._rotary_embedding(q, phases)
59 | k_embed = self._rotary_embedding(k, phases)
60 | return q_embed, k_embed
61 |
62 |
63 | class MultiHeadAttention(nn.Module):
64 | def __init__(
65 | self,
66 | channels: int,
67 | num_heads: int,
68 | ctx_channels: Optional[int]=None,
69 | type: Literal["self", "cross"] = "self",
70 | attn_mode: Literal["full", "windowed"] = "full",
71 | window_size: Optional[int] = None,
72 | shift_window: Optional[Tuple[int, int, int]] = None,
73 | qkv_bias: bool = True,
74 | use_rope: bool = False,
75 | qk_rms_norm: bool = False,
76 | ):
77 | super().__init__()
78 | assert channels % num_heads == 0
79 | assert type in ["self", "cross"], f"Invalid attention type: {type}"
80 | assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
81 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
82 |
83 | if attn_mode == "windowed":
84 | raise NotImplementedError("Windowed attention is not yet implemented")
85 |
86 | self.channels = channels
87 | self.head_dim = channels // num_heads
88 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels
89 | self.num_heads = num_heads
90 | self._type = type
91 | self.attn_mode = attn_mode
92 | self.window_size = window_size
93 | self.shift_window = shift_window
94 | self.use_rope = use_rope
95 | self.qk_rms_norm = qk_rms_norm
96 |
97 | if self._type == "self":
98 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
99 | else:
100 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
101 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
102 |
103 | if self.qk_rms_norm:
104 | self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
105 | self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
106 |
107 | self.to_out = nn.Linear(channels, channels)
108 |
109 | if use_rope:
110 | self.rope = RotaryPositionEmbedder(channels)
111 |
112 | def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
113 | B, L, C = x.shape
114 | if self._type == "self":
115 | qkv = self.to_qkv(x)
116 | qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
117 | if self.use_rope:
118 | q, k, v = qkv.unbind(dim=2)
119 | q, k = self.rope(q, k, indices)
120 | qkv = torch.stack([q, k, v], dim=2)
121 | if self.attn_mode == "full":
122 | if self.qk_rms_norm:
123 | q, k, v = qkv.unbind(dim=2)
124 | q = self.q_rms_norm(q)
125 | k = self.k_rms_norm(k)
126 | h = scaled_dot_product_attention(q, k, v)
127 | else:
128 | h = scaled_dot_product_attention(qkv)
129 | elif self.attn_mode == "windowed":
130 | raise NotImplementedError("Windowed attention is not yet implemented")
131 | else:
132 | Lkv = context.shape[1]
133 | q = self.to_q(x)
134 | kv = self.to_kv(context)
135 | q = q.reshape(B, L, self.num_heads, -1)
136 | kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
137 | if self.qk_rms_norm:
138 | q = self.q_rms_norm(q)
139 | k, v = kv.unbind(dim=2)
140 | k = self.k_rms_norm(k)
141 | h = scaled_dot_product_attention(q, k, v)
142 | else:
143 | h = scaled_dot_product_attention(q, kv)
144 | h = h.reshape(B, L, -1)
145 | h = self.to_out(h)
146 | return h
147 |
--------------------------------------------------------------------------------
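RotaryPositionEmbedder applies the rotation by viewing consecutive feature pairs as complex numbers and multiplying them by unit phasors derived from the position. A minimal sketch of that core step (plain PyTorch; the hidden-size padding and 3D-coordinate handling of the class above are omitted):

    import torch

    def rotate(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
        # x: [N, D] with D even; positions: [N] integer positions.
        d = x.shape[-1] // 2
        freqs = 1.0 / (base ** (torch.arange(d, dtype=torch.float32) / d))
        phases = torch.polar(torch.ones(len(positions), d), torch.outer(positions.float(), freqs))
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(x_complex * phases).reshape(x.shape).to(x.dtype)

    q = torch.randn(6, 16)
    print(rotate(q, torch.arange(6)).shape)   # torch.Size([6, 16])

--------------------------------------------------------------------------------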
/trellis/modules/attention_utils.py:
--------------------------------------------------------------------------------
 1 | # attention_utils.py
2 | import os
3 | from typing import Optional
4 | import torch
5 | import torch.nn.functional as F
6 | from sageattention import sageattn
7 | import math
8 |
 9 | __all__ = ['enable_sage_attention', 'disable_sage_attention']
10 |
11 |
12 | def enable_sage_attention():
13 | """
14 | Enable SageAttention by replacing PyTorch's scaled_dot_product_attention
15 | with sageattn from the SageAttention library.
16 | """
17 | F.scaled_dot_product_attention = sageattn
18 | return True
19 |
20 | def disable_sage_attention():
21 | """
22 | Restore PyTorch's original scaled_dot_product_attention function.
23 | """
24 | F.scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
25 | return True
--------------------------------------------------------------------------------
/trellis/modules/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class LayerNorm32(nn.LayerNorm):
6 | def forward(self, x: torch.Tensor) -> torch.Tensor:
7 | return super().forward(x.float()).type(x.dtype)
8 |
9 |
10 | class GroupNorm32(nn.GroupNorm):
11 | """
12 | A GroupNorm layer that converts to float32 before the forward pass.
13 | """
14 | def forward(self, x: torch.Tensor) -> torch.Tensor:
15 | return super().forward(x.float()).type(x.dtype)
16 |
17 |
18 | class ChannelLayerNorm32(LayerNorm32):
19 | def forward(self, x: torch.Tensor) -> torch.Tensor:
20 | DIM = x.dim()
21 | x = x.permute(0, *range(2, DIM), 1).contiguous()
22 | x = super().forward(x)
23 | x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
24 | return x
25 |
--------------------------------------------------------------------------------
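ChannelLayerNorm32 normalizes over the channel dimension of a channels-first tensor by moving channels to the last axis, applying LayerNorm, and moving them back. A quick stand-alone illustration of the same permutation (the fp32 cast of the class above is omitted):

    import torch
    import torch.nn as nn

    x = torch.randn(2, 16, 8, 8)                          # [N, C, H, W]
    ln = nn.LayerNorm(16)
    y = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)     # normalize over C, restore layout
    print(y.shape)                                        # torch.Size([2, 16, 8, 8])

--------------------------------------------------------------------------------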
/trellis/modules/sparse/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import *
 2 | import importlib
 3 | import logging
 4 | from trellis.backend_config import get_sparse_backend, get_attention_backend, get_debug_mode
 5 | logger = logging.getLogger(__name__)
 6 | BACKEND = get_sparse_backend()
 7 | DEBUG = get_debug_mode()
 8 | ATTN = get_attention_backend()
9 | if ATTN not in ['xformers', 'flash_attn']:
10 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'")
11 | ATTN = 'flash_attn'
12 |
13 | __attributes = {
14 | 'SparseTensor': 'basic',
15 | 'sparse_batch_broadcast': 'basic',
16 | 'sparse_batch_op': 'basic',
17 | 'sparse_cat': 'basic',
18 | 'sparse_unbind': 'basic',
19 | 'SparseGroupNorm': 'norm',
20 | 'SparseLayerNorm': 'norm',
21 | 'SparseGroupNorm32': 'norm',
22 | 'SparseLayerNorm32': 'norm',
23 | 'SparseReLU': 'nonlinearity',
24 | 'SparseSiLU': 'nonlinearity',
25 | 'SparseGELU': 'nonlinearity',
26 | 'SparseActivation': 'nonlinearity',
27 | 'SparseLinear': 'linear',
28 | 'sparse_scaled_dot_product_attention': 'attention',
29 | 'SerializeMode': 'attention',
30 | 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
31 | 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
32 | 'SparseMultiHeadAttention': 'attention',
33 | 'SparseConv3d': 'conv',
34 | 'SparseInverseConv3d': 'conv',
35 | 'SparseDownsample': 'spatial',
36 | 'SparseUpsample': 'spatial',
37 | 'SparseSubdivide' : 'spatial'
38 | }
39 |
40 | __submodules = ['transformer']
41 |
42 | __all__ = list(__attributes.keys()) + __submodules
43 |
44 | def __getattr__(name):
45 | if name not in globals():
46 | if name in __attributes:
47 | module_name = __attributes[name]
48 | module = importlib.import_module(f".{module_name}", __name__)
49 | globals()[name] = getattr(module, name)
50 | elif name in __submodules:
51 | module = importlib.import_module(f".{name}", __name__)
52 | globals()[name] = module
53 | else:
54 | raise AttributeError(f"module {__name__} has no attribute {name}")
55 | return globals()[name]
56 |
57 |
58 | # For Pylance
59 | if __name__ == '__main__':
60 | from .basic import *
61 | from .norm import *
62 | from .nonlinearity import *
63 | from .linear import *
64 | from .attention import *
65 | from .conv import *
66 | from .spatial import *
67 |     from . import transformer
68 |
--------------------------------------------------------------------------------
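The module above defers importing its submodules (and hence the heavy sparse backend) until an attribute is first accessed, using a module-level __getattr__ (PEP 562). A generic, self-contained sketch of the same pattern; the package and attribute names here are illustrative, not the trellis ones:

    # lazy_pkg/__init__.py  (illustrative example of the lazy-import pattern)
    import importlib

    _attributes = {"HeavyThing": "heavy_module"}     # attribute name -> defining submodule

    def __getattr__(name):
        if name in _attributes:
            module = importlib.import_module(f".{_attributes[name]}", __name__)
            value = getattr(module, name)
            globals()[name] = value                  # cache so later lookups are direct
            return value
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

--------------------------------------------------------------------------------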
/trellis/modules/sparse/attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .full_attn import *
2 | from .serialized_attn import *
3 | from .windowed_attn import *
4 | from .modules import *
5 | import os
6 | import logging
7 | from typing import Literal
8 | from trellis.backend_config import (
9 | get_attention_backend,
10 | get_debug_mode,
11 | )
12 | 
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | #ATTN = get_attention_backend()
17 | ATTN = get_attention_backend()
18 | DEBUG = get_debug_mode()
19 |
20 | def __from_env():
21 | """Read current backend configuration"""
22 | #global ATTN
23 | global ATTN
24 | global DEBUG
25 |
26 | # Get current settings from central config
27 | #ATTN =
28 | ATTN = get_attention_backend()
29 | DEBUG = get_debug_mode()
30 |
31 | print(f"[ATTENTION] sparse backend: {ATTN}")
--------------------------------------------------------------------------------
/trellis/modules/sparse/attention/full_attn.py:
--------------------------------------------------------------------------------
1 | #trellis\modules\sparse\attention\full_attn.py
2 | from typing import *
3 | import torch
4 | import math
5 | from .. import SparseTensor
6 | from trellis.backend_config import (
7 | get_attention_backend,
8 | get_debug_mode,
9 | get_available_backends
10 | )
11 | import logging
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | # Get configuration from central config
16 | ATTN = get_attention_backend()
17 | DEBUG = get_debug_mode()
18 |
19 | # Get available backends and import if active
20 | available_backends = get_available_backends()
21 |
22 | if ATTN not in ['xformers', 'flash_attn']:
23 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'")
24 | ATTN = 'flash_attn'
25 |
26 | if ATTN == 'xformers' and available_backends['xformers']:
27 | import xformers.ops as xops
28 | elif ATTN == 'flash_attn' and available_backends['flash_attn']:
29 | import flash_attn
30 | else:
31 | raise ImportError(f"Could not import {ATTN}. Please install either xformers or flash-attn for sparse attention support.")
32 |
33 |
34 |
35 | __all__ = [
36 | 'sparse_scaled_dot_product_attention',
37 | ]
38 |
39 |
40 | @overload
41 | def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
42 | """
43 | Apply scaled dot product attention to a sparse tensor.
44 |
45 | Args:
46 | qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
47 | """
48 | ...
49 |
50 | @overload
51 | def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
52 | """
53 | Apply scaled dot product attention to a sparse tensor.
54 |
55 | Args:
56 | q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
57 | kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
58 | """
59 | ...
60 |
61 | @overload
62 | def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
63 | """
64 | Apply scaled dot product attention to a sparse tensor.
65 |
66 | Args:
 67 |         q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
 68 |         kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
69 | """
70 | ...
71 |
72 | @overload
73 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
74 | """
75 | Apply scaled dot product attention to a sparse tensor.
76 |
77 | Args:
78 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
79 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
80 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
81 |
82 | Note:
83 | k and v are assumed to have the same coordinate map.
84 | """
85 | ...
86 |
87 | @overload
88 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
89 | """
90 | Apply scaled dot product attention to a sparse tensor.
91 |
92 | Args:
93 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
94 | k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
95 | v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
96 | """
97 | ...
98 |
99 | @overload
100 | def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
101 | """
102 | Apply scaled dot product attention to a sparse tensor.
103 |
104 | Args:
105 | q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
106 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
107 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
108 | """
109 | ...
110 |
111 | def sparse_scaled_dot_product_attention(*args, **kwargs):
112 | arg_names_dict = {
113 | 1: ['qkv'],
114 | 2: ['q', 'kv'],
115 | 3: ['q', 'k', 'v']
116 | }
117 | num_all_args = len(args) + len(kwargs)
118 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
119 | for key in arg_names_dict[num_all_args][len(args):]:
120 | assert key in kwargs, f"Missing argument {key}"
121 |
122 | if num_all_args == 1:
123 | qkv = args[0] if len(args) > 0 else kwargs['qkv']
124 | assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
125 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
126 | device = qkv.device
127 |
128 | s = qkv
129 | q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
130 | kv_seqlen = q_seqlen
131 | qkv = qkv.feats # [T, 3, H, C]
132 |
133 | elif num_all_args == 2:
134 | q = args[0] if len(args) > 0 else kwargs['q']
135 | kv = args[1] if len(args) > 1 else kwargs['kv']
136 | assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
137 | isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
138 | f"Invalid types, got {type(q)} and {type(kv)}"
139 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
140 | device = q.device
141 |
142 | if isinstance(q, SparseTensor):
143 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
144 | s = q
145 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
146 | q = q.feats # [T_Q, H, C]
147 | else:
148 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
149 | s = None
150 | N, L, H, C = q.shape
151 | q_seqlen = [L] * N
152 | q = q.reshape(N * L, H, C) # [T_Q, H, C]
153 |
154 | if isinstance(kv, SparseTensor):
155 | assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
156 | kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
157 | kv = kv.feats # [T_KV, 2, H, C]
158 | else:
159 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
160 | N, L, _, H, C = kv.shape
161 | kv_seqlen = [L] * N
162 | kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
163 |
164 | elif num_all_args == 3:
165 | q = args[0] if len(args) > 0 else kwargs['q']
166 | k = args[1] if len(args) > 1 else kwargs['k']
167 | v = args[2] if len(args) > 2 else kwargs['v']
168 | assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
169 | isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
170 | f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
171 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
172 | device = q.device
173 |
174 | if isinstance(q, SparseTensor):
175 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
176 | s = q
177 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
178 | q = q.feats # [T_Q, H, Ci]
179 | else:
180 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
181 | s = None
182 | N, L, H, CI = q.shape
183 | q_seqlen = [L] * N
184 | q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
185 |
186 | if isinstance(k, SparseTensor):
187 | assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
188 | assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
189 | kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
190 | k = k.feats # [T_KV, H, Ci]
191 | v = v.feats # [T_KV, H, Co]
192 | else:
193 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
194 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
195 | N, L, H, CI, CO = *k.shape, v.shape[-1]
196 | kv_seqlen = [L] * N
197 | k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
198 | v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
199 |
200 | if DEBUG:
201 | if s is not None:
202 | for i in range(s.shape[0]):
203 | assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
204 | if num_all_args in [2, 3]:
205 | assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
206 | if num_all_args == 3:
207 | assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
208 | assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"
209 |
210 | if ATTN == 'xformers':
211 | if num_all_args == 1:
212 | q, k, v = qkv.unbind(dim=1)
213 | elif num_all_args == 2:
214 | k, v = kv.unbind(dim=1)
215 | q = q.unsqueeze(0)
216 | k = k.unsqueeze(0)
217 | v = v.unsqueeze(0)
218 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
219 | out = xops.memory_efficient_attention(q, k, v, mask)[0]
220 | elif ATTN == 'flash_attn':
221 | cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
222 | if num_all_args in [2, 3]:
223 | cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
224 | if num_all_args == 1:
225 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
226 | elif num_all_args == 2:
227 | out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
228 | elif num_all_args == 3:
229 | out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
230 | else:
231 | raise ValueError(f"Unknown attention module: {ATTN}")
232 |
233 | if s is not None:
234 | return s.replace(out)
235 | else:
236 | return out.reshape(N, L, H, -1)
--------------------------------------------------------------------------------
/trellis/modules/sparse/attention/modules.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .. import SparseTensor
6 | from .full_attn import sparse_scaled_dot_product_attention
7 | from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
8 | from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
9 | from ...attention import RotaryPositionEmbedder
10 |
11 |
12 | class SparseMultiHeadRMSNorm(nn.Module):
13 | def __init__(self, dim: int, heads: int):
14 | super().__init__()
15 | self.scale = dim ** 0.5
16 | self.gamma = nn.Parameter(torch.ones(heads, dim))
17 |
18 | def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
19 | x_type = x.dtype
20 | x = x.float()
21 | if isinstance(x, SparseTensor):
22 | x = x.replace(F.normalize(x.feats, dim=-1))
23 | else:
24 | x = F.normalize(x, dim=-1)
25 | return (x * self.gamma * self.scale).to(x_type)
26 |
27 |
28 | class SparseMultiHeadAttention(nn.Module):
29 | def __init__(
30 | self,
31 | channels: int,
32 | num_heads: int,
33 | ctx_channels: Optional[int] = None,
34 | type: Literal["self", "cross"] = "self",
35 | attn_mode: Literal["full", "serialized", "windowed"] = "full",
36 | window_size: Optional[int] = None,
37 | shift_sequence: Optional[int] = None,
38 | shift_window: Optional[Tuple[int, int, int]] = None,
39 | serialize_mode: Optional[SerializeMode] = None,
40 | qkv_bias: bool = True,
41 | use_rope: bool = False,
42 | qk_rms_norm: bool = False,
43 | ):
44 | super().__init__()
45 | assert channels % num_heads == 0
46 | assert type in ["self", "cross"], f"Invalid attention type: {type}"
47 | assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
48 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
49 | assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
50 | self.channels = channels
51 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels
52 | self.num_heads = num_heads
53 | self._type = type
54 | self.attn_mode = attn_mode
55 | self.window_size = window_size
56 | self.shift_sequence = shift_sequence
57 | self.shift_window = shift_window
58 | self.serialize_mode = serialize_mode
59 | self.use_rope = use_rope
60 | self.qk_rms_norm = qk_rms_norm
61 |
62 | if self._type == "self":
63 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
64 | else:
65 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
66 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
67 |
68 | if self.qk_rms_norm:
69 | self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
70 | self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
71 |
72 | self.to_out = nn.Linear(channels, channels)
73 |
74 | if use_rope:
75 | self.rope = RotaryPositionEmbedder(channels)
76 |
77 | @staticmethod
78 | def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
79 | if isinstance(x, SparseTensor):
80 | return x.replace(module(x.feats))
81 | else:
82 | return module(x)
83 |
84 | @staticmethod
85 | def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
86 | if isinstance(x, SparseTensor):
87 | return x.reshape(*shape)
88 | else:
89 | return x.reshape(*x.shape[:2], *shape)
90 |
91 | def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
92 | if isinstance(x, SparseTensor):
93 | x_feats = x.feats.unsqueeze(0)
94 | else:
95 | x_feats = x
96 | x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
97 | return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
98 |
99 | def _rope(self, qkv: SparseTensor) -> SparseTensor:
100 | q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
101 | q, k = self.rope(q, k, qkv.coords[:, 1:])
102 | qkv = qkv.replace(torch.stack([q, k, v], dim=1))
103 | return qkv
104 |
105 | def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
106 | if self._type == "self":
107 | qkv = self._linear(self.to_qkv, x)
108 | qkv = self._fused_pre(qkv, num_fused=3)
109 | if self.use_rope:
110 | qkv = self._rope(qkv)
111 | if self.qk_rms_norm:
112 | q, k, v = qkv.unbind(dim=1)
113 | q = self.q_rms_norm(q)
114 | k = self.k_rms_norm(k)
115 | qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
116 | if self.attn_mode == "full":
117 | h = sparse_scaled_dot_product_attention(qkv)
118 | elif self.attn_mode == "serialized":
119 | h = sparse_serialized_scaled_dot_product_self_attention(
120 | qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
121 | )
122 | elif self.attn_mode == "windowed":
123 | h = sparse_windowed_scaled_dot_product_self_attention(
124 | qkv, self.window_size, shift_window=self.shift_window
125 | )
126 | else:
127 | q = self._linear(self.to_q, x)
128 | q = self._reshape_chs(q, (self.num_heads, -1))
129 | kv = self._linear(self.to_kv, context)
130 | kv = self._fused_pre(kv, num_fused=2)
131 | if self.qk_rms_norm:
132 | q = self.q_rms_norm(q)
133 | k, v = kv.unbind(dim=1)
134 | k = self.k_rms_norm(k)
135 | kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
136 | h = sparse_scaled_dot_product_attention(q, kv)
137 | h = self._reshape_chs(h, (-1,))
138 | h = self._linear(self.to_out, h)
139 | return h
140 |
--------------------------------------------------------------------------------
/trellis/modules/sparse/attention/serialized_attn.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | from enum import Enum
3 | import torch
4 | import math
5 | from .. import SparseTensor
6 | from trellis.backend_config import (
7 | get_attention_backend,
8 | get_debug_mode,
9 | get_available_backends
10 | )
11 | import logging
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | # Get configuration from central config
16 | ATTN = get_attention_backend()
17 | DEBUG = get_debug_mode()
18 |
19 | # Get available backends and import if active
20 | available_backends = get_available_backends()
21 |
22 | if ATTN not in ['xformers', 'flash_attn']:
23 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'")
24 | ATTN = 'flash_attn'
25 |
26 | if ATTN == 'xformers' and available_backends['xformers']:
27 | import xformers.ops as xops
28 | elif ATTN == 'flash_attn' and available_backends['flash_attn']:
29 | import flash_attn
30 | else:
31 | raise ImportError(f"Could not import {ATTN}. Please install either xformers or flash-attn for sparse attention support.")
32 |
33 |
34 | __all__ = [
35 | 'sparse_serialized_scaled_dot_product_self_attention',
36 | ]
37 |
38 |
39 | class SerializeMode(Enum):
40 | Z_ORDER = 0
41 | Z_ORDER_TRANSPOSED = 1
42 | HILBERT = 2
43 | HILBERT_TRANSPOSED = 3
44 |
45 |
46 | SerializeModes = [
47 | SerializeMode.Z_ORDER,
48 | SerializeMode.Z_ORDER_TRANSPOSED,
49 | SerializeMode.HILBERT,
50 | SerializeMode.HILBERT_TRANSPOSED
51 | ]
52 |
53 |
54 | def calc_serialization(
55 | tensor: SparseTensor,
56 | window_size: int,
57 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
58 | shift_sequence: int = 0,
59 | shift_window: Tuple[int, int, int] = (0, 0, 0)
 60 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
61 | """
62 | Calculate serialization and partitioning for a set of coordinates.
63 |
64 | Args:
65 | tensor (SparseTensor): The input tensor.
66 | window_size (int): The window size to use.
67 | serialize_mode (SerializeMode): The serialization mode to use.
68 | shift_sequence (int): The shift of serialized sequence.
69 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
70 |
 71 |     Returns:
 72 |         (torch.Tensor, torch.Tensor, List[int], List[int]): Forward indices, backward indices, sequence lengths, and sequence batch indices.
73 | """
74 | fwd_indices = []
75 | bwd_indices = []
76 | seq_lens = []
77 | seq_batch_indices = []
78 | offsets = [0]
79 |
80 | if 'vox2seq' not in globals():
81 | import vox2seq
82 |
83 | # Serialize the input
84 | serialize_coords = tensor.coords[:, 1:].clone()
85 | serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
86 | if serialize_mode == SerializeMode.Z_ORDER:
87 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
88 | elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
89 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
90 | elif serialize_mode == SerializeMode.HILBERT:
91 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
92 | elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
93 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
94 | else:
95 | raise ValueError(f"Unknown serialize mode: {serialize_mode}")
96 |
97 | for bi, s in enumerate(tensor.layout):
98 | num_points = s.stop - s.start
99 | num_windows = (num_points + window_size - 1) // window_size
100 | valid_window_size = num_points / num_windows
101 | to_ordered = torch.argsort(code[s.start:s.stop])
102 | if num_windows == 1:
103 | fwd_indices.append(to_ordered)
104 | bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
105 | fwd_indices[-1] += s.start
106 | bwd_indices[-1] += offsets[-1]
107 | seq_lens.append(num_points)
108 | seq_batch_indices.append(bi)
109 | offsets.append(offsets[-1] + seq_lens[-1])
110 | else:
111 | # Partition the input
112 | offset = 0
113 | mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
114 | split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
115 | bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
116 | for i in range(num_windows):
117 | mid = mids[i]
118 | valid_start = split[i]
119 | valid_end = split[i + 1]
120 | padded_start = math.floor(mid - 0.5 * window_size)
121 | padded_end = padded_start + window_size
122 | fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
123 | offset += valid_start - padded_start
124 | bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
125 | offset += padded_end - valid_start
126 | fwd_indices[-1] += s.start
127 | seq_lens.extend([window_size] * num_windows)
128 | seq_batch_indices.extend([bi] * num_windows)
129 | bwd_indices.append(bwd_index + offsets[-1])
130 | offsets.append(offsets[-1] + num_windows * window_size)
131 |
132 | fwd_indices = torch.cat(fwd_indices)
133 | bwd_indices = torch.cat(bwd_indices)
134 |
135 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
136 |
137 |
138 | def sparse_serialized_scaled_dot_product_self_attention(
139 | qkv: SparseTensor,
140 | window_size: int,
141 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
142 | shift_sequence: int = 0,
143 | shift_window: Tuple[int, int, int] = (0, 0, 0)
144 | ) -> SparseTensor:
145 | """
146 | Apply serialized scaled dot product self attention to a sparse tensor.
147 |
148 | Args:
149 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
150 | window_size (int): The window size to use.
151 | serialize_mode (SerializeMode): The serialization mode to use.
152 | shift_sequence (int): The shift of serialized sequence.
153 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
154 | 
155 | """
156 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
157 |
158 | serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
159 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
160 | if serialization_spatial_cache is None:
161 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
162 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
163 | else:
164 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
165 |
166 | M = fwd_indices.shape[0]
167 | T = qkv.feats.shape[0]
168 | H = qkv.feats.shape[2]
169 | C = qkv.feats.shape[3]
170 |
171 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
172 |
173 | if DEBUG:
174 | start = 0
175 | qkv_coords = qkv.coords[fwd_indices]
176 | for i in range(len(seq_lens)):
177 | assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
178 | start += seq_lens[i]
179 |
180 | if all([seq_len == window_size for seq_len in seq_lens]):
181 | B = len(seq_lens)
182 | N = window_size
183 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
184 | if ATTN == 'xformers':
185 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
186 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
187 | elif ATTN == 'flash_attn':
188 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
189 | else:
190 | raise ValueError(f"Unknown attention module: {ATTN}")
191 | out = out.reshape(B * N, H, C) # [M, H, C]
192 | else:
193 | if ATTN == 'xformers':
194 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
195 | q = q.unsqueeze(0) # [1, M, H, C]
196 | k = k.unsqueeze(0) # [1, M, H, C]
197 | v = v.unsqueeze(0) # [1, M, H, C]
198 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
199 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
200 | elif ATTN == 'flash_attn':
201 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
202 | .to(qkv.device).int()
203 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
204 |
205 | out = out[bwd_indices] # [T, H, C]
206 |
207 | if DEBUG:
208 | qkv_coords = qkv_coords[bwd_indices]
209 | assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
210 |
211 | return qkv.replace(out)
212 |
--------------------------------------------------------------------------------
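calc_serialization produces a forward gather order (voxels sorted along a space-filling curve and padded into windows) together with its inverse, so attention outputs can be scattered back to the original voxel order. A tiny sketch of that permutation/inverse-permutation bookkeeping (plain PyTorch, with a made-up code array standing in for the vox2seq output):

    import torch

    code = torch.tensor([3, 0, 2, 1])                # stand-in for space-filling-curve codes
    fwd = torch.argsort(code)                        # gather order: sort voxels by code
    bwd = torch.empty_like(fwd)
    bwd[fwd] = torch.arange(fwd.numel())             # inverse permutation

    feats = torch.arange(4) * 10
    reordered = feats[fwd]
    print(torch.equal(reordered[bwd], feats))        # True: round-trips to the original order

--------------------------------------------------------------------------------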
/trellis/modules/sparse/attention/windowed_attn.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | from enum import Enum
3 | import math
4 | import os
5 | import logging
6 | import torch
7 | from .. import SparseTensor
8 | from trellis.backend_config import (
9 | get_attention_backend,
10 | get_debug_mode,
11 | get_available_backends
12 | )
 13 | 
14 |
15 | logger = logging.getLogger(__name__)
16 |
 17 | 
 18 | # Get configuration from central config
19 | ATTN = get_attention_backend()
20 | DEBUG = get_debug_mode()
21 |
22 | # Get available backends and import if active
23 | available_backends = get_available_backends()
24 |
25 | if ATTN not in ['xformers', 'flash_attn']:
26 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'")
27 | ATTN = 'flash_attn'
28 |
29 | if ATTN == 'xformers' and available_backends['xformers']:
30 | import xformers.ops as xops
31 | elif ATTN == 'flash_attn' and available_backends['flash_attn']:
32 | import flash_attn
33 | else:
34 | raise ImportError(f"Could not import {ATTN}. Please install either xformers or flash-attn for sparse attention support.")
35 |
36 |
37 | __all__ = [
38 | 'sparse_windowed_scaled_dot_product_self_attention',
39 | ]
40 |
41 |
42 | def calc_window_partition(
43 | tensor: SparseTensor,
44 | window_size: Union[int, Tuple[int, ...]],
45 | shift_window: Union[int, Tuple[int, ...]] = 0
46 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
47 | """
48 | Calculate serialization and partitioning for a set of coordinates.
49 |
50 | Args:
51 | tensor (SparseTensor): The input tensor.
52 | window_size (int): The window size to use.
53 | shift_window (Tuple[int, ...]): The shift of serialized coordinates.
54 |
55 | Returns:
56 | (torch.Tensor): Forwards indices.
57 | (torch.Tensor): Backwards indices.
58 | (List[int]): Sequence lengths.
59 | (List[int]): Sequence batch indices.
60 | """
61 | DIM = tensor.coords.shape[1] - 1
62 | shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
63 | window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
64 | shifted_coords = tensor.coords.clone().detach()
65 | shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
66 |
67 | MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
68 | NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
69 | OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
70 |
71 | shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
72 | shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
73 | fwd_indices = torch.argsort(shifted_indices)
74 | bwd_indices = torch.empty_like(fwd_indices)
75 | bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
76 | seq_lens = torch.bincount(shifted_indices)
77 | seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
78 | mask = seq_lens != 0
79 | seq_lens = seq_lens[mask].tolist()
80 | seq_batch_indices = seq_batch_indices[mask].tolist()
81 |
82 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
83 |
84 |
85 | def sparse_windowed_scaled_dot_product_self_attention(
86 | qkv: SparseTensor,
87 | window_size: int,
88 | shift_window: Tuple[int, int, int] = (0, 0, 0)
89 | ) -> SparseTensor:
90 | """
91 | Apply windowed scaled dot product self attention to a sparse tensor.
92 |
93 | Args:
94 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
95 | window_size (int): The window size to use.
96 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
 97 | 
98 | """
99 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
100 |
101 | serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
102 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
103 | if serialization_spatial_cache is None:
104 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
105 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
106 | else:
107 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
108 |
109 | M = fwd_indices.shape[0]
110 | T = qkv.feats.shape[0]
111 | H = qkv.feats.shape[2]
112 | C = qkv.feats.shape[3]
113 |
114 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
115 |
116 | if DEBUG:
117 | start = 0
118 | qkv_coords = qkv.coords[fwd_indices]
119 | for i in range(len(seq_lens)):
120 | seq_coords = qkv_coords[start:start+seq_lens[i]]
121 | assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
122 | assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
123 | f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
124 | start += seq_lens[i]
125 |
126 | if all([seq_len == window_size for seq_len in seq_lens]):
127 | B = len(seq_lens)
128 | N = window_size
129 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
130 | if ATTN == 'xformers':
131 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
132 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
133 | elif ATTN == 'flash_attn':
134 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
135 | else:
136 | raise ValueError(f"Unknown attention module: {ATTN}")
137 | out = out.reshape(B * N, H, C) # [M, H, C]
138 | else:
139 | if ATTN == 'xformers':
140 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
141 | q = q.unsqueeze(0) # [1, M, H, C]
142 | k = k.unsqueeze(0) # [1, M, H, C]
143 | v = v.unsqueeze(0) # [1, M, H, C]
144 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
145 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
146 | elif ATTN == 'flash_attn':
147 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
148 | .to(qkv.device).int()
149 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
150 |
151 | out = out[bwd_indices] # [T, H, C]
152 |
153 | if DEBUG:
154 | qkv_coords = qkv_coords[bwd_indices]
155 | assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
156 |
157 | return qkv.replace(out)
158 |
--------------------------------------------------------------------------------
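
A minimal, self-contained sketch of the window-partition step in `calc_window_partition` above, using plain tensors instead of `SparseTensor` (the backward indices and spatial cache are omitted). Voxel coordinates are `[batch, x, y, z]`; each voxel is assigned a linear window code, and sorting by that code makes every window contiguous:

    import math
    import torch

    coords = torch.tensor([
        [0, 0, 0, 0],
        [0, 1, 1, 1],
        [0, 5, 0, 2],
        [1, 0, 3, 3],
        [1, 6, 6, 6],
    ])  # [N, 1 + DIM]
    window_size = (4, 4, 4)

    DIM = coords.shape[1] - 1
    max_coords = coords[:, 1:].max(dim=0).values.tolist()
    num_windows = [math.ceil((mc + 1) / ws) for mc, ws in zip(max_coords, window_size)]
    offset = torch.cumprod(torch.tensor([1] + num_windows[::-1]), dim=0).tolist()[::-1]

    win_coords = coords.clone()
    win_coords[:, 1:] //= torch.tensor(window_size)                 # window index per axis
    codes = (win_coords * torch.tensor(offset)).sum(dim=1)          # one linear code per window
    fwd = torch.argsort(codes)                                      # gather order: window-contiguous
    seq_lens = torch.bincount(codes)
    seq_lens = seq_lens[seq_lens != 0].tolist()                     # voxels per non-empty window

    print(codes.tolist())    # [0, 0, 4, 8, 15] -- two voxels share the first window
    print(seq_lens)          # [2, 1, 1, 1]
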
/trellis/modules/sparse/conv/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from trellis.backend_config import get_sparse_backend, get_spconv_algo
4 |
5 | BACKEND = get_sparse_backend()
6 | SPCONV_ALGO = get_spconv_algo()
7 |
8 | def __from_env():
9 | import os
10 |
11 | global SPCONV_ALGO
12 | env_spconv_algo = os.environ.get('SPCONV_ALGO')
13 | if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
14 | SPCONV_ALGO = env_spconv_algo
15 | print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
16 |
17 |
18 | __from_env()
19 |
20 | if BACKEND == 'torchsparse':
21 | from .conv_torchsparse import *
22 | elif BACKEND == 'spconv':
23 | from .conv_spconv import *
24 |
25 | __all__ = [
26 | "SparseConv3d",
27 | "SparseInverseConv3d",
28 | ]
29 |
30 |
--------------------------------------------------------------------------------
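
As the `__from_env()` hook above shows, the spconv algorithm can be overridden from the environment; the override is only picked up when this module is first imported. A usage sketch, assuming the trellis package and one sparse backend (spconv or torchsparse) are installed:

    import os
    os.environ["SPCONV_ALGO"] = "native"        # or "auto" / "implicit_gemm", as checked above

    # Importing the conv package reads the env var and the configured backend.
    from trellis.modules.sparse import conv
    print(conv.BACKEND, conv.SPCONV_ALGO)
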
/trellis/modules/sparse/conv/conv_spconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .. import SparseTensor
4 | from .. import DEBUG
5 | from trellis.backend_config import get_debug_mode, get_spconv_algo, get_sparse_backend
6 |
7 | # Get configuration from central config
8 | DEBUG = get_debug_mode()
9 | SPCONV_ALGO = get_spconv_algo()
10 | BACKEND = get_sparse_backend()
11 |
12 | class SparseConv3d(nn.Module):
13 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
14 | super(SparseConv3d, self).__init__()
15 | if BACKEND == 'spconv':
16 | import spconv.pytorch as spconv
17 | algo = None
18 | if SPCONV_ALGO == 'native':
19 | algo = spconv.ConvAlgo.Native
20 | elif SPCONV_ALGO == 'implicit_gemm':
21 | algo = spconv.ConvAlgo.MaskImplicitGemm
22 | if stride == 1 and (padding is None):
23 | self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
24 | else:
25 | self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
26 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
27 | self.padding = padding
28 |
29 | def forward(self, x: SparseTensor) -> SparseTensor:
30 | spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
31 | new_data = self.conv(x.data)
32 | new_shape = [x.shape[0], self.conv.out_channels]
33 | new_layout = None if spatial_changed else x.layout
34 |
35 | if spatial_changed and (x.shape[0] != 1):
36 |             import spconv.pytorch as spconv  # spconv with non-1 stride breaks the output order, so re-sort by coords
37 | fwd = new_data.indices[:, 0].argsort()
38 | bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
39 | sorted_feats = new_data.features[fwd]
40 | sorted_coords = new_data.indices[fwd]
41 | unsorted_data = new_data
42 | new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
43 |
44 | out = SparseTensor(
45 | new_data, shape=torch.Size(new_shape), layout=new_layout,
46 | scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
47 | spatial_cache=x._spatial_cache,
48 | )
49 |
50 | if spatial_changed and (x.shape[0] != 1):
51 | out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
52 | out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
53 |
54 | return out
55 |
56 |
57 | class SparseInverseConv3d(nn.Module):
58 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
59 | super(SparseInverseConv3d, self).__init__()
60 | if BACKEND == 'spconv':
61 | import spconv.pytorch as spconv
62 | self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
63 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
64 |
65 | def forward(self, x: SparseTensor) -> SparseTensor:
66 | spatial_changed = any(s != 1 for s in self.stride)
67 | if spatial_changed:
68 | # recover the original spconv order
69 | data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
70 | bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
71 | data = data.replace_feature(x.feats[bwd])
72 | if DEBUG:
73 | assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
74 | else:
75 | data = x.data
76 |
77 | new_data = self.conv(data)
78 | new_shape = [x.shape[0], self.conv.out_channels]
79 | new_layout = None if spatial_changed else x.layout
80 | out = SparseTensor(
81 | new_data, shape=torch.Size(new_shape), layout=new_layout,
82 | scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
83 | spatial_cache=x._spatial_cache,
84 | )
85 | return out
86 |
--------------------------------------------------------------------------------
/trellis/modules/sparse/conv/conv_torchsparse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .. import SparseTensor
4 | from trellis.backend_config import get_debug_mode, get_spconv_algo, get_sparse_backend
5 |
6 | # Get configuration from central config
7 | DEBUG = get_debug_mode()
8 | SPCONV_ALGO = get_spconv_algo()
9 | BACKEND = get_sparse_backend()
10 |
11 | class SparseConv3d(nn.Module):
12 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
13 | super(SparseConv3d, self).__init__()
14 | if BACKEND == 'torchsparse':
15 | import torchsparse
16 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
17 |
18 | def forward(self, x: SparseTensor) -> SparseTensor:
19 | out = self.conv(x.data)
20 | new_shape = [x.shape[0], self.conv.out_channels]
21 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
22 | out._spatial_cache = x._spatial_cache
23 | out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
24 | return out
25 |
26 |
27 | class SparseInverseConv3d(nn.Module):
28 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
29 | super(SparseInverseConv3d, self).__init__()
30 | if BACKEND == 'torchsparse':
31 | import torchsparse
32 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
33 |
34 | def forward(self, x: SparseTensor) -> SparseTensor:
35 | out = self.conv(x.data)
36 | new_shape = [x.shape[0], self.conv.out_channels]
37 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
38 | out._spatial_cache = x._spatial_cache
39 | out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
40 | return out
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/trellis/modules/sparse/linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import SparseTensor
4 |
5 | __all__ = [
6 | 'SparseLinear'
7 | ]
8 |
9 |
10 | class SparseLinear(nn.Linear):
11 | def __init__(self, in_features, out_features, bias=True):
12 | super(SparseLinear, self).__init__(in_features, out_features, bias)
13 |
14 | def forward(self, input: SparseTensor) -> SparseTensor:
15 | return input.replace(super().forward(input.feats))
16 |
--------------------------------------------------------------------------------
/trellis/modules/sparse/nonlinearity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import SparseTensor
4 |
5 | __all__ = [
6 | 'SparseReLU',
7 | 'SparseSiLU',
8 | 'SparseGELU',
9 | 'SparseActivation'
10 | ]
11 |
12 |
13 | class SparseReLU(nn.ReLU):
14 | def forward(self, input: SparseTensor) -> SparseTensor:
15 | return input.replace(super().forward(input.feats))
16 |
17 |
18 | class SparseSiLU(nn.SiLU):
19 | def forward(self, input: SparseTensor) -> SparseTensor:
20 | return input.replace(super().forward(input.feats))
21 |
22 |
23 | class SparseGELU(nn.GELU):
24 | def forward(self, input: SparseTensor) -> SparseTensor:
25 | return input.replace(super().forward(input.feats))
26 |
27 |
28 | class SparseActivation(nn.Module):
29 | def __init__(self, activation: nn.Module):
30 | super().__init__()
31 | self.activation = activation
32 |
33 | def forward(self, input: SparseTensor) -> SparseTensor:
34 | return input.replace(self.activation(input.feats))
35 |
36 |
--------------------------------------------------------------------------------
/trellis/modules/sparse/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import SparseTensor
4 | from . import DEBUG
5 |
6 | __all__ = [
7 | 'SparseGroupNorm',
8 | 'SparseLayerNorm',
9 | 'SparseGroupNorm32',
10 | 'SparseLayerNorm32',
11 | ]
12 |
13 |
14 | class SparseGroupNorm(nn.GroupNorm):
15 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
16 | super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
17 |
18 | def forward(self, input: SparseTensor) -> SparseTensor:
19 | nfeats = torch.zeros_like(input.feats)
20 | for k in range(input.shape[0]):
21 | if DEBUG:
22 | assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
23 | bfeats = input.feats[input.layout[k]]
24 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
25 | bfeats = super().forward(bfeats)
26 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
27 | nfeats[input.layout[k]] = bfeats
28 | return input.replace(nfeats)
29 |
30 |
31 | class SparseLayerNorm(nn.LayerNorm):
32 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
33 | super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
34 |
35 | def forward(self, input: SparseTensor) -> SparseTensor:
36 | nfeats = torch.zeros_like(input.feats)
37 | for k in range(input.shape[0]):
38 | bfeats = input.feats[input.layout[k]]
39 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
40 | bfeats = super().forward(bfeats)
41 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
42 | nfeats[input.layout[k]] = bfeats
43 | return input.replace(nfeats)
44 |
45 |
46 | class SparseGroupNorm32(SparseGroupNorm):
47 | """
48 | A GroupNorm layer that converts to float32 before the forward pass.
49 | """
50 | def forward(self, x: SparseTensor) -> SparseTensor:
51 | return super().forward(x.float()).type(x.dtype)
52 |
53 | class SparseLayerNorm32(SparseLayerNorm):
54 | """
55 | A LayerNorm layer that converts to float32 before the forward pass.
56 | """
57 | def forward(self, x: SparseTensor) -> SparseTensor:
58 | return super().forward(x.float()).type(x.dtype)
59 |
--------------------------------------------------------------------------------
/trellis/modules/sparse/spatial.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | from . import SparseTensor
5 |
6 | __all__ = [
7 | 'SparseDownsample',
8 | 'SparseUpsample',
9 | 'SparseSubdivide'
10 | ]
11 |
12 |
13 | class SparseDownsample(nn.Module):
14 | """
15 | Downsample a sparse tensor by a factor of `factor`.
16 | Implemented as average pooling.
17 | """
18 | def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
19 | super(SparseDownsample, self).__init__()
20 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
21 |
22 | def forward(self, input: SparseTensor) -> SparseTensor:
23 | DIM = input.coords.shape[-1] - 1
24 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
25 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
26 |
27 | coord = list(input.coords.unbind(dim=-1))
28 | for i, f in enumerate(factor):
29 | coord[i+1] = coord[i+1] // f
30 |
31 | MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
32 | OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
33 | code = sum([c * o for c, o in zip(coord, OFFSET)])
34 | code, idx = code.unique(return_inverse=True)
35 |
36 | new_feats = torch.scatter_reduce(
37 | torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
38 | dim=0,
39 | index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
40 | src=input.feats,
41 | reduce='mean'
42 | )
43 | new_coords = torch.stack(
44 | [code // OFFSET[0]] +
45 | [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
46 | dim=-1
47 | )
48 | out = SparseTensor(new_feats, new_coords, input.shape,)
49 | out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
50 | out._spatial_cache = input._spatial_cache
51 |
52 | out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
53 | out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
54 | out.register_spatial_cache(f'upsample_{factor}_idx', idx)
55 |
56 | return out
57 |
58 |
59 | class SparseUpsample(nn.Module):
60 | """
61 | Upsample a sparse tensor by a factor of `factor`.
62 | Implemented as nearest neighbor interpolation.
63 | """
64 | def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
65 | super(SparseUpsample, self).__init__()
66 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
67 |
68 | def forward(self, input: SparseTensor) -> SparseTensor:
69 | DIM = input.coords.shape[-1] - 1
70 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
71 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
72 |
73 | new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
74 | new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
75 | idx = input.get_spatial_cache(f'upsample_{factor}_idx')
76 | if any([x is None for x in [new_coords, new_layout, idx]]):
77 | raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
78 | new_feats = input.feats[idx]
79 | out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
80 | out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
81 | out._spatial_cache = input._spatial_cache
82 | return out
83 |
84 | class SparseSubdivide(nn.Module):
85 | """
86 |     Upsample a sparse tensor by subdividing each voxel into 2^DIM children.
87 |     Features are copied to the new voxels (nearest-neighbor style).
88 | """
89 | def __init__(self):
90 | super(SparseSubdivide, self).__init__()
91 |
92 | def forward(self, input: SparseTensor) -> SparseTensor:
93 | DIM = input.coords.shape[-1] - 1
94 | # upsample scale=2^DIM
95 | n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
96 | n_coords = torch.nonzero(n_cube)
97 | n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
98 | factor = n_coords.shape[0]
99 | assert factor == 2 ** DIM
100 | # print(n_coords.shape)
101 | new_coords = input.coords.clone()
102 | new_coords[:, 1:] *= 2
103 | new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
104 |
105 | new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
106 | out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
107 | out._scale = input._scale * 2
108 | out._spatial_cache = input._spatial_cache
109 | return out
110 |
111 |
--------------------------------------------------------------------------------
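
A standalone sketch of the pooling step inside `SparseDownsample` above: voxels that fall into the same coarse cell get the same linear code, and their features are averaged with `scatter_reduce` using the inverse indices from `torch.unique`. Shown here with `include_self=False` so the printed values are the plain per-cell means (the module above relies on `torch.scatter_reduce` defaults with a zero-initialized target):

    import torch

    coords = torch.tensor([[0, 0], [0, 1], [0, 2], [0, 3]])   # [batch, x], 1-D spatial case
    feats = torch.tensor([[1.0], [3.0], [5.0], [7.0]])
    factor = 2

    coarse = coords.clone()
    coarse[:, 1] //= factor                                    # coarse cell index
    code = coarse[:, 0] * (coarse[:, 1].max() + 1) + coarse[:, 1]
    code, idx = code.unique(return_inverse=True)

    pooled = torch.zeros(code.shape[0], feats.shape[1]).scatter_reduce(
        0, idx.unsqueeze(1).expand(-1, feats.shape[1]), feats,
        reduce="mean", include_self=False,
    )
    print(pooled)   # tensor([[2.], [6.]]) -- means of {1, 3} and {5, 7}
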
/trellis/modules/sparse/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import *
2 | from .modulated import *
--------------------------------------------------------------------------------
/trellis/modules/sparse/transformer/blocks.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | from ..basic import SparseTensor
5 | from ..linear import SparseLinear
6 | from ..nonlinearity import SparseGELU
7 | from ..attention import SparseMultiHeadAttention, SerializeMode
8 | from ...norm import LayerNorm32
9 |
10 |
11 | class SparseFeedForwardNet(nn.Module):
12 | def __init__(self, channels: int, mlp_ratio: float = 4.0):
13 | super().__init__()
14 | self.mlp = nn.Sequential(
15 | SparseLinear(channels, int(channels * mlp_ratio)),
16 | SparseGELU(approximate="tanh"),
17 | SparseLinear(int(channels * mlp_ratio), channels),
18 | )
19 |
20 | def forward(self, x: SparseTensor) -> SparseTensor:
21 | return self.mlp(x)
22 |
23 |
24 | class SparseTransformerBlock(nn.Module):
25 | """
26 | Sparse Transformer block (MSA + FFN).
27 | """
28 | def __init__(
29 | self,
30 | channels: int,
31 | num_heads: int,
32 | mlp_ratio: float = 4.0,
33 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
34 | window_size: Optional[int] = None,
35 | shift_sequence: Optional[int] = None,
36 | shift_window: Optional[Tuple[int, int, int]] = None,
37 | serialize_mode: Optional[SerializeMode] = None,
38 | use_checkpoint: bool = False,
39 | use_rope: bool = False,
40 | qk_rms_norm: bool = False,
41 | qkv_bias: bool = True,
42 | ln_affine: bool = False,
43 | ):
44 | super().__init__()
45 | self.use_checkpoint = use_checkpoint
46 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
47 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
48 | self.attn = SparseMultiHeadAttention(
49 | channels,
50 | num_heads=num_heads,
51 | attn_mode=attn_mode,
52 | window_size=window_size,
53 | shift_sequence=shift_sequence,
54 | shift_window=shift_window,
55 | serialize_mode=serialize_mode,
56 | qkv_bias=qkv_bias,
57 | use_rope=use_rope,
58 | qk_rms_norm=qk_rms_norm,
59 | )
60 | self.mlp = SparseFeedForwardNet(
61 | channels,
62 | mlp_ratio=mlp_ratio,
63 | )
64 |
65 | def _forward(self, x: SparseTensor) -> SparseTensor:
66 | h = x.replace(self.norm1(x.feats))
67 | h = self.attn(h)
68 | x = x + h
69 | h = x.replace(self.norm2(x.feats))
70 | h = self.mlp(h)
71 | x = x + h
72 | return x
73 |
74 | def forward(self, x: SparseTensor) -> SparseTensor:
75 | if self.use_checkpoint:
76 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
77 | else:
78 | return self._forward(x)
79 |
80 |
81 | class SparseTransformerCrossBlock(nn.Module):
82 | """
83 | Sparse Transformer cross-attention block (MSA + MCA + FFN).
84 | """
85 | def __init__(
86 | self,
87 | channels: int,
88 | ctx_channels: int,
89 | num_heads: int,
90 | mlp_ratio: float = 4.0,
91 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
92 | window_size: Optional[int] = None,
93 | shift_sequence: Optional[int] = None,
94 | shift_window: Optional[Tuple[int, int, int]] = None,
95 | serialize_mode: Optional[SerializeMode] = None,
96 | use_checkpoint: bool = False,
97 | use_rope: bool = False,
98 | qk_rms_norm: bool = False,
99 | qk_rms_norm_cross: bool = False,
100 | qkv_bias: bool = True,
101 | ln_affine: bool = False,
102 | ):
103 | super().__init__()
104 | self.use_checkpoint = use_checkpoint
105 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
106 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
107 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
108 | self.self_attn = SparseMultiHeadAttention(
109 | channels,
110 | num_heads=num_heads,
111 | type="self",
112 | attn_mode=attn_mode,
113 | window_size=window_size,
114 | shift_sequence=shift_sequence,
115 | shift_window=shift_window,
116 | serialize_mode=serialize_mode,
117 | qkv_bias=qkv_bias,
118 | use_rope=use_rope,
119 | qk_rms_norm=qk_rms_norm,
120 | )
121 | self.cross_attn = SparseMultiHeadAttention(
122 | channels,
123 | ctx_channels=ctx_channels,
124 | num_heads=num_heads,
125 | type="cross",
126 | attn_mode="full",
127 | qkv_bias=qkv_bias,
128 | qk_rms_norm=qk_rms_norm_cross,
129 | )
130 | self.mlp = SparseFeedForwardNet(
131 | channels,
132 | mlp_ratio=mlp_ratio,
133 | )
134 |
135 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor):
136 | h = x.replace(self.norm1(x.feats))
137 | h = self.self_attn(h)
138 | x = x + h
139 | h = x.replace(self.norm2(x.feats))
140 | h = self.cross_attn(h, context)
141 | x = x + h
142 | h = x.replace(self.norm3(x.feats))
143 | h = self.mlp(h)
144 | x = x + h
145 | return x
146 |
147 | def forward(self, x: SparseTensor, context: torch.Tensor):
148 | if self.use_checkpoint:
149 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
150 | else:
151 | return self._forward(x, context)
152 |
--------------------------------------------------------------------------------
/trellis/modules/sparse/transformer/modulated.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | from ..basic import SparseTensor
5 | from ..attention import SparseMultiHeadAttention, SerializeMode
6 | from ...norm import LayerNorm32
7 | from .blocks import SparseFeedForwardNet
8 |
9 |
10 | class ModulatedSparseTransformerBlock(nn.Module):
11 | """
12 | Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
13 | """
14 | def __init__(
15 | self,
16 | channels: int,
17 | num_heads: int,
18 | mlp_ratio: float = 4.0,
19 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
20 | window_size: Optional[int] = None,
21 | shift_sequence: Optional[int] = None,
22 | shift_window: Optional[Tuple[int, int, int]] = None,
23 | serialize_mode: Optional[SerializeMode] = None,
24 | use_checkpoint: bool = False,
25 | use_rope: bool = False,
26 | qk_rms_norm: bool = False,
27 | qkv_bias: bool = True,
28 | share_mod: bool = False,
29 | ):
30 | super().__init__()
31 | self.use_checkpoint = use_checkpoint
32 | self.share_mod = share_mod
33 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
34 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
35 | self.attn = SparseMultiHeadAttention(
36 | channels,
37 | num_heads=num_heads,
38 | attn_mode=attn_mode,
39 | window_size=window_size,
40 | shift_sequence=shift_sequence,
41 | shift_window=shift_window,
42 | serialize_mode=serialize_mode,
43 | qkv_bias=qkv_bias,
44 | use_rope=use_rope,
45 | qk_rms_norm=qk_rms_norm,
46 | )
47 | self.mlp = SparseFeedForwardNet(
48 | channels,
49 | mlp_ratio=mlp_ratio,
50 | )
51 | if not share_mod:
52 | self.adaLN_modulation = nn.Sequential(
53 | nn.SiLU(),
54 | nn.Linear(channels, 6 * channels, bias=True)
55 | )
56 |
57 | def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
58 | if self.share_mod:
59 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
60 | else:
61 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
62 | h = x.replace(self.norm1(x.feats))
63 | h = h * (1 + scale_msa) + shift_msa
64 | h = self.attn(h)
65 | h = h * gate_msa
66 | x = x + h
67 | h = x.replace(self.norm2(x.feats))
68 | h = h * (1 + scale_mlp) + shift_mlp
69 | h = self.mlp(h)
70 | h = h * gate_mlp
71 | x = x + h
72 | return x
73 |
74 | def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
75 | if self.use_checkpoint:
76 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
77 | else:
78 | return self._forward(x, mod)
79 |
80 |
81 | class ModulatedSparseTransformerCrossBlock(nn.Module):
82 | """
83 | Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
84 | """
85 | def __init__(
86 | self,
87 | channels: int,
88 | ctx_channels: int,
89 | num_heads: int,
90 | mlp_ratio: float = 4.0,
91 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
92 | window_size: Optional[int] = None,
93 | shift_sequence: Optional[int] = None,
94 | shift_window: Optional[Tuple[int, int, int]] = None,
95 | serialize_mode: Optional[SerializeMode] = None,
96 | use_checkpoint: bool = False,
97 | use_rope: bool = False,
98 | qk_rms_norm: bool = False,
99 | qk_rms_norm_cross: bool = False,
100 | qkv_bias: bool = True,
101 | share_mod: bool = False,
102 |
103 | ):
104 | super().__init__()
105 | self.use_checkpoint = use_checkpoint
106 | self.share_mod = share_mod
107 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
108 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
109 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
110 | self.self_attn = SparseMultiHeadAttention(
111 | channels,
112 | num_heads=num_heads,
113 | type="self",
114 | attn_mode=attn_mode,
115 | window_size=window_size,
116 | shift_sequence=shift_sequence,
117 | shift_window=shift_window,
118 | serialize_mode=serialize_mode,
119 | qkv_bias=qkv_bias,
120 | use_rope=use_rope,
121 | qk_rms_norm=qk_rms_norm,
122 | )
123 | self.cross_attn = SparseMultiHeadAttention(
124 | channels,
125 | ctx_channels=ctx_channels,
126 | num_heads=num_heads,
127 | type="cross",
128 | attn_mode="full",
129 | qkv_bias=qkv_bias,
130 | qk_rms_norm=qk_rms_norm_cross,
131 | )
132 | self.mlp = SparseFeedForwardNet(
133 | channels,
134 | mlp_ratio=mlp_ratio,
135 | )
136 | if not share_mod:
137 | self.adaLN_modulation = nn.Sequential(
138 | nn.SiLU(),
139 | nn.Linear(channels, 6 * channels, bias=True)
140 | )
141 |
142 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
143 | if self.share_mod:
144 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
145 | else:
146 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
147 | h = x.replace(self.norm1(x.feats))
148 | h = h * (1 + scale_msa) + shift_msa
149 | h = self.self_attn(h)
150 | h = h * gate_msa
151 | x = x + h
152 | h = x.replace(self.norm2(x.feats))
153 | h = self.cross_attn(h, context)
154 | x = x + h
155 | h = x.replace(self.norm3(x.feats))
156 | h = h * (1 + scale_mlp) + shift_mlp
157 | h = self.mlp(h)
158 | h = h * gate_mlp
159 | x = x + h
160 | return x
161 |
162 | def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
163 | if self.use_checkpoint:
164 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
165 | else:
166 | return self._forward(x, mod, context)
167 |
--------------------------------------------------------------------------------
/trellis/modules/spatial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
5 | """
6 | 3D pixel shuffle.
7 | """
8 | B, C, H, W, D = x.shape
9 | C_ = C // scale_factor**3
10 | x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
11 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
12 | x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
13 | return x
14 |
15 |
16 | def patchify(x: torch.Tensor, patch_size: int):
17 | """
18 | Patchify a tensor.
19 |
20 | Args:
21 | x (torch.Tensor): (N, C, *spatial) tensor
22 | patch_size (int): Patch size
23 | """
24 | DIM = x.dim() - 2
25 | for d in range(2, DIM + 2):
26 | assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
27 |
28 | x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
29 | x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
30 | x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
31 | return x
32 |
33 |
34 | def unpatchify(x: torch.Tensor, patch_size: int):
35 | """
36 | Unpatchify a tensor.
37 |
38 | Args:
39 | x (torch.Tensor): (N, C, *spatial) tensor
40 | patch_size (int): Patch size
41 | """
42 | DIM = x.dim() - 2
43 | assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
44 |
45 | x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
46 | x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
47 | x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
48 | return x
49 |
--------------------------------------------------------------------------------
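
A quick round-trip check of `patchify`/`unpatchify` from the module above: each `patch_size`-cube is folded into channels and then restored. The import assumes the trellis package (and one of its attention/sparse backends) imports cleanly in your environment:

    import torch
    from trellis.modules.spatial import patchify, unpatchify

    x = torch.randn(2, 4, 8, 8, 8)       # (N, C, *spatial), every spatial dim divisible by 2
    p = patchify(x, patch_size=2)
    print(p.shape)                        # torch.Size([2, 32, 4, 4, 4]) -- C * 2**3 channels
    y = unpatchify(p, patch_size=2)
    print(torch.allclose(x, y))           # True -- the operations are exact inverses
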
/trellis/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import *
2 | from .modulated import *
--------------------------------------------------------------------------------
/trellis/modules/transformer/blocks.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | from ..attention import MultiHeadAttention
5 | from ..norm import LayerNorm32
6 |
7 |
8 | class AbsolutePositionEmbedder(nn.Module):
9 | """
10 | Embeds spatial positions into vector representations.
11 | """
12 | def __init__(self, channels: int, in_channels: int = 3):
13 | super().__init__()
14 | self.channels = channels
15 | self.in_channels = in_channels
16 | self.freq_dim = channels // in_channels // 2
17 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
18 | self.freqs = 1.0 / (10000 ** self.freqs)
19 |
20 | def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
21 | """
22 | Create sinusoidal position embeddings.
23 |
24 | Args:
25 | x: a 1-D Tensor of N indices
26 |
27 | Returns:
28 | an (N, D) Tensor of positional embeddings.
29 | """
30 | self.freqs = self.freqs.to(x.device)
31 | out = torch.outer(x, self.freqs)
32 | out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
33 | return out
34 |
35 | def forward(self, x: torch.Tensor) -> torch.Tensor:
36 | """
37 | Args:
38 | x (torch.Tensor): (N, D) tensor of spatial positions
39 | """
40 | N, D = x.shape
41 | assert D == self.in_channels, "Input dimension must match number of input channels"
42 | embed = self._sin_cos_embedding(x.reshape(-1))
43 | embed = embed.reshape(N, -1)
44 | if embed.shape[1] < self.channels:
45 | embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
46 | return embed
47 |
48 |
49 | class FeedForwardNet(nn.Module):
50 | def __init__(self, channels: int, mlp_ratio: float = 4.0):
51 | super().__init__()
52 | self.mlp = nn.Sequential(
53 | nn.Linear(channels, int(channels * mlp_ratio)),
54 | nn.GELU(approximate="tanh"),
55 | nn.Linear(int(channels * mlp_ratio), channels),
56 | )
57 |
58 | def forward(self, x: torch.Tensor) -> torch.Tensor:
59 | return self.mlp(x)
60 |
61 |
62 | class TransformerBlock(nn.Module):
63 | """
64 | Transformer block (MSA + FFN).
65 | """
66 | def __init__(
67 | self,
68 | channels: int,
69 | num_heads: int,
70 | mlp_ratio: float = 4.0,
71 | attn_mode: Literal["full", "windowed"] = "full",
72 | window_size: Optional[int] = None,
73 | shift_window: Optional[int] = None,
74 | use_checkpoint: bool = False,
75 | use_rope: bool = False,
76 | qk_rms_norm: bool = False,
77 | qkv_bias: bool = True,
78 | ln_affine: bool = False,
79 | ):
80 | super().__init__()
81 | self.use_checkpoint = use_checkpoint
82 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
83 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
84 | self.attn = MultiHeadAttention(
85 | channels,
86 | num_heads=num_heads,
87 | attn_mode=attn_mode,
88 | window_size=window_size,
89 | shift_window=shift_window,
90 | qkv_bias=qkv_bias,
91 | use_rope=use_rope,
92 | qk_rms_norm=qk_rms_norm,
93 | )
94 | self.mlp = FeedForwardNet(
95 | channels,
96 | mlp_ratio=mlp_ratio,
97 | )
98 |
99 | def _forward(self, x: torch.Tensor) -> torch.Tensor:
100 | h = self.norm1(x)
101 | h = self.attn(h)
102 | x = x + h
103 | h = self.norm2(x)
104 | h = self.mlp(h)
105 | x = x + h
106 | return x
107 |
108 | def forward(self, x: torch.Tensor) -> torch.Tensor:
109 | if self.use_checkpoint:
110 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
111 | else:
112 | return self._forward(x)
113 |
114 |
115 | class TransformerCrossBlock(nn.Module):
116 | """
117 | Transformer cross-attention block (MSA + MCA + FFN).
118 | """
119 | def __init__(
120 | self,
121 | channels: int,
122 | ctx_channels: int,
123 | num_heads: int,
124 | mlp_ratio: float = 4.0,
125 | attn_mode: Literal["full", "windowed"] = "full",
126 | window_size: Optional[int] = None,
127 | shift_window: Optional[Tuple[int, int, int]] = None,
128 | use_checkpoint: bool = False,
129 | use_rope: bool = False,
130 | qk_rms_norm: bool = False,
131 | qk_rms_norm_cross: bool = False,
132 | qkv_bias: bool = True,
133 | ln_affine: bool = False,
134 | ):
135 | super().__init__()
136 | self.use_checkpoint = use_checkpoint
137 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
138 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
139 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
140 | self.self_attn = MultiHeadAttention(
141 | channels,
142 | num_heads=num_heads,
143 | type="self",
144 | attn_mode=attn_mode,
145 | window_size=window_size,
146 | shift_window=shift_window,
147 | qkv_bias=qkv_bias,
148 | use_rope=use_rope,
149 | qk_rms_norm=qk_rms_norm,
150 | )
151 | self.cross_attn = MultiHeadAttention(
152 | channels,
153 | ctx_channels=ctx_channels,
154 | num_heads=num_heads,
155 | type="cross",
156 | attn_mode="full",
157 | qkv_bias=qkv_bias,
158 | qk_rms_norm=qk_rms_norm_cross,
159 | )
160 | self.mlp = FeedForwardNet(
161 | channels,
162 | mlp_ratio=mlp_ratio,
163 | )
164 |
165 | def _forward(self, x: torch.Tensor, context: torch.Tensor):
166 | h = self.norm1(x)
167 | h = self.self_attn(h)
168 | x = x + h
169 | h = self.norm2(x)
170 | h = self.cross_attn(h, context)
171 | x = x + h
172 | h = self.norm3(x)
173 | h = self.mlp(h)
174 | x = x + h
175 | return x
176 |
177 | def forward(self, x: torch.Tensor, context: torch.Tensor):
178 | if self.use_checkpoint:
179 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
180 | else:
181 | return self._forward(x, context)
182 |
--------------------------------------------------------------------------------
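
A standalone sketch of the sinusoidal embedding computed by `AbsolutePositionEmbedder._sin_cos_embedding` above: each scalar position index is mapped to `freq_dim` sine and `freq_dim` cosine values, and the per-axis embeddings are concatenated per point:

    import torch

    channels, in_channels = 12, 3
    freq_dim = channels // in_channels // 2            # frequencies per input axis
    freqs = 1.0 / (10000 ** (torch.arange(freq_dim, dtype=torch.float32) / freq_dim))

    positions = torch.tensor([[1.0, 2.0, 3.0]])        # (N, in_channels) spatial positions
    out = torch.outer(positions.reshape(-1), freqs)    # (N * in_channels, freq_dim)
    embed = torch.cat([torch.sin(out), torch.cos(out)], dim=-1).reshape(1, -1)
    print(embed.shape)                                  # torch.Size([1, 12]) == (N, channels)
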
/trellis/modules/transformer/modulated.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | from ..attention import MultiHeadAttention
5 | from ..norm import LayerNorm32
6 | from .blocks import FeedForwardNet
7 |
8 |
9 | class ModulatedTransformerBlock(nn.Module):
10 | """
11 | Transformer block (MSA + FFN) with adaptive layer norm conditioning.
12 | """
13 | def __init__(
14 | self,
15 | channels: int,
16 | num_heads: int,
17 | mlp_ratio: float = 4.0,
18 | attn_mode: Literal["full", "windowed"] = "full",
19 | window_size: Optional[int] = None,
20 | shift_window: Optional[Tuple[int, int, int]] = None,
21 | use_checkpoint: bool = False,
22 | use_rope: bool = False,
23 | qk_rms_norm: bool = False,
24 | qkv_bias: bool = True,
25 | share_mod: bool = False,
26 | ):
27 | super().__init__()
28 | self.use_checkpoint = use_checkpoint
29 | self.share_mod = share_mod
30 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
31 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
32 | self.attn = MultiHeadAttention(
33 | channels,
34 | num_heads=num_heads,
35 | attn_mode=attn_mode,
36 | window_size=window_size,
37 | shift_window=shift_window,
38 | qkv_bias=qkv_bias,
39 | use_rope=use_rope,
40 | qk_rms_norm=qk_rms_norm,
41 | )
42 | self.mlp = FeedForwardNet(
43 | channels,
44 | mlp_ratio=mlp_ratio,
45 | )
46 | if not share_mod:
47 | self.adaLN_modulation = nn.Sequential(
48 | nn.SiLU(),
49 | nn.Linear(channels, 6 * channels, bias=True)
50 | )
51 |
52 | def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
53 | if self.share_mod:
54 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
55 | else:
56 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
57 | h = self.norm1(x)
58 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
59 | h = self.attn(h)
60 | h = h * gate_msa.unsqueeze(1)
61 | x = x + h
62 | h = self.norm2(x)
63 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
64 | h = self.mlp(h)
65 | h = h * gate_mlp.unsqueeze(1)
66 | x = x + h
67 | return x
68 |
69 | def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
70 | if self.use_checkpoint:
71 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
72 | else:
73 | return self._forward(x, mod)
74 |
75 |
76 | class ModulatedTransformerCrossBlock(nn.Module):
77 | """
78 | Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
79 | """
80 | def __init__(
81 | self,
82 | channels: int,
83 | ctx_channels: int,
84 | num_heads: int,
85 | mlp_ratio: float = 4.0,
86 | attn_mode: Literal["full", "windowed"] = "full",
87 | window_size: Optional[int] = None,
88 | shift_window: Optional[Tuple[int, int, int]] = None,
89 | use_checkpoint: bool = False,
90 | use_rope: bool = False,
91 | qk_rms_norm: bool = False,
92 | qk_rms_norm_cross: bool = False,
93 | qkv_bias: bool = True,
94 | share_mod: bool = False,
95 | ):
96 | super().__init__()
97 | self.use_checkpoint = use_checkpoint
98 | self.share_mod = share_mod
99 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
100 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
101 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
102 | self.self_attn = MultiHeadAttention(
103 | channels,
104 | num_heads=num_heads,
105 | type="self",
106 | attn_mode=attn_mode,
107 | window_size=window_size,
108 | shift_window=shift_window,
109 | qkv_bias=qkv_bias,
110 | use_rope=use_rope,
111 | qk_rms_norm=qk_rms_norm,
112 | )
113 | self.cross_attn = MultiHeadAttention(
114 | channels,
115 | ctx_channels=ctx_channels,
116 | num_heads=num_heads,
117 | type="cross",
118 | attn_mode="full",
119 | qkv_bias=qkv_bias,
120 | qk_rms_norm=qk_rms_norm_cross,
121 | )
122 | self.mlp = FeedForwardNet(
123 | channels,
124 | mlp_ratio=mlp_ratio,
125 | )
126 | if not share_mod:
127 | self.adaLN_modulation = nn.Sequential(
128 | nn.SiLU(),
129 | nn.Linear(channels, 6 * channels, bias=True)
130 | )
131 |
132 | def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
133 | if self.share_mod:
134 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
135 | else:
136 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
137 | h = self.norm1(x)
138 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
139 | h = self.self_attn(h)
140 | h = h * gate_msa.unsqueeze(1)
141 | x = x + h
142 | h = self.norm2(x)
143 | h = self.cross_attn(h, context)
144 | x = x + h
145 | h = self.norm3(x)
146 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
147 | h = self.mlp(h)
148 | h = h * gate_mlp.unsqueeze(1)
149 | x = x + h
150 | return x
151 |
152 | def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
153 | if self.use_checkpoint:
154 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
155 | else:
156 | return self._forward(x, mod, context)
157 |
--------------------------------------------------------------------------------
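
A standalone sketch of the adaLN conditioning used by the modulated blocks above: a per-sample vector is projected to `6 * channels` values and split into shift/scale/gate triples for the attention and MLP branches. The attention call itself is omitted:

    import torch
    import torch.nn as nn

    channels, batch, tokens = 8, 2, 5
    adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels))

    x = torch.randn(batch, tokens, channels)            # token features
    cond = torch.randn(batch, channels)                  # per-sample conditioning ("mod")

    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
        adaLN_modulation(cond).chunk(6, dim=1)            # each (batch, channels)

    h = nn.LayerNorm(channels, elementwise_affine=False)(x)
    h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)   # modulate before attention
    # ... attention would run here ...
    h = h * gate_msa.unsqueeze(1)                                    # gate the residual branch
    print(h.shape)                                                    # torch.Size([2, 5, 8])
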
/trellis/modules/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from ..modules import sparse as sp
3 |
4 | FP16_MODULES = (
5 | nn.Conv1d,
6 | nn.Conv2d,
7 | nn.Conv3d,
8 | nn.ConvTranspose1d,
9 | nn.ConvTranspose2d,
10 | nn.ConvTranspose3d,
11 | nn.Linear,
12 | sp.SparseConv3d,
13 | sp.SparseInverseConv3d,
14 | sp.SparseLinear,
15 | )
16 |
17 | def convert_module_to_f16(l):
18 | """
19 | Convert primitive modules to float16.
20 | """
21 | if isinstance(l, FP16_MODULES):
22 | for p in l.parameters():
23 | p.data = p.data.half()
24 |
25 |
26 | def convert_module_to_f32(l):
27 | """
28 | Convert primitive modules to float32, undoing convert_module_to_f16().
29 | """
30 | if isinstance(l, FP16_MODULES):
31 | for p in l.parameters():
32 | p.data = p.data.float()
33 |
34 |
35 | def zero_module(module):
36 | """
37 | Zero out the parameters of a module and return it.
38 | """
39 | for p in module.parameters():
40 | p.detach().zero_()
41 | return module
42 |
43 |
44 | def scale_module(module, scale):
45 | """
46 | Scale the parameters of a module and return it.
47 | """
48 | for p in module.parameters():
49 | p.detach().mul_(scale)
50 | return module
51 |
52 |
53 | def modulate(x, shift, scale):
54 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
55 |
--------------------------------------------------------------------------------
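
A usage sketch for the helpers above on a plain dense layer (the sparse variants behave the same way parameter-wise). These helpers are usually applied via `Module.apply`, which visits the module and all of its children; the import assumes the trellis package and a sparse backend are installed:

    import torch.nn as nn
    from trellis.modules.utils import convert_module_to_f16, zero_module

    layer = nn.Linear(4, 4)
    layer.apply(convert_module_to_f16)           # casts parameters of FP16_MODULES to half
    print(layer.weight.dtype)                    # torch.float16

    zeroed = zero_module(nn.Linear(4, 4))
    print(zeroed.weight.abs().sum().item())      # 0.0
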
/trellis/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from . import samplers
2 | from .trellis_image_to_3d import TrellisImageTo3DPipeline
3 |
4 |
5 | def from_pretrained(path: str):
6 | """
7 | Load a pipeline from a model folder or a Hugging Face model hub.
8 |
9 | Args:
10 |         path: The path to the model. Can be either a local path or a Hugging Face model name.
11 | """
12 | import os
13 | import json
14 | is_local = os.path.exists(f"{path}/pipeline.json")
15 |
16 | if is_local:
17 | config_file = f"{path}/pipeline.json"
18 | else:
19 | from huggingface_hub import hf_hub_download
20 | config_file = hf_hub_download(path, "pipeline.json")
21 |
22 | with open(config_file, 'r') as f:
23 | config = json.load(f)
24 | return globals()[config['name']].from_pretrained(path)
25 |
--------------------------------------------------------------------------------
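
A hedged usage sketch for the loader above. The argument is illustrative: any local folder containing a `pipeline.json` works, and a Hugging Face repo id is resolved through `hf_hub_download` as shown in the code:

    from trellis.pipelines import from_pretrained

    pipeline = from_pretrained("path/to/model_folder")   # or a Hugging Face repo id providing pipeline.json
    pipeline.cuda()                                        # provided by the Pipeline base class below
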
/trellis/pipelines/base.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn as nn
4 | from .. import models
5 |
6 |
7 | class Pipeline:
8 | """
9 | A base class for pipelines.
10 | """
11 | def __init__(
12 | self,
13 | models: dict[str, nn.Module] = None,
14 | ):
15 | if models is None:
16 | return
17 | self.models = models
18 | for model in self.models.values():
19 | model.eval()
20 |
21 | @staticmethod
22 | def from_pretrained(path: str) -> "Pipeline":
23 | """
24 | Load a pretrained model.
25 | """
26 | import os
27 | import json
28 | is_local = os.path.exists(f"{path}/pipeline.json")
29 |
30 | if is_local:
31 | config_file = f"{path}/pipeline.json"
32 | with open(config_file, 'r') as f:
33 | #print("Loading pipeline.json:")
34 | args = json.load(f)['args']
35 | #print("Models in pipeline.json:")
36 | '''for k, v in args['models'].items():
37 | print(f" {k}: {v}")'''
38 |
39 | # For local paths, construct the models dict differently
40 | _models = {
41 | k: models.from_pretrained(os.path.join(path, v))
42 | for k, v in args['models'].items()
43 | }
44 | else:
45 | from huggingface_hub import hf_hub_download
46 | config_file = hf_hub_download(path, "pipeline.json")
47 | with open(config_file, 'r') as f:
48 | args = json.load(f)['args']
49 |
50 | _models = {
51 | k: models.from_pretrained(os.path.join(path, v))
52 | for k, v in args['models'].items()
53 | }
54 |
55 | new_pipeline = Pipeline(_models)
56 | new_pipeline._pretrained_args = args
57 | return new_pipeline
58 |
59 | @property
60 | def device(self) -> torch.device:
61 | for model in self.models.values():
62 | if hasattr(model, 'device'):
63 | return model.device
64 | for model in self.models.values():
65 | if hasattr(model, 'parameters'):
66 | return next(model.parameters()).device
67 | raise RuntimeError("No device found.")
68 |
69 | def to(self, device: torch.device) -> None:
70 | for model in self.models.values():
71 | model.to(device)
72 |
73 | def cuda(self) -> None:
74 | self.to(torch.device("cuda"))
75 |
76 | def cpu(self) -> None:
77 | self.to(torch.device("cpu"))
--------------------------------------------------------------------------------
/trellis/pipelines/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Sampler
2 | from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
--------------------------------------------------------------------------------
/trellis/pipelines/samplers/base.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | from abc import ABC, abstractmethod
3 |
4 |
5 | class Sampler(ABC):
6 | """
7 | A base class for samplers.
8 | """
9 |
10 | @abstractmethod
11 | def sample(
12 | self,
13 | model,
14 | **kwargs
15 | ):
16 | """
17 | Sample from a model.
18 | """
19 | pass
20 |
--------------------------------------------------------------------------------
/trellis/pipelines/samplers/classifier_free_guidance_mixin.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 |
4 | class ClassifierFreeGuidanceSamplerMixin:
5 | """
6 | A mixin class for samplers that apply classifier-free guidance.
7 | """
8 |
9 | def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
10 | pred = super()._inference_model(model, x_t, t, cond, **kwargs)
11 | neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
12 | return (1 + cfg_strength) * pred - cfg_strength * neg_pred
13 |
--------------------------------------------------------------------------------
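
A standalone illustration of the classifier-free guidance combination used above: the guided prediction extrapolates from the unconditional output toward the conditional one by `cfg_strength`:

    import torch

    cfg_strength = 3.0
    pred = torch.tensor([1.0])        # conditional model output
    neg_pred = torch.tensor([0.2])    # unconditional / negative-prompt output

    guided = (1 + cfg_strength) * pred - cfg_strength * neg_pred
    print(guided)                     # tensor([3.4000]) -- 4 * 1.0 - 3 * 0.2
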
/trellis/pipelines/samplers/flow_euler.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import numpy as np
4 | from tqdm import tqdm
5 | from easydict import EasyDict as edict
6 | from .base import Sampler
7 | from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
8 | from .guidance_interval_mixin import GuidanceIntervalSamplerMixin
9 |
10 |
11 | class FlowEulerSampler(Sampler):
12 | """
13 | Generate samples from a flow-matching model using Euler sampling.
14 |
15 | Args:
16 | sigma_min: The minimum scale of noise in flow.
17 | """
18 | def __init__(
19 | self,
20 | sigma_min: float,
21 | ):
22 | self.sigma_min = sigma_min
23 |
24 | def _eps_to_xstart(self, x_t, t, eps):
25 | assert x_t.shape == eps.shape
26 | return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)
27 |
28 | def _xstart_to_eps(self, x_t, t, x_0):
29 | assert x_t.shape == x_0.shape
30 | return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
31 |
32 | def _v_to_xstart_eps(self, x_t, t, v):
33 | assert x_t.shape == v.shape
34 | eps = (1 - t) * v + x_t
35 | x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
36 | return x_0, eps
37 |
38 | def _inference_model(self, model, x_t, t, cond=None, **kwargs):
39 | t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
40 | return model(x_t, t, cond, **kwargs)
41 |
42 | def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
43 | pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
44 | pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
45 | return pred_x_0, pred_eps, pred_v
46 |
47 | @torch.no_grad()
48 | def sample_once(
49 | self,
50 | model,
51 | x_t,
52 | t: float,
53 | t_prev: float,
54 | cond: Optional[Any] = None,
55 | **kwargs
56 | ):
57 | """
58 | Sample x_{t-1} from the model using Euler method.
59 |
60 | Args:
61 | model: The model to sample from.
62 | x_t: The [N x C x ...] tensor of noisy inputs at time t.
63 | t: The current timestep.
64 | t_prev: The previous timestep.
65 | cond: conditional information.
66 | **kwargs: Additional arguments for model inference.
67 |
68 | Returns:
69 | a dict containing the following
70 | - 'pred_x_prev': x_{t-1}.
71 | - 'pred_x_0': a prediction of x_0.
72 | """
73 | pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
74 | pred_x_prev = x_t - (t - t_prev) * pred_v
75 | return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
76 |
77 | @torch.no_grad()
78 | def sample(
79 | self,
80 | model,
81 | noise,
82 | cond: Optional[Any] = None,
83 | steps: int = 50,
84 | rescale_t: float = 1.0,
85 | verbose: bool = True,
86 | **kwargs
87 | ):
88 | """
89 | Generate samples from the model using Euler method.
90 |
91 | Args:
92 | model: The model to sample from.
93 | noise: The initial noise tensor.
94 | cond: conditional information.
95 | steps: The number of steps to sample.
96 | rescale_t: The rescale factor for t.
97 | verbose: If True, show a progress bar.
98 | **kwargs: Additional arguments for model_inference.
99 |
100 | Returns:
101 | a dict containing the following
102 | - 'samples': the model samples.
103 | - 'pred_x_t': a list of prediction of x_t.
104 | - 'pred_x_0': a list of prediction of x_0.
105 | """
106 | sample = noise
107 | t_seq = np.linspace(1, 0, steps + 1)
108 | t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
109 | t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
110 | ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
111 | for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose):
112 | out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
113 | sample = out.pred_x_prev
114 | ret.pred_x_t.append(out.pred_x_prev)
115 | ret.pred_x_0.append(out.pred_x_0)
116 | ret.samples = sample
117 | return ret
118 |
119 |
120 | class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
121 | """
122 | Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
123 | """
124 | @torch.no_grad()
125 | def sample(
126 | self,
127 | model,
128 | noise,
129 | cond,
130 | neg_cond,
131 | steps: int = 50,
132 | rescale_t: float = 1.0,
133 | cfg_strength: float = 3.0,
134 | verbose: bool = True,
135 | **kwargs
136 | ):
137 | """
138 | Generate samples from the model using Euler method.
139 |
140 | Args:
141 | model: The model to sample from.
142 | noise: The initial noise tensor.
143 | cond: conditional information.
144 | neg_cond: negative conditional information.
145 | steps: The number of steps to sample.
146 | rescale_t: The rescale factor for t.
147 | cfg_strength: The strength of classifier-free guidance.
148 | verbose: If True, show a progress bar.
149 | **kwargs: Additional arguments for model_inference.
150 |
151 | Returns:
152 | a dict containing the following
153 | - 'samples': the model samples.
154 | - 'pred_x_t': a list of prediction of x_t.
155 | - 'pred_x_0': a list of prediction of x_0.
156 | """
157 | return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs)
158 |
159 |
160 | class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
161 | """
162 | Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
163 | """
164 | @torch.no_grad()
165 | def sample(
166 | self,
167 | model,
168 | noise,
169 | cond,
170 | neg_cond,
171 | steps: int = 50,
172 | rescale_t: float = 1.0,
173 | cfg_strength: float = 3.0,
174 | cfg_interval: Tuple[float, float] = (0.0, 1.0),
175 | verbose: bool = True,
176 | **kwargs
177 | ):
178 | """
179 | Generate samples from the model using Euler method.
180 |
181 | Args:
182 | model: The model to sample from.
183 | noise: The initial noise tensor.
184 | cond: conditional information.
185 | neg_cond: negative conditional information.
186 | steps: The number of steps to sample.
187 | rescale_t: The rescale factor for t.
188 | cfg_strength: The strength of classifier-free guidance.
189 | cfg_interval: The interval for classifier-free guidance.
190 | verbose: If True, show a progress bar.
191 | **kwargs: Additional arguments for model_inference.
192 |
193 | Returns:
194 | a dict containing the following
195 | - 'samples': the model samples.
196 | - 'pred_x_t': a list of prediction of x_t.
197 | - 'pred_x_0': a list of prediction of x_0.
198 | """
199 | return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
200 |
--------------------------------------------------------------------------------
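
A standalone look at the timestep schedule built in `FlowEulerSampler.sample` above: `t` runs from 1 to 0 in `steps` intervals, optionally warped by `rescale_t`. With `rescale_t > 1` the warped schedule takes finer steps near t = 1 (high noise) and coarser steps near t = 0:

    import numpy as np

    steps, rescale_t = 5, 3.0
    t_seq = np.linspace(1, 0, steps + 1)
    t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))     # (t, t_prev) pairs fed to sample_once

    print(np.round(t_seq, 3))   # roughly [1, 0.923, 0.818, 0.667, 0.429, 0]
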
/trellis/pipelines/samplers/guidance_interval_mixin.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 |
4 | class GuidanceIntervalSamplerMixin:
5 | """
6 | A mixin class for samplers that apply classifier-free guidance with interval.
7 | """
8 |
9 | def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
10 | if cfg_interval[0] <= t <= cfg_interval[1]:
11 | pred = super()._inference_model(model, x_t, t, cond, **kwargs)
12 | neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
13 | return (1 + cfg_strength) * pred - cfg_strength * neg_pred
14 | else:
15 | return super()._inference_model(model, x_t, t, cond, **kwargs)
16 |
--------------------------------------------------------------------------------
/trellis/representations/__init__.py:
--------------------------------------------------------------------------------
1 | from .mesh import MeshExtractResult
2 |
--------------------------------------------------------------------------------
/trellis/representations/mesh/__init__.py:
--------------------------------------------------------------------------------
1 | from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult
2 |
--------------------------------------------------------------------------------
/trellis/representations/mesh/cube2mesh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ...modules.sparse import SparseTensor
3 | from easydict import EasyDict as edict
4 | from .utils_cube import *
5 | from .flexicube import FlexiCubes
6 |
7 |
8 | class MeshExtractResult:
9 | def __init__(self,
10 | vertices,
11 | faces,
12 | vertex_attrs=None,
13 | res=64
14 | ):
15 | self.vertices = vertices
16 | self.faces = faces.long()
17 | self.vertex_attrs = vertex_attrs
18 | self.face_normal = self.comput_face_normals(vertices, faces)
19 | self.res = res
20 | self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0)
21 |
22 | # training only
23 | self.tsdf_v = None
24 | self.tsdf_s = None
25 | self.reg_loss = None
26 |
27 | def comput_face_normals(self, verts, faces):
28 | i0 = faces[..., 0].long()
29 | i1 = faces[..., 1].long()
30 | i2 = faces[..., 2].long()
31 |
32 | v0 = verts[i0, :]
33 | v1 = verts[i1, :]
34 | v2 = verts[i2, :]
35 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
36 | face_normals = torch.nn.functional.normalize(face_normals, dim=1)
37 | # print(face_normals.min(), face_normals.max(), face_normals.shape)
38 | return face_normals[:, None, :].repeat(1, 3, 1)
39 |
40 | def comput_v_normals(self, verts, faces):
41 | i0 = faces[..., 0].long()
42 | i1 = faces[..., 1].long()
43 | i2 = faces[..., 2].long()
44 |
45 | v0 = verts[i0, :]
46 | v1 = verts[i1, :]
47 | v2 = verts[i2, :]
48 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
49 | v_normals = torch.zeros_like(verts)
50 | v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
51 | v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
52 | v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
53 |
54 | v_normals = torch.nn.functional.normalize(v_normals, dim=1)
55 | return v_normals
56 |
57 |
58 | class SparseFeatures2Mesh:
59 | def __init__(self, device="cuda", res=64, use_color=True):
60 | '''
61 |         A model that generates a mesh from sparse feature structures using FlexiCubes.
62 | '''
63 | super().__init__()
64 | self.device=device
65 | self.res = res
66 | self.mesh_extractor = FlexiCubes(device=device)
67 | self.sdf_bias = -1.0 / res
68 | verts, cube = construct_dense_grid(self.res, self.device)
69 | self.reg_c = cube.to(self.device)
70 | self.reg_v = verts.to(self.device)
71 | self.use_color = use_color
72 | self._calc_layout()
73 |
74 | def _calc_layout(self):
75 | LAYOUTS = {
76 | 'sdf': {'shape': (8, 1), 'size': 8},
77 | 'deform': {'shape': (8, 3), 'size': 8 * 3},
78 | 'weights': {'shape': (21,), 'size': 21}
79 | }
80 | if self.use_color:
81 | '''
82 |             6-channel color per cube corner, including the normal map
83 | '''
84 | LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
85 | self.layouts = edict(LAYOUTS)
86 | start = 0
87 | for k, v in self.layouts.items():
88 | v['range'] = (start, start + v['size'])
89 | start += v['size']
90 | self.feats_channels = start
91 |
92 | def get_layout(self, feats : torch.Tensor, name : str):
93 | if name not in self.layouts:
94 | return None
95 | return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape'])
96 |
97 | def __call__(self, cubefeats : SparseTensor, training=False):
98 | """
99 |         Generate a mesh from the given sparse voxel features.
100 |         Args:
101 |             cubefeats (SparseTensor): per-cube features packing the SDF, deform, (optional) color and FlexiCubes weight channels
102 |             training (bool): if True, also accumulate regularization losses
103 |         Returns:
104 |             MeshExtractResult: the extracted mesh; reg_loss, tsdf_v and tsdf_s are set when training=True
105 | """
106 | # add sdf bias to verts_attrs
107 | coords = cubefeats.coords[:, 1:]
108 | feats = cubefeats.feats
109 |
110 | sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']]
111 | sdf += self.sdf_bias
112 | v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
113 | v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training)
114 | v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True)
115 | weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False)
116 | if self.use_color:
117 | sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:]
118 | else:
119 | sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4]
120 | colors_d = None
121 |
122 | x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)
123 |
124 | vertices, faces, L_dev, colors = self.mesh_extractor(
125 | voxelgrid_vertices=x_nx3,
126 | scalar_field=sdf_d,
127 | cube_idx=self.reg_c,
128 | resolution=self.res,
129 | beta=weights_d[:, :12],
130 | alpha=weights_d[:, 12:20],
131 | gamma_f=weights_d[:, 20],
132 | voxelgrid_colors=colors_d,
133 | training=training)
134 |
135 | mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res)
136 | if training:
137 | if mesh.success:
138 | reg_loss += L_dev.mean() * 0.5
139 | reg_loss += (weights[:,:20]).abs().mean() * 0.2
140 | mesh.reg_loss = reg_loss
141 | mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
142 | mesh.tsdf_s = v_attrs[:, 0]
143 | return mesh
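For reference, the channel layout produced by _calc_layout with use_color=True, plus a minimal construction sketch (instantiating SparseFeatures2Mesh assumes a CUDA device for FlexiCubes):

# sdf     -> channels [0, 8)     (8 corners x 1)
# deform  -> channels [8, 32)    (8 corners x 3)
# weights -> channels [32, 53)   (21 FlexiCubes weights)
# color   -> channels [53, 101)  (8 corners x 6, only when use_color=True)
extractor = SparseFeatures2Mesh(device="cuda", res=64, use_color=True)
assert extractor.feats_channels == 101   # 8 + 24 + 21 + 48; 53 when use_color=False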
--------------------------------------------------------------------------------
/trellis/representations/mesh/utils_cube.py:
--------------------------------------------------------------------------------
1 | import torch
2 | cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
3 | 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int)
4 | cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]])
5 | cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
6 | 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False)
7 |
8 | def construct_dense_grid(res, device='cuda'):
9 | '''construct a dense grid based on resolution'''
10 | res_v = res + 1
11 | vertsid = torch.arange(res_v ** 3, device=device)
12 | coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten()
13 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2]
14 | cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device))
15 | verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1)
16 | return verts, cube_fx8
17 |
18 |
19 | def construct_voxel_grid(coords):
20 | verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3)
21 | verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True)
22 | cubes = inverse_indices.reshape(-1, 8)
23 | return verts_unique, cubes
24 |
25 |
26 | def cubes_to_verts(num_verts, cubes, value, reduce='mean'):
27 | """
28 | Args:
29 |         cubes [Vx8]: vertex indices for each cube
30 |         value [Vx8xM]: values to be scattered
31 |     Operation:
32 |         reduced[cubes[i][j]][k] += value[i][j][k]  (averaged when reduce='mean')
33 | """
34 | M = value.shape[2] # number of channels
35 | reduced = torch.zeros(num_verts, M, device=cubes.device)
36 | return torch.scatter_reduce(reduced, 0,
37 | cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1),
38 | value.flatten(0, 1), reduce=reduce, include_self=False)
39 |
40 | def sparse_cube2verts(coords, feats, training=True):
41 | new_coords, cubes = construct_voxel_grid(coords)
42 | new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats)
43 | if training:
44 | con_loss = torch.mean((feats - new_feats[cubes]) ** 2)
45 | else:
46 | con_loss = 0.0
47 | return new_coords, new_feats, con_loss
48 |
49 |
50 | def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True):
51 | F = feats.shape[-1]
52 | dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device)
53 | if sdf_init:
54 | dense_attrs[..., 0] = 1 # initial outside sdf value
55 | dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats
56 | return dense_attrs.reshape(-1, F)
57 |
58 |
59 | def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res):
60 | return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform)
61 |
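A quick sanity check of construct_voxel_grid (assuming the definitions above are importable): a single occupied voxel yields its 8 corner vertices and one cube indexing them.

import torch

coords = torch.tensor([[0, 0, 0]], dtype=torch.int)
verts, cubes = construct_voxel_grid(coords)
print(verts.shape, cubes.shape)   # torch.Size([8, 3]) torch.Size([1, 8])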
--------------------------------------------------------------------------------
/trellis/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-X/ComfyUI-Hi3DGen/99621aa48dd05203bee91cf3d5813669e0b5721d/trellis/utils/__init__.py
--------------------------------------------------------------------------------
/trellis/utils/_rasterization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, List, Optional, Tuple, Dict
3 | import nvdiffrast.torch as dr
4 | import utils3d
5 | class RastContext:
6 | def __init__(self, backend='cuda'):
7 | self.backend = backend
8 |         self.nvd_ctx = dr.RasterizeCudaContext() if backend == 'cuda' else dr.RasterizeGLContext()  # nvdiffrast context consumed by dr.rasterize below
9 | def rasterize_triangle_faces(
10 | ctx: RastContext,
11 | vertices: torch.Tensor,
12 | faces: torch.Tensor,
13 | width: int,
14 | height: int,
15 | attr: torch.Tensor = None,
16 | uv: torch.Tensor = None,
17 | texture: torch.Tensor = None,
18 | model: torch.Tensor = None,
19 | view: torch.Tensor = None,
20 | projection: torch.Tensor = None,
21 | antialiasing: Union[bool, List[int]] = True,
22 | diff_attrs: Union[None, List[int]] = None,
23 | ) -> Dict[str, torch.Tensor]:
24 | """
25 | Rasterize a mesh with vertex attributes.
26 | """
27 | assert vertices.ndim == 3
28 | assert faces.ndim == 2
29 |
30 | # Handle vertices dimensions
31 | if vertices.shape[-1] == 2:
32 | vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1)
33 | elif vertices.shape[-1] == 3:
34 | vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
35 | elif vertices.shape[-1] == 4:
36 | pass
37 | else:
38 | raise ValueError(f'Wrong shape of vertices: {vertices.shape}')
39 |
40 | # Calculate MVP matrix
41 | mvp = projection if projection is not None else torch.eye(4, device=vertices.device)
42 | if view is not None:
43 | mvp = mvp @ view
44 | if model is not None:
45 | mvp = mvp @ model
46 |
47 | # Transform vertices to clip space
48 | pos_clip = vertices @ mvp.transpose(-1, -2)
49 | faces = faces.contiguous()
50 | if attr is not None:
51 | attr = attr.contiguous()
52 |
53 | # Rasterize
54 | rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True)
55 |
56 | # Extract basic outputs
57 | face_id = rast_out[..., 3].flip(1)
58 | depth = rast_out[..., 2].flip(1)
59 | mask = (face_id > 0).float()
60 | depth = (depth * 0.5 + 0.5) * mask + (1.0 - mask)
61 |
62 | ret = {
63 | 'depth': depth,
64 | 'mask': mask,
65 | 'face_id': face_id,
66 | }
67 |
68 | # Handle attribute interpolation
69 | if attr is not None:
70 | image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs)
71 |
72 |         if antialiasing is True:
73 | image = dr.antialias(image, rast_out, pos_clip, faces)
74 | elif isinstance(antialiasing, list):
75 | aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces)
76 | image[..., antialiasing] = aa_image
77 |
78 | image = image.flip(1).permute(0, 3, 1, 2)
79 | ret['image'] = image
80 |
81 | # Handle UV mapping
82 | if uv is not None:
83 | uv_map, uv_map_dr = dr.interpolate(uv, rast_out, faces, rast_db, diff_attrs='all')
84 | ret['uv'] = uv_map
85 | ret['uv_dr'] = uv_map_dr
86 |
87 | if texture is not None:
88 | texture_map = dr.texture(texture, uv_map, uv_map_dr)
89 | ret['texture'] = texture_map.flip(1).permute(0, 3, 1, 2)
90 |
91 | # Handle derivatives
92 | if diff_attrs is not None:
93 | image_dr = image_dr.flip(1).permute(0, 3, 1, 2)
94 | ret['image_dr'] = image_dr
95 |
96 | return ret
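A minimal usage sketch (assumes a CUDA device with nvdiffrast available; the single triangle is illustrative only):

import torch

ctx = RastContext(backend='cuda')
vertices = torch.tensor([[[-0.5, -0.5, 0.0],
                          [ 0.5, -0.5, 0.0],
                          [ 0.0,  0.5, 0.0]]], device='cuda')          # [1, 3, 3]
faces = torch.tensor([[0, 1, 2]], dtype=torch.int32, device='cuda')    # [1, 3]
out = rasterize_triangle_faces(ctx, vertices, faces, width=256, height=256)
depth, mask = out['depth'], out['mask']                                # [1, 256, 256] each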
--------------------------------------------------------------------------------
/trellis/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 |
5 |
6 | # Dictionary utils
7 | def _dict_merge(dicta, dictb, prefix=''):
8 | """
9 | Merge two dictionaries.
10 | """
11 | assert isinstance(dicta, dict), 'input must be a dictionary'
12 | assert isinstance(dictb, dict), 'input must be a dictionary'
13 | dict_ = {}
14 | all_keys = set(dicta.keys()).union(set(dictb.keys()))
15 | for key in all_keys:
16 | if key in dicta.keys() and key in dictb.keys():
17 | if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
18 | dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
19 | else:
20 | raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
21 | elif key in dicta.keys():
22 | dict_[key] = dicta[key]
23 | else:
24 | dict_[key] = dictb[key]
25 | return dict_
26 |
27 |
28 | def dict_merge(dicta, dictb):
29 | """
30 | Merge two dictionaries.
31 | """
32 | return _dict_merge(dicta, dictb, prefix='')
33 |
34 |
35 | def dict_foreach(dic, func, special_func={}):
36 | """
37 | Recursively apply a function to all non-dictionary leaf values in a dictionary.
38 | """
39 | assert isinstance(dic, dict), 'input must be a dictionary'
40 | for key in dic.keys():
41 | if isinstance(dic[key], dict):
42 |             dic[key] = dict_foreach(dic[key], func, special_func)
43 | else:
44 | if key in special_func.keys():
45 | dic[key] = special_func[key](dic[key])
46 | else:
47 | dic[key] = func(dic[key])
48 | return dic
49 |
50 |
51 | def dict_reduce(dicts, func, special_func={}):
52 | """
53 | Reduce a list of dictionaries. Leaf values must be scalars.
54 | """
55 | assert isinstance(dicts, list), 'input must be a list of dictionaries'
56 | assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
57 | assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
58 | all_keys = set([key for dict_ in dicts for key in dict_.keys()])
59 | reduced_dict = {}
60 | for key in all_keys:
61 | vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
62 | if isinstance(vlist[0], dict):
63 | reduced_dict[key] = dict_reduce(vlist, func, special_func)
64 | else:
65 | if key in special_func.keys():
66 | reduced_dict[key] = special_func[key](vlist)
67 | else:
68 | reduced_dict[key] = func(vlist)
69 | return reduced_dict
70 |
71 |
72 | def dict_any(dic, func):
73 | """
74 |     Return True if func(value) is truthy for any non-dictionary leaf value.
75 | """
76 | assert isinstance(dic, dict), 'input must be a dictionary'
77 | for key in dic.keys():
78 | if isinstance(dic[key], dict):
79 | if dict_any(dic[key], func):
80 | return True
81 | else:
82 | if func(dic[key]):
83 | return True
84 | return False
85 |
86 |
87 | def dict_all(dic, func):
88 | """
89 |     Return True only if func(value) is truthy for every non-dictionary leaf value.
90 | """
91 | assert isinstance(dic, dict), 'input must be a dictionary'
92 | for key in dic.keys():
93 | if isinstance(dic[key], dict):
94 | if not dict_all(dic[key], func):
95 | return False
96 | else:
97 | if not func(dic[key]):
98 | return False
99 | return True
100 |
101 |
102 | def dict_flatten(dic, sep='.'):
103 | """
104 | Flatten a nested dictionary into a dictionary with no nested dictionaries.
105 | """
106 | assert isinstance(dic, dict), 'input must be a dictionary'
107 | flat_dict = {}
108 | for key in dic.keys():
109 | if isinstance(dic[key], dict):
110 | sub_dict = dict_flatten(dic[key], sep=sep)
111 | for sub_key in sub_dict.keys():
112 | flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
113 | else:
114 | flat_dict[key] = dic[key]
115 | return flat_dict
116 |
117 |
118 | def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
119 | num_images = len(images)
120 | if nrow is None and ncol is None:
121 | if aspect_ratio is not None:
122 | nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
123 | else:
124 | nrow = int(np.sqrt(num_images))
125 | ncol = (num_images + nrow - 1) // nrow
126 | elif nrow is None and ncol is not None:
127 | nrow = (num_images + ncol - 1) // ncol
128 | elif nrow is not None and ncol is None:
129 | ncol = (num_images + nrow - 1) // nrow
130 | else:
131 | assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
132 |
133 | grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
134 | for i, img in enumerate(images):
135 | row = i // ncol
136 | col = i % ncol
137 | grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
138 | return grid
139 |
140 |
141 | def notes_on_image(img, notes=None):
142 | img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
143 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
144 | if notes is not None:
145 | img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
146 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
147 | return img
148 |
149 |
150 | def save_image_with_notes(img, path, notes=None):
151 | """
152 | Save an image with notes.
153 | """
154 | if isinstance(img, torch.Tensor):
155 | img = img.cpu().numpy().transpose(1, 2, 0)
156 | if img.dtype == np.float32 or img.dtype == np.float64:
157 | img = np.clip(img * 255, 0, 255).astype(np.uint8)
158 | img = notes_on_image(img, notes)
159 | cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
160 |
161 |
162 | # debug utils
163 |
164 | def atol(x, y):
165 | """
166 | Absolute tolerance.
167 | """
168 | return torch.abs(x - y)
169 |
170 |
171 | def rtol(x, y):
172 | """
173 | Relative tolerance.
174 | """
175 | return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
176 |
177 |
178 | # print utils
179 | def indent(s, n=4):
180 | """
181 | Indent a string.
182 | """
183 | lines = s.split('\n')
184 | for i in range(1, len(lines)):
185 | lines[i] = ' ' * n + lines[i]
186 | return '\n'.join(lines)
187 |
188 |
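A short illustration of the dictionary helpers above (values are illustrative):

a = {'model': {'depth': 12}, 'lr': 1e-4}
b = {'model': {'width': 768}}

merged = dict_merge(a, b)      # {'model': {'depth': 12, 'width': 768}, 'lr': 0.0001}
flat = dict_flatten(merged)    # {'model.depth': 12, 'model.width': 768, 'lr': 0.0001}
print(merged)
print(flat)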
--------------------------------------------------------------------------------
/trellis/utils/random_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
4 |
5 | def radical_inverse(base, n):
6 | val = 0
7 | inv_base = 1.0 / base
8 | inv_base_n = inv_base
9 | while n > 0:
10 | digit = n % base
11 | val += digit * inv_base_n
12 | n //= base
13 | inv_base_n *= inv_base
14 | return val
15 |
16 | def halton_sequence(dim, n):
17 |     return [radical_inverse(PRIMES[i], n) for i in range(dim)]
18 |
19 | def hammersley_sequence(dim, n, num_samples):
20 | return [n / num_samples] + halton_sequence(dim - 1, n)
21 |
22 | def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
23 | u, v = hammersley_sequence(2, n, num_samples)
24 | u += offset[0] / num_samples
25 | v += offset[1]
26 | if remap:
27 | u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
28 | theta = np.arccos(1 - 2 * u) - np.pi / 2
29 | phi = v * 2 * np.pi
30 | return [phi, theta]
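A quick sketch of generating evenly spread view directions with the helpers above; the returned (phi, theta) are azimuth and elevation in radians:

import numpy as np

num_views = 4
views = [sphere_hammersley_sequence(i, num_views) for i in range(num_views)]
for phi, theta in views:
    print(f"phi={phi:.3f}, theta={theta:.3f}")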
--------------------------------------------------------------------------------
/trellis_model_manager.py:
--------------------------------------------------------------------------------
1 | # trellis_model_manager.py
2 | import os
3 | import logging
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from safetensors.torch import load_file
8 | import folder_paths
9 | from huggingface_hub import hf_hub_download, snapshot_download
10 | from typing import Dict, Union
11 | import json
12 | import importlib # Import the importlib module
13 | from trellis.modules.utils import convert_module_to_f16, convert_module_to_f32
14 |
15 | logger = logging.getLogger('model_manager')
16 |
17 | __attributes = {
18 | 'SparseStructureDecoder': 'trellis.models.sparse_structure_vae',
19 | 'SparseStructureFlowModel': 'trellis.models.sparse_structure_flow',
20 | 'SLatFlowModel': 'trellis.models.structured_latent_flow',
21 | }
22 |
23 | __all__ = list(__attributes.keys())
24 |
25 | def __getattr__(name):
26 | if name in __attributes:
27 | module_name = __attributes[name]
28 | module = importlib.import_module(module_name, package=None) # Import the module
29 | return getattr(module, name)
30 | raise AttributeError(f"module {__name__} has no attribute {name}")
31 |
32 | class TrellisModelManager:
33 | """
34 | Basic manager for Trellis models, using ComfyUI's new model path.
35 | """
36 | def __init__(self, model_dir: str, config=None, device: str = "cuda"):
37 | """
38 | Initialize the model manager with a specific model directory.
39 |
40 | Args:
41 | model_dir (str): Path to model directory (e.g. "models/checkpoints/TRELLIS-image-large")
42 | config (dict or object): Global configuration for Trellis
43 | device (str): Device to load models on (e.g. "cuda")
44 | """
45 | self.model_dir = model_dir
46 | # Handle config being either a dict or an object
47 | if config is None:
48 |             self.device, self.config = device, {}  # default to an empty config so later lookups do not fail
49 | elif isinstance(config, dict):
50 | self.device = config.get('device', device)
51 | self.config = config
52 | else:
53 | self.device = getattr(config, 'device', device)
54 | self.config = config
55 | self.model = None
56 | self.dinov2_model = None
57 |
58 | def load(self) -> None:
59 | """Load model configuration and checkpoints"""
60 | try:
61 | # Ensure directory exists
62 | os.makedirs(self.model_dir, exist_ok=True)
63 | ckpts_folder = os.path.join(self.model_dir, "ckpts")
64 | os.makedirs(ckpts_folder, exist_ok=True)
65 |
66 | # Download model files if needed
67 | if not os.path.exists(os.path.join(self.model_dir, "pipeline.json")):
68 | logger.info("Downloading TRELLIS models...")
69 | try:
70 | # Download main pipeline files
71 | snapshot_download(
72 | repo_id="Stable-X/trellis-normal-v0-1",
73 | local_dir=self.model_dir,
74 | local_dir_use_symlinks=False,
75 | allow_patterns=["pipeline.json", "README.md"]
76 | )
77 | # Download checkpoint files
78 | snapshot_download(
79 | repo_id="Stable-X/trellis-normal-v0-1",
80 | local_dir=ckpts_folder,
81 | local_dir_use_symlinks=False,
82 | allow_patterns=["*.safetensors", "*.json"],
83 | cache_dir=os.path.join(self.model_dir, ".cache")
84 | )
85 | logger.info("Model files downloaded successfully")
86 | except Exception as e:
87 | logger.error(f"Error downloading model files: {str(e)}")
88 | raise
89 |
90 | # Load configuration
91 | self.config = self._load_config()
92 |
93 | except Exception as e:
94 | logger.error(f"Error in load(): {str(e)}")
95 | raise
96 |
97 | def get_checkpoint_path(self, filename: str) -> str:
98 | """
99 | Returns the full path to a checkpoint file.
100 | """
101 | ckpts_folder = os.path.join(self.model_dir, "ckpts")
102 | # Add .safetensors extension if not present
103 | if not filename.endswith('.safetensors'):
104 | filename = f"{filename}.safetensors"
105 | full_path = os.path.join(ckpts_folder, filename)
106 | if not os.path.exists(full_path):
107 | raise FileNotFoundError(f"Checkpoint file not found: {full_path}")
108 | return full_path
109 |
110 | def _load_config(self) -> Dict:
111 | """Load model configuration from pipeline.json"""
112 | try:
113 | config_path = os.path.join(self.model_dir, "pipeline.json")
114 |
115 | if os.path.exists(config_path):
116 | logger.info(f"Loading config from {config_path}")
117 | with open(config_path, 'r') as f:
118 | config = json.load(f)
119 | else:
120 | logger.info(f"Config not found locally, downloading from HuggingFace")
121 | config_path = hf_hub_download(
122 | repo_id=f"JeffreyXiang/{os.path.basename(self.model_dir)}",
123 | filename="pipeline.json",
124 | cache_dir=os.path.join(self.model_dir, ".cache")
125 | )
126 | with open(config_path, 'r') as f:
127 | config = json.load(f)
128 |
129 | # Debug: Print raw config
130 | logger.info("Raw config contents:")
131 | logger.info(json.dumps(config, indent=2))
132 |
133 | if not config:
134 | raise ValueError(f"Could not load valid configuration from {self.model_dir}")
135 |
136 | if 'name' not in config:
137 | config['name'] = 'TrellisImageTo3DPipeline'
138 |
139 | return config
140 |
141 | except Exception as e:
142 | logger.error(f"Error loading config from {self.model_dir}: {e}")
143 | return {
144 | 'name': 'TrellisImageTo3DPipeline',
145 | 'version': '1.0'
146 | }
147 |
148 | def load_models(self) -> Dict[str, nn.Module]:
149 | """Load all required models with current configuration"""
150 | return {
151 | 'sparse_structure_flow_model': self.get_checkpoint_path("ss_flow_img_dit_L_16l8_fp16"),
152 | 'slat_flow_model': self.get_checkpoint_path("slat_flow_img_dit_L_64l8p2_fp16")
153 | }
154 |
155 | def load_model_components(self) -> Dict[str, nn.Module]:
156 |         """Load individual model components from their checkpoint paths."""
157 |         from trellis import models as model_registry  # local import (assumed to expose from_pretrained(), as in upstream TRELLIS); the local dict below must not shadow it
158 |         components = {}
159 |         model_paths = self.load_models()
160 |         for name, path in model_paths.items():
161 |             components[name] = model_registry.from_pretrained(path, config=self.config)
162 |             # Ensure each model is converted to the desired precision
163 |             if self.config.get('use_fp16', True):
164 |                 convert_module_to_f16(components[name])
165 |             else:
166 |                 convert_module_to_f32(components[name])
167 |
168 |         # DINOv2 is handled separately
169 |         # components['image_cond_model'] = self.load_dinov2(self.config.get("dinov2_model", "dinov2_vitl14"))
170 |
171 |         return components
172 |
173 | def load_dinov2(self, model_name: str):
174 | """Load DINOv2 model with device, precision, and attention backend management"""
175 | try:
176 | # Get configuration values
177 | use_fp16 = (self.config.get('use_fp16', True)
178 | if isinstance(self.config, dict)
179 | else getattr(self.config, 'use_fp16', True))
180 |
181 | # Get attention backend from config
182 | attention_backend = (self.config.get('attention_backend', 'default')
183 | if isinstance(self.config, dict)
184 | else getattr(self.config, 'attention_backend', 'default'))
185 |
186 | # Try to load from local path first
187 | model_path = folder_paths.get_full_path("classifiers", f"{model_name}.pth")
188 |
189 | if model_path is None:
190 | print(f"Downloading {model_name} from torch hub...")
191 | try:
192 | # Load model architecture with specified attention backend
193 | model = torch.hub.load('facebookresearch/dinov2', model_name,
194 | pretrained=True,
195 | force_reload=False,
196 | trust_repo=True)
197 |
198 | # Save model for future use
199 | save_dir = os.path.join(folder_paths.models_dir, "classifiers")
200 | os.makedirs(save_dir, exist_ok=True)
201 | save_path = os.path.join(save_dir, f"{model_name}.pth")
202 |
203 | # Save on CPU to avoid memory issues
204 | model = model.cpu()
205 | torch.save(model.state_dict(), save_path)
206 | print(f"Saved DINOv2 model to {save_path}")
207 |
208 | except Exception as e:
209 | raise RuntimeError(f"Failed to download DINOv2 model: {str(e)}")
210 | else:
211 | # Load from local path
212 | print(f"Loading DINOv2 model from {model_path}")
213 | model = torch.hub.load('facebookresearch/dinov2', model_name,
214 | pretrained=False,
215 | force_reload=False,
216 | trust_repo=True)
217 | model.load_state_dict(torch.load(model_path))
218 |
219 | # Move model to specified device and apply precision settings
220 | model = model.to(self.device)
221 | if use_fp16:
222 | model = model.half()
223 |
224 | # Set attention backend if specified in config
225 | if hasattr(model, 'set_attention_backend') and attention_backend != 'default':
226 | model.set_attention_backend(attention_backend)
227 |
228 | model.eval()
229 | return model
230 |
231 | except Exception as e:
232 | raise RuntimeError(f"Error loading DINOv2 model: {str(e)}")
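A minimal usage sketch for TrellisModelManager (the model directory path is illustrative and follows ComfyUI's checkpoints layout):

manager = TrellisModelManager("models/checkpoints/trellis-normal-v0-1", device="cuda")
manager.load()                                 # downloads pipeline.json and ckpts/ if missing
paths = manager.load_models()                  # checkpoint paths keyed by component name
dinov2 = manager.load_dinov2("dinov2_vitl14")  # DINOv2 image-conditioning backbone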
--------------------------------------------------------------------------------
/win_requirements.txt:
--------------------------------------------------------------------------------
1 | pillow==10.4.0
2 | imageio==2.36.1
3 | imageio-ffmpeg==0.5.1
4 | tqdm==4.67.1
5 | easydict==1.13
6 | opencv-python-headless==4.10.0.84
7 | scipy
8 | rembg
9 | onnxruntime
10 | trimesh
11 | xatlas
12 | pyvista
13 | pymeshfix
14 | igraph
15 | spconv-cu124
16 | onnxruntime
17 | git+https://github.com/EasternJournalist/utils3d.git
18 |
19 | # triton ( Windows specific builds from : https://github.com/woct0rdho/triton-windows )
20 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp312-cp312-win_amd64.whl; (python_version >= "3.12" and python_version < "3.13")
21 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp311-cp311-win_amd64.whl; (python_version >= "3.11" and python_version < "3.12")
22 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp310-cp310-win_amd64.whl; (python_version >= "3.10" and python_version < "3.11")
23 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp38-cp38-win_amd64.whl; (python_version >= "3.8" and python_version < "3.9")
24 |
--------------------------------------------------------------------------------
/workflow/Hi3DGen_WF_single.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "2ce13a2d-0e15-4260-b3b7-531ba96bf89e",
3 | "revision": 0,
4 | "last_node_id": 89,
5 | "last_link_id": 163,
6 | "nodes": [
7 | {
8 | "id": 87,
9 | "type": "DownloadAndLoadStableXModel",
10 | "pos": [-367.123107910156, 146.043090820313],
11 | "size": [315, 58],
12 | "flags": {
13 |
14 | },
15 | "order": 0,
16 | "mode": 0,
17 | "inputs": [],
18 | "outputs": [
19 | {
20 | "name": "pipeline",
21 | "type": "YOSOPIPE",
22 | "links": [160]
23 | }
24 | ],
25 | "properties": {
26 | "Node name for S&R": "DownloadAndLoadStableXModel"
27 | },
28 | "widgets_values": [
29 | "yoso-normal-v1-8-1"
30 | ]
31 | },
32 | {
33 | "id": 88,
34 | "type": "StableXProcessImage",
35 | "pos": [490, 180],
36 | "size": [315, 150],
37 | "flags": {
38 |
39 | },
40 | "order": 3,
41 | "mode": 0,
42 | "inputs": [
43 | {
44 | "name": "pipeline",
45 | "type": "YOSOPIPE",
46 | "link": 160
47 | },
48 | {
49 | "name": "image",
50 | "type": "IMAGE",
51 | "link": 161
52 | }
53 | ],
54 | "outputs": [
55 | {
56 | "name": "image",
57 | "type": "IMAGE",
58 | "links": [162, 163]
59 | }
60 | ],
61 | "properties": {
62 | "Node name for S&R": "StableXProcessImage"
63 | },
64 | "widgets_values": [2048, 1, 148262827792070, "randomize"
65 | ]
66 | },
67 | {
68 | "id": 89,
69 | "type": "PreviewImage",
70 | "pos": [490, 480],
71 | "size": [210, 246],
72 | "flags": {
73 |
74 | },
75 | "order": 5,
76 | "mode": 0,
77 | "inputs": [
78 | {
79 | "name": "images",
80 | "type": "IMAGE",
81 | "link": 163
82 | }
83 | ],
84 | "outputs": [],
85 | "properties": {
86 | "Node name for S&R": "PreviewImage"
87 | },
88 | "widgets_values": [
89 | ""
90 | ]
91 | },
92 | {
93 | "id": 86,
94 | "type": "IF_TrellisCheckpointLoader",
95 | "pos": [-358.851715087891, -174.489013671875],
96 | "size": [315, 202],
97 | "flags": {
98 |
99 | },
100 | "order": 1,
101 | "mode": 0,
102 | "inputs": [],
103 | "outputs": [
104 | {
105 | "name": "model",
106 | "type": "TRELLIS_MODEL",
107 | "links": [159]
108 | }
109 | ],
110 | "properties": {
111 | "Node name for S&R": "IF_TrellisCheckpointLoader"
112 | },
113 | "widgets_values": [
114 | "trellis-normal-v0-1",
115 | "dinov2_vitl14_reg",
116 | true, "xformers",
117 | "spconv",
118 | "implicit_gemm",
119 | "cuda"
120 | ]
121 | },
122 | {
123 | "id": 84,
124 | "type": "IF_TrellisImageTo3D",
125 | "pos": [860.648742675781, 208.74560546875],
126 | "size": [340.200012207031, 506],
127 | "flags": {
128 |
129 | },
130 | "order": 4,
131 | "mode": 0,
132 | "inputs": [
133 | {
134 | "name": "model",
135 | "type": "TRELLIS_MODEL",
136 | "link": 159
137 | },
138 | {
139 | "name": "images",
140 | "type": "IMAGE",
141 | "link": 162
142 | },
143 | {
144 | "name": "masks",
145 | "shape": 7,
146 | "type": "MASK",
147 | "link": null
148 | }
149 | ],
150 | "outputs": [
151 | {
152 | "name": "model_file",
153 | "type": "STRING",
154 | "links": [157]
155 | },
156 | {
157 | "name": "video_path",
158 | "type": "STRING",
159 | "links": []
160 | },
161 | {
162 | "name": "texture_image",
163 | "type": "IMAGE",
164 | "links": null
165 | }
166 | ],
167 | "properties": {
168 | "Node name for S&R": "IF_TrellisImageTo3D"
169 | },
170 | "widgets_values": [
171 | "single",
172 | 240919286, "randomize",
173 | 7.5, 12, 3, 12, 0.95, "stochastic",
174 | "test"
175 | ]
176 | },
177 | {
178 | "id": 75,
179 | "type": "Preview3D",
180 | "pos": [1249.12377929688, 220.59049987793],
181 | "size": [315, 550],
182 | "flags": {
183 |
184 | },
185 | "order": 6,
186 | "mode": 0,
187 | "inputs": [
188 | {
189 | "name": "model_file",
190 | "type": "STRING",
191 | "widget": {
192 | "name": "model_file"
193 | },
194 | "link": 157
195 | }
196 | ],
197 | "outputs": [],
198 | "properties": {
199 | "Node name for S&R": "Preview3D",
200 | "Camera Info": {
201 | "position": {
202 | "x": 9.19323855253915,
203 | "y": 2.91023369098772,
204 | "z": -1.400979112753
205 | },
206 | "target": {
207 | "x": 0,
208 | "y": 1.61671303147776,
209 | "z": 0
210 | },
211 | "zoom": 1,
212 | "cameraType": "perspective"
213 | }
214 | },
215 | "widgets_values": [
216 | "test/test.glb",
217 | ""
218 | ]
219 | },
220 | {
221 | "id": 67,
222 | "type": "LoadImage",
223 | "pos": [-368.674621582031, 330.965423583984],
224 | "size": [315, 314],
225 | "flags": {
226 |
227 | },
228 | "order": 2,
229 | "mode": 0,
230 | "inputs": [],
231 | "outputs": [
232 | {
233 | "name": "IMAGE",
234 | "type": "IMAGE",
235 | "slot_index": 0,
236 | "links": [161]
237 | },
238 | {
239 | "name": "MASK",
240 | "type": "MASK",
241 | "links": null
242 | }
243 | ],
244 | "properties": {
245 | "Node name for S&R": "LoadImage"
246 | },
247 | "widgets_values": [
248 | "Generated Image March 31, 2025 - 3_19PM.jpeg",
249 | "image",
250 | ""
251 | ]
252 | }
253 | ],
254 | "links": [
255 | [157, 84, 0, 75, 0, "STRING"
256 | ],
257 | [159, 86, 0, 84, 0, "TRELLIS_MODEL"
258 | ],
259 | [160, 87, 0, 88, 0, "YOSOPIPE"
260 | ],
261 | [161, 67, 0, 88, 1, "IMAGE"
262 | ],
263 | [162, 88, 0, 84, 1, "IMAGE"
264 | ],
265 | [163, 88, 0, 89, 0, "IMAGE"
266 | ]
267 | ],
268 | "groups": [
269 | {
270 | "id": 1,
271 | "title": "Hi3DGen",
272 | "bounding": [450, 90, 1017.04064941406, 874.525085449219],
273 | "color": "#3f789e",
274 | "font_size": 24,
275 | "flags": {
276 |
277 | }
278 | },
279 | {
280 | "id": 2,
281 | "title": "Image",
282 | "bounding": [-396.156677246094, 258.446929931641, 355.320861816406, 402.488861083984],
283 | "color": "#3f789e",
284 | "font_size": 24,
285 | "flags": {
286 |
287 | }
288 | }
289 | ],
290 | "config": {
291 |
292 | },
293 | "extra": {
294 | "ds": {
295 | "scale": 0.640376065342708,
296 | "offset": [1194.09386904219, 251.478399306866]
297 | },
298 | "ue_links": [],
299 | "reroutes": [
300 | {
301 | "id": 1,
302 | "pos": [248.001159667969, 386.019378662109],
303 | "linkIds": [159]
304 | }
305 | ],
306 | "VHS_latentpreview": false,
307 | "VHS_latentpreviewrate": 0,
308 | "VHS_MetadataImage": true,
309 | "VHS_KeepIntermediate": true,
310 | "linkExtensions": [
311 | {
312 | "id": 159,
313 | "parentId": 1
314 | }
315 | ]
316 | },
317 | "version": 0.4
318 | }
319 |
--------------------------------------------------------------------------------