├── LICENSE.txt
├── MeshAnything
│   ├── miche
│   │   ├── LICENSE
│   │   ├── encode.py
│   │   ├── michelangelo
│   │   │   ├── __init__.py
│   │   │   ├── data
│   │   │   │   ├── __init__.py
│   │   │   │   ├── templates.json
│   │   │   │   ├── transforms.py
│   │   │   │   └── utils.py
│   │   │   ├── graphics
│   │   │   │   ├── __init__.py
│   │   │   │   └── primitives
│   │   │   │       ├── __init__.py
│   │   │   │       ├── mesh.py
│   │   │   │       └── volume.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── asl_diffusion
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── asl_diffuser_pl_module.py
│   │   │   │   │   ├── asl_udt.py
│   │   │   │   │   ├── base.py
│   │   │   │   │   ├── clip_asl_diffuser_pl_module.py
│   │   │   │   │   └── inference_utils.py
│   │   │   │   ├── conditional_encoders
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── clip.py
│   │   │   │   │   └── encoder_factory.py
│   │   │   │   ├── modules
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── checkpoint.py
│   │   │   │   │   ├── diffusion_transformer.py
│   │   │   │   │   ├── distributions.py
│   │   │   │   │   ├── embedder.py
│   │   │   │   │   ├── transformer_blocks.py
│   │   │   │   │   └── transformer_vit.py
│   │   │   │   └── tsal
│   │   │   │       ├── __init__.py
│   │   │   │       ├── asl_pl_module.py
│   │   │   │       ├── clip_asl_module.py
│   │   │   │       ├── inference_utils.py
│   │   │   │       ├── loss.py
│   │   │   │       ├── sal_perceiver.py
│   │   │   │       ├── sal_pl_module.py
│   │   │   │       └── tsal_base.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── eval.py
│   │   │       ├── io.py
│   │   │       ├── misc.py
│   │   │       └── visualizers
│   │   │           ├── __init__.py
│   │   │           ├── color_util.py
│   │   │           ├── html_util.py
│   │   │           └── pythreejs_viewer.py
│   │   └── shapevae-256.yaml
│   └── models
│       ├── meshanything_v2.py
│       └── shape_opt.py
├── README.md
├── adjacent_mesh_tokenization.py
├── app.py
├── data_process.py
├── demo
│   └── demo_video.gif
├── examples
│   ├── screwdriver.obj
│   └── wand.obj
├── gt_examples
│   └── seals.ply
├── main.py
├── mesh_to_pc.py
├── meshanything_train
│   ├── dist.py
│   ├── engine.py
│   ├── eval_cond_gpt.py
│   ├── loop_set_256.py
│   ├── miche
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── configs
│   │   │   ├── aligned_shape_latents
│   │   │   │   ├── shapevae-256.yaml
│   │   │   │   └── shapevae-512.yaml
│   │   │   ├── image_cond_diffuser_asl
│   │   │   │   └── image-ASLDM-256.yaml
│   │   │   └── text_cond_diffuser_asl
│   │   │       └── text-ASLDM-256.yaml
│   │   ├── encode.py
│   │   ├── example_data
│   │   │   ├── image
│   │   │   │   └── car.jpg
│   │   │   └── surface
│   │   │       └── surface.npz
│   │   ├── inference.py
│   │   ├── michelangelo
│   │   │   ├── __init__.py
│   │   │   ├── data
│   │   │   │   ├── __init__.py
│   │   │   │   ├── templates.json
│   │   │   │   ├── transforms.py
│   │   │   │   └── utils.py
│   │   │   ├── graphics
│   │   │   │   ├── __init__.py
│   │   │   │   └── primitives
│   │   │   │       ├── __init__.py
│   │   │   │       ├── mesh.py
│   │   │   │       └── volume.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── asl_diffusion
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── asl_diffuser_pl_module.py
│   │   │   │   │   ├── asl_udt.py
│   │   │   │   │   ├── base.py
│   │   │   │   │   ├── clip_asl_diffuser_pl_module.py
│   │   │   │   │   └── inference_utils.py
│   │   │   │   ├── conditional_encoders
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── clip.py
│   │   │   │   │   └── encoder_factory.py
│   │   │   │   ├── modules
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── checkpoint.py
│   │   │   │   │   ├── diffusion_transformer.py
│   │   │   │   │   ├── distributions.py
│   │   │   │   │   ├── embedder.py
│   │   │   │   │   ├── transformer_blocks.py
│   │   │   │   │   └── transformer_vit.py
│   │   │   │   └── tsal
│   │   │   │       ├── __init__.py
│   │   │   │       ├── asl_pl_module.py
│   │   │   │       ├── clip_asl_module.py
│   │   │   │       ├── inference_utils.py
│   │   │   │       ├── loss.py
│   │   │   │       ├── sal_perceiver.py
│   │   │   │       ├── sal_pl_module.py
│   │   │   │       └── tsal_base.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── eval.py
│   │   │       ├── io.py
│   │   │       ├── misc.py
│   │   │       └── visualizers
│   │   │           ├── __init__.py
│   │   │           ├── color_util.py
│   │   │           ├── html_util.py
│   │   │           └── pythreejs_viewer.py
│   │   ├── requirements.txt
│   │   ├── scripts
│   │   │   ├── infer.sh
│   │   │   └── inference
│   │   │       ├── image2mesh.sh
│   │   │       ├── reconstruction.sh
│   │   │       └── text2mesh.sh
│   │   └── setup.py
│   ├── misc.py
│   └── models
│       ├── shape_opt.py
│       └── single_gpt.py
├── pc_examples
│   └── grenade.npy
├── requirements.txt
├── train.py
└── training_requirement.txt
/LICENSE.txt: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work.
-------------------------------------------------------------------------------- /MeshAnything/miche/encode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | from omegaconf import OmegaConf 4 | import numpy as np 5 | import torch 6 | from .michelangelo.utils.misc import instantiate_from_config 7 | 8 | def load_surface(fp): 9 | 10 | with np.load(fp) as input_pc: 11 | surface = input_pc['points'] 12 | normal = input_pc['normals'] 13 | 14 | rng = np.random.default_rng() 15 | ind = rng.choice(surface.shape[0], 4096, replace=False)  # subsample a fixed set of 4096 surface points 16 | surface = torch.FloatTensor(surface[ind]) 17 | normal = torch.FloatTensor(normal[ind]) 18 | 19 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()  # [1, 4096, 6] = xyz + normals; assumes a CUDA device 20 | 21 | return surface 22 | 23 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): 24 | 25 | surface = load_surface(args.pointcloud_path) 26 | # old_surface = surface.clone() 27 | 28 | # surface[0,:,0]*=-1 29 | # surface[0,:,1]*=-1 30 | surface[0,:,2]*=-1  # mirror the point cloud along the z-axis before encoding 31 | 32 | # encoding 33 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) 34 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) 35 | 36 | # decoding 37 | latents = model.model.shape_model.decode(shape_zq) 38 | # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) 39 | 40 | return 0  # mesh extraction via the commented-out geometric_func is not performed here 41 | 42 | def load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt"): 43 | model_config = OmegaConf.load("MeshAnything/miche/shapevae-256.yaml") 44 | # print(model_config) 45 | if hasattr(model_config, "model"): 46 | model_config = model_config.model 47 | 48 | model = instantiate_from_config(model_config, ckpt_path=ckpt_path) 49 | device = "cuda" if torch.cuda.is_available() else "cpu" 50 | model.to(device) 51 | model = model.eval() 52 | 53 | return model 54 | if __name__ == "__main__": 55 | ''' 56 | 1. Reconstruct point cloud 57 | 2. Image-conditioned generation 58 | 3.
Text-conditioned generation 59 | ''' 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--config_path", type=str, required=True) 62 | parser.add_argument("--ckpt_path", type=str, required=True) 63 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud') 64 | parser.add_argument("--image_path", type=str, help='Path to the input image') 65 | parser.add_argument("--text", type=str, help='Input text in the format: A 3D model of motorcar; Porsche 911.') 66 | parser.add_argument("--output_dir", type=str, default='./output') 67 | parser.add_argument("-s", "--seed", type=int, default=0) 68 | args = parser.parse_args() 69 | 70 | print(f'-----------------------------------------------------------------------------') 71 | print(f'>>> Output directory: {args.output_dir}') 72 | print(f'-----------------------------------------------------------------------------') 73 | 74 | reconstruction(args, load_model(args.ckpt_path))  # load_model takes the checkpoint path -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "shape": [ 3 | "a point cloud model of {}.", 4 | "There is a {} in the scene.", 5 | "There is the {} in the scene.", 6 | "a photo of a {} in the scene.", 7 | "a photo of the {} in the scene.", 8 | "a photo of one {} in the scene.", 9 | "itap of a {}.", 10 | "itap of my {}.", 11 | "itap of the {}.", 12 | "a photo of a {}.", 13 | "a photo of my {}.", 14 | "a photo of the {}.", 15 | "a photo of one {}.", 16 | "a photo of many {}.", 17 | "a good photo of a {}.", 18 | "a good photo of the {}.", 19 | "a bad photo of a {}.", 20 | "a bad photo of the {}.", 21 | "a photo of a nice {}.", 22 | "a photo of the nice {}.", 23 | "a photo of a cool {}.", 24 | "a photo of the cool {}.", 25 | "a photo of a weird {}.", 26 | "a photo of the weird {}.", 27 | "a photo of a small {}.", 28 | "a photo of the small {}.", 29 | "a photo of a large {}.", 30 | "a photo of the large {}.", 31 | "a photo of a clean {}.", 32 | "a photo of the clean {}.", 33 | "a photo of a dirty {}.", 34 | "a photo of the dirty {}.", 35 | "a bright photo of a {}.", 36 | "a bright photo of the {}.", 37 | "a dark photo of a {}.", 38 | "a dark photo of the {}.", 39 | "a photo of a hard to see {}.", 40 | "a photo of the hard to see {}.", 41 | "a low resolution photo of a {}.", 42 | "a low resolution photo of the {}.", 43 | "a cropped photo of a {}.", 44 | "a cropped photo of the {}.", 45 | "a close-up photo of a {}.", 46 | "a close-up photo of the {}.", 47 | "a jpeg corrupted photo of a {}.", 48 | "a jpeg corrupted photo of the {}.", 49 | "a blurry photo of a {}.", 50 | "a blurry photo of the {}.", 51 | "a pixelated photo of a {}.", 52 | "a pixelated photo of the {}.", 53 | "a black and white photo of the {}.", 54 | "a black and white photo of a {}", 55 | "a plastic {}.", 56 | "the plastic {}.", 57 | "a toy {}.", 58 | "the toy {}.", 59 | "a plushie {}.", 60 | "the
plushie {}.", 61 | "a cartoon {}.", 62 | "the cartoon {}.", 63 | "an embroidered {}.", 64 | "the embroidered {}.", 65 | "a painting of the {}.", 66 | "a painting of a {}." 67 | ] 68 | 69 | } -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def worker_init_fn(_): 8 | worker_info = torch.utils.data.get_worker_info() 9 | worker_id = worker_info.id 10 | 11 | # dataset = worker_info.dataset 12 | # split_size = dataset.num_records // worker_info.num_workers 13 | # # reset num_records to the true number to retain reliable length information 14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1) 16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 17 | 18 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 19 | 20 | 21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True): 22 | """ 23 | 24 | Args: 25 | samples (list[dict]): 26 | combine_tensors: 27 | combine_scalars: 28 | 29 | Returns: 30 | 31 | """ 32 | 33 | result = {} 34 | 35 | keys = samples[0].keys() 36 | 37 | for key in keys: 38 | result[key] = [] 39 | 40 | for sample in samples: 41 | for key in keys: 42 | val = sample[key] 43 | result[key].append(val) 44 | 45 | for key in keys: 46 | val_list = result[key] 47 | if isinstance(val_list[0], (int, float)): 48 | if combine_scalars: 49 | result[key] = np.array(result[key]) 50 | 51 | elif isinstance(val_list[0], torch.Tensor): 52 | if combine_tensors: 53 | result[key] = torch.stack(val_list) 54 | 55 | elif isinstance(val_list[0], np.ndarray): 56 | if combine_tensors: 57 | result[key] = np.stack(val_list) 58 | 59 | return result 60 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .volume import generate_dense_grid_points 4 | 5 | from .mesh import ( 6 | MeshOutput, 7 | save_obj, 8 | savemeshtes2 9 | ) 10 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/mesh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import PIL.Image 7 | from typing import Optional 8 | 9 | import trimesh 10 | 11 | 12 | def save_obj(pointnp_px3, facenp_fx3, fname): 13 | fid = open(fname, "w") 14 | write_str = "" 15 | for pidx, p in enumerate(pointnp_px3): 16 | pp = p 17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2]) 18 | 19 | for i, f in enumerate(facenp_fx3): 20 | f1 = f + 1 21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2]) 22 | fid.write(write_str) 23 | fid.close() 24 | return 25 | 26 | 27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname): 28 | fol, na = os.path.split(fname) 29 | na, _ = 
os.path.splitext(na) 30 | 31 | matname = "%s/%s.mtl" % (fol, na) 32 | fid = open(matname, "w") 33 | fid.write("newmtl material_0\n") 34 | fid.write("Kd 1 1 1\n") 35 | fid.write("Ka 0 0 0\n") 36 | fid.write("Ks 0.4 0.4 0.4\n") 37 | fid.write("Ns 10\n") 38 | fid.write("illum 2\n") 39 | fid.write("map_Kd %s.png\n" % na) 40 | fid.close() 41 | #### 42 | 43 | fid = open(fname, "w") 44 | fid.write("mtllib %s.mtl\n" % na) 45 | 46 | for pidx, p in enumerate(pointnp_px3): 47 | pp = p 48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) 49 | 50 | for pidx, p in enumerate(tcoords_px2): 51 | pp = p 52 | fid.write("vt %f %f\n" % (pp[0], pp[1])) 53 | 54 | fid.write("usemtl material_0\n") 55 | for i, f in enumerate(facenp_fx3): 56 | f1 = f + 1 57 | f2 = facetex_fx3[i] + 1 58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) 59 | fid.close() 60 | 61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save( 62 | os.path.join(fol, "%s.png" % na)) 63 | 64 | return 65 | 66 | 67 | class MeshOutput(object): 68 | 69 | def __init__(self, 70 | mesh_v: np.ndarray, 71 | mesh_f: np.ndarray, 72 | vertex_colors: Optional[np.ndarray] = None, 73 | uvs: Optional[np.ndarray] = None, 74 | mesh_tex_idx: Optional[np.ndarray] = None, 75 | tex_map: Optional[np.ndarray] = None): 76 | 77 | self.mesh_v = mesh_v 78 | self.mesh_f = mesh_f 79 | self.vertex_colors = vertex_colors 80 | self.uvs = uvs 81 | self.mesh_tex_idx = mesh_tex_idx 82 | self.tex_map = tex_map 83 | 84 | def contain_uv_texture(self): 85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None) 86 | 87 | def contain_vertex_colors(self): 88 | return self.vertex_colors is not None 89 | 90 | def export(self, fname): 91 | 92 | if self.contain_uv_texture(): 93 | savemeshtes2( 94 | self.mesh_v, 95 | self.uvs, 96 | self.mesh_f, 97 | self.mesh_tex_idx, 98 | self.tex_map, 99 | fname 100 | ) 101 | 102 | elif self.contain_vertex_colors(): 103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors) 104 | mesh_obj.export(fname) 105 | 106 | else: 107 | save_obj( 108 | self.mesh_v, 109 | self.mesh_f, 110 | fname 111 | ) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/volume.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | 6 | def generate_dense_grid_points(bbox_min: np.ndarray, 7 | bbox_max: np.ndarray, 8 | octree_depth: int, 9 | indexing: str = "ij"): 10 | length = bbox_max - bbox_min 11 | num_cells = np.exp2(octree_depth) 12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) 13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) 14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) 15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) 16 | xyz = np.stack((xs, ys, zs), axis=-1) 17 | xyz = xyz.reshape(-1, 3) 18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] 19 | 20 | return xyz, grid_size, length 21 | 22 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 
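
The dense-grid helper just defined is what the latent-to-mesh decoding path later in this dump feeds to its implicit-function queries. A minimal usage sketch (not part of the repository):

import numpy as np
from MeshAnything.miche.michelangelo.graphics.primitives import generate_dense_grid_points

bbox_min = np.array([-1.25, -1.25, -1.25])
bbox_max = np.array([1.25, 1.25, 1.25])
xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)
# octree_depth=7 -> 2**7 = 128 cells per axis, hence 129 grid-corner samples per axis
assert grid_size == [129, 129, 129]
assert xyz.shape == (129 ** 3, 3)  # flattened [x, y, z] query coordinates

extract_geometry in tsal/inference_utils.py further down chunks exactly these queries through the shape decoder and runs marching cubes on the resulting logits.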
-------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/asl_udt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from diffusers.models.embeddings import Timesteps 7 | import math 8 | 9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import MLP 10 | from MeshAnything.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer 11 | 12 | 13 | class ConditionalASLUDTDenoiser(nn.Module): 14 | 15 | def __init__(self, *, 16 | device: Optional[torch.device], 17 | dtype: Optional[torch.dtype], 18 | input_channels: int, 19 | output_channels: int, 20 | n_ctx: int, 21 | width: int, 22 | layers: int, 23 | heads: int, 24 | context_dim: int, 25 | context_ln: bool = True, 26 | skip_ln: bool = False, 27 | init_scale: float = 0.25, 28 | flip_sin_to_cos: bool = False, 29 | use_checkpoint: bool = False): 30 | super().__init__() 31 | 32 | self.use_checkpoint = use_checkpoint 33 | 34 | init_scale = init_scale * math.sqrt(1.0 / width) 35 | 36 | self.backbone = UNetDiffusionTransformer( 37 | device=device, 38 | dtype=dtype, 39 | n_ctx=n_ctx, 40 | width=width, 41 | layers=layers, 42 | heads=heads, 43 | skip_ln=skip_ln, 44 | init_scale=init_scale, 45 | use_checkpoint=use_checkpoint 46 | ) 47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) 49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) 50 | 51 | # timestep embedding 52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0) 53 | self.time_proj = MLP( 54 | device=device, dtype=dtype, width=width, init_scale=init_scale 55 | ) 56 | 57 | # context conditioning projector: context_ln selects whether a LayerNorm 58 | # precedes the linear projection 59 | 60 | 61 | 62 | if context_ln: 63 | self.context_embed = nn.Sequential( 64 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 65 | nn.Linear(context_dim, width, device=device, dtype=dtype), 66 | ) 67 | else: 68 | self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype) 69 | 70 | def forward(self, 71 | model_input: torch.FloatTensor, 72 | timestep: torch.LongTensor, 73 | context: torch.FloatTensor): 74 | 75 | r""" 76 | Args: 77 | model_input (torch.FloatTensor): [bs, n_data, c] 78 | timestep (torch.LongTensor): [bs,] 79 | context (torch.FloatTensor): [bs, context_tokens, c] 80 | 81 | Returns: 82 | sample (torch.FloatTensor): [bs, n_data, c] 83 | 84 | """ 85 | 86 | _, n_data, _ = model_input.shape 87 | 88 | # 1. time 89 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1) 90 | 91 | # 2. conditions projector 92 | context = self.context_embed(context) 93 | 94 | # 3.
denoiser 95 | x = self.input_proj(model_input) 96 | x = torch.cat([t_emb, context, x], dim=1) 97 | x = self.backbone(x) 98 | x = self.ln_post(x) 99 | x = x[:, -n_data:] 100 | sample = self.output_proj(x) 101 | 102 | return sample 103 | 104 | 105 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BaseDenoiser(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, t, context): 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from typing import Tuple, List, Union, Optional 6 | from diffusers.schedulers import DDIMScheduler 7 | 8 | 9 | __all__ = ["ddim_sample"] 10 | 11 | 12 | def ddim_sample(ddim_scheduler: DDIMScheduler, 13 | diffusion_model: torch.nn.Module, 14 | shape: Union[List[int], Tuple[int]], 15 | cond: torch.FloatTensor, 16 | steps: int, 17 | eta: float = 0.0, 18 | guidance_scale: float = 3.0, 19 | do_classifier_free_guidance: bool = True, 20 | generator: Optional[torch.Generator] = None, 21 | device: torch.device = "cuda:0", 22 | disable_prog: bool = True): 23 | 24 | assert steps > 0, f"{steps} must > 0." 25 | 26 | # init latents 27 | bsz = cond.shape[0] 28 | if do_classifier_free_guidance: 29 | bsz = bsz // 2 30 | 31 | latents = torch.randn( 32 | (bsz, *shape), 33 | generator=generator, 34 | device=cond.device, 35 | dtype=cond.dtype, 36 | ) 37 | # scale the initial noise by the standard deviation required by the scheduler 38 | latents = latents * ddim_scheduler.init_noise_sigma 39 | # set timesteps 40 | ddim_scheduler.set_timesteps(steps) 41 | timesteps = ddim_scheduler.timesteps.to(device) 42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1] 44 | extra_step_kwargs = { 45 | "eta": eta, 46 | "generator": generator 47 | } 48 | 49 | # reverse 50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)): 51 | # expand the latents if we are doing classifier free guidance 52 | latent_model_input = ( 53 | torch.cat([latents] * 2) 54 | if do_classifier_free_guidance 55 | else latents 56 | ) 57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t) 58 | # predict the noise residual 59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) 60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) 61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond) 62 | 63 | # perform guidance 64 | if do_classifier_free_guidance: 65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 66 | noise_pred = noise_pred_uncond + guidance_scale * ( 67 | noise_pred_text - noise_pred_uncond 68 | ) 69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk( 70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states 71 | # compute the previous noisy sample x_t -> x_t-1 72 | latents = ddim_scheduler.step( 73 | noise_pred, t, latents, **extra_step_kwargs 74 
| ).prev_sample 75 | 76 | yield latents, t 77 | 78 | 79 | def karra_sample(): 80 | pass 81 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/conditional_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .clip import CLIPEncoder 4 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/conditional_encoders/clip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from dataclasses import dataclass 7 | from torchvision.transforms import Normalize 8 | from transformers import CLIPModel, CLIPTokenizer 9 | from transformers.utils import ModelOutput 10 | from typing import Iterable, Optional, Union, List 11 | 12 | 13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 14 | 15 | 16 | @dataclass 17 | class CLIPEmbedOutput(ModelOutput): 18 | last_hidden_state: torch.FloatTensor = None 19 | pooler_output: torch.FloatTensor = None 20 | embeds: torch.FloatTensor = None 21 | 22 | 23 | class CLIPEncoder(torch.nn.Module): 24 | 25 | def __init__(self, model_path="openai/clip-vit-base-patch32"): 26 | 27 | super().__init__() 28 | 29 | # Load the CLIP model and processor 30 | self.model: CLIPModel = CLIPModel.from_pretrained(model_path) 31 | self.tokenizer = CLIPTokenizer.from_pretrained(model_path) 32 | self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # NB: ImageNet statistics, not CLIP's own normalization constants 33 | 34 | self.model.eval()  # eval() propagates to submodules; assigning .training directly does not 35 | for p in self.model.parameters(): 36 | p.requires_grad = False 37 | 38 | @torch.no_grad() 39 | def encode_image(self, images: Iterable[Optional[ImageType]]): 40 | pixel_values = self.image_preprocess(images) 41 | 42 | vision_outputs = self.model.vision_model(pixel_values=pixel_values) 43 | 44 | pooler_output = vision_outputs[1] # pooled_output 45 | image_features = self.model.visual_projection(pooler_output) 46 | 47 | visual_embeds = CLIPEmbedOutput( 48 | last_hidden_state=vision_outputs.last_hidden_state, 49 | pooler_output=pooler_output, 50 | embeds=image_features 51 | ) 52 | 53 | return visual_embeds 54 | 55 | @torch.no_grad() 56 | def encode_text(self, texts: List[str]): 57 | text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt") 58 | 59 | text_outputs = self.model.text_model(input_ids=text_inputs.input_ids, attention_mask=text_inputs.attention_mask) 60 | 61 | pooler_output = text_outputs[1] # pooled_output 62 | text_features = self.model.text_projection(pooler_output) 63 | 64 | text_embeds = CLIPEmbedOutput( 65 | last_hidden_state=text_outputs.last_hidden_state, 66 | pooler_output=pooler_output, 67 | embeds=text_features 68 | ) 69 | 70 | return text_embeds 71 | 72 | def forward(self, 73 | images: Iterable[Optional[ImageType]], 74 | texts: List[str]): 75 | 76 | visual_embeds = self.encode_image(images) 77 | text_embeds = self.encode_text(texts) 78 | 79 | return visual_embeds, text_embeds 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | --------------------------------------------------------------------------------
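
A quick smoke test for the CLIPEncoder above (hypothetical usage, not a script from the repository); the 512-dimensional projected embeddings assume the default openai/clip-vit-base-patch32 weights:

import torch
from MeshAnything.miche.michelangelo.models.conditional_encoders import CLIPEncoder

encoder = CLIPEncoder()  # downloads openai/clip-vit-base-patch32 on first use
images = torch.rand(2, 3, 224, 224)  # the encoder only normalizes, so resize to 224x224 beforehand
texts = ["a 3D model of a chair", "a photo of a car"]
visual_embeds, text_embeds = encoder(images, texts)
print(visual_embeds.embeds.shape)  # torch.Size([2, 512])
print(text_embeds.embeds.shape)    # torch.Size([2, 512])

--------------------------------------------------------------------------------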
/MeshAnything/miche/michelangelo/models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 4 | """ 5 | 6 | import torch 7 | from typing import Callable, Iterable, Sequence, Union 8 | 9 | 10 | def checkpoint( 11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], 12 | inputs: Sequence[torch.Tensor], 13 | params: Iterable[torch.Tensor], 14 | flag: bool, 15 | use_deepspeed: bool = False 16 | ): 17 | """ 18 | Evaluate a function without caching intermediate activations, allowing for 19 | reduced memory at the expense of extra compute in the backward pass. 20 | :param func: the function to evaluate. 21 | :param inputs: the argument sequence to pass to `func`. 22 | :param params: a sequence of parameters `func` depends on but does not 23 | explicitly take as arguments. 24 | :param flag: if False, disable gradient checkpointing. 25 | :param use_deepspeed: if True, use deepspeed 26 | """ 27 | if flag: 28 | if use_deepspeed: 29 | import deepspeed 30 | return deepspeed.checkpointing.checkpoint(func, *inputs) 31 | 32 | args = tuple(inputs) + tuple(params) 33 | return CheckpointFunction.apply(func, len(inputs), *args) 34 | else: 35 | return func(*inputs) 36 | 37 | 38 | class CheckpointFunction(torch.autograd.Function): 39 | @staticmethod 40 | @torch.cuda.amp.custom_fwd 41 | def forward(ctx, run_function, length, *args): 42 | ctx.run_function = run_function 43 | ctx.input_tensors = list(args[:length]) 44 | ctx.input_params = list(args[length:]) 45 | 46 | with torch.no_grad(): 47 | output_tensors = ctx.run_function(*ctx.input_tensors) 48 | return output_tensors 49 | 50 | @staticmethod 51 | @torch.cuda.amp.custom_bwd 52 | def backward(ctx, *output_grads): 53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 54 | with torch.enable_grad(): 55 | # Fixes a bug where the first op in run_function modifies the 56 | # Tensor storage in place, which is not allowed for detach()'d 57 | # Tensors. 
58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 59 | output_tensors = ctx.run_function(*shallow_copies) 60 | input_grads = torch.autograd.grad( 61 | output_tensors, 62 | ctx.input_tensors + ctx.input_params, 63 | output_grads, 64 | allow_unused=True, 65 | ) 66 | del ctx.input_tensors 67 | del ctx.input_params 68 | del output_tensors 69 | return (None, None) + input_grads 70 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/diffusion_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | 8 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint 9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import ( 10 | init_linear, 11 | MLP, 12 | MultiheadCrossAttention, 13 | MultiheadAttention, 14 | ResidualAttentionBlock 15 | ) 16 | 17 | 18 | class AdaLayerNorm(nn.Module): 19 | def __init__(self, 20 | device: torch.device, 21 | dtype: torch.dtype, 22 | width: int): 23 | 24 | super().__init__() 25 | 26 | self.silu = nn.SiLU(inplace=True) 27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype) 28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype) 29 | 30 | def forward(self, x, timestep): 31 | emb = self.linear(timestep) 32 | scale, shift = torch.chunk(emb, 2, dim=2) 33 | x = self.layernorm(x) * (1 + scale) + shift 34 | return x 35 | 36 | 37 | class DitBlock(nn.Module): 38 | def __init__( 39 | self, 40 | *, 41 | device: torch.device, 42 | dtype: torch.dtype, 43 | n_ctx: int, 44 | width: int, 45 | heads: int, 46 | context_dim: int, 47 | qkv_bias: bool = False, 48 | init_scale: float = 1.0, 49 | use_checkpoint: bool = False 50 | ): 51 | super().__init__() 52 | 53 | self.use_checkpoint = use_checkpoint 54 | 55 | self.attn = MultiheadAttention( 56 | device=device, 57 | dtype=dtype, 58 | n_ctx=n_ctx, 59 | width=width, 60 | heads=heads, 61 | init_scale=init_scale, 62 | qkv_bias=qkv_bias 63 | ) 64 | self.ln_1 = AdaLayerNorm(device, dtype, width) 65 | 66 | if context_dim is not None: 67 | self.ln_2 = AdaLayerNorm(device, dtype, width) 68 | self.cross_attn = MultiheadCrossAttention( 69 | device=device, 70 | dtype=dtype, 71 | width=width, 72 | heads=heads, 73 | data_width=context_dim, 74 | init_scale=init_scale, 75 | qkv_bias=qkv_bias 76 | ) 77 | 78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 79 | self.ln_3 = AdaLayerNorm(device, dtype, width) 80 | 81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint) 83 | 84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 85 | x = x + self.attn(self.ln_1(x, t)) 86 | if context is not None: 87 | x = x + self.cross_attn(self.ln_2(x, t), context) 88 | x = x + self.mlp(self.ln_3(x, t)) 89 | return x 90 | 91 | 92 | class DiT(nn.Module): 93 | def __init__( 94 | self, 95 | *, 96 | device: Optional[torch.device], 97 | dtype: Optional[torch.dtype], 98 | n_ctx: int, 99 | width: int, 100 | layers: int, 101 | heads: int, 102 | context_dim: int, 103 | init_scale: float = 0.25, 104 | qkv_bias: bool = False, 105 | use_checkpoint: bool = False 106 | ): 107 | super().__init__() 108 | self.n_ctx = n_ctx 
109 | self.width = width 110 | self.layers = layers 111 | 112 | self.resblocks = nn.ModuleList( 113 | [ 114 | DitBlock( 115 | device=device, 116 | dtype=dtype, 117 | n_ctx=n_ctx, 118 | width=width, 119 | heads=heads, 120 | context_dim=context_dim, 121 | qkv_bias=qkv_bias, 122 | init_scale=init_scale, 123 | use_checkpoint=use_checkpoint 124 | ) 125 | for _ in range(layers) 126 | ] 127 | ) 128 | 129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 130 | for block in self.resblocks: 131 | x = block(x, t, context) 132 | return x 133 | 134 | 135 | class UNetDiffusionTransformer(nn.Module): 136 | def __init__( 137 | self, 138 | *, 139 | device: Optional[torch.device], 140 | dtype: Optional[torch.dtype], 141 | n_ctx: int, 142 | width: int, 143 | layers: int, 144 | heads: int, 145 | init_scale: float = 0.25, 146 | qkv_bias: bool = False, 147 | skip_ln: bool = False, 148 | use_checkpoint: bool = False 149 | ): 150 | super().__init__() 151 | 152 | self.n_ctx = n_ctx 153 | self.width = width 154 | self.layers = layers 155 | 156 | self.encoder = nn.ModuleList() 157 | for _ in range(layers): 158 | resblock = ResidualAttentionBlock( 159 | device=device, 160 | dtype=dtype, 161 | n_ctx=n_ctx, 162 | width=width, 163 | heads=heads, 164 | init_scale=init_scale, 165 | qkv_bias=qkv_bias, 166 | use_checkpoint=use_checkpoint 167 | ) 168 | self.encoder.append(resblock) 169 | 170 | self.middle_block = ResidualAttentionBlock( 171 | device=device, 172 | dtype=dtype, 173 | n_ctx=n_ctx, 174 | width=width, 175 | heads=heads, 176 | init_scale=init_scale, 177 | qkv_bias=qkv_bias, 178 | use_checkpoint=use_checkpoint 179 | ) 180 | 181 | self.decoder = nn.ModuleList() 182 | for _ in range(layers): 183 | resblock = ResidualAttentionBlock( 184 | device=device, 185 | dtype=dtype, 186 | n_ctx=n_ctx, 187 | width=width, 188 | heads=heads, 189 | init_scale=init_scale, 190 | qkv_bias=qkv_bias, 191 | use_checkpoint=use_checkpoint 192 | ) 193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype) 194 | init_linear(linear, init_scale) 195 | 196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None 197 | 198 | self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) 199 | 200 | def forward(self, x: torch.Tensor): 201 | 202 | enc_outputs = [] 203 | for block in self.encoder: 204 | x = block(x) 205 | enc_outputs.append(x) 206 | 207 | x = self.middle_block(x) 208 | 209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder): 210 | x = torch.cat([enc_outputs.pop(), x], dim=-1) 211 | x = linear(x) 212 | 213 | if layer_norm is not None: 214 | x = layer_norm(x) 215 | 216 | x = resblock(x) 217 | 218 | return x 219 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union, List 4 | 5 | 6 | class AbstractDistribution(object): 7 | def sample(self): 8 | raise NotImplementedError() 9 | 10 | def mode(self): 11 | raise NotImplementedError() 12 | 13 | 14 | class DiracDistribution(AbstractDistribution): 15 | def __init__(self, value): 16 | self.value = value 17 | 18 | def sample(self): 19 | return self.value 20 | 21 | def mode(self): 22 | return self.value 23 | 24 | 25 | class DiagonalGaussianDistribution(object): 26 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], 
deterministic=False, feat_dim=1): 27 | self.feat_dim = feat_dim 28 | self.parameters = parameters 29 | 30 | if isinstance(parameters, list): 31 | self.mean = parameters[0] 32 | self.logvar = parameters[1] 33 | else: 34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 35 | 36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 37 | self.deterministic = deterministic 38 | self.std = torch.exp(0.5 * self.logvar) 39 | self.var = torch.exp(self.logvar) 40 | if self.deterministic: 41 | self.var = self.std = torch.zeros_like(self.mean) 42 | 43 | def sample(self): 44 | x = self.mean + self.std * torch.randn_like(self.mean) 45 | return x 46 | 47 | def kl(self, other=None, dims=(1, 2, 3)): 48 | if self.deterministic: 49 | return torch.Tensor([0.]) 50 | else: 51 | if other is None: 52 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 53 | + self.var - 1.0 - self.logvar, 54 | dim=dims) 55 | else: 56 | return 0.5 * torch.mean( 57 | torch.pow(self.mean - other.mean, 2) / other.var 58 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 59 | dim=dims) 60 | 61 | def nll(self, sample, dims=(1, 2, 3)): 62 | if self.deterministic: 63 | return torch.Tensor([0.]) 64 | logtwopi = np.log(2.0 * np.pi) 65 | return 0.5 * torch.sum( 66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 67 | dim=dims) 68 | 69 | def mode(self): 70 | return self.mean 71 | 72 | 73 | def normal_kl(mean1, logvar1, mean2, logvar2): 74 | """ 75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 76 | Compute the KL divergence between two gaussians. 77 | Shapes are automatically broadcasted, so batches can be compared to 78 | scalars, among other use cases. 79 | """ 80 | tensor = None 81 | for obj in (mean1, logvar1, mean2, logvar2): 82 | if isinstance(obj, torch.Tensor): 83 | tensor = obj 84 | break 85 | assert tensor is not None, "at least one argument must be a Tensor" 86 | 87 | # Force variances to be Tensors. Broadcasting helps convert scalars to 88 | # Tensors, but it does not work for torch.exp(). 89 | logvar1, logvar2 = [ 90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 91 | for x in (logvar1, logvar2) 92 | ] 93 | 94 | return 0.5 * ( 95 | -1.0 96 | + logvar2 97 | - logvar1 98 | + torch.exp(logvar1 - logvar2) 99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 100 | ) 101 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/embedder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] 9 | 10 | 11 | class FourierEmbedder(nn.Module): 12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts 13 | each feature dimension of `x[..., i]` into: 14 | [ 15 | sin(x[..., i]), 16 | sin(f_1*x[..., i]), 17 | sin(f_2*x[..., i]), 18 | ... 19 | sin(f_N * x[..., i]), 20 | cos(x[..., i]), 21 | cos(f_1*x[..., i]), 22 | cos(f_2*x[..., i]), 23 | ... 24 | cos(f_N * x[..., i]), 25 | x[..., i] # only present if include_input is True. 26 | ], here f_i is the frequency. 27 | 28 | Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. 
29 | If logspace is True, then the frequency f_i is 2^i, i.e. [1, 2, 4, ..., 2^(num_freqs - 1)]; 30 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. 31 | 32 | Args: 33 | num_freqs (int): the number of frequencies, default is 6; 34 | logspace (bool): If logspace is True, then the frequency f_i is 2^i (octave spacing), 35 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; 36 | input_dim (int): the input dimension, default is 3; 37 | include_input (bool): include the input tensor or not, default is True. 38 | 39 | Attributes: 40 | frequencies (torch.Tensor): [1, 2, 4, ..., 2^(num_freqs - 1)] if logspace is True, 41 | otherwise linearly spaced between [1.0, 2^(num_freqs - 1)]; 42 | 43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), 44 | otherwise, it is input_dim * num_freqs * 2. 45 | 46 | """ 47 | 48 | def __init__(self, 49 | num_freqs: int = 6, 50 | logspace: bool = True, 51 | input_dim: int = 3, 52 | include_input: bool = True, 53 | include_pi: bool = True) -> None: 54 | 55 | """The initialization""" 56 | 57 | super().__init__() 58 | 59 | if logspace: 60 | frequencies = 2.0 ** torch.arange( 61 | num_freqs, 62 | dtype=torch.float32 63 | ) 64 | else: 65 | frequencies = torch.linspace( 66 | 1.0, 67 | 2.0 ** (num_freqs - 1), 68 | num_freqs, 69 | dtype=torch.float32 70 | ) 71 | 72 | if include_pi: 73 | frequencies *= torch.pi 74 | 75 | self.register_buffer("frequencies", frequencies, persistent=False) 76 | self.include_input = include_input 77 | self.num_freqs = num_freqs 78 | 79 | self.out_dim = self.get_dims(input_dim) 80 | 81 | def get_dims(self, input_dim): 82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0 83 | out_dim = input_dim * (self.num_freqs * 2 + temp)  # e.g. 3 * (6 * 2 + 1) = 39 for the defaults 84 | 85 | return out_dim 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | """ Forward process. 89 | 90 | Args: 91 | x: tensor of shape [..., dim] 92 | 93 | Returns: 94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] 95 | where temp is 1 if include_input is True and 0 otherwise.
96 | """ 97 | 98 | if self.num_freqs > 0: 99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) 100 | if self.include_input: 101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 102 | else: 103 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 104 | else: 105 | return x 106 | 107 | 108 | class LearnedFourierEmbedder(nn.Module): 109 | """ following @crowsonkb "s lead with learned sinusoidal pos emb """ 110 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 111 | 112 | def __init__(self, in_channels, dim): 113 | super().__init__() 114 | assert (dim % 2) == 0 115 | half_dim = dim // 2 116 | per_channel_dim = half_dim // in_channels 117 | self.weights = nn.Parameter(torch.randn(per_channel_dim)) 118 | 119 | def forward(self, x): 120 | """ 121 | 122 | Args: 123 | x (torch.FloatTensor): [..., c] 124 | 125 | Returns: 126 | x (torch.FloatTensor): [..., d] 127 | """ 128 | 129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] 130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) 131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) 132 | return fouriered 133 | 134 | 135 | class TriplaneLearnedFourierEmbedder(nn.Module): 136 | def __init__(self, in_channels, dim): 137 | super().__init__() 138 | 139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 142 | 143 | self.out_dim = in_channels + dim 144 | 145 | def forward(self, x): 146 | 147 | yz_embed = self.yz_plane_embedder(x) 148 | xz_embed = self.xz_plane_embedder(x) 149 | xy_embed = self.xy_plane_embedder(x) 150 | 151 | embed = yz_embed + xz_embed + xy_embed 152 | 153 | return embed 154 | 155 | 156 | def sequential_pos_embed(num_len, embed_dim): 157 | assert embed_dim % 2 == 0 158 | 159 | pos = torch.arange(num_len, dtype=torch.float32) 160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32) 161 | omega /= embed_dim / 2. 162 | omega = 1. / 10000 ** omega # (D/2,) 163 | 164 | pos = pos.reshape(-1) # (M,) 165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 166 | 167 | emb_sin = torch.sin(out) # (M, D/2) 168 | emb_cos = torch.cos(out) # (M, D/2) 169 | 170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 171 | 172 | return embeddings 173 | 174 | 175 | def timestep_embedding(timesteps, dim, max_period=10000): 176 | """ 177 | Create sinusoidal timestep embeddings. 178 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 179 | These may be fractional. 180 | :param dim: the dimension of the output. 181 | :param max_period: controls the minimum frequency of the embeddings. 182 | :return: an [N x dim] Tensor of positional embeddings. 
183 | """ 184 | half = dim // 2 185 | freqs = torch.exp( 186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 187 | ).to(device=timesteps.device) 188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None] 189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 190 | if dim % 2: 191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 192 | return embedding 193 | 194 | 195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, 196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, 197 | log2_hashmap_size=19, desired_resolution=None): 198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): 199 | return nn.Identity(), input_dim 200 | 201 | elif embed_type == "fourier": 202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, 203 | logspace=True, include_input=True) 204 | return embedder_obj, embedder_obj.out_dim 205 | 206 | elif embed_type == "hashgrid": 207 | raise NotImplementedError 208 | 209 | elif embed_type == "sphere_harmonic": 210 | raise NotImplementedError 211 | 212 | else: 213 | raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") 214 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/clip_asl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | from transformers import CLIPModel 7 | 8 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule 9 | 10 | 11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): 12 | 13 | def __init__(self, *, 14 | shape_model, 15 | clip_model_version: str = "openai/clip-vit-large-patch14"): 16 | 17 | super().__init__() 18 | 19 | # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) 20 | # for params in self.clip_model.parameters(): 21 | # params.requires_grad = False 22 | self.clip_model = None 23 | self.shape_model = shape_model 24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width)) 25 | # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5) 26 | 27 | def set_shape_model_only(self): 28 | self.clip_model = None 29 | 30 | def encode_shape_embed(self, surface, return_latents: bool = False): 31 | """ 32 | 33 | Args: 34 | surface (torch.FloatTensor): [bs, n, 3 + c] 35 | return_latents (bool): 36 | 37 | Returns: 38 | x (torch.FloatTensor): [bs, projection_dim] 39 | shape_latents (torch.FloatTensor): [bs, m, d] 40 | """ 41 | 42 | pc = surface[..., 0:3] 43 | feats = surface[..., 3:] 44 | 45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) 46 | x = shape_embed @ self.shape_projection 47 | 48 | if return_latents: 49 | return x, shape_latents 50 | else: 51 | return x 52 | 53 | def encode_image_embed(self, image): 54 | """ 55 | 56 | Args: 57 | image (torch.FloatTensor): [bs, 3, h, w] 58 | 59 | Returns: 60 | x (torch.FloatTensor): [bs, projection_dim] 61 | """ 62 | 63 | x = self.clip_model.get_image_features(image) 
64 | 65 | return x 66 | 67 | def encode_text_embed(self, text): 68 | x = self.clip_model.get_text_features(text) 69 | return x 70 | 71 | def forward(self, surface, image, text): 72 | """ 73 | 74 | Args: 75 | surface (torch.FloatTensor): 76 | image (torch.FloatTensor): [bs, 3, 224, 224] 77 | text (torch.LongTensor): [bs, num_templates, 77] 78 | 79 | Returns: 80 | embed_outputs (dict): the embedding outputs, and it contains: 81 | - image_embed (torch.FloatTensor): 82 | - text_embed (torch.FloatTensor): 83 | - shape_embed (torch.FloatTensor): 84 | - logit_scale (float): 85 | """ 86 | 87 | # # text embedding 88 | # text_embed_all = [] 89 | # for i in range(text.shape[0]): 90 | # text_for_one_sample = text[i] 91 | # text_embed = self.encode_text_embed(text_for_one_sample) 92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 93 | # text_embed = text_embed.mean(dim=0) 94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 95 | # text_embed_all.append(text_embed) 96 | # text_embed_all = torch.stack(text_embed_all) 97 | 98 | b = text.shape[0] 99 | text_tokens = rearrange(text, "b t l -> (b t) l") 100 | text_embed = self.encode_text_embed(text_tokens) 101 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) 102 | text_embed = text_embed.mean(dim=1) 103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 104 | 105 | # image embedding 106 | image_embed = self.encode_image_embed(image) 107 | 108 | # shape embedding 109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True) 110 | 111 | embed_outputs = { 112 | "image_embed": image_embed, 113 | "text_embed": text_embed, 114 | "shape_embed": shape_embed, 115 | # "logit_scale": self.clip_model.logit_scale.exp() 116 | } 117 | 118 | return embed_outputs, shape_latents 119 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from einops import repeat 6 | import numpy as np 7 | from typing import Callable, Tuple, List, Union, Optional 8 | from skimage import measure 9 | 10 | from MeshAnything.miche.michelangelo.graphics.primitives import generate_dense_grid_points 11 | 12 | 13 | @torch.no_grad() 14 | def extract_geometry(geometric_func: Callable, 15 | device: torch.device, 16 | batch_size: int = 1, 17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 18 | octree_depth: int = 7, 19 | num_chunks: int = 10000, 20 | disable: bool = True): 21 | """ 22 | 23 | Args: 24 | geometric_func: 25 | device: 26 | bounds: 27 | octree_depth: 28 | batch_size: 29 | num_chunks: 30 | disable: 31 | 32 | Returns: 33 | 34 | """ 35 | 36 | if isinstance(bounds, float): 37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] 38 | 39 | bbox_min = np.array(bounds[0:3]) 40 | bbox_max = np.array(bounds[3:6]) 41 | bbox_size = bbox_max - bbox_min 42 | 43 | xyz_samples, grid_size, length = generate_dense_grid_points( 44 | bbox_min=bbox_min, 45 | bbox_max=bbox_max, 46 | octree_depth=octree_depth, 47 | indexing="ij" 48 | ) 49 | xyz_samples = torch.FloatTensor(xyz_samples) 50 | 51 | batch_logits = [] 52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), 53 | desc="Implicit Function:", disable=disable, leave=False): 54 | queries = xyz_samples[start: start + num_chunks, :].to(device) 55 | 
batch_queries = repeat(queries, "p c -> b p c", b=batch_size) 56 | 57 | logits = geometric_func(batch_queries) 58 | batch_logits.append(logits.cpu()) 59 | 60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() 61 | 62 | mesh_v_f = [] 63 | has_surface = np.zeros((batch_size,), dtype=np.bool_) 64 | for i in range(batch_size): 65 | try: 66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") 67 | vertices = vertices / grid_size * bbox_size + bbox_min 68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]] 69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) 70 | has_surface[i] = True 71 | 72 | except ValueError: 73 | mesh_v_f.append((None, None)) 74 | has_surface[i] = False 75 | 76 | except RuntimeError: 77 | mesh_v_f.append((None, None)) 78 | has_surface[i] = False 79 | 80 | return mesh_v_f, has_surface 81 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/tsal_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from typing import Tuple, List, Optional 5 | 6 | 7 | class Point2MeshOutput(object): 8 | def __init__(self): 9 | self.mesh_v = None 10 | self.mesh_f = None 11 | self.center = None 12 | self.pc = None 13 | 14 | 15 | class Latent2MeshOutput(object): 16 | 17 | def __init__(self): 18 | self.mesh_v = None 19 | self.mesh_f = None 20 | 21 | 22 | class AlignedMeshOutput(object): 23 | 24 | def __init__(self): 25 | self.mesh_v = None 26 | self.mesh_f = None 27 | self.surface = None 28 | self.image = None 29 | self.text: Optional[str] = None 30 | self.shape_text_similarity: Optional[float] = None 31 | self.shape_image_similarity: Optional[float] = None 32 | 33 | 34 | class ShapeAsLatentPLModule(nn.Module): 35 | latent_shape: Tuple[int] 36 | 37 | def encode(self, surface, *args, **kwargs): 38 | raise NotImplementedError 39 | 40 | def decode(self, z_q, *args, **kwargs): 41 | raise NotImplementedError 42 | 43 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 44 | raise NotImplementedError 45 | 46 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 47 | raise NotImplementedError 48 | 49 | 50 | class ShapeAsLatentModule(nn.Module): 51 | latent_shape: Tuple[int, int] 52 | 53 | def __init__(self, *args, **kwargs): 54 | super().__init__() 55 | 56 | def encode(self, *args, **kwargs): 57 | raise NotImplementedError 58 | 59 | def decode(self, *args, **kwargs): 60 | raise NotImplementedError 61 | 62 | def query_geometry(self, *args, **kwargs): 63 | raise NotImplementedError 64 | 65 | 66 | class AlignedShapeAsLatentPLModule(nn.Module): 67 | latent_shape: Tuple[int] 68 | 69 | def set_shape_model_only(self): 70 | raise NotImplementedError 71 | 72 | def encode(self, surface, *args, **kwargs): 73 | raise NotImplementedError 74 | 75 | def decode(self, z_q, *args, **kwargs): 76 | raise NotImplementedError 77 | 78 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 79 | raise NotImplementedError 80 | 81 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 82 | raise NotImplementedError 83 | 84 | 85 | class AlignedShapeAsLatentModule(nn.Module): 86 | shape_model: ShapeAsLatentModule 87 | latent_shape: Tuple[int, int] 88 | 89 | def __init__(self, *args, **kwargs): 90 | super().__init__() 91 | 92 | def set_shape_model_only(self): 93 | raise 
NotImplementedError 94 | 95 | def encode_image_embed(self, *args, **kwargs): 96 | raise NotImplementedError 97 | 98 | def encode_text_embed(self, *args, **kwargs): 99 | raise NotImplementedError 100 | 101 | def encode_shape_embed(self, *args, **kwargs): 102 | raise NotImplementedError 103 | 104 | 105 | class TexturedShapeAsLatentModule(nn.Module): 106 | 107 | def __init__(self, *args, **kwargs): 108 | super().__init__() 109 | 110 | def encode(self, *args, **kwargs): 111 | raise NotImplementedError 112 | 113 | def decode(self, *args, **kwargs): 114 | raise NotImplementedError 115 | 116 | def query_geometry(self, *args, **kwargs): 117 | raise NotImplementedError 118 | 119 | def query_color(self, *args, **kwargs): 120 | raise NotImplementedError 121 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .misc import instantiate_from_config 4 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7): 7 | 8 | mse = torch.mean((x - y) ** 2) 9 | psnr = 10 * torch.log10(data_range / (mse + eps)) 10 | 11 | return psnr 12 | 13 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import io 5 | import tarfile 6 | import json 7 | import numpy as np 8 | import numpy.lib.format 9 | 10 | 11 | def mkdir(path): 12 | os.makedirs(path, exist_ok=True) 13 | return path 14 | 15 | 16 | def npy_loads(data): 17 | stream = io.BytesIO(data) 18 | return np.lib.format.read_array(stream) 19 | 20 | 21 | def npz_loads(data): 22 | return np.load(io.BytesIO(data)) 23 | 24 | 25 | def json_loads(data): 26 | return json.loads(data) 27 | 28 | 29 | def load_json(filepath): 30 | with open(filepath, "r") as f: 31 | data = json.load(f) 32 | return data 33 | 34 | 35 | def write_json(filepath, data): 36 | with open(filepath, "w") as f: 37 | json.dump(data, f, indent=2) 38 | 39 | 40 | def extract_tar(tar_path, tar_cache_folder): 41 | 42 | with tarfile.open(tar_path, "r") as tar: 43 | tar.extractall(path=tar_cache_folder) 44 | 45 | tar_uids = sorted(os.listdir(tar_cache_folder)) 46 | print(f"extract tar: {tar_path} to {tar_cache_folder}") 47 | return tar_uids 48 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import importlib 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | 10 | def get_obj_from_str(string, reload=False): 11 | module, cls = string.rsplit(".", 1) 12 | if reload: 13 | module_imp = importlib.import_module(module) 14 | importlib.reload(module_imp) 15 | return getattr(importlib.import_module(module, package=None), cls) 16 | 17 | 18 | def get_obj_from_config(config): 19 | if "target" not in config: 20 | raise KeyError("Expected key `target` to instantiate.") 21 | 22 | return get_obj_from_str(config["target"]) 23 | 24 | 25 | def 
instantiate_from_config(config, **kwargs): 26 | if "target" not in config: 27 | raise KeyError("Expected key `target` to instantiate.") 28 | 29 | cls = get_obj_from_str(config["target"]) 30 | 31 | params = config.get("params", dict()) 32 | # params.update(kwargs) 33 | # instance = cls(**params) 34 | kwargs.update(params) 35 | instance = cls(**kwargs) 36 | 37 | return instance 38 | 39 | 40 | def is_dist_avail_and_initialized(): 41 | if not dist.is_available(): 42 | return False 43 | if not dist.is_initialized(): 44 | return False 45 | return True 46 | 47 | 48 | def get_rank(): 49 | if not is_dist_avail_and_initialized(): 50 | return 0 51 | return dist.get_rank() 52 | 53 | 54 | def get_world_size(): 55 | if not is_dist_avail_and_initialized(): 56 | return 1 57 | return dist.get_world_size() 58 | 59 | 60 | def all_gather_batch(tensors): 61 | """ 62 | Performs all_gather operation on the provided tensors. 63 | """ 64 | # Queue the gathered tensors 65 | world_size = get_world_size() 66 | # There is no need for reduction in the single-proc case 67 | if world_size == 1: 68 | return tensors 69 | tensor_list = [] 70 | output_tensor = [] 71 | for tensor in tensors: 72 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 73 | dist.all_gather( 74 | tensor_all, 75 | tensor, 76 | async_op=False # performance opt 77 | ) 78 | 79 | tensor_list.append(tensor_all) 80 | 81 | for tensor_all in tensor_list: 82 | output_tensor.append(torch.cat(tensor_all, dim=0)) 83 | return output_tensor 84 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | # Helper functions 6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): 7 | colormap = plt.cm.get_cmap(colormap) 8 | if normalize: 9 | vmin = np.min(inp) 10 | vmax = np.max(inp) 11 | 12 | norm = plt.Normalize(vmin, vmax) 13 | return colormap(norm(inp))[:, :3] 14 | 15 | 16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): 17 | # tex dims need to be power of two. 18 | array = np.ones((width, height, 3), dtype='float32') 19 | 20 | # width in texels of each checker 21 | checker_w = width / n_checkers_x 22 | checker_h = height / n_checkers_y 23 | 24 | for y in range(height): 25 | for x in range(width): 26 | color_key = int(x / checker_w) + int(y / checker_h) 27 | if color_key % 2 == 0: 28 | array[x, y, :] = [1., 0.874, 0.0] 29 | else: 30 | array[x, y, :] = [0., 0., 0.] 
31 | return array 32 | 33 | 34 | def gen_circle(width=256, height=256): 35 | xx, yy = np.mgrid[:width, :height] 36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2 37 | array = np.ones((width, height, 4), dtype='float32') 38 | array[:, :, 0] = (circle <= width) 39 | array[:, :, 1] = (circle <= width) 40 | array[:, :, 2] = (circle <= width) 41 | array[:, :, 3] = circle <= width 42 | return array 43 | 44 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/html_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import io 3 | import base64 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def to_html_frame(content): 9 | 10 | html_frame = f""" 11 | <html> 12 | <body> 13 | {content} 14 | </body> 15 | </html> 16 | """ 17 | 18 | return html_frame 19 | 20 | 21 | def to_single_row_table(caption: str, content: str): 22 | 23 | table_html = f""" 24 | <table border="1"> 25 | <caption>{caption}</caption> 26 | <tr> 27 | <td>{content}</td> 28 | </tr> 29 | </table>
30 | """ 31 | 32 | return table_html 33 | 34 | 35 | def to_image_embed_tag(image: np.ndarray): 36 | 37 | # Convert np.ndarray to bytes 38 | img = Image.fromarray(image) 39 | raw_bytes = io.BytesIO() 40 | img.save(raw_bytes, "PNG") 41 | 42 | # Encode bytes to base64 43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8") 44 | 45 | image_tag = f""" 46 | Embedded Image 47 | """ 48 | 49 | return image_tag 50 | -------------------------------------------------------------------------------- /MeshAnything/miche/shapevae-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: MeshAnything.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 3 | params: 4 | shape_module_cfg: 5 | target: MeshAnything.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 6 | params: 7 | num_latents: 256 8 | embed_dim: 64 9 | point_feats: 3 # normal 10 | num_freqs: 8 11 | include_pi: false 12 | heads: 12 13 | width: 768 14 | num_encoder_layers: 8 15 | num_decoder_layers: 16 16 | use_ln_post: true 17 | init_scale: 0.25 18 | qkv_bias: false 19 | use_checkpoint: true 20 | aligned_module_cfg: 21 | target: MeshAnything.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 22 | params: 23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 24 | 25 | loss_cfg: 26 | target: MeshAnything.miche.michelangelo.models.tsal.loss.ContrastKLNearFar 27 | params: 28 | contrast_weight: 0.1 29 | near_weight: 0.1 30 | kl_weight: 0.001 31 | 32 | optimizer_cfg: 33 | optimizer: 34 | target: torch.optim.AdamW 35 | params: 36 | betas: [0.9, 0.99] 37 | eps: 1.e-6 38 | weight_decay: 1.e-2 39 | 40 | scheduler: 41 | target: MeshAnything.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 42 | params: 43 | warm_up_steps: 5000 44 | f_start: 1.e-6 45 | f_min: 1.e-3 46 | f_max: 1.0 47 | -------------------------------------------------------------------------------- /MeshAnything/models/meshanything_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import AutoModelForCausalLM 4 | from MeshAnything.miche.encode import load_model 5 | from MeshAnything.models.shape_opt import ShapeOPTConfig 6 | from einops import rearrange 7 | 8 | from huggingface_hub import PyTorchModelHubMixin 9 | 10 | class MeshAnythingV2(nn.Module, PyTorchModelHubMixin, 11 | repo_url="https://github.com/buaacyw/MeshAnythingV2", pipeline_tag="image-to-3d", license="mit"): 12 | def __init__(self, config={}): 13 | super().__init__() 14 | self.config = config 15 | self.point_encoder = load_model(ckpt_path=None) 16 | self.n_discrete_size = 128 17 | self.max_seq_ratio = 0.70 18 | self.face_per_token = 9 19 | self.cond_length = 257 20 | self.cond_dim = 768 21 | self.pad_id = -1 22 | self.n_max_triangles = 1600 23 | self.max_length = int(self.n_max_triangles * self.face_per_token * self.max_seq_ratio + 3 + self.cond_length) # add 1 24 | 25 | self.coor_continuous_range = (-0.5, 0.5) 26 | 27 | self.config = ShapeOPTConfig.from_pretrained( 28 | "facebook/opt-350m", 29 | n_positions=self.max_length, 30 | max_position_embeddings=self.max_length, 31 | vocab_size=self.n_discrete_size + 4, 32 | _attn_implementation="flash_attention_2" 33 | ) 34 | 35 | self.bos_token_id = 0 36 | self.eos_token_id = 1 37 | self.pad_token_id = 2 38 | 39 | self.config.bos_token_id = self.bos_token_id 40 | self.config.eos_token_id = 
self.eos_token_id 41 | self.config.pad_token_id = self.pad_token_id 42 | self.config._attn_implementation="flash_attention_2" 43 | self.config.n_discrete_size = self.n_discrete_size 44 | self.config.face_per_token = self.face_per_token 45 | self.config.cond_length = self.cond_length 46 | 47 | if self.config.word_embed_proj_dim != self.config.hidden_size: 48 | self.config.word_embed_proj_dim = self.config.hidden_size 49 | self.transformer = AutoModelForCausalLM.from_config( 50 | config=self.config, use_flash_attention_2 = True 51 | ) 52 | self.transformer.to_bettertransformer() 53 | 54 | self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim) 55 | self.cond_proj = nn.Linear(self.cond_dim * 2, self.config.word_embed_proj_dim) 56 | 57 | self.eval() 58 | 59 | def adjacent_detokenize(self, input_ids): 60 | input_ids = input_ids.reshape(input_ids.shape[0], -1) # B x L 61 | batch_size = input_ids.shape[0] 62 | continuous_coors = torch.zeros((batch_size, self.n_max_triangles * 3 * 10, 3), device=input_ids.device) 63 | continuous_coors[...] = float('nan') 64 | 65 | for i in range(batch_size): 66 | cur_ids = input_ids[i] 67 | coor_loop_check = 0 68 | vertice_count = 0 69 | continuous_coors[i, :3, :] = torch.tensor([[-0.1, 0.0, 0.1], [-0.1, 0.1, 0.2], [-0.3, 0.3, 0.2]], 70 | device=input_ids.device) 71 | for id in cur_ids: 72 | if id == self.pad_id: 73 | break 74 | elif id == self.n_discrete_size: 75 | if coor_loop_check < 9: 76 | break 77 | if coor_loop_check % 3 !=0: 78 | break 79 | coor_loop_check = 0 80 | else: 81 | 82 | if coor_loop_check % 3 == 0 and coor_loop_check >= 9: 83 | continuous_coors[i, vertice_count] = continuous_coors[i, vertice_count-2] 84 | continuous_coors[i, vertice_count+1] = continuous_coors[i, vertice_count-1] 85 | vertice_count += 2 86 | continuous_coors[i, vertice_count, coor_loop_check % 3] = undiscretize(id, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size) 87 | if coor_loop_check % 3 == 2: 88 | vertice_count += 1 89 | coor_loop_check += 1 90 | 91 | continuous_coors = rearrange(continuous_coors, 'b (nf nv) c -> b nf nv c', nv=3, c=3) 92 | 93 | return continuous_coors # b, nf, 3, 3 94 | 95 | 96 | def forward(self, data_dict: dict, is_eval: bool = False) -> dict: 97 | if not is_eval: 98 | return self.train_one_step(data_dict) 99 | else: 100 | return self.generate(data_dict) 101 | 102 | def process_point_feature(self, point_feature): 103 | encode_feature = torch.zeros(point_feature.shape[0], self.cond_length, self.config.word_embed_proj_dim, 104 | device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype) 105 | encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0]) 106 | shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:]) 107 | encode_feature[:, 1:] = self.cond_proj(torch.cat([point_feature[:, 1:], shape_latents], dim=-1)) 108 | 109 | return encode_feature 110 | 111 | @torch.no_grad() 112 | def forward(self, pc_normal, sampling=False) -> dict: 113 | batch_size = pc_normal.shape[0] 114 | point_feature = self.point_encoder.encode_latents(pc_normal) 115 | processed_point_feature = self.process_point_feature(point_feature) 116 | generate_length = self.max_length - self.cond_length 117 | net_device = next(self.parameters()).device 118 | outputs = torch.ones(batch_size, generate_length).long().to(net_device) * self.eos_token_id 119 | # batch x ntokens 120 | if not sampling: 121 | results = self.transformer.generate( 122 | inputs_embeds=processed_point_feature, 123 | 
max_new_tokens=generate_length, # all faces plus two 124 | num_beams=1, 125 | bos_token_id=self.bos_token_id, 126 | eos_token_id=self.eos_token_id, 127 | pad_token_id=self.pad_token_id, 128 | ) 129 | else: 130 | results = self.transformer.generate( 131 | inputs_embeds = processed_point_feature, 132 | max_new_tokens = generate_length, # all faces plus two 133 | do_sample=True, 134 | top_k=50, 135 | top_p=0.95, 136 | bos_token_id = self.bos_token_id, 137 | eos_token_id = self.eos_token_id, 138 | pad_token_id = self.pad_token_id, 139 | ) 140 | assert results.shape[1] <= generate_length # B x ID bos is not included since it's predicted 141 | outputs[:, :results.shape[1]] = results 142 | # batch x ntokens ====> batch x ntokens x D 143 | outputs = outputs[:, 1: -1] 144 | 145 | outputs[outputs == self.bos_token_id] = self.pad_id 146 | outputs[outputs == self.eos_token_id] = self.pad_id 147 | outputs[outputs == self.pad_token_id] = self.pad_id 148 | 149 | outputs[outputs != self.pad_id] -= 3 150 | gen_mesh = self.adjacent_detokenize(outputs) 151 | 152 | return gen_mesh 153 | 154 | def undiscretize( 155 | t, 156 | low,#-0.5 157 | high,# 0.5 158 | num_discrete 159 | ): 160 | t = t.float() #[0, num_discrete-1] 161 | 162 | t /= num_discrete # 0<=t<1 163 | t = t * (high - low) + low # -0.5 <= t < 0.5 164 | return t 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # MeshAnything V2: Artist-Created Mesh Generation With Adjacent Mesh Tokenization 3 | 4 |
5 | Yiwen Chen<sup>1,2</sup>, 6 | Yikai Wang<sup>3*</sup>, 7 | Yihao Luo<sup>4</sup>, 8 | Zhengyi Wang<sup>2,3</sup>, 9 |
10 | Zilong Chen<sup>2,3</sup>, 11 | Jun Zhu<sup>2,3</sup>, 12 | Chi Zhang<sup>5*</sup>, 13 | Guosheng Lin<sup>1*</sup> 14 |
15 | <sup>*</sup>Corresponding authors. 16 |
17 | <sup>1</sup>S-Lab, Nanyang Technological University, 18 | <sup>2</sup>Shengshu, 19 |
20 | <sup>3</sup>Tsinghua University, 21 | <sup>4</sup>Imperial College London, 22 | <sup>5</sup>Westlake University 23 |
24 | [badge links: paper / project page / demo] 37 |
38 | ![Demo GIF](demo/demo_video.gif) 39 |
40 | 41 | 42 | ## Contents 43 | - [Contents](#contents) 44 | - [Installation](#installation) 45 | - [Usage](#usage) 46 | - [Training](#training) 47 | - [Important Notes](#important-notes) 48 | - [Acknowledgement](#acknowledgement) 49 | - [BibTeX](#bibtex) 50 | 51 | ## Installation 52 | Our environment has been tested on Ubuntu 22 with CUDA 11.8 on an A800. 53 | 1. Clone our repo and create a conda environment 54 | ``` 55 | git clone https://github.com/buaacyw/MeshAnythingV2.git && cd MeshAnythingV2 56 | conda create -n MeshAnythingV2 python==3.10.13 -y 57 | conda activate MeshAnythingV2 58 | pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118 59 | pip install -r requirements.txt 60 | pip install -r training_requirements.txt # in case you want to train 61 | pip install flash-attn --no-build-isolation 62 | pip install -U gradio 63 | ``` 64 | 65 | ## Usage 66 | 67 | ### Implementation of Adjacent Mesh Tokenization and Detokenization 68 | ``` 69 | # We release our adjacent mesh tokenization implementation in adjacent_mesh_tokenization.py. 70 | # For detokenization, please check the function adjacent_detokenize in MeshAnything/models/meshanything_v2.py. 71 | python adjacent_mesh_tokenization.py 72 | ``` 73 | 74 | 75 | ### Text/Image to Artist-Created Mesh: we suggest using [Rodin](https://hyperhuman.deemos.com/rodin) to first turn the text or image into a dense mesh, then feed that dense mesh to MeshAnything V2. 76 | ``` 77 | # Put the output obj file of Rodin into rodin_result and use the following command to generate the Artist-Created Mesh. 78 | # We suggest using the --mc flag to preprocess the input mesh with Marching Cubes first. This helps align the inference point cloud with our training domain. 79 | python main.py --input_dir rodin_result --out_dir mesh_output --input_type mesh --mc 80 | ``` 81 | 82 | ### Mesh Command Line Inference 83 | #### Important Note: if your mesh input was not produced by Marching Cubes, we suggest preprocessing it with Marching Cubes first (simply add --mc). 84 | ``` 85 | # folder input 86 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh 87 | 88 | # single file input 89 | python main.py --input_path examples/wand.obj --out_dir mesh_output --input_type mesh 90 | 91 | # Preprocess with Marching Cubes first 92 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh --mc 93 | 94 | # The Marching Cubes resolution defaults to 128. For delicate meshes this may not be sufficient; raising it takes more preprocessing time but should achieve a better result. 95 | # Change it with: --mc_level 7 -> 128 (2^7), --mc_level 8 -> 256 (2^8). 96 | # 256-resolution Marching Cubes example. 97 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh --mc --mc_level 8 98 | ``` 99 | 100 | ### Point Cloud Command Line Inference 101 | ``` 102 | # Note: if you want to use your own point cloud, please make sure the normals are included. 103 | # The file format should be a .npy file with shape (N, 6), where N is the number of points. The first 3 columns are the coordinates and the last 3 columns are the per-point normal; a minimal sketch for building one follows.
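# A minimal sketch (assuming trimesh is installed; file names are illustrative) of building
# such an (N, 6) array from a mesh, mirroring mesh_to_pc.py (coordinates first, then normals):
#   import numpy as np, trimesh
#   mesh = trimesh.load("examples/wand.obj", force="mesh")
#   points, face_idx = mesh.sample(8192, return_index=True)
#   pc_normal = np.concatenate([points, mesh.face_normals[face_idx]], axis=-1)
#   np.save("pc_examples/wand.npy", pc_normal)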
104 | 105 | # inference for folder 106 | python main.py --input_dir pc_examples --out_dir pc_output --input_type pc_normal 107 | 108 | # inference for single file 109 | python main.py --input_path pc_examples/grenade.npy --out_dir pc_output --input_type pc_normal 110 | ``` 111 | 112 | ### Local Gradio Demo 113 | ``` 114 | python app.py 115 | ``` 116 | 117 | ## Training 118 | 119 | ### Step 1 Download Dataset 120 | We provide part of our processed dataset from Objaverse. You can download it from https://huggingface.co/datasets/Yiwen-ntu/MeshAnythingV2/tree/main 121 | 122 | After downloading, place `train.npz` and `test.npz` into the `dataset` directory. 123 | 124 | If you prefer to process your own data, please refer to `data_process.py`. 125 | 126 | ### Step 2 Download Point Cloud Encoder Checkpoints 127 | 128 | Download Michelangelo's point encoder from https://huggingface.co/Maikou/Michelangelo/tree/main/checkpoints/aligned_shape_latents and put it into `meshanything_train/miche/checkpoints/aligned_shape_latents/shapevae-256.ckpt`. 129 | 130 | ### Step 3 Training and Evaluation 131 | ``` 132 | # Training with MultiGPU 133 | accelerate launch --multi_gpu --num_processes 8 train.py --batchsize_per_gpu 2 --checkpoint_dir training_trial 134 | 135 | # Evaluation 136 | python train.py --batchsize_per_gpu 2 --checkpoint_dir evaluation_trial --pretrained_weights gpt_output/training_trial/xxx_xxx.pth --test_only 137 | ``` 138 | 139 | ## Important Notes 140 | - It takes about 8GB and 45s to generate a mesh on an A6000 GPU (depending on the face number of the generated mesh). 141 | - The input mesh will be normalized to a unit bounding box. The up vector of the input mesh should be +Y for better results. 142 | - Limited by computational resources, MeshAnything is trained on meshes with fewer than 1600 faces and cannot generate meshes with more than 1600 faces. The shape of the input mesh should be sharp enough; otherwise, it will be challenging to represent it with only 1600 faces. Thus, feed-forward 3D generation methods may often produce bad results due to insufficient shape quality. We suggest using results from 3D reconstruction, scanning, SDS-based method (like [DreamCraft3D](https://github.com/deepseek-ai/DreamCraft3D)) or [Rodin](https://hyperhuman.deemos.com/rodin) as the input of MeshAnything. 143 | - Please refer to https://huggingface.co/spaces/Yiwen-ntu/MeshAnything/tree/main/examples for more examples. 
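The unit-bounding-box normalization mentioned in the notes above follows the same recipe used in `main.py` and `adjacent_mesh_tokenization.py`; a minimal self-contained sketch:
```
import numpy as np

def normalize_to_unit_bbox(vertices: np.ndarray) -> np.ndarray:
    # Center the mesh at its bounding-box midpoint, then scale the
    # longest side to unit length, as done before inference.
    bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
    vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
    return vertices / (bounds[1] - bounds[0]).max()
```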
144 | 145 | ## Acknowledgement 146 | 147 | Our code is based on these wonderful repos: 148 | 149 | * [MeshAnything](https://github.com/buaacyw/MeshAnything) 150 | * [MeshGPT](https://nihalsid.github.io/mesh-gpt/) 151 | * [meshgpt-pytorch](https://github.com/lucidrains/meshgpt-pytorch) 152 | * [Michelangelo](https://github.com/NeuralCarver/Michelangelo) 153 | * [transformers](https://github.com/huggingface/transformers) 154 | * [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch) 155 | 156 | ## BibTeX 157 | ``` 158 | @misc{chen2024meshanythingv2artistcreatedmesh, 159 | title={MeshAnything V2: Artist-Created Mesh Generation With Adjacent Mesh Tokenization}, 160 | author={Yiwen Chen and Yikai Wang and Yihao Luo and Zhengyi Wang and Zilong Chen and Jun Zhu and Chi Zhang and Guosheng Lin}, 161 | year={2024}, 162 | eprint={2408.02555}, 163 | archivePrefix={arXiv}, 164 | primaryClass={cs.CV}, 165 | url={https://arxiv.org/abs/2408.02555}, 166 | } 167 | ``` 168 | -------------------------------------------------------------------------------- /adjacent_mesh_tokenization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | import networkx as nx 4 | import os 5 | 6 | def mesh_sort(vertices_, faces_): 7 | assert (vertices_ <= 0.5).all() and (vertices_ >= -0.5).all() # [-0.5, 0.5] 8 | vertices = (vertices_ + 0.5) * 128 # [0, num_tokens] 9 | vertices -= 0.5 # for evenly distributed, [-0.5, num_tokens -0.5] will be round to 0 or num_tokens (-1) 10 | vertices_quantized_ = np.clip(vertices.round(), 0, 128 - 1).astype(int) # [0, num_tokens -1] 11 | 12 | cur_mesh = trimesh.Trimesh(vertices=vertices_quantized_, faces=faces_) 13 | 14 | cur_mesh.merge_vertices() 15 | cur_mesh.update_faces(cur_mesh.nondegenerate_faces()) 16 | cur_mesh.update_faces(cur_mesh.unique_faces()) 17 | cur_mesh.remove_unreferenced_vertices() 18 | 19 | sort_inds = np.lexsort(cur_mesh.vertices.T) 20 | vertices = cur_mesh.vertices[sort_inds] 21 | faces = [np.argsort(sort_inds)[f] for f in cur_mesh.faces] 22 | 23 | faces = [sorted(sub_arr) for sub_arr in faces] 24 | 25 | def sort_faces(face): 26 | return face[0], face[1], face[2] 27 | 28 | faces = sorted(faces, key=sort_faces) 29 | 30 | vertices = vertices / 128 - 0.5 # [0, num_tokens -1] to [-0.5, 0.5) for computing 31 | 32 | return vertices, faces 33 | 34 | def adjacent_mesh_tokenization(mesh): 35 | naive_v_length = mesh.faces.shape[0] * 9 36 | 37 | graph = mesh.vertex_adjacency_graph 38 | 39 | unvisited_faces = mesh.faces.copy() 40 | dis_vertices = np.asarray((mesh.vertices.copy() + 0.5) * 128) 41 | 42 | sequence = [] 43 | while unvisited_faces.shape[0] > 0: 44 | # find the face with the smallest index 45 | if len(sequence) == 0 or sequence[-1] == -1: 46 | cur_face = unvisited_faces[0] 47 | unvisited_faces = unvisited_faces[1:] 48 | sequence.extend(cur_face.tolist()) 49 | else: 50 | last_vertices = sequence[-2:] 51 | # find common neighbors 52 | commons = sorted(list(nx.common_neighbors(graph, last_vertices[0], last_vertices[1]))) 53 | next_token = None 54 | for common in commons: 55 | common_face = sorted(np.array(last_vertices + [common])) 56 | # find index of common face 57 | equals = np.where((unvisited_faces == common_face).all(axis=1))[0] 58 | assert len(equals) == 1 or len(equals) == 0 59 | if len(equals) == 1: 60 | next_token = common 61 | next_face_index = equals[0] 62 | break 63 | if next_token is not None: 64 | unvisited_faces = np.delete(unvisited_faces, next_face_index, 
axis=0) 65 | sequence.append(int(next_token)) 66 | else: 67 | sequence.append(-1) 68 | 69 | final_sequence = [] 70 | for token_id in sequence: 71 | if token_id == -1: 72 | final_sequence.append(128) 73 | else: 74 | final_sequence.extend(dis_vertices[token_id].tolist()) 75 | 76 | cur_ratio = len(final_sequence) / naive_v_length 77 | 78 | return cur_ratio 79 | 80 | 81 | if __name__ == "__main__": 82 | # read_ply 83 | data_dir = 'gt_examples' 84 | data_list = sorted(os.listdir(data_dir)) 85 | data_list = [os.path.join(data_dir, x) for x in data_list if x.endswith('.ply') or x.endswith('.obj')] 86 | ratio_list = [] 87 | for idx, cur_data in enumerate(data_list): 88 | cur_mesh = trimesh.load(cur_data) 89 | 90 | vertices = cur_mesh.vertices 91 | faces = cur_mesh.faces 92 | bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) 93 | vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 94 | vertices = vertices / (bounds[1] - bounds[0]).max() 95 | vertices = vertices.clip(-0.5, 0.5) 96 | 97 | vertices, faces = mesh_sort(vertices, faces) 98 | dis_mesh = trimesh.Trimesh(vertices=vertices, faces=faces) 99 | try: 100 | cur_ratio = adjacent_mesh_tokenization(dis_mesh) 101 | 102 | ratio_list.append(cur_ratio) 103 | except Exception as e: 104 | print(e) 105 | 106 | # print mean and variance of ratio: 107 | print(f"mean ratio: {np.mean(ratio_list)}, variance ratio: {np.var(ratio_list)}") -------------------------------------------------------------------------------- /demo/demo_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/demo/demo_video.gif -------------------------------------------------------------------------------- /gt_examples/seals.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/gt_examples/seals.ply -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import torch 3 | import time 4 | import trimesh 5 | import numpy as np 6 | import datetime 7 | from accelerate import Accelerator 8 | from accelerate.utils import set_seed 9 | from accelerate.utils import DistributedDataParallelKwargs 10 | from safetensors.torch import load_model 11 | 12 | from mesh_to_pc import process_mesh_to_pc 13 | from huggingface_hub import hf_hub_download 14 | from MeshAnything.models.meshanything_v2 import MeshAnythingV2 15 | 16 | class Dataset: 17 | def __init__(self, input_type, input_list, mc=False, mc_level = 7): 18 | super().__init__() 19 | self.data = [] 20 | if input_type == 'pc_normal': 21 | for input_path in input_list: 22 | # load npy 23 | cur_data = np.load(input_path) 24 | # sample 4096 25 | assert cur_data.shape[0] >= 8192, "input pc_normal should have at least 4096 points" 26 | idx = np.random.choice(cur_data.shape[0], 8192, replace=False) 27 | cur_data = cur_data[idx] 28 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) 29 | 30 | elif input_type == 'mesh': 31 | mesh_list = [] 32 | for input_path in input_list: 33 | # load ply 34 | cur_data = trimesh.load(input_path) 35 | mesh_list.append(cur_data) 36 | if mc: 37 | print("First Marching Cubes and then sample point cloud, need several minutes...") 38 | pc_list, _ = 
process_mesh_to_pc(mesh_list, marching_cubes=mc, mc_level=mc_level) 39 | for input_path, cur_data in zip(input_list, pc_list): 40 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) 41 | print(f"dataset total data samples: {len(self.data)}") 42 | 43 | def __len__(self): 44 | return len(self.data) 45 | 46 | def __getitem__(self, idx): 47 | data_dict = {} 48 | data_dict['pc_normal'] = self.data[idx]['pc_normal'] 49 | # normalize pc coor 50 | pc_coor = data_dict['pc_normal'][:, :3] 51 | normals = data_dict['pc_normal'][:, 3:] 52 | bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)]) 53 | pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2 54 | pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995 55 | assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong" 56 | data_dict['pc_normal'] = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16) 57 | data_dict['uid'] = self.data[idx]['uid'] 58 | 59 | return data_dict 60 | 61 | def get_args(): 62 | parser = argparse.ArgumentParser("MeshAnything", add_help=False) 63 | 64 | parser.add_argument('--input_dir', default=None, type=str) 65 | parser.add_argument('--input_path', default=None, type=str) 66 | 67 | parser.add_argument('--out_dir', default="inference_out", type=str) 68 | 69 | parser.add_argument( 70 | '--input_type', 71 | choices=['mesh','pc_normal'], 72 | default='pc', 73 | help="Type of the asset to process (default: pc)" 74 | ) 75 | 76 | parser.add_argument("--batchsize_per_gpu", default=1, type=int) 77 | parser.add_argument("--seed", default=0, type=int) 78 | 79 | parser.add_argument("--mc", default=False, action="store_true") 80 | parser.add_argument("--mc_level", default=7, type=int) 81 | 82 | parser.add_argument("--sampling", default=False, action="store_true") 83 | 84 | args = parser.parse_args() 85 | return args 86 | 87 | if __name__ == "__main__": 88 | args = get_args() 89 | 90 | cur_time = datetime.datetime.now().strftime("%d_%H-%M-%S") 91 | checkpoint_dir = os.path.join(args.out_dir, cur_time) 92 | os.makedirs(checkpoint_dir, exist_ok=True) 93 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 94 | accelerator = Accelerator( 95 | mixed_precision="fp16", 96 | project_dir=checkpoint_dir, 97 | kwargs_handlers=[kwargs] 98 | ) 99 | 100 | model = MeshAnythingV2.from_pretrained("Yiwen-ntu/meshanythingv2") 101 | 102 | # create dataset 103 | if args.input_dir is not None: 104 | input_list = sorted(os.listdir(args.input_dir)) 105 | # only ply, obj or npy 106 | if args.input_type == 'pc_normal': 107 | input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.npy')] 108 | else: 109 | input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.ply') or x.endswith('.obj') or x.endswith('.npy')] 110 | set_seed(args.seed) 111 | dataset = Dataset(args.input_type, input_list, args.mc, args.mc_level) 112 | elif args.input_path is not None: 113 | set_seed(args.seed) 114 | dataset = Dataset(args.input_type, [args.input_path], args.mc, args.mc_level) 115 | else: 116 | raise ValueError("input_dir or input_path must be provided.") 117 | 118 | dataloader = torch.utils.data.DataLoader( 119 | dataset, 120 | batch_size=args.batchsize_per_gpu, 121 | drop_last = False, 122 | shuffle = False, 123 | ) 124 | 125 | if accelerator.state.num_processes > 1: 126 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 127 | dataloader, model = accelerator.prepare(dataloader, model) 128 | 
begin_time = time.time() 129 | print("Generation Start!!!") 130 | with accelerator.autocast(): 131 | for curr_iter, batch_data_label in enumerate(dataloader): 132 | curr_time = time.time() 133 | outputs = model(batch_data_label['pc_normal'], sampling=args.sampling) 134 | batch_size = outputs.shape[0] 135 | device = outputs.device 136 | 137 | for batch_id in range(batch_size): 138 | recon_mesh = outputs[batch_id] 139 | valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1) 140 | recon_mesh = recon_mesh[valid_mask] # nvalid_face x 3 x 3 141 | 142 | vertices = recon_mesh.reshape(-1, 3).cpu() 143 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face 144 | triangles = vertices_index.reshape(-1, 3) 145 | 146 | scene_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh", 147 | merge_primitives=True) 148 | scene_mesh.merge_vertices() 149 | scene_mesh.update_faces(scene_mesh.nondegenerate_faces()) 150 | scene_mesh.update_faces(scene_mesh.unique_faces()) 151 | scene_mesh.remove_unreferenced_vertices() 152 | scene_mesh.fix_normals() 153 | save_path = os.path.join(checkpoint_dir, f'{batch_data_label["uid"][batch_id]}_gen.obj') 154 | num_faces = len(scene_mesh.faces) 155 | brown_color = np.array([255, 165, 0, 255], dtype=np.uint8) 156 | face_colors = np.tile(brown_color, (num_faces, 1)) 157 | 158 | scene_mesh.visual.face_colors = face_colors 159 | scene_mesh.export(save_path) 160 | print(f"{save_path} Over!!") 161 | end_time = time.time() 162 | print(f"Total time: {end_time - begin_time}") -------------------------------------------------------------------------------- /mesh_to_pc.py: -------------------------------------------------------------------------------- 1 | import mesh2sdf.core 2 | import numpy as np 3 | import skimage.measure 4 | import trimesh 5 | import time 6 | def normalize_vertices(vertices, scale=0.95): 7 | bbmin, bbmax = vertices.min(0), vertices.max(0) 8 | center = (bbmin + bbmax) * 0.5 9 | scale = 2.0 * scale / (bbmax - bbmin).max() 10 | vertices = (vertices - center) * scale 11 | return vertices, center, scale 12 | 13 | def export_to_watertight(normalized_mesh, octree_depth: int = 7): 14 | """ 15 | Convert the non-watertight mesh to watertight. 16 | 17 | Args: 18 | input_path (str): normlized path 19 | octree_depth (int): 20 | 21 | Returns: 22 | mesh(trimesh.Trimesh): watertight mesh 23 | 24 | """ 25 | size = 2 ** octree_depth 26 | level = 2 / size 27 | scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices) 28 | 29 | sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size) 30 | 31 | vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level) 32 | 33 | vertices = vertices / size * 2 - 1 # -1 to 1 34 | vertices = vertices / to_orig_scale + to_orig_center 35 | # vertices = vertices / to_orig_scale + to_orig_center 36 | mesh = trimesh.Trimesh(vertices, faces, normals=normals) 37 | 38 | return mesh 39 | 40 | def process_mesh_to_pc(mesh_list, marching_cubes = False, sample_num = 8192, mc_level= 7): 41 | # mesh_list : list of trimesh 42 | pc_normal_list = [] 43 | return_mesh_list = [] 44 | for mesh in mesh_list: 45 | if marching_cubes: 46 | cur_time = time.time() 47 | mesh = export_to_watertight(mesh, octree_depth=mc_level) 48 | print("MC over! 
", "mc_level: ", mc_level, "process_time:" , time.time() - cur_time) 49 | return_mesh_list.append(mesh) 50 | points, face_idx = mesh.sample(sample_num, return_index=True) 51 | normals = mesh.face_normals[face_idx] 52 | 53 | pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16) 54 | pc_normal_list.append(pc_normal) 55 | print("process mesh success") 56 | return pc_normal_list, return_mesh_list 57 | 58 | -------------------------------------------------------------------------------- /meshanything_train/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import pickle 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import os 7 | 8 | def is_distributed(): 9 | if not dist.is_available() or not dist.is_initialized(): 10 | return False 11 | return True 12 | 13 | 14 | def get_rank(): 15 | if not is_distributed(): 16 | return 0 17 | return dist.get_rank() 18 | 19 | 20 | def is_primary(): 21 | return get_rank() == 0 22 | 23 | 24 | def get_world_size(): 25 | if not is_distributed(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def barrier(): 31 | if not is_distributed(): 32 | return 33 | torch.distributed.barrier() 34 | 35 | 36 | def setup_print_for_distributed(is_primary): 37 | """ 38 | This function disables printing when not in primary process 39 | """ 40 | import builtins as __builtin__ 41 | builtin_print = __builtin__.print 42 | 43 | def print(*args, **kwargs): 44 | force = kwargs.pop('force', False) 45 | if is_primary or force: 46 | builtin_print(*args, **kwargs) 47 | 48 | __builtin__.print = print 49 | 50 | 51 | def init_distributed(gpu_id, global_rank, world_size, dist_url, dist_backend): 52 | torch.cuda.set_device(gpu_id) 53 | os.environ["MASTER_ADDR"] = "localhost" 54 | os.environ["MASTER_PORT"] = "12355" 55 | print( 56 | f"| distributed init (rank {global_rank}) (world {world_size}): {dist_url}", 57 | flush=True, 58 | ) 59 | torch.distributed.init_process_group( 60 | backend=dist_backend, 61 | # init_method="env://", 62 | world_size=world_size, 63 | rank=global_rank, 64 | ) 65 | torch.distributed.barrier() 66 | setup_print_for_distributed(is_primary()) 67 | 68 | 69 | def all_reduce_sum(tensor): 70 | if not is_distributed(): 71 | return tensor 72 | dim_squeeze = False 73 | if tensor.ndim == 0: 74 | tensor = tensor[None, ...] 75 | dim_squeeze = True 76 | torch.distributed.all_reduce(tensor) 77 | print("loss_tensor: ", tensor) 78 | if dim_squeeze: 79 | tensor = tensor.squeeze(0) 80 | return tensor 81 | 82 | 83 | def all_reduce_average(tensor): 84 | val = all_reduce_sum(tensor) 85 | return val / get_world_size() 86 | 87 | 88 | # Function from DETR - https://github.com/facebookresearch/detr/blob/master/util/misc.py 89 | def reduce_dict(input_dict, average=True): 90 | """ 91 | Args: 92 | input_dict (dict): all the values will be reduced 93 | average (bool): whether to do average or sum 94 | Reduce the values in the dictionary from all processes so that all processes 95 | have the averaged results. Returns a dict with the same fields as 96 | input_dict, after reduction. 
97 | """ 98 | world_size = get_world_size() 99 | if world_size < 2: 100 | return input_dict 101 | with torch.no_grad(): 102 | names = [] 103 | values = [] 104 | # sort the keys so that they are consistent across processes 105 | for k in sorted(input_dict.keys()): 106 | names.append(k) 107 | values.append(input_dict[k]) 108 | values = torch.stack(values, dim=0) 109 | torch.distributed.all_reduce(values) 110 | if average: 111 | values /= world_size 112 | reduced_dict = {k: v for k, v in zip(names, values)} 113 | return reduced_dict 114 | 115 | 116 | # Function from https://github.com/facebookresearch/detr/blob/master/util/misc.py 117 | def all_gather_pickle(data, device): 118 | """ 119 | Run all_gather on arbitrary picklable data (not necessarily tensors) 120 | Args: 121 | data: any picklable object 122 | Returns: 123 | list[data]: list of data gathered from each rank 124 | """ 125 | world_size = get_world_size() 126 | if world_size == 1: 127 | return [data] 128 | 129 | # serialized to a Tensor 130 | buffer = pickle.dumps(data) 131 | storage = torch.ByteStorage.from_buffer(buffer) 132 | tensor = torch.ByteTensor(storage).to(device) 133 | 134 | # obtain Tensor size of each rank 135 | local_size = torch.tensor([tensor.numel()], device=device) 136 | size_list = [torch.tensor([0], device=device) for _ in range(world_size)] 137 | dist.all_gather(size_list, local_size) 138 | size_list = [int(size.item()) for size in size_list] 139 | max_size = max(size_list) 140 | 141 | # receiving Tensor from all ranks 142 | # we pad the tensor because torch all_gather does not support 143 | # gathering tensors of different shapes 144 | tensor_list = [] 145 | for _ in size_list: 146 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) 147 | if local_size != max_size: 148 | padding = torch.empty( 149 | size=(max_size - local_size,), dtype=torch.uint8, device=device 150 | ) 151 | tensor = torch.cat((tensor, padding), dim=0) 152 | dist.all_gather(tensor_list, tensor) 153 | 154 | data_list = [] 155 | for size, tensor in zip(size_list, tensor_list): 156 | buffer = tensor.cpu().numpy().tobytes()[:size] 157 | data_list.append(pickle.loads(buffer)) 158 | 159 | return data_list 160 | 161 | 162 | def all_gather_dict(data): 163 | """ 164 | Run all_gather on data which is a dictionary of Tensors 165 | """ 166 | assert isinstance(data, dict) 167 | 168 | gathered_dict = {} 169 | for item_key in data: 170 | if isinstance(data[item_key], torch.Tensor): 171 | if is_distributed(): 172 | data[item_key] = data[item_key].contiguous() 173 | tensor_list = [torch.empty_like(data[item_key]) for _ in range(get_world_size())] 174 | dist.all_gather(tensor_list, data[item_key]) 175 | gathered_tensor = torch.cat(tensor_list, dim=0) 176 | else: 177 | gathered_tensor = data[item_key] 178 | gathered_dict[item_key] = gathered_tensor 179 | return gathered_dict 180 | -------------------------------------------------------------------------------- /meshanything_train/engine.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, math, json, importlib 2 | import torch 3 | import datetime 4 | 5 | from meshanything_train.misc import SmoothedValue 6 | from collections import defaultdict 7 | 8 | def save_checkpoint( 9 | checkpoint_dir, 10 | model, 11 | optimizer, 12 | epoch, 13 | args, 14 | best_val_metrics, 15 | filename=None, 16 | ): 17 | 18 | checkpoint_name = os.path.join(checkpoint_dir, filename) 19 | try: 20 | weight_ckpt = model.module.state_dict() 21 | except 
Exception as e: 22 | print("single GPU") 23 | weight_ckpt = model.state_dict() 24 | 25 | sd = { 26 | "model": weight_ckpt, 27 | "optimizer": optimizer.state_dict(), 28 | "epoch": epoch, 29 | "args": args, 30 | "best_val_metrics": best_val_metrics, 31 | } 32 | torch.save(sd, checkpoint_name) 33 | 34 | 35 | def compute_learning_rate(args, curr_epoch_normalized): 36 | assert curr_epoch_normalized <= 1.0 and curr_epoch_normalized >= 0.0 37 | if ( 38 | curr_epoch_normalized <= (args.warm_lr_epochs / args.max_epoch) 39 | and args.warm_lr_epochs > 0 40 | ): 41 | # Linear Warmup 42 | curr_lr = args.warm_lr + curr_epoch_normalized * args.max_epoch * ( 43 | (args.base_lr - args.warm_lr) / args.warm_lr_epochs 44 | ) 45 | else: 46 | # Cosine Learning Rate Schedule 47 | curr_lr = args.final_lr + 0.5 * (args.base_lr - args.final_lr) * ( 48 | 1 + math.cos(math.pi * curr_epoch_normalized) 49 | ) 50 | return curr_lr 51 | 52 | def adjust_learning_rate(args, optimizer, curr_epoch): 53 | curr_lr = compute_learning_rate(args, curr_epoch) 54 | for param_group in optimizer.param_groups: 55 | param_group["lr"] = curr_lr 56 | return curr_lr 57 | 58 | def do_train( 59 | args, 60 | model, 61 | dataloaders, 62 | logger, 63 | accelerator, 64 | best_val_metrics=dict() 65 | ): 66 | 67 | optimizer = torch.optim.AdamW( 68 | filter(lambda params: params.requires_grad, model.parameters()), # list(model.named_parameters()) 69 | lr=args.base_lr, 70 | weight_decay=args.weight_decay 71 | ) 72 | start_epoch = 0 73 | if args.pretrained_weights is not None: 74 | sd = torch.load(args.pretrained_weights, map_location=torch.device("cpu")) 75 | epoch = sd["epoch"] 76 | print(f"Found checkpoint at {epoch}. Resuming.") 77 | model.load_state_dict(sd["model"], strict=not args.no_strict) 78 | optimizer.load_state_dict(sd["optimizer"]) 79 | start_epoch = epoch 80 | print( 81 | f"Loaded model and optimizer state at {epoch}. Loaded best val metrics so far." 82 | ) 83 | 84 | if accelerator.state.num_processes > 1: 85 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 86 | 87 | dataloaders['train'], dataloaders['test'], model, optimizer = accelerator.prepare( 88 | dataloaders['train'], 89 | dataloaders['test'], 90 | model, 91 | optimizer, 92 | ) 93 | 94 | max_iters = args.max_epoch * len(dataloaders['train']) // args.gradient_accumulation_steps 95 | print("train dataloader len ",len(dataloaders['train'])) 96 | 97 | time_delta = SmoothedValue(window_size=10) 98 | 99 | model.train() 100 | if args.start_epoch == -1: 101 | args.start_epoch = start_epoch 102 | 103 | if args.test_only: 104 | with accelerator.autocast(): 105 | task_metrics, eval_loss_dict = dataloaders['test'].dataset.eval_func( 106 | args, 107 | 0, 108 | model, 109 | dataloaders['test'], 110 | accelerator, 111 | logger, 112 | 0, 113 | test_only = True 114 | ) 115 | accelerator.log(eval_loss_dict, step=0) 116 | return 117 | 118 | curr_iter = args.start_epoch * len(dataloaders['train']) // args.gradient_accumulation_steps 119 | curr_time = time.time() 120 | loss_dict = defaultdict(list) 121 | 122 | for curr_epoch in range(args.start_epoch, args.max_epoch): 123 | for batch_idx, batch_data_label in enumerate(dataloaders['train']): 124 | curr_lr = adjust_learning_rate(args, optimizer, curr_iter / max_iters) 125 | with accelerator.accumulate(model): 126 | optimizer.zero_grad() 127 | with accelerator.autocast(): 128 | outputs = model(batch_data_label) 129 | 130 | loss = outputs['loss'] 131 | 132 | if not math.isfinite(loss.item()): 133 | logger.info("Loss in not finite. 
Terminate training.") 134 | exit(-1) 135 | 136 | accelerator.backward(loss) 137 | if args.clip_gradient > 0 and accelerator.sync_gradients: 138 | accelerator.clip_grad_norm_(model.parameters(), args.clip_gradient) 139 | 140 | optimizer.step() 141 | 142 | for key, value in outputs.items(): 143 | if 'loss' in key.lower(): 144 | loss_dict[key].append(value.item()) 145 | # logging 146 | if accelerator.sync_gradients: 147 | time_delta.update(time.time() - curr_time) 148 | curr_time = time.time() 149 | curr_iter += 1 150 | 151 | if curr_iter % args.log_every == 0: 152 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2) 153 | eta_seconds = (max_iters - curr_iter) * time_delta.avg 154 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds))) 155 | 156 | logger.info( 157 | f"Epoch [{curr_epoch}/{args.max_epoch}]; " 158 | f"Iter [{curr_iter}/{max_iters}]; " + \ 159 | f"LR {curr_lr:0.2e}; Iter time {time_delta.avg:0.2f}; " 160 | f"ETA {eta_str}; Mem {mem_mb:0.2f}MB" 161 | ) 162 | for key, value in loss_dict.items(): 163 | loss_dict[key] = torch.tensor(value, dtype=torch.float32).mean().item() 164 | loss_dict["learning_rate"] = curr_lr 165 | accelerator.log(loss_dict, step=curr_iter) 166 | loss_dict = defaultdict(list) 167 | 168 | if accelerator.is_main_process and (curr_iter + 1) % args.eval_every_iteration == 0: 169 | save_checkpoint( 170 | args.checkpoint_dir, 171 | model, 172 | optimizer, 173 | curr_epoch, 174 | args, 175 | best_val_metrics, 176 | filename=f"checkpoint_{curr_iter+1}.pth", 177 | ) 178 | 179 | # do eval 180 | do_eval_flag = (curr_iter + 1) % args.eval_every_iteration == 0 181 | do_eval_flag &= (curr_iter + 1) > args.start_eval_after 182 | do_eval_flag |= (curr_iter + 1) == max_iters 183 | do_eval_flag |= curr_iter == 1000 184 | 185 | if do_eval_flag is True: 186 | with accelerator.autocast(): 187 | task_metrics, eval_loss_dict = dataloaders['test'].dataset.eval_func( 188 | args, 189 | curr_epoch, 190 | model, 191 | dataloaders['test'], 192 | accelerator, 193 | logger, 194 | curr_iter+1, 195 | ) 196 | if accelerator.is_main_process: 197 | print("Evaluation End, Begin Log!") 198 | accelerator.log(eval_loss_dict, step=curr_iter+1) 199 | print("Resume Training!") 200 | model.train() 201 | 202 | accelerator.end_training() 203 | return -------------------------------------------------------------------------------- /meshanything_train/miche/README.md: -------------------------------------------------------------------------------- 1 | # Michelangelo 2 | 3 | ## [Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation](https://neuralcarver.github.io/michelangelo)
4 | [Zibo Zhao](https://github.com/Maikouuu), 5 | [Wen Liu](https://github.com/StevenLiuWen), 6 | [Xin Chen](https://chenxin.tech/), 7 | [Xianfang Zeng](https://github.com/Zzlongjuanfeng), 8 | [Rui Wang](https://wrong.wang/), 9 | [Pei Cheng](https://neuralcarver.github.io/michelangelo), 10 | [Bin Fu](https://neuralcarver.github.io/michelangelo), 11 | [Tao Chen](https://eetchen.github.io), 12 | [Gang Yu](https://www.skicyyu.org), 13 | [Shenghua Gao](https://sist.shanghaitech.edu.cn/sist_en/2020/0814/c7582a54772/page.htm)
14 | ### [Hugging Face Demo](https://huggingface.co/spaces/Maikou/Michelangelo) | [Project Page](https://neuralcarver.github.io/michelangelo/) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Paper](https://openreview.net/pdf?id=xmxgMij3LY)
15 | 16 | https://github.com/NeuralCarver/Michelangelo/assets/37449470/123bae2c-fbb1-4d63-bd13-0e300a550868 17 | 18 | Visualization of the 3D shape produced by our framework, which splits into triplets with a conditional input on the left, a normal map in the middle, and a triangle mesh on the right. The generated 3D shapes semantically conform to the visual or textual conditional inputs.
19 | 20 | ## 🔆 Features 21 | **Michelangelo** possesses three capabilities: 22 | 23 | 1. Representing a shape into shape-image-text aligned space; 24 | 2. Image-conditioned Shape Generation; 25 | 3. Text-conditioned Shape Generation. 26 | 27 |
28 | **Techniques** 29 | 30 | We present a novel _alignment-before-generation_ approach to tackle the challenging task of generating general 3D shapes based on 2D images or texts. Directly learning a conditional generative model from images or texts to 3D shapes is prone to producing results inconsistent with the conditions, because 3D shapes have an additional dimension whose distribution significantly differs from that of 2D images and texts. To bridge the domain gap among the three modalities and facilitate multi-modal-conditioned 3D shape generation, we explore representing 3D shapes in a shape-image-text-aligned space. Our framework comprises two models: a Shape-Image-Text-Aligned Variational Auto-Encoder (SITA-VAE) and a conditional Aligned Shape Latent Diffusion Model (ASLDM). The former encodes 3D shapes into a shape latent space aligned to the image and text, and reconstructs the fine-grained 3D neural fields corresponding to given shape embeddings via the transformer-based decoder. The latter learns a probabilistic mapping function from the image or text space to the latent shape space. Our extensive experiments demonstrate that our proposed approach can generate higher-quality and more diverse 3D shapes that better semantically conform to the visual or textual conditional inputs, validating the effectiveness of the shape-image-text-aligned space for cross-modality 3D shape generation. 31 | 32 | ![newnetwork](https://github.com/NeuralCarver/Michelangelo/assets/16475892/d5231fb7-7768-45ee-92e1-3599a4c43a2c) 33 |
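In code terms, the alignment-before-generation pipeline above runs in two stages. The sketch below is conceptual only: `sita_vae`, `asldm`, and `clip` are hypothetical handles standing in for SITA-VAE, ASLDM, and the CLIP encoder, not this repository's actual API.
```
# Stage 1: SITA-VAE encodes a surface into a latent aligned with CLIP image/text space.
shape_latent = sita_vae.encode(surface_points)   # shape latent in the aligned space
neural_field = sita_vae.decode(shape_latent)     # fine-grained 3D neural field

# Stage 2: ASLDM diffuses in that aligned latent space, given an image or text condition.
condition = clip.embed(image_or_text)            # CLIP embedding of the condition
latent = asldm.sample(condition)                 # conditional latent diffusion sampling
mesh = sita_vae.decode(latent)                   # decode the sampled latent to a shape
```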
34 | 35 | ## 📰 News 36 | - [2024/1/23] Set up the Hugging Face Demo and release the code 37 | - [2023/09/22] **Michelangelo got accepted by NeurIPS 2023!** 38 | - [2023/6/29] Upload paper and init project 39 | 40 | ## ⚙️ Setup 41 | 42 | ### Installation 43 | Follow the commands below to install the environment. We have tested the installation package on Tesla V100 and Tesla T4. 44 | ``` 45 | git clone https://github.com/NeuralCarver/Michelangelo.git 46 | cd Michelangelo 47 | conda create --name Michelangelo python=3.9 48 | conda activate Michelangelo 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | ### Checkpoints 53 | Please download the weights from the Hugging Face Model Space and put them in the root folder. We have also uploaded the weights related to CLIP to facilitate quick usage. 54 | 55 |
56 | 57 | **Tips for debugging configuration** 58 | 59 | 60 | - If something unfortunately goes wrong during environment configuration, consider skipping the problematic packages, such as pysdf, torch-cluster, and torch-scatter. These packages will not affect the execution of the commands we provide. 61 | - If you encounter any issues while downloading CLIP, you can consider downloading it from [CLIP's Hugging Face page](https://huggingface.co/openai/clip-vit-large-patch14). Once the download is complete, remember to modify line [26](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L26C1-L26C76) and line [34](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L34) in the config file to provide the correct path to CLIP. 62 | - From [issue 6](https://github.com/NeuralCarver/Michelangelo/issues/6#issuecomment-1913513382): Windows users running WSL2 + Ubuntu 22.04 may hit issues. As discussed in [issue 786](https://github.com/microsoft/WSL/issues/8587), it is just a matter of adding this to the .bashrc: 63 | ``` 64 | export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH 65 | ``` 66 |
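Those config edits work because each `target:`/`params:` block in the YAML is resolved at runtime by `michelangelo.utils.misc.instantiate_from_config`. A minimal sketch, assuming the repo root as the working directory and the shape-VAE config shipped with this repo:
```
from omegaconf import OmegaConf
from michelangelo.utils.misc import instantiate_from_config

config = OmegaConf.load("configs/aligned_shape_latents/shapevae-256.yaml")
model = instantiate_from_config(config.model)  # imports `target`, passes `params`
```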
67 | 68 | ## ⚡ Quick Start 69 | 70 | ### Inference 71 | 72 | #### Reconstruction a 3D shape 73 | ``` 74 | ./scripts/inference/reconstruction.sh 75 | ``` 76 | 77 | #### Image-conditioned shape generation 78 | ``` 79 | ./scripts/inference/image2mesh.sh 80 | ``` 81 | 82 | #### Text-conditioned shape generation 83 | ``` 84 | ./scripts/inference/text2mesh.sh 85 | ``` 86 | 87 | #### Simply run all the scripts 88 | ``` 89 | ./scripts/infer.sh 90 | ``` 91 | 92 | 93 | ## ❓ FAQ 94 | 95 | ## Citation 96 | 97 | If you find our code or paper helps, please consider citing: 98 | 99 | ```bibtex 100 | @inproceedings{ 101 | zhao2023michelangelo, 102 | title={Michelangelo: Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation}, 103 | author={Zibo Zhao and Wen Liu and Xin Chen and Xianfang Zeng and Rui Wang and Pei Cheng and BIN FU and Tao Chen and Gang YU and Shenghua Gao}, 104 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 105 | year={2023}, 106 | url={https://openreview.net/forum?id=xmxgMij3LY} 107 | } 108 | ``` 109 | 110 | ## License 111 | 112 | This code is distributed under an [GPL-3.0 license](LICENSE). 113 | 114 | -------------------------------------------------------------------------------- /meshanything_train/miche/configs/aligned_shape_latents/shapevae-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: meshanything_train.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 3 | params: 4 | shape_module_cfg: 5 | target: meshanything_train.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 6 | params: 7 | num_latents: 256 8 | embed_dim: 64 9 | point_feats: 3 # normal 10 | num_freqs: 8 11 | include_pi: false 12 | heads: 12 13 | width: 768 14 | num_encoder_layers: 8 15 | num_decoder_layers: 16 16 | use_ln_post: true 17 | init_scale: 0.25 18 | qkv_bias: false 19 | use_checkpoint: true 20 | aligned_module_cfg: 21 | target: meshanything_train.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 22 | params: 23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 24 | 25 | loss_cfg: 26 | target: meshanything_train.miche.michelangelo.models.tsal.loss.ContrastKLNearFar 27 | params: 28 | contrast_weight: 0.1 29 | near_weight: 0.1 30 | kl_weight: 0.001 31 | 32 | optimizer_cfg: 33 | optimizer: 34 | target: torch.optim.AdamW 35 | params: 36 | betas: [0.9, 0.99] 37 | eps: 1.e-6 38 | weight_decay: 1.e-2 39 | 40 | scheduler: 41 | target: meshanything_train.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 42 | params: 43 | warm_up_steps: 5000 44 | f_start: 1.e-6 45 | f_min: 1.e-3 46 | f_max: 1.0 47 | -------------------------------------------------------------------------------- /meshanything_train/miche/configs/aligned_shape_latents/shapevae-512.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: meshanything_train.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 3 | params: 4 | shape_module_cfg: 5 | target: meshanything_train.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 6 | params: 7 | num_latents: 512 8 | embed_dim: 64 9 | point_feats: 3 # normal 10 | num_freqs: 8 11 | include_pi: false 12 | heads: 16 13 | width: 1024 14 | num_encoder_layers: 12 15 | num_decoder_layers: 16 16 | use_ln_post: true 17 | init_scale: 0.25 18 | qkv_bias: false 19 | use_checkpoint: 
true 20 | aligned_module_cfg: 21 | target: meshanything_train.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 22 | params: 23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 24 | 25 | loss_cfg: 26 | target: meshanything_train.miche.michelangelo.models.tsal.loss.ContrastKLNearFar 27 | params: 28 | contrast_weight: 0.1 29 | near_weight: 0.1 30 | kl_weight: 0.001 31 | 32 | optimizer_cfg: 33 | optimizer: 34 | target: torch.optim.AdamW 35 | params: 36 | betas: [0.9, 0.99] 37 | eps: 1.e-6 38 | weight_decay: 1.e-2 39 | 40 | scheduler: 41 | target: meshanything_train.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 42 | params: 43 | warm_up_steps: 5000 44 | f_start: 1.e-6 45 | f_min: 1.e-3 46 | f_max: 1.0 47 | -------------------------------------------------------------------------------- /meshanything_train/miche/configs/image_cond_diffuser_asl/image-ASLDM-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser 3 | params: 4 | first_stage_config: 5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 6 | params: 7 | shape_module_cfg: 8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 9 | params: 10 | num_latents: &num_latents 256 11 | embed_dim: &embed_dim 64 12 | point_feats: 3 # normal 13 | num_freqs: 8 14 | include_pi: false 15 | heads: 12 16 | width: 768 17 | num_encoder_layers: 8 18 | num_decoder_layers: 16 19 | use_ln_post: true 20 | init_scale: 0.25 21 | qkv_bias: false 22 | use_checkpoint: false 23 | aligned_module_cfg: 24 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 25 | params: 26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 27 | 28 | loss_cfg: 29 | target: torch.nn.Identity 30 | 31 | cond_stage_config: 32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder 33 | params: 34 | version: "./checkpoints/clip/clip-vit-large-patch14" 35 | zero_embedding_radio: 0.1 36 | 37 | first_stage_key: "surface" 38 | cond_stage_key: "image" 39 | scale_by_std: false 40 | 41 | denoiser_cfg: 42 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser 43 | params: 44 | input_channels: *embed_dim 45 | output_channels: *embed_dim 46 | n_ctx: *num_latents 47 | width: 768 48 | layers: 6 # 2 * 6 + 1 = 13 49 | heads: 12 50 | context_dim: 1024 51 | init_scale: 1.0 52 | skip_ln: true 53 | use_checkpoint: true 54 | 55 | scheduler_cfg: 56 | guidance_scale: 7.5 57 | num_inference_steps: 50 58 | eta: 0.0 59 | 60 | noise: 61 | target: diffusers.schedulers.DDPMScheduler 62 | params: 63 | num_train_timesteps: 1000 64 | beta_start: 0.00085 65 | beta_end: 0.012 66 | beta_schedule: "scaled_linear" 67 | variance_type: "fixed_small" 68 | clip_sample: false 69 | denoise: 70 | target: diffusers.schedulers.DDIMScheduler 71 | params: 72 | num_train_timesteps: 1000 73 | beta_start: 0.00085 74 | beta_end: 0.012 75 | beta_schedule: "scaled_linear" 76 | clip_sample: false # clip sample to -1~1 77 | set_alpha_to_one: false 78 | steps_offset: 1 79 | 80 | optimizer_cfg: 81 | optimizer: 82 | target: torch.optim.AdamW 83 | params: 84 | betas: [0.9, 0.99] 85 | eps: 1.e-6 86 | weight_decay: 1.e-2 87 | 88 | scheduler: 89 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 90 | params: 91 | warm_up_steps: 5000 92 | f_start: 
1.e-6 93 | f_min: 1.e-3 94 | f_max: 1.0 95 | 96 | loss_cfg: 97 | loss_type: "mse" 98 | -------------------------------------------------------------------------------- /meshanything_train/miche/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser 3 | params: 4 | first_stage_config: 5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 6 | params: 7 | shape_module_cfg: 8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 9 | params: 10 | num_latents: &num_latents 256 11 | embed_dim: &embed_dim 64 12 | point_feats: 3 # normal 13 | num_freqs: 8 14 | include_pi: false 15 | heads: 12 16 | width: 768 17 | num_encoder_layers: 8 18 | num_decoder_layers: 16 19 | use_ln_post: true 20 | init_scale: 0.25 21 | qkv_bias: false 22 | use_checkpoint: true 23 | aligned_module_cfg: 24 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 25 | params: 26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 27 | 28 | loss_cfg: 29 | target: torch.nn.Identity 30 | 31 | cond_stage_config: 32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder 33 | params: 34 | version: "./checkpoints/clip/clip-vit-large-patch14" 35 | zero_embedding_radio: 0.1 36 | max_length: 77 37 | 38 | first_stage_key: "surface" 39 | cond_stage_key: "text" 40 | scale_by_std: false 41 | 42 | denoiser_cfg: 43 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser 44 | params: 45 | input_channels: *embed_dim 46 | output_channels: *embed_dim 47 | n_ctx: *num_latents 48 | width: 768 49 | layers: 8 # 2 * 8 + 1 = 17 50 | heads: 12 51 | context_dim: 768 52 | init_scale: 1.0 53 | skip_ln: true 54 | use_checkpoint: true 55 | 56 | scheduler_cfg: 57 | guidance_scale: 7.5 58 | num_inference_steps: 50 59 | eta: 0.0 60 | 61 | noise: 62 | target: diffusers.schedulers.DDPMScheduler 63 | params: 64 | num_train_timesteps: 1000 65 | beta_start: 0.00085 66 | beta_end: 0.012 67 | beta_schedule: "scaled_linear" 68 | variance_type: "fixed_small" 69 | clip_sample: false 70 | denoise: 71 | target: diffusers.schedulers.DDIMScheduler 72 | params: 73 | num_train_timesteps: 1000 74 | beta_start: 0.00085 75 | beta_end: 0.012 76 | beta_schedule: "scaled_linear" 77 | clip_sample: false # clip sample to -1~1 78 | set_alpha_to_one: false 79 | steps_offset: 1 80 | 81 | optimizer_cfg: 82 | optimizer: 83 | target: torch.optim.AdamW 84 | params: 85 | betas: [0.9, 0.99] 86 | eps: 1.e-6 87 | weight_decay: 1.e-2 88 | 89 | scheduler: 90 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 91 | params: 92 | warm_up_steps: 5000 93 | f_start: 1.e-6 94 | f_min: 1.e-3 95 | f_max: 1.0 96 | 97 | loss_cfg: 98 | loss_type: "mse" -------------------------------------------------------------------------------- /meshanything_train/miche/encode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import argparse 4 | # from functools import partial 5 | from omegaconf import OmegaConf, DictConfig, ListConfig 6 | import numpy as np 7 | import torch 8 | # from meshanything_train.models.helper import load_config 9 | from .michelangelo.utils.misc import instantiate_from_config 10 | 11 | def load_surface(fp): 12 | 13 | with np.load(fp) as input_pc: 14 | surface
= input_pc['points'] 15 | normal = input_pc['normals'] 16 | 17 | rng = np.random.default_rng() 18 | ind = rng.choice(surface.shape[0], 4096, replace=False) 19 | surface = torch.FloatTensor(surface[ind]) 20 | normal = torch.FloatTensor(normal[ind]) 21 | 22 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() 23 | 24 | return surface 25 | 26 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): 27 | 28 | surface = load_surface(args.pointcloud_path) 29 | # old_surface = surface.clone() 30 | 31 | # surface[0,:,0]*=-1 32 | # surface[0,:,1]*=-1 33 | surface[0,:,2]*=-1 34 | 35 | # encoding 36 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) 37 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) 38 | 39 | # decoding 40 | latents = model.model.shape_model.decode(shape_zq) 41 | # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) 42 | 43 | return 0 44 | 45 | def load_model(ckpt_path="meshanything_train/miche/checkpoints/aligned_shape_latents/shapevae-256.ckpt"): 46 | model_config = OmegaConf.load("meshanything_train/miche/configs/aligned_shape_latents/shapevae-256.yaml") 47 | if hasattr(model_config, "model"): 48 | model_config = model_config.model 49 | 50 | model = instantiate_from_config(model_config, ckpt_path=ckpt_path) 51 | 52 | return model 53 | if __name__ == "__main__": 54 | ''' 55 | 1. Reconstruct point cloud 56 | 2. Image-conditioned generation 57 | 3. Text-conditioned generation 58 | ''' 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--config_path", type=str, required=True) 61 | parser.add_argument("--ckpt_path", type=str, required=True) 62 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud') 63 | parser.add_argument("--image_path", type=str, help='Path to the input image') 64 | parser.add_argument("--text", type=str, help='Input text in the format: A 3D model of motorcar; Porsche 911.') 65 | parser.add_argument("--output_dir", type=str, default='./output') 66 | parser.add_argument("-s", "--seed", type=int, default=0) 67 | args = parser.parse_args() 68 | 69 | print(f'-----------------------------------------------------------------------------') 70 | print(f'>>> Output directory: {args.output_dir}') 71 | print(f'-----------------------------------------------------------------------------') 72 | 73 | reconstruction(args, load_model(args.ckpt_path)) -------------------------------------------------------------------------------- /meshanything_train/miche/example_data/image/car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/meshanything_train/miche/example_data/image/car.jpg -------------------------------------------------------------------------------- /meshanything_train/miche/example_data/surface/surface.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/meshanything_train/miche/example_data/surface/surface.npz -------------------------------------------------------------------------------- /meshanything_train/miche/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 |
import time 4 | from collections import OrderedDict 5 | from typing import Optional, List 6 | import argparse 7 | from functools import partial 8 | 9 | from einops import repeat, rearrange 10 | import numpy as np 11 | from PIL import Image 12 | import trimesh 13 | import cv2 14 | 15 | import torch 16 | import pytorch_lightning as pl 17 | 18 | from meshanything_train.miche.michelangelo.models.tsal.tsal_base import Latent2MeshOutput 19 | from meshanything_train.miche.michelangelo.models.tsal.inference_utils import extract_geometry 20 | from meshanything_train.miche.michelangelo.utils.misc import get_config_from_file, instantiate_from_config 21 | from meshanything_train.miche.michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer 22 | from meshanything_train.miche.michelangelo.utils.visualizers import html_util 23 | 24 | def load_model(args): 25 | 26 | model_config = get_config_from_file(args.config_path) 27 | if hasattr(model_config, "model"): 28 | model_config = model_config.model 29 | 30 | model = instantiate_from_config(model_config, ckpt_path=args.ckpt_path) 31 | model = model.cuda() 32 | model = model.eval() 33 | 34 | return model 35 | 36 | def load_surface(fp): 37 | 38 | with np.load(fp) as input_pc: 39 | surface = input_pc['points'] 40 | normal = input_pc['normals'] 41 | 42 | rng = np.random.default_rng() 43 | ind = rng.choice(surface.shape[0], 4096, replace=False) 44 | surface = torch.FloatTensor(surface[ind]) 45 | normal = torch.FloatTensor(normal[ind]) 46 | 47 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() 48 | 49 | return surface 50 | 51 | def prepare_image(args, number_samples=2): 52 | 53 | image = cv2.imread(f"{args.image_path}") 54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 55 | 56 | image_pt = torch.tensor(image).float() 57 | image_pt = image_pt / 255 * 2 - 1 58 | image_pt = rearrange(image_pt, "h w c -> c h w") 59 | 60 | image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples) 61 | 62 | return image_pt 63 | 64 | def save_output(args, mesh_outputs): 65 | 66 | os.makedirs(args.output_dir, exist_ok=True) 67 | for i, mesh in enumerate(mesh_outputs): 68 | mesh.mesh_f = mesh.mesh_f[:, ::-1] 69 | mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f) 70 | 71 | name = str(i) + "_out_mesh.obj" 72 | mesh_output.export(os.path.join(args.output_dir, name), include_normals=True) 73 | 74 | print(f'-----------------------------------------------------------------------------') 75 | print(f'>>> Finished and mesh saved in {args.output_dir}') 76 | print(f'-----------------------------------------------------------------------------') 77 | 78 | return 0 79 | 80 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): 81 | 82 | surface = load_surface(args.pointcloud_path) 83 | 84 | # encoding 85 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) 86 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) 87 | 88 | # decoding 89 | latents = model.model.shape_model.decode(shape_zq) 90 | geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) 91 | 92 | # reconstruction 93 | mesh_v_f, has_surface = extract_geometry( 94 | geometric_func=geometric_func, 95 | device=surface.device, 96 | batch_size=surface.shape[0], 97 | bounds=bounds, 98 | octree_depth=octree_depth, 99 | num_chunks=num_chunks, 100 | ) 101 | recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1]) 102 | 103 | # save
104 | os.makedirs(args.output_dir, exist_ok=True) 105 | recon_mesh.export(os.path.join(args.output_dir, 'reconstruction.obj')) 106 | 107 | print(f'-----------------------------------------------------------------------------') 108 | print(f'>>> Finished and mesh saved in {os.path.join(args.output_dir, "reconstruction.obj")}') 109 | print(f'-----------------------------------------------------------------------------') 110 | 111 | return 0 112 | 113 | def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7): 114 | 115 | sample_inputs = { 116 | "image": prepare_image(args) 117 | } 118 | 119 | mesh_outputs = model.sample( 120 | sample_inputs, 121 | sample_times=1, 122 | guidance_scale=guidance_scale, 123 | return_intermediates=False, 124 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v], 125 | octree_depth=octree_depth, 126 | )[0] 127 | 128 | save_output(args, mesh_outputs) 129 | 130 | return 0 131 | 132 | def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7): 133 | 134 | sample_inputs = { 135 | "text": [args.text] * num_samples 136 | } 137 | mesh_outputs = model.sample( 138 | sample_inputs, 139 | sample_times=1, 140 | guidance_scale=guidance_scale, 141 | return_intermediates=False, 142 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v], 143 | octree_depth=octree_depth, 144 | )[0] 145 | 146 | save_output(args, mesh_outputs) 147 | 148 | return 0 149 | 150 | task_dict = { 151 | 'reconstruction': reconstruction, 152 | 'image2mesh': image2mesh, 153 | 'text2mesh': text2mesh, 154 | } 155 | 156 | if __name__ == "__main__": 157 | ''' 158 | 1. Reconstruct point cloud 159 | 2. Image-conditioned generation 160 | 3. Text-conditioned generation 161 | ''' 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--task", type=str, choices=['reconstruction', 'image2mesh', 'text2mesh'], required=True) 164 | parser.add_argument("--config_path", type=str, required=True) 165 | parser.add_argument("--ckpt_path", type=str, required=True) 166 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud') 167 | parser.add_argument("--image_path", type=str, help='Path to the input image') 168 | parser.add_argument("--text", type=str, help='Input text in the format: A 3D model of motorcar; Porsche 911.') 169 | parser.add_argument("--output_dir", type=str, default='./output') 170 | parser.add_argument("-s", "--seed", type=int, default=0) 171 | args = parser.parse_args() 172 | 173 | pl.seed_everything(args.seed) 174 | 175 | print(f'-----------------------------------------------------------------------------') 176 | print(f'>>> Running {args.task}') 177 | args.output_dir = os.path.join(args.output_dir, args.task) 178 | print(f'>>> Output directory: {args.output_dir}') 179 | print(f'-----------------------------------------------------------------------------') 180 | 181 | task_dict[args.task](args, load_model(args)) -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 |
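A minimal sketch of how the CLI above is typically invoked (run from meshanything_train/miche; the flags mirror the argparse definition in inference.py, and the reconstruction.sh/image2mesh.sh/text2mesh.sh scripts presumably wrap calls like this — paths here are illustrative):

```bash
python inference.py \
    --task reconstruction \
    --config_path configs/aligned_shape_latents/shapevae-256.yaml \
    --ckpt_path checkpoints/aligned_shape_latents/shapevae-256.ckpt \
    --pointcloud_path ./example_data/surface/surface.npz \
    --output_dir ./output
```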
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/data/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "shape": [ 3 | "a point cloud model of {}.", 4 | "There is a {} in the scene.", 5 | "There is the {} in the scene.", 6 | "a photo of a {} in the scene.", 7 | "a photo of the {} in the scene.", 8 | "a photo of one {} in the scene.", 9 | "itap of a {}.", 10 | "itap of my {}.", 11 | "itap of the {}.", 12 | "a photo of a {}.", 13 | "a photo of my {}.", 14 | "a photo of the {}.", 15 | "a photo of one {}.", 16 | "a photo of many {}.", 17 | "a good photo of a {}.", 18 | "a good photo of the {}.", 19 | "a bad photo of a {}.", 20 | "a bad photo of the {}.", 21 | "a photo of a nice {}.", 22 | "a photo of the nice {}.", 23 | "a photo of a cool {}.", 24 | "a photo of the cool {}.", 25 | "a photo of a weird {}.", 26 | "a photo of the weird {}.", 27 | "a photo of a small {}.", 28 | "a photo of the small {}.", 29 | "a photo of a large {}.", 30 | "a photo of the large {}.", 31 | "a photo of a clean {}.", 32 | "a photo of the clean {}.", 33 | "a photo of a dirty {}.", 34 | "a photo of the dirty {}.", 35 | "a bright photo of a {}.", 36 | "a bright photo of the {}.", 37 | "a dark photo of a {}.", 38 | "a dark photo of the {}.", 39 | "a photo of a hard to see {}.", 40 | "a photo of the hard to see {}.", 41 | "a low resolution photo of a {}.", 42 | "a low resolution photo of the {}.", 43 | "a cropped photo of a {}.", 44 | "a cropped photo of the {}.", 45 | "a close-up photo of a {}.", 46 | "a close-up photo of the {}.", 47 | "a jpeg corrupted photo of a {}.", 48 | "a jpeg corrupted photo of the {}.", 49 | "a blurry photo of a {}.", 50 | "a blurry photo of the {}.", 51 | "a pixelated photo of a {}.", 52 | "a pixelated photo of the {}.", 53 | "a black and white photo of the {}.", 54 | "a black and white photo of a {}", 55 | "a plastic {}.", 56 | "the plastic {}.", 57 | "a toy {}.", 58 | "the toy {}.", 59 | "a plushie {}.", 60 | "the plushie {}.", 61 | "a cartoon {}.", 62 | "the cartoon {}.", 63 | "an embroidered {}.", 64 | "the embroidered {}.", 65 | "a painting of the {}.", 66 | "a painting of a {}." 
67 | ] 68 | 69 | } -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/data/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def worker_init_fn(_): 8 | worker_info = torch.utils.data.get_worker_info() 9 | worker_id = worker_info.id 10 | 11 | # dataset = worker_info.dataset 12 | # split_size = dataset.num_records // worker_info.num_workers 13 | # # reset num_records to the true number to retain reliable length information 14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1) 16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 17 | 18 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 19 | 20 | 21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True): 22 | """ 23 | 24 | Args: 25 | samples (list[dict]): 26 | combine_tensors: 27 | combine_scalars: 28 | 29 | Returns: 30 | 31 | """ 32 | 33 | result = {} 34 | 35 | keys = samples[0].keys() 36 | 37 | for key in keys: 38 | result[key] = [] 39 | 40 | for sample in samples: 41 | for key in keys: 42 | val = sample[key] 43 | result[key].append(val) 44 | 45 | for key in keys: 46 | val_list = result[key] 47 | if isinstance(val_list[0], (int, float)): 48 | if combine_scalars: 49 | result[key] = np.array(result[key]) 50 | 51 | elif isinstance(val_list[0], torch.Tensor): 52 | if combine_tensors: 53 | result[key] = torch.stack(val_list) 54 | 55 | elif isinstance(val_list[0], np.ndarray): 56 | if combine_tensors: 57 | result[key] = np.stack(val_list) 58 | 59 | return result 60 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/graphics/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .volume import generate_dense_grid_points 4 | 5 | from .mesh import ( 6 | MeshOutput, 7 | save_obj, 8 | savemeshtes2 9 | ) 10 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/graphics/primitives/mesh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import PIL.Image 7 | from typing import Optional 8 | 9 | import trimesh 10 | 11 | 12 | def save_obj(pointnp_px3, facenp_fx3, fname): 13 | fid = open(fname, "w") 14 | write_str = "" 15 | for pidx, p in enumerate(pointnp_px3): 16 | pp = p 17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2]) 18 | 19 | for i, f in enumerate(facenp_fx3): 20 | f1 = f + 1 21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2]) 22 | fid.write(write_str) 23 | fid.close() 24 | return 25 | 26 | 27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname): 28 | fol, na = os.path.split(fname) 29 | na, _ = os.path.splitext(na) 30 | 31 | matname = "%s/%s.mtl" % (fol, na) 32 | fid = open(matname, "w") 33 | fid.write("newmtl material_0\n") 34 | fid.write("Kd 1 1 1\n") 
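    # (annotation, not in the original source) Kd above and the keys written
    # below are standard Wavefront .mtl material fields: Kd = diffuse color,
    # Ka = ambient color, Ks = specular color, Ns = specular exponent,
    # illum = illumination model, map_Kd = diffuse texture map (<name>.png).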
35 | fid.write("Ka 0 0 0\n") 36 | fid.write("Ks 0.4 0.4 0.4\n") 37 | fid.write("Ns 10\n") 38 | fid.write("illum 2\n") 39 | fid.write("map_Kd %s.png\n" % na) 40 | fid.close() 41 | #### 42 | 43 | fid = open(fname, "w") 44 | fid.write("mtllib %s.mtl\n" % na) 45 | 46 | for pidx, p in enumerate(pointnp_px3): 47 | pp = p 48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) 49 | 50 | for pidx, p in enumerate(tcoords_px2): 51 | pp = p 52 | fid.write("vt %f %f\n" % (pp[0], pp[1])) 53 | 54 | fid.write("usemtl material_0\n") 55 | for i, f in enumerate(facenp_fx3): 56 | f1 = f + 1 57 | f2 = facetex_fx3[i] + 1 58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) 59 | fid.close() 60 | 61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save( 62 | os.path.join(fol, "%s.png" % na)) 63 | 64 | return 65 | 66 | 67 | class MeshOutput(object): 68 | 69 | def __init__(self, 70 | mesh_v: np.ndarray, 71 | mesh_f: np.ndarray, 72 | vertex_colors: Optional[np.ndarray] = None, 73 | uvs: Optional[np.ndarray] = None, 74 | mesh_tex_idx: Optional[np.ndarray] = None, 75 | tex_map: Optional[np.ndarray] = None): 76 | 77 | self.mesh_v = mesh_v 78 | self.mesh_f = mesh_f 79 | self.vertex_colors = vertex_colors 80 | self.uvs = uvs 81 | self.mesh_tex_idx = mesh_tex_idx 82 | self.tex_map = tex_map 83 | 84 | def contain_uv_texture(self): 85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None) 86 | 87 | def contain_vertex_colors(self): 88 | return self.vertex_colors is not None 89 | 90 | def export(self, fname): 91 | 92 | if self.contain_uv_texture(): 93 | savemeshtes2( 94 | self.mesh_v, 95 | self.uvs, 96 | self.mesh_f, 97 | self.mesh_tex_idx, 98 | self.tex_map, 99 | fname 100 | ) 101 | 102 | elif self.contain_vertex_colors(): 103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors) 104 | mesh_obj.export(fname) 105 | 106 | else: 107 | save_obj( 108 | self.mesh_v, 109 | self.mesh_f, 110 | fname 111 | ) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/graphics/primitives/volume.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | 6 | def generate_dense_grid_points(bbox_min: np.ndarray, 7 | bbox_max: np.ndarray, 8 | octree_depth: int, 9 | indexing: str = "ij"): 10 | length = bbox_max - bbox_min 11 | num_cells = np.exp2(octree_depth) 12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) 13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) 14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) 15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) 16 | xyz = np.stack((xs, ys, zs), axis=-1) 17 | xyz = xyz.reshape(-1, 3) 18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] 19 | 20 | return xyz, grid_size, length 21 | 22 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/asl_diffusion/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/asl_diffusion/asl_udt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from diffusers.models.embeddings import Timesteps 7 | import math 8 | 9 | from meshanything_train.miche.michelangelo.models.modules.transformer_blocks import MLP 10 | from meshanything_train.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer 11 | 12 | 13 | class ConditionalASLUDTDenoiser(nn.Module): 14 | 15 | def __init__(self, *, 16 | device: Optional[torch.device], 17 | dtype: Optional[torch.dtype], 18 | input_channels: int, 19 | output_channels: int, 20 | n_ctx: int, 21 | width: int, 22 | layers: int, 23 | heads: int, 24 | context_dim: int, 25 | context_ln: bool = True, 26 | skip_ln: bool = False, 27 | init_scale: float = 0.25, 28 | flip_sin_to_cos: bool = False, 29 | use_checkpoint: bool = False): 30 | super().__init__() 31 | 32 | self.use_checkpoint = use_checkpoint 33 | 34 | init_scale = init_scale * math.sqrt(1.0 / width) 35 | 36 | self.backbone = UNetDiffusionTransformer( 37 | device=device, 38 | dtype=dtype, 39 | n_ctx=n_ctx, 40 | width=width, 41 | layers=layers, 42 | heads=heads, 43 | skip_ln=skip_ln, 44 | init_scale=init_scale, 45 | use_checkpoint=use_checkpoint 46 | ) 47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) 49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) 50 | 51 | # timestep embedding 52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0) 53 | self.time_proj = MLP( 54 | device=device, dtype=dtype, width=width, init_scale=init_scale 55 | ) 56 | 57 | self.context_embed = nn.Sequential( 58 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 59 | nn.Linear(context_dim, width, device=device, dtype=dtype), 60 | ) 61 | 62 | if context_ln: 63 | self.context_embed = nn.Sequential( 64 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 65 | nn.Linear(context_dim, width, device=device, dtype=dtype), 66 | ) 67 | else: 68 | self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype) 69 | 70 | def forward(self, 71 | model_input: torch.FloatTensor, 72 | timestep: torch.LongTensor, 73 | context: torch.FloatTensor): 74 | 75 | r""" 76 | Args: 77 | model_input (torch.FloatTensor): [bs, n_data, c] 78 | timestep (torch.LongTensor): [bs,] 79 | context (torch.FloatTensor): [bs, context_tokens, c] 80 | 81 | Returns: 82 | sample (torch.FloatTensor): [bs, n_data, c] 83 | 84 | """ 85 | 86 | _, n_data, _ = model_input.shape 87 | 88 | # 1. time 89 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1) 90 | 91 | # 2. conditions projector 92 | context = self.context_embed(context) 93 | 94 | # 3. 
denoiser 95 | x = self.input_proj(model_input) 96 | x = torch.cat([t_emb, context, x], dim=1) 97 | x = self.backbone(x) 98 | x = self.ln_post(x) 99 | x = x[:, -n_data:] 100 | sample = self.output_proj(x) 101 | 102 | return sample 103 | 104 | 105 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/asl_diffusion/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BaseDenoiser(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, t, context): 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/asl_diffusion/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from typing import Tuple, List, Union, Optional 6 | from diffusers.schedulers import DDIMScheduler 7 | 8 | 9 | __all__ = ["ddim_sample"] 10 | 11 | 12 | def ddim_sample(ddim_scheduler: DDIMScheduler, 13 | diffusion_model: torch.nn.Module, 14 | shape: Union[List[int], Tuple[int]], 15 | cond: torch.FloatTensor, 16 | steps: int, 17 | eta: float = 0.0, 18 | guidance_scale: float = 3.0, 19 | do_classifier_free_guidance: bool = True, 20 | generator: Optional[torch.Generator] = None, 21 | device: torch.device = "cuda:0", 22 | disable_prog: bool = True): 23 | 24 | assert steps > 0, f"{steps} must > 0." 25 | 26 | # init latents 27 | bsz = cond.shape[0] 28 | if do_classifier_free_guidance: 29 | bsz = bsz // 2 30 | 31 | latents = torch.randn( 32 | (bsz, *shape), 33 | generator=generator, 34 | device=cond.device, 35 | dtype=cond.dtype, 36 | ) 37 | # scale the initial noise by the standard deviation required by the scheduler 38 | latents = latents * ddim_scheduler.init_noise_sigma 39 | # set timesteps 40 | ddim_scheduler.set_timesteps(steps) 41 | timesteps = ddim_scheduler.timesteps.to(device) 42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1] 44 | extra_step_kwargs = { 45 | "eta": eta, 46 | "generator": generator 47 | } 48 | 49 | # reverse 50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)): 51 | # expand the latents if we are doing classifier free guidance 52 | latent_model_input = ( 53 | torch.cat([latents] * 2) 54 | if do_classifier_free_guidance 55 | else latents 56 | ) 57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t) 58 | # predict the noise residual 59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) 60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) 61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond) 62 | 63 | # perform guidance 64 | if do_classifier_free_guidance: 65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 66 | noise_pred = noise_pred_uncond + guidance_scale * ( 67 | noise_pred_text - noise_pred_uncond 68 | ) 69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk( 70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states 71 | # compute the previous noisy sample x_t -> x_t-1 72 | latents = ddim_scheduler.step( 73 | noise_pred, t, latents, 
**extra_step_kwargs 74 | ).prev_sample 75 | 76 | yield latents, t 77 | 78 | 79 | def karra_sample(): 80 | pass 81 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/conditional_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .clip import CLIPEncoder 4 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/conditional_encoders/clip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from dataclasses import dataclass 7 | from torchvision.transforms import Normalize 8 | from transformers import CLIPModel, CLIPTokenizer 9 | from transformers.utils import ModelOutput 10 | from typing import Iterable, Optional, Union, List 11 | 12 | 13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 14 | 15 | 16 | @dataclass 17 | class CLIPEmbedOutput(ModelOutput): 18 | last_hidden_state: torch.FloatTensor = None 19 | pooler_output: torch.FloatTensor = None 20 | embeds: torch.FloatTensor = None 21 | 22 | 23 | class CLIPEncoder(torch.nn.Module): 24 | 25 | def __init__(self, model_path="openai/clip-vit-base-patch32"): 26 | 27 | super().__init__() 28 | 29 | # Load the CLIP model and processor 30 | self.model: CLIPModel = CLIPModel.from_pretrained(model_path) 31 | self.tokenizer = CLIPTokenizer.from_pretrained(model_path) 32 | self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 33 | 34 | self.model.eval() 35 | for p in self.model.parameters(): 36 | p.requires_grad = False 37 | 38 | @torch.no_grad() 39 | def encode_image(self, images: Iterable[Optional[ImageType]]): 40 | pixel_values = self.image_preprocess(images) 41 | 42 | vision_outputs = self.model.vision_model(pixel_values=pixel_values) 43 | 44 | pooler_output = vision_outputs[1] # pooled_output 45 | image_features = self.model.visual_projection(pooler_output) 46 | 47 | visual_embeds = CLIPEmbedOutput( 48 | last_hidden_state=vision_outputs.last_hidden_state, 49 | pooler_output=pooler_output, 50 | embeds=image_features 51 | ) 52 | 53 | return visual_embeds 54 | 55 | @torch.no_grad() 56 | def encode_text(self, texts: List[str]): 57 | text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt") 58 | 59 | text_outputs = self.model.text_model(input_ids=text_inputs.input_ids, attention_mask=text_inputs.attention_mask) 60 | 61 | pooler_output = text_outputs[1] # pooled_output 62 | text_features = self.model.text_projection(pooler_output) 63 | 64 | text_embeds = CLIPEmbedOutput( 65 | last_hidden_state=text_outputs.last_hidden_state, 66 | pooler_output=pooler_output, 67 | embeds=text_features 68 | ) 69 | 70 | return text_embeds 71 | 72 | def forward(self, 73 | images: Iterable[Optional[ImageType]], 74 | texts: List[str]): 75 | 76 | visual_embeds = self.encode_image(images) 77 | text_embeds = self.encode_text(texts) 78 | 79 | return visual_embeds, text_embeds 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import checkpoint 4 |
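A minimal usage sketch for the CLIPEncoder defined above (the model path is the class default; shapes are illustrative — encode_image applies only a channel-wise Normalize, so inputs must already be resized [B, 3, 224, 224] float tensors in [0, 1]):

```python
import torch
from meshanything_train.miche.michelangelo.models.conditional_encoders import CLIPEncoder

encoder = CLIPEncoder(model_path="openai/clip-vit-base-patch32")

# A batch of two already-resized RGB images scaled to [0, 1].
images = torch.rand(2, 3, 224, 224)
visual = encoder.encode_image(images)

print(visual.embeds.shape)             # projected CLIP image features
print(visual.last_hidden_state.shape)  # per-patch hidden states

text = encoder.encode_text(["a photo of a chair.", "a toy car."])
print(text.embeds.shape)               # projected CLIP text features
```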
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 4 | """ 5 | 6 | import torch 7 | from typing import Callable, Iterable, Sequence, Union 8 | 9 | 10 | def checkpoint( 11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], 12 | inputs: Sequence[torch.Tensor], 13 | params: Iterable[torch.Tensor], 14 | flag: bool, 15 | use_deepspeed: bool = False 16 | ): 17 | """ 18 | Evaluate a function without caching intermediate activations, allowing for 19 | reduced memory at the expense of extra compute in the backward pass. 20 | :param func: the function to evaluate. 21 | :param inputs: the argument sequence to pass to `func`. 22 | :param params: a sequence of parameters `func` depends on but does not 23 | explicitly take as arguments. 24 | :param flag: if False, disable gradient checkpointing. 25 | :param use_deepspeed: if True, use deepspeed 26 | """ 27 | if flag: 28 | if use_deepspeed: 29 | import deepspeed 30 | return deepspeed.checkpointing.checkpoint(func, *inputs) 31 | 32 | args = tuple(inputs) + tuple(params) 33 | return CheckpointFunction.apply(func, len(inputs), *args) 34 | else: 35 | return func(*inputs) 36 | 37 | 38 | class CheckpointFunction(torch.autograd.Function): 39 | @staticmethod 40 | @torch.cuda.amp.custom_fwd 41 | def forward(ctx, run_function, length, *args): 42 | ctx.run_function = run_function 43 | ctx.input_tensors = list(args[:length]) 44 | ctx.input_params = list(args[length:]) 45 | 46 | with torch.no_grad(): 47 | output_tensors = ctx.run_function(*ctx.input_tensors) 48 | return output_tensors 49 | 50 | @staticmethod 51 | @torch.cuda.amp.custom_bwd 52 | def backward(ctx, *output_grads): 53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 54 | with torch.enable_grad(): 55 | # Fixes a bug where the first op in run_function modifies the 56 | # Tensor storage in place, which is not allowed for detach()'d 57 | # Tensors. 
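        # (annotation, not in the original source) view_as(x) returns a fresh
        # non-leaf view sharing x's storage; autograd permits in-place ops on
        # such views, unlike on the detached leaf tensors created above.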
58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 59 | output_tensors = ctx.run_function(*shallow_copies) 60 | input_grads = torch.autograd.grad( 61 | output_tensors, 62 | ctx.input_tensors + ctx.input_params, 63 | output_grads, 64 | allow_unused=True, 65 | ) 66 | del ctx.input_tensors 67 | del ctx.input_params 68 | del output_tensors 69 | return (None, None) + input_grads 70 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/modules/diffusion_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | 8 | from meshanything_train.miche.michelangelo.models.modules.checkpoint import checkpoint 9 | from meshanything_train.miche.michelangelo.models.modules.transformer_blocks import ( 10 | init_linear, 11 | MLP, 12 | MultiheadCrossAttention, 13 | MultiheadAttention, 14 | ResidualAttentionBlock 15 | ) 16 | 17 | 18 | class AdaLayerNorm(nn.Module): 19 | def __init__(self, 20 | device: torch.device, 21 | dtype: torch.dtype, 22 | width: int): 23 | 24 | super().__init__() 25 | 26 | self.silu = nn.SiLU(inplace=True) 27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype) 28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype) 29 | 30 | def forward(self, x, timestep): 31 | emb = self.linear(timestep) 32 | scale, shift = torch.chunk(emb, 2, dim=2) 33 | x = self.layernorm(x) * (1 + scale) + shift 34 | return x 35 | 36 | 37 | class DitBlock(nn.Module): 38 | def __init__( 39 | self, 40 | *, 41 | device: torch.device, 42 | dtype: torch.dtype, 43 | n_ctx: int, 44 | width: int, 45 | heads: int, 46 | context_dim: int, 47 | qkv_bias: bool = False, 48 | init_scale: float = 1.0, 49 | use_checkpoint: bool = False 50 | ): 51 | super().__init__() 52 | 53 | self.use_checkpoint = use_checkpoint 54 | 55 | self.attn = MultiheadAttention( 56 | device=device, 57 | dtype=dtype, 58 | n_ctx=n_ctx, 59 | width=width, 60 | heads=heads, 61 | init_scale=init_scale, 62 | qkv_bias=qkv_bias 63 | ) 64 | self.ln_1 = AdaLayerNorm(device, dtype, width) 65 | 66 | if context_dim is not None: 67 | self.ln_2 = AdaLayerNorm(device, dtype, width) 68 | self.cross_attn = MultiheadCrossAttention( 69 | device=device, 70 | dtype=dtype, 71 | width=width, 72 | heads=heads, 73 | data_width=context_dim, 74 | init_scale=init_scale, 75 | qkv_bias=qkv_bias 76 | ) 77 | 78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 79 | self.ln_3 = AdaLayerNorm(device, dtype, width) 80 | 81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint) 83 | 84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 85 | x = x + self.attn(self.ln_1(x, t)) 86 | if context is not None: 87 | x = x + self.cross_attn(self.ln_2(x, t), context) 88 | x = x + self.mlp(self.ln_3(x, t)) 89 | return x 90 | 91 | 92 | class DiT(nn.Module): 93 | def __init__( 94 | self, 95 | *, 96 | device: Optional[torch.device], 97 | dtype: Optional[torch.dtype], 98 | n_ctx: int, 99 | width: int, 100 | layers: int, 101 | heads: int, 102 | context_dim: int, 103 | init_scale: float = 0.25, 104 | qkv_bias: bool = False, 105 | use_checkpoint: bool = False 106 | ): 107 | super().__init__() 108 | 
self.n_ctx = n_ctx 109 | self.width = width 110 | self.layers = layers 111 | 112 | self.resblocks = nn.ModuleList( 113 | [ 114 | DitBlock( 115 | device=device, 116 | dtype=dtype, 117 | n_ctx=n_ctx, 118 | width=width, 119 | heads=heads, 120 | context_dim=context_dim, 121 | qkv_bias=qkv_bias, 122 | init_scale=init_scale, 123 | use_checkpoint=use_checkpoint 124 | ) 125 | for _ in range(layers) 126 | ] 127 | ) 128 | 129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 130 | for block in self.resblocks: 131 | x = block(x, t, context) 132 | return x 133 | 134 | 135 | class UNetDiffusionTransformer(nn.Module): 136 | def __init__( 137 | self, 138 | *, 139 | device: Optional[torch.device], 140 | dtype: Optional[torch.dtype], 141 | n_ctx: int, 142 | width: int, 143 | layers: int, 144 | heads: int, 145 | init_scale: float = 0.25, 146 | qkv_bias: bool = False, 147 | skip_ln: bool = False, 148 | use_checkpoint: bool = False 149 | ): 150 | super().__init__() 151 | 152 | self.n_ctx = n_ctx 153 | self.width = width 154 | self.layers = layers 155 | 156 | self.encoder = nn.ModuleList() 157 | for _ in range(layers): 158 | resblock = ResidualAttentionBlock( 159 | device=device, 160 | dtype=dtype, 161 | n_ctx=n_ctx, 162 | width=width, 163 | heads=heads, 164 | init_scale=init_scale, 165 | qkv_bias=qkv_bias, 166 | use_checkpoint=use_checkpoint 167 | ) 168 | self.encoder.append(resblock) 169 | 170 | self.middle_block = ResidualAttentionBlock( 171 | device=device, 172 | dtype=dtype, 173 | n_ctx=n_ctx, 174 | width=width, 175 | heads=heads, 176 | init_scale=init_scale, 177 | qkv_bias=qkv_bias, 178 | use_checkpoint=use_checkpoint 179 | ) 180 | 181 | self.decoder = nn.ModuleList() 182 | for _ in range(layers): 183 | resblock = ResidualAttentionBlock( 184 | device=device, 185 | dtype=dtype, 186 | n_ctx=n_ctx, 187 | width=width, 188 | heads=heads, 189 | init_scale=init_scale, 190 | qkv_bias=qkv_bias, 191 | use_checkpoint=use_checkpoint 192 | ) 193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype) 194 | init_linear(linear, init_scale) 195 | 196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None 197 | 198 | self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) 199 | 200 | def forward(self, x: torch.Tensor): 201 | 202 | enc_outputs = [] 203 | for block in self.encoder: 204 | x = block(x) 205 | enc_outputs.append(x) 206 | 207 | x = self.middle_block(x) 208 | 209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder): 210 | x = torch.cat([enc_outputs.pop(), x], dim=-1) 211 | x = linear(x) 212 | 213 | if layer_norm is not None: 214 | x = layer_norm(x) 215 | 216 | x = resblock(x) 217 | 218 | return x 219 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union, List 4 | 5 | 6 | class AbstractDistribution(object): 7 | def sample(self): 8 | raise NotImplementedError() 9 | 10 | def mode(self): 11 | raise NotImplementedError() 12 | 13 | 14 | class DiracDistribution(AbstractDistribution): 15 | def __init__(self, value): 16 | self.value = value 17 | 18 | def sample(self): 19 | return self.value 20 | 21 | def mode(self): 22 | return self.value 23 | 24 | 25 | class DiagonalGaussianDistribution(object): 26 | def __init__(self, parameters: Union[torch.Tensor, 
List[torch.Tensor]], deterministic=False, feat_dim=1): 27 | self.feat_dim = feat_dim 28 | self.parameters = parameters 29 | 30 | if isinstance(parameters, list): 31 | self.mean = parameters[0] 32 | self.logvar = parameters[1] 33 | else: 34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 35 | 36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 37 | self.deterministic = deterministic 38 | self.std = torch.exp(0.5 * self.logvar) 39 | self.var = torch.exp(self.logvar) 40 | if self.deterministic: 41 | self.var = self.std = torch.zeros_like(self.mean) 42 | 43 | def sample(self): 44 | x = self.mean + self.std * torch.randn_like(self.mean) 45 | return x 46 | 47 | def kl(self, other=None, dims=(1, 2, 3)): 48 | if self.deterministic: 49 | return torch.Tensor([0.]) 50 | else: 51 | if other is None: 52 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 53 | + self.var - 1.0 - self.logvar, 54 | dim=dims) 55 | else: 56 | return 0.5 * torch.mean( 57 | torch.pow(self.mean - other.mean, 2) / other.var 58 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 59 | dim=dims) 60 | 61 | def nll(self, sample, dims=(1, 2, 3)): 62 | if self.deterministic: 63 | return torch.Tensor([0.]) 64 | logtwopi = np.log(2.0 * np.pi) 65 | return 0.5 * torch.sum( 66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 67 | dim=dims) 68 | 69 | def mode(self): 70 | return self.mean 71 | 72 | 73 | def normal_kl(mean1, logvar1, mean2, logvar2): 74 | """ 75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 76 | Compute the KL divergence between two gaussians. 77 | Shapes are automatically broadcasted, so batches can be compared to 78 | scalars, among other use cases. 79 | """ 80 | tensor = None 81 | for obj in (mean1, logvar1, mean2, logvar2): 82 | if isinstance(obj, torch.Tensor): 83 | tensor = obj 84 | break 85 | assert tensor is not None, "at least one argument must be a Tensor" 86 | 87 | # Force variances to be Tensors. Broadcasting helps convert scalars to 88 | # Tensors, but it does not work for torch.exp(). 89 | logvar1, logvar2 = [ 90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 91 | for x in (logvar1, logvar2) 92 | ] 93 | 94 | return 0.5 * ( 95 | -1.0 96 | + logvar2 97 | - logvar1 98 | + torch.exp(logvar1 - logvar2) 99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 100 | ) 101 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/modules/embedder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] 9 | 10 | 11 | class FourierEmbedder(nn.Module): 12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts 13 | each feature dimension of `x[..., i]` into: 14 | [ 15 | sin(x[..., i]), 16 | sin(f_1*x[..., i]), 17 | sin(f_2*x[..., i]), 18 | ... 19 | sin(f_N * x[..., i]), 20 | cos(x[..., i]), 21 | cos(f_1*x[..., i]), 22 | cos(f_2*x[..., i]), 23 | ... 24 | cos(f_N * x[..., i]), 25 | x[..., i] # only present if include_input is True. 26 | ], here f_i is the frequency. 27 | 28 | Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. 
29 | If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; 30 | Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. 31 | 32 | Args: 33 | num_freqs (int): the number of frequencies, default is 6; 34 | logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 35 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; 36 | input_dim (int): the input dimension, default is 3; 37 | include_input (bool): include the input tensor or not, default is True. 38 | 39 | Attributes: 40 | frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 41 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); 42 | 43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), 44 | otherwise, it is input_dim * num_freqs * 2. 45 | 46 | """ 47 | 48 | def __init__(self, 49 | num_freqs: int = 6, 50 | logspace: bool = True, 51 | input_dim: int = 3, 52 | include_input: bool = True, 53 | include_pi: bool = True) -> None: 54 | 55 | """The initialization""" 56 | 57 | super().__init__() 58 | 59 | if logspace: 60 | frequencies = 2.0 ** torch.arange( 61 | num_freqs, 62 | dtype=torch.float32 63 | ) 64 | else: 65 | frequencies = torch.linspace( 66 | 1.0, 67 | 2.0 ** (num_freqs - 1), 68 | num_freqs, 69 | dtype=torch.float32 70 | ) 71 | 72 | if include_pi: 73 | frequencies *= torch.pi 74 | 75 | self.register_buffer("frequencies", frequencies, persistent=False) 76 | self.include_input = include_input 77 | self.num_freqs = num_freqs 78 | 79 | self.out_dim = self.get_dims(input_dim) 80 | 81 | def get_dims(self, input_dim): 82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0 83 | out_dim = input_dim * (self.num_freqs * 2 + temp) 84 | 85 | return out_dim 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | """ Forward process. 89 | 90 | Args: 91 | x: tensor of shape [..., dim] 92 | 93 | Returns: 94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] 95 | where temp is 1 if include_input is True and 0 otherwise. 
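        For example (illustrative, following get_dims above): with
        input_dim=3, num_freqs=8 and include_input=True, each 3-D input
        yields 3 * (8 * 2 + 1) = 51 output features.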
96 | """ 97 | 98 | if self.num_freqs > 0: 99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) 100 | if self.include_input: 101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 102 | else: 103 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 104 | else: 105 | return x 106 | 107 | 108 | class LearnedFourierEmbedder(nn.Module): 109 | """ following @crowsonkb "s lead with learned sinusoidal pos emb """ 110 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 111 | 112 | def __init__(self, in_channels, dim): 113 | super().__init__() 114 | assert (dim % 2) == 0 115 | half_dim = dim // 2 116 | per_channel_dim = half_dim // in_channels 117 | self.weights = nn.Parameter(torch.randn(per_channel_dim)) 118 | 119 | def forward(self, x): 120 | """ 121 | 122 | Args: 123 | x (torch.FloatTensor): [..., c] 124 | 125 | Returns: 126 | x (torch.FloatTensor): [..., d] 127 | """ 128 | 129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] 130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) 131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) 132 | return fouriered 133 | 134 | 135 | class TriplaneLearnedFourierEmbedder(nn.Module): 136 | def __init__(self, in_channels, dim): 137 | super().__init__() 138 | 139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 142 | 143 | self.out_dim = in_channels + dim 144 | 145 | def forward(self, x): 146 | 147 | yz_embed = self.yz_plane_embedder(x) 148 | xz_embed = self.xz_plane_embedder(x) 149 | xy_embed = self.xy_plane_embedder(x) 150 | 151 | embed = yz_embed + xz_embed + xy_embed 152 | 153 | return embed 154 | 155 | 156 | def sequential_pos_embed(num_len, embed_dim): 157 | assert embed_dim % 2 == 0 158 | 159 | pos = torch.arange(num_len, dtype=torch.float32) 160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32) 161 | omega /= embed_dim / 2. 162 | omega = 1. / 10000 ** omega # (D/2,) 163 | 164 | pos = pos.reshape(-1) # (M,) 165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 166 | 167 | emb_sin = torch.sin(out) # (M, D/2) 168 | emb_cos = torch.cos(out) # (M, D/2) 169 | 170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 171 | 172 | return embeddings 173 | 174 | 175 | def timestep_embedding(timesteps, dim, max_period=10000): 176 | """ 177 | Create sinusoidal timestep embeddings. 178 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 179 | These may be fractional. 180 | :param dim: the dimension of the output. 181 | :param max_period: controls the minimum frequency of the embeddings. 182 | :return: an [N x dim] Tensor of positional embeddings. 
183 | """ 184 | half = dim // 2 185 | freqs = torch.exp( 186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 187 | ).to(device=timesteps.device) 188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None] 189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 190 | if dim % 2: 191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 192 | return embedding 193 | 194 | 195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, 196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, 197 | log2_hashmap_size=19, desired_resolution=None): 198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): 199 | return nn.Identity(), input_dim 200 | 201 | elif embed_type == "fourier": 202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, 203 | logspace=True, include_input=True) 204 | return embedder_obj, embedder_obj.out_dim 205 | 206 | elif embed_type == "hashgrid": 207 | raise NotImplementedError 208 | 209 | elif embed_type == "sphere_harmonic": 210 | raise NotImplementedError 211 | 212 | else: 213 | raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") 214 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/tsal/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/tsal/clip_asl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | from transformers import CLIPModel 7 | 8 | from meshanything_train.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule 9 | 10 | 11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): 12 | 13 | def __init__(self, *, 14 | shape_model, 15 | clip_model_version: str = "openai/clip-vit-large-patch14"): 16 | 17 | super().__init__() 18 | 19 | # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) 20 | # for params in self.clip_model.parameters(): 21 | # params.requires_grad = False 22 | self.clip_model = None 23 | self.shape_model = shape_model 24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width)) 25 | # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5) 26 | 27 | def set_shape_model_only(self): 28 | self.clip_model = None 29 | 30 | def encode_shape_embed(self, surface, return_latents: bool = False): 31 | """ 32 | 33 | Args: 34 | surface (torch.FloatTensor): [bs, n, 3 + c] 35 | return_latents (bool): 36 | 37 | Returns: 38 | x (torch.FloatTensor): [bs, projection_dim] 39 | shape_latents (torch.FloatTensor): [bs, m, d] 40 | """ 41 | 42 | pc = surface[..., 0:3] 43 | feats = surface[..., 3:] 44 | 45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) 46 | x = shape_embed @ self.shape_projection 47 | 48 | if return_latents: 49 | return x, shape_latents 50 | else: 51 | return x 52 | 53 | def encode_image_embed(self, image): 54 | """ 55 | 56 | Args: 57 | image (torch.FloatTensor): [bs, 3, h, w] 58 | 59 | Returns: 60 | x (torch.FloatTensor): [bs, projection_dim] 61 | """ 62 | 63 | x = 
self.clip_model.get_image_features(image)
64 |
65 | return x
66 |
67 | def encode_text_embed(self, text):
68 | x = self.clip_model.get_text_features(text)
69 | return x
70 |
71 | def forward(self, surface, image, text):
72 | """
73 |
74 | Args:
75 | surface (torch.FloatTensor): [bs, n, 3 + c]
76 | image (torch.FloatTensor): [bs, 3, 224, 224]
77 | text (torch.LongTensor): [bs, num_templates, 77]
78 |
79 | Returns:
80 | embed_outputs (dict): the embedding outputs, containing:
81 | - image_embed (torch.FloatTensor):
82 | - text_embed (torch.FloatTensor):
83 | - shape_embed (torch.FloatTensor):
84 | - logit_scale (float): currently omitted, since the CLIP model is commented out below
85 | """
86 |
87 | # # text embedding
88 | # text_embed_all = []
89 | # for i in range(text.shape[0]):
90 | # text_for_one_sample = text[i]
91 | # text_embed = self.encode_text_embed(text_for_one_sample)
92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
93 | # text_embed = text_embed.mean(dim=0)
94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
95 | # text_embed_all.append(text_embed)
96 | # text_embed_all = torch.stack(text_embed_all)
97 |
98 | b = text.shape[0]
99 | text_tokens = rearrange(text, "b t l -> (b t) l")
100 | text_embed = self.encode_text_embed(text_tokens)
101 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
102 | text_embed = text_embed.mean(dim=1)
103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
104 |
105 | # image embedding
106 | image_embed = self.encode_image_embed(image)
107 |
108 | # shape embedding
109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
110 |
111 | embed_outputs = {
112 | "image_embed": image_embed,
113 | "text_embed": text_embed,
114 | "shape_embed": shape_embed,
115 | # "logit_scale": self.clip_model.logit_scale.exp()
116 | }
117 |
118 | return embed_outputs, shape_latents
119 |
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/tsal/inference_utils.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from einops import repeat
6 | import numpy as np
7 | from typing import Callable, Tuple, List, Union, Optional
8 | from skimage import measure
9 |
10 | from meshanything_train.miche.michelangelo.graphics.primitives import generate_dense_grid_points
11 |
12 |
13 | @torch.no_grad()
14 | def extract_geometry(geometric_func: Callable,
15 | device: torch.device,
16 | batch_size: int = 1,
17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
18 | octree_depth: int = 7,
19 | num_chunks: int = 10000,
20 | disable: bool = True):
21 | """
22 |
23 | Args:
24 | geometric_func: callable mapping [b, n, 3] query points to occupancy / SDF logits
25 | device: device on which the queries are evaluated
26 | bounds: (x_min, y_min, z_min, x_max, y_max, z_max) bounding box, or a scalar half-extent
27 | octree_depth: controls the resolution of the dense sampling grid (higher is finer)
28 | batch_size: number of shapes decoded in parallel
29 | num_chunks: number of query points per chunk, bounding peak memory
30 | disable: disables the tqdm progress bar
31 |
32 | Returns:
33 | mesh_v_f: list of (vertices, faces) tuples, (None, None) where no surface is found; has_surface: bool array of shape (batch_size,)
34 | """
35 |
36 | if isinstance(bounds, float):
37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
38 |
39 | bbox_min = np.array(bounds[0:3])
40 | bbox_max = np.array(bounds[3:6])
41 | bbox_size = bbox_max - bbox_min
42 |
43 | xyz_samples, grid_size, length = generate_dense_grid_points(
44 | bbox_min=bbox_min,
45 | bbox_max=bbox_max,
46 | octree_depth=octree_depth,
47 | indexing="ij"
48 | )
49 | xyz_samples = torch.FloatTensor(xyz_samples)
50 |
51 | batch_logits = []
52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
53 | desc="Implicit Function:", disable=disable, leave=False):
54 | queries = 
xyz_samples[start: start + num_chunks, :].to(device) 55 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size) 56 | 57 | logits = geometric_func(batch_queries) 58 | batch_logits.append(logits.cpu()) 59 | 60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() 61 | 62 | mesh_v_f = [] 63 | has_surface = np.zeros((batch_size,), dtype=np.bool_) 64 | for i in range(batch_size): 65 | try: 66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") 67 | vertices = vertices / grid_size * bbox_size + bbox_min 68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]] 69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) 70 | has_surface[i] = True 71 | 72 | except ValueError: 73 | mesh_v_f.append((None, None)) 74 | has_surface[i] = False 75 | 76 | except RuntimeError: 77 | mesh_v_f.append((None, None)) 78 | has_surface[i] = False 79 | 80 | return mesh_v_f, has_surface 81 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/models/tsal/tsal_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from typing import Tuple, List, Optional 5 | 6 | 7 | class Point2MeshOutput(object): 8 | def __init__(self): 9 | self.mesh_v = None 10 | self.mesh_f = None 11 | self.center = None 12 | self.pc = None 13 | 14 | 15 | class Latent2MeshOutput(object): 16 | 17 | def __init__(self): 18 | self.mesh_v = None 19 | self.mesh_f = None 20 | 21 | 22 | class AlignedMeshOutput(object): 23 | 24 | def __init__(self): 25 | self.mesh_v = None 26 | self.mesh_f = None 27 | self.surface = None 28 | self.image = None 29 | self.text: Optional[str] = None 30 | self.shape_text_similarity: Optional[float] = None 31 | self.shape_image_similarity: Optional[float] = None 32 | 33 | 34 | class ShapeAsLatentPLModule(nn.Module): 35 | latent_shape: Tuple[int] 36 | 37 | def encode(self, surface, *args, **kwargs): 38 | raise NotImplementedError 39 | 40 | def decode(self, z_q, *args, **kwargs): 41 | raise NotImplementedError 42 | 43 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 44 | raise NotImplementedError 45 | 46 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 47 | raise NotImplementedError 48 | 49 | 50 | class ShapeAsLatentModule(nn.Module): 51 | latent_shape: Tuple[int, int] 52 | 53 | def __init__(self, *args, **kwargs): 54 | super().__init__() 55 | 56 | def encode(self, *args, **kwargs): 57 | raise NotImplementedError 58 | 59 | def decode(self, *args, **kwargs): 60 | raise NotImplementedError 61 | 62 | def query_geometry(self, *args, **kwargs): 63 | raise NotImplementedError 64 | 65 | 66 | class AlignedShapeAsLatentPLModule(nn.Module): 67 | latent_shape: Tuple[int] 68 | 69 | def set_shape_model_only(self): 70 | raise NotImplementedError 71 | 72 | def encode(self, surface, *args, **kwargs): 73 | raise NotImplementedError 74 | 75 | def decode(self, z_q, *args, **kwargs): 76 | raise NotImplementedError 77 | 78 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 79 | raise NotImplementedError 80 | 81 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 82 | raise NotImplementedError 83 | 84 | 85 | class AlignedShapeAsLatentModule(nn.Module): 86 | shape_model: ShapeAsLatentModule 87 | latent_shape: Tuple[int, int] 88 | 89 | def __init__(self, *args, **kwargs): 90 | 
super().__init__()
91 |
92 | def set_shape_model_only(self):
93 | raise NotImplementedError
94 |
95 | def encode_image_embed(self, *args, **kwargs):
96 | raise NotImplementedError
97 |
98 | def encode_text_embed(self, *args, **kwargs):
99 | raise NotImplementedError
100 |
101 | def encode_shape_embed(self, *args, **kwargs):
102 | raise NotImplementedError
103 |
104 |
105 | class TexturedShapeAsLatentModule(nn.Module):
106 |
107 | def __init__(self, *args, **kwargs):
108 | super().__init__()
109 |
110 | def encode(self, *args, **kwargs):
111 | raise NotImplementedError
112 |
113 | def decode(self, *args, **kwargs):
114 | raise NotImplementedError
115 |
116 | def query_geometry(self, *args, **kwargs):
117 | raise NotImplementedError
118 |
119 | def query_color(self, *args, **kwargs):
120 | raise NotImplementedError
121 |
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/utils/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .misc import instantiate_from_config
4 |
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/utils/eval.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 |
6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
7 |
8 | mse = torch.mean((x - y) ** 2)
9 | psnr = 10 * torch.log10(data_range / (mse + eps)) # note: data_range is not squared here, so this differs from the textbook PSNR (10 * log10(peak ** 2 / mse)) by a constant offset
10 |
11 | return psnr
12 |
13 |
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/utils/io.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import io
5 | import tarfile
6 | import json
7 | import numpy as np
8 | import numpy.lib.format
9 |
10 |
11 | def mkdir(path):
12 | os.makedirs(path, exist_ok=True)
13 | return path
14 |
15 |
16 | def npy_loads(data):
17 | stream = io.BytesIO(data)
18 | return np.lib.format.read_array(stream)
19 |
20 |
21 | def npz_loads(data):
22 | return np.load(io.BytesIO(data))
23 |
24 |
25 | def json_loads(data):
26 | return json.loads(data)
27 |
28 |
29 | def load_json(filepath):
30 | with open(filepath, "r") as f:
31 | data = json.load(f)
32 | return data
33 |
34 |
35 | def write_json(filepath, data):
36 | with open(filepath, "w") as f:
37 | json.dump(data, f, indent=2)
38 |
39 |
40 | def extract_tar(tar_path, tar_cache_folder):
41 |
42 | with tarfile.open(tar_path, "r") as tar:
43 | tar.extractall(path=tar_cache_folder)
44 |
45 | tar_uids = sorted(os.listdir(tar_cache_folder))
46 | print(f"extract tar: {tar_path} to {tar_cache_folder}")
47 | return tar_uids
48 |
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/utils/misc.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import importlib
4 |
5 | import torch
6 | import torch.distributed as dist
7 | from typing import Union
8 |
9 |
10 |
11 |
12 | def get_obj_from_str(string, reload=False):
13 | module, cls = string.rsplit(".", 1)
14 | if reload:
15 | module_imp = importlib.import_module(module)
16 | importlib.reload(module_imp)
17 | return getattr(importlib.import_module(module, package=None), cls)
18 |
19 |
20 | def get_obj_from_config(config):
21 | if "target" not in config:
22 | 
raise KeyError("Expected key `target` to instantiate.") 23 | 24 | return get_obj_from_str(config["target"]) 25 | 26 | 27 | def instantiate_from_config(config, **kwargs): 28 | if "target" not in config: 29 | raise KeyError("Expected key `target` to instantiate.") 30 | 31 | cls = get_obj_from_str(config["target"]) 32 | 33 | params = config.get("params", dict()) 34 | # params.update(kwargs) 35 | # instance = cls(**params) 36 | kwargs.update(params) 37 | instance = cls(**kwargs) 38 | 39 | return instance 40 | 41 | 42 | def is_dist_avail_and_initialized(): 43 | if not dist.is_available(): 44 | return False 45 | if not dist.is_initialized(): 46 | return False 47 | return True 48 | 49 | 50 | def get_rank(): 51 | if not is_dist_avail_and_initialized(): 52 | return 0 53 | return dist.get_rank() 54 | 55 | 56 | def get_world_size(): 57 | if not is_dist_avail_and_initialized(): 58 | return 1 59 | return dist.get_world_size() 60 | 61 | 62 | def all_gather_batch(tensors): 63 | """ 64 | Performs all_gather operation on the provided tensors. 65 | """ 66 | # Queue the gathered tensors 67 | world_size = get_world_size() 68 | # There is no need for reduction in the single-proc case 69 | if world_size == 1: 70 | return tensors 71 | tensor_list = [] 72 | output_tensor = [] 73 | for tensor in tensors: 74 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 75 | dist.all_gather( 76 | tensor_all, 77 | tensor, 78 | async_op=False # performance opt 79 | ) 80 | 81 | tensor_list.append(tensor_all) 82 | 83 | for tensor_all in tensor_list: 84 | output_tensor.append(torch.cat(tensor_all, dim=0)) 85 | return output_tensor 86 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/utils/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/utils/visualizers/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | # Helper functions 6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): 7 | colormap = plt.cm.get_cmap(colormap) 8 | if normalize: 9 | vmin = np.min(inp) 10 | vmax = np.max(inp) 11 | 12 | norm = plt.Normalize(vmin, vmax) 13 | return colormap(norm(inp))[:, :3] 14 | 15 | 16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): 17 | # tex dims need to be power of two. 18 | array = np.ones((width, height, 3), dtype='float32') 19 | 20 | # width in texels of each checker 21 | checker_w = width / n_checkers_x 22 | checker_h = height / n_checkers_y 23 | 24 | for y in range(height): 25 | for x in range(width): 26 | color_key = int(x / checker_w) + int(y / checker_h) 27 | if color_key % 2 == 0: 28 | array[x, y, :] = [1., 0.874, 0.0] 29 | else: 30 | array[x, y, :] = [0., 0., 0.] 
31 | return array
32 |
33 |
34 | def gen_circle(width=256, height=256):
35 | xx, yy = np.mgrid[:width, :height]
36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
37 | array = np.ones((width, height, 4), dtype='float32')
38 | array[:, :, 0] = (circle <= width)
39 | array[:, :, 1] = (circle <= width)
40 | array[:, :, 2] = (circle <= width)
41 | array[:, :, 3] = circle <= width
42 | return array
43 |
44 |
-------------------------------------------------------------------------------- /meshanything_train/miche/michelangelo/utils/visualizers/html_util.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | import base64
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def to_html_frame(content):
9 |
10 | html_frame = f"""
11 | <html>
12 | <body>
13 | {content}
14 | </body>
15 | </html>
16 | """
17 |
18 | return html_frame
19 |
20 |
21 | def to_single_row_table(caption: str, content: str):
22 |
23 | table_html = f"""
24 | <table border = "1">
25 | <caption>{caption}</caption>
26 | <tr>
27 | <td>{content}</td>
28 | </tr>
29 | </table>
30 | """ 31 | 32 | return table_html 33 | 34 | 35 | def to_image_embed_tag(image: np.ndarray): 36 | 37 | # Convert np.ndarray to bytes 38 | img = Image.fromarray(image) 39 | raw_bytes = io.BytesIO() 40 | img.save(raw_bytes, "PNG") 41 | 42 | # Encode bytes to base64 43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8") 44 | 45 | image_tag = f""" 46 | Embedded Image 47 | """ 48 | 49 | return image_tag 50 | -------------------------------------------------------------------------------- /meshanything_train/miche/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.20.3 3 | addict==2.4.0 4 | aiofiles==23.1.0 5 | aiohttp==3.8.4 6 | aiosignal==1.3.1 7 | altair==5.0.1 8 | antlr4-python3-runtime==4.9.3 9 | anyio==3.6.2 10 | appdirs==1.4.4 11 | argon2-cffi==21.3.0 12 | argon2-cffi-bindings==21.2.0 13 | arrow==1.2.3 14 | asttokens==2.2.1 15 | async-timeout==4.0.2 16 | attrs==22.2.0 17 | backcall==0.2.0 18 | beautifulsoup4==4.11.2 19 | bleach==6.0.0 20 | braceexpand==0.1.7 21 | cachetools==5.3.0 22 | cffi==1.15.1 23 | charset-normalizer==3.0.1 24 | click==8.1.3 25 | coloredlogs==15.0.1 26 | comm==0.1.2 27 | configargparse==1.5.3 28 | contourpy==1.0.7 29 | controlnet-aux==0.0.5 30 | cycler==0.11.0 31 | cython==0.29.33 32 | dash==2.8.1 33 | dash-core-components==2.0.0 34 | dash-html-components==2.0.0 35 | dash-table==5.0.0 36 | dataclasses-json==0.6.0 37 | debugpy==1.6.6 38 | decorator==5.1.1 39 | deepspeed==0.8.1 40 | defusedxml==0.7.1 41 | deprecated==1.2.14 42 | diffusers==0.18.2 43 | docker-pycreds==0.4.0 44 | einops==0.6.0 45 | executing==1.2.0 46 | fastapi==0.101.0 47 | fastjsonschema==2.16.2 48 | ffmpy==0.3.1 49 | filelock==3.9.0 50 | flask==2.2.3 51 | flatbuffers==23.5.26 52 | fonttools==4.38.0 53 | fqdn==1.5.1 54 | frozenlist==1.3.3 55 | fsspec==2023.1.0 56 | ftfy==6.1.1 57 | fvcore==0.1.5.post20221221 58 | gitdb==4.0.10 59 | gitpython==3.1.31 60 | google-auth==2.16.1 61 | google-auth-oauthlib==0.4.6 62 | gradio==3.39.0 63 | gradio-client==0.3.0 64 | grpcio==1.51.3 65 | h11==0.14.0 66 | hjson==3.1.0 67 | httpcore==0.17.3 68 | httpx==0.24.1 69 | huggingface-hub==0.16.4 70 | humanfriendly==10.0 71 | idna==3.4 72 | imageio==2.25.1 73 | importlib-metadata==6.0.0 74 | iopath==0.1.10 75 | ipydatawidgets==4.3.3 76 | ipykernel==6.21.2 77 | ipython==8.10.0 78 | ipython-genutils==0.2.0 79 | ipywidgets==8.0.4 80 | isoduration==20.11.0 81 | itsdangerous==2.1.2 82 | jedi==0.18.2 83 | jinja2==3.1.2 84 | joblib==1.2.0 85 | jsonpointer==2.3 86 | jsonschema==4.17.3 87 | jupyter==1.0.0 88 | jupyter-client==8.0.3 89 | jupyter-console==6.6.1 90 | jupyter-core==5.2.0 91 | jupyter-events==0.6.3 92 | jupyter-server==2.3.0 93 | jupyter-server-terminals==0.4.4 94 | jupyterlab-pygments==0.2.2 95 | jupyterlab-widgets==3.0.5 96 | kiwisolver==1.4.4 97 | lightning-utilities==0.7.1 98 | linkify-it-py==2.0.2 99 | lmdb==1.4.1 100 | markdown==3.4.1 101 | markdown-it-py==2.2.0 102 | markupsafe==2.1.2 103 | marshmallow==3.20.1 104 | matplotlib==3.6.3 105 | matplotlib-inline==0.1.6 106 | mdit-py-plugins==0.3.3 107 | mdurl==0.1.2 108 | mesh2sdf==1.1.0 109 | mistune==2.0.5 110 | mpmath==1.3.0 111 | multidict==6.0.4 112 | mypy-extensions==1.0.0 113 | nbclassic==0.5.2 114 | nbclient==0.7.2 115 | nbconvert==7.2.9 116 | nbformat==5.5.0 117 | nest-asyncio==1.5.6 118 | networkx==3.0 119 | ninja==1.11.1 120 | notebook==6.5.2 121 | notebook-shim==0.2.2 122 | numpy==1.23.1 123 | oauthlib==3.2.2 124 | objaverse==0.0.7 125 | 
omegaconf==2.3.0 126 | onnxruntime==1.15.1 127 | opencv-contrib-python==4.8.0.74 128 | opencv-python==4.7.0.72 129 | orjson==3.9.2 130 | packaging==21.3 131 | pandas==1.4.4 132 | pandocfilters==1.5.0 133 | parso==0.8.3 134 | pathtools==0.1.2 135 | pexpect==4.8.0 136 | pickleshare==0.7.5 137 | pillow==9.2.0 138 | platformdirs==3.0.0 139 | plotly==5.13.0 140 | portalocker==2.7.0 141 | prometheus-client==0.16.0 142 | prompt-toolkit==3.0.37 143 | protobuf==3.19.6 144 | psutil==5.9.4 145 | ptyprocess==0.7.0 146 | pure-eval==0.2.2 147 | py-cpuinfo==9.0.0 148 | pyasn1==0.4.8 149 | pyasn1-modules==0.2.8 150 | pycparser==2.21 151 | pydantic==1.10.5 152 | pydub==0.25.1 153 | pygltflib==1.16.0 154 | pygments==2.14.0 155 | pymeshlab==2022.2.post3 156 | pyparsing==3.0.9 157 | pyquaternion==0.9.9 158 | pyrsistent==0.19.3 159 | pysdf==0.1.8 160 | python-dateutil==2.8.2 161 | python-json-logger==2.0.7 162 | python-multipart==0.0.6 163 | pythreejs==2.4.2 164 | pytz==2022.7.1 165 | pywavelets==1.4.1 166 | pyyaml==6.0 167 | pyzmq==25.0.0 168 | qtconsole==5.4.0 169 | qtpy==2.3.0 170 | regex==2022.10.31 171 | requests==2.28.2 172 | requests-oauthlib==1.3.1 173 | rfc3339-validator==0.1.4 174 | rfc3986-validator==0.1.1 175 | rsa==4.9 176 | rtree==1.0.1 177 | safetensors==0.3.1 178 | scikit-image==0.19.3 179 | scikit-learn==1.2.1 180 | scipy==1.10.1 181 | semantic-version==2.10.0 182 | send2trash==1.8.0 183 | sentencepiece==0.1.97 184 | sentry-sdk==1.15.0 185 | setproctitle==1.3.2 186 | setuptools==63.4.3 187 | sh==2.0.2 188 | shapely==2.0.1 189 | six==1.16.0 190 | smmap==5.0.0 191 | sniffio==1.3.0 192 | soupsieve==2.4 193 | stack-data==0.6.2 194 | starlette==0.27.0 195 | sympy==1.12 196 | tabulate==0.9.0 197 | tenacity==8.2.1 198 | tensorboard==2.10.1 199 | tensorboard-data-server==0.6.1 200 | tensorboard-plugin-wit==1.8.1 201 | termcolor==2.3.0 202 | terminado==0.17.1 203 | threadpoolctl==3.1.0 204 | tifffile==2023.2.3 205 | timm==0.9.2 206 | tinycss2==1.2.1 207 | tokenizers==0.13.2 208 | toolz==0.12.0 209 | tornado==6.2 210 | tqdm==4.64.1 211 | traitlets==5.9.0 212 | traittypes==0.2.1 213 | transformers==4.30.2 214 | trimesh==3.18.3 215 | triton==1.1.1 216 | typing-extensions==4.5.0 217 | typing-inspect==0.9.0 218 | uc-micro-py==1.0.2 219 | uri-template==1.2.0 220 | urllib3==1.26.14 221 | uvicorn==0.23.2 222 | wandb==0.13.10 223 | wcwidth==0.2.6 224 | webcolors==1.12 225 | webdataset==0.2.33 226 | webencodings==0.5.1 227 | websocket-client==1.5.1 228 | websockets==11.0.3 229 | werkzeug==2.2.3 230 | widgetsnbextension==4.0.5 231 | wrapt==1.15.0 232 | xatlas==0.0.7 233 | yacs==0.1.8 234 | yarl==1.8.2 235 | zipp==3.14.0 236 | # torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 237 | https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=db457a822d736013b6ffe509053001bc918bdd78fe68967b605f53984a9afac5 238 | https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=a9fc38040e133d1779f131b4497caef830e9e699faf89cd323cd58794ffb305b 239 | https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=5bc0e29cb78f7c452eeb4f27029c40049770d51553bf840b4ca2edd63da289ee 240 | # torch-cluster 241 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_cluster-1.6.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl 242 | # torch-scatter 243 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl 244 | 
torchmetrics==0.11.1
245 | pytorch_lightning~=1.9.3
246 | git+https://github.com/pyvista/fast-simplification.git
247 | git+https://github.com/skoch9/meshplot.git
248 | git+https://github.com/NVlabs/nvdiffrast/
-------------------------------------------------------------------------------- /meshanything_train/miche/scripts/infer.sh: --------------------------------------------------------------------------------
1 | python inference.py \
2 | --task reconstruction \
3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \
5 | --pointcloud_path ./example_data/surface/surface.npz
6 |
7 | python inference.py \
8 | --task image2mesh \
9 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \
10 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \
11 | --image_path ./example_data/image/car.jpg
12 |
13 | python inference.py \
14 | --task text2mesh \
15 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \
16 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \
17 | --text "A 3D model of motorcar; Porsche Cayenne Turbo."
-------------------------------------------------------------------------------- /meshanything_train/miche/scripts/inference/image2mesh.sh: --------------------------------------------------------------------------------
1 | python inference.py \
2 | --task image2mesh \
3 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \
4 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \
5 | --image_path ./example_data/image/car.jpg
-------------------------------------------------------------------------------- /meshanything_train/miche/scripts/inference/reconstruction.sh: --------------------------------------------------------------------------------
1 | python inference.py \
2 | --task reconstruction \
3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \
5 | --pointcloud_path ./example_data/surface/surface.npz
-------------------------------------------------------------------------------- /meshanything_train/miche/scripts/inference/text2mesh.sh: --------------------------------------------------------------------------------
1 | python inference.py \
2 | --task text2mesh \
3 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \
4 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \
5 | --text "A 3D model of motorcar; Porsche Cayenne Turbo."
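
# Note: the four inference scripts above are meant to be run from meshanything_train/miche/
# and read model weights from a local checkpoints/ directory that is not included in this
# snapshot. The layout below is only a sketch inferred from the --ckpt_path flags; download
# locations are not part of this listing:
#
# checkpoints/
# ├── aligned_shape_latents/
# │   └── shapevae-256.ckpt
# ├── image_cond_diffuser_asl/
# │   └── image-ASLDM-256.ckpt
# └── text_cond_diffuser_asl/
#     └── text-ASLDM-256.ckpt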
-------------------------------------------------------------------------------- /meshanything_train/miche/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | from distutils.extension import Extension 4 | from Cython.Build import cythonize 5 | import numpy as np 6 | 7 | 8 | setup( 9 | name="michelangelo", 10 | version="0.4.1", 11 | author="Zibo Zhao, Wen Liu and Xin Chen", 12 | author_email="liuwen@shanghaitech.edu.cn", 13 | description="Michelangelo: a 3D Shape Generation System.", 14 | packages=find_packages(exclude=("configs", "tests", "scripts", "example_data")), 15 | python_requires=">=3.8", 16 | install_requires=[ 17 | "torch", 18 | "numpy", 19 | "cython", 20 | "tqdm", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /meshanything_train/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | import numpy as np 4 | from collections import deque 5 | from typing import List 6 | from meshanything_train.dist import is_distributed, barrier, all_reduce_sum 7 | 8 | 9 | def my_worker_init_fn(worker_id): 10 | np.random.seed(np.random.get_state()[1][0] + worker_id) 11 | 12 | 13 | @torch.jit.ignore 14 | def to_list_1d(arr) -> List[float]: 15 | arr = arr.detach().cpu().numpy().tolist() 16 | return arr 17 | 18 | 19 | @torch.jit.ignore 20 | def to_list_3d(arr) -> List[List[List[float]]]: 21 | arr = arr.detach().cpu().numpy().tolist() 22 | return arr 23 | 24 | 25 | def huber_loss(error, delta=1.0): 26 | """ 27 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 28 | x = error = pred - gt or dist(pred,gt) 29 | 0.5 * |x|^2 if |x|<=d 30 | 0.5 * d^2 + d * (|x|-d) if |x|>d 31 | """ 32 | abs_error = torch.abs(error) 33 | quadratic = torch.clamp(abs_error, max=delta) 34 | linear = abs_error - quadratic 35 | loss = 0.5 * quadratic ** 2 + delta * linear 36 | return loss 37 | 38 | 39 | # From https://github.com/facebookresearch/detr/blob/master/util/misc.py 40 | class SmoothedValue(object): 41 | """Track a series of values and provide access to smoothed values over a 42 | window or the global series average. 43 | """ 44 | 45 | def __init__(self, window_size=20, fmt=None): 46 | if fmt is None: 47 | fmt = "{median:.4f} ({global_avg:.4f})" 48 | self.deque = deque(maxlen=window_size) 49 | self.total = 0.0 50 | self.count = 0 51 | self.fmt = fmt 52 | 53 | def update(self, value, n=1): 54 | self.deque.append(value) 55 | self.count += n 56 | self.total += value * n 57 | 58 | def synchronize_between_processes(self): 59 | """ 60 | Warning: does not synchronize the deque! 
61 | """ 62 | if not is_distributed(): 63 | return 64 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 65 | barrier() 66 | all_reduce_sum(t) 67 | t = t.tolist() 68 | self.count = int(t[0]) 69 | self.total = t[1] 70 | 71 | @property 72 | def median(self): 73 | d = torch.tensor(list(self.deque)) 74 | return d.median().item() 75 | 76 | @property 77 | def avg(self): 78 | d = torch.tensor(list(self.deque), dtype=torch.float32) 79 | return d.mean().item() 80 | 81 | @property 82 | def global_avg(self): 83 | return self.total / self.count 84 | 85 | @property 86 | def max(self): 87 | return max(self.deque) 88 | 89 | @property 90 | def value(self): 91 | return self.deque[-1] 92 | 93 | def __str__(self): 94 | return self.fmt.format( 95 | median=self.median, 96 | avg=self.avg, 97 | global_avg=self.global_avg, 98 | max=self.max, 99 | value=self.value, 100 | ) 101 | -------------------------------------------------------------------------------- /pc_examples/grenade.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/pc_examples/grenade.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trimesh==4.2.3 2 | accelerate==0.28.0 3 | mesh2sdf==1.1.0 4 | einops==0.7.0 5 | einx==0.1.3 6 | optimum==1.18.0 7 | omegaconf==2.3.0 8 | opencv-python==4.9.0.80 9 | transformers==4.39.3 10 | numpy==1.26.4 11 | huggingface_hub 12 | matplotlib 13 | gradio 14 | spaces -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | 3 | import datetime 4 | from meshanything_train.engine import do_train 5 | from meshanything_train.models.single_gpt import SingleGPT 6 | 7 | from accelerate import Accelerator 8 | from accelerate.logging import get_logger 9 | from accelerate.utils import set_seed 10 | import logging 11 | import importlib 12 | from accelerate.utils import DistributedDataParallelKwargs 13 | import torch 14 | 15 | def make_args_parser(): 16 | parser = argparse.ArgumentParser("MeshAnything", add_help=False) 17 | 18 | parser.add_argument("--input_pc_num", default=8192, type=int) 19 | parser.add_argument("--max_vertices", default=800, type=int) 20 | 21 | parser.add_argument("--warm_lr_epochs", default=1, type=int) 22 | parser.add_argument("--num_beams", default=1, type=int) 23 | parser.add_argument("--max_seq_ratio", default=0.70, type=float) 24 | 25 | ##### Model Setups ##### 26 | parser.add_argument( 27 | '--pretrained_tokenizer_weight', 28 | default=None, 29 | type=str, 30 | help="The weight for pre-trained vqvae" 31 | ) 32 | 33 | parser.add_argument('--llm', default="facebook/opt-350m", type=str, help="The LLM backend") 34 | parser.add_argument("--gen_n_max_triangles", default=1600, type=int, help="max number of triangles") 35 | 36 | ##### Training ##### 37 | parser.add_argument("--eval_every_iteration", default=2000, type=int) 38 | parser.add_argument("--save_every", default=250, type=int) 39 | parser.add_argument("--generate_every_data", default=1, type=int) 40 | 41 | ##### Testing ##### 42 | parser.add_argument( 43 | "--clip_gradient", default=1., type=float, 44 | help="Max L2 norm of the gradient" 45 | ) 46 | parser.add_argument("--pad_id", default=-1, type=int, help="padding id") 
47 | parser.add_argument("--dataset", default='loop_set_256', help="dataset list split by ','") 48 | parser.add_argument("--n_discrete_size", default=128, type=int, help="discretized 3D space") 49 | parser.add_argument("--data_n_max_triangles", default=1600, type=int, help="max number of triangles") 50 | 51 | parser.add_argument("--n_max_triangles", default=1600, type=int, help="max number of triangles") 52 | parser.add_argument("--n_min_triangles", default=40, type=int, help="max number of triangles") 53 | 54 | parser.add_argument("--shift_scale", default=0.1, type=float) 55 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int) 56 | parser.add_argument('--data_dir', default="dataset", type=str, help="data path") 57 | 58 | parser.add_argument("--seed", default=0, type=int) 59 | 60 | parser.add_argument("--base_lr", default=1e-4, type=float) 61 | parser.add_argument("--final_lr", default=6e-5, type=float) 62 | parser.add_argument("--lr_scheduler", default="cosine", type=str) 63 | parser.add_argument("--weight_decay", default=0.1, type=float) 64 | parser.add_argument("--optimizer", default="AdamW", type=str) 65 | parser.add_argument("--warm_lr", default=1e-6, type=float) 66 | 67 | parser.add_argument("--no_aug", default=False, action="store_true") 68 | parser.add_argument("--checkpoint_dir", default="default", type=str) 69 | parser.add_argument("--log_every", default=10, type=int) 70 | parser.add_argument("--test_only", default=False, action="store_true") 71 | parser.add_argument("--generate_every_iteration", default=18000, type=int) 72 | 73 | parser.add_argument("--start_epoch", default=-1, type=int) 74 | parser.add_argument("--max_epoch", default=800, type=int) 75 | parser.add_argument("--start_eval_after", default=-1, type=int) 76 | parser.add_argument("--precision", default="fp16", type=str) 77 | parser.add_argument("--batchsize_per_gpu", default=8, type=int) 78 | parser.add_argument( 79 | "--criterion", default=None, type=str, 80 | help='metrics for saving the best model' 81 | ) 82 | 83 | parser.add_argument('--pretrained_weights', default=None, type=str) 84 | 85 | args = parser.parse_args() 86 | 87 | return args 88 | 89 | if __name__ == "__main__": 90 | 91 | logging.basicConfig(level=logging.INFO) 92 | logger = get_logger(__file__) 93 | 94 | args = make_args_parser() 95 | 96 | cur_time = datetime.datetime.now().strftime("%d_%H-%M-%S") 97 | wandb_name = args.checkpoint_dir + "_" +cur_time 98 | args.checkpoint_dir = os.path.join("gpt_output", wandb_name) 99 | print("checkpoint_dir:", args.checkpoint_dir) 100 | os.makedirs(args.checkpoint_dir, exist_ok=True) 101 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 102 | 103 | accelerator = Accelerator( 104 | gradient_accumulation_steps=args.gradient_accumulation_steps, 105 | mixed_precision=args.precision, 106 | log_with="wandb", 107 | project_dir=args.checkpoint_dir, 108 | kwargs_handlers=[kwargs] 109 | ) 110 | if "default" not in args.checkpoint_dir: 111 | accelerator.init_trackers( 112 | project_name="GPT", 113 | config=vars(args), 114 | init_kwargs={"wandb": {"name": wandb_name}} 115 | ) 116 | 117 | set_seed(args.seed, device_specific=True) 118 | 119 | dataset_module = importlib.import_module(f'meshanything_train.{args.dataset}') 120 | 121 | train_dataset = dataset_module.Dataset(args, split_set="train") 122 | test_dataset = dataset_module.Dataset(args, split_set="test") 123 | # make sure no val sample in train set 124 | train_uids = [cur_data['uid'] for cur_data in train_dataset.data] 125 | 
val_uids = [cur_data['uid'] for cur_data in test_dataset.data] 126 | intersection_list = list(set(train_uids).intersection(set(val_uids))) 127 | print("intersection_list:", len(intersection_list)) 128 | 129 | new_train_set_data = [] 130 | for cur_data in train_dataset.data: 131 | if cur_data['uid'] not in intersection_list: 132 | new_train_set_data.append(cur_data) 133 | train_dataset.data = new_train_set_data 134 | 135 | dataloaders = {} 136 | 137 | dataloaders['train'] = torch.utils.data.DataLoader( 138 | train_dataset, 139 | batch_size=args.batchsize_per_gpu, 140 | drop_last = True, 141 | shuffle = True, 142 | ) 143 | 144 | dataloaders['test'] = torch.utils.data.DataLoader( 145 | test_dataset, 146 | batch_size=args.batchsize_per_gpu, 147 | drop_last = True, 148 | shuffle = False, 149 | ) 150 | 151 | model = SingleGPT(args) 152 | model.to(torch.float32) 153 | do_train( 154 | args, 155 | model, 156 | dataloaders, 157 | logger, 158 | accelerator, 159 | ) -------------------------------------------------------------------------------- /training_requirement.txt: -------------------------------------------------------------------------------- 1 | trimesh==4.2.3 2 | accelerate==0.28.0 3 | mesh2sdf==1.1.0 4 | einops==0.7.0 5 | einx==0.1.3 6 | optimum==1.18.0 7 | omegaconf==2.3.0 8 | opencv-python==4.9.0.80 9 | transformers==4.39.3 10 | numpy==1.26.4 11 | huggingface_hub 12 | matplotlib 13 | gradio 14 | spaces --------------------------------------------------------------------------------
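
# Note: train.py constructs its own Accelerator, so it is normally launched through Hugging
# Face accelerate. A minimal single-GPU sketch; the flag names come from the argparse
# definitions in train.py above, while the values and the launch invocation itself are
# illustrative assumptions, not a documented command from this repo:

accelerate launch train.py \
    --dataset loop_set_256 \
    --data_dir dataset \
    --batchsize_per_gpu 8 \
    --precision fp16 \
    --checkpoint_dir my_run

# Runs whose --checkpoint_dir contains the substring "default" (the default value) skip
# accelerator.init_trackers, so any other name (as above) enables wandb logging, with
# outputs written under gpt_output/<checkpoint_dir>_<timestamp>.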