├── .gitignore ├── LICENSE ├── README.md ├── cam_utils.py ├── clip_sim.py ├── configs ├── image_sai.yaml ├── imagedream.yaml └── text_mv.yaml ├── convert.py ├── core ├── __init__.py ├── attention.py ├── gs.py ├── models.py ├── options.py ├── provider_objaverse.py ├── unet.py └── utils.py ├── grid_put.py ├── gs_postprocess.py ├── gs_renderer.py ├── guidance ├── imagedream_utils.py ├── mvdream_utils.py └── zero123_utils.py ├── loss_utils.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── main.py ├── main2.py ├── mesh.py ├── mesh_renderer.py ├── mesh_utils.py ├── process.py ├── requirements.txt ├── scheduler_config.json ├── scripts ├── cal_sim.py ├── convert_obj_to_video.py ├── run_cal_sim.sh ├── run_imagedream.sh └── run_sai.sh ├── sh_utils.py ├── simple-knn ├── ext.cpp ├── setup.py ├── simple_knn.cu ├── simple_knn.h ├── simple_knn │ └── .gitkeep ├── spatial.cu └── spatial.h ├── test_data ├── 00_zero123_lysol_rgba.png ├── 01_wild_hydrant_rgba.png ├── 02_zero123_spyro_rgba.png ├── 03_wild2_pineapple_bottle_rgba.png ├── 04_unsplash_broccoli_rgba.png ├── 05_objaverse_backpack_rgba.png ├── 06_unsplash_chocolatecake_rgba.png ├── 07_unsplash_stool2_rgba.png ├── 08_dalle_icecream_rgba.png ├── 09_unsplash_bigmac_rgba.png ├── 10_dalle3_blueberryicecream2_rgba.png ├── 11_GSO_Crosley_Alarm_Clock_Vintage_Metal_rgba.png ├── 12_realfusion_cactus_1_rgba.png ├── 13_realfusion_cherry_1_rgba.png ├── 14_dalle_cowbear_rgba.png ├── 15_dalle3_gramophone1_rgba.png ├── 16_dalle3_mushroom2_rgba.png ├── 17_dalle3_rockingchair1_rgba.png ├── 18_unsplash_mario_rgba.png ├── 19_dalle3_stump1_rgba.png ├── 20_objaverse_stool_rgba.png ├── 21_objaverse_barrel_rgba.png ├── 22_unsplash_boxtoy_rgba.png ├── 23_objaverse_tank_rgba.png ├── 24_wild2_yellow_duck_rgba.png ├── 25_unsplash_teapot_rgba.png ├── 26_unsplash_strawberrycake_rgba.png ├── 27_objaverse_robocat_rgba.png ├── 28_wild_goose_chef_rgba.png ├── 29_wild_peroxide_rgba.png ├── alarm_rgba.png ├── anya_rgba.png ├── armor_rgba.png ├── astronaut_rgba.png ├── backpack_rgba.png ├── box_rgba.png ├── bread_rgba.png ├── bucket_rgba.png ├── busket_rgba.png ├── cargo_rgba.png ├── cat_rgba.png ├── catstatue_rgba.png ├── chili_rgba.png ├── crab_rgba.png ├── crystal_rgba.png ├── csm_luigi_rgba.png ├── deer_rgba.png ├── drum2_rgba.png ├── drum_rgba.png ├── elephant_rgba.png ├── flower2_rgba.png ├── flower_rgba.png ├── forest_rgba.png ├── frog_sweater_rgba.png ├── ghost_rgba.png ├── giraffe_rgba.png ├── grandfather_rgba.png ├── ground_rgba.png ├── halloween_rgba.png ├── hat_rgba.png ├── head_rgba.png ├── house_rgba.png ├── kettle_rgba.png ├── kunkun_rgba.png ├── lantern_rgba.png ├── lotus_seed_rgba.png ├── lunch_bag_rgba.png ├── milk_rgba.png ├── monkey_rgba.png ├── oil_rgba.png ├── poro_rgba.png ├── rabbit_chinese_rgba.png ├── school_bus1_rgba.png ├── school_bus2_rgba.png ├── shed_rgba.png ├── shoe_rgba.png ├── sofa2_rgba.png ├── sofa_rgba.png ├── steak_rgba.png ├── teapot2_rgba.png ├── teapot_rgba.png ├── test_rgba.png ├── toaster_rgba.png ├── turtle_rgba.png ├── vase_rgba.png ├── wisky_rgba.png └── zelda_rgba.png └── zero123.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | venv_*/ 6 | .vs/ 7 | .vscode/ 8 | .idea/ 9 | 10 | 11 | logs* 12 | 13 | testing.py 14 | test_dirs/ 15 | tmp/ -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MicroDreamer 2 | Official implementation of *[MicroDreamer: Zero-shot 3D Generation in ~20 Seconds by Score-based Iterative Reconstruction](http://arxiv.org/abs/2404.19525)*. 3 | 4 | 5 | 6 | 7 | https://github.com/user-attachments/assets/0a99424a-2e7a-47f0-9f0a-b6713b7686b5 8 | 9 | 10 | 11 | ## News 12 | [10/2024] Add a new mesh export method from [LGM](https://github.com/3DTopia/LGM) 13 | 14 | 15 | ## Installation 16 | 17 | The codebase is built on [DreamGaussian](https://github.com/dreamgaussian/dreamgaussian). 
For installation, 18 | ```bash 19 | conda create -n MicroDreamer python=3.11 20 | conda activate MicroDreamer 21 | 22 | pip install -r requirements.txt 23 | 24 | # a modified gaussian splatting (+ depth, alpha rendering) 25 | git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization 26 | pip install ./diff-gaussian-rasterization 27 | 28 | # The commit hash we used 29 | # d986da0d4cf2dfeb43b9a379b6e9fa0a7f3f7eea 30 | 31 | # simple-knn 32 | pip install ./simple-knn 33 | 34 | # nvdiffrast 35 | pip install git+https://github.com/NVlabs/nvdiffrast/ 36 | 37 | # The version we used 38 | # pip install git+https://github.com/NVlabs/nvdiffrast/@0.3.1 39 | 40 | # kiuikit 41 | pip install git+https://github.com/ashawkey/kiuikit/ 42 | 43 | # The version we used 44 | # pip install git+https://github.com/ashawkey/kiuikit/@0.2.3 45 | 46 | # To use ImageDream, also install: 47 | pip install git+https://github.com/bytedance/ImageDream/#subdirectory=extern/ImageDream 48 | 49 | # The commit hash we used 50 | # 26c3972e586f0c8d2f6c6b297aa9d792d06abebb 51 | ``` 52 | 53 | ## Usage 54 | 55 | Image-to-3D: 56 | 57 | ```bash 58 | ### preprocess 59 | # background removal and recentering, save rgba at 256x256 60 | python process.py test_data/name.jpg 61 | 62 | # save at a larger resolution 63 | python process.py test_data/name.jpg --size 512 64 | 65 | # process all jpg images under a dir 66 | python process.py test_data 67 | 68 | ### training gaussian stage 69 | # train 20 iters and export ckpt & coarse_mesh to logs 70 | python main.py --config configs/image_sai.yaml input=test_data/name_rgba.png save_path=name_rgba 71 | 72 | ### training mesh stage 73 | # auto load coarse_mesh and refine 3 iters, export fine_mesh to logs 74 | python main2.py --config configs/image_sai.yaml input=test_data/name_rgba.png save_path=name_rgba 75 | ``` 76 | 77 | Image+Text-to-3D (ImageDream): 78 | 79 | ```bash 80 | ### training gaussian stage 81 | python main.py --config configs/imagedream.yaml input=test_data/ghost_rgba.png prompt="a ghost eating hamburger" save_path=ghost_rgba 82 | ``` 83 | 84 | Calculate for CLIP similarity: 85 | ```bash 86 | PYTHONPATH='.' python scripts/cal_sim.py 87 | ``` 88 | 89 | ## More Results 90 | 91 | 92 | 93 | https://github.com/user-attachments/assets/8888a353-df16-4e19-ac1b-7ee37ece7ed1 94 | 95 | 96 | 97 | 98 | https://github.com/user-attachments/assets/7e52a87b-d1f6-4e7b-a6b4-7732ea69613c 99 | 100 | 101 | 102 | 103 | 104 | ## Acknowledgement 105 | 106 | This work is built on many amazing open source projects, thanks to all the authors! 
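As a supplement to `scripts/cal_sim.py` in the Usage section above, the helpers in `clip_sim.py` can also be called directly. A minimal sketch (the image paths and the prompt below are placeholders; point them at your own reference image and rendered views):

```python
# Minimal sketch: score rendered novel views against a reference image / text prompt.
# The paths and the prompt are illustrative only -- substitute your own files.
from clip_sim import cal_clip_sim, cal_clip_sim_text

ref_image = "test_data/anya_rgba.png"
novel_views = ["logs/anya_view_0.png", "logs/anya_view_1.png"]

image_sim = cal_clip_sim(ref_image, novel_views)             # average CLIP image-image similarity
text_sim = cal_clip_sim_text("a cartoon girl", novel_views)  # average CLIP text-image similarity
print(f"image sim: {image_sim:.4f} | text sim: {text_sim:.4f}")
```

Note that importing `clip_sim` loads a large CLIP model (`laion/CLIP-ViT-bigG-14-laion2B-39B-b160k`) onto the GPU at import time, so a CUDA device is required.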
107 | 108 | - [DreamGaussian](https://github.com/dreamgaussian/dreamgaussian) 109 | - [LGM](https://github.com/3DTopia/LGM) 110 | - [threestudio](https://github.com/threestudio-project/threestudio) 111 | 112 | 113 | ## BibTeX 114 | 115 | ``` 116 | @misc{chen2024microdreamerzeroshot3dgeneration, 117 | title={MicroDreamer: Zero-shot 3D Generation in $\sim$20 Seconds by Score-based Iterative Reconstruction}, 118 | author={Luxi Chen and Zhengyi Wang and Zihan Zhou and Tingting Gao and Hang Su and Jun Zhu and Chongxuan Li}, 119 | year={2024}, 120 | eprint={2404.19525}, 121 | archivePrefix={arXiv}, 122 | primaryClass={cs.CV}, 123 | url={https://arxiv.org/abs/2404.19525}, 124 | } 125 | ``` 126 | -------------------------------------------------------------------------------- /cam_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation as R 3 | 4 | import torch 5 | 6 | def dot(x, y): 7 | if isinstance(x, np.ndarray): 8 | return np.sum(x * y, -1, keepdims=True) 9 | else: 10 | return torch.sum(x * y, -1, keepdim=True) 11 | 12 | 13 | def length(x, eps=1e-20): 14 | if isinstance(x, np.ndarray): 15 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 16 | else: 17 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 18 | 19 | 20 | def safe_normalize(x, eps=1e-20): 21 | return x / length(x, eps) 22 | 23 | 24 | def look_at(campos, target, opengl=True): 25 | # campos: [N, 3], camera/eye position 26 | # target: [N, 3], object to look at 27 | # return: [N, 3, 3], rotation matrix 28 | if not opengl: 29 | # camera forward aligns with -z 30 | forward_vector = safe_normalize(target - campos) 31 | up_vector = np.array([0, 1, 0], dtype=np.float32) 32 | right_vector = safe_normalize(np.cross(forward_vector, up_vector)) 33 | up_vector = safe_normalize(np.cross(right_vector, forward_vector)) 34 | else: 35 | # camera forward aligns with +z 36 | forward_vector = safe_normalize(campos - target) 37 | up_vector = np.array([0, 1, 0], dtype=np.float32) 38 | right_vector = safe_normalize(np.cross(up_vector, forward_vector)) 39 | up_vector = safe_normalize(np.cross(forward_vector, right_vector)) 40 | R = np.stack([right_vector, up_vector, forward_vector], axis=1) 41 | return R 42 | 43 | 44 | # elevation & azimuth to pose (cam2world) matrix 45 | def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True): 46 | # radius: scalar 47 | # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90) 48 | # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90) 49 | # return: [4, 4], camera pose matrix 50 | if is_degree: 51 | elevation = np.deg2rad(elevation) 52 | azimuth = np.deg2rad(azimuth) 53 | x = radius * np.cos(elevation) * np.sin(azimuth) 54 | y = - radius * np.sin(elevation) 55 | z = radius * np.cos(elevation) * np.cos(azimuth) 56 | if target is None: 57 | target = np.zeros([3], dtype=np.float32) 58 | campos = np.array([x, y, z]) + target # [3] 59 | T = np.eye(4, dtype=np.float32) 60 | T[:3, :3] = look_at(campos, target, opengl) 61 | T[:3, 3] = campos 62 | return T 63 | 64 | 65 | class OrbitCamera: 66 | def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100): 67 | self.W = W 68 | self.H = H 69 | self.radius = r # camera distance from center 70 | self.fovy = np.deg2rad(fovy) # deg 2 rad 71 | self.near = near 72 | self.far = far 73 | self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point 74 | self.rot = R.from_matrix(np.eye(3)) 75 | self.up = 
np.array([0, 1, 0], dtype=np.float32) # need to be normalized! 76 | 77 | @property 78 | def fovx(self): 79 | return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H) 80 | 81 | @property 82 | def campos(self): 83 | return self.pose[:3, 3] 84 | 85 | # pose (c2w) 86 | @property 87 | def pose(self): 88 | # first move camera to radius 89 | res = np.eye(4, dtype=np.float32) 90 | res[2, 3] = self.radius # opengl convention... 91 | # rotate 92 | rot = np.eye(4, dtype=np.float32) 93 | rot[:3, :3] = self.rot.as_matrix() 94 | res = rot @ res 95 | # translate 96 | res[:3, 3] -= self.center 97 | return res 98 | 99 | # view (w2c) 100 | @property 101 | def view(self): 102 | return np.linalg.inv(self.pose) 103 | 104 | # projection (perspective) 105 | @property 106 | def perspective(self): 107 | y = np.tan(self.fovy / 2) 108 | aspect = self.W / self.H 109 | return np.array( 110 | [ 111 | [1 / (y * aspect), 0, 0, 0], 112 | [0, -1 / y, 0, 0], 113 | [ 114 | 0, 115 | 0, 116 | -(self.far + self.near) / (self.far - self.near), 117 | -(2 * self.far * self.near) / (self.far - self.near), 118 | ], 119 | [0, 0, -1, 0], 120 | ], 121 | dtype=np.float32, 122 | ) 123 | 124 | # intrinsics 125 | @property 126 | def intrinsics(self): 127 | focal = self.H / (2 * np.tan(self.fovy / 2)) 128 | return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32) 129 | 130 | @property 131 | def mvp(self): 132 | return self.perspective @ np.linalg.inv(self.pose) # [4, 4] 133 | 134 | def orbit(self, dx, dy): 135 | # rotate along camera up/side axis! 136 | side = self.rot.as_matrix()[:3, 0] 137 | rotvec_x = self.up * np.radians(-0.05 * dx) 138 | rotvec_y = side * np.radians(-0.05 * dy) 139 | self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot 140 | 141 | def scale(self, delta): 142 | self.radius *= 1.1 ** (-delta) 143 | 144 | def pan(self, dx, dy, dz=0): 145 | # pan in camera coordinate system (careful on the sensitivity!) 
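# dx/dy (and optional dz) are screen-space deltas: they are scaled by a small factor and
# rotated from camera space into world space with the current rotation before shifting self.center.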
146 | self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz]) -------------------------------------------------------------------------------- /clip_sim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from torchvision import transforms as T 5 | from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor 6 | from typing import Literal 7 | from PIL import Image 8 | import kiui 9 | 10 | 11 | class CLIP: 12 | def __init__(self, device, model_name='openai/clip-vit-large-patch14'): 13 | 14 | self.device = device 15 | 16 | self.clip_model = CLIPModel.from_pretrained(model_name).to(self.device) 17 | self.processor = CLIPProcessor.from_pretrained(model_name) 18 | 19 | def encode_image(self, image): 20 | # image: PIL, np.ndarray uint8 [H, W, 3] 21 | 22 | pixel_values = self.processor( 23 | images=image, return_tensors="pt").pixel_values.to(self.device) 24 | image_features = self.clip_model.get_image_features( 25 | pixel_values=pixel_values) 26 | 27 | image_features = image_features / \ 28 | image_features.norm(dim=-1, keepdim=True) # normalize features 29 | 30 | return image_features 31 | 32 | def encode_text(self, text): 33 | # text: str 34 | 35 | inputs = self.processor(text=[text], padding=True, return_tensors="pt").to( 36 | self.device) 37 | text_features = self.clip_model.get_text_features(**inputs) 38 | 39 | text_features = text_features / \ 40 | text_features.norm(dim=-1, keepdim=True) # normalize features 41 | 42 | return text_features 43 | 44 | 45 | def read_image( 46 | path: str, 47 | mode: Literal["float", "uint8", "pil", "torch", "tensor"] = "float", 48 | order: Literal["RGB", "RGBA", "BGR", "BGRA"] = "RGB", 49 | ): 50 | 51 | if mode == "pil": 52 | return Image.open(path).convert(order) 53 | 54 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 55 | 56 | # cvtColor 57 | if len(img.shape) == 3: # ignore if gray scale 58 | if order in ["RGB", "RGBA"]: 59 | if img.shape[-1] == 4: 60 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) 61 | elif img.shape[-1] == 3: 62 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 63 | 64 | # mix background 65 | if img.shape[-1] == 4 and 'A' not in order: 66 | img = img.astype(np.float32) / 255 67 | img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:]) 68 | 69 | # mode 70 | if mode == "uint8": 71 | if img.dtype != np.uint8: 72 | img = (img * 255).astype(np.uint8) 73 | return img 74 | elif mode == "float": 75 | if img.dtype == np.uint8: 76 | img = img.astype(np.float32) / 255 77 | return img 78 | elif mode in ["tensor", "torch"]: 79 | if img.dtype == np.uint8: 80 | img = img.astype(np.float32) / 255 81 | return torch.from_numpy(img) 82 | else: 83 | raise ValueError(f"Unknown read_image mode {mode}") 84 | 85 | 86 | clip = CLIP('cuda', model_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k') 87 | 88 | def cal_clip_sim(ref_path, novel_path_ls): 89 | 90 | ref_img = read_image(ref_path, mode='float') 91 | if ref_img.shape[-1] == 4: 92 | # rgba to white-bg rgb 93 | ref_img = ref_img[..., :3] * ref_img[..., 3:] + (1 - ref_img[..., 3:]) 94 | ref_img = (ref_img * 255).astype(np.uint8) 95 | with torch.no_grad(): 96 | ref_features = clip.encode_image(ref_img) 97 | 98 | results = [] 99 | for novel_path in novel_path_ls: 100 | novel_img = read_image(novel_path, mode='float') 101 | if novel_img.shape[-1] == 4: 102 | # rgba to white-bg rgb 103 | novel_img = novel_img[..., :3] * novel_img[..., 3:] + (1 - novel_img[..., 3:]) 104 | novel_img = (novel_img * 
255).astype(np.uint8) 105 | with torch.no_grad(): 106 | novel_features = clip.encode_image(novel_img) 107 | 108 | sim = (ref_features * novel_features).sum(dim=-1).mean().item() 109 | results.append(sim) 110 | 111 | avg_similarity = np.mean(results) 112 | return avg_similarity 113 | 114 | 115 | def cal_clip_sim_text(ref_text, novel_path_ls): 116 | with torch.no_grad(): 117 | ref_features = clip.encode_text(ref_text) 118 | 119 | results = [] 120 | for novel_path in novel_path_ls: 121 | novel_img = read_image(novel_path, mode='float') 122 | if novel_img.shape[-1] == 4: 123 | # rgba to white-bg rgb 124 | novel_img = novel_img[..., :3] * novel_img[..., 3:] + (1 - novel_img[..., 3:]) 125 | novel_img = (novel_img * 255).astype(np.uint8) 126 | with torch.no_grad(): 127 | novel_features = clip.encode_image(novel_img) 128 | 129 | sim = (ref_features * novel_features).sum(dim=-1).mean().item() 130 | results.append(sim) 131 | 132 | avg_similarity = np.mean(results) 133 | return avg_similarity 134 | -------------------------------------------------------------------------------- /configs/image_sai.yaml: -------------------------------------------------------------------------------- 1 | ### Input 2 | # input rgba image path (default to None, can be load in GUI too) 3 | input: 4 | # input text prompt (default to None, can be input in GUI too) 5 | prompt: 6 | negative_prompt: 7 | # input mesh for stage 2 (auto-search from stage 1 output path if None) 8 | mesh: 9 | # estimated elevation angle for input image 10 | elevation: 0 11 | # reference image resolution 12 | ref_size: 256 13 | # density thresh for mesh extraction 14 | density_thresh: 0.2 15 | 16 | ### Output 17 | outdir: logs 18 | mesh_format: obj 19 | save_path: ??? 20 | 21 | ### Training 22 | # use mvdream instead of sd 2.1 23 | mvdream: False 24 | # use imagedream 25 | imagedream: False 26 | # use stable-zero123 instead of zero123-xl 27 | stable_zero123: True 28 | # guidance loss weights (0 to disable) 29 | lambda_sd: 0 30 | lambda_zero123: 1 31 | # warmup rgb supervision for image-to-3d 32 | warmup_rgb_loss: False 33 | # training batch size per iter 34 | batch_size: 8 35 | # training iterations for stage 1 36 | iters: 500 37 | # whether to linearly anneal timestep 38 | anneal_timestep: True 39 | # training iterations for stage 2 40 | iters_refine: 1 41 | # training camera radius 42 | radius: 2 43 | # training camera fovy 44 | fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 45 | # training camera min elevation 46 | min_ver: -30 47 | # training camera max elevation 48 | max_ver: 30 49 | # checkpoint to load for stage 1 (should be a ply file) 50 | load: 51 | # whether allow geom training in stage 2 52 | train_geo: False 53 | # prob to invert background color during training (0 = always black, 1 = always white) 54 | invert_bg_prob: 1 55 | 56 | 57 | ### GUI 58 | gui: False 59 | force_cuda_rast: True 60 | # GUI resolution 61 | H: 800 62 | W: 800 63 | 64 | ### Gaussian splatting 65 | num_pts: 5000 66 | sh_degree: 0 67 | position_lr_init: 0.001 68 | position_lr_final: 0.00002 69 | position_lr_delay_mult: 0.02 70 | position_lr_max_steps: 500 71 | feature_lr: 0.01 72 | opacity_lr: 0.05 73 | scaling_lr: 0.005 74 | rotation_lr: 0.0005 75 | percent_dense: 0.01 76 | density_start_iter: 0 77 | density_end_iter: 300 78 | densification_interval: 100 79 | opacity_reset_interval: 700 80 | densify_grad_threshold: 0.01 81 | 82 | ### Textured Mesh 83 | 
geom_lr: 0.0001 84 | texture_lr: 0.2 85 | 86 | 87 | ### addtional 88 | denoise_steps: 20 89 | total_steps: 21 90 | t_start: 0.9 91 | t_end: 0.2 92 | steps_max: 15 93 | steps_min: 15 94 | init_steps: 15 95 | steps_schedule: 'fixed' 96 | cfg: 3.0 97 | ref_loss: 0.3 98 | ref_mask_loss: 0.01 99 | inv_r: 0.6 100 | eta: 0.5 101 | batch_size_max: 6 102 | batch_size_min: 6 103 | even_view: True -------------------------------------------------------------------------------- /configs/imagedream.yaml: -------------------------------------------------------------------------------- 1 | ### Input 2 | # input rgba image path (default to None, can be load in GUI too) 3 | input: 4 | # input text prompt (default to None, can be input in GUI too) 5 | prompt: 6 | negative_prompt: "ugly, bad anatomy, blurry, pixelated obscure, unnatural colors, poor lighting, dull, and unclear, cropped, lowres, low quality, artifacts, duplicate, morbid, mutilated, poorly drawn face, deformed, dehydrated, bad proportions" 7 | # input mesh for stage 2 (auto-search from stage 1 output path if None) 8 | mesh: 9 | # estimated elevation angle for input image 10 | elevation: 0 11 | # reference image resolution 12 | ref_size: 256 13 | # density thresh for mesh extraction 14 | density_thresh: 0.2 15 | 16 | ### Output 17 | outdir: logs 18 | mesh_format: obj 19 | save_path: ??? 20 | 21 | ### Training 22 | # use mvdream instead of sd 2.1 23 | mvdream: False 24 | # use imagedream 25 | imagedream: True 26 | # use stable-zero123 instead of zero123-xl 27 | stable_zero123: False 28 | # guidance loss weights (0 to disable) 29 | lambda_sd: 1 30 | lambda_zero123: 0 31 | # warmup rgb supervision for image-to-3d 32 | warmup_rgb_loss: False 33 | # training batch size per iter 34 | batch_size: 4 35 | # training iterations for stage 1 36 | iters: 500 37 | # whether to linearly anneal timestep 38 | anneal_timestep: True 39 | # training iterations for stage 2 40 | iters_refine: 3 41 | # training camera radius 42 | radius: 2.5 43 | # training camera fovy 44 | fovy: 49.1 45 | # training camera min elevation 46 | min_ver: -5 47 | # training camera max elevation 48 | max_ver: 0 49 | # checkpoint to load for stage 1 (should be a ply file) 50 | load: 51 | # whether allow geom training in stage 2 52 | train_geo: False 53 | # prob to invert background color during training (0 = always black, 1 = always white) 54 | invert_bg_prob: 1 55 | 56 | ### GUI 57 | gui: False 58 | force_cuda_rast: True 59 | # GUI resolution 60 | H: 800 61 | W: 800 62 | 63 | ### Gaussian splatting 64 | num_pts: 5000 65 | sh_degree: 0 66 | position_lr_init: 0.001 67 | position_lr_final: 0.00002 68 | position_lr_delay_mult: 0.02 69 | position_lr_max_steps: 500 70 | feature_lr: 0.01 71 | opacity_lr: 0.05 72 | scaling_lr: 0.005 73 | rotation_lr: 0.0005 74 | percent_dense: 0.01 75 | density_start_iter: 0 76 | density_end_iter: 300 77 | densification_interval: 100 78 | opacity_reset_interval: 700 79 | densify_grad_threshold: 0.01 80 | 81 | ### Textured Mesh 82 | geom_lr: 0.0001 83 | texture_lr: 0.2 84 | 85 | ### addtional 86 | denoise_steps: 10 87 | total_steps: 31 88 | t_start: 0.8 89 | t_end: 0.4 90 | steps_max: 15 91 | steps_min: 15 92 | init_steps: 50 93 | steps_schedule: 'cosine_up' 94 | cfg: 2.0 95 | # ref_loss: 0.3 96 | # ref_mask_loss: 0.001 97 | inv_r: 0.6 98 | eta: 0.0 99 | batch_size_max: 4 100 | batch_size_min: 4 101 | even_view: True -------------------------------------------------------------------------------- /configs/text_mv.yaml: 
-------------------------------------------------------------------------------- 1 | ### Input 2 | # input rgba image path (default to None, can be load in GUI too) 3 | input: 4 | # input text prompt (default to None, can be input in GUI too) 5 | prompt: 6 | negative_prompt: "ugly, bad anatomy, blurry, pixelated obscure, unnatural colors, poor lighting, dull, and unclear, cropped, lowres, low quality, artifacts, duplicate, morbid, mutilated, poorly drawn face, deformed, dehydrated, bad proportions" 7 | # input mesh for stage 2 (auto-search from stage 1 output path if None) 8 | mesh: 9 | # estimated elevation angle for input image 10 | elevation: 0 11 | # reference image resolution 12 | ref_size: 256 13 | # density thresh for mesh extraction 14 | density_thresh: 1 15 | 16 | ### Output 17 | outdir: logs 18 | mesh_format: obj 19 | save_path: ??? 20 | 21 | ### Training 22 | # use mvdream instead of sd 2.1 23 | mvdream: True 24 | # use imagedream 25 | imagedream: False 26 | # use stable-zero123 instead of zero123-xl 27 | stable_zero123: False 28 | # guidance loss weights (0 to disable) 29 | lambda_sd: 1 30 | lambda_zero123: 0 31 | # warmup rgb supervision for image-to-3d 32 | warmup_rgb_loss: False 33 | # training batch size per iter 34 | batch_size: 4 35 | # training iterations for stage 1 36 | iters: 500 37 | # whether to linearly anneal timestep 38 | anneal_timestep: True 39 | # training iterations for stage 2 40 | iters_refine: 50 41 | # training camera radius 42 | radius: 2.5 43 | # training camera fovy 44 | fovy: 49.1 45 | # training camera min elevation 46 | min_ver: -10 47 | # training camera max elevation 48 | max_ver: 10 49 | # checkpoint to load for stage 1 (should be a ply file) 50 | load: 51 | # whether allow geom training in stage 2 52 | train_geo: False 53 | # prob to invert background color during training (0 = always black, 1 = always white) 54 | invert_bg_prob: 1 55 | 56 | ### GUI 57 | gui: False 58 | force_cuda_rast: True 59 | # GUI resolution 60 | H: 800 61 | W: 800 62 | 63 | ### Gaussian splatting 64 | num_pts: 5000 65 | sh_degree: 0 66 | position_lr_init: 0.001 67 | position_lr_final: 0.00002 68 | position_lr_delay_mult: 0.02 69 | position_lr_max_steps: 500 70 | feature_lr: 0.01 71 | opacity_lr: 0.05 72 | scaling_lr: 0.005 73 | rotation_lr: 0.0005 74 | percent_dense: 0.01 75 | density_start_iter: 0 76 | density_end_iter: 300 77 | densification_interval: 100 78 | opacity_reset_interval: 700 79 | densify_grad_threshold: 0.01 80 | 81 | ### Textured Mesh 82 | geom_lr: 0.0001 83 | texture_lr: 0.2 84 | 85 | ### addtional 86 | denoise_steps: 50 87 | total_steps: 31 -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/core/__init__.py -------------------------------------------------------------------------------- /core/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
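# This module defines multi-head self-attention (Attention), cross-attention (CrossAttention),
# and memory-efficient variants (MemEffAttention / MemEffCrossAttention) that use xFormers'
# memory_efficient_attention kernel when it is available.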
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import os 11 | import warnings 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 17 | try: 18 | if XFORMERS_ENABLED: 19 | from xformers.ops import memory_efficient_attention, unbind 20 | 21 | XFORMERS_AVAILABLE = True 22 | warnings.warn("xFormers is available (Attention)") 23 | else: 24 | warnings.warn("xFormers is disabled (Attention)") 25 | raise ImportError 26 | except ImportError: 27 | XFORMERS_AVAILABLE = False 28 | warnings.warn("xFormers is not available (Attention)") 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__( 33 | self, 34 | dim: int, 35 | num_heads: int = 8, 36 | qkv_bias: bool = False, 37 | proj_bias: bool = True, 38 | attn_drop: float = 0.0, 39 | proj_drop: float = 0.0, 40 | ) -> None: 41 | super().__init__() 42 | self.num_heads = num_heads 43 | head_dim = dim // num_heads 44 | self.scale = head_dim**-0.5 45 | 46 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 49 | self.proj_drop = nn.Dropout(proj_drop) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | B, N, C = x.shape 53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 54 | 55 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 56 | attn = q @ k.transpose(-2, -1) 57 | 58 | attn = attn.softmax(dim=-1) 59 | attn = self.attn_drop(attn) 60 | 61 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 62 | x = self.proj(x) 63 | x = self.proj_drop(x) 64 | return x 65 | 66 | 67 | class MemEffAttention(Attention): 68 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 69 | if not XFORMERS_AVAILABLE: 70 | if attn_bias is not None: 71 | raise AssertionError("xFormers is required for using nested tensors") 72 | return super().forward(x) 73 | 74 | B, N, C = x.shape 75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 76 | 77 | q, k, v = unbind(qkv, 2) 78 | 79 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 80 | x = x.reshape([B, N, C]) 81 | 82 | x = self.proj(x) 83 | x = self.proj_drop(x) 84 | return x 85 | 86 | 87 | class CrossAttention(nn.Module): 88 | def __init__( 89 | self, 90 | dim: int, 91 | dim_q: int, 92 | dim_k: int, 93 | dim_v: int, 94 | num_heads: int = 8, 95 | qkv_bias: bool = False, 96 | proj_bias: bool = True, 97 | attn_drop: float = 0.0, 98 | proj_drop: float = 0.0, 99 | ) -> None: 100 | super().__init__() 101 | self.dim = dim 102 | self.num_heads = num_heads 103 | head_dim = dim // num_heads 104 | self.scale = head_dim**-0.5 105 | 106 | self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) 107 | self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) 108 | self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 114 | # q: [B, N, Cq] 115 | # k: [B, M, Ck] 116 | # v: [B, M, Cv] 117 | # return: [B, N, C] 118 | 119 | B, N, _ = q.shape 120 | M = k.shape[1] 121 | 122 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh] 123 | k = self.to_k(k).reshape(B, M, 
self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 124 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 125 | 126 | attn = q @ k.transpose(-2, -1) # [B, nh, N, M] 127 | 128 | attn = attn.softmax(dim=-1) # [B, nh, N, M] 129 | attn = self.attn_drop(attn) 130 | 131 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C] 132 | x = self.proj(x) 133 | x = self.proj_drop(x) 134 | return x 135 | 136 | 137 | class MemEffCrossAttention(CrossAttention): 138 | def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor: 139 | if not XFORMERS_AVAILABLE: 140 | if attn_bias is not None: 141 | raise AssertionError("xFormers is required for using nested tensors") 142 | return super().forward(x) 143 | 144 | B, N, _ = q.shape 145 | M = k.shape[1] 146 | 147 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] 148 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 149 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 150 | 151 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 152 | x = x.reshape(B, N, -1) 153 | 154 | x = self.proj(x) 155 | x = self.proj_drop(x) 156 | return x 157 | -------------------------------------------------------------------------------- /core/gs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from diff_gaussian_rasterization import ( 8 | GaussianRasterizationSettings, 9 | GaussianRasterizer, 10 | ) 11 | 12 | from core.options import Options 13 | 14 | import kiui 15 | 16 | class GaussianRenderer: 17 | def __init__(self, opt: Options): 18 | 19 | self.opt = opt 20 | self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") 21 | 22 | # intrinsics 23 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 24 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 25 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 26 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 27 | self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) 28 | self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) 29 | self.proj_matrix[2, 3] = 1 30 | 31 | def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1): 32 | # gaussians: [B, N, 14] 33 | # cam_view, cam_view_proj: [B, V, 4, 4] 34 | # cam_pos: [B, V, 3] 35 | 36 | device = gaussians.device 37 | B, V = cam_view.shape[:2] 38 | 39 | # loop of loop... 
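# each gaussian packs 14 values: xyz position (3) + opacity (1) + scale (3)
# + quaternion rotation (4) + RGB color (3); the nested loops below rasterize
# every view of every batch element with its own camera.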
40 | images = [] 41 | alphas = [] 42 | for b in range(B): 43 | 44 | # pos, opacity, scale, rotation, shs 45 | means3D = gaussians[b, :, 0:3].contiguous().float() 46 | opacity = gaussians[b, :, 3:4].contiguous().float() 47 | scales = gaussians[b, :, 4:7].contiguous().float() 48 | rotations = gaussians[b, :, 7:11].contiguous().float() 49 | rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3] 50 | 51 | for v in range(V): 52 | 53 | # render novel views 54 | view_matrix = cam_view[b, v].float() 55 | view_proj_matrix = cam_view_proj[b, v].float() 56 | campos = cam_pos[b, v].float() 57 | 58 | raster_settings = GaussianRasterizationSettings( 59 | image_height=self.opt.output_size, 60 | image_width=self.opt.output_size, 61 | tanfovx=self.tan_half_fov, 62 | tanfovy=self.tan_half_fov, 63 | bg=self.bg_color if bg_color is None else bg_color, 64 | scale_modifier=scale_modifier, 65 | viewmatrix=view_matrix, 66 | projmatrix=view_proj_matrix, 67 | sh_degree=0, 68 | campos=campos, 69 | prefiltered=False, 70 | debug=False, 71 | ) 72 | 73 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 74 | 75 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 76 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 77 | means3D=means3D, 78 | means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), 79 | shs=None, 80 | colors_precomp=rgbs, 81 | opacities=opacity, 82 | scales=scales, 83 | rotations=rotations, 84 | cov3D_precomp=None, 85 | ) 86 | 87 | rendered_image = rendered_image.clamp(0, 1) 88 | 89 | images.append(rendered_image) 90 | alphas.append(rendered_alpha) 91 | 92 | images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) 93 | alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size) 94 | 95 | return { 96 | "image": images, # [B, V, 3, H, W] 97 | "alpha": alphas, # [B, V, 1, H, W] 98 | } 99 | 100 | 101 | def save_ply(self, gaussians, path, compatible=True): 102 | # gaussians: [B, N, 14] 103 | # compatible: save pre-activated gaussians as in the original paper 104 | 105 | assert gaussians.shape[0] == 1, 'only support batch size 1' 106 | 107 | from plyfile import PlyData, PlyElement 108 | 109 | means3D = gaussians[0, :, 0:3].contiguous().float() 110 | opacity = gaussians[0, :, 3:4].contiguous().float() 111 | scales = gaussians[0, :, 4:7].contiguous().float() 112 | rotations = gaussians[0, :, 7:11].contiguous().float() 113 | shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] 114 | 115 | # prune by opacity 116 | mask = opacity.squeeze(-1) >= 0.005 117 | means3D = means3D[mask] 118 | opacity = opacity[mask] 119 | scales = scales[mask] 120 | rotations = rotations[mask] 121 | shs = shs[mask] 122 | 123 | # invert activation to make it compatible with the original ply format 124 | if compatible: 125 | opacity = kiui.op.inverse_sigmoid(opacity) 126 | scales = torch.log(scales + 1e-8) 127 | shs = (shs - 0.5) / 0.28209479177387814 128 | 129 | xyzs = means3D.detach().cpu().numpy() 130 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 131 | opacities = opacity.detach().cpu().numpy() 132 | scales = scales.detach().cpu().numpy() 133 | rotations = rotations.detach().cpu().numpy() 134 | 135 | l = ['x', 'y', 'z'] 136 | # All channels except the 3 DC 137 | for i in range(f_dc.shape[1]): 138 | l.append('f_dc_{}'.format(i)) 139 | l.append('opacity') 140 | for i in range(scales.shape[1]): 141 | l.append('scale_{}'.format(i)) 142 | for 
i in range(rotations.shape[1]): 143 | l.append('rot_{}'.format(i)) 144 | 145 | dtype_full = [(attribute, 'f4') for attribute in l] 146 | 147 | elements = np.empty(xyzs.shape[0], dtype=dtype_full) 148 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) 149 | elements[:] = list(map(tuple, attributes)) 150 | el = PlyElement.describe(elements, 'vertex') 151 | 152 | PlyData([el]).write(path) 153 | 154 | def load_ply(self, path, compatible=True): 155 | 156 | from plyfile import PlyData, PlyElement 157 | 158 | plydata = PlyData.read(path) 159 | 160 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 161 | np.asarray(plydata.elements[0]["y"]), 162 | np.asarray(plydata.elements[0]["z"])), axis=1) 163 | print("Number of points at loading : ", xyz.shape[0]) 164 | 165 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 166 | 167 | shs = np.zeros((xyz.shape[0], 3)) 168 | shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 169 | shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) 170 | shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) 171 | 172 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 173 | scales = np.zeros((xyz.shape[0], len(scale_names))) 174 | for idx, attr_name in enumerate(scale_names): 175 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 176 | 177 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] 178 | rots = np.zeros((xyz.shape[0], len(rot_names))) 179 | for idx, attr_name in enumerate(rot_names): 180 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 181 | 182 | gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) 183 | gaussians = torch.from_numpy(gaussians).float() # cpu 184 | 185 | if compatible: 186 | gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) 187 | gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) 188 | gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 189 | 190 | return gaussians -------------------------------------------------------------------------------- /core/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | import kiui 7 | from kiui.lpips import LPIPS 8 | 9 | from core.unet import UNet 10 | from core.options import Options 11 | from core.gs import GaussianRenderer 12 | 13 | 14 | class LGM(nn.Module): 15 | def __init__( 16 | self, 17 | opt: Options, 18 | ): 19 | super().__init__() 20 | 21 | self.opt = opt 22 | 23 | # unet 24 | self.unet = UNet( 25 | 9, 14, 26 | down_channels=self.opt.down_channels, 27 | down_attention=self.opt.down_attention, 28 | mid_attention=self.opt.mid_attention, 29 | up_channels=self.opt.up_channels, 30 | up_attention=self.opt.up_attention, 31 | ) 32 | 33 | # last conv 34 | self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again 35 | 36 | # Gaussian Renderer 37 | self.gs = GaussianRenderer(opt) 38 | 39 | # activations... 
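# map raw U-Net outputs to valid gaussian parameters: positions clamped to [-1, 1],
# scales kept positive via softplus, opacity squashed by a sigmoid, rotations
# normalized to unit quaternions, and colors mapped to [0, 1].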
40 | self.pos_act = lambda x: x.clamp(-1, 1) 41 | self.scale_act = lambda x: 0.1 * F.softplus(x) 42 | self.opacity_act = lambda x: torch.sigmoid(x) 43 | self.rot_act = lambda x: F.normalize(x, dim=-1) 44 | self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again 45 | 46 | # LPIPS loss 47 | if self.opt.lambda_lpips > 0: 48 | self.lpips_loss = LPIPS(net='vgg') 49 | self.lpips_loss.requires_grad_(False) 50 | 51 | 52 | def state_dict(self, **kwargs): 53 | # remove lpips_loss 54 | state_dict = super().state_dict(**kwargs) 55 | for k in list(state_dict.keys()): 56 | if 'lpips_loss' in k: 57 | del state_dict[k] 58 | return state_dict 59 | 60 | 61 | def prepare_default_rays(self, device, elevation=0): 62 | 63 | from kiui.cam import orbit_camera 64 | from core.utils import get_rays 65 | 66 | cam_poses = np.stack([ 67 | orbit_camera(elevation, 0, radius=self.opt.cam_radius), 68 | orbit_camera(elevation, 90, radius=self.opt.cam_radius), 69 | orbit_camera(elevation, 180, radius=self.opt.cam_radius), 70 | orbit_camera(elevation, 270, radius=self.opt.cam_radius), 71 | ], axis=0) # [4, 4, 4] 72 | cam_poses = torch.from_numpy(cam_poses) 73 | 74 | rays_embeddings = [] 75 | for i in range(cam_poses.shape[0]): 76 | rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] 77 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 78 | rays_embeddings.append(rays_plucker) 79 | 80 | ## visualize rays for plotting figure 81 | # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True) 82 | 83 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w] 84 | 85 | return rays_embeddings 86 | 87 | 88 | def forward_gaussians(self, images): 89 | # images: [B, 4, 9, H, W] 90 | # return: Gaussians: [B, dim_t] 91 | 92 | B, V, C, H, W = images.shape 93 | images = images.view(B*V, C, H, W) 94 | 95 | x = self.unet(images) # [B*4, 14, h, w] 96 | x = self.conv(x) # [B*4, 14, h, w] 97 | 98 | x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size) 99 | 100 | ## visualize multi-view gaussian features for plotting figure 101 | # tmp_alpha = self.opacity_act(x[0, :, 3:4]) 102 | # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha) 103 | # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5 104 | # kiui.vis.plot_image(tmp_img_rgb, save=True) 105 | # kiui.vis.plot_image(tmp_img_pos, save=True) 106 | 107 | x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14) 108 | 109 | pos = self.pos_act(x[..., 0:3]) # [B, N, 3] 110 | opacity = self.opacity_act(x[..., 3:4]) 111 | scale = self.scale_act(x[..., 4:7]) 112 | rotation = self.rot_act(x[..., 7:11]) 113 | rgbs = self.rgb_act(x[..., 11:]) 114 | 115 | gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14] 116 | 117 | return gaussians 118 | 119 | 120 | def forward(self, data, step_ratio=1): 121 | # data: output of the dataloader 122 | # return: loss 123 | 124 | results = {} 125 | loss = 0 126 | 127 | images = data['input'] # [B, 4, 9, h, W], input features 128 | 129 | # use the first view to predict gaussians 130 | gaussians = self.forward_gaussians(images) # [B, N, 14] 131 | 132 | results['gaussians'] = gaussians 133 | 134 | # always use white bg 135 | bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device) 136 | 137 | # use the other views for rendering and supervision 138 | results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], 
data['cam_pos'], bg_color=bg_color) 139 | pred_images = results['image'] # [B, V, C, output_size, output_size] 140 | pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size] 141 | 142 | results['images_pred'] = pred_images 143 | results['alphas_pred'] = pred_alphas 144 | 145 | gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views 146 | gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks 147 | 148 | gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) 149 | 150 | loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) 151 | loss = loss + loss_mse 152 | 153 | if self.opt.lambda_lpips > 0: 154 | loss_lpips = self.lpips_loss( 155 | # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, 156 | # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, 157 | # downsampled to at most 256 to reduce memory cost 158 | F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), 159 | F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), 160 | ).mean() 161 | results['loss_lpips'] = loss_lpips 162 | loss = loss + self.opt.lambda_lpips * loss_lpips 163 | 164 | results['loss'] = loss 165 | 166 | # metric 167 | with torch.no_grad(): 168 | psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) 169 | results['psnr'] = psnr 170 | 171 | return results 172 | -------------------------------------------------------------------------------- /core/options.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass 3 | from typing import Tuple, Literal, Dict, Optional 4 | 5 | 6 | @dataclass 7 | class Options: 8 | ### model 9 | # Unet image input size 10 | input_size: int = 256 11 | # Unet definition 12 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) 13 | down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) 14 | mid_attention: bool = True 15 | up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) 16 | up_attention: Tuple[bool, ...] = (True, True, True, False) 17 | # Unet output size, dependent on the input_size and U-Net structure! 
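# (with the default 256x256 input this is a 64x64 feature map, i.e. 64*64 gaussians per input view)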
18 | splat_size: int = 64 19 | # gaussian render size 20 | output_size: int = 256 21 | 22 | ### dataset 23 | # data mode (only support s3 now) 24 | data_mode: Literal['s3'] = 's3' 25 | # fovy of the dataset 26 | fovy: float = 49.1 27 | # camera near plane 28 | znear: float = 0.5 29 | # camera far plane 30 | zfar: float = 2.5 31 | # number of all views (input + output) 32 | num_views: int = 12 33 | # number of views 34 | num_input_views: int = 4 35 | # camera radius 36 | cam_radius: float = 1.5 # to better use [-1, 1]^3 space 37 | # num workers 38 | num_workers: int = 8 39 | 40 | ### training 41 | # workspace 42 | workspace: str = './workspace' 43 | # resume 44 | resume: Optional[str] = None 45 | # batch size (per-GPU) 46 | batch_size: int = 8 47 | # gradient accumulation 48 | gradient_accumulation_steps: int = 1 49 | # training epochs 50 | num_epochs: int = 30 51 | # lpips loss weight 52 | lambda_lpips: float = 1.0 53 | # gradient clip 54 | gradient_clip: float = 1.0 55 | # mixed precision 56 | mixed_precision: str = 'bf16' 57 | # learning rate 58 | lr: float = 4e-4 59 | # augmentation prob for grid distortion 60 | prob_grid_distortion: float = 0.5 61 | # augmentation prob for camera jitter 62 | prob_cam_jitter: float = 0.5 63 | 64 | ### testing 65 | # test image path 66 | test_path: Optional[str] = None 67 | 68 | ### misc 69 | # nvdiffrast backend setting 70 | force_cuda_rast: bool = False 71 | # render fancy video with gaussian scaling effect 72 | fancy_video: bool = False 73 | 74 | 75 | # all the default settings 76 | config_defaults: Dict[str, Options] = {} 77 | config_doc: Dict[str, str] = {} 78 | 79 | config_doc['lrm'] = 'the default settings for LGM' 80 | config_defaults['lrm'] = Options() 81 | 82 | config_doc['small'] = 'small model with lower resolution Gaussians' 83 | config_defaults['small'] = Options( 84 | input_size=256, 85 | splat_size=64, 86 | output_size=256, 87 | batch_size=8, 88 | gradient_accumulation_steps=1, 89 | mixed_precision='bf16', 90 | ) 91 | 92 | config_doc['big'] = 'big model with higher resolution Gaussians' 93 | config_defaults['big'] = Options( 94 | input_size=256, 95 | up_channels=(1024, 1024, 512, 256, 128), # one more decoder 96 | up_attention=(True, True, True, False, False), 97 | splat_size=128, 98 | output_size=512, # render & supervise Gaussians at a higher resolution. 
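# (the default 'lrm' / 'small' configs keep 64x64 splats rendered at 256)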
99 | batch_size=8, 100 | num_views=8, 101 | gradient_accumulation_steps=1, 102 | mixed_precision='bf16', 103 | ) 104 | 105 | config_doc['tiny'] = 'tiny model for ablation' 106 | config_defaults['tiny'] = Options( 107 | input_size=256, 108 | down_channels=(32, 64, 128, 256, 512), 109 | down_attention=(False, False, False, False, True), 110 | up_channels=(512, 256, 128), 111 | up_attention=(True, False, False, False), 112 | splat_size=64, 113 | output_size=256, 114 | batch_size=16, 115 | num_views=8, 116 | gradient_accumulation_steps=1, 117 | mixed_precision='bf16', 118 | ) 119 | 120 | AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) 121 | -------------------------------------------------------------------------------- /core/provider_objaverse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.transforms.functional as TF 10 | from torch.utils.data import Dataset 11 | 12 | import kiui 13 | from core.options import Options 14 | from core.utils import get_rays, grid_distortion, orbit_camera_jitter 15 | 16 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 17 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 18 | 19 | 20 | class ObjaverseDataset(Dataset): 21 | 22 | def _warn(self): 23 | raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)') 24 | 25 | def __init__(self, opt: Options, training=True): 26 | 27 | self.opt = opt 28 | self.training = training 29 | 30 | # TODO: remove this barrier 31 | self._warn() 32 | 33 | # TODO: load the list of objects for training 34 | self.items = [] 35 | with open('TODO: file containing the list', 'r') as f: 36 | for line in f.readlines(): 37 | self.items.append(line.strip()) 38 | 39 | # naive split 40 | if self.training: 41 | self.items = self.items[:-self.opt.batch_size] 42 | else: 43 | self.items = self.items[-self.opt.batch_size:] 44 | 45 | # default camera intrinsics 46 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 47 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 48 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 49 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 50 | self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) 51 | self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) 52 | self.proj_matrix[2, 3] = 1 53 | 54 | 55 | def __len__(self): 56 | return len(self.items) 57 | 58 | def __getitem__(self, idx): 59 | 60 | uid = self.items[idx] 61 | results = {} 62 | 63 | # load num_views images 64 | images = [] 65 | masks = [] 66 | cam_poses = [] 67 | 68 | vid_cnt = 0 69 | 70 | # TODO: choose views, based on your rendering settings 71 | if self.training: 72 | # input views are in (36, 72), other views are randomly selected 73 | vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist() 74 | else: 75 | # fixed views 76 | vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist() 77 | 78 | for vid in vids: 79 | 80 | image_path = os.path.join(uid, 'rgb', f'{vid:03d}.png') 81 | camera_path = os.path.join(uid, 'pose', f'{vid:03d}.txt') 82 | 83 | try: 84 | # TODO: load data (modify self.client here) 85 | image = np.frombuffer(self.client.get(image_path), 
np.uint8) 86 | image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] 87 | c2w = [float(t) for t in self.client.get(camera_path).decode().strip().split(' ')] 88 | c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4) 89 | except Exception as e: 90 | # print(f'[WARN] dataset {uid} {vid}: {e}') 91 | continue 92 | 93 | # TODO: you may have a different camera system 94 | # blender world + opencv cam --> opengl world & cam 95 | c2w[1] *= -1 96 | c2w[[1, 2]] = c2w[[2, 1]] 97 | c2w[:3, 1:3] *= -1 # invert up and forward direction 98 | 99 | # scale up radius to fully use the [-1, 1]^3 space! 100 | c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default scale 101 | 102 | image = image.permute(2, 0, 1) # [4, 512, 512] 103 | mask = image[3:4] # [1, 512, 512] 104 | image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg 105 | image = image[[2,1,0]].contiguous() # bgr to rgb 106 | 107 | images.append(image) 108 | masks.append(mask.squeeze(0)) 109 | cam_poses.append(c2w) 110 | 111 | vid_cnt += 1 112 | if vid_cnt == self.opt.num_views: 113 | break 114 | 115 | if vid_cnt < self.opt.num_views: 116 | print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!') 117 | n = self.opt.num_views - vid_cnt 118 | images = images + [images[-1]] * n 119 | masks = masks + [masks[-1]] * n 120 | cam_poses = cam_poses + [cam_poses[-1]] * n 121 | 122 | images = torch.stack(images, dim=0) # [V, C, H, W] 123 | masks = torch.stack(masks, dim=0) # [V, H, W] 124 | cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4] 125 | 126 | # normalized camera feats as in paper (transform the first pose to a fixed position) 127 | transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0]) 128 | cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4] 129 | 130 | images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W] 131 | cam_poses_input = cam_poses[:self.opt.num_input_views].clone() 132 | 133 | # data augmentation 134 | if self.training: 135 | # apply random grid distortion to simulate 3D inconsistency 136 | if random.random() < self.opt.prob_grid_distortion: 137 | images_input[1:] = grid_distortion(images_input[1:]) 138 | # apply camera jittering (only to input!) 
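# orbit_camera_jitter (core/utils.py) rotates each pose by a random angle of up to
# +/- strength*pi around its up axis and +/- strength*pi/2 around its right axis
# (roughly +/-18 and +/-9 degrees at the default strength=0.1); the first input
# view is left untouched so the normalized reference frame above stays fixed.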
139 | if random.random() < self.opt.prob_cam_jitter: 140 | cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:]) 141 | 142 | images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 143 | 144 | # resize render ground-truth images, range still in [0, 1] 145 | results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size] 146 | results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size] 147 | 148 | # build rays for input views 149 | rays_embeddings = [] 150 | for i in range(self.opt.num_input_views): 151 | rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] 152 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 153 | rays_embeddings.append(rays_plucker) 154 | 155 | 156 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w] 157 | final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W] 158 | results['input'] = final_input 159 | 160 | # opengl to colmap camera for gaussian renderer 161 | cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction 162 | 163 | # cameras needed by gaussian rasterizer 164 | cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] 165 | cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] 166 | cam_pos = - cam_poses[:, :3, 3] # [V, 3] 167 | 168 | results['cam_view'] = cam_view 169 | results['cam_view_proj'] = cam_view_proj 170 | results['cam_pos'] = cam_pos 171 | 172 | return results -------------------------------------------------------------------------------- /core/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from typing import Tuple, Literal 7 | from functools import partial 8 | 9 | from core.attention import MemEffAttention 10 | 11 | class MVAttention(nn.Module): 12 | def __init__( 13 | self, 14 | dim: int, 15 | num_heads: int = 8, 16 | qkv_bias: bool = False, 17 | proj_bias: bool = True, 18 | attn_drop: float = 0.0, 19 | proj_drop: float = 0.0, 20 | groups: int = 32, 21 | eps: float = 1e-5, 22 | residual: bool = True, 23 | skip_scale: float = 1, 24 | num_frames: int = 4, # WARN: hardcoded! 
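# num_frames must equal the number of views V packed into the batch dimension:
# forward() below reshapes [B*V, C, H, W] into [B, V*H*W, C] so tokens from all
# V views attend to each other in a single attention call.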
25 | ): 26 | super().__init__() 27 | 28 | self.residual = residual 29 | self.skip_scale = skip_scale 30 | self.num_frames = num_frames 31 | 32 | self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) 33 | self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) 34 | 35 | def forward(self, x): 36 | # x: [B*V, C, H, W] 37 | BV, C, H, W = x.shape 38 | B = BV // self.num_frames # assert BV % self.num_frames == 0 39 | 40 | res = x 41 | x = self.norm(x) 42 | 43 | x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C) 44 | x = self.attn(x) 45 | x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W) 46 | 47 | if self.residual: 48 | x = (x + res) * self.skip_scale 49 | return x 50 | 51 | class ResnetBlock(nn.Module): 52 | def __init__( 53 | self, 54 | in_channels: int, 55 | out_channels: int, 56 | resample: Literal['default', 'up', 'down'] = 'default', 57 | groups: int = 32, 58 | eps: float = 1e-5, 59 | skip_scale: float = 1, # multiplied to output 60 | ): 61 | super().__init__() 62 | 63 | self.in_channels = in_channels 64 | self.out_channels = out_channels 65 | self.skip_scale = skip_scale 66 | 67 | self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 68 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 69 | 70 | self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) 71 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 72 | 73 | self.act = F.silu 74 | 75 | self.resample = None 76 | if resample == 'up': 77 | self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest") 78 | elif resample == 'down': 79 | self.resample = nn.AvgPool2d(kernel_size=2, stride=2) 80 | 81 | self.shortcut = nn.Identity() 82 | if self.in_channels != self.out_channels: 83 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True) 84 | 85 | 86 | def forward(self, x): 87 | res = x 88 | 89 | x = self.norm1(x) 90 | x = self.act(x) 91 | 92 | if self.resample: 93 | res = self.resample(res) 94 | x = self.resample(x) 95 | 96 | x = self.conv1(x) 97 | x = self.norm2(x) 98 | x = self.act(x) 99 | x = self.conv2(x) 100 | 101 | x = (x + self.shortcut(res)) * self.skip_scale 102 | 103 | return x 104 | 105 | class DownBlock(nn.Module): 106 | def __init__( 107 | self, 108 | in_channels: int, 109 | out_channels: int, 110 | num_layers: int = 1, 111 | downsample: bool = True, 112 | attention: bool = True, 113 | attention_heads: int = 16, 114 | skip_scale: float = 1, 115 | ): 116 | super().__init__() 117 | 118 | nets = [] 119 | attns = [] 120 | for i in range(num_layers): 121 | in_channels = in_channels if i == 0 else out_channels 122 | nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale)) 123 | if attention: 124 | attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) 125 | else: 126 | attns.append(None) 127 | self.nets = nn.ModuleList(nets) 128 | self.attns = nn.ModuleList(attns) 129 | 130 | self.downsample = None 131 | if downsample: 132 | self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) 133 | 134 | def forward(self, x): 135 | xs = [] 136 | 137 | for attn, net in zip(self.attns, self.nets): 138 | x = net(x) 139 | if attn: 140 | x = attn(x) 141 | xs.append(x) 142 | 143 | if self.downsample: 144 | x = self.downsample(x) 145 | xs.append(x) 146 | 147 | 
return x, xs 148 | 149 | 150 | class MidBlock(nn.Module): 151 | def __init__( 152 | self, 153 | in_channels: int, 154 | num_layers: int = 1, 155 | attention: bool = True, 156 | attention_heads: int = 16, 157 | skip_scale: float = 1, 158 | ): 159 | super().__init__() 160 | 161 | nets = [] 162 | attns = [] 163 | # first layer 164 | nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) 165 | # more layers 166 | for i in range(num_layers): 167 | nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) 168 | if attention: 169 | attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale)) 170 | else: 171 | attns.append(None) 172 | self.nets = nn.ModuleList(nets) 173 | self.attns = nn.ModuleList(attns) 174 | 175 | def forward(self, x): 176 | x = self.nets[0](x) 177 | for attn, net in zip(self.attns, self.nets[1:]): 178 | if attn: 179 | x = attn(x) 180 | x = net(x) 181 | return x 182 | 183 | 184 | class UpBlock(nn.Module): 185 | def __init__( 186 | self, 187 | in_channels: int, 188 | prev_out_channels: int, 189 | out_channels: int, 190 | num_layers: int = 1, 191 | upsample: bool = True, 192 | attention: bool = True, 193 | attention_heads: int = 16, 194 | skip_scale: float = 1, 195 | ): 196 | super().__init__() 197 | 198 | nets = [] 199 | attns = [] 200 | for i in range(num_layers): 201 | cin = in_channels if i == 0 else out_channels 202 | cskip = prev_out_channels if (i == num_layers - 1) else out_channels 203 | 204 | nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale)) 205 | if attention: 206 | attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) 207 | else: 208 | attns.append(None) 209 | self.nets = nn.ModuleList(nets) 210 | self.attns = nn.ModuleList(attns) 211 | 212 | self.upsample = None 213 | if upsample: 214 | self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 215 | 216 | def forward(self, x, xs): 217 | 218 | for attn, net in zip(self.attns, self.nets): 219 | res_x = xs[-1] 220 | xs = xs[:-1] 221 | x = torch.cat([x, res_x], dim=1) 222 | x = net(x) 223 | if attn: 224 | x = attn(x) 225 | 226 | if self.upsample: 227 | x = F.interpolate(x, scale_factor=2.0, mode='nearest') 228 | x = self.upsample(x) 229 | 230 | return x 231 | 232 | 233 | # it could be asymmetric! 234 | class UNet(nn.Module): 235 | def __init__( 236 | self, 237 | in_channels: int = 3, 238 | out_channels: int = 3, 239 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024), 240 | down_attention: Tuple[bool, ...] = (False, False, False, True, True), 241 | mid_attention: bool = True, 242 | up_channels: Tuple[int, ...] = (1024, 512, 256), 243 | up_attention: Tuple[bool, ...] 
= (True, True, False), 244 | layers_per_block: int = 2, 245 | skip_scale: float = np.sqrt(0.5), 246 | ): 247 | super().__init__() 248 | 249 | # first 250 | self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1) 251 | 252 | # down 253 | down_blocks = [] 254 | cout = down_channels[0] 255 | for i in range(len(down_channels)): 256 | cin = cout 257 | cout = down_channels[i] 258 | 259 | down_blocks.append(DownBlock( 260 | cin, cout, 261 | num_layers=layers_per_block, 262 | downsample=(i != len(down_channels) - 1), # not final layer 263 | attention=down_attention[i], 264 | skip_scale=skip_scale, 265 | )) 266 | self.down_blocks = nn.ModuleList(down_blocks) 267 | 268 | # mid 269 | self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale) 270 | 271 | # up 272 | up_blocks = [] 273 | cout = up_channels[0] 274 | for i in range(len(up_channels)): 275 | cin = cout 276 | cout = up_channels[i] 277 | cskip = down_channels[max(-2 - i, -len(down_channels))] # for assymetric 278 | 279 | up_blocks.append(UpBlock( 280 | cin, cskip, cout, 281 | num_layers=layers_per_block + 1, # one more layer for up 282 | upsample=(i != len(up_channels) - 1), # not final layer 283 | attention=up_attention[i], 284 | skip_scale=skip_scale, 285 | )) 286 | self.up_blocks = nn.ModuleList(up_blocks) 287 | 288 | # last 289 | self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5) 290 | self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) 291 | 292 | 293 | def forward(self, x): 294 | # x: [B, Cin, H, W] 295 | 296 | # first 297 | x = self.conv_in(x) 298 | 299 | # down 300 | xss = [x] 301 | for block in self.down_blocks: 302 | x, xs = block(x) 303 | xss.extend(xs) 304 | 305 | # mid 306 | x = self.mid_block(x) 307 | 308 | # up 309 | for block in self.up_blocks: 310 | xs = xss[-len(block.nets):] 311 | xss = xss[:-len(block.nets)] 312 | x = block(x, xs) 313 | 314 | # last 315 | x = self.norm_out(x) 316 | x = F.silu(x) 317 | x = self.conv_out(x) # [B, Cout, H', W'] 318 | 319 | return x 320 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import roma 8 | from kiui.op import safe_normalize 9 | 10 | def get_rays(pose, h, w, fovy, opengl=True): 11 | 12 | x, y = torch.meshgrid( 13 | torch.arange(w, device=pose.device), 14 | torch.arange(h, device=pose.device), 15 | indexing="xy", 16 | ) 17 | x = x.flatten() 18 | y = y.flatten() 19 | 20 | cx = w * 0.5 21 | cy = h * 0.5 22 | 23 | focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) 24 | 25 | camera_dirs = F.pad( 26 | torch.stack( 27 | [ 28 | (x - cx + 0.5) / focal, 29 | (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), 30 | ], 31 | dim=-1, 32 | ), 33 | (0, 1), 34 | value=(-1.0 if opengl else 1.0), 35 | ) # [hw, 3] 36 | 37 | rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] 38 | rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] 39 | 40 | rays_o = rays_o.view(h, w, 3) 41 | rays_d = safe_normalize(rays_d).view(h, w, 3) 42 | 43 | return rays_o, rays_d 44 | 45 | def orbit_camera_jitter(poses, strength=0.1): 46 | # poses: [B, 4, 4], assume orbit camera in opengl format 47 | # random orbital rotate 48 | 49 | B = poses.shape[0] 50 | rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, 
device=poses.device) * 2 - 1) 51 | rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) 52 | 53 | rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) 54 | R = rot @ poses[:, :3, :3] 55 | T = rot @ poses[:, :3, 3:] 56 | 57 | new_poses = poses.clone() 58 | new_poses[:, :3, :3] = R 59 | new_poses[:, :3, 3:] = T 60 | 61 | return new_poses 62 | 63 | def grid_distortion(images, strength=0.5): 64 | # images: [B, C, H, W] 65 | # num_steps: int, grid resolution for distortion 66 | # strength: float in [0, 1], strength of distortion 67 | 68 | B, C, H, W = images.shape 69 | 70 | num_steps = np.random.randint(8, 17) 71 | grid_steps = torch.linspace(-1, 1, num_steps) 72 | 73 | # have to loop batch... 74 | grids = [] 75 | for b in range(B): 76 | # construct displacement 77 | x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 78 | x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 79 | x_steps = (x_steps * W).long() # [num_steps] 80 | x_steps[0] = 0 81 | x_steps[-1] = W 82 | xs = [] 83 | for i in range(num_steps - 1): 84 | xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) 85 | xs = torch.cat(xs, dim=0) # [W] 86 | 87 | y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 88 | y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 89 | y_steps = (y_steps * H).long() # [num_steps] 90 | y_steps[0] = 0 91 | y_steps[-1] = H 92 | ys = [] 93 | for i in range(num_steps - 1): 94 | ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) 95 | ys = torch.cat(ys, dim=0) # [H] 96 | 97 | # construct grid 98 | grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] 99 | grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] 100 | 101 | grids.append(grid) 102 | 103 | grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] 104 | 105 | # grid sample 106 | images = F.grid_sample(images, grids, align_corners=False) 107 | 108 | return images 109 | 110 | -------------------------------------------------------------------------------- /grid_put.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def stride_from_shape(shape): 5 | stride = [1] 6 | for x in reversed(shape[1:]): 7 | stride.append(stride[-1] * x) 8 | return list(reversed(stride)) 9 | 10 | 11 | def scatter_add_nd(input, indices, values): 12 | # input: [..., C], D dimension + C channel 13 | # indices: [N, D], long 14 | # values: [N, C] 15 | 16 | D = indices.shape[-1] 17 | C = input.shape[-1] 18 | size = input.shape[:-1] 19 | stride = stride_from_shape(size) 20 | 21 | assert len(size) == D 22 | 23 | input = input.view(-1, C) # [HW, C] 24 | flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N] 25 | 26 | input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values) 27 | 28 | return input.view(*size, C) 29 | 30 | 31 | def scatter_add_nd_with_count(input, count, indices, values, weights=None): 32 | # input: [..., C], D dimension + C channel 33 | # count: [..., 1], D dimension 34 | # indices: [N, D], long 35 | # values: [N, C] 36 | 37 | D = indices.shape[-1] 38 | C = input.shape[-1] 39 | size = input.shape[:-1] 40 | stride = stride_from_shape(size) 41 | 42 | assert len(size) == D 43 | 44 | input = input.view(-1, C) # [HW, C] 45 | 
count = count.view(-1, 1) 46 | 47 | flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N] 48 | 49 | if weights is None: 50 | weights = torch.ones_like(values[..., :1]) 51 | 52 | input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values) 53 | count.scatter_add_(0, flatten_indices.unsqueeze(1), weights) 54 | 55 | return input.view(*size, C), count.view(*size, 1) 56 | 57 | def nearest_grid_put_2d(H, W, coords, values, return_count=False): 58 | # coords: [N, 2], float in [-1, 1] 59 | # values: [N, C] 60 | 61 | C = values.shape[-1] 62 | 63 | indices = (coords * 0.5 + 0.5) * torch.tensor( 64 | [H - 1, W - 1], dtype=torch.float32, device=coords.device 65 | ) 66 | indices = indices.round().long() # [N, 2] 67 | 68 | result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C] 69 | count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1] 70 | weights = torch.ones_like(values[..., :1]) # [N, 1] 71 | 72 | result, count = scatter_add_nd_with_count(result, count, indices, values, weights) 73 | 74 | if return_count: 75 | return result, count 76 | 77 | mask = (count.squeeze(-1) > 0) 78 | result[mask] = result[mask] / count[mask].repeat(1, C) 79 | 80 | return result 81 | 82 | 83 | def linear_grid_put_2d(H, W, coords, values, return_count=False): 84 | # coords: [N, 2], float in [-1, 1] 85 | # values: [N, C] 86 | 87 | C = values.shape[-1] 88 | 89 | indices = (coords * 0.5 + 0.5) * torch.tensor( 90 | [H - 1, W - 1], dtype=torch.float32, device=coords.device 91 | ) 92 | indices_00 = indices.floor().long() # [N, 2] 93 | indices_00[:, 0].clamp_(0, H - 2) 94 | indices_00[:, 1].clamp_(0, W - 2) 95 | indices_01 = indices_00 + torch.tensor( 96 | [0, 1], dtype=torch.long, device=indices.device 97 | ) 98 | indices_10 = indices_00 + torch.tensor( 99 | [1, 0], dtype=torch.long, device=indices.device 100 | ) 101 | indices_11 = indices_00 + torch.tensor( 102 | [1, 1], dtype=torch.long, device=indices.device 103 | ) 104 | 105 | h = indices[..., 0] - indices_00[..., 0].float() 106 | w = indices[..., 1] - indices_00[..., 1].float() 107 | w_00 = (1 - h) * (1 - w) 108 | w_01 = (1 - h) * w 109 | w_10 = h * (1 - w) 110 | w_11 = h * w 111 | 112 | result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C] 113 | count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1] 114 | weights = torch.ones_like(values[..., :1]) # [N, 1] 115 | 116 | result, count = scatter_add_nd_with_count(result, count, indices_00, values * w_00.unsqueeze(1), weights* w_00.unsqueeze(1)) 117 | result, count = scatter_add_nd_with_count(result, count, indices_01, values * w_01.unsqueeze(1), weights* w_01.unsqueeze(1)) 118 | result, count = scatter_add_nd_with_count(result, count, indices_10, values * w_10.unsqueeze(1), weights* w_10.unsqueeze(1)) 119 | result, count = scatter_add_nd_with_count(result, count, indices_11, values * w_11.unsqueeze(1), weights* w_11.unsqueeze(1)) 120 | 121 | if return_count: 122 | return result, count 123 | 124 | mask = (count.squeeze(-1) > 0) 125 | result[mask] = result[mask] / count[mask].repeat(1, C) 126 | 127 | return result 128 | 129 | def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False): 130 | # coords: [N, 2], float in [-1, 1] 131 | # values: [N, C] 132 | 133 | C = values.shape[-1] 134 | 135 | result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C] 136 | count = torch.zeros(H, W, 1, 
device=values.device, dtype=values.dtype) # [H, W, 1] 137 | 138 | cur_H, cur_W = H, W 139 | 140 | while min(cur_H, cur_W) > min_resolution: 141 | 142 | # try to fill the holes 143 | mask = (count.squeeze(-1) == 0) 144 | if not mask.any(): 145 | break 146 | 147 | cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True) 148 | result[mask] = result[mask] + F.interpolate(cur_result.permute(2,0,1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False).squeeze(0).permute(1,2,0).contiguous()[mask] 149 | count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False).view(H, W, 1)[mask] 150 | cur_H //= 2 151 | cur_W //= 2 152 | 153 | if return_count: 154 | return result, count 155 | 156 | mask = (count.squeeze(-1) > 0) 157 | result[mask] = result[mask] / count[mask].repeat(1, C) 158 | 159 | return result 160 | 161 | def nearest_grid_put_3d(H, W, D, coords, values, return_count=False): 162 | # coords: [N, 3], float in [-1, 1] 163 | # values: [N, C] 164 | 165 | C = values.shape[-1] 166 | 167 | indices = (coords * 0.5 + 0.5) * torch.tensor( 168 | [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device 169 | ) 170 | indices = indices.round().long() # [N, 2] 171 | 172 | result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, C] 173 | count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, 1] 174 | weights = torch.ones_like(values[..., :1]) # [N, 1] 175 | 176 | result, count = scatter_add_nd_with_count(result, count, indices, values, weights) 177 | 178 | if return_count: 179 | return result, count 180 | 181 | mask = (count.squeeze(-1) > 0) 182 | result[mask] = result[mask] / count[mask].repeat(1, C) 183 | 184 | return result 185 | 186 | 187 | def linear_grid_put_3d(H, W, D, coords, values, return_count=False): 188 | # coords: [N, 3], float in [-1, 1] 189 | # values: [N, C] 190 | 191 | C = values.shape[-1] 192 | 193 | indices = (coords * 0.5 + 0.5) * torch.tensor( 194 | [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device 195 | ) 196 | indices_000 = indices.floor().long() # [N, 3] 197 | indices_000[:, 0].clamp_(0, H - 2) 198 | indices_000[:, 1].clamp_(0, W - 2) 199 | indices_000[:, 2].clamp_(0, D - 2) 200 | 201 | indices_001 = indices_000 + torch.tensor([0, 0, 1], dtype=torch.long, device=indices.device) 202 | indices_010 = indices_000 + torch.tensor([0, 1, 0], dtype=torch.long, device=indices.device) 203 | indices_011 = indices_000 + torch.tensor([0, 1, 1], dtype=torch.long, device=indices.device) 204 | indices_100 = indices_000 + torch.tensor([1, 0, 0], dtype=torch.long, device=indices.device) 205 | indices_101 = indices_000 + torch.tensor([1, 0, 1], dtype=torch.long, device=indices.device) 206 | indices_110 = indices_000 + torch.tensor([1, 1, 0], dtype=torch.long, device=indices.device) 207 | indices_111 = indices_000 + torch.tensor([1, 1, 1], dtype=torch.long, device=indices.device) 208 | 209 | h = indices[..., 0] - indices_000[..., 0].float() 210 | w = indices[..., 1] - indices_000[..., 1].float() 211 | d = indices[..., 2] - indices_000[..., 2].float() 212 | 213 | w_000 = (1 - h) * (1 - w) * (1 - d) 214 | w_001 = (1 - h) * w * (1 - d) 215 | w_010 = h * (1 - w) * (1 - d) 216 | w_011 = h * w * (1 - d) 217 | w_100 = (1 - h) * (1 - w) * d 218 | w_101 = (1 - h) * w * d 219 | w_110 = h * (1 - w) * d 220 | w_111 = h * w * d 221 | 222 | result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, 
C] 223 | count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1] 224 | weights = torch.ones_like(values[..., :1]) # [N, 1] 225 | 226 | result, count = scatter_add_nd_with_count(result, count, indices_000, values * w_000.unsqueeze(1), weights * w_000.unsqueeze(1)) 227 | result, count = scatter_add_nd_with_count(result, count, indices_001, values * w_001.unsqueeze(1), weights * w_001.unsqueeze(1)) 228 | result, count = scatter_add_nd_with_count(result, count, indices_010, values * w_010.unsqueeze(1), weights * w_010.unsqueeze(1)) 229 | result, count = scatter_add_nd_with_count(result, count, indices_011, values * w_011.unsqueeze(1), weights * w_011.unsqueeze(1)) 230 | result, count = scatter_add_nd_with_count(result, count, indices_100, values * w_100.unsqueeze(1), weights * w_100.unsqueeze(1)) 231 | result, count = scatter_add_nd_with_count(result, count, indices_101, values * w_101.unsqueeze(1), weights * w_101.unsqueeze(1)) 232 | result, count = scatter_add_nd_with_count(result, count, indices_110, values * w_110.unsqueeze(1), weights * w_110.unsqueeze(1)) 233 | result, count = scatter_add_nd_with_count(result, count, indices_111, values * w_111.unsqueeze(1), weights * w_111.unsqueeze(1)) 234 | 235 | if return_count: 236 | return result, count 237 | 238 | mask = (count.squeeze(-1) > 0) 239 | result[mask] = result[mask] / count[mask].repeat(1, C) 240 | 241 | return result 242 | 243 | def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False): 244 | # coords: [N, 3], float in [-1, 1] 245 | # values: [N, C] 246 | 247 | C = values.shape[-1] 248 | 249 | result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C] 250 | count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1] 251 | cur_H, cur_W, cur_D = H, W, D 252 | 253 | while min(min(cur_H, cur_W), cur_D) > min_resolution: 254 | 255 | # try to fill the holes 256 | mask = (count.squeeze(-1) == 0) 257 | if not mask.any(): 258 | break 259 | 260 | cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True) 261 | result[mask] = result[mask] + F.interpolate(cur_result.permute(3,0,1,2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False).squeeze(0).permute(1,2,3,0).contiguous()[mask] 262 | count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False).view(H, W, D, 1)[mask] 263 | cur_H //= 2 264 | cur_W //= 2 265 | cur_D //= 2 266 | 267 | if return_count: 268 | return result, count 269 | 270 | mask = (count.squeeze(-1) > 0) 271 | result[mask] = result[mask] / count[mask].repeat(1, C) 272 | 273 | return result 274 | 275 | 276 | def grid_put(shape, coords, values, mode='linear-mipmap', min_resolution=32, return_raw=False): 277 | # shape: [D], list/tuple 278 | # coords: [N, D], float in [-1, 1] 279 | # values: [N, C] 280 | 281 | D = len(shape) 282 | assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}' 283 | 284 | if mode == 'nearest': 285 | if D == 2: 286 | return nearest_grid_put_2d(*shape, coords, values, return_raw) 287 | else: 288 | return nearest_grid_put_3d(*shape, coords, values, return_raw) 289 | elif mode == 'linear': 290 | if D == 2: 291 | return linear_grid_put_2d(*shape, coords, values, return_raw) 292 | else: 293 | return linear_grid_put_3d(*shape, coords, values, return_raw) 294 | elif mode == 'linear-mipmap': 295 | if D == 2: 296 | return 
mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_raw) 297 | else: 298 | return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_raw) 299 | else: 300 | raise NotImplementedError(f"got mode {mode}") -------------------------------------------------------------------------------- /gs_postprocess.py: -------------------------------------------------------------------------------- 1 | from gs_renderer import GaussianModel, Renderer, MiniCam 2 | from cam_utils import orbit_camera, OrbitCamera 3 | from torchvision.utils import save_image 4 | import numpy as np 5 | import torch 6 | import kiui 7 | from sh_utils import SH2RGB 8 | import os 9 | from grid_put import mipmap_linear_grid_put_2d 10 | from mesh import Mesh, safe_normalize 11 | import torch.nn.functional as F 12 | from copy import deepcopy 13 | from glob import glob 14 | 15 | 16 | # Very small Gaussians are not culled; very large Gaussians may be culled, but if a Gaussian's projected area is large while there is no alpha around it, it is not culled 17 | # Check whether the region a point influences actually contributes to the alpha; if it does not, the point is not culled 18 | 19 | def save_model(renderer, path): 20 | mesh = renderer.gaussians.extract_mesh(path, 0.2) 21 | 22 | # perform texture extraction 23 | print(f"[INFO] unwrap uv...") 24 | h = w = 512 25 | mesh.auto_uv() 26 | mesh.auto_normal() 27 | 28 | albedo = torch.zeros((h, w, 3), device="cuda", dtype=torch.float32) 29 | cnt = torch.zeros((h, w, 1), device="cuda", dtype=torch.float32) 30 | 31 | # self.prepare_train() # tmp fix for not loading 0123 32 | # vers = [0] 33 | # hors = [0] 34 | vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9] 35 | hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0] 36 | 37 | render_resolution = 512 38 | 39 | import nvdiffrast.torch as dr 40 | 41 | glctx = dr.RasterizeCudaContext() 42 | 43 | for ver, hor in zip(vers, hors): 44 | # render image 45 | pose = orbit_camera(ver, hor, 2) 46 | 47 | cur_cam = MiniCam( 48 | pose, 49 | render_resolution, 50 | render_resolution, 51 | np.deg2rad(49.1), 52 | np.deg2rad(49.1), 53 | 0.01, 54 | 100, 55 | ) 56 | cam = OrbitCamera(512, 512, r=2, fovy=49.1) 57 | 58 | cur_out = renderer.render(cur_cam) 59 | 60 | rgbs = cur_out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1] 61 | 62 | 63 | # get coordinate in texture image 64 | pose = torch.from_numpy(pose.astype(np.float32)).to("cuda") 65 | proj = torch.from_numpy(cam.perspective.astype(np.float32)).to("cuda") 66 | 67 | v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) 68 | v_clip = v_cam @ proj.T 69 | rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution)) 70 | 71 | depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f) # [1, H, W, 1] 72 | depth = depth.squeeze(0) # [H, W, 1] 73 | 74 | alpha = (rast[0, ..., 3:] > 0).float() 75 | 76 | uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft) # [1, 512, 512, 2] in [0, 1] 77 | 78 | # use normal to produce a back-project mask 79 | normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn) 80 | normal = safe_normalize(normal[0]) 81 | 82 | # rotated normal (where [0, 0, 1] always faces camera) 83 | rot_normal = normal @ pose[:3, :3] 84 | viewcos = rot_normal[..., [2]] 85 | 86 | mask = (alpha > 0) & (viewcos > 0.5) # [H, W, 1] 87 | mask = mask.view(-1) 88 | 89 | uvs = uvs.view(-1, 2).clamp(0, 1)[mask] 90 | rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous() 91 | 92 | # update texture image 93 | cur_albedo, cur_cnt = mipmap_linear_grid_put_2d( 94 | h, w, 95 | uvs[..., [1, 0]] * 2 - 1, 96 | rgbs, 97 | min_resolution=256, 98 | return_count=True, 99 | )
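# cur_albedo / cur_cnt hold the colors splatted into UV space from this view only;
# below, a texel is updated only while its accumulated count is still (nearly) zero,
# so views earlier in vers/hors take precedence and later views only fill holes.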
100 | 101 | # albedo += cur_albedo 102 | # cnt += cur_cnt 103 | mask = cnt.squeeze(-1) < 0.1 104 | albedo[mask] += cur_albedo[mask] 105 | cnt[mask] += cur_cnt[mask] 106 | 107 | mask = cnt.squeeze(-1) > 0 108 | albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3) 109 | 110 | mask = mask.view(h, w) 111 | 112 | albedo = albedo.detach().cpu().numpy() 113 | mask = mask.detach().cpu().numpy() 114 | 115 | # dilate texture 116 | from sklearn.neighbors import NearestNeighbors 117 | from scipy.ndimage import binary_dilation, binary_erosion 118 | 119 | inpaint_region = binary_dilation(mask, iterations=32) 120 | inpaint_region[mask] = 0 121 | 122 | search_region = mask.copy() 123 | not_search_region = binary_erosion(search_region, iterations=3) 124 | search_region[not_search_region] = 0 125 | 126 | search_coords = np.stack(np.nonzero(search_region), axis=-1) 127 | inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1) 128 | 129 | knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit( 130 | search_coords 131 | ) 132 | _, indices = knn.kneighbors(inpaint_coords) 133 | 134 | albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)] 135 | 136 | mesh.albedo = torch.from_numpy(albedo).to("cuda") 137 | mesh.write(path) 138 | 139 | print(f"[INFO] save model to {path}.") 140 | 141 | def render_save(renderer, pose, path=None): 142 | cam = MiniCam(pose, 512, 512, np.deg2rad(49), np.deg2rad(49), 0.01, 100) 143 | bg = torch.tensor([1,0,0],dtype=torch.float,device="cuda") 144 | out = renderer.render(cam,bg_color=bg) 145 | if path is not None: 146 | save_image(out["image"], path) 147 | return out["image"], out["alpha"] 148 | 149 | 150 | # def alpha_save(path): 151 | # bg = torch.tensor([1,0,0],dtype=torch.float,device="cuda") 152 | # out = renderer.render(cam,bg_color=bg) 153 | # save_image(out["alpha"], path) 154 | 155 | def drawPoint(img, point_img, path=None): 156 | color = torch.zeros((3, point_img.shape[0]), device = img.device) 157 | color[1, :] = 1 158 | img[:, point_img[:,1], point_img[:,0]] = color 159 | if path is not None: 160 | save_image(img, path) 161 | return img 162 | 163 | 164 | def filter_out_once(renderer, pose): 165 | cam = MiniCam(pose, 512, 512, np.deg2rad(49), np.deg2rad(49), 0.01, 100) 166 | 167 | renderer_backup = deepcopy(renderer) 168 | 169 | scale = renderer_backup.gaussians.get_scaling 170 | xyz = renderer_backup.gaussians.get_xyz 171 | color = SH2RGB(renderer_backup.gaussians.get_features) 172 | opa = renderer_backup.gaussians.get_opacity 173 | 174 | prune_mask2 = torch.any(scale > 0.01, dim=1) 175 | 176 | prune_mask = torch.all(color > 1.0, dim=2)[:,0] 177 | prune_mask = torch.logical_and(prune_mask, prune_mask2) 178 | 179 | 180 | prune_xyz = xyz[prune_mask] 181 | homo = torch.ones((prune_xyz.shape[0], 1), device=prune_xyz.device) 182 | prune_xyz = torch.cat((prune_xyz, homo), dim=1) 183 | 184 | proj_mat = cam.full_proj_transform 185 | p_proj = prune_xyz @ proj_mat 186 | point_img = p_proj[:,:2] / p_proj[:,[2]] 187 | point_img = ((point_img + 1.0) * 512 / 2 - 0.5).round().int() 188 | 189 | img, alpha_before = render_save(renderer_backup, pose) 190 | #drawPoint(img, point_img, "1.png") 191 | prune_mask_backup = prune_mask.clone() 192 | renderer_backup.gaussians.prune_points_test(prune_mask_backup) 193 | 194 | _, alpha_after = render_save(renderer_backup, pose) 195 | 196 | alpha_delta = alpha_before - alpha_after 197 | alpha_delta = torch.where(alpha_delta > 0.1, 1.0, 0.0) 198 | 199 | torch.clamp_(point_img[:, 1], 0, alpha_delta.shape[1] - 1) 200 | 
torch.clamp_(point_img[:, 0], 0, alpha_delta.shape[2] - 1) 201 | 202 | prune_mask_contributed = torch.where(alpha_delta[0, point_img[:, 1], point_img[:, 0]] == 1.0, True, False) 203 | prune_mask_ind = torch.arange(prune_mask.shape[0], device=xyz.device)[prune_mask] 204 | prune_mask_ind = prune_mask_ind[prune_mask_contributed] 205 | prune_mask[:] = False 206 | prune_mask[prune_mask_ind] = True 207 | 208 | renderer.gaussians.prune_points_test(prune_mask) 209 | 210 | point_img = point_img[prune_mask_contributed] 211 | pointed_img = drawPoint(img, point_img) 212 | print(f"{point_img.shape[0]} guassians have been cleaned") 213 | return renderer, pointed_img 214 | 215 | def filter_out(renderer): 216 | cams = [] 217 | cams.append(orbit_camera(0, 0, 2)) 218 | cams.append(orbit_camera(0, 180, 2)) 219 | cams.append(orbit_camera(0, 90, 2)) 220 | cams.append(orbit_camera(20, 180, 2)) 221 | cams.append(orbit_camera(20, 90, 2)) 222 | cams.append(orbit_camera(0, -90, 2)) 223 | for ind, cam in enumerate(cams): 224 | _, pointed_img = filter_out_once(renderer, cam) 225 | 226 | 227 | 228 | if __name__ == "__main__": 229 | 230 | ply_files = sorted(glob("./logs_v3d/*.ply")) 231 | 232 | renderer = Renderer(sh_degree=0) 233 | renderer.initialize(input=ply_files[13]) 234 | 235 | filter_out(renderer) 236 | -------------------------------------------------------------------------------- /guidance/imagedream_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.transforms.functional as TF 6 | 7 | from imagedream.camera_utils import get_camera, convert_opengl_to_blender, normalize_camera 8 | from imagedream.model_zoo import build_model 9 | from imagedream.ldm.models.diffusion.ddim import DDIMSampler 10 | 11 | from diffusers import DDIMScheduler,DDIMInverseScheduler 12 | from torchvision.utils import save_image 13 | import kiui 14 | 15 | class ImageDream(nn.Module): 16 | def __init__( 17 | self, 18 | device, 19 | model_name='sd-v2.1-base-4view-ipmv', 20 | ckpt_path=None, 21 | t_range=[0.02, 0.98], 22 | ): 23 | super().__init__() 24 | 25 | self.device = device 26 | self.model_name = model_name 27 | self.ckpt_path = ckpt_path 28 | 29 | self.model = build_model(self.model_name, ckpt_path=self.ckpt_path).eval().to(self.device) 30 | self.model.device = device 31 | for p in self.model.parameters(): 32 | p.requires_grad_(False) 33 | 34 | self.dtype = torch.float32 35 | 36 | self.num_train_timesteps = 1000 37 | self.min_step = int(self.num_train_timesteps * t_range[0]) 38 | self.max_step = int(self.num_train_timesteps * t_range[1]) 39 | 40 | self.image_embeddings = {} 41 | self.embeddings = {} 42 | 43 | # self.scheduler = DDIMScheduler.from_pretrained( 44 | # "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", torch_dtype=self.dtype 45 | # ) 46 | import json 47 | self.config =json.load(open("./scheduler_config.json")) 48 | 49 | @torch.no_grad() 50 | def get_image_text_embeds(self, image, prompts, negative_prompts): 51 | 52 | image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) 53 | image_pil = TF.to_pil_image(image[0]) 54 | image_embeddings = self.model.get_learned_image_conditioning(image_pil).repeat(5,1,1) # [5, 257, 1280] 55 | self.image_embeddings['pos'] = image_embeddings 56 | self.image_embeddings['neg'] = torch.zeros_like(image_embeddings) 57 | 58 | self.image_embeddings['ip_img'] = self.encode_imgs(image) 59 | 
self.image_embeddings['neg_ip_img'] = torch.zeros_like(self.image_embeddings['ip_img']) 60 | 61 | pos_embeds = self.encode_text(prompts).repeat(5,1,1) 62 | neg_embeds = self.encode_text(negative_prompts).repeat(5,1,1) 63 | self.embeddings['pos'] = pos_embeds 64 | self.embeddings['neg'] = neg_embeds 65 | 66 | def encode_text(self, prompt): 67 | # prompt: [str] 68 | embeddings = self.model.get_learned_conditioning(prompt).to(self.device) 69 | return embeddings 70 | 71 | @torch.no_grad() 72 | def refine(self, pred_rgb, camera, 73 | guidance_scale=2.0, steps=10, strength=0.8, 74 | ): 75 | 76 | batch_size = pred_rgb.shape[0] 77 | real_batch_size = batch_size // 4 78 | pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) 79 | latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) 80 | 81 | self.scheduler.set_timesteps(steps) 82 | init_step = int(steps * strength) 83 | latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) 84 | 85 | camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) 86 | camera[:, 1] *= -1 87 | camera = normalize_camera(camera).view(batch_size, 16) 88 | 89 | # extra view 90 | camera = camera.view(real_batch_size, 4, 16) 91 | camera = torch.cat([camera, torch.zeros_like(camera[:, :1])], dim=1) # [rB, 5, 16] 92 | camera = camera.view(real_batch_size * 5, 16) 93 | 94 | camera = camera.repeat(2, 1) 95 | embeddings = torch.cat([self.embeddings['neg'].repeat(real_batch_size, 1, 1), self.embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) 96 | image_embeddings = torch.cat([self.image_embeddings['neg'].repeat(real_batch_size, 1, 1), self.image_embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) 97 | ip_img_embeddings= torch.cat([self.image_embeddings['neg_ip_img'].repeat(real_batch_size, 1, 1, 1), self.image_embeddings['ip_img'].repeat(real_batch_size, 1, 1, 1)], dim=0) 98 | 99 | context = { 100 | "context": embeddings, 101 | "ip": image_embeddings, 102 | "ip_img": ip_img_embeddings, 103 | "camera": camera, 104 | "num_frames": 4 + 1 105 | } 106 | 107 | for i, t in enumerate(self.scheduler.timesteps[init_step:]): 108 | 109 | # extra view 110 | 111 | latents = latents.view(real_batch_size, 4, 4, 32, 32) 112 | latents = torch.cat([latents, torch.zeros_like(latents[:, :1])], dim=1).view(-1, 4, 32, 32) 113 | latent_model_input = torch.cat([latents] * 2) 114 | 115 | tt = torch.cat([t.unsqueeze(0).repeat(real_batch_size * 5)] * 2).to(self.device) 116 | 117 | noise_pred = self.model.apply_model(latent_model_input, tt, context) 118 | 119 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 120 | 121 | # remove extra view 122 | noise_pred_uncond = noise_pred_uncond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) 123 | noise_pred_cond = noise_pred_cond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) 124 | latents = latents.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) 125 | 126 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 127 | 128 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 129 | 130 | imgs = self.decode_latents(latents) # [1, 3, 512, 512] 131 | return imgs 132 | 133 | def train_step( 134 | self, 135 | pred_rgb, # [B, C, H, W] 136 | camera, # [B, 4, 4] 137 | step_ratio=None, 138 | guidance_scale=3, 139 | as_latent=False, 140 | target_img=None, 141 | step=None, 142 | iter_steps=20, 143 | init_3d=False, 144 | inverse_ratio=0.6, 145 | 
ddim_eta=1.0 146 | ): 147 | 148 | batch_size = pred_rgb.shape[0] 149 | real_batch_size = batch_size // 4 150 | pred_rgb = pred_rgb.to(self.dtype) 151 | if target_img is None: 152 | with torch.no_grad(): 153 | if as_latent: 154 | latents = F.interpolate(pred_rgb, (32, 32), mode="bilinear", align_corners=False) * 2 - 1 155 | else: 156 | # interp to 256x256 to be fed into vae. 157 | pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) 158 | # encode image into latents with vae, requires grad! 159 | latents = self.encode_imgs(pred_rgb_256) 160 | t=torch.tensor(step,dtype=torch.long,device=self.device).unsqueeze(0) 161 | t_expand = t.repeat(batch_size) 162 | 163 | latents_noisy = latents 164 | inverse_scheduler = DDIMInverseScheduler.from_config(self.config) 165 | # inverse_scheduler = DDIMInverseScheduler(clip_sample=False) 166 | 167 | inverse_scheduler.set_timesteps(iter_steps) 168 | # scheduler = DPMSolverMultistepScheduler() 169 | # scheduler.config.algorithm_type = 'sde-dpmsolver++' 170 | # scheduler.config.solver_order = 1 171 | 172 | scheduler=DDIMScheduler.from_config(self.config) 173 | # scheduler=DDIMScheduler(clip_sample=False) 174 | 175 | scheduler.set_timesteps(iter_steps) 176 | 177 | camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) 178 | camera[:, 1] *= -1 179 | camera = normalize_camera(camera).view(batch_size, 16) 180 | 181 | # extra view 182 | camera = camera.view(real_batch_size, 4, 16) 183 | camera = torch.cat([camera, torch.zeros_like(camera[:, :1])], dim=1) # [rB, 5, 16] 184 | camera = camera.view(real_batch_size * 5, 16) 185 | 186 | camera = camera.repeat(2, 1) 187 | embeddings = torch.cat([self.embeddings['neg'].repeat(real_batch_size, 1, 1), self.embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) 188 | image_embeddings = torch.cat([self.image_embeddings['neg'].repeat(real_batch_size, 1, 1), self.image_embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) 189 | ip_img_embeddings= torch.cat([self.image_embeddings['neg_ip_img'].repeat(real_batch_size, 1, 1, 1), self.image_embeddings['ip_img'].repeat(real_batch_size, 1, 1, 1)], dim=0) 190 | 191 | context = { 192 | "context": embeddings, 193 | "ip": image_embeddings, 194 | "ip_img": ip_img_embeddings, 195 | "camera": camera, 196 | "num_frames": 4 + 1 197 | } 198 | 199 | # predict the noise residual with unet, NO grad! 200 | @torch.no_grad() 201 | def pred_noise(latents,t,uncond=False): 202 | latents_noisy=latents 203 | # extra view 204 | t = t.view(real_batch_size, 4) 205 | t = torch.cat([t, t[:, :1]], dim=1).view(-1) 206 | latents_noisy = latents_noisy.view(real_batch_size, 4, 4, 32, 32) 207 | latents_noisy = torch.cat([latents_noisy, torch.zeros_like(latents_noisy[:, :1])], dim=1).view(-1, 4, 32, 32) 208 | # pred noise 209 | latent_model_input = torch.cat([latents_noisy] * 2) 210 | tt = torch.cat([t] * 2) 211 | 212 | # import kiui 213 | # kiui.lo(latent_model_input, t, context['context'], context['camera']) 214 | 215 | noise_pred = self.model.apply_model(latent_model_input, tt, context) 216 | 217 | # perform guidance (high scale from paper!) 
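# classifier-free guidance below: noise_pred = noise_pred_uncond + guidance_scale *
# (noise_pred_cond - noise_pred_uncond), where guidance_scale defaults to 3 in
# ImageDream's train_step.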
218 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 219 | 220 | # remove extra view 221 | noise_pred_uncond = noise_pred_uncond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) 222 | if uncond: 223 | return noise_pred_uncond 224 | noise_pred_cond = noise_pred_cond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) 225 | 226 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 227 | return noise_pred 228 | 229 | t_stop = t[0] 230 | ratio=torch.tensor((step/1000)*iter_steps*inverse_ratio,dtype=torch.long) 231 | t_spe=inverse_scheduler.timesteps[ratio] 232 | if t_spe>0: 233 | latents_noisy=scheduler.add_noise(latents_noisy,torch.randn_like(latents),t_spe) 234 | 235 | if init_3d: 236 | latents_noisy=torch.randn_like(latents) 237 | t_stop=1000 238 | else: 239 | for i,t_inv in enumerate(inverse_scheduler.timesteps[:-1]): 240 | t_inv_prev=inverse_scheduler.timesteps[i+1] 241 | if t_inv_prev <= t_spe: 242 | continue 243 | if t_inv_prev > t_stop: 244 | break 245 | t_inv_expand=t_inv.repeat(batch_size).to(self.device) 246 | noise_pred = pred_noise(latents_noisy,t_inv_expand,uncond=True) 247 | latents_noisy = inverse_scheduler.step( 248 | noise_pred, t_inv_prev, latents_noisy).prev_sample.clone().detach() 249 | 250 | for tt in scheduler.timesteps: 251 | if tt > t_stop: 252 | continue 253 | tt_expand=tt.repeat(batch_size).to(self.device) 254 | noise_pred=pred_noise(latents_noisy,tt_expand) 255 | latents_noisy=scheduler.step(noise_pred,tt,latents_noisy,eta=ddim_eta).prev_sample.to(latents.dtype).clone().detach() 256 | pred_latents = latents_noisy 257 | target_img=self.decode_latents(pred_latents) 258 | 259 | # real_target_img=F.interpolate(target_img, (pred_rgb.shape[-2], pred_rgb.shape[-1]), mode='bicubic', align_corners=False) 260 | # loss=F.l1_loss(pred_rgb,real_target_img.to(pred_rgb),reduction='sum')/pred_rgb.shape[0] 261 | 262 | # return loss,target_img 263 | return target_img 264 | 265 | def decode_latents(self, latents): 266 | imgs = self.model.decode_first_stage(latents) 267 | imgs = ((imgs + 1) / 2).clamp(0, 1) 268 | return imgs 269 | 270 | def encode_imgs(self, imgs): 271 | # imgs: [B, 3, 256, 256] 272 | imgs = 2 * imgs - 1 273 | latents = self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs)) 274 | return latents # [B, 4, 32, 32] 275 | 276 | @torch.no_grad() 277 | def prompt_to_img( 278 | self, 279 | image, 280 | prompts, 281 | negative_prompts="", 282 | height=256, 283 | width=256, 284 | num_inference_steps=50, 285 | guidance_scale=5.0, 286 | latents=None, 287 | elevation=0, 288 | azimuth_start=0, 289 | ): 290 | if isinstance(prompts, str): 291 | prompts = [prompts] 292 | 293 | if isinstance(negative_prompts, str): 294 | negative_prompts = [negative_prompts] 295 | 296 | real_batch_size = len(prompts) 297 | batch_size = len(prompts) * 5 298 | 299 | # Text embeds -> img latents 300 | sampler = DDIMSampler(self.model) 301 | shape = [4, height // 8, width // 8] 302 | 303 | c_ = {"context": self.encode_text(prompts).repeat(5,1,1)} 304 | uc_ = {"context": self.encode_text(negative_prompts).repeat(5,1,1)} 305 | 306 | # image embeddings 307 | image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) 308 | image_pil = TF.to_pil_image(image[0]) 309 | image_embeddings = self.model.get_learned_image_conditioning(image_pil).repeat(5,1,1).to(self.device) 310 | c_["ip"] = image_embeddings 311 | uc_["ip"] = torch.zeros_like(image_embeddings) 312 | 313 | ip_img = 
self.encode_imgs(image) 314 | c_["ip_img"] = ip_img 315 | uc_["ip_img"] = torch.zeros_like(ip_img) 316 | 317 | camera = get_camera(4, elevation=elevation, azimuth_start=azimuth_start, extra_view=True) 318 | camera = camera.repeat(real_batch_size, 1).to(self.device) 319 | 320 | c_["camera"] = uc_["camera"] = camera 321 | c_["num_frames"] = uc_["num_frames"] = 5 322 | 323 | kiui.lo(image_embeddings, ip_img, camera) 324 | 325 | latents, _ = sampler.sample(S=num_inference_steps, conditioning=c_, 326 | batch_size=batch_size, shape=shape, 327 | verbose=False, 328 | unconditional_guidance_scale=guidance_scale, 329 | unconditional_conditioning=uc_, 330 | eta=0, x_T=None) 331 | 332 | # Img latents -> imgs 333 | imgs = self.decode_latents(latents) # [4, 3, 256, 256] 334 | 335 | kiui.lo(latents, imgs) 336 | 337 | # Img to Numpy 338 | imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() 339 | imgs = (imgs * 255).round().astype("uint8") 340 | 341 | return imgs 342 | -------------------------------------------------------------------------------- /guidance/mvdream_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from mvdream.camera_utils import get_camera, convert_opengl_to_blender, normalize_camera 7 | from mvdream.model_zoo import build_model 8 | from mvdream.ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | from diffusers import DDIMScheduler,DDIMInverseScheduler 11 | from torchvision.utils import save_image 12 | 13 | class MVDream(nn.Module): 14 | def __init__( 15 | self, 16 | device, 17 | model_name='sd-v2.1-base-4view', 18 | ckpt_path=None, 19 | t_range=[0.02, 0.98], 20 | ): 21 | super().__init__() 22 | 23 | self.device = device 24 | self.model_name = model_name 25 | self.ckpt_path = ckpt_path 26 | 27 | self.model = build_model(self.model_name, ckpt_path=self.ckpt_path).eval().to(self.device) 28 | self.model.device = device 29 | for p in self.model.parameters(): 30 | p.requires_grad_(False) 31 | 32 | self.dtype = torch.float32 33 | 34 | self.num_train_timesteps = 1000 35 | self.min_step = int(self.num_train_timesteps * t_range[0]) 36 | self.max_step = int(self.num_train_timesteps * t_range[1]) 37 | 38 | self.embeddings = {} 39 | 40 | # self.scheduler = DDIMScheduler.from_pretrained( 41 | # "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", torch_dtype=self.dtype 42 | # ) 43 | import json 44 | self.config =json.load(open("./scheduler_config.json")) 45 | 46 | @torch.no_grad() 47 | def get_text_embeds(self, prompts, negative_prompts): 48 | pos_embeds = self.encode_text(prompts).repeat(4,1,1) # [1, 77, 768] 49 | neg_embeds = self.encode_text(negative_prompts).repeat(4,1,1) 50 | self.embeddings['pos'] = pos_embeds 51 | self.embeddings['neg'] = neg_embeds 52 | 53 | def encode_text(self, prompt): 54 | # prompt: [str] 55 | embeddings = self.model.get_learned_conditioning(prompt).to(self.device) 56 | return embeddings 57 | 58 | @torch.no_grad() 59 | def refine(self, pred_rgb, camera, 60 | guidance_scale=100, steps=50, strength=0.8, 61 | ): 62 | 63 | batch_size = pred_rgb.shape[0] 64 | real_batch_size = batch_size // 4 65 | pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) 66 | latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) 67 | # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) 68 | 69 | self.scheduler.set_timesteps(steps) 70 | init_step = int(steps * strength) 71 
| latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) 72 | 73 | camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) 74 | camera[:, 1] *= -1 75 | camera = normalize_camera(camera).view(batch_size, 16) 76 | camera = camera.repeat(2, 1) 77 | 78 | embeddings = torch.cat([self.embeddings['neg'].repeat(real_batch_size, 1, 1), self.embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) 79 | context = {"context": embeddings, "camera": camera, "num_frames": 4} 80 | 81 | for i, t in enumerate(self.scheduler.timesteps[init_step:]): 82 | 83 | latent_model_input = torch.cat([latents] * 2) 84 | 85 | tt = torch.cat([t.unsqueeze(0).repeat(batch_size)] * 2).to(self.device) 86 | 87 | noise_pred = self.model.apply_model(latent_model_input, tt, context) 88 | 89 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 90 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 91 | 92 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 93 | 94 | imgs = self.decode_latents(latents) # [1, 3, 512, 512] 95 | return imgs 96 | 97 | def train_step( 98 | self, 99 | pred_rgb, # [B, C, H, W], B is multiples of 4 100 | camera, # [B, 4, 4] 101 | step_ratio=None, 102 | guidance_scale=7.5, 103 | as_latent=False, 104 | target_img=None, 105 | step=None, 106 | iter_steps=20, 107 | init_3d=False 108 | ): 109 | 110 | batch_size = pred_rgb.shape[0] 111 | real_batch_size = batch_size // 4 112 | pred_rgb = pred_rgb.to(self.dtype) 113 | if target_img is None: 114 | with torch.no_grad(): 115 | if as_latent: 116 | latents = F.interpolate(pred_rgb, (32, 32), mode="bilinear", align_corners=False) * 2 - 1 117 | else: 118 | # interp to 256x256 to be fed into vae. 119 | pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) 120 | # encode image into latents with vae, requires grad! 121 | latents = self.encode_imgs(pred_rgb_256) 122 | t=torch.tensor(step,dtype=torch.long,device=self.device).unsqueeze(0) 123 | t_expand = t.repeat(batch_size) 124 | 125 | latents_noisy = latents 126 | inverse_scheduler = DDIMInverseScheduler.from_config(self.config) 127 | # inverse_scheduler = DDIMInverseScheduler(clip_sample=False) 128 | 129 | inverse_scheduler.set_timesteps(iter_steps) 130 | # scheduler = DPMSolverMultistepScheduler() 131 | # scheduler.config.algorithm_type = 'sde-dpmsolver++' 132 | # scheduler.config.solver_order = 1 133 | 134 | scheduler=DDIMScheduler.from_config(self.config) 135 | # scheduler=DDIMScheduler(clip_sample=False) 136 | 137 | scheduler.set_timesteps(iter_steps) 138 | 139 | camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) 140 | camera[:, 1] *= -1 141 | camera = normalize_camera(camera).view(batch_size, 16) 142 | 143 | camera = camera.repeat(2, 1) 144 | embeddings = torch.cat([self.embeddings['neg'].repeat(real_batch_size, 1, 1), self.embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) 145 | context = {"context": embeddings, "camera": camera, "num_frames": 4} 146 | 147 | @torch.no_grad() 148 | def pred_noise(latents,t,uncond=False): 149 | # add noise 150 | latents_noisy = latents 151 | # pred noise 152 | latent_model_input = torch.cat([latents_noisy] * 2) 153 | tt = torch.cat([t] * 2) 154 | noise_pred = self.model.apply_model(latent_model_input, tt, context) 155 | 156 | # perform guidance (high scale from paper!) 
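                # Classifier-free guidance: the conditioning batch was assembled as
                # [negative, positive] above, so the model output is split into an
                # unconditional and a conditional branch and recombined as
                #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)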
157 | noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) 158 | if uncond: 159 | return noise_pred_uncond 160 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) 161 | return noise_pred 162 | 163 | t_stop = t[0] 164 | ratio=torch.tensor((step/1000)*iter_steps*0.6,dtype=torch.long) 165 | t_spe=inverse_scheduler.timesteps[ratio] 166 | if t_spe>0: 167 | latents_noisy=scheduler.add_noise(latents_noisy,torch.randn_like(latents),t_spe) 168 | 169 | if init_3d: 170 | latents_noisy=torch.randn_like(latents) 171 | t_stop=1000 172 | else: 173 | for i,t_inv in enumerate(inverse_scheduler.timesteps[:-1]): 174 | t_inv_prev=inverse_scheduler.timesteps[i+1] 175 | if t_inv_prev <= t_spe: 176 | continue 177 | if t_inv_prev > t_stop: 178 | break 179 | t_inv_expand=t_inv.repeat(batch_size).to(self.device) 180 | noise_pred = pred_noise(latents_noisy,t_inv_expand,uncond=True) 181 | latents_noisy = inverse_scheduler.step( 182 | noise_pred, t_inv_prev, latents_noisy).prev_sample.clone().detach() 183 | 184 | for tt in scheduler.timesteps: 185 | if tt > t_stop: 186 | continue 187 | tt_expand=tt.repeat(batch_size).to(self.device) 188 | noise_pred=pred_noise(latents_noisy,tt_expand) 189 | latents_noisy=scheduler.step(noise_pred,tt,latents_noisy,eta=0.0).prev_sample.to(latents.dtype).clone().detach() 190 | pred_latents = latents_noisy 191 | target_img=self.decode_latents(pred_latents) 192 | 193 | real_target_img=F.interpolate(target_img, (pred_rgb.shape[-2], pred_rgb.shape[-1]), mode='bicubic', align_corners=False) 194 | loss=F.l1_loss(pred_rgb,real_target_img.to(pred_rgb),reduction='sum')/pred_rgb.shape[0] 195 | 196 | return loss,target_img 197 | 198 | def decode_latents(self, latents): 199 | imgs = self.model.decode_first_stage(latents) 200 | imgs = ((imgs + 1) / 2).clamp(0, 1) 201 | return imgs 202 | 203 | def encode_imgs(self, imgs): 204 | # imgs: [B, 3, 256, 256] 205 | imgs = 2 * imgs - 1 206 | latents = self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs)) 207 | return latents # [B, 4, 32, 32] 208 | 209 | @torch.no_grad() 210 | def prompt_to_img( 211 | self, 212 | prompts, 213 | negative_prompts="", 214 | height=256, 215 | width=256, 216 | num_inference_steps=50, 217 | guidance_scale=7.5, 218 | latents=None, 219 | elevation=0, 220 | azimuth_start=0, 221 | ): 222 | if isinstance(prompts, str): 223 | prompts = [prompts] 224 | 225 | if isinstance(negative_prompts, str): 226 | negative_prompts = [negative_prompts] 227 | 228 | batch_size = len(prompts) * 4 229 | 230 | # Text embeds -> img latents 231 | sampler = DDIMSampler(self.model) 232 | shape = [4, height // 8, width // 8] 233 | c_ = {"context": self.encode_text(prompts).repeat(4,1,1)} 234 | uc_ = {"context": self.encode_text(negative_prompts).repeat(4,1,1)} 235 | 236 | camera = get_camera(4, elevation=elevation, azimuth_start=azimuth_start) 237 | camera = camera.repeat(batch_size // 4, 1).to(self.device) 238 | 239 | c_["camera"] = uc_["camera"] = camera 240 | c_["num_frames"] = uc_["num_frames"] = 4 241 | 242 | latents, _ = sampler.sample(S=num_inference_steps, conditioning=c_, 243 | batch_size=batch_size, shape=shape, 244 | verbose=False, 245 | unconditional_guidance_scale=guidance_scale, 246 | unconditional_conditioning=uc_, 247 | eta=0, x_T=None) 248 | 249 | # Img latents -> imgs 250 | imgs = self.decode_latents(latents) # [4, 3, 256, 256] 251 | 252 | # Img to Numpy 253 | imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() 254 | imgs = (imgs * 255).round().astype("uint8") 255 | 256 | return imgs 
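# Minimal usage sketch (illustrative only; assumes a CUDA device and that the
# default sd-v2.1-base-4view checkpoint can be downloaded):
#   guidance = MVDream(torch.device('cuda'))
#   imgs = guidance.prompt_to_img("a DSLR photo of a corgi")
#   # imgs: uint8 numpy array holding the 4 generated views, shape [4, 256, 256, 3]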
257 | -------------------------------------------------------------------------------- /guidance/zero123_utils.py: -------------------------------------------------------------------------------- 1 | from zero123 import Zero123Pipeline 2 | from diffusers import DDIMScheduler 3 | import torchvision.transforms.functional as TF 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision.utils import save_image 10 | from diffusers import DDIMInverseScheduler, DPMSolverMultistepScheduler 11 | 12 | import sys 13 | sys.path.append('./') 14 | 15 | 16 | class Zero123(nn.Module): 17 | def __init__(self, device, fp16=True, t_range=[0.02, 0.98], model_key="ashawkey/stable-zero123-diffusers"): 18 | super().__init__() 19 | 20 | self.device = device 21 | self.fp16 = fp16 22 | self.dtype = torch.float16 if fp16 else torch.float32 23 | 24 | # assert self.fp16, 'Only zero123 fp16 is supported for now.' 25 | 26 | self.pipe = Zero123Pipeline.from_pretrained( 27 | model_key, 28 | torch_dtype=self.dtype, 29 | trust_remote_code=True, 30 | ).to(self.device) 31 | 32 | # stable-zero123 has a different camera embedding 33 | self.use_stable_zero123 = 'stable' in model_key 34 | 35 | self.pipe.image_encoder.eval() 36 | self.pipe.vae.eval() 37 | self.pipe.unet.eval() 38 | self.pipe.clip_camera_projection.eval() 39 | 40 | self.vae = self.pipe.vae 41 | self.unet = self.pipe.unet 42 | 43 | self.pipe.set_progress_bar_config(disable=True) 44 | 45 | self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) 46 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps 47 | 48 | self.min_step = int(self.num_train_timesteps * t_range[0]) 49 | self.max_step = int(self.num_train_timesteps * t_range[1]) 50 | self.alphas = self.scheduler.alphas_cumprod.to( 51 | self.device) # for convenience 52 | 53 | self.embeddings = None 54 | 55 | @torch.no_grad() 56 | def get_img_embeds(self, x): 57 | # x: image tensor in [0, 1] 58 | x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) 59 | x_pil = [TF.to_pil_image(image) for image in x] 60 | x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to( 61 | device=self.device, dtype=self.dtype) 62 | c = self.pipe.image_encoder(x_clip).image_embeds 63 | v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor 64 | self.embeddings = [c, v] 65 | 66 | def get_cam_embeddings(self, elevation, azimuth, radius, default_elevation=0): 67 | if self.use_stable_zero123: 68 | T = np.stack([np.deg2rad(elevation), np.sin(np.deg2rad(azimuth)), np.cos( 69 | np.deg2rad(azimuth)), np.deg2rad([90 + default_elevation] * len(elevation))], axis=-1) 70 | else: 71 | # original zero123 camera embedding 72 | T = np.stack([np.deg2rad(elevation), np.sin(np.deg2rad( 73 | azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) 74 | T = torch.from_numpy(T).unsqueeze(1).to( 75 | dtype=self.dtype, device=self.device) # [8, 1, 4] 76 | return T 77 | 78 | @torch.no_grad() 79 | def refine(self, pred_rgb, elevation, azimuth, radius, 80 | guidance_scale=2, steps=50, strength=0.8, default_elevation=0, 81 | ): 82 | 83 | batch_size = pred_rgb.shape[0] 84 | 85 | self.scheduler.set_timesteps(steps) 86 | 87 | if strength == 0: 88 | init_step = 0 89 | latents = torch.randn( 90 | (1, 4, 32, 32), device=self.device, dtype=self.dtype) 91 | else: 92 | init_step = int(steps * strength) 93 | pred_rgb_256 = F.interpolate( 94 | pred_rgb, (256, 256), mode='bilinear', align_corners=False) 95 | 
latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) 96 | latents = self.scheduler.add_noise(latents, torch.randn_like( 97 | latents), self.scheduler.timesteps[init_step]) 98 | 99 | T = self.get_cam_embeddings( 100 | elevation, azimuth, radius, default_elevation) 101 | cc_emb = torch.cat( 102 | [self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1) 103 | cc_emb = self.pipe.clip_camera_projection(cc_emb) 104 | cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) 105 | 106 | vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1) 107 | vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) 108 | 109 | for i, t in enumerate(self.scheduler.timesteps[init_step:]): 110 | 111 | x_in = torch.cat([latents] * 2) 112 | t_in = torch.cat( 113 | [t.view(1)] * 2).repeat(batch_size).to(self.device) 114 | 115 | noise_pred = self.unet( 116 | torch.cat([x_in, vae_emb], dim=1), 117 | t_in.to(self.unet.dtype), 118 | encoder_hidden_states=cc_emb, 119 | ).sample 120 | 121 | noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) 122 | noise_pred = noise_pred_uncond + guidance_scale * \ 123 | (noise_pred_cond - noise_pred_uncond) 124 | 125 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 126 | 127 | imgs = self.decode_latents(latents) # [1, 3, 256, 256] 128 | return imgs 129 | 130 | def train_step(self, pred_rgb, elevation, azimuth, radius, step_ratio=None, guidance_scale=3.0, as_latent=False, default_elevation=0, target_img=None, step=None, iter_steps=20, init_3d=False, inverse_ratio=0.6, ddim_eta=1.0): 131 | 132 | batch_size = pred_rgb.shape[0] 133 | if target_img is None: 134 | with torch.no_grad(): 135 | if as_latent: 136 | latents = F.interpolate( 137 | pred_rgb, (32, 32), mode='bicubic', align_corners=False) * 2 - 1 138 | else: 139 | pred_rgb_256 = F.interpolate( 140 | pred_rgb, (256, 256), mode='bicubic', align_corners=False) 141 | latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) 142 | 143 | t = torch.tensor(step, dtype=torch.long, 144 | device=self.device).unsqueeze(0) 145 | t_expand = t.repeat(batch_size) 146 | 147 | latents_noisy = latents 148 | inverse_scheduler = DDIMInverseScheduler.from_config( 149 | self.pipe.scheduler.config) 150 | # inverse_scheduler = DDIMInverseScheduler(clip_sample=False) 151 | 152 | inverse_scheduler.set_timesteps(iter_steps) 153 | # scheduler = DPMSolverMultistepScheduler() 154 | # scheduler.config.algorithm_type = 'sde-dpmsolver++' 155 | # scheduler.config.solver_order = 1 156 | 157 | scheduler = DDIMScheduler.from_config( 158 | self.pipe.scheduler.config) 159 | # scheduler=DDIMScheduler(clip_sample=False) 160 | 161 | scheduler.set_timesteps(iter_steps) 162 | 163 | # cond 164 | T = self.get_cam_embeddings( 165 | elevation, azimuth, radius, default_elevation) 166 | cc_emb = torch.cat( 167 | [self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1) 168 | cc_emb = self.pipe.clip_camera_projection(cc_emb) 169 | cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) 170 | 171 | vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1) 172 | vae_emb = torch.cat( 173 | [vae_emb, torch.zeros_like(vae_emb)], dim=0) 174 | 175 | @torch.no_grad() 176 | def pred_noise(latents, t_expand, uncond=False): 177 | 178 | x_in = torch.cat([latents] * 2) 179 | t_in = torch.cat([t_expand] * 2) 180 | noise_pred = self.unet( 181 | torch.cat([x_in, vae_emb], dim=1), 182 | t_in.to(self.unet.dtype), 183 | encoder_hidden_states=cc_emb, 184 | ).sample 185 | noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) 186 | if uncond: 187 | return 
noise_pred_uncond 188 | noise_pred = noise_pred_uncond + guidance_scale * \ 189 | (noise_pred_cond - noise_pred_uncond) 190 | 191 | return noise_pred 192 | 193 | t_stop = t[0] # t_2 in paper 194 | ratio = torch.tensor( 195 | (step/1000)*iter_steps*inverse_ratio, dtype=torch.long) 196 | t_spe = inverse_scheduler.timesteps[ratio] # t_1 in paper 197 | if t_spe > 0: 198 | latents_noisy = self.scheduler.add_noise( 199 | latents_noisy, torch.randn_like(latents), t_spe) 200 | 201 | if init_3d: 202 | latents_noisy = torch.randn_like(latents) 203 | t_stop = 1000 204 | else: 205 | for i, t_inv in enumerate(inverse_scheduler.timesteps[:-1]): 206 | t_inv_prev = inverse_scheduler.timesteps[i+1] 207 | if t_inv_prev <= t_spe: 208 | continue 209 | if t_inv_prev > t_stop: 210 | break 211 | t_inv_expand = t_inv.repeat(batch_size).to(self.device) 212 | noise_pred = pred_noise( 213 | latents_noisy, t_inv_expand, uncond=True) 214 | latents_noisy = inverse_scheduler.step( 215 | noise_pred, t_inv_prev, latents_noisy).prev_sample.clone().detach() 216 | 217 | for tt in scheduler.timesteps: 218 | if tt > t_stop: 219 | continue 220 | tt_expand = tt.repeat(batch_size).to(self.device) 221 | noise_pred = pred_noise(latents_noisy, tt_expand) 222 | latents_noisy = scheduler.step(noise_pred, tt, latents_noisy, eta=ddim_eta).prev_sample.to( 223 | latents.dtype).clone().detach() 224 | pred_latents = latents_noisy 225 | target_img = self.decode_latents(pred_latents) 226 | 227 | # real_target_img = F.interpolate( 228 | # target_img, (pred_rgb.shape[-2], pred_rgb.shape[-1]), mode='bicubic', align_corners=False) 229 | # loss = F.l1_loss(pred_rgb, real_target_img.to( 230 | # pred_rgb), reduction='sum')/pred_rgb.shape[0] 231 | 232 | return target_img 233 | 234 | def decode_latents(self, latents): 235 | latents = 1 / self.vae.config.scaling_factor * latents 236 | 237 | imgs = self.vae.decode(latents).sample 238 | imgs = (imgs / 2 + 0.5).clamp(0, 1) 239 | 240 | return imgs 241 | 242 | def encode_imgs(self, imgs, mode=False): 243 | # imgs: [B, 3, H, W] 244 | 245 | imgs = 2 * imgs - 1 246 | 247 | posterior = self.vae.encode(imgs).latent_dist 248 | if mode: 249 | latents = posterior.mode() 250 | else: 251 | latents = posterior.sample() 252 | latents = latents * self.vae.config.scaling_factor 253 | 254 | return latents 255 | 256 | -------------------------------------------------------------------------------- /loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | from lpipsPyTorch import lpips as lpips_fn 17 | from lpipsPyTorch.modules.lpips import LPIPS 18 | 19 | _lpips = None 20 | 21 | 22 | def l1_loss(network_output, gt): 23 | return torch.abs((network_output - gt)).mean() 24 | 25 | 26 | def l2_loss(network_output, gt): 27 | return ((network_output - gt) ** 2).mean() 28 | 29 | 30 | def gaussian(window_size, sigma): 31 | gauss = torch.Tensor( 32 | [ 33 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 34 | for x in range(window_size) 35 | ] 36 | ) 37 | return gauss / gauss.sum() 38 | 39 | 40 | def create_window(window_size, channel): 41 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 42 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 43 | window = Variable( 44 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 45 | ) 46 | return window 47 | 48 | 49 | def ssim(img1, img2, window_size=11, size_average=True): 50 | channel = img1.size(-3) 51 | window = create_window(window_size, channel) 52 | 53 | if img1.is_cuda: 54 | window = window.cuda(img1.get_device()) 55 | window = window.type_as(img1) 56 | 57 | return _ssim(img1, img2, window, window_size, channel, size_average) 58 | 59 | 60 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 61 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 62 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 63 | 64 | mu1_sq = mu1.pow(2) 65 | mu2_sq = mu2.pow(2) 66 | mu1_mu2 = mu1 * mu2 67 | 68 | sigma1_sq = ( 69 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 70 | ) 71 | sigma2_sq = ( 72 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 73 | ) 74 | sigma12 = ( 75 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 76 | - mu1_mu2 77 | ) 78 | 79 | C1 = 0.01**2 80 | C2 = 0.03**2 81 | 82 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 83 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 84 | ) 85 | 86 | if size_average: 87 | return ssim_map.mean() 88 | else: 89 | return ssim_map.mean(1).mean(1).mean(1) 90 | 91 | 92 | def lpips(img1, img2): 93 | global _lpips 94 | if _lpips is None: 95 | _lpips = LPIPS("vgg", "0.1").to("cuda") 96 | return _lpips(img1, img2).mean() 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
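    Example (illustrative):
        >>> score = lpips(x, y, net_type='vgg')  # LPIPS distance between two image batches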
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | 18 | def __init__(self, net_type: str = "alex", version: str = "0.1"): 19 | 20 | assert version in ["0.1"], "v0.1 is only supported now" 21 | 22 | super(LPIPS, self).__init__() 23 | 24 | # pretrained network 25 | self.net = get_network(net_type) 26 | 27 | # linear layers 28 | self.lin = LinLayers(self.net.n_channels_list) 29 | self.lin.load_state_dict(get_state_dict(net_type, version)) 30 | self.eval() 31 | 32 | def forward(self, x: torch.Tensor, y: torch.Tensor): 33 | feat_x, feat_y = self.net(x), self.net(y) 34 | 35 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 36 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 37 | 38 | return torch.sum(torch.cat(res, 0), 0, True) 39 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /mesh_renderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import cv2 4 | import trimesh 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import nvdiffrast.torch as dr 12 | from mesh import Mesh, safe_normalize 13 | 14 | def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'): 15 | assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" 16 | y = x.permute(0, 3, 1, 2) # NHWC -> NCHW 17 | if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger 18 | y = torch.nn.functional.interpolate(y, size, mode=min) 19 | else: # Magnification 20 | if mag == 'bilinear' or mag == 'bicubic': 21 | y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) 22 | else: 23 | y = torch.nn.functional.interpolate(y, size, mode=mag) 24 | return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC 25 | 26 | def scale_img_hwc(x, size, mag='bilinear', min='bilinear'): 27 | return scale_img_nhwc(x[None, ...], size, mag, min)[0] 28 | 29 | def scale_img_nhw(x, size, mag='bilinear', min='bilinear'): 30 | return scale_img_nhwc(x[..., None], size, mag, min)[..., 0] 31 | 32 | def scale_img_hw(x, size, mag='bilinear', min='bilinear'): 33 | return scale_img_nhwc(x[None, ..., None], size, mag, min)[0, ..., 0] 34 | 35 | def trunc_rev_sigmoid(x, 
eps=1e-6): 36 | x = x.clamp(eps, 1 - eps) 37 | return torch.log(x / (1 - x)) 38 | 39 | def make_divisible(x, m=8): 40 | return int(math.ceil(x / m) * m) 41 | 42 | class Renderer(nn.Module): 43 | def __init__(self, opt): 44 | 45 | super().__init__() 46 | 47 | self.opt = opt 48 | 49 | self.mesh = Mesh.load(self.opt.mesh, resize=False) 50 | 51 | if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'): 52 | self.glctx = dr.RasterizeGLContext() 53 | else: 54 | self.glctx = dr.RasterizeCudaContext() 55 | 56 | # extract trainable parameters 57 | self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)) 58 | self.raw_albedo = nn.Parameter(trunc_rev_sigmoid(self.mesh.albedo)) 59 | 60 | 61 | def get_params(self): 62 | 63 | params = [ 64 | {'params': self.raw_albedo, 'lr': self.opt.texture_lr}, 65 | ] 66 | 67 | if self.opt.train_geo: 68 | params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr}) 69 | 70 | return params 71 | 72 | @torch.no_grad() 73 | def export_mesh(self, save_path): 74 | self.mesh.v = (self.mesh.v + self.v_offsets).detach() 75 | self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach()) 76 | self.mesh.write(save_path) 77 | 78 | 79 | def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'): 80 | 81 | # do super-sampling 82 | if ssaa != 1: 83 | h = make_divisible(h0 * ssaa, 8) 84 | w = make_divisible(w0 * ssaa, 8) 85 | else: 86 | h, w = h0, w0 87 | 88 | results = {} 89 | 90 | # get v 91 | if self.opt.train_geo: 92 | v = self.mesh.v + self.v_offsets # [N, 3] 93 | else: 94 | v = self.mesh.v 95 | 96 | pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) 97 | proj = torch.from_numpy(proj.astype(np.float32)).to(v.device) 98 | 99 | # get v_clip and render rgb 100 | v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) 101 | v_clip = v_cam @ proj.T 102 | 103 | rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w)) 104 | 105 | alpha = (rast[0, ..., 3:] > 0).float() 106 | depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f) # [1, H, W, 1] 107 | depth = depth.squeeze(0) # [H, W, 1] 108 | 109 | texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all') 110 | albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3] 111 | albedo = torch.sigmoid(albedo) 112 | # get vn and render normal 113 | if self.opt.train_geo: 114 | i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long() 115 | v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :] 116 | 117 | face_normals = torch.cross(v1 - v0, v2 - v0) 118 | face_normals = safe_normalize(face_normals) 119 | 120 | vn = torch.zeros_like(v) 121 | vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) 122 | vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) 123 | vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) 124 | 125 | vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) 126 | else: 127 | vn = self.mesh.vn 128 | 129 | normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, self.mesh.fn) 130 | normal = safe_normalize(normal[0]) 131 | 132 | # rotated normal (where [0, 0, 1] always faces camera) 133 | rot_normal = normal @ pose[:3, :3] 134 | viewcos = rot_normal[..., [2]] 135 | 136 | # antialias 137 | albedo = dr.antialias(albedo, rast, v_clip, 
self.mesh.f).squeeze(0) # [H, W, 3] 138 | albedo = alpha * albedo + (1 - alpha) * bg_color 139 | 140 | # ssaa 141 | if ssaa != 1: 142 | albedo = scale_img_hwc(albedo, (h0, w0)) 143 | alpha = scale_img_hwc(alpha, (h0, w0)) 144 | depth = scale_img_hwc(depth, (h0, w0)) 145 | normal = scale_img_hwc(normal, (h0, w0)) 146 | viewcos = scale_img_hwc(viewcos, (h0, w0)) 147 | 148 | results['image'] = albedo.clamp(0, 1) 149 | results['alpha'] = alpha 150 | results['depth'] = depth 151 | results['normal'] = (normal + 1) / 2 152 | results['viewcos'] = viewcos 153 | 154 | return results -------------------------------------------------------------------------------- /mesh_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pymeshlab as pml 3 | 4 | 5 | def poisson_mesh_reconstruction(points, normals=None): 6 | # points/normals: [N, 3] np.ndarray 7 | 8 | import open3d as o3d 9 | 10 | pcd = o3d.geometry.PointCloud() 11 | pcd.points = o3d.utility.Vector3dVector(points) 12 | 13 | # outlier removal 14 | pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10) 15 | 16 | # normals 17 | if normals is None: 18 | pcd.estimate_normals() 19 | else: 20 | pcd.normals = o3d.utility.Vector3dVector(normals[ind]) 21 | 22 | # visualize 23 | o3d.visualization.draw_geometries([pcd], point_show_normal=False) 24 | 25 | mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( 26 | pcd, depth=9 27 | ) 28 | vertices_to_remove = densities < np.quantile(densities, 0.1) 29 | mesh.remove_vertices_by_mask(vertices_to_remove) 30 | 31 | # visualize 32 | o3d.visualization.draw_geometries([mesh]) 33 | 34 | vertices = np.asarray(mesh.vertices) 35 | triangles = np.asarray(mesh.triangles) 36 | 37 | print( 38 | f"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}" 39 | ) 40 | 41 | return vertices, triangles 42 | 43 | 44 | def decimate_mesh( 45 | verts, faces, target, backend="pymeshlab", remesh=False, optimalplacement=True 46 | ): 47 | # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect. 48 | 49 | _ori_vert_shape = verts.shape 50 | _ori_face_shape = faces.shape 51 | 52 | if backend == "pyfqmr": 53 | import pyfqmr 54 | 55 | solver = pyfqmr.Simplify() 56 | solver.setMesh(verts, faces) 57 | solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False) 58 | verts, faces, normals = solver.getMesh() 59 | else: 60 | m = pml.Mesh(verts, faces) 61 | ms = pml.MeshSet() 62 | ms.add_mesh(m, "mesh") # will copy! 
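        # pymeshlab path: quadric edge-collapse decimation down to `target` faces,
        # optionally followed by isotropic remeshing (see the filter calls below).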
63 | 64 | # filters 65 | # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1)) 66 | ms.meshing_decimation_quadric_edge_collapse( 67 | targetfacenum=int(target), optimalplacement=optimalplacement 68 | ) 69 | 70 | if remesh: 71 | # ms.apply_coord_taubin_smoothing() 72 | ms.meshing_isotropic_explicit_remeshing( 73 | iterations=3, targetlen=pml.PercentageValue(1) 74 | ) 75 | 76 | # extract mesh 77 | m = ms.current_mesh() 78 | verts = m.vertex_matrix() 79 | faces = m.face_matrix() 80 | 81 | print( 82 | f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}" 83 | ) 84 | 85 | return verts, faces 86 | 87 | 88 | def clean_mesh( 89 | verts, 90 | faces, 91 | v_pct=1, 92 | min_f=64, 93 | min_d=20, 94 | repair=True, 95 | remesh=True, 96 | remesh_size=0.01, 97 | ): 98 | # verts: [N, 3] 99 | # faces: [N, 3] 100 | 101 | _ori_vert_shape = verts.shape 102 | _ori_face_shape = faces.shape 103 | 104 | m = pml.Mesh(verts, faces) 105 | ms = pml.MeshSet() 106 | ms.add_mesh(m, "mesh") # will copy! 107 | 108 | # filters 109 | ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces 110 | 111 | if v_pct > 0: 112 | ms.meshing_merge_close_vertices( 113 | threshold=pml.PercentageValue(v_pct) 114 | ) # 1/10000 of bounding box diagonal 115 | 116 | ms.meshing_remove_duplicate_faces() # faces defined by the same verts 117 | ms.meshing_remove_null_faces() # faces with area == 0 118 | 119 | if min_d > 0: 120 | ms.meshing_remove_connected_component_by_diameter( 121 | mincomponentdiag=pml.PercentageValue(min_d) 122 | ) 123 | 124 | if min_f > 0: 125 | ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f) 126 | 127 | if repair: 128 | # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True) 129 | ms.meshing_repair_non_manifold_edges(method=0) 130 | ms.meshing_repair_non_manifold_vertices(vertdispratio=0) 131 | 132 | if remesh: 133 | # ms.apply_coord_taubin_smoothing() 134 | ms.meshing_isotropic_explicit_remeshing( 135 | iterations=3, targetlen=pml.PureValue(remesh_size) 136 | ) 137 | 138 | # extract mesh 139 | m = ms.current_mesh() 140 | verts = m.vertex_matrix() 141 | faces = m.face_matrix() 142 | 143 | print( 144 | f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}" 145 | ) 146 | 147 | return verts, faces 148 | -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import cv2 5 | import argparse 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision import transforms 13 | from PIL import Image 14 | import rembg 15 | 16 | class BLIP2(): 17 | def __init__(self, device='cuda'): 18 | self.device = device 19 | from transformers import AutoProcessor, Blip2ForConditionalGeneration 20 | self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") 21 | self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device) 22 | 23 | @torch.no_grad() 24 | def __call__(self, image): 25 | image = Image.fromarray(image) 26 | inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16) 27 | 28 | generated_ids = self.model.generate(**inputs, max_new_tokens=20) 29 | generated_text = self.processor.batch_decode(generated_ids, 
skip_special_tokens=True)[0].strip() 30 | 31 | return generated_text 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)") 38 | parser.add_argument('--model', default='u2net', type=str, help="rembg model, see https://github.com/danielgatis/rembg#models") 39 | parser.add_argument('--size', default=256, type=int, help="output resolution") 40 | parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio") 41 | parser.add_argument('--recenter', type=bool, default=True, help="recenter, potentially not helpful for multiview zero123") 42 | opt = parser.parse_args() 43 | 44 | session = rembg.new_session(model_name=opt.model) 45 | 46 | if os.path.isdir(opt.path): 47 | print(f'[INFO] processing directory {opt.path}...') 48 | files = glob.glob(f'{opt.path}/*') 49 | out_dir = opt.path 50 | else: # isfile 51 | files = [opt.path] 52 | out_dir = os.path.dirname(opt.path) 53 | 54 | for file in files: 55 | 56 | out_base = os.path.basename(file).split('.')[0] 57 | out_rgba = os.path.join(out_dir, out_base + '_rgba.png') 58 | 59 | # load image 60 | print(f'[INFO] loading image {file}...') 61 | image = cv2.imread(file, cv2.IMREAD_UNCHANGED) 62 | 63 | # carve background 64 | print(f'[INFO] background removal...') 65 | try: 66 | carved_image = rembg.remove(image, session=session) # [H, W, 4] 67 | except: 68 | carved_image=image 69 | mask = carved_image[..., -1] > 0 70 | 71 | # recenter 72 | if opt.recenter: 73 | print(f'[INFO] recenter...') 74 | final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8) 75 | 76 | coords = np.nonzero(mask) 77 | x_min, x_max = coords[0].min(), coords[0].max() 78 | y_min, y_max = coords[1].min(), coords[1].max() 79 | h = x_max - x_min 80 | w = y_max - y_min 81 | desired_size = int(opt.size * (1 - opt.border_ratio)) 82 | scale = desired_size / max(h, w) 83 | h2 = int(h * scale) 84 | w2 = int(w * scale) 85 | x2_min = (opt.size - h2) // 2 86 | x2_max = x2_min + h2 87 | y2_min = (opt.size - w2) // 2 88 | y2_max = y2_min + w2 89 | final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) 90 | 91 | else: 92 | final_rgba = carved_image 93 | 94 | # write image 95 | cv2.imwrite(out_rgba, final_rgba) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | rich 3 | ninja 4 | numpy 5 | pandas 6 | scipy 7 | scikit-learn 8 | matplotlib 9 | opencv-python 10 | imageio==2.33.1 11 | imageio-ffmpeg 12 | omegaconf 13 | 14 | torch==2.0.1 15 | torchvision==0.15.2 16 | xformers==0.0.22 17 | einops 18 | plyfile 19 | pygltflib 20 | 21 | 22 | # for stable-diffusion 23 | huggingface_hub 24 | diffusers==0.25.0 25 | accelerate 26 | transformers 27 | 28 | # for dmtet and mesh export 29 | xatlas 30 | trimesh 31 | PyMCubes 32 | pymeshlab 33 | 34 | rembg[gpu,cli] 35 | -------------------------------------------------------------------------------- /scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "DDIMScheduler", 3 | "_diffusers_version": "0.25.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "clip_sample_range": 1.0, 9 | "dynamic_thresholding_ratio": 0.995, 10 | "num_train_timesteps": 1000, 11 | 
"prediction_type": "epsilon", 12 | "rescale_betas_zero_snr": false, 13 | "sample_max_value": 1.0, 14 | "set_alpha_to_one": false, 15 | "steps_offset": 1, 16 | "thresholding": false, 17 | "timestep_spacing": "leading", 18 | "trained_betas": null 19 | } -------------------------------------------------------------------------------- /scripts/cal_sim.py: -------------------------------------------------------------------------------- 1 | import os 2 | base_dir='./logs' 3 | output_file = "./clip_similarity.txt" 4 | ext_name = ".obj" 5 | ft = "+z" # front direction 6 | 7 | def mesh_render(mesh_path,ft='-y',save_path=''): 8 | import kiui 9 | from kiui.render import GUI 10 | import tqdm 11 | import argparse 12 | from PIL import Image 13 | import numpy as np 14 | import os 15 | parser = argparse.ArgumentParser() 16 | # parser.add_argument('mesh', type=str, help="path to mesh (obj, glb, ...)") 17 | parser.add_argument('--pbr', action='store_true', help="enable PBR material") 18 | parser.add_argument('--envmap', type=str, default=None, help="hdr env map path for pbr") 19 | parser.add_argument('--front_dir', type=str, default='+z', help="mesh front-facing dir") 20 | parser.add_argument('--mode', default='albedo', type=str, choices=['lambertian', 'albedo', 'normal', 'depth', 'pbr'], help="rendering mode") 21 | parser.add_argument('--W', type=int, default=512, help="GUI width") 22 | parser.add_argument('--H', type=int, default=512, help="GUI height") 23 | parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center") 24 | parser.add_argument('--fovy', type=float, default=49, help="default GUI camera fovy") 25 | parser.add_argument("--force_cuda_rast", action='store_true', help="force to use RasterizeCudaContext.") 26 | parser.add_argument('--elevation', type=int, default=0, help="rendering elevation") 27 | parser.add_argument('--num_azimuth', type=int, default=8, help="number of images to render from different azimuths") 28 | 29 | os.makedirs(save_path,exist_ok=True) 30 | 31 | opt = parser.parse_args() 32 | opt.mesh=mesh_path 33 | opt.force_cuda_rast=True 34 | opt.wogui = True 35 | opt.front_dir = ft 36 | opt.ssaa=1 37 | gui = GUI(opt) 38 | elevation = [opt.elevation] 39 | azimuth = np.linspace(0, 360, opt.num_azimuth, dtype=np.int32, endpoint=False) 40 | for ele in elevation: 41 | for ii,azi in enumerate(azimuth): 42 | gui.cam.from_angle(ele, azi) 43 | gui.need_update = True 44 | gui.step() 45 | image = (gui.render_buffer * 255).astype(np.uint8) 46 | image = Image.fromarray(image) 47 | 48 | img_pt=save_path 49 | os.makedirs(f'{img_pt}',exist_ok=True) 50 | image.save(f'{img_pt}/{ii}.png') 51 | 52 | def gen(): 53 | cnt=0 54 | os.makedirs('./tmp',exist_ok=True) 55 | for file in sorted(os.listdir(base_dir)): 56 | if file.endswith('rgba'+ext_name): 57 | cnt+=1 58 | filename=file.split(ext_name)[0] 59 | mesh_render(os.path.join(base_dir,file),ft=ft,save_path=f'./tmp/{filename}') 60 | print(cnt) 61 | 62 | gen() 63 | 64 | from clip_sim import cal_clip_sim 65 | def cal_metrics(): 66 | test_dirs='./tmp' 67 | sims=[] 68 | with open(output_file,'w') as f: 69 | for file in sorted(os.listdir(test_dirs)): 70 | pt=os.path.join(test_dirs,file) 71 | if not os.path.isdir(pt): 72 | continue 73 | 74 | ref_img=os.path.join('./test_data',file+'.png') 75 | print(ref_img) 76 | novel=[os.path.join(pt,f'{i}.png') for i in range(8)] 77 | sim=cal_clip_sim(ref_img,novel) 78 | # print(sim) 79 | sims.append(sim) 80 | f.write(f"{file}: {sim}\n") 81 | print(sum(sims)/len(sims)) 82 | 
f.write(f"average: {sum(sims)/len(sims)}\n") 83 | 84 | 85 | cal_metrics() 86 | 87 | -------------------------------------------------------------------------------- /scripts/convert_obj_to_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--dir', default='logs', type=str, help='Directory where obj files are stored') 7 | parser.add_argument('--out', default='videos', type=str, help='Directory where videos will be saved') 8 | args = parser.parse_args() 9 | 10 | out = args.out 11 | os.makedirs(out, exist_ok=True) 12 | 13 | files = glob.glob(f'{args.dir}/*.obj') 14 | for f in files: 15 | name = os.path.basename(f) 16 | # first stage model, ignore 17 | if name.endswith('_mesh.obj'): 18 | continue 19 | print(f'[INFO] process {name}') 20 | os.system(f"python -m kiui.render {f} --save_video {os.path.join(out, name.replace('.obj', '.mp4'))} ") -------------------------------------------------------------------------------- /scripts/run_cal_sim.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | PYTHONPATH='.' python scripts/cal_sim.py -------------------------------------------------------------------------------- /scripts/run_imagedream.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # imagedream with no refinement 4 | 5 | python main.py --config configs/imagedream.yaml input=test_data/ghost_rgba.png prompt="a ghost eating hamburger" save_path=ghost 6 | 7 | -------------------------------------------------------------------------------- /scripts/run_sai.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | python main.py --config configs/image_sai.yaml input=test_data/anya_rgba.png save_path=anya_sai 4 | python main2.py --config configs/image_sai.yaml input=test_data/anya_rgba.png save_path=anya_sai 5 | 6 | python main.py --config configs/image_sai.yaml input=test_data/ghost_rgba.png save_path=ghost_sai 7 | python main2.py --config configs/image_sai.yaml input=test_data/ghost_rgba.png save_path=ghost_sai 8 | 9 | python main.py --config configs/image_sai.yaml input=test_data/astronaut_rgba.png save_path=astro_sai 10 | python main2.py --config configs/image_sai.yaml input=test_data/astronaut_rgba.png save_path=astro_sai -------------------------------------------------------------------------------- /sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 
11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | 
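    # Inverse of RGB2SH above: map the degree-0 (DC) SH coefficient back to RGB.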
return sh * C0 + 0.5 -------------------------------------------------------------------------------- /simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /simple-knn/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == 'nt': 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=[ 27 | "spatial.cu", 28 | "simple_knn.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /simple-knn/simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #define BOX_SIZE 1024 13 | 14 | #include "cuda_runtime.h" 15 | #include "device_launch_parameters.h" 16 | #include "simple_knn.h" 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #define __CUDACC__ 24 | #include 25 | #include 26 | 27 | namespace cg = cooperative_groups; 28 | 29 | struct CustomMin 30 | { 31 | __device__ __forceinline__ 32 | float3 operator()(const float3& a, const float3& b) const { 33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; 34 | } 35 | }; 36 | 37 | struct CustomMax 38 | { 39 | __device__ __forceinline__ 40 | float3 operator()(const float3& a, const float3& b) const { 41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; 42 | } 43 | }; 44 | 45 | __host__ __device__ uint32_t prepMorton(uint32_t x) 46 | { 47 | x = (x | (x << 16)) & 0x030000FF; 48 | x = (x | (x << 8)) & 0x0300F00F; 49 | x = (x | (x << 4)) & 0x030C30C3; 50 | x = (x | (x << 2)) & 0x09249249; 51 | return x; 52 | } 53 | 54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) 55 | { 56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); 57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); 58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); 59 | 60 | return x | (y << 1) | (z << 2); 61 | } 62 | 63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) 64 | { 65 | auto idx = cg::this_grid().thread_rank(); 66 | if (idx >= P) 67 | return; 68 | 69 | codes[idx] = coord2Morton(points[idx], minn, maxx); 70 | } 71 | 72 | struct MinMax 73 | { 74 | float3 minn; 75 | float3 maxx; 76 | }; 77 | 78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | 82 | MinMax me; 83 | if (idx < P) 84 | { 85 | me.minn = points[indices[idx]]; 86 | me.maxx = points[indices[idx]]; 87 | } 88 | else 89 | { 90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; 91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; 92 | } 93 | 94 | __shared__ MinMax redResult[BOX_SIZE]; 95 | 96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2) 97 | { 98 | if (threadIdx.x < 2 * off) 99 | redResult[threadIdx.x] = me; 100 | __syncthreads(); 101 | 102 | if (threadIdx.x < off) 103 | { 104 | MinMax other = redResult[threadIdx.x + off]; 105 | me.minn.x = min(me.minn.x, other.minn.x); 106 | me.minn.y = min(me.minn.y, other.minn.y); 107 | me.minn.z = min(me.minn.z, other.minn.z); 108 | me.maxx.x = max(me.maxx.x, other.maxx.x); 109 | me.maxx.y = max(me.maxx.y, other.maxx.y); 110 | me.maxx.z = max(me.maxx.z, other.maxx.z); 111 | } 112 | __syncthreads(); 113 | } 114 | 115 | if (threadIdx.x == 0) 116 | boxes[blockIdx.x] = me; 117 | } 118 | 119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) 120 | { 121 | float3 diff = { 0, 0, 0 }; 122 | if (p.x < box.minn.x || p.x > box.maxx.x) 123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); 124 | if (p.y < box.minn.y || p.y > box.maxx.y) 125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); 126 | if (p.z < box.minn.z || p.z > box.maxx.z) 127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); 128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; 129 | } 130 | 131 | template 132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn) 133 
| { 134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; 135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z; 136 | for (int j = 0; j < K; j++) 137 | { 138 | if (knn[j] > dist) 139 | { 140 | float t = knn[j]; 141 | knn[j] = dist; 142 | dist = t; 143 | } 144 | } 145 | } 146 | 147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) 148 | { 149 | int idx = cg::this_grid().thread_rank(); 150 | if (idx >= P) 151 | return; 152 | 153 | float3 point = points[indices[idx]]; 154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; 155 | 156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) 157 | { 158 | if (i == idx) 159 | continue; 160 | updateKBest<3>(point, points[indices[i]], best); 161 | } 162 | 163 | float reject = best[2]; 164 | best[0] = FLT_MAX; 165 | best[1] = FLT_MAX; 166 | best[2] = FLT_MAX; 167 | 168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) 169 | { 170 | MinMax box = boxes[b]; 171 | float dist = distBoxPoint(box, point); 172 | if (dist > reject || dist > best[2]) 173 | continue; 174 | 175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) 176 | { 177 | if (i == idx) 178 | continue; 179 | updateKBest<3>(point, points[indices[i]], best); 180 | } 181 | } 182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; 183 | } 184 | 185 | void SimpleKNN::knn(int P, float3* points, float* meanDists) 186 | { 187 | float3* result; 188 | cudaMalloc(&result, sizeof(float3)); 189 | size_t temp_storage_bytes; 190 | 191 | float3 init = { 0, 0, 0 }, minn, maxx; 192 | 193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); 194 | thrust::device_vector temp_storage(temp_storage_bytes); 195 | 196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); 197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); 198 | 199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); 200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); 201 | 202 | thrust::device_vector morton(P); 203 | thrust::device_vector morton_sorted(P); 204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); 205 | 206 | thrust::device_vector indices(P); 207 | thrust::sequence(indices.begin(), indices.end()); 208 | thrust::device_vector indices_sorted(P); 209 | 210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 211 | temp_storage.resize(temp_storage_bytes); 212 | 213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 214 | 215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; 216 | thrust::device_vector boxes(num_boxes); 217 | boxMinMax << > > (P, points, indices_sorted.data().get(), boxes.data().get()); 218 | boxMeanDist << > > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); 219 | 220 | cudaFree(result); 221 | } -------------------------------------------------------------------------------- /simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights 
/simple-knn/simple_knn/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/simple-knn/simple_knn/.gitkeep
--------------------------------------------------------------------------------
/simple-knn/spatial.cu:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2023, Inria
3 |  * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 |  * All rights reserved.
5 |  *
6 |  * This software is free for non-commercial, research and evaluation use
7 |  * under the terms of the LICENSE.md file.
8 |  *
9 |  * For inquiries contact george.drettakis@inria.fr
10 |  */
11 |
12 | #include "spatial.h"
13 | #include "simple_knn.h"
14 |
15 | torch::Tensor
16 | distCUDA2(const torch::Tensor& points)
17 | {
18 |   const int P = points.size(0);
19 |
20 |   auto float_opts = points.options().dtype(torch::kFloat32);
21 |   torch::Tensor means = torch::full({P}, 0.0, float_opts);
22 |
23 |   SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>());
24 |
25 |   return means;
26 | }
--------------------------------------------------------------------------------
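spatial.cu wraps SimpleKNN::knn as the PyTorch operator distCUDA2, which ext.cpp in the same folder presumably binds into the compiled extension. A hedged usage sketch from Python: I am assuming the extension is importable as simple_knn._C, the usual name for this setup.py layout; check simple-knn/setup.py for the name actually registered.

import torch
from simple_knn._C import distCUDA2  # assumed module path; see setup.py / ext.cpp

points = torch.rand(10_000, 3, device="cuda", dtype=torch.float32)
mean_sq_dist = distCUDA2(points)  # shape (10_000,), mean squared distance to the 3 nearest neighbours
scales = torch.log(torch.sqrt(mean_sq_dist.clamp_min(1e-7)))  # typical use: initialising per-Gaussian scales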
/simple-knn/spatial.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2023, Inria
3 |  * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 |  * All rights reserved.
5 |  *
6 |  * This software is free for non-commercial, research and evaluation use
7 |  * under the terms of the LICENSE.md file.
8 |  *
9 |  * For inquiries contact george.drettakis@inria.fr
10 |  */
11 |
12 | #include <torch/extension.h>
13 |
14 | torch::Tensor distCUDA2(const torch::Tensor& points);
--------------------------------------------------------------------------------
/test_data/00_zero123_lysol_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/00_zero123_lysol_rgba.png
--------------------------------------------------------------------------------
/test_data/01_wild_hydrant_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/01_wild_hydrant_rgba.png
--------------------------------------------------------------------------------
/test_data/02_zero123_spyro_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/02_zero123_spyro_rgba.png
--------------------------------------------------------------------------------
/test_data/03_wild2_pineapple_bottle_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/03_wild2_pineapple_bottle_rgba.png
--------------------------------------------------------------------------------
/test_data/04_unsplash_broccoli_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/04_unsplash_broccoli_rgba.png
--------------------------------------------------------------------------------
/test_data/05_objaverse_backpack_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/05_objaverse_backpack_rgba.png
--------------------------------------------------------------------------------
/test_data/06_unsplash_chocolatecake_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/06_unsplash_chocolatecake_rgba.png
--------------------------------------------------------------------------------
/test_data/07_unsplash_stool2_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/07_unsplash_stool2_rgba.png
--------------------------------------------------------------------------------
/test_data/08_dalle_icecream_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/08_dalle_icecream_rgba.png
--------------------------------------------------------------------------------
/test_data/09_unsplash_bigmac_rgba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/09_unsplash_bigmac_rgba.png
-------------------------------------------------------------------------------- /test_data/10_dalle3_blueberryicecream2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/10_dalle3_blueberryicecream2_rgba.png -------------------------------------------------------------------------------- /test_data/11_GSO_Crosley_Alarm_Clock_Vintage_Metal_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/11_GSO_Crosley_Alarm_Clock_Vintage_Metal_rgba.png -------------------------------------------------------------------------------- /test_data/12_realfusion_cactus_1_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/12_realfusion_cactus_1_rgba.png -------------------------------------------------------------------------------- /test_data/13_realfusion_cherry_1_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/13_realfusion_cherry_1_rgba.png -------------------------------------------------------------------------------- /test_data/14_dalle_cowbear_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/14_dalle_cowbear_rgba.png -------------------------------------------------------------------------------- /test_data/15_dalle3_gramophone1_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/15_dalle3_gramophone1_rgba.png -------------------------------------------------------------------------------- /test_data/16_dalle3_mushroom2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/16_dalle3_mushroom2_rgba.png -------------------------------------------------------------------------------- /test_data/17_dalle3_rockingchair1_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/17_dalle3_rockingchair1_rgba.png -------------------------------------------------------------------------------- /test_data/18_unsplash_mario_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/18_unsplash_mario_rgba.png -------------------------------------------------------------------------------- /test_data/19_dalle3_stump1_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/19_dalle3_stump1_rgba.png -------------------------------------------------------------------------------- 
/test_data/20_objaverse_stool_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/20_objaverse_stool_rgba.png -------------------------------------------------------------------------------- /test_data/21_objaverse_barrel_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/21_objaverse_barrel_rgba.png -------------------------------------------------------------------------------- /test_data/22_unsplash_boxtoy_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/22_unsplash_boxtoy_rgba.png -------------------------------------------------------------------------------- /test_data/23_objaverse_tank_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/23_objaverse_tank_rgba.png -------------------------------------------------------------------------------- /test_data/24_wild2_yellow_duck_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/24_wild2_yellow_duck_rgba.png -------------------------------------------------------------------------------- /test_data/25_unsplash_teapot_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/25_unsplash_teapot_rgba.png -------------------------------------------------------------------------------- /test_data/26_unsplash_strawberrycake_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/26_unsplash_strawberrycake_rgba.png -------------------------------------------------------------------------------- /test_data/27_objaverse_robocat_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/27_objaverse_robocat_rgba.png -------------------------------------------------------------------------------- /test_data/28_wild_goose_chef_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/28_wild_goose_chef_rgba.png -------------------------------------------------------------------------------- /test_data/29_wild_peroxide_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/29_wild_peroxide_rgba.png -------------------------------------------------------------------------------- /test_data/alarm_rgba.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/alarm_rgba.png -------------------------------------------------------------------------------- /test_data/anya_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/anya_rgba.png -------------------------------------------------------------------------------- /test_data/armor_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/armor_rgba.png -------------------------------------------------------------------------------- /test_data/astronaut_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/astronaut_rgba.png -------------------------------------------------------------------------------- /test_data/backpack_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/backpack_rgba.png -------------------------------------------------------------------------------- /test_data/box_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/box_rgba.png -------------------------------------------------------------------------------- /test_data/bread_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/bread_rgba.png -------------------------------------------------------------------------------- /test_data/bucket_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/bucket_rgba.png -------------------------------------------------------------------------------- /test_data/busket_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/busket_rgba.png -------------------------------------------------------------------------------- /test_data/cargo_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/cargo_rgba.png -------------------------------------------------------------------------------- /test_data/cat_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/cat_rgba.png -------------------------------------------------------------------------------- /test_data/catstatue_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/catstatue_rgba.png 
-------------------------------------------------------------------------------- /test_data/chili_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/chili_rgba.png -------------------------------------------------------------------------------- /test_data/crab_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/crab_rgba.png -------------------------------------------------------------------------------- /test_data/crystal_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/crystal_rgba.png -------------------------------------------------------------------------------- /test_data/csm_luigi_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/csm_luigi_rgba.png -------------------------------------------------------------------------------- /test_data/deer_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/deer_rgba.png -------------------------------------------------------------------------------- /test_data/drum2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/drum2_rgba.png -------------------------------------------------------------------------------- /test_data/drum_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/drum_rgba.png -------------------------------------------------------------------------------- /test_data/elephant_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/elephant_rgba.png -------------------------------------------------------------------------------- /test_data/flower2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/flower2_rgba.png -------------------------------------------------------------------------------- /test_data/flower_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/flower_rgba.png -------------------------------------------------------------------------------- /test_data/forest_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/forest_rgba.png -------------------------------------------------------------------------------- /test_data/frog_sweater_rgba.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/frog_sweater_rgba.png -------------------------------------------------------------------------------- /test_data/ghost_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/ghost_rgba.png -------------------------------------------------------------------------------- /test_data/giraffe_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/giraffe_rgba.png -------------------------------------------------------------------------------- /test_data/grandfather_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/grandfather_rgba.png -------------------------------------------------------------------------------- /test_data/ground_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/ground_rgba.png -------------------------------------------------------------------------------- /test_data/halloween_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/halloween_rgba.png -------------------------------------------------------------------------------- /test_data/hat_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/hat_rgba.png -------------------------------------------------------------------------------- /test_data/head_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/head_rgba.png -------------------------------------------------------------------------------- /test_data/house_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/house_rgba.png -------------------------------------------------------------------------------- /test_data/kettle_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/kettle_rgba.png -------------------------------------------------------------------------------- /test_data/kunkun_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/kunkun_rgba.png -------------------------------------------------------------------------------- /test_data/lantern_rgba.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/lantern_rgba.png -------------------------------------------------------------------------------- /test_data/lotus_seed_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/lotus_seed_rgba.png -------------------------------------------------------------------------------- /test_data/lunch_bag_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/lunch_bag_rgba.png -------------------------------------------------------------------------------- /test_data/milk_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/milk_rgba.png -------------------------------------------------------------------------------- /test_data/monkey_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/monkey_rgba.png -------------------------------------------------------------------------------- /test_data/oil_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/oil_rgba.png -------------------------------------------------------------------------------- /test_data/poro_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/poro_rgba.png -------------------------------------------------------------------------------- /test_data/rabbit_chinese_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/rabbit_chinese_rgba.png -------------------------------------------------------------------------------- /test_data/school_bus1_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/school_bus1_rgba.png -------------------------------------------------------------------------------- /test_data/school_bus2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/school_bus2_rgba.png -------------------------------------------------------------------------------- /test_data/shed_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/shed_rgba.png -------------------------------------------------------------------------------- /test_data/shoe_rgba.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/shoe_rgba.png -------------------------------------------------------------------------------- /test_data/sofa2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/sofa2_rgba.png -------------------------------------------------------------------------------- /test_data/sofa_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/sofa_rgba.png -------------------------------------------------------------------------------- /test_data/steak_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/steak_rgba.png -------------------------------------------------------------------------------- /test_data/teapot2_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/teapot2_rgba.png -------------------------------------------------------------------------------- /test_data/teapot_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/teapot_rgba.png -------------------------------------------------------------------------------- /test_data/test_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/test_rgba.png -------------------------------------------------------------------------------- /test_data/toaster_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/toaster_rgba.png -------------------------------------------------------------------------------- /test_data/turtle_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/turtle_rgba.png -------------------------------------------------------------------------------- /test_data/vase_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/vase_rgba.png -------------------------------------------------------------------------------- /test_data/wisky_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/wisky_rgba.png -------------------------------------------------------------------------------- /test_data/zelda_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ML-GSAI/MicroDreamer/597e3731f7a425f121cbe781182488d23184402d/test_data/zelda_rgba.png 
--------------------------------------------------------------------------------