├── LICENSE.txt ├── MeshAnything ├── miche │ ├── LICENSE │ ├── encode.py │ ├── michelangelo │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── templates.json │ │ │ ├── transforms.py │ │ │ └── utils.py │ │ ├── graphics │ │ │ ├── __init__.py │ │ │ └── primitives │ │ │ │ ├── __init__.py │ │ │ │ ├── mesh.py │ │ │ │ └── volume.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── asl_diffusion │ │ │ │ ├── __init__.py │ │ │ │ ├── asl_diffuser_pl_module.py │ │ │ │ ├── asl_udt.py │ │ │ │ ├── base.py │ │ │ │ ├── clip_asl_diffuser_pl_module.py │ │ │ │ └── inference_utils.py │ │ │ ├── conditional_encoders │ │ │ │ ├── __init__.py │ │ │ │ ├── clip.py │ │ │ │ └── encoder_factory.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── checkpoint.py │ │ │ │ ├── diffusion_transformer.py │ │ │ │ ├── distributions.py │ │ │ │ ├── embedder.py │ │ │ │ ├── transformer_blocks.py │ │ │ │ └── transformer_vit.py │ │ │ └── tsal │ │ │ │ ├── __init__.py │ │ │ │ ├── asl_pl_module.py │ │ │ │ ├── clip_asl_module.py │ │ │ │ ├── inference_utils.py │ │ │ │ ├── loss.py │ │ │ │ ├── sal_perceiver.py │ │ │ │ ├── sal_pl_module.py │ │ │ │ └── tsal_base.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── eval.py │ │ │ ├── io.py │ │ │ ├── misc.py │ │ │ └── visualizers │ │ │ ├── __init__.py │ │ │ ├── color_util.py │ │ │ ├── html_util.py │ │ │ └── pythreejs_viewer.py │ └── shapevae-256.yaml └── models │ ├── meshanything.py │ └── shape_opt.py ├── README.md ├── app.py ├── demo └── demo_video.gif ├── examples ├── screwdriver.obj └── wand.obj ├── main.py ├── mesh_to_pc.py ├── pc_examples └── mouse.npy ├── requirements.txt └── setup.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 
-------------------------------------------------------------------------------- /MeshAnything/miche/encode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | from omegaconf import OmegaConf 4 | import numpy as np 5 | import torch 6 | from .michelangelo.utils.misc import instantiate_from_config 7 | 8 | def load_surface(fp): 9 | 10 | with np.load(fp) as input_pc: 11 | surface = input_pc['points'] 12 | normal = input_pc['normals'] 13 | 14 | rng = np.random.default_rng() 15 | ind = rng.choice(surface.shape[0], 4096, replace=False) 16 | surface = torch.FloatTensor(surface[ind]) 17 | normal = torch.FloatTensor(normal[ind]) 18 | 19 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() 20 | 21 | return surface 22 | 23 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): 24 | 25 | surface = load_surface(args.pointcloud_path) 26 | # old_surface = surface.clone() 27 | 28 | # surface[0,:,0]*=-1 29 | # surface[0,:,1]*=-1 30 | surface[0,:,2]*=-1 31 | 32 | # encoding 33 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) 34 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) 35 | 36 | # decoding 37 | latents = model.model.shape_model.decode(shape_zq) 38 | # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) 39 | 40 | return 0 41 | 42 | def load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt"): 43 | model_config = OmegaConf.load("MeshAnything/miche/shapevae-256.yaml") 44 | # print(model_config) 45 | if hasattr(model_config, "model"): 46 | model_config = model_config.model 47 | 48 | model = instantiate_from_config(model_config, ckpt_path=ckpt_path) 49 | model = model.cuda() 50 | model = model.eval() 51 | 52 | return model 53 | if __name__ == "__main__": 54 | ''' 55 | 1. Reconstruct point cloud 56 | 2. Image-conditioned generation 57 | 3. 
Text-conditioned generation 58 | ''' 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--config_path", type=str, required=True) 61 | parser.add_argument("--ckpt_path", type=str, required=True) 62 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud') 63 | parser.add_argument("--image_path", type=str, help='Path to the input image') 64 | parser.add_argument("--text", type=str, help='Input text in the format: "A 3D model of motorcar; Porsche 911."') 65 | parser.add_argument("--output_dir", type=str, default='./output') 66 | parser.add_argument("-s", "--seed", type=int, default=0) 67 | args = parser.parse_args() 68 | 69 | print(f'-----------------------------------------------------------------------------') 70 | print(f'>>> Output directory: {args.output_dir}') 71 | print(f'-----------------------------------------------------------------------------') 72 | 73 | reconstruction(args, load_model(args.ckpt_path)) -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "shape": [ 3 | "a point cloud model of {}.", 4 | "There is a {} in the scene.", 5 | "There is the {} in the scene.", 6 | "a photo of a {} in the scene.", 7 | "a photo of the {} in the scene.", 8 | "a photo of one {} in the scene.", 9 | "itap of a {}.", 10 | "itap of my {}.", 11 | "itap of the {}.", 12 | "a photo of a {}.", 13 | "a photo of my {}.", 14 | "a photo of the {}.", 15 | "a photo of one {}.", 16 | "a photo of many {}.", 17 | "a good photo of a {}.", 18 | "a good photo of the {}.", 19 | "a bad photo of a {}.", 20 | "a bad photo of the {}.", 21 | "a photo of a nice {}.", 22 | "a photo of the nice {}.", 23 | "a photo of a cool {}.", 24 | "a photo of the cool {}.", 25 | "a photo of a weird {}.", 26 | "a photo of the weird {}.", 27 | "a photo of a small {}.", 28 | "a photo of the small {}.", 29 | "a photo of a large {}.", 30 | "a photo of the large {}.", 31 | "a photo of a clean {}.", 32 | "a photo of the clean {}.", 33 | "a photo of a dirty {}.", 34 | "a photo of the dirty {}.", 35 | "a bright photo of a {}.", 36 | "a bright photo of the {}.", 37 | "a dark photo of a {}.", 38 | "a dark photo of the {}.", 39 | "a photo of a hard to see {}.", 40 | "a photo of the hard to see {}.", 41 | "a low resolution photo of a {}.", 42 | "a low resolution photo of the {}.", 43 | "a cropped photo of a {}.", 44 | "a cropped photo of the {}.", 45 | "a close-up photo of a {}.", 46 | "a close-up photo of the {}.", 47 | "a jpeg corrupted photo of a {}.", 48 | "a jpeg corrupted photo of the {}.", 49 | "a blurry photo of a {}.", 50 | "a blurry photo of the {}.", 51 | "a pixelated photo of a {}.", 52 | "a pixelated photo of the {}.", 53 | "a black and white photo of the {}.", 54 | "a black and white photo of a {}.", 55 | "a plastic {}.", 56 | "the plastic {}.", 57 | "a toy {}.", 58 | "the toy {}.", 59 | "a plushie {}.", 60 | "the 
plushie {}.", 61 | "a cartoon {}.", 62 | "the cartoon {}.", 63 | "an embroidered {}.", 64 | "the embroidered {}.", 65 | "a painting of the {}.", 66 | "a painting of a {}." 67 | ] 68 | 69 | } -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import numpy as np 5 | import warnings 6 | import random 7 | from omegaconf.listconfig import ListConfig 8 | from webdataset import pipelinefilter 9 | import torch 10 | import torchvision.transforms.functional as TVF 11 | from torchvision.transforms import InterpolationMode 12 | from torchvision.transforms.transforms import _interpolation_modes_from_int 13 | from typing import Sequence 14 | 15 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 16 | 17 | 18 | def _uid_buffer_pick(buf_dict, rng): 19 | uid_keys = list(buf_dict.keys()) 20 | selected_uid = rng.choice(uid_keys) 21 | buf = buf_dict[selected_uid] 22 | 23 | k = rng.randint(0, len(buf) - 1) 24 | sample = buf[k] 25 | buf[k] = buf[-1] 26 | buf.pop() 27 | 28 | if len(buf) == 0: 29 | del buf_dict[selected_uid] 30 | 31 | return sample 32 | 33 | 34 | def _add_to_buf_dict(buf_dict, sample): 35 | key = sample["__key__"] 36 | uid, uid_sample_id = key.split("_") 37 | if uid not in buf_dict: 38 | buf_dict[uid] = [] 39 | buf_dict[uid].append(sample) 40 | 41 | return buf_dict 42 | 43 | 44 | def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): 45 | """Shuffle the data in the stream. 46 | 47 | This uses a buffer of size `bufsize`. Shuffling at 48 | startup is less random; this is traded off against 49 | yielding samples quickly. 50 | 51 | data: iterator 52 | bufsize: buffer size for shuffling 53 | returns: iterator 54 | rng: either random module or random.Random instance 55 | 56 | """ 57 | if rng is None: 58 | rng = random.Random(int((os.getpid() + time.time()) * 1e9)) 59 | initial = min(initial, bufsize) 60 | buf_dict = dict() 61 | current_samples = 0 62 | for sample in data: 63 | _add_to_buf_dict(buf_dict, sample) 64 | current_samples += 1 65 | 66 | if current_samples < bufsize: 67 | try: 68 | _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708 69 | current_samples += 1 70 | except StopIteration: 71 | pass 72 | 73 | if current_samples >= initial: 74 | current_samples -= 1 75 | yield _uid_buffer_pick(buf_dict, rng) 76 | 77 | while current_samples > 0: 78 | current_samples -= 1 79 | yield _uid_buffer_pick(buf_dict, rng) 80 | 81 | 82 | uid_shuffle = pipelinefilter(_uid_shuffle) 83 | 84 | 85 | class RandomSample(object): 86 | def __init__(self, 87 | num_volume_samples: int = 1024, 88 | num_near_samples: int = 1024): 89 | 90 | super().__init__() 91 | 92 | self.num_volume_samples = num_volume_samples 93 | self.num_near_samples = num_near_samples 94 | 95 | def __call__(self, sample): 96 | rng = np.random.default_rng() 97 | 98 | # 1. sample surface input 99 | total_surface = sample["surface"] 100 | ind = rng.choice(total_surface.shape[0], replace=False) 101 | surface = total_surface[ind] 102 | 103 | # 2. 
sample volume/near geometric points 104 | vol_points = sample["vol_points"] 105 | vol_label = sample["vol_label"] 106 | near_points = sample["near_points"] 107 | near_label = sample["near_label"] 108 | 109 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) 110 | vol_points = vol_points[ind] 111 | vol_label = vol_label[ind] 112 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) 113 | 114 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) 115 | near_points = near_points[ind] 116 | near_label = near_label[ind] 117 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) 118 | 119 | # concat sampled volume and near points 120 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) 121 | 122 | sample = { 123 | "surface": surface, 124 | "geo_points": geo_points 125 | } 126 | 127 | return sample 128 | 129 | 130 | class SplitRandomSample(object): 131 | def __init__(self, 132 | use_surface_sample: bool = False, 133 | num_surface_samples: int = 4096, 134 | num_volume_samples: int = 1024, 135 | num_near_samples: int = 1024): 136 | 137 | super().__init__() 138 | 139 | self.use_surface_sample = use_surface_sample 140 | self.num_surface_samples = num_surface_samples 141 | self.num_volume_samples = num_volume_samples 142 | self.num_near_samples = num_near_samples 143 | 144 | def __call__(self, sample): 145 | 146 | rng = np.random.default_rng() 147 | 148 | # 1. sample surface input 149 | surface = sample["surface"] 150 | 151 | if self.use_surface_sample: 152 | replace = surface.shape[0] < self.num_surface_samples 153 | ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace) 154 | surface = surface[ind] 155 | 156 | # 2. 
sample volume/near geometric points 157 | vol_points = sample["vol_points"] 158 | vol_label = sample["vol_label"] 159 | near_points = sample["near_points"] 160 | near_label = sample["near_label"] 161 | 162 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) 163 | vol_points = vol_points[ind] 164 | vol_label = vol_label[ind] 165 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) 166 | 167 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) 168 | near_points = near_points[ind] 169 | near_label = near_label[ind] 170 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) 171 | 172 | # concat sampled volume and near points 173 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) 174 | 175 | sample = { 176 | "surface": surface, 177 | "geo_points": geo_points 178 | } 179 | 180 | return sample 181 | 182 | 183 | class FeatureSelection(object): 184 | 185 | VALID_SURFACE_FEATURE_DIMS = { 186 | "none": [0, 1, 2], # xyz 187 | "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal 188 | "normal": [0, 1, 2, 6, 7, 8] 189 | } 190 | 191 | def __init__(self, surface_feature_type: str): 192 | 193 | self.surface_feature_type = surface_feature_type 194 | self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type] 195 | 196 | def __call__(self, sample): 197 | sample["surface"] = sample["surface"][:, self.surface_dims] 198 | return sample 199 | 200 | 201 | class AxisScaleTransform(object): 202 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): 203 | assert isinstance(interval, (tuple, list, ListConfig)) 204 | self.interval = interval 205 | self.min_val = interval[0] 206 | self.max_val = interval[1] 207 | self.inter_size = interval[1] - interval[0] 208 | self.jitter = jitter 209 | self.jitter_scale = jitter_scale 210 | 211 | def __call__(self, sample): 212 | 213 | surface = sample["surface"][..., 0:3] 214 | geo_points = sample["geo_points"][..., 0:3] 215 | 216 | scaling = torch.rand(1, 3) * self.inter_size + self.min_val 217 | # print(scaling) 218 | surface = surface * scaling 219 | geo_points = geo_points * scaling 220 | 221 | scale = (1 / torch.abs(surface).max().item()) * 0.999999 222 | surface *= scale 223 | geo_points *= scale 224 | 225 | if self.jitter: 226 | surface += self.jitter_scale * torch.randn_like(surface) 227 | surface.clamp_(min=-1.015, max=1.015) 228 | 229 | sample["surface"][..., 0:3] = surface 230 | sample["geo_points"][..., 0:3] = geo_points 231 | 232 | return sample 233 | 234 | 235 | class ToTensor(object): 236 | 237 | def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")): 238 | self.tensor_keys = tensor_keys 239 | 240 | def __call__(self, sample): 241 | for key in self.tensor_keys: 242 | if key not in sample: 243 | continue 244 | 245 | sample[key] = torch.tensor(sample[key], dtype=torch.float32) 246 | 247 | return sample 248 | 249 | 250 | class AxisScale(object): 251 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): 252 | assert isinstance(interval, (tuple, list, ListConfig)) 253 | self.interval = interval 254 | self.jitter = jitter 255 | self.jitter_scale = jitter_scale 256 | 257 | def __call__(self, surface, *args): 258 | scaling = torch.rand(1, 3) * (self.interval[1] - self.interval[0]) + self.interval[0] 259 | # print(scaling) 260 | surface = surface * scaling 261 | scale = (1 / torch.abs(surface).max().item()) * 0.999999 262 | surface *= scale 263 | 264 | args_outputs = [] 265 | for _arg in args: 266 
| _arg = _arg * scaling * scale 267 | args_outputs.append(_arg) 268 | 269 | if self.jitter: 270 | surface += self.jitter_scale * torch.randn_like(surface) 271 | surface.clamp_(min=-1, max=1) 272 | 273 | if len(args) == 0: 274 | return surface 275 | else: 276 | return surface, *args_outputs 277 | 278 | 279 | class RandomResize(torch.nn.Module): 280 | """Apply randomly Resize with a given probability.""" 281 | 282 | def __init__( 283 | self, 284 | size, 285 | resize_radio=(0.5, 1), 286 | allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR), 287 | interpolation=InterpolationMode.BICUBIC, 288 | max_size=None, 289 | antialias=None, 290 | ): 291 | super().__init__() 292 | if not isinstance(size, (int, Sequence)): 293 | raise TypeError(f"Size should be int or sequence. Got {type(size)}") 294 | if isinstance(size, Sequence) and len(size) not in (1, 2): 295 | raise ValueError("If size is a sequence, it should have 1 or 2 values") 296 | 297 | self.size = size 298 | self.max_size = max_size 299 | # Backward compatibility with integer value 300 | if isinstance(interpolation, int): 301 | warnings.warn( 302 | "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " 303 | "Please use InterpolationMode enum." 304 | ) 305 | interpolation = _interpolation_modes_from_int(interpolation) 306 | 307 | self.interpolation = interpolation 308 | self.antialias = antialias 309 | 310 | self.resize_radio = resize_radio 311 | self.allow_resize_interpolations = allow_resize_interpolations 312 | 313 | def random_resize_params(self): 314 | radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0] 315 | 316 | if isinstance(self.size, int): 317 | size = int(self.size * radio) 318 | elif isinstance(self.size, Sequence): 319 | size = list(self.size) 320 | size = (int(size[0] * radio), int(size[1] * radio)) 321 | else: 322 | raise RuntimeError() 323 | 324 | interpolation = self.allow_resize_interpolations[ 325 | torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,)) 326 | ] 327 | return size, interpolation 328 | 329 | def forward(self, img): 330 | size, interpolation = self.random_resize_params() 331 | img = TVF.resize(img, size, interpolation, self.max_size, self.antialias) 332 | img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias) 333 | return img 334 | 335 | def __repr__(self) -> str: 336 | detail = f"(size={self.size}, interpolation={self.interpolation.value}," 337 | detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}" 338 | return f"{self.__class__.__name__}{detail}" 339 | 340 | 341 | class Compose(object): 342 | """Composes several transforms together. This transform does not support torchscript. 343 | Please, see the note below. 344 | 345 | Args: 346 | transforms (list of ``Transform`` objects): list of transforms to compose. 347 | 348 | Example: 349 | >>> transforms.Compose([ 350 | >>> transforms.CenterCrop(10), 351 | >>> transforms.ToTensor(), 352 | >>> ]) 353 | 354 | .. note:: 355 | In order to script the transformations, please use ``torch.nn.Sequential`` as below. 356 | 357 | >>> transforms = torch.nn.Sequential( 358 | >>> transforms.CenterCrop(10), 359 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 360 | >>> ) 361 | >>> scripted_transforms = torch.jit.script(transforms) 362 | 363 | Make sure to use only scriptable transformations, i.e. 
that work with ``torch.Tensor`` and do not require 364 | `lambda` functions or ``PIL.Image``. 365 | 366 | """ 367 | 368 | def __init__(self, transforms): 369 | self.transforms = transforms 370 | 371 | def __call__(self, *args): 372 | for t in self.transforms: 373 | args = t(*args) 374 | return args 375 | 376 | def __repr__(self): 377 | format_string = self.__class__.__name__ + '(' 378 | for t in self.transforms: 379 | format_string += '\n' 380 | format_string += ' {0}'.format(t) 381 | format_string += '\n)' 382 | return format_string 383 | 384 | 385 | def identity(*args, **kwargs): 386 | if len(args) == 1: 387 | return args[0] 388 | else: 389 | return args 390 | 391 | 392 | def build_transforms(cfg): 393 | 394 | if cfg is None: 395 | return identity 396 | 397 | transforms = [] 398 | 399 | for transform_name, cfg_instance in cfg.items(): 400 | transform_instance = instantiate_from_config(cfg_instance) 401 | transforms.append(transform_instance) 402 | print(f"Build transform: {transform_instance}") 403 | 404 | transforms = Compose(transforms) 405 | 406 | return transforms 407 | 408 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/data/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def worker_init_fn(_): 8 | worker_info = torch.utils.data.get_worker_info() 9 | worker_id = worker_info.id 10 | 11 | # dataset = worker_info.dataset 12 | # split_size = dataset.num_records // worker_info.num_workers 13 | # # reset num_records to the true number to retain reliable length information 14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1) 16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 17 | 18 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 19 | 20 | 21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True): 22 | """Collate a list of per-sample dicts into a single batch dict. 23 | 24 | Args: 25 | samples (list[dict]): per-sample dicts sharing the same keys. 26 | combine_tensors: if True, stack torch.Tensor / np.ndarray values along a new batch axis. 27 | combine_scalars: if True, pack int/float values into an np.ndarray. 28 | 29 | Returns: 30 | dict: one batch dict with the values collated per key. 31 | """ 32 | 33 | result = {} 34 | 35 | keys = samples[0].keys() 36 | 37 | for key in keys: 38 | result[key] = [] 39 | 40 | for sample in samples: 41 | for key in keys: 42 | val = sample[key] 43 | result[key].append(val) 44 | 45 | for key in keys: 46 | val_list = result[key] 47 | if isinstance(val_list[0], (int, float)): 48 | if combine_scalars: 49 | result[key] = np.array(result[key]) 50 | 51 | elif isinstance(val_list[0], torch.Tensor): 52 | if combine_tensors: 53 | result[key] = torch.stack(val_list) 54 | 55 | elif isinstance(val_list[0], np.ndarray): 56 | if combine_tensors: 57 | result[key] = np.stack(val_list) 58 | 59 | return result 60 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .volume import generate_dense_grid_points 4 | 5 | from .mesh import ( 6 | MeshOutput, 7 | save_obj, 8 | savemeshtes2 9 | ) 10 | 
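# A minimal usage sketch of the two main exports (hypothetical `vertices`/`faces`
# arrays; generate_dense_grid_points and save_obj are defined in volume.py and
# mesh.py below, and the bounds match the defaults used in encode.py):
#
#   import numpy as np
#   from MeshAnything.miche.michelangelo.graphics.primitives import (
#       generate_dense_grid_points, save_obj)
#
#   xyz, grid_size, length = generate_dense_grid_points(
#       bbox_min=np.array([-1.25, -1.25, -1.25]),
#       bbox_max=np.array([1.25, 1.25, 1.25]),
#       octree_depth=7)                       # xyz: ((2**7 + 1)**3, 3) query points
#   save_obj(vertices, faces, "shape.obj")    # vertices: (V, 3) floats, faces: (F, 3) int indices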
-------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/mesh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import PIL.Image 7 | from typing import Optional 8 | 9 | import trimesh 10 | 11 | 12 | def save_obj(pointnp_px3, facenp_fx3, fname): 13 | fid = open(fname, "w") 14 | write_str = "" 15 | for pidx, p in enumerate(pointnp_px3): 16 | pp = p 17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2]) 18 | 19 | for i, f in enumerate(facenp_fx3): 20 | f1 = f + 1 21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2]) 22 | fid.write(write_str) 23 | fid.close() 24 | return 25 | 26 | 27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname): 28 | fol, na = os.path.split(fname) 29 | na, _ = os.path.splitext(na) 30 | 31 | matname = "%s/%s.mtl" % (fol, na) 32 | fid = open(matname, "w") 33 | fid.write("newmtl material_0\n") 34 | fid.write("Kd 1 1 1\n") 35 | fid.write("Ka 0 0 0\n") 36 | fid.write("Ks 0.4 0.4 0.4\n") 37 | fid.write("Ns 10\n") 38 | fid.write("illum 2\n") 39 | fid.write("map_Kd %s.png\n" % na) 40 | fid.close() 41 | #### 42 | 43 | fid = open(fname, "w") 44 | fid.write("mtllib %s.mtl\n" % na) 45 | 46 | for pidx, p in enumerate(pointnp_px3): 47 | pp = p 48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) 49 | 50 | for pidx, p in enumerate(tcoords_px2): 51 | pp = p 52 | fid.write("vt %f %f\n" % (pp[0], pp[1])) 53 | 54 | fid.write("usemtl material_0\n") 55 | for i, f in enumerate(facenp_fx3): 56 | f1 = f + 1 57 | f2 = facetex_fx3[i] + 1 58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) 59 | fid.close() 60 | 61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save( 62 | os.path.join(fol, "%s.png" % na)) 63 | 64 | return 65 | 66 | 67 | class MeshOutput(object): 68 | 69 | def __init__(self, 70 | mesh_v: np.ndarray, 71 | mesh_f: np.ndarray, 72 | vertex_colors: Optional[np.ndarray] = None, 73 | uvs: Optional[np.ndarray] = None, 74 | mesh_tex_idx: Optional[np.ndarray] = None, 75 | tex_map: Optional[np.ndarray] = None): 76 | 77 | self.mesh_v = mesh_v 78 | self.mesh_f = mesh_f 79 | self.vertex_colors = vertex_colors 80 | self.uvs = uvs 81 | self.mesh_tex_idx = mesh_tex_idx 82 | self.tex_map = tex_map 83 | 84 | def contain_uv_texture(self): 85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None) 86 | 87 | def contain_vertex_colors(self): 88 | return self.vertex_colors is not None 89 | 90 | def export(self, fname): 91 | 92 | if self.contain_uv_texture(): 93 | savemeshtes2( 94 | self.mesh_v, 95 | self.uvs, 96 | self.mesh_f, 97 | self.mesh_tex_idx, 98 | self.tex_map, 99 | fname 100 | ) 101 | 102 | elif self.contain_vertex_colors(): 103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors) 104 | mesh_obj.export(fname) 105 | 106 | else: 107 | save_obj( 108 | self.mesh_v, 109 | self.mesh_f, 110 | fname 111 | ) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/graphics/primitives/volume.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | 6 | def generate_dense_grid_points(bbox_min: np.ndarray, 7 | bbox_max: np.ndarray, 8 | octree_depth: int, 9 | 
indexing: str = "ij"): 10 | length = bbox_max - bbox_min 11 | num_cells = np.exp2(octree_depth) 12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) 13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) 14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) 15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) 16 | xyz = np.stack((xs, ys, zs), axis=-1) 17 | xyz = xyz.reshape(-1, 3) 18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] 19 | 20 | return xyz, grid_size, length 21 | 22 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from omegaconf import DictConfig 4 | from typing import List, Tuple, Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | from einops import rearrange 14 | 15 | from diffusers.schedulers import ( 16 | DDPMScheduler, 17 | DDIMScheduler, 18 | KarrasVeScheduler, 19 | DPMSolverMultistepScheduler 20 | ) 21 | 22 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 23 | # from MeshAnything.miche.michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule 24 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule 25 | from MeshAnything.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample 26 | 27 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] 28 | 29 | 30 | def disabled_train(self, mode=True): 31 | """Overwrite model.train with this function to make sure train/eval mode 32 | does not change anymore.""" 33 | return self 34 | 35 | 36 | class ASLDiffuser(pl.LightningModule): 37 | first_stage_model: Optional[AlignedShapeAsLatentPLModule] 38 | # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] 39 | model: nn.Module 40 | 41 | def __init__(self, *, 42 | first_stage_config, 43 | denoiser_cfg, 44 | scheduler_cfg, 45 | optimizer_cfg, 46 | loss_cfg, 47 | first_stage_key: str = "surface", 48 | cond_stage_key: str = "image", 49 | cond_stage_trainable: bool = True, 50 | scale_by_std: bool = False, 51 | z_scale_factor: float = 1.0, 52 | ckpt_path: Optional[str] = None, 53 | ignore_keys: Union[Tuple[str], List[str]] = ()): 54 | 55 | super().__init__() 56 | 57 | self.first_stage_key = first_stage_key 58 | self.cond_stage_key = cond_stage_key 59 | self.cond_stage_trainable = cond_stage_trainable 60 | 61 | # 1. initialize first stage. 62 | # Note: the condition model contained in the first stage model. 
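# (lazy initialization: first_stage_model is left as None here and is built by
# instantiate_first_stage() on the first forward() / sample() call)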
63 | self.first_stage_config = first_stage_config 64 | self.first_stage_model = None 65 | # self.instantiate_first_stage(first_stage_config) 66 | 67 | # 2. initialize conditional stage 68 | # self.instantiate_cond_stage(cond_stage_config) 69 | self.cond_stage_model = { 70 | "image": self.encode_image, 71 | "image_unconditional_embedding": self.empty_img_cond, 72 | "text": self.encode_text, 73 | "text_unconditional_embedding": self.empty_text_cond, 74 | "surface": self.encode_surface, 75 | "surface_unconditional_embedding": self.empty_surface_cond, 76 | } 77 | 78 | # 3. diffusion model 79 | self.model = instantiate_from_config( 80 | denoiser_cfg, device=None, dtype=None 81 | ) 82 | 83 | self.optimizer_cfg = optimizer_cfg 84 | 85 | # 4. scheduling strategy 86 | self.scheduler_cfg = scheduler_cfg 87 | 88 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) 89 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) 90 | 91 | # 5. loss configures 92 | self.loss_cfg = loss_cfg 93 | 94 | self.scale_by_std = scale_by_std 95 | if scale_by_std: 96 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) 97 | else: 98 | self.z_scale_factor = z_scale_factor 99 | 100 | self.ckpt_path = ckpt_path 101 | if ckpt_path is not None: 102 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 103 | 104 | def instantiate_first_stage(self, config): 105 | model = instantiate_from_config(config) 106 | self.first_stage_model = model.eval() 107 | self.first_stage_model.train = disabled_train 108 | for param in self.first_stage_model.parameters(): 109 | param.requires_grad = False 110 | 111 | self.first_stage_model = self.first_stage_model.to(self.device) 112 | 113 | # def instantiate_cond_stage(self, config): 114 | # if not self.cond_stage_trainable: 115 | # if config == "__is_first_stage__": 116 | # print("Using first stage also as cond stage.") 117 | # self.cond_stage_model = self.first_stage_model 118 | # elif config == "__is_unconditional__": 119 | # print(f"Training {self.__class__.__name__} as an unconditional model.") 120 | # self.cond_stage_model = None 121 | # # self.be_unconditional = True 122 | # else: 123 | # model = instantiate_from_config(config) 124 | # self.cond_stage_model = model.eval() 125 | # self.cond_stage_model.train = disabled_train 126 | # for param in self.cond_stage_model.parameters(): 127 | # param.requires_grad = False 128 | # else: 129 | # assert config != "__is_first_stage__" 130 | # assert config != "__is_unconditional__" 131 | # model = instantiate_from_config(config) 132 | # self.cond_stage_model = model 133 | 134 | def init_from_ckpt(self, path, ignore_keys=()): 135 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 136 | 137 | keys = list(state_dict.keys()) 138 | for k in keys: 139 | for ik in ignore_keys: 140 | if k.startswith(ik): 141 | print("Deleting key {} from state_dict.".format(k)) 142 | del state_dict[k] 143 | 144 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 145 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 146 | if len(missing) > 0: 147 | print(f"Missing Keys: {missing}") 148 | print(f"Unexpected Keys: {unexpected}") 149 | 150 | @property 151 | def zero_rank(self): 152 | if self._trainer: 153 | zero_rank = self.trainer.local_rank == 0 154 | else: 155 | zero_rank = True 156 | 157 | return zero_rank 158 | 159 | def configure_optimizers(self) -> Tuple[List, List]: 160 | 161 | lr = self.learning_rate 
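# only the denoiser's parameters are optimized: the first stage is frozen in
# instantiate_first_stage(), and the condition encoders are no-grad methods that reuse it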
162 | 163 | trainable_parameters = list(self.model.parameters()) 164 | # if the conditional encoder is trainable 165 | 166 | # if self.cond_stage_trainable: 167 | # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad] 168 | # trainable_parameters += conditioner_params 169 | # print(f"number of trainable conditional parameters: {len(conditioner_params)}.") 170 | 171 | if self.optimizer_cfg is None: 172 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 173 | schedulers = [] 174 | else: 175 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 176 | scheduler_func = instantiate_from_config( 177 | self.optimizer_cfg.scheduler, 178 | max_decay_steps=self.trainer.max_steps, 179 | lr_max=lr 180 | ) 181 | scheduler = { 182 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 183 | "interval": "step", 184 | "frequency": 1 185 | } 186 | optimizers = [optimizer] 187 | schedulers = [scheduler] 188 | 189 | return optimizers, schedulers 190 | 191 | @torch.no_grad() 192 | def encode_text(self, text): 193 | 194 | b = text.shape[0] 195 | text_tokens = rearrange(text, "b t l -> (b t) l") 196 | text_embed = self.first_stage_model.model.encode_text_embed(text_tokens) 197 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) 198 | text_embed = text_embed.mean(dim=1) 199 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 200 | 201 | return text_embed 202 | 203 | @torch.no_grad() 204 | def encode_image(self, img): 205 | 206 | return self.first_stage_model.model.encode_image_embed(img) 207 | 208 | @torch.no_grad() 209 | def encode_surface(self, surface): 210 | 211 | return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False) 212 | 213 | @torch.no_grad() 214 | def empty_text_cond(self, cond): 215 | 216 | return torch.zeros_like(cond, device=cond.device) 217 | 218 | @torch.no_grad() 219 | def empty_img_cond(self, cond): 220 | 221 | return torch.zeros_like(cond, device=cond.device) 222 | 223 | @torch.no_grad() 224 | def empty_surface_cond(self, cond): 225 | 226 | return torch.zeros_like(cond, device=cond.device) 227 | 228 | @torch.no_grad() 229 | def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): 230 | 231 | z_q = self.first_stage_model.encode(surface, sample_posterior) 232 | z_q = self.z_scale_factor * z_q 233 | 234 | return z_q 235 | 236 | @torch.no_grad() 237 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): 238 | 239 | z_q = 1. / self.z_scale_factor * z_q 240 | latents = self.first_stage_model.decode(z_q, **kwargs) 241 | return latents 242 | 243 | @rank_zero_only 244 | @torch.no_grad() 245 | def on_train_batch_start(self, batch, batch_idx): 246 | # only for very first batch 247 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ 248 | and batch_idx == 0 and self.ckpt_path is None: 249 | # set rescale weight to 1./std of encodings 250 | print("### USING STD-RESCALING ###") 251 | 252 | z_q = self.encode_first_stage(batch[self.first_stage_key]) 253 | z = z_q.detach() 254 | 255 | del self.z_scale_factor 256 | self.register_buffer("z_scale_factor", 1. 
/ z.flatten().std()) 257 | print(f"setting self.z_scale_factor to {self.z_scale_factor}") 258 | 259 | print("### USING STD-RESCALING ###") 260 | 261 | def compute_loss(self, model_outputs, split): 262 | """ 263 | 264 | Args: 265 | model_outputs (dict): 266 | - x_0: 267 | - noise: 268 | - noise_prior: 269 | - noise_pred: 270 | - noise_pred_prior: 271 | 272 | split (str): 273 | 274 | Returns: 275 | 276 | """ 277 | 278 | pred = model_outputs["pred"] 279 | 280 | if self.noise_scheduler.prediction_type == "epsilon": 281 | target = model_outputs["noise"] 282 | elif self.noise_scheduler.prediction_type == "sample": 283 | target = model_outputs["x_0"] 284 | else: 285 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") 286 | 287 | if self.loss_cfg.loss_type == "l1": 288 | simple = F.l1_loss(pred, target, reduction="mean") 289 | elif self.loss_cfg.loss_type in ["mse", "l2"]: 290 | simple = F.mse_loss(pred, target, reduction="mean") 291 | else: 292 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") 293 | 294 | total_loss = simple 295 | 296 | loss_dict = { 297 | f"{split}/total_loss": total_loss.clone().detach(), 298 | f"{split}/simple": simple.detach(), 299 | } 300 | 301 | return total_loss, loss_dict 302 | 303 | def forward(self, batch): 304 | """ 305 | 306 | Args: 307 | batch: 308 | 309 | Returns: 310 | 311 | """ 312 | 313 | if self.first_stage_model is None: 314 | self.instantiate_first_stage(self.first_stage_config) 315 | 316 | latents = self.encode_first_stage(batch[self.first_stage_key]) 317 | 318 | # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) 319 | 320 | conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1) 321 | 322 | mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1 323 | conditions = conditions * mask.to(conditions) 324 | 325 | # Sample noise that we"ll add to the latents 326 | # [batch_size, n_token, latent_dim] 327 | noise = torch.randn_like(latents) 328 | bs = latents.shape[0] 329 | # Sample a random timestep for each motion 330 | timesteps = torch.randint( 331 | 0, 332 | self.noise_scheduler.config.num_train_timesteps, 333 | (bs,), 334 | device=latents.device, 335 | ) 336 | timesteps = timesteps.long() 337 | # Add noise to the latents according to the noise magnitude at each timestep 338 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) 339 | 340 | # diffusion model forward 341 | noise_pred = self.model(noisy_z, timesteps, conditions) 342 | 343 | diffusion_outputs = { 344 | "x_0": noisy_z, 345 | "noise": noise, 346 | "pred": noise_pred 347 | } 348 | 349 | return diffusion_outputs 350 | 351 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], 352 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 353 | """ 354 | 355 | Args: 356 | batch (dict): the batch sample, and it contains: 357 | - surface (torch.FloatTensor): 358 | - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] 359 | - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] 360 | - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] 361 | - text (list of str): 362 | 363 | batch_idx (int): 364 | 365 | optimizer_idx (int): 366 | 367 | Returns: 368 | loss (torch.FloatTensor): 369 | 370 | """ 371 | 372 | diffusion_outputs = self(batch) 373 | 374 | loss, loss_dict = self.compute_loss(diffusion_outputs, 
"train") 375 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 376 | 377 | return loss 378 | 379 | def validation_step(self, batch: Dict[str, torch.FloatTensor], 380 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 381 | """ 382 | 383 | Args: 384 | batch (dict): the batch sample, and it contains: 385 | - surface_pc (torch.FloatTensor): [n_pts, 4] 386 | - surface_feats (torch.FloatTensor): [n_pts, c] 387 | - text (list of str): 388 | 389 | batch_idx (int): 390 | 391 | optimizer_idx (int): 392 | 393 | Returns: 394 | loss (torch.FloatTensor): 395 | 396 | """ 397 | 398 | diffusion_outputs = self(batch) 399 | 400 | loss, loss_dict = self.compute_loss(diffusion_outputs, "val") 401 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 402 | 403 | return loss 404 | 405 | @torch.no_grad() 406 | def sample(self, 407 | batch: Dict[str, Union[torch.FloatTensor, List[str]]], 408 | sample_times: int = 1, 409 | steps: Optional[int] = None, 410 | guidance_scale: Optional[float] = None, 411 | eta: float = 0.0, 412 | return_intermediates: bool = False, **kwargs): 413 | 414 | if self.first_stage_model is None: 415 | self.instantiate_first_stage(self.first_stage_config) 416 | 417 | if steps is None: 418 | steps = self.scheduler_cfg.num_inference_steps 419 | 420 | if guidance_scale is None: 421 | guidance_scale = self.scheduler_cfg.guidance_scale 422 | do_classifier_free_guidance = guidance_scale > 0 423 | 424 | # conditional encode 425 | xc = batch[self.cond_stage_key] 426 | # cond = self.cond_stage_model[self.cond_stage_key](xc) 427 | cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1) 428 | 429 | if do_classifier_free_guidance: 430 | """ 431 | Note: There are two kinds of uncond for text. 
432 | 1: using "" as uncond text; (in SAL diffusion) 433 | 2: zeros_like(cond) as uncond text; (in MDM) 434 | """ 435 | # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) 436 | un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond) 437 | # un_cond = torch.zeros_like(cond, device=cond.device) 438 | cond = torch.cat([un_cond, cond], dim=0) 439 | 440 | outputs = [] 441 | latents = None 442 | 443 | if not return_intermediates: 444 | for _ in range(sample_times): 445 | sample_loop = ddim_sample( 446 | self.denoise_scheduler, 447 | self.model, 448 | shape=self.first_stage_model.latent_shape, 449 | cond=cond, 450 | steps=steps, 451 | guidance_scale=guidance_scale, 452 | do_classifier_free_guidance=do_classifier_free_guidance, 453 | device=self.device, 454 | eta=eta, 455 | disable_prog=not self.zero_rank 456 | ) 457 | for sample, t in sample_loop: 458 | latents = sample 459 | outputs.append(self.decode_first_stage(latents, **kwargs)) 460 | else: 461 | 462 | sample_loop = ddim_sample( 463 | self.denoise_scheduler, 464 | self.model, 465 | shape=self.first_stage_model.latent_shape, 466 | cond=cond, 467 | steps=steps, 468 | guidance_scale=guidance_scale, 469 | do_classifier_free_guidance=do_classifier_free_guidance, 470 | device=self.device, 471 | eta=eta, 472 | disable_prog=not self.zero_rank 473 | ) 474 | 475 | iter_size = steps // sample_times 476 | i = 0 477 | for sample, t in sample_loop: 478 | latents = sample 479 | if i % iter_size == 0 or i == steps - 1: 480 | outputs.append(self.decode_first_stage(latents, **kwargs)) 481 | i += 1 482 | 483 | return outputs 484 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/asl_udt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from diffusers.models.embeddings import Timesteps 7 | import math 8 | 9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import MLP 10 | from MeshAnything.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer 11 | 12 | 13 | class ConditionalASLUDTDenoiser(nn.Module): 14 | 15 | def __init__(self, *, 16 | device: Optional[torch.device], 17 | dtype: Optional[torch.dtype], 18 | input_channels: int, 19 | output_channels: int, 20 | n_ctx: int, 21 | width: int, 22 | layers: int, 23 | heads: int, 24 | context_dim: int, 25 | context_ln: bool = True, 26 | skip_ln: bool = False, 27 | init_scale: float = 0.25, 28 | flip_sin_to_cos: bool = False, 29 | use_checkpoint: bool = False): 30 | super().__init__() 31 | 32 | self.use_checkpoint = use_checkpoint 33 | 34 | init_scale = init_scale * math.sqrt(1.0 / width) 35 | 36 | self.backbone = UNetDiffusionTransformer( 37 | device=device, 38 | dtype=dtype, 39 | n_ctx=n_ctx, 40 | width=width, 41 | layers=layers, 42 | heads=heads, 43 | skip_ln=skip_ln, 44 | init_scale=init_scale, 45 | use_checkpoint=use_checkpoint 46 | ) 47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) 49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) 50 | 51 | # timestep embedding 52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0) 53 | self.time_proj = MLP( 54 | device=device, dtype=dtype, width=width, 
init_scale=init_scale 55 | ) 56 | 57 | if context_ln: 58 | self.context_embed = nn.Sequential( 59 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 60 | nn.Linear(context_dim, width, device=device, dtype=dtype), 61 | ) 62 | else: 63 | self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype) 64 | 65 | def forward(self, 66 | model_input: torch.FloatTensor, 67 | timestep: torch.LongTensor, 68 | context: torch.FloatTensor): 69 | 70 | r""" 71 | Args: 72 | model_input (torch.FloatTensor): [bs, n_data, c] 73 | timestep (torch.LongTensor): [bs,] 74 | context (torch.FloatTensor): [bs, context_tokens, c] 75 | 76 | Returns: 77 | sample (torch.FloatTensor): [bs, n_data, c] 78 | 79 | """ 80 | 81 | _, n_data, _ = model_input.shape 82 | 83 | # 1. time 84 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1) 85 | 86 | # 2. conditions projector 87 | context = self.context_embed(context) 88 | 89 | # 3. denoiser 90 | x = self.input_proj(model_input) 91 | x = torch.cat([t_emb, context, x], dim=1) 92 | x = self.backbone(x) 93 | x = self.ln_post(x) 94 | x = x[:, -n_data:] 95 | sample = self.output_proj(x) 96 | 97 | return sample 98 | 99 | 100 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BaseDenoiser(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, t, context): 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from omegaconf import DictConfig 4 | from typing import List, Tuple, Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | from diffusers.schedulers import ( 14 | DDPMScheduler, 15 | DDIMScheduler, 16 | KarrasVeScheduler, 17 | DPMSolverMultistepScheduler 18 | ) 19 | 20 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 21 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule 22 | from MeshAnything.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample 23 | 24 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] 25 | 26 | 27 | def disabled_train(self, mode=True): 28 | """Overwrite model.train with this function to make sure train/eval mode 29 | does not change anymore.""" 30 | return self 31 | 32 | 33 | class ClipASLDiffuser(pl.LightningModule): 34 | first_stage_model: Optional[AlignedShapeAsLatentPLModule] 35 | cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] 36 | model: nn.Module 37 | 38 | def __init__(self, *, 39 | first_stage_config, 40 | cond_stage_config, 41 | denoiser_cfg, 42 | scheduler_cfg, 43 | optimizer_cfg, 44 | loss_cfg, 45 | first_stage_key: str = "surface", 46 | 
cond_stage_key: str = "image", 47 | scale_by_std: bool = False, 48 | z_scale_factor: float = 1.0, 49 | ckpt_path: Optional[str] = None, 50 | ignore_keys: Union[Tuple[str], List[str]] = ()): 51 | 52 | super().__init__() 53 | 54 | self.first_stage_key = first_stage_key 55 | self.cond_stage_key = cond_stage_key 56 | 57 | # 1. lazy initialize first stage 58 | self.instantiate_first_stage(first_stage_config) 59 | 60 | # 2. initialize conditional stage 61 | self.instantiate_cond_stage(cond_stage_config) 62 | 63 | # 3. diffusion model 64 | self.model = instantiate_from_config( 65 | denoiser_cfg, device=None, dtype=None 66 | ) 67 | 68 | self.optimizer_cfg = optimizer_cfg 69 | 70 | # 4. scheduling strategy 71 | self.scheduler_cfg = scheduler_cfg 72 | 73 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) 74 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) 75 | 76 | # 5. loss configures 77 | self.loss_cfg = loss_cfg 78 | 79 | self.scale_by_std = scale_by_std 80 | if scale_by_std: 81 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) 82 | else: 83 | self.z_scale_factor = z_scale_factor 84 | 85 | self.ckpt_path = ckpt_path 86 | if ckpt_path is not None: 87 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 88 | 89 | def instantiate_non_trainable_model(self, config): 90 | model = instantiate_from_config(config) 91 | model = model.eval() 92 | model.train = disabled_train 93 | for param in model.parameters(): 94 | param.requires_grad = False 95 | 96 | return model 97 | 98 | def instantiate_first_stage(self, first_stage_config): 99 | self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config) 100 | self.first_stage_model.set_shape_model_only() 101 | 102 | def instantiate_cond_stage(self, cond_stage_config): 103 | self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config) 104 | 105 | def init_from_ckpt(self, path, ignore_keys=()): 106 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 107 | 108 | keys = list(state_dict.keys()) 109 | for k in keys: 110 | for ik in ignore_keys: 111 | if k.startswith(ik): 112 | print("Deleting key {} from state_dict.".format(k)) 113 | del state_dict[k] 114 | 115 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 116 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 117 | if len(missing) > 0: 118 | print(f"Missing Keys: {missing}") 119 | print(f"Unexpected Keys: {unexpected}") 120 | 121 | @property 122 | def zero_rank(self): 123 | if self._trainer: 124 | zero_rank = self.trainer.local_rank == 0 125 | else: 126 | zero_rank = True 127 | 128 | return zero_rank 129 | 130 | def configure_optimizers(self) -> Tuple[List, List]: 131 | 132 | lr = self.learning_rate 133 | 134 | trainable_parameters = list(self.model.parameters()) 135 | if self.optimizer_cfg is None: 136 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 137 | schedulers = [] 138 | else: 139 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 140 | scheduler_func = instantiate_from_config( 141 | self.optimizer_cfg.scheduler, 142 | max_decay_steps=self.trainer.max_steps, 143 | lr_max=lr 144 | ) 145 | scheduler = { 146 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 147 | "interval": "step", 148 | "frequency": 1 149 | } 150 | optimizers = [optimizer] 151 | schedulers = 
[scheduler] 152 | 153 | return optimizers, schedulers 154 | 155 | @torch.no_grad() 156 | def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): 157 | 158 | z_q = self.first_stage_model.encode(surface, sample_posterior) 159 | z_q = self.z_scale_factor * z_q 160 | 161 | return z_q 162 | 163 | @torch.no_grad() 164 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): 165 | 166 | z_q = 1. / self.z_scale_factor * z_q 167 | latents = self.first_stage_model.decode(z_q, **kwargs) 168 | return latents 169 | 170 | @rank_zero_only 171 | @torch.no_grad() 172 | def on_train_batch_start(self, batch, batch_idx): 173 | # only for very first batch 174 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ 175 | and batch_idx == 0 and self.ckpt_path is None: 176 | # set rescale weight to 1./std of encodings 177 | print("### USING STD-RESCALING ###") 178 | 179 | z_q = self.encode_first_stage(batch[self.first_stage_key]) 180 | z = z_q.detach() 181 | 182 | del self.z_scale_factor 183 | self.register_buffer("z_scale_factor", 1. / z.flatten().std()) 184 | print(f"setting self.z_scale_factor to {self.z_scale_factor}") 185 | 186 | print("### USING STD-RESCALING ###") 187 | 188 | def compute_loss(self, model_outputs, split): 189 | """ 190 | 191 | Args: 192 | model_outputs (dict): 193 | - x_0: 194 | - noise: 195 | - noise_prior: 196 | - noise_pred: 197 | - noise_pred_prior: 198 | 199 | split (str): 200 | 201 | Returns: 202 | 203 | """ 204 | 205 | pred = model_outputs["pred"] 206 | 207 | if self.noise_scheduler.prediction_type == "epsilon": 208 | target = model_outputs["noise"] 209 | elif self.noise_scheduler.prediction_type == "sample": 210 | target = model_outputs["x_0"] 211 | else: 212 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") 213 | 214 | if self.loss_cfg.loss_type == "l1": 215 | simple = F.l1_loss(pred, target, reduction="mean") 216 | elif self.loss_cfg.loss_type in ["mse", "l2"]: 217 | simple = F.mse_loss(pred, target, reduction="mean") 218 | else: 219 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") 220 | 221 | total_loss = simple 222 | 223 | loss_dict = { 224 | f"{split}/total_loss": total_loss.clone().detach(), 225 | f"{split}/simple": simple.detach(), 226 | } 227 | 228 | return total_loss, loss_dict 229 | 230 | def forward(self, batch): 231 | """ 232 | 233 | Args: 234 | batch: 235 | 236 | Returns: 237 | 238 | """ 239 | 240 | latents = self.encode_first_stage(batch[self.first_stage_key]) 241 | conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) 242 | 243 | # Sample noise that we"ll add to the latents 244 | # [batch_size, n_token, latent_dim] 245 | noise = torch.randn_like(latents) 246 | bs = latents.shape[0] 247 | # Sample a random timestep for each motion 248 | timesteps = torch.randint( 249 | 0, 250 | self.noise_scheduler.config.num_train_timesteps, 251 | (bs,), 252 | device=latents.device, 253 | ) 254 | timesteps = timesteps.long() 255 | # Add noise to the latents according to the noise magnitude at each timestep 256 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) 257 | 258 | # diffusion model forward 259 | noise_pred = self.model(noisy_z, timesteps, conditions) 260 | 261 | diffusion_outputs = { 262 | "x_0": noisy_z, 263 | "noise": noise, 264 | "pred": noise_pred 265 | } 266 | 267 | return diffusion_outputs 268 | 269 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], 270 | 
269 |     def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
270 |                       batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
271 |         """
272 | 
273 |         Args:
274 |             batch (dict): the batch sample, containing:
275 |                 - surface (torch.FloatTensor):
276 |                 - image (torch.FloatTensor): if provided, [bs, 3, h, w], item range [0, 1]
277 |                 - depth (torch.FloatTensor): if provided, [bs, 1, h, w], item range [-1, 1]
278 |                 - normal (torch.FloatTensor): if provided, [bs, 3, h, w], item range [-1, 1]
279 |                 - text (list of str):
280 | 
281 |             batch_idx (int):
282 | 
283 |             optimizer_idx (int):
284 | 
285 |         Returns:
286 |             loss (torch.FloatTensor):
287 | 
288 |         """
289 | 
290 |         diffusion_outputs = self(batch)
291 | 
292 |         loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
293 |         self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
294 | 
295 |         return loss
296 | 
297 |     def validation_step(self, batch: Dict[str, torch.FloatTensor],
298 |                         batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
299 |         """
300 | 
301 |         Args:
302 |             batch (dict): the batch sample, containing:
303 |                 - surface_pc (torch.FloatTensor): [n_pts, 4]
304 |                 - surface_feats (torch.FloatTensor): [n_pts, c]
305 |                 - text (list of str):
306 | 
307 |             batch_idx (int):
308 | 
309 |             optimizer_idx (int):
310 | 
311 |         Returns:
312 |             loss (torch.FloatTensor):
313 | 
314 |         """
315 | 
316 |         diffusion_outputs = self(batch)
317 | 
318 |         loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
319 |         self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
320 | 
321 |         return loss
322 | 
323 |     @torch.no_grad()
324 |     def sample(self,
325 |                batch: Dict[str, Union[torch.FloatTensor, List[str]]],
326 |                sample_times: int = 1,
327 |                steps: Optional[int] = None,
328 |                guidance_scale: Optional[float] = None,
329 |                eta: float = 0.0,
330 |                return_intermediates: bool = False, **kwargs):
331 | 
332 |         if steps is None:
333 |             steps = self.scheduler_cfg.num_inference_steps
334 | 
335 |         if guidance_scale is None:
336 |             guidance_scale = self.scheduler_cfg.guidance_scale
337 |         do_classifier_free_guidance = guidance_scale > 0
338 | 
339 |         # conditional encode
340 |         xc = batch[self.cond_stage_key]
341 | 
342 |         # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
343 | 
344 |         cond = self.cond_stage_model(xc)
345 | 
346 |         if do_classifier_free_guidance:
347 |             un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
348 |             cond = torch.cat([un_cond, cond], dim=0)
349 | 
350 |         outputs = []
351 |         latents = None
352 | 
353 |         if not return_intermediates:
354 |             for _ in range(sample_times):
355 |                 sample_loop = ddim_sample(
356 |                     self.denoise_scheduler,
357 |                     self.model,
358 |                     shape=self.first_stage_model.latent_shape,
359 |                     cond=cond,
360 |                     steps=steps,
361 |                     guidance_scale=guidance_scale,
362 |                     do_classifier_free_guidance=do_classifier_free_guidance,
363 |                     device=self.device,
364 |                     eta=eta,
365 |                     disable_prog=not self.zero_rank
366 |                 )
367 |                 for sample, t in sample_loop:
368 |                     latents = sample
369 |                 outputs.append(self.decode_first_stage(latents, **kwargs))
370 |         else:
371 | 
372 |             sample_loop = ddim_sample(
373 |                 self.denoise_scheduler,
374 |                 self.model,
375 |                 shape=self.first_stage_model.latent_shape,
376 |                 cond=cond,
377 |                 steps=steps,
378 |                 guidance_scale=guidance_scale,
379 |                 do_classifier_free_guidance=do_classifier_free_guidance,
380 |                 device=self.device,
381 |                 eta=eta,
382 |                 disable_prog=not self.zero_rank
383 |             )
384 | 
385 |             iter_size = steps // sample_times
386 |             i = 0
387 |             for sample, t in sample_loop:
388 |                 latents = 
sample 389 | if i % iter_size == 0 or i == steps - 1: 390 | outputs.append(self.decode_first_stage(latents, **kwargs)) 391 | i += 1 392 | 393 | return outputs 394 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/asl_diffusion/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from typing import Tuple, List, Union, Optional 6 | from diffusers.schedulers import DDIMScheduler 7 | 8 | 9 | __all__ = ["ddim_sample"] 10 | 11 | 12 | def ddim_sample(ddim_scheduler: DDIMScheduler, 13 | diffusion_model: torch.nn.Module, 14 | shape: Union[List[int], Tuple[int]], 15 | cond: torch.FloatTensor, 16 | steps: int, 17 | eta: float = 0.0, 18 | guidance_scale: float = 3.0, 19 | do_classifier_free_guidance: bool = True, 20 | generator: Optional[torch.Generator] = None, 21 | device: torch.device = "cuda:0", 22 | disable_prog: bool = True): 23 | 24 | assert steps > 0, f"{steps} must > 0." 25 | 26 | # init latents 27 | bsz = cond.shape[0] 28 | if do_classifier_free_guidance: 29 | bsz = bsz // 2 30 | 31 | latents = torch.randn( 32 | (bsz, *shape), 33 | generator=generator, 34 | device=cond.device, 35 | dtype=cond.dtype, 36 | ) 37 | # scale the initial noise by the standard deviation required by the scheduler 38 | latents = latents * ddim_scheduler.init_noise_sigma 39 | # set timesteps 40 | ddim_scheduler.set_timesteps(steps) 41 | timesteps = ddim_scheduler.timesteps.to(device) 42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1] 44 | extra_step_kwargs = { 45 | "eta": eta, 46 | "generator": generator 47 | } 48 | 49 | # reverse 50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)): 51 | # expand the latents if we are doing classifier free guidance 52 | latent_model_input = ( 53 | torch.cat([latents] * 2) 54 | if do_classifier_free_guidance 55 | else latents 56 | ) 57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t) 58 | # predict the noise residual 59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) 60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) 61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond) 62 | 63 | # perform guidance 64 | if do_classifier_free_guidance: 65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 66 | noise_pred = noise_pred_uncond + guidance_scale * ( 67 | noise_pred_text - noise_pred_uncond 68 | ) 69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk( 70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states 71 | # compute the previous noisy sample x_t -> x_t-1 72 | latents = ddim_scheduler.step( 73 | noise_pred, t, latents, **extra_step_kwargs 74 | ).prev_sample 75 | 76 | yield latents, t 77 | 78 | 79 | def karra_sample(): 80 | pass 81 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/conditional_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .clip import CLIPEncoder 4 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/conditional_encoders/clip.py: 
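
A hedged usage sketch for the `CLIPEncoder` defined below (the checkpoint name is the class default; the image batch and prompt strings are hypothetical, and `encode_text` assumes the `input_ids` fix noted in the listing):

import torch
from MeshAnything.miche.michelangelo.models.conditional_encoders.clip import CLIPEncoder

encoder = CLIPEncoder()                        # "openai/clip-vit-base-patch32"
images = torch.rand(2, 3, 224, 224)            # [bs, 3, h, w], values in [0, 1]
visual, text = encoder(images, ["a wooden chair", "a screwdriver"])
print(visual.embeds.shape, text.embeds.shape)  # torch.Size([2, 512]) torch.Size([2, 512])
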
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from dataclasses import dataclass
7 | from torchvision.transforms import Normalize
8 | from transformers import CLIPModel, CLIPTokenizer
9 | from transformers.utils import ModelOutput
10 | from typing import Iterable, Optional, Union, List
11 | 
12 | 
13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
14 | 
15 | 
16 | @dataclass
17 | class CLIPEmbedOutput(ModelOutput):
18 |     last_hidden_state: torch.FloatTensor = None
19 |     pooler_output: torch.FloatTensor = None
20 |     embeds: torch.FloatTensor = None
21 | 
22 | 
23 | class CLIPEncoder(torch.nn.Module):
24 | 
25 |     def __init__(self, model_path="openai/clip-vit-base-patch32"):
26 | 
27 |         super().__init__()
28 | 
29 |         # Load the CLIP model and tokenizer
30 |         self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
31 |         self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
32 |         self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33 | 
34 |         self.model.eval()  # was `self.model.training = False`, which does not propagate to submodules
35 |         for p in self.model.parameters():
36 |             p.requires_grad = False
37 | 
38 |     @torch.no_grad()
39 |     def encode_image(self, images: Iterable[Optional[ImageType]]):
40 |         pixel_values = self.image_preprocess(images)
41 | 
42 |         vision_outputs = self.model.vision_model(pixel_values=pixel_values)
43 | 
44 |         pooler_output = vision_outputs[1]  # pooled_output
45 |         image_features = self.model.visual_projection(pooler_output)
46 | 
47 |         visual_embeds = CLIPEmbedOutput(
48 |             last_hidden_state=vision_outputs.last_hidden_state,
49 |             pooler_output=pooler_output,
50 |             embeds=image_features
51 |         )
52 | 
53 |         return visual_embeds
54 | 
55 |     @torch.no_grad()
56 |     def encode_text(self, texts: List[str]):
57 |         text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
58 | 
59 |         text_outputs = self.model.text_model(input_ids=text_inputs.input_ids)  # was `input_ids=text_inputs`, which passed the whole BatchEncoding instead of the token ids
60 | 
61 |         pooler_output = text_outputs[1]  # pooled_output
62 |         text_features = self.model.text_projection(pooler_output)
63 | 
64 |         text_embeds = CLIPEmbedOutput(
65 |             last_hidden_state=text_outputs.last_hidden_state,
66 |             pooler_output=pooler_output,
67 |             embeds=text_features
68 |         )
69 | 
70 |         return text_embeds
71 | 
72 |     def forward(self,
73 |                 images: Iterable[Optional[ImageType]],
74 |                 texts: List[str]):
75 | 
76 |         visual_embeds = self.encode_image(images)
77 |         text_embeds = self.encode_text(texts)
78 | 
79 |         return visual_embeds, text_embeds
80 | 
81 | 
82 | 
83 | 
84 | 
85 | 
86 | 
87 | 
88 | 
89 | 
90 | 
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from .checkpoint import checkpoint
4 | 
--------------------------------------------------------------------------------
/MeshAnything/miche/michelangelo/models/modules/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4 | """
5 | 
6 | import torch
7 | from typing import Callable, Iterable, Sequence, Union
8 | 
9 | 
10 | def checkpoint(
11 |     func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12 |     inputs: Sequence[torch.Tensor],
13 |     params: Iterable[torch.Tensor],
14 | 
flag: bool, 15 | use_deepspeed: bool = False 16 | ): 17 | """ 18 | Evaluate a function without caching intermediate activations, allowing for 19 | reduced memory at the expense of extra compute in the backward pass. 20 | :param func: the function to evaluate. 21 | :param inputs: the argument sequence to pass to `func`. 22 | :param params: a sequence of parameters `func` depends on but does not 23 | explicitly take as arguments. 24 | :param flag: if False, disable gradient checkpointing. 25 | :param use_deepspeed: if True, use deepspeed 26 | """ 27 | if flag: 28 | if use_deepspeed: 29 | import deepspeed 30 | return deepspeed.checkpointing.checkpoint(func, *inputs) 31 | 32 | args = tuple(inputs) + tuple(params) 33 | return CheckpointFunction.apply(func, len(inputs), *args) 34 | else: 35 | return func(*inputs) 36 | 37 | 38 | class CheckpointFunction(torch.autograd.Function): 39 | @staticmethod 40 | @torch.cuda.amp.custom_fwd 41 | def forward(ctx, run_function, length, *args): 42 | ctx.run_function = run_function 43 | ctx.input_tensors = list(args[:length]) 44 | ctx.input_params = list(args[length:]) 45 | 46 | with torch.no_grad(): 47 | output_tensors = ctx.run_function(*ctx.input_tensors) 48 | return output_tensors 49 | 50 | @staticmethod 51 | @torch.cuda.amp.custom_bwd 52 | def backward(ctx, *output_grads): 53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 54 | with torch.enable_grad(): 55 | # Fixes a bug where the first op in run_function modifies the 56 | # Tensor storage in place, which is not allowed for detach()'d 57 | # Tensors. 58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 59 | output_tensors = ctx.run_function(*shallow_copies) 60 | input_grads = torch.autograd.grad( 61 | output_tensors, 62 | ctx.input_tensors + ctx.input_params, 63 | output_grads, 64 | allow_unused=True, 65 | ) 66 | del ctx.input_tensors 67 | del ctx.input_params 68 | del output_tensors 69 | return (None, None) + input_grads 70 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/diffusion_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | 8 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint 9 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import ( 10 | init_linear, 11 | MLP, 12 | MultiheadCrossAttention, 13 | MultiheadAttention, 14 | ResidualAttentionBlock 15 | ) 16 | 17 | 18 | class AdaLayerNorm(nn.Module): 19 | def __init__(self, 20 | device: torch.device, 21 | dtype: torch.dtype, 22 | width: int): 23 | 24 | super().__init__() 25 | 26 | self.silu = nn.SiLU(inplace=True) 27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype) 28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype) 29 | 30 | def forward(self, x, timestep): 31 | emb = self.linear(timestep) 32 | scale, shift = torch.chunk(emb, 2, dim=2) 33 | x = self.layernorm(x) * (1 + scale) + shift 34 | return x 35 | 36 | 37 | class DitBlock(nn.Module): 38 | def __init__( 39 | self, 40 | *, 41 | device: torch.device, 42 | dtype: torch.dtype, 43 | n_ctx: int, 44 | width: int, 45 | heads: int, 46 | context_dim: int, 47 | qkv_bias: bool = False, 48 | init_scale: float = 1.0, 49 | use_checkpoint: bool = False 50 | ): 51 | super().__init__() 52 | 53 | 
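
`AdaLayerNorm` above conditions the normalization on the timestep embedding: the embedding is projected to a per-channel (scale, shift) pair and applied after a non-affine LayerNorm (note `self.silu` is constructed but never used in `forward`). A small numerical sketch of that modulation, with hypothetical sizes:

import torch
import torch.nn as nn

width = 8
layernorm = nn.LayerNorm(width, elementwise_affine=False)
linear = nn.Linear(width, width * 2)

x = torch.randn(2, 16, width)      # [bs, n_ctx, width]
t_emb = torch.randn(2, 16, width)  # timestep embedding, one vector per token

scale, shift = torch.chunk(linear(t_emb), 2, dim=2)
out = layernorm(x) * (1 + scale) + shift   # the same modulation as AdaLayerNorm.forward
print(out.shape)                           # torch.Size([2, 16, 8])
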
self.use_checkpoint = use_checkpoint 54 | 55 | self.attn = MultiheadAttention( 56 | device=device, 57 | dtype=dtype, 58 | n_ctx=n_ctx, 59 | width=width, 60 | heads=heads, 61 | init_scale=init_scale, 62 | qkv_bias=qkv_bias 63 | ) 64 | self.ln_1 = AdaLayerNorm(device, dtype, width) 65 | 66 | if context_dim is not None: 67 | self.ln_2 = AdaLayerNorm(device, dtype, width) 68 | self.cross_attn = MultiheadCrossAttention( 69 | device=device, 70 | dtype=dtype, 71 | width=width, 72 | heads=heads, 73 | data_width=context_dim, 74 | init_scale=init_scale, 75 | qkv_bias=qkv_bias 76 | ) 77 | 78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 79 | self.ln_3 = AdaLayerNorm(device, dtype, width) 80 | 81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint) 83 | 84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 85 | x = x + self.attn(self.ln_1(x, t)) 86 | if context is not None: 87 | x = x + self.cross_attn(self.ln_2(x, t), context) 88 | x = x + self.mlp(self.ln_3(x, t)) 89 | return x 90 | 91 | 92 | class DiT(nn.Module): 93 | def __init__( 94 | self, 95 | *, 96 | device: Optional[torch.device], 97 | dtype: Optional[torch.dtype], 98 | n_ctx: int, 99 | width: int, 100 | layers: int, 101 | heads: int, 102 | context_dim: int, 103 | init_scale: float = 0.25, 104 | qkv_bias: bool = False, 105 | use_checkpoint: bool = False 106 | ): 107 | super().__init__() 108 | self.n_ctx = n_ctx 109 | self.width = width 110 | self.layers = layers 111 | 112 | self.resblocks = nn.ModuleList( 113 | [ 114 | DitBlock( 115 | device=device, 116 | dtype=dtype, 117 | n_ctx=n_ctx, 118 | width=width, 119 | heads=heads, 120 | context_dim=context_dim, 121 | qkv_bias=qkv_bias, 122 | init_scale=init_scale, 123 | use_checkpoint=use_checkpoint 124 | ) 125 | for _ in range(layers) 126 | ] 127 | ) 128 | 129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 130 | for block in self.resblocks: 131 | x = block(x, t, context) 132 | return x 133 | 134 | 135 | class UNetDiffusionTransformer(nn.Module): 136 | def __init__( 137 | self, 138 | *, 139 | device: Optional[torch.device], 140 | dtype: Optional[torch.dtype], 141 | n_ctx: int, 142 | width: int, 143 | layers: int, 144 | heads: int, 145 | init_scale: float = 0.25, 146 | qkv_bias: bool = False, 147 | skip_ln: bool = False, 148 | use_checkpoint: bool = False 149 | ): 150 | super().__init__() 151 | 152 | self.n_ctx = n_ctx 153 | self.width = width 154 | self.layers = layers 155 | 156 | self.encoder = nn.ModuleList() 157 | for _ in range(layers): 158 | resblock = ResidualAttentionBlock( 159 | device=device, 160 | dtype=dtype, 161 | n_ctx=n_ctx, 162 | width=width, 163 | heads=heads, 164 | init_scale=init_scale, 165 | qkv_bias=qkv_bias, 166 | use_checkpoint=use_checkpoint 167 | ) 168 | self.encoder.append(resblock) 169 | 170 | self.middle_block = ResidualAttentionBlock( 171 | device=device, 172 | dtype=dtype, 173 | n_ctx=n_ctx, 174 | width=width, 175 | heads=heads, 176 | init_scale=init_scale, 177 | qkv_bias=qkv_bias, 178 | use_checkpoint=use_checkpoint 179 | ) 180 | 181 | self.decoder = nn.ModuleList() 182 | for _ in range(layers): 183 | resblock = ResidualAttentionBlock( 184 | device=device, 185 | dtype=dtype, 186 | n_ctx=n_ctx, 187 | width=width, 188 | heads=heads, 189 | init_scale=init_scale, 190 | qkv_bias=qkv_bias, 191 | 
use_checkpoint=use_checkpoint 192 | ) 193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype) 194 | init_linear(linear, init_scale) 195 | 196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None 197 | 198 | self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) 199 | 200 | def forward(self, x: torch.Tensor): 201 | 202 | enc_outputs = [] 203 | for block in self.encoder: 204 | x = block(x) 205 | enc_outputs.append(x) 206 | 207 | x = self.middle_block(x) 208 | 209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder): 210 | x = torch.cat([enc_outputs.pop(), x], dim=-1) 211 | x = linear(x) 212 | 213 | if layer_norm is not None: 214 | x = layer_norm(x) 215 | 216 | x = resblock(x) 217 | 218 | return x 219 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union, List 4 | 5 | 6 | class AbstractDistribution(object): 7 | def sample(self): 8 | raise NotImplementedError() 9 | 10 | def mode(self): 11 | raise NotImplementedError() 12 | 13 | 14 | class DiracDistribution(AbstractDistribution): 15 | def __init__(self, value): 16 | self.value = value 17 | 18 | def sample(self): 19 | return self.value 20 | 21 | def mode(self): 22 | return self.value 23 | 24 | 25 | class DiagonalGaussianDistribution(object): 26 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): 27 | self.feat_dim = feat_dim 28 | self.parameters = parameters 29 | 30 | if isinstance(parameters, list): 31 | self.mean = parameters[0] 32 | self.logvar = parameters[1] 33 | else: 34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 35 | 36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 37 | self.deterministic = deterministic 38 | self.std = torch.exp(0.5 * self.logvar) 39 | self.var = torch.exp(self.logvar) 40 | if self.deterministic: 41 | self.var = self.std = torch.zeros_like(self.mean) 42 | 43 | def sample(self): 44 | x = self.mean + self.std * torch.randn_like(self.mean) 45 | return x 46 | 47 | def kl(self, other=None, dims=(1, 2, 3)): 48 | if self.deterministic: 49 | return torch.Tensor([0.]) 50 | else: 51 | if other is None: 52 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 53 | + self.var - 1.0 - self.logvar, 54 | dim=dims) 55 | else: 56 | return 0.5 * torch.mean( 57 | torch.pow(self.mean - other.mean, 2) / other.var 58 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 59 | dim=dims) 60 | 61 | def nll(self, sample, dims=(1, 2, 3)): 62 | if self.deterministic: 63 | return torch.Tensor([0.]) 64 | logtwopi = np.log(2.0 * np.pi) 65 | return 0.5 * torch.sum( 66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 67 | dim=dims) 68 | 69 | def mode(self): 70 | return self.mean 71 | 72 | 73 | def normal_kl(mean1, logvar1, mean2, logvar2): 74 | """ 75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 76 | Compute the KL divergence between two gaussians. 77 | Shapes are automatically broadcasted, so batches can be compared to 78 | scalars, among other use cases. 
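
In closed form, the value returned below is KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) = 0.5 * (-1 + logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2)). A quick sanity check with hypothetical values:

import torch
from MeshAnything.miche.michelangelo.models.modules.distributions import normal_kl

# identical Gaussians: KL = 0
print(float(normal_kl(torch.tensor(0.3), torch.tensor(-1.2),
                      torch.tensor(0.3), torch.tensor(-1.2))))  # 0.0

# unit-variance Gaussians with means one apart: KL = 0.5 * (mean1 - mean2)^2 = 0.5
print(float(normal_kl(torch.tensor(0.0), torch.tensor(0.0),
                      torch.tensor(1.0), torch.tensor(0.0))))   # 0.5
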
79 | """ 80 | tensor = None 81 | for obj in (mean1, logvar1, mean2, logvar2): 82 | if isinstance(obj, torch.Tensor): 83 | tensor = obj 84 | break 85 | assert tensor is not None, "at least one argument must be a Tensor" 86 | 87 | # Force variances to be Tensors. Broadcasting helps convert scalars to 88 | # Tensors, but it does not work for torch.exp(). 89 | logvar1, logvar2 = [ 90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 91 | for x in (logvar1, logvar2) 92 | ] 93 | 94 | return 0.5 * ( 95 | -1.0 96 | + logvar2 97 | - logvar1 98 | + torch.exp(logvar1 - logvar2) 99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 100 | ) 101 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/embedder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] 9 | 10 | 11 | class FourierEmbedder(nn.Module): 12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts 13 | each feature dimension of `x[..., i]` into: 14 | [ 15 | sin(x[..., i]), 16 | sin(f_1*x[..., i]), 17 | sin(f_2*x[..., i]), 18 | ... 19 | sin(f_N * x[..., i]), 20 | cos(x[..., i]), 21 | cos(f_1*x[..., i]), 22 | cos(f_2*x[..., i]), 23 | ... 24 | cos(f_N * x[..., i]), 25 | x[..., i] # only present if include_input is True. 26 | ], here f_i is the frequency. 27 | 28 | Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. 29 | If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; 30 | Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. 31 | 32 | Args: 33 | num_freqs (int): the number of frequencies, default is 6; 34 | logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 35 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; 36 | input_dim (int): the input dimension, default is 3; 37 | include_input (bool): include the input tensor or not, default is True. 38 | 39 | Attributes: 40 | frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 41 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); 42 | 43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), 44 | otherwise, it is input_dim * num_freqs * 2. 
45 | 46 | """ 47 | 48 | def __init__(self, 49 | num_freqs: int = 6, 50 | logspace: bool = True, 51 | input_dim: int = 3, 52 | include_input: bool = True, 53 | include_pi: bool = True) -> None: 54 | 55 | """The initialization""" 56 | 57 | super().__init__() 58 | 59 | if logspace: 60 | frequencies = 2.0 ** torch.arange( 61 | num_freqs, 62 | dtype=torch.float32 63 | ) 64 | else: 65 | frequencies = torch.linspace( 66 | 1.0, 67 | 2.0 ** (num_freqs - 1), 68 | num_freqs, 69 | dtype=torch.float32 70 | ) 71 | 72 | if include_pi: 73 | frequencies *= torch.pi 74 | 75 | self.register_buffer("frequencies", frequencies, persistent=False) 76 | self.include_input = include_input 77 | self.num_freqs = num_freqs 78 | 79 | self.out_dim = self.get_dims(input_dim) 80 | 81 | def get_dims(self, input_dim): 82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0 83 | out_dim = input_dim * (self.num_freqs * 2 + temp) 84 | 85 | return out_dim 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | """ Forward process. 89 | 90 | Args: 91 | x: tensor of shape [..., dim] 92 | 93 | Returns: 94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] 95 | where temp is 1 if include_input is True and 0 otherwise. 96 | """ 97 | 98 | if self.num_freqs > 0: 99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) 100 | if self.include_input: 101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 102 | else: 103 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 104 | else: 105 | return x 106 | 107 | 108 | class LearnedFourierEmbedder(nn.Module): 109 | """ following @crowsonkb "s lead with learned sinusoidal pos emb """ 110 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 111 | 112 | def __init__(self, in_channels, dim): 113 | super().__init__() 114 | assert (dim % 2) == 0 115 | half_dim = dim // 2 116 | per_channel_dim = half_dim // in_channels 117 | self.weights = nn.Parameter(torch.randn(per_channel_dim)) 118 | 119 | def forward(self, x): 120 | """ 121 | 122 | Args: 123 | x (torch.FloatTensor): [..., c] 124 | 125 | Returns: 126 | x (torch.FloatTensor): [..., d] 127 | """ 128 | 129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] 130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) 131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) 132 | return fouriered 133 | 134 | 135 | class TriplaneLearnedFourierEmbedder(nn.Module): 136 | def __init__(self, in_channels, dim): 137 | super().__init__() 138 | 139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 142 | 143 | self.out_dim = in_channels + dim 144 | 145 | def forward(self, x): 146 | 147 | yz_embed = self.yz_plane_embedder(x) 148 | xz_embed = self.xz_plane_embedder(x) 149 | xy_embed = self.xy_plane_embedder(x) 150 | 151 | embed = yz_embed + xz_embed + xy_embed 152 | 153 | return embed 154 | 155 | 156 | def sequential_pos_embed(num_len, embed_dim): 157 | assert embed_dim % 2 == 0 158 | 159 | pos = torch.arange(num_len, dtype=torch.float32) 160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32) 161 | omega /= embed_dim / 2. 162 | omega = 1. 
/ 10000 ** omega # (D/2,) 163 | 164 | pos = pos.reshape(-1) # (M,) 165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 166 | 167 | emb_sin = torch.sin(out) # (M, D/2) 168 | emb_cos = torch.cos(out) # (M, D/2) 169 | 170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 171 | 172 | return embeddings 173 | 174 | 175 | def timestep_embedding(timesteps, dim, max_period=10000): 176 | """ 177 | Create sinusoidal timestep embeddings. 178 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 179 | These may be fractional. 180 | :param dim: the dimension of the output. 181 | :param max_period: controls the minimum frequency of the embeddings. 182 | :return: an [N x dim] Tensor of positional embeddings. 183 | """ 184 | half = dim // 2 185 | freqs = torch.exp( 186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 187 | ).to(device=timesteps.device) 188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None] 189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 190 | if dim % 2: 191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 192 | return embedding 193 | 194 | 195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, 196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, 197 | log2_hashmap_size=19, desired_resolution=None): 198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): 199 | return nn.Identity(), input_dim 200 | 201 | elif embed_type == "fourier": 202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, 203 | logspace=True, include_input=True) 204 | return embedder_obj, embedder_obj.out_dim 205 | 206 | elif embed_type == "hashgrid": 207 | raise NotImplementedError 208 | 209 | elif embed_type == "sphere_harmonic": 210 | raise NotImplementedError 211 | 212 | else: 213 | raise ValueError(f"{embed_type} is not valid. 
Currently only supprts {VALID_EMBED_TYPES}") 214 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/transformer_blocks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Optional 8 | 9 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint 10 | 11 | 12 | def init_linear(l, stddev): 13 | nn.init.normal_(l.weight, std=stddev) 14 | if l.bias is not None: 15 | nn.init.constant_(l.bias, 0.0) 16 | 17 | 18 | class MultiheadAttention(nn.Module): 19 | def __init__( 20 | self, 21 | *, 22 | device: torch.device, 23 | dtype: torch.dtype, 24 | n_ctx: int, 25 | width: int, 26 | heads: int, 27 | init_scale: float, 28 | qkv_bias: bool, 29 | flash: bool = False 30 | ): 31 | super().__init__() 32 | self.n_ctx = n_ctx 33 | self.width = width 34 | self.heads = heads 35 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) 36 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 37 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) 38 | init_linear(self.c_qkv, init_scale) 39 | init_linear(self.c_proj, init_scale) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = checkpoint(self.attention, (x,), (), True) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | 48 | class QKVMultiheadAttention(nn.Module): 49 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): 50 | super().__init__() 51 | self.device = device 52 | self.dtype = dtype 53 | self.heads = heads 54 | self.n_ctx = n_ctx 55 | self.flash = flash 56 | 57 | def forward(self, qkv): 58 | bs, n_ctx, width = qkv.shape 59 | attn_ch = width // self.heads // 3 60 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 61 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 62 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 63 | 64 | if self.flash: 65 | out = F.scaled_dot_product_attention(q, k, v) 66 | else: 67 | weight = torch.einsum( 68 | "bthc,bshc->bhts", q * scale, k * scale 69 | ) # More stable with f16 than dividing afterwards 70 | wdtype = weight.dtype 71 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 72 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 73 | 74 | return out 75 | 76 | 77 | class ResidualAttentionBlock(nn.Module): 78 | def __init__( 79 | self, 80 | *, 81 | device: torch.device, 82 | dtype: torch.dtype, 83 | n_ctx: int, 84 | width: int, 85 | heads: int, 86 | init_scale: float = 1.0, 87 | qkv_bias: bool = True, 88 | flash: bool = False, 89 | use_checkpoint: bool = False 90 | ): 91 | super().__init__() 92 | 93 | self.use_checkpoint = use_checkpoint 94 | 95 | self.attn = MultiheadAttention( 96 | device=device, 97 | dtype=dtype, 98 | n_ctx=n_ctx, 99 | width=width, 100 | heads=heads, 101 | init_scale=init_scale, 102 | qkv_bias=qkv_bias, 103 | flash=flash 104 | ) 105 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 106 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 107 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 108 | 109 | def _forward(self, x: torch.Tensor): 110 | x = x + self.attn(self.ln_1(x)) 111 | x = x + self.mlp(self.ln_2(x)) 112 | return x 113 | 114 | def forward(self, x: torch.Tensor): 115 | return 
checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 116 | 117 | 118 | class MultiheadCrossAttention(nn.Module): 119 | def __init__( 120 | self, 121 | *, 122 | device: torch.device, 123 | dtype: torch.dtype, 124 | width: int, 125 | heads: int, 126 | init_scale: float, 127 | qkv_bias: bool = True, 128 | flash: bool = False, 129 | n_data: Optional[int] = None, 130 | data_width: Optional[int] = None, 131 | ): 132 | super().__init__() 133 | self.n_data = n_data 134 | self.width = width 135 | self.heads = heads 136 | self.data_width = width if data_width is None else data_width 137 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) 138 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) 139 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 140 | self.attention = QKVMultiheadCrossAttention( 141 | device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash 142 | ) 143 | init_linear(self.c_q, init_scale) 144 | init_linear(self.c_kv, init_scale) 145 | init_linear(self.c_proj, init_scale) 146 | 147 | def forward(self, x, data): 148 | x = self.c_q(x) 149 | data = self.c_kv(data) 150 | x = checkpoint(self.attention, (x, data), (), True) 151 | x = self.c_proj(x) 152 | return x 153 | 154 | 155 | class QKVMultiheadCrossAttention(nn.Module): 156 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, 157 | flash: bool = False, n_data: Optional[int] = None): 158 | 159 | super().__init__() 160 | self.device = device 161 | self.dtype = dtype 162 | self.heads = heads 163 | self.n_data = n_data 164 | self.flash = flash 165 | 166 | def forward(self, q, kv): 167 | _, n_ctx, _ = q.shape 168 | bs, n_data, width = kv.shape 169 | attn_ch = width // self.heads // 2 170 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 171 | q = q.view(bs, n_ctx, self.heads, -1) 172 | kv = kv.view(bs, n_data, self.heads, -1) 173 | k, v = torch.split(kv, attn_ch, dim=-1) 174 | 175 | if self.flash: 176 | out = F.scaled_dot_product_attention(q, k, v) 177 | else: 178 | weight = torch.einsum( 179 | "bthc,bshc->bhts", q * scale, k * scale 180 | ) # More stable with f16 than dividing afterwards 181 | wdtype = weight.dtype 182 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 183 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 184 | 185 | return out 186 | 187 | 188 | class ResidualCrossAttentionBlock(nn.Module): 189 | def __init__( 190 | self, 191 | *, 192 | device: Optional[torch.device], 193 | dtype: Optional[torch.dtype], 194 | n_data: Optional[int] = None, 195 | width: int, 196 | heads: int, 197 | data_width: Optional[int] = None, 198 | init_scale: float = 0.25, 199 | qkv_bias: bool = True, 200 | flash: bool = False 201 | ): 202 | super().__init__() 203 | 204 | if data_width is None: 205 | data_width = width 206 | 207 | self.attn = MultiheadCrossAttention( 208 | device=device, 209 | dtype=dtype, 210 | n_data=n_data, 211 | width=width, 212 | heads=heads, 213 | data_width=data_width, 214 | init_scale=init_scale, 215 | qkv_bias=qkv_bias, 216 | flash=flash, 217 | ) 218 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 219 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 220 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 221 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) 222 | 223 | def forward(self, x: torch.Tensor, data: torch.Tensor): 224 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 225 | 
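
`MultiheadCrossAttention` above lets queries and context live in different widths: x is [bs, n_ctx, width], data is [bs, n_data, data_width], and `c_kv` projects the context into the query width. A minimal shape check with hypothetical sizes (CPU, fp32):

import torch
from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import MultiheadCrossAttention

attn = MultiheadCrossAttention(device=torch.device("cpu"), dtype=torch.float32,
                               width=64, heads=8, data_width=32, init_scale=0.25)
x = torch.randn(2, 10, 64)     # queries, [bs, n_ctx, width]
data = torch.randn(2, 77, 32)  # context, [bs, n_data, data_width] (e.g. text tokens)
print(attn(x, data).shape)     # torch.Size([2, 10, 64])
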
x = x + self.mlp(self.ln_3(x)) 226 | return x 227 | 228 | 229 | class MLP(nn.Module): 230 | def __init__(self, *, 231 | device: Optional[torch.device], 232 | dtype: Optional[torch.dtype], 233 | width: int, 234 | init_scale: float): 235 | super().__init__() 236 | self.width = width 237 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 238 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 239 | self.gelu = nn.GELU() 240 | init_linear(self.c_fc, init_scale) 241 | init_linear(self.c_proj, init_scale) 242 | 243 | def forward(self, x): 244 | return self.c_proj(self.gelu(self.c_fc(x))) 245 | 246 | 247 | class Transformer(nn.Module): 248 | def __init__( 249 | self, 250 | *, 251 | device: Optional[torch.device], 252 | dtype: Optional[torch.dtype], 253 | n_ctx: int, 254 | width: int, 255 | layers: int, 256 | heads: int, 257 | init_scale: float = 0.25, 258 | qkv_bias: bool = True, 259 | flash: bool = False, 260 | use_checkpoint: bool = False 261 | ): 262 | super().__init__() 263 | self.n_ctx = n_ctx 264 | self.width = width 265 | self.layers = layers 266 | self.resblocks = nn.ModuleList( 267 | [ 268 | ResidualAttentionBlock( 269 | device=device, 270 | dtype=dtype, 271 | n_ctx=n_ctx, 272 | width=width, 273 | heads=heads, 274 | init_scale=init_scale, 275 | qkv_bias=qkv_bias, 276 | flash=flash, 277 | use_checkpoint=use_checkpoint 278 | ) 279 | for _ in range(layers) 280 | ] 281 | ) 282 | 283 | def forward(self, x: torch.Tensor): 284 | for block in self.resblocks: 285 | x = block(x) 286 | return x 287 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/modules/transformer_vit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | import warnings 8 | 9 | from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint 10 | 11 | 12 | def _trunc_normal_(tensor, mean, std, a, b): 13 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 14 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 15 | def norm_cdf(x): 16 | # Computes standard normal cumulative distribution function 17 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 18 | 19 | if (mean < a - 2 * std) or (mean > b + 2 * std): 20 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 21 | "The distribution of values may be incorrect.", 22 | stacklevel=2) 23 | 24 | # Values are generated by using a truncated uniform distribution and 25 | # then using the inverse CDF for the normal distribution. 26 | # Get upper and lower cdf values 27 | l = norm_cdf((a - mean) / std) 28 | u = norm_cdf((b - mean) / std) 29 | 30 | # Uniformly fill tensor with values from [l, u], then translate to 31 | # [2l-1, 2u-1]. 
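
The construction continues below (inverse-CDF transform, then rescale and clamp). As a usage note for the `trunc_normal_` wrapper defined after this helper, samples end up hard-clamped to [a, b]:

import torch
from MeshAnything.miche.michelangelo.models.modules.transformer_vit import trunc_normal_

w = torch.empty(10000)
trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)
print(bool(w.min() >= -2.0), bool(w.max() <= 2.0))  # True True
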
32 | tensor.uniform_(2 * l - 1, 2 * u - 1) 33 | 34 | # Use inverse cdf transform for normal distribution to get truncated 35 | # standard normal 36 | tensor.erfinv_() 37 | 38 | # Transform to proper mean, std 39 | tensor.mul_(std * math.sqrt(2.)) 40 | tensor.add_(mean) 41 | 42 | # Clamp to ensure it's in the proper range 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | 47 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 48 | # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor 49 | r"""Fills the input Tensor with values drawn from a truncated 50 | normal distribution. The values are effectively drawn from the 51 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 52 | with values outside :math:`[a, b]` redrawn until they are within 53 | the bounds. The method used for generating the random values works 54 | best when :math:`a \leq \text{mean} \leq b`. 55 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 56 | applied while sampling the normal with mean/std applied, therefore a, b args 57 | should be adjusted to match the range of mean, std args. 58 | Args: 59 | tensor: an n-dimensional `torch.Tensor` 60 | mean: the mean of the normal distribution 61 | std: the standard deviation of the normal distribution 62 | a: the minimum cutoff value 63 | b: the maximum cutoff value 64 | Examples: 65 | >>> w = torch.empty(3, 5) 66 | >>> nn.init.trunc_normal_(w) 67 | """ 68 | with torch.no_grad(): 69 | return _trunc_normal_(tensor, mean, std, a, b) 70 | 71 | 72 | def init_weights(m): 73 | if isinstance(m, nn.Linear): 74 | trunc_normal_(m.weight, std=.02) 75 | if isinstance(m, nn.Linear) and m.bias is not None: 76 | nn.init.constant_(m.bias, 0) 77 | elif isinstance(m, nn.LayerNorm): 78 | nn.init.constant_(m.bias, 0) 79 | nn.init.constant_(m.weight, 1.0) 80 | 81 | 82 | class MultiheadAttention(nn.Module): 83 | def __init__( 84 | self, 85 | *, 86 | device: torch.device, 87 | dtype: torch.dtype, 88 | n_ctx: int, 89 | width: int, 90 | heads: int, 91 | qkv_bias: bool 92 | ): 93 | super().__init__() 94 | self.n_ctx = n_ctx 95 | self.width = width 96 | self.heads = heads 97 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) 98 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 99 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) 100 | 101 | def forward(self, x): 102 | x = self.c_qkv(x) 103 | x = checkpoint(self.attention, (x,), (), True) 104 | x = self.c_proj(x) 105 | return x 106 | 107 | 108 | class QKVMultiheadAttention(nn.Module): 109 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): 110 | super().__init__() 111 | self.device = device 112 | self.dtype = dtype 113 | self.heads = heads 114 | self.n_ctx = n_ctx 115 | 116 | def forward(self, qkv): 117 | bs, n_ctx, width = qkv.shape 118 | attn_ch = width // self.heads // 3 119 | scale = 1 / math.sqrt(attn_ch) 120 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 121 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 122 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale 123 | wdtype = weight.dtype 124 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 125 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 126 | 127 | 128 | class ResidualAttentionBlock(nn.Module): 129 | def __init__( 130 | self, 131 | *, 132 | device: torch.device, 133 | dtype: torch.dtype, 134 | n_ctx: int, 135 | width: int, 136 | heads: int, 137 
| qkv_bias: bool = True, 138 | use_checkpoint: bool = False 139 | ): 140 | super().__init__() 141 | 142 | self.use_checkpoint = use_checkpoint 143 | 144 | self.attn = MultiheadAttention( 145 | device=device, 146 | dtype=dtype, 147 | n_ctx=n_ctx, 148 | width=width, 149 | heads=heads, 150 | qkv_bias=qkv_bias 151 | ) 152 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 153 | self.mlp = MLP(device=device, dtype=dtype, width=width) 154 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 155 | 156 | def _forward(self, x: torch.Tensor): 157 | x = x + self.attn(self.ln_1(x)) 158 | x = x + self.mlp(self.ln_2(x)) 159 | return x 160 | 161 | def forward(self, x: torch.Tensor): 162 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 163 | 164 | 165 | class MultiheadCrossAttention(nn.Module): 166 | def __init__( 167 | self, 168 | *, 169 | device: torch.device, 170 | dtype: torch.dtype, 171 | width: int, 172 | heads: int, 173 | qkv_bias: bool = True, 174 | n_data: Optional[int] = None, 175 | data_width: Optional[int] = None, 176 | ): 177 | super().__init__() 178 | self.n_data = n_data 179 | self.width = width 180 | self.heads = heads 181 | self.data_width = width if data_width is None else data_width 182 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) 183 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) 184 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 185 | self.attention = QKVMultiheadCrossAttention( 186 | device=device, dtype=dtype, heads=heads, n_data=n_data 187 | ) 188 | 189 | def forward(self, x, data): 190 | x = self.c_q(x) 191 | data = self.c_kv(data) 192 | x = checkpoint(self.attention, (x, data), (), True) 193 | x = self.c_proj(x) 194 | return x 195 | 196 | 197 | class QKVMultiheadCrossAttention(nn.Module): 198 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None): 199 | super().__init__() 200 | self.device = device 201 | self.dtype = dtype 202 | self.heads = heads 203 | self.n_data = n_data 204 | 205 | def forward(self, q, kv): 206 | _, n_ctx, _ = q.shape 207 | bs, n_data, width = kv.shape 208 | attn_ch = width // self.heads // 2 209 | scale = 1 / math.sqrt(attn_ch) 210 | q = q.view(bs, n_ctx, self.heads, -1) 211 | kv = kv.view(bs, n_data, self.heads, -1) 212 | k, v = torch.split(kv, attn_ch, dim=-1) 213 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale 214 | wdtype = weight.dtype 215 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 216 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 217 | 218 | 219 | class ResidualCrossAttentionBlock(nn.Module): 220 | def __init__( 221 | self, 222 | *, 223 | device: Optional[torch.device], 224 | dtype: Optional[torch.dtype], 225 | n_data: Optional[int] = None, 226 | width: int, 227 | heads: int, 228 | data_width: Optional[int] = None, 229 | qkv_bias: bool = True 230 | ): 231 | super().__init__() 232 | 233 | if data_width is None: 234 | data_width = width 235 | 236 | self.attn = MultiheadCrossAttention( 237 | device=device, 238 | dtype=dtype, 239 | n_data=n_data, 240 | width=width, 241 | heads=heads, 242 | data_width=data_width, 243 | qkv_bias=qkv_bias 244 | ) 245 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 246 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 247 | self.mlp = MLP(device=device, dtype=dtype, width=width) 248 | self.ln_3 = nn.LayerNorm(width, device=device, 
dtype=dtype) 249 | 250 | def forward(self, x: torch.Tensor, data: torch.Tensor): 251 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 252 | x = x + self.mlp(self.ln_3(x)) 253 | return x 254 | 255 | 256 | class MLP(nn.Module): 257 | def __init__(self, *, 258 | device: Optional[torch.device], 259 | dtype: Optional[torch.dtype], 260 | width: int): 261 | super().__init__() 262 | self.width = width 263 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 264 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 265 | self.gelu = nn.GELU() 266 | 267 | def forward(self, x): 268 | return self.c_proj(self.gelu(self.c_fc(x))) 269 | 270 | 271 | class Transformer(nn.Module): 272 | def __init__( 273 | self, 274 | *, 275 | device: Optional[torch.device], 276 | dtype: Optional[torch.dtype], 277 | n_ctx: int, 278 | width: int, 279 | layers: int, 280 | heads: int, 281 | qkv_bias: bool = True, 282 | use_checkpoint: bool = False 283 | ): 284 | super().__init__() 285 | self.n_ctx = n_ctx 286 | self.width = width 287 | self.layers = layers 288 | self.resblocks = nn.ModuleList( 289 | [ 290 | ResidualAttentionBlock( 291 | device=device, 292 | dtype=dtype, 293 | n_ctx=n_ctx, 294 | width=width, 295 | heads=heads, 296 | qkv_bias=qkv_bias, 297 | use_checkpoint=use_checkpoint 298 | ) 299 | for _ in range(layers) 300 | ] 301 | ) 302 | 303 | self.apply(init_weights) 304 | 305 | def forward(self, x: torch.Tensor): 306 | for block in self.resblocks: 307 | x = block(x) 308 | return x 309 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/asl_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import List, Tuple, Dict, Optional 4 | from omegaconf import DictConfig 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch.optim import lr_scheduler 10 | from typing import Union 11 | from functools import partial 12 | 13 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 14 | 15 | from .tsal_base import ( 16 | AlignedShapeAsLatentModule, 17 | ShapeAsLatentModule, 18 | Latent2MeshOutput, 19 | AlignedMeshOutput 20 | ) 21 | from MeshAnything.miche.michelangelo.models.tsal.inference_utils import extract_geometry 22 | import trimesh 23 | 24 | class AlignedShapeAsLatentPLModule(nn.Module): 25 | def __init__(self, *, 26 | shape_module_cfg, 27 | aligned_module_cfg, 28 | loss_cfg, 29 | optimizer_cfg: Optional[DictConfig] = None, 30 | ckpt_path: Optional[str] = None, 31 | ignore_keys: Union[Tuple[str], List[str]] = ()): 32 | 33 | super().__init__() 34 | 35 | shape_model: ShapeAsLatentModule = instantiate_from_config( 36 | shape_module_cfg, device=None, dtype=None 37 | ) 38 | self.model: AlignedShapeAsLatentModule = instantiate_from_config( 39 | aligned_module_cfg, shape_model=shape_model 40 | ) 41 | 42 | self.loss = instantiate_from_config(loss_cfg) 43 | 44 | self.optimizer_cfg = optimizer_cfg 45 | 46 | if ckpt_path is not None: 47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 48 | 49 | def set_shape_model_only(self): 50 | self.model.set_shape_model_only() 51 | 52 | 53 | 54 | @property 55 | def 
latent_shape(self): 56 | return self.model.shape_model.latent_shape 57 | 58 | @property 59 | def zero_rank(self): 60 | if self._trainer: 61 | zero_rank = self.trainer.local_rank == 0 62 | else: 63 | zero_rank = True 64 | 65 | return zero_rank 66 | 67 | def init_from_ckpt(self, path, ignore_keys=()): 68 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 69 | 70 | keys = list(state_dict.keys()) 71 | for k in keys: 72 | for ik in ignore_keys: 73 | if k.startswith(ik): 74 | print("Deleting key {} from state_dict.".format(k)) 75 | del state_dict[k] 76 | 77 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 78 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 79 | if len(missing) > 0: 80 | print(f"Missing Keys: {missing}") 81 | print(f"Unexpected Keys: {unexpected}") 82 | 83 | def configure_optimizers(self) -> Tuple[List, List]: 84 | lr = self.learning_rate 85 | 86 | trainable_parameters = list(self.model.parameters()) 87 | 88 | if self.optimizer_cfg is None: 89 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 90 | schedulers = [] 91 | else: 92 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 93 | scheduler_func = instantiate_from_config( 94 | self.optimizer_cfg.scheduler, 95 | max_decay_steps=self.trainer.max_steps, 96 | lr_max=lr 97 | ) 98 | scheduler = { 99 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 100 | "interval": "step", 101 | "frequency": 1 102 | } 103 | optimizers = [optimizer] 104 | schedulers = [scheduler] 105 | 106 | return optimizers, schedulers 107 | 108 | def forward(self, 109 | surface: torch.FloatTensor, 110 | image: torch.FloatTensor, 111 | text: torch.FloatTensor, 112 | volume_queries: torch.FloatTensor): 113 | 114 | """ 115 | 116 | Args: 117 | surface (torch.FloatTensor): 118 | image (torch.FloatTensor): 119 | text (torch.FloatTensor): 120 | volume_queries (torch.FloatTensor): 121 | 122 | Returns: 123 | 124 | """ 125 | 126 | embed_outputs, shape_z = self.model(surface, image, text) 127 | 128 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) 129 | latents = self.model.shape_model.decode(shape_zq) 130 | logits = self.model.shape_model.query_geometry(volume_queries, latents) 131 | 132 | return embed_outputs, logits, posterior 133 | 134 | def encode(self, surface: torch.FloatTensor, sample_posterior=True): 135 | 136 | pc = surface[..., 0:3] 137 | feats = surface[..., 3:6] 138 | 139 | shape_embed, shape_zq, posterior = self.model.shape_model.encode( 140 | pc=pc, feats=feats, sample_posterior=sample_posterior 141 | ) 142 | 143 | return shape_zq 144 | 145 | def encode_latents(self, surface: torch.FloatTensor): 146 | 147 | pc = surface[..., 0:3] 148 | feats = surface[..., 3:6] 149 | 150 | shape_embed, shape_latents = self.model.shape_model.encode_latents( 151 | pc=pc, feats=feats 152 | ) 153 | shape_embed = shape_embed.unsqueeze(1) 154 | assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256 155 | cat_latents = torch.cat([shape_embed, shape_latents], dim=1) 156 | 157 | return cat_latents 158 | 159 | def recon(self, surface): 160 | cat_latents = self.encode_latents(surface) 161 | shape_latents = cat_latents[:, 1:] 162 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_latents) 163 | 164 | # decoding 165 | latents = self.model.shape_model.decode(shape_zq) 166 | geometric_func = 
partial(self.model.shape_model.query_geometry, latents=latents) 167 | 168 | # reconstruction 169 | mesh_v_f, has_surface = extract_geometry( 170 | geometric_func=geometric_func, 171 | device=surface.device, 172 | batch_size=surface.shape[0], 173 | bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 174 | octree_depth=7, 175 | num_chunks=10000, 176 | ) 177 | recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1]) 178 | 179 | return recon_mesh 180 | 181 | 182 | def to_shape_latents(self, latents): 183 | 184 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior = False) 185 | return self.model.shape_model.decode(shape_zq) 186 | 187 | def decode(self, 188 | z_q, 189 | bounds: Union[Tuple[float], List[float], float] = 1.1, 190 | octree_depth: int = 7, 191 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 192 | 193 | latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim] 194 | outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) 195 | 196 | return outputs 197 | 198 | def training_step(self, batch: Dict[str, torch.FloatTensor], 199 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 200 | """ 201 | 202 | Args: 203 | batch (dict): the batch sample, and it contains: 204 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] 205 | - image (torch.FloatTensor): [bs, 3, 224, 224] 206 | - text (torch.FloatTensor): [bs, num_templates, 77] 207 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] 208 | 209 | batch_idx (int): 210 | 211 | optimizer_idx (int): 212 | 213 | Returns: 214 | loss (torch.FloatTensor): 215 | 216 | """ 217 | 218 | surface = batch["surface"] 219 | image = batch["image"] 220 | text = batch["text"] 221 | 222 | volume_queries = batch["geo_points"][..., 0:3] 223 | shape_labels = batch["geo_points"][..., -1] 224 | 225 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) 226 | 227 | aeloss, log_dict_ae = self.loss( 228 | **embed_outputs, 229 | posteriors=posteriors, 230 | shape_logits=shape_logits, 231 | shape_labels=shape_labels, 232 | split="train" 233 | ) 234 | 235 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], 236 | sync_dist=False, rank_zero_only=True) 237 | 238 | return aeloss 239 | 240 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: 241 | 242 | surface = batch["surface"] 243 | image = batch["image"] 244 | text = batch["text"] 245 | 246 | volume_queries = batch["geo_points"][..., 0:3] 247 | shape_labels = batch["geo_points"][..., -1] 248 | 249 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) 250 | 251 | aeloss, log_dict_ae = self.loss( 252 | **embed_outputs, 253 | posteriors=posteriors, 254 | shape_logits=shape_logits, 255 | shape_labels=shape_labels, 256 | split="val" 257 | ) 258 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], 259 | sync_dist=False, rank_zero_only=True) 260 | 261 | return aeloss 262 | 263 | def visual_alignment(self, 264 | surface: torch.FloatTensor, 265 | image: torch.FloatTensor, 266 | text: torch.FloatTensor, 267 | description: Optional[List[str]] = None, 268 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 269 | octree_depth: int = 7, 270 | num_chunks: int = 10000) -> List[AlignedMeshOutput]: 271 | 272 | """ 273 | 274 | Args: 275 | surface: 276 | image: 277 | text: 278 | 
description: 279 | bounds: 280 | octree_depth: 281 | num_chunks: 282 | 283 | Returns: 284 | mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list. 285 | 286 | """ 287 | 288 | outputs = [] 289 | 290 | device = surface.device 291 | bs = surface.shape[0] 292 | 293 | embed_outputs, shape_z = self.model(surface, image, text) 294 | 295 | # calculate the similarity 296 | image_embed = embed_outputs["image_embed"] 297 | text_embed = embed_outputs["text_embed"] 298 | shape_embed = embed_outputs["shape_embed"] 299 | 300 | # normalized features 301 | shape_embed = F.normalize(shape_embed, dim=-1, p=2) 302 | text_embed = F.normalize(text_embed, dim=-1, p=2) 303 | image_embed = F.normalize(image_embed, dim=-1, p=2) 304 | 305 | # B x B 306 | shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1) 307 | 308 | # B x B 309 | shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1) 310 | 311 | # shape reconstruction 312 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) 313 | latents = self.model.shape_model.decode(shape_zq) 314 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) 315 | 316 | # 2. decode geometry 317 | mesh_v_f, has_surface = extract_geometry( 318 | geometric_func=geometric_func, 319 | device=device, 320 | batch_size=bs, 321 | bounds=bounds, 322 | octree_depth=octree_depth, 323 | num_chunks=num_chunks, 324 | disable=not self.zero_rank 325 | ) 326 | 327 | # 3. decode texture 328 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 329 | if not is_surface: 330 | outputs.append(None) 331 | continue 332 | 333 | out = AlignedMeshOutput() 334 | out.mesh_v = mesh_v 335 | out.mesh_f = mesh_f 336 | out.surface = surface[i].cpu().numpy() 337 | out.image = image[i].cpu().numpy() 338 | if description is not None: 339 | out.text = description[i] 340 | out.shape_text_similarity = shape_text_similarity[i, i] 341 | out.shape_image_similarity = shape_image_similarity[i, i] 342 | 343 | outputs.append(out) 344 | 345 | return outputs 346 | 347 | def latent2mesh(self, 348 | latents: torch.FloatTensor, 349 | bounds: Union[Tuple[float], List[float], float] = 1.1, 350 | octree_depth: int = 7, 351 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 352 | 353 | """ 354 | 355 | Args: 356 | latents: [bs, num_latents, dim] 357 | bounds: 358 | octree_depth: 359 | num_chunks: 360 | 361 | Returns: 362 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 363 | 364 | """ 365 | 366 | outputs = [] 367 | 368 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) 369 | 370 | # 2. decode geometry 371 | device = latents.device 372 | mesh_v_f, has_surface = extract_geometry( 373 | geometric_func=geometric_func, 374 | device=device, 375 | batch_size=len(latents), 376 | bounds=bounds, 377 | octree_depth=octree_depth, 378 | num_chunks=num_chunks, 379 | disable=not self.zero_rank 380 | ) 381 | 382 | # 3. 
decode texture 383 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 384 | if not is_surface: 385 | outputs.append(None) 386 | continue 387 | 388 | out = Latent2MeshOutput() 389 | out.mesh_v = mesh_v 390 | out.mesh_f = mesh_f 391 | 392 | outputs.append(out) 393 | 394 | return outputs 395 | 396 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/clip_asl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | from transformers import CLIPModel 7 | 8 | from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule 9 | 10 | 11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): 12 | 13 | def __init__(self, *, 14 | shape_model, 15 | clip_model_version: str = "openai/clip-vit-large-patch14"): 16 | 17 | super().__init__() 18 | 19 | # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) 20 | # for params in self.clip_model.parameters(): 21 | # params.requires_grad = False 22 | self.clip_model = None 23 | self.shape_model = shape_model 24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width)) 25 | # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5) 26 | 27 | def set_shape_model_only(self): 28 | self.clip_model = None 29 | 30 | def encode_shape_embed(self, surface, return_latents: bool = False): 31 | """ 32 | 33 | Args: 34 | surface (torch.FloatTensor): [bs, n, 3 + c] 35 | return_latents (bool): 36 | 37 | Returns: 38 | x (torch.FloatTensor): [bs, projection_dim] 39 | shape_latents (torch.FloatTensor): [bs, m, d] 40 | """ 41 | 42 | pc = surface[..., 0:3] 43 | feats = surface[..., 3:] 44 | 45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) 46 | x = shape_embed @ self.shape_projection 47 | 48 | if return_latents: 49 | return x, shape_latents 50 | else: 51 | return x 52 | 53 | def encode_image_embed(self, image): 54 | """ 55 | 56 | Args: 57 | image (torch.FloatTensor): [bs, 3, h, w] 58 | 59 | Returns: 60 | x (torch.FloatTensor): [bs, projection_dim] 61 | """ 62 | 63 | x = self.clip_model.get_image_features(image) 64 | 65 | return x 66 | 67 | def encode_text_embed(self, text): 68 | x = self.clip_model.get_text_features(text) 69 | return x 70 | 71 | def forward(self, surface, image, text): 72 | """ 73 | 74 | Args: 75 | surface (torch.FloatTensor): 76 | image (torch.FloatTensor): [bs, 3, 224, 224] 77 | text (torch.LongTensor): [bs, num_templates, 77] 78 | 79 | Returns: 80 | embed_outputs (dict): the embedding outputs, and it contains: 81 | - image_embed (torch.FloatTensor): 82 | - text_embed (torch.FloatTensor): 83 | - shape_embed (torch.FloatTensor): 84 | - logit_scale (float): 85 | """ 86 | 87 | # # text embedding 88 | # text_embed_all = [] 89 | # for i in range(text.shape[0]): 90 | # text_for_one_sample = text[i] 91 | # text_embed = self.encode_text_embed(text_for_one_sample) 92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 93 | # text_embed = text_embed.mean(dim=0) 94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 95 | # text_embed_all.append(text_embed) 96 | # text_embed_all = torch.stack(text_embed_all) 97 | 98 | b = text.shape[0] 99 | text_tokens = rearrange(text, "b t l -> (b t) l") 100 | text_embed = self.encode_text_embed(text_tokens) 101 | 
text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) 102 | text_embed = text_embed.mean(dim=1) 103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 104 | 105 | # image embedding 106 | image_embed = self.encode_image_embed(image) 107 | 108 | # shape embedding 109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True) 110 | 111 | embed_outputs = { 112 | "image_embed": image_embed, 113 | "text_embed": text_embed, 114 | "shape_embed": shape_embed, 115 | # "logit_scale": self.clip_model.logit_scale.exp() 116 | } 117 | 118 | return embed_outputs, shape_latents 119 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from einops import repeat 6 | import numpy as np 7 | from typing import Callable, Tuple, List, Union, Optional 8 | from skimage import measure 9 | 10 | from MeshAnything.miche.michelangelo.graphics.primitives import generate_dense_grid_points 11 | 12 | 13 | @torch.no_grad() 14 | def extract_geometry(geometric_func: Callable, 15 | device: torch.device, 16 | batch_size: int = 1, 17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 18 | octree_depth: int = 7, 19 | num_chunks: int = 10000, 20 | disable: bool = True): 21 | """ 22 | 23 | Args: 24 | geometric_func: 25 | device: 26 | bounds: 27 | octree_depth: 28 | batch_size: 29 | num_chunks: 30 | disable: 31 | 32 | Returns: 33 | 34 | """ 35 | 36 | if isinstance(bounds, float): 37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] 38 | 39 | bbox_min = np.array(bounds[0:3]) 40 | bbox_max = np.array(bounds[3:6]) 41 | bbox_size = bbox_max - bbox_min 42 | 43 | xyz_samples, grid_size, length = generate_dense_grid_points( 44 | bbox_min=bbox_min, 45 | bbox_max=bbox_max, 46 | octree_depth=octree_depth, 47 | indexing="ij" 48 | ) 49 | xyz_samples = torch.FloatTensor(xyz_samples) 50 | 51 | batch_logits = [] 52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), 53 | desc="Implicit Function:", disable=disable, leave=False): 54 | queries = xyz_samples[start: start + num_chunks, :].to(device) 55 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size) 56 | 57 | logits = geometric_func(batch_queries) 58 | batch_logits.append(logits.cpu()) 59 | 60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() 61 | 62 | mesh_v_f = [] 63 | has_surface = np.zeros((batch_size,), dtype=np.bool_) 64 | for i in range(batch_size): 65 | try: 66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") 67 | vertices = vertices / grid_size * bbox_size + bbox_min 68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]] 69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) 70 | has_surface[i] = True 71 | 72 | except ValueError: 73 | mesh_v_f.append((None, None)) 74 | has_surface[i] = False 75 | 76 | except RuntimeError: 77 | mesh_v_f.append((None, None)) 78 | has_surface[i] = False 79 | 80 | return mesh_v_f, has_surface 81 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 
import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from typing import Optional, Tuple, Dict 7 | 8 | from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution 9 | from MeshAnything.miche.michelangelo.utils.eval import compute_psnr 10 | from MeshAnything.miche.michelangelo.utils import misc 11 | 12 | 13 | class KLNearFar(nn.Module): 14 | def __init__(self, 15 | near_weight: float = 0.1, 16 | kl_weight: float = 1.0, 17 | num_near_samples: Optional[int] = None): 18 | 19 | super().__init__() 20 | 21 | self.near_weight = near_weight 22 | self.kl_weight = kl_weight 23 | self.num_near_samples = num_near_samples 24 | self.geo_criterion = nn.BCEWithLogitsLoss() 25 | 26 | def forward(self, 27 | posteriors: Optional[DiagonalGaussianDistribution], 28 | logits: torch.FloatTensor, 29 | labels: torch.FloatTensor, 30 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: 31 | 32 | """ 33 | 34 | Args: 35 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): 36 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; 37 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; 38 | split (str): 39 | **kwargs: 40 | 41 | Returns: 42 | loss (torch.Tensor): (,) 43 | log (dict): 44 | 45 | """ 46 | 47 | if self.num_near_samples is None: 48 | num_vol = logits.shape[1] // 2 49 | else: 50 | num_vol = logits.shape[1] - self.num_near_samples 51 | 52 | vol_logits = logits[:, 0:num_vol] 53 | vol_labels = labels[:, 0:num_vol] 54 | 55 | near_logits = logits[:, num_vol:] 56 | near_labels = labels[:, num_vol:] 57 | 58 | # occupancy loss 59 | # vol_bce = self.geo_criterion(vol_logits, vol_labels) 60 | # near_bce = self.geo_criterion(near_logits, near_labels) 61 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 62 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 63 | 64 | if posteriors is None: 65 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) 66 | else: 67 | kl_loss = posteriors.kl(dims=(1, 2)) 68 | kl_loss = torch.mean(kl_loss) 69 | 70 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight 71 | 72 | with torch.no_grad(): 73 | preds = logits >= 0 74 | accuracy = (preds == labels).float() 75 | accuracy = accuracy.mean() 76 | pos_ratio = torch.mean(labels) 77 | 78 | log = { 79 | "{}/total_loss".format(split): loss.clone().detach(), 80 | "{}/near".format(split): near_bce.detach(), 81 | "{}/far".format(split): vol_bce.detach(), 82 | "{}/kl".format(split): kl_loss.detach(), 83 | "{}/accuracy".format(split): accuracy, 84 | "{}/pos_ratio".format(split): pos_ratio 85 | } 86 | 87 | if posteriors is not None: 88 | log[f"{split}/mean"] = posteriors.mean.mean().detach() 89 | log[f"{split}/std_mean"] = posteriors.std.mean().detach() 90 | log[f"{split}/std_max"] = posteriors.std.max().detach() 91 | 92 | return loss, log 93 | 94 | 95 | class KLNearFarColor(nn.Module): 96 | def __init__(self, 97 | near_weight: float = 0.1, 98 | kl_weight: float = 1.0, 99 | color_weight: float = 1.0, 100 | color_criterion: str = "mse", 101 | num_near_samples: Optional[int] = None): 102 | 103 | super().__init__() 104 | 105 | self.color_weight = color_weight 106 | self.near_weight = near_weight 107 | self.kl_weight = kl_weight 108 | self.num_near_samples = num_near_samples 109 | 110 | if color_criterion == "mse": 111 | self.color_criterion = 
nn.MSELoss() 112 | 113 | elif color_criterion == "l1": 114 | self.color_criterion = nn.L1Loss() 115 | 116 | else: 117 | raise ValueError(f"{color_criterion} must be [`mse`, `l1`].") 118 | 119 | self.geo_criterion = nn.BCEWithLogitsLoss() 120 | 121 | def forward(self, 122 | posteriors: Optional[DiagonalGaussianDistribution], 123 | logits: torch.FloatTensor, 124 | labels: torch.FloatTensor, 125 | pred_colors: torch.FloatTensor, 126 | gt_colors: torch.FloatTensor, 127 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: 128 | 129 | """ 130 | 131 | Args: 132 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): 133 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; 134 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; 135 | pred_colors (torch.FloatTensor): [B, M, 3] 136 | gt_colors (torch.FloatTensor): [B, M, 3] 137 | split (str): 138 | **kwargs: 139 | 140 | Returns: 141 | loss (torch.Tensor): (,) 142 | log (dict): 143 | 144 | """ 145 | 146 | if self.num_near_samples is None: 147 | num_vol = logits.shape[1] // 2 148 | else: 149 | num_vol = logits.shape[1] - self.num_near_samples 150 | 151 | vol_logits = logits[:, 0:num_vol] 152 | vol_labels = labels[:, 0:num_vol] 153 | 154 | near_logits = logits[:, num_vol:] 155 | near_labels = labels[:, num_vol:] 156 | 157 | # occupancy loss 158 | # vol_bce = self.geo_criterion(vol_logits, vol_labels) 159 | # near_bce = self.geo_criterion(near_logits, near_labels) 160 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 161 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 162 | 163 | # surface color loss 164 | color = self.color_criterion(pred_colors, gt_colors) 165 | 166 | if posteriors is None: 167 | kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device) 168 | else: 169 | kl_loss = posteriors.kl(dims=(1, 2)) 170 | kl_loss = torch.mean(kl_loss) 171 | 172 | loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight 173 | 174 | with torch.no_grad(): 175 | preds = logits >= 0 176 | accuracy = (preds == labels).float() 177 | accuracy = accuracy.mean() 178 | psnr = compute_psnr(pred_colors, gt_colors) 179 | 180 | log = { 181 | "{}/total_loss".format(split): loss.clone().detach(), 182 | "{}/near".format(split): near_bce.detach(), 183 | "{}/far".format(split): vol_bce.detach(), 184 | "{}/color".format(split): color.detach(), 185 | "{}/kl".format(split): kl_loss.detach(), 186 | "{}/psnr".format(split): psnr.detach(), 187 | "{}/accuracy".format(split): accuracy 188 | } 189 | 190 | return loss, log 191 | 192 | 193 | class ContrastKLNearFar(nn.Module): 194 | def __init__(self, 195 | contrast_weight: float = 1.0, 196 | near_weight: float = 0.1, 197 | kl_weight: float = 1.0, 198 | num_near_samples: Optional[int] = None): 199 | 200 | super().__init__() 201 | 202 | self.labels = None 203 | self.last_local_batch_size = None 204 | 205 | self.contrast_weight = contrast_weight 206 | self.near_weight = near_weight 207 | self.kl_weight = kl_weight 208 | self.num_near_samples = num_near_samples 209 | self.geo_criterion = nn.BCEWithLogitsLoss() 210 | 211 | def forward(self, 212 | shape_embed: torch.FloatTensor, 213 | text_embed: torch.FloatTensor, 214 | image_embed: torch.FloatTensor, 215 | logit_scale: torch.FloatTensor, 216 | posteriors: Optional[DiagonalGaussianDistribution], 217 | 
shape_logits: torch.FloatTensor, 218 | shape_labels: torch.FloatTensor, 219 | split: Optional[str] = "train", **kwargs): 220 | 221 | local_batch_size = shape_embed.size(0) 222 | 223 | if local_batch_size != self.last_local_batch_size: 224 | self.labels = local_batch_size * misc.get_rank() + torch.arange( 225 | local_batch_size, device=shape_embed.device 226 | ).long() 227 | self.last_local_batch_size = local_batch_size 228 | 229 | # normalized features 230 | shape_embed = F.normalize(shape_embed, dim=-1, p=2) 231 | text_embed = F.normalize(text_embed, dim=-1, p=2) 232 | image_embed = F.normalize(image_embed, dim=-1, p=2) 233 | 234 | # gather features from all GPUs 235 | shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( 236 | [shape_embed, text_embed, image_embed] 237 | ) 238 | 239 | # cosine similarity as logits 240 | logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() 241 | logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() 242 | logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() 243 | logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() 244 | contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + 245 | F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ 246 | (F.cross_entropy(logits_per_shape_image, self.labels) + 247 | F.cross_entropy(logits_per_image_shape, self.labels)) / 2 248 | 249 | # shape reconstruction 250 | if self.num_near_samples is None: 251 | num_vol = shape_logits.shape[1] // 2 252 | else: 253 | num_vol = shape_logits.shape[1] - self.num_near_samples 254 | 255 | vol_logits = shape_logits[:, 0:num_vol] 256 | vol_labels = shape_labels[:, 0:num_vol] 257 | 258 | near_logits = shape_logits[:, num_vol:] 259 | near_labels = shape_labels[:, num_vol:] 260 | 261 | # occupancy loss 262 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 263 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 264 | 265 | if posteriors is None: 266 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) 267 | else: 268 | kl_loss = posteriors.kl(dims=(1, 2)) 269 | kl_loss = torch.mean(kl_loss) 270 | 271 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight 272 | 273 | # compute accuracy 274 | with torch.no_grad(): 275 | pred = torch.argmax(logits_per_shape_text, dim=-1) 276 | correct = pred.eq(self.labels).sum() 277 | shape_text_acc = 100 * correct / local_batch_size 278 | 279 | pred = torch.argmax(logits_per_shape_image, dim=-1) 280 | correct = pred.eq(self.labels).sum() 281 | shape_image_acc = 100 * correct / local_batch_size 282 | 283 | preds = shape_logits >= 0 284 | accuracy = (preds == shape_labels).float() 285 | accuracy = accuracy.mean() 286 | 287 | log = { 288 | "{}/contrast".format(split): contrast_loss.clone().detach(), 289 | "{}/near".format(split): near_bce.detach(), 290 | "{}/far".format(split): vol_bce.detach(), 291 | "{}/kl".format(split): kl_loss.detach(), 292 | "{}/shape_text_acc".format(split): shape_text_acc, 293 | "{}/shape_image_acc".format(split): shape_image_acc, 294 | "{}/total_loss".format(split): loss.clone().detach(), 295 | "{}/accuracy".format(split): accuracy, 296 | } 297 | 298 | if posteriors is not None: 299 | log[f"{split}/mean"] = posteriors.mean.mean().detach() 300 | log[f"{split}/std_mean"] = posteriors.std.mean().detach() 301 | log[f"{split}/std_max"] = posteriors.std.max().detach() 302 | 303 | return loss, log 304 | 
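305 | 
306 | if __name__ == "__main__":
307 |     # Minimal smoke test for KLNearFar (an illustrative sketch, not part of
308 |     # the original training pipeline). It shows the expected call signature:
309 |     # logits/labels are [B, 2N] with volume points first and near-surface
310 |     # points second; posteriors=None makes the KL term zero. The batch size,
311 |     # point count and loss weights below are arbitrary assumptions.
312 |     # Run as a module so the absolute imports resolve, e.g.
313 |     #   python -m MeshAnything.miche.michelangelo.models.tsal.loss
314 |     criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3)
315 |     logits = torch.randn(4, 2048)
316 |     labels = (torch.rand(4, 2048) > 0.5).float()
317 |     loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="debug")
318 |     print(loss.item(), {k: float(v) for k, v in log.items()})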
-------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/sal_perceiver.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from einops import repeat 7 | import math 8 | 9 | from MeshAnything.miche.michelangelo.models.modules import checkpoint 10 | from MeshAnything.miche.michelangelo.models.modules.embedder import FourierEmbedder 11 | from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution 12 | from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import ( 13 | ResidualCrossAttentionBlock, 14 | Transformer 15 | ) 16 | 17 | from .tsal_base import ShapeAsLatentModule 18 | 19 | 20 | class CrossAttentionEncoder(nn.Module): 21 | 22 | def __init__(self, *, 23 | device: Optional[torch.device], 24 | dtype: Optional[torch.dtype], 25 | num_latents: int, 26 | fourier_embedder: FourierEmbedder, 27 | point_feats: int, 28 | width: int, 29 | heads: int, 30 | layers: int, 31 | init_scale: float = 0.25, 32 | qkv_bias: bool = True, 33 | flash: bool = False, 34 | use_ln_post: bool = False, 35 | use_checkpoint: bool = False): 36 | 37 | super().__init__() 38 | 39 | self.use_checkpoint = use_checkpoint 40 | self.num_latents = num_latents 41 | 42 | self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) 43 | 44 | self.fourier_embedder = fourier_embedder 45 | self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) 46 | self.cross_attn = ResidualCrossAttentionBlock( 47 | device=device, 48 | dtype=dtype, 49 | width=width, 50 | heads=heads, 51 | init_scale=init_scale, 52 | qkv_bias=qkv_bias, 53 | flash=flash, 54 | ) 55 | 56 | self.self_attn = Transformer( 57 | device=device, 58 | dtype=dtype, 59 | n_ctx=num_latents, 60 | width=width, 61 | layers=layers, 62 | heads=heads, 63 | init_scale=init_scale, 64 | qkv_bias=qkv_bias, 65 | flash=flash, 66 | use_checkpoint=False 67 | ) 68 | 69 | if use_ln_post: 70 | self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) 71 | else: 72 | self.ln_post = None 73 | 74 | def _forward(self, pc, feats): 75 | """ 76 | 77 | Args: 78 | pc (torch.FloatTensor): [B, N, 3] 79 | feats (torch.FloatTensor or None): [B, N, C] 80 | 81 | Returns: 82 | 83 | """ 84 | 85 | bs = pc.shape[0] 86 | 87 | data = self.fourier_embedder(pc) 88 | if feats is not None: 89 | data = torch.cat([data, feats], dim=-1) 90 | data = self.input_proj(data) 91 | 92 | query = repeat(self.query, "m c -> b m c", b=bs) 93 | latents = self.cross_attn(query, data) 94 | latents = self.self_attn(latents) 95 | 96 | if self.ln_post is not None: 97 | latents = self.ln_post(latents) 98 | 99 | return latents, pc 100 | 101 | def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): 102 | """ 103 | 104 | Args: 105 | pc (torch.FloatTensor): [B, N, 3] 106 | feats (torch.FloatTensor or None): [B, N, C] 107 | 108 | Returns: 109 | dict 110 | """ 111 | 112 | return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) 113 | 114 | 115 | class CrossAttentionDecoder(nn.Module): 116 | 117 | def __init__(self, *, 118 | device: Optional[torch.device], 119 | dtype: Optional[torch.dtype], 120 | num_latents: int, 121 | out_channels: int, 122 | fourier_embedder: FourierEmbedder, 123 | width: int, 124 | heads: int, 125 | init_scale: 
float = 0.25, 126 | qkv_bias: bool = True, 127 | flash: bool = False, 128 | use_checkpoint: bool = False): 129 | 130 | super().__init__() 131 | 132 | self.use_checkpoint = use_checkpoint 133 | self.fourier_embedder = fourier_embedder 134 | 135 | self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) 136 | 137 | self.cross_attn_decoder = ResidualCrossAttentionBlock( 138 | device=device, 139 | dtype=dtype, 140 | n_data=num_latents, 141 | width=width, 142 | heads=heads, 143 | init_scale=init_scale, 144 | qkv_bias=qkv_bias, 145 | flash=flash 146 | ) 147 | 148 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 149 | self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) 150 | 151 | def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 152 | queries = self.query_proj(self.fourier_embedder(queries)) 153 | x = self.cross_attn_decoder(queries, latents) 154 | x = self.ln_post(x) 155 | x = self.output_proj(x) 156 | return x 157 | 158 | def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 159 | return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) 160 | 161 | 162 | class ShapeAsLatentPerceiver(ShapeAsLatentModule): 163 | def __init__(self, *, 164 | device: Optional[torch.device], 165 | dtype: Optional[torch.dtype], 166 | num_latents: int, 167 | point_feats: int = 0, 168 | embed_dim: int = 0, 169 | num_freqs: int = 8, 170 | include_pi: bool = True, 171 | width: int, 172 | heads: int, 173 | num_encoder_layers: int, 174 | num_decoder_layers: int, 175 | init_scale: float = 0.25, 176 | qkv_bias: bool = True, 177 | flash: bool = False, 178 | use_ln_post: bool = False, 179 | use_checkpoint: bool = False): 180 | 181 | super().__init__() 182 | 183 | self.use_checkpoint = use_checkpoint 184 | 185 | self.num_latents = num_latents 186 | self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) 187 | 188 | init_scale = init_scale * math.sqrt(1.0 / width) 189 | self.encoder = CrossAttentionEncoder( 190 | device=device, 191 | dtype=dtype, 192 | fourier_embedder=self.fourier_embedder, 193 | num_latents=num_latents, 194 | point_feats=point_feats, 195 | width=width, 196 | heads=heads, 197 | layers=num_encoder_layers, 198 | init_scale=init_scale, 199 | qkv_bias=qkv_bias, 200 | flash=flash, 201 | use_ln_post=use_ln_post, 202 | use_checkpoint=use_checkpoint 203 | ) 204 | 205 | self.embed_dim = embed_dim 206 | if embed_dim > 0: 207 | # VAE embed 208 | self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) 209 | self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype) 210 | self.latent_shape = (num_latents, embed_dim) 211 | else: 212 | self.latent_shape = (num_latents, width) 213 | 214 | self.transformer = Transformer( 215 | device=device, 216 | dtype=dtype, 217 | n_ctx=num_latents, 218 | width=width, 219 | layers=num_decoder_layers, 220 | heads=heads, 221 | init_scale=init_scale, 222 | qkv_bias=qkv_bias, 223 | flash=flash, 224 | use_checkpoint=use_checkpoint 225 | ) 226 | 227 | # geometry decoder 228 | self.geo_decoder = CrossAttentionDecoder( 229 | device=device, 230 | dtype=dtype, 231 | fourier_embedder=self.fourier_embedder, 232 | out_channels=1, 233 | num_latents=num_latents, 234 | width=width, 235 | heads=heads, 236 | init_scale=init_scale, 237 | qkv_bias=qkv_bias, 238 | flash=flash, 239 | use_checkpoint=use_checkpoint 240 | ) 241 | 242 | def encode(self, 243 | pc: torch.FloatTensor, 244 | feats: 
Optional[torch.FloatTensor] = None, 245 | sample_posterior: bool = True): 246 | """ 247 | 248 | Args: 249 | pc (torch.FloatTensor): [B, N, 3] 250 | feats (torch.FloatTensor or None): [B, N, C] 251 | sample_posterior (bool): 252 | 253 | Returns: 254 | latents (torch.FloatTensor) 255 | center_pos (torch.FloatTensor or None): 256 | posterior (DiagonalGaussianDistribution or None): 257 | """ 258 | 259 | latents, center_pos = self.encoder(pc, feats) 260 | 261 | posterior = None 262 | if self.embed_dim > 0: 263 | moments = self.pre_kl(latents) 264 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) 265 | 266 | if sample_posterior: 267 | latents = posterior.sample() 268 | else: 269 | latents = posterior.mode() 270 | 271 | return latents, center_pos, posterior 272 | 273 | def decode(self, latents: torch.FloatTensor): 274 | latents = self.post_kl(latents) 275 | return self.transformer(latents) 276 | 277 | def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 278 | logits = self.geo_decoder(queries, latents).squeeze(-1) 279 | return logits 280 | 281 | def forward(self, 282 | pc: torch.FloatTensor, 283 | feats: torch.FloatTensor, 284 | volume_queries: torch.FloatTensor, 285 | sample_posterior: bool = True): 286 | """ 287 | 288 | Args: 289 | pc (torch.FloatTensor): [B, N, 3] 290 | feats (torch.FloatTensor or None): [B, N, C] 291 | volume_queries (torch.FloatTensor): [B, P, 3] 292 | sample_posterior (bool): 293 | 294 | Returns: 295 | logits (torch.FloatTensor): [B, P] 296 | center_pos (torch.FloatTensor): [B, M, 3] 297 | posterior (DiagonalGaussianDistribution or None). 298 | 299 | """ 300 | 301 | latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) 302 | 303 | latents = self.decode(latents) 304 | logits = self.query_geometry(volume_queries, latents) 305 | 306 | return logits, center_pos, posterior 307 | 308 | 309 | class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): 310 | 311 | def __init__(self, *, 312 | device: Optional[torch.device], 313 | dtype: Optional[torch.dtype], 314 | num_latents: int, 315 | point_feats: int = 0, 316 | embed_dim: int = 0, 317 | num_freqs: int = 8, 318 | include_pi: bool = True, 319 | width: int, 320 | heads: int, 321 | num_encoder_layers: int, 322 | num_decoder_layers: int, 323 | init_scale: float = 0.25, 324 | qkv_bias: bool = True, 325 | flash: bool = False, 326 | use_ln_post: bool = False, 327 | use_checkpoint: bool = False): 328 | 329 | super().__init__( 330 | device=device, 331 | dtype=dtype, 332 | num_latents=1 + num_latents, 333 | point_feats=point_feats, 334 | embed_dim=embed_dim, 335 | num_freqs=num_freqs, 336 | include_pi=include_pi, 337 | width=width, 338 | heads=heads, 339 | num_encoder_layers=num_encoder_layers, 340 | num_decoder_layers=num_decoder_layers, 341 | init_scale=init_scale, 342 | qkv_bias=qkv_bias, 343 | flash=flash, 344 | use_ln_post=use_ln_post, 345 | use_checkpoint=use_checkpoint 346 | ) 347 | 348 | self.width = width 349 | 350 | def encode(self, 351 | pc: torch.FloatTensor, 352 | feats: Optional[torch.FloatTensor] = None, 353 | sample_posterior: bool = True): 354 | """ 355 | 356 | Args: 357 | pc (torch.FloatTensor): [B, N, 3] 358 | feats (torch.FloatTensor or None): [B, N, c] 359 | sample_posterior (bool): 360 | 361 | Returns: 362 | shape_embed (torch.FloatTensor) 363 | kl_embed (torch.FloatTensor): 364 | posterior (DiagonalGaussianDistribution or None): 365 | """ 366 | 367 | shape_embed, latents = self.encode_latents(pc, feats) 368 | kl_embed, posterior = 
self.encode_kl_embed(latents, sample_posterior) 369 | 370 | return shape_embed, kl_embed, posterior 371 | 372 | def encode_latents(self, 373 | pc: torch.FloatTensor, 374 | feats: Optional[torch.FloatTensor] = None): 375 | 376 | x, _ = self.encoder(pc, feats) 377 | 378 | shape_embed = x[:, 0] 379 | latents = x[:, 1:] 380 | 381 | return shape_embed, latents 382 | 383 | def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): 384 | posterior = None 385 | if self.embed_dim > 0: 386 | moments = self.pre_kl(latents) 387 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) 388 | 389 | if sample_posterior: 390 | kl_embed = posterior.sample() 391 | else: 392 | kl_embed = posterior.mode() 393 | else: 394 | kl_embed = latents 395 | 396 | return kl_embed, posterior 397 | 398 | def forward(self, 399 | pc: torch.FloatTensor, 400 | feats: torch.FloatTensor, 401 | volume_queries: torch.FloatTensor, 402 | sample_posterior: bool = True): 403 | """ 404 | 405 | Args: 406 | pc (torch.FloatTensor): [B, N, 3] 407 | feats (torch.FloatTensor or None): [B, N, C] 408 | volume_queries (torch.FloatTensor): [B, P, 3] 409 | sample_posterior (bool): 410 | 411 | Returns: 412 | shape_embed (torch.FloatTensor): [B, projection_dim] 413 | logits (torch.FloatTensor): [B, M] 414 | posterior (DiagonalGaussianDistribution or None). 415 | 416 | """ 417 | 418 | shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) 419 | 420 | latents = self.decode(kl_embed) 421 | logits = self.query_geometry(volume_queries, latents) 422 | 423 | return shape_embed, logits, posterior 424 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/sal_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import List, Tuple, Dict, Optional 4 | from omegaconf import DictConfig 5 | 6 | import torch 7 | from torch.optim import lr_scheduler 8 | import pytorch_lightning as pl 9 | from typing import Union 10 | from functools import partial 11 | 12 | from MeshAnything.miche.michelangelo.utils import instantiate_from_config 13 | 14 | from .inference_utils import extract_geometry 15 | from .tsal_base import ( 16 | ShapeAsLatentModule, 17 | Latent2MeshOutput, 18 | Point2MeshOutput 19 | ) 20 | 21 | 22 | class ShapeAsLatentPLModule(pl.LightningModule): 23 | 24 | def __init__(self, *, 25 | module_cfg, 26 | loss_cfg, 27 | optimizer_cfg: Optional[DictConfig] = None, 28 | ckpt_path: Optional[str] = None, 29 | ignore_keys: Union[Tuple[str], List[str]] = ()): 30 | 31 | super().__init__() 32 | 33 | self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None) 34 | 35 | self.loss = instantiate_from_config(loss_cfg) 36 | 37 | self.optimizer_cfg = optimizer_cfg 38 | 39 | if ckpt_path is not None: 40 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 41 | 42 | self.save_hyperparameters() 43 | 44 | @property 45 | def latent_shape(self): 46 | return self.sal.latent_shape 47 | 48 | @property 49 | def zero_rank(self): 50 | if self._trainer: 51 | zero_rank = self.trainer.local_rank == 0 52 | else: 53 | zero_rank = True 54 | 55 | return zero_rank 56 | 57 | def init_from_ckpt(self, path, ignore_keys=()): 58 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 59 | 60 | keys = list(state_dict.keys()) 61 | for k in keys: 62 | for ik in ignore_keys: 63 | if k.startswith(ik): 64 | print("Deleting key {} 
from state_dict.".format(k)) 65 | del state_dict[k] 66 | 67 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 68 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 69 | if len(missing) > 0: 70 | print(f"Missing Keys: {missing}") 71 | print(f"Unexpected Keys: {unexpected}") 72 | 73 | def configure_optimizers(self) -> Tuple[List, List]: 74 | lr = self.learning_rate 75 | 76 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)] 77 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 78 | 79 | if self.optimizer_cfg is None: 80 | optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 81 | schedulers = [] 82 | else: 83 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters()) 84 | scheduler_func = instantiate_from_config( 85 | self.optimizer_cfg.scheduler, 86 | max_decay_steps=self.trainer.max_steps, 87 | lr_max=lr 88 | ) 89 | scheduler = { 90 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 91 | "interval": "step", 92 | "frequency": 1 93 | } 94 | optimizers = [optimizer] 95 | schedulers = [scheduler] 96 | 97 | return optimizers, schedulers 98 | 99 | def forward(self, 100 | pc: torch.FloatTensor, 101 | feats: torch.FloatTensor, 102 | volume_queries: torch.FloatTensor): 103 | 104 | logits, center_pos, posterior = self.sal(pc, feats, volume_queries) 105 | 106 | return posterior, logits 107 | 108 | def encode(self, surface: torch.FloatTensor, sample_posterior=True): 109 | 110 | pc = surface[..., 0:3] 111 | feats = surface[..., 3:6] 112 | 113 | latents, center_pos, posterior = self.sal.encode( 114 | pc=pc, feats=feats, sample_posterior=sample_posterior 115 | ) 116 | 117 | return latents 118 | 119 | def decode(self, 120 | z_q, 121 | bounds: Union[Tuple[float], List[float], float] = 1.1, 122 | octree_depth: int = 7, 123 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 124 | 125 | latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim] 126 | outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) 127 | 128 | return outputs 129 | 130 | def training_step(self, batch: Dict[str, torch.FloatTensor], 131 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 132 | """ 133 | 134 | Args: 135 | batch (dict): the batch sample, and it contains: 136 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] 137 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] 138 | 139 | batch_idx (int): 140 | 141 | optimizer_idx (int): 142 | 143 | Returns: 144 | loss (torch.FloatTensor): 145 | 146 | """ 147 | 148 | pc = batch["surface"][..., 0:3] 149 | feats = batch["surface"][..., 3:] 150 | 151 | volume_queries = batch["geo_points"][..., 0:3] 152 | volume_labels = batch["geo_points"][..., -1] 153 | 154 | posterior, logits = self( 155 | pc=pc, feats=feats, volume_queries=volume_queries 156 | ) 157 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train") 158 | 159 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], 160 | sync_dist=False, rank_zero_only=True) 161 | 162 | return aeloss 163 | 164 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: 165 | 166 | pc = batch["surface"][..., 0:3] 167 | feats = batch["surface"][..., 3:] 168 | 169 | volume_queries 
= batch["geo_points"][..., 0:3] 170 | volume_labels = batch["geo_points"][..., -1] 171 | 172 | posterior, logits = self( 173 | pc=pc, feats=feats, volume_queries=volume_queries, 174 | ) 175 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val") 176 | 177 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], 178 | sync_dist=False, rank_zero_only=True) 179 | 180 | return aeloss 181 | 182 | def point2mesh(self, 183 | pc: torch.FloatTensor, 184 | feats: torch.FloatTensor, 185 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 186 | octree_depth: int = 7, 187 | num_chunks: int = 10000) -> List[Point2MeshOutput]: 188 | 189 | """ 190 | 191 | Args: 192 | pc: 193 | feats: 194 | bounds: 195 | octree_depth: 196 | num_chunks: 197 | 198 | Returns: 199 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 200 | 201 | """ 202 | 203 | outputs = [] 204 | 205 | device = pc.device 206 | bs = pc.shape[0] 207 | 208 | # 1. point encoder + latents transformer 209 | latents, center_pos, posterior = self.sal.encode(pc, feats) 210 | latents = self.sal.decode(latents) # latents: [bs, num_latents, dim] 211 | 212 | geometric_func = partial(self.sal.query_geometry, latents=latents) 213 | 214 | # 2. decode geometry 215 | mesh_v_f, has_surface = extract_geometry( 216 | geometric_func=geometric_func, 217 | device=device, 218 | batch_size=bs, 219 | bounds=bounds, 220 | octree_depth=octree_depth, 221 | num_chunks=num_chunks, 222 | disable=not self.zero_rank 223 | ) 224 | 225 | # 3. decode texture 226 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 227 | if not is_surface: 228 | outputs.append(None) 229 | continue 230 | 231 | out = Point2MeshOutput() 232 | out.mesh_v = mesh_v 233 | out.mesh_f = mesh_f 234 | out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy() 235 | 236 | if center_pos is not None: 237 | out.center = center_pos[i].cpu().numpy() 238 | 239 | outputs.append(out) 240 | 241 | return outputs 242 | 243 | def latent2mesh(self, 244 | latents: torch.FloatTensor, 245 | bounds: Union[Tuple[float], List[float], float] = 1.1, 246 | octree_depth: int = 7, 247 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 248 | 249 | """ 250 | 251 | Args: 252 | latents: [bs, num_latents, dim] 253 | bounds: 254 | octree_depth: 255 | num_chunks: 256 | 257 | Returns: 258 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 259 | 260 | """ 261 | 262 | outputs = [] 263 | 264 | geometric_func = partial(self.sal.query_geometry, latents=latents) 265 | 266 | # 2. decode geometry 267 | device = latents.device 268 | mesh_v_f, has_surface = extract_geometry( 269 | geometric_func=geometric_func, 270 | device=device, 271 | batch_size=len(latents), 272 | bounds=bounds, 273 | octree_depth=octree_depth, 274 | num_chunks=num_chunks, 275 | disable=not self.zero_rank 276 | ) 277 | 278 | # 3. 
decode texture 279 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 280 | if not is_surface: 281 | outputs.append(None) 282 | continue 283 | 284 | out = Latent2MeshOutput() 285 | out.mesh_v = mesh_v 286 | out.mesh_f = mesh_f 287 | 288 | outputs.append(out) 289 | 290 | return outputs 291 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/models/tsal/tsal_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from typing import Tuple, List, Optional 5 | 6 | 7 | class Point2MeshOutput(object): 8 | def __init__(self): 9 | self.mesh_v = None 10 | self.mesh_f = None 11 | self.center = None 12 | self.pc = None 13 | 14 | 15 | class Latent2MeshOutput(object): 16 | 17 | def __init__(self): 18 | self.mesh_v = None 19 | self.mesh_f = None 20 | 21 | 22 | class AlignedMeshOutput(object): 23 | 24 | def __init__(self): 25 | self.mesh_v = None 26 | self.mesh_f = None 27 | self.surface = None 28 | self.image = None 29 | self.text: Optional[str] = None 30 | self.shape_text_similarity: Optional[float] = None 31 | self.shape_image_similarity: Optional[float] = None 32 | 33 | 34 | class ShapeAsLatentPLModule(nn.Module): 35 | latent_shape: Tuple[int] 36 | 37 | def encode(self, surface, *args, **kwargs): 38 | raise NotImplementedError 39 | 40 | def decode(self, z_q, *args, **kwargs): 41 | raise NotImplementedError 42 | 43 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 44 | raise NotImplementedError 45 | 46 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 47 | raise NotImplementedError 48 | 49 | 50 | class ShapeAsLatentModule(nn.Module): 51 | latent_shape: Tuple[int, int] 52 | 53 | def __init__(self, *args, **kwargs): 54 | super().__init__() 55 | 56 | def encode(self, *args, **kwargs): 57 | raise NotImplementedError 58 | 59 | def decode(self, *args, **kwargs): 60 | raise NotImplementedError 61 | 62 | def query_geometry(self, *args, **kwargs): 63 | raise NotImplementedError 64 | 65 | 66 | class AlignedShapeAsLatentPLModule(nn.Module): 67 | latent_shape: Tuple[int] 68 | 69 | def set_shape_model_only(self): 70 | raise NotImplementedError 71 | 72 | def encode(self, surface, *args, **kwargs): 73 | raise NotImplementedError 74 | 75 | def decode(self, z_q, *args, **kwargs): 76 | raise NotImplementedError 77 | 78 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 79 | raise NotImplementedError 80 | 81 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 82 | raise NotImplementedError 83 | 84 | 85 | class AlignedShapeAsLatentModule(nn.Module): 86 | shape_model: ShapeAsLatentModule 87 | latent_shape: Tuple[int, int] 88 | 89 | def __init__(self, *args, **kwargs): 90 | super().__init__() 91 | 92 | def set_shape_model_only(self): 93 | raise NotImplementedError 94 | 95 | def encode_image_embed(self, *args, **kwargs): 96 | raise NotImplementedError 97 | 98 | def encode_text_embed(self, *args, **kwargs): 99 | raise NotImplementedError 100 | 101 | def encode_shape_embed(self, *args, **kwargs): 102 | raise NotImplementedError 103 | 104 | 105 | class TexturedShapeAsLatentModule(nn.Module): 106 | 107 | def __init__(self, *args, **kwargs): 108 | super().__init__() 109 | 110 | def encode(self, *args, **kwargs): 111 | raise NotImplementedError 112 | 113 | def decode(self, *args, **kwargs): 114 | raise NotImplementedError 115 | 116 | def query_geometry(self, 
*args, **kwargs): 117 | raise NotImplementedError 118 | 119 | def query_color(self, *args, **kwargs): 120 | raise NotImplementedError 121 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .misc import instantiate_from_config 4 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7): 7 | 8 | mse = torch.mean((x - y) ** 2) 9 | psnr = 10 * torch.log10(data_range / (mse + eps)) 10 | 11 | return psnr 12 | 13 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import io 5 | import tarfile 6 | import json 7 | import numpy as np 8 | import numpy.lib.format 9 | 10 | 11 | def mkdir(path): 12 | os.makedirs(path, exist_ok=True) 13 | return path 14 | 15 | 16 | def npy_loads(data): 17 | stream = io.BytesIO(data) 18 | return np.lib.format.read_array(stream) 19 | 20 | 21 | def npz_loads(data): 22 | return np.load(io.BytesIO(data)) 23 | 24 | 25 | def json_loads(data): 26 | return json.loads(data) 27 | 28 | 29 | def load_json(filepath): 30 | with open(filepath, "r") as f: 31 | data = json.load(f) 32 | return data 33 | 34 | 35 | def write_json(filepath, data): 36 | with open(filepath, "w") as f: 37 | json.dump(data, f, indent=2) 38 | 39 | 40 | def extract_tar(tar_path, tar_cache_folder): 41 | 42 | with tarfile.open(tar_path, "r") as tar: 43 | tar.extractall(path=tar_cache_folder) 44 | 45 | tar_uids = sorted(os.listdir(tar_cache_folder)) 46 | print(f"extract tar: {tar_path} to {tar_cache_folder}") 47 | return tar_uids 48 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import importlib 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | 10 | def get_obj_from_str(string, reload=False): 11 | module, cls = string.rsplit(".", 1) 12 | if reload: 13 | module_imp = importlib.import_module(module) 14 | importlib.reload(module_imp) 15 | return getattr(importlib.import_module(module, package=None), cls) 16 | 17 | 18 | def get_obj_from_config(config): 19 | if "target" not in config: 20 | raise KeyError("Expected key `target` to instantiate.") 21 | 22 | return get_obj_from_str(config["target"]) 23 | 24 | 25 | def instantiate_from_config(config, **kwargs): 26 | if "target" not in config: 27 | raise KeyError("Expected key `target` to instantiate.") 28 | 29 | cls = get_obj_from_str(config["target"]) 30 | 31 | params = config.get("params", dict()) 32 | # params.update(kwargs) 33 | # instance = cls(**params) 34 | kwargs.update(params) 35 | instance = cls(**kwargs) 36 | 37 | return instance 38 | 39 | 40 | def is_dist_avail_and_initialized(): 41 | if not dist.is_available(): 42 | return False 43 | if not dist.is_initialized(): 44 | return False 45 | return True 46 | 47 | 48 | def get_rank(): 49 | if not 
is_dist_avail_and_initialized(): 50 | return 0 51 | return dist.get_rank() 52 | 53 | 54 | def get_world_size(): 55 | if not is_dist_avail_and_initialized(): 56 | return 1 57 | return dist.get_world_size() 58 | 59 | 60 | def all_gather_batch(tensors): 61 | """ 62 | Performs all_gather operation on the provided tensors. 63 | """ 64 | # Queue the gathered tensors 65 | world_size = get_world_size() 66 | # There is no need for reduction in the single-proc case 67 | if world_size == 1: 68 | return tensors 69 | tensor_list = [] 70 | output_tensor = [] 71 | for tensor in tensors: 72 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 73 | dist.all_gather( 74 | tensor_all, 75 | tensor, 76 | async_op=False # performance opt 77 | ) 78 | 79 | tensor_list.append(tensor_all) 80 | 81 | for tensor_all in tensor_list: 82 | output_tensor.append(torch.cat(tensor_all, dim=0)) 83 | return output_tensor 84 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | # Helper functions 6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): 7 | colormap = plt.cm.get_cmap(colormap) 8 | if normalize: 9 | vmin = np.min(inp) 10 | vmax = np.max(inp) 11 | 12 | norm = plt.Normalize(vmin, vmax) 13 | return colormap(norm(inp))[:, :3] 14 | 15 | 16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): 17 | # tex dims need to be power of two. 18 | array = np.ones((width, height, 3), dtype='float32') 19 | 20 | # width in texels of each checker 21 | checker_w = width / n_checkers_x 22 | checker_h = height / n_checkers_y 23 | 24 | for y in range(height): 25 | for x in range(width): 26 | color_key = int(x / checker_w) + int(y / checker_h) 27 | if color_key % 2 == 0: 28 | array[x, y, :] = [1., 0.874, 0.0] 29 | else: 30 | array[x, y, :] = [0., 0., 0.] 31 | return array 32 | 33 | 34 | def gen_circle(width=256, height=256): 35 | xx, yy = np.mgrid[:width, :height] 36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2 37 | array = np.ones((width, height, 4), dtype='float32') 38 | array[:, :, 0] = (circle <= width) 39 | array[:, :, 1] = (circle <= width) 40 | array[:, :, 2] = (circle <= width) 41 | array[:, :, 3] = circle <= width 42 | return array 43 | 44 | -------------------------------------------------------------------------------- /MeshAnything/miche/michelangelo/utils/visualizers/html_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import io 3 | import base64 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def to_html_frame(content): 9 | 10 | html_frame = f""" 11 | 12 |
13 |         {content}
14 |     </body>
15 | </html>
16 |     """
17 | 
18 |     return html_frame
19 | 
20 | 
21 | def to_single_row_table(caption: str, content: str):
22 | 
23 |     table_html = f"""
24 |     <table>
25 |         <caption>{caption}</caption>
26 |         <tr>
27 |             <td>{content}</td>
28 |         </tr>
29 |     </table>
30 |     """
31 | 
32 |     return table_html
33 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
2 |
5 | Yiwen Chen<sup>1,2</sup>*,
6 | Tong He<sup>2</sup>†,
7 | Di Huang<sup>2</sup>,
8 | Weicai Ye<sup>2</sup>,
9 | Sijin Chen<sup>3</sup>,
10 | Jiaxiang Tang<sup>4</sup>,
11 | Xin Chen<sup>5</sup>,
12 | Zhongang Cai<sup>6</sup>,
13 | Lei Yang<sup>6</sup>,
14 | Gang Yu<sup>7</sup>,
15 | Guosheng Lin<sup>1</sup>†,
16 | Chi Zhang<sup>8</sup>†
17 |
18 | *Work done during a research internship at Shanghai AI Lab.
19 |
20 | †Corresponding authors.
21 |
22 | <sup>1</sup>S-Lab, Nanyang Technological University,
23 | <sup>2</sup>Shanghai AI Lab,
24 | 
25 | <sup>3</sup>Fudan University,
26 | <sup>4</sup>Peking University,
27 | <sup>5</sup>University of Chinese Academy of Sciences,
28 | 
29 | <sup>6</sup>SenseTime Research,
30 | <sup>7</sup>Stepfun,
31 | <sup>8</sup>Westlake University
32 |
47 |
48 |
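49 | ## Shape VAE round trip (illustrative)
50 | 
51 | The snippet below is a minimal, hypothetical sketch of how the aligned shape VAE
52 | bundled under `MeshAnything/miche/michelangelo` can round-trip a sampled surface
53 | into a mesh via the `recon` method of the aligned shape-as-latent module
54 | (`michelangelo/models/tsal/asl_pl_module.py`). `load_shape_vae` is a stand-in for
55 | the checkpoint-loading code (see `MeshAnything/miche/encode.py` and
56 | `shapevae-256.yaml`), not a function this repo exports; the remaining calls use
57 | only methods defined in `michelangelo/models/tsal`.
58 | 
59 | ```python
60 | import numpy as np
61 | import torch
62 | import trimesh
63 | 
64 | model = load_shape_vae()  # hypothetical loader returning the shape VAE module
65 | model.eval()
66 | 
67 | # recon() expects an oriented point cloud of shape [bs, n, 6] = xyz + normals.
68 | mesh = trimesh.load("examples/wand.obj", force="mesh")
69 | points, face_idx = trimesh.sample.sample_surface(mesh, 4096)
70 | normals = mesh.face_normals[face_idx]
71 | surface = torch.from_numpy(np.concatenate([points, normals], axis=-1)).float()[None]
72 | 
73 | with torch.no_grad():
74 |     recon_mesh = model.recon(surface)  # trimesh.Trimesh from marching cubes
75 | recon_mesh.export("recon_wand.obj")
76 | ```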