├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── bear.png
│   ├── book.png
│   ├── car.png
│   ├── castle.png
│   ├── demo
│   │   ├── video1.gif
│   │   └── video2.gif
│   ├── devil.png
│   ├── dragon.png
│   ├── earth.png
│   ├── figure
│   │   ├── framework.png
│   │   └── teaser.gif
│   ├── fish.png
│   ├── girl.png
│   ├── hamburger.png
│   ├── man.png
│   ├── panda.png
│   ├── parrot.png
│   ├── phoenix.png
│   ├── riding.png
│   └── sunglasses.png
├── direct3d
│   ├── models
│   │   ├── __init__.py
│   │   ├── condition.py
│   │   ├── dit.py
│   │   └── vae.py
│   ├── pipeline.py
│   └── utils
│       ├── __init__.py
│       ├── image.py
│       ├── triplane.py
│       └── util.py
├── requirements.txt
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *__pycache__* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Direct3D: Scalable Image-to-3D Generation via 3D Latent Diffusion Transformer (NeurIPS 2024) 3 | 4 |
5 | <!-- project badges and teaser figure (assets/figure/teaser.gif) --> 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |

17 | 18 | --- 19 | 20 | ## ✨ News 21 | 22 | - Feb 11, 2025: 🔨 We are working on the Gradio demo and will release it soon! 23 | - Feb 11, 2025: 🎁 Enjoy our improved version of Direct3D with high-quality geometry and texture at [https://www.neural4d.com](https://www.neural4d.com/). 24 | - Feb 11, 2025: 🚀 Released the inference code of Direct3D; the pretrained models are available at 🤗 [Hugging Face](https://huggingface.co/DreamTechAI/Direct3D/tree/main). 25 | 26 | ## 📝 Abstract 27 | 28 | We introduce **Direct3D**, a native 3D generative model scalable to in-the-wild input images, without requiring a multiview diffusion model or SDS optimization. Our approach comprises two primary components: a Direct 3D Variational Auto-Encoder **(D3D-VAE)** and a Direct 3D Diffusion Transformer **(D3D-DiT)**. D3D-VAE efficiently encodes high-resolution 3D shapes into a compact and continuous latent triplane space. Notably, our method directly supervises the decoded geometry using a semi-continuous surface sampling strategy, diverging from previous methods relying on rendered images as supervision signals. D3D-DiT models the distribution of encoded 3D latents and is specifically designed to fuse positional information from the three feature maps of the triplane latent, enabling a native 3D generative model scalable to large-scale 3D datasets. Additionally, we introduce an innovative image-to-3D generation pipeline incorporating semantic and pixel-level image conditions, allowing the model to produce 3D shapes consistent with the provided conditional image input. Extensive experiments demonstrate the superiority of our large-scale pre-trained Direct3D over previous image-to-3D approaches, achieving significantly better generation quality and generalization ability, thus establishing a new state-of-the-art for 3D content creation. 29 | 30 |
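For orientation, the sketch below condenses how this repository realizes the two components described above, following the code in `direct3d/pipeline.py`: CLIP provides semantic tokens and DINOv2 provides pixel-level tokens, the D3D-DiT denoises a rolled-out triplane latent under classifier-free guidance, and the D3D-VAE decodes that latent into an occupancy field and extracts a mesh. The `sketch_generate` helper and its arguments are illustrative only, not part of the released API; in practice you would call `Direct3dPipeline` directly as shown in the Usage section below.

```python
import torch

@torch.no_grad()
def sketch_generate(pipe, image, steps=50, cfg=4.0, mc_threshold=-2.0):
    # Illustrative sketch (not part of the released API); assumes `pipe` is a
    # loaded Direct3dPipeline and `image` is an already-preprocessed image
    # tensor of shape (1, 3, H, W) on `pipe.device`.

    # 1. Image conditions: semantic (CLIP) and pixel-level (DINOv2) tokens,
    #    each concatenated with an all-zero unconditional branch for CFG.
    semantic_cond, pixel_cond = pipe.encode_image(image, do_classifier_free_guidance=True)

    # 2. Start from Gaussian noise shaped like the rolled-out triplane latent
    #    (C, H, 3*H): the three planes are concatenated along the width axis.
    latents = torch.randn((1, *pipe.vae.latent_shape), device=pipe.device, dtype=image.dtype)

    # 3. Denoise with the D3D-DiT under classifier-free guidance.
    pipe.scheduler.set_timesteps(steps, device=pipe.device)
    for t in pipe.scheduler.timesteps:
        x = pipe.scheduler.scale_model_input(torch.cat([latents] * 2), t)
        eps = pipe.dit(
            hidden_states=x,
            timestep=t.expand(x.shape[0]),
            encoder_hidden_states=semantic_cond,
            pixel_hidden_states=pixel_cond,
        )
        eps_uncond, eps_cond = eps.chunk(2)
        eps = eps_uncond + cfg * (eps_cond - eps_uncond)
        latents = pipe.scheduler.step(eps, t, latents, return_dict=False)[0]

    # 4. Un-scale the latent, decode it to a triplane, query occupancy on a
    #    voxel grid, and extract the surface with marching cubes.
    latents = latents / pipe.vae.latents_scale + pipe.vae.latents_shift
    return pipe.vae.decode_mesh(latents, mc_threshold=mc_threshold)[0]
```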

31 | <!-- framework overview figure (assets/figure/framework.png) --> 32 | 33 |

34 | 35 | ## 🚀 Getting Started 36 | 37 | ### Installation 38 | 39 | ```sh 40 | git clone https://github.com/DreamTechAI/Direct3D.git 41 | 42 | cd Direct3D 43 | 44 | pip install -r requirements.txt 45 | 46 | pip install -e . 47 | ``` 48 | 49 | ### Usage 50 | 51 | ```python 52 | from direct3d.pipeline import Direct3dPipeline 53 | pipeline = Direct3dPipeline.from_pretrained("DreamTechAI/Direct3D") 54 | pipeline.to("cuda") 55 | mesh = pipeline( 56 | "assets/devil.png", 57 | remove_background=False, # set to True if the background of the image needs to be removed 58 | mc_threshold=-1.0, 59 | guidance_scale=4.0, 60 | num_inference_steps=50, 61 | )["meshes"][0] 62 | mesh.export("output.obj") 63 | ``` 64 | 65 | ## 🤗 Acknowledgements 66 | 67 | Thanks to the following repos for their great work, which helps us a lot in the development of Direct3D: 68 | 69 | - [3DShape2VecSet](https://github.com/1zb/3DShape2VecSet/tree/master) 70 | - [Michelangelo](https://github.com/NeuralCarver/Michelangelo) 71 | - [Objaverse](https://objaverse.allenai.org/) 72 | - [diffusers](https://github.com/huggingface/diffusers) 73 | 74 | ## 📖 Citation 75 | 76 | If you find our work useful, please consider citing our paper: 77 | 78 | ```bibtex 79 | @article{direct3d, 80 | title={Direct3D: Scalable Image-to-3D Generation via 3D Latent Diffusion Transformer}, 81 | author={Wu, Shuang and Lin, Youtian and Zhang, Feihu and Zeng, Yifei and Xu, Jingxi and Torr, Philip and Cao, Xun and Yao, Yao}, 82 | journal={arXiv preprint arXiv:2405.14832}, 83 | year={2024} 84 | } 85 | ``` 86 | 87 | --- 88 | -------------------------------------------------------------------------------- /assets/bear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/bear.png -------------------------------------------------------------------------------- /assets/book.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/book.png -------------------------------------------------------------------------------- /assets/car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/car.png -------------------------------------------------------------------------------- /assets/castle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/castle.png -------------------------------------------------------------------------------- /assets/demo/video1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/demo/video1.gif -------------------------------------------------------------------------------- /assets/demo/video2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/demo/video2.gif -------------------------------------------------------------------------------- /assets/devil.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/devil.png -------------------------------------------------------------------------------- /assets/dragon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/dragon.png -------------------------------------------------------------------------------- /assets/earth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/earth.png -------------------------------------------------------------------------------- /assets/figure/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/figure/framework.png -------------------------------------------------------------------------------- /assets/figure/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/figure/teaser.gif -------------------------------------------------------------------------------- /assets/fish.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/fish.png -------------------------------------------------------------------------------- /assets/girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/girl.png -------------------------------------------------------------------------------- /assets/hamburger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/hamburger.png -------------------------------------------------------------------------------- /assets/man.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/man.png -------------------------------------------------------------------------------- /assets/panda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/panda.png -------------------------------------------------------------------------------- /assets/parrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/parrot.png -------------------------------------------------------------------------------- /assets/phoenix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/phoenix.png -------------------------------------------------------------------------------- /assets/riding.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/riding.png -------------------------------------------------------------------------------- /assets/sunglasses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/sunglasses.png -------------------------------------------------------------------------------- /direct3d/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/direct3d/models/__init__.py -------------------------------------------------------------------------------- /direct3d/models/condition.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import CLIPModel, AutoModel 3 | from torchvision import transforms as T 4 | 5 | class ClipImageEncoder(nn.Module): 6 | 7 | def __init__(self, version="openai/clip-vit-large-patch14", img_size=224): 8 | super().__init__() 9 | 10 | encoder = CLIPModel.from_pretrained(version) 11 | encoder = encoder.eval() 12 | self.encoder = encoder 13 | self.transform = T.Compose( 14 | [ 15 | T.Resize(img_size, antialias=True), 16 | T.Normalize( 17 | mean=[0.48145466, 0.4578275, 0.40821073], 18 | std=[0.26862954, 0.26130258, 0.27577711], 19 | ), 20 | ] 21 | ) 22 | 23 | def forward(self, image): 24 | image = self.transform(image) 25 | embbed = self.encoder.vision_model(image).last_hidden_state 26 | return embbed 27 | 28 | 29 | class DinoEncoder(nn.Module): 30 | 31 | def __init__(self, version="facebook/dinov2-large", img_size=224): 32 | super().__init__() 33 | 34 | encoder = AutoModel.from_pretrained(version) 35 | encoder = encoder.eval() 36 | self.encoder = encoder 37 | self.transform = T.Compose( 38 | [ 39 | T.Resize(img_size, antialias=True), 40 | T.Normalize( 41 | mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225], 43 | ), 44 | ] 45 | ) 46 | 47 | def forward(self, image): 48 | image = self.transform(image) 49 | embbed = self.encoder(image).last_hidden_state 50 | return embbed 51 | 52 | -------------------------------------------------------------------------------- /direct3d/models/dit.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py 2 | 3 | from typing import Optional 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection, get_2d_sincos_pos_embed_from_grid 12 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 13 | from diffusers.models.attention import FeedForward 14 | 15 | 16 | class ClassCombinedTimestepSizeEmbeddings(nn.Module): 17 | 18 | def __init__(self, embedding_dim, class_emb_dim): 19 | super().__init__() 20 | 21 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) 22 | self.timestep_embedder = TimestepEmbedding(in_channels=256, 23 | time_embed_dim=embedding_dim, 24 | cond_proj_dim=class_emb_dim) 25 | 26 | def forward(self, timestep, hidden_dtype, class_embedding=None): 27 | timesteps_proj = 
self.time_proj(timestep) 28 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype), 29 | condition=class_embedding) # (N, D) 30 | return timesteps_emb 31 | 32 | 33 | class AdaLayerNormClassEmb(nn.Module): 34 | 35 | def __init__(self, embedding_dim: int, class_emb_dim: int): 36 | super().__init__() 37 | 38 | self.emb = ClassCombinedTimestepSizeEmbeddings( 39 | embedding_dim, class_emb_dim 40 | ) 41 | self.silu = nn.SiLU() 42 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) 43 | 44 | def forward( 45 | self, 46 | timestep: torch.Tensor, 47 | class_embedding: torch.Tensor = None, 48 | hidden_dtype: Optional[torch.dtype] = None, 49 | ): 50 | embedded_timestep = self.emb(timestep, 51 | class_embedding=class_embedding, 52 | hidden_dtype=hidden_dtype) 53 | return self.linear(self.silu(embedded_timestep)), embedded_timestep 54 | 55 | 56 | def get_2d_sincos_pos_embed( 57 | embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 58 | ): 59 | if isinstance(grid_size, int): 60 | grid_size = (grid_size, grid_size) 61 | 62 | if isinstance(base_size, int): 63 | base_size = (base_size, base_size) 64 | 65 | grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale 66 | grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale 67 | grid = np.meshgrid(grid_w, grid_h) 68 | grid = np.stack(grid, axis=0) 69 | 70 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 71 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 72 | if cls_token and extra_tokens > 0: 73 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 74 | return pos_embed 75 | 76 | 77 | class PatchEmbed(nn.Module): 78 | 79 | def __init__( 80 | self, 81 | height=224, 82 | width=224, 83 | patch_size=16, 84 | in_channels=3, 85 | embed_dim=768, 86 | layer_norm=False, 87 | flatten=True, 88 | bias=True, 89 | interpolation_scale=1, 90 | ): 91 | super().__init__() 92 | 93 | self.flatten = flatten 94 | self.layer_norm = layer_norm 95 | 96 | self.proj = nn.Conv2d( 97 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 98 | ) 99 | if layer_norm: 100 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 101 | else: 102 | self.norm = None 103 | 104 | self.patch_size = patch_size 105 | # See: 106 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 107 | self.height, self.width = height // patch_size, width // patch_size 108 | self.base_size = height // patch_size 109 | self.interpolation_scale = interpolation_scale 110 | pos_embed = get_2d_sincos_pos_embed( 111 | embed_dim, (self.height, self.width), base_size=(self.height, self.width), interpolation_scale=self.interpolation_scale 112 | ) 113 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) 114 | 115 | def forward(self, latent): 116 | height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size 117 | 118 | latent = self.proj(latent) 119 | if self.flatten: 120 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 121 | if self.layer_norm: 122 | latent = self.norm(latent) 123 | 124 | # Interpolate positional embeddings if needed. 
125 | # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) 126 | if self.height != height or self.width != width: 127 | pos_embed = get_2d_sincos_pos_embed( 128 | embed_dim=self.pos_embed.shape[-1], 129 | grid_size=(height, width), 130 | base_size=(height, width), 131 | interpolation_scale=self.interpolation_scale, 132 | ) 133 | pos_embed = torch.from_numpy(pos_embed) 134 | pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) 135 | else: 136 | pos_embed = self.pos_embed 137 | 138 | return (latent + pos_embed).to(latent.dtype) 139 | 140 | 141 | class Attention(nn.Module): 142 | 143 | def __init__( 144 | self, 145 | heads: int = 8, 146 | dim_head: int = 64, 147 | dropout: float = 0.0, 148 | bias: bool = False, 149 | out_bias: bool = True, 150 | ): 151 | super().__init__() 152 | 153 | self.inner_dim = dim_head * heads 154 | self.use_bias = bias 155 | self.dropout = dropout 156 | self.heads = heads 157 | 158 | self.to_q = nn.Linear(self.inner_dim, self.inner_dim, bias=bias) 159 | self.to_k = nn.Linear(self.inner_dim, self.inner_dim, bias=bias) 160 | self.to_v = nn.Linear(self.inner_dim, self.inner_dim, bias=bias) 161 | 162 | self.to_out = nn.ModuleList([ 163 | nn.Linear(self.inner_dim, self.inner_dim, bias=out_bias), 164 | nn.Dropout(dropout) 165 | ]) 166 | 167 | def forward( 168 | self, 169 | hidden_states: torch.Tensor, 170 | encoder_hidden_states: Optional[torch.Tensor] = None, 171 | ): 172 | 173 | input_ndim = hidden_states.ndim 174 | 175 | if input_ndim == 4: 176 | batch_size, channel, height, width = hidden_states.shape 177 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 178 | 179 | batch_size = hidden_states.shape[0] if encoder_hidden_states is None else encoder_hidden_states.shape[0] 180 | 181 | query = self.to_q(hidden_states) 182 | 183 | if encoder_hidden_states is None: 184 | encoder_hidden_states = hidden_states 185 | 186 | key = self.to_k(encoder_hidden_states) 187 | value = self.to_v(encoder_hidden_states) 188 | 189 | inner_dim = key.shape[-1] 190 | head_dim = inner_dim // self.heads 191 | 192 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 193 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 194 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 195 | 196 | hidden_states = F.scaled_dot_product_attention( 197 | query, key, value, dropout_p=0.0, is_causal=False 198 | ) 199 | 200 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 201 | hidden_states = hidden_states.to(query.dtype) 202 | 203 | hidden_states = self.to_out[0](hidden_states) 204 | hidden_states = self.to_out[1](hidden_states) 205 | 206 | if input_ndim == 4: 207 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 208 | 209 | return hidden_states 210 | 211 | 212 | class DiTBlock(nn.Module): 213 | 214 | def __init__( 215 | self, 216 | dim: int, 217 | num_attention_heads: int, 218 | attention_head_dim: int, 219 | dropout=0.0, 220 | activation_fn: str = "geglu", 221 | attention_bias: bool = False, 222 | norm_elementwise_affine: bool = True, 223 | norm_eps: float = 1e-5, 224 | final_dropout: bool = False, 225 | ff_inner_dim: Optional[int] = None, 226 | ff_bias: bool = True, 227 | attention_out_bias: bool = True, 228 | ): 229 | super().__init__() 230 | 231 | self.norm1 = nn.LayerNorm(dim, 
elementwise_affine=norm_elementwise_affine, eps=norm_eps) 232 | 233 | self.attn1 = Attention( 234 | heads=num_attention_heads, 235 | dim_head=attention_head_dim, 236 | dropout=dropout, 237 | bias=attention_bias, 238 | out_bias=attention_out_bias, 239 | ) 240 | 241 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 242 | 243 | self.attn2 = Attention( 244 | heads=num_attention_heads, 245 | dim_head=attention_head_dim, 246 | dropout=dropout, 247 | bias=attention_bias, 248 | out_bias=attention_out_bias, 249 | ) 250 | 251 | self.ff = FeedForward( 252 | dim, 253 | dropout=dropout, 254 | activation_fn=activation_fn, 255 | final_dropout=final_dropout, 256 | inner_dim=ff_inner_dim, 257 | bias=ff_bias, 258 | ) 259 | 260 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 261 | 262 | def forward( 263 | self, 264 | hidden_states: torch.FloatTensor, 265 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 266 | timestep: Optional[torch.LongTensor] = None, 267 | pixel_hidden_states: Optional[torch.FloatTensor] = None, 268 | ): 269 | batch_size = hidden_states.shape[0] 270 | 271 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 272 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 273 | ).chunk(6, dim=1) 274 | norm_hidden_states = self.norm1(hidden_states) 275 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 276 | norm_hidden_states = norm_hidden_states.squeeze(1) 277 | hidden_states_len = norm_hidden_states.shape[1] 278 | attn_output = self.attn1( 279 | torch.cat([pixel_hidden_states, norm_hidden_states], dim=1), 280 | )[:, -hidden_states_len:] 281 | attn_output = gate_msa * attn_output 282 | 283 | hidden_states = attn_output + hidden_states 284 | if hidden_states.ndim == 4: 285 | hidden_states = hidden_states.squeeze(1) 286 | 287 | norm_hidden_states = hidden_states 288 | 289 | attn_output = self.attn2( 290 | norm_hidden_states, 291 | encoder_hidden_states=encoder_hidden_states, 292 | ) 293 | hidden_states = attn_output + hidden_states 294 | 295 | norm_hidden_states = self.norm2(hidden_states) 296 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 297 | 298 | ff_output = self.ff(norm_hidden_states) 299 | 300 | ff_output = gate_mlp * ff_output 301 | 302 | hidden_states = ff_output + hidden_states 303 | if hidden_states.ndim == 4: 304 | hidden_states = hidden_states.squeeze(1) 305 | 306 | return hidden_states 307 | 308 | 309 | class D3D_DiT(nn.Module): 310 | 311 | def __init__( 312 | self, 313 | num_attention_heads: int = 16, 314 | attention_head_dim: int = 72, 315 | in_channels: Optional[int] = None, 316 | out_channels: Optional[int] = None, 317 | num_layers: int = 1, 318 | dropout: float = 0.0, 319 | attention_bias: bool = False, 320 | sample_size: Optional[int] = None, 321 | patch_size: Optional[int] = None, 322 | activation_fn: str = "gelu-approximate", 323 | norm_elementwise_affine: bool = False, 324 | norm_eps: float = 1e-6, 325 | semantic_channels: int = None, 326 | pixel_channels: int = None, 327 | interpolation_scale: float = 1.0, 328 | gradient_checkpointing: bool = False, 329 | ): 330 | super().__init__() 331 | 332 | self.num_attention_heads = num_attention_heads 333 | self.attention_head_dim = attention_head_dim 334 | inner_dim = num_attention_heads * attention_head_dim 335 | 336 | if isinstance(sample_size, int): 337 | sample_size = (sample_size, sample_size) 338 | self.height = sample_size[0] 339 | self.width = sample_size[1] 340 | 341 | self.patch_size = patch_size 
342 | interpolation_scale = ( 343 | interpolation_scale if interpolation_scale is not None else max(min(self.config.sample_size) // 32, 1) 344 | ) 345 | self.pos_embed = PatchEmbed( 346 | height=sample_size[0], 347 | width=sample_size[1], 348 | patch_size=patch_size, 349 | in_channels=in_channels, 350 | embed_dim=inner_dim, 351 | interpolation_scale=interpolation_scale, 352 | ) 353 | 354 | self.transformer_blocks = nn.ModuleList( 355 | [ 356 | DiTBlock( 357 | inner_dim, 358 | num_attention_heads, 359 | attention_head_dim, 360 | dropout=dropout, 361 | activation_fn=activation_fn, 362 | attention_bias=attention_bias, 363 | norm_elementwise_affine=norm_elementwise_affine, 364 | norm_eps=norm_eps, 365 | ) 366 | for d in range(num_layers) 367 | ] 368 | ) 369 | 370 | self.out_channels = in_channels if out_channels is None else out_channels 371 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 372 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) 373 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 374 | 375 | self.adaln_single = AdaLayerNormClassEmb(inner_dim, semantic_channels) 376 | 377 | self.semantic_projection = PixArtAlphaTextProjection(in_features=semantic_channels, hidden_size=inner_dim) 378 | self.pixel_projection = PixArtAlphaTextProjection(in_features=pixel_channels, hidden_size=inner_dim) 379 | 380 | self.gradient_checkpointing = gradient_checkpointing 381 | 382 | def forward( 383 | self, 384 | hidden_states: torch.Tensor, 385 | encoder_hidden_states: Optional[torch.Tensor] = None, 386 | timestep: Optional[torch.LongTensor] = None, 387 | pixel_hidden_states: Optional[torch.Tensor] = None, 388 | ): 389 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size 390 | hidden_states = self.pos_embed(hidden_states) 391 | 392 | timestep, embedded_timestep = self.adaln_single( 393 | timestep, class_embedding=encoder_hidden_states[:, 0], hidden_dtype=hidden_states.dtype 394 | ) 395 | 396 | batch_size = hidden_states.shape[0] 397 | encoder_hidden_states = self.semantic_projection(encoder_hidden_states) 398 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 399 | 400 | pixel_hidden_states = self.pixel_projection(pixel_hidden_states) 401 | pixel_hidden_states = pixel_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 402 | 403 | for block in self.transformer_blocks: 404 | if self.training and self.gradient_checkpointing: 405 | hidden_states = torch.utils.checkpoint.checkpoint( 406 | block, 407 | hidden_states, 408 | encoder_hidden_states, 409 | timestep, 410 | pixel_hidden_states, 411 | ) 412 | else: 413 | hidden_states = block( 414 | hidden_states, 415 | encoder_hidden_states=encoder_hidden_states, 416 | timestep=timestep, 417 | pixel_hidden_states=pixel_hidden_states, 418 | ) 419 | 420 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) 421 | hidden_states = self.norm_out(hidden_states) 422 | hidden_states = hidden_states * (1 + scale) + shift 423 | hidden_states = self.proj_out(hidden_states) 424 | hidden_states = hidden_states.squeeze(1) 425 | 426 | hidden_states = hidden_states.reshape( 427 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 428 | ) 429 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 430 | output = hidden_states.reshape( 431 | shape=(-1, self.out_channels, height * self.patch_size, width * 
self.patch_size) 432 | ) 433 | 434 | return output 435 | 436 | -------------------------------------------------------------------------------- /direct3d/models/vae.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/vae.py 2 | 3 | import trimesh 4 | import itertools 5 | import numpy as np 6 | from tqdm import tqdm 7 | from einops import rearrange, repeat 8 | from skimage import measure 9 | from typing import List, Tuple, Optional, Union 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from direct3d.utils.triplane import sample_from_planes, generate_planes 16 | from diffusers.models.autoencoders.vae import UNetMidBlock2D 17 | from diffusers.models.unets.unet_2d_blocks import UpDecoderBlock2D 18 | 19 | 20 | class DiagonalGaussianDistribution(object): 21 | def __init__(self, parameters: torch.Tensor, deterministic: bool=False): 22 | self.parameters = parameters 23 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 24 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 25 | self.deterministic = deterministic 26 | self.std = torch.exp(0.5 * self.logvar) 27 | self.var = torch.exp(self.logvar) 28 | if self.deterministic: 29 | self.var = self.std = torch.zeros_like( 30 | self.mean, device=self.parameters.device, dtype=self.parameters.dtype 31 | ) 32 | 33 | def sample(self): 34 | x = self.mean + self.std * torch.randn_like(self.mean) 35 | return x 36 | 37 | def kl(self, other=None): 38 | if self.deterministic: 39 | return torch.Tensor([0.]) 40 | else: 41 | if other is None: 42 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 43 | + self.var - 1.0 - self.logvar, 44 | dim=[1, 2, 3]) 45 | else: 46 | return 0.5 * torch.mean( 47 | torch.pow(self.mean - other.mean, 2) / other.var 48 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 49 | dim=[1, 2, 3]) 50 | 51 | def nll(self, sample, dims=(1, 2, 3)): 52 | if self.deterministic: 53 | return torch.Tensor([0.]) 54 | logtwopi = np.log(2.0 * np.pi) 55 | return 0.5 * torch.sum( 56 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 57 | dim=dims) 58 | 59 | def mode(self): 60 | return self.mean 61 | 62 | 63 | class FourierEmbedder(nn.Module): 64 | 65 | def __init__(self, 66 | num_freqs: int = 6, 67 | input_dim: int = 3): 68 | 69 | super().__init__() 70 | freq = 2.0 ** torch.arange(num_freqs) 71 | self.register_buffer("freq", freq, persistent=False) 72 | self.num_freqs = num_freqs 73 | self.out_dim = input_dim * (num_freqs * 2 + 1) 74 | 75 | def forward(self, x: torch.Tensor): 76 | embed = (x[..., None].contiguous() * self.freq).view(*x.shape[:-1], -1) 77 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 78 | 79 | 80 | class OccDecoder(nn.Module): 81 | 82 | def __init__(self, 83 | n_features: int, 84 | hidden_dim: int = 64, 85 | num_layers: int = 4, 86 | activation: nn.Module = nn.ReLU, 87 | final_activation: str = None): 88 | super().__init__() 89 | 90 | self.net = nn.Sequential( 91 | nn.Linear(3 * n_features, hidden_dim), 92 | activation(), 93 | *itertools.chain(*[[ 94 | nn.Linear(hidden_dim, hidden_dim), 95 | activation(), 96 | ] for _ in range(num_layers - 2)]), 97 | nn.Linear(hidden_dim, 1), 98 | ) 99 | self.final_activation = final_activation 100 | 101 | def forward(self, sampled_features): 102 | 103 | x = rearrange(sampled_features, "N_b N_t N_s C -> N_b N_s (N_t C)") 104 | x = self.net(x) 105 | 106 | if 
self.final_activation is None: 107 | pass 108 | elif self.final_activation == 'tanh': 109 | x = torch.tanh(x) 110 | elif self.final_activation == 'sigmoid': 111 | x = torch.sigmoid(x) 112 | else: 113 | raise ValueError(f"Unknown final activation: {self.final_activation}") 114 | 115 | return x[..., 0] 116 | 117 | 118 | class Attention(nn.Module): 119 | 120 | def __init__(self, 121 | dim: int, 122 | heads: int = 8, 123 | dim_head: int = 64): 124 | super().__init__() 125 | inner_dim = dim_head * heads 126 | self.heads = heads 127 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 128 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 129 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 130 | self.to_out = nn.Linear(inner_dim, inner_dim) 131 | 132 | def forward(self, 133 | hidden_states: torch.Tensor, 134 | encoder_hidden_states: Optional[torch.Tensor] = None, 135 | ): 136 | batch_size = hidden_states.shape[0] 137 | if encoder_hidden_states is None: 138 | encoder_hidden_states = hidden_states 139 | key = self.to_k(encoder_hidden_states) 140 | value = self.to_v(encoder_hidden_states) 141 | query = self.to_q(hidden_states) 142 | 143 | inner_dim = key.shape[-1] 144 | head_dim = inner_dim // self.heads 145 | 146 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 147 | 148 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 149 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 150 | 151 | hidden_states = F.scaled_dot_product_attention( 152 | query, key, value 153 | ) 154 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 155 | hidden_states = self.to_out(hidden_states) 156 | 157 | return hidden_states 158 | 159 | 160 | class TransformerBlock(nn.Module): 161 | 162 | def __init__(self, 163 | num_attention_heads: int, 164 | attention_head_dim: int, 165 | cross_attention: bool = False): 166 | super().__init__() 167 | inner_dim = attention_head_dim * num_attention_heads 168 | self.norm1 = nn.LayerNorm(inner_dim) 169 | if cross_attention: 170 | self.norm1_c = nn.LayerNorm(inner_dim) 171 | else: 172 | self.norm1_c = None 173 | self.attn = Attention(inner_dim, num_attention_heads, attention_head_dim) 174 | self.norm2 = nn.LayerNorm(inner_dim) 175 | self.mlp = nn.Sequential( 176 | nn.Linear(inner_dim, 4 * inner_dim), 177 | nn.GELU(), 178 | nn.Linear(4 * inner_dim, inner_dim), 179 | ) 180 | 181 | def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None): 182 | if self.norm1_c is not None: 183 | x = self.attn(self.norm1(x), self.norm1_c(y)) + x 184 | else: 185 | x = self.attn(self.norm1(x)) + x 186 | x = x + self.mlp(self.norm2(x)) 187 | return x 188 | 189 | 190 | class PointEncoder(nn.Module): 191 | 192 | def __init__(self, 193 | num_latents: int, 194 | in_channels: int, 195 | num_attention_heads: int, 196 | attention_head_dim: int, 197 | num_layers: int, 198 | gradient_checkpointing: bool = False): 199 | 200 | super().__init__() 201 | 202 | self.gradient_checkpointing = gradient_checkpointing 203 | self.num_latents = num_latents 204 | inner_dim = attention_head_dim * num_attention_heads 205 | 206 | self.learnable_token = nn.Parameter(torch.randn((num_latents, inner_dim)) * 0.01) 207 | 208 | self.proj_in = nn.Linear(in_channels, inner_dim) 209 | self.cross_attn = TransformerBlock(num_attention_heads, attention_head_dim, cross_attention=True) 210 | 211 | self.self_attn = nn.ModuleList([ 212 | TransformerBlock(num_attention_heads, attention_head_dim) for _ in range(num_layers) 213 | 
]) 214 | 215 | self.norm_out = nn.LayerNorm(inner_dim) 216 | 217 | def forward(self, pc): 218 | 219 | bs = pc.shape[0] 220 | pc = self.proj_in(pc) 221 | 222 | learnable_token = repeat(self.learnable_token, "m c -> b m c", b=bs) 223 | 224 | if self.training and self.gradient_checkpointing: 225 | latents = torch.utils.checkpoint.checkpoint(self.cross_attn, learnable_token, pc) 226 | for block in self.self_attn: 227 | latents = torch.utils.checkpoint.checkpoint(block, latents) 228 | else: 229 | latents = self.cross_attn(learnable_token, pc) 230 | for block in self.self_attn: 231 | latents = block(latents) 232 | 233 | latents = self.norm_out(latents) 234 | 235 | return latents 236 | 237 | 238 | class TriplaneDecoder(nn.Module): 239 | 240 | def __init__( 241 | self, 242 | in_channels: int = 3, 243 | out_channels: int = 3, 244 | block_out_channels: Tuple[int, ...] = (64,), 245 | layers_per_block: int = 2, 246 | norm_num_groups: int = 32, 247 | act_fn: str = "silu", 248 | norm_type: str = "group", 249 | mid_block_add_attention=True, 250 | gradient_checkpointing: bool = False, 251 | ): 252 | super().__init__() 253 | self.layers_per_block = layers_per_block 254 | 255 | self.conv_in = nn.Conv2d( 256 | in_channels, 257 | block_out_channels[-1], 258 | kernel_size=3, 259 | stride=1, 260 | padding=1, 261 | ) 262 | 263 | self.up_blocks = nn.ModuleList([]) 264 | 265 | temb_channels = in_channels if norm_type == "spatial" else None 266 | 267 | # mid 268 | self.mid_block = UNetMidBlock2D( 269 | in_channels=block_out_channels[-1], 270 | resnet_eps=1e-6, 271 | resnet_act_fn=act_fn, 272 | output_scale_factor=1, 273 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type, 274 | attention_head_dim=block_out_channels[-1], 275 | resnet_groups=norm_num_groups, 276 | temb_channels=temb_channels, 277 | add_attention=mid_block_add_attention, 278 | ) 279 | 280 | # up 281 | reversed_block_out_channels = list(reversed(block_out_channels)) 282 | output_channel = reversed_block_out_channels[0] 283 | for i in range(len(block_out_channels)): 284 | prev_output_channel = output_channel 285 | output_channel = reversed_block_out_channels[i] 286 | 287 | is_final_block = i == len(block_out_channels) - 1 288 | up_block = UpDecoderBlock2D( 289 | num_layers=self.layers_per_block + 1, 290 | in_channels=prev_output_channel, 291 | out_channels=output_channel, 292 | add_upsample=not is_final_block, 293 | resnet_eps=1e-6, 294 | resnet_act_fn=act_fn, 295 | resnet_groups=norm_num_groups, 296 | resnet_time_scale_shift=norm_type, 297 | temb_channels=temb_channels, 298 | ) 299 | self.up_blocks.append(up_block) 300 | prev_output_channel = output_channel 301 | 302 | # out 303 | if norm_type == "group": 304 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) 305 | else: 306 | raise ValueError(f"Unsupported norm type: {norm_type}") 307 | 308 | self.conv_act = nn.SiLU() 309 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) 310 | 311 | self.gradient_checkpointing = gradient_checkpointing 312 | 313 | def forward(self, sample: torch.Tensor): 314 | r"""The forward method of the `Decoder` class.""" 315 | 316 | sample = self.conv_in(sample) 317 | 318 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype 319 | if self.training and self.gradient_checkpointing: 320 | # middle 321 | sample = torch.utils.checkpoint.checkpoint( 322 | self.mid_block, sample 323 | ) 324 | sample = sample.to(upscale_dtype) 325 | 326 | # up 327 | for up_block in 
self.up_blocks: 328 | sample = torch.utils.checkpoint.checkpoint(up_block, sample) 329 | else: 330 | # middle 331 | sample = self.mid_block(sample) 332 | sample = sample.to(upscale_dtype) 333 | 334 | # up 335 | for up_block in self.up_blocks: 336 | sample = up_block(sample) 337 | 338 | # post-process 339 | sample = self.conv_norm_out(sample) 340 | sample = self.conv_act(sample) 341 | sample = self.conv_out(sample) 342 | 343 | return sample 344 | 345 | 346 | class D3D_VAE(nn.Module): 347 | def __init__(self, 348 | triplane_res: int, 349 | latent_dim: int = 0, 350 | triplane_dim: int = 32, 351 | num_freqs: int = 8, 352 | num_attention_heads: int = 12, 353 | attention_head_dim: int = 64, 354 | num_encoder_layers: int = 8, 355 | num_geodecoder_layers: int = 5, 356 | final_activation: str = None, 357 | block_out_channels=[128, 256, 512, 512], 358 | mid_block_add_attention=True, 359 | gradient_checkpointing: bool = False, 360 | latents_scale: float = 1.0, 361 | latents_shift: float = 0.0): 362 | 363 | super().__init__() 364 | 365 | self.gradient_checkpointing = gradient_checkpointing 366 | 367 | self.triplane_res = triplane_res 368 | self.num_latents = triplane_res ** 2 * 3 369 | self.latent_shape = (latent_dim, triplane_res, 3 * triplane_res) 370 | self.latents_scale = latents_scale 371 | self.latents_shift = latents_shift 372 | 373 | self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs) 374 | 375 | inner_dim = attention_head_dim * num_attention_heads 376 | self.encoder = PointEncoder( 377 | num_latents=self.num_latents, 378 | in_channels=self.fourier_embedder.out_dim + 3, 379 | num_attention_heads=num_attention_heads, 380 | attention_head_dim=attention_head_dim, 381 | num_layers=num_encoder_layers, 382 | gradient_checkpointing=gradient_checkpointing, 383 | ) 384 | 385 | self.latent_dim = latent_dim 386 | 387 | self.pre_latent = nn.Conv2d(inner_dim, 2 * latent_dim, 1) 388 | self.post_latent = nn.Conv2d(latent_dim, inner_dim, 1) 389 | 390 | self.decoder = TriplaneDecoder( 391 | in_channels=inner_dim, 392 | out_channels=triplane_dim, 393 | block_out_channels=block_out_channels, 394 | mid_block_add_attention=mid_block_add_attention, 395 | gradient_checkpointing=gradient_checkpointing, 396 | ) 397 | 398 | self.plane_axes = generate_planes() 399 | self.occ_decoder = OccDecoder( 400 | n_features=triplane_dim, 401 | num_layers=num_geodecoder_layers, 402 | final_activation=final_activation, 403 | ) 404 | 405 | def rollout(self, triplane): 406 | triplane = rearrange(triplane, "N_b (N_t C) H_t W_t -> N_b C H_t (N_t W_t)", N_t=3) 407 | return triplane 408 | 409 | def unrollout(self, triplane): 410 | triplane = rearrange(triplane, "N_b C H_t (N_t W_t) -> N_b N_t C H_t W_t", N_t=3) 411 | return triplane 412 | 413 | def encode(self, 414 | pc: torch.FloatTensor, 415 | feats: Optional[torch.FloatTensor] = None): 416 | 417 | x = self.fourier_embedder(pc) 418 | if feats is not None: 419 | x = torch.cat((x, feats), dim=-1) 420 | x = self.encoder(x) 421 | x = rearrange(x, "N_b (N_t H_t W_t) C -> N_b (N_t C) H_t W_t", 422 | N_t=3, H_t=self.triplane_res, W_t=self.triplane_res) 423 | x = self.rollout(x) 424 | moments = self.pre_latent(x) 425 | 426 | posterior = DiagonalGaussianDistribution(moments) 427 | latents = posterior.sample() 428 | 429 | return latents, posterior 430 | 431 | def decode(self, z, unrollout=False): 432 | z = self.post_latent(z) 433 | dec = self.decoder(z) 434 | if unrollout: 435 | dec = self.unrollout(dec) 436 | return dec 437 | 438 | def decode_mesh(self, 439 | latents, 440 | bounds: 
Union[Tuple[float], List[float], float] = 1.0, 441 | voxel_resolution: int = 512, 442 | mc_threshold: float = 0.0): 443 | triplane = self.decode(latents, unrollout=True) 444 | mesh = self.triplane2mesh(triplane, 445 | bounds=bounds, 446 | voxel_resolution=voxel_resolution, 447 | mc_threshold=mc_threshold) 448 | return mesh 449 | 450 | def triplane2mesh(self, 451 | latents: torch.FloatTensor, 452 | bounds: Union[Tuple[float], List[float], float] = 1.0, 453 | voxel_resolution: int = 512, 454 | mc_threshold: float = 0.0, 455 | chunk_size: int = 50000): 456 | 457 | batch_size = len(latents) 458 | 459 | if isinstance(bounds, float): 460 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] 461 | bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) 462 | bbox_length = bbox_max - bbox_min 463 | 464 | x = torch.linspace(bbox_min[0], bbox_max[0], steps=int(voxel_resolution) + 1) 465 | y = torch.linspace(bbox_min[1], bbox_max[1], steps=int(voxel_resolution) + 1) 466 | z = torch.linspace(bbox_min[2], bbox_max[2], steps=int(voxel_resolution) + 1) 467 | xs, ys, zs = torch.meshgrid(x, y, z, indexing='ij') 468 | xyz = torch.stack((xs, ys, zs), dim=-1) 469 | xyz = xyz.reshape(-1, 3) 470 | grid_size = [int(voxel_resolution) + 1, int(voxel_resolution) + 1, int(voxel_resolution) + 1] 471 | 472 | logits_total = [] 473 | for start in tqdm(range(0, xyz.shape[0], chunk_size), desc="Triplane Sampling:"): 474 | positions = xyz[start:start + chunk_size].to(latents.device) 475 | positions = repeat(positions, "p d -> b p d", b=batch_size) 476 | 477 | triplane_features = sample_from_planes(self.plane_axes.to(latents.device), 478 | latents, positions, 479 | box_warp=2.0) 480 | logits = self.occ_decoder(triplane_features) 481 | logits_total.append(logits) 482 | 483 | logits_total = torch.cat(logits_total, dim=1).view( 484 | (batch_size, grid_size[0], grid_size[1], grid_size[2])).cpu().numpy() 485 | 486 | meshes = [] 487 | for i in range(batch_size): 488 | vertices, faces, _, _ = measure.marching_cubes( 489 | logits_total[i], 490 | mc_threshold, 491 | method="lewiner" 492 | ) 493 | vertices = vertices / grid_size * bbox_length + bbox_min 494 | faces = faces[:, ::-1] 495 | meshes.append(trimesh.Trimesh(vertices, faces)) 496 | return meshes 497 | 498 | -------------------------------------------------------------------------------- /direct3d/pipeline.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py 2 | 3 | import os 4 | from tqdm import tqdm 5 | from PIL import Image 6 | from omegaconf import OmegaConf 7 | from huggingface_hub import hf_hub_download 8 | from typing import Union, List, Optional 9 | 10 | import torch 11 | from direct3d.utils import instantiate_from_config, preprocess 12 | from diffusers.utils.torch_utils import randn_tensor 13 | 14 | 15 | class Direct3dPipeline(object): 16 | 17 | def __init__(self, 18 | vae, 19 | dit, 20 | semantic_encoder, 21 | pixel_encoder, 22 | scheduler): 23 | self.vae = vae 24 | self.dit = dit 25 | self.semantic_encoder = semantic_encoder 26 | self.pixel_encoder = pixel_encoder 27 | self.scheduler = scheduler 28 | 29 | def to(self, device): 30 | self.device = torch.device(device) 31 | self.vae.to(device) 32 | self.dit.to(device) 33 | self.semantic_encoder.to(device) 34 | self.pixel_encoder.to(device) 35 | 36 | @classmethod 37 | def from_pretrained(cls, 38 | pipeline_path): 39 | 40 | if 
os.path.isdir(pipeline_path): 41 | config_path = os.path.join(pipeline_path, 'config.yaml') 42 | model_path = os.path.join(pipeline_path, 'model.ckpt') 43 | else: 44 | config_path = hf_hub_download(repo_id=pipeline_path, filename="config.yaml", repo_type="model") 45 | model_path = hf_hub_download(repo_id=pipeline_path, filename="model.ckpt", repo_type="model") 46 | 47 | cfg = OmegaConf.load(config_path) 48 | state_dict = torch.load(model_path, map_location='cpu') 49 | 50 | vae = instantiate_from_config(cfg.vae) 51 | vae.load_state_dict(state_dict["vae"], strict=True) 52 | dit = instantiate_from_config(cfg.dit) 53 | dit.load_state_dict(state_dict["dit"], strict=True) 54 | 55 | semantic_encoder = instantiate_from_config(cfg.semantic_encoder) 56 | pixel_encoder = instantiate_from_config(cfg.pixel_encoder) 57 | 58 | scheduler = instantiate_from_config(cfg.scheduler) 59 | 60 | return cls( 61 | vae=vae, 62 | dit=dit, 63 | semantic_encoder=semantic_encoder, 64 | pixel_encoder=pixel_encoder, 65 | scheduler=scheduler) 66 | 67 | def prepare_image(self, image: Union[str, List[str], Image.Image, List[Image.Image]], rmbg: bool = True): 68 | if not isinstance(image, list): 69 | image = [image] 70 | if isinstance(image[0], str): 71 | image = [Image.open(img) for img in image] 72 | image = [preprocess(img, rmbg=rmbg) for img in image] 73 | image = torch.stack([img for img in image]).to(self.device) 74 | return image 75 | 76 | def encode_image(self, image: torch.Tensor, do_classifier_free_guidance: bool = True): 77 | semantic_cond = self.semantic_encoder(image) 78 | pixel_cond = self.pixel_encoder(image) 79 | if do_classifier_free_guidance: 80 | semantic_uncond = torch.zeros_like(semantic_cond) 81 | pixel_uncond = torch.zeros_like(pixel_cond) 82 | semantic_cond = torch.cat([semantic_uncond, semantic_cond], dim=0) 83 | pixel_cond = torch.cat([pixel_uncond, pixel_cond], dim=0) 84 | 85 | return semantic_cond, pixel_cond 86 | 87 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator): 88 | shape = ( 89 | batch_size, 90 | num_channels_latents, 91 | height, 92 | width, 93 | ) 94 | if isinstance(generator, list) and len(generator) != batch_size: 95 | raise ValueError( 96 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 97 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
98 | ) 99 | 100 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 101 | return latents 102 | 103 | @torch.no_grad() 104 | def __call__( 105 | self, 106 | image: Union[str, List[str], Image.Image, List[Image.Image]] = None, 107 | num_inference_steps: int = 50, 108 | guidance_scale: float = 4.0, 109 | generator: Optional[torch.Generator] = None, 110 | mc_threshold: float = -2.0, 111 | remove_background: bool = True,): 112 | 113 | batch_size = len(image) if isinstance(image, list) else 1 114 | do_classifier_free_guidance = guidance_scale > 0 115 | 116 | self.scheduler.set_timesteps(num_inference_steps, device=self.device) 117 | timesteps = self.scheduler.timesteps 118 | 119 | image = self.prepare_image(image, remove_background) 120 | semantic_cond, pixel_cond = self.encode_image(image, do_classifier_free_guidance) 121 | latents = self.prepare_latents( 122 | batch_size=batch_size, 123 | num_channels_latents=self.vae.latent_shape[0], 124 | height=self.vae.latent_shape[1], 125 | width=self.vae.latent_shape[2], 126 | dtype=image.dtype, 127 | device=self.device, 128 | generator=generator, 129 | ) 130 | 131 | extra_step_kwargs = { 132 | "generator": generator 133 | } 134 | 135 | for i, t in enumerate(tqdm(timesteps, desc="Diffusion Sampling:")): 136 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 137 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 138 | 139 | t = t.expand(latent_model_input.shape[0]) 140 | 141 | noise_pred = self.dit( 142 | hidden_states=latent_model_input, 143 | timestep=t, 144 | encoder_hidden_states=semantic_cond, 145 | pixel_hidden_states=pixel_cond, 146 | ) 147 | 148 | if do_classifier_free_guidance: 149 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 150 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 151 | 152 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 153 | 154 | latents = 1. 
/ self.vae.latents_scale * latents + self.vae.latents_shift 155 | meshes = self.vae.decode_mesh(latents, mc_threshold=mc_threshold) 156 | outputs = {"meshes": meshes, "latents": latents} 157 | 158 | return outputs 159 | -------------------------------------------------------------------------------- /direct3d/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import instantiate_from_config, get_obj_from_str 2 | from .image import preprocess -------------------------------------------------------------------------------- /direct3d/utils/image.py: -------------------------------------------------------------------------------- 1 | import rembg 2 | import numpy as np 3 | from PIL import Image 4 | from torchvision import transforms as T 5 | 6 | 7 | def crop_recenter(image_no_bg, threshold=100): 8 | image_no_bg_np = np.array(image_no_bg) 9 | if image_no_bg_np.shape[2] == 3: 10 | return image_no_bg 11 | mask = (image_no_bg_np[..., -1]).astype(np.uint8) 12 | mask_bin = mask > threshold 13 | 14 | H, W = image_no_bg_np.shape[:2] 15 | 16 | valid_pixels = mask_bin.astype(np.float32).nonzero() 17 | if np.sum(mask_bin) < (H*W) * 0.001: 18 | min_h = 0 19 | max_h = H - 1 20 | min_w = 0 21 | max_w = W - 1 22 | else: 23 | min_h, max_h = valid_pixels[0].min(), valid_pixels[0].max() 24 | min_w, max_w = valid_pixels[1].min(), valid_pixels[1].max() 25 | 26 | if min_h < 0: 27 | min_h = 0 28 | if min_w < 0: 29 | min_w = 0 30 | if max_h > H: 31 | max_h = H - 1 32 | if max_w > W: 33 | max_w = W - 1 34 | 35 | image_no_bg_np = image_no_bg_np[min_h:max_h+1, min_w:max_w+1] 36 | image_no_bg = Image.fromarray(image_no_bg_np) 37 | return image_no_bg 38 | 39 | 40 | def pad_to_same_size(image, pad_value=1): 41 | image = np.array(image) 42 | h, w, _ = image.shape 43 | image_temp = image.copy() 44 | if h != w: 45 | # pad the shorter side with pad_value so the image becomes square 46 | max_size = max(h, w) 47 | 48 | pad_h = max_size - h 49 | pad_w = max_size - w 50 | pad_h_top = max(pad_h // 2, 0) 51 | pad_h_bottom = max(pad_h - pad_h_top, 0) 52 | pad_w_left = max(pad_w // 2, 0) 53 | pad_w_right = max(pad_w - pad_w_left, 0) 54 | 55 | image_temp = np.pad( 56 | image[..., :3], 57 | ((pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)), 58 | constant_values=pad_value 59 | ) 60 | if image.shape[2] == 4: 61 | image_bg = np.pad( 62 | image[..., 3:], 63 | ((pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)), 64 | constant_values=0 65 | ) 66 | image = np.concatenate([image_temp, image_bg], axis=2) 67 | else: 68 | image = image_temp 69 | 70 | return Image.fromarray(image) 71 | 72 | 73 | def remove_bg(image): 74 | image = rembg.remove(image) 75 | return image 76 | 77 | 78 | def preprocess(image, rmbg=True): 79 | 80 | if rmbg: 81 | image = remove_bg(image) 82 | 83 | image = crop_recenter(image) 84 | image = pad_to_same_size(image, pad_value=255) 85 | image = np.array(image) 86 | image = image / 255.
87 | if image.shape[2] == 4: 88 | image = image[..., :3] * image[..., 3:] + (1 - image[..., 3:]) 89 | image = Image.fromarray((image * 255).astype('uint8'), "RGB") 90 | 91 | W, H = image.size[:2] 92 | pad_margin = int(W * 0.04) 93 | image_transforms = T.Compose([ 94 | T.Pad((pad_margin, pad_margin, pad_margin, pad_margin), fill=255), 95 | T.ToTensor(), 96 | ]) 97 | 98 | image = image_transforms(image) 99 | 100 | return image 101 | -------------------------------------------------------------------------------- /direct3d/utils/triplane.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | 6 | def generate_planes(): 7 | """ 8 | Defines planes by the three vectors that form the "axes" of the 9 | plane. Should work with an arbitrary number of planes and planes of 10 | arbitrary orientation. 11 | 12 | Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 13 | """ 14 | return torch.tensor([[[1, 0, 0], 15 | [0, 1, 0], 16 | [0, 0, 1]], 17 | [[1, 0, 0], 18 | [0, 0, 1], 19 | [0, 1, 0]], 20 | [[0, 0, 1], 21 | [0, 1, 0], 22 | [1, 0, 0]]]) 23 | 24 | 25 | def project_onto_planes(planes, coordinates): 26 | """ 27 | Projects a batch of 3D points onto a batch of 2D planes, 28 | returning 2D plane coordinates. 29 | 30 | Takes plane axes of shape (n_planes, 3, 3) 31 | and coordinates of shape (N, M, 3); 32 | returns projections of shape (N*n_planes, M, 2). 33 | """ 34 | N, M, C = coordinates.shape 35 | n_planes, _, _ = planes.shape 36 | coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) 37 | inv_planes = torch.linalg.inv(planes.to(coordinates.dtype)).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) 38 | projections = torch.bmm(coordinates, inv_planes) 39 | return projections[..., :2] 40 | 41 | 42 | def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): 43 | assert padding_mode == 'zeros' 44 | N, n_planes, C, H, W = plane_features.shape 45 | _, M, _ = coordinates.shape 46 | plane_features = rearrange(plane_features, "N_b N_t C H_t W_t -> (N_b N_t) C H_t W_t") 47 | 48 | coordinates = (2/box_warp) * coordinates 49 | 50 | projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) 51 | output_features = F.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) 52 | return output_features -------------------------------------------------------------------------------- /direct3d/utils/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def instantiate_from_config(config): 5 | if "target" not in config: 6 | if config == '__is_first_stage__': 7 | return None 8 | elif config == "__is_unconditional__": 9 | return None 10 | raise KeyError("Expected key `target` to instantiate.") 11 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 12 | 13 | 14 | def get_obj_from_str(string, reload=False): 15 | module, cls = string.rsplit(".", 1) 16 | if reload: 17 | module_imp = importlib.import_module(module) 18 | importlib.reload(module_imp) 19 | return getattr(importlib.import_module(module, package=None), cls) 20 | -------------------------------------------------------------------------------- /requirements.txt:
-------------------------------------------------------------------------------- 1 | torch>=2.1.0 2 | scikit-image 3 | trimesh 4 | tqdm 5 | einops 6 | numpy 7 | transformers==4.40.2 8 | diffusers 9 | rembg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name="direct3d", 6 | version="1.0.0", 7 | description="Direct3D: Scalable Image-to-3D Generation via 3D Latent Diffusion Transformer", 8 | packages=find_packages(), 9 | python_requires=">=3.10", 10 | install_requires=[ 11 | "torch", 12 | "numpy", 13 | "cython", 14 | "trimesh", 15 | "diffusers", 16 | ], 17 | ) --------------------------------------------------------------------------------
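Usage sketch (not part of the repository files above): the snippet below shows how the pieces are meant to fit together. Direct3dPipeline.from_pretrained builds the VAE, DiT, image encoders, and scheduler from config.yaml/model.ckpt, and calling the pipeline runs background removal, classifier-free-guided diffusion sampling, and marching-cubes mesh extraction. The checkpoint path and output filename are placeholders, not values confirmed by this codebase.

import torch
from direct3d.pipeline import Direct3dPipeline

# Load from a local folder containing config.yaml and model.ckpt,
# or from a Hugging Face repo id (the path below is a placeholder).
pipeline = Direct3dPipeline.from_pretrained("path/to/direct3d-checkpoint")
pipeline.to("cuda")  # also sets pipeline.device, which __call__ relies on

outputs = pipeline(
    image="assets/bear.png",            # str, PIL.Image, or a list of either
    num_inference_steps=50,
    guidance_scale=4.0,                 # values <= 0 disable classifier-free guidance
    mc_threshold=-2.0,                  # occupancy level passed to marching cubes
    remove_background=True,             # run rembg before crop/pad/normalize
    generator=torch.Generator("cuda").manual_seed(0),
)

# decode_mesh returns one trimesh.Trimesh per input image.
outputs["meshes"][0].export("bear.obj")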
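A minimal, self-contained sketch of the triplane lookup used by the triplane2mesh method above: generate_planes supplies the three axis-aligned plane bases, and sample_from_planes projects 3D query points onto each plane and bilinearly samples the feature maps. The feature channel count and resolution below are arbitrary stand-ins, not the model's real latent shape.

import torch
from direct3d.utils.triplane import generate_planes, sample_from_planes

plane_axes = generate_planes()           # (3, 3, 3) plane basis vectors
features = torch.randn(1, 3, 8, 32, 32)  # (N, n_planes, C, H, W) dummy triplane
points = torch.rand(1, 100, 3) * 2 - 1   # query points in [-1, 1]^3

# box_warp=2.0 maps the [-1, 1] cube onto the full plane extent,
# matching the value passed in triplane2mesh.
feats = sample_from_planes(plane_axes, features, points, box_warp=2.0)
print(feats.shape)                       # torch.Size([1, 3, 100, 8])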
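For completeness, a small illustration of the config-driven construction used in from_pretrained: instantiate_from_config imports the dotted path in the "target" field and calls it with "params". The target below is a generic stand-in class, not one of the repo's own modules.

from omegaconf import OmegaConf
from direct3d.utils import instantiate_from_config

cfg = OmegaConf.create({
    "target": "torch.nn.Linear",         # any importable "module.Class" path
    "params": {"in_features": 8, "out_features": 4},
})
layer = instantiate_from_config(cfg)     # equivalent to torch.nn.Linear(8, 4)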