├── .gitignore ├── IF_Trellis.py ├── IF_TrellisCheckpointLoader.py ├── LICENSE ├── README.md ├── StableXWrapper.py ├── __init__.py ├── assets └── teaser.png ├── linux_requirements.txt ├── pyproject.toml ├── requirements.txt ├── stablex ├── __init__.py ├── controlnetvae.py └── pipeline_yoso.py ├── trellis ├── __init__.py ├── backend_config.py ├── models │ ├── __init__.py │ ├── sparse_structure_flow.py │ ├── sparse_structure_vae.py │ ├── structured_latent_flow.py │ └── structured_latent_vae │ │ ├── __init__.py │ │ ├── base.py │ │ ├── decoder_mesh.py │ │ └── encoder.py ├── modules │ ├── __init__.py │ ├── attention │ │ ├── __init__.py │ │ ├── full_attn.py │ │ └── modules.py │ ├── attention_utils.py │ ├── norm.py │ ├── sparse │ │ ├── __init__.py │ │ ├── attention │ │ │ ├── __init__.py │ │ │ ├── full_attn.py │ │ │ ├── modules.py │ │ │ ├── serialized_attn.py │ │ │ └── windowed_attn.py │ │ ├── basic.py │ │ ├── conv │ │ │ ├── __init__.py │ │ │ ├── conv_spconv.py │ │ │ └── conv_torchsparse.py │ │ ├── linear.py │ │ ├── nonlinearity.py │ │ ├── norm.py │ │ ├── spatial.py │ │ └── transformer │ │ │ ├── __init__.py │ │ │ ├── blocks.py │ │ │ └── modulated.py │ ├── spatial.py │ ├── transformer │ │ ├── __init__.py │ │ ├── blocks.py │ │ └── modulated.py │ └── utils.py ├── pipelines │ ├── __init__.py │ ├── base.py │ ├── samplers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── classifier_free_guidance_mixin.py │ │ ├── flow_euler.py │ │ └── guidance_interval_mixin.py │ └── trellis_image_to_3d.py ├── representations │ ├── __init__.py │ └── mesh │ │ ├── __init__.py │ │ ├── cube2mesh.py │ │ ├── flexicube.py │ │ ├── tables.py │ │ └── utils_cube.py └── utils │ ├── __init__.py │ ├── _rasterization.py │ ├── general_utils.py │ └── random_utils.py ├── trellis_model_manager.py ├── win_requirements.txt └── workflow └── Hi3DGen_WF_single.json /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /IF_TrellisCheckpointLoader.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/if-ai/ComfyUI-IF_Trellis/blob/main/IF_TrellisCheckpointLoader.py 2 | import os 3 | import logging 4 | import torch 5 | import huggingface_hub 6 | import folder_paths 7 | from trellis_model_manager import TrellisModelManager 8 | from trellis.pipelines.trellis_image_to_3d import TrellisImageTo3DPipeline 9 | from trellis.backend_config import ( 10 | set_attention_backend, 11 | set_sparse_backend, 12 | get_available_backends, 13 | get_available_sparse_backends 14 | ) 15 | from typing import Literal 16 | from torchvision import transforms 17 | 18 | logger = logging.getLogger("IF_Trellis") 19 | 20 | class IF_TrellisCheckpointLoader: 21 | """ 22 | Node to manage the loading of the TRELLIS model with lazy backend selection. 23 | """ 24 | def __init__(self): 25 | self.logger = logger 26 | self.model_manager = None 27 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | # We might call these to figure out what's actually installed, 30 | # if we want to populate UI dropdowns: 31 | self.attn_backends = get_available_backends() # e.g. { 'xformers': True, 'flash_attn': False, ... } 32 | self.sparse_backends = get_available_sparse_backends()# e.g. { 'spconv': True, 'torchsparse': True } 33 | 34 | @classmethod 35 | def INPUT_TYPES(cls): 36 | """Define input types with device-specific options.""" 37 | # Filter only available backends 38 | attn_backends = get_available_backends() 39 | sparse_backends = get_available_sparse_backends() 40 | 41 | # e.g. create a list of names that are True: 42 | available_attn = [k for k, v in attn_backends.items() if v] 43 | if not available_attn: 44 | available_attn = ['flash_attn'] # fallback 45 | 46 | available_sparse = [k for k, v in sparse_backends.items() if v] 47 | if not available_sparse: 48 | available_sparse = ['spconv'] # fallback 49 | 50 | return { 51 | "required": { 52 | "model_name": (["trellis-normal-v0-1"],), 53 | "dinov2_model": (["dinov2_vitl14_reg"], 54 | {"default": "dinov2_vitl14_reg", 55 | "tooltip": "Select which Dinov2 model to use."}), 56 | "use_fp16": ("BOOLEAN", {"default": True}), 57 | # 58 | # The user picks from the actually installed backends 59 | # 60 | "attn_backend": (available_attn, 61 | {"default": "flash_attn" if "flash_attn" in available_attn else available_attn[0], 62 | "tooltip": "Select attention backend."}), 63 | "sparse_backend": (available_sparse, 64 | {"default": "spconv" if "spconv" in available_sparse else available_sparse[0], 65 | "tooltip": "Select sparse backend."}), 66 | "spconv_algo": (["implicit_gemm", "native", "auto"], 67 | {"default": "implicit_gemm", 68 | "tooltip": "Spconv algorithm. 'implicit_gemm' is slower but more robust."}), 69 | "smooth_k": ("BOOLEAN", 70 | {"default": True, 71 | "tooltip": "Smooth-k for SageAttention. Only relevant if attn_backend=sage."}), 72 | }, 73 | } 74 | 75 | RETURN_TYPES = ("TRELLIS_MODEL",) 76 | RETURN_NAMES = ("model",) 77 | FUNCTION = "load_model" 78 | CATEGORY = "ImpactFrames💥🎞️/Trellis" 79 | 80 | def _setup_environment(self, attn_backend: str, sparse_backend: str, spconv_algo: str, smooth_k: bool): 81 | """ 82 | Set up environment variables and backends lazily. 83 | This is the main difference: we call our new lazy set_*_backend funcs. 84 | """ 85 | # Try attention 86 | success = set_attention_backend(attn_backend) 87 | if not success: 88 | self.logger.warning(f"Failed to set {attn_backend} or not installed, fallback to sdpa.") 89 | 90 | # Try sparse 91 | success2 = set_sparse_backend(sparse_backend, spconv_algo) 92 | if not success2: 93 | self.logger.warning(f"Failed to set {sparse_backend} or not installed, fallback to default.") 94 | 95 | # If user wants SageAttn smooth_k, we set environment var (if they'd want that): 96 | os.environ['SAGEATTN_SMOOTH_K'] = '1' if smooth_k else '0' 97 | 98 | def _initialize_transforms(self): 99 | """Initialize image transforms if needed.""" 100 | return transforms.Compose([ 101 | transforms.Normalize( 102 | mean=[0.485, 0.456, 0.406], 103 | std=[0.229, 0.224, 0.225] 104 | ) 105 | ]) 106 | 107 | def _optimize_pipeline(self, pipeline, use_fp16: bool = True): 108 | """ 109 | Apply typical optimizations, half-precision, etc. 110 | """ 111 | if self.device.type == "cuda": 112 | try: 113 | if hasattr(pipeline, 'cuda'): 114 | pipeline.cuda() 115 | 116 | if use_fp16: 117 | if hasattr(pipeline, 'enable_attention_slicing'): 118 | pipeline.enable_attention_slicing(slice_size="auto") 119 | if hasattr(pipeline, 'half'): 120 | pipeline.half() 121 | except Exception as e: 122 | logger.warning(f"Some pipeline optimizations failed: {str(e)}") 123 | 124 | return pipeline 125 | 126 | def load_model( 127 | self, 128 | model_name: str, 129 | dinov2_model: str = "dinov2_vitl14_reg", 130 | attn_backend: str = "sdpa", 131 | sparse_backend: str = "spconv", 132 | spconv_algo: str = "implicit_gemm", 133 | use_fp16: bool = True, 134 | smooth_k: bool = True, 135 | ) -> tuple: 136 | """ 137 | Load and configure the TRELLIS pipeline. 138 | This is typically the main function invoked by ComfyUI at node execution time. 139 | """ 140 | try: 141 | # 1) Setup environment + backends 142 | self._setup_environment(attn_backend, sparse_backend, spconv_algo, smooth_k) 143 | 144 | # 2) Get model paths, download if needed 145 | model_path = os.path.join(folder_paths.models_dir, "checkpoints", model_name) 146 | if not os.path.exists(model_path) or not os.listdir(model_path): 147 | repo_id = "Stable-X" 148 | try: 149 | huggingface_hub.snapshot_download( 150 | f"{repo_id}/{model_name}", 151 | repo_type="model", 152 | local_dir=model_path 153 | ) 154 | except Exception as e: 155 | raise RuntimeError(f"Failed to download {repo_id}/{model_name} to: {model_path}, {e}") 156 | 157 | # 3) Create pipeline with the config 158 | pipeline = TrellisImageTo3DPipeline.from_pretrained( 159 | model_path, 160 | dinov2_model=dinov2_model 161 | ) 162 | pipeline._device = self.device # ensure pipeline uses our same device 163 | 164 | # 4) Apply optimizations 165 | pipeline = self._optimize_pipeline(pipeline, use_fp16) 166 | 167 | return (pipeline,) 168 | 169 | except Exception as e: 170 | logger.error(f"Error loading TRELLIS model: {str(e)}") 171 | raise 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-Hi3DGen 2 |
3 | 4 | [![Website](https://raw.githubusercontent.com/prs-eth/Marigold/main/doc/badges/badge-website.svg)](https://stable-x.github.io/Hi3DGen/) 5 | [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2503.22236) 6 | [![Online Demo](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Space-yellow)](https://huggingface.co/spaces/Stable-X/Hi3DGen) 7 | [![Hugging Face Model](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-green)](https://huggingface.co/Stable-X/trellis-normal-v0-1) 8 |
9 | 10 | This extension integrates [Hi3DGen](https://github.com/Stable-X/Hi3DGen) into ComfyUI, allowing user to generate high-fidelity 3D geometry generation from Images. 11 | ![comfyui_t2mv](assets/teaser.png) 12 | 13 | ## 🔥 Feature Updates 14 | * [2024-04-02] Support window installation 15 | * [2024-03-31] Support single-view image to 3D goemetry generation 16 | 17 | ## Installation 18 | 19 | ### From Source (Linux) 20 | * Clone or download this repository into your `ComfyUI/custom_nodes/` directory. 21 | * Install the required dependencies by running `pip install -r linux_requirements.txt`. 22 | 23 | ### From Source (Window) 24 | * Clone or download this repository into your `ComfyUI/custom_nodes/` directory. 25 | * Install the required dependencies by running `pip install -r win_requirements.txt`. 26 | 27 | ## Notes 28 | 29 | ### Workflows 30 | 31 | We provide the example workflows in `workflows` directory. 32 | 33 | Note that our code depends on diffusers, and will automatically download the model weights from huggingface to the hf cache path at the first time. The `model_name` in the node corresponds to the model name in huggingface, such as `yoso-normal-v1-8-1`. 34 | 35 | ## Usage 36 | ### Single-view Image to 3D 37 | * `workflow/Hi3DGen_WF_single.json` 38 | 39 | ## Acknowledgement 40 | This repository builds upon the excellent work in [Trellis](https://github.com/microsoft/TRELLIS), [ComfyUI-IF_Trellis](https://github.com/if-ai/ComfyUI-IF_Trellis) and [ComfyUI-StableXWrapper](https://github.com/kijai/ComfyUI-StableXWrapper). Special thanks to their developers for the foundational work that made this project possible. 41 | -------------------------------------------------------------------------------- /StableXWrapper.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/kijai/ComfyUI-StableXWrapper/blob/main/nodes.py 2 | import os 3 | import torch 4 | import json 5 | 6 | from accelerate import init_empty_weights 7 | from accelerate.utils import set_module_tensor_to_device 8 | 9 | from stablex.pipeline_yoso import YosoPipeline 10 | from stablex.controlnetvae import ControlNetVAEModel 11 | 12 | from diffusers.models import ( 13 | AutoencoderKL, 14 | UNet2DConditionModel, 15 | ) 16 | 17 | import folder_paths 18 | import comfy.model_management as mm 19 | from comfy.utils import load_torch_file, ProgressBar 20 | 21 | script_directory = os.path.dirname(os.path.abspath(__file__)) 22 | import logging 23 | log = logging.getLogger(__name__) 24 | 25 | #region Model loading 26 | 27 | class DownloadAndLoadStableXModel: 28 | @classmethod 29 | def INPUT_TYPES(s): 30 | return { 31 | "required": { 32 | "model": (["yoso-normal-v1-8-1"],), 33 | }, 34 | } 35 | 36 | RETURN_TYPES = ("YOSOPIPE",) 37 | RETURN_NAMES = ("pipeline", ) 38 | FUNCTION = "loadmodel" 39 | CATEGORY = "StableXWrapper" 40 | 41 | def loadmodel(self, model): 42 | device = mm.get_torch_device() 43 | offload_device = mm.unet_offload_device() 44 | 45 | download_path = os.path.join(folder_paths.models_dir,"diffusers") 46 | model_path = os.path.join(download_path, model) 47 | 48 | if not os.path.exists(model_path): 49 | log.info(f"Downloading model to: {model_path}") 50 | from huggingface_hub import snapshot_download 51 | snapshot_download( 52 | repo_id=f"Stable-X/{model}", 53 | #allow_patterns=[f"*{model}*"], 54 | ignore_patterns=["*text_encoder*", "tokenizer*", "*scheduler*"], 55 | local_dir=model_path, 56 | local_dir_use_symlinks=False, 57 | ) 58 | 59 | torch_dtype = torch.float16 60 | config_path = os.path.join(model_path, 'unet', 'config.json') 61 | unet_ckpt_path_safetensors = os.path.join(model_path, 'unet','diffusion_pytorch_model.fp16.safetensors') 62 | 63 | if not os.path.exists(config_path): 64 | raise FileNotFoundError(f"Config not found at {config_path}") 65 | 66 | with open(config_path, 'r', encoding='utf-8') as file: 67 | config = json.load(file) 68 | 69 | with init_empty_weights(): 70 | unet = UNet2DConditionModel(**config) 71 | 72 | if os.path.exists(unet_ckpt_path_safetensors): 73 | import safetensors.torch 74 | unet_sd = safetensors.torch.load_file(unet_ckpt_path_safetensors) 75 | else: 76 | raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors}") 77 | 78 | for name, param in unet.named_parameters(): 79 | set_module_tensor_to_device(unet, name, device=offload_device, dtype=torch_dtype, value=unet_sd[name]) 80 | 81 | vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", variant="fp16", device=device, torch_dtype=torch_dtype) 82 | controlnet = ControlNetVAEModel.from_pretrained(model_path, subfolder="controlnet", variant="fp16", device=device, torch_dtype=torch_dtype) 83 | 84 | pipeline = YosoPipeline( 85 | unet=unet, 86 | vae = vae, 87 | controlnet = controlnet, 88 | ) 89 | 90 | #pipeline.enable_model_cpu_offload() 91 | return (pipeline,) 92 | 93 | class StableXProcessImage: 94 | @classmethod 95 | def INPUT_TYPES(s): 96 | return { 97 | "required": { 98 | "pipeline": ("YOSOPIPE",), 99 | "image": ("IMAGE", ), 100 | "processing_resolution": ("INT", {"default": 2048, "min": 64, "max": 4096, "step": 16}), 101 | "controlnet_strength": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 10.0, "step": 0.01, "tooltip": "controlnet condition scale"}), 102 | "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Seed only affects normal prediction mode"}), 103 | }, 104 | } 105 | 106 | RETURN_TYPES = ("IMAGE",) 107 | RETURN_NAMES = ("image",) 108 | FUNCTION = "process" 109 | CATEGORY = "StableXWrapper" 110 | 111 | def process(self, pipeline, image, processing_resolution,controlnet_strength, seed): 112 | 113 | device = mm.get_torch_device() 114 | offload_device = mm.unet_offload_device() 115 | 116 | image = image.permute(0, 3, 1, 2).to(device).to(torch.float16) 117 | 118 | pipeline.unet.to(device) 119 | pipeline.vae.to(device) 120 | pipeline.controlnet.to(device) 121 | 122 | pipe_out = pipeline( 123 | image, 124 | controlnet_conditioning_scale=controlnet_strength, 125 | processing_resolution=processing_resolution, 126 | generator = torch.Generator(device=device).manual_seed(seed), 127 | output_type="pt", 128 | ) 129 | pipeline.unet.to(offload_device) 130 | pipeline.vae.to(offload_device) 131 | pipeline.controlnet.to(offload_device) 132 | pipe_out = (pipe_out.prediction.clip(-1, 1) + 1) / 2 133 | 134 | out_tensor = pipe_out.permute(0, 2, 3, 1).cpu().float() 135 | 136 | return (out_tensor, ) 137 | 138 | class DifferenceExtractorNode: 139 | @classmethod 140 | def INPUT_TYPES(s): 141 | return { 142 | "required": { 143 | "original_image": ("IMAGE",), 144 | "processed_image": ("IMAGE",), 145 | "amplification": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}), 146 | } 147 | } 148 | 149 | RETURN_TYPES = ("IMAGE",) 150 | FUNCTION = "extract_luminosity_difference" 151 | CATEGORY = "iStableXWrapper" 152 | 153 | def extract_luminosity_difference(self, original_image, processed_image, amplification=1.0): 154 | import torch 155 | 156 | # RGB to luminosity conversion weights 157 | rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722]).to(original_image.device) 158 | 159 | # Convert images to luminosity (shape: B,H,W) 160 | original_lum = torch.sum(original_image * rgb_weights[None, None, None, :], dim=3) 161 | processed_lum = torch.sum(processed_image * rgb_weights[None, None, None, :], dim=3) 162 | 163 | # Calculate luminosity difference 164 | difference = (original_lum - processed_lum) * amplification 165 | 166 | # Normalize and clamp 167 | difference = torch.clamp(difference, 0, 1) 168 | 169 | # Convert back to RGB format (all channels identical) 170 | difference = difference.unsqueeze(3).repeat(1, 1, 1, 3) 171 | 172 | return (difference,) 173 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | #__init__.py 2 | import os 3 | import sys 4 | import torch 5 | import logging 6 | import platform 7 | import folder_paths 8 | 9 | # Configure logging 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger('ComfyUI-Hi3DGen') 12 | 13 | # Add parent directory to Python path 14 | current_dir = os.path.dirname(os.path.abspath(__file__)) 15 | parent_dir = os.path.dirname(current_dir) 16 | 17 | # Add both current and parent dir to handle different installation scenarios 18 | if current_dir not in sys.path: 19 | sys.path.insert(0, current_dir) 20 | if parent_dir not in sys.path: 21 | sys.path.insert(0, parent_dir) 22 | 23 | # Add trellis package path 24 | trellis_path = os.path.join(current_dir, "trellis") 25 | if os.path.exists(trellis_path) and trellis_path not in sys.path: 26 | sys.path.insert(0, trellis_path) 27 | logger.info(f"Added trellis path to sys.path: {trellis_path}") 28 | 29 | # Add stablx package path 30 | stablex_path = os.path.join(current_dir, "stablex") 31 | if os.path.exists(trellis_path) and trellis_path not in sys.path: 32 | sys.path.insert(0, trellis_path) 33 | logger.info(f"Added stablex path to sys.path: {trellis_path}") 34 | 35 | # Verify trellis package is importable 36 | try: 37 | import trellis 38 | logger.info("Trellis package imported successfully") 39 | except ImportError as e: 40 | logger.error(f"Failed to import trellis package: {e}") 41 | logger.error(f"Current sys.path: {sys.path}") 42 | raise 43 | 44 | # Verify stablex package is importable 45 | try: 46 | import stablex 47 | logger.info("stablex package imported successfully") 48 | except ImportError as e: 49 | logger.error(f"Failed to import stablex package: {e}") 50 | logger.error(f"Current sys.path: {sys.path}") 51 | raise 52 | 53 | # Register model paths with ComfyUI 54 | try: 55 | folder_paths.add_model_folder_path("trellis", os.path.join(folder_paths.models_dir, "trellis")) 56 | folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.models_dir, "checkpoints")) 57 | except Exception as e: 58 | logger.error(f"Error registering model paths: {e}") 59 | 60 | # Register model paths with ComfyUI 61 | try: 62 | folder_paths.add_model_folder_path("stablex", os.path.join(folder_paths.models_dir, "stablex")) 63 | except Exception as e: 64 | logger.error(f"Error registering model paths: {e}") 65 | 66 | try: 67 | from IF_TrellisCheckpointLoader import IF_TrellisCheckpointLoader 68 | from IF_Trellis import IF_TrellisImageTo3D 69 | from StableXWrapper import DownloadAndLoadStableXModel, StableXProcessImage, DifferenceExtractorNode 70 | NODE_CLASS_MAPPINGS = { 71 | "IF_TrellisCheckpointLoader": IF_TrellisCheckpointLoader, 72 | "IF_TrellisImageTo3D": IF_TrellisImageTo3D, 73 | "DownloadAndLoadStableXModel": DownloadAndLoadStableXModel, 74 | "StableXProcessImage": StableXProcessImage, 75 | "DifferenceExtractorNode": DifferenceExtractorNode 76 | } 77 | 78 | NODE_DISPLAY_NAME_MAPPINGS = { 79 | "IF_TrellisCheckpointLoader": "Trellis Model Loader 💾", 80 | "IF_TrellisImageTo3D": "Trellis Image to 3D 🖼️➡️🎲", 81 | "DownloadAndLoadStableXModel": "(Down)load StableX Model", 82 | "StableXProcessImage": "StableX Process Image", 83 | "DifferenceExtractorNode": "Extract Difference" 84 | } 85 | 86 | except Exception as e: 87 | logger.error(f"Error importing node classes: {e}") 88 | raise 89 | 90 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-X/ComfyUI-Hi3DGen/99621aa48dd05203bee91cf3d5813669e0b5721d/assets/teaser.png -------------------------------------------------------------------------------- /linux_requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | torchvision==0.19.0 3 | pillow==10.4.0 4 | imageio==2.36.1 5 | imageio-ffmpeg==0.5.1 6 | tqdm==4.67.1 7 | easydict==1.13 8 | opencv-python-headless==4.10.0.84 9 | scipy==1.14.1 10 | rembg==2.0.60 11 | onnxruntime==1.20.1 12 | trimesh==4.5.3 13 | git+https://github.com/EasternJournalist/utils3d.git 14 | xformers==0.0.27.post2 15 | spconv-cu120==2.3.6 16 | transformers==4.46.3 17 | accelerate==1.5.2 18 | diffusers==0.32.2 19 | 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-hi3dgen" 3 | description = "ComfyUI Hi3DGen creates a high-fiedlity 3d mesh from a single view or multi angle pictures" 4 | version = "0.2.5" 5 | license = { file = "MIT License" } 6 | dependencies = [ 7 | "pillow", 8 | "imageio", 9 | "imageio-ffmpeg", 10 | "tqdm", 11 | "easydict", 12 | "opencv-python-headless", 13 | "scipy", 14 | "ninja", 15 | "plyfile", 16 | "rembg", 17 | "onnxruntime-gpu", 18 | "trimesh", 19 | "git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8", 20 | "xformers", 21 | "spconv-cu124", 22 | "diffrp-nvdiffrast", 23 | 24 | # triton for linux 25 | 'triton; sys_platform == "linux"', 26 | # triton for windows 27 | 'https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp312-cp312-win_amd64.whl; sys_platform == "win64" and (python_version >= "3.12" and python_version < "3.13")', 28 | 'https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp311-cp311-win_amd64.whl; sys_platform == "win64" and (python_version >= "3.11" and python_version < "3.12")', 29 | 'https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp310-cp310-win_amd64.whl; sys_platform == "win64" and (python_version >= "3.10" and python_version < "3.11")', 30 | 'https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp38-cp38-win_amd64.whl; sys_platform == "win64" and (python_version >= "3.8" and python_version < "3.9")' 31 | ] 32 | 33 | [project.urls] 34 | Repository = "https://github.com/Stable-X/ComfyUI-Hi3DGen" 35 | # Used by Comfy Registry https://comfyregistry.org 36 | 37 | [tool.comfy] 38 | PublisherId = "impactframes" 39 | DisplayName = "IF_Hi3DGen" 40 | Icon = "https://impactframes.ai/System/Icons/48x48/if.png" 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | torchvision==0.19.0 3 | diffusers==0.32.2 4 | accelerate==1.5.2 5 | pillow 6 | imageio 7 | imageio-ffmpeg 8 | tqdm 9 | easydict 10 | opencv-python-headless 11 | scipy 12 | rembg 13 | ninja 14 | plyfile 15 | trimesh 16 | git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 17 | xformers 18 | triton 19 | spconv-cu124>=2.3.6 20 | -------------------------------------------------------------------------------- /stablex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-X/ComfyUI-Hi3DGen/99621aa48dd05203bee91cf3d5813669e0b5721d/stablex/__init__.py -------------------------------------------------------------------------------- /trellis/__init__.py: -------------------------------------------------------------------------------- 1 | from . import models 2 | from . import modules 3 | from . import pipelines 4 | from . import representations 5 | from . import utils 6 | -------------------------------------------------------------------------------- /trellis/backend_config.py: -------------------------------------------------------------------------------- 1 | # trellis/backend_config.py 2 | from typing import * 3 | import os 4 | import logging 5 | import importlib 6 | 7 | # Global variables 8 | BACKEND = 'spconv' # Default sparse backend 9 | DEBUG = False # Debug mode flag 10 | ATTN = 'xformers' # Default attention backend 11 | SPCONV_ALGO = 'implicit_gemm' # Default algorithm 12 | 13 | def get_spconv_algo() -> str: 14 | """Get current spconv algorithm.""" 15 | global SPCONV_ALGO 16 | return SPCONV_ALGO 17 | 18 | def set_spconv_algo(algo: Literal['implicit_gemm', 'native', 'auto']) -> bool: 19 | """Set spconv algorithm with validation.""" 20 | global SPCONV_ALGO 21 | 22 | if algo not in ['implicit_gemm', 'native', 'auto']: 23 | logger.warning(f"Invalid spconv algorithm: {algo}. Must be 'implicit_gemm', 'native', or 'auto'") 24 | return False 25 | 26 | SPCONV_ALGO = algo 27 | os.environ['SPCONV_ALGO'] = algo 28 | logger.info(f"Set spconv algorithm to: {algo}") 29 | return True 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | def _try_import_xformers() -> bool: 34 | try: 35 | import xformers.ops 36 | return True 37 | except ImportError: 38 | return False 39 | 40 | def _try_import_flash_attn() -> bool: 41 | try: 42 | import flash_attn 43 | return True 44 | except ImportError: 45 | return False 46 | 47 | def _try_import_sageattention() -> bool: 48 | try: 49 | import torch.nn.functional as F 50 | from sageattention import sageattn 51 | F.scaled_dot_product_attention = sageattn 52 | #import sageattention 53 | return True 54 | except ImportError: 55 | return False 56 | 57 | def _try_import_spconv() -> bool: 58 | try: 59 | import spconv 60 | return True 61 | except ImportError: 62 | return False 63 | 64 | def _try_import_torchsparse() -> bool: 65 | try: 66 | import torchsparse 67 | return True 68 | except ImportError: 69 | return False 70 | 71 | def get_available_backends() -> Dict[str, bool]: 72 | """Return dict of available attention backends and their status""" 73 | return { 74 | 'xformers': _try_import_xformers(), 75 | 'flash_attn': _try_import_flash_attn(), 76 | 'sage': _try_import_sageattention(), 77 | 'naive': True, 78 | 'sdpa': True # Always available with PyTorch >= 2.0 79 | } 80 | 81 | def get_available_sparse_backends() -> Dict[str, bool]: 82 | """Return dict of available sparse backends and their status""" 83 | return { 84 | 'spconv': _try_import_spconv(), 85 | 'torchsparse': _try_import_torchsparse() 86 | } 87 | 88 | def get_attention_backend() -> str: 89 | """Get current attention backend""" 90 | global ATTN 91 | return ATTN 92 | 93 | def get_sparse_backend() -> str: 94 | """Get current sparse backend""" 95 | global BACKEND 96 | return BACKEND 97 | 98 | def get_debug_mode() -> bool: 99 | """Get current debug mode status""" 100 | global DEBUG 101 | return DEBUG 102 | 103 | def __from_env(): 104 | """Initialize settings from environment variables""" 105 | global BACKEND 106 | global DEBUG 107 | global ATTN 108 | 109 | env_sparse_backend = os.environ.get('SPARSE_BACKEND') 110 | env_sparse_debug = os.environ.get('SPARSE_DEBUG') 111 | env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') 112 | 113 | if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: 114 | BACKEND = env_sparse_backend 115 | if env_sparse_debug is not None: 116 | DEBUG = env_sparse_debug == '1' 117 | if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sage', 'sdpa', 'naive']: 118 | ATTN = env_sparse_attn 119 | os.environ['SPARSE_ATTN_BACKEND'] = env_sparse_attn 120 | os.environ['ATTN_BACKEND'] = env_sparse_attn 121 | 122 | logger.info(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") 123 | 124 | def set_backend(backend: Literal['spconv', 'torchsparse']) -> bool: 125 | """Set sparse backend with validation""" 126 | global BACKEND 127 | 128 | backend = backend.lower().strip() 129 | logger.info(f"Setting sparse backend to: {backend}") 130 | 131 | if backend == 'spconv': 132 | try: 133 | import spconv 134 | BACKEND = 'spconv' 135 | os.environ['SPARSE_BACKEND'] = 'spconv' 136 | return True 137 | except ImportError: 138 | logger.warning("spconv not available") 139 | return False 140 | 141 | elif backend == 'torchsparse': 142 | try: 143 | import torchsparse 144 | BACKEND = 'torchsparse' 145 | os.environ['SPARSE_BACKEND'] = 'torchsparse' 146 | return True 147 | except ImportError: 148 | logger.warning("torchsparse not available") 149 | return False 150 | 151 | return False 152 | 153 | def set_sparse_backend(backend: Literal['spconv', 'torchsparse'], algo: str = None) -> bool: 154 | """Alias for set_backend for backwards compatibility 155 | 156 | Parameters: 157 | backend: The sparse backend to use 158 | algo: The algorithm to use (only relevant for spconv backend) 159 | """ 160 | # Call set_backend first 161 | result = set_backend(backend) 162 | 163 | # If algorithm is provided and backend was set successfully 164 | if algo is not None and result: 165 | set_spconv_algo(algo) 166 | 167 | return result 168 | 169 | def set_debug(debug: bool): 170 | """Set debug mode""" 171 | global DEBUG 172 | DEBUG = debug 173 | if debug: 174 | os.environ['SPARSE_DEBUG'] = '1' 175 | else: 176 | os.environ['SPARSE_DEBUG'] = '0' 177 | 178 | def set_attn(attn: Literal['xformers', 'flash_attn', 'sage', 'sdpa', 'naive']) -> bool: 179 | """Set attention backend with validation""" 180 | global ATTN 181 | 182 | attn = attn.lower().strip() 183 | logger.info(f"Setting attention backend to: {attn}") 184 | 185 | if attn == 'xformers' and _try_import_xformers(): 186 | ATTN = 'xformers' 187 | os.environ['SPARSE_ATTN_BACKEND'] = 'xformers' 188 | os.environ['ATTN_BACKEND'] = 'xformers' 189 | return True 190 | 191 | elif attn == 'flash_attn' and _try_import_flash_attn(): 192 | ATTN = 'flash_attn' 193 | os.environ['SPARSE_ATTN_BACKEND'] = 'flash_attn' 194 | os.environ['ATTN_BACKEND'] = 'flash_attn' 195 | return True 196 | 197 | elif attn == 'sage' and _try_import_sageattention(): 198 | ATTN = 'sage' 199 | os.environ['SPARSE_ATTN_BACKEND'] = 'sage' 200 | os.environ['ATTN_BACKEND'] = 'sage' 201 | return True 202 | 203 | elif attn == 'sdpa': 204 | ATTN = 'sdpa' 205 | os.environ['SPARSE_ATTN_BACKEND'] = 'sdpa' 206 | os.environ['ATTN_BACKEND'] = 'sdpa' 207 | return True 208 | 209 | elif attn == 'naive': 210 | ATTN = 'naive' 211 | os.environ['SPARSE_ATTN_BACKEND'] = 'naive' 212 | os.environ['ATTN_BACKEND'] = 'naive' 213 | return True 214 | 215 | 216 | logger.warning(f"Attention backend {attn} not available") 217 | return False 218 | 219 | # Add alias for backwards compatibility 220 | def set_attention_backend(backend: Literal['xformers', 'flash_attn', 'sage', 'sdpa']) -> bool: 221 | """Alias for set_attn for backwards compatibility""" 222 | return set_attn(backend) 223 | 224 | # Initialize from environment variables on module import 225 | __from_env() 226 | -------------------------------------------------------------------------------- /trellis/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import torch.nn as nn 4 | 5 | __attributes = { 6 | 'SparseStructureEncoder': 'sparse_structure_vae', 7 | 'SparseStructureDecoder': 'sparse_structure_vae', 8 | 'SparseStructureFlowModel': 'sparse_structure_flow', 9 | 'SLatEncoder': 'structured_latent_vae', 10 | 'SLatGaussianDecoder': 'structured_latent_vae', 11 | 'SLatRadianceFieldDecoder': 'structured_latent_vae', 12 | 'SLatMeshDecoder': 'structured_latent_vae', 13 | 'SLatFlowModel': 'structured_latent_flow', 14 | } 15 | 16 | __submodules = [] 17 | 18 | __all__ = list(__attributes.keys()) + __submodules 19 | 20 | def __getattr__(name): 21 | if name not in globals(): 22 | if name in __attributes: 23 | module_name = __attributes[name] 24 | module = importlib.import_module(f".{module_name}", __name__) 25 | globals()[name] = getattr(module, name) 26 | elif name in __submodules: 27 | module = importlib.import_module(f".{name}", __name__) 28 | globals()[name] = module 29 | else: 30 | raise AttributeError(f"module {__name__} has no attribute {name}") 31 | return globals()[name] 32 | 33 | 34 | def from_pretrained(path: str) -> nn.Module: 35 | """ 36 | Load a pretrained model. 37 | 38 | Args: 39 | path (str): Full path to model file or HuggingFace repo ID with model name 40 | """ 41 | import os 42 | import json 43 | from safetensors.torch import load_file 44 | 45 | # Split path into directory and model name 46 | path = os.path.normpath(path) 47 | model_dir = os.path.dirname(os.path.dirname(path)) # Go up two levels (past ckpts/) 48 | model_name = os.path.basename(path) 49 | 50 | is_local = os.path.exists(model_dir) 51 | 52 | if is_local: 53 | # For local paths 54 | print(f"Loading local model: {model_name}") 55 | model_name = model_name.replace('ckpts/', '').replace('ckpts\\', '') 56 | config_path = os.path.normpath(os.path.join(model_dir, "ckpts", f"{model_name}.json")) 57 | weights_path = os.path.normpath(os.path.join(model_dir, "ckpts", f"{model_name}.safetensors")) 58 | 59 | if not os.path.exists(config_path): 60 | raise FileNotFoundError(f"Config file not found: {config_path}") 61 | if not os.path.exists(weights_path): 62 | raise FileNotFoundError(f"Weights file not found: {weights_path}") 63 | 64 | # Load config 65 | with open(config_path, 'r') as f: 66 | config = json.load(f) 67 | 68 | # Create model 69 | model = create_model_from_config(config) 70 | 71 | # Load weights 72 | state_dict = load_file(weights_path) 73 | model.load_state_dict(state_dict) 74 | 75 | else: 76 | # For HuggingFace paths 77 | from huggingface_hub import hf_hub_download 78 | 79 | config_file = hf_hub_download(path, f"{model_name}.json") 80 | with open(config_file, 'r') as f: 81 | config = json.load(f) 82 | 83 | model = create_model_from_config(config) 84 | 85 | weights_file = hf_hub_download(path, f"{model_name}.safetensors") 86 | state_dict = load_file(weights_file) 87 | model.load_state_dict(state_dict) 88 | 89 | return model 90 | 91 | def create_model_from_config(config): 92 | """Helper function to create model from config""" 93 | #print(f"Creating model from config: {config}") 94 | model_type = config.get('type') or config.get('name') 95 | #print(f"Model type: {model_type}") 96 | #print(f"Available model types: {list(__attributes.keys())}") 97 | if not model_type in __attributes: 98 | raise ValueError(f"Unknown model type: {model_type}") 99 | 100 | model_class = __getattr__(model_type) 101 | #print(f"Model class: {model_class}") 102 | args = config.get('args', {}) 103 | #print(f"Model args: {args}") 104 | return model_class(**args) 105 | 106 | 107 | # For Pylance 108 | if __name__ == '__main__': 109 | from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder 110 | from .sparse_structure_flow import SparseStructureFlowModel 111 | from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder 112 | from .structured_latent_flow import SLatFlowModel 113 | -------------------------------------------------------------------------------- /trellis/models/sparse_structure_flow.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from ..modules.utils import convert_module_to_f16, convert_module_to_f32 7 | from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock 8 | from ..modules.spatial import patchify, unpatchify 9 | 10 | 11 | class TimestepEmbedder(nn.Module): 12 | """ 13 | Embeds scalar timesteps into vector representations. 14 | """ 15 | def __init__(self, hidden_size, frequency_embedding_size=256): 16 | super().__init__() 17 | self.mlp = nn.Sequential( 18 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 19 | nn.SiLU(), 20 | nn.Linear(hidden_size, hidden_size, bias=True), 21 | ) 22 | self.frequency_embedding_size = frequency_embedding_size 23 | 24 | @staticmethod 25 | def timestep_embedding(t, dim, max_period=10000): 26 | """ 27 | Create sinusoidal timestep embeddings. 28 | 29 | Args: 30 | t: a 1-D Tensor of N indices, one per batch element. 31 | These may be fractional. 32 | dim: the dimension of the output. 33 | max_period: controls the minimum frequency of the embeddings. 34 | 35 | Returns: 36 | an (N, D) Tensor of positional embeddings. 37 | """ 38 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 39 | half = dim // 2 40 | freqs = torch.exp( 41 | -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 42 | ).to(device=t.device) 43 | args = t[:, None].float() * freqs[None] 44 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 45 | if dim % 2: 46 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 47 | return embedding 48 | 49 | def forward(self, t): 50 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 51 | t_emb = self.mlp(t_freq) 52 | return t_emb 53 | 54 | 55 | class SparseStructureFlowModel(nn.Module): 56 | def __init__( 57 | self, 58 | resolution: int, 59 | in_channels: int, 60 | model_channels: int, 61 | cond_channels: int, 62 | out_channels: int, 63 | num_blocks: int, 64 | num_heads: Optional[int] = None, 65 | num_head_channels: Optional[int] = 64, 66 | mlp_ratio: float = 4, 67 | patch_size: int = 2, 68 | pe_mode: Literal["ape", "rope"] = "ape", 69 | use_fp16: bool = False, 70 | use_checkpoint: bool = False, 71 | share_mod: bool = False, 72 | qk_rms_norm: bool = False, 73 | qk_rms_norm_cross: bool = False, 74 | ): 75 | super().__init__() 76 | self.resolution = resolution 77 | self.in_channels = in_channels 78 | self.model_channels = model_channels 79 | self.cond_channels = cond_channels 80 | self.out_channels = out_channels 81 | self.num_blocks = num_blocks 82 | self.num_heads = num_heads or model_channels // num_head_channels 83 | self.mlp_ratio = mlp_ratio 84 | self.patch_size = patch_size 85 | self.pe_mode = pe_mode 86 | self.use_fp16 = use_fp16 87 | self.use_checkpoint = use_checkpoint 88 | self.share_mod = share_mod 89 | self.qk_rms_norm = qk_rms_norm 90 | self.qk_rms_norm_cross = qk_rms_norm_cross 91 | self.dtype = torch.float16 if use_fp16 else torch.float32 92 | 93 | self.t_embedder = TimestepEmbedder(model_channels) 94 | if share_mod: 95 | self.adaLN_modulation = nn.Sequential( 96 | nn.SiLU(), 97 | nn.Linear(model_channels, 6 * model_channels, bias=True) 98 | ) 99 | 100 | if pe_mode == "ape": 101 | pos_embedder = AbsolutePositionEmbedder(model_channels, 3) 102 | coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') 103 | coords = torch.stack(coords, dim=-1).reshape(-1, 3) 104 | pos_emb = pos_embedder(coords) 105 | self.register_buffer("pos_emb", pos_emb) 106 | 107 | self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) 108 | 109 | self.blocks = nn.ModuleList([ 110 | ModulatedTransformerCrossBlock( 111 | model_channels, 112 | cond_channels, 113 | num_heads=self.num_heads, 114 | mlp_ratio=self.mlp_ratio, 115 | attn_mode='full', 116 | use_checkpoint=self.use_checkpoint, 117 | use_rope=(pe_mode == "rope"), 118 | share_mod=share_mod, 119 | qk_rms_norm=self.qk_rms_norm, 120 | qk_rms_norm_cross=self.qk_rms_norm_cross, 121 | ) 122 | for _ in range(num_blocks) 123 | ]) 124 | 125 | self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) 126 | 127 | self.initialize_weights() 128 | if use_fp16: 129 | self.convert_to_fp16() 130 | 131 | @property 132 | def device(self) -> torch.device: 133 | """ 134 | Return the device of the model. 135 | """ 136 | return next(self.parameters()).device 137 | 138 | def convert_to_fp16(self) -> None: 139 | """ 140 | Convert the torso of the model to float16. 141 | """ 142 | self.blocks.apply(convert_module_to_f16) 143 | 144 | def convert_to_fp32(self) -> None: 145 | """ 146 | Convert the torso of the model to float32. 147 | """ 148 | self.blocks.apply(convert_module_to_f32) 149 | 150 | def initialize_weights(self) -> None: 151 | # Initialize transformer layers: 152 | def _basic_init(module): 153 | if isinstance(module, nn.Linear): 154 | torch.nn.init.xavier_uniform_(module.weight) 155 | if module.bias is not None: 156 | nn.init.constant_(module.bias, 0) 157 | self.apply(_basic_init) 158 | 159 | # Initialize timestep embedding MLP: 160 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 161 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 162 | 163 | # Zero-out adaLN modulation layers in DiT blocks: 164 | if self.share_mod: 165 | nn.init.constant_(self.adaLN_modulation[-1].weight, 0) 166 | nn.init.constant_(self.adaLN_modulation[-1].bias, 0) 167 | else: 168 | for block in self.blocks: 169 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 170 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 171 | 172 | # Zero-out output layers: 173 | nn.init.constant_(self.out_layer.weight, 0) 174 | nn.init.constant_(self.out_layer.bias, 0) 175 | 176 | def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: 177 | assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ 178 | f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" 179 | 180 | h = patchify(x, self.patch_size) 181 | h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() 182 | 183 | h = self.input_layer(h) 184 | h = h + self.pos_emb[None] 185 | t_emb = self.t_embedder(t) 186 | if self.share_mod: 187 | t_emb = self.adaLN_modulation(t_emb) 188 | t_emb = t_emb.type(self.dtype) 189 | h = h.type(self.dtype) 190 | cond = cond.type(self.dtype) 191 | for block in self.blocks: 192 | h = block(h, t_emb, cond) 193 | h = h.type(x.dtype) 194 | h = F.layer_norm(h, h.shape[-1:]) 195 | h = self.out_layer(h) 196 | 197 | h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) 198 | h = unpatchify(h, self.patch_size).contiguous() 199 | 200 | return h 201 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_flow.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 7 | from ..modules.transformer import AbsolutePositionEmbedder 8 | from ..modules.norm import LayerNorm32 9 | from ..modules import sparse as sp 10 | from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock 11 | from .sparse_structure_flow import TimestepEmbedder 12 | 13 | 14 | class SparseResBlock3d(nn.Module): 15 | def __init__( 16 | self, 17 | channels: int, 18 | emb_channels: int, 19 | out_channels: Optional[int] = None, 20 | downsample: bool = False, 21 | upsample: bool = False, 22 | ): 23 | super().__init__() 24 | self.channels = channels 25 | self.emb_channels = emb_channels 26 | self.out_channels = out_channels or channels 27 | self.downsample = downsample 28 | self.upsample = upsample 29 | 30 | assert not (downsample and upsample), "Cannot downsample and upsample at the same time" 31 | 32 | self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) 33 | self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) 34 | self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) 35 | self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) 36 | self.emb_layers = nn.Sequential( 37 | nn.SiLU(), 38 | nn.Linear(emb_channels, 2 * self.out_channels, bias=True), 39 | ) 40 | self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() 41 | self.updown = None 42 | if self.downsample: 43 | self.updown = sp.SparseDownsample(2) 44 | elif self.upsample: 45 | self.updown = sp.SparseUpsample(2) 46 | 47 | def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: 48 | if self.updown is not None: 49 | x = self.updown(x) 50 | return x 51 | 52 | def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: 53 | emb_out = self.emb_layers(emb).type(x.dtype) 54 | scale, shift = torch.chunk(emb_out, 2, dim=1) 55 | 56 | x = self._updown(x) 57 | h = x.replace(self.norm1(x.feats)) 58 | h = h.replace(F.silu(h.feats)) 59 | h = self.conv1(h) 60 | h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift 61 | h = h.replace(F.silu(h.feats)) 62 | h = self.conv2(h) 63 | h = h + self.skip_connection(x) 64 | 65 | return h 66 | 67 | 68 | class SLatFlowModel(nn.Module): 69 | def __init__( 70 | self, 71 | resolution: int, 72 | in_channels: int, 73 | model_channels: int, 74 | cond_channels: int, 75 | out_channels: int, 76 | num_blocks: int, 77 | num_heads: Optional[int] = None, 78 | num_head_channels: Optional[int] = 64, 79 | mlp_ratio: float = 4, 80 | patch_size: int = 2, 81 | num_io_res_blocks: int = 2, 82 | io_block_channels: List[int] = None, 83 | pe_mode: Literal["ape", "rope"] = "ape", 84 | use_fp16: bool = False, 85 | use_checkpoint: bool = False, 86 | use_skip_connection: bool = True, 87 | share_mod: bool = False, 88 | qk_rms_norm: bool = False, 89 | qk_rms_norm_cross: bool = False, 90 | ): 91 | super().__init__() 92 | self.resolution = resolution 93 | self.in_channels = in_channels 94 | self.model_channels = model_channels 95 | self.cond_channels = cond_channels 96 | self.out_channels = out_channels 97 | self.num_blocks = num_blocks 98 | self.num_heads = num_heads or model_channels // num_head_channels 99 | self.mlp_ratio = mlp_ratio 100 | self.patch_size = patch_size 101 | self.num_io_res_blocks = num_io_res_blocks 102 | self.io_block_channels = io_block_channels 103 | self.pe_mode = pe_mode 104 | self.use_fp16 = use_fp16 105 | self.use_checkpoint = use_checkpoint 106 | self.use_skip_connection = use_skip_connection 107 | self.share_mod = share_mod 108 | self.qk_rms_norm = qk_rms_norm 109 | self.qk_rms_norm_cross = qk_rms_norm_cross 110 | self.dtype = torch.float16 if use_fp16 else torch.float32 111 | 112 | assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" 113 | assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" 114 | 115 | self.t_embedder = TimestepEmbedder(model_channels) 116 | if share_mod: 117 | self.adaLN_modulation = nn.Sequential( 118 | nn.SiLU(), 119 | nn.Linear(model_channels, 6 * model_channels, bias=True) 120 | ) 121 | 122 | if pe_mode == "ape": 123 | self.pos_embedder = AbsolutePositionEmbedder(model_channels) 124 | 125 | self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) 126 | self.input_blocks = nn.ModuleList([]) 127 | for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): 128 | self.input_blocks.extend([ 129 | SparseResBlock3d( 130 | chs, 131 | model_channels, 132 | out_channels=chs, 133 | ) 134 | for _ in range(num_io_res_blocks-1) 135 | ]) 136 | self.input_blocks.append( 137 | SparseResBlock3d( 138 | chs, 139 | model_channels, 140 | out_channels=next_chs, 141 | downsample=True, 142 | ) 143 | ) 144 | 145 | self.blocks = nn.ModuleList([ 146 | ModulatedSparseTransformerCrossBlock( 147 | model_channels, 148 | cond_channels, 149 | num_heads=self.num_heads, 150 | mlp_ratio=self.mlp_ratio, 151 | attn_mode='full', 152 | use_checkpoint=self.use_checkpoint, 153 | use_rope=(pe_mode == "rope"), 154 | share_mod=self.share_mod, 155 | qk_rms_norm=self.qk_rms_norm, 156 | qk_rms_norm_cross=self.qk_rms_norm_cross, 157 | ) 158 | for _ in range(num_blocks) 159 | ]) 160 | 161 | self.out_blocks = nn.ModuleList([]) 162 | for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): 163 | self.out_blocks.append( 164 | SparseResBlock3d( 165 | prev_chs * 2 if self.use_skip_connection else prev_chs, 166 | model_channels, 167 | out_channels=chs, 168 | upsample=True, 169 | ) 170 | ) 171 | self.out_blocks.extend([ 172 | SparseResBlock3d( 173 | chs * 2 if self.use_skip_connection else chs, 174 | model_channels, 175 | out_channels=chs, 176 | ) 177 | for _ in range(num_io_res_blocks-1) 178 | ]) 179 | self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) 180 | 181 | self.initialize_weights() 182 | if use_fp16: 183 | self.convert_to_fp16() 184 | 185 | @property 186 | def device(self) -> torch.device: 187 | """ 188 | Return the device of the model. 189 | """ 190 | return next(self.parameters()).device 191 | 192 | def convert_to_fp16(self) -> None: 193 | """ 194 | Convert the torso of the model to float16. 195 | """ 196 | self.input_blocks.apply(convert_module_to_f16) 197 | self.blocks.apply(convert_module_to_f16) 198 | self.out_blocks.apply(convert_module_to_f16) 199 | 200 | def convert_to_fp32(self) -> None: 201 | """ 202 | Convert the torso of the model to float32. 203 | """ 204 | self.input_blocks.apply(convert_module_to_f32) 205 | self.blocks.apply(convert_module_to_f32) 206 | self.out_blocks.apply(convert_module_to_f32) 207 | 208 | def initialize_weights(self) -> None: 209 | # Initialize transformer layers: 210 | def _basic_init(module): 211 | if isinstance(module, nn.Linear): 212 | torch.nn.init.xavier_uniform_(module.weight) 213 | if module.bias is not None: 214 | nn.init.constant_(module.bias, 0) 215 | self.apply(_basic_init) 216 | 217 | # Initialize timestep embedding MLP: 218 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 219 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 220 | 221 | # Zero-out adaLN modulation layers in DiT blocks: 222 | if self.share_mod: 223 | nn.init.constant_(self.adaLN_modulation[-1].weight, 0) 224 | nn.init.constant_(self.adaLN_modulation[-1].bias, 0) 225 | else: 226 | for block in self.blocks: 227 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 228 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 229 | 230 | # Zero-out output layers: 231 | nn.init.constant_(self.out_layer.weight, 0) 232 | nn.init.constant_(self.out_layer.bias, 0) 233 | 234 | def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: 235 | h = self.input_layer(x).type(self.dtype) 236 | t_emb = self.t_embedder(t) 237 | if self.share_mod: 238 | t_emb = self.adaLN_modulation(t_emb) 239 | t_emb = t_emb.type(self.dtype) 240 | cond = cond.type(self.dtype) 241 | 242 | skips = [] 243 | # pack with input blocks 244 | for block in self.input_blocks: 245 | h = block(h, t_emb) 246 | skips.append(h.feats) 247 | 248 | if self.pe_mode == "ape": 249 | h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) 250 | for block in self.blocks: 251 | h = block(h, t_emb, cond) 252 | 253 | # unpack with output blocks 254 | for block, skip in zip(self.out_blocks, reversed(skips)): 255 | if self.use_skip_connection: 256 | h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) 257 | else: 258 | h = block(h, t_emb) 259 | 260 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) 261 | h = self.out_layer(h.type(x.dtype)) 262 | return h 263 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import SLatEncoder 2 | from .decoder_mesh import SLatMeshDecoder 3 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/base.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from ...modules.utils import convert_module_to_f16, convert_module_to_f32 5 | from ...modules import sparse as sp 6 | from ...modules.transformer import AbsolutePositionEmbedder 7 | from ...modules.sparse.transformer import SparseTransformerBlock 8 | 9 | 10 | def block_attn_config(self): 11 | """ 12 | Return the attention configuration of the model. 13 | """ 14 | for i in range(self.num_blocks): 15 | if self.attn_mode == "shift_window": 16 | yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER 17 | elif self.attn_mode == "shift_sequence": 18 | yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER 19 | elif self.attn_mode == "shift_order": 20 | yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] 21 | elif self.attn_mode == "full": 22 | yield "full", None, None, None, None 23 | elif self.attn_mode == "swin": 24 | yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None 25 | 26 | 27 | class SparseTransformerBase(nn.Module): 28 | """ 29 | Sparse Transformer without output layers. 30 | Serve as the base class for encoder and decoder. 31 | """ 32 | def __init__( 33 | self, 34 | in_channels: int, 35 | model_channels: int, 36 | num_blocks: int, 37 | num_heads: Optional[int] = None, 38 | num_head_channels: Optional[int] = 64, 39 | mlp_ratio: float = 4.0, 40 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 41 | window_size: Optional[int] = None, 42 | pe_mode: Literal["ape", "rope"] = "ape", 43 | use_fp16: bool = False, 44 | use_checkpoint: bool = False, 45 | qk_rms_norm: bool = False, 46 | ): 47 | super().__init__() 48 | self.in_channels = in_channels 49 | self.model_channels = model_channels 50 | self.num_blocks = num_blocks 51 | self.window_size = window_size 52 | self.num_heads = num_heads or model_channels // num_head_channels 53 | self.mlp_ratio = mlp_ratio 54 | self.attn_mode = attn_mode 55 | self.pe_mode = pe_mode 56 | self.use_fp16 = use_fp16 57 | self.use_checkpoint = use_checkpoint 58 | self.qk_rms_norm = qk_rms_norm 59 | self.dtype = torch.float16 if use_fp16 else torch.float32 60 | 61 | if pe_mode == "ape": 62 | self.pos_embedder = AbsolutePositionEmbedder(model_channels) 63 | 64 | self.input_layer = sp.SparseLinear(in_channels, model_channels) 65 | self.blocks = nn.ModuleList([ 66 | SparseTransformerBlock( 67 | model_channels, 68 | num_heads=self.num_heads, 69 | mlp_ratio=self.mlp_ratio, 70 | attn_mode=attn_mode, 71 | window_size=window_size, 72 | shift_sequence=shift_sequence, 73 | shift_window=shift_window, 74 | serialize_mode=serialize_mode, 75 | use_checkpoint=self.use_checkpoint, 76 | use_rope=(pe_mode == "rope"), 77 | qk_rms_norm=self.qk_rms_norm, 78 | ) 79 | for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) 80 | ]) 81 | 82 | @property 83 | def device(self) -> torch.device: 84 | """ 85 | Return the device of the model. 86 | """ 87 | return next(self.parameters()).device 88 | 89 | def convert_to_fp16(self) -> None: 90 | """ 91 | Convert the torso of the model to float16. 92 | """ 93 | self.blocks.apply(convert_module_to_f16) 94 | 95 | def convert_to_fp32(self) -> None: 96 | """ 97 | Convert the torso of the model to float32. 98 | """ 99 | self.blocks.apply(convert_module_to_f32) 100 | 101 | def initialize_weights(self) -> None: 102 | # Initialize transformer layers: 103 | def _basic_init(module): 104 | if isinstance(module, nn.Linear): 105 | torch.nn.init.xavier_uniform_(module.weight) 106 | if module.bias is not None: 107 | nn.init.constant_(module.bias, 0) 108 | self.apply(_basic_init) 109 | 110 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: 111 | h = self.input_layer(x) 112 | if self.pe_mode == "ape": 113 | h = h + self.pos_embedder(x.coords[:, 1:]) 114 | h = h.type(self.dtype) 115 | for block in self.blocks: 116 | h = block(h) 117 | return h 118 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/decoder_mesh.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 7 | from ...modules import sparse as sp 8 | from .base import SparseTransformerBase 9 | from ...representations import MeshExtractResult 10 | from ...representations.mesh import SparseFeatures2Mesh 11 | 12 | 13 | class SparseSubdivideBlock3d(nn.Module): 14 | """ 15 | A 3D subdivide block that can subdivide the sparse tensor. 16 | 17 | Args: 18 | channels: channels in the inputs and outputs. 19 | out_channels: if specified, the number of output channels. 20 | num_groups: the number of groups for the group norm. 21 | """ 22 | def __init__( 23 | self, 24 | channels: int, 25 | resolution: int, 26 | out_channels: Optional[int] = None, 27 | num_groups: int = 32 28 | ): 29 | super().__init__() 30 | self.channels = channels 31 | self.resolution = resolution 32 | self.out_resolution = resolution * 2 33 | self.out_channels = out_channels or channels 34 | 35 | self.act_layers = nn.Sequential( 36 | sp.SparseGroupNorm32(num_groups, channels), 37 | sp.SparseSiLU() 38 | ) 39 | 40 | self.sub = sp.SparseSubdivide() 41 | 42 | self.out_layers = nn.Sequential( 43 | sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), 44 | sp.SparseGroupNorm32(num_groups, self.out_channels), 45 | sp.SparseSiLU(), 46 | zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), 47 | ) 48 | 49 | if self.out_channels == channels: 50 | self.skip_connection = nn.Identity() 51 | else: 52 | self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") 53 | 54 | def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: 55 | """ 56 | Apply the block to a Tensor, conditioned on a timestep embedding. 57 | 58 | Args: 59 | x: an [N x C x ...] Tensor of features. 60 | Returns: 61 | an [N x C x ...] Tensor of outputs. 62 | """ 63 | h = self.act_layers(x) 64 | h = self.sub(h) 65 | x = self.sub(x) 66 | h = self.out_layers(h) 67 | h = h + self.skip_connection(x) 68 | return h 69 | 70 | 71 | class SLatMeshDecoder(SparseTransformerBase): 72 | def __init__( 73 | self, 74 | resolution: int, 75 | model_channels: int, 76 | latent_channels: int, 77 | num_blocks: int, 78 | num_heads: Optional[int] = None, 79 | num_head_channels: Optional[int] = 64, 80 | mlp_ratio: float = 4, 81 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 82 | window_size: int = 8, 83 | pe_mode: Literal["ape", "rope"] = "ape", 84 | use_fp16: bool = False, 85 | use_checkpoint: bool = False, 86 | qk_rms_norm: bool = False, 87 | representation_config: dict = None, 88 | ): 89 | super().__init__( 90 | in_channels=latent_channels, 91 | model_channels=model_channels, 92 | num_blocks=num_blocks, 93 | num_heads=num_heads, 94 | num_head_channels=num_head_channels, 95 | mlp_ratio=mlp_ratio, 96 | attn_mode=attn_mode, 97 | window_size=window_size, 98 | pe_mode=pe_mode, 99 | use_fp16=use_fp16, 100 | use_checkpoint=use_checkpoint, 101 | qk_rms_norm=qk_rms_norm, 102 | ) 103 | self.resolution = resolution 104 | self.rep_config = representation_config 105 | self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) 106 | self.out_channels = self.mesh_extractor.feats_channels 107 | self.upsample = nn.ModuleList([ 108 | SparseSubdivideBlock3d( 109 | channels=model_channels, 110 | resolution=resolution, 111 | out_channels=model_channels // 4 112 | ), 113 | SparseSubdivideBlock3d( 114 | channels=model_channels // 4, 115 | resolution=resolution * 2, 116 | out_channels=model_channels // 8 117 | ) 118 | ]) 119 | self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) 120 | 121 | self.initialize_weights() 122 | if use_fp16: 123 | self.convert_to_fp16() 124 | 125 | def initialize_weights(self) -> None: 126 | super().initialize_weights() 127 | # Zero-out output layers: 128 | nn.init.constant_(self.out_layer.weight, 0) 129 | nn.init.constant_(self.out_layer.bias, 0) 130 | 131 | def convert_to_fp16(self) -> None: 132 | """ 133 | Convert the torso of the model to float16. 134 | """ 135 | super().convert_to_fp16() 136 | self.upsample.apply(convert_module_to_f16) 137 | 138 | def convert_to_fp32(self) -> None: 139 | """ 140 | Convert the torso of the model to float32. 141 | """ 142 | super().convert_to_fp32() 143 | self.upsample.apply(convert_module_to_f32) 144 | 145 | def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: 146 | """ 147 | Convert a batch of network outputs to 3D representations. 148 | 149 | Args: 150 | x: The [N x * x C] sparse tensor output by the network. 151 | 152 | Returns: 153 | list of representations 154 | """ 155 | ret = [] 156 | for i in range(x.shape[0]): 157 | mesh = self.mesh_extractor(x[i], training=self.training) 158 | ret.append(mesh) 159 | return ret 160 | 161 | def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: 162 | h = super().forward(x) 163 | for block in self.upsample: 164 | h = block(h) 165 | h = h.type(x.dtype) 166 | h = self.out_layer(h) 167 | return self.to_representation(h) 168 | -------------------------------------------------------------------------------- /trellis/models/structured_latent_vae/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from ...modules import sparse as sp 6 | from .base import SparseTransformerBase 7 | 8 | 9 | class SLatEncoder(SparseTransformerBase): 10 | def __init__( 11 | self, 12 | resolution: int, 13 | in_channels: int, 14 | model_channels: int, 15 | latent_channels: int, 16 | num_blocks: int, 17 | num_heads: Optional[int] = None, 18 | num_head_channels: Optional[int] = 64, 19 | mlp_ratio: float = 4, 20 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", 21 | window_size: int = 8, 22 | pe_mode: Literal["ape", "rope"] = "ape", 23 | use_fp16: bool = False, 24 | use_checkpoint: bool = False, 25 | qk_rms_norm: bool = False, 26 | ): 27 | super().__init__( 28 | in_channels=in_channels, 29 | model_channels=model_channels, 30 | num_blocks=num_blocks, 31 | num_heads=num_heads, 32 | num_head_channels=num_head_channels, 33 | mlp_ratio=mlp_ratio, 34 | attn_mode=attn_mode, 35 | window_size=window_size, 36 | pe_mode=pe_mode, 37 | use_fp16=use_fp16, 38 | use_checkpoint=use_checkpoint, 39 | qk_rms_norm=qk_rms_norm, 40 | ) 41 | self.resolution = resolution 42 | self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) 43 | 44 | self.initialize_weights() 45 | if use_fp16: 46 | self.convert_to_fp16() 47 | 48 | def initialize_weights(self) -> None: 49 | super().initialize_weights() 50 | # Zero-out output layers: 51 | nn.init.constant_(self.out_layer.weight, 0) 52 | nn.init.constant_(self.out_layer.bias, 0) 53 | 54 | def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): 55 | h = super().forward(x) 56 | h = h.type(x.dtype) 57 | h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) 58 | h = self.out_layer(h) 59 | 60 | # Sample from the posterior distribution 61 | mean, logvar = h.feats.chunk(2, dim=-1) 62 | if sample_posterior: 63 | std = torch.exp(0.5 * logvar) 64 | z = mean + std * torch.randn_like(std) 65 | else: 66 | z = mean 67 | z = h.replace(z) 68 | 69 | if return_raw: 70 | return z, mean, logvar 71 | else: 72 | return z 73 | -------------------------------------------------------------------------------- /trellis/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #from .attention_utils import enable_sage_attention, disable_sage_attention 2 | from .attention import ( 3 | scaled_dot_product_attention, 4 | BACKEND, 5 | DEBUG, 6 | MultiHeadAttention, 7 | RotaryPositionEmbedder 8 | ) 9 | 10 | __all__ = [ 11 | 'scaled_dot_product_attention', 12 | 'BACKEND', 13 | 'DEBUG', 14 | 'MultiHeadAttention', 15 | 'RotaryPositionEmbedder' 16 | ] -------------------------------------------------------------------------------- /trellis/modules/attention/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Literal 4 | from trellis.backend_config import ( 5 | get_attention_backend, 6 | get_debug_mode, 7 | ) 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | #ATTN = get_attention_backend() 13 | BACKEND = get_attention_backend() 14 | DEBUG = get_debug_mode() 15 | 16 | def __from_env(): 17 | """Read current backend configuration""" 18 | #global ATTN 19 | global BACKEND 20 | global DEBUG 21 | 22 | # Get current settings from central config 23 | #ATTN = 24 | BACKEND = get_attention_backend() 25 | DEBUG = get_debug_mode() 26 | 27 | print(f"[ATTENTION] Using backend: {BACKEND}") 28 | 29 | from .modules import MultiHeadAttention, RotaryPositionEmbedder 30 | from .full_attn import scaled_dot_product_attention 31 | 32 | __all__ = [ 33 | 'scaled_dot_product_attention', 34 | 'BACKEND', 35 | 'DEBUG', 36 | 'MultiHeadAttention', 37 | 'RotaryPositionEmbedder' 38 | ] 39 | -------------------------------------------------------------------------------- /trellis/modules/attention/full_attn.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import math 4 | from trellis.backend_config import ( 5 | get_attention_backend, 6 | get_debug_mode, 7 | get_available_backends 8 | ) 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | # Get configuration from central config 14 | BACKEND = get_attention_backend() 15 | DEBUG = get_debug_mode() 16 | 17 | # Get available backends and import if active 18 | available_backends = get_available_backends() 19 | 20 | if BACKEND == "xformers" and available_backends['xformers']: 21 | import xformers.ops as xops 22 | elif BACKEND == "flash_attn" and available_backends['flash_attn']: 23 | import flash_attn 24 | elif BACKEND == "sage" and available_backends['sage']: 25 | import torch.nn.functional as F 26 | from sageattention import sageattn 27 | F.scaled_dot_product_attention = sageattn 28 | elif BACKEND == "sdpa": 29 | from torch.nn.functional import scaled_dot_product_attention as sdpa 30 | elif BACKEND == "naive": 31 | from torch.nn.functional import scaled_dot_product_attention as naive 32 | else: 33 | raise ValueError(f"Unknown attention module: {BACKEND}") 34 | 35 | 36 | __all__ = [ 37 | 'scaled_dot_product_attention', 38 | ] 39 | 40 | 41 | def _naive_sdpa(q, k, v): 42 | """ 43 | Naive implementation of scaled dot product attention. 44 | """ 45 | q = q.permute(0, 2, 1, 3) # [N, H, L, C] 46 | k = k.permute(0, 2, 1, 3) # [N, H, L, C] 47 | v = v.permute(0, 2, 1, 3) # [N, H, L, C] 48 | scale_factor = 1 / math.sqrt(q.size(-1)) 49 | attn_weight = q @ k.transpose(-2, -1) * scale_factor 50 | attn_weight = torch.softmax(attn_weight, dim=-1) 51 | out = attn_weight @ v 52 | out = out.permute(0, 2, 1, 3) # [N, L, H, C] 53 | return out 54 | 55 | 56 | @overload 57 | def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: 58 | """ 59 | Apply scaled dot product attention. 60 | 61 | Args: 62 | qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. 63 | """ 64 | ... 65 | 66 | @overload 67 | def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: 68 | """ 69 | Apply scaled dot product attention. 70 | 71 | Args: 72 | q (torch.Tensor): A [N, L, H, C] tensor containing Qs. 73 | kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. 74 | """ 75 | ... 76 | 77 | @overload 78 | def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 79 | """ 80 | Apply scaled dot product attention. 81 | 82 | Args: 83 | q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. 84 | k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. 85 | v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. 86 | 87 | Note: 88 | k and v are assumed to have the same coordinate map. 89 | """ 90 | ... 91 | 92 | def scaled_dot_product_attention(*args, **kwargs): 93 | arg_names_dict = { 94 | 1: ['qkv'], 95 | 2: ['q', 'kv'], 96 | 3: ['q', 'k', 'v'] 97 | } 98 | num_all_args = len(args) + len(kwargs) 99 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" 100 | for key in arg_names_dict[num_all_args][len(args):]: 101 | assert key in kwargs, f"Missing argument {key}" 102 | 103 | if num_all_args == 1: 104 | qkv = args[0] if len(args) > 0 else kwargs['qkv'] 105 | assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" 106 | device = qkv.device 107 | 108 | elif num_all_args == 2: 109 | q = args[0] if len(args) > 0 else kwargs['q'] 110 | kv = args[1] if len(args) > 1 else kwargs['kv'] 111 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" 112 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" 113 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" 114 | device = q.device 115 | 116 | elif num_all_args == 3: 117 | q = args[0] if len(args) > 0 else kwargs['q'] 118 | k = args[1] if len(args) > 1 else kwargs['k'] 119 | v = args[2] if len(args) > 2 else kwargs['v'] 120 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" 121 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" 122 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" 123 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" 124 | device = q.device 125 | 126 | if BACKEND == 'xformers': 127 | if num_all_args == 1: 128 | q, k, v = qkv.unbind(dim=2) 129 | elif num_all_args == 2: 130 | k, v = kv.unbind(dim=2) 131 | out = xops.memory_efficient_attention(q, k, v) 132 | elif BACKEND == 'flash_attn': 133 | if num_all_args == 1: 134 | out = flash_attn.flash_attn_qkvpacked_func(qkv) 135 | elif num_all_args == 2: 136 | out = flash_attn.flash_attn_kvpacked_func(q, kv) 137 | elif num_all_args == 3: 138 | out = flash_attn.flash_attn_func(q, k, v) 139 | elif BACKEND == 'sdpa': 140 | if num_all_args == 1: 141 | q, k, v = qkv.unbind(dim=2) 142 | elif num_all_args == 2: 143 | k, v = kv.unbind(dim=2) 144 | q = q.permute(0, 2, 1, 3) # [N, H, L, C] 145 | k = k.permute(0, 2, 1, 3) # [N, H, L, C] 146 | v = v.permute(0, 2, 1, 3) # [N, H, L, C] 147 | out = sdpa(q, k, v) # [N, H, L, C] 148 | out = out.permute(0, 2, 1, 3) # [N, L, H, C] 149 | elif BACKEND == 'naive': 150 | if num_all_args == 1: 151 | q, k, v = qkv.unbind(dim=2) 152 | elif num_all_args == 2: 153 | k, v = kv.unbind(dim=2) 154 | out = _naive_sdpa(q, k, v) 155 | else: 156 | raise ValueError(f"Unknown attention module: {BACKEND}") 157 | 158 | return out -------------------------------------------------------------------------------- /trellis/modules/attention/modules.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .full_attn import scaled_dot_product_attention 6 | 7 | 8 | class MultiHeadRMSNorm(nn.Module): 9 | def __init__(self, dim: int, heads: int): 10 | super().__init__() 11 | self.scale = dim ** 0.5 12 | self.gamma = nn.Parameter(torch.ones(heads, dim)) 13 | 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) 16 | 17 | 18 | class RotaryPositionEmbedder(nn.Module): 19 | def __init__(self, hidden_size: int, in_channels: int = 3): 20 | super().__init__() 21 | assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" 22 | self.hidden_size = hidden_size 23 | self.in_channels = in_channels 24 | self.freq_dim = hidden_size // in_channels // 2 25 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim 26 | self.freqs = 1.0 / (10000 ** self.freqs) 27 | 28 | def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: 29 | self.freqs = self.freqs.to(indices.device) 30 | phases = torch.outer(indices, self.freqs) 31 | phases = torch.polar(torch.ones_like(phases), phases) 32 | return phases 33 | 34 | def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: 35 | x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) 36 | x_rotated = x_complex * phases 37 | x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) 38 | return x_embed 39 | 40 | def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: 41 | """ 42 | Args: 43 | q (sp.SparseTensor): [..., N, D] tensor of queries 44 | k (sp.SparseTensor): [..., N, D] tensor of keys 45 | indices (torch.Tensor): [..., N, C] tensor of spatial positions 46 | """ 47 | if indices is None: 48 | indices = torch.arange(q.shape[-2], device=q.device) 49 | if len(q.shape) > 2: 50 | indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) 51 | 52 | phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) 53 | if phases.shape[1] < self.hidden_size // 2: 54 | phases = torch.cat([phases, torch.polar( 55 | torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), 56 | torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) 57 | )], dim=-1) 58 | q_embed = self._rotary_embedding(q, phases) 59 | k_embed = self._rotary_embedding(k, phases) 60 | return q_embed, k_embed 61 | 62 | 63 | class MultiHeadAttention(nn.Module): 64 | def __init__( 65 | self, 66 | channels: int, 67 | num_heads: int, 68 | ctx_channels: Optional[int]=None, 69 | type: Literal["self", "cross"] = "self", 70 | attn_mode: Literal["full", "windowed"] = "full", 71 | window_size: Optional[int] = None, 72 | shift_window: Optional[Tuple[int, int, int]] = None, 73 | qkv_bias: bool = True, 74 | use_rope: bool = False, 75 | qk_rms_norm: bool = False, 76 | ): 77 | super().__init__() 78 | assert channels % num_heads == 0 79 | assert type in ["self", "cross"], f"Invalid attention type: {type}" 80 | assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" 81 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" 82 | 83 | if attn_mode == "windowed": 84 | raise NotImplementedError("Windowed attention is not yet implemented") 85 | 86 | self.channels = channels 87 | self.head_dim = channels // num_heads 88 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels 89 | self.num_heads = num_heads 90 | self._type = type 91 | self.attn_mode = attn_mode 92 | self.window_size = window_size 93 | self.shift_window = shift_window 94 | self.use_rope = use_rope 95 | self.qk_rms_norm = qk_rms_norm 96 | 97 | if self._type == "self": 98 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) 99 | else: 100 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias) 101 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) 102 | 103 | if self.qk_rms_norm: 104 | self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) 105 | self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) 106 | 107 | self.to_out = nn.Linear(channels, channels) 108 | 109 | if use_rope: 110 | self.rope = RotaryPositionEmbedder(channels) 111 | 112 | def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: 113 | B, L, C = x.shape 114 | if self._type == "self": 115 | qkv = self.to_qkv(x) 116 | qkv = qkv.reshape(B, L, 3, self.num_heads, -1) 117 | if self.use_rope: 118 | q, k, v = qkv.unbind(dim=2) 119 | q, k = self.rope(q, k, indices) 120 | qkv = torch.stack([q, k, v], dim=2) 121 | if self.attn_mode == "full": 122 | if self.qk_rms_norm: 123 | q, k, v = qkv.unbind(dim=2) 124 | q = self.q_rms_norm(q) 125 | k = self.k_rms_norm(k) 126 | h = scaled_dot_product_attention(q, k, v) 127 | else: 128 | h = scaled_dot_product_attention(qkv) 129 | elif self.attn_mode == "windowed": 130 | raise NotImplementedError("Windowed attention is not yet implemented") 131 | else: 132 | Lkv = context.shape[1] 133 | q = self.to_q(x) 134 | kv = self.to_kv(context) 135 | q = q.reshape(B, L, self.num_heads, -1) 136 | kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) 137 | if self.qk_rms_norm: 138 | q = self.q_rms_norm(q) 139 | k, v = kv.unbind(dim=2) 140 | k = self.k_rms_norm(k) 141 | h = scaled_dot_product_attention(q, k, v) 142 | else: 143 | h = scaled_dot_product_attention(q, kv) 144 | h = h.reshape(B, L, -1) 145 | h = self.to_out(h) 146 | return h 147 | -------------------------------------------------------------------------------- /trellis/modules/attention_utils.py: -------------------------------------------------------------------------------- 1 | #sage_attn.py 2 | import os 3 | from typing import Optional 4 | import torch 5 | import torch.nn.functional as F 6 | from sageattention import sageattn 7 | import math 8 | 9 | __all__ = ['SageAttention', 'sage_attention'] 10 | 11 | 12 | def enable_sage_attention(): 13 | """ 14 | Enable SageAttention by replacing PyTorch's scaled_dot_product_attention 15 | with sageattn from the SageAttention library. 16 | """ 17 | F.scaled_dot_product_attention = sageattn 18 | return True 19 | 20 | def disable_sage_attention(): 21 | """ 22 | Restore PyTorch's original scaled_dot_product_attention function. 23 | """ 24 | F.scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention 25 | return True -------------------------------------------------------------------------------- /trellis/modules/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LayerNorm32(nn.LayerNorm): 6 | def forward(self, x: torch.Tensor) -> torch.Tensor: 7 | return super().forward(x.float()).type(x.dtype) 8 | 9 | 10 | class GroupNorm32(nn.GroupNorm): 11 | """ 12 | A GroupNorm layer that converts to float32 before the forward pass. 13 | """ 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | return super().forward(x.float()).type(x.dtype) 16 | 17 | 18 | class ChannelLayerNorm32(LayerNorm32): 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | DIM = x.dim() 21 | x = x.permute(0, *range(2, DIM), 1).contiguous() 22 | x = super().forward(x) 23 | x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() 24 | return x 25 | -------------------------------------------------------------------------------- /trellis/modules/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import importlib 3 | from trellis.backend_config import get_sparse_backend, get_attention_backend, get_debug_mode 4 | 5 | BACKEND = get_sparse_backend() 6 | DEBUG = get_debug_mode() 7 | ATTN = get_attention_backend() 8 | 9 | if ATTN not in ['xformers', 'flash_attn']: 10 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'") 11 | ATTN = 'flash_attn' 12 | 13 | __attributes = { 14 | 'SparseTensor': 'basic', 15 | 'sparse_batch_broadcast': 'basic', 16 | 'sparse_batch_op': 'basic', 17 | 'sparse_cat': 'basic', 18 | 'sparse_unbind': 'basic', 19 | 'SparseGroupNorm': 'norm', 20 | 'SparseLayerNorm': 'norm', 21 | 'SparseGroupNorm32': 'norm', 22 | 'SparseLayerNorm32': 'norm', 23 | 'SparseReLU': 'nonlinearity', 24 | 'SparseSiLU': 'nonlinearity', 25 | 'SparseGELU': 'nonlinearity', 26 | 'SparseActivation': 'nonlinearity', 27 | 'SparseLinear': 'linear', 28 | 'sparse_scaled_dot_product_attention': 'attention', 29 | 'SerializeMode': 'attention', 30 | 'sparse_serialized_scaled_dot_product_self_attention': 'attention', 31 | 'sparse_windowed_scaled_dot_product_self_attention': 'attention', 32 | 'SparseMultiHeadAttention': 'attention', 33 | 'SparseConv3d': 'conv', 34 | 'SparseInverseConv3d': 'conv', 35 | 'SparseDownsample': 'spatial', 36 | 'SparseUpsample': 'spatial', 37 | 'SparseSubdivide' : 'spatial' 38 | } 39 | 40 | __submodules = ['transformer'] 41 | 42 | __all__ = list(__attributes.keys()) + __submodules 43 | 44 | def __getattr__(name): 45 | if name not in globals(): 46 | if name in __attributes: 47 | module_name = __attributes[name] 48 | module = importlib.import_module(f".{module_name}", __name__) 49 | globals()[name] = getattr(module, name) 50 | elif name in __submodules: 51 | module = importlib.import_module(f".{name}", __name__) 52 | globals()[name] = module 53 | else: 54 | raise AttributeError(f"module {__name__} has no attribute {name}") 55 | return globals()[name] 56 | 57 | 58 | # For Pylance 59 | if __name__ == '__main__': 60 | from .basic import * 61 | from .norm import * 62 | from .nonlinearity import * 63 | from .linear import * 64 | from .attention import * 65 | from .conv import * 66 | from .spatial import * 67 | import transformer 68 | -------------------------------------------------------------------------------- /trellis/modules/sparse/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .full_attn import * 2 | from .serialized_attn import * 3 | from .windowed_attn import * 4 | from .modules import * 5 | import os 6 | import logging 7 | from typing import Literal 8 | from trellis.backend_config import ( 9 | get_attention_backend, 10 | get_debug_mode, 11 | ) 12 | import logging 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | #ATTN = get_attention_backend() 17 | ATTN = get_attention_backend() 18 | DEBUG = get_debug_mode() 19 | 20 | def __from_env(): 21 | """Read current backend configuration""" 22 | #global ATTN 23 | global ATTN 24 | global DEBUG 25 | 26 | # Get current settings from central config 27 | #ATTN = 28 | ATTN = get_attention_backend() 29 | DEBUG = get_debug_mode() 30 | 31 | print(f"[ATTENTION] sparse backend: {ATTN}") -------------------------------------------------------------------------------- /trellis/modules/sparse/attention/full_attn.py: -------------------------------------------------------------------------------- 1 | #trellis\modules\sparse\attention\full_attn.py 2 | from typing import * 3 | import torch 4 | import math 5 | from .. import SparseTensor 6 | from trellis.backend_config import ( 7 | get_attention_backend, 8 | get_debug_mode, 9 | get_available_backends 10 | ) 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | # Get configuration from central config 16 | ATTN = get_attention_backend() 17 | DEBUG = get_debug_mode() 18 | 19 | # Get available backends and import if active 20 | available_backends = get_available_backends() 21 | 22 | if ATTN not in ['xformers', 'flash_attn']: 23 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'") 24 | ATTN = 'flash_attn' 25 | 26 | if ATTN == 'xformers' and available_backends['xformers']: 27 | import xformers.ops as xops 28 | elif ATTN == 'flash_attn' and available_backends['flash_attn']: 29 | import flash_attn 30 | else: 31 | raise ImportError(f"Could not import {ATTN}. Please install either xformers or flash-attn for sparse attention support.") 32 | 33 | 34 | 35 | __all__ = [ 36 | 'sparse_scaled_dot_product_attention', 37 | ] 38 | 39 | 40 | @overload 41 | def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: 42 | """ 43 | Apply scaled dot product attention to a sparse tensor. 44 | 45 | Args: 46 | qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. 47 | """ 48 | ... 49 | 50 | @overload 51 | def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: 52 | """ 53 | Apply scaled dot product attention to a sparse tensor. 54 | 55 | Args: 56 | q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. 57 | kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. 58 | """ 59 | ... 60 | 61 | @overload 62 | def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: 63 | """ 64 | Apply scaled dot product attention to a sparse tensor. 65 | 66 | Args: 67 | q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. 68 | kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. 69 | """ 70 | ... 71 | 72 | @overload 73 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: 74 | """ 75 | Apply scaled dot product attention to a sparse tensor. 76 | 77 | Args: 78 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. 79 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. 80 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. 81 | 82 | Note: 83 | k and v are assumed to have the same coordinate map. 84 | """ 85 | ... 86 | 87 | @overload 88 | def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: 89 | """ 90 | Apply scaled dot product attention to a sparse tensor. 91 | 92 | Args: 93 | q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. 94 | k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. 95 | v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. 96 | """ 97 | ... 98 | 99 | @overload 100 | def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: 101 | """ 102 | Apply scaled dot product attention to a sparse tensor. 103 | 104 | Args: 105 | q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. 106 | k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. 107 | v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. 108 | """ 109 | ... 110 | 111 | def sparse_scaled_dot_product_attention(*args, **kwargs): 112 | arg_names_dict = { 113 | 1: ['qkv'], 114 | 2: ['q', 'kv'], 115 | 3: ['q', 'k', 'v'] 116 | } 117 | num_all_args = len(args) + len(kwargs) 118 | assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" 119 | for key in arg_names_dict[num_all_args][len(args):]: 120 | assert key in kwargs, f"Missing argument {key}" 121 | 122 | if num_all_args == 1: 123 | qkv = args[0] if len(args) > 0 else kwargs['qkv'] 124 | assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" 125 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" 126 | device = qkv.device 127 | 128 | s = qkv 129 | q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] 130 | kv_seqlen = q_seqlen 131 | qkv = qkv.feats # [T, 3, H, C] 132 | 133 | elif num_all_args == 2: 134 | q = args[0] if len(args) > 0 else kwargs['q'] 135 | kv = args[1] if len(args) > 1 else kwargs['kv'] 136 | assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ 137 | isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ 138 | f"Invalid types, got {type(q)} and {type(kv)}" 139 | assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" 140 | device = q.device 141 | 142 | if isinstance(q, SparseTensor): 143 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" 144 | s = q 145 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] 146 | q = q.feats # [T_Q, H, C] 147 | else: 148 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" 149 | s = None 150 | N, L, H, C = q.shape 151 | q_seqlen = [L] * N 152 | q = q.reshape(N * L, H, C) # [T_Q, H, C] 153 | 154 | if isinstance(kv, SparseTensor): 155 | assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" 156 | kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] 157 | kv = kv.feats # [T_KV, 2, H, C] 158 | else: 159 | assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" 160 | N, L, _, H, C = kv.shape 161 | kv_seqlen = [L] * N 162 | kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] 163 | 164 | elif num_all_args == 3: 165 | q = args[0] if len(args) > 0 else kwargs['q'] 166 | k = args[1] if len(args) > 1 else kwargs['k'] 167 | v = args[2] if len(args) > 2 else kwargs['v'] 168 | assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ 169 | isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ 170 | f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" 171 | assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" 172 | device = q.device 173 | 174 | if isinstance(q, SparseTensor): 175 | assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" 176 | s = q 177 | q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] 178 | q = q.feats # [T_Q, H, Ci] 179 | else: 180 | assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" 181 | s = None 182 | N, L, H, CI = q.shape 183 | q_seqlen = [L] * N 184 | q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] 185 | 186 | if isinstance(k, SparseTensor): 187 | assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" 188 | assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" 189 | kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] 190 | k = k.feats # [T_KV, H, Ci] 191 | v = v.feats # [T_KV, H, Co] 192 | else: 193 | assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" 194 | assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" 195 | N, L, H, CI, CO = *k.shape, v.shape[-1] 196 | kv_seqlen = [L] * N 197 | k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] 198 | v = v.reshape(N * L, H, CO) # [T_KV, H, Co] 199 | 200 | if DEBUG: 201 | if s is not None: 202 | for i in range(s.shape[0]): 203 | assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" 204 | if num_all_args in [2, 3]: 205 | assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" 206 | if num_all_args == 3: 207 | assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" 208 | assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" 209 | 210 | if ATTN == 'xformers': 211 | if num_all_args == 1: 212 | q, k, v = qkv.unbind(dim=1) 213 | elif num_all_args == 2: 214 | k, v = kv.unbind(dim=1) 215 | q = q.unsqueeze(0) 216 | k = k.unsqueeze(0) 217 | v = v.unsqueeze(0) 218 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) 219 | out = xops.memory_efficient_attention(q, k, v, mask)[0] 220 | elif ATTN == 'flash_attn': 221 | cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) 222 | if num_all_args in [2, 3]: 223 | cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) 224 | if num_all_args == 1: 225 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) 226 | elif num_all_args == 2: 227 | out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) 228 | elif num_all_args == 3: 229 | out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) 230 | else: 231 | raise ValueError(f"Unknown attention module: {ATTN}") 232 | 233 | if s is not None: 234 | return s.replace(out) 235 | else: 236 | return out.reshape(N, L, H, -1) -------------------------------------------------------------------------------- /trellis/modules/sparse/attention/modules.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .. import SparseTensor 6 | from .full_attn import sparse_scaled_dot_product_attention 7 | from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention 8 | from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention 9 | from ...attention import RotaryPositionEmbedder 10 | 11 | 12 | class SparseMultiHeadRMSNorm(nn.Module): 13 | def __init__(self, dim: int, heads: int): 14 | super().__init__() 15 | self.scale = dim ** 0.5 16 | self.gamma = nn.Parameter(torch.ones(heads, dim)) 17 | 18 | def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: 19 | x_type = x.dtype 20 | x = x.float() 21 | if isinstance(x, SparseTensor): 22 | x = x.replace(F.normalize(x.feats, dim=-1)) 23 | else: 24 | x = F.normalize(x, dim=-1) 25 | return (x * self.gamma * self.scale).to(x_type) 26 | 27 | 28 | class SparseMultiHeadAttention(nn.Module): 29 | def __init__( 30 | self, 31 | channels: int, 32 | num_heads: int, 33 | ctx_channels: Optional[int] = None, 34 | type: Literal["self", "cross"] = "self", 35 | attn_mode: Literal["full", "serialized", "windowed"] = "full", 36 | window_size: Optional[int] = None, 37 | shift_sequence: Optional[int] = None, 38 | shift_window: Optional[Tuple[int, int, int]] = None, 39 | serialize_mode: Optional[SerializeMode] = None, 40 | qkv_bias: bool = True, 41 | use_rope: bool = False, 42 | qk_rms_norm: bool = False, 43 | ): 44 | super().__init__() 45 | assert channels % num_heads == 0 46 | assert type in ["self", "cross"], f"Invalid attention type: {type}" 47 | assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" 48 | assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" 49 | assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" 50 | self.channels = channels 51 | self.ctx_channels = ctx_channels if ctx_channels is not None else channels 52 | self.num_heads = num_heads 53 | self._type = type 54 | self.attn_mode = attn_mode 55 | self.window_size = window_size 56 | self.shift_sequence = shift_sequence 57 | self.shift_window = shift_window 58 | self.serialize_mode = serialize_mode 59 | self.use_rope = use_rope 60 | self.qk_rms_norm = qk_rms_norm 61 | 62 | if self._type == "self": 63 | self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) 64 | else: 65 | self.to_q = nn.Linear(channels, channels, bias=qkv_bias) 66 | self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) 67 | 68 | if self.qk_rms_norm: 69 | self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) 70 | self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) 71 | 72 | self.to_out = nn.Linear(channels, channels) 73 | 74 | if use_rope: 75 | self.rope = RotaryPositionEmbedder(channels) 76 | 77 | @staticmethod 78 | def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: 79 | if isinstance(x, SparseTensor): 80 | return x.replace(module(x.feats)) 81 | else: 82 | return module(x) 83 | 84 | @staticmethod 85 | def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: 86 | if isinstance(x, SparseTensor): 87 | return x.reshape(*shape) 88 | else: 89 | return x.reshape(*x.shape[:2], *shape) 90 | 91 | def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: 92 | if isinstance(x, SparseTensor): 93 | x_feats = x.feats.unsqueeze(0) 94 | else: 95 | x_feats = x 96 | x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) 97 | return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats 98 | 99 | def _rope(self, qkv: SparseTensor) -> SparseTensor: 100 | q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] 101 | q, k = self.rope(q, k, qkv.coords[:, 1:]) 102 | qkv = qkv.replace(torch.stack([q, k, v], dim=1)) 103 | return qkv 104 | 105 | def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: 106 | if self._type == "self": 107 | qkv = self._linear(self.to_qkv, x) 108 | qkv = self._fused_pre(qkv, num_fused=3) 109 | if self.use_rope: 110 | qkv = self._rope(qkv) 111 | if self.qk_rms_norm: 112 | q, k, v = qkv.unbind(dim=1) 113 | q = self.q_rms_norm(q) 114 | k = self.k_rms_norm(k) 115 | qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) 116 | if self.attn_mode == "full": 117 | h = sparse_scaled_dot_product_attention(qkv) 118 | elif self.attn_mode == "serialized": 119 | h = sparse_serialized_scaled_dot_product_self_attention( 120 | qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window 121 | ) 122 | elif self.attn_mode == "windowed": 123 | h = sparse_windowed_scaled_dot_product_self_attention( 124 | qkv, self.window_size, shift_window=self.shift_window 125 | ) 126 | else: 127 | q = self._linear(self.to_q, x) 128 | q = self._reshape_chs(q, (self.num_heads, -1)) 129 | kv = self._linear(self.to_kv, context) 130 | kv = self._fused_pre(kv, num_fused=2) 131 | if self.qk_rms_norm: 132 | q = self.q_rms_norm(q) 133 | k, v = kv.unbind(dim=1) 134 | k = self.k_rms_norm(k) 135 | kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) 136 | h = sparse_scaled_dot_product_attention(q, kv) 137 | h = self._reshape_chs(h, (-1,)) 138 | h = self._linear(self.to_out, h) 139 | return h 140 | -------------------------------------------------------------------------------- /trellis/modules/sparse/attention/serialized_attn.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from enum import Enum 3 | import torch 4 | import math 5 | from .. import SparseTensor 6 | from trellis.backend_config import ( 7 | get_attention_backend, 8 | get_debug_mode, 9 | get_available_backends 10 | ) 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | # Get configuration from central config 16 | ATTN = get_attention_backend() 17 | DEBUG = get_debug_mode() 18 | 19 | # Get available backends and import if active 20 | available_backends = get_available_backends() 21 | 22 | if ATTN not in ['xformers', 'flash_attn']: 23 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'") 24 | ATTN = 'flash_attn' 25 | 26 | if ATTN == 'xformers' and available_backends['xformers']: 27 | import xformers.ops as xops 28 | elif ATTN == 'flash_attn' and available_backends['flash_attn']: 29 | import flash_attn 30 | else: 31 | raise ImportError(f"Could not import {ATTN}. Please install either xformers or flash-attn for sparse attention support.") 32 | 33 | 34 | __all__ = [ 35 | 'sparse_serialized_scaled_dot_product_self_attention', 36 | ] 37 | 38 | 39 | class SerializeMode(Enum): 40 | Z_ORDER = 0 41 | Z_ORDER_TRANSPOSED = 1 42 | HILBERT = 2 43 | HILBERT_TRANSPOSED = 3 44 | 45 | 46 | SerializeModes = [ 47 | SerializeMode.Z_ORDER, 48 | SerializeMode.Z_ORDER_TRANSPOSED, 49 | SerializeMode.HILBERT, 50 | SerializeMode.HILBERT_TRANSPOSED 51 | ] 52 | 53 | 54 | def calc_serialization( 55 | tensor: SparseTensor, 56 | window_size: int, 57 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER, 58 | shift_sequence: int = 0, 59 | shift_window: Tuple[int, int, int] = (0, 0, 0) 60 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: 61 | """ 62 | Calculate serialization and partitioning for a set of coordinates. 63 | 64 | Args: 65 | tensor (SparseTensor): The input tensor. 66 | window_size (int): The window size to use. 67 | serialize_mode (SerializeMode): The serialization mode to use. 68 | shift_sequence (int): The shift of serialized sequence. 69 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates. 70 | 71 | Returns: 72 | (torch.Tensor, torch.Tensor): Forwards and backwards indices. 73 | """ 74 | fwd_indices = [] 75 | bwd_indices = [] 76 | seq_lens = [] 77 | seq_batch_indices = [] 78 | offsets = [0] 79 | 80 | if 'vox2seq' not in globals(): 81 | import vox2seq 82 | 83 | # Serialize the input 84 | serialize_coords = tensor.coords[:, 1:].clone() 85 | serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) 86 | if serialize_mode == SerializeMode.Z_ORDER: 87 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) 88 | elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: 89 | code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) 90 | elif serialize_mode == SerializeMode.HILBERT: 91 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) 92 | elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: 93 | code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) 94 | else: 95 | raise ValueError(f"Unknown serialize mode: {serialize_mode}") 96 | 97 | for bi, s in enumerate(tensor.layout): 98 | num_points = s.stop - s.start 99 | num_windows = (num_points + window_size - 1) // window_size 100 | valid_window_size = num_points / num_windows 101 | to_ordered = torch.argsort(code[s.start:s.stop]) 102 | if num_windows == 1: 103 | fwd_indices.append(to_ordered) 104 | bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) 105 | fwd_indices[-1] += s.start 106 | bwd_indices[-1] += offsets[-1] 107 | seq_lens.append(num_points) 108 | seq_batch_indices.append(bi) 109 | offsets.append(offsets[-1] + seq_lens[-1]) 110 | else: 111 | # Partition the input 112 | offset = 0 113 | mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] 114 | split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] 115 | bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) 116 | for i in range(num_windows): 117 | mid = mids[i] 118 | valid_start = split[i] 119 | valid_end = split[i + 1] 120 | padded_start = math.floor(mid - 0.5 * window_size) 121 | padded_end = padded_start + window_size 122 | fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) 123 | offset += valid_start - padded_start 124 | bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) 125 | offset += padded_end - valid_start 126 | fwd_indices[-1] += s.start 127 | seq_lens.extend([window_size] * num_windows) 128 | seq_batch_indices.extend([bi] * num_windows) 129 | bwd_indices.append(bwd_index + offsets[-1]) 130 | offsets.append(offsets[-1] + num_windows * window_size) 131 | 132 | fwd_indices = torch.cat(fwd_indices) 133 | bwd_indices = torch.cat(bwd_indices) 134 | 135 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices 136 | 137 | 138 | def sparse_serialized_scaled_dot_product_self_attention( 139 | qkv: SparseTensor, 140 | window_size: int, 141 | serialize_mode: SerializeMode = SerializeMode.Z_ORDER, 142 | shift_sequence: int = 0, 143 | shift_window: Tuple[int, int, int] = (0, 0, 0) 144 | ) -> SparseTensor: 145 | """ 146 | Apply serialized scaled dot product self attention to a sparse tensor. 147 | 148 | Args: 149 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. 150 | window_size (int): The window size to use. 151 | serialize_mode (SerializeMode): The serialization mode to use. 152 | shift_sequence (int): The shift of serialized sequence. 153 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates. 154 | shift (int): The shift to use. 155 | """ 156 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" 157 | 158 | serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' 159 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) 160 | if serialization_spatial_cache is None: 161 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) 162 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) 163 | else: 164 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache 165 | 166 | M = fwd_indices.shape[0] 167 | T = qkv.feats.shape[0] 168 | H = qkv.feats.shape[2] 169 | C = qkv.feats.shape[3] 170 | 171 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] 172 | 173 | if DEBUG: 174 | start = 0 175 | qkv_coords = qkv.coords[fwd_indices] 176 | for i in range(len(seq_lens)): 177 | assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" 178 | start += seq_lens[i] 179 | 180 | if all([seq_len == window_size for seq_len in seq_lens]): 181 | B = len(seq_lens) 182 | N = window_size 183 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C) 184 | if ATTN == 'xformers': 185 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] 186 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] 187 | elif ATTN == 'flash_attn': 188 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] 189 | else: 190 | raise ValueError(f"Unknown attention module: {ATTN}") 191 | out = out.reshape(B * N, H, C) # [M, H, C] 192 | else: 193 | if ATTN == 'xformers': 194 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] 195 | q = q.unsqueeze(0) # [1, M, H, C] 196 | k = k.unsqueeze(0) # [1, M, H, C] 197 | v = v.unsqueeze(0) # [1, M, H, C] 198 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) 199 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] 200 | elif ATTN == 'flash_attn': 201 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ 202 | .to(qkv.device).int() 203 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] 204 | 205 | out = out[bwd_indices] # [T, H, C] 206 | 207 | if DEBUG: 208 | qkv_coords = qkv_coords[bwd_indices] 209 | assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" 210 | 211 | return qkv.replace(out) 212 | -------------------------------------------------------------------------------- /trellis/modules/sparse/attention/windowed_attn.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from enum import Enum 3 | import math 4 | import os 5 | import logging 6 | import torch 7 | from .. import SparseTensor 8 | from trellis.backend_config import ( 9 | get_attention_backend, 10 | get_debug_mode, 11 | get_available_backends 12 | ) 13 | import logging 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # Get configuration from central config 18 | # Get configuration from central config 19 | ATTN = get_attention_backend() 20 | DEBUG = get_debug_mode() 21 | 22 | # Get available backends and import if active 23 | available_backends = get_available_backends() 24 | 25 | if ATTN not in ['xformers', 'flash_attn']: 26 | logger.warning(f"Attention backend {ATTN} not supported for sparse attention. Only 'xformers' and 'flash_attn' are available. Defaulting to 'flash_attn'") 27 | ATTN = 'flash_attn' 28 | 29 | if ATTN == 'xformers' and available_backends['xformers']: 30 | import xformers.ops as xops 31 | elif ATTN == 'flash_attn' and available_backends['flash_attn']: 32 | import flash_attn 33 | else: 34 | raise ImportError(f"Could not import {ATTN}. Please install either xformers or flash-attn for sparse attention support.") 35 | 36 | 37 | __all__ = [ 38 | 'sparse_windowed_scaled_dot_product_self_attention', 39 | ] 40 | 41 | 42 | def calc_window_partition( 43 | tensor: SparseTensor, 44 | window_size: Union[int, Tuple[int, ...]], 45 | shift_window: Union[int, Tuple[int, ...]] = 0 46 | ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: 47 | """ 48 | Calculate serialization and partitioning for a set of coordinates. 49 | 50 | Args: 51 | tensor (SparseTensor): The input tensor. 52 | window_size (int): The window size to use. 53 | shift_window (Tuple[int, ...]): The shift of serialized coordinates. 54 | 55 | Returns: 56 | (torch.Tensor): Forwards indices. 57 | (torch.Tensor): Backwards indices. 58 | (List[int]): Sequence lengths. 59 | (List[int]): Sequence batch indices. 60 | """ 61 | DIM = tensor.coords.shape[1] - 1 62 | shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window 63 | window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size 64 | shifted_coords = tensor.coords.clone().detach() 65 | shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) 66 | 67 | MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() 68 | NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] 69 | OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] 70 | 71 | shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) 72 | shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) 73 | fwd_indices = torch.argsort(shifted_indices) 74 | bwd_indices = torch.empty_like(fwd_indices) 75 | bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) 76 | seq_lens = torch.bincount(shifted_indices) 77 | seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] 78 | mask = seq_lens != 0 79 | seq_lens = seq_lens[mask].tolist() 80 | seq_batch_indices = seq_batch_indices[mask].tolist() 81 | 82 | return fwd_indices, bwd_indices, seq_lens, seq_batch_indices 83 | 84 | 85 | def sparse_windowed_scaled_dot_product_self_attention( 86 | qkv: SparseTensor, 87 | window_size: int, 88 | shift_window: Tuple[int, int, int] = (0, 0, 0) 89 | ) -> SparseTensor: 90 | """ 91 | Apply windowed scaled dot product self attention to a sparse tensor. 92 | 93 | Args: 94 | qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. 95 | window_size (int): The window size to use. 96 | shift_window (Tuple[int, int, int]): The shift of serialized coordinates. 97 | shift (int): The shift to use. 98 | """ 99 | assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" 100 | 101 | serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' 102 | serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) 103 | if serialization_spatial_cache is None: 104 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) 105 | qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) 106 | else: 107 | fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache 108 | 109 | M = fwd_indices.shape[0] 110 | T = qkv.feats.shape[0] 111 | H = qkv.feats.shape[2] 112 | C = qkv.feats.shape[3] 113 | 114 | qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] 115 | 116 | if DEBUG: 117 | start = 0 118 | qkv_coords = qkv.coords[fwd_indices] 119 | for i in range(len(seq_lens)): 120 | seq_coords = qkv_coords[start:start+seq_lens[i]] 121 | assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" 122 | assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ 123 | f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" 124 | start += seq_lens[i] 125 | 126 | if all([seq_len == window_size for seq_len in seq_lens]): 127 | B = len(seq_lens) 128 | N = window_size 129 | qkv_feats = qkv_feats.reshape(B, N, 3, H, C) 130 | if ATTN == 'xformers': 131 | q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] 132 | out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] 133 | elif ATTN == 'flash_attn': 134 | out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] 135 | else: 136 | raise ValueError(f"Unknown attention module: {ATTN}") 137 | out = out.reshape(B * N, H, C) # [M, H, C] 138 | else: 139 | if ATTN == 'xformers': 140 | q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] 141 | q = q.unsqueeze(0) # [1, M, H, C] 142 | k = k.unsqueeze(0) # [1, M, H, C] 143 | v = v.unsqueeze(0) # [1, M, H, C] 144 | mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) 145 | out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] 146 | elif ATTN == 'flash_attn': 147 | cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ 148 | .to(qkv.device).int() 149 | out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] 150 | 151 | out = out[bwd_indices] # [T, H, C] 152 | 153 | if DEBUG: 154 | qkv_coords = qkv_coords[bwd_indices] 155 | assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" 156 | 157 | return qkv.replace(out) 158 | -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from trellis.backend_config import get_sparse_backend, get_spconv_algo 4 | 5 | BACKEND = get_sparse_backend() 6 | SPCONV_ALGO = get_spconv_algo() 7 | 8 | def __from_env(): 9 | import os 10 | 11 | global SPCONV_ALGO 12 | env_spconv_algo = os.environ.get('SPCONV_ALGO') 13 | if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: 14 | SPCONV_ALGO = env_spconv_algo 15 | print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") 16 | 17 | 18 | __from_env() 19 | 20 | if BACKEND == 'torchsparse': 21 | from .conv_torchsparse import * 22 | elif BACKEND == 'spconv': 23 | from .conv_spconv import * 24 | 25 | __all__ = [ 26 | "SparseConv3d", 27 | "SparseInverseConv3d", 28 | ] 29 | 30 | -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/conv_spconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. import SparseTensor 4 | from .. import DEBUG 5 | from trellis.backend_config import get_debug_mode, get_spconv_algo, get_sparse_backend 6 | 7 | # Get configuration from central config 8 | DEBUG = get_debug_mode() 9 | SPCONV_ALGO = get_spconv_algo() 10 | BACKEND = get_sparse_backend() 11 | 12 | class SparseConv3d(nn.Module): 13 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): 14 | super(SparseConv3d, self).__init__() 15 | if BACKEND == 'spconv': 16 | import spconv.pytorch as spconv 17 | algo = None 18 | if SPCONV_ALGO == 'native': 19 | algo = spconv.ConvAlgo.Native 20 | elif SPCONV_ALGO == 'implicit_gemm': 21 | algo = spconv.ConvAlgo.MaskImplicitGemm 22 | if stride == 1 and (padding is None): 23 | self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) 24 | else: 25 | self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) 26 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) 27 | self.padding = padding 28 | 29 | def forward(self, x: SparseTensor) -> SparseTensor: 30 | spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) 31 | new_data = self.conv(x.data) 32 | new_shape = [x.shape[0], self.conv.out_channels] 33 | new_layout = None if spatial_changed else x.layout 34 | 35 | if spatial_changed and (x.shape[0] != 1): 36 | # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords 37 | fwd = new_data.indices[:, 0].argsort() 38 | bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) 39 | sorted_feats = new_data.features[fwd] 40 | sorted_coords = new_data.indices[fwd] 41 | unsorted_data = new_data 42 | new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore 43 | 44 | out = SparseTensor( 45 | new_data, shape=torch.Size(new_shape), layout=new_layout, 46 | scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), 47 | spatial_cache=x._spatial_cache, 48 | ) 49 | 50 | if spatial_changed and (x.shape[0] != 1): 51 | out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) 52 | out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) 53 | 54 | return out 55 | 56 | 57 | class SparseInverseConv3d(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 59 | super(SparseInverseConv3d, self).__init__() 60 | if BACKEND == 'spconv': 61 | import spconv.pytorch as spconv 62 | self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) 63 | self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) 64 | 65 | def forward(self, x: SparseTensor) -> SparseTensor: 66 | spatial_changed = any(s != 1 for s in self.stride) 67 | if spatial_changed: 68 | # recover the original spconv order 69 | data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') 70 | bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') 71 | data = data.replace_feature(x.feats[bwd]) 72 | if DEBUG: 73 | assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' 74 | else: 75 | data = x.data 76 | 77 | new_data = self.conv(data) 78 | new_shape = [x.shape[0], self.conv.out_channels] 79 | new_layout = None if spatial_changed else x.layout 80 | out = SparseTensor( 81 | new_data, shape=torch.Size(new_shape), layout=new_layout, 82 | scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), 83 | spatial_cache=x._spatial_cache, 84 | ) 85 | return out 86 | -------------------------------------------------------------------------------- /trellis/modules/sparse/conv/conv_torchsparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .. import SparseTensor 4 | from trellis.backend_config import get_debug_mode, get_spconv_algo, get_sparse_backend 5 | 6 | # Get configuration from central config 7 | DEBUG = get_debug_mode() 8 | SPCONV_ALGO = get_spconv_algo() 9 | BACKEND = get_sparse_backend() 10 | 11 | class SparseConv3d(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 13 | super(SparseConv3d, self).__init__() 14 | if BACKEND == 'torchsparse': 15 | import torchsparse 16 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) 17 | 18 | def forward(self, x: SparseTensor) -> SparseTensor: 19 | out = self.conv(x.data) 20 | new_shape = [x.shape[0], self.conv.out_channels] 21 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 22 | out._spatial_cache = x._spatial_cache 23 | out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) 24 | return out 25 | 26 | 27 | class SparseInverseConv3d(nn.Module): 28 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): 29 | super(SparseInverseConv3d, self).__init__() 30 | if BACKEND == 'torchsparse': 31 | import torchsparse 32 | self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) 33 | 34 | def forward(self, x: SparseTensor) -> SparseTensor: 35 | out = self.conv(x.data) 36 | new_shape = [x.shape[0], self.conv.out_channels] 37 | out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) 38 | out._spatial_cache = x._spatial_cache 39 | out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) 40 | return out 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /trellis/modules/sparse/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import SparseTensor 4 | 5 | __all__ = [ 6 | 'SparseLinear' 7 | ] 8 | 9 | 10 | class SparseLinear(nn.Linear): 11 | def __init__(self, in_features, out_features, bias=True): 12 | super(SparseLinear, self).__init__(in_features, out_features, bias) 13 | 14 | def forward(self, input: SparseTensor) -> SparseTensor: 15 | return input.replace(super().forward(input.feats)) 16 | -------------------------------------------------------------------------------- /trellis/modules/sparse/nonlinearity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import SparseTensor 4 | 5 | __all__ = [ 6 | 'SparseReLU', 7 | 'SparseSiLU', 8 | 'SparseGELU', 9 | 'SparseActivation' 10 | ] 11 | 12 | 13 | class SparseReLU(nn.ReLU): 14 | def forward(self, input: SparseTensor) -> SparseTensor: 15 | return input.replace(super().forward(input.feats)) 16 | 17 | 18 | class SparseSiLU(nn.SiLU): 19 | def forward(self, input: SparseTensor) -> SparseTensor: 20 | return input.replace(super().forward(input.feats)) 21 | 22 | 23 | class SparseGELU(nn.GELU): 24 | def forward(self, input: SparseTensor) -> SparseTensor: 25 | return input.replace(super().forward(input.feats)) 26 | 27 | 28 | class SparseActivation(nn.Module): 29 | def __init__(self, activation: nn.Module): 30 | super().__init__() 31 | self.activation = activation 32 | 33 | def forward(self, input: SparseTensor) -> SparseTensor: 34 | return input.replace(self.activation(input.feats)) 35 | 36 | -------------------------------------------------------------------------------- /trellis/modules/sparse/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import SparseTensor 4 | from . import DEBUG 5 | 6 | __all__ = [ 7 | 'SparseGroupNorm', 8 | 'SparseLayerNorm', 9 | 'SparseGroupNorm32', 10 | 'SparseLayerNorm32', 11 | ] 12 | 13 | 14 | class SparseGroupNorm(nn.GroupNorm): 15 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): 16 | super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) 17 | 18 | def forward(self, input: SparseTensor) -> SparseTensor: 19 | nfeats = torch.zeros_like(input.feats) 20 | for k in range(input.shape[0]): 21 | if DEBUG: 22 | assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" 23 | bfeats = input.feats[input.layout[k]] 24 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 25 | bfeats = super().forward(bfeats) 26 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 27 | nfeats[input.layout[k]] = bfeats 28 | return input.replace(nfeats) 29 | 30 | 31 | class SparseLayerNorm(nn.LayerNorm): 32 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 33 | super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) 34 | 35 | def forward(self, input: SparseTensor) -> SparseTensor: 36 | nfeats = torch.zeros_like(input.feats) 37 | for k in range(input.shape[0]): 38 | bfeats = input.feats[input.layout[k]] 39 | bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) 40 | bfeats = super().forward(bfeats) 41 | bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) 42 | nfeats[input.layout[k]] = bfeats 43 | return input.replace(nfeats) 44 | 45 | 46 | class SparseGroupNorm32(SparseGroupNorm): 47 | """ 48 | A GroupNorm layer that converts to float32 before the forward pass. 49 | """ 50 | def forward(self, x: SparseTensor) -> SparseTensor: 51 | return super().forward(x.float()).type(x.dtype) 52 | 53 | class SparseLayerNorm32(SparseLayerNorm): 54 | """ 55 | A LayerNorm layer that converts to float32 before the forward pass. 56 | """ 57 | def forward(self, x: SparseTensor) -> SparseTensor: 58 | return super().forward(x.float()).type(x.dtype) 59 | -------------------------------------------------------------------------------- /trellis/modules/sparse/spatial.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from . import SparseTensor 5 | 6 | __all__ = [ 7 | 'SparseDownsample', 8 | 'SparseUpsample', 9 | 'SparseSubdivide' 10 | ] 11 | 12 | 13 | class SparseDownsample(nn.Module): 14 | """ 15 | Downsample a sparse tensor by a factor of `factor`. 16 | Implemented as average pooling. 17 | """ 18 | def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): 19 | super(SparseDownsample, self).__init__() 20 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor 21 | 22 | def forward(self, input: SparseTensor) -> SparseTensor: 23 | DIM = input.coords.shape[-1] - 1 24 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM 25 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' 26 | 27 | coord = list(input.coords.unbind(dim=-1)) 28 | for i, f in enumerate(factor): 29 | coord[i+1] = coord[i+1] // f 30 | 31 | MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] 32 | OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] 33 | code = sum([c * o for c, o in zip(coord, OFFSET)]) 34 | code, idx = code.unique(return_inverse=True) 35 | 36 | new_feats = torch.scatter_reduce( 37 | torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), 38 | dim=0, 39 | index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), 40 | src=input.feats, 41 | reduce='mean' 42 | ) 43 | new_coords = torch.stack( 44 | [code // OFFSET[0]] + 45 | [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], 46 | dim=-1 47 | ) 48 | out = SparseTensor(new_feats, new_coords, input.shape,) 49 | out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) 50 | out._spatial_cache = input._spatial_cache 51 | 52 | out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) 53 | out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) 54 | out.register_spatial_cache(f'upsample_{factor}_idx', idx) 55 | 56 | return out 57 | 58 | 59 | class SparseUpsample(nn.Module): 60 | """ 61 | Upsample a sparse tensor by a factor of `factor`. 62 | Implemented as nearest neighbor interpolation. 63 | """ 64 | def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): 65 | super(SparseUpsample, self).__init__() 66 | self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor 67 | 68 | def forward(self, input: SparseTensor) -> SparseTensor: 69 | DIM = input.coords.shape[-1] - 1 70 | factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM 71 | assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' 72 | 73 | new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') 74 | new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') 75 | idx = input.get_spatial_cache(f'upsample_{factor}_idx') 76 | if any([x is None for x in [new_coords, new_layout, idx]]): 77 | raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') 78 | new_feats = input.feats[idx] 79 | out = SparseTensor(new_feats, new_coords, input.shape, new_layout) 80 | out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) 81 | out._spatial_cache = input._spatial_cache 82 | return out 83 | 84 | class SparseSubdivide(nn.Module): 85 | """ 86 | Upsample a sparse tensor by a factor of `factor`. 87 | Implemented as nearest neighbor interpolation. 88 | """ 89 | def __init__(self): 90 | super(SparseSubdivide, self).__init__() 91 | 92 | def forward(self, input: SparseTensor) -> SparseTensor: 93 | DIM = input.coords.shape[-1] - 1 94 | # upsample scale=2^DIM 95 | n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) 96 | n_coords = torch.nonzero(n_cube) 97 | n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) 98 | factor = n_coords.shape[0] 99 | assert factor == 2 ** DIM 100 | # print(n_coords.shape) 101 | new_coords = input.coords.clone() 102 | new_coords[:, 1:] *= 2 103 | new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) 104 | 105 | new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) 106 | out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) 107 | out._scale = input._scale * 2 108 | out._spatial_cache = input._spatial_cache 109 | return out 110 | 111 | -------------------------------------------------------------------------------- /trellis/modules/sparse/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | from .modulated import * -------------------------------------------------------------------------------- /trellis/modules/sparse/transformer/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from ..basic import SparseTensor 5 | from ..linear import SparseLinear 6 | from ..nonlinearity import SparseGELU 7 | from ..attention import SparseMultiHeadAttention, SerializeMode 8 | from ...norm import LayerNorm32 9 | 10 | 11 | class SparseFeedForwardNet(nn.Module): 12 | def __init__(self, channels: int, mlp_ratio: float = 4.0): 13 | super().__init__() 14 | self.mlp = nn.Sequential( 15 | SparseLinear(channels, int(channels * mlp_ratio)), 16 | SparseGELU(approximate="tanh"), 17 | SparseLinear(int(channels * mlp_ratio), channels), 18 | ) 19 | 20 | def forward(self, x: SparseTensor) -> SparseTensor: 21 | return self.mlp(x) 22 | 23 | 24 | class SparseTransformerBlock(nn.Module): 25 | """ 26 | Sparse Transformer block (MSA + FFN). 27 | """ 28 | def __init__( 29 | self, 30 | channels: int, 31 | num_heads: int, 32 | mlp_ratio: float = 4.0, 33 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 34 | window_size: Optional[int] = None, 35 | shift_sequence: Optional[int] = None, 36 | shift_window: Optional[Tuple[int, int, int]] = None, 37 | serialize_mode: Optional[SerializeMode] = None, 38 | use_checkpoint: bool = False, 39 | use_rope: bool = False, 40 | qk_rms_norm: bool = False, 41 | qkv_bias: bool = True, 42 | ln_affine: bool = False, 43 | ): 44 | super().__init__() 45 | self.use_checkpoint = use_checkpoint 46 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 47 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 48 | self.attn = SparseMultiHeadAttention( 49 | channels, 50 | num_heads=num_heads, 51 | attn_mode=attn_mode, 52 | window_size=window_size, 53 | shift_sequence=shift_sequence, 54 | shift_window=shift_window, 55 | serialize_mode=serialize_mode, 56 | qkv_bias=qkv_bias, 57 | use_rope=use_rope, 58 | qk_rms_norm=qk_rms_norm, 59 | ) 60 | self.mlp = SparseFeedForwardNet( 61 | channels, 62 | mlp_ratio=mlp_ratio, 63 | ) 64 | 65 | def _forward(self, x: SparseTensor) -> SparseTensor: 66 | h = x.replace(self.norm1(x.feats)) 67 | h = self.attn(h) 68 | x = x + h 69 | h = x.replace(self.norm2(x.feats)) 70 | h = self.mlp(h) 71 | x = x + h 72 | return x 73 | 74 | def forward(self, x: SparseTensor) -> SparseTensor: 75 | if self.use_checkpoint: 76 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) 77 | else: 78 | return self._forward(x) 79 | 80 | 81 | class SparseTransformerCrossBlock(nn.Module): 82 | """ 83 | Sparse Transformer cross-attention block (MSA + MCA + FFN). 84 | """ 85 | def __init__( 86 | self, 87 | channels: int, 88 | ctx_channels: int, 89 | num_heads: int, 90 | mlp_ratio: float = 4.0, 91 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 92 | window_size: Optional[int] = None, 93 | shift_sequence: Optional[int] = None, 94 | shift_window: Optional[Tuple[int, int, int]] = None, 95 | serialize_mode: Optional[SerializeMode] = None, 96 | use_checkpoint: bool = False, 97 | use_rope: bool = False, 98 | qk_rms_norm: bool = False, 99 | qk_rms_norm_cross: bool = False, 100 | qkv_bias: bool = True, 101 | ln_affine: bool = False, 102 | ): 103 | super().__init__() 104 | self.use_checkpoint = use_checkpoint 105 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 106 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 107 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 108 | self.self_attn = SparseMultiHeadAttention( 109 | channels, 110 | num_heads=num_heads, 111 | type="self", 112 | attn_mode=attn_mode, 113 | window_size=window_size, 114 | shift_sequence=shift_sequence, 115 | shift_window=shift_window, 116 | serialize_mode=serialize_mode, 117 | qkv_bias=qkv_bias, 118 | use_rope=use_rope, 119 | qk_rms_norm=qk_rms_norm, 120 | ) 121 | self.cross_attn = SparseMultiHeadAttention( 122 | channels, 123 | ctx_channels=ctx_channels, 124 | num_heads=num_heads, 125 | type="cross", 126 | attn_mode="full", 127 | qkv_bias=qkv_bias, 128 | qk_rms_norm=qk_rms_norm_cross, 129 | ) 130 | self.mlp = SparseFeedForwardNet( 131 | channels, 132 | mlp_ratio=mlp_ratio, 133 | ) 134 | 135 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): 136 | h = x.replace(self.norm1(x.feats)) 137 | h = self.self_attn(h) 138 | x = x + h 139 | h = x.replace(self.norm2(x.feats)) 140 | h = self.cross_attn(h, context) 141 | x = x + h 142 | h = x.replace(self.norm3(x.feats)) 143 | h = self.mlp(h) 144 | x = x + h 145 | return x 146 | 147 | def forward(self, x: SparseTensor, context: torch.Tensor): 148 | if self.use_checkpoint: 149 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) 150 | else: 151 | return self._forward(x, context) 152 | -------------------------------------------------------------------------------- /trellis/modules/sparse/transformer/modulated.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from ..basic import SparseTensor 5 | from ..attention import SparseMultiHeadAttention, SerializeMode 6 | from ...norm import LayerNorm32 7 | from .blocks import SparseFeedForwardNet 8 | 9 | 10 | class ModulatedSparseTransformerBlock(nn.Module): 11 | """ 12 | Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. 13 | """ 14 | def __init__( 15 | self, 16 | channels: int, 17 | num_heads: int, 18 | mlp_ratio: float = 4.0, 19 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 20 | window_size: Optional[int] = None, 21 | shift_sequence: Optional[int] = None, 22 | shift_window: Optional[Tuple[int, int, int]] = None, 23 | serialize_mode: Optional[SerializeMode] = None, 24 | use_checkpoint: bool = False, 25 | use_rope: bool = False, 26 | qk_rms_norm: bool = False, 27 | qkv_bias: bool = True, 28 | share_mod: bool = False, 29 | ): 30 | super().__init__() 31 | self.use_checkpoint = use_checkpoint 32 | self.share_mod = share_mod 33 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 34 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 35 | self.attn = SparseMultiHeadAttention( 36 | channels, 37 | num_heads=num_heads, 38 | attn_mode=attn_mode, 39 | window_size=window_size, 40 | shift_sequence=shift_sequence, 41 | shift_window=shift_window, 42 | serialize_mode=serialize_mode, 43 | qkv_bias=qkv_bias, 44 | use_rope=use_rope, 45 | qk_rms_norm=qk_rms_norm, 46 | ) 47 | self.mlp = SparseFeedForwardNet( 48 | channels, 49 | mlp_ratio=mlp_ratio, 50 | ) 51 | if not share_mod: 52 | self.adaLN_modulation = nn.Sequential( 53 | nn.SiLU(), 54 | nn.Linear(channels, 6 * channels, bias=True) 55 | ) 56 | 57 | def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: 58 | if self.share_mod: 59 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 60 | else: 61 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 62 | h = x.replace(self.norm1(x.feats)) 63 | h = h * (1 + scale_msa) + shift_msa 64 | h = self.attn(h) 65 | h = h * gate_msa 66 | x = x + h 67 | h = x.replace(self.norm2(x.feats)) 68 | h = h * (1 + scale_mlp) + shift_mlp 69 | h = self.mlp(h) 70 | h = h * gate_mlp 71 | x = x + h 72 | return x 73 | 74 | def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: 75 | if self.use_checkpoint: 76 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) 77 | else: 78 | return self._forward(x, mod) 79 | 80 | 81 | class ModulatedSparseTransformerCrossBlock(nn.Module): 82 | """ 83 | Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 84 | """ 85 | def __init__( 86 | self, 87 | channels: int, 88 | ctx_channels: int, 89 | num_heads: int, 90 | mlp_ratio: float = 4.0, 91 | attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", 92 | window_size: Optional[int] = None, 93 | shift_sequence: Optional[int] = None, 94 | shift_window: Optional[Tuple[int, int, int]] = None, 95 | serialize_mode: Optional[SerializeMode] = None, 96 | use_checkpoint: bool = False, 97 | use_rope: bool = False, 98 | qk_rms_norm: bool = False, 99 | qk_rms_norm_cross: bool = False, 100 | qkv_bias: bool = True, 101 | share_mod: bool = False, 102 | 103 | ): 104 | super().__init__() 105 | self.use_checkpoint = use_checkpoint 106 | self.share_mod = share_mod 107 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 108 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) 109 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 110 | self.self_attn = SparseMultiHeadAttention( 111 | channels, 112 | num_heads=num_heads, 113 | type="self", 114 | attn_mode=attn_mode, 115 | window_size=window_size, 116 | shift_sequence=shift_sequence, 117 | shift_window=shift_window, 118 | serialize_mode=serialize_mode, 119 | qkv_bias=qkv_bias, 120 | use_rope=use_rope, 121 | qk_rms_norm=qk_rms_norm, 122 | ) 123 | self.cross_attn = SparseMultiHeadAttention( 124 | channels, 125 | ctx_channels=ctx_channels, 126 | num_heads=num_heads, 127 | type="cross", 128 | attn_mode="full", 129 | qkv_bias=qkv_bias, 130 | qk_rms_norm=qk_rms_norm_cross, 131 | ) 132 | self.mlp = SparseFeedForwardNet( 133 | channels, 134 | mlp_ratio=mlp_ratio, 135 | ) 136 | if not share_mod: 137 | self.adaLN_modulation = nn.Sequential( 138 | nn.SiLU(), 139 | nn.Linear(channels, 6 * channels, bias=True) 140 | ) 141 | 142 | def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: 143 | if self.share_mod: 144 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 145 | else: 146 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 147 | h = x.replace(self.norm1(x.feats)) 148 | h = h * (1 + scale_msa) + shift_msa 149 | h = self.self_attn(h) 150 | h = h * gate_msa 151 | x = x + h 152 | h = x.replace(self.norm2(x.feats)) 153 | h = self.cross_attn(h, context) 154 | x = x + h 155 | h = x.replace(self.norm3(x.feats)) 156 | h = h * (1 + scale_mlp) + shift_mlp 157 | h = self.mlp(h) 158 | h = h * gate_mlp 159 | x = x + h 160 | return x 161 | 162 | def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: 163 | if self.use_checkpoint: 164 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) 165 | else: 166 | return self._forward(x, mod, context) 167 | -------------------------------------------------------------------------------- /trellis/modules/spatial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: 5 | """ 6 | 3D pixel shuffle. 7 | """ 8 | B, C, H, W, D = x.shape 9 | C_ = C // scale_factor**3 10 | x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) 11 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) 12 | x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) 13 | return x 14 | 15 | 16 | def patchify(x: torch.Tensor, patch_size: int): 17 | """ 18 | Patchify a tensor. 19 | 20 | Args: 21 | x (torch.Tensor): (N, C, *spatial) tensor 22 | patch_size (int): Patch size 23 | """ 24 | DIM = x.dim() - 2 25 | for d in range(2, DIM + 2): 26 | assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" 27 | 28 | x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) 29 | x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) 30 | x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) 31 | return x 32 | 33 | 34 | def unpatchify(x: torch.Tensor, patch_size: int): 35 | """ 36 | Unpatchify a tensor. 37 | 38 | Args: 39 | x (torch.Tensor): (N, C, *spatial) tensor 40 | patch_size (int): Patch size 41 | """ 42 | DIM = x.dim() - 2 43 | assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" 44 | 45 | x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) 46 | x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) 47 | x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) 48 | return x 49 | -------------------------------------------------------------------------------- /trellis/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | from .modulated import * -------------------------------------------------------------------------------- /trellis/modules/transformer/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from ..attention import MultiHeadAttention 5 | from ..norm import LayerNorm32 6 | 7 | 8 | class AbsolutePositionEmbedder(nn.Module): 9 | """ 10 | Embeds spatial positions into vector representations. 11 | """ 12 | def __init__(self, channels: int, in_channels: int = 3): 13 | super().__init__() 14 | self.channels = channels 15 | self.in_channels = in_channels 16 | self.freq_dim = channels // in_channels // 2 17 | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim 18 | self.freqs = 1.0 / (10000 ** self.freqs) 19 | 20 | def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: 21 | """ 22 | Create sinusoidal position embeddings. 23 | 24 | Args: 25 | x: a 1-D Tensor of N indices 26 | 27 | Returns: 28 | an (N, D) Tensor of positional embeddings. 29 | """ 30 | self.freqs = self.freqs.to(x.device) 31 | out = torch.outer(x, self.freqs) 32 | out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) 33 | return out 34 | 35 | def forward(self, x: torch.Tensor) -> torch.Tensor: 36 | """ 37 | Args: 38 | x (torch.Tensor): (N, D) tensor of spatial positions 39 | """ 40 | N, D = x.shape 41 | assert D == self.in_channels, "Input dimension must match number of input channels" 42 | embed = self._sin_cos_embedding(x.reshape(-1)) 43 | embed = embed.reshape(N, -1) 44 | if embed.shape[1] < self.channels: 45 | embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) 46 | return embed 47 | 48 | 49 | class FeedForwardNet(nn.Module): 50 | def __init__(self, channels: int, mlp_ratio: float = 4.0): 51 | super().__init__() 52 | self.mlp = nn.Sequential( 53 | nn.Linear(channels, int(channels * mlp_ratio)), 54 | nn.GELU(approximate="tanh"), 55 | nn.Linear(int(channels * mlp_ratio), channels), 56 | ) 57 | 58 | def forward(self, x: torch.Tensor) -> torch.Tensor: 59 | return self.mlp(x) 60 | 61 | 62 | class TransformerBlock(nn.Module): 63 | """ 64 | Transformer block (MSA + FFN). 65 | """ 66 | def __init__( 67 | self, 68 | channels: int, 69 | num_heads: int, 70 | mlp_ratio: float = 4.0, 71 | attn_mode: Literal["full", "windowed"] = "full", 72 | window_size: Optional[int] = None, 73 | shift_window: Optional[int] = None, 74 | use_checkpoint: bool = False, 75 | use_rope: bool = False, 76 | qk_rms_norm: bool = False, 77 | qkv_bias: bool = True, 78 | ln_affine: bool = False, 79 | ): 80 | super().__init__() 81 | self.use_checkpoint = use_checkpoint 82 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 83 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 84 | self.attn = MultiHeadAttention( 85 | channels, 86 | num_heads=num_heads, 87 | attn_mode=attn_mode, 88 | window_size=window_size, 89 | shift_window=shift_window, 90 | qkv_bias=qkv_bias, 91 | use_rope=use_rope, 92 | qk_rms_norm=qk_rms_norm, 93 | ) 94 | self.mlp = FeedForwardNet( 95 | channels, 96 | mlp_ratio=mlp_ratio, 97 | ) 98 | 99 | def _forward(self, x: torch.Tensor) -> torch.Tensor: 100 | h = self.norm1(x) 101 | h = self.attn(h) 102 | x = x + h 103 | h = self.norm2(x) 104 | h = self.mlp(h) 105 | x = x + h 106 | return x 107 | 108 | def forward(self, x: torch.Tensor) -> torch.Tensor: 109 | if self.use_checkpoint: 110 | return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) 111 | else: 112 | return self._forward(x) 113 | 114 | 115 | class TransformerCrossBlock(nn.Module): 116 | """ 117 | Transformer cross-attention block (MSA + MCA + FFN). 118 | """ 119 | def __init__( 120 | self, 121 | channels: int, 122 | ctx_channels: int, 123 | num_heads: int, 124 | mlp_ratio: float = 4.0, 125 | attn_mode: Literal["full", "windowed"] = "full", 126 | window_size: Optional[int] = None, 127 | shift_window: Optional[Tuple[int, int, int]] = None, 128 | use_checkpoint: bool = False, 129 | use_rope: bool = False, 130 | qk_rms_norm: bool = False, 131 | qk_rms_norm_cross: bool = False, 132 | qkv_bias: bool = True, 133 | ln_affine: bool = False, 134 | ): 135 | super().__init__() 136 | self.use_checkpoint = use_checkpoint 137 | self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 138 | self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 139 | self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) 140 | self.self_attn = MultiHeadAttention( 141 | channels, 142 | num_heads=num_heads, 143 | type="self", 144 | attn_mode=attn_mode, 145 | window_size=window_size, 146 | shift_window=shift_window, 147 | qkv_bias=qkv_bias, 148 | use_rope=use_rope, 149 | qk_rms_norm=qk_rms_norm, 150 | ) 151 | self.cross_attn = MultiHeadAttention( 152 | channels, 153 | ctx_channels=ctx_channels, 154 | num_heads=num_heads, 155 | type="cross", 156 | attn_mode="full", 157 | qkv_bias=qkv_bias, 158 | qk_rms_norm=qk_rms_norm_cross, 159 | ) 160 | self.mlp = FeedForwardNet( 161 | channels, 162 | mlp_ratio=mlp_ratio, 163 | ) 164 | 165 | def _forward(self, x: torch.Tensor, context: torch.Tensor): 166 | h = self.norm1(x) 167 | h = self.self_attn(h) 168 | x = x + h 169 | h = self.norm2(x) 170 | h = self.cross_attn(h, context) 171 | x = x + h 172 | h = self.norm3(x) 173 | h = self.mlp(h) 174 | x = x + h 175 | return x 176 | 177 | def forward(self, x: torch.Tensor, context: torch.Tensor): 178 | if self.use_checkpoint: 179 | return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) 180 | else: 181 | return self._forward(x, context) 182 | -------------------------------------------------------------------------------- /trellis/modules/transformer/modulated.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from ..attention import MultiHeadAttention 5 | from ..norm import LayerNorm32 6 | from .blocks import FeedForwardNet 7 | 8 | 9 | class ModulatedTransformerBlock(nn.Module): 10 | """ 11 | Transformer block (MSA + FFN) with adaptive layer norm conditioning. 12 | """ 13 | def __init__( 14 | self, 15 | channels: int, 16 | num_heads: int, 17 | mlp_ratio: float = 4.0, 18 | attn_mode: Literal["full", "windowed"] = "full", 19 | window_size: Optional[int] = None, 20 | shift_window: Optional[Tuple[int, int, int]] = None, 21 | use_checkpoint: bool = False, 22 | use_rope: bool = False, 23 | qk_rms_norm: bool = False, 24 | qkv_bias: bool = True, 25 | share_mod: bool = False, 26 | ): 27 | super().__init__() 28 | self.use_checkpoint = use_checkpoint 29 | self.share_mod = share_mod 30 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 31 | self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 32 | self.attn = MultiHeadAttention( 33 | channels, 34 | num_heads=num_heads, 35 | attn_mode=attn_mode, 36 | window_size=window_size, 37 | shift_window=shift_window, 38 | qkv_bias=qkv_bias, 39 | use_rope=use_rope, 40 | qk_rms_norm=qk_rms_norm, 41 | ) 42 | self.mlp = FeedForwardNet( 43 | channels, 44 | mlp_ratio=mlp_ratio, 45 | ) 46 | if not share_mod: 47 | self.adaLN_modulation = nn.Sequential( 48 | nn.SiLU(), 49 | nn.Linear(channels, 6 * channels, bias=True) 50 | ) 51 | 52 | def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: 53 | if self.share_mod: 54 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 55 | else: 56 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 57 | h = self.norm1(x) 58 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) 59 | h = self.attn(h) 60 | h = h * gate_msa.unsqueeze(1) 61 | x = x + h 62 | h = self.norm2(x) 63 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) 64 | h = self.mlp(h) 65 | h = h * gate_mlp.unsqueeze(1) 66 | x = x + h 67 | return x 68 | 69 | def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: 70 | if self.use_checkpoint: 71 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) 72 | else: 73 | return self._forward(x, mod) 74 | 75 | 76 | class ModulatedTransformerCrossBlock(nn.Module): 77 | """ 78 | Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 79 | """ 80 | def __init__( 81 | self, 82 | channels: int, 83 | ctx_channels: int, 84 | num_heads: int, 85 | mlp_ratio: float = 4.0, 86 | attn_mode: Literal["full", "windowed"] = "full", 87 | window_size: Optional[int] = None, 88 | shift_window: Optional[Tuple[int, int, int]] = None, 89 | use_checkpoint: bool = False, 90 | use_rope: bool = False, 91 | qk_rms_norm: bool = False, 92 | qk_rms_norm_cross: bool = False, 93 | qkv_bias: bool = True, 94 | share_mod: bool = False, 95 | ): 96 | super().__init__() 97 | self.use_checkpoint = use_checkpoint 98 | self.share_mod = share_mod 99 | self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 100 | self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) 101 | self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) 102 | self.self_attn = MultiHeadAttention( 103 | channels, 104 | num_heads=num_heads, 105 | type="self", 106 | attn_mode=attn_mode, 107 | window_size=window_size, 108 | shift_window=shift_window, 109 | qkv_bias=qkv_bias, 110 | use_rope=use_rope, 111 | qk_rms_norm=qk_rms_norm, 112 | ) 113 | self.cross_attn = MultiHeadAttention( 114 | channels, 115 | ctx_channels=ctx_channels, 116 | num_heads=num_heads, 117 | type="cross", 118 | attn_mode="full", 119 | qkv_bias=qkv_bias, 120 | qk_rms_norm=qk_rms_norm_cross, 121 | ) 122 | self.mlp = FeedForwardNet( 123 | channels, 124 | mlp_ratio=mlp_ratio, 125 | ) 126 | if not share_mod: 127 | self.adaLN_modulation = nn.Sequential( 128 | nn.SiLU(), 129 | nn.Linear(channels, 6 * channels, bias=True) 130 | ) 131 | 132 | def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): 133 | if self.share_mod: 134 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) 135 | else: 136 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) 137 | h = self.norm1(x) 138 | h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) 139 | h = self.self_attn(h) 140 | h = h * gate_msa.unsqueeze(1) 141 | x = x + h 142 | h = self.norm2(x) 143 | h = self.cross_attn(h, context) 144 | x = x + h 145 | h = self.norm3(x) 146 | h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) 147 | h = self.mlp(h) 148 | h = h * gate_mlp.unsqueeze(1) 149 | x = x + h 150 | return x 151 | 152 | def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): 153 | if self.use_checkpoint: 154 | return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) 155 | else: 156 | return self._forward(x, mod, context) 157 | -------------------------------------------------------------------------------- /trellis/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..modules import sparse as sp 3 | 4 | FP16_MODULES = ( 5 | nn.Conv1d, 6 | nn.Conv2d, 7 | nn.Conv3d, 8 | nn.ConvTranspose1d, 9 | nn.ConvTranspose2d, 10 | nn.ConvTranspose3d, 11 | nn.Linear, 12 | sp.SparseConv3d, 13 | sp.SparseInverseConv3d, 14 | sp.SparseLinear, 15 | ) 16 | 17 | def convert_module_to_f16(l): 18 | """ 19 | Convert primitive modules to float16. 20 | """ 21 | if isinstance(l, FP16_MODULES): 22 | for p in l.parameters(): 23 | p.data = p.data.half() 24 | 25 | 26 | def convert_module_to_f32(l): 27 | """ 28 | Convert primitive modules to float32, undoing convert_module_to_f16(). 29 | """ 30 | if isinstance(l, FP16_MODULES): 31 | for p in l.parameters(): 32 | p.data = p.data.float() 33 | 34 | 35 | def zero_module(module): 36 | """ 37 | Zero out the parameters of a module and return it. 38 | """ 39 | for p in module.parameters(): 40 | p.detach().zero_() 41 | return module 42 | 43 | 44 | def scale_module(module, scale): 45 | """ 46 | Scale the parameters of a module and return it. 47 | """ 48 | for p in module.parameters(): 49 | p.detach().mul_(scale) 50 | return module 51 | 52 | 53 | def modulate(x, shift, scale): 54 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 55 | -------------------------------------------------------------------------------- /trellis/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from . import samplers 2 | from .trellis_image_to_3d import TrellisImageTo3DPipeline 3 | 4 | 5 | def from_pretrained(path: str): 6 | """ 7 | Load a pipeline from a model folder or a Hugging Face model hub. 8 | 9 | Args: 10 | path: The path to the model. Can be either local path or a Hugging Face model name. 11 | """ 12 | import os 13 | import json 14 | is_local = os.path.exists(f"{path}/pipeline.json") 15 | 16 | if is_local: 17 | config_file = f"{path}/pipeline.json" 18 | else: 19 | from huggingface_hub import hf_hub_download 20 | config_file = hf_hub_download(path, "pipeline.json") 21 | 22 | with open(config_file, 'r') as f: 23 | config = json.load(f) 24 | return globals()[config['name']].from_pretrained(path) 25 | -------------------------------------------------------------------------------- /trellis/pipelines/base.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn as nn 4 | from .. import models 5 | 6 | 7 | class Pipeline: 8 | """ 9 | A base class for pipelines. 10 | """ 11 | def __init__( 12 | self, 13 | models: dict[str, nn.Module] = None, 14 | ): 15 | if models is None: 16 | return 17 | self.models = models 18 | for model in self.models.values(): 19 | model.eval() 20 | 21 | @staticmethod 22 | def from_pretrained(path: str) -> "Pipeline": 23 | """ 24 | Load a pretrained model. 25 | """ 26 | import os 27 | import json 28 | is_local = os.path.exists(f"{path}/pipeline.json") 29 | 30 | if is_local: 31 | config_file = f"{path}/pipeline.json" 32 | with open(config_file, 'r') as f: 33 | #print("Loading pipeline.json:") 34 | args = json.load(f)['args'] 35 | #print("Models in pipeline.json:") 36 | '''for k, v in args['models'].items(): 37 | print(f" {k}: {v}")''' 38 | 39 | # For local paths, construct the models dict differently 40 | _models = { 41 | k: models.from_pretrained(os.path.join(path, v)) 42 | for k, v in args['models'].items() 43 | } 44 | else: 45 | from huggingface_hub import hf_hub_download 46 | config_file = hf_hub_download(path, "pipeline.json") 47 | with open(config_file, 'r') as f: 48 | args = json.load(f)['args'] 49 | 50 | _models = { 51 | k: models.from_pretrained(os.path.join(path, v)) 52 | for k, v in args['models'].items() 53 | } 54 | 55 | new_pipeline = Pipeline(_models) 56 | new_pipeline._pretrained_args = args 57 | return new_pipeline 58 | 59 | @property 60 | def device(self) -> torch.device: 61 | for model in self.models.values(): 62 | if hasattr(model, 'device'): 63 | return model.device 64 | for model in self.models.values(): 65 | if hasattr(model, 'parameters'): 66 | return next(model.parameters()).device 67 | raise RuntimeError("No device found.") 68 | 69 | def to(self, device: torch.device) -> None: 70 | for model in self.models.values(): 71 | model.to(device) 72 | 73 | def cuda(self) -> None: 74 | self.to(torch.device("cuda")) 75 | 76 | def cpu(self) -> None: 77 | self.to(torch.device("cpu")) -------------------------------------------------------------------------------- /trellis/pipelines/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Sampler 2 | from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler -------------------------------------------------------------------------------- /trellis/pipelines/samplers/base.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class Sampler(ABC): 6 | """ 7 | A base class for samplers. 8 | """ 9 | 10 | @abstractmethod 11 | def sample( 12 | self, 13 | model, 14 | **kwargs 15 | ): 16 | """ 17 | Sample from a model. 18 | """ 19 | pass 20 | -------------------------------------------------------------------------------- /trellis/pipelines/samplers/classifier_free_guidance_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | 4 | class ClassifierFreeGuidanceSamplerMixin: 5 | """ 6 | A mixin class for samplers that apply classifier-free guidance. 7 | """ 8 | 9 | def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): 10 | pred = super()._inference_model(model, x_t, t, cond, **kwargs) 11 | neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) 12 | return (1 + cfg_strength) * pred - cfg_strength * neg_pred 13 | -------------------------------------------------------------------------------- /trellis/pipelines/samplers/flow_euler.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from easydict import EasyDict as edict 6 | from .base import Sampler 7 | from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin 8 | from .guidance_interval_mixin import GuidanceIntervalSamplerMixin 9 | 10 | 11 | class FlowEulerSampler(Sampler): 12 | """ 13 | Generate samples from a flow-matching model using Euler sampling. 14 | 15 | Args: 16 | sigma_min: The minimum scale of noise in flow. 17 | """ 18 | def __init__( 19 | self, 20 | sigma_min: float, 21 | ): 22 | self.sigma_min = sigma_min 23 | 24 | def _eps_to_xstart(self, x_t, t, eps): 25 | assert x_t.shape == eps.shape 26 | return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) 27 | 28 | def _xstart_to_eps(self, x_t, t, x_0): 29 | assert x_t.shape == x_0.shape 30 | return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) 31 | 32 | def _v_to_xstart_eps(self, x_t, t, v): 33 | assert x_t.shape == v.shape 34 | eps = (1 - t) * v + x_t 35 | x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v 36 | return x_0, eps 37 | 38 | def _inference_model(self, model, x_t, t, cond=None, **kwargs): 39 | t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) 40 | return model(x_t, t, cond, **kwargs) 41 | 42 | def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): 43 | pred_v = self._inference_model(model, x_t, t, cond, **kwargs) 44 | pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) 45 | return pred_x_0, pred_eps, pred_v 46 | 47 | @torch.no_grad() 48 | def sample_once( 49 | self, 50 | model, 51 | x_t, 52 | t: float, 53 | t_prev: float, 54 | cond: Optional[Any] = None, 55 | **kwargs 56 | ): 57 | """ 58 | Sample x_{t-1} from the model using Euler method. 59 | 60 | Args: 61 | model: The model to sample from. 62 | x_t: The [N x C x ...] tensor of noisy inputs at time t. 63 | t: The current timestep. 64 | t_prev: The previous timestep. 65 | cond: conditional information. 66 | **kwargs: Additional arguments for model inference. 67 | 68 | Returns: 69 | a dict containing the following 70 | - 'pred_x_prev': x_{t-1}. 71 | - 'pred_x_0': a prediction of x_0. 72 | """ 73 | pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) 74 | pred_x_prev = x_t - (t - t_prev) * pred_v 75 | return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) 76 | 77 | @torch.no_grad() 78 | def sample( 79 | self, 80 | model, 81 | noise, 82 | cond: Optional[Any] = None, 83 | steps: int = 50, 84 | rescale_t: float = 1.0, 85 | verbose: bool = True, 86 | **kwargs 87 | ): 88 | """ 89 | Generate samples from the model using Euler method. 90 | 91 | Args: 92 | model: The model to sample from. 93 | noise: The initial noise tensor. 94 | cond: conditional information. 95 | steps: The number of steps to sample. 96 | rescale_t: The rescale factor for t. 97 | verbose: If True, show a progress bar. 98 | **kwargs: Additional arguments for model_inference. 99 | 100 | Returns: 101 | a dict containing the following 102 | - 'samples': the model samples. 103 | - 'pred_x_t': a list of prediction of x_t. 104 | - 'pred_x_0': a list of prediction of x_0. 105 | """ 106 | sample = noise 107 | t_seq = np.linspace(1, 0, steps + 1) 108 | t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) 109 | t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) 110 | ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) 111 | for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): 112 | out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) 113 | sample = out.pred_x_prev 114 | ret.pred_x_t.append(out.pred_x_prev) 115 | ret.pred_x_0.append(out.pred_x_0) 116 | ret.samples = sample 117 | return ret 118 | 119 | 120 | class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): 121 | """ 122 | Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. 123 | """ 124 | @torch.no_grad() 125 | def sample( 126 | self, 127 | model, 128 | noise, 129 | cond, 130 | neg_cond, 131 | steps: int = 50, 132 | rescale_t: float = 1.0, 133 | cfg_strength: float = 3.0, 134 | verbose: bool = True, 135 | **kwargs 136 | ): 137 | """ 138 | Generate samples from the model using Euler method. 139 | 140 | Args: 141 | model: The model to sample from. 142 | noise: The initial noise tensor. 143 | cond: conditional information. 144 | neg_cond: negative conditional information. 145 | steps: The number of steps to sample. 146 | rescale_t: The rescale factor for t. 147 | cfg_strength: The strength of classifier-free guidance. 148 | verbose: If True, show a progress bar. 149 | **kwargs: Additional arguments for model_inference. 150 | 151 | Returns: 152 | a dict containing the following 153 | - 'samples': the model samples. 154 | - 'pred_x_t': a list of prediction of x_t. 155 | - 'pred_x_0': a list of prediction of x_0. 156 | """ 157 | return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) 158 | 159 | 160 | class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): 161 | """ 162 | Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. 163 | """ 164 | @torch.no_grad() 165 | def sample( 166 | self, 167 | model, 168 | noise, 169 | cond, 170 | neg_cond, 171 | steps: int = 50, 172 | rescale_t: float = 1.0, 173 | cfg_strength: float = 3.0, 174 | cfg_interval: Tuple[float, float] = (0.0, 1.0), 175 | verbose: bool = True, 176 | **kwargs 177 | ): 178 | """ 179 | Generate samples from the model using Euler method. 180 | 181 | Args: 182 | model: The model to sample from. 183 | noise: The initial noise tensor. 184 | cond: conditional information. 185 | neg_cond: negative conditional information. 186 | steps: The number of steps to sample. 187 | rescale_t: The rescale factor for t. 188 | cfg_strength: The strength of classifier-free guidance. 189 | cfg_interval: The interval for classifier-free guidance. 190 | verbose: If True, show a progress bar. 191 | **kwargs: Additional arguments for model_inference. 192 | 193 | Returns: 194 | a dict containing the following 195 | - 'samples': the model samples. 196 | - 'pred_x_t': a list of prediction of x_t. 197 | - 'pred_x_0': a list of prediction of x_0. 198 | """ 199 | return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) 200 | -------------------------------------------------------------------------------- /trellis/pipelines/samplers/guidance_interval_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | 4 | class GuidanceIntervalSamplerMixin: 5 | """ 6 | A mixin class for samplers that apply classifier-free guidance with interval. 7 | """ 8 | 9 | def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): 10 | if cfg_interval[0] <= t <= cfg_interval[1]: 11 | pred = super()._inference_model(model, x_t, t, cond, **kwargs) 12 | neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) 13 | return (1 + cfg_strength) * pred - cfg_strength * neg_pred 14 | else: 15 | return super()._inference_model(model, x_t, t, cond, **kwargs) 16 | -------------------------------------------------------------------------------- /trellis/representations/__init__.py: -------------------------------------------------------------------------------- 1 | from .mesh import MeshExtractResult 2 | -------------------------------------------------------------------------------- /trellis/representations/mesh/__init__.py: -------------------------------------------------------------------------------- 1 | from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult 2 | -------------------------------------------------------------------------------- /trellis/representations/mesh/cube2mesh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ...modules.sparse import SparseTensor 3 | from easydict import EasyDict as edict 4 | from .utils_cube import * 5 | from .flexicube import FlexiCubes 6 | 7 | 8 | class MeshExtractResult: 9 | def __init__(self, 10 | vertices, 11 | faces, 12 | vertex_attrs=None, 13 | res=64 14 | ): 15 | self.vertices = vertices 16 | self.faces = faces.long() 17 | self.vertex_attrs = vertex_attrs 18 | self.face_normal = self.comput_face_normals(vertices, faces) 19 | self.res = res 20 | self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) 21 | 22 | # training only 23 | self.tsdf_v = None 24 | self.tsdf_s = None 25 | self.reg_loss = None 26 | 27 | def comput_face_normals(self, verts, faces): 28 | i0 = faces[..., 0].long() 29 | i1 = faces[..., 1].long() 30 | i2 = faces[..., 2].long() 31 | 32 | v0 = verts[i0, :] 33 | v1 = verts[i1, :] 34 | v2 = verts[i2, :] 35 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) 36 | face_normals = torch.nn.functional.normalize(face_normals, dim=1) 37 | # print(face_normals.min(), face_normals.max(), face_normals.shape) 38 | return face_normals[:, None, :].repeat(1, 3, 1) 39 | 40 | def comput_v_normals(self, verts, faces): 41 | i0 = faces[..., 0].long() 42 | i1 = faces[..., 1].long() 43 | i2 = faces[..., 2].long() 44 | 45 | v0 = verts[i0, :] 46 | v1 = verts[i1, :] 47 | v2 = verts[i2, :] 48 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) 49 | v_normals = torch.zeros_like(verts) 50 | v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) 51 | v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) 52 | v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) 53 | 54 | v_normals = torch.nn.functional.normalize(v_normals, dim=1) 55 | return v_normals 56 | 57 | 58 | class SparseFeatures2Mesh: 59 | def __init__(self, device="cuda", res=64, use_color=True): 60 | ''' 61 | a model to generate a mesh from sparse features structures using flexicube 62 | ''' 63 | super().__init__() 64 | self.device=device 65 | self.res = res 66 | self.mesh_extractor = FlexiCubes(device=device) 67 | self.sdf_bias = -1.0 / res 68 | verts, cube = construct_dense_grid(self.res, self.device) 69 | self.reg_c = cube.to(self.device) 70 | self.reg_v = verts.to(self.device) 71 | self.use_color = use_color 72 | self._calc_layout() 73 | 74 | def _calc_layout(self): 75 | LAYOUTS = { 76 | 'sdf': {'shape': (8, 1), 'size': 8}, 77 | 'deform': {'shape': (8, 3), 'size': 8 * 3}, 78 | 'weights': {'shape': (21,), 'size': 21} 79 | } 80 | if self.use_color: 81 | ''' 82 | 6 channel color including normal map 83 | ''' 84 | LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} 85 | self.layouts = edict(LAYOUTS) 86 | start = 0 87 | for k, v in self.layouts.items(): 88 | v['range'] = (start, start + v['size']) 89 | start += v['size'] 90 | self.feats_channels = start 91 | 92 | def get_layout(self, feats : torch.Tensor, name : str): 93 | if name not in self.layouts: 94 | return None 95 | return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) 96 | 97 | def __call__(self, cubefeats : SparseTensor, training=False): 98 | """ 99 | Generates a mesh based on the specified sparse voxel structures. 100 | Args: 101 | cube_attrs [Nx21] : Sparse Tensor attrs about cube weights 102 | verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal 103 | Returns: 104 | return the success tag and ni you loss, 105 | """ 106 | # add sdf bias to verts_attrs 107 | coords = cubefeats.coords[:, 1:] 108 | feats = cubefeats.feats 109 | 110 | sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] 111 | sdf += self.sdf_bias 112 | v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] 113 | v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) 114 | v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) 115 | weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) 116 | if self.use_color: 117 | sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] 118 | else: 119 | sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] 120 | colors_d = None 121 | 122 | x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) 123 | 124 | vertices, faces, L_dev, colors = self.mesh_extractor( 125 | voxelgrid_vertices=x_nx3, 126 | scalar_field=sdf_d, 127 | cube_idx=self.reg_c, 128 | resolution=self.res, 129 | beta=weights_d[:, :12], 130 | alpha=weights_d[:, 12:20], 131 | gamma_f=weights_d[:, 20], 132 | voxelgrid_colors=colors_d, 133 | training=training) 134 | 135 | mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) 136 | if training: 137 | if mesh.success: 138 | reg_loss += L_dev.mean() * 0.5 139 | reg_loss += (weights[:,:20]).abs().mean() * 0.2 140 | mesh.reg_loss = reg_loss 141 | mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) 142 | mesh.tsdf_s = v_attrs[:, 0] 143 | return mesh -------------------------------------------------------------------------------- /trellis/representations/mesh/utils_cube.py: -------------------------------------------------------------------------------- 1 | import torch 2 | cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ 3 | 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) 4 | cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) 5 | cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 6 | 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) 7 | 8 | def construct_dense_grid(res, device='cuda'): 9 | '''construct a dense grid based on resolution''' 10 | res_v = res + 1 11 | vertsid = torch.arange(res_v ** 3, device=device) 12 | coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() 13 | cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] 14 | cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) 15 | verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) 16 | return verts, cube_fx8 17 | 18 | 19 | def construct_voxel_grid(coords): 20 | verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) 21 | verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) 22 | cubes = inverse_indices.reshape(-1, 8) 23 | return verts_unique, cubes 24 | 25 | 26 | def cubes_to_verts(num_verts, cubes, value, reduce='mean'): 27 | """ 28 | Args: 29 | cubes [Vx8] verts index for each cube 30 | value [Vx8xM] value to be scattered 31 | Operation: 32 | reduced[cubes[i][j]][k] += value[i][k] 33 | """ 34 | M = value.shape[2] # number of channels 35 | reduced = torch.zeros(num_verts, M, device=cubes.device) 36 | return torch.scatter_reduce(reduced, 0, 37 | cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), 38 | value.flatten(0, 1), reduce=reduce, include_self=False) 39 | 40 | def sparse_cube2verts(coords, feats, training=True): 41 | new_coords, cubes = construct_voxel_grid(coords) 42 | new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) 43 | if training: 44 | con_loss = torch.mean((feats - new_feats[cubes]) ** 2) 45 | else: 46 | con_loss = 0.0 47 | return new_coords, new_feats, con_loss 48 | 49 | 50 | def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): 51 | F = feats.shape[-1] 52 | dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) 53 | if sdf_init: 54 | dense_attrs[..., 0] = 1 # initial outside sdf value 55 | dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats 56 | return dense_attrs.reshape(-1, F) 57 | 58 | 59 | def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): 60 | return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) 61 | -------------------------------------------------------------------------------- /trellis/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stable-X/ComfyUI-Hi3DGen/99621aa48dd05203bee91cf3d5813669e0b5721d/trellis/utils/__init__.py -------------------------------------------------------------------------------- /trellis/utils/_rasterization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, List, Optional, Tuple, Dict 3 | import nvdiffrast.torch as dr 4 | import utils3d 5 | class RastContext: 6 | def __init__(self, backend='cuda'): 7 | self.backend = backend 8 | 9 | def rasterize_triangle_faces( 10 | ctx: RastContext, 11 | vertices: torch.Tensor, 12 | faces: torch.Tensor, 13 | width: int, 14 | height: int, 15 | attr: torch.Tensor = None, 16 | uv: torch.Tensor = None, 17 | texture: torch.Tensor = None, 18 | model: torch.Tensor = None, 19 | view: torch.Tensor = None, 20 | projection: torch.Tensor = None, 21 | antialiasing: Union[bool, List[int]] = True, 22 | diff_attrs: Union[None, List[int]] = None, 23 | ) -> Dict[str, torch.Tensor]: 24 | """ 25 | Rasterize a mesh with vertex attributes. 26 | """ 27 | assert vertices.ndim == 3 28 | assert faces.ndim == 2 29 | 30 | # Handle vertices dimensions 31 | if vertices.shape[-1] == 2: 32 | vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1) 33 | elif vertices.shape[-1] == 3: 34 | vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) 35 | elif vertices.shape[-1] == 4: 36 | pass 37 | else: 38 | raise ValueError(f'Wrong shape of vertices: {vertices.shape}') 39 | 40 | # Calculate MVP matrix 41 | mvp = projection if projection is not None else torch.eye(4, device=vertices.device) 42 | if view is not None: 43 | mvp = mvp @ view 44 | if model is not None: 45 | mvp = mvp @ model 46 | 47 | # Transform vertices to clip space 48 | pos_clip = vertices @ mvp.transpose(-1, -2) 49 | faces = faces.contiguous() 50 | if attr is not None: 51 | attr = attr.contiguous() 52 | 53 | # Rasterize 54 | rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True) 55 | 56 | # Extract basic outputs 57 | face_id = rast_out[..., 3].flip(1) 58 | depth = rast_out[..., 2].flip(1) 59 | mask = (face_id > 0).float() 60 | depth = (depth * 0.5 + 0.5) * mask + (1.0 - mask) 61 | 62 | ret = { 63 | 'depth': depth, 64 | 'mask': mask, 65 | 'face_id': face_id, 66 | } 67 | 68 | # Handle attribute interpolation 69 | if attr is not None: 70 | image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs) 71 | 72 | if antialiasing == True: 73 | image = dr.antialias(image, rast_out, pos_clip, faces) 74 | elif isinstance(antialiasing, list): 75 | aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces) 76 | image[..., antialiasing] = aa_image 77 | 78 | image = image.flip(1).permute(0, 3, 1, 2) 79 | ret['image'] = image 80 | 81 | # Handle UV mapping 82 | if uv is not None: 83 | uv_map, uv_map_dr = dr.interpolate(uv, rast_out, faces, rast_db, diff_attrs='all') 84 | ret['uv'] = uv_map 85 | ret['uv_dr'] = uv_map_dr 86 | 87 | if texture is not None: 88 | texture_map = dr.texture(texture, uv_map, uv_map_dr) 89 | ret['texture'] = texture_map.flip(1).permute(0, 3, 1, 2) 90 | 91 | # Handle derivatives 92 | if diff_attrs is not None: 93 | image_dr = image_dr.flip(1).permute(0, 3, 1, 2) 94 | ret['image_dr'] = image_dr 95 | 96 | return ret -------------------------------------------------------------------------------- /trellis/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | 6 | # Dictionary utils 7 | def _dict_merge(dicta, dictb, prefix=''): 8 | """ 9 | Merge two dictionaries. 10 | """ 11 | assert isinstance(dicta, dict), 'input must be a dictionary' 12 | assert isinstance(dictb, dict), 'input must be a dictionary' 13 | dict_ = {} 14 | all_keys = set(dicta.keys()).union(set(dictb.keys())) 15 | for key in all_keys: 16 | if key in dicta.keys() and key in dictb.keys(): 17 | if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): 18 | dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') 19 | else: 20 | raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') 21 | elif key in dicta.keys(): 22 | dict_[key] = dicta[key] 23 | else: 24 | dict_[key] = dictb[key] 25 | return dict_ 26 | 27 | 28 | def dict_merge(dicta, dictb): 29 | """ 30 | Merge two dictionaries. 31 | """ 32 | return _dict_merge(dicta, dictb, prefix='') 33 | 34 | 35 | def dict_foreach(dic, func, special_func={}): 36 | """ 37 | Recursively apply a function to all non-dictionary leaf values in a dictionary. 38 | """ 39 | assert isinstance(dic, dict), 'input must be a dictionary' 40 | for key in dic.keys(): 41 | if isinstance(dic[key], dict): 42 | dic[key] = dict_foreach(dic[key], func) 43 | else: 44 | if key in special_func.keys(): 45 | dic[key] = special_func[key](dic[key]) 46 | else: 47 | dic[key] = func(dic[key]) 48 | return dic 49 | 50 | 51 | def dict_reduce(dicts, func, special_func={}): 52 | """ 53 | Reduce a list of dictionaries. Leaf values must be scalars. 54 | """ 55 | assert isinstance(dicts, list), 'input must be a list of dictionaries' 56 | assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' 57 | assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' 58 | all_keys = set([key for dict_ in dicts for key in dict_.keys()]) 59 | reduced_dict = {} 60 | for key in all_keys: 61 | vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] 62 | if isinstance(vlist[0], dict): 63 | reduced_dict[key] = dict_reduce(vlist, func, special_func) 64 | else: 65 | if key in special_func.keys(): 66 | reduced_dict[key] = special_func[key](vlist) 67 | else: 68 | reduced_dict[key] = func(vlist) 69 | return reduced_dict 70 | 71 | 72 | def dict_any(dic, func): 73 | """ 74 | Recursively apply a function to all non-dictionary leaf values in a dictionary. 75 | """ 76 | assert isinstance(dic, dict), 'input must be a dictionary' 77 | for key in dic.keys(): 78 | if isinstance(dic[key], dict): 79 | if dict_any(dic[key], func): 80 | return True 81 | else: 82 | if func(dic[key]): 83 | return True 84 | return False 85 | 86 | 87 | def dict_all(dic, func): 88 | """ 89 | Recursively apply a function to all non-dictionary leaf values in a dictionary. 90 | """ 91 | assert isinstance(dic, dict), 'input must be a dictionary' 92 | for key in dic.keys(): 93 | if isinstance(dic[key], dict): 94 | if not dict_all(dic[key], func): 95 | return False 96 | else: 97 | if not func(dic[key]): 98 | return False 99 | return True 100 | 101 | 102 | def dict_flatten(dic, sep='.'): 103 | """ 104 | Flatten a nested dictionary into a dictionary with no nested dictionaries. 105 | """ 106 | assert isinstance(dic, dict), 'input must be a dictionary' 107 | flat_dict = {} 108 | for key in dic.keys(): 109 | if isinstance(dic[key], dict): 110 | sub_dict = dict_flatten(dic[key], sep=sep) 111 | for sub_key in sub_dict.keys(): 112 | flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] 113 | else: 114 | flat_dict[key] = dic[key] 115 | return flat_dict 116 | 117 | 118 | def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): 119 | num_images = len(images) 120 | if nrow is None and ncol is None: 121 | if aspect_ratio is not None: 122 | nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) 123 | else: 124 | nrow = int(np.sqrt(num_images)) 125 | ncol = (num_images + nrow - 1) // nrow 126 | elif nrow is None and ncol is not None: 127 | nrow = (num_images + ncol - 1) // ncol 128 | elif nrow is not None and ncol is None: 129 | ncol = (num_images + nrow - 1) // nrow 130 | else: 131 | assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' 132 | 133 | grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) 134 | for i, img in enumerate(images): 135 | row = i // ncol 136 | col = i % ncol 137 | grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img 138 | return grid 139 | 140 | 141 | def notes_on_image(img, notes=None): 142 | img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) 143 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 144 | if notes is not None: 145 | img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) 146 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 147 | return img 148 | 149 | 150 | def save_image_with_notes(img, path, notes=None): 151 | """ 152 | Save an image with notes. 153 | """ 154 | if isinstance(img, torch.Tensor): 155 | img = img.cpu().numpy().transpose(1, 2, 0) 156 | if img.dtype == np.float32 or img.dtype == np.float64: 157 | img = np.clip(img * 255, 0, 255).astype(np.uint8) 158 | img = notes_on_image(img, notes) 159 | cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 160 | 161 | 162 | # debug utils 163 | 164 | def atol(x, y): 165 | """ 166 | Absolute tolerance. 167 | """ 168 | return torch.abs(x - y) 169 | 170 | 171 | def rtol(x, y): 172 | """ 173 | Relative tolerance. 174 | """ 175 | return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) 176 | 177 | 178 | # print utils 179 | def indent(s, n=4): 180 | """ 181 | Indent a string. 182 | """ 183 | lines = s.split('\n') 184 | for i in range(1, len(lines)): 185 | lines[i] = ' ' * n + lines[i] 186 | return '\n'.join(lines) 187 | 188 | -------------------------------------------------------------------------------- /trellis/utils/random_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] 4 | 5 | def radical_inverse(base, n): 6 | val = 0 7 | inv_base = 1.0 / base 8 | inv_base_n = inv_base 9 | while n > 0: 10 | digit = n % base 11 | val += digit * inv_base_n 12 | n //= base 13 | inv_base_n *= inv_base 14 | return val 15 | 16 | def halton_sequence(dim, n): 17 | return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] 18 | 19 | def hammersley_sequence(dim, n, num_samples): 20 | return [n / num_samples] + halton_sequence(dim - 1, n) 21 | 22 | def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): 23 | u, v = hammersley_sequence(2, n, num_samples) 24 | u += offset[0] / num_samples 25 | v += offset[1] 26 | if remap: 27 | u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 28 | theta = np.arccos(1 - 2 * u) - np.pi / 2 29 | phi = v * 2 * np.pi 30 | return [phi, theta] -------------------------------------------------------------------------------- /trellis_model_manager.py: -------------------------------------------------------------------------------- 1 | # trellis_model_manager.py 2 | import os 3 | import logging 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from safetensors.torch import load_file 8 | import folder_paths 9 | from huggingface_hub import hf_hub_download, snapshot_download 10 | from typing import Dict, Union 11 | import json 12 | import importlib # Import the importlib module 13 | from trellis.modules.utils import convert_module_to_f16, convert_module_to_f32 14 | 15 | logger = logging.getLogger('model_manager') 16 | 17 | __attributes = { 18 | 'SparseStructureDecoder': 'trellis.models.sparse_structure_vae', 19 | 'SparseStructureFlowModel': 'trellis.models.sparse_structure_flow', 20 | 'SLatFlowModel': 'trellis.models.structured_latent_flow', 21 | } 22 | 23 | __all__ = list(__attributes.keys()) 24 | 25 | def __getattr__(name): 26 | if name in __attributes: 27 | module_name = __attributes[name] 28 | module = importlib.import_module(module_name, package=None) # Import the module 29 | return getattr(module, name) 30 | raise AttributeError(f"module {__name__} has no attribute {name}") 31 | 32 | class TrellisModelManager: 33 | """ 34 | Basic manager for Trellis models, using ComfyUI's new model path. 35 | """ 36 | def __init__(self, model_dir: str, config=None, device: str = "cuda"): 37 | """ 38 | Initialize the model manager with a specific model directory. 39 | 40 | Args: 41 | model_dir (str): Path to model directory (e.g. "models/checkpoints/TRELLIS-image-large") 42 | config (dict or object): Global configuration for Trellis 43 | device (str): Device to load models on (e.g. "cuda") 44 | """ 45 | self.model_dir = model_dir 46 | # Handle config being either a dict or an object 47 | if config is None: 48 | self.device = device 49 | elif isinstance(config, dict): 50 | self.device = config.get('device', device) 51 | self.config = config 52 | else: 53 | self.device = getattr(config, 'device', device) 54 | self.config = config 55 | self.model = None 56 | self.dinov2_model = None 57 | 58 | def load(self) -> None: 59 | """Load model configuration and checkpoints""" 60 | try: 61 | # Ensure directory exists 62 | os.makedirs(self.model_dir, exist_ok=True) 63 | ckpts_folder = os.path.join(self.model_dir, "ckpts") 64 | os.makedirs(ckpts_folder, exist_ok=True) 65 | 66 | # Download model files if needed 67 | if not os.path.exists(os.path.join(self.model_dir, "pipeline.json")): 68 | logger.info("Downloading TRELLIS models...") 69 | try: 70 | # Download main pipeline files 71 | snapshot_download( 72 | repo_id="Stable-X/trellis-normal-v0-1", 73 | local_dir=self.model_dir, 74 | local_dir_use_symlinks=False, 75 | allow_patterns=["pipeline.json", "README.md"] 76 | ) 77 | # Download checkpoint files 78 | snapshot_download( 79 | repo_id="Stable-X/trellis-normal-v0-1", 80 | local_dir=ckpts_folder, 81 | local_dir_use_symlinks=False, 82 | allow_patterns=["*.safetensors", "*.json"], 83 | cache_dir=os.path.join(self.model_dir, ".cache") 84 | ) 85 | logger.info("Model files downloaded successfully") 86 | except Exception as e: 87 | logger.error(f"Error downloading model files: {str(e)}") 88 | raise 89 | 90 | # Load configuration 91 | self.config = self._load_config() 92 | 93 | except Exception as e: 94 | logger.error(f"Error in load(): {str(e)}") 95 | raise 96 | 97 | def get_checkpoint_path(self, filename: str) -> str: 98 | """ 99 | Returns the full path to a checkpoint file. 100 | """ 101 | ckpts_folder = os.path.join(self.model_dir, "ckpts") 102 | # Add .safetensors extension if not present 103 | if not filename.endswith('.safetensors'): 104 | filename = f"{filename}.safetensors" 105 | full_path = os.path.join(ckpts_folder, filename) 106 | if not os.path.exists(full_path): 107 | raise FileNotFoundError(f"Checkpoint file not found: {full_path}") 108 | return full_path 109 | 110 | def _load_config(self) -> Dict: 111 | """Load model configuration from pipeline.json""" 112 | try: 113 | config_path = os.path.join(self.model_dir, "pipeline.json") 114 | 115 | if os.path.exists(config_path): 116 | logger.info(f"Loading config from {config_path}") 117 | with open(config_path, 'r') as f: 118 | config = json.load(f) 119 | else: 120 | logger.info(f"Config not found locally, downloading from HuggingFace") 121 | config_path = hf_hub_download( 122 | repo_id=f"JeffreyXiang/{os.path.basename(self.model_dir)}", 123 | filename="pipeline.json", 124 | cache_dir=os.path.join(self.model_dir, ".cache") 125 | ) 126 | with open(config_path, 'r') as f: 127 | config = json.load(f) 128 | 129 | # Debug: Print raw config 130 | logger.info("Raw config contents:") 131 | logger.info(json.dumps(config, indent=2)) 132 | 133 | if not config: 134 | raise ValueError(f"Could not load valid configuration from {self.model_dir}") 135 | 136 | if 'name' not in config: 137 | config['name'] = 'TrellisImageTo3DPipeline' 138 | 139 | return config 140 | 141 | except Exception as e: 142 | logger.error(f"Error loading config from {self.model_dir}: {e}") 143 | return { 144 | 'name': 'TrellisImageTo3DPipeline', 145 | 'version': '1.0' 146 | } 147 | 148 | def load_models(self) -> Dict[str, nn.Module]: 149 | """Load all required models with current configuration""" 150 | return { 151 | 'sparse_structure_flow_model': self.get_checkpoint_path("ss_flow_img_dit_L_16l8_fp16"), 152 | 'slat_flow_model': self.get_checkpoint_path("slat_flow_img_dit_L_64l8p2_fp16") 153 | } 154 | 155 | def load_model_components(self) -> Dict[str, nn.Module]: 156 | """Loads individual model components.""" 157 | models = {} 158 | model_paths = self.load_models() 159 | for name, path in model_paths.items(): 160 | models[name] = models.from_pretrained(path, config=self.config) 161 | 162 | # Ensure each model is converted to the desired precision 163 | if self.config.get('use_fp16', True): 164 | convert_module_to_f16(models[name]) 165 | else: 166 | convert_module_to_f32(models[name]) 167 | 168 | # DINOv2 is handled separately 169 | # models['image_cond_model'] = self.load_dinov2(self.config.get("dinov2_model", "dinov2_vitl14")) 170 | 171 | return models 172 | 173 | def load_dinov2(self, model_name: str): 174 | """Load DINOv2 model with device, precision, and attention backend management""" 175 | try: 176 | # Get configuration values 177 | use_fp16 = (self.config.get('use_fp16', True) 178 | if isinstance(self.config, dict) 179 | else getattr(self.config, 'use_fp16', True)) 180 | 181 | # Get attention backend from config 182 | attention_backend = (self.config.get('attention_backend', 'default') 183 | if isinstance(self.config, dict) 184 | else getattr(self.config, 'attention_backend', 'default')) 185 | 186 | # Try to load from local path first 187 | model_path = folder_paths.get_full_path("classifiers", f"{model_name}.pth") 188 | 189 | if model_path is None: 190 | print(f"Downloading {model_name} from torch hub...") 191 | try: 192 | # Load model architecture with specified attention backend 193 | model = torch.hub.load('facebookresearch/dinov2', model_name, 194 | pretrained=True, 195 | force_reload=False, 196 | trust_repo=True) 197 | 198 | # Save model for future use 199 | save_dir = os.path.join(folder_paths.models_dir, "classifiers") 200 | os.makedirs(save_dir, exist_ok=True) 201 | save_path = os.path.join(save_dir, f"{model_name}.pth") 202 | 203 | # Save on CPU to avoid memory issues 204 | model = model.cpu() 205 | torch.save(model.state_dict(), save_path) 206 | print(f"Saved DINOv2 model to {save_path}") 207 | 208 | except Exception as e: 209 | raise RuntimeError(f"Failed to download DINOv2 model: {str(e)}") 210 | else: 211 | # Load from local path 212 | print(f"Loading DINOv2 model from {model_path}") 213 | model = torch.hub.load('facebookresearch/dinov2', model_name, 214 | pretrained=False, 215 | force_reload=False, 216 | trust_repo=True) 217 | model.load_state_dict(torch.load(model_path)) 218 | 219 | # Move model to specified device and apply precision settings 220 | model = model.to(self.device) 221 | if use_fp16: 222 | model = model.half() 223 | 224 | # Set attention backend if specified in config 225 | if hasattr(model, 'set_attention_backend') and attention_backend != 'default': 226 | model.set_attention_backend(attention_backend) 227 | 228 | model.eval() 229 | return model 230 | 231 | except Exception as e: 232 | raise RuntimeError(f"Error loading DINOv2 model: {str(e)}") -------------------------------------------------------------------------------- /win_requirements.txt: -------------------------------------------------------------------------------- 1 | pillow==10.4.0 2 | imageio==2.36.1 3 | imageio-ffmpeg==0.5.1 4 | tqdm==4.67.1 5 | easydict==1.13 6 | opencv-python-headless==4.10.0.84 7 | scipy 8 | rembg 9 | onnxruntime 10 | trimesh 11 | xatlas 12 | pyvista 13 | pymeshfix 14 | igraph 15 | spconv-cu124 16 | onnxruntime 17 | git+https://github.com/EasternJournalist/utils3d.git 18 | 19 | # triton ( Windows specific builds from : https://github.com/woct0rdho/triton-windows ) 20 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp312-cp312-win_amd64.whl; (python_version >= "3.12" and python_version < "3.13") 21 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp311-cp311-win_amd64.whl; (python_version >= "3.11" and python_version < "3.12") 22 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp310-cp310-win_amd64.whl; (python_version >= "3.10" and python_version < "3.11") 23 | https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post8/triton-3.1.0-cp38-cp38-win_amd64.whl; (python_version >= "3.8" and python_version < "3.9") 24 | -------------------------------------------------------------------------------- /workflow/Hi3DGen_WF_single.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "2ce13a2d-0e15-4260-b3b7-531ba96bf89e", 3 | "revision": 0, 4 | "last_node_id": 89, 5 | "last_link_id": 163, 6 | "nodes": [ 7 | { 8 | "id": 87, 9 | "type": "DownloadAndLoadStableXModel", 10 | "pos": [-367.123107910156, 146.043090820313], 11 | "size": [315, 58], 12 | "flags": { 13 | 14 | }, 15 | "order": 0, 16 | "mode": 0, 17 | "inputs": [], 18 | "outputs": [ 19 | { 20 | "name": "pipeline", 21 | "type": "YOSOPIPE", 22 | "links": [160] 23 | } 24 | ], 25 | "properties": { 26 | "Node name for S&R": "DownloadAndLoadStableXModel" 27 | }, 28 | "widgets_values": [ 29 | "yoso-normal-v1-8-1" 30 | ] 31 | }, 32 | { 33 | "id": 88, 34 | "type": "StableXProcessImage", 35 | "pos": [490, 180], 36 | "size": [315, 150], 37 | "flags": { 38 | 39 | }, 40 | "order": 3, 41 | "mode": 0, 42 | "inputs": [ 43 | { 44 | "name": "pipeline", 45 | "type": "YOSOPIPE", 46 | "link": 160 47 | }, 48 | { 49 | "name": "image", 50 | "type": "IMAGE", 51 | "link": 161 52 | } 53 | ], 54 | "outputs": [ 55 | { 56 | "name": "image", 57 | "type": "IMAGE", 58 | "links": [162, 163] 59 | } 60 | ], 61 | "properties": { 62 | "Node name for S&R": "StableXProcessImage" 63 | }, 64 | "widgets_values": [2048, 1, 148262827792070, "randomize" 65 | ] 66 | }, 67 | { 68 | "id": 89, 69 | "type": "PreviewImage", 70 | "pos": [490, 480], 71 | "size": [210, 246], 72 | "flags": { 73 | 74 | }, 75 | "order": 5, 76 | "mode": 0, 77 | "inputs": [ 78 | { 79 | "name": "images", 80 | "type": "IMAGE", 81 | "link": 163 82 | } 83 | ], 84 | "outputs": [], 85 | "properties": { 86 | "Node name for S&R": "PreviewImage" 87 | }, 88 | "widgets_values": [ 89 | "" 90 | ] 91 | }, 92 | { 93 | "id": 86, 94 | "type": "IF_TrellisCheckpointLoader", 95 | "pos": [-358.851715087891, -174.489013671875], 96 | "size": [315, 202], 97 | "flags": { 98 | 99 | }, 100 | "order": 1, 101 | "mode": 0, 102 | "inputs": [], 103 | "outputs": [ 104 | { 105 | "name": "model", 106 | "type": "TRELLIS_MODEL", 107 | "links": [159] 108 | } 109 | ], 110 | "properties": { 111 | "Node name for S&R": "IF_TrellisCheckpointLoader" 112 | }, 113 | "widgets_values": [ 114 | "trellis-normal-v0-1", 115 | "dinov2_vitl14_reg", 116 | true, "xformers", 117 | "spconv", 118 | "implicit_gemm", 119 | "cuda" 120 | ] 121 | }, 122 | { 123 | "id": 84, 124 | "type": "IF_TrellisImageTo3D", 125 | "pos": [860.648742675781, 208.74560546875], 126 | "size": [340.200012207031, 506], 127 | "flags": { 128 | 129 | }, 130 | "order": 4, 131 | "mode": 0, 132 | "inputs": [ 133 | { 134 | "name": "model", 135 | "type": "TRELLIS_MODEL", 136 | "link": 159 137 | }, 138 | { 139 | "name": "images", 140 | "type": "IMAGE", 141 | "link": 162 142 | }, 143 | { 144 | "name": "masks", 145 | "shape": 7, 146 | "type": "MASK", 147 | "link": null 148 | } 149 | ], 150 | "outputs": [ 151 | { 152 | "name": "model_file", 153 | "type": "STRING", 154 | "links": [157] 155 | }, 156 | { 157 | "name": "video_path", 158 | "type": "STRING", 159 | "links": [] 160 | }, 161 | { 162 | "name": "texture_image", 163 | "type": "IMAGE", 164 | "links": null 165 | } 166 | ], 167 | "properties": { 168 | "Node name for S&R": "IF_TrellisImageTo3D" 169 | }, 170 | "widgets_values": [ 171 | "single", 172 | 240919286, "randomize", 173 | 7.5, 12, 3, 12, 0.95, "stochastic", 174 | "test" 175 | ] 176 | }, 177 | { 178 | "id": 75, 179 | "type": "Preview3D", 180 | "pos": [1249.12377929688, 220.59049987793], 181 | "size": [315, 550], 182 | "flags": { 183 | 184 | }, 185 | "order": 6, 186 | "mode": 0, 187 | "inputs": [ 188 | { 189 | "name": "model_file", 190 | "type": "STRING", 191 | "widget": { 192 | "name": "model_file" 193 | }, 194 | "link": 157 195 | } 196 | ], 197 | "outputs": [], 198 | "properties": { 199 | "Node name for S&R": "Preview3D", 200 | "Camera Info": { 201 | "position": { 202 | "x": 9.19323855253915, 203 | "y": 2.91023369098772, 204 | "z": -1.400979112753 205 | }, 206 | "target": { 207 | "x": 0, 208 | "y": 1.61671303147776, 209 | "z": 0 210 | }, 211 | "zoom": 1, 212 | "cameraType": "perspective" 213 | } 214 | }, 215 | "widgets_values": [ 216 | "test/test.glb", 217 | "" 218 | ] 219 | }, 220 | { 221 | "id": 67, 222 | "type": "LoadImage", 223 | "pos": [-368.674621582031, 330.965423583984], 224 | "size": [315, 314], 225 | "flags": { 226 | 227 | }, 228 | "order": 2, 229 | "mode": 0, 230 | "inputs": [], 231 | "outputs": [ 232 | { 233 | "name": "IMAGE", 234 | "type": "IMAGE", 235 | "slot_index": 0, 236 | "links": [161] 237 | }, 238 | { 239 | "name": "MASK", 240 | "type": "MASK", 241 | "links": null 242 | } 243 | ], 244 | "properties": { 245 | "Node name for S&R": "LoadImage" 246 | }, 247 | "widgets_values": [ 248 | "Generated Image March 31, 2025 - 3_19PM.jpeg", 249 | "image", 250 | "" 251 | ] 252 | } 253 | ], 254 | "links": [ 255 | [157, 84, 0, 75, 0, "STRING" 256 | ], 257 | [159, 86, 0, 84, 0, "TRELLIS_MODEL" 258 | ], 259 | [160, 87, 0, 88, 0, "YOSOPIPE" 260 | ], 261 | [161, 67, 0, 88, 1, "IMAGE" 262 | ], 263 | [162, 88, 0, 84, 1, "IMAGE" 264 | ], 265 | [163, 88, 0, 89, 0, "IMAGE" 266 | ] 267 | ], 268 | "groups": [ 269 | { 270 | "id": 1, 271 | "title": "Hi3DGen", 272 | "bounding": [450, 90, 1017.04064941406, 874.525085449219], 273 | "color": "#3f789e", 274 | "font_size": 24, 275 | "flags": { 276 | 277 | } 278 | }, 279 | { 280 | "id": 2, 281 | "title": "Image", 282 | "bounding": [-396.156677246094, 258.446929931641, 355.320861816406, 402.488861083984], 283 | "color": "#3f789e", 284 | "font_size": 24, 285 | "flags": { 286 | 287 | } 288 | } 289 | ], 290 | "config": { 291 | 292 | }, 293 | "extra": { 294 | "ds": { 295 | "scale": 0.640376065342708, 296 | "offset": [1194.09386904219, 251.478399306866] 297 | }, 298 | "ue_links": [], 299 | "reroutes": [ 300 | { 301 | "id": 1, 302 | "pos": [248.001159667969, 386.019378662109], 303 | "linkIds": [159] 304 | } 305 | ], 306 | "VHS_latentpreview": false, 307 | "VHS_latentpreviewrate": 0, 308 | "VHS_MetadataImage": true, 309 | "VHS_KeepIntermediate": true, 310 | "linkExtensions": [ 311 | { 312 | "id": 159, 313 | "parentId": 1 314 | } 315 | ] 316 | }, 317 | "version": 0.4 318 | } 319 | --------------------------------------------------------------------------------