├── LICENSE.txt
├── MeshAnything
├── miche
│ ├── LICENSE
│ ├── encode.py
│ ├── michelangelo
│ │ ├── __init__.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── templates.json
│ │ │ ├── transforms.py
│ │ │ └── utils.py
│ │ ├── graphics
│ │ │ ├── __init__.py
│ │ │ └── primitives
│ │ │ │ ├── __init__.py
│ │ │ │ ├── mesh.py
│ │ │ │ └── volume.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── asl_diffusion
│ │ │ │ ├── __init__.py
│ │ │ │ ├── asl_diffuser_pl_module.py
│ │ │ │ ├── asl_udt.py
│ │ │ │ ├── base.py
│ │ │ │ ├── clip_asl_diffuser_pl_module.py
│ │ │ │ └── inference_utils.py
│ │ │ ├── conditional_encoders
│ │ │ │ ├── __init__.py
│ │ │ │ ├── clip.py
│ │ │ │ └── encoder_factory.py
│ │ │ ├── modules
│ │ │ │ ├── __init__.py
│ │ │ │ ├── checkpoint.py
│ │ │ │ ├── diffusion_transformer.py
│ │ │ │ ├── distributions.py
│ │ │ │ ├── embedder.py
│ │ │ │ ├── transformer_blocks.py
│ │ │ │ └── transformer_vit.py
│ │ │ └── tsal
│ │ │ │ ├── __init__.py
│ │ │ │ ├── asl_pl_module.py
│ │ │ │ ├── clip_asl_module.py
│ │ │ │ ├── inference_utils.py
│ │ │ │ ├── loss.py
│ │ │ │ ├── sal_perceiver.py
│ │ │ │ ├── sal_pl_module.py
│ │ │ │ └── tsal_base.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── eval.py
│ │ │ ├── io.py
│ │ │ ├── misc.py
│ │ │ └── visualizers
│ │ │ ├── __init__.py
│ │ │ ├── color_util.py
│ │ │ ├── html_util.py
│ │ │ └── pythreejs_viewer.py
│ └── shapevae-256.yaml
└── models
│ ├── meshanything_v2.py
│ └── shape_opt.py
├── README.md
├── adjacent_mesh_tokenization.py
├── app.py
├── data_process.py
├── demo
└── demo_video.gif
├── examples
├── screwdriver.obj
└── wand.obj
├── gt_examples
└── seals.ply
├── main.py
├── mesh_to_pc.py
├── meshanything_train
├── dist.py
├── engine.py
├── eval_cond_gpt.py
├── loop_set_256.py
├── miche
│ ├── LICENSE
│ ├── README.md
│ ├── configs
│ │ ├── aligned_shape_latents
│ │ │ ├── shapevae-256.yaml
│ │ │ └── shapevae-512.yaml
│ │ ├── image_cond_diffuser_asl
│ │ │ └── image-ASLDM-256.yaml
│ │ └── text_cond_diffuser_asl
│ │ │ └── text-ASLDM-256.yaml
│ ├── encode.py
│ ├── example_data
│ │ ├── image
│ │ │ └── car.jpg
│ │ └── surface
│ │ │ └── surface.npz
│ ├── inference.py
│ ├── michelangelo
│ │ ├── __init__.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── templates.json
│ │ │ ├── transforms.py
│ │ │ └── utils.py
│ │ ├── graphics
│ │ │ ├── __init__.py
│ │ │ └── primitives
│ │ │ │ ├── __init__.py
│ │ │ │ ├── mesh.py
│ │ │ │ └── volume.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── asl_diffusion
│ │ │ │ ├── __init__.py
│ │ │ │ ├── asl_diffuser_pl_module.py
│ │ │ │ ├── asl_udt.py
│ │ │ │ ├── base.py
│ │ │ │ ├── clip_asl_diffuser_pl_module.py
│ │ │ │ └── inference_utils.py
│ │ │ ├── conditional_encoders
│ │ │ │ ├── __init__.py
│ │ │ │ ├── clip.py
│ │ │ │ └── encoder_factory.py
│ │ │ ├── modules
│ │ │ │ ├── __init__.py
│ │ │ │ ├── checkpoint.py
│ │ │ │ ├── diffusion_transformer.py
│ │ │ │ ├── distributions.py
│ │ │ │ ├── embedder.py
│ │ │ │ ├── transformer_blocks.py
│ │ │ │ └── transformer_vit.py
│ │ │ └── tsal
│ │ │ │ ├── __init__.py
│ │ │ │ ├── asl_pl_module.py
│ │ │ │ ├── clip_asl_module.py
│ │ │ │ ├── inference_utils.py
│ │ │ │ ├── loss.py
│ │ │ │ ├── sal_perceiver.py
│ │ │ │ ├── sal_pl_module.py
│ │ │ │ └── tsal_base.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── eval.py
│ │ │ ├── io.py
│ │ │ ├── misc.py
│ │ │ └── visualizers
│ │ │ ├── __init__.py
│ │ │ ├── color_util.py
│ │ │ ├── html_util.py
│ │ │ └── pythreejs_viewer.py
│ ├── requirements.txt
│ ├── scripts
│ │ ├── infer.sh
│ │ └── inference
│ │ │ ├── image2mesh.sh
│ │ │ ├── reconstruction.sh
│ │ │ └── text2mesh.sh
│ └── setup.py
├── misc.py
└── models
│ ├── shape_opt.py
│ └── single_gpt.py
├── pc_examples
└── grenade.npy
├── requirements.txt
├── train.py
└── training_requirement.txt
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2023 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and
6 | binary forms, with or without modification, are permitted provided
7 | that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in
14 | the documentation and/or other materials provided with the
15 | distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived
19 | from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 |
33 | In the event that redistribution and/or use for commercial purpose in
34 | source or binary forms, with or without modification is required,
35 | please contact the contributor(s) of the work.
--------------------------------------------------------------------------------
/MeshAnything/miche/encode.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | from omegaconf import OmegaConf
4 | import numpy as np
5 | import torch
6 | from .michelangelo.utils.misc import instantiate_from_config
7 |
8 | def load_surface(fp):
9 |
10 | with np.load(fp) as input_pc:
11 | surface = input_pc['points']
12 | normal = input_pc['normals']
13 |
14 | rng = np.random.default_rng()
15 | ind = rng.choice(surface.shape[0], 4096, replace=False)
16 | surface = torch.FloatTensor(surface[ind])
17 | normal = torch.FloatTensor(normal[ind])
18 |
19 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
20 |
21 | return surface
22 |
23 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
24 |
25 | surface = load_surface(args.pointcloud_path)
26 | # old_surface = surface.clone()
27 |
28 | # surface[0,:,0]*=-1
29 | # surface[0,:,1]*=-1
30 | surface[0,:,2]*=-1
31 |
32 | # encoding
33 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
34 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
35 |
36 | # decoding
37 | latents = model.model.shape_model.decode(shape_zq)
38 | # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
39 |
40 | return 0
41 |
42 | def load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt"):
43 | model_config = OmegaConf.load("MeshAnything/miche/shapevae-256.yaml")
44 | # print(model_config)
45 | if hasattr(model_config, "model"):
46 | model_config = model_config.model
47 |
48 | model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
49 | device = "cuda" if torch.cuda.is_available() else "cpu"
50 | model.to(device)
51 | model = model.eval()
52 |
53 | return model
54 | if __name__ == "__main__":
55 | '''
56 | 1. Reconstruct point cloud
57 | 2. Image-conditioned generation
58 | 3. Text-conditioned generation
59 | '''
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("--config_path", type=str, required=True)
62 | parser.add_argument("--ckpt_path", type=str, required=True)
63 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
64 | parser.add_argument("--image_path", type=str, help='Path to the input image')
65 |     parser.add_argument("--text", type=str, help='Input text in the format: A 3D model of a motorcar; Porsche 911.')
66 | parser.add_argument("--output_dir", type=str, default='./output')
67 | parser.add_argument("-s", "--seed", type=int, default=0)
68 | args = parser.parse_args()
69 |
70 | print(f'-----------------------------------------------------------------------------')
71 | print(f'>>> Output directory: {args.output_dir}')
72 | print(f'-----------------------------------------------------------------------------')
73 |
74 |     reconstruction(args, load_model(args.ckpt_path))  # pass the checkpoint path, not the whole argparse namespace
--------------------------------------------------------------------------------
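Note: a minimal sketch of how the helpers above are meant to be combined, mirroring the body of reconstruction(); the checkpoint and surface paths are illustrative, and load_surface as written assumes a CUDA device:

    from MeshAnything.miche.encode import load_model, load_surface

    model = load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt")   # also reads shapevae-256.yaml
    surface = load_surface("path/to/surface.npz")                          # npz with 'points' and 'normals'
    shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
    shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
    latents = model.model.shape_model.decode(shape_zq)
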
/MeshAnything/miche/michelangelo/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/data/templates.json:
--------------------------------------------------------------------------------
1 | {
2 | "shape": [
3 | "a point cloud model of {}.",
4 | "There is a {} in the scene.",
5 | "There is the {} in the scene.",
6 | "a photo of a {} in the scene.",
7 | "a photo of the {} in the scene.",
8 | "a photo of one {} in the scene.",
9 | "itap of a {}.",
10 | "itap of my {}.",
11 | "itap of the {}.",
12 | "a photo of a {}.",
13 | "a photo of my {}.",
14 | "a photo of the {}.",
15 | "a photo of one {}.",
16 | "a photo of many {}.",
17 | "a good photo of a {}.",
18 | "a good photo of the {}.",
19 | "a bad photo of a {}.",
20 | "a bad photo of the {}.",
21 | "a photo of a nice {}.",
22 | "a photo of the nice {}.",
23 | "a photo of a cool {}.",
24 | "a photo of the cool {}.",
25 | "a photo of a weird {}.",
26 | "a photo of the weird {}.",
27 | "a photo of a small {}.",
28 | "a photo of the small {}.",
29 | "a photo of a large {}.",
30 | "a photo of the large {}.",
31 | "a photo of a clean {}.",
32 | "a photo of the clean {}.",
33 | "a photo of a dirty {}.",
34 | "a photo of the dirty {}.",
35 | "a bright photo of a {}.",
36 | "a bright photo of the {}.",
37 | "a dark photo of a {}.",
38 | "a dark photo of the {}.",
39 | "a photo of a hard to see {}.",
40 | "a photo of the hard to see {}.",
41 | "a low resolution photo of a {}.",
42 | "a low resolution photo of the {}.",
43 | "a cropped photo of a {}.",
44 | "a cropped photo of the {}.",
45 | "a close-up photo of a {}.",
46 | "a close-up photo of the {}.",
47 | "a jpeg corrupted photo of a {}.",
48 | "a jpeg corrupted photo of the {}.",
49 | "a blurry photo of a {}.",
50 | "a blurry photo of the {}.",
51 | "a pixelated photo of a {}.",
52 | "a pixelated photo of the {}.",
53 | "a black and white photo of the {}.",
54 | "a black and white photo of a {}",
55 | "a plastic {}.",
56 | "the plastic {}.",
57 | "a toy {}.",
58 | "the toy {}.",
59 | "a plushie {}.",
60 | "the plushie {}.",
61 | "a cartoon {}.",
62 | "the cartoon {}.",
63 | "an embroidered {}.",
64 | "the embroidered {}.",
65 | "a painting of the {}.",
66 | "a painting of a {}."
67 | ]
68 |
69 | }
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/data/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def worker_init_fn(_):
8 | worker_info = torch.utils.data.get_worker_info()
9 | worker_id = worker_info.id
10 |
11 | # dataset = worker_info.dataset
12 | # split_size = dataset.num_records // worker_info.num_workers
13 | # # reset num_records to the true number to retain reliable length information
14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
17 |
18 | return np.random.seed(np.random.get_state()[1][0] + worker_id)
19 |
20 |
21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True):
22 | """
23 |
24 | Args:
25 | samples (list[dict]):
26 | combine_tensors:
27 | combine_scalars:
28 |
29 | Returns:
30 |
31 | """
32 |
33 | result = {}
34 |
35 | keys = samples[0].keys()
36 |
37 | for key in keys:
38 | result[key] = []
39 |
40 | for sample in samples:
41 | for key in keys:
42 | val = sample[key]
43 | result[key].append(val)
44 |
45 | for key in keys:
46 | val_list = result[key]
47 | if isinstance(val_list[0], (int, float)):
48 | if combine_scalars:
49 | result[key] = np.array(result[key])
50 |
51 | elif isinstance(val_list[0], torch.Tensor):
52 | if combine_tensors:
53 | result[key] = torch.stack(val_list)
54 |
55 | elif isinstance(val_list[0], np.ndarray):
56 | if combine_tensors:
57 | result[key] = np.stack(val_list)
58 |
59 | return result
60 |
--------------------------------------------------------------------------------
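Note: collation_fn stacks per-key values across samples; a small sketch of the expected behaviour (keys and shapes are illustrative):

    import torch
    from MeshAnything.miche.michelangelo.data.utils import collation_fn

    samples = [
        {"surface": torch.rand(4096, 6), "label": 1},
        {"surface": torch.rand(4096, 6), "label": 2},
    ]
    batch = collation_fn(samples)
    # batch["surface"] is a stacked tensor of shape [2, 4096, 6];
    # batch["label"] is numpy array([1, 2]) because scalar values are combined with np.array.
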
/MeshAnything/miche/michelangelo/graphics/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/graphics/primitives/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .volume import generate_dense_grid_points
4 |
5 | from .mesh import (
6 | MeshOutput,
7 | save_obj,
8 | savemeshtes2
9 | )
10 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/graphics/primitives/mesh.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import cv2
5 | import numpy as np
6 | import PIL.Image
7 | from typing import Optional
8 |
9 | import trimesh
10 |
11 |
12 | def save_obj(pointnp_px3, facenp_fx3, fname):
13 | fid = open(fname, "w")
14 | write_str = ""
15 | for pidx, p in enumerate(pointnp_px3):
16 | pp = p
17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
18 |
19 | for i, f in enumerate(facenp_fx3):
20 | f1 = f + 1
21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
22 | fid.write(write_str)
23 | fid.close()
24 | return
25 |
26 |
27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
28 | fol, na = os.path.split(fname)
29 | na, _ = os.path.splitext(na)
30 |
31 | matname = "%s/%s.mtl" % (fol, na)
32 | fid = open(matname, "w")
33 | fid.write("newmtl material_0\n")
34 | fid.write("Kd 1 1 1\n")
35 | fid.write("Ka 0 0 0\n")
36 | fid.write("Ks 0.4 0.4 0.4\n")
37 | fid.write("Ns 10\n")
38 | fid.write("illum 2\n")
39 | fid.write("map_Kd %s.png\n" % na)
40 | fid.close()
41 | ####
42 |
43 | fid = open(fname, "w")
44 | fid.write("mtllib %s.mtl\n" % na)
45 |
46 | for pidx, p in enumerate(pointnp_px3):
47 | pp = p
48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
49 |
50 | for pidx, p in enumerate(tcoords_px2):
51 | pp = p
52 | fid.write("vt %f %f\n" % (pp[0], pp[1]))
53 |
54 | fid.write("usemtl material_0\n")
55 | for i, f in enumerate(facenp_fx3):
56 | f1 = f + 1
57 | f2 = facetex_fx3[i] + 1
58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
59 | fid.close()
60 |
61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
62 | os.path.join(fol, "%s.png" % na))
63 |
64 | return
65 |
66 |
67 | class MeshOutput(object):
68 |
69 | def __init__(self,
70 | mesh_v: np.ndarray,
71 | mesh_f: np.ndarray,
72 | vertex_colors: Optional[np.ndarray] = None,
73 | uvs: Optional[np.ndarray] = None,
74 | mesh_tex_idx: Optional[np.ndarray] = None,
75 | tex_map: Optional[np.ndarray] = None):
76 |
77 | self.mesh_v = mesh_v
78 | self.mesh_f = mesh_f
79 | self.vertex_colors = vertex_colors
80 | self.uvs = uvs
81 | self.mesh_tex_idx = mesh_tex_idx
82 | self.tex_map = tex_map
83 |
84 | def contain_uv_texture(self):
85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
86 |
87 | def contain_vertex_colors(self):
88 | return self.vertex_colors is not None
89 |
90 | def export(self, fname):
91 |
92 | if self.contain_uv_texture():
93 | savemeshtes2(
94 | self.mesh_v,
95 | self.uvs,
96 | self.mesh_f,
97 | self.mesh_tex_idx,
98 | self.tex_map,
99 | fname
100 | )
101 |
102 | elif self.contain_vertex_colors():
103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
104 | mesh_obj.export(fname)
105 |
106 | else:
107 | save_obj(
108 | self.mesh_v,
109 | self.mesh_f,
110 | fname
111 | )
112 |
113 |
114 |
115 |
--------------------------------------------------------------------------------
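Note: a minimal export sketch for MeshOutput; with no UVs or vertex colors, export() falls back to save_obj (the output file name is illustrative):

    import numpy as np
    from MeshAnything.miche.michelangelo.graphics.primitives import MeshOutput

    # a single unit triangle
    verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)
    faces = np.array([[0, 1, 2]], dtype=np.int64)
    MeshOutput(mesh_v=verts, mesh_f=faces).export("triangle.obj")
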
/MeshAnything/miche/michelangelo/graphics/primitives/volume.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 |
5 |
6 | def generate_dense_grid_points(bbox_min: np.ndarray,
7 | bbox_max: np.ndarray,
8 | octree_depth: int,
9 | indexing: str = "ij"):
10 | length = bbox_max - bbox_min
11 | num_cells = np.exp2(octree_depth)
12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
16 | xyz = np.stack((xs, ys, zs), axis=-1)
17 | xyz = xyz.reshape(-1, 3)
18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
19 |
20 | return xyz, grid_size, length
21 |
22 |
--------------------------------------------------------------------------------
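Note: generate_dense_grid_points builds a regular (2^octree_depth + 1)^3 lattice over the bounding box; a quick sketch of the returned shapes:

    import numpy as np
    from MeshAnything.miche.michelangelo.graphics.primitives import generate_dense_grid_points

    bbox_min = np.array([-1.25, -1.25, -1.25])
    bbox_max = np.array([1.25, 1.25, 1.25])
    xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)
    # octree_depth=7 -> 128 cells per axis -> 129 samples per axis:
    # xyz.shape == (129 ** 3, 3), grid_size == [129, 129, 129], length == bbox_max - bbox_min
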
/MeshAnything/miche/michelangelo/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/asl_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/asl_diffusion/asl_udt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | from typing import Optional
6 | from diffusers.models.embeddings import Timesteps
7 | import math
8 |
9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import MLP
10 | from MeshAnything.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer
11 |
12 |
13 | class ConditionalASLUDTDenoiser(nn.Module):
14 |
15 | def __init__(self, *,
16 | device: Optional[torch.device],
17 | dtype: Optional[torch.dtype],
18 | input_channels: int,
19 | output_channels: int,
20 | n_ctx: int,
21 | width: int,
22 | layers: int,
23 | heads: int,
24 | context_dim: int,
25 | context_ln: bool = True,
26 | skip_ln: bool = False,
27 | init_scale: float = 0.25,
28 | flip_sin_to_cos: bool = False,
29 | use_checkpoint: bool = False):
30 | super().__init__()
31 |
32 | self.use_checkpoint = use_checkpoint
33 |
34 | init_scale = init_scale * math.sqrt(1.0 / width)
35 |
36 | self.backbone = UNetDiffusionTransformer(
37 | device=device,
38 | dtype=dtype,
39 | n_ctx=n_ctx,
40 | width=width,
41 | layers=layers,
42 | heads=heads,
43 | skip_ln=skip_ln,
44 | init_scale=init_scale,
45 | use_checkpoint=use_checkpoint
46 | )
47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
50 |
51 | # timestep embedding
52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
53 | self.time_proj = MLP(
54 | device=device, dtype=dtype, width=width, init_scale=init_scale
55 | )
56 |
57 | self.context_embed = nn.Sequential(
58 | nn.LayerNorm(context_dim, device=device, dtype=dtype),
59 | nn.Linear(context_dim, width, device=device, dtype=dtype),
60 | )
61 |
62 | if context_ln:
63 | self.context_embed = nn.Sequential(
64 | nn.LayerNorm(context_dim, device=device, dtype=dtype),
65 | nn.Linear(context_dim, width, device=device, dtype=dtype),
66 | )
67 | else:
68 | self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
69 |
70 | def forward(self,
71 | model_input: torch.FloatTensor,
72 | timestep: torch.LongTensor,
73 | context: torch.FloatTensor):
74 |
75 | r"""
76 | Args:
77 | model_input (torch.FloatTensor): [bs, n_data, c]
78 | timestep (torch.LongTensor): [bs,]
79 | context (torch.FloatTensor): [bs, context_tokens, c]
80 |
81 | Returns:
82 | sample (torch.FloatTensor): [bs, n_data, c]
83 |
84 | """
85 |
86 | _, n_data, _ = model_input.shape
87 |
88 | # 1. time
89 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
90 |
91 | # 2. conditions projector
92 | context = self.context_embed(context)
93 |
94 | # 3. denoiser
95 | x = self.input_proj(model_input)
96 | x = torch.cat([t_emb, context, x], dim=1)
97 | x = self.backbone(x)
98 | x = self.ln_post(x)
99 | x = x[:, -n_data:]
100 | sample = self.output_proj(x)
101 |
102 | return sample
103 |
104 |
105 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/asl_diffusion/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class BaseDenoiser(nn.Module):
8 |
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x, t, context):
13 | raise NotImplementedError
14 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/asl_diffusion/inference_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from typing import Tuple, List, Union, Optional
6 | from diffusers.schedulers import DDIMScheduler
7 |
8 |
9 | __all__ = ["ddim_sample"]
10 |
11 |
12 | def ddim_sample(ddim_scheduler: DDIMScheduler,
13 | diffusion_model: torch.nn.Module,
14 | shape: Union[List[int], Tuple[int]],
15 | cond: torch.FloatTensor,
16 | steps: int,
17 | eta: float = 0.0,
18 | guidance_scale: float = 3.0,
19 | do_classifier_free_guidance: bool = True,
20 | generator: Optional[torch.Generator] = None,
21 | device: torch.device = "cuda:0",
22 | disable_prog: bool = True):
23 |
24 |     assert steps > 0, f"steps must be > 0, got {steps}."
25 |
26 | # init latents
27 | bsz = cond.shape[0]
28 | if do_classifier_free_guidance:
29 | bsz = bsz // 2
30 |
31 | latents = torch.randn(
32 | (bsz, *shape),
33 | generator=generator,
34 | device=cond.device,
35 | dtype=cond.dtype,
36 | )
37 | # scale the initial noise by the standard deviation required by the scheduler
38 | latents = latents * ddim_scheduler.init_noise_sigma
39 | # set timesteps
40 | ddim_scheduler.set_timesteps(steps)
41 | timesteps = ddim_scheduler.timesteps.to(device)
42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1]
44 | extra_step_kwargs = {
45 | "eta": eta,
46 | "generator": generator
47 | }
48 |
49 | # reverse
50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
51 | # expand the latents if we are doing classifier free guidance
52 | latent_model_input = (
53 | torch.cat([latents] * 2)
54 | if do_classifier_free_guidance
55 | else latents
56 | )
57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
58 | # predict the noise residual
59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
62 |
63 | # perform guidance
64 | if do_classifier_free_guidance:
65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
66 | noise_pred = noise_pred_uncond + guidance_scale * (
67 | noise_pred_text - noise_pred_uncond
68 | )
69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk(
70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states
71 | # compute the previous noisy sample x_t -> x_t-1
72 | latents = ddim_scheduler.step(
73 | noise_pred, t, latents, **extra_step_kwargs
74 | ).prev_sample
75 |
76 | yield latents, t
77 |
78 |
79 | def karra_sample():
80 | pass
81 |
--------------------------------------------------------------------------------
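Note: ddim_sample is a generator that yields (latents, t) after every DDIM step, so callers keep the last yielded latents. A runnable sketch with a toy denoiser standing in for the real ConditionalASLUDTDenoiser (shapes and step count are illustrative):

    import torch
    from diffusers.schedulers import DDIMScheduler
    from MeshAnything.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample

    class ToyDenoiser(torch.nn.Module):
        def forward(self, x, t, cond):
            return torch.zeros_like(x)      # stand-in noise prediction

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    cond = torch.randn(2, 77, 768)          # [uncond; cond] halves for classifier-free guidance
    sample = None
    for latents, t in ddim_sample(scheduler, ToyDenoiser(), shape=(256, 64), cond=cond,
                                  steps=25, device=torch.device("cpu")):
        sample = latents                    # the final iteration holds the denoised latents
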
/MeshAnything/miche/michelangelo/models/conditional_encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .clip import CLIPEncoder
4 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/conditional_encoders/clip.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from dataclasses import dataclass
7 | from torchvision.transforms import Normalize
8 | from transformers import CLIPModel, CLIPTokenizer
9 | from transformers.utils import ModelOutput
10 | from typing import Iterable, Optional, Union, List
11 |
12 |
13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
14 |
15 |
16 | @dataclass
17 | class CLIPEmbedOutput(ModelOutput):
18 | last_hidden_state: torch.FloatTensor = None
19 | pooler_output: torch.FloatTensor = None
20 | embeds: torch.FloatTensor = None
21 |
22 |
23 | class CLIPEncoder(torch.nn.Module):
24 |
25 | def __init__(self, model_path="openai/clip-vit-base-patch32"):
26 |
27 | super().__init__()
28 |
29 | # Load the CLIP model and processor
30 | self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
31 | self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
32 | self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33 |
34 | self.model.training = False
35 | for p in self.model.parameters():
36 | p.requires_grad = False
37 |
38 | @torch.no_grad()
39 | def encode_image(self, images: Iterable[Optional[ImageType]]):
40 | pixel_values = self.image_preprocess(images)
41 |
42 | vision_outputs = self.model.vision_model(pixel_values=pixel_values)
43 |
44 | pooler_output = vision_outputs[1] # pooled_output
45 | image_features = self.model.visual_projection(pooler_output)
46 |
47 | visual_embeds = CLIPEmbedOutput(
48 | last_hidden_state=vision_outputs.last_hidden_state,
49 | pooler_output=pooler_output,
50 | embeds=image_features
51 | )
52 |
53 | return visual_embeds
54 |
55 | @torch.no_grad()
56 | def encode_text(self, texts: List[str]):
57 | text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
58 |
59 |         text_outputs = self.model.text_model(**text_inputs)  # pass input_ids and attention_mask from the tokenizer
60 |
61 | pooler_output = text_outputs[1] # pooled_output
62 | text_features = self.model.text_projection(pooler_output)
63 |
64 | text_embeds = CLIPEmbedOutput(
65 | last_hidden_state=text_outputs.last_hidden_state,
66 | pooler_output=pooler_output,
67 | embeds=text_features
68 | )
69 |
70 | return text_embeds
71 |
72 | def forward(self,
73 | images: Iterable[Optional[ImageType]],
74 | texts: List[str]):
75 |
76 | visual_embeds = self.encode_image(images)
77 | text_embeds = self.encode_text(texts)
78 |
79 | return visual_embeds, text_embeds
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
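Note: a minimal sketch of the encoder above; the CLIP weights are fetched from the Hugging Face hub on first use, and the image batch is assumed to be already resized to 224x224 and scaled to [0, 1]:

    import torch
    from MeshAnything.miche.michelangelo.models.conditional_encoders import CLIPEncoder

    encoder = CLIPEncoder(model_path="openai/clip-vit-base-patch32")
    images = torch.rand(2, 3, 224, 224)
    image_out = encoder.encode_image(images)
    text_out = encoder.encode_text(["a chair", "a wooden table"])
    # both are CLIPEmbedOutput instances; .embeds is the projected feature, e.g. [2, 512] for the base model
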
/MeshAnything/miche/michelangelo/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .checkpoint import checkpoint
4 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/modules/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4 | """
5 |
6 | import torch
7 | from typing import Callable, Iterable, Sequence, Union
8 |
9 |
10 | def checkpoint(
11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12 | inputs: Sequence[torch.Tensor],
13 | params: Iterable[torch.Tensor],
14 | flag: bool,
15 | use_deepspeed: bool = False
16 | ):
17 | """
18 | Evaluate a function without caching intermediate activations, allowing for
19 | reduced memory at the expense of extra compute in the backward pass.
20 | :param func: the function to evaluate.
21 | :param inputs: the argument sequence to pass to `func`.
22 | :param params: a sequence of parameters `func` depends on but does not
23 | explicitly take as arguments.
24 | :param flag: if False, disable gradient checkpointing.
25 | :param use_deepspeed: if True, use deepspeed
26 | """
27 | if flag:
28 | if use_deepspeed:
29 | import deepspeed
30 | return deepspeed.checkpointing.checkpoint(func, *inputs)
31 |
32 | args = tuple(inputs) + tuple(params)
33 | return CheckpointFunction.apply(func, len(inputs), *args)
34 | else:
35 | return func(*inputs)
36 |
37 |
38 | class CheckpointFunction(torch.autograd.Function):
39 | @staticmethod
40 | @torch.cuda.amp.custom_fwd
41 | def forward(ctx, run_function, length, *args):
42 | ctx.run_function = run_function
43 | ctx.input_tensors = list(args[:length])
44 | ctx.input_params = list(args[length:])
45 |
46 | with torch.no_grad():
47 | output_tensors = ctx.run_function(*ctx.input_tensors)
48 | return output_tensors
49 |
50 | @staticmethod
51 | @torch.cuda.amp.custom_bwd
52 | def backward(ctx, *output_grads):
53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
54 | with torch.enable_grad():
55 | # Fixes a bug where the first op in run_function modifies the
56 | # Tensor storage in place, which is not allowed for detach()'d
57 | # Tensors.
58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
59 | output_tensors = ctx.run_function(*shallow_copies)
60 | input_grads = torch.autograd.grad(
61 | output_tensors,
62 | ctx.input_tensors + ctx.input_params,
63 | output_grads,
64 | allow_unused=True,
65 | )
66 | del ctx.input_tensors
67 | del ctx.input_params
68 | del output_tensors
69 | return (None, None) + input_grads
70 |
--------------------------------------------------------------------------------
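Note: a small usage sketch for checkpoint(); with flag=True the wrapped function is re-run during the backward pass instead of storing activations (the module here is illustrative):

    import torch
    import torch.nn as nn
    from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint

    layer = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
    x = torch.randn(4, 16, requires_grad=True)

    y = checkpoint(layer, (x,), layer.parameters(), flag=True)
    y.sum().backward()      # gradients flow to x and to layer's parameters
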
/MeshAnything/miche/michelangelo/models/modules/diffusion_transformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from typing import Optional
7 |
8 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint
9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import (
10 | init_linear,
11 | MLP,
12 | MultiheadCrossAttention,
13 | MultiheadAttention,
14 | ResidualAttentionBlock
15 | )
16 |
17 |
18 | class AdaLayerNorm(nn.Module):
19 | def __init__(self,
20 | device: torch.device,
21 | dtype: torch.dtype,
22 | width: int):
23 |
24 | super().__init__()
25 |
26 | self.silu = nn.SiLU(inplace=True)
27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
29 |
30 | def forward(self, x, timestep):
31 | emb = self.linear(timestep)
32 | scale, shift = torch.chunk(emb, 2, dim=2)
33 | x = self.layernorm(x) * (1 + scale) + shift
34 | return x
35 |
36 |
37 | class DitBlock(nn.Module):
38 | def __init__(
39 | self,
40 | *,
41 | device: torch.device,
42 | dtype: torch.dtype,
43 | n_ctx: int,
44 | width: int,
45 | heads: int,
46 | context_dim: int,
47 | qkv_bias: bool = False,
48 | init_scale: float = 1.0,
49 | use_checkpoint: bool = False
50 | ):
51 | super().__init__()
52 |
53 | self.use_checkpoint = use_checkpoint
54 |
55 | self.attn = MultiheadAttention(
56 | device=device,
57 | dtype=dtype,
58 | n_ctx=n_ctx,
59 | width=width,
60 | heads=heads,
61 | init_scale=init_scale,
62 | qkv_bias=qkv_bias
63 | )
64 | self.ln_1 = AdaLayerNorm(device, dtype, width)
65 |
66 | if context_dim is not None:
67 | self.ln_2 = AdaLayerNorm(device, dtype, width)
68 | self.cross_attn = MultiheadCrossAttention(
69 | device=device,
70 | dtype=dtype,
71 | width=width,
72 | heads=heads,
73 | data_width=context_dim,
74 | init_scale=init_scale,
75 | qkv_bias=qkv_bias
76 | )
77 |
78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
79 | self.ln_3 = AdaLayerNorm(device, dtype, width)
80 |
81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
83 |
84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
85 | x = x + self.attn(self.ln_1(x, t))
86 | if context is not None:
87 | x = x + self.cross_attn(self.ln_2(x, t), context)
88 | x = x + self.mlp(self.ln_3(x, t))
89 | return x
90 |
91 |
92 | class DiT(nn.Module):
93 | def __init__(
94 | self,
95 | *,
96 | device: Optional[torch.device],
97 | dtype: Optional[torch.dtype],
98 | n_ctx: int,
99 | width: int,
100 | layers: int,
101 | heads: int,
102 | context_dim: int,
103 | init_scale: float = 0.25,
104 | qkv_bias: bool = False,
105 | use_checkpoint: bool = False
106 | ):
107 | super().__init__()
108 | self.n_ctx = n_ctx
109 | self.width = width
110 | self.layers = layers
111 |
112 | self.resblocks = nn.ModuleList(
113 | [
114 | DitBlock(
115 | device=device,
116 | dtype=dtype,
117 | n_ctx=n_ctx,
118 | width=width,
119 | heads=heads,
120 | context_dim=context_dim,
121 | qkv_bias=qkv_bias,
122 | init_scale=init_scale,
123 | use_checkpoint=use_checkpoint
124 | )
125 | for _ in range(layers)
126 | ]
127 | )
128 |
129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
130 | for block in self.resblocks:
131 | x = block(x, t, context)
132 | return x
133 |
134 |
135 | class UNetDiffusionTransformer(nn.Module):
136 | def __init__(
137 | self,
138 | *,
139 | device: Optional[torch.device],
140 | dtype: Optional[torch.dtype],
141 | n_ctx: int,
142 | width: int,
143 | layers: int,
144 | heads: int,
145 | init_scale: float = 0.25,
146 | qkv_bias: bool = False,
147 | skip_ln: bool = False,
148 | use_checkpoint: bool = False
149 | ):
150 | super().__init__()
151 |
152 | self.n_ctx = n_ctx
153 | self.width = width
154 | self.layers = layers
155 |
156 | self.encoder = nn.ModuleList()
157 | for _ in range(layers):
158 | resblock = ResidualAttentionBlock(
159 | device=device,
160 | dtype=dtype,
161 | n_ctx=n_ctx,
162 | width=width,
163 | heads=heads,
164 | init_scale=init_scale,
165 | qkv_bias=qkv_bias,
166 | use_checkpoint=use_checkpoint
167 | )
168 | self.encoder.append(resblock)
169 |
170 | self.middle_block = ResidualAttentionBlock(
171 | device=device,
172 | dtype=dtype,
173 | n_ctx=n_ctx,
174 | width=width,
175 | heads=heads,
176 | init_scale=init_scale,
177 | qkv_bias=qkv_bias,
178 | use_checkpoint=use_checkpoint
179 | )
180 |
181 | self.decoder = nn.ModuleList()
182 | for _ in range(layers):
183 | resblock = ResidualAttentionBlock(
184 | device=device,
185 | dtype=dtype,
186 | n_ctx=n_ctx,
187 | width=width,
188 | heads=heads,
189 | init_scale=init_scale,
190 | qkv_bias=qkv_bias,
191 | use_checkpoint=use_checkpoint
192 | )
193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
194 | init_linear(linear, init_scale)
195 |
196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
197 |
198 | self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
199 |
200 | def forward(self, x: torch.Tensor):
201 |
202 | enc_outputs = []
203 | for block in self.encoder:
204 | x = block(x)
205 | enc_outputs.append(x)
206 |
207 | x = self.middle_block(x)
208 |
209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
210 | x = torch.cat([enc_outputs.pop(), x], dim=-1)
211 | x = linear(x)
212 |
213 | if layer_norm is not None:
214 | x = layer_norm(x)
215 |
216 | x = resblock(x)
217 |
218 | return x
219 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/modules/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Union, List
4 |
5 |
6 | class AbstractDistribution(object):
7 | def sample(self):
8 | raise NotImplementedError()
9 |
10 | def mode(self):
11 | raise NotImplementedError()
12 |
13 |
14 | class DiracDistribution(AbstractDistribution):
15 | def __init__(self, value):
16 | self.value = value
17 |
18 | def sample(self):
19 | return self.value
20 |
21 | def mode(self):
22 | return self.value
23 |
24 |
25 | class DiagonalGaussianDistribution(object):
26 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
27 | self.feat_dim = feat_dim
28 | self.parameters = parameters
29 |
30 | if isinstance(parameters, list):
31 | self.mean = parameters[0]
32 | self.logvar = parameters[1]
33 | else:
34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
35 |
36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
37 | self.deterministic = deterministic
38 | self.std = torch.exp(0.5 * self.logvar)
39 | self.var = torch.exp(self.logvar)
40 | if self.deterministic:
41 | self.var = self.std = torch.zeros_like(self.mean)
42 |
43 | def sample(self):
44 | x = self.mean + self.std * torch.randn_like(self.mean)
45 | return x
46 |
47 | def kl(self, other=None, dims=(1, 2, 3)):
48 | if self.deterministic:
49 | return torch.Tensor([0.])
50 | else:
51 | if other is None:
52 | return 0.5 * torch.mean(torch.pow(self.mean, 2)
53 | + self.var - 1.0 - self.logvar,
54 | dim=dims)
55 | else:
56 | return 0.5 * torch.mean(
57 | torch.pow(self.mean - other.mean, 2) / other.var
58 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
59 | dim=dims)
60 |
61 | def nll(self, sample, dims=(1, 2, 3)):
62 | if self.deterministic:
63 | return torch.Tensor([0.])
64 | logtwopi = np.log(2.0 * np.pi)
65 | return 0.5 * torch.sum(
66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
67 | dim=dims)
68 |
69 | def mode(self):
70 | return self.mean
71 |
72 |
73 | def normal_kl(mean1, logvar1, mean2, logvar2):
74 | """
75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
76 | Compute the KL divergence between two gaussians.
77 | Shapes are automatically broadcasted, so batches can be compared to
78 | scalars, among other use cases.
79 | """
80 | tensor = None
81 | for obj in (mean1, logvar1, mean2, logvar2):
82 | if isinstance(obj, torch.Tensor):
83 | tensor = obj
84 | break
85 | assert tensor is not None, "at least one argument must be a Tensor"
86 |
87 | # Force variances to be Tensors. Broadcasting helps convert scalars to
88 | # Tensors, but it does not work for torch.exp().
89 | logvar1, logvar2 = [
90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
91 | for x in (logvar1, logvar2)
92 | ]
93 |
94 | return 0.5 * (
95 | -1.0
96 | + logvar2
97 | - logvar1
98 | + torch.exp(logvar1 - logvar2)
99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
100 | )
101 |
--------------------------------------------------------------------------------
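Note: DiagonalGaussianDistribution expects the mean and log-variance concatenated along feat_dim; a small sketch (sizes are illustrative):

    import torch
    from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution

    params = torch.randn(4, 2 * 8, 256)                  # [bs, 2*c, n]: mean and logvar stacked on dim 1
    posterior = DiagonalGaussianDistribution(params, feat_dim=1)
    z = posterior.sample()                               # [4, 8, 256], reparameterised sample
    kl = posterior.kl(dims=(1, 2))                       # [4], KL to a standard normal per batch element
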
/MeshAnything/miche/michelangelo/models/modules/embedder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import math
7 |
8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
9 |
10 |
11 | class FourierEmbedder(nn.Module):
12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
13 | each feature dimension of `x[..., i]` into:
14 | [
15 | sin(x[..., i]),
16 | sin(f_1*x[..., i]),
17 | sin(f_2*x[..., i]),
18 | ...
19 | sin(f_N * x[..., i]),
20 | cos(x[..., i]),
21 | cos(f_1*x[..., i]),
22 | cos(f_2*x[..., i]),
23 | ...
24 | cos(f_N * x[..., i]),
25 | x[..., i] # only present if include_input is True.
26 | ], here f_i is the frequency.
27 |
28 |     The frequencies f_1, ..., f_N are determined by `num_freqs` and `logspace`:
29 |     if logspace is True, the frequencies are the powers of two [2^0, 2^1, ..., 2^(num_freqs - 1)];
30 |     otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1).
31 |
32 | Args:
33 | num_freqs (int): the number of frequencies, default is 6;
34 |         logspace (bool): if True, the frequencies are the powers of two [2^0, ..., 2^(num_freqs - 1)];
35 |             otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1);
36 | input_dim (int): the input dimension, default is 3;
37 | include_input (bool): include the input tensor or not, default is True.
38 |
39 | Attributes:
40 |         frequencies (torch.Tensor): the powers of two [2^0, ..., 2^(num_freqs - 1)] if logspace is True,
41 |             otherwise linearly spaced between 1.0 and 2^(num_freqs - 1);
42 |
43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
44 | otherwise, it is input_dim * num_freqs * 2.
45 |
46 | """
47 |
48 | def __init__(self,
49 | num_freqs: int = 6,
50 | logspace: bool = True,
51 | input_dim: int = 3,
52 | include_input: bool = True,
53 | include_pi: bool = True) -> None:
54 |
55 | """The initialization"""
56 |
57 | super().__init__()
58 |
59 | if logspace:
60 | frequencies = 2.0 ** torch.arange(
61 | num_freqs,
62 | dtype=torch.float32
63 | )
64 | else:
65 | frequencies = torch.linspace(
66 | 1.0,
67 | 2.0 ** (num_freqs - 1),
68 | num_freqs,
69 | dtype=torch.float32
70 | )
71 |
72 | if include_pi:
73 | frequencies *= torch.pi
74 |
75 | self.register_buffer("frequencies", frequencies, persistent=False)
76 | self.include_input = include_input
77 | self.num_freqs = num_freqs
78 |
79 | self.out_dim = self.get_dims(input_dim)
80 |
81 | def get_dims(self, input_dim):
82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0
83 | out_dim = input_dim * (self.num_freqs * 2 + temp)
84 |
85 | return out_dim
86 |
87 | def forward(self, x: torch.Tensor) -> torch.Tensor:
88 | """ Forward process.
89 |
90 | Args:
91 | x: tensor of shape [..., dim]
92 |
93 | Returns:
94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
95 | where temp is 1 if include_input is True and 0 otherwise.
96 | """
97 |
98 | if self.num_freqs > 0:
99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
100 | if self.include_input:
101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
102 | else:
103 | return torch.cat((embed.sin(), embed.cos()), dim=-1)
104 | else:
105 | return x
106 |
107 |
108 | class LearnedFourierEmbedder(nn.Module):
109 |     """ following @crowsonkb's lead with learned sinusoidal pos emb """
110 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
111 |
112 | def __init__(self, in_channels, dim):
113 | super().__init__()
114 | assert (dim % 2) == 0
115 | half_dim = dim // 2
116 | per_channel_dim = half_dim // in_channels
117 | self.weights = nn.Parameter(torch.randn(per_channel_dim))
118 |
119 | def forward(self, x):
120 | """
121 |
122 | Args:
123 | x (torch.FloatTensor): [..., c]
124 |
125 | Returns:
126 | x (torch.FloatTensor): [..., d]
127 | """
128 |
129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
132 | return fouriered
133 |
134 |
135 | class TriplaneLearnedFourierEmbedder(nn.Module):
136 | def __init__(self, in_channels, dim):
137 | super().__init__()
138 |
139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
142 |
143 | self.out_dim = in_channels + dim
144 |
145 | def forward(self, x):
146 |
147 | yz_embed = self.yz_plane_embedder(x)
148 | xz_embed = self.xz_plane_embedder(x)
149 | xy_embed = self.xy_plane_embedder(x)
150 |
151 | embed = yz_embed + xz_embed + xy_embed
152 |
153 | return embed
154 |
155 |
156 | def sequential_pos_embed(num_len, embed_dim):
157 | assert embed_dim % 2 == 0
158 |
159 | pos = torch.arange(num_len, dtype=torch.float32)
160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32)
161 | omega /= embed_dim / 2.
162 | omega = 1. / 10000 ** omega # (D/2,)
163 |
164 | pos = pos.reshape(-1) # (M,)
165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
166 |
167 | emb_sin = torch.sin(out) # (M, D/2)
168 | emb_cos = torch.cos(out) # (M, D/2)
169 |
170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
171 |
172 | return embeddings
173 |
174 |
175 | def timestep_embedding(timesteps, dim, max_period=10000):
176 | """
177 | Create sinusoidal timestep embeddings.
178 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
179 | These may be fractional.
180 | :param dim: the dimension of the output.
181 | :param max_period: controls the minimum frequency of the embeddings.
182 | :return: an [N x dim] Tensor of positional embeddings.
183 | """
184 | half = dim // 2
185 | freqs = torch.exp(
186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
187 | ).to(device=timesteps.device)
188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190 | if dim % 2:
191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
192 | return embedding
193 |
194 |
195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
197 | log2_hashmap_size=19, desired_resolution=None):
198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
199 | return nn.Identity(), input_dim
200 |
201 | elif embed_type == "fourier":
202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
203 | logspace=True, include_input=True)
204 | return embedder_obj, embedder_obj.out_dim
205 |
206 | elif embed_type == "hashgrid":
207 | raise NotImplementedError
208 |
209 | elif embed_type == "sphere_harmonic":
210 | raise NotImplementedError
211 |
212 | else:
213 |         raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
214 |
--------------------------------------------------------------------------------
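Note: a quick sketch of the Fourier embedder path in get_embedder (the input shapes are illustrative):

    import torch
    from MeshAnything.miche.michelangelo.models.modules.embedder import get_embedder

    embedder, out_dim = get_embedder(embed_type="fourier", num_freqs=8, input_dim=3)
    x = torch.rand(2, 1024, 3)
    feat = embedder(x)
    # include_input=True, so out_dim == 3 * (2 * 8 + 1) == 51 and feat.shape == (2, 1024, 51)
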
/MeshAnything/miche/michelangelo/models/tsal/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/tsal/clip_asl_module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from torch import nn
5 | from einops import rearrange
6 | from transformers import CLIPModel
7 |
8 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule
9 |
10 |
11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
12 |
13 | def __init__(self, *,
14 | shape_model,
15 | clip_model_version: str = "openai/clip-vit-large-patch14"):
16 |
17 | super().__init__()
18 |
19 | # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version)
20 | # for params in self.clip_model.parameters():
21 | # params.requires_grad = False
22 | self.clip_model = None
23 | self.shape_model = shape_model
24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width))
25 | # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5)
26 |
27 | def set_shape_model_only(self):
28 | self.clip_model = None
29 |
30 | def encode_shape_embed(self, surface, return_latents: bool = False):
31 | """
32 |
33 | Args:
34 | surface (torch.FloatTensor): [bs, n, 3 + c]
35 | return_latents (bool):
36 |
37 | Returns:
38 | x (torch.FloatTensor): [bs, projection_dim]
39 | shape_latents (torch.FloatTensor): [bs, m, d]
40 | """
41 |
42 | pc = surface[..., 0:3]
43 | feats = surface[..., 3:]
44 |
45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
46 | x = shape_embed @ self.shape_projection
47 |
48 | if return_latents:
49 | return x, shape_latents
50 | else:
51 | return x
52 |
53 | def encode_image_embed(self, image):
54 | """
55 |
56 | Args:
57 | image (torch.FloatTensor): [bs, 3, h, w]
58 |
59 | Returns:
60 | x (torch.FloatTensor): [bs, projection_dim]
61 | """
62 |
63 | x = self.clip_model.get_image_features(image)
64 |
65 | return x
66 |
67 | def encode_text_embed(self, text):
68 | x = self.clip_model.get_text_features(text)
69 | return x
70 |
71 | def forward(self, surface, image, text):
72 | """
73 |
74 | Args:
75 | surface (torch.FloatTensor):
76 | image (torch.FloatTensor): [bs, 3, 224, 224]
77 | text (torch.LongTensor): [bs, num_templates, 77]
78 |
79 | Returns:
80 | embed_outputs (dict): the embedding outputs, and it contains:
81 | - image_embed (torch.FloatTensor):
82 | - text_embed (torch.FloatTensor):
83 | - shape_embed (torch.FloatTensor):
84 | - logit_scale (float):
85 | """
86 |
87 | # # text embedding
88 | # text_embed_all = []
89 | # for i in range(text.shape[0]):
90 | # text_for_one_sample = text[i]
91 | # text_embed = self.encode_text_embed(text_for_one_sample)
92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
93 | # text_embed = text_embed.mean(dim=0)
94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
95 | # text_embed_all.append(text_embed)
96 | # text_embed_all = torch.stack(text_embed_all)
97 |
98 | b = text.shape[0]
99 | text_tokens = rearrange(text, "b t l -> (b t) l")
100 | text_embed = self.encode_text_embed(text_tokens)
101 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
102 | text_embed = text_embed.mean(dim=1)
103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
104 |
105 | # image embedding
106 | image_embed = self.encode_image_embed(image)
107 |
108 | # shape embedding
109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
110 |
111 | embed_outputs = {
112 | "image_embed": image_embed,
113 | "text_embed": text_embed,
114 | "shape_embed": shape_embed,
115 | # "logit_scale": self.clip_model.logit_scale.exp()
116 | }
117 |
118 | return embed_outputs, shape_latents
119 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/tsal/inference_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from einops import repeat
6 | import numpy as np
7 | from typing import Callable, Tuple, List, Union, Optional
8 | from skimage import measure
9 |
10 | from MeshAnything.miche.michelangelo.graphics.primitives import generate_dense_grid_points
11 |
12 |
13 | @torch.no_grad()
14 | def extract_geometry(geometric_func: Callable,
15 | device: torch.device,
16 | batch_size: int = 1,
17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
18 | octree_depth: int = 7,
19 | num_chunks: int = 10000,
20 | disable: bool = True):
21 | """
22 |
23 | Args:
24 | geometric_func:
25 | device:
26 | bounds:
27 | octree_depth:
28 | batch_size:
29 | num_chunks:
30 | disable:
31 |
32 | Returns:
33 |
34 | """
35 |
36 | if isinstance(bounds, float):
37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
38 |
39 | bbox_min = np.array(bounds[0:3])
40 | bbox_max = np.array(bounds[3:6])
41 | bbox_size = bbox_max - bbox_min
42 |
43 | xyz_samples, grid_size, length = generate_dense_grid_points(
44 | bbox_min=bbox_min,
45 | bbox_max=bbox_max,
46 | octree_depth=octree_depth,
47 | indexing="ij"
48 | )
49 | xyz_samples = torch.FloatTensor(xyz_samples)
50 |
51 | batch_logits = []
52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
53 | desc="Implicit Function:", disable=disable, leave=False):
54 | queries = xyz_samples[start: start + num_chunks, :].to(device)
55 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
56 |
57 | logits = geometric_func(batch_queries)
58 | batch_logits.append(logits.cpu())
59 |
60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()
61 |
62 | mesh_v_f = []
63 | has_surface = np.zeros((batch_size,), dtype=np.bool_)
64 | for i in range(batch_size):
65 | try:
66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
67 | vertices = vertices / grid_size * bbox_size + bbox_min
68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]]
69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
70 | has_surface[i] = True
71 |
72 | except ValueError:
73 | mesh_v_f.append((None, None))
74 | has_surface[i] = False
75 |
76 | except RuntimeError:
77 | mesh_v_f.append((None, None))
78 | has_surface[i] = False
79 |
80 | return mesh_v_f, has_surface
81 |
--------------------------------------------------------------------------------
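Note: extract_geometry only needs a callable mapping [bs, n, 3] query points to occupancy logits; a self-contained sketch using an analytic unit sphere in place of the shape model:

    import torch
    from MeshAnything.miche.michelangelo.models.tsal.inference_utils import extract_geometry

    def sphere_logits(queries):                 # queries: [bs, n, 3]
        return 1.0 - queries.norm(dim=-1)       # positive inside the unit sphere

    mesh_v_f, has_surface = extract_geometry(sphere_logits, device=torch.device("cpu"),
                                             batch_size=1, octree_depth=6)
    vertices, faces = mesh_v_f[0]               # numpy arrays when has_surface[0] is True
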
/MeshAnything/miche/michelangelo/models/tsal/tsal_base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch.nn as nn
4 | from typing import Tuple, List, Optional
5 |
6 |
7 | class Point2MeshOutput(object):
8 | def __init__(self):
9 | self.mesh_v = None
10 | self.mesh_f = None
11 | self.center = None
12 | self.pc = None
13 |
14 |
15 | class Latent2MeshOutput(object):
16 |
17 | def __init__(self):
18 | self.mesh_v = None
19 | self.mesh_f = None
20 |
21 |
22 | class AlignedMeshOutput(object):
23 |
24 | def __init__(self):
25 | self.mesh_v = None
26 | self.mesh_f = None
27 | self.surface = None
28 | self.image = None
29 | self.text: Optional[str] = None
30 | self.shape_text_similarity: Optional[float] = None
31 | self.shape_image_similarity: Optional[float] = None
32 |
33 |
34 | class ShapeAsLatentPLModule(nn.Module):
35 | latent_shape: Tuple[int]
36 |
37 | def encode(self, surface, *args, **kwargs):
38 | raise NotImplementedError
39 |
40 | def decode(self, z_q, *args, **kwargs):
41 | raise NotImplementedError
42 |
43 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
44 | raise NotImplementedError
45 |
46 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
47 | raise NotImplementedError
48 |
49 |
50 | class ShapeAsLatentModule(nn.Module):
51 | latent_shape: Tuple[int, int]
52 |
53 | def __init__(self, *args, **kwargs):
54 | super().__init__()
55 |
56 | def encode(self, *args, **kwargs):
57 | raise NotImplementedError
58 |
59 | def decode(self, *args, **kwargs):
60 | raise NotImplementedError
61 |
62 | def query_geometry(self, *args, **kwargs):
63 | raise NotImplementedError
64 |
65 |
66 | class AlignedShapeAsLatentPLModule(nn.Module):
67 | latent_shape: Tuple[int]
68 |
69 | def set_shape_model_only(self):
70 | raise NotImplementedError
71 |
72 | def encode(self, surface, *args, **kwargs):
73 | raise NotImplementedError
74 |
75 | def decode(self, z_q, *args, **kwargs):
76 | raise NotImplementedError
77 |
78 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
79 | raise NotImplementedError
80 |
81 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
82 | raise NotImplementedError
83 |
84 |
85 | class AlignedShapeAsLatentModule(nn.Module):
86 | shape_model: ShapeAsLatentModule
87 | latent_shape: Tuple[int, int]
88 |
89 | def __init__(self, *args, **kwargs):
90 | super().__init__()
91 |
92 | def set_shape_model_only(self):
93 | raise NotImplementedError
94 |
95 | def encode_image_embed(self, *args, **kwargs):
96 | raise NotImplementedError
97 |
98 | def encode_text_embed(self, *args, **kwargs):
99 | raise NotImplementedError
100 |
101 | def encode_shape_embed(self, *args, **kwargs):
102 | raise NotImplementedError
103 |
104 |
105 | class TexturedShapeAsLatentModule(nn.Module):
106 |
107 | def __init__(self, *args, **kwargs):
108 | super().__init__()
109 |
110 | def encode(self, *args, **kwargs):
111 | raise NotImplementedError
112 |
113 | def decode(self, *args, **kwargs):
114 | raise NotImplementedError
115 |
116 | def query_geometry(self, *args, **kwargs):
117 | raise NotImplementedError
118 |
119 | def query_color(self, *args, **kwargs):
120 | raise NotImplementedError
121 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .misc import instantiate_from_config
4 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/utils/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 |
6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
7 |
8 | mse = torch.mean((x - y) ** 2)
9 | psnr = 10 * torch.log10(data_range / (mse + eps))
10 |
11 | return psnr
12 |
13 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/utils/io.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import io
5 | import tarfile
6 | import json
7 | import numpy as np
8 | import numpy.lib.format
9 |
10 |
11 | def mkdir(path):
12 | os.makedirs(path, exist_ok=True)
13 | return path
14 |
15 |
16 | def npy_loads(data):
17 | stream = io.BytesIO(data)
18 | return np.lib.format.read_array(stream)
19 |
20 |
21 | def npz_loads(data):
22 | return np.load(io.BytesIO(data))
23 |
24 |
25 | def json_loads(data):
26 | return json.loads(data)
27 |
28 |
29 | def load_json(filepath):
30 | with open(filepath, "r") as f:
31 | data = json.load(f)
32 | return data
33 |
34 |
35 | def write_json(filepath, data):
36 | with open(filepath, "w") as f:
37 | json.dump(data, f, indent=2)
38 |
39 |
40 | def extract_tar(tar_path, tar_cache_folder):
41 |
42 | with tarfile.open(tar_path, "r") as tar:
43 | tar.extractall(path=tar_cache_folder)
44 |
45 | tar_uids = sorted(os.listdir(tar_cache_folder))
46 | print(f"extract tar: {tar_path} to {tar_cache_folder}")
47 | return tar_uids
48 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/utils/misc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import importlib
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 |
9 |
10 | def get_obj_from_str(string, reload=False):
11 | module, cls = string.rsplit(".", 1)
12 | if reload:
13 | module_imp = importlib.import_module(module)
14 | importlib.reload(module_imp)
15 | return getattr(importlib.import_module(module, package=None), cls)
16 |
17 |
18 | def get_obj_from_config(config):
19 | if "target" not in config:
20 | raise KeyError("Expected key `target` to instantiate.")
21 |
22 | return get_obj_from_str(config["target"])
23 |
24 |
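# Config dicts passed to instantiate_from_config follow the layout used by the
# YAML files in this repo: {"target": "package.module.ClassName", "params": {...}}.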
25 | def instantiate_from_config(config, **kwargs):
26 | if "target" not in config:
27 | raise KeyError("Expected key `target` to instantiate.")
28 |
29 | cls = get_obj_from_str(config["target"])
30 |
31 | params = config.get("params", dict())
32 | # params.update(kwargs)
33 | # instance = cls(**params)
34 | kwargs.update(params)
35 | instance = cls(**kwargs)
36 |
37 | return instance
38 |
39 |
40 | def is_dist_avail_and_initialized():
41 | if not dist.is_available():
42 | return False
43 | if not dist.is_initialized():
44 | return False
45 | return True
46 |
47 |
48 | def get_rank():
49 | if not is_dist_avail_and_initialized():
50 | return 0
51 | return dist.get_rank()
52 |
53 |
54 | def get_world_size():
55 | if not is_dist_avail_and_initialized():
56 | return 1
57 | return dist.get_world_size()
58 |
59 |
60 | def all_gather_batch(tensors):
61 | """
62 | Performs all_gather operation on the provided tensors.
63 | """
64 | # Queue the gathered tensors
65 | world_size = get_world_size()
66 | # There is no need for reduction in the single-proc case
67 | if world_size == 1:
68 | return tensors
69 | tensor_list = []
70 | output_tensor = []
71 | for tensor in tensors:
72 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
73 | dist.all_gather(
74 | tensor_all,
75 | tensor,
76 | async_op=False # performance opt
77 | )
78 |
79 | tensor_list.append(tensor_all)
80 |
81 | for tensor_all in tensor_list:
82 | output_tensor.append(torch.cat(tensor_all, dim=0))
83 | return output_tensor
84 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/utils/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/utils/visualizers/color_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | # Helper functions
6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
7 | colormap = plt.cm.get_cmap(colormap)
8 | if normalize:
9 | vmin = np.min(inp)
10 | vmax = np.max(inp)
11 |
12 | norm = plt.Normalize(vmin, vmax)
13 | return colormap(norm(inp))[:, :3]
14 |
15 |
16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
17 | # tex dims need to be power of two.
18 | array = np.ones((width, height, 3), dtype='float32')
19 |
20 | # width in texels of each checker
21 | checker_w = width / n_checkers_x
22 | checker_h = height / n_checkers_y
23 |
24 | for y in range(height):
25 | for x in range(width):
26 | color_key = int(x / checker_w) + int(y / checker_h)
27 | if color_key % 2 == 0:
28 | array[x, y, :] = [1., 0.874, 0.0]
29 | else:
30 | array[x, y, :] = [0., 0., 0.]
31 | return array
32 |
33 |
34 | def gen_circle(width=256, height=256):
35 | xx, yy = np.mgrid[:width, :height]
36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
37 | array = np.ones((width, height, 4), dtype='float32')
38 | array[:, :, 0] = (circle <= width)
39 | array[:, :, 1] = (circle <= width)
40 | array[:, :, 2] = (circle <= width)
41 | array[:, :, 3] = circle <= width
42 | return array
43 |
44 |
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/utils/visualizers/html_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | import base64
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def to_html_frame(content):
9 |
10 | html_frame = f"""
11 | <html>
12 | <body>
13 | {content}
14 | </body>
15 | </html>
16 | """
17 |
18 | return html_frame
19 |
20 |
21 | def to_single_row_table(caption: str, content: str):
22 |
23 | table_html = f"""
24 | <table border="1">
25 | <caption>{caption}</caption>
26 | <tr>
27 | <td>{content}</td>
28 | </tr>
29 | </table>
30 | """
31 |
32 | return table_html
33 |
34 |
35 | def to_image_embed_tag(image: np.ndarray):
36 |
37 | # Convert np.ndarray to bytes
38 | img = Image.fromarray(image)
39 | raw_bytes = io.BytesIO()
40 | img.save(raw_bytes, "PNG")
41 |
42 | # Encode bytes to base64
43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
44 |
45 | image_tag = f"""
46 | <img src="data:image/png;base64,{image_base64}" alt="embedded image">
47 | """
48 |
49 | return image_tag
50 |
--------------------------------------------------------------------------------
/MeshAnything/miche/shapevae-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: MeshAnything.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
3 | params:
4 | shape_module_cfg:
5 | target: MeshAnything.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
6 | params:
7 | num_latents: 256
8 | embed_dim: 64
9 | point_feats: 3 # normal
10 | num_freqs: 8
11 | include_pi: false
12 | heads: 12
13 | width: 768
14 | num_encoder_layers: 8
15 | num_decoder_layers: 16
16 | use_ln_post: true
17 | init_scale: 0.25
18 | qkv_bias: false
19 | use_checkpoint: true
20 | aligned_module_cfg:
21 | target: MeshAnything.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
22 | params:
23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
24 |
25 | loss_cfg:
26 | target: MeshAnything.miche.michelangelo.models.tsal.loss.ContrastKLNearFar
27 | params:
28 | contrast_weight: 0.1
29 | near_weight: 0.1
30 | kl_weight: 0.001
31 |
32 | optimizer_cfg:
33 | optimizer:
34 | target: torch.optim.AdamW
35 | params:
36 | betas: [0.9, 0.99]
37 | eps: 1.e-6
38 | weight_decay: 1.e-2
39 |
40 | scheduler:
41 | target: MeshAnything.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
42 | params:
43 | warm_up_steps: 5000
44 | f_start: 1.e-6
45 | f_min: 1.e-3
46 | f_max: 1.0
47 |
--------------------------------------------------------------------------------
/MeshAnything/models/meshanything_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers import AutoModelForCausalLM
4 | from MeshAnything.miche.encode import load_model
5 | from MeshAnything.models.shape_opt import ShapeOPTConfig
6 | from einops import rearrange
7 |
8 | from huggingface_hub import PyTorchModelHubMixin
9 |
10 | class MeshAnythingV2(nn.Module, PyTorchModelHubMixin,
11 | repo_url="https://github.com/buaacyw/MeshAnythingV2", pipeline_tag="image-to-3d", license="mit"):
12 | def __init__(self, config={}):
13 | super().__init__()
14 | self.config = config
15 | self.point_encoder = load_model(ckpt_path=None)
16 | self.n_discrete_size = 128
17 | self.max_seq_ratio = 0.70
18 | self.face_per_token = 9
19 | self.cond_length = 257
20 | self.cond_dim = 768
21 | self.pad_id = -1
22 | self.n_max_triangles = 1600
23 | self.max_length = int(self.n_max_triangles * self.face_per_token * self.max_seq_ratio + 3 + self.cond_length) # add 1
24 |
25 | self.coor_continuous_range = (-0.5, 0.5)
26 |
27 | self.config = ShapeOPTConfig.from_pretrained(
28 | "facebook/opt-350m",
29 | n_positions=self.max_length,
30 | max_position_embeddings=self.max_length,
31 | vocab_size=self.n_discrete_size + 4,
32 | _attn_implementation="flash_attention_2"
33 | )
34 |
35 | self.bos_token_id = 0
36 | self.eos_token_id = 1
37 | self.pad_token_id = 2
38 |
39 | self.config.bos_token_id = self.bos_token_id
40 | self.config.eos_token_id = self.eos_token_id
41 | self.config.pad_token_id = self.pad_token_id
42 | self.config._attn_implementation="flash_attention_2"
43 | self.config.n_discrete_size = self.n_discrete_size
44 | self.config.face_per_token = self.face_per_token
45 | self.config.cond_length = self.cond_length
46 |
47 | if self.config.word_embed_proj_dim != self.config.hidden_size:
48 | self.config.word_embed_proj_dim = self.config.hidden_size
49 | self.transformer = AutoModelForCausalLM.from_config(
50 | config=self.config, use_flash_attention_2 = True
51 | )
52 | self.transformer.to_bettertransformer()
53 |
54 | self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
55 | self.cond_proj = nn.Linear(self.cond_dim * 2, self.config.word_embed_proj_dim)
56 |
57 | self.eval()
58 |
59 | def adjacent_detokenize(self, input_ids):
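# Inverse of the adjacent mesh tokenization: a token equal to n_discrete_size (128)
# starts a new face strip (it plays the role of the -1 separator in
# adjacent_mesh_tokenization.py); all other tokens are discretized vertex coordinates.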
60 | input_ids = input_ids.reshape(input_ids.shape[0], -1) # B x L
61 | batch_size = input_ids.shape[0]
62 | continuous_coors = torch.zeros((batch_size, self.n_max_triangles * 3 * 10, 3), device=input_ids.device)
63 | continuous_coors[...] = float('nan')
64 |
65 | for i in range(batch_size):
66 | cur_ids = input_ids[i]
67 | coor_loop_check = 0
68 | vertice_count = 0
69 | continuous_coors[i, :3, :] = torch.tensor([[-0.1, 0.0, 0.1], [-0.1, 0.1, 0.2], [-0.3, 0.3, 0.2]],
70 | device=input_ids.device)
71 | for id in cur_ids:
72 | if id == self.pad_id:
73 | break
74 | elif id == self.n_discrete_size:
75 | if coor_loop_check < 9:
76 | break
77 | if coor_loop_check % 3 !=0:
78 | break
79 | coor_loop_check = 0
80 | else:
81 |
82 | if coor_loop_check % 3 == 0 and coor_loop_check >= 9:
83 | continuous_coors[i, vertice_count] = continuous_coors[i, vertice_count-2]
84 | continuous_coors[i, vertice_count+1] = continuous_coors[i, vertice_count-1]
85 | vertice_count += 2
86 | continuous_coors[i, vertice_count, coor_loop_check % 3] = undiscretize(id, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size)
87 | if coor_loop_check % 3 == 2:
88 | vertice_count += 1
89 | coor_loop_check += 1
90 |
91 | continuous_coors = rearrange(continuous_coors, 'b (nf nv) c -> b nf nv c', nv=3, c=3)
92 |
93 | return continuous_coors # b, nf, 3, 3
94 |
95 |
96 | def forward(self, data_dict: dict, is_eval: bool = False) -> dict:
97 | if not is_eval:
98 | return self.train_one_step(data_dict)
99 | else:
100 | return self.generate(data_dict)
101 |
102 | def process_point_feature(self, point_feature):
103 | encode_feature = torch.zeros(point_feature.shape[0], self.cond_length, self.config.word_embed_proj_dim,
104 | device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype)
105 | encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0])
106 | shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:])
107 | encode_feature[:, 1:] = self.cond_proj(torch.cat([point_feature[:, 1:], shape_latents], dim=-1))
108 |
109 | return encode_feature
110 |
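# NOTE: this second `forward` definition overrides the data_dict-based one above,
# so only the point-cloud-conditioned inference path below is active in this file.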
111 | @torch.no_grad()
112 | def forward(self, pc_normal, sampling=False) -> dict:
113 | batch_size = pc_normal.shape[0]
114 | point_feature = self.point_encoder.encode_latents(pc_normal)
115 | processed_point_feature = self.process_point_feature(point_feature)
116 | generate_length = self.max_length - self.cond_length
117 | net_device = next(self.parameters()).device
118 | outputs = torch.ones(batch_size, generate_length).long().to(net_device) * self.eos_token_id
119 | # batch x ntokens
120 | if not sampling:
121 | results = self.transformer.generate(
122 | inputs_embeds=processed_point_feature,
123 | max_new_tokens=generate_length, # all faces plus two
124 | num_beams=1,
125 | bos_token_id=self.bos_token_id,
126 | eos_token_id=self.eos_token_id,
127 | pad_token_id=self.pad_token_id,
128 | )
129 | else:
130 | results = self.transformer.generate(
131 | inputs_embeds = processed_point_feature,
132 | max_new_tokens = generate_length, # all faces plus two
133 | do_sample=True,
134 | top_k=50,
135 | top_p=0.95,
136 | bos_token_id = self.bos_token_id,
137 | eos_token_id = self.eos_token_id,
138 | pad_token_id = self.pad_token_id,
139 | )
140 | assert results.shape[1] <= generate_length # B x ID bos is not included since it's predicted
141 | outputs[:, :results.shape[1]] = results
142 | # batch x ntokens ====> batch x ntokens x D
143 | outputs = outputs[:, 1: -1]
144 |
145 | outputs[outputs == self.bos_token_id] = self.pad_id
146 | outputs[outputs == self.eos_token_id] = self.pad_id
147 | outputs[outputs == self.pad_token_id] = self.pad_id
148 |
149 | outputs[outputs != self.pad_id] -= 3
150 | gen_mesh = self.adjacent_detokenize(outputs)
151 |
152 | return gen_mesh
153 |
154 | def undiscretize(
155 | t,
156 | low,#-0.5
157 | high,# 0.5
158 | num_discrete
159 | ):
160 | t = t.float() #[0, num_discrete-1]
161 |
162 | t /= num_discrete # 0<=t<1
163 | t = t * (high - low) + low # -0.5 <= t < 0.5
164 | return t
165 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # MeshAnything V2: Artist-Created Mesh Generation With Adjacent Mesh Tokenization
3 |
4 |
5 | Yiwen Chen1,2,
6 | Yikai Wang3*,
7 | Yihao Luo4,
8 | Zhengyi Wang2,3,
9 |
10 | Zilong Chen2,3,
11 | Jun Zhu2,3,
12 | Chi Zhang5*,
13 | Guosheng Lin1*
14 |
15 | *Corresponding authors.
16 |
17 | 1S-Lab, Nanyang Technological University,
18 | 2Shengshu,
19 |
20 | 3Tsinghua University,
21 | 4Imperial College London,
22 | 5Westlake University
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | ## Contents
43 | - [Contents](#contents)
44 | - [Installation](#installation)
45 | - [Usage](#usage)
46 | - [Training](#training)
47 | - [Important Notes](#important-notes)
48 | - [Acknowledgement](#acknowledgement)
49 | - [BibTeX](#bibtex)
50 |
51 | ## Installation
52 | Our environment has been tested on Ubuntu 22 with CUDA 11.8 on an A800 GPU.
53 | 1. Clone our repo and create the conda environment
54 | ```
55 | git clone https://github.com/buaacyw/MeshAnythingV2.git && cd MeshAnythingV2
56 | conda create -n MeshAnythingV2 python==3.10.13 -y
57 | conda activate MeshAnythingV2
58 | pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
59 | pip install -r requirements.txt
60 | pip install -r training_requirements.txt # in case you want to train
61 | pip install flash-attn --no-build-isolation
62 | pip install -U gradio
63 | ```
64 |
65 | ## Usage
66 |
67 | ### Implementation of Adjacent Mesh Tokenization and Detokenization
68 | ```
69 | # We release our adjacent mesh tokenization implementation in adjacent_mesh_tokenization.py.
70 | # For detokenization please check the function adjacent_detokenize in MeshAnything/models/meshanything_v2.py
71 | python adjacent_mesh_tokenization.py
72 | ```
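
A minimal sketch of calling the tokenizer directly from Python (run from the repo root; the normalization mirrors the `__main__` block of `adjacent_mesh_tokenization.py`, and `gt_examples/seals.ply` is the bundled example). `adjacent_mesh_tokenization` returns the sequence-length ratio relative to the naive 9-tokens-per-face tokenization.
```
import numpy as np
import trimesh
from adjacent_mesh_tokenization import mesh_sort, adjacent_mesh_tokenization

mesh = trimesh.load("gt_examples/seals.ply")
v, f = mesh.vertices, mesh.faces

# normalize into the [-0.5, 0.5] box expected by mesh_sort
bounds = np.array([v.min(axis=0), v.max(axis=0)])
v = (v - (bounds[0] + bounds[1])[None, :] / 2) / (bounds[1] - bounds[0]).max()
v = v.clip(-0.5, 0.5)

v, f = mesh_sort(v, f)
ratio = adjacent_mesh_tokenization(trimesh.Trimesh(vertices=v, faces=f))
print(f"sequence length ratio vs naive tokenization: {ratio:.3f}")
```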
73 |
74 |
75 | ### Text/image to Artist-Created Mesh: we suggest using [Rodin](https://hyperhuman.deemos.com/rodin) to first turn the text or image into a dense mesh, and then feeding that dense mesh to MeshAnything V2.
76 | ```
77 | # Put the .obj file produced by Rodin into rodin_result and use the following command to generate the Artist-Created Mesh.
78 | # We suggest using the --mc flag to preprocess the input mesh with Marching Cubes first. This helps align the inference point cloud with our training domain.
79 | python main.py --input_dir rodin_result --out_dir mesh_output --input_type mesh --mc
80 | ```
81 |
82 | ### Mesh Command line inference
83 | #### Important Notes: If your input mesh was not produced by Marching Cubes, we suggest preprocessing it with Marching Cubes first (simply add --mc).
84 | ```
85 | # folder input
86 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh
87 |
88 | # single file input
89 | python main.py --input_path examples/wand.obj --out_dir mesh_output --input_type mesh
90 |
91 | # Preprocess with Marching Cubes first
92 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh --mc
93 |
94 | # The Marching Cubes resolution defaults to 128. For some delicate meshes this is not sufficient; raising the resolution takes more preprocessing time but should give a better result.
95 | # Change it with: --mc_level 7 -> 128 (2^7), --mc_level 8 -> 256 (2^8).
96 | # 256 resolution Marching Cube example.
97 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh --mc --mc_level 8
98 | ```
99 |
100 | ### Point Cloud Command line inference
101 | ```
102 | # Note: if you want to use your own point cloud, please make sure normals are included.
103 | # The file format should be a .npy file with shape (N, 6), where N is the number of points. The first 3 columns are the coordinates and the last 3 columns are the normals.
104 |
105 | # inference for folder
106 | python main.py --input_dir pc_examples --out_dir pc_output --input_type pc_normal
107 |
108 | # inference for single file
109 | python main.py --input_path pc_examples/grenade.npy --out_dir pc_output --input_type pc_normal
110 | ```
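
For reference, here is a minimal sketch (assuming `trimesh` is installed; `examples/wand.obj` is one of the bundled example meshes) of writing such a .npy file. It mirrors the sampling done in `mesh_to_pc.py`.
```
import os
import numpy as np
import trimesh

mesh = trimesh.load("examples/wand.obj")
points, face_idx = mesh.sample(8192, return_index=True)  # main.py expects at least 8192 points
normals = mesh.face_normals[face_idx]                    # unit normals of the sampled faces

os.makedirs("pc_examples", exist_ok=True)
np.save("pc_examples/wand.npy", np.concatenate([points, normals], axis=-1))
```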
111 |
112 | ### Local Gradio Demo
113 | ```
114 | python app.py
115 | ```
116 |
117 | ## Training
118 |
119 | ### Step 1 Download Dataset
120 | We provide part of our processed dataset from Objaverse. You can download it from https://huggingface.co/datasets/Yiwen-ntu/MeshAnythingV2/tree/main
121 |
122 | After downloading, place `train.npz` and `test.npz` into the `dataset` directory.
123 |
124 | If you prefer to process your own data, please refer to `data_process.py`.
125 |
126 | ### Step 2 Download Point Cloud Encoder Checkpoints
127 |
128 | Download Michelangelo's point encoder from https://huggingface.co/Maikou/Michelangelo/tree/main/checkpoints/aligned_shape_latents and put it into `meshanything_train/miche/checkpoints/aligned_shape_latents/shapevae-256.ckpt`.
129 |
130 | ### Step 3 Training and Evaluation
131 | ```
132 | # Training with MultiGPU
133 | accelerate launch --multi_gpu --num_processes 8 train.py --batchsize_per_gpu 2 --checkpoint_dir training_trial
134 |
135 | # Evaluation
136 | python train.py --batchsize_per_gpu 2 --checkpoint_dir evaluation_trial --pretrained_weights gpt_output/training_trial/xxx_xxx.pth --test_only
137 | ```
138 |
139 | ## Important Notes
140 | - It takes about 8 GB of GPU memory and roughly 45 seconds to generate a mesh on an A6000 GPU (depending on the number of faces in the generated mesh).
141 | - The input mesh will be normalized to a unit bounding box. The up vector of the input mesh should be +Y for better results.
142 | - Limited by computational resources, MeshAnything is trained on meshes with fewer than 1600 faces and cannot generate meshes with more than 1600 faces. The shape of the input mesh should be sharp enough; otherwise, it will be challenging to represent it with only 1600 faces. Consequently, feed-forward 3D generation methods may often produce bad results due to insufficient shape quality. We suggest using results from 3D reconstruction, scanning, SDS-based methods (like [DreamCraft3D](https://github.com/deepseek-ai/DreamCraft3D)), or [Rodin](https://hyperhuman.deemos.com/rodin) as input to MeshAnything.
143 | - Please refer to https://huggingface.co/spaces/Yiwen-ntu/MeshAnything/tree/main/examples for more examples.
144 |
145 | ## Acknowledgement
146 |
147 | Our code is based on these wonderful repos:
148 |
149 | * [MeshAnything](https://github.com/buaacyw/MeshAnything)
150 | * [MeshGPT](https://nihalsid.github.io/mesh-gpt/)
151 | * [meshgpt-pytorch](https://github.com/lucidrains/meshgpt-pytorch)
152 | * [Michelangelo](https://github.com/NeuralCarver/Michelangelo)
153 | * [transformers](https://github.com/huggingface/transformers)
154 | * [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch)
155 |
156 | ## BibTeX
157 | ```
158 | @misc{chen2024meshanythingv2artistcreatedmesh,
159 | title={MeshAnything V2: Artist-Created Mesh Generation With Adjacent Mesh Tokenization},
160 | author={Yiwen Chen and Yikai Wang and Yihao Luo and Zhengyi Wang and Zilong Chen and Jun Zhu and Chi Zhang and Guosheng Lin},
161 | year={2024},
162 | eprint={2408.02555},
163 | archivePrefix={arXiv},
164 | primaryClass={cs.CV},
165 | url={https://arxiv.org/abs/2408.02555},
166 | }
167 | ```
168 |
--------------------------------------------------------------------------------
/adjacent_mesh_tokenization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import trimesh
3 | import networkx as nx
4 | import os
5 |
6 | def mesh_sort(vertices_, faces_):
7 | assert (vertices_ <= 0.5).all() and (vertices_ >= -0.5).all() # [-0.5, 0.5]
8 | vertices = (vertices_ + 0.5) * 128 # [0, num_tokens]
9 | vertices -= 0.5 # shift by 0.5 so rounding is evenly distributed: values in [-0.5, num_tokens - 0.5] round to 0 ... num_tokens - 1
10 | vertices_quantized_ = np.clip(vertices.round(), 0, 128 - 1).astype(int) # [0, num_tokens -1]
11 |
12 | cur_mesh = trimesh.Trimesh(vertices=vertices_quantized_, faces=faces_)
13 |
14 | cur_mesh.merge_vertices()
15 | cur_mesh.update_faces(cur_mesh.nondegenerate_faces())
16 | cur_mesh.update_faces(cur_mesh.unique_faces())
17 | cur_mesh.remove_unreferenced_vertices()
18 |
19 | sort_inds = np.lexsort(cur_mesh.vertices.T)
20 | vertices = cur_mesh.vertices[sort_inds]
21 | faces = [np.argsort(sort_inds)[f] for f in cur_mesh.faces]
22 |
23 | faces = [sorted(sub_arr) for sub_arr in faces]
24 |
25 | def sort_faces(face):
26 | return face[0], face[1], face[2]
27 |
28 | faces = sorted(faces, key=sort_faces)
29 |
30 | vertices = vertices / 128 - 0.5 # [0, num_tokens -1] to [-0.5, 0.5) for computing
31 |
32 | return vertices, faces
33 |
34 | def adjacent_mesh_tokenization(mesh):
35 | naive_v_length = mesh.faces.shape[0] * 9
36 |
37 | graph = mesh.vertex_adjacency_graph
38 |
39 | unvisited_faces = mesh.faces.copy()
40 | dis_vertices = np.asarray((mesh.vertices.copy() + 0.5) * 128)
41 |
42 | sequence = []
43 | while unvisited_faces.shape[0] > 0:
44 | # find the face with the smallest index
45 | if len(sequence) == 0 or sequence[-1] == -1:
46 | cur_face = unvisited_faces[0]
47 | unvisited_faces = unvisited_faces[1:]
48 | sequence.extend(cur_face.tolist())
49 | else:
50 | last_vertices = sequence[-2:]
51 | # find common neighbors
52 | commons = sorted(list(nx.common_neighbors(graph, last_vertices[0], last_vertices[1])))
53 | next_token = None
54 | for common in commons:
55 | common_face = sorted(np.array(last_vertices + [common]))
56 | # find index of common face
57 | equals = np.where((unvisited_faces == common_face).all(axis=1))[0]
58 | assert len(equals) == 1 or len(equals) == 0
59 | if len(equals) == 1:
60 | next_token = common
61 | next_face_index = equals[0]
62 | break
63 | if next_token is not None:
64 | unvisited_faces = np.delete(unvisited_faces, next_face_index, axis=0)
65 | sequence.append(int(next_token))
66 | else:
67 | sequence.append(-1)
68 |
69 | final_sequence = []
70 | for token_id in sequence:
71 | if token_id == -1:
72 | final_sequence.append(128)
73 | else:
74 | final_sequence.extend(dis_vertices[token_id].tolist())
75 |
76 | cur_ratio = len(final_sequence) / naive_v_length
77 |
78 | return cur_ratio
79 |
80 |
81 | if __name__ == "__main__":
82 | # read_ply
83 | data_dir = 'gt_examples'
84 | data_list = sorted(os.listdir(data_dir))
85 | data_list = [os.path.join(data_dir, x) for x in data_list if x.endswith('.ply') or x.endswith('.obj')]
86 | ratio_list = []
87 | for idx, cur_data in enumerate(data_list):
88 | cur_mesh = trimesh.load(cur_data)
89 |
90 | vertices = cur_mesh.vertices
91 | faces = cur_mesh.faces
92 | bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
93 | vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
94 | vertices = vertices / (bounds[1] - bounds[0]).max()
95 | vertices = vertices.clip(-0.5, 0.5)
96 |
97 | vertices, faces = mesh_sort(vertices, faces)
98 | dis_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
99 | try:
100 | cur_ratio = adjacent_mesh_tokenization(dis_mesh)
101 |
102 | ratio_list.append(cur_ratio)
103 | except Exception as e:
104 | print(e)
105 |
106 | # print mean and variance of ratio:
107 | print(f"mean ratio: {np.mean(ratio_list)}, variance ratio: {np.var(ratio_list)}")
--------------------------------------------------------------------------------
/demo/demo_video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/demo/demo_video.gif
--------------------------------------------------------------------------------
/gt_examples/seals.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/gt_examples/seals.ply
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import torch
3 | import time
4 | import trimesh
5 | import numpy as np
6 | import datetime
7 | from accelerate import Accelerator
8 | from accelerate.utils import set_seed
9 | from accelerate.utils import DistributedDataParallelKwargs
10 | from safetensors.torch import load_model
11 |
12 | from mesh_to_pc import process_mesh_to_pc
13 | from huggingface_hub import hf_hub_download
14 | from MeshAnything.models.meshanything_v2 import MeshAnythingV2
15 |
16 | class Dataset:
17 | def __init__(self, input_type, input_list, mc=False, mc_level = 7):
18 | super().__init__()
19 | self.data = []
20 | if input_type == 'pc_normal':
21 | for input_path in input_list:
22 | # load npy
23 | cur_data = np.load(input_path)
24 | # sample 8192 points
25 | assert cur_data.shape[0] >= 8192, "input pc_normal should have at least 8192 points"
26 | idx = np.random.choice(cur_data.shape[0], 8192, replace=False)
27 | cur_data = cur_data[idx]
28 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]})
29 |
30 | elif input_type == 'mesh':
31 | mesh_list = []
32 | for input_path in input_list:
33 | # load ply
34 | cur_data = trimesh.load(input_path)
35 | mesh_list.append(cur_data)
36 | if mc:
37 | print("Running Marching Cubes and then sampling the point cloud; this may take several minutes...")
38 | pc_list, _ = process_mesh_to_pc(mesh_list, marching_cubes=mc, mc_level=mc_level)
39 | for input_path, cur_data in zip(input_list, pc_list):
40 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]})
41 | print(f"dataset total data samples: {len(self.data)}")
42 |
43 | def __len__(self):
44 | return len(self.data)
45 |
46 | def __getitem__(self, idx):
47 | data_dict = {}
48 | data_dict['pc_normal'] = self.data[idx]['pc_normal']
49 | # normalize pc coor
50 | pc_coor = data_dict['pc_normal'][:, :3]
51 | normals = data_dict['pc_normal'][:, 3:]
52 | bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
53 | pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
54 | pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995
55 | assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong"
56 | data_dict['pc_normal'] = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
57 | data_dict['uid'] = self.data[idx]['uid']
58 |
59 | return data_dict
60 |
61 | def get_args():
62 | parser = argparse.ArgumentParser("MeshAnything", add_help=False)
63 |
64 | parser.add_argument('--input_dir', default=None, type=str)
65 | parser.add_argument('--input_path', default=None, type=str)
66 |
67 | parser.add_argument('--out_dir', default="inference_out", type=str)
68 |
69 | parser.add_argument(
70 | '--input_type',
71 | choices=['mesh','pc_normal'],
72 | default='pc_normal',
73 | help="Type of the asset to process (default: pc_normal)"
74 | )
75 |
76 | parser.add_argument("--batchsize_per_gpu", default=1, type=int)
77 | parser.add_argument("--seed", default=0, type=int)
78 |
79 | parser.add_argument("--mc", default=False, action="store_true")
80 | parser.add_argument("--mc_level", default=7, type=int)
81 |
82 | parser.add_argument("--sampling", default=False, action="store_true")
83 |
84 | args = parser.parse_args()
85 | return args
86 |
87 | if __name__ == "__main__":
88 | args = get_args()
89 |
90 | cur_time = datetime.datetime.now().strftime("%d_%H-%M-%S")
91 | checkpoint_dir = os.path.join(args.out_dir, cur_time)
92 | os.makedirs(checkpoint_dir, exist_ok=True)
93 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
94 | accelerator = Accelerator(
95 | mixed_precision="fp16",
96 | project_dir=checkpoint_dir,
97 | kwargs_handlers=[kwargs]
98 | )
99 |
100 | model = MeshAnythingV2.from_pretrained("Yiwen-ntu/meshanythingv2")
101 |
102 | # create dataset
103 | if args.input_dir is not None:
104 | input_list = sorted(os.listdir(args.input_dir))
105 | # only ply, obj or npy
106 | if args.input_type == 'pc_normal':
107 | input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.npy')]
108 | else:
109 | input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.ply') or x.endswith('.obj') or x.endswith('.npy')]
110 | set_seed(args.seed)
111 | dataset = Dataset(args.input_type, input_list, args.mc, args.mc_level)
112 | elif args.input_path is not None:
113 | set_seed(args.seed)
114 | dataset = Dataset(args.input_type, [args.input_path], args.mc, args.mc_level)
115 | else:
116 | raise ValueError("input_dir or input_path must be provided.")
117 |
118 | dataloader = torch.utils.data.DataLoader(
119 | dataset,
120 | batch_size=args.batchsize_per_gpu,
121 | drop_last = False,
122 | shuffle = False,
123 | )
124 |
125 | if accelerator.state.num_processes > 1:
126 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
127 | dataloader, model = accelerator.prepare(dataloader, model)
128 | begin_time = time.time()
129 | print("Generation Start!!!")
130 | with accelerator.autocast():
131 | for curr_iter, batch_data_label in enumerate(dataloader):
132 | curr_time = time.time()
133 | outputs = model(batch_data_label['pc_normal'], sampling=args.sampling)
134 | batch_size = outputs.shape[0]
135 | device = outputs.device
136 |
137 | for batch_id in range(batch_size):
138 | recon_mesh = outputs[batch_id]
139 | valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
140 | recon_mesh = recon_mesh[valid_mask] # nvalid_face x 3 x 3
141 |
142 | vertices = recon_mesh.reshape(-1, 3).cpu()
143 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face
144 | triangles = vertices_index.reshape(-1, 3)
145 |
146 | scene_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh",
147 | merge_primitives=True)
148 | scene_mesh.merge_vertices()
149 | scene_mesh.update_faces(scene_mesh.nondegenerate_faces())
150 | scene_mesh.update_faces(scene_mesh.unique_faces())
151 | scene_mesh.remove_unreferenced_vertices()
152 | scene_mesh.fix_normals()
153 | save_path = os.path.join(checkpoint_dir, f'{batch_data_label["uid"][batch_id]}_gen.obj')
154 | num_faces = len(scene_mesh.faces)
155 | brown_color = np.array([255, 165, 0, 255], dtype=np.uint8)
156 | face_colors = np.tile(brown_color, (num_faces, 1))
157 |
158 | scene_mesh.visual.face_colors = face_colors
159 | scene_mesh.export(save_path)
160 | print(f"{save_path} Over!!")
161 | end_time = time.time()
162 | print(f"Total time: {end_time - begin_time}")
--------------------------------------------------------------------------------
/mesh_to_pc.py:
--------------------------------------------------------------------------------
1 | import mesh2sdf.core
2 | import numpy as np
3 | import skimage.measure
4 | import trimesh
5 | import time
6 | def normalize_vertices(vertices, scale=0.95):
7 | bbmin, bbmax = vertices.min(0), vertices.max(0)
8 | center = (bbmin + bbmax) * 0.5
9 | scale = 2.0 * scale / (bbmax - bbmin).max()
10 | vertices = (vertices - center) * scale
11 | return vertices, center, scale
12 |
13 | def export_to_watertight(normalized_mesh, octree_depth: int = 7):
14 | """
15 | Convert the non-watertight mesh to watertight.
16 |
17 | Args:
18 | normalized_mesh (trimesh.Trimesh): normalized input mesh
19 | octree_depth (int): SDF grid resolution exponent (grid size = 2 ** octree_depth)
20 |
21 | Returns:
22 | mesh(trimesh.Trimesh): watertight mesh
23 |
24 | """
25 | size = 2 ** octree_depth
26 | level = 2 / size
27 | scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
28 |
29 | sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
30 |
31 | vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
32 |
33 | vertices = vertices / size * 2 - 1 # -1 to 1
34 | vertices = vertices / to_orig_scale + to_orig_center
35 | # vertices = vertices / to_orig_scale + to_orig_center
36 | mesh = trimesh.Trimesh(vertices, faces, normals=normals)
37 |
38 | return mesh
39 |
40 | def process_mesh_to_pc(mesh_list, marching_cubes = False, sample_num = 8192, mc_level= 7):
41 | # mesh_list : list of trimesh
42 | pc_normal_list = []
43 | return_mesh_list = []
44 | for mesh in mesh_list:
45 | if marching_cubes:
46 | cur_time = time.time()
47 | mesh = export_to_watertight(mesh, octree_depth=mc_level)
48 | print("MC over! ", "mc_level: ", mc_level, "process_time:" , time.time() - cur_time)
49 | return_mesh_list.append(mesh)
50 | points, face_idx = mesh.sample(sample_num, return_index=True)
51 | normals = mesh.face_normals[face_idx]
52 |
53 | pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
54 | pc_normal_list.append(pc_normal)
55 | print("process mesh success")
56 | return pc_normal_list, return_mesh_list
57 |
58 |
--------------------------------------------------------------------------------
/meshanything_train/dist.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import pickle
3 |
4 | import torch
5 | import torch.distributed as dist
6 | import os
7 |
8 | def is_distributed():
9 | if not dist.is_available() or not dist.is_initialized():
10 | return False
11 | return True
12 |
13 |
14 | def get_rank():
15 | if not is_distributed():
16 | return 0
17 | return dist.get_rank()
18 |
19 |
20 | def is_primary():
21 | return get_rank() == 0
22 |
23 |
24 | def get_world_size():
25 | if not is_distributed():
26 | return 1
27 | return dist.get_world_size()
28 |
29 |
30 | def barrier():
31 | if not is_distributed():
32 | return
33 | torch.distributed.barrier()
34 |
35 |
36 | def setup_print_for_distributed(is_primary):
37 | """
38 | This function disables printing when not in primary process
39 | """
40 | import builtins as __builtin__
41 | builtin_print = __builtin__.print
42 |
43 | def print(*args, **kwargs):
44 | force = kwargs.pop('force', False)
45 | if is_primary or force:
46 | builtin_print(*args, **kwargs)
47 |
48 | __builtin__.print = print
49 |
50 |
51 | def init_distributed(gpu_id, global_rank, world_size, dist_url, dist_backend):
52 | torch.cuda.set_device(gpu_id)
53 | os.environ["MASTER_ADDR"] = "localhost"
54 | os.environ["MASTER_PORT"] = "12355"
55 | print(
56 | f"| distributed init (rank {global_rank}) (world {world_size}): {dist_url}",
57 | flush=True,
58 | )
59 | torch.distributed.init_process_group(
60 | backend=dist_backend,
61 | # init_method="env://",
62 | world_size=world_size,
63 | rank=global_rank,
64 | )
65 | torch.distributed.barrier()
66 | setup_print_for_distributed(is_primary())
67 |
68 |
69 | def all_reduce_sum(tensor):
70 | if not is_distributed():
71 | return tensor
72 | dim_squeeze = False
73 | if tensor.ndim == 0:
74 | tensor = tensor[None, ...]
75 | dim_squeeze = True
76 | torch.distributed.all_reduce(tensor)
77 | print("loss_tensor: ", tensor)
78 | if dim_squeeze:
79 | tensor = tensor.squeeze(0)
80 | return tensor
81 |
82 |
83 | def all_reduce_average(tensor):
84 | val = all_reduce_sum(tensor)
85 | return val / get_world_size()
86 |
87 |
88 | # Function from DETR - https://github.com/facebookresearch/detr/blob/master/util/misc.py
89 | def reduce_dict(input_dict, average=True):
90 | """
91 | Args:
92 | input_dict (dict): all the values will be reduced
93 | average (bool): whether to do average or sum
94 | Reduce the values in the dictionary from all processes so that all processes
95 | have the averaged results. Returns a dict with the same fields as
96 | input_dict, after reduction.
97 | """
98 | world_size = get_world_size()
99 | if world_size < 2:
100 | return input_dict
101 | with torch.no_grad():
102 | names = []
103 | values = []
104 | # sort the keys so that they are consistent across processes
105 | for k in sorted(input_dict.keys()):
106 | names.append(k)
107 | values.append(input_dict[k])
108 | values = torch.stack(values, dim=0)
109 | torch.distributed.all_reduce(values)
110 | if average:
111 | values /= world_size
112 | reduced_dict = {k: v for k, v in zip(names, values)}
113 | return reduced_dict
114 |
115 |
116 | # Function from https://github.com/facebookresearch/detr/blob/master/util/misc.py
117 | def all_gather_pickle(data, device):
118 | """
119 | Run all_gather on arbitrary picklable data (not necessarily tensors)
120 | Args:
121 | data: any picklable object
122 | Returns:
123 | list[data]: list of data gathered from each rank
124 | """
125 | world_size = get_world_size()
126 | if world_size == 1:
127 | return [data]
128 |
129 | # serialized to a Tensor
130 | buffer = pickle.dumps(data)
131 | storage = torch.ByteStorage.from_buffer(buffer)
132 | tensor = torch.ByteTensor(storage).to(device)
133 |
134 | # obtain Tensor size of each rank
135 | local_size = torch.tensor([tensor.numel()], device=device)
136 | size_list = [torch.tensor([0], device=device) for _ in range(world_size)]
137 | dist.all_gather(size_list, local_size)
138 | size_list = [int(size.item()) for size in size_list]
139 | max_size = max(size_list)
140 |
141 | # receiving Tensor from all ranks
142 | # we pad the tensor because torch all_gather does not support
143 | # gathering tensors of different shapes
144 | tensor_list = []
145 | for _ in size_list:
146 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
147 | if local_size != max_size:
148 | padding = torch.empty(
149 | size=(max_size - local_size,), dtype=torch.uint8, device=device
150 | )
151 | tensor = torch.cat((tensor, padding), dim=0)
152 | dist.all_gather(tensor_list, tensor)
153 |
154 | data_list = []
155 | for size, tensor in zip(size_list, tensor_list):
156 | buffer = tensor.cpu().numpy().tobytes()[:size]
157 | data_list.append(pickle.loads(buffer))
158 |
159 | return data_list
160 |
161 |
162 | def all_gather_dict(data):
163 | """
164 | Run all_gather on data which is a dictionary of Tensors
165 | """
166 | assert isinstance(data, dict)
167 |
168 | gathered_dict = {}
169 | for item_key in data:
170 | if isinstance(data[item_key], torch.Tensor):
171 | if is_distributed():
172 | data[item_key] = data[item_key].contiguous()
173 | tensor_list = [torch.empty_like(data[item_key]) for _ in range(get_world_size())]
174 | dist.all_gather(tensor_list, data[item_key])
175 | gathered_tensor = torch.cat(tensor_list, dim=0)
176 | else:
177 | gathered_tensor = data[item_key]
178 | gathered_dict[item_key] = gathered_tensor
179 | return gathered_dict
180 |
--------------------------------------------------------------------------------
/meshanything_train/engine.py:
--------------------------------------------------------------------------------
1 | import os, sys, time, math, json, importlib
2 | import torch
3 | import datetime
4 |
5 | from meshanything_train.misc import SmoothedValue
6 | from collections import defaultdict
7 |
8 | def save_checkpoint(
9 | checkpoint_dir,
10 | model,
11 | optimizer,
12 | epoch,
13 | args,
14 | best_val_metrics,
15 | filename=None,
16 | ):
17 |
18 | checkpoint_name = os.path.join(checkpoint_dir, filename)
19 | try:
20 | weight_ckpt = model.module.state_dict()
21 | except Exception as e:
22 | print("single GPU")
23 | weight_ckpt = model.state_dict()
24 |
25 | sd = {
26 | "model": weight_ckpt,
27 | "optimizer": optimizer.state_dict(),
28 | "epoch": epoch,
29 | "args": args,
30 | "best_val_metrics": best_val_metrics,
31 | }
32 | torch.save(sd, checkpoint_name)
33 |
34 |
35 | def compute_learning_rate(args, curr_epoch_normalized):
36 | assert curr_epoch_normalized <= 1.0 and curr_epoch_normalized >= 0.0
37 | if (
38 | curr_epoch_normalized <= (args.warm_lr_epochs / args.max_epoch)
39 | and args.warm_lr_epochs > 0
40 | ):
41 | # Linear Warmup
42 | curr_lr = args.warm_lr + curr_epoch_normalized * args.max_epoch * (
43 | (args.base_lr - args.warm_lr) / args.warm_lr_epochs
44 | )
45 | else:
46 | # Cosine Learning Rate Schedule
47 | curr_lr = args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (
48 | 1 + math.cos(math.pi * curr_epoch_normalized)
49 | )
50 | return curr_lr
51 |
52 | def adjust_learning_rate(args, optimizer, curr_epoch):
53 | curr_lr = compute_learning_rate(args, curr_epoch)
54 | for param_group in optimizer.param_groups:
55 | param_group["lr"] = curr_lr
56 | return curr_lr
57 |
58 | def do_train(
59 | args,
60 | model,
61 | dataloaders,
62 | logger,
63 | accelerator,
64 | best_val_metrics=dict()
65 | ):
66 |
67 | optimizer = torch.optim.AdamW(
68 | filter(lambda params: params.requires_grad, model.parameters()), # list(model.named_parameters())
69 | lr=args.base_lr,
70 | weight_decay=args.weight_decay
71 | )
72 | start_epoch = 0
73 | if args.pretrained_weights is not None:
74 | sd = torch.load(args.pretrained_weights, map_location=torch.device("cpu"))
75 | epoch = sd["epoch"]
76 | print(f"Found checkpoint at {epoch}. Resuming.")
77 | model.load_state_dict(sd["model"], strict=not args.no_strict)
78 | optimizer.load_state_dict(sd["optimizer"])
79 | start_epoch = epoch
80 | print(
81 | f"Loaded model and optimizer state at {epoch}. Loaded best val metrics so far."
82 | )
83 |
84 | if accelerator.state.num_processes > 1:
85 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
86 |
87 | dataloaders['train'], dataloaders['test'], model, optimizer = accelerator.prepare(
88 | dataloaders['train'],
89 | dataloaders['test'],
90 | model,
91 | optimizer,
92 | )
93 |
94 | max_iters = args.max_epoch * len(dataloaders['train']) // args.gradient_accumulation_steps
95 | print("train dataloader len ",len(dataloaders['train']))
96 |
97 | time_delta = SmoothedValue(window_size=10)
98 |
99 | model.train()
100 | if args.start_epoch == -1:
101 | args.start_epoch = start_epoch
102 |
103 | if args.test_only:
104 | with accelerator.autocast():
105 | task_metrics, eval_loss_dict = dataloaders['test'].dataset.eval_func(
106 | args,
107 | 0,
108 | model,
109 | dataloaders['test'],
110 | accelerator,
111 | logger,
112 | 0,
113 | test_only = True
114 | )
115 | accelerator.log(eval_loss_dict, step=0)
116 | return
117 |
118 | curr_iter = args.start_epoch * len(dataloaders['train']) // args.gradient_accumulation_steps
119 | curr_time = time.time()
120 | loss_dict = defaultdict(list)
121 |
122 | for curr_epoch in range(args.start_epoch, args.max_epoch):
123 | for batch_idx, batch_data_label in enumerate(dataloaders['train']):
124 | curr_lr = adjust_learning_rate(args, optimizer, curr_iter / max_iters)
125 | with accelerator.accumulate(model):
126 | optimizer.zero_grad()
127 | with accelerator.autocast():
128 | outputs = model(batch_data_label)
129 |
130 | loss = outputs['loss']
131 |
132 | if not math.isfinite(loss.item()):
133 | logger.info("Loss is not finite. Terminating training.")
134 | exit(-1)
135 |
136 | accelerator.backward(loss)
137 | if args.clip_gradient > 0 and accelerator.sync_gradients:
138 | accelerator.clip_grad_norm_(model.parameters(), args.clip_gradient)
139 |
140 | optimizer.step()
141 |
142 | for key, value in outputs.items():
143 | if 'loss' in key.lower():
144 | loss_dict[key].append(value.item())
145 | # logging
146 | if accelerator.sync_gradients:
147 | time_delta.update(time.time() - curr_time)
148 | curr_time = time.time()
149 | curr_iter += 1
150 |
151 | if curr_iter % args.log_every == 0:
152 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
153 | eta_seconds = (max_iters - curr_iter) * time_delta.avg
154 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
155 |
156 | logger.info(
157 | f"Epoch [{curr_epoch}/{args.max_epoch}]; "
158 | f"Iter [{curr_iter}/{max_iters}]; " + \
159 | f"LR {curr_lr:0.2e}; Iter time {time_delta.avg:0.2f}; "
160 | f"ETA {eta_str}; Mem {mem_mb:0.2f}MB"
161 | )
162 | for key, value in loss_dict.items():
163 | loss_dict[key] = torch.tensor(value, dtype=torch.float32).mean().item()
164 | loss_dict["learning_rate"] = curr_lr
165 | accelerator.log(loss_dict, step=curr_iter)
166 | loss_dict = defaultdict(list)
167 |
168 | if accelerator.is_main_process and (curr_iter + 1) % args.eval_every_iteration == 0:
169 | save_checkpoint(
170 | args.checkpoint_dir,
171 | model,
172 | optimizer,
173 | curr_epoch,
174 | args,
175 | best_val_metrics,
176 | filename=f"checkpoint_{curr_iter+1}.pth",
177 | )
178 |
179 | # do eval
180 | do_eval_flag = (curr_iter + 1) % args.eval_every_iteration == 0
181 | do_eval_flag &= (curr_iter + 1) > args.start_eval_after
182 | do_eval_flag |= (curr_iter + 1) == max_iters
183 | do_eval_flag |= curr_iter == 1000
184 |
185 | if do_eval_flag is True:
186 | with accelerator.autocast():
187 | task_metrics, eval_loss_dict = dataloaders['test'].dataset.eval_func(
188 | args,
189 | curr_epoch,
190 | model,
191 | dataloaders['test'],
192 | accelerator,
193 | logger,
194 | curr_iter+1,
195 | )
196 | if accelerator.is_main_process:
197 | print("Evaluation End, Begin Log!")
198 | accelerator.log(eval_loss_dict, step=curr_iter+1)
199 | print("Resume Training!")
200 | model.train()
201 |
202 | accelerator.end_training()
203 | return
--------------------------------------------------------------------------------
/meshanything_train/miche/README.md:
--------------------------------------------------------------------------------
1 | # Michelangelo
2 |
3 | ## [Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation](https://neuralcarver.github.io/michelangelo)
4 | [Zibo Zhao](https://github.com/Maikouuu),
5 | [Wen Liu](https://github.com/StevenLiuWen),
6 | [Xin Chen](https://chenxin.tech/),
7 | [Xianfang Zeng](https://github.com/Zzlongjuanfeng),
8 | [Rui Wang](https://wrong.wang/),
9 | [Pei Cheng](https://neuralcarver.github.io/michelangelo),
10 | [Bin Fu](https://neuralcarver.github.io/michelangelo),
11 | [Tao Chen](https://eetchen.github.io),
12 | [Gang Yu](https://www.skicyyu.org),
13 | [Shenghua Gao](https://sist.shanghaitech.edu.cn/sist_en/2020/0814/c7582a54772/page.htm)
14 | ### [Hugging Face Demo](https://huggingface.co/spaces/Maikou/Michelangelo) | [Project Page](https://neuralcarver.github.io/michelangelo/) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Paper](https://openreview.net/pdf?id=xmxgMij3LY)
15 |
16 | https://github.com/NeuralCarver/Michelangelo/assets/37449470/123bae2c-fbb1-4d63-bd13-0e300a550868
17 |
18 | Visualization of the 3D shapes produced by our framework, shown as triplets with the conditional input on the left, a normal map in the middle, and a triangle mesh on the right. The generated 3D shapes semantically conform to the visual or textual conditional inputs.
19 |
20 | ## 🔆 Features
21 | **Michelangelo** possesses three capabilities:
22 |
23 | 1. Representing a shape into shape-image-text aligned space;
24 | 2. Image-conditioned Shape Generation;
25 | 3. Text-conditioned Shape Generation.
26 |
27 |
28 | Techniques
29 |
30 | We present a novel _alignment-before-generation_ approach to tackle the challenging task of generating general 3D shapes based on 2D images or texts. Directly learning a conditional generative model from images or texts to 3D shapes is prone to producing inconsistent results with the conditions because 3D shapes have an additional dimension whose distribution significantly differs from that of 2D images and texts. To bridge the domain gap among the three modalities and facilitate multi-modal-conditioned 3D shape generation, we explore representing 3D shapes in a shape-image-text-aligned space. Our framework comprises two models: a Shape-Image-Text-Aligned Variational Auto-Encoder (SITA-VAE) and a conditional Aligned Shape Latent Diffusion Model (ASLDM). The former model encodes the 3D shapes into the shape latent space aligned to the image and text and reconstructs the fine-grained 3D neural fields corresponding to given shape embeddings via the transformer-based decoder. The latter model learns a probabilistic mapping function from the image or text space to the latent shape space. Our extensive experiments demonstrate that our proposed approach can generate higher-quality and more diverse 3D shapes that better semantically conform to the visual or textual conditional inputs, validating the effectiveness of the shape-image-text-aligned space for cross-modality 3D shape generation.
31 |
32 | 
33 |
34 |
35 | ## 📰 News
36 | - [2024/1/23] Set up the Hugging Face Demo and release the code
37 | - [2023/09/22] **Michelangelo got accepted by NeurIPS 2023!**
38 | - [2023/6/29] Upload paper and init project
39 |
40 | ## ⚙️ Setup
41 |
42 | ### Installation
43 | Follow the commands below to set up the environment. We have tested the installation on Tesla V100 and Tesla T4 GPUs.
44 | ```
45 | git clone https://github.com/NeuralCarver/Michelangelo.git
46 | cd Michelangelo
47 | conda create --name Michelangelo python=3.9
48 | conda activate Michelangelo
49 | pip install -r requirements.txt
50 | ```
51 |
52 | ### Checkpoints
53 | Please download the weights from the Hugging Face Model Space and put them in the root folder. We have also uploaded the CLIP-related weights to facilitate quick usage.
54 |
55 |
56 |
57 | Tips for debugging the configuration
58 |
59 |
60 | - If something goes wrong during the environment configuration, you may consider skipping the problematic packages, such as pysdf, torch-cluster, and torch-scatter. These packages do not affect the execution of the commands we provide.
61 | - If you encounter any issues while downloading CLIP, you can download it from [CLIP's Hugging Face page](https://huggingface.co/openai/clip-vit-large-patch14). Once the download is complete, remember to modify line [26](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L26C1-L26C76) and line [34](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L34) in the config file to provide the correct path to CLIP.
62 | - From [issue 6](https://github.com/NeuralCarver/Michelangelo/issues/6#issuecomment-1913513382): Windows users running WSL2 + Ubuntu 22.04 may run into issues. As discussed in [WSL issue 8587](https://github.com/microsoft/WSL/issues/8587), it is just a matter of adding this to your .bashrc:
63 | ```
64 | export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH
65 | ```
66 |
67 |
68 | ## ⚡ Quick Start
69 |
70 | ### Inference
71 |
72 | #### Reconstruct a 3D shape
73 | ```
74 | ./scripts/inference/reconstruction.sh
75 | ```
76 |
77 | #### Image-conditioned shape generation
78 | ```
79 | ./scripts/inference/image2mesh.sh
80 | ```
81 |
82 | #### Text-conditioned shape generation
83 | ```
84 | ./scripts/inference/text2mesh.sh
85 | ```
86 |
87 | #### Simply run all the scripts
88 | ```
89 | ./scripts/infer.sh
90 | ```
91 |
92 |
93 | ## ❓ FAQ
94 |
95 | ## Citation
96 |
97 | If you find our code or paper helps, please consider citing:
98 |
99 | ```bibtex
100 | @inproceedings{
101 | zhao2023michelangelo,
102 | title={Michelangelo: Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation},
103 | author={Zibo Zhao and Wen Liu and Xin Chen and Xianfang Zeng and Rui Wang and Pei Cheng and BIN FU and Tao Chen and Gang YU and Shenghua Gao},
104 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
105 | year={2023},
106 | url={https://openreview.net/forum?id=xmxgMij3LY}
107 | }
108 | ```
109 |
110 | ## License
111 |
112 | This code is distributed under a [GPL-3.0 license](LICENSE).
113 |
114 |
--------------------------------------------------------------------------------
/meshanything_train/miche/configs/aligned_shape_latents/shapevae-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: meshanything_train.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
3 | params:
4 | shape_module_cfg:
5 | target: meshanything_train.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
6 | params:
7 | num_latents: 256
8 | embed_dim: 64
9 | point_feats: 3 # normal
10 | num_freqs: 8
11 | include_pi: false
12 | heads: 12
13 | width: 768
14 | num_encoder_layers: 8
15 | num_decoder_layers: 16
16 | use_ln_post: true
17 | init_scale: 0.25
18 | qkv_bias: false
19 | use_checkpoint: true
20 | aligned_module_cfg:
21 | target: meshanything_train.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
22 | params:
23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
24 |
25 | loss_cfg:
26 | target: meshanything_train.miche.michelangelo.models.tsal.loss.ContrastKLNearFar
27 | params:
28 | contrast_weight: 0.1
29 | near_weight: 0.1
30 | kl_weight: 0.001
31 |
32 | optimizer_cfg:
33 | optimizer:
34 | target: torch.optim.AdamW
35 | params:
36 | betas: [0.9, 0.99]
37 | eps: 1.e-6
38 | weight_decay: 1.e-2
39 |
40 | scheduler:
41 | target: meshanything_train.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
42 | params:
43 | warm_up_steps: 5000
44 | f_start: 1.e-6
45 | f_min: 1.e-3
46 | f_max: 1.0
47 |
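A minimal sketch of how this config is consumed (mirroring `load_model` in `encode.py` further below), assuming the shapevae-256 checkpoint has been downloaded to the path used there:

```python
from omegaconf import OmegaConf

from meshanything_train.miche.michelangelo.utils.misc import instantiate_from_config

# The top-level "model" node carries the target class plus its params.
cfg = OmegaConf.load("meshanything_train/miche/configs/aligned_shape_latents/shapevae-256.yaml")
model_cfg = cfg.model if hasattr(cfg, "model") else cfg

# Assumed checkpoint location; adjust to wherever shapevae-256.ckpt actually lives.
vae = instantiate_from_config(
    model_cfg,
    ckpt_path="meshanything_train/miche/checkpoints/aligned_shape_latents/shapevae-256.ckpt",
)
vae = vae.eval()
```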
--------------------------------------------------------------------------------
/meshanything_train/miche/configs/aligned_shape_latents/shapevae-512.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: meshanything_train.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
3 | params:
4 | shape_module_cfg:
5 | target: meshanything_train.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
6 | params:
7 | num_latents: 512
8 | embed_dim: 64
9 | point_feats: 3 # normal
10 | num_freqs: 8
11 | include_pi: false
12 | heads: 16
13 | width: 1024
14 | num_encoder_layers: 12
15 | num_decoder_layers: 16
16 | use_ln_post: true
17 | init_scale: 0.25
18 | qkv_bias: false
19 | use_checkpoint: true
20 | aligned_module_cfg:
21 | target: meshanything_train.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
22 | params:
23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
24 |
25 | loss_cfg:
26 | target: meshanything_train.miche.michelangelo.models.tsal.loss.ContrastKLNearFar
27 | params:
28 | contrast_weight: 0.1
29 | near_weight: 0.1
30 | kl_weight: 0.001
31 |
32 | optimizer_cfg:
33 | optimizer:
34 | target: torch.optim.AdamW
35 | params:
36 | betas: [0.9, 0.99]
37 | eps: 1.e-6
38 | weight_decay: 1.e-2
39 |
40 | scheduler:
41 | target: meshanything_train.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
42 | params:
43 | warm_up_steps: 5000
44 | f_start: 1.e-6
45 | f_min: 1.e-3
46 | f_max: 1.0
47 |
--------------------------------------------------------------------------------
/meshanything_train/miche/configs/image_cond_diffuser_asl/image-ASLDM-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
3 | params:
4 | first_stage_config:
5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
6 | params:
7 | shape_module_cfg:
8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
9 | params:
10 | num_latents: &num_latents 256
11 | embed_dim: &embed_dim 64
12 | point_feats: 3 # normal
13 | num_freqs: 8
14 | include_pi: false
15 | heads: 12
16 | width: 768
17 | num_encoder_layers: 8
18 | num_decoder_layers: 16
19 | use_ln_post: true
20 | init_scale: 0.25
21 | qkv_bias: false
22 | use_checkpoint: false
23 | aligned_module_cfg:
24 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
25 | params:
26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
27 |
28 | loss_cfg:
29 | target: torch.nn.Identity
30 |
31 | cond_stage_config:
32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder
33 | params:
34 | version: "./checkpoints/clip/clip-vit-large-patch14"
35 | zero_embedding_radio: 0.1
36 |
37 | first_stage_key: "surface"
38 | cond_stage_key: "image"
39 | scale_by_std: false
40 |
41 | denoiser_cfg:
42 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
43 | params:
44 | input_channels: *embed_dim
45 | output_channels: *embed_dim
46 | n_ctx: *num_latents
47 | width: 768
48 | layers: 6 # 2 * 6 + 1 = 13
49 | heads: 12
50 | context_dim: 1024
51 | init_scale: 1.0
52 | skip_ln: true
53 | use_checkpoint: true
54 |
55 | scheduler_cfg:
56 | guidance_scale: 7.5
57 | num_inference_steps: 50
58 | eta: 0.0
59 |
60 | noise:
61 | target: diffusers.schedulers.DDPMScheduler
62 | params:
63 | num_train_timesteps: 1000
64 | beta_start: 0.00085
65 | beta_end: 0.012
66 | beta_schedule: "scaled_linear"
67 | variance_type: "fixed_small"
68 | clip_sample: false
69 | denoise:
70 | target: diffusers.schedulers.DDIMScheduler
71 | params:
72 | num_train_timesteps: 1000
73 | beta_start: 0.00085
74 | beta_end: 0.012
75 | beta_schedule: "scaled_linear"
76 | clip_sample: false # clip sample to -1~1
77 | set_alpha_to_one: false
78 | steps_offset: 1
79 |
80 | optimizer_cfg:
81 | optimizer:
82 | target: torch.optim.AdamW
83 | params:
84 | betas: [0.9, 0.99]
85 | eps: 1.e-6
86 | weight_decay: 1.e-2
87 |
88 | scheduler:
89 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
90 | params:
91 | warm_up_steps: 5000
92 | f_start: 1.e-6
93 | f_min: 1.e-3
94 | f_max: 1.0
95 |
96 | loss_cfg:
97 | loss_type: "mse"
98 |
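The `&num_latents` / `&embed_dim` markers above are standard YAML anchors: they are resolved when the file is parsed, which keeps the denoiser's channel count and context length tied to the first-stage VAE. A small check, assuming the config path below:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("meshanything_train/miche/configs/image_cond_diffuser_asl/image-ASLDM-256.yaml")

shape_params = cfg.model.params.first_stage_config.params.shape_module_cfg.params
denoiser_params = cfg.model.params.denoiser_cfg.params

# *embed_dim / *num_latents were substituted at parse time.
assert denoiser_params.input_channels == shape_params.embed_dim == 64
assert denoiser_params.n_ctx == shape_params.num_latents == 256
```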
--------------------------------------------------------------------------------
/meshanything_train/miche/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
3 | params:
4 | first_stage_config:
5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
6 | params:
7 | shape_module_cfg:
8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
9 | params:
10 | num_latents: &num_latents 256
11 | embed_dim: &embed_dim 64
12 | point_feats: 3 # normal
13 | num_freqs: 8
14 | include_pi: false
15 | heads: 12
16 | width: 768
17 | num_encoder_layers: 8
18 | num_decoder_layers: 16
19 | use_ln_post: true
20 | init_scale: 0.25
21 | qkv_bias: false
22 | use_checkpoint: true
23 | aligned_module_cfg:
24 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
25 | params:
26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
27 |
28 | loss_cfg:
29 | target: torch.nn.Identity
30 |
31 | cond_stage_config:
32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder
33 | params:
34 | version: "./checkpoints/clip/clip-vit-large-patch14"
35 | zero_embedding_radio: 0.1
36 | max_length: 77
37 |
38 | first_stage_key: "surface"
39 | cond_stage_key: "text"
40 | scale_by_std: false
41 |
42 | denoiser_cfg:
43 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
44 | params:
45 | input_channels: *embed_dim
46 | output_channels: *embed_dim
47 | n_ctx: *num_latents
48 | width: 768
49 |         layers: 8 # 2 * 8 + 1 = 17
50 | heads: 12
51 | context_dim: 768
52 | init_scale: 1.0
53 | skip_ln: true
54 | use_checkpoint: true
55 |
56 | scheduler_cfg:
57 | guidance_scale: 7.5
58 | num_inference_steps: 50
59 | eta: 0.0
60 |
61 | noise:
62 | target: diffusers.schedulers.DDPMScheduler
63 | params:
64 | num_train_timesteps: 1000
65 | beta_start: 0.00085
66 | beta_end: 0.012
67 | beta_schedule: "scaled_linear"
68 | variance_type: "fixed_small"
69 | clip_sample: false
70 | denoise:
71 | target: diffusers.schedulers.DDIMScheduler
72 | params:
73 | num_train_timesteps: 1000
74 | beta_start: 0.00085
75 | beta_end: 0.012
76 | beta_schedule: "scaled_linear"
77 | clip_sample: false # clip sample to -1~1
78 | set_alpha_to_one: false
79 | steps_offset: 1
80 |
81 | optimizer_cfg:
82 | optimizer:
83 | target: torch.optim.AdamW
84 | params:
85 | betas: [0.9, 0.99]
86 | eps: 1.e-6
87 | weight_decay: 1.e-2
88 |
89 | scheduler:
90 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
91 | params:
92 | warm_up_steps: 5000
93 | f_start: 1.e-6
94 | f_min: 1.e-3
95 | f_max: 1.0
96 |
97 | loss_cfg:
98 | loss_type: "mse"
--------------------------------------------------------------------------------
/meshanything_train/miche/encode.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import argparse
4 | # from functools import partial
5 | from omegaconf import OmegaConf, DictConfig, ListConfig
6 | import numpy as np
7 | import torch
8 | # from meshanything_train.models.helper import load_config
9 | from .michelangelo.utils.misc import instantiate_from_config
10 |
11 | def load_surface(fp):
12 |
13 | with np.load(fp) as input_pc:
14 | surface = input_pc['points']
15 | normal = input_pc['normals']
16 |
17 | rng = np.random.default_rng()
18 | ind = rng.choice(surface.shape[0], 4096, replace=False)
19 | surface = torch.FloatTensor(surface[ind])
20 | normal = torch.FloatTensor(normal[ind])
21 |
22 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
23 |
24 | return surface
25 |
26 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
27 |
28 | surface = load_surface(args.pointcloud_path)
29 | # old_surface = surface.clone()
30 |
31 | # surface[0,:,0]*=-1
32 | # surface[0,:,1]*=-1
33 | surface[0,:,2]*=-1
34 |
35 | # encoding
36 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
37 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
38 |
39 | # decoding
40 | latents = model.model.shape_model.decode(shape_zq)
41 | # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
42 |
43 | return 0
44 |
45 | def load_model(ckpt_path="meshanything_train/miche/checkpoints/aligned_shape_latents/shapevae-256.ckpt"):
46 | model_config = OmegaConf.load("meshanything_train/miche/configs/aligned_shape_latents/shapevae-256.yaml")
47 | if hasattr(model_config, "model"):
48 | model_config = model_config.model
49 |
50 | model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
51 |
52 | return model
53 | if __name__ == "__main__":
54 | '''
55 | 1. Reconstruct point cloud
56 | 2. Image-conditioned generation
57 | 3. Text-conditioned generation
58 | '''
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument("--config_path", type=str, required=True)
61 | parser.add_argument("--ckpt_path", type=str, required=True)
62 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
63 | parser.add_argument("--image_path", type=str, help='Path to the input image')
64 | parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
65 | parser.add_argument("--output_dir", type=str, default='./output')
66 | parser.add_argument("-s", "--seed", type=int, default=0)
67 | args = parser.parse_args()
68 |
69 | print(f'-----------------------------------------------------------------------------')
70 | print(f'>>> Output directory: {args.output_dir}')
71 | print(f'-----------------------------------------------------------------------------')
72 |
73 |     reconstruction(args, load_model(args.ckpt_path))
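
`load_surface` expects an `.npz` file with `points` and `normals` arrays of at least 4096 rows; it samples 4096 of them and returns a `[1, 4096, 6]` CUDA tensor. A sketch of producing such a file from one of the repo's example meshes (the output filename is illustrative):

```python
import numpy as np
import trimesh

# Sample a point cloud with per-point normals from a mesh and store it in the
# layout load_surface() expects: keys "points" and "normals", >= 4096 rows each.
mesh = trimesh.load("examples/wand.obj", force="mesh")
points, face_idx = trimesh.sample.sample_surface(mesh, 8192)
normals = mesh.face_normals[face_idx]

np.savez("my_surface.npz",
         points=points.astype(np.float32),
         normals=normals.astype(np.float32))

# surface = load_surface("my_surface.npz")  # -> FloatTensor of shape [1, 4096, 6] on GPU
```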
--------------------------------------------------------------------------------
/meshanything_train/miche/example_data/image/car.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/meshanything_train/miche/example_data/image/car.jpg
--------------------------------------------------------------------------------
/meshanything_train/miche/example_data/surface/surface.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/meshanything_train/miche/example_data/surface/surface.npz
--------------------------------------------------------------------------------
/meshanything_train/miche/inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | from collections import OrderedDict
5 | from typing import Optional, List
6 | import argparse
7 | from functools import partial
8 |
9 | from einops import repeat, rearrange
10 | import numpy as np
11 | from PIL import Image
12 | import trimesh
13 | import cv2
14 |
15 | import torch
16 | import pytorch_lightning as pl
17 |
18 | from meshanything_train.miche.michelangelo.models.tsal.tsal_base import Latent2MeshOutput
19 | from meshanything_train.miche.michelangelo.models.tsal.inference_utils import extract_geometry
20 | from meshanything_train.miche.michelangelo.utils.misc import get_config_from_file, instantiate_from_config
21 | from meshanything_train.miche.michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
22 | from meshanything_train.miche.michelangelo.utils.visualizers import html_util
23 |
24 | def load_model(args):
25 |
26 | model_config = get_config_from_file(args.config_path)
27 | if hasattr(model_config, "model"):
28 | model_config = model_config.model
29 |
30 | model = instantiate_from_config(model_config, ckpt_path=args.ckpt_path)
31 | model = model.cuda()
32 | model = model.eval()
33 |
34 | return model
35 |
36 | def load_surface(fp):
37 |
38 |     with np.load(fp) as input_pc:
39 | surface = input_pc['points']
40 | normal = input_pc['normals']
41 |
42 | rng = np.random.default_rng()
43 | ind = rng.choice(surface.shape[0], 4096, replace=False)
44 | surface = torch.FloatTensor(surface[ind])
45 | normal = torch.FloatTensor(normal[ind])
46 |
47 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
48 |
49 | return surface
50 |
51 | def prepare_image(args, number_samples=2):
52 |
53 | image = cv2.imread(f"{args.image_path}")
54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
55 |
56 | image_pt = torch.tensor(image).float()
57 | image_pt = image_pt / 255 * 2 - 1
58 | image_pt = rearrange(image_pt, "h w c -> c h w")
59 |
60 | image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples)
61 |
62 | return image_pt
63 |
64 | def save_output(args, mesh_outputs):
65 |
66 | os.makedirs(args.output_dir, exist_ok=True)
67 | for i, mesh in enumerate(mesh_outputs):
68 | mesh.mesh_f = mesh.mesh_f[:, ::-1]
69 | mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
70 |
71 | name = str(i) + "_out_mesh.obj"
72 | mesh_output.export(os.path.join(args.output_dir, name), include_normals=True)
73 |
74 | print(f'-----------------------------------------------------------------------------')
75 | print(f'>>> Finished and mesh saved in {args.output_dir}')
76 | print(f'-----------------------------------------------------------------------------')
77 |
78 | return 0
79 |
80 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
81 |
82 | surface = load_surface(args.pointcloud_path)
83 |
84 | # encoding
85 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
86 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
87 |
88 | # decoding
89 | latents = model.model.shape_model.decode(shape_zq)
90 | geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
91 |
92 | # reconstruction
93 | mesh_v_f, has_surface = extract_geometry(
94 | geometric_func=geometric_func,
95 | device=surface.device,
96 | batch_size=surface.shape[0],
97 | bounds=bounds,
98 | octree_depth=octree_depth,
99 | num_chunks=num_chunks,
100 | )
101 | recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])
102 |
103 | # save
104 | os.makedirs(args.output_dir, exist_ok=True)
105 | recon_mesh.export(os.path.join(args.output_dir, 'reconstruction.obj'))
106 |
107 | print(f'-----------------------------------------------------------------------------')
108 | print(f'>>> Finished and mesh saved in {os.path.join(args.output_dir, "reconstruction.obj")}')
109 | print(f'-----------------------------------------------------------------------------')
110 |
111 | return 0
112 |
113 | def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7):
114 |
115 | sample_inputs = {
116 | "image": prepare_image(args)
117 | }
118 |
119 | mesh_outputs = model.sample(
120 | sample_inputs,
121 | sample_times=1,
122 | guidance_scale=guidance_scale,
123 | return_intermediates=False,
124 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
125 | octree_depth=octree_depth,
126 | )[0]
127 |
128 | save_output(args, mesh_outputs)
129 |
130 | return 0
131 |
132 | def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
133 |
134 | sample_inputs = {
135 | "text": [args.text] * num_samples
136 | }
137 | mesh_outputs = model.sample(
138 | sample_inputs,
139 | sample_times=1,
140 | guidance_scale=guidance_scale,
141 | return_intermediates=False,
142 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
143 | octree_depth=octree_depth,
144 | )[0]
145 |
146 | save_output(args, mesh_outputs)
147 |
148 | return 0
149 |
150 | task_dict = {
151 | 'reconstruction': reconstruction,
152 | 'image2mesh': image2mesh,
153 | 'text2mesh': text2mesh,
154 | }
155 |
156 | if __name__ == "__main__":
157 | '''
158 | 1. Reconstruct point cloud
159 | 2. Image-conditioned generation
160 | 3. Text-conditioned generation
161 | '''
162 | parser = argparse.ArgumentParser()
163 | parser.add_argument("--task", type=str, choices=['reconstruction', 'image2mesh', 'text2mesh'], required=True)
164 | parser.add_argument("--config_path", type=str, required=True)
165 | parser.add_argument("--ckpt_path", type=str, required=True)
166 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
167 | parser.add_argument("--image_path", type=str, help='Path to the input image')
168 | parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
169 | parser.add_argument("--output_dir", type=str, default='./output')
170 | parser.add_argument("-s", "--seed", type=int, default=0)
171 | args = parser.parse_args()
172 |
173 | pl.seed_everything(args.seed)
174 |
175 | print(f'-----------------------------------------------------------------------------')
176 | print(f'>>> Running {args.task}')
177 | args.output_dir = os.path.join(args.output_dir, args.task)
178 | print(f'>>> Output directory: {args.output_dir}')
179 | print(f'-----------------------------------------------------------------------------')
180 |
181 |     task_dict[args.task](args, load_model(args))
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/data/templates.json:
--------------------------------------------------------------------------------
1 | {
2 | "shape": [
3 | "a point cloud model of {}.",
4 | "There is a {} in the scene.",
5 | "There is the {} in the scene.",
6 | "a photo of a {} in the scene.",
7 | "a photo of the {} in the scene.",
8 | "a photo of one {} in the scene.",
9 | "itap of a {}.",
10 | "itap of my {}.",
11 | "itap of the {}.",
12 | "a photo of a {}.",
13 | "a photo of my {}.",
14 | "a photo of the {}.",
15 | "a photo of one {}.",
16 | "a photo of many {}.",
17 | "a good photo of a {}.",
18 | "a good photo of the {}.",
19 | "a bad photo of a {}.",
20 | "a bad photo of the {}.",
21 | "a photo of a nice {}.",
22 | "a photo of the nice {}.",
23 | "a photo of a cool {}.",
24 | "a photo of the cool {}.",
25 | "a photo of a weird {}.",
26 | "a photo of the weird {}.",
27 | "a photo of a small {}.",
28 | "a photo of the small {}.",
29 | "a photo of a large {}.",
30 | "a photo of the large {}.",
31 | "a photo of a clean {}.",
32 | "a photo of the clean {}.",
33 | "a photo of a dirty {}.",
34 | "a photo of the dirty {}.",
35 | "a bright photo of a {}.",
36 | "a bright photo of the {}.",
37 | "a dark photo of a {}.",
38 | "a dark photo of the {}.",
39 | "a photo of a hard to see {}.",
40 | "a photo of the hard to see {}.",
41 | "a low resolution photo of a {}.",
42 | "a low resolution photo of the {}.",
43 | "a cropped photo of a {}.",
44 | "a cropped photo of the {}.",
45 | "a close-up photo of a {}.",
46 | "a close-up photo of the {}.",
47 | "a jpeg corrupted photo of a {}.",
48 | "a jpeg corrupted photo of the {}.",
49 | "a blurry photo of a {}.",
50 | "a blurry photo of the {}.",
51 | "a pixelated photo of a {}.",
52 | "a pixelated photo of the {}.",
53 | "a black and white photo of the {}.",
54 | "a black and white photo of a {}",
55 | "a plastic {}.",
56 | "the plastic {}.",
57 | "a toy {}.",
58 | "the toy {}.",
59 | "a plushie {}.",
60 | "the plushie {}.",
61 | "a cartoon {}.",
62 | "the cartoon {}.",
63 | "an embroidered {}.",
64 | "the embroidered {}.",
65 | "a painting of the {}.",
66 | "a painting of a {}."
67 | ]
68 |
69 | }
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/data/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def worker_init_fn(_):
8 | worker_info = torch.utils.data.get_worker_info()
9 | worker_id = worker_info.id
10 |
11 | # dataset = worker_info.dataset
12 | # split_size = dataset.num_records // worker_info.num_workers
13 | # # reset num_records to the true number to retain reliable length information
14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
17 |
18 | return np.random.seed(np.random.get_state()[1][0] + worker_id)
19 |
20 |
21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True):
22 |     """Collate a list of per-sample dicts (sharing the same keys) into one batch dict.
23 | 
24 |     Args:
25 |         samples (list[dict]): per-sample dictionaries with identical keys.
26 |         combine_tensors (bool): if True, stack torch.Tensor / np.ndarray values along a new batch dim.
27 |         combine_scalars (bool): if True, convert lists of int/float values into an np.ndarray.
28 | 
29 |     Returns:
30 |         dict: each key mapped to its batched value (or the original list if not combined).
31 |     """
32 |
33 | result = {}
34 |
35 | keys = samples[0].keys()
36 |
37 | for key in keys:
38 | result[key] = []
39 |
40 | for sample in samples:
41 | for key in keys:
42 | val = sample[key]
43 | result[key].append(val)
44 |
45 | for key in keys:
46 | val_list = result[key]
47 | if isinstance(val_list[0], (int, float)):
48 | if combine_scalars:
49 | result[key] = np.array(result[key])
50 |
51 | elif isinstance(val_list[0], torch.Tensor):
52 | if combine_tensors:
53 | result[key] = torch.stack(val_list)
54 |
55 | elif isinstance(val_list[0], np.ndarray):
56 | if combine_tensors:
57 | result[key] = np.stack(val_list)
58 |
59 | return result
60 |
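A short sketch of what `collation_fn` produces for a toy batch (the key names are made up for illustration):

```python
import numpy as np
import torch

from meshanything_train.miche.michelangelo.data.utils import collation_fn

samples = [
    {"surface": torch.randn(4096, 6), "label": 1, "name": "sample_a"},
    {"surface": torch.randn(4096, 6), "label": 0, "name": "sample_b"},
]
batch = collation_fn(samples)

# Tensors are stacked, scalars become a numpy array, anything else stays a list.
assert batch["surface"].shape == (2, 4096, 6)
assert isinstance(batch["label"], np.ndarray) and batch["label"].tolist() == [1, 0]
assert batch["name"] == ["sample_a", "sample_b"]
```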
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/graphics/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/graphics/primitives/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .volume import generate_dense_grid_points
4 |
5 | from .mesh import (
6 | MeshOutput,
7 | save_obj,
8 | savemeshtes2
9 | )
10 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/graphics/primitives/mesh.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import cv2
5 | import numpy as np
6 | import PIL.Image
7 | from typing import Optional
8 |
9 | import trimesh
10 |
11 |
12 | def save_obj(pointnp_px3, facenp_fx3, fname):
13 | fid = open(fname, "w")
14 | write_str = ""
15 | for pidx, p in enumerate(pointnp_px3):
16 | pp = p
17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
18 |
19 | for i, f in enumerate(facenp_fx3):
20 | f1 = f + 1
21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
22 | fid.write(write_str)
23 | fid.close()
24 | return
25 |
26 |
27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
28 | fol, na = os.path.split(fname)
29 | na, _ = os.path.splitext(na)
30 |
31 | matname = "%s/%s.mtl" % (fol, na)
32 | fid = open(matname, "w")
33 | fid.write("newmtl material_0\n")
34 | fid.write("Kd 1 1 1\n")
35 | fid.write("Ka 0 0 0\n")
36 | fid.write("Ks 0.4 0.4 0.4\n")
37 | fid.write("Ns 10\n")
38 | fid.write("illum 2\n")
39 | fid.write("map_Kd %s.png\n" % na)
40 | fid.close()
41 | ####
42 |
43 | fid = open(fname, "w")
44 | fid.write("mtllib %s.mtl\n" % na)
45 |
46 | for pidx, p in enumerate(pointnp_px3):
47 | pp = p
48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
49 |
50 | for pidx, p in enumerate(tcoords_px2):
51 | pp = p
52 | fid.write("vt %f %f\n" % (pp[0], pp[1]))
53 |
54 | fid.write("usemtl material_0\n")
55 | for i, f in enumerate(facenp_fx3):
56 | f1 = f + 1
57 | f2 = facetex_fx3[i] + 1
58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
59 | fid.close()
60 |
61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
62 | os.path.join(fol, "%s.png" % na))
63 |
64 | return
65 |
66 |
67 | class MeshOutput(object):
68 |
69 | def __init__(self,
70 | mesh_v: np.ndarray,
71 | mesh_f: np.ndarray,
72 | vertex_colors: Optional[np.ndarray] = None,
73 | uvs: Optional[np.ndarray] = None,
74 | mesh_tex_idx: Optional[np.ndarray] = None,
75 | tex_map: Optional[np.ndarray] = None):
76 |
77 | self.mesh_v = mesh_v
78 | self.mesh_f = mesh_f
79 | self.vertex_colors = vertex_colors
80 | self.uvs = uvs
81 | self.mesh_tex_idx = mesh_tex_idx
82 | self.tex_map = tex_map
83 |
84 | def contain_uv_texture(self):
85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
86 |
87 | def contain_vertex_colors(self):
88 | return self.vertex_colors is not None
89 |
90 | def export(self, fname):
91 |
92 | if self.contain_uv_texture():
93 | savemeshtes2(
94 | self.mesh_v,
95 | self.uvs,
96 | self.mesh_f,
97 | self.mesh_tex_idx,
98 | self.tex_map,
99 | fname
100 | )
101 |
102 | elif self.contain_vertex_colors():
103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
104 | mesh_obj.export(fname)
105 |
106 | else:
107 | save_obj(
108 | self.mesh_v,
109 | self.mesh_f,
110 | fname
111 | )
112 |
113 |
114 |
115 |
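`MeshOutput.export` dispatches on what the mesh carries: a textured OBJ via `savemeshtes2`, a vertex-colored export via trimesh, or a plain OBJ via `save_obj`. A minimal sketch of the plain-OBJ path (the output filename is illustrative):

```python
import numpy as np

from meshanything_train.miche.michelangelo.graphics.primitives import MeshOutput

# A single triangle with no UVs or vertex colors, so export() falls back to save_obj().
verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])

out = MeshOutput(mesh_v=verts, mesh_f=faces)
assert not out.contain_uv_texture() and not out.contain_vertex_colors()
out.export("triangle.obj")  # writes "v x y z" lines and 1-indexed "f a b c" lines
```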
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/graphics/primitives/volume.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 |
5 |
6 | def generate_dense_grid_points(bbox_min: np.ndarray,
7 | bbox_max: np.ndarray,
8 | octree_depth: int,
9 | indexing: str = "ij"):
10 | length = bbox_max - bbox_min
11 | num_cells = np.exp2(octree_depth)
12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
16 | xyz = np.stack((xs, ys, zs), axis=-1)
17 | xyz = xyz.reshape(-1, 3)
18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
19 |
20 | return xyz, grid_size, length
21 |
22 |
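`generate_dense_grid_points` builds a lattice of `2^octree_depth + 1` samples per axis over a bounding box, which is presumably the set of query locations used during mesh extraction. A quick sketch with the default reconstruction bounds:

```python
import numpy as np

from meshanything_train.miche.michelangelo.graphics.primitives import generate_dense_grid_points

bbox_min = np.array([-1.25, -1.25, -1.25])
bbox_max = np.array([1.25, 1.25, 1.25])

# octree_depth=7 -> 2**7 = 128 cells per axis, i.e. 129 samples per axis and 129**3 points total.
xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)
assert grid_size == [129, 129, 129]
assert xyz.shape == (129 ** 3, 3)
assert np.allclose(length, bbox_max - bbox_min)
```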
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/asl_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/asl_diffusion/asl_udt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | from typing import Optional
6 | from diffusers.models.embeddings import Timesteps
7 | import math
8 |
9 | from meshanything_train.miche.michelangelo.models.modules.transformer_blocks import MLP
10 | from meshanything_train.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer
11 |
12 |
13 | class ConditionalASLUDTDenoiser(nn.Module):
14 |
15 | def __init__(self, *,
16 | device: Optional[torch.device],
17 | dtype: Optional[torch.dtype],
18 | input_channels: int,
19 | output_channels: int,
20 | n_ctx: int,
21 | width: int,
22 | layers: int,
23 | heads: int,
24 | context_dim: int,
25 | context_ln: bool = True,
26 | skip_ln: bool = False,
27 | init_scale: float = 0.25,
28 | flip_sin_to_cos: bool = False,
29 | use_checkpoint: bool = False):
30 | super().__init__()
31 |
32 | self.use_checkpoint = use_checkpoint
33 |
34 | init_scale = init_scale * math.sqrt(1.0 / width)
35 |
36 | self.backbone = UNetDiffusionTransformer(
37 | device=device,
38 | dtype=dtype,
39 | n_ctx=n_ctx,
40 | width=width,
41 | layers=layers,
42 | heads=heads,
43 | skip_ln=skip_ln,
44 | init_scale=init_scale,
45 | use_checkpoint=use_checkpoint
46 | )
47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
50 |
51 | # timestep embedding
52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
53 | self.time_proj = MLP(
54 | device=device, dtype=dtype, width=width, init_scale=init_scale
55 | )
56 |
57 | self.context_embed = nn.Sequential(
58 | nn.LayerNorm(context_dim, device=device, dtype=dtype),
59 | nn.Linear(context_dim, width, device=device, dtype=dtype),
60 | )
61 |
62 | if context_ln:
63 | self.context_embed = nn.Sequential(
64 | nn.LayerNorm(context_dim, device=device, dtype=dtype),
65 | nn.Linear(context_dim, width, device=device, dtype=dtype),
66 | )
67 | else:
68 | self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
69 |
70 | def forward(self,
71 | model_input: torch.FloatTensor,
72 | timestep: torch.LongTensor,
73 | context: torch.FloatTensor):
74 |
75 | r"""
76 | Args:
77 | model_input (torch.FloatTensor): [bs, n_data, c]
78 | timestep (torch.LongTensor): [bs,]
79 | context (torch.FloatTensor): [bs, context_tokens, c]
80 |
81 | Returns:
82 | sample (torch.FloatTensor): [bs, n_data, c]
83 |
84 | """
85 |
86 | _, n_data, _ = model_input.shape
87 |
88 | # 1. time
89 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
90 |
91 | # 2. conditions projector
92 | context = self.context_embed(context)
93 |
94 | # 3. denoiser
95 | x = self.input_proj(model_input)
96 | x = torch.cat([t_emb, context, x], dim=1)
97 | x = self.backbone(x)
98 | x = self.ln_post(x)
99 | x = x[:, -n_data:]
100 | sample = self.output_proj(x)
101 |
102 | return sample
103 |
104 |
105 |
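`forward` prepends one time token and the projected context tokens to the projected latents, runs the U-Net transformer, and keeps only the last `n_data` tokens, so the output matches the input latents in shape. A CPU shape-check sketch with the same hyper-parameters as the image config above (batch size and context token count are illustrative):

```python
import torch

from meshanything_train.miche.michelangelo.models.asl_diffusion.asl_udt import ConditionalASLUDTDenoiser

denoiser = ConditionalASLUDTDenoiser(
    device=None, dtype=None,
    input_channels=64, output_channels=64,
    n_ctx=256, width=768, layers=6, heads=12, context_dim=1024,
)

x = torch.randn(2, 256, 64)       # noisy latents        [bs, n_data, c]
t = torch.randint(0, 1000, (2,))  # diffusion timesteps  [bs]
ctx = torch.randn(2, 257, 1024)   # conditioning tokens  [bs, context_tokens, context_dim]

eps = denoiser(x, t, ctx)
assert eps.shape == x.shape       # only the last n_data tokens are projected back out
```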
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/asl_diffusion/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class BaseDenoiser(nn.Module):
8 |
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x, t, context):
13 | raise NotImplementedError
14 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/asl_diffusion/inference_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from typing import Tuple, List, Union, Optional
6 | from diffusers.schedulers import DDIMScheduler
7 |
8 |
9 | __all__ = ["ddim_sample"]
10 |
11 |
12 | def ddim_sample(ddim_scheduler: DDIMScheduler,
13 | diffusion_model: torch.nn.Module,
14 | shape: Union[List[int], Tuple[int]],
15 | cond: torch.FloatTensor,
16 | steps: int,
17 | eta: float = 0.0,
18 | guidance_scale: float = 3.0,
19 | do_classifier_free_guidance: bool = True,
20 | generator: Optional[torch.Generator] = None,
21 | device: torch.device = "cuda:0",
22 | disable_prog: bool = True):
23 |
24 |     assert steps > 0, f"steps must be > 0, got {steps}."
25 |
26 | # init latents
27 | bsz = cond.shape[0]
28 | if do_classifier_free_guidance:
29 | bsz = bsz // 2
30 |
31 | latents = torch.randn(
32 | (bsz, *shape),
33 | generator=generator,
34 | device=cond.device,
35 | dtype=cond.dtype,
36 | )
37 | # scale the initial noise by the standard deviation required by the scheduler
38 | latents = latents * ddim_scheduler.init_noise_sigma
39 | # set timesteps
40 | ddim_scheduler.set_timesteps(steps)
41 | timesteps = ddim_scheduler.timesteps.to(device)
42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1]
44 | extra_step_kwargs = {
45 | "eta": eta,
46 | "generator": generator
47 | }
48 |
49 | # reverse
50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
51 | # expand the latents if we are doing classifier free guidance
52 | latent_model_input = (
53 | torch.cat([latents] * 2)
54 | if do_classifier_free_guidance
55 | else latents
56 | )
57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
58 | # predict the noise residual
59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
62 |
63 | # perform guidance
64 | if do_classifier_free_guidance:
65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
66 | noise_pred = noise_pred_uncond + guidance_scale * (
67 | noise_pred_text - noise_pred_uncond
68 | )
69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk(
70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states
71 | # compute the previous noisy sample x_t -> x_t-1
72 | latents = ddim_scheduler.step(
73 | noise_pred, t, latents, **extra_step_kwargs
74 | ).prev_sample
75 |
76 | yield latents, t
77 |
78 |
79 | def karra_sample():
80 | pass
81 |
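`ddim_sample` is a generator: it yields `(latents, t)` after every denoising step, and with classifier-free guidance enabled it expects `cond` to already stack the unconditional and conditional rows, so a `cond` batch of 2 yields 1 latent. A CPU sketch driven by a dummy denoiser (a stand-in for illustration, not the repo's model):

```python
import torch
from diffusers.schedulers import DDIMScheduler

from meshanything_train.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample

class DummyDenoiser(torch.nn.Module):
    """Stand-in for ConditionalASLUDTDenoiser: returns a noise prediction of the input's shape."""
    def forward(self, x, t, context):
        return torch.zeros_like(x)

scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear",
                          beta_start=0.00085, beta_end=0.012, clip_sample=False,
                          set_alpha_to_one=False, steps_offset=1)

# cond stacks [uncond; cond] rows when classifier-free guidance is on: batch of 2 -> 1 latent.
cond = torch.randn(2, 257, 1024)
sampler = ddim_sample(scheduler, DummyDenoiser(), shape=(256, 64), cond=cond,
                      steps=50, guidance_scale=7.5, device="cpu")

for latents, t in sampler:
    pass                            # keep only the final latents
assert latents.shape == (1, 256, 64)
```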
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/conditional_encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .clip import CLIPEncoder
4 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/conditional_encoders/clip.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from dataclasses import dataclass
7 | from torchvision.transforms import Normalize
8 | from transformers import CLIPModel, CLIPTokenizer
9 | from transformers.utils import ModelOutput
10 | from typing import Iterable, Optional, Union, List
11 |
12 |
13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
14 |
15 |
16 | @dataclass
17 | class CLIPEmbedOutput(ModelOutput):
18 | last_hidden_state: torch.FloatTensor = None
19 | pooler_output: torch.FloatTensor = None
20 | embeds: torch.FloatTensor = None
21 |
22 |
23 | class CLIPEncoder(torch.nn.Module):
24 |
25 | def __init__(self, model_path="openai/clip-vit-base-patch32"):
26 |
27 | super().__init__()
28 |
29 | # Load the CLIP model and processor
30 | self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
31 | self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
32 | self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33 |
34 |         self.model.eval()
35 | for p in self.model.parameters():
36 | p.requires_grad = False
37 |
38 | @torch.no_grad()
39 | def encode_image(self, images: Iterable[Optional[ImageType]]):
40 | pixel_values = self.image_preprocess(images)
41 |
42 | vision_outputs = self.model.vision_model(pixel_values=pixel_values)
43 |
44 | pooler_output = vision_outputs[1] # pooled_output
45 | image_features = self.model.visual_projection(pooler_output)
46 |
47 | visual_embeds = CLIPEmbedOutput(
48 | last_hidden_state=vision_outputs.last_hidden_state,
49 | pooler_output=pooler_output,
50 | embeds=image_features
51 | )
52 |
53 | return visual_embeds
54 |
55 | @torch.no_grad()
56 | def encode_text(self, texts: List[str]):
57 | text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
58 |
59 |         text_outputs = self.model.text_model(**text_inputs)
60 |
61 | pooler_output = text_outputs[1] # pooled_output
62 | text_features = self.model.text_projection(pooler_output)
63 |
64 | text_embeds = CLIPEmbedOutput(
65 | last_hidden_state=text_outputs.last_hidden_state,
66 | pooler_output=pooler_output,
67 | embeds=text_features
68 | )
69 |
70 | return text_embeds
71 |
72 | def forward(self,
73 | images: Iterable[Optional[ImageType]],
74 | texts: List[str]):
75 |
76 | visual_embeds = self.encode_image(images)
77 | text_embeds = self.encode_text(texts)
78 |
79 | return visual_embeds, text_embeds
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .checkpoint import checkpoint
4 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/modules/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4 | """
5 |
6 | import torch
7 | from typing import Callable, Iterable, Sequence, Union
8 |
9 |
10 | def checkpoint(
11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12 | inputs: Sequence[torch.Tensor],
13 | params: Iterable[torch.Tensor],
14 | flag: bool,
15 | use_deepspeed: bool = False
16 | ):
17 | """
18 | Evaluate a function without caching intermediate activations, allowing for
19 | reduced memory at the expense of extra compute in the backward pass.
20 | :param func: the function to evaluate.
21 | :param inputs: the argument sequence to pass to `func`.
22 | :param params: a sequence of parameters `func` depends on but does not
23 | explicitly take as arguments.
24 | :param flag: if False, disable gradient checkpointing.
25 | :param use_deepspeed: if True, use deepspeed
26 | """
27 | if flag:
28 | if use_deepspeed:
29 | import deepspeed
30 | return deepspeed.checkpointing.checkpoint(func, *inputs)
31 |
32 | args = tuple(inputs) + tuple(params)
33 | return CheckpointFunction.apply(func, len(inputs), *args)
34 | else:
35 | return func(*inputs)
36 |
37 |
38 | class CheckpointFunction(torch.autograd.Function):
39 | @staticmethod
40 | @torch.cuda.amp.custom_fwd
41 | def forward(ctx, run_function, length, *args):
42 | ctx.run_function = run_function
43 | ctx.input_tensors = list(args[:length])
44 | ctx.input_params = list(args[length:])
45 |
46 | with torch.no_grad():
47 | output_tensors = ctx.run_function(*ctx.input_tensors)
48 | return output_tensors
49 |
50 | @staticmethod
51 | @torch.cuda.amp.custom_bwd
52 | def backward(ctx, *output_grads):
53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
54 | with torch.enable_grad():
55 | # Fixes a bug where the first op in run_function modifies the
56 | # Tensor storage in place, which is not allowed for detach()'d
57 | # Tensors.
58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
59 | output_tensors = ctx.run_function(*shallow_copies)
60 | input_grads = torch.autograd.grad(
61 | output_tensors,
62 | ctx.input_tensors + ctx.input_params,
63 | output_grads,
64 | allow_unused=True,
65 | )
66 | del ctx.input_tensors
67 | del ctx.input_params
68 | del output_tensors
69 | return (None, None) + input_grads
70 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/modules/diffusion_transformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from typing import Optional
7 |
8 | from meshanything_train.miche.michelangelo.models.modules.checkpoint import checkpoint
9 | from meshanything_train.miche.michelangelo.models.modules.transformer_blocks import (
10 | init_linear,
11 | MLP,
12 | MultiheadCrossAttention,
13 | MultiheadAttention,
14 | ResidualAttentionBlock
15 | )
16 |
17 |
18 | class AdaLayerNorm(nn.Module):
19 | def __init__(self,
20 | device: torch.device,
21 | dtype: torch.dtype,
22 | width: int):
23 |
24 | super().__init__()
25 |
26 | self.silu = nn.SiLU(inplace=True)
27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
29 |
30 | def forward(self, x, timestep):
31 | emb = self.linear(timestep)
32 | scale, shift = torch.chunk(emb, 2, dim=2)
33 | x = self.layernorm(x) * (1 + scale) + shift
34 | return x
35 |
36 |
37 | class DitBlock(nn.Module):
38 | def __init__(
39 | self,
40 | *,
41 | device: torch.device,
42 | dtype: torch.dtype,
43 | n_ctx: int,
44 | width: int,
45 | heads: int,
46 | context_dim: int,
47 | qkv_bias: bool = False,
48 | init_scale: float = 1.0,
49 | use_checkpoint: bool = False
50 | ):
51 | super().__init__()
52 |
53 | self.use_checkpoint = use_checkpoint
54 |
55 | self.attn = MultiheadAttention(
56 | device=device,
57 | dtype=dtype,
58 | n_ctx=n_ctx,
59 | width=width,
60 | heads=heads,
61 | init_scale=init_scale,
62 | qkv_bias=qkv_bias
63 | )
64 | self.ln_1 = AdaLayerNorm(device, dtype, width)
65 |
66 | if context_dim is not None:
67 | self.ln_2 = AdaLayerNorm(device, dtype, width)
68 | self.cross_attn = MultiheadCrossAttention(
69 | device=device,
70 | dtype=dtype,
71 | width=width,
72 | heads=heads,
73 | data_width=context_dim,
74 | init_scale=init_scale,
75 | qkv_bias=qkv_bias
76 | )
77 |
78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
79 | self.ln_3 = AdaLayerNorm(device, dtype, width)
80 |
81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
83 |
84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
85 | x = x + self.attn(self.ln_1(x, t))
86 | if context is not None:
87 | x = x + self.cross_attn(self.ln_2(x, t), context)
88 | x = x + self.mlp(self.ln_3(x, t))
89 | return x
90 |
91 |
92 | class DiT(nn.Module):
93 | def __init__(
94 | self,
95 | *,
96 | device: Optional[torch.device],
97 | dtype: Optional[torch.dtype],
98 | n_ctx: int,
99 | width: int,
100 | layers: int,
101 | heads: int,
102 | context_dim: int,
103 | init_scale: float = 0.25,
104 | qkv_bias: bool = False,
105 | use_checkpoint: bool = False
106 | ):
107 | super().__init__()
108 | self.n_ctx = n_ctx
109 | self.width = width
110 | self.layers = layers
111 |
112 | self.resblocks = nn.ModuleList(
113 | [
114 | DitBlock(
115 | device=device,
116 | dtype=dtype,
117 | n_ctx=n_ctx,
118 | width=width,
119 | heads=heads,
120 | context_dim=context_dim,
121 | qkv_bias=qkv_bias,
122 | init_scale=init_scale,
123 | use_checkpoint=use_checkpoint
124 | )
125 | for _ in range(layers)
126 | ]
127 | )
128 |
129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
130 | for block in self.resblocks:
131 | x = block(x, t, context)
132 | return x
133 |
134 |
135 | class UNetDiffusionTransformer(nn.Module):
136 | def __init__(
137 | self,
138 | *,
139 | device: Optional[torch.device],
140 | dtype: Optional[torch.dtype],
141 | n_ctx: int,
142 | width: int,
143 | layers: int,
144 | heads: int,
145 | init_scale: float = 0.25,
146 | qkv_bias: bool = False,
147 | skip_ln: bool = False,
148 | use_checkpoint: bool = False
149 | ):
150 | super().__init__()
151 |
152 | self.n_ctx = n_ctx
153 | self.width = width
154 | self.layers = layers
155 |
156 | self.encoder = nn.ModuleList()
157 | for _ in range(layers):
158 | resblock = ResidualAttentionBlock(
159 | device=device,
160 | dtype=dtype,
161 | n_ctx=n_ctx,
162 | width=width,
163 | heads=heads,
164 | init_scale=init_scale,
165 | qkv_bias=qkv_bias,
166 | use_checkpoint=use_checkpoint
167 | )
168 | self.encoder.append(resblock)
169 |
170 | self.middle_block = ResidualAttentionBlock(
171 | device=device,
172 | dtype=dtype,
173 | n_ctx=n_ctx,
174 | width=width,
175 | heads=heads,
176 | init_scale=init_scale,
177 | qkv_bias=qkv_bias,
178 | use_checkpoint=use_checkpoint
179 | )
180 |
181 | self.decoder = nn.ModuleList()
182 | for _ in range(layers):
183 | resblock = ResidualAttentionBlock(
184 | device=device,
185 | dtype=dtype,
186 | n_ctx=n_ctx,
187 | width=width,
188 | heads=heads,
189 | init_scale=init_scale,
190 | qkv_bias=qkv_bias,
191 | use_checkpoint=use_checkpoint
192 | )
193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
194 | init_linear(linear, init_scale)
195 |
196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
197 |
198 | self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
199 |
200 | def forward(self, x: torch.Tensor):
201 |
202 | enc_outputs = []
203 | for block in self.encoder:
204 | x = block(x)
205 | enc_outputs.append(x)
206 |
207 | x = self.middle_block(x)
208 |
209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
210 | x = torch.cat([enc_outputs.pop(), x], dim=-1)
211 | x = linear(x)
212 |
213 | if layer_norm is not None:
214 | x = layer_norm(x)
215 |
216 | x = resblock(x)
217 |
218 | return x
219 |
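`UNetDiffusionTransformer` stacks `layers` encoder blocks, one middle block, and `layers` decoder blocks; each decoder block first concatenates the matching encoder activation and projects the doubled width back down (optionally LayerNorm-ed when `skip_ln` is set). A CPU shape-check sketch (the token count is illustrative):

```python
import torch

from meshanything_train.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer

backbone = UNetDiffusionTransformer(
    device=None, dtype=None, n_ctx=514, width=768, layers=6, heads=12, skip_ln=True,
)

tokens = torch.randn(2, 514, 768)  # e.g. 1 time token + 257 context tokens + 256 latents
out = backbone(tokens)
assert out.shape == tokens.shape   # skip-merge linears restore the width at every level
```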
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/modules/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Union, List
4 |
5 |
6 | class AbstractDistribution(object):
7 | def sample(self):
8 | raise NotImplementedError()
9 |
10 | def mode(self):
11 | raise NotImplementedError()
12 |
13 |
14 | class DiracDistribution(AbstractDistribution):
15 | def __init__(self, value):
16 | self.value = value
17 |
18 | def sample(self):
19 | return self.value
20 |
21 | def mode(self):
22 | return self.value
23 |
24 |
25 | class DiagonalGaussianDistribution(object):
26 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
27 | self.feat_dim = feat_dim
28 | self.parameters = parameters
29 |
30 | if isinstance(parameters, list):
31 | self.mean = parameters[0]
32 | self.logvar = parameters[1]
33 | else:
34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
35 |
36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
37 | self.deterministic = deterministic
38 | self.std = torch.exp(0.5 * self.logvar)
39 | self.var = torch.exp(self.logvar)
40 | if self.deterministic:
41 | self.var = self.std = torch.zeros_like(self.mean)
42 |
43 | def sample(self):
44 | x = self.mean + self.std * torch.randn_like(self.mean)
45 | return x
46 |
47 | def kl(self, other=None, dims=(1, 2, 3)):
48 | if self.deterministic:
49 | return torch.Tensor([0.])
50 | else:
51 | if other is None:
52 | return 0.5 * torch.mean(torch.pow(self.mean, 2)
53 | + self.var - 1.0 - self.logvar,
54 | dim=dims)
55 | else:
56 | return 0.5 * torch.mean(
57 | torch.pow(self.mean - other.mean, 2) / other.var
58 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
59 | dim=dims)
60 |
61 | def nll(self, sample, dims=(1, 2, 3)):
62 | if self.deterministic:
63 | return torch.Tensor([0.])
64 | logtwopi = np.log(2.0 * np.pi)
65 | return 0.5 * torch.sum(
66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
67 | dim=dims)
68 |
69 | def mode(self):
70 | return self.mean
71 |
72 |
73 | def normal_kl(mean1, logvar1, mean2, logvar2):
74 | """
75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
76 | Compute the KL divergence between two gaussians.
77 | Shapes are automatically broadcasted, so batches can be compared to
78 | scalars, among other use cases.
79 | """
80 | tensor = None
81 | for obj in (mean1, logvar1, mean2, logvar2):
82 | if isinstance(obj, torch.Tensor):
83 | tensor = obj
84 | break
85 | assert tensor is not None, "at least one argument must be a Tensor"
86 |
87 | # Force variances to be Tensors. Broadcasting helps convert scalars to
88 | # Tensors, but it does not work for torch.exp().
89 | logvar1, logvar2 = [
90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
91 | for x in (logvar1, logvar2)
92 | ]
93 |
94 | return 0.5 * (
95 | -1.0
96 | + logvar2
97 | - logvar1
98 | + torch.exp(logvar1 - logvar2)
99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
100 | )
101 |
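`DiagonalGaussianDistribution` splits an encoder output into mean and log-variance along `feat_dim` and exposes reparameterised sampling plus a KL term (against N(0, I) when `other` is omitted). A small sketch with shapes chosen to echo the 256-latent / 64-dim setup above:

```python
import torch

from meshanything_train.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution

# Mean and logvar stacked along the last dim: 256 latents x (64 + 64) channels.
params = torch.randn(4, 256, 128)
posterior = DiagonalGaussianDistribution(params, feat_dim=-1)

z = posterior.sample()          # reparameterised sample, shape [4, 256, 64]
kl = posterior.kl(dims=(1, 2))  # mean KL to N(0, I), one value per batch element
assert z.shape == (4, 256, 64) and kl.shape == (4,)
assert torch.allclose(posterior.mode(), posterior.mean)
```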
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/modules/embedder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import math
7 |
8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
9 |
10 |
11 | class FourierEmbedder(nn.Module):
12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
13 | each feature dimension of `x[..., i]` into:
14 | [
15 | sin(x[..., i]),
16 | sin(f_1*x[..., i]),
17 | sin(f_2*x[..., i]),
18 | ...
19 | sin(f_N * x[..., i]),
20 | cos(x[..., i]),
21 | cos(f_1*x[..., i]),
22 | cos(f_2*x[..., i]),
23 | ...
24 | cos(f_N * x[..., i]),
25 | x[..., i] # only present if include_input is True.
26 | ], here f_i is the frequency.
27 |
28 |     If logspace is True, the frequencies are f_i = 2^i for i = 0, ..., num_freqs - 1;
29 |     otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
31 |
32 | Args:
33 | num_freqs (int): the number of frequencies, default is 6;
34 | logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
35 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
36 | input_dim (int): the input dimension, default is 3;
37 | include_input (bool): include the input tensor or not, default is True.
38 |
39 | Attributes:
40 |         frequencies (torch.Tensor): if logspace is True, the frequencies are f_i = 2^i,
41 |             otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1);
42 |
43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
44 | otherwise, it is input_dim * num_freqs * 2.
45 |
46 | """
47 |
48 | def __init__(self,
49 | num_freqs: int = 6,
50 | logspace: bool = True,
51 | input_dim: int = 3,
52 | include_input: bool = True,
53 | include_pi: bool = True) -> None:
54 |
55 | """The initialization"""
56 |
57 | super().__init__()
58 |
59 | if logspace:
60 | frequencies = 2.0 ** torch.arange(
61 | num_freqs,
62 | dtype=torch.float32
63 | )
64 | else:
65 | frequencies = torch.linspace(
66 | 1.0,
67 | 2.0 ** (num_freqs - 1),
68 | num_freqs,
69 | dtype=torch.float32
70 | )
71 |
72 | if include_pi:
73 | frequencies *= torch.pi
74 |
75 | self.register_buffer("frequencies", frequencies, persistent=False)
76 | self.include_input = include_input
77 | self.num_freqs = num_freqs
78 |
79 | self.out_dim = self.get_dims(input_dim)
80 |
81 | def get_dims(self, input_dim):
82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0
83 | out_dim = input_dim * (self.num_freqs * 2 + temp)
84 |
85 | return out_dim
86 |
87 | def forward(self, x: torch.Tensor) -> torch.Tensor:
88 | """ Forward process.
89 |
90 | Args:
91 | x: tensor of shape [..., dim]
92 |
93 | Returns:
94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
95 | where temp is 1 if include_input is True and 0 otherwise.
96 | """
97 |
98 | if self.num_freqs > 0:
99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
100 | if self.include_input:
101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
102 | else:
103 | return torch.cat((embed.sin(), embed.cos()), dim=-1)
104 | else:
105 | return x
106 |
107 |
108 | class LearnedFourierEmbedder(nn.Module):
109 |     """Following @crowsonkb's lead with a learned sinusoidal positional embedding:
110 |     https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
111 |
112 | def __init__(self, in_channels, dim):
113 | super().__init__()
114 | assert (dim % 2) == 0
115 | half_dim = dim // 2
116 | per_channel_dim = half_dim // in_channels
117 | self.weights = nn.Parameter(torch.randn(per_channel_dim))
118 |
119 | def forward(self, x):
120 | """
121 |
122 | Args:
123 | x (torch.FloatTensor): [..., c]
124 |
125 | Returns:
126 | x (torch.FloatTensor): [..., d]
127 | """
128 |
129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
132 | return fouriered
133 |
134 |
135 | class TriplaneLearnedFourierEmbedder(nn.Module):
136 | def __init__(self, in_channels, dim):
137 | super().__init__()
138 |
139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
142 |
143 | self.out_dim = in_channels + dim
144 |
145 | def forward(self, x):
146 |
147 | yz_embed = self.yz_plane_embedder(x)
148 | xz_embed = self.xz_plane_embedder(x)
149 | xy_embed = self.xy_plane_embedder(x)
150 |
151 | embed = yz_embed + xz_embed + xy_embed
152 |
153 | return embed
154 |
155 |
156 | def sequential_pos_embed(num_len, embed_dim):
157 | assert embed_dim % 2 == 0
158 |
159 | pos = torch.arange(num_len, dtype=torch.float32)
160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32)
161 | omega /= embed_dim / 2.
162 | omega = 1. / 10000 ** omega # (D/2,)
163 |
164 | pos = pos.reshape(-1) # (M,)
165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
166 |
167 | emb_sin = torch.sin(out) # (M, D/2)
168 | emb_cos = torch.cos(out) # (M, D/2)
169 |
170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
171 |
172 | return embeddings
173 |
174 |
175 | def timestep_embedding(timesteps, dim, max_period=10000):
176 | """
177 | Create sinusoidal timestep embeddings.
178 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
179 | These may be fractional.
180 | :param dim: the dimension of the output.
181 | :param max_period: controls the minimum frequency of the embeddings.
182 | :return: an [N x dim] Tensor of positional embeddings.
183 | """
184 | half = dim // 2
185 | freqs = torch.exp(
186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
187 | ).to(device=timesteps.device)
188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190 | if dim % 2:
191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
192 | return embedding
193 |
194 |
195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
197 | log2_hashmap_size=19, desired_resolution=None):
198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
199 | return nn.Identity(), input_dim
200 |
201 | elif embed_type == "fourier":
202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
203 | logspace=True, include_input=True)
204 | return embedder_obj, embedder_obj.out_dim
205 |
206 | elif embed_type == "hashgrid":
207 | raise NotImplementedError
208 |
209 | elif embed_type == "sphere_harmonic":
210 | raise NotImplementedError
211 |
212 | else:
213 |         raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
214 |
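A minimal usage sketch for the embedders in this file (illustrative only, not part of the repository; it assumes the package is importable and that this module lives at meshanything_train/miche/michelangelo/models/modules/embedder.py, as in the directory layout):

    import torch
    from meshanything_train.miche.michelangelo.models.modules.embedder import (
        get_embedder,
        timestep_embedding,
    )

    # Fourier features for 3D query points; get_embedder returns the module and its output width.
    embedder, out_dim = get_embedder("fourier", num_freqs=8, input_dim=3)
    pts = torch.rand(2, 1024, 3)
    print(embedder(pts).shape)                   # (2, 1024, out_dim)

    # Sinusoidal diffusion timestep embedding, as computed by timestep_embedding() above.
    t = torch.randint(0, 1000, (4,))
    print(timestep_embedding(t, dim=256).shape)  # (4, 256)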
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/tsal/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/tsal/clip_asl_module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from torch import nn
5 | from einops import rearrange
6 | from transformers import CLIPModel
7 |
8 | from meshanything_train.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule
9 |
10 |
11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
12 |
13 | def __init__(self, *,
14 | shape_model,
15 | clip_model_version: str = "openai/clip-vit-large-patch14"):
16 |
17 | super().__init__()
18 |
19 | # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version)
20 | # for params in self.clip_model.parameters():
21 | # params.requires_grad = False
22 | self.clip_model = None
23 | self.shape_model = shape_model
24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width))
25 | # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5)
26 |
27 | def set_shape_model_only(self):
28 | self.clip_model = None
29 |
30 | def encode_shape_embed(self, surface, return_latents: bool = False):
31 | """
32 |
33 | Args:
34 | surface (torch.FloatTensor): [bs, n, 3 + c]
35 | return_latents (bool):
36 |
37 | Returns:
38 | x (torch.FloatTensor): [bs, projection_dim]
39 | shape_latents (torch.FloatTensor): [bs, m, d]
40 | """
41 |
42 | pc = surface[..., 0:3]
43 | feats = surface[..., 3:]
44 |
45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
46 | x = shape_embed @ self.shape_projection
47 |
48 | if return_latents:
49 | return x, shape_latents
50 | else:
51 | return x
52 |
53 | def encode_image_embed(self, image):
54 | """
55 |
56 | Args:
57 | image (torch.FloatTensor): [bs, 3, h, w]
58 |
59 | Returns:
60 | x (torch.FloatTensor): [bs, projection_dim]
61 | """
62 |
63 | x = self.clip_model.get_image_features(image)
64 |
65 | return x
66 |
67 | def encode_text_embed(self, text):
68 | x = self.clip_model.get_text_features(text)
69 | return x
70 |
71 | def forward(self, surface, image, text):
72 | """
73 |
74 | Args:
75 | surface (torch.FloatTensor):
76 | image (torch.FloatTensor): [bs, 3, 224, 224]
77 | text (torch.LongTensor): [bs, num_templates, 77]
78 |
79 | Returns:
80 | embed_outputs (dict): the embedding outputs, and it contains:
81 | - image_embed (torch.FloatTensor):
82 | - text_embed (torch.FloatTensor):
83 | - shape_embed (torch.FloatTensor):
84 | - logit_scale (float):
85 | """
86 |
87 | # # text embedding
88 | # text_embed_all = []
89 | # for i in range(text.shape[0]):
90 | # text_for_one_sample = text[i]
91 | # text_embed = self.encode_text_embed(text_for_one_sample)
92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
93 | # text_embed = text_embed.mean(dim=0)
94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
95 | # text_embed_all.append(text_embed)
96 | # text_embed_all = torch.stack(text_embed_all)
97 |
98 | b = text.shape[0]
99 | text_tokens = rearrange(text, "b t l -> (b t) l")
100 | text_embed = self.encode_text_embed(text_tokens)
101 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
102 | text_embed = text_embed.mean(dim=1)
103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
104 |
105 | # image embedding
106 | image_embed = self.encode_image_embed(image)
107 |
108 | # shape embedding
109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
110 |
111 | embed_outputs = {
112 | "image_embed": image_embed,
113 | "text_embed": text_embed,
114 | "shape_embed": shape_embed,
115 | # "logit_scale": self.clip_model.logit_scale.exp()
116 | }
117 |
118 | return embed_outputs, shape_latents
119 |
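The text branch of forward() above averages per-template CLIP features before L2-normalizing; a small, self-contained sketch of that tensor flow with dummy features (the 768-dim width is only an assumption for illustration):

    import torch
    from einops import rearrange

    b, t, d = 2, 4, 768                                   # batch, prompt templates, feature dim (assumed)
    per_template = torch.randn(b * t, d)                  # stand-in for CLIP text features
    text_embed = rearrange(per_template, "(b t) d -> b t d", b=b)
    text_embed = text_embed.mean(dim=1)                   # average over the templates
    text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
    print(text_embed.shape)                               # torch.Size([2, 768])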
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/tsal/inference_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from einops import repeat
6 | import numpy as np
7 | from typing import Callable, Tuple, List, Union, Optional
8 | from skimage import measure
9 |
10 | from meshanything_train.miche.michelangelo.graphics.primitives import generate_dense_grid_points
11 |
12 |
13 | @torch.no_grad()
14 | def extract_geometry(geometric_func: Callable,
15 | device: torch.device,
16 | batch_size: int = 1,
17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
18 | octree_depth: int = 7,
19 | num_chunks: int = 10000,
20 | disable: bool = True):
21 | """
22 |
23 | Args:
24 | geometric_func:
25 | device:
26 | bounds:
27 | octree_depth:
28 | batch_size:
29 | num_chunks:
30 | disable:
31 |
32 | Returns:
33 |
34 | """
35 |
36 | if isinstance(bounds, float):
37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
38 |
39 | bbox_min = np.array(bounds[0:3])
40 | bbox_max = np.array(bounds[3:6])
41 | bbox_size = bbox_max - bbox_min
42 |
43 | xyz_samples, grid_size, length = generate_dense_grid_points(
44 | bbox_min=bbox_min,
45 | bbox_max=bbox_max,
46 | octree_depth=octree_depth,
47 | indexing="ij"
48 | )
49 | xyz_samples = torch.FloatTensor(xyz_samples)
50 |
51 | batch_logits = []
52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
53 | desc="Implicit Function:", disable=disable, leave=False):
54 | queries = xyz_samples[start: start + num_chunks, :].to(device)
55 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
56 |
57 | logits = geometric_func(batch_queries)
58 | batch_logits.append(logits.cpu())
59 |
60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()
61 |
62 | mesh_v_f = []
63 | has_surface = np.zeros((batch_size,), dtype=np.bool_)
64 | for i in range(batch_size):
65 | try:
66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
67 | vertices = vertices / grid_size * bbox_size + bbox_min
68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]]
69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
70 | has_surface[i] = True
71 |
72 | except ValueError:
73 | mesh_v_f.append((None, None))
74 | has_surface[i] = False
75 |
76 | except RuntimeError:
77 | mesh_v_f.append((None, None))
78 | has_surface[i] = False
79 |
80 | return mesh_v_f, has_surface
81 |
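extract_geometry only needs a callable that maps query points to signed logits whose zero level set is the surface; a toy sketch with an analytic sphere (assumes the repository and its dependencies such as scikit-image are installed):

    import torch
    from meshanything_train.miche.michelangelo.models.tsal.inference_utils import extract_geometry

    def sphere_logits(queries):             # queries: [b, p, 3]; positive inside the surface
        return 0.5 - queries.norm(dim=-1)   # zero level set = sphere of radius 0.5

    mesh_v_f, has_surface = extract_geometry(
        geometric_func=sphere_logits,
        device=torch.device("cpu"),
        batch_size=1,
        bounds=1.25,
        octree_depth=6,
    )
    vertices, faces = mesh_v_f[0]            # marching-cubes output mapped back to world coordinates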
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/models/tsal/tsal_base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch.nn as nn
4 | from typing import Tuple, List, Optional
5 |
6 |
7 | class Point2MeshOutput(object):
8 | def __init__(self):
9 | self.mesh_v = None
10 | self.mesh_f = None
11 | self.center = None
12 | self.pc = None
13 |
14 |
15 | class Latent2MeshOutput(object):
16 |
17 | def __init__(self):
18 | self.mesh_v = None
19 | self.mesh_f = None
20 |
21 |
22 | class AlignedMeshOutput(object):
23 |
24 | def __init__(self):
25 | self.mesh_v = None
26 | self.mesh_f = None
27 | self.surface = None
28 | self.image = None
29 | self.text: Optional[str] = None
30 | self.shape_text_similarity: Optional[float] = None
31 | self.shape_image_similarity: Optional[float] = None
32 |
33 |
34 | class ShapeAsLatentPLModule(nn.Module):
35 | latent_shape: Tuple[int]
36 |
37 | def encode(self, surface, *args, **kwargs):
38 | raise NotImplementedError
39 |
40 | def decode(self, z_q, *args, **kwargs):
41 | raise NotImplementedError
42 |
43 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
44 | raise NotImplementedError
45 |
46 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
47 | raise NotImplementedError
48 |
49 |
50 | class ShapeAsLatentModule(nn.Module):
51 | latent_shape: Tuple[int, int]
52 |
53 | def __init__(self, *args, **kwargs):
54 | super().__init__()
55 |
56 | def encode(self, *args, **kwargs):
57 | raise NotImplementedError
58 |
59 | def decode(self, *args, **kwargs):
60 | raise NotImplementedError
61 |
62 | def query_geometry(self, *args, **kwargs):
63 | raise NotImplementedError
64 |
65 |
66 | class AlignedShapeAsLatentPLModule(nn.Module):
67 | latent_shape: Tuple[int]
68 |
69 | def set_shape_model_only(self):
70 | raise NotImplementedError
71 |
72 | def encode(self, surface, *args, **kwargs):
73 | raise NotImplementedError
74 |
75 | def decode(self, z_q, *args, **kwargs):
76 | raise NotImplementedError
77 |
78 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
79 | raise NotImplementedError
80 |
81 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
82 | raise NotImplementedError
83 |
84 |
85 | class AlignedShapeAsLatentModule(nn.Module):
86 | shape_model: ShapeAsLatentModule
87 | latent_shape: Tuple[int, int]
88 |
89 | def __init__(self, *args, **kwargs):
90 | super().__init__()
91 |
92 | def set_shape_model_only(self):
93 | raise NotImplementedError
94 |
95 | def encode_image_embed(self, *args, **kwargs):
96 | raise NotImplementedError
97 |
98 | def encode_text_embed(self, *args, **kwargs):
99 | raise NotImplementedError
100 |
101 | def encode_shape_embed(self, *args, **kwargs):
102 | raise NotImplementedError
103 |
104 |
105 | class TexturedShapeAsLatentModule(nn.Module):
106 |
107 | def __init__(self, *args, **kwargs):
108 | super().__init__()
109 |
110 | def encode(self, *args, **kwargs):
111 | raise NotImplementedError
112 |
113 | def decode(self, *args, **kwargs):
114 | raise NotImplementedError
115 |
116 | def query_geometry(self, *args, **kwargs):
117 | raise NotImplementedError
118 |
119 | def query_color(self, *args, **kwargs):
120 | raise NotImplementedError
121 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .misc import instantiate_from_config
4 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/utils/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 |
6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
7 |
8 | mse = torch.mean((x - y) ** 2)
9 |     psnr = 10 * torch.log10(data_range ** 2 / (mse + eps))  # PSNR = 10 * log10(peak_range^2 / MSE)
10 |
11 | return psnr
12 |
13 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/utils/io.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import io
5 | import tarfile
6 | import json
7 | import numpy as np
8 | import numpy.lib.format
9 |
10 |
11 | def mkdir(path):
12 | os.makedirs(path, exist_ok=True)
13 | return path
14 |
15 |
16 | def npy_loads(data):
17 | stream = io.BytesIO(data)
18 | return np.lib.format.read_array(stream)
19 |
20 |
21 | def npz_loads(data):
22 | return np.load(io.BytesIO(data))
23 |
24 |
25 | def json_loads(data):
26 | return json.loads(data)
27 |
28 |
29 | def load_json(filepath):
30 | with open(filepath, "r") as f:
31 | data = json.load(f)
32 | return data
33 |
34 |
35 | def write_json(filepath, data):
36 | with open(filepath, "w") as f:
37 | json.dump(data, f, indent=2)
38 |
39 |
40 | def extract_tar(tar_path, tar_cache_folder):
41 |
42 | with tarfile.open(tar_path, "r") as tar:
43 | tar.extractall(path=tar_cache_folder)
44 |
45 | tar_uids = sorted(os.listdir(tar_cache_folder))
46 | print(f"extract tar: {tar_path} to {tar_cache_folder}")
47 | return tar_uids
48 |
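A self-contained round-trip showing the kind of bytes npy_loads above expects (the raw content of a .npy file); no repository imports are needed:

    import io
    import numpy as np

    arr = np.arange(6, dtype=np.float32).reshape(2, 3)
    buf = io.BytesIO()
    np.save(buf, arr)                                     # serialize to .npy bytes
    restored = np.lib.format.read_array(io.BytesIO(buf.getvalue()))
    assert np.array_equal(arr, restored)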
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/utils/misc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import importlib
4 |
5 | import torch
6 | import torch.distributed as dist
7 | from typing import Union
8 |
9 |
10 |
11 |
12 | def get_obj_from_str(string, reload=False):
13 | module, cls = string.rsplit(".", 1)
14 | if reload:
15 | module_imp = importlib.import_module(module)
16 | importlib.reload(module_imp)
17 | return getattr(importlib.import_module(module, package=None), cls)
18 |
19 |
20 | def get_obj_from_config(config):
21 | if "target" not in config:
22 | raise KeyError("Expected key `target` to instantiate.")
23 |
24 | return get_obj_from_str(config["target"])
25 |
26 |
27 | def instantiate_from_config(config, **kwargs):
28 | if "target" not in config:
29 | raise KeyError("Expected key `target` to instantiate.")
30 |
31 | cls = get_obj_from_str(config["target"])
32 |
33 | params = config.get("params", dict())
34 | # params.update(kwargs)
35 | # instance = cls(**params)
36 | kwargs.update(params)
37 | instance = cls(**kwargs)
38 |
39 | return instance
40 |
41 |
42 | def is_dist_avail_and_initialized():
43 | if not dist.is_available():
44 | return False
45 | if not dist.is_initialized():
46 | return False
47 | return True
48 |
49 |
50 | def get_rank():
51 | if not is_dist_avail_and_initialized():
52 | return 0
53 | return dist.get_rank()
54 |
55 |
56 | def get_world_size():
57 | if not is_dist_avail_and_initialized():
58 | return 1
59 | return dist.get_world_size()
60 |
61 |
62 | def all_gather_batch(tensors):
63 | """
64 | Performs all_gather operation on the provided tensors.
65 | """
66 | # Queue the gathered tensors
67 | world_size = get_world_size()
68 | # There is no need for reduction in the single-proc case
69 | if world_size == 1:
70 | return tensors
71 | tensor_list = []
72 | output_tensor = []
73 | for tensor in tensors:
74 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
75 | dist.all_gather(
76 | tensor_all,
77 | tensor,
78 | async_op=False # performance opt
79 | )
80 |
81 | tensor_list.append(tensor_all)
82 |
83 | for tensor_all in tensor_list:
84 | output_tensor.append(torch.cat(tensor_all, dim=0))
85 | return output_tensor
86 |
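instantiate_from_config expects a dict with a dotted `target` path and optional `params`; the sketch below mirrors the same resolution logic on a dummy target (torch.nn.Linear is only an illustrative target, not one used by the configs here):

    import importlib

    config = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
    module, cls = config["target"].rsplit(".", 1)         # same split as get_obj_from_str
    obj = getattr(importlib.import_module(module), cls)(**config.get("params", {}))
    print(obj)                                            # Linear(in_features=8, out_features=4, bias=True)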
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/utils/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/utils/visualizers/color_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | # Helper functions
6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
7 | colormap = plt.cm.get_cmap(colormap)
8 | if normalize:
9 | vmin = np.min(inp)
10 | vmax = np.max(inp)
11 |
12 | norm = plt.Normalize(vmin, vmax)
13 | return colormap(norm(inp))[:, :3]
14 |
15 |
16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
17 | # tex dims need to be power of two.
18 | array = np.ones((width, height, 3), dtype='float32')
19 |
20 | # width in texels of each checker
21 | checker_w = width / n_checkers_x
22 | checker_h = height / n_checkers_y
23 |
24 | for y in range(height):
25 | for x in range(width):
26 | color_key = int(x / checker_w) + int(y / checker_h)
27 | if color_key % 2 == 0:
28 | array[x, y, :] = [1., 0.874, 0.0]
29 | else:
30 | array[x, y, :] = [0., 0., 0.]
31 | return array
32 |
33 |
34 | def gen_circle(width=256, height=256):
35 | xx, yy = np.mgrid[:width, :height]
36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
37 | array = np.ones((width, height, 4), dtype='float32')
38 | array[:, :, 0] = (circle <= width)
39 | array[:, :, 1] = (circle <= width)
40 | array[:, :, 2] = (circle <= width)
41 | array[:, :, 3] = circle <= width
42 | return array
43 |
44 |
--------------------------------------------------------------------------------
/meshanything_train/miche/michelangelo/utils/visualizers/html_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | import base64
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def to_html_frame(content):
9 |
10 | html_frame = f"""
11 |     <html>
12 |     <body>
13 |     {content}
14 |     </body>
15 |     </html>
16 | """
17 |
18 | return html_frame
19 |
20 |
21 | def to_single_row_table(caption: str, content: str):
22 |
23 | table_html = f"""
24 |     <table>
25 |         <caption>{caption}</caption>
26 |         <tr>
27 |             <td>{content}</td>
28 |         </tr>
29 |     </table>
30 | """
31 |
32 | return table_html
33 |
34 |
35 | def to_image_embed_tag(image: np.ndarray):
36 |
37 | # Convert np.ndarray to bytes
38 | img = Image.fromarray(image)
39 | raw_bytes = io.BytesIO()
40 | img.save(raw_bytes, "PNG")
41 |
42 | # Encode bytes to base64
43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
44 |
45 | image_tag = f"""
46 |     <img src="data:image/png;base64,{image_base64}" alt="Embedded image">
47 | """
48 |
49 | return image_tag
50 |
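A short end-to-end sketch of the helpers above, embedding a random image into a one-row table inside an HTML page (assumes the repository is importable; Pillow and NumPy come from its requirements):

    import numpy as np
    from meshanything_train.miche.michelangelo.utils.visualizers.html_util import (
        to_html_frame,
        to_image_embed_tag,
        to_single_row_table,
    )

    image = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
    html = to_html_frame(to_single_row_table("demo", to_image_embed_tag(image)))
    with open("preview.html", "w") as f:
        f.write(html)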
--------------------------------------------------------------------------------
/meshanything_train/miche/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | accelerate==0.20.3
3 | addict==2.4.0
4 | aiofiles==23.1.0
5 | aiohttp==3.8.4
6 | aiosignal==1.3.1
7 | altair==5.0.1
8 | antlr4-python3-runtime==4.9.3
9 | anyio==3.6.2
10 | appdirs==1.4.4
11 | argon2-cffi==21.3.0
12 | argon2-cffi-bindings==21.2.0
13 | arrow==1.2.3
14 | asttokens==2.2.1
15 | async-timeout==4.0.2
16 | attrs==22.2.0
17 | backcall==0.2.0
18 | beautifulsoup4==4.11.2
19 | bleach==6.0.0
20 | braceexpand==0.1.7
21 | cachetools==5.3.0
22 | cffi==1.15.1
23 | charset-normalizer==3.0.1
24 | click==8.1.3
25 | coloredlogs==15.0.1
26 | comm==0.1.2
27 | configargparse==1.5.3
28 | contourpy==1.0.7
29 | controlnet-aux==0.0.5
30 | cycler==0.11.0
31 | cython==0.29.33
32 | dash==2.8.1
33 | dash-core-components==2.0.0
34 | dash-html-components==2.0.0
35 | dash-table==5.0.0
36 | dataclasses-json==0.6.0
37 | debugpy==1.6.6
38 | decorator==5.1.1
39 | deepspeed==0.8.1
40 | defusedxml==0.7.1
41 | deprecated==1.2.14
42 | diffusers==0.18.2
43 | docker-pycreds==0.4.0
44 | einops==0.6.0
45 | executing==1.2.0
46 | fastapi==0.101.0
47 | fastjsonschema==2.16.2
48 | ffmpy==0.3.1
49 | filelock==3.9.0
50 | flask==2.2.3
51 | flatbuffers==23.5.26
52 | fonttools==4.38.0
53 | fqdn==1.5.1
54 | frozenlist==1.3.3
55 | fsspec==2023.1.0
56 | ftfy==6.1.1
57 | fvcore==0.1.5.post20221221
58 | gitdb==4.0.10
59 | gitpython==3.1.31
60 | google-auth==2.16.1
61 | google-auth-oauthlib==0.4.6
62 | gradio==3.39.0
63 | gradio-client==0.3.0
64 | grpcio==1.51.3
65 | h11==0.14.0
66 | hjson==3.1.0
67 | httpcore==0.17.3
68 | httpx==0.24.1
69 | huggingface-hub==0.16.4
70 | humanfriendly==10.0
71 | idna==3.4
72 | imageio==2.25.1
73 | importlib-metadata==6.0.0
74 | iopath==0.1.10
75 | ipydatawidgets==4.3.3
76 | ipykernel==6.21.2
77 | ipython==8.10.0
78 | ipython-genutils==0.2.0
79 | ipywidgets==8.0.4
80 | isoduration==20.11.0
81 | itsdangerous==2.1.2
82 | jedi==0.18.2
83 | jinja2==3.1.2
84 | joblib==1.2.0
85 | jsonpointer==2.3
86 | jsonschema==4.17.3
87 | jupyter==1.0.0
88 | jupyter-client==8.0.3
89 | jupyter-console==6.6.1
90 | jupyter-core==5.2.0
91 | jupyter-events==0.6.3
92 | jupyter-server==2.3.0
93 | jupyter-server-terminals==0.4.4
94 | jupyterlab-pygments==0.2.2
95 | jupyterlab-widgets==3.0.5
96 | kiwisolver==1.4.4
97 | lightning-utilities==0.7.1
98 | linkify-it-py==2.0.2
99 | lmdb==1.4.1
100 | markdown==3.4.1
101 | markdown-it-py==2.2.0
102 | markupsafe==2.1.2
103 | marshmallow==3.20.1
104 | matplotlib==3.6.3
105 | matplotlib-inline==0.1.6
106 | mdit-py-plugins==0.3.3
107 | mdurl==0.1.2
108 | mesh2sdf==1.1.0
109 | mistune==2.0.5
110 | mpmath==1.3.0
111 | multidict==6.0.4
112 | mypy-extensions==1.0.0
113 | nbclassic==0.5.2
114 | nbclient==0.7.2
115 | nbconvert==7.2.9
116 | nbformat==5.5.0
117 | nest-asyncio==1.5.6
118 | networkx==3.0
119 | ninja==1.11.1
120 | notebook==6.5.2
121 | notebook-shim==0.2.2
122 | numpy==1.23.1
123 | oauthlib==3.2.2
124 | objaverse==0.0.7
125 | omegaconf==2.3.0
126 | onnxruntime==1.15.1
127 | opencv-contrib-python==4.8.0.74
128 | opencv-python==4.7.0.72
129 | orjson==3.9.2
130 | packaging==21.3
131 | pandas==1.4.4
132 | pandocfilters==1.5.0
133 | parso==0.8.3
134 | pathtools==0.1.2
135 | pexpect==4.8.0
136 | pickleshare==0.7.5
137 | pillow==9.2.0
138 | platformdirs==3.0.0
139 | plotly==5.13.0
140 | portalocker==2.7.0
141 | prometheus-client==0.16.0
142 | prompt-toolkit==3.0.37
143 | protobuf==3.19.6
144 | psutil==5.9.4
145 | ptyprocess==0.7.0
146 | pure-eval==0.2.2
147 | py-cpuinfo==9.0.0
148 | pyasn1==0.4.8
149 | pyasn1-modules==0.2.8
150 | pycparser==2.21
151 | pydantic==1.10.5
152 | pydub==0.25.1
153 | pygltflib==1.16.0
154 | pygments==2.14.0
155 | pymeshlab==2022.2.post3
156 | pyparsing==3.0.9
157 | pyquaternion==0.9.9
158 | pyrsistent==0.19.3
159 | pysdf==0.1.8
160 | python-dateutil==2.8.2
161 | python-json-logger==2.0.7
162 | python-multipart==0.0.6
163 | pythreejs==2.4.2
164 | pytz==2022.7.1
165 | pywavelets==1.4.1
166 | pyyaml==6.0
167 | pyzmq==25.0.0
168 | qtconsole==5.4.0
169 | qtpy==2.3.0
170 | regex==2022.10.31
171 | requests==2.28.2
172 | requests-oauthlib==1.3.1
173 | rfc3339-validator==0.1.4
174 | rfc3986-validator==0.1.1
175 | rsa==4.9
176 | rtree==1.0.1
177 | safetensors==0.3.1
178 | scikit-image==0.19.3
179 | scikit-learn==1.2.1
180 | scipy==1.10.1
181 | semantic-version==2.10.0
182 | send2trash==1.8.0
183 | sentencepiece==0.1.97
184 | sentry-sdk==1.15.0
185 | setproctitle==1.3.2
186 | setuptools==63.4.3
187 | sh==2.0.2
188 | shapely==2.0.1
189 | six==1.16.0
190 | smmap==5.0.0
191 | sniffio==1.3.0
192 | soupsieve==2.4
193 | stack-data==0.6.2
194 | starlette==0.27.0
195 | sympy==1.12
196 | tabulate==0.9.0
197 | tenacity==8.2.1
198 | tensorboard==2.10.1
199 | tensorboard-data-server==0.6.1
200 | tensorboard-plugin-wit==1.8.1
201 | termcolor==2.3.0
202 | terminado==0.17.1
203 | threadpoolctl==3.1.0
204 | tifffile==2023.2.3
205 | timm==0.9.2
206 | tinycss2==1.2.1
207 | tokenizers==0.13.2
208 | toolz==0.12.0
209 | tornado==6.2
210 | tqdm==4.64.1
211 | traitlets==5.9.0
212 | traittypes==0.2.1
213 | transformers==4.30.2
214 | trimesh==3.18.3
215 | triton==1.1.1
216 | typing-extensions==4.5.0
217 | typing-inspect==0.9.0
218 | uc-micro-py==1.0.2
219 | uri-template==1.2.0
220 | urllib3==1.26.14
221 | uvicorn==0.23.2
222 | wandb==0.13.10
223 | wcwidth==0.2.6
224 | webcolors==1.12
225 | webdataset==0.2.33
226 | webencodings==0.5.1
227 | websocket-client==1.5.1
228 | websockets==11.0.3
229 | werkzeug==2.2.3
230 | widgetsnbextension==4.0.5
231 | wrapt==1.15.0
232 | xatlas==0.0.7
233 | yacs==0.1.8
234 | yarl==1.8.2
235 | zipp==3.14.0
236 | # torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
237 | https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=db457a822d736013b6ffe509053001bc918bdd78fe68967b605f53984a9afac5
238 | https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=a9fc38040e133d1779f131b4497caef830e9e699faf89cd323cd58794ffb305b
239 | https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=5bc0e29cb78f7c452eeb4f27029c40049770d51553bf840b4ca2edd63da289ee
240 | # torch-cluster
241 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_cluster-1.6.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl
242 | # torch-scatter
243 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl
244 | torchmetrics==0.11.1
245 | pytorch_lightning~=1.9.3
246 | git+https://github.com/pyvista/fast-simplification.git
247 | git+https://github.com/skoch9/meshplot.git
248 | git+https://github.com/NVlabs/nvdiffrast/
--------------------------------------------------------------------------------
/meshanything_train/miche/scripts/infer.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task reconstruction \
3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \
5 | --pointcloud_path ./example_data/surface/surface.npz
6 |
7 | python inference.py \
8 | --task image2mesh \
9 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \
10 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \
11 | --image_path ./example_data/image/car.jpg
12 |
13 | python inference.py \
14 | --task text2mesh \
15 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \
16 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \
17 | --text "A 3D model of motorcar; Porsche Cayenne Turbo."
--------------------------------------------------------------------------------
/meshanything_train/miche/scripts/inference/image2mesh.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task image2mesh \
3 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \
4 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \
5 | --image_path ./example_data/image/car.jpg
--------------------------------------------------------------------------------
/meshanything_train/miche/scripts/inference/reconstruction.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task reconstruction \
3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \
5 | --pointcloud_path ./example_data/surface/surface.npz
--------------------------------------------------------------------------------
/meshanything_train/miche/scripts/inference/text2mesh.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task text2mesh \
3 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \
4 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \
5 | --text "A 3D model of motorcar; Porsche Cayenne Turbo."
--------------------------------------------------------------------------------
/meshanything_train/miche/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 | from distutils.extension import Extension
4 | from Cython.Build import cythonize
5 | import numpy as np
6 |
7 |
8 | setup(
9 | name="michelangelo",
10 | version="0.4.1",
11 | author="Zibo Zhao, Wen Liu and Xin Chen",
12 | author_email="liuwen@shanghaitech.edu.cn",
13 | description="Michelangelo: a 3D Shape Generation System.",
14 | packages=find_packages(exclude=("configs", "tests", "scripts", "example_data")),
15 | python_requires=">=3.8",
16 | install_requires=[
17 | "torch",
18 | "numpy",
19 | "cython",
20 | "tqdm",
21 | ],
22 | )
23 |
--------------------------------------------------------------------------------
/meshanything_train/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | import numpy as np
4 | from collections import deque
5 | from typing import List
6 | from meshanything_train.dist import is_distributed, barrier, all_reduce_sum
7 |
8 |
9 | def my_worker_init_fn(worker_id):
10 | np.random.seed(np.random.get_state()[1][0] + worker_id)
11 |
12 |
13 | @torch.jit.ignore
14 | def to_list_1d(arr) -> List[float]:
15 | arr = arr.detach().cpu().numpy().tolist()
16 | return arr
17 |
18 |
19 | @torch.jit.ignore
20 | def to_list_3d(arr) -> List[List[List[float]]]:
21 | arr = arr.detach().cpu().numpy().tolist()
22 | return arr
23 |
24 |
25 | def huber_loss(error, delta=1.0):
26 | """
27 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
28 | x = error = pred - gt or dist(pred,gt)
29 | 0.5 * |x|^2 if |x|<=d
30 | 0.5 * d^2 + d * (|x|-d) if |x|>d
31 | """
32 | abs_error = torch.abs(error)
33 | quadratic = torch.clamp(abs_error, max=delta)
34 | linear = abs_error - quadratic
35 | loss = 0.5 * quadratic ** 2 + delta * linear
36 | return loss
37 |
38 |
39 | # From https://github.com/facebookresearch/detr/blob/master/util/misc.py
40 | class SmoothedValue(object):
41 | """Track a series of values and provide access to smoothed values over a
42 | window or the global series average.
43 | """
44 |
45 | def __init__(self, window_size=20, fmt=None):
46 | if fmt is None:
47 | fmt = "{median:.4f} ({global_avg:.4f})"
48 | self.deque = deque(maxlen=window_size)
49 | self.total = 0.0
50 | self.count = 0
51 | self.fmt = fmt
52 |
53 | def update(self, value, n=1):
54 | self.deque.append(value)
55 | self.count += n
56 | self.total += value * n
57 |
58 | def synchronize_between_processes(self):
59 | """
60 | Warning: does not synchronize the deque!
61 | """
62 | if not is_distributed():
63 | return
64 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
65 | barrier()
66 | all_reduce_sum(t)
67 | t = t.tolist()
68 | self.count = int(t[0])
69 | self.total = t[1]
70 |
71 | @property
72 | def median(self):
73 | d = torch.tensor(list(self.deque))
74 | return d.median().item()
75 |
76 | @property
77 | def avg(self):
78 | d = torch.tensor(list(self.deque), dtype=torch.float32)
79 | return d.mean().item()
80 |
81 | @property
82 | def global_avg(self):
83 | return self.total / self.count
84 |
85 | @property
86 | def max(self):
87 | return max(self.deque)
88 |
89 | @property
90 | def value(self):
91 | return self.deque[-1]
92 |
93 | def __str__(self):
94 | return self.fmt.format(
95 | median=self.median,
96 | avg=self.avg,
97 | global_avg=self.global_avg,
98 | max=self.max,
99 | value=self.value,
100 | )
101 |
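A quick, self-contained check of the piecewise Huber form documented above (delta = 1); the reference below re-implements the formula rather than importing the training package:

    import torch

    def huber_ref(x, delta=1.0):
        ax = torch.abs(x)
        return torch.where(ax <= delta,
                           0.5 * ax ** 2,
                           0.5 * delta ** 2 + delta * (ax - delta))

    print(huber_ref(torch.tensor([0.5, 2.0])))   # tensor([0.1250, 1.5000])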
--------------------------------------------------------------------------------
/pc_examples/grenade.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buaacyw/MeshAnythingV2/461d3b6ed750ab3443281b2e4a0e30e8ee98097e/pc_examples/grenade.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | trimesh==4.2.3
2 | accelerate==0.28.0
3 | mesh2sdf==1.1.0
4 | einops==0.7.0
5 | einx==0.1.3
6 | optimum==1.18.0
7 | omegaconf==2.3.0
8 | opencv-python==4.9.0.80
9 | transformers==4.39.3
10 | numpy==1.26.4
11 | huggingface_hub
12 | matplotlib
13 | gradio
14 | spaces
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 |
3 | import datetime
4 | from meshanything_train.engine import do_train
5 | from meshanything_train.models.single_gpt import SingleGPT
6 |
7 | from accelerate import Accelerator
8 | from accelerate.logging import get_logger
9 | from accelerate.utils import set_seed
10 | import logging
11 | import importlib
12 | from accelerate.utils import DistributedDataParallelKwargs
13 | import torch
14 |
15 | def make_args_parser():
16 | parser = argparse.ArgumentParser("MeshAnything", add_help=False)
17 |
18 | parser.add_argument("--input_pc_num", default=8192, type=int)
19 | parser.add_argument("--max_vertices", default=800, type=int)
20 |
21 | parser.add_argument("--warm_lr_epochs", default=1, type=int)
22 | parser.add_argument("--num_beams", default=1, type=int)
23 | parser.add_argument("--max_seq_ratio", default=0.70, type=float)
24 |
25 | ##### Model Setups #####
26 | parser.add_argument(
27 | '--pretrained_tokenizer_weight',
28 | default=None,
29 | type=str,
30 | help="The weight for pre-trained vqvae"
31 | )
32 |
33 | parser.add_argument('--llm', default="facebook/opt-350m", type=str, help="The LLM backend")
34 | parser.add_argument("--gen_n_max_triangles", default=1600, type=int, help="max number of triangles")
35 |
36 | ##### Training #####
37 | parser.add_argument("--eval_every_iteration", default=2000, type=int)
38 | parser.add_argument("--save_every", default=250, type=int)
39 | parser.add_argument("--generate_every_data", default=1, type=int)
40 |
41 | ##### Testing #####
42 | parser.add_argument(
43 | "--clip_gradient", default=1., type=float,
44 | help="Max L2 norm of the gradient"
45 | )
46 | parser.add_argument("--pad_id", default=-1, type=int, help="padding id")
47 | parser.add_argument("--dataset", default='loop_set_256', help="dataset list split by ','")
48 | parser.add_argument("--n_discrete_size", default=128, type=int, help="discretized 3D space")
49 | parser.add_argument("--data_n_max_triangles", default=1600, type=int, help="max number of triangles")
50 |
51 | parser.add_argument("--n_max_triangles", default=1600, type=int, help="max number of triangles")
52 |     parser.add_argument("--n_min_triangles", default=40, type=int, help="min number of triangles")
53 |
54 | parser.add_argument("--shift_scale", default=0.1, type=float)
55 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
56 | parser.add_argument('--data_dir', default="dataset", type=str, help="data path")
57 |
58 | parser.add_argument("--seed", default=0, type=int)
59 |
60 | parser.add_argument("--base_lr", default=1e-4, type=float)
61 | parser.add_argument("--final_lr", default=6e-5, type=float)
62 | parser.add_argument("--lr_scheduler", default="cosine", type=str)
63 | parser.add_argument("--weight_decay", default=0.1, type=float)
64 | parser.add_argument("--optimizer", default="AdamW", type=str)
65 | parser.add_argument("--warm_lr", default=1e-6, type=float)
66 |
67 | parser.add_argument("--no_aug", default=False, action="store_true")
68 | parser.add_argument("--checkpoint_dir", default="default", type=str)
69 | parser.add_argument("--log_every", default=10, type=int)
70 | parser.add_argument("--test_only", default=False, action="store_true")
71 | parser.add_argument("--generate_every_iteration", default=18000, type=int)
72 |
73 | parser.add_argument("--start_epoch", default=-1, type=int)
74 | parser.add_argument("--max_epoch", default=800, type=int)
75 | parser.add_argument("--start_eval_after", default=-1, type=int)
76 | parser.add_argument("--precision", default="fp16", type=str)
77 | parser.add_argument("--batchsize_per_gpu", default=8, type=int)
78 | parser.add_argument(
79 | "--criterion", default=None, type=str,
80 | help='metrics for saving the best model'
81 | )
82 |
83 | parser.add_argument('--pretrained_weights', default=None, type=str)
84 |
85 | args = parser.parse_args()
86 |
87 | return args
88 |
89 | if __name__ == "__main__":
90 |
91 | logging.basicConfig(level=logging.INFO)
92 | logger = get_logger(__file__)
93 |
94 | args = make_args_parser()
95 |
96 | cur_time = datetime.datetime.now().strftime("%d_%H-%M-%S")
97 |     wandb_name = args.checkpoint_dir + "_" + cur_time
98 | args.checkpoint_dir = os.path.join("gpt_output", wandb_name)
99 | print("checkpoint_dir:", args.checkpoint_dir)
100 | os.makedirs(args.checkpoint_dir, exist_ok=True)
101 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
102 |
103 | accelerator = Accelerator(
104 | gradient_accumulation_steps=args.gradient_accumulation_steps,
105 | mixed_precision=args.precision,
106 | log_with="wandb",
107 | project_dir=args.checkpoint_dir,
108 | kwargs_handlers=[kwargs]
109 | )
110 | if "default" not in args.checkpoint_dir:
111 | accelerator.init_trackers(
112 | project_name="GPT",
113 | config=vars(args),
114 | init_kwargs={"wandb": {"name": wandb_name}}
115 | )
116 |
117 | set_seed(args.seed, device_specific=True)
118 |
119 | dataset_module = importlib.import_module(f'meshanything_train.{args.dataset}')
120 |
121 | train_dataset = dataset_module.Dataset(args, split_set="train")
122 | test_dataset = dataset_module.Dataset(args, split_set="test")
123 | # make sure no val sample in train set
124 | train_uids = [cur_data['uid'] for cur_data in train_dataset.data]
125 | val_uids = [cur_data['uid'] for cur_data in test_dataset.data]
126 | intersection_list = list(set(train_uids).intersection(set(val_uids)))
127 | print("intersection_list:", len(intersection_list))
128 |
129 | new_train_set_data = []
130 | for cur_data in train_dataset.data:
131 | if cur_data['uid'] not in intersection_list:
132 | new_train_set_data.append(cur_data)
133 | train_dataset.data = new_train_set_data
134 |
135 | dataloaders = {}
136 |
137 | dataloaders['train'] = torch.utils.data.DataLoader(
138 | train_dataset,
139 | batch_size=args.batchsize_per_gpu,
140 |         drop_last=True,
141 |         shuffle=True,
142 | )
143 |
144 | dataloaders['test'] = torch.utils.data.DataLoader(
145 | test_dataset,
146 | batch_size=args.batchsize_per_gpu,
147 |         drop_last=True,
148 |         shuffle=False,
149 | )
150 |
151 | model = SingleGPT(args)
152 | model.to(torch.float32)
153 | do_train(
154 | args,
155 | model,
156 | dataloaders,
157 | logger,
158 | accelerator,
159 | )
--------------------------------------------------------------------------------
/training_requirement.txt:
--------------------------------------------------------------------------------
1 | trimesh==4.2.3
2 | accelerate==0.28.0
3 | mesh2sdf==1.1.0
4 | einops==0.7.0
5 | einx==0.1.3
6 | optimum==1.18.0
7 | omegaconf==2.3.0
8 | opencv-python==4.9.0.80
9 | transformers==4.39.3
10 | numpy==1.26.4
11 | huggingface_hub
12 | matplotlib
13 | gradio
14 | spaces
--------------------------------------------------------------------------------