├── LICENSE ├── README.md ├── configs ├── aligned_shape_latents │ └── shapevae-256.yaml ├── image_cond_diffuser_asl │ └── image-ASLDM-256.yaml └── text_cond_diffuser_asl │ └── text-ASLDM-256.yaml ├── example_data ├── image │ └── car.jpg └── surface │ └── surface.npz ├── inference.py ├── michelangelo ├── __init__.py ├── data │ ├── __init__.py │ ├── templates.json │ ├── transforms.py │ └── utils.py ├── graphics │ ├── __init__.py │ └── primitives │ │ ├── __init__.py │ │ ├── mesh.py │ │ └── volume.py ├── models │ ├── __init__.py │ ├── asl_diffusion │ │ ├── __init__.py │ │ ├── asl_diffuser_pl_module.py │ │ ├── asl_udt.py │ │ ├── base.py │ │ ├── clip_asl_diffuser_pl_module.py │ │ └── inference_utils.py │ ├── conditional_encoders │ │ ├── __init__.py │ │ ├── clip.py │ │ └── encoder_factory.py │ ├── modules │ │ ├── __init__.py │ │ ├── checkpoint.py │ │ ├── diffusion_transformer.py │ │ ├── distributions.py │ │ ├── embedder.py │ │ ├── transformer_blocks.py │ │ └── transformer_vit.py │ └── tsal │ │ ├── __init__.py │ │ ├── asl_pl_module.py │ │ ├── clip_asl_module.py │ │ ├── inference_utils.py │ │ ├── loss.py │ │ ├── sal_perceiver.py │ │ ├── sal_pl_module.py │ │ └── tsal_base.py └── utils │ ├── __init__.py │ ├── eval.py │ ├── io.py │ ├── misc.py │ └── visualizers │ ├── __init__.py │ ├── color_util.py │ ├── html_util.py │ └── pythreejs_viewer.py ├── requirements.txt ├── scripts ├── infer.sh └── inference │ ├── image2mesh.sh │ ├── reconstruction.sh │ └── text2mesh.sh └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | # Michelangelo 2 | 3 | ## [Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation](https://neuralcarver.github.io/michelangelo)
4 | [Zibo Zhao](https://github.com/Maikouuu), 5 | [Wen Liu](https://github.com/StevenLiuWen), 6 | [Xin Chen](https://chenxin.tech/), 7 | [Xianfang Zeng](https://github.com/Zzlongjuanfeng), 8 | [Rui Wang](https://wrong.wang/), 9 | [Pei Cheng](https://neuralcarver.github.io/michelangelo), 10 | [Bin Fu](https://neuralcarver.github.io/michelangelo), 11 | [Tao Chen](https://eetchen.github.io), 12 | [Gang Yu](https://www.skicyyu.org), 13 | [Shenghua Gao](https://sist.shanghaitech.edu.cn/sist_en/2020/0814/c7582a54772/page.htm)
14 | ### [Hugging Face Demo](https://huggingface.co/spaces/Maikou/Michelangelo) | [Project Page](https://neuralcarver.github.io/michelangelo/) | [arXiv](https://arxiv.org/abs/2306.17115) | [Paper](https://openreview.net/pdf?id=xmxgMij3LY)
15 | 16 | https://github.com/NeuralCarver/Michelangelo/assets/37449470/123bae2c-fbb1-4d63-bd13-0e300a550868 17 | 18 | Visualization of 3D shapes produced by our framework, shown as triplets: the conditional input on the left, a normal map in the middle, and a triangle mesh on the right. The generated 3D shapes semantically conform to the visual or textual conditional inputs.
19 | 20 | ## 🔆 Features 21 | **Michelangelo** possesses three capabilities: 22 | 23 | 1. Representing a shape in the shape-image-text aligned space; 24 | 2. Image-conditioned Shape Generation; 25 | 3. Text-conditioned Shape Generation. 26 | 27 |
28 | Techniques 29 | 30 | We present a novel _alignment-before-generation_ approach to tackle the challenging task of generating general 3D shapes based on 2D images or texts. Directly learning a conditional generative model from images or texts to 3D shapes is prone to producing results that are inconsistent with the conditions, because 3D shapes have an additional dimension whose distribution significantly differs from that of 2D images and texts. To bridge the domain gap among the three modalities and facilitate multi-modal-conditioned 3D shape generation, we explore representing 3D shapes in a shape-image-text-aligned space. Our framework comprises two models: a Shape-Image-Text-Aligned Variational Auto-Encoder (SITA-VAE) and a conditional Aligned Shape Latent Diffusion Model (ASLDM). The former model encodes the 3D shapes into the shape latent space aligned to the image and text and reconstructs the fine-grained 3D neural fields corresponding to given shape embeddings via the transformer-based decoder. The latter model learns a probabilistic mapping function from the image or text space to the latent shape space. Our extensive experiments demonstrate that our proposed approach can generate higher-quality and more diverse 3D shapes that better semantically conform to the visual or textual conditional inputs, validating the effectiveness of the shape-image-text-aligned space for cross-modality 3D shape generation. 31 | 32 | ![newnetwork](https://github.com/NeuralCarver/Michelangelo/assets/16475892/d5231fb7-7768-45ee-92e1-3599a4c43a2c) 33 |
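For orientation, the sketch below mirrors how `inference.py` (included in this repository) wires the two stages together at inference time: SITA-VAE encodes a surface point cloud into the aligned latent space and decodes it back to a mesh, while ASLDM maps an image or text condition into that same latent space. It is a minimal outline rather than a turnkey script: the two `ckpt_path` values are placeholders, and a CUDA device plus the downloaded weights are assumed.

```python
# Minimal sketch of the alignment-before-generation pipeline (mirrors inference.py).
# The two ckpt_path values are placeholders; point them at the downloaded checkpoints.
from functools import partial
import numpy as np
import torch

from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
from michelangelo.models.tsal.inference_utils import extract_geometry

# Stage 1: SITA-VAE encodes a surface point cloud (xyz + normals) into the aligned latent
# space, then decodes it back to a neural field from which a mesh is extracted.
vae_cfg = get_config_from_file("configs/aligned_shape_latents/shapevae-256.yaml").model
vae = instantiate_from_config(vae_cfg, ckpt_path="checkpoints/shapevae-256.ckpt").cuda().eval()

pc = np.load("example_data/surface/surface.npz")
ind = np.random.default_rng(0).choice(pc["points"].shape[0], 4096, replace=False)
surface = np.concatenate([pc["points"][ind], pc["normals"][ind]], axis=-1)
surface = torch.from_numpy(surface).float().unsqueeze(0).cuda()

shape_embed, shape_latents = vae.model.encode_shape_embed(surface, return_latents=True)
shape_zq, _ = vae.model.shape_model.encode_kl_embed(shape_latents)
latents = vae.model.shape_model.decode(shape_zq)
geometric_func = partial(vae.model.shape_model.query_geometry, latents=latents)
mesh_v_f, _ = extract_geometry(geometric_func=geometric_func, device=surface.device, batch_size=1,
                               bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                               octree_depth=7, num_chunks=10000)

# Stage 2: ASLDM diffuses an image or text condition into the same latent space and
# reuses the SITA-VAE decoder to produce meshes.
asldm_cfg = get_config_from_file("configs/text_cond_diffuser_asl/text-ASLDM-256.yaml").model
asldm = instantiate_from_config(asldm_cfg, ckpt_path="checkpoints/text-ASLDM-256.ckpt").cuda().eval()
meshes = asldm.sample({"text": ["A 3D model of motorcar; Porsche 911."]},
                      sample_times=1, guidance_scale=7.5)[0]
```

In practice you do not need to write this yourself; the scripts under `scripts/inference/` call `inference.py` with the corresponding config and checkpoint.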
34 | 35 | ## 📰 News 36 | - [2024/1/23] Set up the Hugging Face demo and released the code 37 | - [2023/09/22] **Michelangelo got accepted by NeurIPS 2023!** 38 | - [2023/6/29] Uploaded the paper and initialized the project 39 | 40 | ## ⚙️ Setup 41 | 42 | ### Installation 43 | Follow the commands below to set up the environment. We have tested the installation on Tesla V100 and Tesla T4 GPUs. 44 | ``` 45 | git clone https://github.com/NeuralCarver/Michelangelo.git 46 | cd Michelangelo 47 | conda create --name Michelangelo python=3.9 48 | conda activate Michelangelo 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | ### Checkpoints 53 | Please download the weights from the Hugging Face model repository and put them in the root folder. We have also uploaded the CLIP weights to facilitate quick usage. 54 | 55 |
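If you prefer to script the download, a minimal sketch with `huggingface_hub` is shown below. The repo id is an assumption (substitute the repository that actually hosts the released weights); the only firm requirement is that CLIP ends up under `./checkpoints/clip/clip-vit-large-patch14`, the path referenced by the configs.

```python
# Sketch only: the repo_id below is a placeholder, not a confirmed repository name.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Maikou/Michelangelo",  # hypothetical repo id; replace with the real one
    local_dir=".",                  # download into the repo root so ./checkpoints/... is in place
)
```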
56 | 57 | Tips for debugging the configuration 58 | 59 | 60 | - If something goes wrong during the environment setup, consider skipping the problematic packages, such as pysdf, torch-cluster, and torch-scatter; they are not needed for the commands we provide. 61 | - If you encounter any issues while downloading CLIP, you can download it directly from [CLIP's Hugging Face page](https://huggingface.co/openai/clip-vit-large-patch14). Once the download is complete, remember to modify line [26](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L26C1-L26C76) and line [34](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L34) in the config file so that they point to the correct CLIP path (a quick way to verify the path is sketched right after these tips). 62 | - From [issue 6](https://github.com/NeuralCarver/Michelangelo/issues/6#issuecomment-1913513382): Windows users running WSL2 with Ubuntu 22.04 may run into issues. As discussed in [WSL issue 8587](https://github.com/microsoft/WSL/issues/8587), it is just a matter of adding this line to your .bashrc: 63 | ``` 64 | export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH 65 | ``` 66 |
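If you downloaded CLIP manually, a quick sanity check before running any script is to load it with `transformers` (a small sketch, assuming the default `./checkpoints/clip/clip-vit-large-patch14` location referenced by the configs):

```python
# Verify that the local CLIP folder is complete and sits at the path the configs reference.
from transformers import CLIPModel, CLIPProcessor

clip_path = "./checkpoints/clip/clip-vit-large-patch14"
clip_model = CLIPModel.from_pretrained(clip_path)
clip_processor = CLIPProcessor.from_pretrained(clip_path)
print(clip_model.config.projection_dim)  # 768 for ViT-L/14
```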
67 | 68 | ## ⚡ Quick Start 69 | 70 | ### Inference 71 | 72 | #### Reconstruct a 3D shape 73 | ``` 74 | ./scripts/inference/reconstruction.sh 75 | ``` 76 | 77 | #### Image-conditioned shape generation 78 | ``` 79 | ./scripts/inference/image2mesh.sh 80 | ``` 81 | 82 | #### Text-conditioned shape generation 83 | ``` 84 | ./scripts/inference/text2mesh.sh 85 | ``` 86 | 87 | #### Simply run all the scripts 88 | ``` 89 | ./scripts/infer.sh 90 | ``` 91 | 92 | 93 | ## ❓ FAQ 94 | 95 | ## Citation 96 | 97 | If you find our code or paper helpful, please consider citing: 98 | 99 | ```bibtex 100 | @inproceedings{ 101 | zhao2023michelangelo, 102 | title={Michelangelo: Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation}, 103 | author={Zibo Zhao and Wen Liu and Xin Chen and Xianfang Zeng and Rui Wang and Pei Cheng and BIN FU and Tao Chen and Gang YU and Shenghua Gao}, 104 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 105 | year={2023}, 106 | url={https://openreview.net/forum?id=xmxgMij3LY} 107 | } 108 | ``` 109 | 110 | ## License 111 | 112 | This code is distributed under the [GPL-3.0 license](LICENSE). 113 | 114 | -------------------------------------------------------------------------------- /configs/aligned_shape_latents/shapevae-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 3 | params: 4 | shape_module_cfg: 5 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 6 | params: 7 | num_latents: 256 8 | embed_dim: 64 9 | point_feats: 3 # normal 10 | num_freqs: 8 11 | include_pi: false 12 | heads: 12 13 | width: 768 14 | num_encoder_layers: 8 15 | num_decoder_layers: 16 16 | use_ln_post: true 17 | init_scale: 0.25 18 | qkv_bias: false 19 | use_checkpoint: true 20 | aligned_module_cfg: 21 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 22 | params: 23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 24 | 25 | loss_cfg: 26 | target: michelangelo.models.tsal.loss.ContrastKLNearFar 27 | params: 28 | contrast_weight: 0.1 29 | near_weight: 0.1 30 | kl_weight: 0.001 31 | 32 | optimizer_cfg: 33 | optimizer: 34 | target: torch.optim.AdamW 35 | params: 36 | betas: [0.9, 0.99] 37 | eps: 1.e-6 38 | weight_decay: 1.e-2 39 | 40 | scheduler: 41 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 42 | params: 43 | warm_up_steps: 5000 44 | f_start: 1.e-6 45 | f_min: 1.e-3 46 | f_max: 1.0 47 | -------------------------------------------------------------------------------- /configs/image_cond_diffuser_asl/image-ASLDM-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser 3 | params: 4 | first_stage_config: 5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 6 | params: 7 | shape_module_cfg: 8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 9 | params: 10 | num_latents: &num_latents 256 11 | embed_dim: &embed_dim 64 12 | point_feats: 3 # normal 13 | num_freqs: 8 14 | include_pi: false 15 | heads: 12 16 | width: 768 17 | num_encoder_layers: 8 18 | num_decoder_layers: 16 19 | use_ln_post: true 20 | init_scale: 0.25 21 | qkv_bias: false 22 | use_checkpoint: false 23 | aligned_module_cfg: 24 | target: 
michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 25 | params: 26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 27 | 28 | loss_cfg: 29 | target: torch.nn.Identity 30 | 31 | cond_stage_config: 32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder 33 | params: 34 | version: "./checkpoints/clip/clip-vit-large-patch14" 35 | zero_embedding_radio: 0.1 36 | 37 | first_stage_key: "surface" 38 | cond_stage_key: "image" 39 | scale_by_std: false 40 | 41 | denoiser_cfg: 42 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser 43 | params: 44 | input_channels: *embed_dim 45 | output_channels: *embed_dim 46 | n_ctx: *num_latents 47 | width: 768 48 | layers: 6 # 2 * 6 + 1 = 13 49 | heads: 12 50 | context_dim: 1024 51 | init_scale: 1.0 52 | skip_ln: true 53 | use_checkpoint: true 54 | 55 | scheduler_cfg: 56 | guidance_scale: 7.5 57 | num_inference_steps: 50 58 | eta: 0.0 59 | 60 | noise: 61 | target: diffusers.schedulers.DDPMScheduler 62 | params: 63 | num_train_timesteps: 1000 64 | beta_start: 0.00085 65 | beta_end: 0.012 66 | beta_schedule: "scaled_linear" 67 | variance_type: "fixed_small" 68 | clip_sample: false 69 | denoise: 70 | target: diffusers.schedulers.DDIMScheduler 71 | params: 72 | num_train_timesteps: 1000 73 | beta_start: 0.00085 74 | beta_end: 0.012 75 | beta_schedule: "scaled_linear" 76 | clip_sample: false # clip sample to -1~1 77 | set_alpha_to_one: false 78 | steps_offset: 1 79 | 80 | optimizer_cfg: 81 | optimizer: 82 | target: torch.optim.AdamW 83 | params: 84 | betas: [0.9, 0.99] 85 | eps: 1.e-6 86 | weight_decay: 1.e-2 87 | 88 | scheduler: 89 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 90 | params: 91 | warm_up_steps: 5000 92 | f_start: 1.e-6 93 | f_min: 1.e-3 94 | f_max: 1.0 95 | 96 | loss_cfg: 97 | loss_type: "mse" 98 | -------------------------------------------------------------------------------- /configs/text_cond_diffuser_asl/text-ASLDM-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser 3 | params: 4 | first_stage_config: 5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 6 | params: 7 | shape_module_cfg: 8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 9 | params: 10 | num_latents: &num_latents 256 11 | embed_dim: &embed_dim 64 12 | point_feats: 3 # normal 13 | num_freqs: 8 14 | include_pi: false 15 | heads: 12 16 | width: 768 17 | num_encoder_layers: 8 18 | num_decoder_layers: 16 19 | use_ln_post: true 20 | init_scale: 0.25 21 | qkv_bias: false 22 | use_checkpoint: true 23 | aligned_module_cfg: 24 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 25 | params: 26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 27 | 28 | loss_cfg: 29 | target: torch.nn.Identity 30 | 31 | cond_stage_config: 32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder 33 | params: 34 | version: "./checkpoints/clip/clip-vit-large-patch14" 35 | zero_embedding_radio: 0.1 36 | max_length: 77 37 | 38 | first_stage_key: "surface" 39 | cond_stage_key: "text" 40 | scale_by_std: false 41 | 42 | denoiser_cfg: 43 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser 44 | params: 45 | input_channels: *embed_dim 46 | output_channels: *embed_dim 47 | 
n_ctx: *num_latents 48 | width: 768 49 | layers: 8 # 2 * 8 + 1 = 17 50 | heads: 12 51 | context_dim: 768 52 | init_scale: 1.0 53 | skip_ln: true 54 | use_checkpoint: true 55 | 56 | scheduler_cfg: 57 | guidance_scale: 7.5 58 | num_inference_steps: 50 59 | eta: 0.0 60 | 61 | noise: 62 | target: diffusers.schedulers.DDPMScheduler 63 | params: 64 | num_train_timesteps: 1000 65 | beta_start: 0.00085 66 | beta_end: 0.012 67 | beta_schedule: "scaled_linear" 68 | variance_type: "fixed_small" 69 | clip_sample: false 70 | denoise: 71 | target: diffusers.schedulers.DDIMScheduler 72 | params: 73 | num_train_timesteps: 1000 74 | beta_start: 0.00085 75 | beta_end: 0.012 76 | beta_schedule: "scaled_linear" 77 | clip_sample: false # clip sample to -1~1 78 | set_alpha_to_one: false 79 | steps_offset: 1 80 | 81 | optimizer_cfg: 82 | optimizer: 83 | target: torch.optim.AdamW 84 | params: 85 | betas: [0.9, 0.99] 86 | eps: 1.e-6 87 | weight_decay: 1.e-2 88 | 89 | scheduler: 90 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 91 | params: 92 | warm_up_steps: 5000 93 | f_start: 1.e-6 94 | f_min: 1.e-3 95 | f_max: 1.0 96 | 97 | loss_cfg: 98 | loss_type: "mse" -------------------------------------------------------------------------------- /example_data/image/car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralCarver/Michelangelo/6d83b0bacef92715dd5179d45647ed9a3d39bc95/example_data/image/car.jpg -------------------------------------------------------------------------------- /example_data/surface/surface.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralCarver/Michelangelo/6d83b0bacef92715dd5179d45647ed9a3d39bc95/example_data/surface/surface.npz -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | from collections import OrderedDict 5 | from typing import Optional, List 6 | import argparse 7 | from functools import partial 8 | 9 | from einops import repeat, rearrange 10 | import numpy as np 11 | from PIL import Image 12 | import trimesh 13 | import cv2 14 | 15 | import torch 16 | import pytorch_lightning as pl 17 | 18 | from michelangelo.models.tsal.tsal_base import Latent2MeshOutput 19 | from michelangelo.models.tsal.inference_utils import extract_geometry 20 | from michelangelo.utils.misc import get_config_from_file, instantiate_from_config 21 | from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer 22 | from michelangelo.utils.visualizers import html_util 23 | 24 | def load_model(args): 25 | 26 | model_config = get_config_from_file(args.config_path) 27 | if hasattr(model_config, "model"): 28 | model_config = model_config.model 29 | 30 | model = instantiate_from_config(model_config, ckpt_path=args.ckpt_path) 31 | model = model.cuda() 32 | model = model.eval() 33 | 34 | return model 35 | 36 | def load_surface(fp): 37 | 38 | with np.load(fp) as input_pc: 39 | surface = input_pc['points'] 40 | normal = input_pc['normals'] 41 | 42 | rng = np.random.default_rng() 43 | ind = rng.choice(surface.shape[0], 4096, replace=False) 44 | surface = torch.FloatTensor(surface[ind]) 45 | normal = torch.FloatTensor(normal[ind]) 46 | 47 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() 48 | 49 | return 
surface 50 | 51 | def prepare_image(args, number_samples=2): 52 | 53 | image = cv2.imread(f"{args.image_path}") 54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 55 | 56 | image_pt = torch.tensor(image).float() 57 | image_pt = image_pt / 255 * 2 - 1 58 | image_pt = rearrange(image_pt, "h w c -> c h w") 59 | 60 | image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples) 61 | 62 | return image_pt 63 | 64 | def save_output(args, mesh_outputs): 65 | 66 | os.makedirs(args.output_dir, exist_ok=True) 67 | for i, mesh in enumerate(mesh_outputs): 68 | mesh.mesh_f = mesh.mesh_f[:, ::-1] 69 | mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f) 70 | 71 | name = str(i) + "_out_mesh.obj" 72 | mesh_output.export(os.path.join(args.output_dir, name), include_normals=True) 73 | 74 | print(f'-----------------------------------------------------------------------------') 75 | print(f'>>> Finished and mesh saved in {args.output_dir}') 76 | print(f'-----------------------------------------------------------------------------') 77 | 78 | return 0 79 | 80 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): 81 | 82 | surface = load_surface(args.pointcloud_path) 83 | 84 | # encoding 85 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) 86 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) 87 | 88 | # decoding 89 | latents = model.model.shape_model.decode(shape_zq) 90 | geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) 91 | 92 | # reconstruction 93 | mesh_v_f, has_surface = extract_geometry( 94 | geometric_func=geometric_func, 95 | device=surface.device, 96 | batch_size=surface.shape[0], 97 | bounds=bounds, 98 | octree_depth=octree_depth, 99 | num_chunks=num_chunks, 100 | ) 101 | recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1]) 102 | 103 | # save 104 | os.makedirs(args.output_dir, exist_ok=True) 105 | recon_mesh.export(os.path.join(args.output_dir, 'reconstruction.obj')) 106 | 107 | print(f'-----------------------------------------------------------------------------') 108 | print(f'>>> Finished and mesh saved in {os.path.join(args.output_dir, "reconstruction.obj")}') 109 | print(f'-----------------------------------------------------------------------------') 110 | 111 | return 0 112 | 113 | def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7): 114 | 115 | sample_inputs = { 116 | "image": prepare_image(args) 117 | } 118 | 119 | mesh_outputs = model.sample( 120 | sample_inputs, 121 | sample_times=1, 122 | guidance_scale=guidance_scale, 123 | return_intermediates=False, 124 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v], 125 | octree_depth=octree_depth, 126 | )[0] 127 | 128 | save_output(args, mesh_outputs) 129 | 130 | return 0 131 | 132 | def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7): 133 | 134 | sample_inputs = { 135 | "text": [args.text] * num_samples 136 | } 137 | mesh_outputs = model.sample( 138 | sample_inputs, 139 | sample_times=1, 140 | guidance_scale=guidance_scale, 141 | return_intermediates=False, 142 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v], 143 | octree_depth=octree_depth, 144 | )[0] 145 | 146 | save_output(args, mesh_outputs) 147 | 148 | return 0 149 | 150 | task_dick = { 151 | 'reconstruction': reconstruction, 152 | 'image2mesh': image2mesh, 153 | 'text2mesh': text2mesh, 154 | } 155 | 156 | if __name__ == "__main__": 157 | ''' 
158 | 1. Reconstruct point cloud 159 | 2. Image-conditioned generation 160 | 3. Text-conditioned generation 161 | ''' 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--task", type=str, choices=['reconstruction', 'image2mesh', 'text2mesh'], required=True) 164 | parser.add_argument("--config_path", type=str, required=True) 165 | parser.add_argument("--ckpt_path", type=str, required=True) 166 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud') 167 | parser.add_argument("--image_path", type=str, help='Path to the input image') 168 | parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.') 169 | parser.add_argument("--output_dir", type=str, default='./output') 170 | parser.add_argument("-s", "--seed", type=int, default=0) 171 | args = parser.parse_args() 172 | 173 | pl.seed_everything(args.seed) 174 | 175 | print(f'-----------------------------------------------------------------------------') 176 | print(f'>>> Running {args.task}') 177 | args.output_dir = os.path.join(args.output_dir, args.task) 178 | print(f'>>> Output directory: {args.output_dir}') 179 | print(f'-----------------------------------------------------------------------------') 180 | 181 | task_dick[args.task](args, load_model(args)) -------------------------------------------------------------------------------- /michelangelo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /michelangelo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /michelangelo/data/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "shape": [ 3 | "a point cloud model of {}.", 4 | "There is a {} in the scene.", 5 | "There is the {} in the scene.", 6 | "a photo of a {} in the scene.", 7 | "a photo of the {} in the scene.", 8 | "a photo of one {} in the scene.", 9 | "itap of a {}.", 10 | "itap of my {}.", 11 | "itap of the {}.", 12 | "a photo of a {}.", 13 | "a photo of my {}.", 14 | "a photo of the {}.", 15 | "a photo of one {}.", 16 | "a photo of many {}.", 17 | "a good photo of a {}.", 18 | "a good photo of the {}.", 19 | "a bad photo of a {}.", 20 | "a bad photo of the {}.", 21 | "a photo of a nice {}.", 22 | "a photo of the nice {}.", 23 | "a photo of a cool {}.", 24 | "a photo of the cool {}.", 25 | "a photo of a weird {}.", 26 | "a photo of the weird {}.", 27 | "a photo of a small {}.", 28 | "a photo of the small {}.", 29 | "a photo of a large {}.", 30 | "a photo of the large {}.", 31 | "a photo of a clean {}.", 32 | "a photo of the clean {}.", 33 | "a photo of a dirty {}.", 34 | "a photo of the dirty {}.", 35 | "a bright photo of a {}.", 36 | "a bright photo of the {}.", 37 | "a dark photo of a {}.", 38 | "a dark photo of the {}.", 39 | "a photo of a hard to see {}.", 40 | "a photo of the hard to see {}.", 41 | "a low resolution photo of a {}.", 42 | "a low resolution photo of the {}.", 43 | "a cropped photo of a {}.", 44 | "a cropped photo of the {}.", 45 | "a close-up photo of a {}.", 46 | "a close-up photo of the {}.", 47 | "a jpeg corrupted photo of a {}.", 48 | "a jpeg corrupted photo of the {}.", 49 | "a blurry photo of a 
{}.", 50 | "a blurry photo of the {}.", 51 | "a pixelated photo of a {}.", 52 | "a pixelated photo of the {}.", 53 | "a black and white photo of the {}.", 54 | "a black and white photo of a {}", 55 | "a plastic {}.", 56 | "the plastic {}.", 57 | "a toy {}.", 58 | "the toy {}.", 59 | "a plushie {}.", 60 | "the plushie {}.", 61 | "a cartoon {}.", 62 | "the cartoon {}.", 63 | "an embroidered {}.", 64 | "the embroidered {}.", 65 | "a painting of the {}.", 66 | "a painting of a {}." 67 | ] 68 | 69 | } -------------------------------------------------------------------------------- /michelangelo/data/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import numpy as np 5 | import warnings 6 | import random 7 | from omegaconf.listconfig import ListConfig 8 | from webdataset import pipelinefilter 9 | import torch 10 | import torchvision.transforms.functional as TVF 11 | from torchvision.transforms import InterpolationMode 12 | from torchvision.transforms.transforms import _interpolation_modes_from_int 13 | from typing import Sequence 14 | 15 | from michelangelo.utils import instantiate_from_config 16 | 17 | 18 | def _uid_buffer_pick(buf_dict, rng): 19 | uid_keys = list(buf_dict.keys()) 20 | selected_uid = rng.choice(uid_keys) 21 | buf = buf_dict[selected_uid] 22 | 23 | k = rng.randint(0, len(buf) - 1) 24 | sample = buf[k] 25 | buf[k] = buf[-1] 26 | buf.pop() 27 | 28 | if len(buf) == 0: 29 | del buf_dict[selected_uid] 30 | 31 | return sample 32 | 33 | 34 | def _add_to_buf_dict(buf_dict, sample): 35 | key = sample["__key__"] 36 | uid, uid_sample_id = key.split("_") 37 | if uid not in buf_dict: 38 | buf_dict[uid] = [] 39 | buf_dict[uid].append(sample) 40 | 41 | return buf_dict 42 | 43 | 44 | def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): 45 | """Shuffle the data in the stream. 46 | 47 | This uses a buffer of size `bufsize`. Shuffling at 48 | startup is less random; this is traded off against 49 | yielding samples quickly. 50 | 51 | data: iterator 52 | bufsize: buffer size for shuffling 53 | returns: iterator 54 | rng: either random module or random.Random instance 55 | 56 | """ 57 | if rng is None: 58 | rng = random.Random(int((os.getpid() + time.time()) * 1e9)) 59 | initial = min(initial, bufsize) 60 | buf_dict = dict() 61 | current_samples = 0 62 | for sample in data: 63 | _add_to_buf_dict(buf_dict, sample) 64 | current_samples += 1 65 | 66 | if current_samples < bufsize: 67 | try: 68 | _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708 69 | current_samples += 1 70 | except StopIteration: 71 | pass 72 | 73 | if current_samples >= initial: 74 | current_samples -= 1 75 | yield _uid_buffer_pick(buf_dict, rng) 76 | 77 | while current_samples > 0: 78 | current_samples -= 1 79 | yield _uid_buffer_pick(buf_dict, rng) 80 | 81 | 82 | uid_shuffle = pipelinefilter(_uid_shuffle) 83 | 84 | 85 | class RandomSample(object): 86 | def __init__(self, 87 | num_volume_samples: int = 1024, 88 | num_near_samples: int = 1024): 89 | 90 | super().__init__() 91 | 92 | self.num_volume_samples = num_volume_samples 93 | self.num_near_samples = num_near_samples 94 | 95 | def __call__(self, sample): 96 | rng = np.random.default_rng() 97 | 98 | # 1. sample surface input 99 | total_surface = sample["surface"] 100 | ind = rng.choice(total_surface.shape[0], replace=False) 101 | surface = total_surface[ind] 102 | 103 | # 2. 
sample volume/near geometric points 104 | vol_points = sample["vol_points"] 105 | vol_label = sample["vol_label"] 106 | near_points = sample["near_points"] 107 | near_label = sample["near_label"] 108 | 109 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) 110 | vol_points = vol_points[ind] 111 | vol_label = vol_label[ind] 112 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) 113 | 114 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) 115 | near_points = near_points[ind] 116 | near_label = near_label[ind] 117 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) 118 | 119 | # concat sampled volume and near points 120 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) 121 | 122 | sample = { 123 | "surface": surface, 124 | "geo_points": geo_points 125 | } 126 | 127 | return sample 128 | 129 | 130 | class SplitRandomSample(object): 131 | def __init__(self, 132 | use_surface_sample: bool = False, 133 | num_surface_samples: int = 4096, 134 | num_volume_samples: int = 1024, 135 | num_near_samples: int = 1024): 136 | 137 | super().__init__() 138 | 139 | self.use_surface_sample = use_surface_sample 140 | self.num_surface_samples = num_surface_samples 141 | self.num_volume_samples = num_volume_samples 142 | self.num_near_samples = num_near_samples 143 | 144 | def __call__(self, sample): 145 | 146 | rng = np.random.default_rng() 147 | 148 | # 1. sample surface input 149 | surface = sample["surface"] 150 | 151 | if self.use_surface_sample: 152 | replace = surface.shape[0] < self.num_surface_samples 153 | ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace) 154 | surface = surface[ind] 155 | 156 | # 2. 
sample volume/near geometric points 157 | vol_points = sample["vol_points"] 158 | vol_label = sample["vol_label"] 159 | near_points = sample["near_points"] 160 | near_label = sample["near_label"] 161 | 162 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) 163 | vol_points = vol_points[ind] 164 | vol_label = vol_label[ind] 165 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) 166 | 167 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) 168 | near_points = near_points[ind] 169 | near_label = near_label[ind] 170 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) 171 | 172 | # concat sampled volume and near points 173 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) 174 | 175 | sample = { 176 | "surface": surface, 177 | "geo_points": geo_points 178 | } 179 | 180 | return sample 181 | 182 | 183 | class FeatureSelection(object): 184 | 185 | VALID_SURFACE_FEATURE_DIMS = { 186 | "none": [0, 1, 2], # xyz 187 | "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal 188 | "normal": [0, 1, 2, 6, 7, 8] 189 | } 190 | 191 | def __init__(self, surface_feature_type: str): 192 | 193 | self.surface_feature_type = surface_feature_type 194 | self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type] 195 | 196 | def __call__(self, sample): 197 | sample["surface"] = sample["surface"][:, self.surface_dims] 198 | return sample 199 | 200 | 201 | class AxisScaleTransform(object): 202 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): 203 | assert isinstance(interval, (tuple, list, ListConfig)) 204 | self.interval = interval 205 | self.min_val = interval[0] 206 | self.max_val = interval[1] 207 | self.inter_size = interval[1] - interval[0] 208 | self.jitter = jitter 209 | self.jitter_scale = jitter_scale 210 | 211 | def __call__(self, sample): 212 | 213 | surface = sample["surface"][..., 0:3] 214 | geo_points = sample["geo_points"][..., 0:3] 215 | 216 | scaling = torch.rand(1, 3) * self.inter_size + self.min_val 217 | # print(scaling) 218 | surface = surface * scaling 219 | geo_points = geo_points * scaling 220 | 221 | scale = (1 / torch.abs(surface).max().item()) * 0.999999 222 | surface *= scale 223 | geo_points *= scale 224 | 225 | if self.jitter: 226 | surface += self.jitter_scale * torch.randn_like(surface) 227 | surface.clamp_(min=-1.015, max=1.015) 228 | 229 | sample["surface"][..., 0:3] = surface 230 | sample["geo_points"][..., 0:3] = geo_points 231 | 232 | return sample 233 | 234 | 235 | class ToTensor(object): 236 | 237 | def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")): 238 | self.tensor_keys = tensor_keys 239 | 240 | def __call__(self, sample): 241 | for key in self.tensor_keys: 242 | if key not in sample: 243 | continue 244 | 245 | sample[key] = torch.tensor(sample[key], dtype=torch.float32) 246 | 247 | return sample 248 | 249 | 250 | class AxisScale(object): 251 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): 252 | assert isinstance(interval, (tuple, list, ListConfig)) 253 | self.interval = interval 254 | self.jitter = jitter 255 | self.jitter_scale = jitter_scale 256 | 257 | def __call__(self, surface, *args): 258 | scaling = torch.rand(1, 3) * 0.5 + 0.75 259 | # print(scaling) 260 | surface = surface * scaling 261 | scale = (1 / torch.abs(surface).max().item()) * 0.999999 262 | surface *= scale 263 | 264 | args_outputs = [] 265 | for _arg in args: 266 
| _arg = _arg * scaling * scale 267 | args_outputs.append(_arg) 268 | 269 | if self.jitter: 270 | surface += self.jitter_scale * torch.randn_like(surface) 271 | surface.clamp_(min=-1, max=1) 272 | 273 | if len(args) == 0: 274 | return surface 275 | else: 276 | return surface, *args_outputs 277 | 278 | 279 | class RandomResize(torch.nn.Module): 280 | """Apply randomly Resize with a given probability.""" 281 | 282 | def __init__( 283 | self, 284 | size, 285 | resize_radio=(0.5, 1), 286 | allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR), 287 | interpolation=InterpolationMode.BICUBIC, 288 | max_size=None, 289 | antialias=None, 290 | ): 291 | super().__init__() 292 | if not isinstance(size, (int, Sequence)): 293 | raise TypeError(f"Size should be int or sequence. Got {type(size)}") 294 | if isinstance(size, Sequence) and len(size) not in (1, 2): 295 | raise ValueError("If size is a sequence, it should have 1 or 2 values") 296 | 297 | self.size = size 298 | self.max_size = max_size 299 | # Backward compatibility with integer value 300 | if isinstance(interpolation, int): 301 | warnings.warn( 302 | "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " 303 | "Please use InterpolationMode enum." 304 | ) 305 | interpolation = _interpolation_modes_from_int(interpolation) 306 | 307 | self.interpolation = interpolation 308 | self.antialias = antialias 309 | 310 | self.resize_radio = resize_radio 311 | self.allow_resize_interpolations = allow_resize_interpolations 312 | 313 | def random_resize_params(self): 314 | radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0] 315 | 316 | if isinstance(self.size, int): 317 | size = int(self.size * radio) 318 | elif isinstance(self.size, Sequence): 319 | size = list(self.size) 320 | size = (int(size[0] * radio), int(size[1] * radio)) 321 | else: 322 | raise RuntimeError() 323 | 324 | interpolation = self.allow_resize_interpolations[ 325 | torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,)) 326 | ] 327 | return size, interpolation 328 | 329 | def forward(self, img): 330 | size, interpolation = self.random_resize_params() 331 | img = TVF.resize(img, size, interpolation, self.max_size, self.antialias) 332 | img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias) 333 | return img 334 | 335 | def __repr__(self) -> str: 336 | detail = f"(size={self.size}, interpolation={self.interpolation.value}," 337 | detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}" 338 | return f"{self.__class__.__name__}{detail}" 339 | 340 | 341 | class Compose(object): 342 | """Composes several transforms together. This transform does not support torchscript. 343 | Please, see the note below. 344 | 345 | Args: 346 | transforms (list of ``Transform`` objects): list of transforms to compose. 347 | 348 | Example: 349 | >>> transforms.Compose([ 350 | >>> transforms.CenterCrop(10), 351 | >>> transforms.ToTensor(), 352 | >>> ]) 353 | 354 | .. note:: 355 | In order to script the transformations, please use ``torch.nn.Sequential`` as below. 356 | 357 | >>> transforms = torch.nn.Sequential( 358 | >>> transforms.CenterCrop(10), 359 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 360 | >>> ) 361 | >>> scripted_transforms = torch.jit.script(transforms) 362 | 363 | Make sure to use only scriptable transformations, i.e. 
that work with ``torch.Tensor``, does not require 364 | `lambda` functions or ``PIL.Image``. 365 | 366 | """ 367 | 368 | def __init__(self, transforms): 369 | self.transforms = transforms 370 | 371 | def __call__(self, *args): 372 | for t in self.transforms: 373 | args = t(*args) 374 | return args 375 | 376 | def __repr__(self): 377 | format_string = self.__class__.__name__ + '(' 378 | for t in self.transforms: 379 | format_string += '\n' 380 | format_string += ' {0}'.format(t) 381 | format_string += '\n)' 382 | return format_string 383 | 384 | 385 | def identity(*args, **kwargs): 386 | if len(args) == 1: 387 | return args[0] 388 | else: 389 | return args 390 | 391 | 392 | def build_transforms(cfg): 393 | 394 | if cfg is None: 395 | return identity 396 | 397 | transforms = [] 398 | 399 | for transform_name, cfg_instance in cfg.items(): 400 | transform_instance = instantiate_from_config(cfg_instance) 401 | transforms.append(transform_instance) 402 | print(f"Build transform: {transform_instance}") 403 | 404 | transforms = Compose(transforms) 405 | 406 | return transforms 407 | 408 | -------------------------------------------------------------------------------- /michelangelo/data/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def worker_init_fn(_): 8 | worker_info = torch.utils.data.get_worker_info() 9 | worker_id = worker_info.id 10 | 11 | # dataset = worker_info.dataset 12 | # split_size = dataset.num_records // worker_info.num_workers 13 | # # reset num_records to the true number to retain reliable length information 14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1) 16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 17 | 18 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 19 | 20 | 21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True): 22 | """ 23 | 24 | Args: 25 | samples (list[dict]): 26 | combine_tensors: 27 | combine_scalars: 28 | 29 | Returns: 30 | 31 | """ 32 | 33 | result = {} 34 | 35 | keys = samples[0].keys() 36 | 37 | for key in keys: 38 | result[key] = [] 39 | 40 | for sample in samples: 41 | for key in keys: 42 | val = sample[key] 43 | result[key].append(val) 44 | 45 | for key in keys: 46 | val_list = result[key] 47 | if isinstance(val_list[0], (int, float)): 48 | if combine_scalars: 49 | result[key] = np.array(result[key]) 50 | 51 | elif isinstance(val_list[0], torch.Tensor): 52 | if combine_tensors: 53 | result[key] = torch.stack(val_list) 54 | 55 | elif isinstance(val_list[0], np.ndarray): 56 | if combine_tensors: 57 | result[key] = np.stack(val_list) 58 | 59 | return result 60 | -------------------------------------------------------------------------------- /michelangelo/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /michelangelo/graphics/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .volume import generate_dense_grid_points 4 | 5 | from .mesh import ( 6 | MeshOutput, 7 | save_obj, 8 | savemeshtes2 9 | ) 10 | -------------------------------------------------------------------------------- 
/michelangelo/graphics/primitives/mesh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import PIL.Image 7 | from typing import Optional 8 | 9 | import trimesh 10 | 11 | 12 | def save_obj(pointnp_px3, facenp_fx3, fname): 13 | fid = open(fname, "w") 14 | write_str = "" 15 | for pidx, p in enumerate(pointnp_px3): 16 | pp = p 17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2]) 18 | 19 | for i, f in enumerate(facenp_fx3): 20 | f1 = f + 1 21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2]) 22 | fid.write(write_str) 23 | fid.close() 24 | return 25 | 26 | 27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname): 28 | fol, na = os.path.split(fname) 29 | na, _ = os.path.splitext(na) 30 | 31 | matname = "%s/%s.mtl" % (fol, na) 32 | fid = open(matname, "w") 33 | fid.write("newmtl material_0\n") 34 | fid.write("Kd 1 1 1\n") 35 | fid.write("Ka 0 0 0\n") 36 | fid.write("Ks 0.4 0.4 0.4\n") 37 | fid.write("Ns 10\n") 38 | fid.write("illum 2\n") 39 | fid.write("map_Kd %s.png\n" % na) 40 | fid.close() 41 | #### 42 | 43 | fid = open(fname, "w") 44 | fid.write("mtllib %s.mtl\n" % na) 45 | 46 | for pidx, p in enumerate(pointnp_px3): 47 | pp = p 48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) 49 | 50 | for pidx, p in enumerate(tcoords_px2): 51 | pp = p 52 | fid.write("vt %f %f\n" % (pp[0], pp[1])) 53 | 54 | fid.write("usemtl material_0\n") 55 | for i, f in enumerate(facenp_fx3): 56 | f1 = f + 1 57 | f2 = facetex_fx3[i] + 1 58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) 59 | fid.close() 60 | 61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save( 62 | os.path.join(fol, "%s.png" % na)) 63 | 64 | return 65 | 66 | 67 | class MeshOutput(object): 68 | 69 | def __init__(self, 70 | mesh_v: np.ndarray, 71 | mesh_f: np.ndarray, 72 | vertex_colors: Optional[np.ndarray] = None, 73 | uvs: Optional[np.ndarray] = None, 74 | mesh_tex_idx: Optional[np.ndarray] = None, 75 | tex_map: Optional[np.ndarray] = None): 76 | 77 | self.mesh_v = mesh_v 78 | self.mesh_f = mesh_f 79 | self.vertex_colors = vertex_colors 80 | self.uvs = uvs 81 | self.mesh_tex_idx = mesh_tex_idx 82 | self.tex_map = tex_map 83 | 84 | def contain_uv_texture(self): 85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None) 86 | 87 | def contain_vertex_colors(self): 88 | return self.vertex_colors is not None 89 | 90 | def export(self, fname): 91 | 92 | if self.contain_uv_texture(): 93 | savemeshtes2( 94 | self.mesh_v, 95 | self.uvs, 96 | self.mesh_f, 97 | self.mesh_tex_idx, 98 | self.tex_map, 99 | fname 100 | ) 101 | 102 | elif self.contain_vertex_colors(): 103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors) 104 | mesh_obj.export(fname) 105 | 106 | else: 107 | save_obj( 108 | self.mesh_v, 109 | self.mesh_f, 110 | fname 111 | ) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /michelangelo/graphics/primitives/volume.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | 6 | def generate_dense_grid_points(bbox_min: np.ndarray, 7 | bbox_max: np.ndarray, 8 | octree_depth: int, 9 | indexing: str = "ij"): 10 | length = bbox_max - bbox_min 11 | num_cells = np.exp2(octree_depth) 12 | x = 
np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) 13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) 14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) 15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) 16 | xyz = np.stack((xs, ys, zs), axis=-1) 17 | xyz = xyz.reshape(-1, 3) 18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] 19 | 20 | return xyz, grid_size, length 21 | 22 | -------------------------------------------------------------------------------- /michelangelo/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /michelangelo/models/asl_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from omegaconf import DictConfig 4 | from typing import List, Tuple, Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | from einops import rearrange 14 | 15 | from diffusers.schedulers import ( 16 | DDPMScheduler, 17 | DDIMScheduler, 18 | KarrasVeScheduler, 19 | DPMSolverMultistepScheduler 20 | ) 21 | 22 | from michelangelo.utils import instantiate_from_config 23 | # from michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule 24 | from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule 25 | from michelangelo.models.asl_diffusion.inference_utils import ddim_sample 26 | 27 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] 28 | 29 | 30 | def disabled_train(self, mode=True): 31 | """Overwrite model.train with this function to make sure train/eval mode 32 | does not change anymore.""" 33 | return self 34 | 35 | 36 | class ASLDiffuser(pl.LightningModule): 37 | first_stage_model: Optional[AlignedShapeAsLatentPLModule] 38 | # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] 39 | model: nn.Module 40 | 41 | def __init__(self, *, 42 | first_stage_config, 43 | denoiser_cfg, 44 | scheduler_cfg, 45 | optimizer_cfg, 46 | loss_cfg, 47 | first_stage_key: str = "surface", 48 | cond_stage_key: str = "image", 49 | cond_stage_trainable: bool = True, 50 | scale_by_std: bool = False, 51 | z_scale_factor: float = 1.0, 52 | ckpt_path: Optional[str] = None, 53 | ignore_keys: Union[Tuple[str], List[str]] = ()): 54 | 55 | super().__init__() 56 | 57 | self.first_stage_key = first_stage_key 58 | self.cond_stage_key = cond_stage_key 59 | self.cond_stage_trainable = cond_stage_trainable 60 | 61 | # 1. initialize first stage. 62 | # Note: the condition model contained in the first stage model. 63 | self.first_stage_config = first_stage_config 64 | self.first_stage_model = None 65 | # self.instantiate_first_stage(first_stage_config) 66 | 67 | # 2. 
initialize conditional stage 68 | # self.instantiate_cond_stage(cond_stage_config) 69 | self.cond_stage_model = { 70 | "image": self.encode_image, 71 | "image_unconditional_embedding": self.empty_img_cond, 72 | "text": self.encode_text, 73 | "text_unconditional_embedding": self.empty_text_cond, 74 | "surface": self.encode_surface, 75 | "surface_unconditional_embedding": self.empty_surface_cond, 76 | } 77 | 78 | # 3. diffusion model 79 | self.model = instantiate_from_config( 80 | denoiser_cfg, device=None, dtype=None 81 | ) 82 | 83 | self.optimizer_cfg = optimizer_cfg 84 | 85 | # 4. scheduling strategy 86 | self.scheduler_cfg = scheduler_cfg 87 | 88 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) 89 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) 90 | 91 | # 5. loss configures 92 | self.loss_cfg = loss_cfg 93 | 94 | self.scale_by_std = scale_by_std 95 | if scale_by_std: 96 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) 97 | else: 98 | self.z_scale_factor = z_scale_factor 99 | 100 | self.ckpt_path = ckpt_path 101 | if ckpt_path is not None: 102 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 103 | 104 | def instantiate_first_stage(self, config): 105 | model = instantiate_from_config(config) 106 | self.first_stage_model = model.eval() 107 | self.first_stage_model.train = disabled_train 108 | for param in self.first_stage_model.parameters(): 109 | param.requires_grad = False 110 | 111 | self.first_stage_model = self.first_stage_model.to(self.device) 112 | 113 | # def instantiate_cond_stage(self, config): 114 | # if not self.cond_stage_trainable: 115 | # if config == "__is_first_stage__": 116 | # print("Using first stage also as cond stage.") 117 | # self.cond_stage_model = self.first_stage_model 118 | # elif config == "__is_unconditional__": 119 | # print(f"Training {self.__class__.__name__} as an unconditional model.") 120 | # self.cond_stage_model = None 121 | # # self.be_unconditional = True 122 | # else: 123 | # model = instantiate_from_config(config) 124 | # self.cond_stage_model = model.eval() 125 | # self.cond_stage_model.train = disabled_train 126 | # for param in self.cond_stage_model.parameters(): 127 | # param.requires_grad = False 128 | # else: 129 | # assert config != "__is_first_stage__" 130 | # assert config != "__is_unconditional__" 131 | # model = instantiate_from_config(config) 132 | # self.cond_stage_model = model 133 | 134 | def init_from_ckpt(self, path, ignore_keys=()): 135 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 136 | 137 | keys = list(state_dict.keys()) 138 | for k in keys: 139 | for ik in ignore_keys: 140 | if k.startswith(ik): 141 | print("Deleting key {} from state_dict.".format(k)) 142 | del state_dict[k] 143 | 144 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 145 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 146 | if len(missing) > 0: 147 | print(f"Missing Keys: {missing}") 148 | print(f"Unexpected Keys: {unexpected}") 149 | 150 | @property 151 | def zero_rank(self): 152 | if self._trainer: 153 | zero_rank = self.trainer.local_rank == 0 154 | else: 155 | zero_rank = True 156 | 157 | return zero_rank 158 | 159 | def configure_optimizers(self) -> Tuple[List, List]: 160 | 161 | lr = self.learning_rate 162 | 163 | trainable_parameters = list(self.model.parameters()) 164 | # if the conditional encoder is trainable 165 | 166 | # if self.cond_stage_trainable: 
167 | # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad] 168 | # trainable_parameters += conditioner_params 169 | # print(f"number of trainable conditional parameters: {len(conditioner_params)}.") 170 | 171 | if self.optimizer_cfg is None: 172 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 173 | schedulers = [] 174 | else: 175 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 176 | scheduler_func = instantiate_from_config( 177 | self.optimizer_cfg.scheduler, 178 | max_decay_steps=self.trainer.max_steps, 179 | lr_max=lr 180 | ) 181 | scheduler = { 182 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 183 | "interval": "step", 184 | "frequency": 1 185 | } 186 | optimizers = [optimizer] 187 | schedulers = [scheduler] 188 | 189 | return optimizers, schedulers 190 | 191 | @torch.no_grad() 192 | def encode_text(self, text): 193 | 194 | b = text.shape[0] 195 | text_tokens = rearrange(text, "b t l -> (b t) l") 196 | text_embed = self.first_stage_model.model.encode_text_embed(text_tokens) 197 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) 198 | text_embed = text_embed.mean(dim=1) 199 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 200 | 201 | return text_embed 202 | 203 | @torch.no_grad() 204 | def encode_image(self, img): 205 | 206 | return self.first_stage_model.model.encode_image_embed(img) 207 | 208 | @torch.no_grad() 209 | def encode_surface(self, surface): 210 | 211 | return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False) 212 | 213 | @torch.no_grad() 214 | def empty_text_cond(self, cond): 215 | 216 | return torch.zeros_like(cond, device=cond.device) 217 | 218 | @torch.no_grad() 219 | def empty_img_cond(self, cond): 220 | 221 | return torch.zeros_like(cond, device=cond.device) 222 | 223 | @torch.no_grad() 224 | def empty_surface_cond(self, cond): 225 | 226 | return torch.zeros_like(cond, device=cond.device) 227 | 228 | @torch.no_grad() 229 | def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): 230 | 231 | z_q = self.first_stage_model.encode(surface, sample_posterior) 232 | z_q = self.z_scale_factor * z_q 233 | 234 | return z_q 235 | 236 | @torch.no_grad() 237 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): 238 | 239 | z_q = 1. / self.z_scale_factor * z_q 240 | latents = self.first_stage_model.decode(z_q, **kwargs) 241 | return latents 242 | 243 | @rank_zero_only 244 | @torch.no_grad() 245 | def on_train_batch_start(self, batch, batch_idx): 246 | # only for very first batch 247 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ 248 | and batch_idx == 0 and self.ckpt_path is None: 249 | # set rescale weight to 1./std of encodings 250 | print("### USING STD-RESCALING ###") 251 | 252 | z_q = self.encode_first_stage(batch[self.first_stage_key]) 253 | z = z_q.detach() 254 | 255 | del self.z_scale_factor 256 | self.register_buffer("z_scale_factor", 1. 
/ z.flatten().std()) 257 | print(f"setting self.z_scale_factor to {self.z_scale_factor}") 258 | 259 | print("### USING STD-RESCALING ###") 260 | 261 | def compute_loss(self, model_outputs, split): 262 | """ 263 | 264 | Args: 265 | model_outputs (dict): 266 | - x_0: 267 | - noise: 268 | - noise_prior: 269 | - noise_pred: 270 | - noise_pred_prior: 271 | 272 | split (str): 273 | 274 | Returns: 275 | 276 | """ 277 | 278 | pred = model_outputs["pred"] 279 | 280 | if self.noise_scheduler.prediction_type == "epsilon": 281 | target = model_outputs["noise"] 282 | elif self.noise_scheduler.prediction_type == "sample": 283 | target = model_outputs["x_0"] 284 | else: 285 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") 286 | 287 | if self.loss_cfg.loss_type == "l1": 288 | simple = F.l1_loss(pred, target, reduction="mean") 289 | elif self.loss_cfg.loss_type in ["mse", "l2"]: 290 | simple = F.mse_loss(pred, target, reduction="mean") 291 | else: 292 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") 293 | 294 | total_loss = simple 295 | 296 | loss_dict = { 297 | f"{split}/total_loss": total_loss.clone().detach(), 298 | f"{split}/simple": simple.detach(), 299 | } 300 | 301 | return total_loss, loss_dict 302 | 303 | def forward(self, batch): 304 | """ 305 | 306 | Args: 307 | batch: 308 | 309 | Returns: 310 | 311 | """ 312 | 313 | if self.first_stage_model is None: 314 | self.instantiate_first_stage(self.first_stage_config) 315 | 316 | latents = self.encode_first_stage(batch[self.first_stage_key]) 317 | 318 | # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) 319 | 320 | conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1) 321 | 322 | mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1 323 | conditions = conditions * mask.to(conditions) 324 | 325 | # Sample noise that we"ll add to the latents 326 | # [batch_size, n_token, latent_dim] 327 | noise = torch.randn_like(latents) 328 | bs = latents.shape[0] 329 | # Sample a random timestep for each motion 330 | timesteps = torch.randint( 331 | 0, 332 | self.noise_scheduler.config.num_train_timesteps, 333 | (bs,), 334 | device=latents.device, 335 | ) 336 | timesteps = timesteps.long() 337 | # Add noise to the latents according to the noise magnitude at each timestep 338 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) 339 | 340 | # diffusion model forward 341 | noise_pred = self.model(noisy_z, timesteps, conditions) 342 | 343 | diffusion_outputs = { 344 | "x_0": noisy_z, 345 | "noise": noise, 346 | "pred": noise_pred 347 | } 348 | 349 | return diffusion_outputs 350 | 351 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], 352 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 353 | """ 354 | 355 | Args: 356 | batch (dict): the batch sample, and it contains: 357 | - surface (torch.FloatTensor): 358 | - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] 359 | - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] 360 | - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] 361 | - text (list of str): 362 | 363 | batch_idx (int): 364 | 365 | optimizer_idx (int): 366 | 367 | Returns: 368 | loss (torch.FloatTensor): 369 | 370 | """ 371 | 372 | diffusion_outputs = self(batch) 373 | 374 | loss, loss_dict = self.compute_loss(diffusion_outputs, 
"train") 375 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 376 | 377 | return loss 378 | 379 | def validation_step(self, batch: Dict[str, torch.FloatTensor], 380 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 381 | """ 382 | 383 | Args: 384 | batch (dict): the batch sample, and it contains: 385 | - surface_pc (torch.FloatTensor): [n_pts, 4] 386 | - surface_feats (torch.FloatTensor): [n_pts, c] 387 | - text (list of str): 388 | 389 | batch_idx (int): 390 | 391 | optimizer_idx (int): 392 | 393 | Returns: 394 | loss (torch.FloatTensor): 395 | 396 | """ 397 | 398 | diffusion_outputs = self(batch) 399 | 400 | loss, loss_dict = self.compute_loss(diffusion_outputs, "val") 401 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 402 | 403 | return loss 404 | 405 | @torch.no_grad() 406 | def sample(self, 407 | batch: Dict[str, Union[torch.FloatTensor, List[str]]], 408 | sample_times: int = 1, 409 | steps: Optional[int] = None, 410 | guidance_scale: Optional[float] = None, 411 | eta: float = 0.0, 412 | return_intermediates: bool = False, **kwargs): 413 | 414 | if self.first_stage_model is None: 415 | self.instantiate_first_stage(self.first_stage_config) 416 | 417 | if steps is None: 418 | steps = self.scheduler_cfg.num_inference_steps 419 | 420 | if guidance_scale is None: 421 | guidance_scale = self.scheduler_cfg.guidance_scale 422 | do_classifier_free_guidance = guidance_scale > 0 423 | 424 | # conditional encode 425 | xc = batch[self.cond_stage_key] 426 | # cond = self.cond_stage_model[self.cond_stage_key](xc) 427 | cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1) 428 | 429 | if do_classifier_free_guidance: 430 | """ 431 | Note: There are two kinds of uncond for text. 
432 | 1: using "" as uncond text; (in SAL diffusion) 433 | 2: zeros_like(cond) as uncond text; (in MDM) 434 | """ 435 | # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) 436 | un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond) 437 | # un_cond = torch.zeros_like(cond, device=cond.device) 438 | cond = torch.cat([un_cond, cond], dim=0) 439 | 440 | outputs = [] 441 | latents = None 442 | 443 | if not return_intermediates: 444 | for _ in range(sample_times): 445 | sample_loop = ddim_sample( 446 | self.denoise_scheduler, 447 | self.model, 448 | shape=self.first_stage_model.latent_shape, 449 | cond=cond, 450 | steps=steps, 451 | guidance_scale=guidance_scale, 452 | do_classifier_free_guidance=do_classifier_free_guidance, 453 | device=self.device, 454 | eta=eta, 455 | disable_prog=not self.zero_rank 456 | ) 457 | for sample, t in sample_loop: 458 | latents = sample 459 | outputs.append(self.decode_first_stage(latents, **kwargs)) 460 | else: 461 | 462 | sample_loop = ddim_sample( 463 | self.denoise_scheduler, 464 | self.model, 465 | shape=self.first_stage_model.latent_shape, 466 | cond=cond, 467 | steps=steps, 468 | guidance_scale=guidance_scale, 469 | do_classifier_free_guidance=do_classifier_free_guidance, 470 | device=self.device, 471 | eta=eta, 472 | disable_prog=not self.zero_rank 473 | ) 474 | 475 | iter_size = steps // sample_times 476 | i = 0 477 | for sample, t in sample_loop: 478 | latents = sample 479 | if i % iter_size == 0 or i == steps - 1: 480 | outputs.append(self.decode_first_stage(latents, **kwargs)) 481 | i += 1 482 | 483 | return outputs 484 | -------------------------------------------------------------------------------- /michelangelo/models/asl_diffusion/asl_udt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from diffusers.models.embeddings import Timesteps 7 | import math 8 | 9 | from michelangelo.models.modules.transformer_blocks import MLP 10 | from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer 11 | 12 | 13 | class ConditionalASLUDTDenoiser(nn.Module): 14 | 15 | def __init__(self, *, 16 | device: Optional[torch.device], 17 | dtype: Optional[torch.dtype], 18 | input_channels: int, 19 | output_channels: int, 20 | n_ctx: int, 21 | width: int, 22 | layers: int, 23 | heads: int, 24 | context_dim: int, 25 | context_ln: bool = True, 26 | skip_ln: bool = False, 27 | init_scale: float = 0.25, 28 | flip_sin_to_cos: bool = False, 29 | use_checkpoint: bool = False): 30 | super().__init__() 31 | 32 | self.use_checkpoint = use_checkpoint 33 | 34 | init_scale = init_scale * math.sqrt(1.0 / width) 35 | 36 | self.backbone = UNetDiffusionTransformer( 37 | device=device, 38 | dtype=dtype, 39 | n_ctx=n_ctx, 40 | width=width, 41 | layers=layers, 42 | heads=heads, 43 | skip_ln=skip_ln, 44 | init_scale=init_scale, 45 | use_checkpoint=use_checkpoint 46 | ) 47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) 49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) 50 | 51 | # timestep embedding 52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0) 53 | self.time_proj = MLP( 54 | device=device, dtype=dtype, width=width, init_scale=init_scale 55 | ) 56 | 57 | self.context_embed = 
nn.Sequential( 58 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 59 | nn.Linear(context_dim, width, device=device, dtype=dtype), 60 | ) 61 | 62 | if context_ln: 63 | self.context_embed = nn.Sequential( 64 | nn.LayerNorm(context_dim, device=device, dtype=dtype), 65 | nn.Linear(context_dim, width, device=device, dtype=dtype), 66 | ) 67 | else: 68 | self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype) 69 | 70 | def forward(self, 71 | model_input: torch.FloatTensor, 72 | timestep: torch.LongTensor, 73 | context: torch.FloatTensor): 74 | 75 | r""" 76 | Args: 77 | model_input (torch.FloatTensor): [bs, n_data, c] 78 | timestep (torch.LongTensor): [bs,] 79 | context (torch.FloatTensor): [bs, context_tokens, c] 80 | 81 | Returns: 82 | sample (torch.FloatTensor): [bs, n_data, c] 83 | 84 | """ 85 | 86 | _, n_data, _ = model_input.shape 87 | 88 | # 1. time 89 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1) 90 | 91 | # 2. conditions projector 92 | context = self.context_embed(context) 93 | 94 | # 3. denoiser 95 | x = self.input_proj(model_input) 96 | x = torch.cat([t_emb, context, x], dim=1) 97 | x = self.backbone(x) 98 | x = self.ln_post(x) 99 | x = x[:, -n_data:] 100 | sample = self.output_proj(x) 101 | 102 | return sample 103 | 104 | 105 | -------------------------------------------------------------------------------- /michelangelo/models/asl_diffusion/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BaseDenoiser(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, t, context): 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from omegaconf import DictConfig 4 | from typing import List, Tuple, Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import lr_scheduler 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | from diffusers.schedulers import ( 14 | DDPMScheduler, 15 | DDIMScheduler, 16 | KarrasVeScheduler, 17 | DPMSolverMultistepScheduler 18 | ) 19 | 20 | from michelangelo.utils import instantiate_from_config 21 | from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule 22 | from michelangelo.models.asl_diffusion.inference_utils import ddim_sample 23 | 24 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] 25 | 26 | 27 | def disabled_train(self, mode=True): 28 | """Overwrite model.train with this function to make sure train/eval mode 29 | does not change anymore.""" 30 | return self 31 | 32 | 33 | class ClipASLDiffuser(pl.LightningModule): 34 | first_stage_model: Optional[AlignedShapeAsLatentPLModule] 35 | cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] 36 | model: nn.Module 37 | 38 | def __init__(self, *, 39 | first_stage_config, 40 | cond_stage_config, 41 | denoiser_cfg, 42 | scheduler_cfg, 43 | optimizer_cfg, 44 | loss_cfg, 45 | first_stage_key: str = "surface", 46 | cond_stage_key: str = "image", 47 | scale_by_std: bool = False, 48 | z_scale_factor: float = 1.0, 49 | ckpt_path: Optional[str] = None, 50 | ignore_keys: 
Union[Tuple[str], List[str]] = ()): 51 | 52 | super().__init__() 53 | 54 | self.first_stage_key = first_stage_key 55 | self.cond_stage_key = cond_stage_key 56 | 57 | # 1. lazy initialize first stage 58 | self.instantiate_first_stage(first_stage_config) 59 | 60 | # 2. initialize conditional stage 61 | self.instantiate_cond_stage(cond_stage_config) 62 | 63 | # 3. diffusion model 64 | self.model = instantiate_from_config( 65 | denoiser_cfg, device=None, dtype=None 66 | ) 67 | 68 | self.optimizer_cfg = optimizer_cfg 69 | 70 | # 4. scheduling strategy 71 | self.scheduler_cfg = scheduler_cfg 72 | 73 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) 74 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) 75 | 76 | # 5. loss configures 77 | self.loss_cfg = loss_cfg 78 | 79 | self.scale_by_std = scale_by_std 80 | if scale_by_std: 81 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) 82 | else: 83 | self.z_scale_factor = z_scale_factor 84 | 85 | self.ckpt_path = ckpt_path 86 | if ckpt_path is not None: 87 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 88 | 89 | def instantiate_non_trainable_model(self, config): 90 | model = instantiate_from_config(config) 91 | model = model.eval() 92 | model.train = disabled_train 93 | for param in model.parameters(): 94 | param.requires_grad = False 95 | 96 | return model 97 | 98 | def instantiate_first_stage(self, first_stage_config): 99 | self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config) 100 | self.first_stage_model.set_shape_model_only() 101 | 102 | def instantiate_cond_stage(self, cond_stage_config): 103 | self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config) 104 | 105 | def init_from_ckpt(self, path, ignore_keys=()): 106 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 107 | 108 | keys = list(state_dict.keys()) 109 | for k in keys: 110 | for ik in ignore_keys: 111 | if k.startswith(ik): 112 | print("Deleting key {} from state_dict.".format(k)) 113 | del state_dict[k] 114 | 115 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 116 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 117 | if len(missing) > 0: 118 | print(f"Missing Keys: {missing}") 119 | print(f"Unexpected Keys: {unexpected}") 120 | 121 | @property 122 | def zero_rank(self): 123 | if self._trainer: 124 | zero_rank = self.trainer.local_rank == 0 125 | else: 126 | zero_rank = True 127 | 128 | return zero_rank 129 | 130 | def configure_optimizers(self) -> Tuple[List, List]: 131 | 132 | lr = self.learning_rate 133 | 134 | trainable_parameters = list(self.model.parameters()) 135 | if self.optimizer_cfg is None: 136 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 137 | schedulers = [] 138 | else: 139 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 140 | scheduler_func = instantiate_from_config( 141 | self.optimizer_cfg.scheduler, 142 | max_decay_steps=self.trainer.max_steps, 143 | lr_max=lr 144 | ) 145 | scheduler = { 146 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 147 | "interval": "step", 148 | "frequency": 1 149 | } 150 | optimizers = [optimizer] 151 | schedulers = [scheduler] 152 | 153 | return optimizers, schedulers 154 | 155 | @torch.no_grad() 156 | def encode_first_stage(self, surface: torch.FloatTensor, 
sample_posterior=True): 157 | 158 | z_q = self.first_stage_model.encode(surface, sample_posterior) 159 | z_q = self.z_scale_factor * z_q 160 | 161 | return z_q 162 | 163 | @torch.no_grad() 164 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): 165 | 166 | z_q = 1. / self.z_scale_factor * z_q 167 | latents = self.first_stage_model.decode(z_q, **kwargs) 168 | return latents 169 | 170 | @rank_zero_only 171 | @torch.no_grad() 172 | def on_train_batch_start(self, batch, batch_idx): 173 | # only for very first batch 174 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ 175 | and batch_idx == 0 and self.ckpt_path is None: 176 | # set rescale weight to 1./std of encodings 177 | print("### USING STD-RESCALING ###") 178 | 179 | z_q = self.encode_first_stage(batch[self.first_stage_key]) 180 | z = z_q.detach() 181 | 182 | del self.z_scale_factor 183 | self.register_buffer("z_scale_factor", 1. / z.flatten().std()) 184 | print(f"setting self.z_scale_factor to {self.z_scale_factor}") 185 | 186 | print("### USING STD-RESCALING ###") 187 | 188 | def compute_loss(self, model_outputs, split): 189 | """ 190 | 191 | Args: 192 | model_outputs (dict): 193 | - x_0: 194 | - noise: 195 | - noise_prior: 196 | - noise_pred: 197 | - noise_pred_prior: 198 | 199 | split (str): 200 | 201 | Returns: 202 | 203 | """ 204 | 205 | pred = model_outputs["pred"] 206 | 207 | if self.noise_scheduler.prediction_type == "epsilon": 208 | target = model_outputs["noise"] 209 | elif self.noise_scheduler.prediction_type == "sample": 210 | target = model_outputs["x_0"] 211 | else: 212 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") 213 | 214 | if self.loss_cfg.loss_type == "l1": 215 | simple = F.l1_loss(pred, target, reduction="mean") 216 | elif self.loss_cfg.loss_type in ["mse", "l2"]: 217 | simple = F.mse_loss(pred, target, reduction="mean") 218 | else: 219 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") 220 | 221 | total_loss = simple 222 | 223 | loss_dict = { 224 | f"{split}/total_loss": total_loss.clone().detach(), 225 | f"{split}/simple": simple.detach(), 226 | } 227 | 228 | return total_loss, loss_dict 229 | 230 | def forward(self, batch): 231 | """ 232 | 233 | Args: 234 | batch: 235 | 236 | Returns: 237 | 238 | """ 239 | 240 | latents = self.encode_first_stage(batch[self.first_stage_key]) 241 | conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) 242 | 243 | # Sample noise that we"ll add to the latents 244 | # [batch_size, n_token, latent_dim] 245 | noise = torch.randn_like(latents) 246 | bs = latents.shape[0] 247 | # Sample a random timestep for each motion 248 | timesteps = torch.randint( 249 | 0, 250 | self.noise_scheduler.config.num_train_timesteps, 251 | (bs,), 252 | device=latents.device, 253 | ) 254 | timesteps = timesteps.long() 255 | # Add noise to the latents according to the noise magnitude at each timestep 256 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) 257 | 258 | # diffusion model forward 259 | noise_pred = self.model(noisy_z, timesteps, conditions) 260 | 261 | diffusion_outputs = { 262 | "x_0": noisy_z, 263 | "noise": noise, 264 | "pred": noise_pred 265 | } 266 | 267 | return diffusion_outputs 268 | 269 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], 270 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 271 | """ 272 | 273 | Args: 274 | batch (dict): the batch sample, and it contains: 
275 | - surface (torch.FloatTensor): 276 | - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] 277 | - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] 278 | - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] 279 | - text (list of str): 280 | 281 | batch_idx (int): 282 | 283 | optimizer_idx (int): 284 | 285 | Returns: 286 | loss (torch.FloatTensor): 287 | 288 | """ 289 | 290 | diffusion_outputs = self(batch) 291 | 292 | loss, loss_dict = self.compute_loss(diffusion_outputs, "train") 293 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 294 | 295 | return loss 296 | 297 | def validation_step(self, batch: Dict[str, torch.FloatTensor], 298 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 299 | """ 300 | 301 | Args: 302 | batch (dict): the batch sample, and it contains: 303 | - surface_pc (torch.FloatTensor): [n_pts, 4] 304 | - surface_feats (torch.FloatTensor): [n_pts, c] 305 | - text (list of str): 306 | 307 | batch_idx (int): 308 | 309 | optimizer_idx (int): 310 | 311 | Returns: 312 | loss (torch.FloatTensor): 313 | 314 | """ 315 | 316 | diffusion_outputs = self(batch) 317 | 318 | loss, loss_dict = self.compute_loss(diffusion_outputs, "val") 319 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) 320 | 321 | return loss 322 | 323 | @torch.no_grad() 324 | def sample(self, 325 | batch: Dict[str, Union[torch.FloatTensor, List[str]]], 326 | sample_times: int = 1, 327 | steps: Optional[int] = None, 328 | guidance_scale: Optional[float] = None, 329 | eta: float = 0.0, 330 | return_intermediates: bool = False, **kwargs): 331 | 332 | if steps is None: 333 | steps = self.scheduler_cfg.num_inference_steps 334 | 335 | if guidance_scale is None: 336 | guidance_scale = self.scheduler_cfg.guidance_scale 337 | do_classifier_free_guidance = guidance_scale > 0 338 | 339 | # conditional encode 340 | xc = batch[self.cond_stage_key] 341 | 342 | # print(self.first_stage_model.device, self.cond_stage_model.device, self.device) 343 | 344 | cond = self.cond_stage_model(xc) 345 | 346 | if do_classifier_free_guidance: 347 | un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) 348 | cond = torch.cat([un_cond, cond], dim=0) 349 | 350 | outputs = [] 351 | latents = None 352 | 353 | if not return_intermediates: 354 | for _ in range(sample_times): 355 | sample_loop = ddim_sample( 356 | self.denoise_scheduler, 357 | self.model, 358 | shape=self.first_stage_model.latent_shape, 359 | cond=cond, 360 | steps=steps, 361 | guidance_scale=guidance_scale, 362 | do_classifier_free_guidance=do_classifier_free_guidance, 363 | device=self.device, 364 | eta=eta, 365 | disable_prog=not self.zero_rank 366 | ) 367 | for sample, t in sample_loop: 368 | latents = sample 369 | outputs.append(self.decode_first_stage(latents, **kwargs)) 370 | else: 371 | 372 | sample_loop = ddim_sample( 373 | self.denoise_scheduler, 374 | self.model, 375 | shape=self.first_stage_model.latent_shape, 376 | cond=cond, 377 | steps=steps, 378 | guidance_scale=guidance_scale, 379 | do_classifier_free_guidance=do_classifier_free_guidance, 380 | device=self.device, 381 | eta=eta, 382 | disable_prog=not self.zero_rank 383 | ) 384 | 385 | iter_size = steps // sample_times 386 | i = 0 387 | for sample, t in sample_loop: 388 | latents = sample 389 | if i % iter_size == 0 or i == steps - 1: 390 | outputs.append(self.decode_first_stage(latents, **kwargs)) 391 | i += 1 392 | 393 | 
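        # (Editor's note, illustrative arithmetic) In this return_intermediates branch, e.g.
        # steps=50 and sample_times=5 give iter_size=10, so the latents are decoded at
        # i = 0, 10, 20, 30, 40 and again at the final step i = 49, i.e. sample_times + 1
        # partially denoised reconstructions; the branch above instead returns one decoded
        # output per sampling run.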
return outputs 394 | -------------------------------------------------------------------------------- /michelangelo/models/asl_diffusion/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from typing import Tuple, List, Union, Optional 6 | from diffusers.schedulers import DDIMScheduler 7 | 8 | 9 | __all__ = ["ddim_sample"] 10 | 11 | 12 | def ddim_sample(ddim_scheduler: DDIMScheduler, 13 | diffusion_model: torch.nn.Module, 14 | shape: Union[List[int], Tuple[int]], 15 | cond: torch.FloatTensor, 16 | steps: int, 17 | eta: float = 0.0, 18 | guidance_scale: float = 3.0, 19 | do_classifier_free_guidance: bool = True, 20 | generator: Optional[torch.Generator] = None, 21 | device: torch.device = "cuda:0", 22 | disable_prog: bool = True): 23 | 24 | assert steps > 0, f"{steps} must > 0." 25 | 26 | # init latents 27 | bsz = cond.shape[0] 28 | if do_classifier_free_guidance: 29 | bsz = bsz // 2 30 | 31 | latents = torch.randn( 32 | (bsz, *shape), 33 | generator=generator, 34 | device=cond.device, 35 | dtype=cond.dtype, 36 | ) 37 | # scale the initial noise by the standard deviation required by the scheduler 38 | latents = latents * ddim_scheduler.init_noise_sigma 39 | # set timesteps 40 | ddim_scheduler.set_timesteps(steps) 41 | timesteps = ddim_scheduler.timesteps.to(device) 42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1] 44 | extra_step_kwargs = { 45 | "eta": eta, 46 | "generator": generator 47 | } 48 | 49 | # reverse 50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)): 51 | # expand the latents if we are doing classifier free guidance 52 | latent_model_input = ( 53 | torch.cat([latents] * 2) 54 | if do_classifier_free_guidance 55 | else latents 56 | ) 57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t) 58 | # predict the noise residual 59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) 60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) 61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond) 62 | 63 | # perform guidance 64 | if do_classifier_free_guidance: 65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 66 | noise_pred = noise_pred_uncond + guidance_scale * ( 67 | noise_pred_text - noise_pred_uncond 68 | ) 69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk( 70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states 71 | # compute the previous noisy sample x_t -> x_t-1 72 | latents = ddim_scheduler.step( 73 | noise_pred, t, latents, **extra_step_kwargs 74 | ).prev_sample 75 | 76 | yield latents, t 77 | 78 | 79 | def karra_sample(): 80 | pass 81 | -------------------------------------------------------------------------------- /michelangelo/models/conditional_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .clip import CLIPEncoder 4 | -------------------------------------------------------------------------------- /michelangelo/models/conditional_encoders/clip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from dataclasses import 
dataclass 7 | from torchvision.transforms import Normalize 8 | from transformers import CLIPModel, CLIPTokenizer 9 | from transformers.utils import ModelOutput 10 | from typing import Iterable, Optional, Union, List 11 | 12 | 13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 14 | 15 | 16 | @dataclass 17 | class CLIPEmbedOutput(ModelOutput): 18 | last_hidden_state: torch.FloatTensor = None 19 | pooler_output: torch.FloatTensor = None 20 | embeds: torch.FloatTensor = None 21 | 22 | 23 | class CLIPEncoder(torch.nn.Module): 24 | 25 | def __init__(self, model_path="openai/clip-vit-base-patch32"): 26 | 27 | super().__init__() 28 | 29 | # Load the CLIP model and processor 30 | self.model: CLIPModel = CLIPModel.from_pretrained(model_path) 31 | self.tokenizer = CLIPTokenizer.from_pretrained(model_path) 32 | self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 33 | 34 | self.model.training = False 35 | for p in self.model.parameters(): 36 | p.requires_grad = False 37 | 38 | @torch.no_grad() 39 | def encode_image(self, images: Iterable[Optional[ImageType]]): 40 | pixel_values = self.image_preprocess(images) 41 | 42 | vision_outputs = self.model.vision_model(pixel_values=pixel_values) 43 | 44 | pooler_output = vision_outputs[1] # pooled_output 45 | image_features = self.model.visual_projection(pooler_output) 46 | 47 | visual_embeds = CLIPEmbedOutput( 48 | last_hidden_state=vision_outputs.last_hidden_state, 49 | pooler_output=pooler_output, 50 | embeds=image_features 51 | ) 52 | 53 | return visual_embeds 54 | 55 | @torch.no_grad() 56 | def encode_text(self, texts: List[str]): 57 | text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt") 58 | 59 | text_outputs = self.model.text_model(input_ids=text_inputs) 60 | 61 | pooler_output = text_outputs[1] # pooled_output 62 | text_features = self.model.text_projection(pooler_output) 63 | 64 | text_embeds = CLIPEmbedOutput( 65 | last_hidden_state=text_outputs.last_hidden_state, 66 | pooler_output=pooler_output, 67 | embeds=text_features 68 | ) 69 | 70 | return text_embeds 71 | 72 | def forward(self, 73 | images: Iterable[Optional[ImageType]], 74 | texts: List[str]): 75 | 76 | visual_embeds = self.encode_image(images) 77 | text_embeds = self.encode_text(texts) 78 | 79 | return visual_embeds, text_embeds 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /michelangelo/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import checkpoint 4 | -------------------------------------------------------------------------------- /michelangelo/models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 4 | """ 5 | 6 | import torch 7 | from typing import Callable, Iterable, Sequence, Union 8 | 9 | 10 | def checkpoint( 11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], 12 | inputs: Sequence[torch.Tensor], 13 | params: Iterable[torch.Tensor], 14 | flag: bool, 15 | use_deepspeed: bool = False 16 | ): 17 | """ 18 | Evaluate a function without caching intermediate activations, allowing for 19 | reduced memory at the expense of extra compute in the backward pass. 
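    An illustrative call, mirroring how the residual attention blocks in this repository
    use it: ``checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)``.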
20 | :param func: the function to evaluate. 21 | :param inputs: the argument sequence to pass to `func`. 22 | :param params: a sequence of parameters `func` depends on but does not 23 | explicitly take as arguments. 24 | :param flag: if False, disable gradient checkpointing. 25 | :param use_deepspeed: if True, use deepspeed 26 | """ 27 | if flag: 28 | if use_deepspeed: 29 | import deepspeed 30 | return deepspeed.checkpointing.checkpoint(func, *inputs) 31 | 32 | args = tuple(inputs) + tuple(params) 33 | return CheckpointFunction.apply(func, len(inputs), *args) 34 | else: 35 | return func(*inputs) 36 | 37 | 38 | class CheckpointFunction(torch.autograd.Function): 39 | @staticmethod 40 | @torch.cuda.amp.custom_fwd 41 | def forward(ctx, run_function, length, *args): 42 | ctx.run_function = run_function 43 | ctx.input_tensors = list(args[:length]) 44 | ctx.input_params = list(args[length:]) 45 | 46 | with torch.no_grad(): 47 | output_tensors = ctx.run_function(*ctx.input_tensors) 48 | return output_tensors 49 | 50 | @staticmethod 51 | @torch.cuda.amp.custom_bwd 52 | def backward(ctx, *output_grads): 53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 54 | with torch.enable_grad(): 55 | # Fixes a bug where the first op in run_function modifies the 56 | # Tensor storage in place, which is not allowed for detach()'d 57 | # Tensors. 58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 59 | output_tensors = ctx.run_function(*shallow_copies) 60 | input_grads = torch.autograd.grad( 61 | output_tensors, 62 | ctx.input_tensors + ctx.input_params, 63 | output_grads, 64 | allow_unused=True, 65 | ) 66 | del ctx.input_tensors 67 | del ctx.input_params 68 | del output_tensors 69 | return (None, None) + input_grads 70 | -------------------------------------------------------------------------------- /michelangelo/models/modules/diffusion_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | 8 | from michelangelo.models.modules.checkpoint import checkpoint 9 | from michelangelo.models.modules.transformer_blocks import ( 10 | init_linear, 11 | MLP, 12 | MultiheadCrossAttention, 13 | MultiheadAttention, 14 | ResidualAttentionBlock 15 | ) 16 | 17 | 18 | class AdaLayerNorm(nn.Module): 19 | def __init__(self, 20 | device: torch.device, 21 | dtype: torch.dtype, 22 | width: int): 23 | 24 | super().__init__() 25 | 26 | self.silu = nn.SiLU(inplace=True) 27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype) 28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype) 29 | 30 | def forward(self, x, timestep): 31 | emb = self.linear(timestep) 32 | scale, shift = torch.chunk(emb, 2, dim=2) 33 | x = self.layernorm(x) * (1 + scale) + shift 34 | return x 35 | 36 | 37 | class DitBlock(nn.Module): 38 | def __init__( 39 | self, 40 | *, 41 | device: torch.device, 42 | dtype: torch.dtype, 43 | n_ctx: int, 44 | width: int, 45 | heads: int, 46 | context_dim: int, 47 | qkv_bias: bool = False, 48 | init_scale: float = 1.0, 49 | use_checkpoint: bool = False 50 | ): 51 | super().__init__() 52 | 53 | self.use_checkpoint = use_checkpoint 54 | 55 | self.attn = MultiheadAttention( 56 | device=device, 57 | dtype=dtype, 58 | n_ctx=n_ctx, 59 | width=width, 60 | heads=heads, 61 | init_scale=init_scale, 62 | qkv_bias=qkv_bias 63 | ) 64 | self.ln_1 = AdaLayerNorm(device, dtype, 
width) 65 | 66 | if context_dim is not None: 67 | self.ln_2 = AdaLayerNorm(device, dtype, width) 68 | self.cross_attn = MultiheadCrossAttention( 69 | device=device, 70 | dtype=dtype, 71 | width=width, 72 | heads=heads, 73 | data_width=context_dim, 74 | init_scale=init_scale, 75 | qkv_bias=qkv_bias 76 | ) 77 | 78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 79 | self.ln_3 = AdaLayerNorm(device, dtype, width) 80 | 81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint) 83 | 84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 85 | x = x + self.attn(self.ln_1(x, t)) 86 | if context is not None: 87 | x = x + self.cross_attn(self.ln_2(x, t), context) 88 | x = x + self.mlp(self.ln_3(x, t)) 89 | return x 90 | 91 | 92 | class DiT(nn.Module): 93 | def __init__( 94 | self, 95 | *, 96 | device: Optional[torch.device], 97 | dtype: Optional[torch.dtype], 98 | n_ctx: int, 99 | width: int, 100 | layers: int, 101 | heads: int, 102 | context_dim: int, 103 | init_scale: float = 0.25, 104 | qkv_bias: bool = False, 105 | use_checkpoint: bool = False 106 | ): 107 | super().__init__() 108 | self.n_ctx = n_ctx 109 | self.width = width 110 | self.layers = layers 111 | 112 | self.resblocks = nn.ModuleList( 113 | [ 114 | DitBlock( 115 | device=device, 116 | dtype=dtype, 117 | n_ctx=n_ctx, 118 | width=width, 119 | heads=heads, 120 | context_dim=context_dim, 121 | qkv_bias=qkv_bias, 122 | init_scale=init_scale, 123 | use_checkpoint=use_checkpoint 124 | ) 125 | for _ in range(layers) 126 | ] 127 | ) 128 | 129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): 130 | for block in self.resblocks: 131 | x = block(x, t, context) 132 | return x 133 | 134 | 135 | class UNetDiffusionTransformer(nn.Module): 136 | def __init__( 137 | self, 138 | *, 139 | device: Optional[torch.device], 140 | dtype: Optional[torch.dtype], 141 | n_ctx: int, 142 | width: int, 143 | layers: int, 144 | heads: int, 145 | init_scale: float = 0.25, 146 | qkv_bias: bool = False, 147 | skip_ln: bool = False, 148 | use_checkpoint: bool = False 149 | ): 150 | super().__init__() 151 | 152 | self.n_ctx = n_ctx 153 | self.width = width 154 | self.layers = layers 155 | 156 | self.encoder = nn.ModuleList() 157 | for _ in range(layers): 158 | resblock = ResidualAttentionBlock( 159 | device=device, 160 | dtype=dtype, 161 | n_ctx=n_ctx, 162 | width=width, 163 | heads=heads, 164 | init_scale=init_scale, 165 | qkv_bias=qkv_bias, 166 | use_checkpoint=use_checkpoint 167 | ) 168 | self.encoder.append(resblock) 169 | 170 | self.middle_block = ResidualAttentionBlock( 171 | device=device, 172 | dtype=dtype, 173 | n_ctx=n_ctx, 174 | width=width, 175 | heads=heads, 176 | init_scale=init_scale, 177 | qkv_bias=qkv_bias, 178 | use_checkpoint=use_checkpoint 179 | ) 180 | 181 | self.decoder = nn.ModuleList() 182 | for _ in range(layers): 183 | resblock = ResidualAttentionBlock( 184 | device=device, 185 | dtype=dtype, 186 | n_ctx=n_ctx, 187 | width=width, 188 | heads=heads, 189 | init_scale=init_scale, 190 | qkv_bias=qkv_bias, 191 | use_checkpoint=use_checkpoint 192 | ) 193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype) 194 | init_linear(linear, init_scale) 195 | 196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None 197 | 198 | 
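            # (Editor's note) Each decoder stage is stored as [resblock, linear, layer_norm].
            # In forward() below, the matching encoder activation is concatenated onto x as a
            # U-Net style skip connection, the Linear projects the doubled width back to `width`,
            # the optional LayerNorm (when skip_ln=True) normalizes the merged features, and the
            # ResidualAttentionBlock then refines them.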
self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) 199 | 200 | def forward(self, x: torch.Tensor): 201 | 202 | enc_outputs = [] 203 | for block in self.encoder: 204 | x = block(x) 205 | enc_outputs.append(x) 206 | 207 | x = self.middle_block(x) 208 | 209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder): 210 | x = torch.cat([enc_outputs.pop(), x], dim=-1) 211 | x = linear(x) 212 | 213 | if layer_norm is not None: 214 | x = layer_norm(x) 215 | 216 | x = resblock(x) 217 | 218 | return x 219 | -------------------------------------------------------------------------------- /michelangelo/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Union, List 4 | 5 | 6 | class AbstractDistribution(object): 7 | def sample(self): 8 | raise NotImplementedError() 9 | 10 | def mode(self): 11 | raise NotImplementedError() 12 | 13 | 14 | class DiracDistribution(AbstractDistribution): 15 | def __init__(self, value): 16 | self.value = value 17 | 18 | def sample(self): 19 | return self.value 20 | 21 | def mode(self): 22 | return self.value 23 | 24 | 25 | class DiagonalGaussianDistribution(object): 26 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): 27 | self.feat_dim = feat_dim 28 | self.parameters = parameters 29 | 30 | if isinstance(parameters, list): 31 | self.mean = parameters[0] 32 | self.logvar = parameters[1] 33 | else: 34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 35 | 36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 37 | self.deterministic = deterministic 38 | self.std = torch.exp(0.5 * self.logvar) 39 | self.var = torch.exp(self.logvar) 40 | if self.deterministic: 41 | self.var = self.std = torch.zeros_like(self.mean) 42 | 43 | def sample(self): 44 | x = self.mean + self.std * torch.randn_like(self.mean) 45 | return x 46 | 47 | def kl(self, other=None, dims=(1, 2, 3)): 48 | if self.deterministic: 49 | return torch.Tensor([0.]) 50 | else: 51 | if other is None: 52 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 53 | + self.var - 1.0 - self.logvar, 54 | dim=dims) 55 | else: 56 | return 0.5 * torch.mean( 57 | torch.pow(self.mean - other.mean, 2) / other.var 58 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 59 | dim=dims) 60 | 61 | def nll(self, sample, dims=(1, 2, 3)): 62 | if self.deterministic: 63 | return torch.Tensor([0.]) 64 | logtwopi = np.log(2.0 * np.pi) 65 | return 0.5 * torch.sum( 66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 67 | dim=dims) 68 | 69 | def mode(self): 70 | return self.mean 71 | 72 | 73 | def normal_kl(mean1, logvar1, mean2, logvar2): 74 | """ 75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 76 | Compute the KL divergence between two gaussians. 77 | Shapes are automatically broadcasted, so batches can be compared to 78 | scalars, among other use cases. 79 | """ 80 | tensor = None 81 | for obj in (mean1, logvar1, mean2, logvar2): 82 | if isinstance(obj, torch.Tensor): 83 | tensor = obj 84 | break 85 | assert tensor is not None, "at least one argument must be a Tensor" 86 | 87 | # Force variances to be Tensors. Broadcasting helps convert scalars to 88 | # Tensors, but it does not work for torch.exp(). 
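    # (Editor's note) The expression returned below is the closed-form KL divergence
    # KL(N(mean1, var1) || N(mean2, var2))
    #     = 0.5 * (logvar2 - logvar1) + 0.5 * (var1 + (mean1 - mean2)**2) / var2 - 0.5,
    # rewritten purely in terms of log-variances so it stays numerically stable and broadcasts.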
89 | logvar1, logvar2 = [ 90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 91 | for x in (logvar1, logvar2) 92 | ] 93 | 94 | return 0.5 * ( 95 | -1.0 96 | + logvar2 97 | - logvar1 98 | + torch.exp(logvar1 - logvar2) 99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 100 | ) 101 | -------------------------------------------------------------------------------- /michelangelo/models/modules/embedder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] 9 | 10 | 11 | class FourierEmbedder(nn.Module): 12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts 13 | each feature dimension of `x[..., i]` into: 14 | [ 15 | sin(x[..., i]), 16 | sin(f_1*x[..., i]), 17 | sin(f_2*x[..., i]), 18 | ... 19 | sin(f_N * x[..., i]), 20 | cos(x[..., i]), 21 | cos(f_1*x[..., i]), 22 | cos(f_2*x[..., i]), 23 | ... 24 | cos(f_N * x[..., i]), 25 | x[..., i] # only present if include_input is True. 26 | ], here f_i is the frequency. 27 | 28 | Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. 29 | If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; 30 | Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. 31 | 32 | Args: 33 | num_freqs (int): the number of frequencies, default is 6; 34 | logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 35 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; 36 | input_dim (int): the input dimension, default is 3; 37 | include_input (bool): include the input tensor or not, default is True. 38 | 39 | Attributes: 40 | frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], 41 | otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); 42 | 43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), 44 | otherwise, it is input_dim * num_freqs * 2. 45 | 46 | """ 47 | 48 | def __init__(self, 49 | num_freqs: int = 6, 50 | logspace: bool = True, 51 | input_dim: int = 3, 52 | include_input: bool = True, 53 | include_pi: bool = True) -> None: 54 | 55 | """The initialization""" 56 | 57 | super().__init__() 58 | 59 | if logspace: 60 | frequencies = 2.0 ** torch.arange( 61 | num_freqs, 62 | dtype=torch.float32 63 | ) 64 | else: 65 | frequencies = torch.linspace( 66 | 1.0, 67 | 2.0 ** (num_freqs - 1), 68 | num_freqs, 69 | dtype=torch.float32 70 | ) 71 | 72 | if include_pi: 73 | frequencies *= torch.pi 74 | 75 | self.register_buffer("frequencies", frequencies, persistent=False) 76 | self.include_input = include_input 77 | self.num_freqs = num_freqs 78 | 79 | self.out_dim = self.get_dims(input_dim) 80 | 81 | def get_dims(self, input_dim): 82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0 83 | out_dim = input_dim * (self.num_freqs * 2 + temp) 84 | 85 | return out_dim 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | """ Forward process. 
89 | 90 | Args: 91 | x: tensor of shape [..., dim] 92 | 93 | Returns: 94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] 95 | where temp is 1 if include_input is True and 0 otherwise. 96 | """ 97 | 98 | if self.num_freqs > 0: 99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) 100 | if self.include_input: 101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 102 | else: 103 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 104 | else: 105 | return x 106 | 107 | 108 | class LearnedFourierEmbedder(nn.Module): 109 | """ following @crowsonkb "s lead with learned sinusoidal pos emb """ 110 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 111 | 112 | def __init__(self, in_channels, dim): 113 | super().__init__() 114 | assert (dim % 2) == 0 115 | half_dim = dim // 2 116 | per_channel_dim = half_dim // in_channels 117 | self.weights = nn.Parameter(torch.randn(per_channel_dim)) 118 | 119 | def forward(self, x): 120 | """ 121 | 122 | Args: 123 | x (torch.FloatTensor): [..., c] 124 | 125 | Returns: 126 | x (torch.FloatTensor): [..., d] 127 | """ 128 | 129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] 130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) 131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) 132 | return fouriered 133 | 134 | 135 | class TriplaneLearnedFourierEmbedder(nn.Module): 136 | def __init__(self, in_channels, dim): 137 | super().__init__() 138 | 139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 142 | 143 | self.out_dim = in_channels + dim 144 | 145 | def forward(self, x): 146 | 147 | yz_embed = self.yz_plane_embedder(x) 148 | xz_embed = self.xz_plane_embedder(x) 149 | xy_embed = self.xy_plane_embedder(x) 150 | 151 | embed = yz_embed + xz_embed + xy_embed 152 | 153 | return embed 154 | 155 | 156 | def sequential_pos_embed(num_len, embed_dim): 157 | assert embed_dim % 2 == 0 158 | 159 | pos = torch.arange(num_len, dtype=torch.float32) 160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32) 161 | omega /= embed_dim / 2. 162 | omega = 1. / 10000 ** omega # (D/2,) 163 | 164 | pos = pos.reshape(-1) # (M,) 165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 166 | 167 | emb_sin = torch.sin(out) # (M, D/2) 168 | emb_cos = torch.cos(out) # (M, D/2) 169 | 170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 171 | 172 | return embeddings 173 | 174 | 175 | def timestep_embedding(timesteps, dim, max_period=10000): 176 | """ 177 | Create sinusoidal timestep embeddings. 178 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 179 | These may be fractional. 180 | :param dim: the dimension of the output. 181 | :param max_period: controls the minimum frequency of the embeddings. 182 | :return: an [N x dim] Tensor of positional embeddings. 
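    Example (illustrative): with dim=4 and max_period=10000, freqs = [1.0, 0.01], so timestep
    t=5 is embedded as [cos(5.0), cos(0.05), sin(5.0), sin(0.05)].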
183 | """ 184 | half = dim // 2 185 | freqs = torch.exp( 186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 187 | ).to(device=timesteps.device) 188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None] 189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 190 | if dim % 2: 191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 192 | return embedding 193 | 194 | 195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, 196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, 197 | log2_hashmap_size=19, desired_resolution=None): 198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): 199 | return nn.Identity(), input_dim 200 | 201 | elif embed_type == "fourier": 202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, 203 | logspace=True, include_input=True) 204 | return embedder_obj, embedder_obj.out_dim 205 | 206 | elif embed_type == "hashgrid": 207 | raise NotImplementedError 208 | 209 | elif embed_type == "sphere_harmonic": 210 | raise NotImplementedError 211 | 212 | else: 213 | raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") 214 | -------------------------------------------------------------------------------- /michelangelo/models/modules/transformer_blocks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Optional 8 | 9 | from michelangelo.models.modules.checkpoint import checkpoint 10 | 11 | 12 | def init_linear(l, stddev): 13 | nn.init.normal_(l.weight, std=stddev) 14 | if l.bias is not None: 15 | nn.init.constant_(l.bias, 0.0) 16 | 17 | 18 | class MultiheadAttention(nn.Module): 19 | def __init__( 20 | self, 21 | *, 22 | device: torch.device, 23 | dtype: torch.dtype, 24 | n_ctx: int, 25 | width: int, 26 | heads: int, 27 | init_scale: float, 28 | qkv_bias: bool, 29 | flash: bool = False 30 | ): 31 | super().__init__() 32 | self.n_ctx = n_ctx 33 | self.width = width 34 | self.heads = heads 35 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) 36 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 37 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) 38 | init_linear(self.c_qkv, init_scale) 39 | init_linear(self.c_proj, init_scale) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = checkpoint(self.attention, (x,), (), True) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | 48 | class QKVMultiheadAttention(nn.Module): 49 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): 50 | super().__init__() 51 | self.device = device 52 | self.dtype = dtype 53 | self.heads = heads 54 | self.n_ctx = n_ctx 55 | self.flash = flash 56 | 57 | def forward(self, qkv): 58 | bs, n_ctx, width = qkv.shape 59 | attn_ch = width // self.heads // 3 60 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 61 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 62 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 63 | 64 | if self.flash: 65 | out = F.scaled_dot_product_attention(q, k, v) 66 | else: 67 | weight = torch.einsum( 68 | "bthc,bshc->bhts", q * scale, k * scale 69 | ) # More stable with f16 than dividing afterwards 70 | wdtype = weight.dtype 71 | 
weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 72 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 73 | 74 | return out 75 | 76 | 77 | class ResidualAttentionBlock(nn.Module): 78 | def __init__( 79 | self, 80 | *, 81 | device: torch.device, 82 | dtype: torch.dtype, 83 | n_ctx: int, 84 | width: int, 85 | heads: int, 86 | init_scale: float = 1.0, 87 | qkv_bias: bool = True, 88 | flash: bool = False, 89 | use_checkpoint: bool = False 90 | ): 91 | super().__init__() 92 | 93 | self.use_checkpoint = use_checkpoint 94 | 95 | self.attn = MultiheadAttention( 96 | device=device, 97 | dtype=dtype, 98 | n_ctx=n_ctx, 99 | width=width, 100 | heads=heads, 101 | init_scale=init_scale, 102 | qkv_bias=qkv_bias, 103 | flash=flash 104 | ) 105 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 106 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 107 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 108 | 109 | def _forward(self, x: torch.Tensor): 110 | x = x + self.attn(self.ln_1(x)) 111 | x = x + self.mlp(self.ln_2(x)) 112 | return x 113 | 114 | def forward(self, x: torch.Tensor): 115 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 116 | 117 | 118 | class MultiheadCrossAttention(nn.Module): 119 | def __init__( 120 | self, 121 | *, 122 | device: torch.device, 123 | dtype: torch.dtype, 124 | width: int, 125 | heads: int, 126 | init_scale: float, 127 | qkv_bias: bool = True, 128 | flash: bool = False, 129 | n_data: Optional[int] = None, 130 | data_width: Optional[int] = None, 131 | ): 132 | super().__init__() 133 | self.n_data = n_data 134 | self.width = width 135 | self.heads = heads 136 | self.data_width = width if data_width is None else data_width 137 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) 138 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) 139 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 140 | self.attention = QKVMultiheadCrossAttention( 141 | device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash 142 | ) 143 | init_linear(self.c_q, init_scale) 144 | init_linear(self.c_kv, init_scale) 145 | init_linear(self.c_proj, init_scale) 146 | 147 | def forward(self, x, data): 148 | x = self.c_q(x) 149 | data = self.c_kv(data) 150 | x = checkpoint(self.attention, (x, data), (), True) 151 | x = self.c_proj(x) 152 | return x 153 | 154 | 155 | class QKVMultiheadCrossAttention(nn.Module): 156 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, 157 | flash: bool = False, n_data: Optional[int] = None): 158 | 159 | super().__init__() 160 | self.device = device 161 | self.dtype = dtype 162 | self.heads = heads 163 | self.n_data = n_data 164 | self.flash = flash 165 | 166 | def forward(self, q, kv): 167 | _, n_ctx, _ = q.shape 168 | bs, n_data, width = kv.shape 169 | attn_ch = width // self.heads // 2 170 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 171 | q = q.view(bs, n_ctx, self.heads, -1) 172 | kv = kv.view(bs, n_data, self.heads, -1) 173 | k, v = torch.split(kv, attn_ch, dim=-1) 174 | 175 | if self.flash: 176 | out = F.scaled_dot_product_attention(q, k, v) 177 | else: 178 | weight = torch.einsum( 179 | "bthc,bshc->bhts", q * scale, k * scale 180 | ) # More stable with f16 than dividing afterwards 181 | wdtype = weight.dtype 182 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 183 | out = torch.einsum("bhts,bshc->bthc", weight, 
v).reshape(bs, n_ctx, -1) 184 | 185 | return out 186 | 187 | 188 | class ResidualCrossAttentionBlock(nn.Module): 189 | def __init__( 190 | self, 191 | *, 192 | device: Optional[torch.device], 193 | dtype: Optional[torch.dtype], 194 | n_data: Optional[int] = None, 195 | width: int, 196 | heads: int, 197 | data_width: Optional[int] = None, 198 | init_scale: float = 0.25, 199 | qkv_bias: bool = True, 200 | flash: bool = False 201 | ): 202 | super().__init__() 203 | 204 | if data_width is None: 205 | data_width = width 206 | 207 | self.attn = MultiheadCrossAttention( 208 | device=device, 209 | dtype=dtype, 210 | n_data=n_data, 211 | width=width, 212 | heads=heads, 213 | data_width=data_width, 214 | init_scale=init_scale, 215 | qkv_bias=qkv_bias, 216 | flash=flash, 217 | ) 218 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 219 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 220 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 221 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) 222 | 223 | def forward(self, x: torch.Tensor, data: torch.Tensor): 224 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 225 | x = x + self.mlp(self.ln_3(x)) 226 | return x 227 | 228 | 229 | class MLP(nn.Module): 230 | def __init__(self, *, 231 | device: Optional[torch.device], 232 | dtype: Optional[torch.dtype], 233 | width: int, 234 | init_scale: float): 235 | super().__init__() 236 | self.width = width 237 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 238 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 239 | self.gelu = nn.GELU() 240 | init_linear(self.c_fc, init_scale) 241 | init_linear(self.c_proj, init_scale) 242 | 243 | def forward(self, x): 244 | return self.c_proj(self.gelu(self.c_fc(x))) 245 | 246 | 247 | class Transformer(nn.Module): 248 | def __init__( 249 | self, 250 | *, 251 | device: Optional[torch.device], 252 | dtype: Optional[torch.dtype], 253 | n_ctx: int, 254 | width: int, 255 | layers: int, 256 | heads: int, 257 | init_scale: float = 0.25, 258 | qkv_bias: bool = True, 259 | flash: bool = False, 260 | use_checkpoint: bool = False 261 | ): 262 | super().__init__() 263 | self.n_ctx = n_ctx 264 | self.width = width 265 | self.layers = layers 266 | self.resblocks = nn.ModuleList( 267 | [ 268 | ResidualAttentionBlock( 269 | device=device, 270 | dtype=dtype, 271 | n_ctx=n_ctx, 272 | width=width, 273 | heads=heads, 274 | init_scale=init_scale, 275 | qkv_bias=qkv_bias, 276 | flash=flash, 277 | use_checkpoint=use_checkpoint 278 | ) 279 | for _ in range(layers) 280 | ] 281 | ) 282 | 283 | def forward(self, x: torch.Tensor): 284 | for block in self.resblocks: 285 | x = block(x) 286 | return x 287 | -------------------------------------------------------------------------------- /michelangelo/models/modules/transformer_vit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from typing import Optional 7 | import warnings 8 | 9 | from michelangelo.models.modules.checkpoint import checkpoint 10 | 11 | 12 | def _trunc_normal_(tensor, mean, std, a, b): 13 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 14 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 15 | def norm_cdf(x): 16 | # Computes standard normal cumulative distribution function 17 | return (1. 
+ math.erf(x / math.sqrt(2.))) / 2. 18 | 19 | if (mean < a - 2 * std) or (mean > b + 2 * std): 20 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 21 | "The distribution of values may be incorrect.", 22 | stacklevel=2) 23 | 24 | # Values are generated by using a truncated uniform distribution and 25 | # then using the inverse CDF for the normal distribution. 26 | # Get upper and lower cdf values 27 | l = norm_cdf((a - mean) / std) 28 | u = norm_cdf((b - mean) / std) 29 | 30 | # Uniformly fill tensor with values from [l, u], then translate to 31 | # [2l-1, 2u-1]. 32 | tensor.uniform_(2 * l - 1, 2 * u - 1) 33 | 34 | # Use inverse cdf transform for normal distribution to get truncated 35 | # standard normal 36 | tensor.erfinv_() 37 | 38 | # Transform to proper mean, std 39 | tensor.mul_(std * math.sqrt(2.)) 40 | tensor.add_(mean) 41 | 42 | # Clamp to ensure it's in the proper range 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | 47 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 48 | # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor 49 | r"""Fills the input Tensor with values drawn from a truncated 50 | normal distribution. The values are effectively drawn from the 51 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 52 | with values outside :math:`[a, b]` redrawn until they are within 53 | the bounds. The method used for generating the random values works 54 | best when :math:`a \leq \text{mean} \leq b`. 55 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 56 | applied while sampling the normal with mean/std applied, therefore a, b args 57 | should be adjusted to match the range of mean, std args. 58 | Args: 59 | tensor: an n-dimensional `torch.Tensor` 60 | mean: the mean of the normal distribution 61 | std: the standard deviation of the normal distribution 62 | a: the minimum cutoff value 63 | b: the maximum cutoff value 64 | Examples: 65 | >>> w = torch.empty(3, 5) 66 | >>> nn.init.trunc_normal_(w) 67 | """ 68 | with torch.no_grad(): 69 | return _trunc_normal_(tensor, mean, std, a, b) 70 | 71 | 72 | def init_weights(m): 73 | if isinstance(m, nn.Linear): 74 | trunc_normal_(m.weight, std=.02) 75 | if isinstance(m, nn.Linear) and m.bias is not None: 76 | nn.init.constant_(m.bias, 0) 77 | elif isinstance(m, nn.LayerNorm): 78 | nn.init.constant_(m.bias, 0) 79 | nn.init.constant_(m.weight, 1.0) 80 | 81 | 82 | class MultiheadAttention(nn.Module): 83 | def __init__( 84 | self, 85 | *, 86 | device: torch.device, 87 | dtype: torch.dtype, 88 | n_ctx: int, 89 | width: int, 90 | heads: int, 91 | qkv_bias: bool 92 | ): 93 | super().__init__() 94 | self.n_ctx = n_ctx 95 | self.width = width 96 | self.heads = heads 97 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) 98 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 99 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) 100 | 101 | def forward(self, x): 102 | x = self.c_qkv(x) 103 | x = checkpoint(self.attention, (x,), (), True) 104 | x = self.c_proj(x) 105 | return x 106 | 107 | 108 | class QKVMultiheadAttention(nn.Module): 109 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): 110 | super().__init__() 111 | self.device = device 112 | self.dtype = dtype 113 | self.heads = heads 114 | self.n_ctx = n_ctx 115 | 116 | def forward(self, qkv): 117 | bs, n_ctx, width = qkv.shape 118 | attn_ch = 
width // self.heads // 3 119 | scale = 1 / math.sqrt(attn_ch) 120 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 121 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 122 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale 123 | wdtype = weight.dtype 124 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 125 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 126 | 127 | 128 | class ResidualAttentionBlock(nn.Module): 129 | def __init__( 130 | self, 131 | *, 132 | device: torch.device, 133 | dtype: torch.dtype, 134 | n_ctx: int, 135 | width: int, 136 | heads: int, 137 | qkv_bias: bool = True, 138 | use_checkpoint: bool = False 139 | ): 140 | super().__init__() 141 | 142 | self.use_checkpoint = use_checkpoint 143 | 144 | self.attn = MultiheadAttention( 145 | device=device, 146 | dtype=dtype, 147 | n_ctx=n_ctx, 148 | width=width, 149 | heads=heads, 150 | qkv_bias=qkv_bias 151 | ) 152 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 153 | self.mlp = MLP(device=device, dtype=dtype, width=width) 154 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 155 | 156 | def _forward(self, x: torch.Tensor): 157 | x = x + self.attn(self.ln_1(x)) 158 | x = x + self.mlp(self.ln_2(x)) 159 | return x 160 | 161 | def forward(self, x: torch.Tensor): 162 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 163 | 164 | 165 | class MultiheadCrossAttention(nn.Module): 166 | def __init__( 167 | self, 168 | *, 169 | device: torch.device, 170 | dtype: torch.dtype, 171 | width: int, 172 | heads: int, 173 | qkv_bias: bool = True, 174 | n_data: Optional[int] = None, 175 | data_width: Optional[int] = None, 176 | ): 177 | super().__init__() 178 | self.n_data = n_data 179 | self.width = width 180 | self.heads = heads 181 | self.data_width = width if data_width is None else data_width 182 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) 183 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) 184 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 185 | self.attention = QKVMultiheadCrossAttention( 186 | device=device, dtype=dtype, heads=heads, n_data=n_data 187 | ) 188 | 189 | def forward(self, x, data): 190 | x = self.c_q(x) 191 | data = self.c_kv(data) 192 | x = checkpoint(self.attention, (x, data), (), True) 193 | x = self.c_proj(x) 194 | return x 195 | 196 | 197 | class QKVMultiheadCrossAttention(nn.Module): 198 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None): 199 | super().__init__() 200 | self.device = device 201 | self.dtype = dtype 202 | self.heads = heads 203 | self.n_data = n_data 204 | 205 | def forward(self, q, kv): 206 | _, n_ctx, _ = q.shape 207 | bs, n_data, width = kv.shape 208 | attn_ch = width // self.heads // 2 209 | scale = 1 / math.sqrt(attn_ch) 210 | q = q.view(bs, n_ctx, self.heads, -1) 211 | kv = kv.view(bs, n_data, self.heads, -1) 212 | k, v = torch.split(kv, attn_ch, dim=-1) 213 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale 214 | wdtype = weight.dtype 215 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 216 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 217 | 218 | 219 | class ResidualCrossAttentionBlock(nn.Module): 220 | def __init__( 221 | self, 222 | *, 223 | device: Optional[torch.device], 224 | dtype: Optional[torch.dtype], 225 | n_data: Optional[int] = None, 226 | width: int, 227 | heads: int, 228 | data_width: 
Optional[int] = None, 229 | qkv_bias: bool = True 230 | ): 231 | super().__init__() 232 | 233 | if data_width is None: 234 | data_width = width 235 | 236 | self.attn = MultiheadCrossAttention( 237 | device=device, 238 | dtype=dtype, 239 | n_data=n_data, 240 | width=width, 241 | heads=heads, 242 | data_width=data_width, 243 | qkv_bias=qkv_bias 244 | ) 245 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 246 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 247 | self.mlp = MLP(device=device, dtype=dtype, width=width) 248 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) 249 | 250 | def forward(self, x: torch.Tensor, data: torch.Tensor): 251 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 252 | x = x + self.mlp(self.ln_3(x)) 253 | return x 254 | 255 | 256 | class MLP(nn.Module): 257 | def __init__(self, *, 258 | device: Optional[torch.device], 259 | dtype: Optional[torch.dtype], 260 | width: int): 261 | super().__init__() 262 | self.width = width 263 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 264 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 265 | self.gelu = nn.GELU() 266 | 267 | def forward(self, x): 268 | return self.c_proj(self.gelu(self.c_fc(x))) 269 | 270 | 271 | class Transformer(nn.Module): 272 | def __init__( 273 | self, 274 | *, 275 | device: Optional[torch.device], 276 | dtype: Optional[torch.dtype], 277 | n_ctx: int, 278 | width: int, 279 | layers: int, 280 | heads: int, 281 | qkv_bias: bool = True, 282 | use_checkpoint: bool = False 283 | ): 284 | super().__init__() 285 | self.n_ctx = n_ctx 286 | self.width = width 287 | self.layers = layers 288 | self.resblocks = nn.ModuleList( 289 | [ 290 | ResidualAttentionBlock( 291 | device=device, 292 | dtype=dtype, 293 | n_ctx=n_ctx, 294 | width=width, 295 | heads=heads, 296 | qkv_bias=qkv_bias, 297 | use_checkpoint=use_checkpoint 298 | ) 299 | for _ in range(layers) 300 | ] 301 | ) 302 | 303 | self.apply(init_weights) 304 | 305 | def forward(self, x: torch.Tensor): 306 | for block in self.resblocks: 307 | x = block(x) 308 | return x 309 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/asl_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import List, Tuple, Dict, Optional 4 | from omegaconf import DictConfig 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.optim import lr_scheduler 9 | import pytorch_lightning as pl 10 | from typing import Union 11 | from functools import partial 12 | 13 | from michelangelo.utils import instantiate_from_config 14 | 15 | from .inference_utils import extract_geometry 16 | from .tsal_base import ( 17 | AlignedShapeAsLatentModule, 18 | ShapeAsLatentModule, 19 | Latent2MeshOutput, 20 | AlignedMeshOutput 21 | ) 22 | 23 | 24 | class AlignedShapeAsLatentPLModule(pl.LightningModule): 25 | 26 | def __init__(self, *, 27 | shape_module_cfg, 28 | aligned_module_cfg, 29 | loss_cfg, 30 | optimizer_cfg: Optional[DictConfig] = None, 31 | ckpt_path: Optional[str] = None, 32 | ignore_keys: Union[Tuple[str], List[str]] = ()): 33 | 34 | super().__init__() 35 | 36 | shape_model: ShapeAsLatentModule = instantiate_from_config( 37 
| shape_module_cfg, device=None, dtype=None 38 | ) 39 | self.model: AlignedShapeAsLatentModule = instantiate_from_config( 40 | aligned_module_cfg, shape_model=shape_model 41 | ) 42 | 43 | self.loss = instantiate_from_config(loss_cfg) 44 | 45 | self.optimizer_cfg = optimizer_cfg 46 | 47 | if ckpt_path is not None: 48 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 49 | 50 | self.save_hyperparameters() 51 | 52 | def set_shape_model_only(self): 53 | self.model.set_shape_model_only() 54 | 55 | @property 56 | def latent_shape(self): 57 | return self.model.shape_model.latent_shape 58 | 59 | @property 60 | def zero_rank(self): 61 | if self._trainer: 62 | zero_rank = self.trainer.local_rank == 0 63 | else: 64 | zero_rank = True 65 | 66 | return zero_rank 67 | 68 | def init_from_ckpt(self, path, ignore_keys=()): 69 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 70 | 71 | keys = list(state_dict.keys()) 72 | for k in keys: 73 | for ik in ignore_keys: 74 | if k.startswith(ik): 75 | print("Deleting key {} from state_dict.".format(k)) 76 | del state_dict[k] 77 | 78 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 79 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 80 | if len(missing) > 0: 81 | print(f"Missing Keys: {missing}") 82 | print(f"Unexpected Keys: {unexpected}") 83 | 84 | def configure_optimizers(self) -> Tuple[List, List]: 85 | lr = self.learning_rate 86 | 87 | trainable_parameters = list(self.model.parameters()) 88 | 89 | if self.optimizer_cfg is None: 90 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 91 | schedulers = [] 92 | else: 93 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) 94 | scheduler_func = instantiate_from_config( 95 | self.optimizer_cfg.scheduler, 96 | max_decay_steps=self.trainer.max_steps, 97 | lr_max=lr 98 | ) 99 | scheduler = { 100 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 101 | "interval": "step", 102 | "frequency": 1 103 | } 104 | optimizers = [optimizer] 105 | schedulers = [scheduler] 106 | 107 | return optimizers, schedulers 108 | 109 | def forward(self, 110 | surface: torch.FloatTensor, 111 | image: torch.FloatTensor, 112 | text: torch.FloatTensor, 113 | volume_queries: torch.FloatTensor): 114 | 115 | """ 116 | 117 | Args: 118 | surface (torch.FloatTensor): 119 | image (torch.FloatTensor): 120 | text (torch.FloatTensor): 121 | volume_queries (torch.FloatTensor): 122 | 123 | Returns: 124 | 125 | """ 126 | 127 | embed_outputs, shape_z = self.model(surface, image, text) 128 | 129 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) 130 | latents = self.model.shape_model.decode(shape_zq) 131 | logits = self.model.shape_model.query_geometry(volume_queries, latents) 132 | 133 | return embed_outputs, logits, posterior 134 | 135 | def encode(self, surface: torch.FloatTensor, sample_posterior=True): 136 | 137 | pc = surface[..., 0:3] 138 | feats = surface[..., 3:6] 139 | 140 | shape_embed, shape_zq, posterior = self.model.shape_model.encode( 141 | pc=pc, feats=feats, sample_posterior=sample_posterior 142 | ) 143 | 144 | return shape_zq 145 | 146 | def decode(self, 147 | z_q, 148 | bounds: Union[Tuple[float], List[float], float] = 1.1, 149 | octree_depth: int = 7, 150 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 151 | 152 | latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim] 153 
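# latent2mesh (below) runs marching cubes over a dense query grid via
# extract_geometry in inference_utils.py; entries are returned as None for
# samples in the batch where no iso-surface is found.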
| outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) 154 | 155 | return outputs 156 | 157 | def training_step(self, batch: Dict[str, torch.FloatTensor], 158 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 159 | """ 160 | 161 | Args: 162 | batch (dict): the batch sample, and it contains: 163 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] 164 | - image (torch.FloatTensor): [bs, 3, 224, 224] 165 | - text (torch.FloatTensor): [bs, num_templates, 77] 166 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] 167 | 168 | batch_idx (int): 169 | 170 | optimizer_idx (int): 171 | 172 | Returns: 173 | loss (torch.FloatTensor): 174 | 175 | """ 176 | 177 | surface = batch["surface"] 178 | image = batch["image"] 179 | text = batch["text"] 180 | 181 | volume_queries = batch["geo_points"][..., 0:3] 182 | shape_labels = batch["geo_points"][..., -1] 183 | 184 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) 185 | 186 | aeloss, log_dict_ae = self.loss( 187 | **embed_outputs, 188 | posteriors=posteriors, 189 | shape_logits=shape_logits, 190 | shape_labels=shape_labels, 191 | split="train" 192 | ) 193 | 194 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], 195 | sync_dist=False, rank_zero_only=True) 196 | 197 | return aeloss 198 | 199 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: 200 | 201 | surface = batch["surface"] 202 | image = batch["image"] 203 | text = batch["text"] 204 | 205 | volume_queries = batch["geo_points"][..., 0:3] 206 | shape_labels = batch["geo_points"][..., -1] 207 | 208 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) 209 | 210 | aeloss, log_dict_ae = self.loss( 211 | **embed_outputs, 212 | posteriors=posteriors, 213 | shape_logits=shape_logits, 214 | shape_labels=shape_labels, 215 | split="val" 216 | ) 217 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], 218 | sync_dist=False, rank_zero_only=True) 219 | 220 | return aeloss 221 | 222 | def visual_alignment(self, 223 | surface: torch.FloatTensor, 224 | image: torch.FloatTensor, 225 | text: torch.FloatTensor, 226 | description: Optional[List[str]] = None, 227 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 228 | octree_depth: int = 7, 229 | num_chunks: int = 10000) -> List[AlignedMeshOutput]: 230 | 231 | """ 232 | 233 | Args: 234 | surface: 235 | image: 236 | text: 237 | description: 238 | bounds: 239 | octree_depth: 240 | num_chunks: 241 | 242 | Returns: 243 | mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list. 
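Example (illustrative sketch; batch size, point count and template count are
            assumed values, and `pl_module` stands for a constructed/loaded
            AlignedShapeAsLatentPLModule):
                >>> surface = torch.randn(2, 4096, 6)            # xyz + normals, [bs, n, 3 + c]
                >>> image = torch.randn(2, 3, 224, 224)          # CLIP-preprocessed images
                >>> text = torch.randint(0, 49408, (2, 4, 77))   # tokenized prompt templates
                >>> outs = pl_module.visual_alignment(surface, image, text,
                ...                                   description=["a car", "a chair"])
                >>> outs[0].mesh_v.shape, float(outs[0].shape_image_similarity)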
244 | 245 | """ 246 | 247 | outputs = [] 248 | 249 | device = surface.device 250 | bs = surface.shape[0] 251 | 252 | embed_outputs, shape_z = self.model(surface, image, text) 253 | 254 | # calculate the similarity 255 | image_embed = embed_outputs["image_embed"] 256 | text_embed = embed_outputs["text_embed"] 257 | shape_embed = embed_outputs["shape_embed"] 258 | 259 | # normalized features 260 | shape_embed = F.normalize(shape_embed, dim=-1, p=2) 261 | text_embed = F.normalize(text_embed, dim=-1, p=2) 262 | image_embed = F.normalize(image_embed, dim=-1, p=2) 263 | 264 | # B x B 265 | shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1) 266 | 267 | # B x B 268 | shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1) 269 | 270 | # shape reconstruction 271 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) 272 | latents = self.model.shape_model.decode(shape_zq) 273 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) 274 | 275 | # 2. decode geometry 276 | mesh_v_f, has_surface = extract_geometry( 277 | geometric_func=geometric_func, 278 | device=device, 279 | batch_size=bs, 280 | bounds=bounds, 281 | octree_depth=octree_depth, 282 | num_chunks=num_chunks, 283 | disable=not self.zero_rank 284 | ) 285 | 286 | # 3. decode texture 287 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 288 | if not is_surface: 289 | outputs.append(None) 290 | continue 291 | 292 | out = AlignedMeshOutput() 293 | out.mesh_v = mesh_v 294 | out.mesh_f = mesh_f 295 | out.surface = surface[i].cpu().numpy() 296 | out.image = image[i].cpu().numpy() 297 | if description is not None: 298 | out.text = description[i] 299 | out.shape_text_similarity = shape_text_similarity[i, i] 300 | out.shape_image_similarity = shape_image_similarity[i, i] 301 | 302 | outputs.append(out) 303 | 304 | return outputs 305 | 306 | def latent2mesh(self, 307 | latents: torch.FloatTensor, 308 | bounds: Union[Tuple[float], List[float], float] = 1.1, 309 | octree_depth: int = 7, 310 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 311 | 312 | """ 313 | 314 | Args: 315 | latents: [bs, num_latents, dim] 316 | bounds: 317 | octree_depth: 318 | num_chunks: 319 | 320 | Returns: 321 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 322 | 323 | """ 324 | 325 | outputs = [] 326 | 327 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) 328 | 329 | # 2. decode geometry 330 | device = latents.device 331 | mesh_v_f, has_surface = extract_geometry( 332 | geometric_func=geometric_func, 333 | device=device, 334 | batch_size=len(latents), 335 | bounds=bounds, 336 | octree_depth=octree_depth, 337 | num_chunks=num_chunks, 338 | disable=not self.zero_rank 339 | ) 340 | 341 | # 3. 
decode texture 342 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 343 | if not is_surface: 344 | outputs.append(None) 345 | continue 346 | 347 | out = Latent2MeshOutput() 348 | out.mesh_v = mesh_v 349 | out.mesh_f = mesh_f 350 | 351 | outputs.append(out) 352 | 353 | return outputs 354 | 355 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/clip_asl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | from transformers import CLIPModel 7 | 8 | from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule 9 | 10 | 11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): 12 | 13 | def __init__(self, *, 14 | shape_model, 15 | clip_model_version: str = "openai/clip-vit-large-patch14"): 16 | 17 | super().__init__() 18 | 19 | self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) 20 | for params in self.clip_model.parameters(): 21 | params.requires_grad = False 22 | 23 | self.shape_model = shape_model 24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.clip_model.projection_dim)) 25 | nn.init.normal_(self.shape_projection, std=self.clip_model.projection_dim ** -0.5) 26 | 27 | def set_shape_model_only(self): 28 | self.clip_model = None 29 | 30 | def encode_shape_embed(self, surface, return_latents: bool = False): 31 | """ 32 | 33 | Args: 34 | surface (torch.FloatTensor): [bs, n, 3 + c] 35 | return_latents (bool): 36 | 37 | Returns: 38 | x (torch.FloatTensor): [bs, projection_dim] 39 | shape_latents (torch.FloatTensor): [bs, m, d] 40 | """ 41 | 42 | pc = surface[..., 0:3] 43 | feats = surface[..., 3:] 44 | 45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) 46 | x = shape_embed @ self.shape_projection 47 | 48 | if return_latents: 49 | return x, shape_latents 50 | else: 51 | return x 52 | 53 | def encode_image_embed(self, image): 54 | """ 55 | 56 | Args: 57 | image (torch.FloatTensor): [bs, 3, h, w] 58 | 59 | Returns: 60 | x (torch.FloatTensor): [bs, projection_dim] 61 | """ 62 | 63 | x = self.clip_model.get_image_features(image) 64 | 65 | return x 66 | 67 | def encode_text_embed(self, text): 68 | x = self.clip_model.get_text_features(text) 69 | return x 70 | 71 | def forward(self, surface, image, text): 72 | """ 73 | 74 | Args: 75 | surface (torch.FloatTensor): 76 | image (torch.FloatTensor): [bs, 3, 224, 224] 77 | text (torch.LongTensor): [bs, num_templates, 77] 78 | 79 | Returns: 80 | embed_outputs (dict): the embedding outputs, and it contains: 81 | - image_embed (torch.FloatTensor): 82 | - text_embed (torch.FloatTensor): 83 | - shape_embed (torch.FloatTensor): 84 | - logit_scale (float): 85 | """ 86 | 87 | # # text embedding 88 | # text_embed_all = [] 89 | # for i in range(text.shape[0]): 90 | # text_for_one_sample = text[i] 91 | # text_embed = self.encode_text_embed(text_for_one_sample) 92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 93 | # text_embed = text_embed.mean(dim=0) 94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 95 | # text_embed_all.append(text_embed) 96 | # text_embed_all = torch.stack(text_embed_all) 97 | 98 | b = text.shape[0] 99 | text_tokens = rearrange(text, "b t l -> (b t) l") 100 | text_embed = self.encode_text_embed(text_tokens) 101 | text_embed = rearrange(text_embed, "(b t) d -> b t d", 
b=b) 102 | text_embed = text_embed.mean(dim=1) 103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 104 | 105 | # image embedding 106 | image_embed = self.encode_image_embed(image) 107 | 108 | # shape embedding 109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True) 110 | 111 | embed_outputs = { 112 | "image_embed": image_embed, 113 | "text_embed": text_embed, 114 | "shape_embed": shape_embed, 115 | "logit_scale": self.clip_model.logit_scale.exp() 116 | } 117 | 118 | return embed_outputs, shape_latents 119 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from einops import repeat 6 | import numpy as np 7 | from typing import Callable, Tuple, List, Union, Optional 8 | from skimage import measure 9 | 10 | from michelangelo.graphics.primitives import generate_dense_grid_points 11 | 12 | 13 | @torch.no_grad() 14 | def extract_geometry(geometric_func: Callable, 15 | device: torch.device, 16 | batch_size: int = 1, 17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 18 | octree_depth: int = 7, 19 | num_chunks: int = 10000, 20 | disable: bool = True): 21 | """ 22 | 23 | Args: 24 | geometric_func: 25 | device: 26 | bounds: 27 | octree_depth: 28 | batch_size: 29 | num_chunks: 30 | disable: 31 | 32 | Returns: 33 | 34 | """ 35 | 36 | if isinstance(bounds, float): 37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] 38 | 39 | bbox_min = np.array(bounds[0:3]) 40 | bbox_max = np.array(bounds[3:6]) 41 | bbox_size = bbox_max - bbox_min 42 | 43 | xyz_samples, grid_size, length = generate_dense_grid_points( 44 | bbox_min=bbox_min, 45 | bbox_max=bbox_max, 46 | octree_depth=octree_depth, 47 | indexing="ij" 48 | ) 49 | xyz_samples = torch.FloatTensor(xyz_samples) 50 | 51 | batch_logits = [] 52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), 53 | desc="Implicit Function:", disable=disable, leave=False): 54 | queries = xyz_samples[start: start + num_chunks, :].to(device) 55 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size) 56 | 57 | logits = geometric_func(batch_queries) 58 | batch_logits.append(logits.cpu()) 59 | 60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() 61 | 62 | mesh_v_f = [] 63 | has_surface = np.zeros((batch_size,), dtype=np.bool_) 64 | for i in range(batch_size): 65 | try: 66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") 67 | vertices = vertices / grid_size * bbox_size + bbox_min 68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]] 69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) 70 | has_surface[i] = True 71 | 72 | except ValueError: 73 | mesh_v_f.append((None, None)) 74 | has_surface[i] = False 75 | 76 | except RuntimeError: 77 | mesh_v_f.append((None, None)) 78 | has_surface[i] = False 79 | 80 | return mesh_v_f, has_surface 81 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from typing import Optional, Tuple, Dict 7 | 8 | 
from michelangelo.models.modules.distributions import DiagonalGaussianDistribution 9 | from michelangelo.utils.eval import compute_psnr 10 | from michelangelo.utils import misc 11 | 12 | 13 | class KLNearFar(nn.Module): 14 | def __init__(self, 15 | near_weight: float = 0.1, 16 | kl_weight: float = 1.0, 17 | num_near_samples: Optional[int] = None): 18 | 19 | super().__init__() 20 | 21 | self.near_weight = near_weight 22 | self.kl_weight = kl_weight 23 | self.num_near_samples = num_near_samples 24 | self.geo_criterion = nn.BCEWithLogitsLoss() 25 | 26 | def forward(self, 27 | posteriors: Optional[DiagonalGaussianDistribution], 28 | logits: torch.FloatTensor, 29 | labels: torch.FloatTensor, 30 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: 31 | 32 | """ 33 | 34 | Args: 35 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): 36 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; 37 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; 38 | split (str): 39 | **kwargs: 40 | 41 | Returns: 42 | loss (torch.Tensor): (,) 43 | log (dict): 44 | 45 | """ 46 | 47 | if self.num_near_samples is None: 48 | num_vol = logits.shape[1] // 2 49 | else: 50 | num_vol = logits.shape[1] - self.num_near_samples 51 | 52 | vol_logits = logits[:, 0:num_vol] 53 | vol_labels = labels[:, 0:num_vol] 54 | 55 | near_logits = logits[:, num_vol:] 56 | near_labels = labels[:, num_vol:] 57 | 58 | # occupancy loss 59 | # vol_bce = self.geo_criterion(vol_logits, vol_labels) 60 | # near_bce = self.geo_criterion(near_logits, near_labels) 61 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 62 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 63 | 64 | if posteriors is None: 65 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) 66 | else: 67 | kl_loss = posteriors.kl(dims=(1, 2)) 68 | kl_loss = torch.mean(kl_loss) 69 | 70 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight 71 | 72 | with torch.no_grad(): 73 | preds = logits >= 0 74 | accuracy = (preds == labels).float() 75 | accuracy = accuracy.mean() 76 | pos_ratio = torch.mean(labels) 77 | 78 | log = { 79 | "{}/total_loss".format(split): loss.clone().detach(), 80 | "{}/near".format(split): near_bce.detach(), 81 | "{}/far".format(split): vol_bce.detach(), 82 | "{}/kl".format(split): kl_loss.detach(), 83 | "{}/accuracy".format(split): accuracy, 84 | "{}/pos_ratio".format(split): pos_ratio 85 | } 86 | 87 | if posteriors is not None: 88 | log[f"{split}/mean"] = posteriors.mean.mean().detach() 89 | log[f"{split}/std_mean"] = posteriors.std.mean().detach() 90 | log[f"{split}/std_max"] = posteriors.std.max().detach() 91 | 92 | return loss, log 93 | 94 | 95 | class KLNearFarColor(nn.Module): 96 | def __init__(self, 97 | near_weight: float = 0.1, 98 | kl_weight: float = 1.0, 99 | color_weight: float = 1.0, 100 | color_criterion: str = "mse", 101 | num_near_samples: Optional[int] = None): 102 | 103 | super().__init__() 104 | 105 | self.color_weight = color_weight 106 | self.near_weight = near_weight 107 | self.kl_weight = kl_weight 108 | self.num_near_samples = num_near_samples 109 | 110 | if color_criterion == "mse": 111 | self.color_criterion = nn.MSELoss() 112 | 113 | elif color_criterion == "l1": 114 | self.color_criterion = nn.L1Loss() 115 | 116 | else: 117 | raise ValueError(f"{color_criterion} must be [`mse`, 
`l1`].") 118 | 119 | self.geo_criterion = nn.BCEWithLogitsLoss() 120 | 121 | def forward(self, 122 | posteriors: Optional[DiagonalGaussianDistribution], 123 | logits: torch.FloatTensor, 124 | labels: torch.FloatTensor, 125 | pred_colors: torch.FloatTensor, 126 | gt_colors: torch.FloatTensor, 127 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: 128 | 129 | """ 130 | 131 | Args: 132 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): 133 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; 134 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; 135 | pred_colors (torch.FloatTensor): [B, M, 3] 136 | gt_colors (torch.FloatTensor): [B, M, 3] 137 | split (str): 138 | **kwargs: 139 | 140 | Returns: 141 | loss (torch.Tensor): (,) 142 | log (dict): 143 | 144 | """ 145 | 146 | if self.num_near_samples is None: 147 | num_vol = logits.shape[1] // 2 148 | else: 149 | num_vol = logits.shape[1] - self.num_near_samples 150 | 151 | vol_logits = logits[:, 0:num_vol] 152 | vol_labels = labels[:, 0:num_vol] 153 | 154 | near_logits = logits[:, num_vol:] 155 | near_labels = labels[:, num_vol:] 156 | 157 | # occupancy loss 158 | # vol_bce = self.geo_criterion(vol_logits, vol_labels) 159 | # near_bce = self.geo_criterion(near_logits, near_labels) 160 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 161 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 162 | 163 | # surface color loss 164 | color = self.color_criterion(pred_colors, gt_colors) 165 | 166 | if posteriors is None: 167 | kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device) 168 | else: 169 | kl_loss = posteriors.kl(dims=(1, 2)) 170 | kl_loss = torch.mean(kl_loss) 171 | 172 | loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight 173 | 174 | with torch.no_grad(): 175 | preds = logits >= 0 176 | accuracy = (preds == labels).float() 177 | accuracy = accuracy.mean() 178 | psnr = compute_psnr(pred_colors, gt_colors) 179 | 180 | log = { 181 | "{}/total_loss".format(split): loss.clone().detach(), 182 | "{}/near".format(split): near_bce.detach(), 183 | "{}/far".format(split): vol_bce.detach(), 184 | "{}/color".format(split): color.detach(), 185 | "{}/kl".format(split): kl_loss.detach(), 186 | "{}/psnr".format(split): psnr.detach(), 187 | "{}/accuracy".format(split): accuracy 188 | } 189 | 190 | return loss, log 191 | 192 | 193 | class ContrastKLNearFar(nn.Module): 194 | def __init__(self, 195 | contrast_weight: float = 1.0, 196 | near_weight: float = 0.1, 197 | kl_weight: float = 1.0, 198 | num_near_samples: Optional[int] = None): 199 | 200 | super().__init__() 201 | 202 | self.labels = None 203 | self.last_local_batch_size = None 204 | 205 | self.contrast_weight = contrast_weight 206 | self.near_weight = near_weight 207 | self.kl_weight = kl_weight 208 | self.num_near_samples = num_near_samples 209 | self.geo_criterion = nn.BCEWithLogitsLoss() 210 | 211 | def forward(self, 212 | shape_embed: torch.FloatTensor, 213 | text_embed: torch.FloatTensor, 214 | image_embed: torch.FloatTensor, 215 | logit_scale: torch.FloatTensor, 216 | posteriors: Optional[DiagonalGaussianDistribution], 217 | shape_logits: torch.FloatTensor, 218 | shape_labels: torch.FloatTensor, 219 | split: Optional[str] = "train", **kwargs): 220 | 221 | local_batch_size = shape_embed.size(0) 222 | 223 
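# Cache the per-rank contrastive targets: sample i on this rank matches column
# (rank * local_batch_size + i) of the all-gathered embeddings used below; the
# cached labels are rebuilt only when the local batch size changes.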
| if local_batch_size != self.last_local_batch_size: 224 | self.labels = local_batch_size * misc.get_rank() + torch.arange( 225 | local_batch_size, device=shape_embed.device 226 | ).long() 227 | self.last_local_batch_size = local_batch_size 228 | 229 | # normalized features 230 | shape_embed = F.normalize(shape_embed, dim=-1, p=2) 231 | text_embed = F.normalize(text_embed, dim=-1, p=2) 232 | image_embed = F.normalize(image_embed, dim=-1, p=2) 233 | 234 | # gather features from all GPUs 235 | shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( 236 | [shape_embed, text_embed, image_embed] 237 | ) 238 | 239 | # cosine similarity as logits 240 | logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() 241 | logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() 242 | logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() 243 | logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() 244 | contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + 245 | F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ 246 | (F.cross_entropy(logits_per_shape_image, self.labels) + 247 | F.cross_entropy(logits_per_image_shape, self.labels)) / 2 248 | 249 | # shape reconstruction 250 | if self.num_near_samples is None: 251 | num_vol = shape_logits.shape[1] // 2 252 | else: 253 | num_vol = shape_logits.shape[1] - self.num_near_samples 254 | 255 | vol_logits = shape_logits[:, 0:num_vol] 256 | vol_labels = shape_labels[:, 0:num_vol] 257 | 258 | near_logits = shape_logits[:, num_vol:] 259 | near_labels = shape_labels[:, num_vol:] 260 | 261 | # occupancy loss 262 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 263 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 264 | 265 | if posteriors is None: 266 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) 267 | else: 268 | kl_loss = posteriors.kl(dims=(1, 2)) 269 | kl_loss = torch.mean(kl_loss) 270 | 271 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight 272 | 273 | # compute accuracy 274 | with torch.no_grad(): 275 | pred = torch.argmax(logits_per_shape_text, dim=-1) 276 | correct = pred.eq(self.labels).sum() 277 | shape_text_acc = 100 * correct / local_batch_size 278 | 279 | pred = torch.argmax(logits_per_shape_image, dim=-1) 280 | correct = pred.eq(self.labels).sum() 281 | shape_image_acc = 100 * correct / local_batch_size 282 | 283 | preds = shape_logits >= 0 284 | accuracy = (preds == shape_labels).float() 285 | accuracy = accuracy.mean() 286 | 287 | log = { 288 | "{}/contrast".format(split): contrast_loss.clone().detach(), 289 | "{}/near".format(split): near_bce.detach(), 290 | "{}/far".format(split): vol_bce.detach(), 291 | "{}/kl".format(split): kl_loss.detach(), 292 | "{}/shape_text_acc".format(split): shape_text_acc, 293 | "{}/shape_image_acc".format(split): shape_image_acc, 294 | "{}/total_loss".format(split): loss.clone().detach(), 295 | "{}/accuracy".format(split): accuracy, 296 | } 297 | 298 | if posteriors is not None: 299 | log[f"{split}/mean"] = posteriors.mean.mean().detach() 300 | log[f"{split}/std_mean"] = posteriors.std.mean().detach() 301 | log[f"{split}/std_max"] = posteriors.std.max().detach() 302 | 303 | return loss, log 304 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/sal_perceiver.py: 
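The file below holds the perceiver-style shape VAE: a set of learned query tokens cross-attends over Fourier-embedded surface points, passes through a self-attention Transformer, is optionally squeezed through a KL-regularized `embed_dim` bottleneck, and a cross-attention decoder maps arbitrary 3D query points to occupancy logits; the Aligned variant prepends one extra latent that serves as the global shape embedding. A minimal instantiation sketch follows; every hyperparameter value in it is a placeholder, not taken from the repository's YAML configs.

import torch
from michelangelo.models.tsal.sal_perceiver import AlignedShapeLatentPerceiver

model = AlignedShapeLatentPerceiver(
    device=None, dtype=None,
    num_latents=256, embed_dim=64, point_feats=3,   # assumed sizes, not the shipped config
    width=768, heads=12,
    num_encoder_layers=8, num_decoder_layers=16,
)
surface = torch.randn(1, 4096, 6)                   # xyz + normals
queries = torch.rand(1, 2048, 3) * 2.0 - 1.0        # query points in [-1, 1]^3
shape_embed, logits, posterior = model(surface[..., 0:3], surface[..., 3:], queries)
# shape_embed: [1, 768], logits: [1, 2048], posterior: DiagonalGaussianDistribution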
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Optional 6 | from einops import repeat 7 | import math 8 | 9 | from michelangelo.models.modules import checkpoint 10 | from michelangelo.models.modules.embedder import FourierEmbedder 11 | from michelangelo.models.modules.distributions import DiagonalGaussianDistribution 12 | from michelangelo.models.modules.transformer_blocks import ( 13 | ResidualCrossAttentionBlock, 14 | Transformer 15 | ) 16 | 17 | from .tsal_base import ShapeAsLatentModule 18 | 19 | 20 | class CrossAttentionEncoder(nn.Module): 21 | 22 | def __init__(self, *, 23 | device: Optional[torch.device], 24 | dtype: Optional[torch.dtype], 25 | num_latents: int, 26 | fourier_embedder: FourierEmbedder, 27 | point_feats: int, 28 | width: int, 29 | heads: int, 30 | layers: int, 31 | init_scale: float = 0.25, 32 | qkv_bias: bool = True, 33 | flash: bool = False, 34 | use_ln_post: bool = False, 35 | use_checkpoint: bool = False): 36 | 37 | super().__init__() 38 | 39 | self.use_checkpoint = use_checkpoint 40 | self.num_latents = num_latents 41 | 42 | self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) 43 | 44 | self.fourier_embedder = fourier_embedder 45 | self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) 46 | self.cross_attn = ResidualCrossAttentionBlock( 47 | device=device, 48 | dtype=dtype, 49 | width=width, 50 | heads=heads, 51 | init_scale=init_scale, 52 | qkv_bias=qkv_bias, 53 | flash=flash, 54 | ) 55 | 56 | self.self_attn = Transformer( 57 | device=device, 58 | dtype=dtype, 59 | n_ctx=num_latents, 60 | width=width, 61 | layers=layers, 62 | heads=heads, 63 | init_scale=init_scale, 64 | qkv_bias=qkv_bias, 65 | flash=flash, 66 | use_checkpoint=False 67 | ) 68 | 69 | if use_ln_post: 70 | self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) 71 | else: 72 | self.ln_post = None 73 | 74 | def _forward(self, pc, feats): 75 | """ 76 | 77 | Args: 78 | pc (torch.FloatTensor): [B, N, 3] 79 | feats (torch.FloatTensor or None): [B, N, C] 80 | 81 | Returns: 82 | 83 | """ 84 | 85 | bs = pc.shape[0] 86 | 87 | data = self.fourier_embedder(pc) 88 | if feats is not None: 89 | data = torch.cat([data, feats], dim=-1) 90 | data = self.input_proj(data) 91 | 92 | query = repeat(self.query, "m c -> b m c", b=bs) 93 | latents = self.cross_attn(query, data) 94 | latents = self.self_attn(latents) 95 | 96 | if self.ln_post is not None: 97 | latents = self.ln_post(latents) 98 | 99 | return latents, pc 100 | 101 | def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): 102 | """ 103 | 104 | Args: 105 | pc (torch.FloatTensor): [B, N, 3] 106 | feats (torch.FloatTensor or None): [B, N, C] 107 | 108 | Returns: 109 | dict 110 | """ 111 | 112 | return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) 113 | 114 | 115 | class CrossAttentionDecoder(nn.Module): 116 | 117 | def __init__(self, *, 118 | device: Optional[torch.device], 119 | dtype: Optional[torch.dtype], 120 | num_latents: int, 121 | out_channels: int, 122 | fourier_embedder: FourierEmbedder, 123 | width: int, 124 | heads: int, 125 | init_scale: float = 0.25, 126 | qkv_bias: bool = True, 127 | flash: bool = False, 128 | use_checkpoint: bool = False): 129 | 130 | super().__init__() 131 | 132 | self.use_checkpoint = use_checkpoint 133 | self.fourier_embedder = 
fourier_embedder 134 | 135 | self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) 136 | 137 | self.cross_attn_decoder = ResidualCrossAttentionBlock( 138 | device=device, 139 | dtype=dtype, 140 | n_data=num_latents, 141 | width=width, 142 | heads=heads, 143 | init_scale=init_scale, 144 | qkv_bias=qkv_bias, 145 | flash=flash 146 | ) 147 | 148 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) 149 | self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) 150 | 151 | def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 152 | queries = self.query_proj(self.fourier_embedder(queries)) 153 | x = self.cross_attn_decoder(queries, latents) 154 | x = self.ln_post(x) 155 | x = self.output_proj(x) 156 | return x 157 | 158 | def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 159 | return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) 160 | 161 | 162 | class ShapeAsLatentPerceiver(ShapeAsLatentModule): 163 | def __init__(self, *, 164 | device: Optional[torch.device], 165 | dtype: Optional[torch.dtype], 166 | num_latents: int, 167 | point_feats: int = 0, 168 | embed_dim: int = 0, 169 | num_freqs: int = 8, 170 | include_pi: bool = True, 171 | width: int, 172 | heads: int, 173 | num_encoder_layers: int, 174 | num_decoder_layers: int, 175 | init_scale: float = 0.25, 176 | qkv_bias: bool = True, 177 | flash: bool = False, 178 | use_ln_post: bool = False, 179 | use_checkpoint: bool = False): 180 | 181 | super().__init__() 182 | 183 | self.use_checkpoint = use_checkpoint 184 | 185 | self.num_latents = num_latents 186 | self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) 187 | 188 | init_scale = init_scale * math.sqrt(1.0 / width) 189 | self.encoder = CrossAttentionEncoder( 190 | device=device, 191 | dtype=dtype, 192 | fourier_embedder=self.fourier_embedder, 193 | num_latents=num_latents, 194 | point_feats=point_feats, 195 | width=width, 196 | heads=heads, 197 | layers=num_encoder_layers, 198 | init_scale=init_scale, 199 | qkv_bias=qkv_bias, 200 | flash=flash, 201 | use_ln_post=use_ln_post, 202 | use_checkpoint=use_checkpoint 203 | ) 204 | 205 | self.embed_dim = embed_dim 206 | if embed_dim > 0: 207 | # VAE embed 208 | self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) 209 | self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype) 210 | self.latent_shape = (num_latents, embed_dim) 211 | else: 212 | self.latent_shape = (num_latents, width) 213 | 214 | self.transformer = Transformer( 215 | device=device, 216 | dtype=dtype, 217 | n_ctx=num_latents, 218 | width=width, 219 | layers=num_decoder_layers, 220 | heads=heads, 221 | init_scale=init_scale, 222 | qkv_bias=qkv_bias, 223 | flash=flash, 224 | use_checkpoint=use_checkpoint 225 | ) 226 | 227 | # geometry decoder 228 | self.geo_decoder = CrossAttentionDecoder( 229 | device=device, 230 | dtype=dtype, 231 | fourier_embedder=self.fourier_embedder, 232 | out_channels=1, 233 | num_latents=num_latents, 234 | width=width, 235 | heads=heads, 236 | init_scale=init_scale, 237 | qkv_bias=qkv_bias, 238 | flash=flash, 239 | use_checkpoint=use_checkpoint 240 | ) 241 | 242 | def encode(self, 243 | pc: torch.FloatTensor, 244 | feats: Optional[torch.FloatTensor] = None, 245 | sample_posterior: bool = True): 246 | """ 247 | 248 | Args: 249 | pc (torch.FloatTensor): [B, N, 3] 250 | feats (torch.FloatTensor or None): [B, N, C] 251 | sample_posterior (bool): 
252 | 253 | Returns: 254 | latents (torch.FloatTensor) 255 | center_pos (torch.FloatTensor or None): 256 | posterior (DiagonalGaussianDistribution or None): 257 | """ 258 | 259 | latents, center_pos = self.encoder(pc, feats) 260 | 261 | posterior = None 262 | if self.embed_dim > 0: 263 | moments = self.pre_kl(latents) 264 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) 265 | 266 | if sample_posterior: 267 | latents = posterior.sample() 268 | else: 269 | latents = posterior.mode() 270 | 271 | return latents, center_pos, posterior 272 | 273 | def decode(self, latents: torch.FloatTensor): 274 | latents = self.post_kl(latents) 275 | return self.transformer(latents) 276 | 277 | def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): 278 | logits = self.geo_decoder(queries, latents).squeeze(-1) 279 | return logits 280 | 281 | def forward(self, 282 | pc: torch.FloatTensor, 283 | feats: torch.FloatTensor, 284 | volume_queries: torch.FloatTensor, 285 | sample_posterior: bool = True): 286 | """ 287 | 288 | Args: 289 | pc (torch.FloatTensor): [B, N, 3] 290 | feats (torch.FloatTensor or None): [B, N, C] 291 | volume_queries (torch.FloatTensor): [B, P, 3] 292 | sample_posterior (bool): 293 | 294 | Returns: 295 | logits (torch.FloatTensor): [B, P] 296 | center_pos (torch.FloatTensor): [B, M, 3] 297 | posterior (DiagonalGaussianDistribution or None). 298 | 299 | """ 300 | 301 | latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) 302 | 303 | latents = self.decode(latents) 304 | logits = self.query_geometry(volume_queries, latents) 305 | 306 | return logits, center_pos, posterior 307 | 308 | 309 | class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): 310 | 311 | def __init__(self, *, 312 | device: Optional[torch.device], 313 | dtype: Optional[torch.dtype], 314 | num_latents: int, 315 | point_feats: int = 0, 316 | embed_dim: int = 0, 317 | num_freqs: int = 8, 318 | include_pi: bool = True, 319 | width: int, 320 | heads: int, 321 | num_encoder_layers: int, 322 | num_decoder_layers: int, 323 | init_scale: float = 0.25, 324 | qkv_bias: bool = True, 325 | flash: bool = False, 326 | use_ln_post: bool = False, 327 | use_checkpoint: bool = False): 328 | 329 | super().__init__( 330 | device=device, 331 | dtype=dtype, 332 | num_latents=1 + num_latents, 333 | point_feats=point_feats, 334 | embed_dim=embed_dim, 335 | num_freqs=num_freqs, 336 | include_pi=include_pi, 337 | width=width, 338 | heads=heads, 339 | num_encoder_layers=num_encoder_layers, 340 | num_decoder_layers=num_decoder_layers, 341 | init_scale=init_scale, 342 | qkv_bias=qkv_bias, 343 | flash=flash, 344 | use_ln_post=use_ln_post, 345 | use_checkpoint=use_checkpoint 346 | ) 347 | 348 | self.width = width 349 | 350 | def encode(self, 351 | pc: torch.FloatTensor, 352 | feats: Optional[torch.FloatTensor] = None, 353 | sample_posterior: bool = True): 354 | """ 355 | 356 | Args: 357 | pc (torch.FloatTensor): [B, N, 3] 358 | feats (torch.FloatTensor or None): [B, N, c] 359 | sample_posterior (bool): 360 | 361 | Returns: 362 | shape_embed (torch.FloatTensor) 363 | kl_embed (torch.FloatTensor): 364 | posterior (DiagonalGaussianDistribution or None): 365 | """ 366 | 367 | shape_embed, latents = self.encode_latents(pc, feats) 368 | kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) 369 | 370 | return shape_embed, kl_embed, posterior 371 | 372 | def encode_latents(self, 373 | pc: torch.FloatTensor, 374 | feats: Optional[torch.FloatTensor] = None): 375 | 376 | x, 
_ = self.encoder(pc, feats) 377 | 378 | shape_embed = x[:, 0] 379 | latents = x[:, 1:] 380 | 381 | return shape_embed, latents 382 | 383 | def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): 384 | posterior = None 385 | if self.embed_dim > 0: 386 | moments = self.pre_kl(latents) 387 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) 388 | 389 | if sample_posterior: 390 | kl_embed = posterior.sample() 391 | else: 392 | kl_embed = posterior.mode() 393 | else: 394 | kl_embed = latents 395 | 396 | return kl_embed, posterior 397 | 398 | def forward(self, 399 | pc: torch.FloatTensor, 400 | feats: torch.FloatTensor, 401 | volume_queries: torch.FloatTensor, 402 | sample_posterior: bool = True): 403 | """ 404 | 405 | Args: 406 | pc (torch.FloatTensor): [B, N, 3] 407 | feats (torch.FloatTensor or None): [B, N, C] 408 | volume_queries (torch.FloatTensor): [B, P, 3] 409 | sample_posterior (bool): 410 | 411 | Returns: 412 | shape_embed (torch.FloatTensor): [B, projection_dim] 413 | logits (torch.FloatTensor): [B, M] 414 | posterior (DiagonalGaussianDistribution or None). 415 | 416 | """ 417 | 418 | shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) 419 | 420 | latents = self.decode(kl_embed) 421 | logits = self.query_geometry(volume_queries, latents) 422 | 423 | return shape_embed, logits, posterior 424 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/sal_pl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import List, Tuple, Dict, Optional 4 | from omegaconf import DictConfig 5 | 6 | import torch 7 | from torch.optim import lr_scheduler 8 | import pytorch_lightning as pl 9 | from typing import Union 10 | from functools import partial 11 | 12 | from michelangelo.utils import instantiate_from_config 13 | 14 | from .inference_utils import extract_geometry 15 | from .tsal_base import ( 16 | ShapeAsLatentModule, 17 | Latent2MeshOutput, 18 | Point2MeshOutput 19 | ) 20 | 21 | 22 | class ShapeAsLatentPLModule(pl.LightningModule): 23 | 24 | def __init__(self, *, 25 | module_cfg, 26 | loss_cfg, 27 | optimizer_cfg: Optional[DictConfig] = None, 28 | ckpt_path: Optional[str] = None, 29 | ignore_keys: Union[Tuple[str], List[str]] = ()): 30 | 31 | super().__init__() 32 | 33 | self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None) 34 | 35 | self.loss = instantiate_from_config(loss_cfg) 36 | 37 | self.optimizer_cfg = optimizer_cfg 38 | 39 | if ckpt_path is not None: 40 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 41 | 42 | self.save_hyperparameters() 43 | 44 | @property 45 | def latent_shape(self): 46 | return self.sal.latent_shape 47 | 48 | @property 49 | def zero_rank(self): 50 | if self._trainer: 51 | zero_rank = self.trainer.local_rank == 0 52 | else: 53 | zero_rank = True 54 | 55 | return zero_rank 56 | 57 | def init_from_ckpt(self, path, ignore_keys=()): 58 | state_dict = torch.load(path, map_location="cpu")["state_dict"] 59 | 60 | keys = list(state_dict.keys()) 61 | for k in keys: 62 | for ik in ignore_keys: 63 | if k.startswith(ik): 64 | print("Deleting key {} from state_dict.".format(k)) 65 | del state_dict[k] 66 | 67 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 68 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 69 | if len(missing) > 0: 70 | 
print(f"Missing Keys: {missing}") 71 | print(f"Unexpected Keys: {unexpected}") 72 | 73 | def configure_optimizers(self) -> Tuple[List, List]: 74 | lr = self.learning_rate 75 | 76 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)] 77 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 78 | 79 | if self.optimizer_cfg is None: 80 | optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] 81 | schedulers = [] 82 | else: 83 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters()) 84 | scheduler_func = instantiate_from_config( 85 | self.optimizer_cfg.scheduler, 86 | max_decay_steps=self.trainer.max_steps, 87 | lr_max=lr 88 | ) 89 | scheduler = { 90 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), 91 | "interval": "step", 92 | "frequency": 1 93 | } 94 | optimizers = [optimizer] 95 | schedulers = [scheduler] 96 | 97 | return optimizers, schedulers 98 | 99 | def forward(self, 100 | pc: torch.FloatTensor, 101 | feats: torch.FloatTensor, 102 | volume_queries: torch.FloatTensor): 103 | 104 | logits, center_pos, posterior = self.sal(pc, feats, volume_queries) 105 | 106 | return posterior, logits 107 | 108 | def encode(self, surface: torch.FloatTensor, sample_posterior=True): 109 | 110 | pc = surface[..., 0:3] 111 | feats = surface[..., 3:6] 112 | 113 | latents, center_pos, posterior = self.sal.encode( 114 | pc=pc, feats=feats, sample_posterior=sample_posterior 115 | ) 116 | 117 | return latents 118 | 119 | def decode(self, 120 | z_q, 121 | bounds: Union[Tuple[float], List[float], float] = 1.1, 122 | octree_depth: int = 7, 123 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 124 | 125 | latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim] 126 | outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) 127 | 128 | return outputs 129 | 130 | def training_step(self, batch: Dict[str, torch.FloatTensor], 131 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: 132 | """ 133 | 134 | Args: 135 | batch (dict): the batch sample, and it contains: 136 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] 137 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] 138 | 139 | batch_idx (int): 140 | 141 | optimizer_idx (int): 142 | 143 | Returns: 144 | loss (torch.FloatTensor): 145 | 146 | """ 147 | 148 | pc = batch["surface"][..., 0:3] 149 | feats = batch["surface"][..., 3:] 150 | 151 | volume_queries = batch["geo_points"][..., 0:3] 152 | volume_labels = batch["geo_points"][..., -1] 153 | 154 | posterior, logits = self( 155 | pc=pc, feats=feats, volume_queries=volume_queries 156 | ) 157 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train") 158 | 159 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], 160 | sync_dist=False, rank_zero_only=True) 161 | 162 | return aeloss 163 | 164 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: 165 | 166 | pc = batch["surface"][..., 0:3] 167 | feats = batch["surface"][..., 3:] 168 | 169 | volume_queries = batch["geo_points"][..., 0:3] 170 | volume_labels = batch["geo_points"][..., -1] 171 | 172 | posterior, logits = self( 173 | pc=pc, feats=feats, volume_queries=volume_queries, 174 | ) 175 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, 
split="val") 176 | 177 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], 178 | sync_dist=False, rank_zero_only=True) 179 | 180 | return aeloss 181 | 182 | def point2mesh(self, 183 | pc: torch.FloatTensor, 184 | feats: torch.FloatTensor, 185 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 186 | octree_depth: int = 7, 187 | num_chunks: int = 10000) -> List[Point2MeshOutput]: 188 | 189 | """ 190 | 191 | Args: 192 | pc: 193 | feats: 194 | bounds: 195 | octree_depth: 196 | num_chunks: 197 | 198 | Returns: 199 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 200 | 201 | """ 202 | 203 | outputs = [] 204 | 205 | device = pc.device 206 | bs = pc.shape[0] 207 | 208 | # 1. point encoder + latents transformer 209 | latents, center_pos, posterior = self.sal.encode(pc, feats) 210 | latents = self.sal.decode(latents) # latents: [bs, num_latents, dim] 211 | 212 | geometric_func = partial(self.sal.query_geometry, latents=latents) 213 | 214 | # 2. decode geometry 215 | mesh_v_f, has_surface = extract_geometry( 216 | geometric_func=geometric_func, 217 | device=device, 218 | batch_size=bs, 219 | bounds=bounds, 220 | octree_depth=octree_depth, 221 | num_chunks=num_chunks, 222 | disable=not self.zero_rank 223 | ) 224 | 225 | # 3. decode texture 226 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 227 | if not is_surface: 228 | outputs.append(None) 229 | continue 230 | 231 | out = Point2MeshOutput() 232 | out.mesh_v = mesh_v 233 | out.mesh_f = mesh_f 234 | out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy() 235 | 236 | if center_pos is not None: 237 | out.center = center_pos[i].cpu().numpy() 238 | 239 | outputs.append(out) 240 | 241 | return outputs 242 | 243 | def latent2mesh(self, 244 | latents: torch.FloatTensor, 245 | bounds: Union[Tuple[float], List[float], float] = 1.1, 246 | octree_depth: int = 7, 247 | num_chunks: int = 10000) -> List[Latent2MeshOutput]: 248 | 249 | """ 250 | 251 | Args: 252 | latents: [bs, num_latents, dim] 253 | bounds: 254 | octree_depth: 255 | num_chunks: 256 | 257 | Returns: 258 | mesh_outputs (List[MeshOutput]): the mesh outputs list. 259 | 260 | """ 261 | 262 | outputs = [] 263 | 264 | geometric_func = partial(self.sal.query_geometry, latents=latents) 265 | 266 | # 2. decode geometry 267 | device = latents.device 268 | mesh_v_f, has_surface = extract_geometry( 269 | geometric_func=geometric_func, 270 | device=device, 271 | batch_size=len(latents), 272 | bounds=bounds, 273 | octree_depth=octree_depth, 274 | num_chunks=num_chunks, 275 | disable=not self.zero_rank 276 | ) 277 | 278 | # 3. 
decode texture 279 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): 280 | if not is_surface: 281 | outputs.append(None) 282 | continue 283 | 284 | out = Latent2MeshOutput() 285 | out.mesh_v = mesh_v 286 | out.mesh_f = mesh_f 287 | 288 | outputs.append(out) 289 | 290 | return outputs 291 | -------------------------------------------------------------------------------- /michelangelo/models/tsal/tsal_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from typing import Tuple, List, Optional 5 | import pytorch_lightning as pl 6 | 7 | 8 | class Point2MeshOutput(object): 9 | def __init__(self): 10 | self.mesh_v = None 11 | self.mesh_f = None 12 | self.center = None 13 | self.pc = None 14 | 15 | 16 | class Latent2MeshOutput(object): 17 | 18 | def __init__(self): 19 | self.mesh_v = None 20 | self.mesh_f = None 21 | 22 | 23 | class AlignedMeshOutput(object): 24 | 25 | def __init__(self): 26 | self.mesh_v = None 27 | self.mesh_f = None 28 | self.surface = None 29 | self.image = None 30 | self.text: Optional[str] = None 31 | self.shape_text_similarity: Optional[float] = None 32 | self.shape_image_similarity: Optional[float] = None 33 | 34 | 35 | class ShapeAsLatentPLModule(pl.LightningModule): 36 | latent_shape: Tuple[int] 37 | 38 | def encode(self, surface, *args, **kwargs): 39 | raise NotImplementedError 40 | 41 | def decode(self, z_q, *args, **kwargs): 42 | raise NotImplementedError 43 | 44 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 45 | raise NotImplementedError 46 | 47 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 48 | raise NotImplementedError 49 | 50 | 51 | class ShapeAsLatentModule(nn.Module): 52 | latent_shape: Tuple[int, int] 53 | 54 | def __init__(self, *args, **kwargs): 55 | super().__init__() 56 | 57 | def encode(self, *args, **kwargs): 58 | raise NotImplementedError 59 | 60 | def decode(self, *args, **kwargs): 61 | raise NotImplementedError 62 | 63 | def query_geometry(self, *args, **kwargs): 64 | raise NotImplementedError 65 | 66 | 67 | class AlignedShapeAsLatentPLModule(pl.LightningModule): 68 | latent_shape: Tuple[int] 69 | 70 | def set_shape_model_only(self): 71 | raise NotImplementedError 72 | 73 | def encode(self, surface, *args, **kwargs): 74 | raise NotImplementedError 75 | 76 | def decode(self, z_q, *args, **kwargs): 77 | raise NotImplementedError 78 | 79 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 80 | raise NotImplementedError 81 | 82 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 83 | raise NotImplementedError 84 | 85 | 86 | class AlignedShapeAsLatentModule(nn.Module): 87 | shape_model: ShapeAsLatentModule 88 | latent_shape: Tuple[int, int] 89 | 90 | def __init__(self, *args, **kwargs): 91 | super().__init__() 92 | 93 | def set_shape_model_only(self): 94 | raise NotImplementedError 95 | 96 | def encode_image_embed(self, *args, **kwargs): 97 | raise NotImplementedError 98 | 99 | def encode_text_embed(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | def encode_shape_embed(self, *args, **kwargs): 103 | raise NotImplementedError 104 | 105 | 106 | class TexturedShapeAsLatentModule(nn.Module): 107 | 108 | def __init__(self, *args, **kwargs): 109 | super().__init__() 110 | 111 | def encode(self, *args, **kwargs): 112 | raise NotImplementedError 113 | 114 | def decode(self, *args, **kwargs): 115 | raise NotImplementedError 
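# For orientation: the abstract interfaces in this file are realized elsewhere in the
# package. ShapeAsLatentModule is implemented by ShapeAsLatentPerceiver (sal_perceiver.py),
# AlignedShapeAsLatentModule by CLIPAlignedShapeAsLatentModule (clip_asl_module.py), and
# the Lightning wrappers live in sal_pl_module.py and asl_pl_module.py.
# TexturedShapeAsLatentModule adds query_color for a color-aware variant, whose
# predictions are presumably scored by KLNearFarColor in loss.py.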
116 | 
117 |     def query_geometry(self, *args, **kwargs):
118 |         raise NotImplementedError
119 | 
120 |     def query_color(self, *args, **kwargs):
121 |         raise NotImplementedError
122 | 
--------------------------------------------------------------------------------
/michelangelo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from .misc import get_config_from_file
4 | from .misc import instantiate_from_config
5 | 
--------------------------------------------------------------------------------
/michelangelo/utils/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import torch
4 | 
5 | 
6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
7 |     # PSNR = 10 * log10(MAX^2 / MSE), where MAX (`data_range`) is the
8 |     # peak-to-peak range of the signal (2 for values in [-1, 1]).
9 |     mse = torch.mean((x - y) ** 2)
10 |     psnr = 10 * torch.log10(data_range ** 2 / (mse + eps))
11 | 
12 |     return psnr
13 | 
--------------------------------------------------------------------------------
/michelangelo/utils/io.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import os
4 | import io
5 | import tarfile
6 | import json
7 | import numpy as np
8 | import numpy.lib.format
9 | 
10 | 
11 | def mkdir(path):
12 |     os.makedirs(path, exist_ok=True)
13 |     return path
14 | 
15 | 
16 | def npy_loads(data):
17 |     stream = io.BytesIO(data)
18 |     return np.lib.format.read_array(stream)
19 | 
20 | 
21 | def npz_loads(data):
22 |     return np.load(io.BytesIO(data))
23 | 
24 | 
25 | def json_loads(data):
26 |     return json.loads(data)
27 | 
28 | 
29 | def load_json(filepath):
30 |     with open(filepath, "r") as f:
31 |         data = json.load(f)
32 |     return data
33 | 
34 | 
35 | def write_json(filepath, data):
36 |     with open(filepath, "w") as f:
37 |         json.dump(data, f, indent=2)
38 | 
39 | 
40 | def extract_tar(tar_path, tar_cache_folder):
41 | 
42 |     with tarfile.open(tar_path, "r") as tar:
43 |         tar.extractall(path=tar_cache_folder)
44 | 
45 |     tar_uids = sorted(os.listdir(tar_cache_folder))
46 |     print(f"extract tar: {tar_path} to {tar_cache_folder}")
47 |     return tar_uids
48 | 
--------------------------------------------------------------------------------
/michelangelo/utils/misc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import importlib
4 | from omegaconf import OmegaConf, DictConfig, ListConfig
5 | 
6 | import torch
7 | import torch.distributed as dist
8 | from typing import Union
9 | 
10 | 
11 | def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
12 |     config_file = OmegaConf.load(config_file)
13 | 
14 |     if 'base_config' in config_file.keys():
15 |         if config_file['base_config'] == "default_base":
16 |             base_config = OmegaConf.create()
17 |             # base_config = get_default_config()
18 |         elif config_file['base_config'].endswith(".yaml"):
19 |             base_config = get_config_from_file(config_file['base_config'])
20 |         else:
21 |             raise ValueError(f"`base_config` must be \"default_base\" or a `.yaml` file, got {config_file['base_config']}.")
22 | 
23 |         config_file = {key: value for key, value in config_file.items() if key != "base_config"}
24 | 
25 |         return OmegaConf.merge(base_config, config_file)
26 | 
27 |     return config_file
28 | 
29 | 
30 | def get_obj_from_str(string, reload=False):
31 |     module, cls = string.rsplit(".", 1)
32 |     if reload:
33 |         module_imp = importlib.import_module(module)
34 |         importlib.reload(module_imp)
35 |     return getattr(importlib.import_module(module, package=None), cls)
36 | 
37 | 
38
| def get_obj_from_config(config): 39 | if "target" not in config: 40 | raise KeyError("Expected key `target` to instantiate.") 41 | 42 | return get_obj_from_str(config["target"]) 43 | 44 | 45 | def instantiate_from_config(config, **kwargs): 46 | if "target" not in config: 47 | raise KeyError("Expected key `target` to instantiate.") 48 | 49 | cls = get_obj_from_str(config["target"]) 50 | 51 | params = config.get("params", dict()) 52 | # params.update(kwargs) 53 | # instance = cls(**params) 54 | kwargs.update(params) 55 | instance = cls(**kwargs) 56 | 57 | return instance 58 | 59 | 60 | def is_dist_avail_and_initialized(): 61 | if not dist.is_available(): 62 | return False 63 | if not dist.is_initialized(): 64 | return False 65 | return True 66 | 67 | 68 | def get_rank(): 69 | if not is_dist_avail_and_initialized(): 70 | return 0 71 | return dist.get_rank() 72 | 73 | 74 | def get_world_size(): 75 | if not is_dist_avail_and_initialized(): 76 | return 1 77 | return dist.get_world_size() 78 | 79 | 80 | def all_gather_batch(tensors): 81 | """ 82 | Performs all_gather operation on the provided tensors. 83 | """ 84 | # Queue the gathered tensors 85 | world_size = get_world_size() 86 | # There is no need for reduction in the single-proc case 87 | if world_size == 1: 88 | return tensors 89 | tensor_list = [] 90 | output_tensor = [] 91 | for tensor in tensors: 92 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 93 | dist.all_gather( 94 | tensor_all, 95 | tensor, 96 | async_op=False # performance opt 97 | ) 98 | 99 | tensor_list.append(tensor_all) 100 | 101 | for tensor_all in tensor_list: 102 | output_tensor.append(torch.cat(tensor_all, dim=0)) 103 | return output_tensor 104 | -------------------------------------------------------------------------------- /michelangelo/utils/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /michelangelo/utils/visualizers/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | # Helper functions 6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): 7 | colormap = plt.cm.get_cmap(colormap) 8 | if normalize: 9 | vmin = np.min(inp) 10 | vmax = np.max(inp) 11 | 12 | norm = plt.Normalize(vmin, vmax) 13 | return colormap(norm(inp))[:, :3] 14 | 15 | 16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): 17 | # tex dims need to be power of two. 18 | array = np.ones((width, height, 3), dtype='float32') 19 | 20 | # width in texels of each checker 21 | checker_w = width / n_checkers_x 22 | checker_h = height / n_checkers_y 23 | 24 | for y in range(height): 25 | for x in range(width): 26 | color_key = int(x / checker_w) + int(y / checker_h) 27 | if color_key % 2 == 0: 28 | array[x, y, :] = [1., 0.874, 0.0] 29 | else: 30 | array[x, y, :] = [0., 0., 0.] 
31 |     return array
32 | 
33 | 
34 | def gen_circle(width=256, height=256):
35 |     xx, yy = np.mgrid[:width, :height]
36 |     circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
37 |     array = np.ones((width, height, 4), dtype='float32')
38 |     array[:, :, 0] = (circle <= width)
39 |     array[:, :, 1] = (circle <= width)
40 |     array[:, :, 2] = (circle <= width)
41 |     array[:, :, 3] = circle <= width
42 |     return array
43 | 
44 | 
--------------------------------------------------------------------------------
/michelangelo/utils/visualizers/html_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | import base64
4 | import numpy as np
5 | from PIL import Image
6 | 
7 | 
8 | def to_html_frame(content):
9 | 
10 |     html_frame = f"""
11 |     <html>
12 |     <body>
13 |     {content}
14 |     </body>
15 |     </html>
16 |     """
17 | 
18 |     return html_frame
19 | 
20 | 
21 | def to_single_row_table(caption: str, content: str):
22 | 
23 |     table_html = f"""
24 |     <table>
25 |         <caption>{caption}</caption>
26 |         <tr>
27 |             <td>{content}</td>
28 |         </tr>
29 |     </table>
30 | """ 31 | 32 | return table_html 33 | 34 | 35 | def to_image_embed_tag(image: np.ndarray): 36 | 37 | # Convert np.ndarray to bytes 38 | img = Image.fromarray(image) 39 | raw_bytes = io.BytesIO() 40 | img.save(raw_bytes, "PNG") 41 | 42 | # Encode bytes to base64 43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8") 44 | 45 | image_tag = f""" 46 | Embedded Image 47 | """ 48 | 49 | return image_tag 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.20.3 3 | addict==2.4.0 4 | aiofiles==23.1.0 5 | aiohttp==3.8.4 6 | aiosignal==1.3.1 7 | altair==5.0.1 8 | antlr4-python3-runtime==4.9.3 9 | anyio==3.6.2 10 | appdirs==1.4.4 11 | argon2-cffi==21.3.0 12 | argon2-cffi-bindings==21.2.0 13 | arrow==1.2.3 14 | asttokens==2.2.1 15 | async-timeout==4.0.2 16 | attrs==22.2.0 17 | backcall==0.2.0 18 | beautifulsoup4==4.11.2 19 | bleach==6.0.0 20 | braceexpand==0.1.7 21 | cachetools==5.3.0 22 | cffi==1.15.1 23 | charset-normalizer==3.0.1 24 | click==8.1.3 25 | coloredlogs==15.0.1 26 | comm==0.1.2 27 | configargparse==1.5.3 28 | contourpy==1.0.7 29 | controlnet-aux==0.0.5 30 | cycler==0.11.0 31 | cython==0.29.33 32 | dash==2.8.1 33 | dash-core-components==2.0.0 34 | dash-html-components==2.0.0 35 | dash-table==5.0.0 36 | dataclasses-json==0.6.0 37 | debugpy==1.6.6 38 | decorator==5.1.1 39 | deepspeed==0.8.1 40 | defusedxml==0.7.1 41 | deprecated==1.2.14 42 | diffusers==0.18.2 43 | docker-pycreds==0.4.0 44 | einops==0.6.0 45 | executing==1.2.0 46 | fastapi==0.101.0 47 | fastjsonschema==2.16.2 48 | ffmpy==0.3.1 49 | filelock==3.9.0 50 | flask==2.2.3 51 | flatbuffers==23.5.26 52 | fonttools==4.38.0 53 | fqdn==1.5.1 54 | frozenlist==1.3.3 55 | fsspec==2023.1.0 56 | ftfy==6.1.1 57 | fvcore==0.1.5.post20221221 58 | gitdb==4.0.10 59 | gitpython==3.1.31 60 | google-auth==2.16.1 61 | google-auth-oauthlib==0.4.6 62 | gradio==3.39.0 63 | gradio-client==0.3.0 64 | grpcio==1.51.3 65 | h11==0.14.0 66 | hjson==3.1.0 67 | httpcore==0.17.3 68 | httpx==0.24.1 69 | huggingface-hub==0.16.4 70 | humanfriendly==10.0 71 | idna==3.4 72 | imageio==2.25.1 73 | importlib-metadata==6.0.0 74 | iopath==0.1.10 75 | ipydatawidgets==4.3.3 76 | ipykernel==6.21.2 77 | ipython==8.10.0 78 | ipython-genutils==0.2.0 79 | ipywidgets==8.0.4 80 | isoduration==20.11.0 81 | itsdangerous==2.1.2 82 | jedi==0.18.2 83 | jinja2==3.1.2 84 | joblib==1.2.0 85 | jsonpointer==2.3 86 | jsonschema==4.17.3 87 | jupyter==1.0.0 88 | jupyter-client==8.0.3 89 | jupyter-console==6.6.1 90 | jupyter-core==5.2.0 91 | jupyter-events==0.6.3 92 | jupyter-server==2.3.0 93 | jupyter-server-terminals==0.4.4 94 | jupyterlab-pygments==0.2.2 95 | jupyterlab-widgets==3.0.5 96 | kiwisolver==1.4.4 97 | lightning-utilities==0.7.1 98 | linkify-it-py==2.0.2 99 | lmdb==1.4.1 100 | markdown==3.4.1 101 | markdown-it-py==2.2.0 102 | markupsafe==2.1.2 103 | marshmallow==3.20.1 104 | matplotlib==3.6.3 105 | matplotlib-inline==0.1.6 106 | mdit-py-plugins==0.3.3 107 | mdurl==0.1.2 108 | mesh2sdf==1.1.0 109 | mistune==2.0.5 110 | mpmath==1.3.0 111 | multidict==6.0.4 112 | mypy-extensions==1.0.0 113 | nbclassic==0.5.2 114 | nbclient==0.7.2 115 | nbconvert==7.2.9 116 | nbformat==5.5.0 117 | nest-asyncio==1.5.6 118 | networkx==3.0 119 | ninja==1.11.1 120 | notebook==6.5.2 121 | notebook-shim==0.2.2 122 | numpy==1.23.1 123 | oauthlib==3.2.2 124 | objaverse==0.0.7 125 | omegaconf==2.3.0 126 | 
onnxruntime==1.15.1 127 | opencv-contrib-python==4.8.0.74 128 | opencv-python==4.7.0.72 129 | orjson==3.9.2 130 | packaging==21.3 131 | pandas==1.4.4 132 | pandocfilters==1.5.0 133 | parso==0.8.3 134 | pathtools==0.1.2 135 | pexpect==4.8.0 136 | pickleshare==0.7.5 137 | pillow==9.2.0 138 | platformdirs==3.0.0 139 | plotly==5.13.0 140 | portalocker==2.7.0 141 | prometheus-client==0.16.0 142 | prompt-toolkit==3.0.37 143 | protobuf==3.19.6 144 | psutil==5.9.4 145 | ptyprocess==0.7.0 146 | pure-eval==0.2.2 147 | py-cpuinfo==9.0.0 148 | pyasn1==0.4.8 149 | pyasn1-modules==0.2.8 150 | pycparser==2.21 151 | pydantic==1.10.5 152 | pydub==0.25.1 153 | pygltflib==1.16.0 154 | pygments==2.14.0 155 | pymeshlab==2022.2.post3 156 | pyparsing==3.0.9 157 | pyquaternion==0.9.9 158 | pyrsistent==0.19.3 159 | pysdf==0.1.8 160 | python-dateutil==2.8.2 161 | python-json-logger==2.0.7 162 | python-multipart==0.0.6 163 | pythreejs==2.4.2 164 | pytz==2022.7.1 165 | pywavelets==1.4.1 166 | pyyaml==6.0 167 | pyzmq==25.0.0 168 | qtconsole==5.4.0 169 | qtpy==2.3.0 170 | regex==2022.10.31 171 | requests==2.28.2 172 | requests-oauthlib==1.3.1 173 | rfc3339-validator==0.1.4 174 | rfc3986-validator==0.1.1 175 | rsa==4.9 176 | rtree==1.0.1 177 | safetensors==0.3.1 178 | scikit-image==0.19.3 179 | scikit-learn==1.2.1 180 | scipy==1.10.1 181 | semantic-version==2.10.0 182 | send2trash==1.8.0 183 | sentencepiece==0.1.97 184 | sentry-sdk==1.15.0 185 | setproctitle==1.3.2 186 | setuptools==63.4.3 187 | sh==2.0.2 188 | shapely==2.0.1 189 | six==1.16.0 190 | smmap==5.0.0 191 | sniffio==1.3.0 192 | soupsieve==2.4 193 | stack-data==0.6.2 194 | starlette==0.27.0 195 | sympy==1.12 196 | tabulate==0.9.0 197 | tenacity==8.2.1 198 | tensorboard==2.10.1 199 | tensorboard-data-server==0.6.1 200 | tensorboard-plugin-wit==1.8.1 201 | termcolor==2.3.0 202 | terminado==0.17.1 203 | threadpoolctl==3.1.0 204 | tifffile==2023.2.3 205 | timm==0.9.2 206 | tinycss2==1.2.1 207 | tokenizers==0.13.2 208 | toolz==0.12.0 209 | tornado==6.2 210 | tqdm==4.64.1 211 | traitlets==5.9.0 212 | traittypes==0.2.1 213 | transformers==4.30.2 214 | trimesh==3.18.3 215 | triton==1.1.1 216 | typing-extensions==4.5.0 217 | typing-inspect==0.9.0 218 | uc-micro-py==1.0.2 219 | uri-template==1.2.0 220 | urllib3==1.26.14 221 | uvicorn==0.23.2 222 | wandb==0.13.10 223 | wcwidth==0.2.6 224 | webcolors==1.12 225 | webdataset==0.2.33 226 | webencodings==0.5.1 227 | websocket-client==1.5.1 228 | websockets==11.0.3 229 | werkzeug==2.2.3 230 | widgetsnbextension==4.0.5 231 | wrapt==1.15.0 232 | xatlas==0.0.7 233 | yacs==0.1.8 234 | yarl==1.8.2 235 | zipp==3.14.0 236 | # torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 237 | https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=db457a822d736013b6ffe509053001bc918bdd78fe68967b605f53984a9afac5 238 | https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=a9fc38040e133d1779f131b4497caef830e9e699faf89cd323cd58794ffb305b 239 | https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=5bc0e29cb78f7c452eeb4f27029c40049770d51553bf840b4ca2edd63da289ee 240 | # torch-cluster 241 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_cluster-1.6.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl 242 | # torch-scatter 243 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl 244 | torchmetrics==0.11.1 
245 | pytorch_lightning~=1.9.3 246 | git+https://github.com/pyvista/fast-simplification.git 247 | git+https://github.com/skoch9/meshplot.git 248 | git+https://github.com/NVlabs/nvdiffrast/ -------------------------------------------------------------------------------- /scripts/infer.sh: -------------------------------------------------------------------------------- 1 | python inference.py \ 2 | --task reconstruction \ 3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \ 4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \ 5 | --pointcloud_path ./example_data/surface/surface.npz 6 | 7 | python inference.py \ 8 | --task image2mesh \ 9 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \ 10 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \ 11 | --image_path ./example_data/image/car.jpg 12 | 13 | python inference.py \ 14 | --task text2mesh \ 15 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \ 16 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \ 17 | --text "A 3D model of motorcar; Porche Cayenne Turbo." -------------------------------------------------------------------------------- /scripts/inference/image2mesh.sh: -------------------------------------------------------------------------------- 1 | python inference.py \ 2 | --task image2mesh \ 3 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \ 4 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \ 5 | --image_path ./example_data/image/car.jpg -------------------------------------------------------------------------------- /scripts/inference/reconstruction.sh: -------------------------------------------------------------------------------- 1 | python inference.py \ 2 | --task reconstruction \ 3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \ 4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \ 5 | --pointcloud_path ./example_data/surface/surface.npz -------------------------------------------------------------------------------- /scripts/inference/text2mesh.sh: -------------------------------------------------------------------------------- 1 | python inference.py \ 2 | --task text2mesh \ 3 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \ 4 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \ 5 | --text "A 3D model of motorcar; Porche Cayenne Turbo." -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | from distutils.extension import Extension 4 | from Cython.Build import cythonize 5 | import numpy as np 6 | 7 | 8 | setup( 9 | name="michelangelo", 10 | version="0.4.1", 11 | author="Zibo Zhao, Wen Liu and Xin Chen", 12 | author_email="liuwen@shanghaitech.edu.cn", 13 | description="Michelangelo: a 3D Shape Generation System.", 14 | packages=find_packages(exclude=("configs", "tests", "scripts", "example_data")), 15 | python_requires=">=3.8", 16 | install_requires=[ 17 | "torch", 18 | "numpy", 19 | "cython", 20 | "tqdm", 21 | ], 22 | ) 23 | --------------------------------------------------------------------------------
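
As a quick orientation for the utilities and scripts above, the snippet below shows one way they are typically wired together. It is a minimal, illustrative sketch rather than code from the repository: it reuses the config and checkpoint paths from `scripts/infer.sh`, and it assumes the YAML config exposes a top-level `model` node with the `target`/`params` layout that `instantiate_from_config` expects, and that the checkpoint is a standard PyTorch Lightning checkpoint containing a `state_dict` key.

```python
# Illustrative sketch only -- paths follow scripts/infer.sh; the `model` config key
# and the Lightning-style `state_dict` entry in the checkpoint are assumptions.
import torch

from michelangelo.utils import get_config_from_file, instantiate_from_config

config = get_config_from_file("./configs/aligned_shape_latents/shapevae-256.yaml")
model = instantiate_from_config(config.model)  # builds the class named in `target`

ckpt = torch.load("./checkpoints/aligned_shape_latents/shapevae-256.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=False)
model.eval()
```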