├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── bear.png
│   ├── book.png
│   ├── car.png
│   ├── castle.png
│   ├── demo
│   │   ├── video1.gif
│   │   └── video2.gif
│   ├── devil.png
│   ├── dragon.png
│   ├── earth.png
│   ├── figure
│   │   ├── framework.png
│   │   └── teaser.gif
│   ├── fish.png
│   ├── girl.png
│   ├── hamburger.png
│   ├── man.png
│   ├── panda.png
│   ├── parrot.png
│   ├── phoenix.png
│   ├── riding.png
│   └── sunglasses.png
├── direct3d
│   ├── models
│   │   ├── __init__.py
│   │   ├── condition.py
│   │   ├── dit.py
│   │   └── vae.py
│   ├── pipeline.py
│   └── utils
│       ├── __init__.py
│       ├── image.py
│       ├── triplane.py
│       └── util.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *__pycache__*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Direct3D: Scalable Image-to-3D Generation via 3D Latent Diffusion Transformer (NeurIPS 2024)
3 |
4 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | ---
19 |
20 | ## ✨ News
21 |
22 | - Feb 11, 2025: 🔨 We are working on the Gradio demo and will release it soon!
23 | - Feb 11, 2025: 🎁 Enjoy our improved version of Direct3D with high-quality geometry and texture at [https://www.neural4d.com](https://www.neural4d.com/).
24 | - Feb 11, 2025: 🚀 Released the inference code of Direct3D; the pretrained models are available at 🤗 [Hugging Face](https://huggingface.co/DreamTechAI/Direct3D/tree/main).
25 |
26 | ## 📝 Abstract
27 |
28 | We introduce **Direct3D**, a native 3D generative model scalable to in-the-wild input images, without requiring a multiview diffusion model or SDS optimization. Our approach comprises two primary components: a Direct 3D Variational Auto-Encoder **(D3D-VAE)** and a Direct 3D Diffusion Transformer **(D3D-DiT)**. D3D-VAE efficiently encodes high-resolution 3D shapes into a compact and continuous latent triplane space. Notably, our method directly supervises the decoded geometry using a semi-continuous surface sampling strategy, diverging from previous methods relying on rendered images as supervision signals. D3D-DiT models the distribution of encoded 3D latents and is specifically designed to fuse positional information from the three feature maps of the triplane latent, enabling a native 3D generative model scalable to large-scale 3D datasets. Additionally, we introduce an innovative image-to-3D generation pipeline incorporating semantic and pixel-level image conditions, allowing the model to produce 3D shapes consistent with the provided conditional image input. Extensive experiments demonstrate the superiority of our large-scale pre-trained Direct3D over previous image-to-3D approaches, achieving significantly better generation quality and generalization ability, thus establishing a new state-of-the-art for 3D content creation.
29 |
30 |
31 |
32 |
33 |
34 |
35 | ## 🚀 Getting Started
36 |
37 | ### Installation
38 |
39 | ```sh
40 | git clone https://github.com/DreamTechAI/Direct3D.git
41 |
42 | cd Direct3D
43 |
44 | pip install -r requirements.txt
45 |
46 | pip install -e .
47 | ```
48 |
49 | ### Usage
50 |
51 | ```python
52 | from direct3d.pipeline import Direct3dPipeline
53 | pipeline = Direct3dPipeline.from_pretrained("DreamTechAI/Direct3D")
54 | pipeline.to("cuda")
55 | mesh = pipeline(
56 | "assets/devil.png",
57 | remove_background=False, # set to True if the background of the image needs to be removed
58 | mc_threshold=-1.0,
59 | guidance_scale=4.0,
60 | num_inference_steps=50,
61 | )["meshes"][0]
62 | mesh.export("output.obj")
63 | ```
64 |
65 | ## 🤗 Acknowledgements
66 |
67 | Thanks to the following repos for their great work, which helped us a lot in the development of Direct3D:
68 |
69 | - [3DShape2VecSet](https://github.com/1zb/3DShape2VecSet/tree/master)
70 | - [Michelangelo](https://github.com/NeuralCarver/Michelangelo)
71 | - [Objaverse](https://objaverse.allenai.org/)
72 | - [diffusers](https://github.com/huggingface/diffusers)
73 |
74 | ## 📖 Citation
75 |
76 | If you find our work useful, please consider citing our paper:
77 |
78 | ```bibtex
79 | @article{direct3d,
80 | title={Direct3D: Scalable Image-to-3D Generation via 3D Latent Diffusion Transformer},
81 | author={Wu, Shuang and Lin, Youtian and Zhang, Feihu and Zeng, Yifei and Xu, Jingxi and Torr, Philip and Cao, Xun and Yao, Yao},
82 | journal={arXiv preprint arXiv:2405.14832},
83 | year={2024}
84 | }
85 | ```
86 |
87 | ---
88 |
--------------------------------------------------------------------------------
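
A minimal batch-inference sketch building on the README usage above. It is not part of the repository; it assumes the list interface shown in `direct3d/pipeline.py` (`prepare_image` accepts a list of image paths and `__call__` returns one mesh per input) and uses two bundled assets as placeholders:

```python
# Sketch only: batch image-to-3D with the same parameters as the README example.
# Assumes Direct3dPipeline accepts a list of image paths (see direct3d/pipeline.py).
from direct3d.pipeline import Direct3dPipeline

pipeline = Direct3dPipeline.from_pretrained("DreamTechAI/Direct3D")
pipeline.to("cuda")

paths = ["assets/bear.png", "assets/dragon.png"]
meshes = pipeline(
    paths,
    remove_background=False,   # bundled assets already have clean backgrounds
    mc_threshold=-1.0,         # marching-cubes iso-level used by the VAE mesh decoder
    guidance_scale=4.0,
    num_inference_steps=50,
)["meshes"]

for path, mesh in zip(paths, meshes):
    mesh.export(path.split("/")[-1].replace(".png", ".obj"))
```
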
/assets/bear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/bear.png
--------------------------------------------------------------------------------
/assets/book.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/book.png
--------------------------------------------------------------------------------
/assets/car.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/car.png
--------------------------------------------------------------------------------
/assets/castle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/castle.png
--------------------------------------------------------------------------------
/assets/demo/video1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/demo/video1.gif
--------------------------------------------------------------------------------
/assets/demo/video2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/demo/video2.gif
--------------------------------------------------------------------------------
/assets/devil.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/devil.png
--------------------------------------------------------------------------------
/assets/dragon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/dragon.png
--------------------------------------------------------------------------------
/assets/earth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/earth.png
--------------------------------------------------------------------------------
/assets/figure/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/figure/framework.png
--------------------------------------------------------------------------------
/assets/figure/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/figure/teaser.gif
--------------------------------------------------------------------------------
/assets/fish.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/fish.png
--------------------------------------------------------------------------------
/assets/girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/girl.png
--------------------------------------------------------------------------------
/assets/hamburger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/hamburger.png
--------------------------------------------------------------------------------
/assets/man.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/man.png
--------------------------------------------------------------------------------
/assets/panda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/panda.png
--------------------------------------------------------------------------------
/assets/parrot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/parrot.png
--------------------------------------------------------------------------------
/assets/phoenix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/phoenix.png
--------------------------------------------------------------------------------
/assets/riding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/riding.png
--------------------------------------------------------------------------------
/assets/sunglasses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/assets/sunglasses.png
--------------------------------------------------------------------------------
/direct3d/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DreamTechAI/Direct3D/e786532fb4ed4564fe352c67267f1728c92c9da6/direct3d/models/__init__.py
--------------------------------------------------------------------------------
/direct3d/models/condition.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import CLIPModel, AutoModel
3 | from torchvision import transforms as T
4 |
5 | class ClipImageEncoder(nn.Module):
6 |
7 | def __init__(self, version="openai/clip-vit-large-patch14", img_size=224):
8 | super().__init__()
9 |
10 | encoder = CLIPModel.from_pretrained(version)
11 | encoder = encoder.eval()
12 | self.encoder = encoder
13 | self.transform = T.Compose(
14 | [
15 | T.Resize(img_size, antialias=True),
16 | T.Normalize(
17 | mean=[0.48145466, 0.4578275, 0.40821073],
18 | std=[0.26862954, 0.26130258, 0.27577711],
19 | ),
20 | ]
21 | )
22 |
23 | def forward(self, image):
24 | image = self.transform(image)
25 |         embed = self.encoder.vision_model(image).last_hidden_state
26 |         return embed
27 |
28 |
29 | class DinoEncoder(nn.Module):
30 |
31 | def __init__(self, version="facebook/dinov2-large", img_size=224):
32 | super().__init__()
33 |
34 | encoder = AutoModel.from_pretrained(version)
35 | encoder = encoder.eval()
36 | self.encoder = encoder
37 | self.transform = T.Compose(
38 | [
39 | T.Resize(img_size, antialias=True),
40 | T.Normalize(
41 | mean=[0.485, 0.456, 0.406],
42 | std=[0.229, 0.224, 0.225],
43 | ),
44 | ]
45 | )
46 |
47 | def forward(self, image):
48 | image = self.transform(image)
49 |         embed = self.encoder(image).last_hidden_state
50 |         return embed
51 |
52 |
--------------------------------------------------------------------------------
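
A quick sanity-check sketch for the two condition encoders above. The dummy input is my own (the transforms assume a float RGB tensor in [0, 1] resized around 224 px), and the commented shapes are the standard token layouts of CLIP ViT-L/14 and DINOv2-large rather than anything stated in the repo:

```python
# Sketch only: run the semantic (CLIP) and pixel-level (DINOv2) encoders on a dummy image.
import torch
from direct3d.models.condition import ClipImageEncoder, DinoEncoder

clip_encoder = ClipImageEncoder()   # openai/clip-vit-large-patch14
dino_encoder = DinoEncoder()        # facebook/dinov2-large

image = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed RGB image in [0, 1]
with torch.no_grad():
    semantic_tokens = clip_encoder(image)  # ~(1, 257, 1024): CLS + 16x16 patch tokens
    pixel_tokens = dino_encoder(image)     # ~(1, 257, 1024): CLS + 16x16 patch tokens
print(semantic_tokens.shape, pixel_tokens.shape)
```
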
/direct3d/models/dit.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py
2 |
3 | from typing import Optional
4 |
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import nn
10 |
11 | from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection, get_2d_sincos_pos_embed_from_grid
12 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
13 | from diffusers.models.attention import FeedForward
14 |
15 |
16 | class ClassCombinedTimestepSizeEmbeddings(nn.Module):
17 |
18 | def __init__(self, embedding_dim, class_emb_dim):
19 | super().__init__()
20 |
21 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
22 | self.timestep_embedder = TimestepEmbedding(in_channels=256,
23 | time_embed_dim=embedding_dim,
24 | cond_proj_dim=class_emb_dim)
25 |
26 | def forward(self, timestep, hidden_dtype, class_embedding=None):
27 | timesteps_proj = self.time_proj(timestep)
28 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype),
29 | condition=class_embedding) # (N, D)
30 | return timesteps_emb
31 |
32 |
33 | class AdaLayerNormClassEmb(nn.Module):
34 |
35 | def __init__(self, embedding_dim: int, class_emb_dim: int):
36 | super().__init__()
37 |
38 | self.emb = ClassCombinedTimestepSizeEmbeddings(
39 | embedding_dim, class_emb_dim
40 | )
41 | self.silu = nn.SiLU()
42 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
43 |
44 | def forward(
45 | self,
46 | timestep: torch.Tensor,
47 | class_embedding: torch.Tensor = None,
48 | hidden_dtype: Optional[torch.dtype] = None,
49 | ):
50 | embedded_timestep = self.emb(timestep,
51 | class_embedding=class_embedding,
52 | hidden_dtype=hidden_dtype)
53 | return self.linear(self.silu(embedded_timestep)), embedded_timestep
54 |
55 |
56 | def get_2d_sincos_pos_embed(
57 | embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
58 | ):
59 | if isinstance(grid_size, int):
60 | grid_size = (grid_size, grid_size)
61 |
62 | if isinstance(base_size, int):
63 | base_size = (base_size, base_size)
64 |
65 | grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale
66 | grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale
67 | grid = np.meshgrid(grid_w, grid_h)
68 | grid = np.stack(grid, axis=0)
69 |
70 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
71 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
72 | if cls_token and extra_tokens > 0:
73 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
74 | return pos_embed
75 |
76 |
77 | class PatchEmbed(nn.Module):
78 |
79 | def __init__(
80 | self,
81 | height=224,
82 | width=224,
83 | patch_size=16,
84 | in_channels=3,
85 | embed_dim=768,
86 | layer_norm=False,
87 | flatten=True,
88 | bias=True,
89 | interpolation_scale=1,
90 | ):
91 | super().__init__()
92 |
93 | self.flatten = flatten
94 | self.layer_norm = layer_norm
95 |
96 | self.proj = nn.Conv2d(
97 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
98 | )
99 | if layer_norm:
100 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
101 | else:
102 | self.norm = None
103 |
104 | self.patch_size = patch_size
105 | # See:
106 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
107 | self.height, self.width = height // patch_size, width // patch_size
108 | self.base_size = height // patch_size
109 | self.interpolation_scale = interpolation_scale
110 | pos_embed = get_2d_sincos_pos_embed(
111 | embed_dim, (self.height, self.width), base_size=(self.height, self.width), interpolation_scale=self.interpolation_scale
112 | )
113 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
114 |
115 | def forward(self, latent):
116 | height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
117 |
118 | latent = self.proj(latent)
119 | if self.flatten:
120 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
121 | if self.layer_norm:
122 | latent = self.norm(latent)
123 |
124 | # Interpolate positional embeddings if needed.
125 | # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
126 | if self.height != height or self.width != width:
127 | pos_embed = get_2d_sincos_pos_embed(
128 | embed_dim=self.pos_embed.shape[-1],
129 | grid_size=(height, width),
130 | base_size=(height, width),
131 | interpolation_scale=self.interpolation_scale,
132 | )
133 | pos_embed = torch.from_numpy(pos_embed)
134 | pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
135 | else:
136 | pos_embed = self.pos_embed
137 |
138 | return (latent + pos_embed).to(latent.dtype)
139 |
140 |
141 | class Attention(nn.Module):
142 |
143 | def __init__(
144 | self,
145 | heads: int = 8,
146 | dim_head: int = 64,
147 | dropout: float = 0.0,
148 | bias: bool = False,
149 | out_bias: bool = True,
150 | ):
151 | super().__init__()
152 |
153 | self.inner_dim = dim_head * heads
154 | self.use_bias = bias
155 | self.dropout = dropout
156 | self.heads = heads
157 |
158 | self.to_q = nn.Linear(self.inner_dim, self.inner_dim, bias=bias)
159 | self.to_k = nn.Linear(self.inner_dim, self.inner_dim, bias=bias)
160 | self.to_v = nn.Linear(self.inner_dim, self.inner_dim, bias=bias)
161 |
162 | self.to_out = nn.ModuleList([
163 | nn.Linear(self.inner_dim, self.inner_dim, bias=out_bias),
164 | nn.Dropout(dropout)
165 | ])
166 |
167 | def forward(
168 | self,
169 | hidden_states: torch.Tensor,
170 | encoder_hidden_states: Optional[torch.Tensor] = None,
171 | ):
172 |
173 | input_ndim = hidden_states.ndim
174 |
175 | if input_ndim == 4:
176 | batch_size, channel, height, width = hidden_states.shape
177 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
178 |
179 | batch_size = hidden_states.shape[0] if encoder_hidden_states is None else encoder_hidden_states.shape[0]
180 |
181 | query = self.to_q(hidden_states)
182 |
183 | if encoder_hidden_states is None:
184 | encoder_hidden_states = hidden_states
185 |
186 | key = self.to_k(encoder_hidden_states)
187 | value = self.to_v(encoder_hidden_states)
188 |
189 | inner_dim = key.shape[-1]
190 | head_dim = inner_dim // self.heads
191 |
192 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
193 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
194 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
195 |
196 | hidden_states = F.scaled_dot_product_attention(
197 | query, key, value, dropout_p=0.0, is_causal=False
198 | )
199 |
200 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
201 | hidden_states = hidden_states.to(query.dtype)
202 |
203 | hidden_states = self.to_out[0](hidden_states)
204 | hidden_states = self.to_out[1](hidden_states)
205 |
206 | if input_ndim == 4:
207 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
208 |
209 | return hidden_states
210 |
211 |
212 | class DiTBlock(nn.Module):
213 |
214 | def __init__(
215 | self,
216 | dim: int,
217 | num_attention_heads: int,
218 | attention_head_dim: int,
219 | dropout=0.0,
220 | activation_fn: str = "geglu",
221 | attention_bias: bool = False,
222 | norm_elementwise_affine: bool = True,
223 | norm_eps: float = 1e-5,
224 | final_dropout: bool = False,
225 | ff_inner_dim: Optional[int] = None,
226 | ff_bias: bool = True,
227 | attention_out_bias: bool = True,
228 | ):
229 | super().__init__()
230 |
231 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
232 |
233 | self.attn1 = Attention(
234 | heads=num_attention_heads,
235 | dim_head=attention_head_dim,
236 | dropout=dropout,
237 | bias=attention_bias,
238 | out_bias=attention_out_bias,
239 | )
240 |
241 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
242 |
243 | self.attn2 = Attention(
244 | heads=num_attention_heads,
245 | dim_head=attention_head_dim,
246 | dropout=dropout,
247 | bias=attention_bias,
248 | out_bias=attention_out_bias,
249 | )
250 |
251 | self.ff = FeedForward(
252 | dim,
253 | dropout=dropout,
254 | activation_fn=activation_fn,
255 | final_dropout=final_dropout,
256 | inner_dim=ff_inner_dim,
257 | bias=ff_bias,
258 | )
259 |
260 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
261 |
262 | def forward(
263 | self,
264 | hidden_states: torch.FloatTensor,
265 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
266 | timestep: Optional[torch.LongTensor] = None,
267 | pixel_hidden_states: Optional[torch.FloatTensor] = None,
268 | ):
269 | batch_size = hidden_states.shape[0]
270 |
271 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
272 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
273 | ).chunk(6, dim=1)
274 | norm_hidden_states = self.norm1(hidden_states)
275 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
276 | norm_hidden_states = norm_hidden_states.squeeze(1)
277 | hidden_states_len = norm_hidden_states.shape[1]
278 | attn_output = self.attn1(
279 | torch.cat([pixel_hidden_states, norm_hidden_states], dim=1),
280 | )[:, -hidden_states_len:]
281 | attn_output = gate_msa * attn_output
282 |
283 | hidden_states = attn_output + hidden_states
284 | if hidden_states.ndim == 4:
285 | hidden_states = hidden_states.squeeze(1)
286 |
287 | norm_hidden_states = hidden_states
288 |
289 | attn_output = self.attn2(
290 | norm_hidden_states,
291 | encoder_hidden_states=encoder_hidden_states,
292 | )
293 | hidden_states = attn_output + hidden_states
294 |
295 | norm_hidden_states = self.norm2(hidden_states)
296 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
297 |
298 | ff_output = self.ff(norm_hidden_states)
299 |
300 | ff_output = gate_mlp * ff_output
301 |
302 | hidden_states = ff_output + hidden_states
303 | if hidden_states.ndim == 4:
304 | hidden_states = hidden_states.squeeze(1)
305 |
306 | return hidden_states
307 |
308 |
309 | class D3D_DiT(nn.Module):
310 |
311 | def __init__(
312 | self,
313 | num_attention_heads: int = 16,
314 | attention_head_dim: int = 72,
315 | in_channels: Optional[int] = None,
316 | out_channels: Optional[int] = None,
317 | num_layers: int = 1,
318 | dropout: float = 0.0,
319 | attention_bias: bool = False,
320 | sample_size: Optional[int] = None,
321 | patch_size: Optional[int] = None,
322 | activation_fn: str = "gelu-approximate",
323 | norm_elementwise_affine: bool = False,
324 | norm_eps: float = 1e-6,
325 | semantic_channels: int = None,
326 | pixel_channels: int = None,
327 | interpolation_scale: float = 1.0,
328 | gradient_checkpointing: bool = False,
329 | ):
330 | super().__init__()
331 |
332 | self.num_attention_heads = num_attention_heads
333 | self.attention_head_dim = attention_head_dim
334 | inner_dim = num_attention_heads * attention_head_dim
335 |
336 | if isinstance(sample_size, int):
337 | sample_size = (sample_size, sample_size)
338 | self.height = sample_size[0]
339 | self.width = sample_size[1]
340 |
341 | self.patch_size = patch_size
342 |         interpolation_scale = (
343 |             interpolation_scale if interpolation_scale is not None else max(min(sample_size) // 32, 1)
344 |         )
345 | self.pos_embed = PatchEmbed(
346 | height=sample_size[0],
347 | width=sample_size[1],
348 | patch_size=patch_size,
349 | in_channels=in_channels,
350 | embed_dim=inner_dim,
351 | interpolation_scale=interpolation_scale,
352 | )
353 |
354 | self.transformer_blocks = nn.ModuleList(
355 | [
356 | DiTBlock(
357 | inner_dim,
358 | num_attention_heads,
359 | attention_head_dim,
360 | dropout=dropout,
361 | activation_fn=activation_fn,
362 | attention_bias=attention_bias,
363 | norm_elementwise_affine=norm_elementwise_affine,
364 | norm_eps=norm_eps,
365 | )
366 | for d in range(num_layers)
367 | ]
368 | )
369 |
370 | self.out_channels = in_channels if out_channels is None else out_channels
371 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
372 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
373 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
374 |
375 | self.adaln_single = AdaLayerNormClassEmb(inner_dim, semantic_channels)
376 |
377 | self.semantic_projection = PixArtAlphaTextProjection(in_features=semantic_channels, hidden_size=inner_dim)
378 | self.pixel_projection = PixArtAlphaTextProjection(in_features=pixel_channels, hidden_size=inner_dim)
379 |
380 | self.gradient_checkpointing = gradient_checkpointing
381 |
382 | def forward(
383 | self,
384 | hidden_states: torch.Tensor,
385 | encoder_hidden_states: Optional[torch.Tensor] = None,
386 | timestep: Optional[torch.LongTensor] = None,
387 | pixel_hidden_states: Optional[torch.Tensor] = None,
388 | ):
389 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
390 | hidden_states = self.pos_embed(hidden_states)
391 |
392 | timestep, embedded_timestep = self.adaln_single(
393 | timestep, class_embedding=encoder_hidden_states[:, 0], hidden_dtype=hidden_states.dtype
394 | )
395 |
396 | batch_size = hidden_states.shape[0]
397 | encoder_hidden_states = self.semantic_projection(encoder_hidden_states)
398 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
399 |
400 | pixel_hidden_states = self.pixel_projection(pixel_hidden_states)
401 | pixel_hidden_states = pixel_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
402 |
403 | for block in self.transformer_blocks:
404 | if self.training and self.gradient_checkpointing:
405 | hidden_states = torch.utils.checkpoint.checkpoint(
406 | block,
407 | hidden_states,
408 | encoder_hidden_states,
409 | timestep,
410 | pixel_hidden_states,
411 | )
412 | else:
413 | hidden_states = block(
414 | hidden_states,
415 | encoder_hidden_states=encoder_hidden_states,
416 | timestep=timestep,
417 | pixel_hidden_states=pixel_hidden_states,
418 | )
419 |
420 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
421 | hidden_states = self.norm_out(hidden_states)
422 | hidden_states = hidden_states * (1 + scale) + shift
423 | hidden_states = self.proj_out(hidden_states)
424 | hidden_states = hidden_states.squeeze(1)
425 |
426 | hidden_states = hidden_states.reshape(
427 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
428 | )
429 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
430 | output = hidden_states.reshape(
431 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
432 | )
433 |
434 | return output
435 |
436 |
--------------------------------------------------------------------------------
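
A small smoke-test sketch for `D3D_DiT` above. Every size here is a toy placeholder of mine (the released channel counts come from the `config.yaml` on Hugging Face); it only illustrates that the model maps a rolled-out triplane latent of shape `(B, C, H, 3H)` back to the same shape:

```python
# Sketch only: forward pass of D3D_DiT with placeholder sizes (not the released config).
import torch
from direct3d.models.dit import D3D_DiT

dit = D3D_DiT(
    num_attention_heads=4,
    attention_head_dim=32,     # inner_dim = 4 * 32 = 128
    in_channels=8,             # channels of the rolled-out triplane latent
    out_channels=8,
    num_layers=2,
    sample_size=(32, 96),      # triplane rolled out as H x 3H
    patch_size=2,
    semantic_channels=1024,    # assumed width of the CLIP condition tokens
    pixel_channels=1024,       # assumed width of the DINOv2 condition tokens
)

latents = torch.randn(2, 8, 32, 96)     # noisy triplane latents
semantic = torch.randn(2, 257, 1024)    # semantic condition tokens
pixel = torch.randn(2, 257, 1024)       # pixel-level condition tokens
timestep = torch.randint(0, 1000, (2,))

with torch.no_grad():
    out = dit(latents, encoder_hidden_states=semantic,
              timestep=timestep, pixel_hidden_states=pixel)
print(out.shape)  # torch.Size([2, 8, 32, 96])
```
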
/direct3d/models/vae.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/vae.py
2 |
3 | import trimesh
4 | import itertools
5 | import numpy as np
6 | from tqdm import tqdm
7 | from einops import rearrange, repeat
8 | from skimage import measure
9 | from typing import List, Tuple, Optional, Union
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from direct3d.utils.triplane import sample_from_planes, generate_planes
16 | from diffusers.models.autoencoders.vae import UNetMidBlock2D
17 | from diffusers.models.unets.unet_2d_blocks import UpDecoderBlock2D
18 |
19 |
20 | class DiagonalGaussianDistribution(object):
21 | def __init__(self, parameters: torch.Tensor, deterministic: bool=False):
22 | self.parameters = parameters
23 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
24 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
25 | self.deterministic = deterministic
26 | self.std = torch.exp(0.5 * self.logvar)
27 | self.var = torch.exp(self.logvar)
28 | if self.deterministic:
29 | self.var = self.std = torch.zeros_like(
30 | self.mean, device=self.parameters.device, dtype=self.parameters.dtype
31 | )
32 |
33 | def sample(self):
34 | x = self.mean + self.std * torch.randn_like(self.mean)
35 | return x
36 |
37 | def kl(self, other=None):
38 | if self.deterministic:
39 | return torch.Tensor([0.])
40 | else:
41 | if other is None:
42 | return 0.5 * torch.mean(torch.pow(self.mean, 2)
43 | + self.var - 1.0 - self.logvar,
44 | dim=[1, 2, 3])
45 | else:
46 | return 0.5 * torch.mean(
47 | torch.pow(self.mean - other.mean, 2) / other.var
48 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
49 | dim=[1, 2, 3])
50 |
51 | def nll(self, sample, dims=(1, 2, 3)):
52 | if self.deterministic:
53 | return torch.Tensor([0.])
54 | logtwopi = np.log(2.0 * np.pi)
55 | return 0.5 * torch.sum(
56 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
57 | dim=dims)
58 |
59 | def mode(self):
60 | return self.mean
61 |
62 |
63 | class FourierEmbedder(nn.Module):
64 |
65 | def __init__(self,
66 | num_freqs: int = 6,
67 | input_dim: int = 3):
68 |
69 | super().__init__()
70 | freq = 2.0 ** torch.arange(num_freqs)
71 | self.register_buffer("freq", freq, persistent=False)
72 | self.num_freqs = num_freqs
73 | self.out_dim = input_dim * (num_freqs * 2 + 1)
74 |
75 | def forward(self, x: torch.Tensor):
76 | embed = (x[..., None].contiguous() * self.freq).view(*x.shape[:-1], -1)
77 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
78 |
79 |
80 | class OccDecoder(nn.Module):
81 |
82 | def __init__(self,
83 | n_features: int,
84 | hidden_dim: int = 64,
85 | num_layers: int = 4,
86 | activation: nn.Module = nn.ReLU,
87 | final_activation: str = None):
88 | super().__init__()
89 |
90 | self.net = nn.Sequential(
91 | nn.Linear(3 * n_features, hidden_dim),
92 | activation(),
93 | *itertools.chain(*[[
94 | nn.Linear(hidden_dim, hidden_dim),
95 | activation(),
96 | ] for _ in range(num_layers - 2)]),
97 | nn.Linear(hidden_dim, 1),
98 | )
99 | self.final_activation = final_activation
100 |
101 | def forward(self, sampled_features):
102 |
103 | x = rearrange(sampled_features, "N_b N_t N_s C -> N_b N_s (N_t C)")
104 | x = self.net(x)
105 |
106 | if self.final_activation is None:
107 | pass
108 | elif self.final_activation == 'tanh':
109 | x = torch.tanh(x)
110 | elif self.final_activation == 'sigmoid':
111 | x = torch.sigmoid(x)
112 | else:
113 | raise ValueError(f"Unknown final activation: {self.final_activation}")
114 |
115 | return x[..., 0]
116 |
117 |
118 | class Attention(nn.Module):
119 |
120 | def __init__(self,
121 | dim: int,
122 | heads: int = 8,
123 | dim_head: int = 64):
124 | super().__init__()
125 | inner_dim = dim_head * heads
126 | self.heads = heads
127 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
128 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
129 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
130 | self.to_out = nn.Linear(inner_dim, inner_dim)
131 |
132 | def forward(self,
133 | hidden_states: torch.Tensor,
134 | encoder_hidden_states: Optional[torch.Tensor] = None,
135 | ):
136 | batch_size = hidden_states.shape[0]
137 | if encoder_hidden_states is None:
138 | encoder_hidden_states = hidden_states
139 | key = self.to_k(encoder_hidden_states)
140 | value = self.to_v(encoder_hidden_states)
141 | query = self.to_q(hidden_states)
142 |
143 | inner_dim = key.shape[-1]
144 | head_dim = inner_dim // self.heads
145 |
146 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
147 |
148 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
149 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
150 |
151 | hidden_states = F.scaled_dot_product_attention(
152 | query, key, value
153 | )
154 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
155 | hidden_states = self.to_out(hidden_states)
156 |
157 | return hidden_states
158 |
159 |
160 | class TransformerBlock(nn.Module):
161 |
162 | def __init__(self,
163 | num_attention_heads: int,
164 | attention_head_dim: int,
165 | cross_attention: bool = False):
166 | super().__init__()
167 | inner_dim = attention_head_dim * num_attention_heads
168 | self.norm1 = nn.LayerNorm(inner_dim)
169 | if cross_attention:
170 | self.norm1_c = nn.LayerNorm(inner_dim)
171 | else:
172 | self.norm1_c = None
173 | self.attn = Attention(inner_dim, num_attention_heads, attention_head_dim)
174 | self.norm2 = nn.LayerNorm(inner_dim)
175 | self.mlp = nn.Sequential(
176 | nn.Linear(inner_dim, 4 * inner_dim),
177 | nn.GELU(),
178 | nn.Linear(4 * inner_dim, inner_dim),
179 | )
180 |
181 | def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
182 | if self.norm1_c is not None:
183 | x = self.attn(self.norm1(x), self.norm1_c(y)) + x
184 | else:
185 | x = self.attn(self.norm1(x)) + x
186 | x = x + self.mlp(self.norm2(x))
187 | return x
188 |
189 |
190 | class PointEncoder(nn.Module):
191 |
192 | def __init__(self,
193 | num_latents: int,
194 | in_channels: int,
195 | num_attention_heads: int,
196 | attention_head_dim: int,
197 | num_layers: int,
198 | gradient_checkpointing: bool = False):
199 |
200 | super().__init__()
201 |
202 | self.gradient_checkpointing = gradient_checkpointing
203 | self.num_latents = num_latents
204 | inner_dim = attention_head_dim * num_attention_heads
205 |
206 | self.learnable_token = nn.Parameter(torch.randn((num_latents, inner_dim)) * 0.01)
207 |
208 | self.proj_in = nn.Linear(in_channels, inner_dim)
209 | self.cross_attn = TransformerBlock(num_attention_heads, attention_head_dim, cross_attention=True)
210 |
211 | self.self_attn = nn.ModuleList([
212 | TransformerBlock(num_attention_heads, attention_head_dim) for _ in range(num_layers)
213 | ])
214 |
215 | self.norm_out = nn.LayerNorm(inner_dim)
216 |
217 | def forward(self, pc):
218 |
219 | bs = pc.shape[0]
220 | pc = self.proj_in(pc)
221 |
222 | learnable_token = repeat(self.learnable_token, "m c -> b m c", b=bs)
223 |
224 | if self.training and self.gradient_checkpointing:
225 | latents = torch.utils.checkpoint.checkpoint(self.cross_attn, learnable_token, pc)
226 | for block in self.self_attn:
227 | latents = torch.utils.checkpoint.checkpoint(block, latents)
228 | else:
229 | latents = self.cross_attn(learnable_token, pc)
230 | for block in self.self_attn:
231 | latents = block(latents)
232 |
233 | latents = self.norm_out(latents)
234 |
235 | return latents
236 |
237 |
238 | class TriplaneDecoder(nn.Module):
239 |
240 | def __init__(
241 | self,
242 | in_channels: int = 3,
243 | out_channels: int = 3,
244 | block_out_channels: Tuple[int, ...] = (64,),
245 | layers_per_block: int = 2,
246 | norm_num_groups: int = 32,
247 | act_fn: str = "silu",
248 | norm_type: str = "group",
249 | mid_block_add_attention=True,
250 | gradient_checkpointing: bool = False,
251 | ):
252 | super().__init__()
253 | self.layers_per_block = layers_per_block
254 |
255 | self.conv_in = nn.Conv2d(
256 | in_channels,
257 | block_out_channels[-1],
258 | kernel_size=3,
259 | stride=1,
260 | padding=1,
261 | )
262 |
263 | self.up_blocks = nn.ModuleList([])
264 |
265 | temb_channels = in_channels if norm_type == "spatial" else None
266 |
267 | # mid
268 | self.mid_block = UNetMidBlock2D(
269 | in_channels=block_out_channels[-1],
270 | resnet_eps=1e-6,
271 | resnet_act_fn=act_fn,
272 | output_scale_factor=1,
273 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
274 | attention_head_dim=block_out_channels[-1],
275 | resnet_groups=norm_num_groups,
276 | temb_channels=temb_channels,
277 | add_attention=mid_block_add_attention,
278 | )
279 |
280 | # up
281 | reversed_block_out_channels = list(reversed(block_out_channels))
282 | output_channel = reversed_block_out_channels[0]
283 | for i in range(len(block_out_channels)):
284 | prev_output_channel = output_channel
285 | output_channel = reversed_block_out_channels[i]
286 |
287 | is_final_block = i == len(block_out_channels) - 1
288 | up_block = UpDecoderBlock2D(
289 | num_layers=self.layers_per_block + 1,
290 | in_channels=prev_output_channel,
291 | out_channels=output_channel,
292 | add_upsample=not is_final_block,
293 | resnet_eps=1e-6,
294 | resnet_act_fn=act_fn,
295 | resnet_groups=norm_num_groups,
296 | resnet_time_scale_shift=norm_type,
297 | temb_channels=temb_channels,
298 | )
299 | self.up_blocks.append(up_block)
300 | prev_output_channel = output_channel
301 |
302 | # out
303 | if norm_type == "group":
304 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
305 | else:
306 | raise ValueError(f"Unsupported norm type: {norm_type}")
307 |
308 | self.conv_act = nn.SiLU()
309 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
310 |
311 | self.gradient_checkpointing = gradient_checkpointing
312 |
313 | def forward(self, sample: torch.Tensor):
314 | r"""The forward method of the `Decoder` class."""
315 |
316 | sample = self.conv_in(sample)
317 |
318 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
319 | if self.training and self.gradient_checkpointing:
320 | # middle
321 | sample = torch.utils.checkpoint.checkpoint(
322 | self.mid_block, sample
323 | )
324 | sample = sample.to(upscale_dtype)
325 |
326 | # up
327 | for up_block in self.up_blocks:
328 | sample = torch.utils.checkpoint.checkpoint(up_block, sample)
329 | else:
330 | # middle
331 | sample = self.mid_block(sample)
332 | sample = sample.to(upscale_dtype)
333 |
334 | # up
335 | for up_block in self.up_blocks:
336 | sample = up_block(sample)
337 |
338 | # post-process
339 | sample = self.conv_norm_out(sample)
340 | sample = self.conv_act(sample)
341 | sample = self.conv_out(sample)
342 |
343 | return sample
344 |
345 |
346 | class D3D_VAE(nn.Module):
347 | def __init__(self,
348 | triplane_res: int,
349 | latent_dim: int = 0,
350 | triplane_dim: int = 32,
351 | num_freqs: int = 8,
352 | num_attention_heads: int = 12,
353 | attention_head_dim: int = 64,
354 | num_encoder_layers: int = 8,
355 | num_geodecoder_layers: int = 5,
356 | final_activation: str = None,
357 | block_out_channels=[128, 256, 512, 512],
358 | mid_block_add_attention=True,
359 | gradient_checkpointing: bool = False,
360 | latents_scale: float = 1.0,
361 | latents_shift: float = 0.0):
362 |
363 | super().__init__()
364 |
365 | self.gradient_checkpointing = gradient_checkpointing
366 |
367 | self.triplane_res = triplane_res
368 | self.num_latents = triplane_res ** 2 * 3
369 | self.latent_shape = (latent_dim, triplane_res, 3 * triplane_res)
370 | self.latents_scale = latents_scale
371 | self.latents_shift = latents_shift
372 |
373 | self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs)
374 |
375 | inner_dim = attention_head_dim * num_attention_heads
376 | self.encoder = PointEncoder(
377 | num_latents=self.num_latents,
378 | in_channels=self.fourier_embedder.out_dim + 3,
379 | num_attention_heads=num_attention_heads,
380 | attention_head_dim=attention_head_dim,
381 | num_layers=num_encoder_layers,
382 | gradient_checkpointing=gradient_checkpointing,
383 | )
384 |
385 | self.latent_dim = latent_dim
386 |
387 | self.pre_latent = nn.Conv2d(inner_dim, 2 * latent_dim, 1)
388 | self.post_latent = nn.Conv2d(latent_dim, inner_dim, 1)
389 |
390 | self.decoder = TriplaneDecoder(
391 | in_channels=inner_dim,
392 | out_channels=triplane_dim,
393 | block_out_channels=block_out_channels,
394 | mid_block_add_attention=mid_block_add_attention,
395 | gradient_checkpointing=gradient_checkpointing,
396 | )
397 |
398 | self.plane_axes = generate_planes()
399 | self.occ_decoder = OccDecoder(
400 | n_features=triplane_dim,
401 | num_layers=num_geodecoder_layers,
402 | final_activation=final_activation,
403 | )
404 |
405 | def rollout(self, triplane):
406 | triplane = rearrange(triplane, "N_b (N_t C) H_t W_t -> N_b C H_t (N_t W_t)", N_t=3)
407 | return triplane
408 |
409 | def unrollout(self, triplane):
410 | triplane = rearrange(triplane, "N_b C H_t (N_t W_t) -> N_b N_t C H_t W_t", N_t=3)
411 | return triplane
412 |
413 | def encode(self,
414 | pc: torch.FloatTensor,
415 | feats: Optional[torch.FloatTensor] = None):
416 |
417 | x = self.fourier_embedder(pc)
418 | if feats is not None:
419 | x = torch.cat((x, feats), dim=-1)
420 | x = self.encoder(x)
421 | x = rearrange(x, "N_b (N_t H_t W_t) C -> N_b (N_t C) H_t W_t",
422 | N_t=3, H_t=self.triplane_res, W_t=self.triplane_res)
423 | x = self.rollout(x)
424 | moments = self.pre_latent(x)
425 |
426 | posterior = DiagonalGaussianDistribution(moments)
427 | latents = posterior.sample()
428 |
429 | return latents, posterior
430 |
431 | def decode(self, z, unrollout=False):
432 | z = self.post_latent(z)
433 | dec = self.decoder(z)
434 | if unrollout:
435 | dec = self.unrollout(dec)
436 | return dec
437 |
438 | def decode_mesh(self,
439 | latents,
440 | bounds: Union[Tuple[float], List[float], float] = 1.0,
441 | voxel_resolution: int = 512,
442 | mc_threshold: float = 0.0):
443 | triplane = self.decode(latents, unrollout=True)
444 | mesh = self.triplane2mesh(triplane,
445 | bounds=bounds,
446 | voxel_resolution=voxel_resolution,
447 | mc_threshold=mc_threshold)
448 | return mesh
449 |
450 | def triplane2mesh(self,
451 | latents: torch.FloatTensor,
452 | bounds: Union[Tuple[float], List[float], float] = 1.0,
453 | voxel_resolution: int = 512,
454 | mc_threshold: float = 0.0,
455 | chunk_size: int = 50000):
456 |
457 | batch_size = len(latents)
458 |
459 | if isinstance(bounds, float):
460 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
461 | bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
462 | bbox_length = bbox_max - bbox_min
463 |
464 | x = torch.linspace(bbox_min[0], bbox_max[0], steps=int(voxel_resolution) + 1)
465 | y = torch.linspace(bbox_min[1], bbox_max[1], steps=int(voxel_resolution) + 1)
466 | z = torch.linspace(bbox_min[2], bbox_max[2], steps=int(voxel_resolution) + 1)
467 | xs, ys, zs = torch.meshgrid(x, y, z, indexing='ij')
468 | xyz = torch.stack((xs, ys, zs), dim=-1)
469 | xyz = xyz.reshape(-1, 3)
470 | grid_size = [int(voxel_resolution) + 1, int(voxel_resolution) + 1, int(voxel_resolution) + 1]
471 |
472 | logits_total = []
473 | for start in tqdm(range(0, xyz.shape[0], chunk_size), desc="Triplane Sampling:"):
474 | positions = xyz[start:start + chunk_size].to(latents.device)
475 | positions = repeat(positions, "p d -> b p d", b=batch_size)
476 |
477 | triplane_features = sample_from_planes(self.plane_axes.to(latents.device),
478 | latents, positions,
479 | box_warp=2.0)
480 | logits = self.occ_decoder(triplane_features)
481 | logits_total.append(logits)
482 |
483 | logits_total = torch.cat(logits_total, dim=1).view(
484 | (batch_size, grid_size[0], grid_size[1], grid_size[2])).cpu().numpy()
485 |
486 | meshes = []
487 | for i in range(batch_size):
488 | vertices, faces, _, _ = measure.marching_cubes(
489 | logits_total[i],
490 | mc_threshold,
491 | method="lewiner"
492 | )
493 | vertices = vertices / grid_size * bbox_length + bbox_min
494 | faces = faces[:, ::-1]
495 | meshes.append(trimesh.Trimesh(vertices, faces))
496 | return meshes
497 |
498 |
--------------------------------------------------------------------------------
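
A shape-only sketch for `D3D_VAE` above, again with toy sizes of my own (the real `triplane_res` and `latent_dim` come from the released `config.yaml`). It shows the rolled-out latent layout `(latent_dim, triplane_res, 3 * triplane_res)` and the unrolled triplane produced by `decode`:

```python
# Sketch only: decode a random triplane latent with placeholder sizes (not the release config).
import torch
from direct3d.models.vae import D3D_VAE

vae = D3D_VAE(triplane_res=32, latent_dim=16)  # toy sizes; the shipped checkpoint differs
print(vae.latent_shape)                        # (16, 32, 96): channels, H, 3 * H (rolled out)

z = torch.randn(1, *vae.latent_shape)
with torch.no_grad():
    triplane = vae.decode(z, unrollout=True)   # (1, 3, 32, 256, 256): three upsampled planes
print(triplane.shape)
```
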
/direct3d/pipeline.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2 |
3 | import os
4 | from tqdm import tqdm
5 | from PIL import Image
6 | from omegaconf import OmegaConf
7 | from huggingface_hub import hf_hub_download
8 | from typing import Union, List, Optional
9 |
10 | import torch
11 | from direct3d.utils import instantiate_from_config, preprocess
12 | from diffusers.utils.torch_utils import randn_tensor
13 |
14 |
15 | class Direct3dPipeline(object):
16 |
17 | def __init__(self,
18 | vae,
19 | dit,
20 | semantic_encoder,
21 | pixel_encoder,
22 | scheduler):
23 | self.vae = vae
24 | self.dit = dit
25 | self.semantic_encoder = semantic_encoder
26 | self.pixel_encoder = pixel_encoder
27 | self.scheduler = scheduler
28 |
29 | def to(self, device):
30 | self.device = torch.device(device)
31 | self.vae.to(device)
32 | self.dit.to(device)
33 | self.semantic_encoder.to(device)
34 | self.pixel_encoder.to(device)
35 |
36 | @classmethod
37 | def from_pretrained(cls,
38 | pipeline_path):
39 |
40 | if os.path.isdir(pipeline_path):
41 | config_path = os.path.join(pipeline_path, 'config.yaml')
42 | model_path = os.path.join(pipeline_path, 'model.ckpt')
43 | else:
44 | config_path = hf_hub_download(repo_id=pipeline_path, filename="config.yaml", repo_type="model")
45 | model_path = hf_hub_download(repo_id=pipeline_path, filename="model.ckpt", repo_type="model")
46 |
47 | cfg = OmegaConf.load(config_path)
48 | state_dict = torch.load(model_path, map_location='cpu')
49 |
50 | vae = instantiate_from_config(cfg.vae)
51 | vae.load_state_dict(state_dict["vae"], strict=True)
52 | dit = instantiate_from_config(cfg.dit)
53 | dit.load_state_dict(state_dict["dit"], strict=True)
54 |
55 | semantic_encoder = instantiate_from_config(cfg.semantic_encoder)
56 | pixel_encoder = instantiate_from_config(cfg.pixel_encoder)
57 |
58 | scheduler = instantiate_from_config(cfg.scheduler)
59 |
60 | return cls(
61 | vae=vae,
62 | dit=dit,
63 | semantic_encoder=semantic_encoder,
64 | pixel_encoder=pixel_encoder,
65 | scheduler=scheduler)
66 |
67 | def prepare_image(self, image: Union[str, List[str], Image.Image, List[Image.Image]], rmbg: bool = True):
68 | if not isinstance(image, list):
69 | image = [image]
70 | if isinstance(image[0], str):
71 | image = [Image.open(img) for img in image]
72 | image = [preprocess(img, rmbg=rmbg) for img in image]
73 | image = torch.stack([img for img in image]).to(self.device)
74 | return image
75 |
76 | def encode_image(self, image: torch.Tensor, do_classifier_free_guidance: bool = True):
77 | semantic_cond = self.semantic_encoder(image)
78 | pixel_cond = self.pixel_encoder(image)
79 | if do_classifier_free_guidance:
80 | semantic_uncond = torch.zeros_like(semantic_cond)
81 | pixel_uncond = torch.zeros_like(pixel_cond)
82 | semantic_cond = torch.cat([semantic_uncond, semantic_cond], dim=0)
83 | pixel_cond = torch.cat([pixel_uncond, pixel_cond], dim=0)
84 |
85 | return semantic_cond, pixel_cond
86 |
87 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator):
88 | shape = (
89 | batch_size,
90 | num_channels_latents,
91 | height,
92 | width,
93 | )
94 | if isinstance(generator, list) and len(generator) != batch_size:
95 | raise ValueError(
96 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
97 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
98 | )
99 |
100 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
101 | return latents
102 |
103 | @torch.no_grad()
104 | def __call__(
105 | self,
106 | image: Union[str, List[str], Image.Image, List[Image.Image]] = None,
107 | num_inference_steps: int = 50,
108 | guidance_scale: float = 4.0,
109 | generator: Optional[torch.Generator] = None,
110 | mc_threshold: float = -2.0,
111 | remove_background: bool = True,):
112 |
113 | batch_size = len(image) if isinstance(image, list) else 1
114 | do_classifier_free_guidance = guidance_scale > 0
115 |
116 | self.scheduler.set_timesteps(num_inference_steps, device=self.device)
117 | timesteps = self.scheduler.timesteps
118 |
119 | image = self.prepare_image(image, remove_background)
120 | semantic_cond, pixel_cond = self.encode_image(image, do_classifier_free_guidance)
121 | latents = self.prepare_latents(
122 | batch_size=batch_size,
123 | num_channels_latents=self.vae.latent_shape[0],
124 | height=self.vae.latent_shape[1],
125 | width=self.vae.latent_shape[2],
126 | dtype=image.dtype,
127 | device=self.device,
128 | generator=generator,
129 | )
130 |
131 | extra_step_kwargs = {
132 | "generator": generator
133 | }
134 |
135 | for i, t in enumerate(tqdm(timesteps, desc="Diffusion Sampling:")):
136 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
137 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
138 |
139 | t = t.expand(latent_model_input.shape[0])
140 |
141 | noise_pred = self.dit(
142 | hidden_states=latent_model_input,
143 | timestep=t,
144 | encoder_hidden_states=semantic_cond,
145 | pixel_hidden_states=pixel_cond,
146 | )
147 |
148 | if do_classifier_free_guidance:
149 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
150 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
151 |
152 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
153 |
154 | latents = 1. / self.vae.latents_scale * latents + self.vae.latents_shift
155 | meshes = self.vae.decode_mesh(latents, mc_threshold=mc_threshold)
156 | outputs = {"meshes": meshes, "latents": latents}
157 |
158 | return outputs
159 |
--------------------------------------------------------------------------------
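Taken together, Direct3dPipeline exposes a compact inference API: from_pretrained accepts either a local directory containing config.yaml and model.ckpt or a Hugging Face repo id hosting those two files, and calling the pipeline on an image returns the extracted meshes plus the raw latents. A minimal usage sketch, assuming a CUDA device and treating the checkpoint id and output path below as placeholders:

    import torch
    from direct3d.pipeline import Direct3dPipeline

    # "path/to/checkpoint" stands in for a local folder with config.yaml + model.ckpt,
    # or a Hugging Face repo id that hosts those two files.
    pipeline = Direct3dPipeline.from_pretrained("path/to/checkpoint")
    pipeline.to("cuda")

    outputs = pipeline(
        image="assets/bear.png",     # path, PIL.Image, or a list of either
        num_inference_steps=50,
        guidance_scale=4.0,
        generator=torch.Generator(device="cuda").manual_seed(0),
        remove_background=True,      # run rembg before cropping/padding
    )
    outputs["meshes"][0].export("output.obj")   # each entry is a trimesh.Trimesh
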
/direct3d/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import instantiate_from_config, get_obj_from_str
2 | from .image import preprocess
--------------------------------------------------------------------------------
/direct3d/utils/image.py:
--------------------------------------------------------------------------------
1 | import rembg
2 | import numpy as np
3 | from PIL import Image
4 | from torchvision import transforms as T
5 |
6 |
7 | def crop_recenter(image_no_bg, threshold=100):
8 | image_no_bg_np = np.array(image_no_bg)
9 | if image_no_bg_np.shape[2] == 3:
10 | return image_no_bg
11 | mask = (image_no_bg_np[..., -1]).astype(np.uint8)
12 |     mask_bin = mask > threshold
13 |
14 | H, W = image_no_bg_np.shape[:2]
15 |
16 | valid_pixels = mask_bin.astype(np.float32).nonzero()
17 | if np.sum(mask_bin) < (H*W) * 0.001:
18 | min_h = 0
19 | max_h = H - 1
20 | min_w = 0
21 |         max_w = W - 1
22 | else:
23 | min_h, max_h = valid_pixels[0].min(), valid_pixels[0].max()
24 | min_w, max_w = valid_pixels[1].min(), valid_pixels[1].max()
25 |
26 | if min_h < 0:
27 | min_h = 0
28 | if min_w < 0:
29 | min_w = 0
30 | if max_h > H:
31 | max_h = H - 1
32 | if max_w > W:
33 | max_w = W - 1
34 |
35 | image_no_bg_np = image_no_bg_np[min_h:max_h+1, min_w:max_w+1]
36 | image_no_bg = Image.fromarray(image_no_bg_np)
37 | return image_no_bg
38 |
39 |
40 | def pad_to_same_size(image, pad_value=1):
41 | image = np.array(image)
42 | h, w, _ = image.shape
43 | image_temp = image.copy()
44 | if h != w:
45 | # find the max one and pad the other side with white
46 | max_size = max(h, w)
47 |
48 | pad_h = max_size - h
49 | pad_w = max_size - w
50 | pad_h_top = max(pad_h // 2, 0)
51 | pad_h_bottom = max(pad_h - pad_h_top, 0)
52 | pad_w_left = max(pad_w // 2, 0)
53 | pad_w_right = max(pad_w - pad_w_left, 0)
54 |
55 | image_temp = np.pad(
56 | image[..., :3],
57 | ((pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)),
58 | constant_values=pad_value
59 | )
60 | if image.shape[2] == 4:
61 | image_bg = np.pad(
62 | image[..., 3:],
63 | ((pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)),
64 | constant_values=0
65 | )
66 | image = np.concatenate([image_temp, image_bg], axis=2)
67 | else:
68 | image = image_temp
69 |
70 | return Image.fromarray(image)
71 |
72 |
73 | def remove_bg(image):
74 | image = rembg.remove(image)
75 | return image
76 |
77 |
78 | def preprocess(image, rmbg=True):
79 |
80 | if rmbg:
81 | image = remove_bg(image)
82 |
83 | image = crop_recenter(image)
84 | image = pad_to_same_size(image, pad_value=255)
85 | image = np.array(image)
86 | image = image / 255.
87 | if image.shape[2] == 4:
88 | image = image[..., :3] * image[..., 3:] + (1 - image[..., 3:])
89 | image = Image.fromarray((image * 255).astype('uint8'), "RGB")
90 |
91 | W, H = image.size[:2]
92 | pad_margin = int(W * 0.04)
93 | image_transforms = T.Compose([
94 | T.Pad((pad_margin, pad_margin, pad_margin, pad_margin), fill=255),
95 | T.ToTensor(),
96 | ])
97 |
98 | image = image_transforms(image)
99 |
100 | return image
101 |
--------------------------------------------------------------------------------
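preprocess above chains optional background removal, tight cropping around the alpha mask, padding to a square on a white background, alpha compositing onto white, and a 4% white border, then returns a float tensor in [0, 1]. A minimal sketch of calling it directly (the asset path is just one of the example images shipped in this repo):

    from PIL import Image
    from direct3d.utils.image import preprocess

    img = Image.open("assets/car.png")
    tensor = preprocess(img, rmbg=True)   # shape (3, H, W), values in [0, 1]
    print(tensor.shape, float(tensor.min()), float(tensor.max()))
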
/direct3d/utils/triplane.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 |
5 |
6 | def generate_planes():
7 | """
8 | Defines planes by the three vectors that form the "axes" of the
9 | plane. Should work with arbitrary number of planes and planes of
10 | arbitrary orientation.
11 |
12 | Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
13 | """
14 | return torch.tensor([[[1, 0, 0],
15 | [0, 1, 0],
16 | [0, 0, 1]],
17 | [[1, 0, 0],
18 | [0, 0, 1],
19 | [0, 1, 0]],
20 | [[0, 0, 1],
21 | [0, 1, 0],
22 | [1, 0, 0]]])
23 |
24 |
25 | def project_onto_planes(planes, coordinates):
26 | """
27 | Does a projection of a 3D point onto a batch of 2D planes,
28 | returning 2D plane coordinates.
29 |
30 |     Takes plane axes of shape (n_planes, 3, 3)
31 |     and coordinates of shape (N, M, 3);
32 |     returns projections of shape (N*n_planes, M, 2).
33 | """
34 | N, M, C = coordinates.shape
35 | n_planes, _, _ = planes.shape
36 | coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
37 | inv_planes = torch.linalg.inv(planes.to(coordinates.dtype)).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
38 | projections = torch.bmm(coordinates, inv_planes)
39 | return projections[..., :2]
40 |
41 |
42 | def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
43 | assert padding_mode == 'zeros'
44 | N, n_planes, C, H, W = plane_features.shape
45 | _, M, _ = coordinates.shape
46 | plane_features = rearrange(plane_features, "N_b N_t C H_t W_t -> (N_b N_t) C H_t W_t")
47 |
48 | coordinates = (2/box_warp) * coordinates
49 |
50 | projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
51 | output_features = F.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
52 | return output_features
--------------------------------------------------------------------------------
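The helpers above implement EG3D-style triplane sampling: generate_planes defines the three canonical plane bases, project_onto_planes drops 3D query points onto each plane, and sample_from_planes bilinearly samples the per-plane feature maps at those 2D locations. A small shape check under assumed toy dimensions:

    import torch
    from direct3d.utils.triplane import generate_planes, sample_from_planes

    plane_axes = generate_planes()                   # (3, 3, 3)
    plane_features = torch.randn(2, 3, 32, 64, 64)   # (N, n_planes, C, H, W)
    coordinates = torch.rand(2, 1000, 3) * 2 - 1     # (N, M, 3) points in [-1, 1]^3

    feats = sample_from_planes(plane_axes, plane_features, coordinates, box_warp=2.0)
    print(feats.shape)   # torch.Size([2, 3, 1000, 32]), i.e. (N, n_planes, M, C)
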
/direct3d/utils/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 |
4 | def instantiate_from_config(config):
5 |     if "target" not in config:
6 | if config == '__is_first_stage__':
7 | return None
8 | elif config == "__is_unconditional__":
9 | return None
10 | raise KeyError("Expected key `target` to instantiate.")
11 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
12 |
13 |
14 | def get_obj_from_str(string, reload=False):
15 | module, cls = string.rsplit(".", 1)
16 | if reload:
17 | module_imp = importlib.import_module(module)
18 | importlib.reload(module_imp)
19 | return getattr(importlib.import_module(module, package=None), cls)
20 |
--------------------------------------------------------------------------------
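instantiate_from_config is the usual config-driven factory: it expects a mapping with a dotted target class path and an optional params dict of constructor keyword arguments. A minimal sketch using a standard-library class as the target (the config shown is illustrative, not one of this repo's actual configs):

    from direct3d.utils.util import instantiate_from_config

    cfg = {
        "target": "collections.OrderedDict",   # dotted "module.Class" path
        "params": {},                          # forwarded as keyword arguments
    }
    obj = instantiate_from_config(cfg)
    print(type(obj))   # <class 'collections.OrderedDict'>
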
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.1.0
2 | scikit-image
3 | trimesh
4 | tqdm
5 | einops
6 | numpy
7 | transformers==4.40.2
8 | diffusers
9 | rembg
10 | omegaconf
11 | torchvision
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | setup(
5 | name="direct3d",
6 | version="1.0.0",
7 | description="Direct3D: Scalable Image-to-3D Generation via 3D Latent Diffusion Transformer",
8 | packages=find_packages(),
9 | python_requires=">=3.10",
10 | install_requires=[
11 | "torch",
12 | "numpy",
13 | "cython",
14 | "trimesh",
15 | "diffusers",
16 | ],
17 | )
--------------------------------------------------------------------------------
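For a local setup, the usual editable install applies: from the repository root, run pip install -r requirements.txt followed by pip install -e . (Python 3.10 or newer, per python_requires above). Note that runtime dependencies such as rembg and omegaconf are listed only in requirements.txt, not in install_requires.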