├── LICENSE.txt ├── MeshAnything ├── miche │ ├── LICENSE │ ├── encode.py │ ├── michelangelo │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── templates.json │ │ │ ├── transforms.py │ │ │ └── utils.py │ │ ├── graphics │ │ │ ├── __init__.py │ │ │ └── primitives │ │ │ │ ├── __init__.py │ │ │ │ ├── mesh.py │ │ │ │ └── volume.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── asl_diffusion │ │ │ │ ├── __init__.py │ │ │ │ ├── asl_diffuser_pl_module.py │ │ │ │ ├── asl_udt.py │ │ │ │ ├── base.py │ │ │ │ ├── clip_asl_diffuser_pl_module.py │ │ │ │ └── inference_utils.py │ │ │ ├── conditional_encoders │ │ │ │ ├── __init__.py │ │ │ │ ├── clip.py │ │ │ │ └── encoder_factory.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── checkpoint.py │ │ │ │ ├── diffusion_transformer.py │ │ │ │ ├── distributions.py │ │ │ │ ├── embedder.py │ │ │ │ ├── transformer_blocks.py │ │ │ │ └── transformer_vit.py │ │ │ └── tsal │ │ │ │ ├── __init__.py │ │ │ │ ├── asl_pl_module.py │ │ │ │ ├── clip_asl_module.py │ │ │ │ ├── inference_utils.py │ │ │ │ ├── loss.py │ │ │ │ ├── sal_perceiver.py │ │ │ │ ├── sal_pl_module.py │ │ │ │ └── tsal_base.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── eval.py │ │ │ ├── io.py │ │ │ ├── misc.py │ │ │ └── visualizers │ │ │ ├── __init__.py │ │ │ ├── color_util.py │ │ │ ├── html_util.py │ │ │ └── pythreejs_viewer.py │ └── shapevae-256.yaml └── models │ ├── meshanything.py │ └── shape_opt.py ├── README.md ├── app.py ├── demo └── demo_video.gif ├── examples ├── screwdriver.obj └── wand.obj ├── main.py ├── mesh_to_pc.py ├── pc_examples └── mouse.npy ├── requirements.txt └── setup.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 
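Before the sources, a minimal usage sketch of the Michelangelo shape VAE loader in MeshAnything/miche/encode.py (shown next). The checkpoint and point-cloud paths are the defaults/placeholders from that file, a CUDA device is assumed, and this is an illustrative sketch rather than the repository's documented entry point.

from MeshAnything.miche.encode import load_model, load_surface

# Assumes a CUDA device and a downloaded shapevae-256.ckpt (default path from encode.py).
model = load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt")

# load_surface expects an .npz with 'points' and 'normals' arrays; this path is the encode.py default.
surface = load_surface("./example_data/surface.npz")

# Encode the sampled surface into a shape embedding and per-token latents.
shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)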
-------------------------------------------------------------------------------- /MeshAnything/miche/encode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | from omegaconf import OmegaConf 4 | import numpy as np 5 | import torch 6 | from .michelangelo.utils.misc import instantiate_from_config 7 | 8 | def load_surface(fp): 9 | 10 | with np.load(fp) as input_pc: 11 | surface = input_pc['points'] 12 | normal = input_pc['normals'] 13 | 14 | rng = np.random.default_rng() 15 | ind = rng.choice(surface.shape[0], 4096, replace=False) 16 | surface = torch.FloatTensor(surface[ind]) 17 | normal = torch.FloatTensor(normal[ind]) 18 | 19 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() 20 | 21 | return surface 22 | 23 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): 24 | 25 | surface = load_surface(args.pointcloud_path) 26 | # old_surface = surface.clone() 27 | 28 | # surface[0,:,0]*=-1 29 | # surface[0,:,1]*=-1 30 | surface[0,:,2]*=-1 31 | 32 | # encoding 33 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) 34 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) 35 | 36 | # decoding 37 | latents = model.model.shape_model.decode(shape_zq) 38 | # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) 39 | 40 | return 0 41 | 42 | def load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt"): 43 | model_config = OmegaConf.load("MeshAnything/miche/shapevae-256.yaml") 44 | # print(model_config) 45 | if hasattr(model_config, "model"): 46 | model_config = model_config.model 47 | 48 | model = instantiate_from_config(model_config, ckpt_path=ckpt_path) 49 | model = model.cuda() 50 | model = model.eval() 51 | 52 | return model 53 | if __name__ == "__main__": 54 | ''' 55 | 1. Reconstruct point cloud 56 | 2. Image-conditioned generation 57 | 3. 
Text-conditioned generation 58 | ''' 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--config_path", type=str, required=True) 61 | parser.add_argument("--ckpt_path", type=str, required=True) 62 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud') 63 | parser.add_argument("--image_path", type=str, help='Path to the input image') 64 | parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.') 65 | parser.add_argument("--output_dir", type=str, default='./output') 66 | parser.add_argument("-s", "--seed", type=int, default=0) 67 | args = parser.parse_args() 68 | 69 | print(f'-----------------------------------------------------------------------------') 70 | print(f'>>> Output directory: {args.output_dir}') 71 | print(f'-----------------------------------------------------------------------------') 72 | 73 | reconstruction(args, load_model(args)) -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "shape": [ 3 | "a point cloud model of {}.", 4 | "There is a {} in the scene.", 5 | "There is the {} in the scene.", 6 | "a photo of a {} in the scene.", 7 | "a photo of the {} in the scene.", 8 | "a photo of one {} in the scene.", 9 | "itap of a {}.", 10 | "itap of my {}.", 11 | "itap of the {}.", 12 | "a photo of a {}.", 13 | "a photo of my {}.", 14 | "a photo of the {}.", 15 | "a photo of one {}.", 16 | "a photo of many {}.", 17 | "a good photo of a {}.", 18 | "a good photo of the {}.", 19 | "a bad photo of a {}.", 20 | "a bad photo of the {}.", 21 | "a photo of a nice {}.", 22 | "a photo of the nice {}.", 23 | "a photo of a cool {}.", 24 | "a photo of the cool {}.", 25 | "a photo of a weird {}.", 26 | "a photo of the weird {}.", 27 | "a photo of a small {}.", 28 | "a photo of the small {}.", 29 | "a photo of a large {}.", 30 | "a photo of the large {}.", 31 | "a photo of a clean {}.", 32 | "a photo of the clean {}.", 33 | "a photo of a dirty {}.", 34 | "a photo of the dirty {}.", 35 | "a bright photo of a {}.", 36 | "a bright photo of the {}.", 37 | "a dark photo of a {}.", 38 | "a dark photo of the {}.", 39 | "a photo of a hard to see {}.", 40 | "a photo of the hard to see {}.", 41 | "a low resolution photo of a {}.", 42 | "a low resolution photo of the {}.", 43 | "a cropped photo of a {}.", 44 | "a cropped photo of the {}.", 45 | "a close-up photo of a {}.", 46 | "a close-up photo of the {}.", 47 | "a jpeg corrupted photo of a {}.", 48 | "a jpeg corrupted photo of the {}.", 49 | "a blurry photo of a {}.", 50 | "a blurry photo of the {}.", 51 | "a pixelated photo of a {}.", 52 | "a pixelated photo of the {}.", 53 | "a black and white photo of the {}.", 54 | "a black and white photo of a {}", 55 | "a plastic {}.", 56 | "the plastic {}.", 57 | "a toy {}.", 58 | "the toy {}.", 59 | "a plushie {}.", 60 | "the 
plushie {}.", 61 | "a cartoon {}.", 62 | "the cartoon {}.", 63 | "an embroidered {}.", 64 | "the embroidered {}.", 65 | "a painting of the {}.", 66 | "a painting of a {}." 67 | ] 68 | 69 | } -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import numpy as np 5 | import warnings 6 | import random 7 | from omegaconf.listconfig import ListConfig 8 | from webdataset import pipelinefilter 9 | import torch 10 | import torchvision.transforms.functional as TVF 11 | from torchvision.transforms import InterpolationMode 12 | from torchvision.transforms.transforms import _interpolation_modes_from_int 13 | from typing import Sequence 14 | 15 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 16 | 17 | 18 | def _uid_buffer_pick(buf_dict, rng): 19 | uid_keys = list(buf_dict.keys()) 20 | selected_uid = rng.choice(uid_keys) 21 | buf = buf_dict[selected_uid] 22 | 23 | k = rng.randint(0, len(buf) - 1) 24 | sample = buf[k] 25 | buf[k] = buf[-1] 26 | buf.pop() 27 | 28 | if len(buf) == 0: 29 | del buf_dict[selected_uid] 30 | 31 | return sample 32 | 33 | 34 | def _add_to_buf_dict(buf_dict, sample): 35 | key = sample["__key__"] 36 | uid, uid_sample_id = key.split("_") 37 | if uid not in buf_dict: 38 | buf_dict[uid] = [] 39 | buf_dict[uid].append(sample) 40 | 41 | return buf_dict 42 | 43 | 44 | def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): 45 | """Shuffle the data in the stream. 46 | 47 | This uses a buffer of size `bufsize`. Shuffling at 48 | startup is less random; this is traded off against 49 | yielding samples quickly. 50 | 51 | data: iterator 52 | bufsize: buffer size for shuffling 53 | returns: iterator 54 | rng: either random module or random.Random instance 55 | 56 | """ 57 | if rng is None: 58 | rng = random.Random(int((os.getpid() + time.time()) * 1e9)) 59 | initial = min(initial, bufsize) 60 | buf_dict = dict() 61 | current_samples = 0 62 | for sample in data: 63 | _add_to_buf_dict(buf_dict, sample) 64 | current_samples += 1 65 | 66 | if current_samples < bufsize: 67 | try: 68 | _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708 69 | current_samples += 1 70 | except StopIteration: 71 | pass 72 | 73 | if current_samples >= initial: 74 | current_samples -= 1 75 | yield _uid_buffer_pick(buf_dict, rng) 76 | 77 | while current_samples > 0: 78 | current_samples -= 1 79 | yield _uid_buffer_pick(buf_dict, rng) 80 | 81 | 82 | uid_shuffle = pipelinefilter(_uid_shuffle) 83 | 84 | 85 | class RandomSample(object): 86 | def __init__(self, 87 | num_volume_samples: int = 1024, 88 | num_near_samples: int = 1024): 89 | 90 | super().__init__() 91 | 92 | self.num_volume_samples = num_volume_samples 93 | self.num_near_samples = num_near_samples 94 | 95 | def __call__(self, sample): 96 | rng = np.random.default_rng() 97 | 98 | # 1. sample surface input 99 | total_surface = sample["surface"] 100 | ind = rng.choice(total_surface.shape[0], replace=False) 101 | surface = total_surface[ind] 102 | 103 | # 2. 
sample volume/near geometric points 104 | vol_points = sample["vol_points"] 105 | vol_label = sample["vol_label"] 106 | near_points = sample["near_points"] 107 | near_label = sample["near_label"] 108 | 109 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) 110 | vol_points = vol_points[ind] 111 | vol_label = vol_label[ind] 112 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) 113 | 114 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) 115 | near_points = near_points[ind] 116 | near_label = near_label[ind] 117 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) 118 | 119 | # concat sampled volume and near points 120 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) 121 | 122 | sample = { 123 | "surface": surface, 124 | "geo_points": geo_points 125 | } 126 | 127 | return sample 128 | 129 | 130 | class SplitRandomSample(object): 131 | def __init__(self, 132 | use_surface_sample: bool = False, 133 | num_surface_samples: int = 4096, 134 | num_volume_samples: int = 1024, 135 | num_near_samples: int = 1024): 136 | 137 | super().__init__() 138 | 139 | self.use_surface_sample = use_surface_sample 140 | self.num_surface_samples = num_surface_samples 141 | self.num_volume_samples = num_volume_samples 142 | self.num_near_samples = num_near_samples 143 | 144 | def __call__(self, sample): 145 | 146 | rng = np.random.default_rng() 147 | 148 | # 1. sample surface input 149 | surface = sample["surface"] 150 | 151 | if self.use_surface_sample: 152 | replace = surface.shape[0] < self.num_surface_samples 153 | ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace) 154 | surface = surface[ind] 155 | 156 | # 2. 
sample volume/near geometric points 157 | vol_points = sample["vol_points"] 158 | vol_label = sample["vol_label"] 159 | near_points = sample["near_points"] 160 | near_label = sample["near_label"] 161 | 162 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) 163 | vol_points = vol_points[ind] 164 | vol_label = vol_label[ind] 165 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) 166 | 167 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) 168 | near_points = near_points[ind] 169 | near_label = near_label[ind] 170 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) 171 | 172 | # concat sampled volume and near points 173 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) 174 | 175 | sample = { 176 | "surface": surface, 177 | "geo_points": geo_points 178 | } 179 | 180 | return sample 181 | 182 | 183 | class FeatureSelection(object): 184 | 185 | VALID_SURFACE_FEATURE_DIMS = { 186 | "none": [0, 1, 2], # xyz 187 | "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal 188 | "normal": [0, 1, 2, 6, 7, 8] 189 | } 190 | 191 | def __init__(self, surface_feature_type: str): 192 | 193 | self.surface_feature_type = surface_feature_type 194 | self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type] 195 | 196 | def __call__(self, sample): 197 | sample["surface"] = sample["surface"][:, self.surface_dims] 198 | return sample 199 | 200 | 201 | class AxisScaleTransform(object): 202 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): 203 | assert isinstance(interval, (tuple, list, ListConfig)) 204 | self.interval = interval 205 | self.min_val = interval[0] 206 | self.max_val = interval[1] 207 | self.inter_size = interval[1] - interval[0] 208 | self.jitter = jitter 209 | self.jitter_scale = jitter_scale 210 | 211 | def __call__(self, sample): 212 | 213 | surface = sample["surface"][..., 0:3] 214 | geo_points = sample["geo_points"][..., 0:3] 215 | 216 | scaling = torch.rand(1, 3) * self.inter_size + self.min_val 217 | # print(scaling) 218 | surface = surface * scaling 219 | geo_points = geo_points * scaling 220 | 221 | scale = (1 / torch.abs(surface).max().item()) * 0.999999 222 | surface *= scale 223 | geo_points *= scale 224 | 225 | if self.jitter: 226 | surface += self.jitter_scale * torch.randn_like(surface) 227 | surface.clamp_(min=-1.015, max=1.015) 228 | 229 | sample["surface"][..., 0:3] = surface 230 | sample["geo_points"][..., 0:3] = geo_points 231 | 232 | return sample 233 | 234 | 235 | class ToTensor(object): 236 | 237 | def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")): 238 | self.tensor_keys = tensor_keys 239 | 240 | def __call__(self, sample): 241 | for key in self.tensor_keys: 242 | if key not in sample: 243 | continue 244 | 245 | sample[key] = torch.tensor(sample[key], dtype=torch.float32) 246 | 247 | return sample 248 | 249 | 250 | class AxisScale(object): 251 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): 252 | assert isinstance(interval, (tuple, list, ListConfig)) 253 | self.interval = interval 254 | self.jitter = jitter 255 | self.jitter_scale = jitter_scale 256 | 257 | def __call__(self, surface, *args): 258 | scaling = torch.rand(1, 3) * 0.5 + 0.75 259 | # print(scaling) 260 | surface = surface * scaling 261 | scale = (1 / torch.abs(surface).max().item()) * 0.999999 262 | surface *= scale 263 | 264 | args_outputs = [] 265 | for _arg in args: 266 
| _arg = _arg * scaling * scale 267 | args_outputs.append(_arg) 268 | 269 | if self.jitter: 270 | surface += self.jitter_scale * torch.randn_like(surface) 271 | surface.clamp_(min=-1, max=1) 272 | 273 | if len(args) == 0: 274 | return surface 275 | else: 276 | return surface, *args_outputs 277 | 278 | 279 | class RandomResize(torch.nn.Module): 280 | """Apply randomly Resize with a given probability.""" 281 | 282 | def __init__( 283 | self, 284 | size, 285 | resize_radio=(0.5, 1), 286 | allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR), 287 | interpolation=InterpolationMode.BICUBIC, 288 | max_size=None, 289 | antialias=None, 290 | ): 291 | super().__init__() 292 | if not isinstance(size, (int, Sequence)): 293 | raise TypeError(f"Size should be int or sequence. Got {type(size)}") 294 | if isinstance(size, Sequence) and len(size) not in (1, 2): 295 | raise ValueError("If size is a sequence, it should have 1 or 2 values") 296 | 297 | self.size = size 298 | self.max_size = max_size 299 | # Backward compatibility with integer value 300 | if isinstance(interpolation, int): 301 | warnings.warn( 302 | "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " 303 | "Please use InterpolationMode enum." 304 | ) 305 | interpolation = _interpolation_modes_from_int(interpolation) 306 | 307 | self.interpolation = interpolation 308 | self.antialias = antialias 309 | 310 | self.resize_radio = resize_radio 311 | self.allow_resize_interpolations = allow_resize_interpolations 312 | 313 | def random_resize_params(self): 314 | radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0] 315 | 316 | if isinstance(self.size, int): 317 | size = int(self.size * radio) 318 | elif isinstance(self.size, Sequence): 319 | size = list(self.size) 320 | size = (int(size[0] * radio), int(size[1] * radio)) 321 | else: 322 | raise RuntimeError() 323 | 324 | interpolation = self.allow_resize_interpolations[ 325 | torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,)) 326 | ] 327 | return size, interpolation 328 | 329 | def forward(self, img): 330 | size, interpolation = self.random_resize_params() 331 | img = TVF.resize(img, size, interpolation, self.max_size, self.antialias) 332 | img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias) 333 | return img 334 | 335 | def __repr__(self) -> str: 336 | detail = f"(size={self.size}, interpolation={self.interpolation.value}," 337 | detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}" 338 | return f"{self.__class__.__name__}{detail}" 339 | 340 | 341 | class Compose(object): 342 | """Composes several transforms together. This transform does not support torchscript. 343 | Please, see the note below. 344 | 345 | Args: 346 | transforms (list of ``Transform`` objects): list of transforms to compose. 347 | 348 | Example: 349 | >>> transforms.Compose([ 350 | >>> transforms.CenterCrop(10), 351 | >>> transforms.ToTensor(), 352 | >>> ]) 353 | 354 | .. note:: 355 | In order to script the transformations, please use ``torch.nn.Sequential`` as below. 356 | 357 | >>> transforms = torch.nn.Sequential( 358 | >>> transforms.CenterCrop(10), 359 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 360 | >>> ) 361 | >>> scripted_transforms = torch.jit.script(transforms) 362 | 363 | Make sure to use only scriptable transformations, i.e. 
that work with ``torch.Tensor``, does not require 364 | `lambda` functions or ``PIL.Image``. 365 | 366 | """ 367 | 368 | def __init__(self, transforms): 369 | self.transforms = transforms 370 | 371 | def __call__(self, *args): 372 | for t in self.transforms: 373 | args = t(*args) 374 | return args 375 | 376 | def __repr__(self): 377 | format_string = self.__class__.__name__ + '(' 378 | for t in self.transforms: 379 | format_string += '\n' 380 | format_string += ' {0}'.format(t) 381 | format_string += '\n)' 382 | return format_string 383 | 384 | 385 | def identity(*args, **kwargs): 386 | if len(args) == 1: 387 | return args[0] 388 | else: 389 | return args 390 | 391 | 392 | def build_transforms(cfg): 393 | 394 | if cfg is None: 395 | return identity 396 | 397 | transforms = [] 398 | 399 | for transform_name, cfg_instance in cfg.items(): 400 | transform_instance = instantiate_from_config(cfg_instance) 401 | transforms.append(transform_instance) 402 | print(f"Build transform: {transform_instance}") 403 | 404 | transforms = Compose(transforms) 405 | 406 | return transforms 407 | 408 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def worker_init_fn(_): 8 | worker_info = torch.utils.data.get_worker_info() 9 | worker_id = worker_info.id 10 | 11 | # dataset = worker_info.dataset 12 | # split_size = dataset.num_records // worker_info.num_workers 13 | # # reset num_records to the true number to retain reliable length information 14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1) 16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 17 | 18 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 19 | 20 | 21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True): 22 | """ 23 | 24 | Args: 25 | samples (list[dict]): 26 | combine_tensors: 27 | combine_scalars: 28 | 29 | Returns: 30 | 31 | """ 32 | 33 | result = {} 34 | 35 | keys = samples[0].keys() 36 | 37 | for key in keys: 38 | result[key] = [] 39 | 40 | for sample in samples: 41 | for key in keys: 42 | val = sample[key] 43 | result[key].append(val) 44 | 45 | for key in keys: 46 | val_list = result[key] 47 | if isinstance(val_list[0], (int, float)): 48 | if combine_scalars: 49 | result[key] = np.array(result[key]) 50 | 51 | elif isinstance(val_list[0], torch.Tensor): 52 | if combine_tensors: 53 | result[key] = torch.stack(val_list) 54 | 55 | elif isinstance(val_list[0], np.ndarray): 56 | if combine_tensors: 57 | result[key] = np.stack(val_list) 58 | 59 | return result 60 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .volume import generate_dense_grid_points 4 | 5 | from .mesh import ( 6 | MeshOutput, 7 | save_obj, 8 | savemeshtes2 9 | ) 10 | 
-------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/mesh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import PIL.Image 7 | from typing import Optional 8 | 9 | import trimesh 10 | 11 | 12 | def save_obj(pointnp_px3, facenp_fx3, fname): 13 | fid = open(fname, "w") 14 | write_str = "" 15 | for pidx, p in enumerate(pointnp_px3): 16 | pp = p 17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2]) 18 | 19 | for i, f in enumerate(facenp_fx3): 20 | f1 = f + 1 21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2]) 22 | fid.write(write_str) 23 | fid.close() 24 | return 25 | 26 | 27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname): 28 | fol, na = os.path.split(fname) 29 | na, _ = os.path.splitext(na) 30 | 31 | matname = "%s/%s.mtl" % (fol, na) 32 | fid = open(matname, "w") 33 | fid.write("newmtl material_0\n") 34 | fid.write("Kd 1 1 1\n") 35 | fid.write("Ka 0 0 0\n") 36 | fid.write("Ks 0.4 0.4 0.4\n") 37 | fid.write("Ns 10\n") 38 | fid.write("illum 2\n") 39 | fid.write("map_Kd %s.png\n" % na) 40 | fid.close() 41 | #### 42 | 43 | fid = open(fname, "w") 44 | fid.write("mtllib %s.mtl\n" % na) 45 | 46 | for pidx, p in enumerate(pointnp_px3): 47 | pp = p 48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) 49 | 50 | for pidx, p in enumerate(tcoords_px2): 51 | pp = p 52 | fid.write("vt %f %f\n" % (pp[0], pp[1])) 53 | 54 | fid.write("usemtl material_0\n") 55 | for i, f in enumerate(facenp_fx3): 56 | f1 = f + 1 57 | f2 = facetex_fx3[i] + 1 58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) 59 | fid.close() 60 | 61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save( 62 | os.path.join(fol, "%s.png" % na)) 63 | 64 | return 65 | 66 | 67 | class MeshOutput(object): 68 | 69 | def __init__(self, 70 | mesh_v: np.ndarray, 71 | mesh_f: np.ndarray, 72 | vertex_colors: Optional[np.ndarray] = None, 73 | uvs: Optional[np.ndarray] = None, 74 | mesh_tex_idx: Optional[np.ndarray] = None, 75 | tex_map: Optional[np.ndarray] = None): 76 | 77 | self.mesh_v = mesh_v 78 | self.mesh_f = mesh_f 79 | self.vertex_colors = vertex_colors 80 | self.uvs = uvs 81 | self.mesh_tex_idx = mesh_tex_idx 82 | self.tex_map = tex_map 83 | 84 | def contain_uv_texture(self): 85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None) 86 | 87 | def contain_vertex_colors(self): 88 | return self.vertex_colors is not None 89 | 90 | def export(self, fname): 91 | 92 | if self.contain_uv_texture(): 93 | savemeshtes2( 94 | self.mesh_v, 95 | self.uvs, 96 | self.mesh_f, 97 | self.mesh_tex_idx, 98 | self.tex_map, 99 | fname 100 | ) 101 | 102 | elif self.contain_vertex_colors(): 103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors) 104 | mesh_obj.export(fname) 105 | 106 | else: 107 | save_obj( 108 | self.mesh_v, 109 | self.mesh_f, 110 | fname 111 | ) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/volume.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | 6 | def generate_dense_grid_points(bbox_min: np.ndarray, 7 | bbox_max: np.ndarray, 8 | octree_depth: int, 9 | 
indexing: str = "ij"): 10 | length = bbox_max - bbox_min 11 | num_cells = np.exp2(octree_depth) 12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) 13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) 14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) 15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) 16 | xyz = np.stack((xs, ys, zs), axis=-1) 17 | xyz = xyz.reshape(-1, 3) 18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] 19 | 20 | return xyz, grid_size, length 21 | 22 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from omegaconf import DictConfig 4 | from typing import List, Tuple, Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | from einops import rearrange 14 | 15 | from diffusers.schedulers import ( 16 | DDPMScheduler, 17 | DDIMScheduler, 18 | KarrasVeScheduler, 19 | DPMSolverMultistepScheduler 20 | ) 21 | 22 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 23 | # from MeshAnything.miche.michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule 24 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule 25 | from MeshAnything.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample 26 | 27 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] 28 | 29 | 30 | def disabled_train(self, mode=True): 31 | """Overwrite model.train with this function to make sure train/eval mode 32 | does not change anymore.""" 33 | return self 34 | 35 | 36 | class ASLDiffuser(pl.LightningModule): 37 | first_stage_model: Optional[AlignedShapeAsLatentPLModule] 38 | # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] 39 | model: nn.Module 40 | 41 | def __init__(self, *, 42 | first_stage_config, 43 | denoiser_cfg, 44 | scheduler_cfg, 45 | optimizer_cfg, 46 | loss_cfg, 47 | first_stage_key: str = "surface", 48 | cond_stage_key: str = "image", 49 | cond_stage_trainable: bool = True, 50 | scale_by_std: bool = False, 51 | z_scale_factor: float = 1.0, 52 | ckpt_path: Optional[str] = None, 53 | ignore_keys: Union[Tuple[str], List[str]] = ()): 54 | 55 | super().__init__() 56 | 57 | self.first_stage_key = first_stage_key 58 | self.cond_stage_key = cond_stage_key 59 | self.cond_stage_trainable = cond_stage_trainable 60 | 61 | # 1. initialize first stage. 62 | # Note: the condition model contained in the first stage model. 
63 | self.first_stage_config = first_stage_config 64 | self.first_stage_model = None 65 | # self.instantiate_first_stage(first_stage_config) 66 | 67 | # 2. initialize conditional stage 68 | # self.instantiate_cond_stage(cond_stage_config) 69 | self.cond_stage_model = { 70 | "image": self.encode_image, 71 | "image_unconditional_embedding": self.empty_img_cond, 72 | "text": self.encode_text, 73 | "text_unconditional_embedding": self.empty_text_cond, 74 | "surface": self.encode_surface, 75 | "surface_unconditional_embedding": self.empty_surface_cond, 76 | } 77 | 78 | # 3. diffusion model 79 | self.model = instantiate_from_config( 80 | denoiser_cfg, device=None, dtype=None 81 | ) 82 | 83 | self.optimizer_cfg = optimizer_cfg 84 | 85 | # 4. scheduling strategy 86 | self.scheduler_cfg = scheduler_cfg 87 | 88 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) 89 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) 90 | 91 | # 5. loss configures 92 | self.loss_cfg = loss_cfg 93 | 94 | self.scale_by_std = scale_by_std 95 | if scale_by_std: 96 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) 97 | else: 98 | self.z_scale_factor = z_scale_factor 99 | 100 | self.ckpt_path = ckpt_path 101 | if ckpt_path is not None: 102 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 103 | 104 | def instantiate_first_stage(self, config): 105 | model = instantiate_from_config(config) 106 | self.first_stage_model = model.eval() 107 | self.first_stage_model.train = disabled_train 108 | for param in self.first_stage_model.parameters(): 109 | param.requires_grad = False 110 | 111 | self.first_stage_model = self.first_stage_model.to(self.device) 112 | 113 | # def instantiate_cond_stage(self, config): 114 | # if not self.cond_stage_trainable: 115 | # if config == "__is_first_stage__": 116 | # print("Using first stage also as cond stage.") 117 | # self.cond_stage_model = self.first_stage_model 118 | # elif config == "__is_unconditional__": 119 | # print(f"Training {self.__class__.__name__} as an unconditional model.") 120 | # self.cond_stage_model = None 121 | # # self.be_unconditional = True 122 | # else: 123 | # model = instantiate_from_config(config) 124 | # self.cond_stage_model = model.eval() 125 | # self.cond_stage_model.train = disabled_train 126 | # for param in self.cond_stage_model.parameters(): 127 | # param.requires_grad = False 128 | # else: 129 | # assert config != "__is_first_stage__" 130 | # assert config != "__is_unconditional__" 131 | # model = instantiate_from_config(config) 132 | # self.cond_stage_model = model 133 | 134 | def init_from_ckpt(self, path, ignore_keys=()): 135 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 136 | 137 | keys = list(state_dict.keys()) 138 | for k in keys: 139 | for ik in ignore_keys: 140 | if k.startswith(ik): 141 | print("Deleting key {} from state_dict.".format(k)) 142 | del state_dict[k] 143 | 144 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 145 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 146 | if len(missing) > 0: 147 | print(f"Missing Keys: {missing}") 148 | print(f"Unexpected Keys: {unexpected}") 149 | 150 | @property 151 | def zero_rank(self): 152 | if self._trainer: 153 | zero_rank = self.trainer.local_rank == 0 154 | else: 155 | zero_rank = True 156 | 157 | return zero_rank 158 | 159 | def configure_optimizers(self) -> Tuple[List, List]: 160 | 161 | lr = self.learning_rate 
162 | 163 | trainable_parameters = list(self.model.parameters()) 164 | # if the conditional encoder is trainable 165 | 166 | # if self.cond_stage_trainable: 167 | # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad] 168 | # trainable_parameters += conditioner_params 169 | # print(f"number of trainable conditional parameters: {len(conditioner_params)}.") 170 | 171 | if self.optimizer_cfg is None: 172 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 173 | schedulers = [] 174 | else: 175 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 176 | scheduler_func = instantiate_from_config( 177 | self.optimizer_cfg.scheduler, 178 | max_decay_steps=self.trainer.max_steps, 179 | lr_max=lr 180 | ) 181 | scheduler = { 182 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 183 | "interval": "step", 184 | "frequency": 1 185 | } 186 | optimizers = [optimizer] 187 | schedulers = [scheduler] 188 | 189 | return optimizers, schedulers 190 | 191 | @torch.no_grad() 192 | def encode_text(self, text): 193 | 194 | b = text.shape[0] 195 | text_tokens = rearrange(text, "b t l -> (b t) l") 196 | text_embed = self.first_stage_model.model.encode_text_embed(text_tokens) 197 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) 198 | text_embed = text_embed.mean(dim=1) 199 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 200 | 201 | return text_embed 202 | 203 | @torch.no_grad() 204 | def encode_image(self, img): 205 | 206 | return self.first_stage_model.model.encode_image_embed(img) 207 | 208 | @torch.no_grad() 209 | def encode_surface(self, surface): 210 | 211 | return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False) 212 | 213 | @torch.no_grad() 214 | def empty_text_cond(self, cond): 215 | 216 | return torch.zeros_like(cond, device=cond.device) 217 | 218 | @torch.no_grad() 219 | def empty_img_cond(self, cond): 220 | 221 | return torch.zeros_like(cond, device=cond.device) 222 | 223 | @torch.no_grad() 224 | def empty_surface_cond(self, cond): 225 | 226 | return torch.zeros_like(cond, device=cond.device) 227 | 228 | @torch.no_grad() 229 | def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): 230 | 231 | z_q = self.first_stage_model.encode(surface, sample_posterior) 232 | z_q = self.z_scale_factor * z_q 233 | 234 | return z_q 235 | 236 | @torch.no_grad() 237 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): 238 | 239 | z_q = 1. / self.z_scale_factor * z_q 240 | latents = self.first_stage_model.decode(z_q, **kwargs) 241 | return latents 242 | 243 | @rank_zero_only 244 | @torch.no_grad() 245 | def on_train_batch_start(self, batch, batch_idx): 246 | # only for very first batch 247 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ 248 | and batch_idx == 0 and self.ckpt_path is None: 249 | # set rescale weight to 1./std of encodings 250 | print("### USING STD-RESCALING ###") 251 | 252 | z_q = self.encode_first_stage(batch[self.first_stage_key]) 253 | z = z_q.detach() 254 | 255 | del self.z_scale_factor 256 | self.register_buffer("z_scale_factor", 1. 
/ z.flatten().std()) 257 | print(f"setting self.z_scale_factor to {self.z_scale_factor}") 258 | 259 | print("### USING STD-RESCALING ###") 260 | 261 | def compute_loss(self, model_outputs, split): 262 | """ 263 | 264 | Args: 265 | model_outputs (dict): 266 | - x_0: 267 | - noise: 268 | - noise_prior: 269 | - noise_pred: 270 | - noise_pred_prior: 271 | 272 | split (str): 273 | 274 | Returns: 275 | 276 | """ 277 | 278 | pred = model_outputs["pred"] 279 | 280 | if self.noise_scheduler.prediction_type == "epsilon": 281 | target = model_outputs["noise"] 282 | elif self.noise_scheduler.prediction_type == "sample": 283 | target = model_outputs["x_0"] 284 | else: 285 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") 286 | 287 | if self.loss_cfg.loss_type == "l1": 288 | simple = F.l1_loss(pred, target, reduction="mean") 289 | elif self.loss_cfg.loss_type in ["mse", "l2"]: 290 | simple = F.mse_loss(pred, target, reduction="mean") 291 | else: 292 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") 293 | 294 | total_loss = simple 295 | 296 | loss_dict = { 297 | f"{split}/total_loss": total_loss.clone().detach(), 298 | f"{split}/simple": simple.detach(), 299 | } 300 | 301 | return total_loss, loss_dict 302 | 303 | def forward(self, batch): 304 | """ 305 | 306 | Args: 307 | batch: 308 | 309 | Returns: 310 | 311 | """ 312 | 313 | if self.first_stage_model is None: 314 | self.instantiate_first_stage(self.first_stage_config) 315 | 316 | latents = self.encode_first_stage(batch[self.first_stage_key]) 317 | 318 | # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) 319 | 320 | conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1) 321 | 322 | mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1 323 | conditions = conditions * mask.to(conditions) 324 | 325 | # Sample noise that we"ll add to the latents 326 | # [batch_size, n_token, latent_dim] 327 | noise = torch.randn_like(latents) 328 | bs = latents.shape[0] 329 | # Sample a random timestep for each motion 330 | timesteps = torch.randint( 331 | 0, 332 | self.noise_scheduler.config.num_train_timesteps, 333 | (bs,), 334 | device=latents.device, 335 | ) 336 | timesteps = timesteps.long() 337 | # Add noise to the latents according to the noise magnitude at each timestep 338 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) 339 | 340 | # diffusion model forward 341 | noise_pred = self.model(noisy_z, timesteps, conditions) 342 | 343 | diffusion_outputs = { 344 | "x_0": noisy_z, 345 | "noise": noise, 346 | "pred": noise_pred 347 | } 348 | 349 | return diffusion_outputs 350 | 351 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], 352 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 353 | """ 354 | 355 | Args: 356 | batch (dict): the batch sample, and it contains: 357 | - surface (torch.FloatTensor): 358 | - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] 359 | - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] 360 | - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] 361 | - text (list of str): 362 | 363 | batch_idx (int): 364 | 365 | optimizer_idx (int): 366 | 367 | Returns: 368 | loss (torch.FloatTensor): 369 | 370 | """ 371 | 372 | diffusion_outputs = self(batch) 373 | 374 | loss, loss_dict = self.compute_loss(diffusion_outputs, 
"train") 375 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 376 | 377 | return loss 378 | 379 | def validation_step(self, batch: Dict[str, torch.FloatTensor], 380 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 381 | """ 382 | 383 | Args: 384 | batch (dict): the batch sample, and it contains: 385 | - surface_pc (torch.FloatTensor): [n_pts, 4] 386 | - surface_feats (torch.FloatTensor): [n_pts, c] 387 | - text (list of str): 388 | 389 | batch_idx (int): 390 | 391 | optimizer_idx (int): 392 | 393 | Returns: 394 | loss (torch.FloatTensor): 395 | 396 | """ 397 | 398 | diffusion_outputs = self(batch) 399 | 400 | loss, loss_dict = self.compute_loss(diffusion_outputs, "val") 401 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 402 | 403 | return loss 404 | 405 | @torch.no_grad() 406 | def sample(self, 407 | batch: Dict[str, Union[torch.FloatTensor, List[str]]], 408 | sample_times: int = 1, 409 | steps: Optional[int] = None, 410 | guidance_scale: Optional[float] = None, 411 | eta: float = 0.0, 412 | return_intermediates: bool = False, **kwargs): 413 | 414 | if self.first_stage_model is None: 415 | self.instantiate_first_stage(self.first_stage_config) 416 | 417 | if steps is None: 418 | steps = self.scheduler_cfg.num_inference_steps 419 | 420 | if guidance_scale is None: 421 | guidance_scale = self.scheduler_cfg.guidance_scale 422 | do_classifier_free_guidance = guidance_scale > 0 423 | 424 | # conditional encode 425 | xc = batch[self.cond_stage_key] 426 | # cond = self.cond_stage_model[self.cond_stage_key](xc) 427 | cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1) 428 | 429 | if do_classifier_free_guidance: 430 | """ 431 | Note: There are two kinds of uncond for text. 
432 | 1: using "" as uncond text; (in SAL diffusion) 433 | 2: zeros_like(cond) as uncond text; (in MDM) 434 | """ 435 | # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) 436 | un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond) 437 | # un_cond = torch.zeros_like(cond, device=cond.device) 438 | cond = torch.cat([un_cond, cond], dim=0) 439 | 440 | outputs = [] 441 | latents = None 442 | 443 | if not return_intermediates: 444 | for _ in range(sample_times): 445 | sample_loop = ddim_sample( 446 | self.denoise_scheduler, 447 | self.model, 448 | shape=self.first_stage_model.latent_shape, 449 | cond=cond, 450 | steps=steps, 451 | guidance_scale=guidance_scale, 452 | do_classifier_free_guidance=do_classifier_free_guidance, 453 | device=self.device, 454 | eta=eta, 455 | disable_prog=not self.zero_rank 456 | ) 457 | for sample, t in sample_loop: 458 | latents = sample 459 | outputs.append(self.decode_first_stage(latents, **kwargs)) 460 | else: 461 | 462 | sample_loop = ddim_sample( 463 | self.denoise_scheduler, 464 | self.model, 465 | shape=self.first_stage_model.latent_shape, 466 | cond=cond, 467 | steps=steps, 468 | guidance_scale=guidance_scale, 469 | do_classifier_free_guidance=do_classifier_free_guidance, 470 | device=self.device, 471 | eta=eta, 472 | disable_prog=not self.zero_rank 473 | ) 474 | 475 | iter_size = steps // sample_times 476 | i = 0 477 | for sample, t in sample_loop: 478 | latents = sample 479 | if i % iter_size == 0 or i == steps - 1: 480 | outputs.append(self.decode_first_stage(latents, **kwargs)) 481 | i += 1 482 | 483 | return outputs 484 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/asl_udt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from diffusers.models.embeddings import Timesteps 7 | import math 8 | 9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import MLP 10 | from MeshAnything.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer 11 | 12 | 13 | class ConditionalASLUDTDenoiser(nn.Module): 14 | 15 | def __init__(self, *, 16 | device: Optional[torch.device], 17 | dtype: Optional[torch.dtype], 18 | input_channels: int, 19 | output_channels: int, 20 | n_ctx: int, 21 | width: int, 22 | layers: int, 23 | heads: int, 24 | context_dim: int, 25 | context_ln: bool = True, 26 | skip_ln: bool = False, 27 | init_scale: float = 0.25, 28 | flip_sin_to_cos: bool = False, 29 | use_checkpoint: bool = False): 30 | super().__init__() 31 | 32 | self.use_checkpoint = use_checkpoint 33 | 34 | init_scale = init_scale * math.sqrt(1.0 / width) 35 | 36 | self.backbone = UNetDiffusionTransformer( 37 | device=device, 38 | dtype=dtype, 39 | n_ctx=n_ctx, 40 | width=width, 41 | layers=layers, 42 | heads=heads, 43 | skip_ln=skip_ln, 44 | init_scale=init_scale, 45 | use_checkpoint=use_checkpoint 46 | ) 47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) 49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) 50 | 51 | # timestep embedding 52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0) 53 | self.time_proj = MLP( 54 | device=device, dtype=dtype, width=width, 
init_scale=init_scale 55 | ) 56 | 57 | self.context_embed = nn.Sequential( 58 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 59 | nn.Linear(context_dim, width, device=device, dtype=dtype), 60 | ) 61 | 62 | if context_ln: 63 | self.context_embed = nn.Sequential( 64 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 65 | nn.Linear(context_dim, width, device=device, dtype=dtype), 66 | ) 67 | else: 68 | self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype) 69 | 70 | def forward(self, 71 | model_input: torch.FloatTensor, 72 | timestep: torch.LongTensor, 73 | context: torch.FloatTensor): 74 | 75 | r""" 76 | Args: 77 | model_input (torch.FloatTensor): [bs, n_data, c] 78 | timestep (torch.LongTensor): [bs,] 79 | context (torch.FloatTensor): [bs, context_tokens, c] 80 | 81 | Returns: 82 | sample (torch.FloatTensor): [bs, n_data, c] 83 | 84 | """ 85 | 86 | _, n_data, _ = model_input.shape 87 | 88 | # 1. time 89 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1) 90 | 91 | # 2. conditions projector 92 | context = self.context_embed(context) 93 | 94 | # 3. denoiser 95 | x = self.input_proj(model_input) 96 | x = torch.cat([t_emb, context, x], dim=1) 97 | x = self.backbone(x) 98 | x = self.ln_post(x) 99 | x = x[:, -n_data:] 100 | sample = self.output_proj(x) 101 | 102 | return sample 103 | 104 | 105 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BaseDenoiser(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, t, context): 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from omegaconf import DictConfig 4 | from typing import List, Tuple, Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | from diffusers.schedulers import ( 14 | DDPMScheduler, 15 | DDIMScheduler, 16 | KarrasVeScheduler, 17 | DPMSolverMultistepScheduler 18 | ) 19 | 20 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 21 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule 22 | from MeshAnything.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample 23 | 24 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] 25 | 26 | 27 | def disabled_train(self, mode=True): 28 | """Overwrite model.train with this function to make sure train/eval mode 29 | does not change anymore.""" 30 | return self 31 | 32 | 33 | class ClipASLDiffuser(pl.LightningModule): 34 | first_stage_model: Optional[AlignedShapeAsLatentPLModule] 35 | cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] 36 | model: nn.Module 37 | 38 | def __init__(self, *, 39 | first_stage_config, 40 | cond_stage_config, 41 | denoiser_cfg, 42 | scheduler_cfg, 43 | optimizer_cfg, 44 | loss_cfg, 45 | first_stage_key: str = "surface", 46 | 
cond_stage_key: str = "image", 47 | scale_by_std: bool = False, 48 | z_scale_factor: float = 1.0, 49 | ckpt_path: Optional[str] = None, 50 | ignore_keys: Union[Tuple[str], List[str]] = ()): 51 | 52 | super().__init__() 53 | 54 | self.first_stage_key = first_stage_key 55 | self.cond_stage_key = cond_stage_key 56 | 57 | # 1. lazy initialize first stage 58 | self.instantiate_first_stage(first_stage_config) 59 | 60 | # 2. initialize conditional stage 61 | self.instantiate_cond_stage(cond_stage_config) 62 | 63 | # 3. diffusion model 64 | self.model = instantiate_from_config( 65 | denoiser_cfg, device=None, dtype=None 66 | ) 67 | 68 | self.optimizer_cfg = optimizer_cfg 69 | 70 | # 4. scheduling strategy 71 | self.scheduler_cfg = scheduler_cfg 72 | 73 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) 74 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) 75 | 76 | # 5. loss configures 77 | self.loss_cfg = loss_cfg 78 | 79 | self.scale_by_std = scale_by_std 80 | if scale_by_std: 81 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) 82 | else: 83 | self.z_scale_factor = z_scale_factor 84 | 85 | self.ckpt_path = ckpt_path 86 | if ckpt_path is not None: 87 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 88 | 89 | def instantiate_non_trainable_model(self, config): 90 | model = instantiate_from_config(config) 91 | model = model.eval() 92 | model.train = disabled_train 93 | for param in model.parameters(): 94 | param.requires_grad = False 95 | 96 | return model 97 | 98 | def instantiate_first_stage(self, first_stage_config): 99 | self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config) 100 | self.first_stage_model.set_shape_model_only() 101 | 102 | def instantiate_cond_stage(self, cond_stage_config): 103 | self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config) 104 | 105 | def init_from_ckpt(self, path, ignore_keys=()): 106 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 107 | 108 | keys = list(state_dict.keys()) 109 | for k in keys: 110 | for ik in ignore_keys: 111 | if k.startswith(ik): 112 | print("Deleting key {} from state_dict.".format(k)) 113 | del state_dict[k] 114 | 115 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 116 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 117 | if len(missing) > 0: 118 | print(f"Missing Keys: {missing}") 119 | print(f"Unexpected Keys: {unexpected}") 120 | 121 | @property 122 | def zero_rank(self): 123 | if self._trainer: 124 | zero_rank = self.trainer.local_rank == 0 125 | else: 126 | zero_rank = True 127 | 128 | return zero_rank 129 | 130 | def configure_optimizers(self) -> Tuple[List, List]: 131 | 132 | lr = self.learning_rate 133 | 134 | trainable_parameters = list(self.model.parameters()) 135 | if self.optimizer_cfg is None: 136 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 137 | schedulers = [] 138 | else: 139 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 140 | scheduler_func = instantiate_from_config( 141 | self.optimizer_cfg.scheduler, 142 | max_decay_steps=self.trainer.max_steps, 143 | lr_max=lr 144 | ) 145 | scheduler = { 146 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 147 | "interval": "step", 148 | "frequency": 1 149 | } 150 | optimizers = [optimizer] 151 | schedulers = 
[scheduler] 152 | 153 | return optimizers, schedulers 154 | 155 | @torch.no_grad() 156 | def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): 157 | 158 | z_q = self.first_stage_model.encode(surface, sample_posterior) 159 | z_q = self.z_scale_factor * z_q 160 | 161 | return z_q 162 | 163 | @torch.no_grad() 164 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): 165 | 166 | z_q = 1. / self.z_scale_factor * z_q 167 | latents = self.first_stage_model.decode(z_q, **kwargs) 168 | return latents 169 | 170 | @rank_zero_only 171 | @torch.no_grad() 172 | def on_train_batch_start(self, batch, batch_idx): 173 | # only for very first batch 174 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ 175 | and batch_idx == 0 and self.ckpt_path is None: 176 | # set rescale weight to 1./std of encodings 177 | print("### USING STD-RESCALING ###") 178 | 179 | z_q = self.encode_first_stage(batch[self.first_stage_key]) 180 | z = z_q.detach() 181 | 182 | del self.z_scale_factor 183 | self.register_buffer("z_scale_factor", 1. / z.flatten().std()) 184 | print(f"setting self.z_scale_factor to {self.z_scale_factor}") 185 | 186 | print("### USING STD-RESCALING ###") 187 | 188 | def compute_loss(self, model_outputs, split): 189 | """ 190 | 191 | Args: 192 | model_outputs (dict): 193 | - x_0: 194 | - noise: 195 | - noise_prior: 196 | - noise_pred: 197 | - noise_pred_prior: 198 | 199 | split (str): 200 | 201 | Returns: 202 | 203 | """ 204 | 205 | pred = model_outputs["pred"] 206 | 207 | if self.noise_scheduler.prediction_type == "epsilon": 208 | target = model_outputs["noise"] 209 | elif self.noise_scheduler.prediction_type == "sample": 210 | target = model_outputs["x_0"] 211 | else: 212 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") 213 | 214 | if self.loss_cfg.loss_type == "l1": 215 | simple = F.l1_loss(pred, target, reduction="mean") 216 | elif self.loss_cfg.loss_type in ["mse", "l2"]: 217 | simple = F.mse_loss(pred, target, reduction="mean") 218 | else: 219 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") 220 | 221 | total_loss = simple 222 | 223 | loss_dict = { 224 | f"{split}/total_loss": total_loss.clone().detach(), 225 | f"{split}/simple": simple.detach(), 226 | } 227 | 228 | return total_loss, loss_dict 229 | 230 | def forward(self, batch): 231 | """ 232 | 233 | Args: 234 | batch: 235 | 236 | Returns: 237 | 238 | """ 239 | 240 | latents = self.encode_first_stage(batch[self.first_stage_key]) 241 | conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) 242 | 243 | # Sample noise that we"ll add to the latents 244 | # [batch_size, n_token, latent_dim] 245 | noise = torch.randn_like(latents) 246 | bs = latents.shape[0] 247 | # Sample a random timestep for each motion 248 | timesteps = torch.randint( 249 | 0, 250 | self.noise_scheduler.config.num_train_timesteps, 251 | (bs,), 252 | device=latents.device, 253 | ) 254 | timesteps = timesteps.long() 255 | # Add noise to the latents according to the noise magnitude at each timestep 256 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) 257 | 258 | # diffusion model forward 259 | noise_pred = self.model(noisy_z, timesteps, conditions) 260 | 261 | diffusion_outputs = { 262 | "x_0": noisy_z, 263 | "noise": noise, 264 | "pred": noise_pred 265 | } 266 | 267 | return diffusion_outputs 268 | 269 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], 270 | 
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 271 | """ 272 | 273 | Args: 274 | batch (dict): the batch sample, and it contains: 275 | - surface (torch.FloatTensor): 276 | - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] 277 | - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] 278 | - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] 279 | - text (list of str): 280 | 281 | batch_idx (int): 282 | 283 | optimizer_idx (int): 284 | 285 | Returns: 286 | loss (torch.FloatTensor): 287 | 288 | """ 289 | 290 | diffusion_outputs = self(batch) 291 | 292 | loss, loss_dict = self.compute_loss(diffusion_outputs, "train") 293 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 294 | 295 | return loss 296 | 297 | def validation_step(self, batch: Dict[str, torch.FloatTensor], 298 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 299 | """ 300 | 301 | Args: 302 | batch (dict): the batch sample, and it contains: 303 | - surface_pc (torch.FloatTensor): [n_pts, 4] 304 | - surface_feats (torch.FloatTensor): [n_pts, c] 305 | - text (list of str): 306 | 307 | batch_idx (int): 308 | 309 | optimizer_idx (int): 310 | 311 | Returns: 312 | loss (torch.FloatTensor): 313 | 314 | """ 315 | 316 | diffusion_outputs = self(batch) 317 | 318 | loss, loss_dict = self.compute_loss(diffusion_outputs, "val") 319 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 320 | 321 | return loss 322 | 323 | @torch.no_grad() 324 | def sample(self, 325 | batch: Dict[str, Union[torch.FloatTensor, List[str]]], 326 | sample_times: int = 1, 327 | steps: Optional[int] = None, 328 | guidance_scale: Optional[float] = None, 329 | eta: float = 0.0, 330 | return_intermediates: bool = False, **kwargs): 331 | 332 | if steps is None: 333 | steps = self.scheduler_cfg.num_inference_steps 334 | 335 | if guidance_scale is None: 336 | guidance_scale = self.scheduler_cfg.guidance_scale 337 | do_classifier_free_guidance = guidance_scale > 0 338 | 339 | # conditional encode 340 | xc = batch[self.cond_stage_key] 341 | 342 | # print(self.first_stage_model.device, self.cond_stage_model.device, self.device) 343 | 344 | cond = self.cond_stage_model(xc) 345 | 346 | if do_classifier_free_guidance: 347 | un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) 348 | cond = torch.cat([un_cond, cond], dim=0) 349 | 350 | outputs = [] 351 | latents = None 352 | 353 | if not return_intermediates: 354 | for _ in range(sample_times): 355 | sample_loop = ddim_sample( 356 | self.denoise_scheduler, 357 | self.model, 358 | shape=self.first_stage_model.latent_shape, 359 | cond=cond, 360 | steps=steps, 361 | guidance_scale=guidance_scale, 362 | do_classifier_free_guidance=do_classifier_free_guidance, 363 | device=self.device, 364 | eta=eta, 365 | disable_prog=not self.zero_rank 366 | ) 367 | for sample, t in sample_loop: 368 | latents = sample 369 | outputs.append(self.decode_first_stage(latents, **kwargs)) 370 | else: 371 | 372 | sample_loop = ddim_sample( 373 | self.denoise_scheduler, 374 | self.model, 375 | shape=self.first_stage_model.latent_shape, 376 | cond=cond, 377 | steps=steps, 378 | guidance_scale=guidance_scale, 379 | do_classifier_free_guidance=do_classifier_free_guidance, 380 | device=self.device, 381 | eta=eta, 382 | disable_prog=not self.zero_rank 383 | ) 384 | 385 | iter_size = steps // sample_times 386 | i = 0 387 | for sample, t in sample_loop: 388 | latents = 
sample 389 | if i % iter_size == 0 or i == steps - 1: 390 | outputs.append(self.decode_first_stage(latents, **kwargs)) 391 | i += 1 392 | 393 | return outputs 394 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from typing import Tuple, List, Union, Optional 6 | from diffusers.schedulers import DDIMScheduler 7 | 8 | 9 | __all__ = ["ddim_sample"] 10 | 11 | 12 | def ddim_sample(ddim_scheduler: DDIMScheduler, 13 | diffusion_model: torch.nn.Module, 14 | shape: Union[List[int], Tuple[int]], 15 | cond: torch.FloatTensor, 16 | steps: int, 17 | eta: float = 0.0, 18 | guidance_scale: float = 3.0, 19 | do_classifier_free_guidance: bool = True, 20 | generator: Optional[torch.Generator] = None, 21 | device: torch.device = "cuda:0", 22 | disable_prog: bool = True): 23 | 24 | assert steps > 0, f"{steps} must > 0." 25 | 26 | # init latents 27 | bsz = cond.shape[0] 28 | if do_classifier_free_guidance: 29 | bsz = bsz // 2 30 | 31 | latents = torch.randn( 32 | (bsz, *shape), 33 | generator=generator, 34 | device=cond.device, 35 | dtype=cond.dtype, 36 | ) 37 | # scale the initial noise by the standard deviation required by the scheduler 38 | latents = latents * ddim_scheduler.init_noise_sigma 39 | # set timesteps 40 | ddim_scheduler.set_timesteps(steps) 41 | timesteps = ddim_scheduler.timesteps.to(device) 42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1] 44 | extra_step_kwargs = { 45 | "eta": eta, 46 | "generator": generator 47 | } 48 | 49 | # reverse 50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)): 51 | # expand the latents if we are doing classifier free guidance 52 | latent_model_input = ( 53 | torch.cat([latents] * 2) 54 | if do_classifier_free_guidance 55 | else latents 56 | ) 57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t) 58 | # predict the noise residual 59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) 60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) 61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond) 62 | 63 | # perform guidance 64 | if do_classifier_free_guidance: 65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 66 | noise_pred = noise_pred_uncond + guidance_scale * ( 67 | noise_pred_text - noise_pred_uncond 68 | ) 69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk( 70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states 71 | # compute the previous noisy sample x_t -> x_t-1 72 | latents = ddim_scheduler.step( 73 | noise_pred, t, latents, **extra_step_kwargs 74 | ).prev_sample 75 | 76 | yield latents, t 77 | 78 | 79 | def karra_sample(): 80 | pass 81 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/conditional_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .clip import CLIPEncoder 4 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/conditional_encoders/clip.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from dataclasses import dataclass 7 | from torchvision.transforms import Normalize 8 | from transformers import CLIPModel, CLIPTokenizer 9 | from transformers.utils import ModelOutput 10 | from typing import Iterable, Optional, Union, List 11 | 12 | 13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 14 | 15 | 16 | @dataclass 17 | class CLIPEmbedOutput(ModelOutput): 18 | last_hidden_state: torch.FloatTensor = None 19 | pooler_output: torch.FloatTensor = None 20 | embeds: torch.FloatTensor = None 21 | 22 | 23 | class CLIPEncoder(torch.nn.Module): 24 | 25 | def __init__(self, model_path="openai/clip-vit-base-patch32"): 26 | 27 | super().__init__() 28 | 29 | # Load the CLIP model and processor 30 | self.model: CLIPModel = CLIPModel.from_pretrained(model_path) 31 | self.tokenizer = CLIPTokenizer.from_pretrained(model_path) 32 | self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 33 | 34 | self.model.eval()  # put all submodules in eval mode and freeze the weights 35 | for p in self.model.parameters(): 36 | p.requires_grad = False 37 | 38 | @torch.no_grad() 39 | def encode_image(self, images: Iterable[Optional[ImageType]]): 40 | pixel_values = self.image_preprocess(images) 41 | 42 | vision_outputs = self.model.vision_model(pixel_values=pixel_values) 43 | 44 | pooler_output = vision_outputs[1]  # pooled_output 45 | image_features = self.model.visual_projection(pooler_output) 46 | 47 | visual_embeds = CLIPEmbedOutput( 48 | last_hidden_state=vision_outputs.last_hidden_state, 49 | pooler_output=pooler_output, 50 | embeds=image_features 51 | ) 52 | 53 | return visual_embeds 54 | 55 | @torch.no_grad() 56 | def encode_text(self, texts: List[str]): 57 | text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt") 58 | 59 | text_outputs = self.model.text_model(input_ids=text_inputs.input_ids, attention_mask=text_inputs.attention_mask) 60 | 61 | pooler_output = text_outputs[1]  # pooled_output 62 | text_features = self.model.text_projection(pooler_output) 63 | 64 | text_embeds = CLIPEmbedOutput( 65 | last_hidden_state=text_outputs.last_hidden_state, 66 | pooler_output=pooler_output, 67 | embeds=text_features 68 | ) 69 | 70 | return text_embeds 71 | 72 | def forward(self, 73 | images: Iterable[Optional[ImageType]], 74 | texts: List[str]): 75 | 76 | visual_embeds = self.encode_image(images) 77 | text_embeds = self.encode_text(texts) 78 | 79 | return visual_embeds, text_embeds 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import checkpoint 4 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 4 | """ 5 | 6 | import torch 7 | from typing import Callable, Iterable, Sequence, Union 8 | 9 | 10 | def checkpoint( 11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], 12 | inputs: Sequence[torch.Tensor], 13 | params: Iterable[torch.Tensor], 14 | 
flag: bool, 15 | use_deepspeed: bool = False 16 | ): 17 | """ 18 | Evaluate a function without caching intermediate activations, allowing for 19 | reduced memory at the expense of extra compute in the backward pass. 20 | :param func: the function to evaluate. 21 | :param inputs: the argument sequence to pass to `func`. 22 | :param params: a sequence of parameters `func` depends on but does not 23 | explicitly take as arguments. 24 | :param flag: if False, disable gradient checkpointing. 25 | :param use_deepspeed: if True, use deepspeed 26 | """ 27 | if flag: 28 | if use_deepspeed: 29 | import deepspeed 30 | return deepspeed.checkpointing.checkpoint(func, *inputs) 31 | 32 | args = tuple(inputs) + tuple(params) 33 | return CheckpointFunction.apply(func, len(inputs), *args) 34 | else: 35 | return func(*inputs) 36 | 37 | 38 | class CheckpointFunction(torch.autograd.Function): 39 | @staticmethod 40 | @torch.cuda.amp.custom_fwd 41 | def forward(ctx, run_function, length, *args): 42 | ctx.run_function = run_function 43 | ctx.input_tensors = list(args[:length]) 44 | ctx.input_params = list(args[length:]) 45 | 46 | with torch.no_grad(): 47 | output_tensors = ctx.run_function(*ctx.input_tensors) 48 | return output_tensors 49 | 50 | @staticmethod 51 | @torch.cuda.amp.custom_bwd 52 | def backward(ctx, *output_grads): 53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 54 | with torch.enable_grad(): 55 | # Fixes a bug where the first op in run_function modifies the 56 | # Tensor storage in place, which is not allowed for detach()'d 57 | # Tensors. 58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 59 | output_tensors = ctx.run_function(*shallow_copies) 60 | input_grads = torch.autograd.grad( 61 | output_tensors, 62 | ctx.input_tensors + ctx.input_params, 63 | output_grads, 64 | allow_unused=True, 65 | ) 66 | del ctx.input_tensors 67 | del ctx.input_params 68 | del output_tensors 69 | return (None, None) + input_grads 70 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/diffusion_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | 8 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint 9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import ( 10 | init_linear, 11 | MLP, 12 | MultiheadCrossAttention, 13 | MultiheadAttention, 14 | ResidualAttentionBlock 15 | ) 16 | 17 | 18 | class AdaLayerNorm(nn.Module): 19 | def __init__(self, 20 | device: torch.device, 21 | dtype: torch.dtype, 22 | width: int): 23 | 24 | super().__init__() 25 | 26 | self.silu = nn.SiLU(inplace=True) 27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype) 28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype) 29 | 30 | def forward(self, x, timestep): 31 | emb = self.linear(timestep) 32 | scale, shift = torch.chunk(emb, 2, dim=2) 33 | x = self.layernorm(x) * (1 + scale) + shift 34 | return x 35 | 36 | 37 | class DitBlock(nn.Module): 38 | def __init__( 39 | self, 40 | *, 41 | device: torch.device, 42 | dtype: torch.dtype, 43 | n_ctx: int, 44 | width: int, 45 | heads: int, 46 | context_dim: int, 47 | qkv_bias: bool = False, 48 | init_scale: float = 1.0, 49 | use_checkpoint: bool = False 50 | ): 51 | super().__init__() 52 | 53 | 
self.use_checkpoint = use_checkpoint 54 | 55 | self.attn = MultiheadAttention( 56 | device=device, 57 | dtype=dtype, 58 | n_ctx=n_ctx, 59 | width=width, 60 | heads=heads, 61 | init_scale=init_scale, 62 | qkv_bias=qkv_bias 63 | ) 64 | self.ln_1 = AdaLayerNorm(device, dtype, width) 65 | 66 | if context_dim is not None: 67 | self.ln_2 = AdaLayerNorm(device, dtype, width) 68 | self.cross_attn = MultiheadCrossAttention( 69 | device=device, 70 | dtype=dtype, 71 | width=width, 72 | heads=heads, 73 | data_width=context_dim, 74 | init_scale=init_scale, 75 | qkv_bias=qkv_bias 76 | ) 77 | 78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 79 | self.ln_3 = AdaLayerNorm(device, dtype, width) 80 | 81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint) 83 | 84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 85 | x = x + self.attn(self.ln_1(x, t)) 86 | if context is not None: 87 | x = x + self.cross_attn(self.ln_2(x, t), context) 88 | x = x + self.mlp(self.ln_3(x, t)) 89 | return x 90 | 91 | 92 | class DiT(nn.Module): 93 | def __init__( 94 | self, 95 | *, 96 | device: Optional[torch.device], 97 | dtype: Optional[torch.dtype], 98 | n_ctx: int, 99 | width: int, 100 | layers: int, 101 | heads: int, 102 | context_dim: int, 103 | init_scale: float = 0.25, 104 | qkv_bias: bool = False, 105 | use_checkpoint: bool = False 106 | ): 107 | super().__init__() 108 | self.n_ctx = n_ctx 109 | self.width = width 110 | self.layers = layers 111 | 112 | self.resblocks = nn.ModuleList( 113 | [ 114 | DitBlock( 115 | device=device, 116 | dtype=dtype, 117 | n_ctx=n_ctx, 118 | width=width, 119 | heads=heads, 120 | context_dim=context_dim, 121 | qkv_bias=qkv_bias, 122 | init_scale=init_scale, 123 | use_checkpoint=use_checkpoint 124 | ) 125 | for _ in range(layers) 126 | ] 127 | ) 128 | 129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 130 | for block in self.resblocks: 131 | x = block(x, t, context) 132 | return x 133 | 134 | 135 | class UNetDiffusionTransformer(nn.Module): 136 | def __init__( 137 | self, 138 | *, 139 | device: Optional[torch.device], 140 | dtype: Optional[torch.dtype], 141 | n_ctx: int, 142 | width: int, 143 | layers: int, 144 | heads: int, 145 | init_scale: float = 0.25, 146 | qkv_bias: bool = False, 147 | skip_ln: bool = False, 148 | use_checkpoint: bool = False 149 | ): 150 | super().__init__() 151 | 152 | self.n_ctx = n_ctx 153 | self.width = width 154 | self.layers = layers 155 | 156 | self.encoder = nn.ModuleList() 157 | for _ in range(layers): 158 | resblock = ResidualAttentionBlock( 159 | device=device, 160 | dtype=dtype, 161 | n_ctx=n_ctx, 162 | width=width, 163 | heads=heads, 164 | init_scale=init_scale, 165 | qkv_bias=qkv_bias, 166 | use_checkpoint=use_checkpoint 167 | ) 168 | self.encoder.append(resblock) 169 | 170 | self.middle_block = ResidualAttentionBlock( 171 | device=device, 172 | dtype=dtype, 173 | n_ctx=n_ctx, 174 | width=width, 175 | heads=heads, 176 | init_scale=init_scale, 177 | qkv_bias=qkv_bias, 178 | use_checkpoint=use_checkpoint 179 | ) 180 | 181 | self.decoder = nn.ModuleList() 182 | for _ in range(layers): 183 | resblock = ResidualAttentionBlock( 184 | device=device, 185 | dtype=dtype, 186 | n_ctx=n_ctx, 187 | width=width, 188 | heads=heads, 189 | init_scale=init_scale, 190 | qkv_bias=qkv_bias, 191 | 
use_checkpoint=use_checkpoint 192 | ) 193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype) 194 | init_linear(linear, init_scale) 195 | 196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None 197 | 198 | self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) 199 | 200 | def forward(self, x: torch.Tensor): 201 | 202 | enc_outputs = [] 203 | for block in self.encoder: 204 | x = block(x) 205 | enc_outputs.append(x) 206 | 207 | x = self.middle_block(x) 208 | 209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder): 210 | x = torch.cat([enc_outputs.pop(), x], dim=-1) 211 | x = linear(x) 212 | 213 | if layer_norm is not None: 214 | x = layer_norm(x) 215 | 216 | x = resblock(x) 217 | 218 | return x 219 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union, List 4 | 5 | 6 | class AbstractDistribution(object): 7 | def sample(self): 8 | raise NotImplementedError() 9 | 10 | def mode(self): 11 | raise NotImplementedError() 12 | 13 | 14 | class DiracDistribution(AbstractDistribution): 15 | def __init__(self, value): 16 | self.value = value 17 | 18 | def sample(self): 19 | return self.value 20 | 21 | def mode(self): 22 | return self.value 23 | 24 | 25 | class DiagonalGaussianDistribution(object): 26 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): 27 | self.feat_dim = feat_dim 28 | self.parameters = parameters 29 | 30 | if isinstance(parameters, list): 31 | self.mean = parameters[0] 32 | self.logvar = parameters[1] 33 | else: 34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 35 | 36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 37 | self.deterministic = deterministic 38 | self.std = torch.exp(0.5 * self.logvar) 39 | self.var = torch.exp(self.logvar) 40 | if self.deterministic: 41 | self.var = self.std = torch.zeros_like(self.mean) 42 | 43 | def sample(self): 44 | x = self.mean + self.std * torch.randn_like(self.mean) 45 | return x 46 | 47 | def kl(self, other=None, dims=(1, 2, 3)): 48 | if self.deterministic: 49 | return torch.Tensor([0.]) 50 | else: 51 | if other is None: 52 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 53 | + self.var - 1.0 - self.logvar, 54 | dim=dims) 55 | else: 56 | return 0.5 * torch.mean( 57 | torch.pow(self.mean - other.mean, 2) / other.var 58 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 59 | dim=dims) 60 | 61 | def nll(self, sample, dims=(1, 2, 3)): 62 | if self.deterministic: 63 | return torch.Tensor([0.]) 64 | logtwopi = np.log(2.0 * np.pi) 65 | return 0.5 * torch.sum( 66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 67 | dim=dims) 68 | 69 | def mode(self): 70 | return self.mean 71 | 72 | 73 | def normal_kl(mean1, logvar1, mean2, logvar2): 74 | """ 75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 76 | Compute the KL divergence between two gaussians. 77 | Shapes are automatically broadcasted, so batches can be compared to 78 | scalars, among other use cases. 
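In closed form, with logvar_i = log(sigma_i^2), the value computed below is
    0.5 * (logvar2 - logvar1 - 1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2)),
i.e. the elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).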
79 | """ 80 | tensor = None 81 | for obj in (mean1, logvar1, mean2, logvar2): 82 | if isinstance(obj, torch.Tensor): 83 | tensor = obj 84 | break 85 | assert tensor is not None, "at least one argument must be a Tensor" 86 | 87 | # Force variances to be Tensors. Broadcasting helps convert scalars to 88 | # Tensors, but it does not work for torch.exp(). 89 | logvar1, logvar2 = [ 90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 91 | for x in (logvar1, logvar2) 92 | ] 93 | 94 | return 0.5 * ( 95 | -1.0 96 | + logvar2 97 | - logvar1 98 | + torch.exp(logvar1 - logvar2) 99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 100 | ) 101 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/embedder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] 9 | 10 | 11 | class FourierEmbedder(nn.Module): 12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts 13 | each feature dimension of `x[..., i]` into: 14 | [ 15 | sin(x[..., i]), 16 | sin(f_1*x[..., i]), 17 | sin(f_2*x[..., i]), 18 | ... 19 | sin(f_N * x[..., i]), 20 | cos(x[..., i]), 21 | cos(f_1*x[..., i]), 22 | cos(f_2*x[..., i]), 23 | ... 24 | cos(f_N * x[..., i]), 25 | x[..., i] # only present if include_input is True. 26 | ], here f_i is the frequency. 27 | 28 | Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. 29 | If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; 30 | Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. 31 | 32 | Args: 33 | num_freqs (int): the number of frequencies, default is 6; 34 | logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 35 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; 36 | input_dim (int): the input dimension, default is 3; 37 | include_input (bool): include the input tensor or not, default is True. 38 | 39 | Attributes: 40 | frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 41 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); 42 | 43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), 44 | otherwise, it is input_dim * num_freqs * 2. 
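For example, with the defaults used below (input_dim=3, num_freqs=6, include_input=True), out_dim = 3 * (6 * 2 + 1) = 39.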
45 | 46 | """ 47 | 48 | def __init__(self, 49 | num_freqs: int = 6, 50 | logspace: bool = True, 51 | input_dim: int = 3, 52 | include_input: bool = True, 53 | include_pi: bool = True) -> None: 54 | 55 | """The initialization""" 56 | 57 | super().__init__() 58 | 59 | if logspace: 60 | frequencies = 2.0 ** torch.arange( 61 | num_freqs, 62 | dtype=torch.float32 63 | ) 64 | else: 65 | frequencies = torch.linspace( 66 | 1.0, 67 | 2.0 ** (num_freqs - 1), 68 | num_freqs, 69 | dtype=torch.float32 70 | ) 71 | 72 | if include_pi: 73 | frequencies *= torch.pi 74 | 75 | self.register_buffer("frequencies", frequencies, persistent=False) 76 | self.include_input = include_input 77 | self.num_freqs = num_freqs 78 | 79 | self.out_dim = self.get_dims(input_dim) 80 | 81 | def get_dims(self, input_dim): 82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0 83 | out_dim = input_dim * (self.num_freqs * 2 + temp) 84 | 85 | return out_dim 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | """ Forward process. 89 | 90 | Args: 91 | x: tensor of shape [..., dim] 92 | 93 | Returns: 94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] 95 | where temp is 1 if include_input is True and 0 otherwise. 96 | """ 97 | 98 | if self.num_freqs > 0: 99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) 100 | if self.include_input: 101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 102 | else: 103 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 104 | else: 105 | return x 106 | 107 | 108 | class LearnedFourierEmbedder(nn.Module): 109 | """ following @crowsonkb "s lead with learned sinusoidal pos emb """ 110 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 111 | 112 | def __init__(self, in_channels, dim): 113 | super().__init__() 114 | assert (dim % 2) == 0 115 | half_dim = dim // 2 116 | per_channel_dim = half_dim // in_channels 117 | self.weights = nn.Parameter(torch.randn(per_channel_dim)) 118 | 119 | def forward(self, x): 120 | """ 121 | 122 | Args: 123 | x (torch.FloatTensor): [..., c] 124 | 125 | Returns: 126 | x (torch.FloatTensor): [..., d] 127 | """ 128 | 129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] 130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) 131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) 132 | return fouriered 133 | 134 | 135 | class TriplaneLearnedFourierEmbedder(nn.Module): 136 | def __init__(self, in_channels, dim): 137 | super().__init__() 138 | 139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 142 | 143 | self.out_dim = in_channels + dim 144 | 145 | def forward(self, x): 146 | 147 | yz_embed = self.yz_plane_embedder(x) 148 | xz_embed = self.xz_plane_embedder(x) 149 | xy_embed = self.xy_plane_embedder(x) 150 | 151 | embed = yz_embed + xz_embed + xy_embed 152 | 153 | return embed 154 | 155 | 156 | def sequential_pos_embed(num_len, embed_dim): 157 | assert embed_dim % 2 == 0 158 | 159 | pos = torch.arange(num_len, dtype=torch.float32) 160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32) 161 | omega /= embed_dim / 2. 162 | omega = 1. 
/ 10000 ** omega # (D/2,) 163 | 164 | pos = pos.reshape(-1) # (M,) 165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 166 | 167 | emb_sin = torch.sin(out) # (M, D/2) 168 | emb_cos = torch.cos(out) # (M, D/2) 169 | 170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 171 | 172 | return embeddings 173 | 174 | 175 | def timestep_embedding(timesteps, dim, max_period=10000): 176 | """ 177 | Create sinusoidal timestep embeddings. 178 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 179 | These may be fractional. 180 | :param dim: the dimension of the output. 181 | :param max_period: controls the minimum frequency of the embeddings. 182 | :return: an [N x dim] Tensor of positional embeddings. 183 | """ 184 | half = dim // 2 185 | freqs = torch.exp( 186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 187 | ).to(device=timesteps.device) 188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None] 189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 190 | if dim % 2: 191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 192 | return embedding 193 | 194 | 195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, 196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, 197 | log2_hashmap_size=19, desired_resolution=None): 198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): 199 | return nn.Identity(), input_dim 200 | 201 | elif embed_type == "fourier": 202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, 203 | logspace=True, include_input=True) 204 | return embedder_obj, embedder_obj.out_dim 205 | 206 | elif embed_type == "hashgrid": 207 | raise NotImplementedError 208 | 209 | elif embed_type == "sphere_harmonic": 210 | raise NotImplementedError 211 | 212 | else: 213 | raise ValueError(f"{embed_type} is not valid. 
Currently only supprts {VALID_EMBED_TYPES}") 214 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/transformer_blocks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Optional 8 | 9 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint 10 | 11 | 12 | def init_linear(l, stddev): 13 | nn.init.normal_(l.weight, std=stddev) 14 | if l.bias is not None: 15 | nn.init.constant_(l.bias, 0.0) 16 | 17 | 18 | class MultiheadAttention(nn.Module): 19 | def __init__( 20 | self, 21 | *, 22 | device: torch.device, 23 | dtype: torch.dtype, 24 | n_ctx: int, 25 | width: int, 26 | heads: int, 27 | init_scale: float, 28 | qkv_bias: bool, 29 | flash: bool = False 30 | ): 31 | super().__init__() 32 | self.n_ctx = n_ctx 33 | self.width = width 34 | self.heads = heads 35 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) 36 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 37 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) 38 | init_linear(self.c_qkv, init_scale) 39 | init_linear(self.c_proj, init_scale) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = checkpoint(self.attention, (x,), (), True) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | 48 | class QKVMultiheadAttention(nn.Module): 49 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): 50 | super().__init__() 51 | self.device = device 52 | self.dtype = dtype 53 | self.heads = heads 54 | self.n_ctx = n_ctx 55 | self.flash = flash 56 | 57 | def forward(self, qkv): 58 | bs, n_ctx, width = qkv.shape 59 | attn_ch = width // self.heads // 3 60 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 61 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 62 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 63 | 64 | if self.flash: 65 | out = F.scaled_dot_product_attention(q, k, v) 66 | else: 67 | weight = torch.einsum( 68 | "bthc,bshc->bhts", q * scale, k * scale 69 | ) # More stable with f16 than dividing afterwards 70 | wdtype = weight.dtype 71 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 72 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 73 | 74 | return out 75 | 76 | 77 | class ResidualAttentionBlock(nn.Module): 78 | def __init__( 79 | self, 80 | *, 81 | device: torch.device, 82 | dtype: torch.dtype, 83 | n_ctx: int, 84 | width: int, 85 | heads: int, 86 | init_scale: float = 1.0, 87 | qkv_bias: bool = True, 88 | flash: bool = False, 89 | use_checkpoint: bool = False 90 | ): 91 | super().__init__() 92 | 93 | self.use_checkpoint = use_checkpoint 94 | 95 | self.attn = MultiheadAttention( 96 | device=device, 97 | dtype=dtype, 98 | n_ctx=n_ctx, 99 | width=width, 100 | heads=heads, 101 | init_scale=init_scale, 102 | qkv_bias=qkv_bias, 103 | flash=flash 104 | ) 105 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 106 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 107 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 108 | 109 | def _forward(self, x: torch.Tensor): 110 | x = x + self.attn(self.ln_1(x)) 111 | x = x + self.mlp(self.ln_2(x)) 112 | return x 113 | 114 | def forward(self, x: torch.Tensor): 115 | return 
checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 116 | 117 | 118 | class MultiheadCrossAttention(nn.Module): 119 | def __init__( 120 | self, 121 | *, 122 | device: torch.device, 123 | dtype: torch.dtype, 124 | width: int, 125 | heads: int, 126 | init_scale: float, 127 | qkv_bias: bool = True, 128 | flash: bool = False, 129 | n_data: Optional[int] = None, 130 | data_width: Optional[int] = None, 131 | ): 132 | super().__init__() 133 | self.n_data = n_data 134 | self.width = width 135 | self.heads = heads 136 | self.data_width = width if data_width is None else data_width 137 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) 138 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) 139 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 140 | self.attention = QKVMultiheadCrossAttention( 141 | device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash 142 | ) 143 | init_linear(self.c_q, init_scale) 144 | init_linear(self.c_kv, init_scale) 145 | init_linear(self.c_proj, init_scale) 146 | 147 | def forward(self, x, data): 148 | x = self.c_q(x) 149 | data = self.c_kv(data) 150 | x = checkpoint(self.attention, (x, data), (), True) 151 | x = self.c_proj(x) 152 | return x 153 | 154 | 155 | class QKVMultiheadCrossAttention(nn.Module): 156 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, 157 | flash: bool = False, n_data: Optional[int] = None): 158 | 159 | super().__init__() 160 | self.device = device 161 | self.dtype = dtype 162 | self.heads = heads 163 | self.n_data = n_data 164 | self.flash = flash 165 | 166 | def forward(self, q, kv): 167 | _, n_ctx, _ = q.shape 168 | bs, n_data, width = kv.shape 169 | attn_ch = width // self.heads // 2 170 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 171 | q = q.view(bs, n_ctx, self.heads, -1) 172 | kv = kv.view(bs, n_data, self.heads, -1) 173 | k, v = torch.split(kv, attn_ch, dim=-1) 174 | 175 | if self.flash: 176 | out = F.scaled_dot_product_attention(q, k, v) 177 | else: 178 | weight = torch.einsum( 179 | "bthc,bshc->bhts", q * scale, k * scale 180 | ) # More stable with f16 than dividing afterwards 181 | wdtype = weight.dtype 182 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 183 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 184 | 185 | return out 186 | 187 | 188 | class ResidualCrossAttentionBlock(nn.Module): 189 | def __init__( 190 | self, 191 | *, 192 | device: Optional[torch.device], 193 | dtype: Optional[torch.dtype], 194 | n_data: Optional[int] = None, 195 | width: int, 196 | heads: int, 197 | data_width: Optional[int] = None, 198 | init_scale: float = 0.25, 199 | qkv_bias: bool = True, 200 | flash: bool = False 201 | ): 202 | super().__init__() 203 | 204 | if data_width is None: 205 | data_width = width 206 | 207 | self.attn = MultiheadCrossAttention( 208 | device=device, 209 | dtype=dtype, 210 | n_data=n_data, 211 | width=width, 212 | heads=heads, 213 | data_width=data_width, 214 | init_scale=init_scale, 215 | qkv_bias=qkv_bias, 216 | flash=flash, 217 | ) 218 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 219 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 220 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 221 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) 222 | 223 | def forward(self, x: torch.Tensor, data: torch.Tensor): 224 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 225 | 
x = x + self.mlp(self.ln_3(x)) 226 | return x 227 | 228 | 229 | class MLP(nn.Module): 230 | def __init__(self, *, 231 | device: Optional[torch.device], 232 | dtype: Optional[torch.dtype], 233 | width: int, 234 | init_scale: float): 235 | super().__init__() 236 | self.width = width 237 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 238 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 239 | self.gelu = nn.GELU() 240 | init_linear(self.c_fc, init_scale) 241 | init_linear(self.c_proj, init_scale) 242 | 243 | def forward(self, x): 244 | return self.c_proj(self.gelu(self.c_fc(x))) 245 | 246 | 247 | class Transformer(nn.Module): 248 | def __init__( 249 | self, 250 | *, 251 | device: Optional[torch.device], 252 | dtype: Optional[torch.dtype], 253 | n_ctx: int, 254 | width: int, 255 | layers: int, 256 | heads: int, 257 | init_scale: float = 0.25, 258 | qkv_bias: bool = True, 259 | flash: bool = False, 260 | use_checkpoint: bool = False 261 | ): 262 | super().__init__() 263 | self.n_ctx = n_ctx 264 | self.width = width 265 | self.layers = layers 266 | self.resblocks = nn.ModuleList( 267 | [ 268 | ResidualAttentionBlock( 269 | device=device, 270 | dtype=dtype, 271 | n_ctx=n_ctx, 272 | width=width, 273 | heads=heads, 274 | init_scale=init_scale, 275 | qkv_bias=qkv_bias, 276 | flash=flash, 277 | use_checkpoint=use_checkpoint 278 | ) 279 | for _ in range(layers) 280 | ] 281 | ) 282 | 283 | def forward(self, x: torch.Tensor): 284 | for block in self.resblocks: 285 | x = block(x) 286 | return x 287 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/transformer_vit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | import warnings 8 | 9 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint 10 | 11 | 12 | def _trunc_normal_(tensor, mean, std, a, b): 13 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 14 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 15 | def norm_cdf(x): 16 | # Computes standard normal cumulative distribution function 17 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 18 | 19 | if (mean < a - 2 * std) or (mean > b + 2 * std): 20 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 21 | "The distribution of values may be incorrect.", 22 | stacklevel=2) 23 | 24 | # Values are generated by using a truncated uniform distribution and 25 | # then using the inverse CDF for the normal distribution. 26 | # Get upper and lower cdf values 27 | l = norm_cdf((a - mean) / std) 28 | u = norm_cdf((b - mean) / std) 29 | 30 | # Uniformly fill tensor with values from [l, u], then translate to 31 | # [2l-1, 2u-1]. 
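# Since Phi^{-1}(p) = sqrt(2) * erfinv(2 * p - 1), applying erfinv_ to u ~ U(2l-1, 2u-1)
# yields Phi^{-1}(U) / sqrt(2) for U ~ U(l, u); the subsequent mul_(std * sqrt(2)) and
# add_(mean) therefore produce samples from N(mean, std^2) truncated to [a, b].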
32 | tensor.uniform_(2 * l - 1, 2 * u - 1) 33 | 34 | # Use inverse cdf transform for normal distribution to get truncated 35 | # standard normal 36 | tensor.erfinv_() 37 | 38 | # Transform to proper mean, std 39 | tensor.mul_(std * math.sqrt(2.)) 40 | tensor.add_(mean) 41 | 42 | # Clamp to ensure it's in the proper range 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | 47 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 48 | # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor 49 | r"""Fills the input Tensor with values drawn from a truncated 50 | normal distribution. The values are effectively drawn from the 51 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 52 | with values outside :math:`[a, b]` redrawn until they are within 53 | the bounds. The method used for generating the random values works 54 | best when :math:`a \leq \text{mean} \leq b`. 55 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 56 | applied while sampling the normal with mean/std applied, therefore a, b args 57 | should be adjusted to match the range of mean, std args. 58 | Args: 59 | tensor: an n-dimensional `torch.Tensor` 60 | mean: the mean of the normal distribution 61 | std: the standard deviation of the normal distribution 62 | a: the minimum cutoff value 63 | b: the maximum cutoff value 64 | Examples: 65 | >>> w = torch.empty(3, 5) 66 | >>> nn.init.trunc_normal_(w) 67 | """ 68 | with torch.no_grad(): 69 | return _trunc_normal_(tensor, mean, std, a, b) 70 | 71 | 72 | def init_weights(m): 73 | if isinstance(m, nn.Linear): 74 | trunc_normal_(m.weight, std=.02) 75 | if isinstance(m, nn.Linear) and m.bias is not None: 76 | nn.init.constant_(m.bias, 0) 77 | elif isinstance(m, nn.LayerNorm): 78 | nn.init.constant_(m.bias, 0) 79 | nn.init.constant_(m.weight, 1.0) 80 | 81 | 82 | class MultiheadAttention(nn.Module): 83 | def __init__( 84 | self, 85 | *, 86 | device: torch.device, 87 | dtype: torch.dtype, 88 | n_ctx: int, 89 | width: int, 90 | heads: int, 91 | qkv_bias: bool 92 | ): 93 | super().__init__() 94 | self.n_ctx = n_ctx 95 | self.width = width 96 | self.heads = heads 97 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) 98 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 99 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) 100 | 101 | def forward(self, x): 102 | x = self.c_qkv(x) 103 | x = checkpoint(self.attention, (x,), (), True) 104 | x = self.c_proj(x) 105 | return x 106 | 107 | 108 | class QKVMultiheadAttention(nn.Module): 109 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): 110 | super().__init__() 111 | self.device = device 112 | self.dtype = dtype 113 | self.heads = heads 114 | self.n_ctx = n_ctx 115 | 116 | def forward(self, qkv): 117 | bs, n_ctx, width = qkv.shape 118 | attn_ch = width // self.heads // 3 119 | scale = 1 / math.sqrt(attn_ch) 120 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 121 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 122 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale 123 | wdtype = weight.dtype 124 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 125 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 126 | 127 | 128 | class ResidualAttentionBlock(nn.Module): 129 | def __init__( 130 | self, 131 | *, 132 | device: torch.device, 133 | dtype: torch.dtype, 134 | n_ctx: int, 135 | width: int, 136 | heads: int, 137 
| qkv_bias: bool = True, 138 | use_checkpoint: bool = False 139 | ): 140 | super().__init__() 141 | 142 | self.use_checkpoint = use_checkpoint 143 | 144 | self.attn = MultiheadAttention( 145 | device=device, 146 | dtype=dtype, 147 | n_ctx=n_ctx, 148 | width=width, 149 | heads=heads, 150 | qkv_bias=qkv_bias 151 | ) 152 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 153 | self.mlp = MLP(device=device, dtype=dtype, width=width) 154 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 155 | 156 | def _forward(self, x: torch.Tensor): 157 | x = x + self.attn(self.ln_1(x)) 158 | x = x + self.mlp(self.ln_2(x)) 159 | return x 160 | 161 | def forward(self, x: torch.Tensor): 162 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 163 | 164 | 165 | class MultiheadCrossAttention(nn.Module): 166 | def __init__( 167 | self, 168 | *, 169 | device: torch.device, 170 | dtype: torch.dtype, 171 | width: int, 172 | heads: int, 173 | qkv_bias: bool = True, 174 | n_data: Optional[int] = None, 175 | data_width: Optional[int] = None, 176 | ): 177 | super().__init__() 178 | self.n_data = n_data 179 | self.width = width 180 | self.heads = heads 181 | self.data_width = width if data_width is None else data_width 182 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) 183 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) 184 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 185 | self.attention = QKVMultiheadCrossAttention( 186 | device=device, dtype=dtype, heads=heads, n_data=n_data 187 | ) 188 | 189 | def forward(self, x, data): 190 | x = self.c_q(x) 191 | data = self.c_kv(data) 192 | x = checkpoint(self.attention, (x, data), (), True) 193 | x = self.c_proj(x) 194 | return x 195 | 196 | 197 | class QKVMultiheadCrossAttention(nn.Module): 198 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None): 199 | super().__init__() 200 | self.device = device 201 | self.dtype = dtype 202 | self.heads = heads 203 | self.n_data = n_data 204 | 205 | def forward(self, q, kv): 206 | _, n_ctx, _ = q.shape 207 | bs, n_data, width = kv.shape 208 | attn_ch = width // self.heads // 2 209 | scale = 1 / math.sqrt(attn_ch) 210 | q = q.view(bs, n_ctx, self.heads, -1) 211 | kv = kv.view(bs, n_data, self.heads, -1) 212 | k, v = torch.split(kv, attn_ch, dim=-1) 213 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale 214 | wdtype = weight.dtype 215 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 216 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 217 | 218 | 219 | class ResidualCrossAttentionBlock(nn.Module): 220 | def __init__( 221 | self, 222 | *, 223 | device: Optional[torch.device], 224 | dtype: Optional[torch.dtype], 225 | n_data: Optional[int] = None, 226 | width: int, 227 | heads: int, 228 | data_width: Optional[int] = None, 229 | qkv_bias: bool = True 230 | ): 231 | super().__init__() 232 | 233 | if data_width is None: 234 | data_width = width 235 | 236 | self.attn = MultiheadCrossAttention( 237 | device=device, 238 | dtype=dtype, 239 | n_data=n_data, 240 | width=width, 241 | heads=heads, 242 | data_width=data_width, 243 | qkv_bias=qkv_bias 244 | ) 245 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 246 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 247 | self.mlp = MLP(device=device, dtype=dtype, width=width) 248 | self.ln_3 = nn.LayerNorm(width, device=device, 
dtype=dtype) 249 | 250 | def forward(self, x: torch.Tensor, data: torch.Tensor): 251 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 252 | x = x + self.mlp(self.ln_3(x)) 253 | return x 254 | 255 | 256 | class MLP(nn.Module): 257 | def __init__(self, *, 258 | device: Optional[torch.device], 259 | dtype: Optional[torch.dtype], 260 | width: int): 261 | super().__init__() 262 | self.width = width 263 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 264 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 265 | self.gelu = nn.GELU() 266 | 267 | def forward(self, x): 268 | return self.c_proj(self.gelu(self.c_fc(x))) 269 | 270 | 271 | class Transformer(nn.Module): 272 | def __init__( 273 | self, 274 | *, 275 | device: Optional[torch.device], 276 | dtype: Optional[torch.dtype], 277 | n_ctx: int, 278 | width: int, 279 | layers: int, 280 | heads: int, 281 | qkv_bias: bool = True, 282 | use_checkpoint: bool = False 283 | ): 284 | super().__init__() 285 | self.n_ctx = n_ctx 286 | self.width = width 287 | self.layers = layers 288 | self.resblocks = nn.ModuleList( 289 | [ 290 | ResidualAttentionBlock( 291 | device=device, 292 | dtype=dtype, 293 | n_ctx=n_ctx, 294 | width=width, 295 | heads=heads, 296 | qkv_bias=qkv_bias, 297 | use_checkpoint=use_checkpoint 298 | ) 299 | for _ in range(layers) 300 | ] 301 | ) 302 | 303 | self.apply(init_weights) 304 | 305 | def forward(self, x: torch.Tensor): 306 | for block in self.resblocks: 307 | x = block(x) 308 | return x 309 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/asl_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import List, Tuple, Dict, Optional 4 | from omegaconf import DictConfig 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch.optim import lr_scheduler 10 | from typing import Union 11 | from functools import partial 12 | 13 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 14 | 15 | from .tsal_base import ( 16 | AlignedShapeAsLatentModule, 17 | ShapeAsLatentModule, 18 | Latent2MeshOutput, 19 | AlignedMeshOutput 20 | ) 21 | from MeshAnything.miche.michelangelo.models.tsal.inference_utils import extract_geometry 22 | import trimesh 23 | 24 | class AlignedShapeAsLatentPLModule(nn.Module): 25 | def __init__(self, *, 26 | shape_module_cfg, 27 | aligned_module_cfg, 28 | loss_cfg, 29 | optimizer_cfg: Optional[DictConfig] = None, 30 | ckpt_path: Optional[str] = None, 31 | ignore_keys: Union[Tuple[str], List[str]] = ()): 32 | 33 | super().__init__() 34 | 35 | shape_model: ShapeAsLatentModule = instantiate_from_config( 36 | shape_module_cfg, device=None, dtype=None 37 | ) 38 | self.model: AlignedShapeAsLatentModule = instantiate_from_config( 39 | aligned_module_cfg, shape_model=shape_model 40 | ) 41 | 42 | self.loss = instantiate_from_config(loss_cfg) 43 | 44 | self.optimizer_cfg = optimizer_cfg 45 | 46 | if ckpt_path is not None: 47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 48 | 49 | def set_shape_model_only(self): 50 | self.model.set_shape_model_only() 51 | 52 | 53 | 54 | @property 55 | def 
latent_shape(self): 56 | return self.model.shape_model.latent_shape 57 | 58 | @property 59 | def zero_rank(self): 60 | if self._trainer: 61 | zero_rank = self.trainer.local_rank == 0 62 | else: 63 | zero_rank = True 64 | 65 | return zero_rank 66 | 67 | def init_from_ckpt(self, path, ignore_keys=()): 68 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 69 | 70 | keys = list(state_dict.keys()) 71 | for k in keys: 72 | for ik in ignore_keys: 73 | if k.startswith(ik): 74 | print("Deleting key {} from state_dict.".format(k)) 75 | del state_dict[k] 76 | 77 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 78 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 79 | if len(missing) > 0: 80 | print(f"Missing Keys: {missing}") 81 | print(f"Unexpected Keys: {unexpected}") 82 | 83 | def configure_optimizers(self) -> Tuple[List, List]: 84 | lr = self.learning_rate 85 | 86 | trainable_parameters = list(self.model.parameters()) 87 | 88 | if self.optimizer_cfg is None: 89 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 90 | schedulers = [] 91 | else: 92 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 93 | scheduler_func = instantiate_from_config( 94 | self.optimizer_cfg.scheduler, 95 | max_decay_steps=self.trainer.max_steps, 96 | lr_max=lr 97 | ) 98 | scheduler = { 99 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 100 | "interval": "step", 101 | "frequency": 1 102 | } 103 | optimizers = [optimizer] 104 | schedulers = [scheduler] 105 | 106 | return optimizers, schedulers 107 | 108 | def forward(self, 109 | surface: torch.FloatTensor, 110 | image: torch.FloatTensor, 111 | text: torch.FloatTensor, 112 | volume_queries: torch.FloatTensor): 113 | 114 | """ 115 | 116 | Args: 117 | surface (torch.FloatTensor): 118 | image (torch.FloatTensor): 119 | text (torch.FloatTensor): 120 | volume_queries (torch.FloatTensor): 121 | 122 | Returns: 123 | 124 | """ 125 | 126 | embed_outputs, shape_z = self.model(surface, image, text) 127 | 128 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) 129 | latents = self.model.shape_model.decode(shape_zq) 130 | logits = self.model.shape_model.query_geometry(volume_queries, latents) 131 | 132 | return embed_outputs, logits, posterior 133 | 134 | def encode(self, surface: torch.FloatTensor, sample_posterior=True): 135 | 136 | pc = surface[..., 0:3] 137 | feats = surface[..., 3:6] 138 | 139 | shape_embed, shape_zq, posterior = self.model.shape_model.encode( 140 | pc=pc, feats=feats, sample_posterior=sample_posterior 141 | ) 142 | 143 | return shape_zq 144 | 145 | def encode_latents(self, surface: torch.FloatTensor): 146 | 147 | pc = surface[..., 0:3] 148 | feats = surface[..., 3:6] 149 | 150 | shape_embed, shape_latents = self.model.shape_model.encode_latents( 151 | pc=pc, feats=feats 152 | ) 153 | shape_embed = shape_embed.unsqueeze(1) 154 | assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256 155 | cat_latents = torch.cat([shape_embed, shape_latents], dim=1) 156 | 157 | return cat_latents 158 | 159 | def recon(self, surface): 160 | cat_latents = self.encode_latents(surface) 161 | shape_latents = cat_latents[:, 1:] 162 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_latents) 163 | 164 | # decoding 165 | latents = self.model.shape_model.decode(shape_zq) 166 | geometric_func = 
partial(self.model.shape_model.query_geometry, latents=latents) 167 | 168 | # reconstruction 169 | mesh_v_f, has_surface = extract_geometry( 170 | geometric_func=geometric_func, 171 | device=surface.device, 172 | batch_size=surface.shape[0], 173 | bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 174 | octree_depth=7, 175 | num_chunks=10000, 176 | ) 177 | recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1]) 178 | 179 | return recon_mesh 180 | 181 | 182 | def to_shape_latents(self, latents): 183 | 184 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior = False) 185 | return self.model.shape_model.decode(shape_zq) 186 | 187 | def decode(self, 188 | z_q, 189 | bounds: Union[Tuple[float], List[float], float] = 1.1, 190 | octree_depth: int = 7, 191 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 192 | 193 | latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim] 194 | outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) 195 | 196 | return outputs 197 | 198 | def training_step(self, batch: Dict[str, torch.FloatTensor], 199 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 200 | """ 201 | 202 | Args: 203 | batch (dict): the batch sample, and it contains: 204 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] 205 | - image (torch.FloatTensor): [bs, 3, 224, 224] 206 | - text (torch.FloatTensor): [bs, num_templates, 77] 207 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] 208 | 209 | batch_idx (int): 210 | 211 | optimizer_idx (int): 212 | 213 | Returns: 214 | loss (torch.FloatTensor): 215 | 216 | """ 217 | 218 | surface = batch["surface"] 219 | image = batch["image"] 220 | text = batch["text"] 221 | 222 | volume_queries = batch["geo_points"][..., 0:3] 223 | shape_labels = batch["geo_points"][..., -1] 224 | 225 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) 226 | 227 | aeloss, log_dict_ae = self.loss( 228 | **embed_outputs, 229 | posteriors=posteriors, 230 | shape_logits=shape_logits, 231 | shape_labels=shape_labels, 232 | split="train" 233 | ) 234 | 235 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], 236 | sync_dist=False, rank_zero_only=True) 237 | 238 | return aeloss 239 | 240 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: 241 | 242 | surface = batch["surface"] 243 | image = batch["image"] 244 | text = batch["text"] 245 | 246 | volume_queries = batch["geo_points"][..., 0:3] 247 | shape_labels = batch["geo_points"][..., -1] 248 | 249 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) 250 | 251 | aeloss, log_dict_ae = self.loss( 252 | **embed_outputs, 253 | posteriors=posteriors, 254 | shape_logits=shape_logits, 255 | shape_labels=shape_labels, 256 | split="val" 257 | ) 258 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], 259 | sync_dist=False, rank_zero_only=True) 260 | 261 | return aeloss 262 | 263 | def visual_alignment(self, 264 | surface: torch.FloatTensor, 265 | image: torch.FloatTensor, 266 | text: torch.FloatTensor, 267 | description: Optional[List[str]] = None, 268 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 269 | octree_depth: int = 7, 270 | num_chunks: int = 10000) -> List[AlignedMeshOutput]: 271 | 272 | """ 273 | 274 | Args: 275 | surface: 276 | image: 277 | text: 278 | 
description: 279 | bounds: 280 | octree_depth: 281 | num_chunks: 282 | 283 | Returns: 284 | mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list. 285 | 286 | """ 287 | 288 | outputs = [] 289 | 290 | device = surface.device 291 | bs = surface.shape[0] 292 | 293 | embed_outputs, shape_z = self.model(surface, image, text) 294 | 295 | # calculate the similarity 296 | image_embed = embed_outputs["image_embed"] 297 | text_embed = embed_outputs["text_embed"] 298 | shape_embed = embed_outputs["shape_embed"] 299 | 300 | # normalized features 301 | shape_embed = F.normalize(shape_embed, dim=-1, p=2) 302 | text_embed = F.normalize(text_embed, dim=-1, p=2) 303 | image_embed = F.normalize(image_embed, dim=-1, p=2) 304 | 305 | # B x B 306 | shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1) 307 | 308 | # B x B 309 | shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1) 310 | 311 | # shape reconstruction 312 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) 313 | latents = self.model.shape_model.decode(shape_zq) 314 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) 315 | 316 | # 2. decode geometry 317 | mesh_v_f, has_surface = extract_geometry( 318 | geometric_func=geometric_func, 319 | device=device, 320 | batch_size=bs, 321 | bounds=bounds, 322 | octree_depth=octree_depth, 323 | num_chunks=num_chunks, 324 | disable=not self.zero_rank 325 | ) 326 | 327 | # 3. decode texture 328 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 329 | if not is_surface: 330 | outputs.append(None) 331 | continue 332 | 333 | out = AlignedMeshOutput() 334 | out.mesh_v = mesh_v 335 | out.mesh_f = mesh_f 336 | out.surface = surface[i].cpu().numpy() 337 | out.image = image[i].cpu().numpy() 338 | if description is not None: 339 | out.text = description[i] 340 | out.shape_text_similarity = shape_text_similarity[i, i] 341 | out.shape_image_similarity = shape_image_similarity[i, i] 342 | 343 | outputs.append(out) 344 | 345 | return outputs 346 | 347 | def latent2mesh(self, 348 | latents: torch.FloatTensor, 349 | bounds: Union[Tuple[float], List[float], float] = 1.1, 350 | octree_depth: int = 7, 351 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 352 | 353 | """ 354 | 355 | Args: 356 | latents: [bs, num_latents, dim] 357 | bounds: 358 | octree_depth: 359 | num_chunks: 360 | 361 | Returns: 362 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 363 | 364 | """ 365 | 366 | outputs = [] 367 | 368 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) 369 | 370 | # 2. decode geometry 371 | device = latents.device 372 | mesh_v_f, has_surface = extract_geometry( 373 | geometric_func=geometric_func, 374 | device=device, 375 | batch_size=len(latents), 376 | bounds=bounds, 377 | octree_depth=octree_depth, 378 | num_chunks=num_chunks, 379 | disable=not self.zero_rank 380 | ) 381 | 382 | # 3. 
decode texture 383 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 384 | if not is_surface: 385 | outputs.append(None) 386 | continue 387 | 388 | out = Latent2MeshOutput() 389 | out.mesh_v = mesh_v 390 | out.mesh_f = mesh_f 391 | 392 | outputs.append(out) 393 | 394 | return outputs 395 | 396 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/clip_asl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | from transformers import CLIPModel 7 | 8 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule 9 | 10 | 11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): 12 | 13 | def __init__(self, *, 14 | shape_model, 15 | clip_model_version: str = "openai/clip-vit-large-patch14"): 16 | 17 | super().__init__() 18 | 19 | # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) 20 | # for params in self.clip_model.parameters(): 21 | # params.requires_grad = False 22 | self.clip_model = None 23 | self.shape_model = shape_model 24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width)) 25 | # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5) 26 | 27 | def set_shape_model_only(self): 28 | self.clip_model = None 29 | 30 | def encode_shape_embed(self, surface, return_latents: bool = False): 31 | """ 32 | 33 | Args: 34 | surface (torch.FloatTensor): [bs, n, 3 + c] 35 | return_latents (bool): 36 | 37 | Returns: 38 | x (torch.FloatTensor): [bs, projection_dim] 39 | shape_latents (torch.FloatTensor): [bs, m, d] 40 | """ 41 | 42 | pc = surface[..., 0:3] 43 | feats = surface[..., 3:] 44 | 45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) 46 | x = shape_embed @ self.shape_projection 47 | 48 | if return_latents: 49 | return x, shape_latents 50 | else: 51 | return x 52 | 53 | def encode_image_embed(self, image): 54 | """ 55 | 56 | Args: 57 | image (torch.FloatTensor): [bs, 3, h, w] 58 | 59 | Returns: 60 | x (torch.FloatTensor): [bs, projection_dim] 61 | """ 62 | 63 | x = self.clip_model.get_image_features(image) 64 | 65 | return x 66 | 67 | def encode_text_embed(self, text): 68 | x = self.clip_model.get_text_features(text) 69 | return x 70 | 71 | def forward(self, surface, image, text): 72 | """ 73 | 74 | Args: 75 | surface (torch.FloatTensor): 76 | image (torch.FloatTensor): [bs, 3, 224, 224] 77 | text (torch.LongTensor): [bs, num_templates, 77] 78 | 79 | Returns: 80 | embed_outputs (dict): the embedding outputs, and it contains: 81 | - image_embed (torch.FloatTensor): 82 | - text_embed (torch.FloatTensor): 83 | - shape_embed (torch.FloatTensor): 84 | - logit_scale (float): 85 | """ 86 | 87 | # # text embedding 88 | # text_embed_all = [] 89 | # for i in range(text.shape[0]): 90 | # text_for_one_sample = text[i] 91 | # text_embed = self.encode_text_embed(text_for_one_sample) 92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 93 | # text_embed = text_embed.mean(dim=0) 94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 95 | # text_embed_all.append(text_embed) 96 | # text_embed_all = torch.stack(text_embed_all) 97 | 98 | b = text.shape[0] 99 | text_tokens = rearrange(text, "b t l -> (b t) l") 100 | text_embed = self.encode_text_embed(text_tokens) 101 | 
text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) 102 | text_embed = text_embed.mean(dim=1) 103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 104 | 105 | # image embedding 106 | image_embed = self.encode_image_embed(image) 107 | 108 | # shape embedding 109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True) 110 | 111 | embed_outputs = { 112 | "image_embed": image_embed, 113 | "text_embed": text_embed, 114 | "shape_embed": shape_embed, 115 | # "logit_scale": self.clip_model.logit_scale.exp() 116 | } 117 | 118 | return embed_outputs, shape_latents 119 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from einops import repeat 6 | import numpy as np 7 | from typing import Callable, Tuple, List, Union, Optional 8 | from skimage import measure 9 | 10 | from MeshAnything.miche.michelangelo.graphics.primitives import generate_dense_grid_points 11 | 12 | 13 | @torch.no_grad() 14 | def extract_geometry(geometric_func: Callable, 15 | device: torch.device, 16 | batch_size: int = 1, 17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 18 | octree_depth: int = 7, 19 | num_chunks: int = 10000, 20 | disable: bool = True): 21 | """ 22 | 23 | Args: 24 | geometric_func: 25 | device: 26 | bounds: 27 | octree_depth: 28 | batch_size: 29 | num_chunks: 30 | disable: 31 | 32 | Returns: 33 | 34 | """ 35 | 36 | if isinstance(bounds, float): 37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] 38 | 39 | bbox_min = np.array(bounds[0:3]) 40 | bbox_max = np.array(bounds[3:6]) 41 | bbox_size = bbox_max - bbox_min 42 | 43 | xyz_samples, grid_size, length = generate_dense_grid_points( 44 | bbox_min=bbox_min, 45 | bbox_max=bbox_max, 46 | octree_depth=octree_depth, 47 | indexing="ij" 48 | ) 49 | xyz_samples = torch.FloatTensor(xyz_samples) 50 | 51 | batch_logits = [] 52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), 53 | desc="Implicit Function:", disable=disable, leave=False): 54 | queries = xyz_samples[start: start + num_chunks, :].to(device) 55 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size) 56 | 57 | logits = geometric_func(batch_queries) 58 | batch_logits.append(logits.cpu()) 59 | 60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() 61 | 62 | mesh_v_f = [] 63 | has_surface = np.zeros((batch_size,), dtype=np.bool_) 64 | for i in range(batch_size): 65 | try: 66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") 67 | vertices = vertices / grid_size * bbox_size + bbox_min 68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]] 69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) 70 | has_surface[i] = True 71 | 72 | except ValueError: 73 | mesh_v_f.append((None, None)) 74 | has_surface[i] = False 75 | 76 | except RuntimeError: 77 | mesh_v_f.append((None, None)) 78 | has_surface[i] = False 79 | 80 | return mesh_v_f, has_surface 81 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 
import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from typing import Optional, Tuple, Dict 7 | 8 | from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution 9 | from MeshAnything.miche.michelangelo.utils.eval import compute_psnr 10 | from MeshAnything.miche.michelangelo.utils import misc 11 | 12 | 13 | class KLNearFar(nn.Module): 14 | def __init__(self, 15 | near_weight: float = 0.1, 16 | kl_weight: float = 1.0, 17 | num_near_samples: Optional[int] = None): 18 | 19 | super().__init__() 20 | 21 | self.near_weight = near_weight 22 | self.kl_weight = kl_weight 23 | self.num_near_samples = num_near_samples 24 | self.geo_criterion = nn.BCEWithLogitsLoss() 25 | 26 | def forward(self, 27 | posteriors: Optional[DiagonalGaussianDistribution], 28 | logits: torch.FloatTensor, 29 | labels: torch.FloatTensor, 30 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: 31 | 32 | """ 33 | 34 | Args: 35 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): 36 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; 37 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; 38 | split (str): 39 | **kwargs: 40 | 41 | Returns: 42 | loss (torch.Tensor): (,) 43 | log (dict): 44 | 45 | """ 46 | 47 | if self.num_near_samples is None: 48 | num_vol = logits.shape[1] // 2 49 | else: 50 | num_vol = logits.shape[1] - self.num_near_samples 51 | 52 | vol_logits = logits[:, 0:num_vol] 53 | vol_labels = labels[:, 0:num_vol] 54 | 55 | near_logits = logits[:, num_vol:] 56 | near_labels = labels[:, num_vol:] 57 | 58 | # occupancy loss 59 | # vol_bce = self.geo_criterion(vol_logits, vol_labels) 60 | # near_bce = self.geo_criterion(near_logits, near_labels) 61 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 62 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 63 | 64 | if posteriors is None: 65 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) 66 | else: 67 | kl_loss = posteriors.kl(dims=(1, 2)) 68 | kl_loss = torch.mean(kl_loss) 69 | 70 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight 71 | 72 | with torch.no_grad(): 73 | preds = logits >= 0 74 | accuracy = (preds == labels).float() 75 | accuracy = accuracy.mean() 76 | pos_ratio = torch.mean(labels) 77 | 78 | log = { 79 | "{}/total_loss".format(split): loss.clone().detach(), 80 | "{}/near".format(split): near_bce.detach(), 81 | "{}/far".format(split): vol_bce.detach(), 82 | "{}/kl".format(split): kl_loss.detach(), 83 | "{}/accuracy".format(split): accuracy, 84 | "{}/pos_ratio".format(split): pos_ratio 85 | } 86 | 87 | if posteriors is not None: 88 | log[f"{split}/mean"] = posteriors.mean.mean().detach() 89 | log[f"{split}/std_mean"] = posteriors.std.mean().detach() 90 | log[f"{split}/std_max"] = posteriors.std.max().detach() 91 | 92 | return loss, log 93 | 94 | 95 | class KLNearFarColor(nn.Module): 96 | def __init__(self, 97 | near_weight: float = 0.1, 98 | kl_weight: float = 1.0, 99 | color_weight: float = 1.0, 100 | color_criterion: str = "mse", 101 | num_near_samples: Optional[int] = None): 102 | 103 | super().__init__() 104 | 105 | self.color_weight = color_weight 106 | self.near_weight = near_weight 107 | self.kl_weight = kl_weight 108 | self.num_near_samples = num_near_samples 109 | 110 | if color_criterion == "mse": 111 | self.color_criterion = 
nn.MSELoss() 112 | 113 | elif color_criterion == "l1": 114 | self.color_criterion = nn.L1Loss() 115 | 116 | else: 117 | raise ValueError(f"{color_criterion} must be [`mse`, `l1`].") 118 | 119 | self.geo_criterion = nn.BCEWithLogitsLoss() 120 | 121 | def forward(self, 122 | posteriors: Optional[DiagonalGaussianDistribution], 123 | logits: torch.FloatTensor, 124 | labels: torch.FloatTensor, 125 | pred_colors: torch.FloatTensor, 126 | gt_colors: torch.FloatTensor, 127 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: 128 | 129 | """ 130 | 131 | Args: 132 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): 133 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; 134 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; 135 | pred_colors (torch.FloatTensor): [B, M, 3] 136 | gt_colors (torch.FloatTensor): [B, M, 3] 137 | split (str): 138 | **kwargs: 139 | 140 | Returns: 141 | loss (torch.Tensor): (,) 142 | log (dict): 143 | 144 | """ 145 | 146 | if self.num_near_samples is None: 147 | num_vol = logits.shape[1] // 2 148 | else: 149 | num_vol = logits.shape[1] - self.num_near_samples 150 | 151 | vol_logits = logits[:, 0:num_vol] 152 | vol_labels = labels[:, 0:num_vol] 153 | 154 | near_logits = logits[:, num_vol:] 155 | near_labels = labels[:, num_vol:] 156 | 157 | # occupancy loss 158 | # vol_bce = self.geo_criterion(vol_logits, vol_labels) 159 | # near_bce = self.geo_criterion(near_logits, near_labels) 160 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 161 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 162 | 163 | # surface color loss 164 | color = self.color_criterion(pred_colors, gt_colors) 165 | 166 | if posteriors is None: 167 | kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device) 168 | else: 169 | kl_loss = posteriors.kl(dims=(1, 2)) 170 | kl_loss = torch.mean(kl_loss) 171 | 172 | loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight 173 | 174 | with torch.no_grad(): 175 | preds = logits >= 0 176 | accuracy = (preds == labels).float() 177 | accuracy = accuracy.mean() 178 | psnr = compute_psnr(pred_colors, gt_colors) 179 | 180 | log = { 181 | "{}/total_loss".format(split): loss.clone().detach(), 182 | "{}/near".format(split): near_bce.detach(), 183 | "{}/far".format(split): vol_bce.detach(), 184 | "{}/color".format(split): color.detach(), 185 | "{}/kl".format(split): kl_loss.detach(), 186 | "{}/psnr".format(split): psnr.detach(), 187 | "{}/accuracy".format(split): accuracy 188 | } 189 | 190 | return loss, log 191 | 192 | 193 | class ContrastKLNearFar(nn.Module): 194 | def __init__(self, 195 | contrast_weight: float = 1.0, 196 | near_weight: float = 0.1, 197 | kl_weight: float = 1.0, 198 | num_near_samples: Optional[int] = None): 199 | 200 | super().__init__() 201 | 202 | self.labels = None 203 | self.last_local_batch_size = None 204 | 205 | self.contrast_weight = contrast_weight 206 | self.near_weight = near_weight 207 | self.kl_weight = kl_weight 208 | self.num_near_samples = num_near_samples 209 | self.geo_criterion = nn.BCEWithLogitsLoss() 210 | 211 | def forward(self, 212 | shape_embed: torch.FloatTensor, 213 | text_embed: torch.FloatTensor, 214 | image_embed: torch.FloatTensor, 215 | logit_scale: torch.FloatTensor, 216 | posteriors: Optional[DiagonalGaussianDistribution], 217 | 
shape_logits: torch.FloatTensor, 218 | shape_labels: torch.FloatTensor, 219 | split: Optional[str] = "train", **kwargs): 220 | 221 | local_batch_size = shape_embed.size(0) 222 | 223 | if local_batch_size != self.last_local_batch_size: 224 | self.labels = local_batch_size * misc.get_rank() + torch.arange( 225 | local_batch_size, device=shape_embed.device 226 | ).long() 227 | self.last_local_batch_size = local_batch_size 228 | 229 | # normalized features 230 | shape_embed = F.normalize(shape_embed, dim=-1, p=2) 231 | text_embed = F.normalize(text_embed, dim=-1, p=2) 232 | image_embed = F.normalize(image_embed, dim=-1, p=2) 233 | 234 | # gather features from all GPUs 235 | shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( 236 | [shape_embed, text_embed, image_embed] 237 | ) 238 | 239 | # cosine similarity as logits 240 | logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() 241 | logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() 242 | logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() 243 | logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() 244 | contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + 245 | F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ 246 | (F.cross_entropy(logits_per_shape_image, self.labels) + 247 | F.cross_entropy(logits_per_image_shape, self.labels)) / 2 248 | 249 | # shape reconstruction 250 | if self.num_near_samples is None: 251 | num_vol = shape_logits.shape[1] // 2 252 | else: 253 | num_vol = shape_logits.shape[1] - self.num_near_samples 254 | 255 | vol_logits = shape_logits[:, 0:num_vol] 256 | vol_labels = shape_labels[:, 0:num_vol] 257 | 258 | near_logits = shape_logits[:, num_vol:] 259 | near_labels = shape_labels[:, num_vol:] 260 | 261 | # occupancy loss 262 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 263 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 264 | 265 | if posteriors is None: 266 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) 267 | else: 268 | kl_loss = posteriors.kl(dims=(1, 2)) 269 | kl_loss = torch.mean(kl_loss) 270 | 271 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight 272 | 273 | # compute accuracy 274 | with torch.no_grad(): 275 | pred = torch.argmax(logits_per_shape_text, dim=-1) 276 | correct = pred.eq(self.labels).sum() 277 | shape_text_acc = 100 * correct / local_batch_size 278 | 279 | pred = torch.argmax(logits_per_shape_image, dim=-1) 280 | correct = pred.eq(self.labels).sum() 281 | shape_image_acc = 100 * correct / local_batch_size 282 | 283 | preds = shape_logits >= 0 284 | accuracy = (preds == shape_labels).float() 285 | accuracy = accuracy.mean() 286 | 287 | log = { 288 | "{}/contrast".format(split): contrast_loss.clone().detach(), 289 | "{}/near".format(split): near_bce.detach(), 290 | "{}/far".format(split): vol_bce.detach(), 291 | "{}/kl".format(split): kl_loss.detach(), 292 | "{}/shape_text_acc".format(split): shape_text_acc, 293 | "{}/shape_image_acc".format(split): shape_image_acc, 294 | "{}/total_loss".format(split): loss.clone().detach(), 295 | "{}/accuracy".format(split): accuracy, 296 | } 297 | 298 | if posteriors is not None: 299 | log[f"{split}/mean"] = posteriors.mean.mean().detach() 300 | log[f"{split}/std_mean"] = posteriors.std.mean().detach() 301 | log[f"{split}/std_max"] = posteriors.std.max().detach() 302 | 303 | return loss, log 304 | 
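The loss modules above share one occupancy interface: logits and labels are laid out as [volume samples | near-surface samples] along the second dimension, the near-surface term is scaled by near_weight, and the KL term is skipped whenever no posterior is supplied. The following is a minimal sketch of driving KLNearFar directly; it is not part of the repository, and the batch size, point counts, and weights are purely illustrative.

import torch
from MeshAnything.miche.michelangelo.models.tsal.loss import KLNearFar

criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3)

batch_size, n = 4, 2048                                  # illustrative sizes
logits = torch.randn(batch_size, 2 * n)                  # [volume | near-surface] occupancy logits
labels = (torch.rand(batch_size, 2 * n) > 0.5).float()   # occupancy targets in {0, 1}

# posteriors=None makes the KL term zero (as when embed_dim == 0)
loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="val")
print(loss.item(), log["val/accuracy"].item())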
-------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/sal_perceiver.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from einops import repeat 7 | import math 8 | 9 | from MeshAnything.miche.michelangelo.models.modules import checkpoint 10 | from MeshAnything.miche.michelangelo.models.modules.embedder import FourierEmbedder 11 | from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution 12 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import ( 13 | ResidualCrossAttentionBlock, 14 | Transformer 15 | ) 16 | 17 | from .tsal_base import ShapeAsLatentModule 18 | 19 | 20 | class CrossAttentionEncoder(nn.Module): 21 | 22 | def __init__(self, *, 23 | device: Optional[torch.device], 24 | dtype: Optional[torch.dtype], 25 | num_latents: int, 26 | fourier_embedder: FourierEmbedder, 27 | point_feats: int, 28 | width: int, 29 | heads: int, 30 | layers: int, 31 | init_scale: float = 0.25, 32 | qkv_bias: bool = True, 33 | flash: bool = False, 34 | use_ln_post: bool = False, 35 | use_checkpoint: bool = False): 36 | 37 | super().__init__() 38 | 39 | self.use_checkpoint = use_checkpoint 40 | self.num_latents = num_latents 41 | 42 | self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) 43 | 44 | self.fourier_embedder = fourier_embedder 45 | self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) 46 | self.cross_attn = ResidualCrossAttentionBlock( 47 | device=device, 48 | dtype=dtype, 49 | width=width, 50 | heads=heads, 51 | init_scale=init_scale, 52 | qkv_bias=qkv_bias, 53 | flash=flash, 54 | ) 55 | 56 | self.self_attn = Transformer( 57 | device=device, 58 | dtype=dtype, 59 | n_ctx=num_latents, 60 | width=width, 61 | layers=layers, 62 | heads=heads, 63 | init_scale=init_scale, 64 | qkv_bias=qkv_bias, 65 | flash=flash, 66 | use_checkpoint=False 67 | ) 68 | 69 | if use_ln_post: 70 | self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) 71 | else: 72 | self.ln_post = None 73 | 74 | def _forward(self, pc, feats): 75 | """ 76 | 77 | Args: 78 | pc (torch.FloatTensor): [B, N, 3] 79 | feats (torch.FloatTensor or None): [B, N, C] 80 | 81 | Returns: 82 | 83 | """ 84 | 85 | bs = pc.shape[0] 86 | 87 | data = self.fourier_embedder(pc) 88 | if feats is not None: 89 | data = torch.cat([data, feats], dim=-1) 90 | data = self.input_proj(data) 91 | 92 | query = repeat(self.query, "m c -> b m c", b=bs) 93 | latents = self.cross_attn(query, data) 94 | latents = self.self_attn(latents) 95 | 96 | if self.ln_post is not None: 97 | latents = self.ln_post(latents) 98 | 99 | return latents, pc 100 | 101 | def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): 102 | """ 103 | 104 | Args: 105 | pc (torch.FloatTensor): [B, N, 3] 106 | feats (torch.FloatTensor or None): [B, N, C] 107 | 108 | Returns: 109 | dict 110 | """ 111 | 112 | return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) 113 | 114 | 115 | class CrossAttentionDecoder(nn.Module): 116 | 117 | def __init__(self, *, 118 | device: Optional[torch.device], 119 | dtype: Optional[torch.dtype], 120 | num_latents: int, 121 | out_channels: int, 122 | fourier_embedder: FourierEmbedder, 123 | width: int, 124 | heads: int, 125 | init_scale: 
float = 0.25, 126 | qkv_bias: bool = True, 127 | flash: bool = False, 128 | use_checkpoint: bool = False): 129 | 130 | super().__init__() 131 | 132 | self.use_checkpoint = use_checkpoint 133 | self.fourier_embedder = fourier_embedder 134 | 135 | self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) 136 | 137 | self.cross_attn_decoder = ResidualCrossAttentionBlock( 138 | device=device, 139 | dtype=dtype, 140 | n_data=num_latents, 141 | width=width, 142 | heads=heads, 143 | init_scale=init_scale, 144 | qkv_bias=qkv_bias, 145 | flash=flash 146 | ) 147 | 148 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 149 | self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) 150 | 151 | def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 152 | queries = self.query_proj(self.fourier_embedder(queries)) 153 | x = self.cross_attn_decoder(queries, latents) 154 | x = self.ln_post(x) 155 | x = self.output_proj(x) 156 | return x 157 | 158 | def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 159 | return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) 160 | 161 | 162 | class ShapeAsLatentPerceiver(ShapeAsLatentModule): 163 | def __init__(self, *, 164 | device: Optional[torch.device], 165 | dtype: Optional[torch.dtype], 166 | num_latents: int, 167 | point_feats: int = 0, 168 | embed_dim: int = 0, 169 | num_freqs: int = 8, 170 | include_pi: bool = True, 171 | width: int, 172 | heads: int, 173 | num_encoder_layers: int, 174 | num_decoder_layers: int, 175 | init_scale: float = 0.25, 176 | qkv_bias: bool = True, 177 | flash: bool = False, 178 | use_ln_post: bool = False, 179 | use_checkpoint: bool = False): 180 | 181 | super().__init__() 182 | 183 | self.use_checkpoint = use_checkpoint 184 | 185 | self.num_latents = num_latents 186 | self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) 187 | 188 | init_scale = init_scale * math.sqrt(1.0 / width) 189 | self.encoder = CrossAttentionEncoder( 190 | device=device, 191 | dtype=dtype, 192 | fourier_embedder=self.fourier_embedder, 193 | num_latents=num_latents, 194 | point_feats=point_feats, 195 | width=width, 196 | heads=heads, 197 | layers=num_encoder_layers, 198 | init_scale=init_scale, 199 | qkv_bias=qkv_bias, 200 | flash=flash, 201 | use_ln_post=use_ln_post, 202 | use_checkpoint=use_checkpoint 203 | ) 204 | 205 | self.embed_dim = embed_dim 206 | if embed_dim > 0: 207 | # VAE embed 208 | self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) 209 | self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype) 210 | self.latent_shape = (num_latents, embed_dim) 211 | else: 212 | self.latent_shape = (num_latents, width) 213 | 214 | self.transformer = Transformer( 215 | device=device, 216 | dtype=dtype, 217 | n_ctx=num_latents, 218 | width=width, 219 | layers=num_decoder_layers, 220 | heads=heads, 221 | init_scale=init_scale, 222 | qkv_bias=qkv_bias, 223 | flash=flash, 224 | use_checkpoint=use_checkpoint 225 | ) 226 | 227 | # geometry decoder 228 | self.geo_decoder = CrossAttentionDecoder( 229 | device=device, 230 | dtype=dtype, 231 | fourier_embedder=self.fourier_embedder, 232 | out_channels=1, 233 | num_latents=num_latents, 234 | width=width, 235 | heads=heads, 236 | init_scale=init_scale, 237 | qkv_bias=qkv_bias, 238 | flash=flash, 239 | use_checkpoint=use_checkpoint 240 | ) 241 | 242 | def encode(self, 243 | pc: torch.FloatTensor, 244 | feats: 
Optional[torch.FloatTensor] = None, 245 | sample_posterior: bool = True): 246 | """ 247 | 248 | Args: 249 | pc (torch.FloatTensor): [B, N, 3] 250 | feats (torch.FloatTensor or None): [B, N, C] 251 | sample_posterior (bool): 252 | 253 | Returns: 254 | latents (torch.FloatTensor) 255 | center_pos (torch.FloatTensor or None): 256 | posterior (DiagonalGaussianDistribution or None): 257 | """ 258 | 259 | latents, center_pos = self.encoder(pc, feats) 260 | 261 | posterior = None 262 | if self.embed_dim > 0: 263 | moments = self.pre_kl(latents) 264 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) 265 | 266 | if sample_posterior: 267 | latents = posterior.sample() 268 | else: 269 | latents = posterior.mode() 270 | 271 | return latents, center_pos, posterior 272 | 273 | def decode(self, latents: torch.FloatTensor): 274 | latents = self.post_kl(latents) 275 | return self.transformer(latents) 276 | 277 | def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 278 | logits = self.geo_decoder(queries, latents).squeeze(-1) 279 | return logits 280 | 281 | def forward(self, 282 | pc: torch.FloatTensor, 283 | feats: torch.FloatTensor, 284 | volume_queries: torch.FloatTensor, 285 | sample_posterior: bool = True): 286 | """ 287 | 288 | Args: 289 | pc (torch.FloatTensor): [B, N, 3] 290 | feats (torch.FloatTensor or None): [B, N, C] 291 | volume_queries (torch.FloatTensor): [B, P, 3] 292 | sample_posterior (bool): 293 | 294 | Returns: 295 | logits (torch.FloatTensor): [B, P] 296 | center_pos (torch.FloatTensor): [B, M, 3] 297 | posterior (DiagonalGaussianDistribution or None). 298 | 299 | """ 300 | 301 | latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) 302 | 303 | latents = self.decode(latents) 304 | logits = self.query_geometry(volume_queries, latents) 305 | 306 | return logits, center_pos, posterior 307 | 308 | 309 | class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): 310 | 311 | def __init__(self, *, 312 | device: Optional[torch.device], 313 | dtype: Optional[torch.dtype], 314 | num_latents: int, 315 | point_feats: int = 0, 316 | embed_dim: int = 0, 317 | num_freqs: int = 8, 318 | include_pi: bool = True, 319 | width: int, 320 | heads: int, 321 | num_encoder_layers: int, 322 | num_decoder_layers: int, 323 | init_scale: float = 0.25, 324 | qkv_bias: bool = True, 325 | flash: bool = False, 326 | use_ln_post: bool = False, 327 | use_checkpoint: bool = False): 328 | 329 | super().__init__( 330 | device=device, 331 | dtype=dtype, 332 | num_latents=1 + num_latents, 333 | point_feats=point_feats, 334 | embed_dim=embed_dim, 335 | num_freqs=num_freqs, 336 | include_pi=include_pi, 337 | width=width, 338 | heads=heads, 339 | num_encoder_layers=num_encoder_layers, 340 | num_decoder_layers=num_decoder_layers, 341 | init_scale=init_scale, 342 | qkv_bias=qkv_bias, 343 | flash=flash, 344 | use_ln_post=use_ln_post, 345 | use_checkpoint=use_checkpoint 346 | ) 347 | 348 | self.width = width 349 | 350 | def encode(self, 351 | pc: torch.FloatTensor, 352 | feats: Optional[torch.FloatTensor] = None, 353 | sample_posterior: bool = True): 354 | """ 355 | 356 | Args: 357 | pc (torch.FloatTensor): [B, N, 3] 358 | feats (torch.FloatTensor or None): [B, N, c] 359 | sample_posterior (bool): 360 | 361 | Returns: 362 | shape_embed (torch.FloatTensor) 363 | kl_embed (torch.FloatTensor): 364 | posterior (DiagonalGaussianDistribution or None): 365 | """ 366 | 367 | shape_embed, latents = self.encode_latents(pc, feats) 368 | kl_embed, posterior = 
self.encode_kl_embed(latents, sample_posterior) 369 | 370 | return shape_embed, kl_embed, posterior 371 | 372 | def encode_latents(self, 373 | pc: torch.FloatTensor, 374 | feats: Optional[torch.FloatTensor] = None): 375 | 376 | x, _ = self.encoder(pc, feats) 377 | 378 | shape_embed = x[:, 0] 379 | latents = x[:, 1:] 380 | 381 | return shape_embed, latents 382 | 383 | def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): 384 | posterior = None 385 | if self.embed_dim > 0: 386 | moments = self.pre_kl(latents) 387 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) 388 | 389 | if sample_posterior: 390 | kl_embed = posterior.sample() 391 | else: 392 | kl_embed = posterior.mode() 393 | else: 394 | kl_embed = latents 395 | 396 | return kl_embed, posterior 397 | 398 | def forward(self, 399 | pc: torch.FloatTensor, 400 | feats: torch.FloatTensor, 401 | volume_queries: torch.FloatTensor, 402 | sample_posterior: bool = True): 403 | """ 404 | 405 | Args: 406 | pc (torch.FloatTensor): [B, N, 3] 407 | feats (torch.FloatTensor or None): [B, N, C] 408 | volume_queries (torch.FloatTensor): [B, P, 3] 409 | sample_posterior (bool): 410 | 411 | Returns: 412 | shape_embed (torch.FloatTensor): [B, projection_dim] 413 | logits (torch.FloatTensor): [B, M] 414 | posterior (DiagonalGaussianDistribution or None). 415 | 416 | """ 417 | 418 | shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) 419 | 420 | latents = self.decode(kl_embed) 421 | logits = self.query_geometry(volume_queries, latents) 422 | 423 | return shape_embed, logits, posterior 424 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/sal_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import List, Tuple, Dict, Optional 4 | from omegaconf import DictConfig 5 | 6 | import torch 7 | from torch.optim import lr_scheduler 8 | import pytorch_lightning as pl 9 | from typing import Union 10 | from functools import partial 11 | 12 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 13 | 14 | from .inference_utils import extract_geometry 15 | from .tsal_base import ( 16 | ShapeAsLatentModule, 17 | Latent2MeshOutput, 18 | Point2MeshOutput 19 | ) 20 | 21 | 22 | class ShapeAsLatentPLModule(pl.LightningModule): 23 | 24 | def __init__(self, *, 25 | module_cfg, 26 | loss_cfg, 27 | optimizer_cfg: Optional[DictConfig] = None, 28 | ckpt_path: Optional[str] = None, 29 | ignore_keys: Union[Tuple[str], List[str]] = ()): 30 | 31 | super().__init__() 32 | 33 | self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None) 34 | 35 | self.loss = instantiate_from_config(loss_cfg) 36 | 37 | self.optimizer_cfg = optimizer_cfg 38 | 39 | if ckpt_path is not None: 40 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 41 | 42 | self.save_hyperparameters() 43 | 44 | @property 45 | def latent_shape(self): 46 | return self.sal.latent_shape 47 | 48 | @property 49 | def zero_rank(self): 50 | if self._trainer: 51 | zero_rank = self.trainer.local_rank == 0 52 | else: 53 | zero_rank = True 54 | 55 | return zero_rank 56 | 57 | def init_from_ckpt(self, path, ignore_keys=()): 58 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 59 | 60 | keys = list(state_dict.keys()) 61 | for k in keys: 62 | for ik in ignore_keys: 63 | if k.startswith(ik): 64 | print("Deleting key {} 
from state_dict.".format(k)) 65 | del state_dict[k] 66 | 67 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 68 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 69 | if len(missing) > 0: 70 | print(f"Missing Keys: {missing}") 71 | print(f"Unexpected Keys: {unexpected}") 72 | 73 | def configure_optimizers(self) -> Tuple[List, List]: 74 | lr = self.learning_rate 75 | 76 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)] 77 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 78 | 79 | if self.optimizer_cfg is None: 80 | optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 81 | schedulers = [] 82 | else: 83 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters()) 84 | scheduler_func = instantiate_from_config( 85 | self.optimizer_cfg.scheduler, 86 | max_decay_steps=self.trainer.max_steps, 87 | lr_max=lr 88 | ) 89 | scheduler = { 90 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 91 | "interval": "step", 92 | "frequency": 1 93 | } 94 | optimizers = [optimizer] 95 | schedulers = [scheduler] 96 | 97 | return optimizers, schedulers 98 | 99 | def forward(self, 100 | pc: torch.FloatTensor, 101 | feats: torch.FloatTensor, 102 | volume_queries: torch.FloatTensor): 103 | 104 | logits, center_pos, posterior = self.sal(pc, feats, volume_queries) 105 | 106 | return posterior, logits 107 | 108 | def encode(self, surface: torch.FloatTensor, sample_posterior=True): 109 | 110 | pc = surface[..., 0:3] 111 | feats = surface[..., 3:6] 112 | 113 | latents, center_pos, posterior = self.sal.encode( 114 | pc=pc, feats=feats, sample_posterior=sample_posterior 115 | ) 116 | 117 | return latents 118 | 119 | def decode(self, 120 | z_q, 121 | bounds: Union[Tuple[float], List[float], float] = 1.1, 122 | octree_depth: int = 7, 123 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 124 | 125 | latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim] 126 | outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) 127 | 128 | return outputs 129 | 130 | def training_step(self, batch: Dict[str, torch.FloatTensor], 131 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 132 | """ 133 | 134 | Args: 135 | batch (dict): the batch sample, and it contains: 136 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] 137 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] 138 | 139 | batch_idx (int): 140 | 141 | optimizer_idx (int): 142 | 143 | Returns: 144 | loss (torch.FloatTensor): 145 | 146 | """ 147 | 148 | pc = batch["surface"][..., 0:3] 149 | feats = batch["surface"][..., 3:] 150 | 151 | volume_queries = batch["geo_points"][..., 0:3] 152 | volume_labels = batch["geo_points"][..., -1] 153 | 154 | posterior, logits = self( 155 | pc=pc, feats=feats, volume_queries=volume_queries 156 | ) 157 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train") 158 | 159 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], 160 | sync_dist=False, rank_zero_only=True) 161 | 162 | return aeloss 163 | 164 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: 165 | 166 | pc = batch["surface"][..., 0:3] 167 | feats = batch["surface"][..., 3:] 168 | 169 | volume_queries 
= batch["geo_points"][..., 0:3] 170 | volume_labels = batch["geo_points"][..., -1] 171 | 172 | posterior, logits = self( 173 | pc=pc, feats=feats, volume_queries=volume_queries, 174 | ) 175 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val") 176 | 177 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], 178 | sync_dist=False, rank_zero_only=True) 179 | 180 | return aeloss 181 | 182 | def point2mesh(self, 183 | pc: torch.FloatTensor, 184 | feats: torch.FloatTensor, 185 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 186 | octree_depth: int = 7, 187 | num_chunks: int = 10000) -> List[Point2MeshOutput]: 188 | 189 | """ 190 | 191 | Args: 192 | pc: 193 | feats: 194 | bounds: 195 | octree_depth: 196 | num_chunks: 197 | 198 | Returns: 199 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 200 | 201 | """ 202 | 203 | outputs = [] 204 | 205 | device = pc.device 206 | bs = pc.shape[0] 207 | 208 | # 1. point encoder + latents transformer 209 | latents, center_pos, posterior = self.sal.encode(pc, feats) 210 | latents = self.sal.decode(latents) # latents: [bs, num_latents, dim] 211 | 212 | geometric_func = partial(self.sal.query_geometry, latents=latents) 213 | 214 | # 2. decode geometry 215 | mesh_v_f, has_surface = extract_geometry( 216 | geometric_func=geometric_func, 217 | device=device, 218 | batch_size=bs, 219 | bounds=bounds, 220 | octree_depth=octree_depth, 221 | num_chunks=num_chunks, 222 | disable=not self.zero_rank 223 | ) 224 | 225 | # 3. decode texture 226 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 227 | if not is_surface: 228 | outputs.append(None) 229 | continue 230 | 231 | out = Point2MeshOutput() 232 | out.mesh_v = mesh_v 233 | out.mesh_f = mesh_f 234 | out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy() 235 | 236 | if center_pos is not None: 237 | out.center = center_pos[i].cpu().numpy() 238 | 239 | outputs.append(out) 240 | 241 | return outputs 242 | 243 | def latent2mesh(self, 244 | latents: torch.FloatTensor, 245 | bounds: Union[Tuple[float], List[float], float] = 1.1, 246 | octree_depth: int = 7, 247 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 248 | 249 | """ 250 | 251 | Args: 252 | latents: [bs, num_latents, dim] 253 | bounds: 254 | octree_depth: 255 | num_chunks: 256 | 257 | Returns: 258 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 259 | 260 | """ 261 | 262 | outputs = [] 263 | 264 | geometric_func = partial(self.sal.query_geometry, latents=latents) 265 | 266 | # 2. decode geometry 267 | device = latents.device 268 | mesh_v_f, has_surface = extract_geometry( 269 | geometric_func=geometric_func, 270 | device=device, 271 | batch_size=len(latents), 272 | bounds=bounds, 273 | octree_depth=octree_depth, 274 | num_chunks=num_chunks, 275 | disable=not self.zero_rank 276 | ) 277 | 278 | # 3. 
decode texture 279 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 280 | if not is_surface: 281 | outputs.append(None) 282 | continue 283 | 284 | out = Latent2MeshOutput() 285 | out.mesh_v = mesh_v 286 | out.mesh_f = mesh_f 287 | 288 | outputs.append(out) 289 | 290 | return outputs 291 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/tsal_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from typing import Tuple, List, Optional 5 | 6 | 7 | class Point2MeshOutput(object): 8 | def __init__(self): 9 | self.mesh_v = None 10 | self.mesh_f = None 11 | self.center = None 12 | self.pc = None 13 | 14 | 15 | class Latent2MeshOutput(object): 16 | 17 | def __init__(self): 18 | self.mesh_v = None 19 | self.mesh_f = None 20 | 21 | 22 | class AlignedMeshOutput(object): 23 | 24 | def __init__(self): 25 | self.mesh_v = None 26 | self.mesh_f = None 27 | self.surface = None 28 | self.image = None 29 | self.text: Optional[str] = None 30 | self.shape_text_similarity: Optional[float] = None 31 | self.shape_image_similarity: Optional[float] = None 32 | 33 | 34 | class ShapeAsLatentPLModule(nn.Module): 35 | latent_shape: Tuple[int] 36 | 37 | def encode(self, surface, *args, **kwargs): 38 | raise NotImplementedError 39 | 40 | def decode(self, z_q, *args, **kwargs): 41 | raise NotImplementedError 42 | 43 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 44 | raise NotImplementedError 45 | 46 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 47 | raise NotImplementedError 48 | 49 | 50 | class ShapeAsLatentModule(nn.Module): 51 | latent_shape: Tuple[int, int] 52 | 53 | def __init__(self, *args, **kwargs): 54 | super().__init__() 55 | 56 | def encode(self, *args, **kwargs): 57 | raise NotImplementedError 58 | 59 | def decode(self, *args, **kwargs): 60 | raise NotImplementedError 61 | 62 | def query_geometry(self, *args, **kwargs): 63 | raise NotImplementedError 64 | 65 | 66 | class AlignedShapeAsLatentPLModule(nn.Module): 67 | latent_shape: Tuple[int] 68 | 69 | def set_shape_model_only(self): 70 | raise NotImplementedError 71 | 72 | def encode(self, surface, *args, **kwargs): 73 | raise NotImplementedError 74 | 75 | def decode(self, z_q, *args, **kwargs): 76 | raise NotImplementedError 77 | 78 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 79 | raise NotImplementedError 80 | 81 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 82 | raise NotImplementedError 83 | 84 | 85 | class AlignedShapeAsLatentModule(nn.Module): 86 | shape_model: ShapeAsLatentModule 87 | latent_shape: Tuple[int, int] 88 | 89 | def __init__(self, *args, **kwargs): 90 | super().__init__() 91 | 92 | def set_shape_model_only(self): 93 | raise NotImplementedError 94 | 95 | def encode_image_embed(self, *args, **kwargs): 96 | raise NotImplementedError 97 | 98 | def encode_text_embed(self, *args, **kwargs): 99 | raise NotImplementedError 100 | 101 | def encode_shape_embed(self, *args, **kwargs): 102 | raise NotImplementedError 103 | 104 | 105 | class TexturedShapeAsLatentModule(nn.Module): 106 | 107 | def __init__(self, *args, **kwargs): 108 | super().__init__() 109 | 110 | def encode(self, *args, **kwargs): 111 | raise NotImplementedError 112 | 113 | def decode(self, *args, **kwargs): 114 | raise NotImplementedError 115 | 116 | def query_geometry(self, 
*args, **kwargs): 117 | raise NotImplementedError 118 | 119 | def query_color(self, *args, **kwargs): 120 | raise NotImplementedError 121 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .misc import instantiate_from_config 4 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7): 7 | 8 | mse = torch.mean((x - y) ** 2) 9 | psnr = 10 * torch.log10(data_range / (mse + eps)) 10 | 11 | return psnr 12 | 13 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import io 5 | import tarfile 6 | import json 7 | import numpy as np 8 | import numpy.lib.format 9 | 10 | 11 | def mkdir(path): 12 | os.makedirs(path, exist_ok=True) 13 | return path 14 | 15 | 16 | def npy_loads(data): 17 | stream = io.BytesIO(data) 18 | return np.lib.format.read_array(stream) 19 | 20 | 21 | def npz_loads(data): 22 | return np.load(io.BytesIO(data)) 23 | 24 | 25 | def json_loads(data): 26 | return json.loads(data) 27 | 28 | 29 | def load_json(filepath): 30 | with open(filepath, "r") as f: 31 | data = json.load(f) 32 | return data 33 | 34 | 35 | def write_json(filepath, data): 36 | with open(filepath, "w") as f: 37 | json.dump(data, f, indent=2) 38 | 39 | 40 | def extract_tar(tar_path, tar_cache_folder): 41 | 42 | with tarfile.open(tar_path, "r") as tar: 43 | tar.extractall(path=tar_cache_folder) 44 | 45 | tar_uids = sorted(os.listdir(tar_cache_folder)) 46 | print(f"extract tar: {tar_path} to {tar_cache_folder}") 47 | return tar_uids 48 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import importlib 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | 10 | def get_obj_from_str(string, reload=False): 11 | module, cls = string.rsplit(".", 1) 12 | if reload: 13 | module_imp = importlib.import_module(module) 14 | importlib.reload(module_imp) 15 | return getattr(importlib.import_module(module, package=None), cls) 16 | 17 | 18 | def get_obj_from_config(config): 19 | if "target" not in config: 20 | raise KeyError("Expected key `target` to instantiate.") 21 | 22 | return get_obj_from_str(config["target"]) 23 | 24 | 25 | def instantiate_from_config(config, **kwargs): 26 | if "target" not in config: 27 | raise KeyError("Expected key `target` to instantiate.") 28 | 29 | cls = get_obj_from_str(config["target"]) 30 | 31 | params = config.get("params", dict()) 32 | # params.update(kwargs) 33 | # instance = cls(**params) 34 | kwargs.update(params) 35 | instance = cls(**kwargs) 36 | 37 | return instance 38 | 39 | 40 | def is_dist_avail_and_initialized(): 41 | if not dist.is_available(): 42 | return False 43 | if not dist.is_initialized(): 44 | return False 45 | return True 46 | 47 | 48 | def get_rank(): 49 | if not 
is_dist_avail_and_initialized(): 50 | return 0 51 | return dist.get_rank() 52 | 53 | 54 | def get_world_size(): 55 | if not is_dist_avail_and_initialized(): 56 | return 1 57 | return dist.get_world_size() 58 | 59 | 60 | def all_gather_batch(tensors): 61 | """ 62 | Performs all_gather operation on the provided tensors. 63 | """ 64 | # Queue the gathered tensors 65 | world_size = get_world_size() 66 | # There is no need for reduction in the single-proc case 67 | if world_size == 1: 68 | return tensors 69 | tensor_list = [] 70 | output_tensor = [] 71 | for tensor in tensors: 72 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 73 | dist.all_gather( 74 | tensor_all, 75 | tensor, 76 | async_op=False # performance opt 77 | ) 78 | 79 | tensor_list.append(tensor_all) 80 | 81 | for tensor_all in tensor_list: 82 | output_tensor.append(torch.cat(tensor_all, dim=0)) 83 | return output_tensor 84 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | # Helper functions 6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): 7 | colormap = plt.cm.get_cmap(colormap) 8 | if normalize: 9 | vmin = np.min(inp) 10 | vmax = np.max(inp) 11 | 12 | norm = plt.Normalize(vmin, vmax) 13 | return colormap(norm(inp))[:, :3] 14 | 15 | 16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): 17 | # tex dims need to be power of two. 18 | array = np.ones((width, height, 3), dtype='float32') 19 | 20 | # width in texels of each checker 21 | checker_w = width / n_checkers_x 22 | checker_h = height / n_checkers_y 23 | 24 | for y in range(height): 25 | for x in range(width): 26 | color_key = int(x / checker_w) + int(y / checker_h) 27 | if color_key % 2 == 0: 28 | array[x, y, :] = [1., 0.874, 0.0] 29 | else: 30 | array[x, y, :] = [0., 0., 0.] 31 | return array 32 | 33 | 34 | def gen_circle(width=256, height=256): 35 | xx, yy = np.mgrid[:width, :height] 36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2 37 | array = np.ones((width, height, 4), dtype='float32') 38 | array[:, :, 0] = (circle <= width) 39 | array[:, :, 1] = (circle <= width) 40 | array[:, :, 2] = (circle <= width) 41 | array[:, :, 3] = circle <= width 42 | return array 43 | 44 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/html_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import io 3 | import base64 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def to_html_frame(content): 9 | 10 | html_frame = f""" 11 | 12 | 13 | {content} 14 | 15 | 16 | """ 17 | 18 | return html_frame 19 | 20 | 21 | def to_single_row_table(caption: str, content: str): 22 | 23 | table_html = f""" 24 | 25 | 26 | 27 | 28 | 29 |
<table>
    <caption>{caption}</caption>
    <tr><td>{content}</td></tr>
</table>
30 | """ 31 | 32 | return table_html 33 | 34 | 35 | def to_image_embed_tag(image: np.ndarray): 36 | 37 | # Convert np.ndarray to bytes 38 | img = Image.fromarray(image) 39 | raw_bytes = io.BytesIO() 40 | img.save(raw_bytes, "PNG") 41 | 42 | # Encode bytes to base64 43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8") 44 | 45 | image_tag = f""" 46 | Embedded Image 47 | """ 48 | 49 | return image_tag 50 | -------------------------------------------------------------------------------- /MeshAnything/miche/shapevae-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: MeshAnything.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 3 | params: 4 | shape_module_cfg: 5 | target: MeshAnything.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 6 | params: 7 | num_latents: 256 8 | embed_dim: 64 9 | point_feats: 3 # normal 10 | num_freqs: 8 11 | include_pi: false 12 | heads: 12 13 | width: 768 14 | num_encoder_layers: 8 15 | num_decoder_layers: 16 16 | use_ln_post: true 17 | init_scale: 0.25 18 | qkv_bias: false 19 | use_checkpoint: true 20 | aligned_module_cfg: 21 | target: MeshAnything.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 22 | params: 23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 24 | 25 | loss_cfg: 26 | target: MeshAnything.miche.michelangelo.models.tsal.loss.ContrastKLNearFar 27 | params: 28 | contrast_weight: 0.1 29 | near_weight: 0.1 30 | kl_weight: 0.001 31 | 32 | optimizer_cfg: 33 | optimizer: 34 | target: torch.optim.AdamW 35 | params: 36 | betas: [0.9, 0.99] 37 | eps: 1.e-6 38 | weight_decay: 1.e-2 39 | 40 | scheduler: 41 | target: MeshAnything.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 42 | params: 43 | warm_up_steps: 5000 44 | f_start: 1.e-6 45 | f_min: 1.e-3 46 | f_max: 1.0 47 | -------------------------------------------------------------------------------- /MeshAnything/models/meshanything.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from transformers import AutoModelForCausalLM, AutoConfig, AutoModel 4 | from MeshAnything.miche.encode import load_model 5 | from MeshAnything.models.shape_opt import ShapeOPTConfig 6 | from einops.layers.torch import Rearrange 7 | from einops import rearrange, repeat, reduce, pack, unpack 8 | import torch.nn.functional as F 9 | 10 | class NoiseResistantDecoder(nn.Module): 11 | 12 | def __init__(self, args): 13 | super().__init__() 14 | self.args = args 15 | self.pad_id = -1 16 | self.num_quantizers = 3 17 | 18 | self.discrete_num = 128 19 | self.codebook_size = args.codebook_size 20 | self.codebook_dim = args.codebook_dim 21 | 22 | config = AutoConfig.from_pretrained("bert-base-uncased") 23 | config.num_hidden_layers = 6 24 | self.decoder= AutoModel.from_config(config=config).to_bettertransformer().encoder 25 | self.n_embd = self.decoder.config.hidden_size 26 | 27 | self.pos_embedding = nn.Embedding(18000, self.n_embd) 28 | self.layernorm = nn.LayerNorm(self.n_embd) 29 | self.point_layernorm = nn.LayerNorm(self.n_embd) 30 | 31 | self.cond_length = 257 32 | self.cond_dim = 768 33 | self.point_pe = nn.Embedding(self.cond_length, self.n_embd) 34 | self.cond_proj = nn.Linear(self.cond_dim, self.n_embd) 35 | self.cond_head_proj = nn.Linear(self.cond_dim, self.n_embd) 36 | 37 | self.project_down_codebook = nn.Linear(self.codebook_dim * 3, 
self.n_embd) 38 | self.to_coor_logits = nn.Sequential( 39 | nn.Linear(self.n_embd, self.discrete_num * 9), 40 | Rearrange('... (v c) -> ... v c', v = 9) 41 | ) 42 | def process_point_feature(self, encode_feature): 43 | point_feature = torch.zeros(encode_feature.shape[0], self.cond_length, self.n_embd, device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype) 44 | point_feature[:, 0] = self.cond_head_proj(encode_feature[:, 0]) 45 | point_feature[:, 1:] = self.cond_proj(encode_feature[:, 1:]) 46 | 47 | point_feature = self.point_layernorm(point_feature + self.point_pe.weight[None, :point_feature.shape[1]]) 48 | return point_feature 49 | 50 | def forward(self, input_ids, input_embeds, point_feature = None): 51 | input_ids = input_ids.reshape(input_ids.shape[0], -1) 52 | point_feature = self.process_point_feature(point_feature) 53 | 54 | face_embeds = rearrange(input_embeds, 'b (nf nv) d -> b nf (nv d)', nv = 3) 55 | face_embeds = self.project_down_codebook(face_embeds) 56 | 57 | face_mask = reduce(input_ids != self.pad_id, 'b (nf nv q) -> b nf', 'all', nv = 3, q = self.num_quantizers) 58 | face_embeds[~face_mask] = 0 59 | 60 | face_embeds = self.layernorm(face_embeds + self.pos_embedding.weight[None, :face_embeds.shape[1]]) 61 | 62 | outputs = self.decoder( 63 | hidden_states=torch.concatenate([point_feature, face_embeds], dim=1), 64 | ) 65 | decoded = outputs.last_hidden_state[:, self.cond_length:] # batch x nfaces x dim 66 | decoded = decoded.masked_fill(~face_mask.unsqueeze(-1), 0.) 67 | 68 | # batch x nfaces x 9 -> batch x nfaces x 3 x 3 69 | pred_face_logits = self.to_coor_logits(decoded) # batch x nfaces x 9 x ndiscrete 70 | pred_face_coords = rearrange(pred_face_logits.argmax(dim = -1), '... (v c) -> ... v c', v = 3) 71 | 72 | continuous_coors = undiscretize( 73 | pred_face_coords, 74 | num_discrete = self.discrete_num, 75 | low = -0.5, 76 | high = 0.5 77 | ) 78 | continuous_coors = continuous_coors.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1 1'), float('nan')) 79 | 80 | return continuous_coors 81 | 82 | class MeshAnything(nn.Module): 83 | def __init__(self, args): 84 | super().__init__() 85 | self.args = args 86 | self.point_encoder = load_model(ckpt_path=None) 87 | self.tokenizer = NoiseResistantDecoder(args) 88 | 89 | self.num_quantizers = 3 90 | self.face_per_token = self.num_quantizers * 3 91 | self.cond_length = 257 92 | self.cond_dim = 768 93 | self.max_length = args.n_max_triangles * self.face_per_token + 2 + self.cond_length 94 | 95 | self.config = ShapeOPTConfig.from_pretrained( 96 | args.llm, 97 | n_positions=18259, 98 | max_position_embeddings=18259, 99 | vocab_size=self.tokenizer.codebook_size + 3, 100 | _attn_implementation="flash_attention_2" 101 | ) 102 | self.bos_token_id = 0 103 | self.eos_token_id = 1 104 | self.pad_token_id = 2 105 | self.config.bos_token_id = self.bos_token_id 106 | self.config.eos_token_id = self.eos_token_id 107 | self.config.pad_token_id = self.pad_token_id 108 | self.config.quantize_codebook_dim = self.tokenizer.codebook_dim 109 | self.config.face_per_token = self.face_per_token 110 | self.config._attn_implementation="flash_attention_2" 111 | self.config.cond_length = self.cond_length 112 | if self.config.word_embed_proj_dim != self.config.hidden_size: 113 | self.config.word_embed_proj_dim = self.config.hidden_size 114 | self.transformer = AutoModelForCausalLM.from_config( 115 | config=self.config, use_flash_attention_2 = True 116 | ) 117 | self.transformer.to_bettertransformer() 118 | 
self.transformer.model.decoder.quantize_codebooks = nn.Parameter(torch.zeros(1, self.tokenizer.codebook_size, self.tokenizer.codebook_dim)) 119 | 120 | self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim) 121 | self.cond_proj = nn.Linear(self.cond_dim * 2, self.config.word_embed_proj_dim) 122 | 123 | self.eval() 124 | 125 | def process_point_feature(self, point_feature): 126 | encode_feature = torch.zeros(point_feature.shape[0], self.cond_length, self.config.word_embed_proj_dim, 127 | device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype) 128 | encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0]) 129 | shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:]) 130 | encode_feature[:, 1:] = self.cond_proj(torch.cat([point_feature[:, 1:], shape_latents], dim=-1)) 131 | 132 | return encode_feature 133 | 134 | @torch.no_grad() 135 | def forward(self, pc_normal, sampling=False) -> dict: 136 | batch_size = pc_normal.shape[0] 137 | point_feature = self.point_encoder.encode_latents(pc_normal) 138 | processed_point_feature = self.process_point_feature(point_feature) 139 | 140 | generate_length = self.max_length - self.cond_length 141 | net_device = next(self.parameters()).device 142 | outputs = torch.ones(batch_size, generate_length).long().to(net_device) * self.eos_token_id 143 | if not sampling: 144 | results = self.transformer.generate( 145 | inputs_embeds=processed_point_feature, 146 | max_new_tokens=generate_length, # all faces plus two 147 | num_beams=1, 148 | bos_token_id=self.bos_token_id, 149 | eos_token_id=self.eos_token_id, 150 | pad_token_id=self.pad_token_id, 151 | ) 152 | else: 153 | results = self.transformer.generate( 154 | inputs_embeds = processed_point_feature, 155 | max_new_tokens=generate_length, # all faces plus two 156 | do_sample=True, 157 | top_k=50, 158 | top_p=0.95, 159 | bos_token_id = self.bos_token_id, 160 | eos_token_id = self.eos_token_id, 161 | pad_token_id = self.pad_token_id, 162 | ) 163 | assert results.shape[1] <= generate_length # B x ID bos is not included since it's predicted 164 | outputs[:, :results.shape[1]] = results 165 | # batch x ntokens ====> batch x ntokens x D 166 | outputs = outputs[:, 1: -1] 167 | 168 | outputs[outputs == self.bos_token_id] = self.tokenizer.pad_id 169 | outputs[outputs == self.eos_token_id] = self.tokenizer.pad_id 170 | outputs[outputs == self.pad_token_id] = self.tokenizer.pad_id 171 | 172 | outputs[outputs != self.tokenizer.pad_id] -= 3 173 | code_embed = self.get_codes(outputs) 174 | decoder_output = self.tokenizer(outputs, code_embed, point_feature=point_feature) 175 | 176 | return decoder_output 177 | 178 | def get_codes(self, indices): 179 | indices = indices.reshape(indices.shape[0], -1) 180 | 181 | indices = rearrange(indices, 'b (n q) -> b n q', q=self.num_quantizers) 182 | 183 | batch, quantize_dim = indices.shape[0], indices.shape[-1] 184 | # may also receive indices in the shape of 'b h w q' (accept_image_fmap) 185 | 186 | indices, ps = pack([indices], 'b * q') 187 | 188 | # because of quantize dropout, one can pass in indices that are coarse 189 | # and the network should be able to reconstruct 190 | 191 | if quantize_dim < self.num_quantizers: 192 | indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1) 193 | 194 | # take care of quantizer dropout 195 | 196 | mask = indices == -1. 
197 | indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later 198 | 199 | # dummy implementation for shared codebook 200 | all_codes = self.transformer.model.decoder.quantize_codebooks[0][indices] 201 | all_codes = all_codes.permute(2, 0, 1, 3) 202 | 203 | # mask out any codes that were dropout-ed 204 | 205 | all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.) 206 | 207 | # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension) 208 | 209 | codes, = unpack(all_codes, ps, 'q b * d') 210 | 211 | codes_summed = reduce(codes, 'q ... -> ...', 'sum') 212 | return codes_summed 213 | 214 | def undiscretize( 215 | t, 216 | low, 217 | high, 218 | num_discrete 219 | ) -> Tensor: 220 | t = t.float() 221 | 222 | t /= num_discrete 223 | return t * (high - low) + low 224 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # MeshAnything: Artist-Created Mesh Generation with Autoregressive Transformers
3 | 4 |
5 | Yiwen Chen<sup>1,2</sup>*, 6 | Tong He<sup>2</sup>†, 7 | Di Huang<sup>2</sup>, 8 | Weicai Ye<sup>2</sup>, 9 | Sijin Chen<sup>3</sup>, 10 | Jiaxiang Tang<sup>4</sup>
11 | Xin Chen<sup>5</sup>, 12 | Zhongang Cai<sup>6</sup>, 13 | Lei Yang<sup>6</sup>, 14 | Gang Yu<sup>7</sup>, 15 | Guosheng Lin<sup>1</sup>†, 16 | Chi Zhang<sup>8</sup>† 17 |
18 | *Work done during a research internship at Shanghai AI Lab. 19 |
20 | †Corresponding authors. 21 |
22 | <sup>1</sup>S-Lab, Nanyang Technological University, 23 | <sup>2</sup>Shanghai AI Lab, 24 |
25 | <sup>3</sup>Fudan University, 26 | <sup>4</sup>Peking University, 27 | <sup>5</sup>University of Chinese Academy of Sciences, 28 |
29 | <sup>6</sup>SenseTime Research, 30 | <sup>7</sup>Stepfun, 31 | <sup>8</sup>Westlake University 32 |
33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 |
47 | Demo GIF (demo/demo_video.gif) 48 |

49 | 50 | 51 | ## Release 52 | - [6/17] 🔥🔥 Try our newly released **[MeshAnything V2](https://github.com/buaacyw/MeshAnythingV2)**. Maximum face number is increased to **1600** in V2 with better performance. 53 | - [6/17] We released the 350m version of **MeshAnything**. 54 | 55 | ## Contents 56 | - [Release](#release) 57 | - [Contents](#contents) 58 | - [Installation](#installation) 59 | - [Usage](#usage) 60 | - [Important Notes](#important-notes) 61 | - [Training](#training) 62 | - [Acknowledgement](#acknowledgement) 63 | - [Star History](#star-history) 64 | - [BibTeX](#bibtex) 65 | 66 | ## Installation 67 | Our environment has been tested on Ubuntu 22, CUDA 11.8 with A100, A800 and A6000. 68 | 1. Clone our repo and create conda environment 69 | ``` 70 | git clone https://github.com/buaacyw/MeshAnything.git && cd MeshAnything 71 | conda create -n MeshAnything python==3.10.13 -y 72 | conda activate MeshAnything 73 | pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118 74 | pip install -r requirements.txt 75 | pip install flash-attn --no-build-isolation 76 | ``` 77 | or 78 | ```shell 79 | pip install git+https://github.com/buaacyw/MeshAnything.git 80 | ``` 81 | And directly use in your code as 82 | ``` 83 | import MeshAnything 84 | ``` 85 | 86 | ## Usage 87 | ### Local Gradio Demo 88 | ``` 89 | python app.py 90 | ``` 91 | 92 | ### Mesh Command line inference 93 | ``` 94 | # folder input 95 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh 96 | 97 | # single file input 98 | python main.py --input_path examples/wand.obj --out_dir mesh_output --input_type mesh 99 | 100 | # Preprocess with Marching Cubes first 101 | python main.py --input_dir examples --out_dir mesh_output --input_type mesh --mc 102 | ``` 103 | ### Point Cloud Command line inference 104 | ``` 105 | # Note: if you want to use your own point cloud, please make sure the normal is included. 106 | # The file format should be a .npy file with shape (N, 6), where N is the number of points. The first 3 columns are the coordinates, and the last 3 columns are the normal. 107 | 108 | # inference for folder 109 | python main.py --input_dir pc_examples --out_dir pc_output --input_type pc_normal 110 | 111 | # inference for single file 112 | python main.py --input_path pc_examples/mouse.npy --out_dir pc_output --input_type pc_normal 113 | ``` 114 | 115 | ## Important Notes 116 | - It takes about 7GB and 30s to generate a mesh on an A6000 GPU. 117 | - The input mesh will be normalized to a unit bounding box. The up vector of the input mesh should be +Y for better results. 118 | - Limited by computational resources, MeshAnything is trained on meshes with fewer than 800 faces and cannot generate meshes with more than 800 faces. The shape of the input mesh should be sharp enough; otherwise, it will be challenging to represent it with only 800 faces. Thus, feed-forward 3D generation methods may often produce bad results due to insufficient shape quality. We suggest using results from 3D reconstruction, scanning and SDS-based method (like [DreamCraft3D](https://github.com/deepseek-ai/DreamCraft3D)) as the input of MeshAnything. 119 | - Please refer to https://huggingface.co/spaces/Yiwen-ntu/MeshAnything/tree/main/examples for more examples. 120 | 121 | ## Training 122 | Please refer to the training code of MeshAnythingV2 at https://github.com/buaacyw/MeshAnythingV2. 
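If you want to run the point-cloud pipeline on your own data, the `(N, 6)` input described under Usage can be sampled from any mesh with `trimesh`, mirroring what `mesh_to_pc.py` does. A minimal sketch; the 8192-point sample count and the output file name are illustrative, not part of the repo:
```
import numpy as np
import trimesh

# Sample surface points and per-face normals; columns are x, y, z, nx, ny, nz.
# main.py randomly keeps 4096 points, so sample at least that many.
mesh = trimesh.load("examples/wand.obj", force="mesh")
points, face_idx = mesh.sample(8192, return_index=True)
normals = mesh.face_normals[face_idx]  # unit-length, as main.py asserts

pc_normal = np.concatenate([points, normals], axis=-1).astype(np.float16)
np.save("pc_examples/wand.npy", pc_normal)
```
The saved file can then be fed to the point-cloud inference command shown above, e.g. `python main.py --input_path pc_examples/wand.npy --out_dir pc_output --input_type pc_normal`.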
123 | 124 | ## Acknowledgement 125 | 126 | Our code is based on these wonderful repos: 127 | 128 | * [MeshGPT](https://nihalsid.github.io/mesh-gpt/) 129 | * [meshgpt-pytorch](https://github.com/lucidrains/meshgpt-pytorch) 130 | * [Michelangelo](https://github.com/NeuralCarver/Michelangelo) 131 | * [transformers](https://github.com/huggingface/transformers) 132 | * [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch) 133 | 134 | ## Star History 135 | 136 | [![Star History Chart](https://api.star-history.com/svg?repos=buaacyw/MeshAnything&type=Date)](https://star-history.com/#buaacyw/MeshAnything&Date) 137 | 138 | ## BibTeX 139 | ``` 140 | @misc{chen2024meshanything, 141 | title={MeshAnything: Artist-Created Mesh Generation with Autoregressive Transformers}, 142 | author={Yiwen Chen and Tong He and Di Huang and Weicai Ye and Sijin Chen and Jiaxiang Tang and Xin Chen and Zhongang Cai and Lei Yang and Gang Yu and Guosheng Lin and Chi Zhang}, 143 | year={2024}, 144 | eprint={2406.10163}, 145 | archivePrefix={arXiv}, 146 | primaryClass={cs.CV} 147 | } 148 | ``` 149 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import trimesh 4 | from accelerate.utils import set_seed 5 | from accelerate import Accelerator 6 | import numpy as np 7 | import gradio as gr 8 | from main import get_args, load_model 9 | from mesh_to_pc import process_mesh_to_pc 10 | import time 11 | import matplotlib.pyplot as plt 12 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 13 | from PIL import Image 14 | import io 15 | 16 | args = get_args() 17 | model = load_model(args) 18 | 19 | device = torch.device('cuda') 20 | accelerator = Accelerator( 21 | mixed_precision="fp16", 22 | ) 23 | model = accelerator.prepare(model) 24 | model.eval() 25 | print("Model loaded to device") 26 | 27 | def wireframe_render(mesh): 28 | views = [ 29 | (90, 20), (270, 20) 30 | ] 31 | mesh.vertices = mesh.vertices[:, [0, 2, 1]] 32 | 33 | bounding_box = mesh.bounds 34 | center = mesh.centroid 35 | scale = np.ptp(bounding_box, axis=0).max() 36 | 37 | fig = plt.figure(figsize=(10, 10)) 38 | 39 | # Function to render and return each view as an image 40 | def render_view(mesh, azimuth, elevation): 41 | ax = fig.add_subplot(111, projection='3d') 42 | ax.set_axis_off() 43 | 44 | # Extract vertices and faces for plotting 45 | vertices = mesh.vertices 46 | faces = mesh.faces 47 | 48 | # Plot faces 49 | ax.add_collection3d(Poly3DCollection( 50 | vertices[faces], 51 | facecolors=(0.8, 0.5, 0.2, 1.0), # Brownish yellow 52 | edgecolors='k', 53 | linewidths=0.5, 54 | )) 55 | 56 | # Set limits and center the view on the object 57 | ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2) 58 | ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2) 59 | ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2) 60 | 61 | # Set view angle 62 | ax.view_init(elev=elevation, azim=azimuth) 63 | 64 | # Save the figure to a buffer 65 | buf = io.BytesIO() 66 | plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300) 67 | plt.clf() 68 | buf.seek(0) 69 | 70 | return Image.open(buf) 71 | 72 | # Render each view and store in a list 73 | images = [render_view(mesh, az, el) for az, el in views] 74 | 75 | # Combine images horizontally 76 | widths, heights = zip(*(i.size for i in images)) 77 | total_width = sum(widths) 78 | max_height = max(heights) 79 | 80 | 
combined_image = Image.new('RGBA', (total_width, max_height)) 81 | 82 | x_offset = 0 83 | for img in images: 84 | combined_image.paste(img, (x_offset, 0)) 85 | x_offset += img.width 86 | 87 | # Save the combined image 88 | save_path = f"combined_mesh_view_{int(time.time())}.png" 89 | combined_image.save(save_path) 90 | 91 | plt.close(fig) 92 | return save_path 93 | 94 | @torch.no_grad() 95 | def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False): 96 | set_seed(sample_seed) 97 | print("Seed value:", sample_seed) 98 | 99 | input_mesh = trimesh.load(input_3d) 100 | pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes) 101 | pc_normal = pc_list[0] # 4096, 6 102 | mesh = mesh_list[0] 103 | vertices = mesh.vertices 104 | 105 | pc_coor = pc_normal[:, :3] 106 | normals = pc_normal[:, 3:] 107 | 108 | bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) 109 | # scale mesh and pc 110 | vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 111 | vertices = vertices / (bounds[1] - bounds[0]).max() 112 | mesh.vertices = vertices 113 | pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2 114 | pc_coor = pc_coor / (bounds[1] - bounds[0]).max() 115 | 116 | mesh.merge_vertices() 117 | mesh.update_faces(mesh.unique_faces()) 118 | mesh.fix_normals() 119 | if mesh.visual.vertex_colors is not None: 120 | orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) 121 | 122 | mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1)) 123 | else: 124 | orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) 125 | mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1)) 126 | input_save_name = f"processed_input_{int(time.time())}.obj" 127 | mesh.export(input_save_name) 128 | input_render_res = wireframe_render(mesh) 129 | 130 | pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995 # input should be from -1 to 1 131 | 132 | assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong" 133 | normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16) 134 | 135 | input = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None] 136 | print("Data loaded") 137 | 138 | # with accelerator.autocast(): 139 | with accelerator.autocast(): 140 | outputs = model(input, do_sampling) 141 | print("Model inference done") 142 | recon_mesh = outputs[0] 143 | 144 | recon_mesh = recon_mesh[~torch.isnan(recon_mesh[:, 0, 0])] # nvalid_face x 3 x 3 145 | vertices = recon_mesh.reshape(-1, 3).cpu() 146 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face 147 | triangles = vertices_index.reshape(-1, 3) 148 | 149 | artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh", 150 | merge_primitives=True) 151 | artist_mesh.merge_vertices() 152 | artist_mesh.update_faces(artist_mesh.unique_faces()) 153 | artist_mesh.fix_normals() 154 | 155 | if artist_mesh.visual.vertex_colors is not None: 156 | orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) 157 | 158 | artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1)) 159 | else: 160 | orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) 161 | artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1)) 162 | 163 | num_faces = len(artist_mesh.faces) 164 | 165 | brown_color = np.array([165, 42, 42, 255], dtype=np.uint8) 166 | face_colors = np.tile(brown_color, (num_faces, 
1)) 167 | 168 | artist_mesh.visual.face_colors = face_colors 169 | # add time stamp to avoid cache 170 | save_name = f"output_{int(time.time())}.obj" 171 | artist_mesh.export(save_name) 172 | output_render = wireframe_render(artist_mesh) 173 | return input_save_name, input_render_res, save_name, output_render 174 | 175 | 176 | _HEADER_ = ''' 177 |

## Official Gradio Demo
## MeshAnything: Artist-Created Mesh Generation with Autoregressive Transformers

178 | 179 | **MeshAnything** converts any 3D representation into meshes created by human artists, i.e., Artist-Created Meshes (AMs). 180 | 181 | Code: GitHub. Arxiv Paper: ArXiv. 182 | 183 | ??????**Important Notes:** 184 | - Gradio doesn't support interactive wireframe rendering currently. For interactive mesh visualization, please use download the obj file and open it with MeshLab or https://3dviewer.net/. 185 | - The input mesh will be normalized to a unit bounding box. The up vector of the input mesh should be +Y for better results. Click **Preprocess with Marching Cubes** if the input mesh is a manually created mesh. 186 | - Limited by computational resources, MeshAnything is trained on meshes with fewer than 800 faces and cannot generate meshes with more than 800 faces. The shape of the input mesh should be sharp enough; otherwise, it will be challenging to represent it with only 800 faces. Thus, feed-forward image-to-3D methods may often produce bad results due to insufficient shape quality. 187 | - For point cloud input, please refer to our github repo GitHub. 188 | ''' 189 | 190 | 191 | _CITE_ = r""" 192 | If MeshAnything is helpful, please help to ? the Github Repo. Thanks! 193 | --- 194 | ? **License** 195 | 196 | S-Lab-1.0 LICENSE. Please refer to the [LICENSE file](https://github.com/buaacyw/GaussianEditor/blob/master/LICENSE.txt) for details. 197 | 198 | ? **Contact** 199 | 200 | If you have any questions, feel free to open a discussion or contact us at yiwen002@e.ntu.edu.sg. 201 | 202 | """ 203 | output_model_obj = gr.Model3D( 204 | label="Generated Mesh (OBJ Format)", 205 | clear_color=[1, 1, 1, 1], 206 | ) 207 | preprocess_model_obj = gr.Model3D( 208 | label="Processed Input Mesh (OBJ Format)", 209 | clear_color=[1, 1, 1, 1], 210 | ) 211 | input_image_render = gr.Image( 212 | label="Wireframe Render of Processed Input Mesh", 213 | ) 214 | output_image_render = gr.Image( 215 | label="Wireframe Render of Generated Mesh", 216 | ) 217 | with (gr.Blocks() as demo): 218 | gr.Markdown(_HEADER_) 219 | with gr.Row(variant="panel"): 220 | with gr.Column(): 221 | with gr.Row(): 222 | input_3d = gr.Model3D( 223 | label="Input Mesh", 224 | clear_color=[1,1,1,1], 225 | ) 226 | 227 | with gr.Row(): 228 | with gr.Group(): 229 | do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False) 230 | do_sampling = gr.Checkbox(label="Random Sampling", value=False) 231 | sample_seed = gr.Number(value=0, label="Seed Value", precision=0) 232 | 233 | with gr.Row(): 234 | submit = gr.Button("Generate", elem_id="generate", variant="primary") 235 | 236 | with gr.Row(variant="panel"): 237 | mesh_examples = gr.Examples( 238 | examples=[ 239 | os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples")) 240 | ], 241 | inputs=input_3d, 242 | outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render], 243 | fn=do_inference, 244 | cache_examples = False, 245 | examples_per_page=10 246 | ) 247 | with gr.Column(): 248 | with gr.Row(): 249 | input_image_render.render() 250 | with gr.Row(): 251 | with gr.Tab("OBJ"): 252 | preprocess_model_obj.render() 253 | with gr.Row(): 254 | output_image_render.render() 255 | with gr.Row(): 256 | with gr.Tab("OBJ"): 257 | output_model_obj.render() 258 | with gr.Row(): 259 | gr.Markdown('''Try click random sampling and different Seed Value if the result is unsatisfying''') 260 | 261 | gr.Markdown(_CITE_) 262 | 263 | mv_images = gr.State() 264 | 265 | submit.click( 266 | fn=do_inference, 267 | 
inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes], 268 | outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render], 269 | ) 270 | 271 | demo.launch(share=True) -------------------------------------------------------------------------------- /demo/demo_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacyw/MeshAnything/7a3cd736a5caa48950af40fdb11a5f0185229e85/demo/demo_video.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, argparse, importlib 2 | import torch 3 | import time 4 | import trimesh 5 | import numpy as np 6 | from MeshAnything.models.meshanything import MeshAnything 7 | import datetime 8 | from accelerate import Accelerator 9 | from accelerate.utils import set_seed 10 | from accelerate.utils import DistributedDataParallelKwargs 11 | from safetensors import safe_open 12 | from mesh_to_pc import process_mesh_to_pc 13 | from huggingface_hub import hf_hub_download 14 | 15 | class Dataset: 16 | def __init__(self, input_type, input_list, mc=False): 17 | super().__init__() 18 | self.data = [] 19 | if input_type == 'pc_normal': 20 | for input_path in input_list: 21 | # load npy 22 | cur_data = np.load(input_path) 23 | # sample 4096 24 | assert cur_data.shape[0] >= 4096, "input pc_normal should have at least 4096 points" 25 | idx = np.random.choice(cur_data.shape[0], 4096, replace=False) 26 | cur_data = cur_data[idx] 27 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) 28 | 29 | elif input_type == 'mesh': 30 | mesh_list = [] 31 | for input_path in input_list: 32 | # load ply 33 | cur_data = trimesh.load(input_path) 34 | mesh_list.append(cur_data) 35 | if mc: 36 | print("First Marching Cubes and then sample point cloud, need several minutes...") 37 | pc_list, _ = process_mesh_to_pc(mesh_list, marching_cubes=mc) 38 | for input_path, cur_data in zip(input_list, pc_list): 39 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) 40 | print(f"dataset total data samples: {len(self.data)}") 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def __getitem__(self, idx): 46 | data_dict = {} 47 | data_dict['pc_normal'] = self.data[idx]['pc_normal'] 48 | # normalize pc coor 49 | pc_coor = data_dict['pc_normal'][:, :3] 50 | normals = data_dict['pc_normal'][:, 3:] 51 | bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)]) 52 | pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2 53 | pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995 54 | assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong" 55 | data_dict['pc_normal'] = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16) 56 | data_dict['uid'] = self.data[idx]['uid'] 57 | 58 | return data_dict 59 | 60 | def get_args(): 61 | parser = argparse.ArgumentParser("MeshAnything", add_help=False) 62 | 63 | parser.add_argument('--llm', default="facebook/opt-350m", type=str) 64 | parser.add_argument('--input_dir', default=None, type=str) 65 | parser.add_argument('--input_path', default=None, type=str) 66 | 67 | parser.add_argument('--out_dir', default="inference_out", type=str) 68 | parser.add_argument('--pretrained_weights', default="MeshAnything_350m.pth", type=str) 69 | 70 | parser.add_argument( 71 | '--input_type', 72 | 
choices=['mesh','pc_normal'], 73 | default='pc', 74 | help="Type of the asset to process (default: pc)" 75 | ) 76 | 77 | parser.add_argument("--codebook_size", default=8192, type=int) 78 | parser.add_argument("--codebook_dim", default=1024, type=int) 79 | 80 | parser.add_argument("--n_max_triangles", default=800, type=int) 81 | 82 | parser.add_argument("--batchsize_per_gpu", default=1, type=int) 83 | parser.add_argument("--seed", default=0, type=int) 84 | 85 | parser.add_argument("--mc", default=False, action="store_true") 86 | parser.add_argument("--sampling", default=False, action="store_true") 87 | 88 | args = parser.parse_args() 89 | return args 90 | 91 | def load_model(args): 92 | model = MeshAnything(args) 93 | print("load model over!!!") 94 | 95 | ckpt_path = hf_hub_download( 96 | repo_id="Yiwen-ntu/MeshAnything", 97 | filename="MeshAnything_350m.pth", 98 | ) 99 | tensors = {} 100 | with safe_open(ckpt_path, framework="pt", device=0) as f: 101 | for k in f.keys(): 102 | tensors[k] = f.get_tensor(k) 103 | 104 | model.load_state_dict(tensors, strict=True) 105 | print("load weights over!!!") 106 | return model 107 | if __name__ == "__main__": 108 | args = get_args() 109 | 110 | cur_time = datetime.datetime.now().strftime("%d_%H-%M-%S") 111 | checkpoint_dir = os.path.join(args.out_dir, cur_time) 112 | os.makedirs(checkpoint_dir, exist_ok=True) 113 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 114 | accelerator = Accelerator( 115 | mixed_precision="fp16", 116 | project_dir=checkpoint_dir, 117 | kwargs_handlers=[kwargs] 118 | ) 119 | 120 | model = load_model(args) 121 | # create dataset 122 | if args.input_dir is not None: 123 | input_list = sorted(os.listdir(args.input_dir)) 124 | # only ply, obj or npy 125 | if args.input_type == 'pc_normal': 126 | input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.npy')] 127 | else: 128 | input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.ply') or x.endswith('.obj') or x.endswith('.npy')] 129 | set_seed(args.seed) 130 | dataset = Dataset(args.input_type, input_list, args.mc) 131 | elif args.input_path is not None: 132 | set_seed(args.seed) 133 | dataset = Dataset(args.input_type, [args.input_path], args.mc) 134 | else: 135 | raise ValueError("input_dir or input_path must be provided.") 136 | 137 | dataloader = torch.utils.data.DataLoader( 138 | dataset, 139 | batch_size=args.batchsize_per_gpu, 140 | drop_last = False, 141 | shuffle = False, 142 | ) 143 | 144 | if accelerator.state.num_processes > 1: 145 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 146 | dataloader, model = accelerator.prepare(dataloader, model) 147 | begin_time = time.time() 148 | print("Generation Start!!!") 149 | with accelerator.autocast(): 150 | for curr_iter, batch_data_label in enumerate(dataloader): 151 | curr_time = time.time() 152 | outputs = model(batch_data_label['pc_normal'], sampling=args.sampling) 153 | batch_size = outputs.shape[0] 154 | device = outputs.device 155 | 156 | for batch_id in range(batch_size): 157 | recon_mesh = outputs[batch_id] 158 | recon_mesh = recon_mesh[~torch.isnan(recon_mesh[:, 0, 0])] # nvalid_face x 3 x 3 159 | vertices = recon_mesh.reshape(-1, 3).cpu() 160 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face 161 | triangles = vertices_index.reshape(-1, 3) 162 | 163 | scene_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh", 164 | merge_primitives=True) 165 | scene_mesh.merge_vertices() 166 | 
scene_mesh.update_faces(scene_mesh.unique_faces()) 167 | scene_mesh.fix_normals() 168 | save_path = os.path.join(checkpoint_dir, f'{batch_data_label["uid"][batch_id]}_gen.obj') 169 | num_faces = len(scene_mesh.faces) 170 | brown_color = np.array([255, 165, 0, 255], dtype=np.uint8) 171 | face_colors = np.tile(brown_color, (num_faces, 1)) 172 | 173 | scene_mesh.visual.face_colors = face_colors 174 | scene_mesh.export(save_path) 175 | print(f"{save_path} Over!!") 176 | end_time = time.time() 177 | print(f"Total time: {end_time - begin_time}") -------------------------------------------------------------------------------- /mesh_to_pc.py: -------------------------------------------------------------------------------- 1 | import mesh2sdf.core 2 | import numpy as np 3 | import skimage.measure 4 | import trimesh 5 | 6 | def normalize_vertices(vertices, scale=0.9): 7 | bbmin, bbmax = vertices.min(0), vertices.max(0) 8 | center = (bbmin + bbmax) * 0.5 9 | scale = 2.0 * scale / (bbmax - bbmin).max() 10 | vertices = (vertices - center) * scale 11 | return vertices, center, scale 12 | 13 | def export_to_watertight(normalized_mesh, octree_depth: int = 7): 14 | """ 15 | Convert the non-watertight mesh to watertight. 16 | 17 | Args: 18 | input_path (str): normalized path 19 | octree_depth (int): 20 | 21 | Returns: 22 | mesh(trimesh.Trimesh): watertight mesh 23 | 24 | """ 25 | size = 2 ** octree_depth 26 | level = 2 / size 27 | 28 | scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices) 29 | 30 | sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size) 31 | 32 | vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level) 33 | 34 | # watertight mesh 35 | vertices = vertices / size * 2 - 1 # -1 to 1 36 | vertices = vertices / to_orig_scale + to_orig_center 37 | # vertices = vertices / to_orig_scale + to_orig_center 38 | mesh = trimesh.Trimesh(vertices, faces, normals=normals) 39 | 40 | return mesh 41 | 42 | def process_mesh_to_pc(mesh_list, marching_cubes = False, sample_num = 4096): 43 | # mesh_list : list of trimesh 44 | pc_normal_list = [] 45 | return_mesh_list = [] 46 | for mesh in mesh_list: 47 | if marching_cubes: 48 | mesh = export_to_watertight(mesh) 49 | print("MC over!") 50 | return_mesh_list.append(mesh) 51 | points, face_idx = mesh.sample(sample_num, return_index=True) 52 | normals = mesh.face_normals[face_idx] 53 | 54 | pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16) 55 | pc_normal_list.append(pc_normal) 56 | print("process mesh success") 57 | return pc_normal_list, return_mesh_list 58 | 59 | -------------------------------------------------------------------------------- /pc_examples/mouse.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacyw/MeshAnything/7a3cd736a5caa48950af40fdb11a5f0185229e85/pc_examples/mouse.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trimesh==4.2.3 2 | accelerate==0.28.0 3 | mesh2sdf==1.1.0 4 | einops==0.7.0 5 | einx==0.1.3 6 | optimum==1.18.0 7 | omegaconf==2.3.0 8 | opencv-python==4.9.0.80 9 | transformers==4.39.3 10 | huggingface_hub 11 | matplotlib 12 | gradio 13 | spaces -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 
from pathlib import Path 2 | from setuptools import setup, find_packages 3 | 4 | setup_path = Path(__file__).parent 5 | # Read the README once; it is reused below as the package's long description. 6 | README = (setup_path / "README.md").read_text(encoding="utf-8") 7 | 8 | def split_requirements(requirements): 9 | install_requires = [] 10 | dependency_links = [] 11 | for requirement in requirements: 12 | if requirement.startswith("git+"): 13 | dependency_links.append(requirement) 14 | else: 15 | install_requires.append(requirement) 16 | 17 | return install_requires, dependency_links 18 | 19 | with open("./requirements.txt", "r") as f: 20 | requirements = f.read().splitlines() 21 | 22 | install_requires, dependency_links = split_requirements(requirements) 23 | 24 | setup( 25 | name = "MeshAnything", 26 | packages=find_packages(), 27 | description="MeshAnything: Artist-Created Mesh Generation with Autoregressive Transformers", 28 | long_description=README, 29 | long_description_content_type="text/markdown", 30 | install_requires=install_requires 31 | ) 32 | --------------------------------------------------------------------------------