├── LICENSE
├── README.md
├── configs
│   ├── aligned_shape_latents
│   │   └── shapevae-256.yaml
│   ├── image_cond_diffuser_asl
│   │   └── image-ASLDM-256.yaml
│   └── text_cond_diffuser_asl
│       └── text-ASLDM-256.yaml
├── example_data
│   ├── image
│   │   └── car.jpg
│   └── surface
│       └── surface.npz
├── inference.py
├── michelangelo
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── templates.json
│   │   ├── transforms.py
│   │   └── utils.py
│   ├── graphics
│   │   ├── __init__.py
│   │   └── primitives
│   │       ├── __init__.py
│   │       ├── mesh.py
│   │       └── volume.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── asl_diffusion
│   │   │   ├── __init__.py
│   │   │   ├── asl_diffuser_pl_module.py
│   │   │   ├── asl_udt.py
│   │   │   ├── base.py
│   │   │   ├── clip_asl_diffuser_pl_module.py
│   │   │   └── inference_utils.py
│   │   ├── conditional_encoders
│   │   │   ├── __init__.py
│   │   │   ├── clip.py
│   │   │   └── encoder_factory.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── checkpoint.py
│   │   │   ├── diffusion_transformer.py
│   │   │   ├── distributions.py
│   │   │   ├── embedder.py
│   │   │   ├── transformer_blocks.py
│   │   │   └── transformer_vit.py
│   │   └── tsal
│   │       ├── __init__.py
│   │       ├── asl_pl_module.py
│   │       ├── clip_asl_module.py
│   │       ├── inference_utils.py
│   │       ├── loss.py
│   │       ├── sal_perceiver.py
│   │       ├── sal_pl_module.py
│   │       └── tsal_base.py
│   └── utils
│       ├── __init__.py
│       ├── eval.py
│       ├── io.py
│       ├── misc.py
│       └── visualizers
│           ├── __init__.py
│           ├── color_util.py
│           ├── html_util.py
│           └── pythreejs_viewer.py
├── requirements.txt
├── scripts
│   ├── infer.sh
│   └── inference
│       ├── image2mesh.sh
│       ├── reconstruction.sh
│       └── text2mesh.sh
└── setup.py
/README.md:
--------------------------------------------------------------------------------
1 | # Michelangelo
2 |
3 | ## [Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation](https://neuralcarver.github.io/michelangelo)
4 | [Zibo Zhao](https://github.com/Maikouuu),
5 | [Wen Liu](https://github.com/StevenLiuWen),
6 | [Xin Chen](https://chenxin.tech/),
7 | [Xianfang Zeng](https://github.com/Zzlongjuanfeng),
8 | [Rui Wang](https://wrong.wang/),
9 | [Pei Cheng](https://neuralcarver.github.io/michelangelo),
10 | [Bin Fu](https://neuralcarver.github.io/michelangelo),
11 | [Tao Chen](https://eetchen.github.io),
12 | [Gang Yu](https://www.skicyyu.org),
13 | [Shenghua Gao](https://sist.shanghaitech.edu.cn/sist_en/2020/0814/c7582a54772/page.htm)
14 | ### [Hugging Face Demo](https://huggingface.co/spaces/Maikou/Michelangelo) | [Project Page](https://neuralcarver.github.io/michelangelo/) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Paper](https://openreview.net/pdf?id=xmxgMij3LY)
15 |
16 | https://github.com/NeuralCarver/Michelangelo/assets/37449470/123bae2c-fbb1-4d63-bd13-0e300a550868
17 |
18 | Visualization of 3D shapes produced by our framework, shown as triplets: the conditional input on the left, a normal map in the middle, and the generated triangle mesh on the right. The generated 3D shapes semantically conform to the visual or textual conditional inputs.
19 |
20 | ## 🔆 Features
21 | **Michelangelo** possesses three capabilities:
22 |
23 | 1. Encoding a shape into the shape-image-text aligned space;
24 | 2. Image-conditioned shape generation;
25 | 3. Text-conditioned shape generation.
26 |
27 |
28 | Techniques
29 |
30 | We present a novel _alignment-before-generation_ approach to tackle the challenging task of generating general 3D shapes based on 2D images or texts. Directly learning a conditional generative model from images or texts to 3D shapes is prone to producing results that are inconsistent with the conditions, because 3D shapes have an additional dimension whose distribution significantly differs from that of 2D images and texts. To bridge the domain gap among the three modalities and facilitate multi-modal-conditioned 3D shape generation, we explore representing 3D shapes in a shape-image-text-aligned space. Our framework comprises two models: a Shape-Image-Text-Aligned Variational Auto-Encoder (SITA-VAE) and a conditional Aligned Shape Latent Diffusion Model (ASLDM). The former encodes 3D shapes into a shape latent space aligned with image and text embeddings and reconstructs the fine-grained 3D neural fields corresponding to given shape embeddings via a transformer-based decoder. The latter learns a probabilistic mapping function from the image or text space to the latent shape space. Our extensive experiments demonstrate that our proposed approach can generate higher-quality and more diverse 3D shapes that better semantically conform to the visual or textual conditional inputs, validating the effectiveness of the shape-image-text-aligned space for cross-modality 3D shape generation.
31 |
32 | 
33 |
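To make the two-stage design concrete, the sketch below walks through the first stage the way `inference.py` exercises it: the SITA-VAE encodes a sampled surface (points + normals) into KL-regularized shape latents and decodes them back into an implicit field that is meshed by `extract_geometry`. The second-stage ASLDM then only has to denoise in this latent space under a CLIP image or text condition and reuse the same decoder. This is a minimal sketch, assuming a loaded `AlignedShapeAsLatentPLModule` named `vae` and a `[1, N, 6]` surface tensor (xyz + normals); it is not a standalone script.

```python
from functools import partial
from michelangelo.models.tsal.inference_utils import extract_geometry

# Stage 1 (SITA-VAE): surface -> aligned shape embedding + KL latents -> implicit field.
shape_embed, shape_latents = vae.model.encode_shape_embed(surface, return_latents=True)
shape_zq, posterior = vae.model.shape_model.encode_kl_embed(shape_latents)  # 256 x 64 latent set
latents = vae.model.shape_model.decode(shape_zq)
geometric_func = partial(vae.model.shape_model.query_geometry, latents=latents)

# Query the implicit field on a dense grid and run marching cubes to recover a mesh.
mesh_v_f, has_surface = extract_geometry(
    geometric_func=geometric_func,
    device=surface.device,
    batch_size=surface.shape[0],
    bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
    octree_depth=7,
    num_chunks=10000,
)
```

In the conditional tasks, the ASLDM's `sample` call produces `shape_zq`-like latents from noise with classifier-free guidance and feeds them through the same decoder, so image- and text-conditioned generation differ only in the conditional encoder.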
34 |
35 | ## 📰 News
36 | - [2024/1/23] Set up the Hugging Face Demo and release the code
37 | - [2023/09/22] **Michelangelo got accepted by NeurIPS 2023!**
38 | - [2023/6/29] Upload paper and init project
39 |
40 | ## ⚙️ Setup
41 |
42 | ### Installation
43 | Follow the commands below to set up the environment. We have tested the installation on Tesla V100 and Tesla T4 GPUs.
44 | ```
45 | git clone https://github.com/NeuralCarver/Michelangelo.git
46 | cd Michelangelo
47 | conda create --name Michelangelo python=3.9
48 | conda activate Michelangelo
49 | pip install -r requirements.txt
50 | ```
51 |
52 | ### Checkpoints
53 | Please download the weights from the Hugging Face model hub and place them in the root folder of the repository. We have also uploaded the CLIP weights to facilitate quick usage.
54 |
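One possible way to fetch them, assuming the released weights follow the `./checkpoints` layout referenced by the configs (the repository id below is a placeholder; use the one linked from the project page or demo):

```
git lfs install
git clone https://huggingface.co/<org>/<michelangelo-weights-repo> ./checkpoints
```

Afterwards, `./checkpoints/clip/clip-vit-large-patch14` should match the `clip_model_version` and `version` paths used in the configs.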
55 |
56 |
57 | Tips for debugging the configuration
58 |
59 |
60 | - If something goes wrong during environment configuration, consider skipping the problematic packages, such as pysdf, torch-cluster, and torch-scatter; they do not affect the execution of the commands we provide.
61 | - If you encounter any issues while downloading CLIP, you can consider downloading it from [CLIP's Hugging Face page](https://huggingface.co/openai/clip-vit-large-patch14). Once the download is complete, remember to modify line [26](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L26C1-L26C76) and line [34](https://github.com/NeuralCarver/Michelangelo/blob/b53fa004cd4aeb0f4eb4d159ecec8489a4450dab/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml#L34) in the config file to provide the correct path to CLIP.
62 | - From [issue 6](https://github.com/NeuralCarver/Michelangelo/issues/6#issuecomment-1913513382): Windows users running WSL2 + Ubuntu 22.04 may hit library-loading issues. As discussed in [this WSL issue](https://github.com/microsoft/WSL/issues/8587), the fix is simply to add the following to your .bashrc:
63 | ```
64 | export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH
65 | ```
66 |
67 |
68 | ## ⚡ Quick Start
69 |
70 | ### Inference
71 |
72 | #### Reconstructing a 3D shape
73 | ```
74 | ./scripts/inference/reconstruction.sh
75 | ```
76 |
77 | #### Image-conditioned shape generation
78 | ```
79 | ./scripts/inference/image2mesh.sh
80 | ```
81 |
82 | #### Text-conditioned shape generation
83 | ```
84 | ./scripts/inference/text2mesh.sh
85 | ```
86 |
87 | #### Simply run all the scripts
88 | ```
89 | ./scripts/infer.sh
90 | ```
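
The shell scripts above presumably wrap `inference.py`; equivalent direct invocations (checkpoint paths are placeholders, argument names follow the parser in `inference.py`) look like:

```
python inference.py --task reconstruction \
    --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
    --ckpt_path <path/to/shapevae-256.ckpt> \
    --pointcloud_path ./example_data/surface/surface.npz

python inference.py --task image2mesh \
    --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \
    --ckpt_path <path/to/image-ASLDM-256.ckpt> \
    --image_path ./example_data/image/car.jpg

python inference.py --task text2mesh \
    --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \
    --ckpt_path <path/to/text-ASLDM-256.ckpt> \
    --text "A 3D model of motorcar; Porsche 911."
```

Results are written to `./output/<task>/` by default.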
91 |
92 |
93 | ## ❓ FAQ
94 |
95 | ## Citation
96 |
97 | If you find our code or paper helpful, please consider citing:
98 |
99 | ```bibtex
100 | @inproceedings{
101 | zhao2023michelangelo,
102 | title={Michelangelo: Conditional 3D Shape Generation based on Shape-Image-Text Aligned Latent Representation},
103 | author={Zibo Zhao and Wen Liu and Xin Chen and Xianfang Zeng and Rui Wang and Pei Cheng and BIN FU and Tao Chen and Gang YU and Shenghua Gao},
104 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
105 | year={2023},
106 | url={https://openreview.net/forum?id=xmxgMij3LY}
107 | }
108 | ```
109 |
110 | ## License
111 |
112 | This code is distributed under a [GPL-3.0 license](LICENSE).
113 |
114 |
--------------------------------------------------------------------------------
/configs/aligned_shape_latents/shapevae-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
3 | params:
4 | shape_module_cfg:
5 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
6 | params:
7 | num_latents: 256
8 | embed_dim: 64
9 | point_feats: 3 # normal
10 | num_freqs: 8
11 | include_pi: false
12 | heads: 12
13 | width: 768
14 | num_encoder_layers: 8
15 | num_decoder_layers: 16
16 | use_ln_post: true
17 | init_scale: 0.25
18 | qkv_bias: false
19 | use_checkpoint: true
20 | aligned_module_cfg:
21 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
22 | params:
23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
24 |
25 | loss_cfg:
26 | target: michelangelo.models.tsal.loss.ContrastKLNearFar
27 | params:
28 | contrast_weight: 0.1
29 | near_weight: 0.1
30 | kl_weight: 0.001
31 |
32 | optimizer_cfg:
33 | optimizer:
34 | target: torch.optim.AdamW
35 | params:
36 | betas: [0.9, 0.99]
37 | eps: 1.e-6
38 | weight_decay: 1.e-2
39 |
40 | scheduler:
41 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
42 | params:
43 | warm_up_steps: 5000
44 | f_start: 1.e-6
45 | f_min: 1.e-3
46 | f_max: 1.0
47 |
--------------------------------------------------------------------------------
/configs/image_cond_diffuser_asl/image-ASLDM-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
3 | params:
4 | first_stage_config:
5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
6 | params:
7 | shape_module_cfg:
8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
9 | params:
10 | num_latents: &num_latents 256
11 | embed_dim: &embed_dim 64
12 | point_feats: 3 # normal
13 | num_freqs: 8
14 | include_pi: false
15 | heads: 12
16 | width: 768
17 | num_encoder_layers: 8
18 | num_decoder_layers: 16
19 | use_ln_post: true
20 | init_scale: 0.25
21 | qkv_bias: false
22 | use_checkpoint: false
23 | aligned_module_cfg:
24 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
25 | params:
26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
27 |
28 | loss_cfg:
29 | target: torch.nn.Identity
30 |
31 | cond_stage_config:
32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder
33 | params:
34 | version: "./checkpoints/clip/clip-vit-large-patch14"
35 | zero_embedding_radio: 0.1
36 |
37 | first_stage_key: "surface"
38 | cond_stage_key: "image"
39 | scale_by_std: false
40 |
41 | denoiser_cfg:
42 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
43 | params:
44 | input_channels: *embed_dim
45 | output_channels: *embed_dim
46 | n_ctx: *num_latents
47 | width: 768
48 | layers: 6 # 2 * 6 + 1 = 13
49 | heads: 12
50 | context_dim: 1024
51 | init_scale: 1.0
52 | skip_ln: true
53 | use_checkpoint: true
54 |
55 | scheduler_cfg:
56 | guidance_scale: 7.5
57 | num_inference_steps: 50
58 | eta: 0.0
59 |
60 | noise:
61 | target: diffusers.schedulers.DDPMScheduler
62 | params:
63 | num_train_timesteps: 1000
64 | beta_start: 0.00085
65 | beta_end: 0.012
66 | beta_schedule: "scaled_linear"
67 | variance_type: "fixed_small"
68 | clip_sample: false
69 | denoise:
70 | target: diffusers.schedulers.DDIMScheduler
71 | params:
72 | num_train_timesteps: 1000
73 | beta_start: 0.00085
74 | beta_end: 0.012
75 | beta_schedule: "scaled_linear"
76 | clip_sample: false # clip sample to -1~1
77 | set_alpha_to_one: false
78 | steps_offset: 1
79 |
80 | optimizer_cfg:
81 | optimizer:
82 | target: torch.optim.AdamW
83 | params:
84 | betas: [0.9, 0.99]
85 | eps: 1.e-6
86 | weight_decay: 1.e-2
87 |
88 | scheduler:
89 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
90 | params:
91 | warm_up_steps: 5000
92 | f_start: 1.e-6
93 | f_min: 1.e-3
94 | f_max: 1.0
95 |
96 | loss_cfg:
97 | loss_type: "mse"
98 |
--------------------------------------------------------------------------------
/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
3 | params:
4 | first_stage_config:
5 | target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
6 | params:
7 | shape_module_cfg:
8 | target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
9 | params:
10 | num_latents: &num_latents 256
11 | embed_dim: &embed_dim 64
12 | point_feats: 3 # normal
13 | num_freqs: 8
14 | include_pi: false
15 | heads: 12
16 | width: 768
17 | num_encoder_layers: 8
18 | num_decoder_layers: 16
19 | use_ln_post: true
20 | init_scale: 0.25
21 | qkv_bias: false
22 | use_checkpoint: true
23 | aligned_module_cfg:
24 | target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
25 | params:
26 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
27 |
28 | loss_cfg:
29 | target: torch.nn.Identity
30 |
31 | cond_stage_config:
32 | target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder
33 | params:
34 | version: "./checkpoints/clip/clip-vit-large-patch14"
35 | zero_embedding_radio: 0.1
36 | max_length: 77
37 |
38 | first_stage_key: "surface"
39 | cond_stage_key: "text"
40 | scale_by_std: false
41 |
42 | denoiser_cfg:
43 | target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
44 | params:
45 | input_channels: *embed_dim
46 | output_channels: *embed_dim
47 | n_ctx: *num_latents
48 | width: 768
49 | layers: 8 # 2 * 8 + 1 = 17
50 | heads: 12
51 | context_dim: 768
52 | init_scale: 1.0
53 | skip_ln: true
54 | use_checkpoint: true
55 |
56 | scheduler_cfg:
57 | guidance_scale: 7.5
58 | num_inference_steps: 50
59 | eta: 0.0
60 |
61 | noise:
62 | target: diffusers.schedulers.DDPMScheduler
63 | params:
64 | num_train_timesteps: 1000
65 | beta_start: 0.00085
66 | beta_end: 0.012
67 | beta_schedule: "scaled_linear"
68 | variance_type: "fixed_small"
69 | clip_sample: false
70 | denoise:
71 | target: diffusers.schedulers.DDIMScheduler
72 | params:
73 | num_train_timesteps: 1000
74 | beta_start: 0.00085
75 | beta_end: 0.012
76 | beta_schedule: "scaled_linear"
77 | clip_sample: false # clip sample to -1~1
78 | set_alpha_to_one: false
79 | steps_offset: 1
80 |
81 | optimizer_cfg:
82 | optimizer:
83 | target: torch.optim.AdamW
84 | params:
85 | betas: [0.9, 0.99]
86 | eps: 1.e-6
87 | weight_decay: 1.e-2
88 |
89 | scheduler:
90 | target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
91 | params:
92 | warm_up_steps: 5000
93 | f_start: 1.e-6
94 | f_min: 1.e-3
95 | f_max: 1.0
96 |
97 | loss_cfg:
98 | loss_type: "mse"
--------------------------------------------------------------------------------
/example_data/image/car.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralCarver/Michelangelo/6d83b0bacef92715dd5179d45647ed9a3d39bc95/example_data/image/car.jpg
--------------------------------------------------------------------------------
/example_data/surface/surface.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuralCarver/Michelangelo/6d83b0bacef92715dd5179d45647ed9a3d39bc95/example_data/surface/surface.npz
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | from collections import OrderedDict
5 | from typing import Optional, List
6 | import argparse
7 | from functools import partial
8 |
9 | from einops import repeat, rearrange
10 | import numpy as np
11 | from PIL import Image
12 | import trimesh
13 | import cv2
14 |
15 | import torch
16 | import pytorch_lightning as pl
17 |
18 | from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
19 | from michelangelo.models.tsal.inference_utils import extract_geometry
20 | from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
21 | from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
22 | from michelangelo.utils.visualizers import html_util
23 |
24 | def load_model(args):
25 |
26 | model_config = get_config_from_file(args.config_path)
27 | if hasattr(model_config, "model"):
28 | model_config = model_config.model
29 |
30 | model = instantiate_from_config(model_config, ckpt_path=args.ckpt_path)
31 | model = model.cuda()
32 | model = model.eval()
33 |
34 | return model
35 |
36 | def load_surface(fp):
37 |
38 | with np.load(fp) as input_pc:
39 | surface = input_pc['points']
40 | normal = input_pc['normals']
41 |
42 | rng = np.random.default_rng()
43 | ind = rng.choice(surface.shape[0], 4096, replace=False)
44 | surface = torch.FloatTensor(surface[ind])
45 | normal = torch.FloatTensor(normal[ind])
46 |
47 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
48 |
49 | return surface
50 |
51 | def prepare_image(args, number_samples=2):
52 |
53 | image = cv2.imread(f"{args.image_path}")
54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
55 |
56 | image_pt = torch.tensor(image).float()
57 | image_pt = image_pt / 255 * 2 - 1
58 | image_pt = rearrange(image_pt, "h w c -> c h w")
59 |
60 | image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples)
61 |
62 | return image_pt
63 |
64 | def save_output(args, mesh_outputs):
65 |
66 | os.makedirs(args.output_dir, exist_ok=True)
67 | for i, mesh in enumerate(mesh_outputs):
68 | mesh.mesh_f = mesh.mesh_f[:, ::-1]
69 | mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
70 |
71 | name = str(i) + "_out_mesh.obj"
72 | mesh_output.export(os.path.join(args.output_dir, name), include_normals=True)
73 |
74 | print(f'-----------------------------------------------------------------------------')
75 | print(f'>>> Finished and mesh saved in {args.output_dir}')
76 | print(f'-----------------------------------------------------------------------------')
77 |
78 | return 0
79 |
80 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
81 |
82 | surface = load_surface(args.pointcloud_path)
83 |
84 | # encoding
85 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
86 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
87 |
88 | # decoding
89 | latents = model.model.shape_model.decode(shape_zq)
90 | geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
91 |
92 | # reconstruction
93 | mesh_v_f, has_surface = extract_geometry(
94 | geometric_func=geometric_func,
95 | device=surface.device,
96 | batch_size=surface.shape[0],
97 | bounds=bounds,
98 | octree_depth=octree_depth,
99 | num_chunks=num_chunks,
100 | )
101 | recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])
102 |
103 | # save
104 | os.makedirs(args.output_dir, exist_ok=True)
105 | recon_mesh.export(os.path.join(args.output_dir, 'reconstruction.obj'))
106 |
107 | print(f'-----------------------------------------------------------------------------')
108 | print(f'>>> Finished and mesh saved in {os.path.join(args.output_dir, "reconstruction.obj")}')
109 | print(f'-----------------------------------------------------------------------------')
110 |
111 | return 0
112 |
113 | def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7):
114 |
115 | sample_inputs = {
116 | "image": prepare_image(args)
117 | }
118 |
119 | mesh_outputs = model.sample(
120 | sample_inputs,
121 | sample_times=1,
122 | guidance_scale=guidance_scale,
123 | return_intermediates=False,
124 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
125 | octree_depth=octree_depth,
126 | )[0]
127 |
128 | save_output(args, mesh_outputs)
129 |
130 | return 0
131 |
132 | def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
133 |
134 | sample_inputs = {
135 | "text": [args.text] * num_samples
136 | }
137 | mesh_outputs = model.sample(
138 | sample_inputs,
139 | sample_times=1,
140 | guidance_scale=guidance_scale,
141 | return_intermediates=False,
142 | bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
143 | octree_depth=octree_depth,
144 | )[0]
145 |
146 | save_output(args, mesh_outputs)
147 |
148 | return 0
149 |
150 | task_dict = {
151 | 'reconstruction': reconstruction,
152 | 'image2mesh': image2mesh,
153 | 'text2mesh': text2mesh,
154 | }
155 |
156 | if __name__ == "__main__":
157 | '''
158 | 1. Reconstruct point cloud
159 | 2. Image-conditioned generation
160 | 3. Text-conditioned generation
161 | '''
162 | parser = argparse.ArgumentParser()
163 | parser.add_argument("--task", type=str, choices=['reconstruction', 'image2mesh', 'text2mesh'], required=True)
164 | parser.add_argument("--config_path", type=str, required=True)
165 | parser.add_argument("--ckpt_path", type=str, required=True)
166 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface/surface.npz', help='Path to the input point cloud')
167 | parser.add_argument("--image_path", type=str, help='Path to the input image')
168 | parser.add_argument("--text", type=str, help='Input text prompt, e.g. "A 3D model of motorcar; Porsche 911."')
169 | parser.add_argument("--output_dir", type=str, default='./output')
170 | parser.add_argument("-s", "--seed", type=int, default=0)
171 | args = parser.parse_args()
172 |
173 | pl.seed_everything(args.seed)
174 |
175 | print(f'-----------------------------------------------------------------------------')
176 | print(f'>>> Running {args.task}')
177 | args.output_dir = os.path.join(args.output_dir, args.task)
178 | print(f'>>> Output directory: {args.output_dir}')
179 | print(f'-----------------------------------------------------------------------------')
180 |
181 | task_dict[args.task](args, load_model(args))
--------------------------------------------------------------------------------
/michelangelo/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/michelangelo/data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/michelangelo/data/templates.json:
--------------------------------------------------------------------------------
1 | {
2 | "shape": [
3 | "a point cloud model of {}.",
4 | "There is a {} in the scene.",
5 | "There is the {} in the scene.",
6 | "a photo of a {} in the scene.",
7 | "a photo of the {} in the scene.",
8 | "a photo of one {} in the scene.",
9 | "itap of a {}.",
10 | "itap of my {}.",
11 | "itap of the {}.",
12 | "a photo of a {}.",
13 | "a photo of my {}.",
14 | "a photo of the {}.",
15 | "a photo of one {}.",
16 | "a photo of many {}.",
17 | "a good photo of a {}.",
18 | "a good photo of the {}.",
19 | "a bad photo of a {}.",
20 | "a bad photo of the {}.",
21 | "a photo of a nice {}.",
22 | "a photo of the nice {}.",
23 | "a photo of a cool {}.",
24 | "a photo of the cool {}.",
25 | "a photo of a weird {}.",
26 | "a photo of the weird {}.",
27 | "a photo of a small {}.",
28 | "a photo of the small {}.",
29 | "a photo of a large {}.",
30 | "a photo of the large {}.",
31 | "a photo of a clean {}.",
32 | "a photo of the clean {}.",
33 | "a photo of a dirty {}.",
34 | "a photo of the dirty {}.",
35 | "a bright photo of a {}.",
36 | "a bright photo of the {}.",
37 | "a dark photo of a {}.",
38 | "a dark photo of the {}.",
39 | "a photo of a hard to see {}.",
40 | "a photo of the hard to see {}.",
41 | "a low resolution photo of a {}.",
42 | "a low resolution photo of the {}.",
43 | "a cropped photo of a {}.",
44 | "a cropped photo of the {}.",
45 | "a close-up photo of a {}.",
46 | "a close-up photo of the {}.",
47 | "a jpeg corrupted photo of a {}.",
48 | "a jpeg corrupted photo of the {}.",
49 | "a blurry photo of a {}.",
50 | "a blurry photo of the {}.",
51 | "a pixelated photo of a {}.",
52 | "a pixelated photo of the {}.",
53 | "a black and white photo of the {}.",
54 | "a black and white photo of a {}",
55 | "a plastic {}.",
56 | "the plastic {}.",
57 | "a toy {}.",
58 | "the toy {}.",
59 | "a plushie {}.",
60 | "the plushie {}.",
61 | "a cartoon {}.",
62 | "the cartoon {}.",
63 | "an embroidered {}.",
64 | "the embroidered {}.",
65 | "a painting of the {}.",
66 | "a painting of a {}."
67 | ]
68 |
69 | }
--------------------------------------------------------------------------------
/michelangelo/data/transforms.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | import numpy as np
5 | import warnings
6 | import random
7 | from omegaconf.listconfig import ListConfig
8 | from webdataset import pipelinefilter
9 | import torch
10 | import torchvision.transforms.functional as TVF
11 | from torchvision.transforms import InterpolationMode
12 | from torchvision.transforms.transforms import _interpolation_modes_from_int
13 | from typing import Sequence
14 |
15 | from michelangelo.utils import instantiate_from_config
16 |
17 |
18 | def _uid_buffer_pick(buf_dict, rng):
19 | uid_keys = list(buf_dict.keys())
20 | selected_uid = rng.choice(uid_keys)
21 | buf = buf_dict[selected_uid]
22 |
23 | k = rng.randint(0, len(buf) - 1)
24 | sample = buf[k]
25 | buf[k] = buf[-1]
26 | buf.pop()
27 |
28 | if len(buf) == 0:
29 | del buf_dict[selected_uid]
30 |
31 | return sample
32 |
33 |
34 | def _add_to_buf_dict(buf_dict, sample):
35 | key = sample["__key__"]
36 | uid, uid_sample_id = key.split("_")
37 | if uid not in buf_dict:
38 | buf_dict[uid] = []
39 | buf_dict[uid].append(sample)
40 |
41 | return buf_dict
42 |
43 |
44 | def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
45 | """Shuffle the data in the stream.
46 |
47 | This uses a buffer of size `bufsize`. Shuffling at
48 | startup is less random; this is traded off against
49 | yielding samples quickly.
50 |
51 | data: iterator
52 | bufsize: buffer size for shuffling
53 | returns: iterator
54 | rng: either random module or random.Random instance
55 |
56 | """
57 | if rng is None:
58 | rng = random.Random(int((os.getpid() + time.time()) * 1e9))
59 | initial = min(initial, bufsize)
60 | buf_dict = dict()
61 | current_samples = 0
62 | for sample in data:
63 | _add_to_buf_dict(buf_dict, sample)
64 | current_samples += 1
65 |
66 | if current_samples < bufsize:
67 | try:
68 | _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708
69 | current_samples += 1
70 | except StopIteration:
71 | pass
72 |
73 | if current_samples >= initial:
74 | current_samples -= 1
75 | yield _uid_buffer_pick(buf_dict, rng)
76 |
77 | while current_samples > 0:
78 | current_samples -= 1
79 | yield _uid_buffer_pick(buf_dict, rng)
80 |
81 |
82 | uid_shuffle = pipelinefilter(_uid_shuffle)
83 |
84 |
85 | class RandomSample(object):
86 | def __init__(self,
87 | num_volume_samples: int = 1024,
88 | num_near_samples: int = 1024):
89 |
90 | super().__init__()
91 |
92 | self.num_volume_samples = num_volume_samples
93 | self.num_near_samples = num_near_samples
94 |
95 | def __call__(self, sample):
96 | rng = np.random.default_rng()
97 |
98 | # 1. sample surface input
99 | total_surface = sample["surface"]
100 | ind = rng.choice(total_surface.shape[0], replace=False)
101 | surface = total_surface[ind]
102 |
103 | # 2. sample volume/near geometric points
104 | vol_points = sample["vol_points"]
105 | vol_label = sample["vol_label"]
106 | near_points = sample["near_points"]
107 | near_label = sample["near_label"]
108 |
109 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
110 | vol_points = vol_points[ind]
111 | vol_label = vol_label[ind]
112 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
113 |
114 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
115 | near_points = near_points[ind]
116 | near_label = near_label[ind]
117 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
118 |
119 | # concat sampled volume and near points
120 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
121 |
122 | sample = {
123 | "surface": surface,
124 | "geo_points": geo_points
125 | }
126 |
127 | return sample
128 |
129 |
130 | class SplitRandomSample(object):
131 | def __init__(self,
132 | use_surface_sample: bool = False,
133 | num_surface_samples: int = 4096,
134 | num_volume_samples: int = 1024,
135 | num_near_samples: int = 1024):
136 |
137 | super().__init__()
138 |
139 | self.use_surface_sample = use_surface_sample
140 | self.num_surface_samples = num_surface_samples
141 | self.num_volume_samples = num_volume_samples
142 | self.num_near_samples = num_near_samples
143 |
144 | def __call__(self, sample):
145 |
146 | rng = np.random.default_rng()
147 |
148 | # 1. sample surface input
149 | surface = sample["surface"]
150 |
151 | if self.use_surface_sample:
152 | replace = surface.shape[0] < self.num_surface_samples
153 | ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
154 | surface = surface[ind]
155 |
156 | # 2. sample volume/near geometric points
157 | vol_points = sample["vol_points"]
158 | vol_label = sample["vol_label"]
159 | near_points = sample["near_points"]
160 | near_label = sample["near_label"]
161 |
162 | ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
163 | vol_points = vol_points[ind]
164 | vol_label = vol_label[ind]
165 | vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
166 |
167 | ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
168 | near_points = near_points[ind]
169 | near_label = near_label[ind]
170 | near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
171 |
172 | # concat sampled volume and near points
173 | geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
174 |
175 | sample = {
176 | "surface": surface,
177 | "geo_points": geo_points
178 | }
179 |
180 | return sample
181 |
182 |
183 | class FeatureSelection(object):
184 |
185 | VALID_SURFACE_FEATURE_DIMS = {
186 | "none": [0, 1, 2], # xyz
187 | "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal
188 | "normal": [0, 1, 2, 6, 7, 8]
189 | }
190 |
191 | def __init__(self, surface_feature_type: str):
192 |
193 | self.surface_feature_type = surface_feature_type
194 | self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
195 |
196 | def __call__(self, sample):
197 | sample["surface"] = sample["surface"][:, self.surface_dims]
198 | return sample
199 |
200 |
201 | class AxisScaleTransform(object):
202 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
203 | assert isinstance(interval, (tuple, list, ListConfig))
204 | self.interval = interval
205 | self.min_val = interval[0]
206 | self.max_val = interval[1]
207 | self.inter_size = interval[1] - interval[0]
208 | self.jitter = jitter
209 | self.jitter_scale = jitter_scale
210 |
211 | def __call__(self, sample):
212 |
213 | surface = sample["surface"][..., 0:3]
214 | geo_points = sample["geo_points"][..., 0:3]
215 |
216 | scaling = torch.rand(1, 3) * self.inter_size + self.min_val
217 | # print(scaling)
218 | surface = surface * scaling
219 | geo_points = geo_points * scaling
220 |
221 | scale = (1 / torch.abs(surface).max().item()) * 0.999999
222 | surface *= scale
223 | geo_points *= scale
224 |
225 | if self.jitter:
226 | surface += self.jitter_scale * torch.randn_like(surface)
227 | surface.clamp_(min=-1.015, max=1.015)
228 |
229 | sample["surface"][..., 0:3] = surface
230 | sample["geo_points"][..., 0:3] = geo_points
231 |
232 | return sample
233 |
234 |
235 | class ToTensor(object):
236 |
237 | def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
238 | self.tensor_keys = tensor_keys
239 |
240 | def __call__(self, sample):
241 | for key in self.tensor_keys:
242 | if key not in sample:
243 | continue
244 |
245 | sample[key] = torch.tensor(sample[key], dtype=torch.float32)
246 |
247 | return sample
248 |
249 |
250 | class AxisScale(object):
251 | def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
252 | assert isinstance(interval, (tuple, list, ListConfig))
253 | self.interval = interval
254 | self.jitter = jitter
255 | self.jitter_scale = jitter_scale
256 |
257 | def __call__(self, surface, *args):
258 | scaling = torch.rand(1, 3) * 0.5 + 0.75
259 | # print(scaling)
260 | surface = surface * scaling
261 | scale = (1 / torch.abs(surface).max().item()) * 0.999999
262 | surface *= scale
263 |
264 | args_outputs = []
265 | for _arg in args:
266 | _arg = _arg * scaling * scale
267 | args_outputs.append(_arg)
268 |
269 | if self.jitter:
270 | surface += self.jitter_scale * torch.randn_like(surface)
271 | surface.clamp_(min=-1, max=1)
272 |
273 | if len(args) == 0:
274 | return surface
275 | else:
276 | return surface, *args_outputs
277 |
278 |
279 | class RandomResize(torch.nn.Module):
280 | """Resize to a random fraction of the target size, then back to the target size (degrades resolution)."""
281 |
282 | def __init__(
283 | self,
284 | size,
285 | resize_radio=(0.5, 1),
286 | allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
287 | interpolation=InterpolationMode.BICUBIC,
288 | max_size=None,
289 | antialias=None,
290 | ):
291 | super().__init__()
292 | if not isinstance(size, (int, Sequence)):
293 | raise TypeError(f"Size should be int or sequence. Got {type(size)}")
294 | if isinstance(size, Sequence) and len(size) not in (1, 2):
295 | raise ValueError("If size is a sequence, it should have 1 or 2 values")
296 |
297 | self.size = size
298 | self.max_size = max_size
299 | # Backward compatibility with integer value
300 | if isinstance(interpolation, int):
301 | warnings.warn(
302 | "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
303 | "Please use InterpolationMode enum."
304 | )
305 | interpolation = _interpolation_modes_from_int(interpolation)
306 |
307 | self.interpolation = interpolation
308 | self.antialias = antialias
309 |
310 | self.resize_radio = resize_radio
311 | self.allow_resize_interpolations = allow_resize_interpolations
312 |
313 | def random_resize_params(self):
314 | radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]
315 |
316 | if isinstance(self.size, int):
317 | size = int(self.size * radio)
318 | elif isinstance(self.size, Sequence):
319 | size = list(self.size)
320 | size = (int(size[0] * radio), int(size[1] * radio))
321 | else:
322 | raise RuntimeError()
323 |
324 | interpolation = self.allow_resize_interpolations[
325 | torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
326 | ]
327 | return size, interpolation
328 |
329 | def forward(self, img):
330 | size, interpolation = self.random_resize_params()
331 | img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
332 | img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
333 | return img
334 |
335 | def __repr__(self) -> str:
336 | detail = f"(size={self.size}, interpolation={self.interpolation.value},"
337 | detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}"
338 | return f"{self.__class__.__name__}{detail}"
339 |
340 |
341 | class Compose(object):
342 | """Composes several transforms together. This transform does not support torchscript.
343 | Please, see the note below.
344 |
345 | Args:
346 | transforms (list of ``Transform`` objects): list of transforms to compose.
347 |
348 | Example:
349 | >>> transforms.Compose([
350 | >>> transforms.CenterCrop(10),
351 | >>> transforms.ToTensor(),
352 | >>> ])
353 |
354 | .. note::
355 | In order to script the transformations, please use ``torch.nn.Sequential`` as below.
356 |
357 | >>> transforms = torch.nn.Sequential(
358 | >>> transforms.CenterCrop(10),
359 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
360 | >>> )
361 | >>> scripted_transforms = torch.jit.script(transforms)
362 |
363 | Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
364 | `lambda` functions or ``PIL.Image``.
365 |
366 | """
367 |
368 | def __init__(self, transforms):
369 | self.transforms = transforms
370 |
371 | def __call__(self, *args):
372 | for t in self.transforms:
373 | args = t(*args)
374 | return args
375 |
376 | def __repr__(self):
377 | format_string = self.__class__.__name__ + '('
378 | for t in self.transforms:
379 | format_string += '\n'
380 | format_string += ' {0}'.format(t)
381 | format_string += '\n)'
382 | return format_string
383 |
384 |
385 | def identity(*args, **kwargs):
386 | if len(args) == 1:
387 | return args[0]
388 | else:
389 | return args
390 |
391 |
392 | def build_transforms(cfg):
393 |
394 | if cfg is None:
395 | return identity
396 |
397 | transforms = []
398 |
399 | for transform_name, cfg_instance in cfg.items():
400 | transform_instance = instantiate_from_config(cfg_instance)
401 | transforms.append(transform_instance)
402 | print(f"Build transform: {transform_instance}")
403 |
404 | transforms = Compose(transforms)
405 |
406 | return transforms
407 |
408 |
--------------------------------------------------------------------------------
/michelangelo/data/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def worker_init_fn(_):
8 | worker_info = torch.utils.data.get_worker_info()
9 | worker_id = worker_info.id
10 |
11 | # dataset = worker_info.dataset
12 | # split_size = dataset.num_records // worker_info.num_workers
13 | # # reset num_records to the true number to retain reliable length information
14 | # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
15 | # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
16 | # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
17 |
18 | return np.random.seed(np.random.get_state()[1][0] + worker_id)
19 |
20 |
21 | def collation_fn(samples, combine_tensors=True, combine_scalars=True):
22 | """
23 |
24 | Args:
25 | samples (list[dict]):
26 | combine_tensors:
27 | combine_scalars:
28 |
29 | Returns:
30 |
31 | """
32 |
33 | result = {}
34 |
35 | keys = samples[0].keys()
36 |
37 | for key in keys:
38 | result[key] = []
39 |
40 | for sample in samples:
41 | for key in keys:
42 | val = sample[key]
43 | result[key].append(val)
44 |
45 | for key in keys:
46 | val_list = result[key]
47 | if isinstance(val_list[0], (int, float)):
48 | if combine_scalars:
49 | result[key] = np.array(result[key])
50 |
51 | elif isinstance(val_list[0], torch.Tensor):
52 | if combine_tensors:
53 | result[key] = torch.stack(val_list)
54 |
55 | elif isinstance(val_list[0], np.ndarray):
56 | if combine_tensors:
57 | result[key] = np.stack(val_list)
58 |
59 | return result
60 |
--------------------------------------------------------------------------------
/michelangelo/graphics/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/michelangelo/graphics/primitives/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .volume import generate_dense_grid_points
4 |
5 | from .mesh import (
6 | MeshOutput,
7 | save_obj,
8 | savemeshtes2
9 | )
10 |
--------------------------------------------------------------------------------
/michelangelo/graphics/primitives/mesh.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import cv2
5 | import numpy as np
6 | import PIL.Image
7 | from typing import Optional
8 |
9 | import trimesh
10 |
11 |
12 | def save_obj(pointnp_px3, facenp_fx3, fname):
13 | fid = open(fname, "w")
14 | write_str = ""
15 | for pidx, p in enumerate(pointnp_px3):
16 | pp = p
17 | write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
18 |
19 | for i, f in enumerate(facenp_fx3):
20 | f1 = f + 1
21 | write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
22 | fid.write(write_str)
23 | fid.close()
24 | return
25 |
26 |
27 | def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
28 | fol, na = os.path.split(fname)
29 | na, _ = os.path.splitext(na)
30 |
31 | matname = "%s/%s.mtl" % (fol, na)
32 | fid = open(matname, "w")
33 | fid.write("newmtl material_0\n")
34 | fid.write("Kd 1 1 1\n")
35 | fid.write("Ka 0 0 0\n")
36 | fid.write("Ks 0.4 0.4 0.4\n")
37 | fid.write("Ns 10\n")
38 | fid.write("illum 2\n")
39 | fid.write("map_Kd %s.png\n" % na)
40 | fid.close()
41 | ####
42 |
43 | fid = open(fname, "w")
44 | fid.write("mtllib %s.mtl\n" % na)
45 |
46 | for pidx, p in enumerate(pointnp_px3):
47 | pp = p
48 | fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
49 |
50 | for pidx, p in enumerate(tcoords_px2):
51 | pp = p
52 | fid.write("vt %f %f\n" % (pp[0], pp[1]))
53 |
54 | fid.write("usemtl material_0\n")
55 | for i, f in enumerate(facenp_fx3):
56 | f1 = f + 1
57 | f2 = facetex_fx3[i] + 1
58 | fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
59 | fid.close()
60 |
61 | PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
62 | os.path.join(fol, "%s.png" % na))
63 |
64 | return
65 |
66 |
67 | class MeshOutput(object):
68 |
69 | def __init__(self,
70 | mesh_v: np.ndarray,
71 | mesh_f: np.ndarray,
72 | vertex_colors: Optional[np.ndarray] = None,
73 | uvs: Optional[np.ndarray] = None,
74 | mesh_tex_idx: Optional[np.ndarray] = None,
75 | tex_map: Optional[np.ndarray] = None):
76 |
77 | self.mesh_v = mesh_v
78 | self.mesh_f = mesh_f
79 | self.vertex_colors = vertex_colors
80 | self.uvs = uvs
81 | self.mesh_tex_idx = mesh_tex_idx
82 | self.tex_map = tex_map
83 |
84 | def contain_uv_texture(self):
85 | return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
86 |
87 | def contain_vertex_colors(self):
88 | return self.vertex_colors is not None
89 |
90 | def export(self, fname):
91 |
92 | if self.contain_uv_texture():
93 | savemeshtes2(
94 | self.mesh_v,
95 | self.uvs,
96 | self.mesh_f,
97 | self.mesh_tex_idx,
98 | self.tex_map,
99 | fname
100 | )
101 |
102 | elif self.contain_vertex_colors():
103 | mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
104 | mesh_obj.export(fname)
105 |
106 | else:
107 | save_obj(
108 | self.mesh_v,
109 | self.mesh_f,
110 | fname
111 | )
112 |
113 |
114 |
115 |
--------------------------------------------------------------------------------
/michelangelo/graphics/primitives/volume.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 |
5 |
6 | def generate_dense_grid_points(bbox_min: np.ndarray,
7 | bbox_max: np.ndarray,
8 | octree_depth: int,
9 | indexing: str = "ij"):
10 | length = bbox_max - bbox_min
11 | num_cells = np.exp2(octree_depth)
12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
16 | xyz = np.stack((xs, ys, zs), axis=-1)
17 | xyz = xyz.reshape(-1, 3)
18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
19 |
20 | return xyz, grid_size, length
21 |
22 |
--------------------------------------------------------------------------------
/michelangelo/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/michelangelo/models/asl_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from omegaconf import DictConfig
4 | from typing import List, Tuple, Dict, Optional, Union
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.optim import lr_scheduler
10 | import pytorch_lightning as pl
11 | from pytorch_lightning.utilities import rank_zero_only
12 |
13 | from einops import rearrange
14 |
15 | from diffusers.schedulers import (
16 | DDPMScheduler,
17 | DDIMScheduler,
18 | KarrasVeScheduler,
19 | DPMSolverMultistepScheduler
20 | )
21 |
22 | from michelangelo.utils import instantiate_from_config
23 | # from michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule
24 | from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
25 | from michelangelo.models.asl_diffusion.inference_utils import ddim_sample
26 |
27 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
28 |
29 |
30 | def disabled_train(self, mode=True):
31 | """Overwrite model.train with this function to make sure train/eval mode
32 | does not change anymore."""
33 | return self
34 |
35 |
36 | class ASLDiffuser(pl.LightningModule):
37 | first_stage_model: Optional[AlignedShapeAsLatentPLModule]
38 | # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
39 | model: nn.Module
40 |
41 | def __init__(self, *,
42 | first_stage_config,
43 | denoiser_cfg,
44 | scheduler_cfg,
45 | optimizer_cfg,
46 | loss_cfg,
47 | first_stage_key: str = "surface",
48 | cond_stage_key: str = "image",
49 | cond_stage_trainable: bool = True,
50 | scale_by_std: bool = False,
51 | z_scale_factor: float = 1.0,
52 | ckpt_path: Optional[str] = None,
53 | ignore_keys: Union[Tuple[str], List[str]] = ()):
54 |
55 | super().__init__()
56 |
57 | self.first_stage_key = first_stage_key
58 | self.cond_stage_key = cond_stage_key
59 | self.cond_stage_trainable = cond_stage_trainable
60 |
61 | # 1. initialize first stage.
62 | # Note: the condition model contained in the first stage model.
63 | self.first_stage_config = first_stage_config
64 | self.first_stage_model = None
65 | # self.instantiate_first_stage(first_stage_config)
66 |
67 | # 2. initialize conditional stage
68 | # self.instantiate_cond_stage(cond_stage_config)
69 | self.cond_stage_model = {
70 | "image": self.encode_image,
71 | "image_unconditional_embedding": self.empty_img_cond,
72 | "text": self.encode_text,
73 | "text_unconditional_embedding": self.empty_text_cond,
74 | "surface": self.encode_surface,
75 | "surface_unconditional_embedding": self.empty_surface_cond,
76 | }
77 |
78 | # 3. diffusion model
79 | self.model = instantiate_from_config(
80 | denoiser_cfg, device=None, dtype=None
81 | )
82 |
83 | self.optimizer_cfg = optimizer_cfg
84 |
85 | # 4. scheduling strategy
86 | self.scheduler_cfg = scheduler_cfg
87 |
88 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
89 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
90 |
91 | # 5. loss configures
92 | self.loss_cfg = loss_cfg
93 |
94 | self.scale_by_std = scale_by_std
95 | if scale_by_std:
96 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
97 | else:
98 | self.z_scale_factor = z_scale_factor
99 |
100 | self.ckpt_path = ckpt_path
101 | if ckpt_path is not None:
102 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
103 |
104 | def instantiate_first_stage(self, config):
105 | model = instantiate_from_config(config)
106 | self.first_stage_model = model.eval()
107 | self.first_stage_model.train = disabled_train
108 | for param in self.first_stage_model.parameters():
109 | param.requires_grad = False
110 |
111 | self.first_stage_model = self.first_stage_model.to(self.device)
112 |
113 | # def instantiate_cond_stage(self, config):
114 | # if not self.cond_stage_trainable:
115 | # if config == "__is_first_stage__":
116 | # print("Using first stage also as cond stage.")
117 | # self.cond_stage_model = self.first_stage_model
118 | # elif config == "__is_unconditional__":
119 | # print(f"Training {self.__class__.__name__} as an unconditional model.")
120 | # self.cond_stage_model = None
121 | # # self.be_unconditional = True
122 | # else:
123 | # model = instantiate_from_config(config)
124 | # self.cond_stage_model = model.eval()
125 | # self.cond_stage_model.train = disabled_train
126 | # for param in self.cond_stage_model.parameters():
127 | # param.requires_grad = False
128 | # else:
129 | # assert config != "__is_first_stage__"
130 | # assert config != "__is_unconditional__"
131 | # model = instantiate_from_config(config)
132 | # self.cond_stage_model = model
133 |
134 | def init_from_ckpt(self, path, ignore_keys=()):
135 | state_dict = torch.load(path, map_location="cpu")["state_dict"]
136 |
137 | keys = list(state_dict.keys())
138 | for k in keys:
139 | for ik in ignore_keys:
140 | if k.startswith(ik):
141 | print("Deleting key {} from state_dict.".format(k))
142 | del state_dict[k]
143 |
144 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
145 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
146 | if len(missing) > 0:
147 | print(f"Missing Keys: {missing}")
148 | print(f"Unexpected Keys: {unexpected}")
149 |
150 | @property
151 | def zero_rank(self):
152 | if self._trainer:
153 | zero_rank = self.trainer.local_rank == 0
154 | else:
155 | zero_rank = True
156 |
157 | return zero_rank
158 |
159 | def configure_optimizers(self) -> Tuple[List, List]:
160 |
161 | lr = self.learning_rate
162 |
163 | trainable_parameters = list(self.model.parameters())
164 | # if the conditional encoder is trainable
165 |
166 | # if self.cond_stage_trainable:
167 | # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
168 | # trainable_parameters += conditioner_params
169 | # print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
170 |
171 | if self.optimizer_cfg is None:
172 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
173 | schedulers = []
174 | else:
175 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
176 | scheduler_func = instantiate_from_config(
177 | self.optimizer_cfg.scheduler,
178 | max_decay_steps=self.trainer.max_steps,
179 | lr_max=lr
180 | )
181 | scheduler = {
182 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
183 | "interval": "step",
184 | "frequency": 1
185 | }
186 | optimizers = [optimizer]
187 | schedulers = [scheduler]
188 |
189 | return optimizers, schedulers
190 |
191 | @torch.no_grad()
192 | def encode_text(self, text):
193 |
194 | b = text.shape[0]
195 | text_tokens = rearrange(text, "b t l -> (b t) l")
196 | text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
197 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
198 | text_embed = text_embed.mean(dim=1)
199 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
200 |
201 | return text_embed
202 |
203 | @torch.no_grad()
204 | def encode_image(self, img):
205 |
206 | return self.first_stage_model.model.encode_image_embed(img)
207 |
208 | @torch.no_grad()
209 | def encode_surface(self, surface):
210 |
211 | return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
212 |
213 | @torch.no_grad()
214 | def empty_text_cond(self, cond):
215 |
216 | return torch.zeros_like(cond, device=cond.device)
217 |
218 | @torch.no_grad()
219 | def empty_img_cond(self, cond):
220 |
221 | return torch.zeros_like(cond, device=cond.device)
222 |
223 | @torch.no_grad()
224 | def empty_surface_cond(self, cond):
225 |
226 | return torch.zeros_like(cond, device=cond.device)
227 |
228 | @torch.no_grad()
229 | def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
230 |
231 | z_q = self.first_stage_model.encode(surface, sample_posterior)
232 | z_q = self.z_scale_factor * z_q
233 |
234 | return z_q
235 |
236 | @torch.no_grad()
237 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
238 |
239 | z_q = 1. / self.z_scale_factor * z_q
240 | latents = self.first_stage_model.decode(z_q, **kwargs)
241 | return latents
242 |
243 | @rank_zero_only
244 | @torch.no_grad()
245 | def on_train_batch_start(self, batch, batch_idx):
246 | # only for very first batch
247 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
248 | and batch_idx == 0 and self.ckpt_path is None:
249 | # set rescale weight to 1./std of encodings
250 | print("### USING STD-RESCALING ###")
251 |
252 | z_q = self.encode_first_stage(batch[self.first_stage_key])
253 | z = z_q.detach()
254 |
255 | del self.z_scale_factor
256 | self.register_buffer("z_scale_factor", 1. / z.flatten().std())
257 | print(f"setting self.z_scale_factor to {self.z_scale_factor}")
258 |
259 | print("### USING STD-RESCALING ###")
260 |
261 | def compute_loss(self, model_outputs, split):
262 | """
263 |
264 | Args:
265 | model_outputs (dict):
266 | - x_0:
267 | - noise:
268 | - pred:
271 |
272 | split (str):
273 |
274 | Returns:
275 |
276 | """
277 |
278 | pred = model_outputs["pred"]
279 |
280 | if self.noise_scheduler.prediction_type == "epsilon":
281 | target = model_outputs["noise"]
282 | elif self.noise_scheduler.prediction_type == "sample":
283 | target = model_outputs["x_0"]
284 | else:
285 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
286 |
287 | if self.loss_cfg.loss_type == "l1":
288 | simple = F.l1_loss(pred, target, reduction="mean")
289 | elif self.loss_cfg.loss_type in ["mse", "l2"]:
290 | simple = F.mse_loss(pred, target, reduction="mean")
291 | else:
292 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
293 |
294 | total_loss = simple
295 |
296 | loss_dict = {
297 | f"{split}/total_loss": total_loss.clone().detach(),
298 | f"{split}/simple": simple.detach(),
299 | }
300 |
301 | return total_loss, loss_dict
302 |
303 | def forward(self, batch):
304 | """
305 |
306 | Args:
307 | batch:
308 |
309 | Returns:
310 |
311 | """
312 |
313 | if self.first_stage_model is None:
314 | self.instantiate_first_stage(self.first_stage_config)
315 |
316 | latents = self.encode_first_stage(batch[self.first_stage_key])
317 |
318 | # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
319 |
320 | conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
321 |
322 | mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
323 | conditions = conditions * mask.to(conditions)
324 |
325 | # Sample noise that we'll add to the latents
326 | # [batch_size, n_token, latent_dim]
327 | noise = torch.randn_like(latents)
328 | bs = latents.shape[0]
329 | # Sample a random timestep for each shape in the batch
330 | timesteps = torch.randint(
331 | 0,
332 | self.noise_scheduler.config.num_train_timesteps,
333 | (bs,),
334 | device=latents.device,
335 | )
336 | timesteps = timesteps.long()
337 | # Add noise to the latents according to the noise magnitude at each timestep
338 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
339 |
340 | # diffusion model forward
341 | noise_pred = self.model(noisy_z, timesteps, conditions)
342 |
343 | diffusion_outputs = {
344 | "x_0": noisy_z,
345 | "noise": noise,
346 | "pred": noise_pred
347 | }
348 |
349 | return diffusion_outputs
350 |
351 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
352 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
353 | """
354 |
355 | Args:
356 |             batch (dict): the batch sample, which contains:
357 |                 - surface (torch.FloatTensor): the surface point samples;
358 |                 - image (torch.FloatTensor): if provided, [bs, 3, h, w], values in [0, 1];
359 |                 - depth (torch.FloatTensor): if provided, [bs, 1, h, w], values in [-1, 1];
360 |                 - normal (torch.FloatTensor): if provided, [bs, 3, h, w], values in [-1, 1];
361 |                 - text (list of str): the text prompts.
362 |
363 | batch_idx (int):
364 |
365 | optimizer_idx (int):
366 |
367 | Returns:
368 | loss (torch.FloatTensor):
369 |
370 | """
371 |
372 | diffusion_outputs = self(batch)
373 |
374 | loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
375 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
376 |
377 | return loss
378 |
379 | def validation_step(self, batch: Dict[str, torch.FloatTensor],
380 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
381 | """
382 |
383 | Args:
384 |             batch (dict): the batch sample, which contains:
385 | - surface_pc (torch.FloatTensor): [n_pts, 4]
386 | - surface_feats (torch.FloatTensor): [n_pts, c]
387 | - text (list of str):
388 |
389 | batch_idx (int):
390 |
391 | optimizer_idx (int):
392 |
393 | Returns:
394 | loss (torch.FloatTensor):
395 |
396 | """
397 |
398 | diffusion_outputs = self(batch)
399 |
400 | loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
401 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
402 |
403 | return loss
404 |
405 | @torch.no_grad()
406 | def sample(self,
407 | batch: Dict[str, Union[torch.FloatTensor, List[str]]],
408 | sample_times: int = 1,
409 | steps: Optional[int] = None,
410 | guidance_scale: Optional[float] = None,
411 | eta: float = 0.0,
412 | return_intermediates: bool = False, **kwargs):
413 |
414 | if self.first_stage_model is None:
415 | self.instantiate_first_stage(self.first_stage_config)
416 |
417 | if steps is None:
418 | steps = self.scheduler_cfg.num_inference_steps
419 |
420 | if guidance_scale is None:
421 | guidance_scale = self.scheduler_cfg.guidance_scale
422 | do_classifier_free_guidance = guidance_scale > 0
423 |
424 | # conditional encode
425 | xc = batch[self.cond_stage_key]
426 | # cond = self.cond_stage_model[self.cond_stage_key](xc)
427 | cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
428 |
429 | if do_classifier_free_guidance:
430 |             """
431 |             Note: there are two common choices of unconditional embedding for text:
432 |             1. encode the empty string "" as the uncond text (as in SAL diffusion);
433 |             2. use zeros_like(cond) as the uncond embedding (as in MDM).
434 |             """
435 | # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
436 | un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
437 | # un_cond = torch.zeros_like(cond, device=cond.device)
438 | cond = torch.cat([un_cond, cond], dim=0)
439 |
440 | outputs = []
441 | latents = None
442 |
443 | if not return_intermediates:
444 | for _ in range(sample_times):
445 | sample_loop = ddim_sample(
446 | self.denoise_scheduler,
447 | self.model,
448 | shape=self.first_stage_model.latent_shape,
449 | cond=cond,
450 | steps=steps,
451 | guidance_scale=guidance_scale,
452 | do_classifier_free_guidance=do_classifier_free_guidance,
453 | device=self.device,
454 | eta=eta,
455 | disable_prog=not self.zero_rank
456 | )
457 | for sample, t in sample_loop:
458 | latents = sample
459 | outputs.append(self.decode_first_stage(latents, **kwargs))
460 | else:
461 |
462 | sample_loop = ddim_sample(
463 | self.denoise_scheduler,
464 | self.model,
465 | shape=self.first_stage_model.latent_shape,
466 | cond=cond,
467 | steps=steps,
468 | guidance_scale=guidance_scale,
469 | do_classifier_free_guidance=do_classifier_free_guidance,
470 | device=self.device,
471 | eta=eta,
472 | disable_prog=not self.zero_rank
473 | )
474 |
475 | iter_size = steps // sample_times
476 | i = 0
477 | for sample, t in sample_loop:
478 | latents = sample
479 | if i % iter_size == 0 or i == steps - 1:
480 | outputs.append(self.decode_first_stage(latents, **kwargs))
481 | i += 1
482 |
483 | return outputs
484 |
--------------------------------------------------------------------------------
/michelangelo/models/asl_diffusion/asl_udt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | from typing import Optional
6 | from diffusers.models.embeddings import Timesteps
7 | import math
8 |
9 | from michelangelo.models.modules.transformer_blocks import MLP
10 | from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer
11 |
12 |
13 | class ConditionalASLUDTDenoiser(nn.Module):
14 |
15 | def __init__(self, *,
16 | device: Optional[torch.device],
17 | dtype: Optional[torch.dtype],
18 | input_channels: int,
19 | output_channels: int,
20 | n_ctx: int,
21 | width: int,
22 | layers: int,
23 | heads: int,
24 | context_dim: int,
25 | context_ln: bool = True,
26 | skip_ln: bool = False,
27 | init_scale: float = 0.25,
28 | flip_sin_to_cos: bool = False,
29 | use_checkpoint: bool = False):
30 | super().__init__()
31 |
32 | self.use_checkpoint = use_checkpoint
33 |
34 | init_scale = init_scale * math.sqrt(1.0 / width)
35 |
36 | self.backbone = UNetDiffusionTransformer(
37 | device=device,
38 | dtype=dtype,
39 | n_ctx=n_ctx,
40 | width=width,
41 | layers=layers,
42 | heads=heads,
43 | skip_ln=skip_ln,
44 | init_scale=init_scale,
45 | use_checkpoint=use_checkpoint
46 | )
47 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
48 | self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
49 | self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
50 |
51 | # timestep embedding
52 | self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
53 | self.time_proj = MLP(
54 | device=device, dtype=dtype, width=width, init_scale=init_scale
55 | )
56 |
57 |         # conditions projector: project the context features (e.g. CLIP embeddings)
58 |         # to the transformer width. When `context_ln` is True, a LayerNorm is applied
59 |         # to the context features before the linear projection; otherwise the raw
60 |         # features are projected directly.
61 |         if context_ln:
62 |             self.context_embed = nn.Sequential(
63 |                 nn.LayerNorm(context_dim, device=device, dtype=dtype),
64 |                 nn.Linear(context_dim, width, device=device, dtype=dtype),
65 |             )
66 |         else:
67 |             self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
68 | 
69 |
70 | def forward(self,
71 | model_input: torch.FloatTensor,
72 | timestep: torch.LongTensor,
73 | context: torch.FloatTensor):
74 |
75 | r"""
76 | Args:
77 | model_input (torch.FloatTensor): [bs, n_data, c]
78 | timestep (torch.LongTensor): [bs,]
79 | context (torch.FloatTensor): [bs, context_tokens, c]
80 |
81 | Returns:
82 | sample (torch.FloatTensor): [bs, n_data, c]
83 |
84 | """
85 |
86 | _, n_data, _ = model_input.shape
87 |
88 | # 1. time
89 | t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
90 |
91 | # 2. conditions projector
92 | context = self.context_embed(context)
93 |
94 | # 3. denoiser
95 | x = self.input_proj(model_input)
96 | x = torch.cat([t_emb, context, x], dim=1)
97 | x = self.backbone(x)
98 | x = self.ln_post(x)
99 | x = x[:, -n_data:]
100 | sample = self.output_proj(x)
101 |
102 | return sample
103 |
104 |
105 |
--------------------------------------------------------------------------------
/michelangelo/models/asl_diffusion/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class BaseDenoiser(nn.Module):
8 |
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x, t, context):
13 | raise NotImplementedError
14 |
--------------------------------------------------------------------------------
/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from omegaconf import DictConfig
4 | from typing import List, Tuple, Dict, Optional, Union
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.optim import lr_scheduler
10 | import pytorch_lightning as pl
11 | from pytorch_lightning.utilities import rank_zero_only
12 |
13 | from diffusers.schedulers import (
14 | DDPMScheduler,
15 | DDIMScheduler,
16 | KarrasVeScheduler,
17 | DPMSolverMultistepScheduler
18 | )
19 |
20 | from michelangelo.utils import instantiate_from_config
21 | from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
22 | from michelangelo.models.asl_diffusion.inference_utils import ddim_sample
23 |
24 | SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
25 |
26 |
27 | def disabled_train(self, mode=True):
28 | """Overwrite model.train with this function to make sure train/eval mode
29 | does not change anymore."""
30 | return self
31 |
32 |
33 | class ClipASLDiffuser(pl.LightningModule):
34 | first_stage_model: Optional[AlignedShapeAsLatentPLModule]
35 | cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
36 | model: nn.Module
37 |
38 | def __init__(self, *,
39 | first_stage_config,
40 | cond_stage_config,
41 | denoiser_cfg,
42 | scheduler_cfg,
43 | optimizer_cfg,
44 | loss_cfg,
45 | first_stage_key: str = "surface",
46 | cond_stage_key: str = "image",
47 | scale_by_std: bool = False,
48 | z_scale_factor: float = 1.0,
49 | ckpt_path: Optional[str] = None,
50 | ignore_keys: Union[Tuple[str], List[str]] = ()):
51 |
52 | super().__init__()
53 |
54 | self.first_stage_key = first_stage_key
55 | self.cond_stage_key = cond_stage_key
56 |
57 | # 1. lazy initialize first stage
58 | self.instantiate_first_stage(first_stage_config)
59 |
60 | # 2. initialize conditional stage
61 | self.instantiate_cond_stage(cond_stage_config)
62 |
63 | # 3. diffusion model
64 | self.model = instantiate_from_config(
65 | denoiser_cfg, device=None, dtype=None
66 | )
67 |
68 | self.optimizer_cfg = optimizer_cfg
69 |
70 | # 4. scheduling strategy
71 | self.scheduler_cfg = scheduler_cfg
72 |
73 | self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
74 | self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
75 |
76 | # 5. loss configures
77 | self.loss_cfg = loss_cfg
78 |
79 | self.scale_by_std = scale_by_std
80 | if scale_by_std:
81 | self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
82 | else:
83 | self.z_scale_factor = z_scale_factor
84 |
85 | self.ckpt_path = ckpt_path
86 | if ckpt_path is not None:
87 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
88 |
89 | def instantiate_non_trainable_model(self, config):
90 | model = instantiate_from_config(config)
91 | model = model.eval()
92 | model.train = disabled_train
93 | for param in model.parameters():
94 | param.requires_grad = False
95 |
96 | return model
97 |
98 | def instantiate_first_stage(self, first_stage_config):
99 | self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
100 | self.first_stage_model.set_shape_model_only()
101 |
102 | def instantiate_cond_stage(self, cond_stage_config):
103 | self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
104 |
105 | def init_from_ckpt(self, path, ignore_keys=()):
106 | state_dict = torch.load(path, map_location="cpu")["state_dict"]
107 |
108 | keys = list(state_dict.keys())
109 | for k in keys:
110 | for ik in ignore_keys:
111 | if k.startswith(ik):
112 | print("Deleting key {} from state_dict.".format(k))
113 | del state_dict[k]
114 |
115 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
116 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
117 | if len(missing) > 0:
118 | print(f"Missing Keys: {missing}")
119 | print(f"Unexpected Keys: {unexpected}")
120 |
121 | @property
122 | def zero_rank(self):
123 | if self._trainer:
124 | zero_rank = self.trainer.local_rank == 0
125 | else:
126 | zero_rank = True
127 |
128 | return zero_rank
129 |
130 | def configure_optimizers(self) -> Tuple[List, List]:
131 |
132 | lr = self.learning_rate
133 |
134 | trainable_parameters = list(self.model.parameters())
135 | if self.optimizer_cfg is None:
136 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
137 | schedulers = []
138 | else:
139 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
140 | scheduler_func = instantiate_from_config(
141 | self.optimizer_cfg.scheduler,
142 | max_decay_steps=self.trainer.max_steps,
143 | lr_max=lr
144 | )
145 | scheduler = {
146 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
147 | "interval": "step",
148 | "frequency": 1
149 | }
150 | optimizers = [optimizer]
151 | schedulers = [scheduler]
152 |
153 | return optimizers, schedulers
154 |
155 | @torch.no_grad()
156 | def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
157 |
158 | z_q = self.first_stage_model.encode(surface, sample_posterior)
159 | z_q = self.z_scale_factor * z_q
160 |
161 | return z_q
162 |
163 | @torch.no_grad()
164 | def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
165 |
166 | z_q = 1. / self.z_scale_factor * z_q
167 | latents = self.first_stage_model.decode(z_q, **kwargs)
168 | return latents
169 |
170 | @rank_zero_only
171 | @torch.no_grad()
172 | def on_train_batch_start(self, batch, batch_idx):
173 | # only for very first batch
174 | if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
175 | and batch_idx == 0 and self.ckpt_path is None:
176 | # set rescale weight to 1./std of encodings
177 | print("### USING STD-RESCALING ###")
178 |
179 | z_q = self.encode_first_stage(batch[self.first_stage_key])
180 | z = z_q.detach()
181 |
182 | del self.z_scale_factor
183 | self.register_buffer("z_scale_factor", 1. / z.flatten().std())
184 | print(f"setting self.z_scale_factor to {self.z_scale_factor}")
185 |
186 | print("### USING STD-RESCALING ###")
187 |
188 | def compute_loss(self, model_outputs, split):
189 |         """Compute the diffusion training loss.
190 | 
191 |         Args:
192 |             model_outputs (dict): outputs of `self.forward`, containing:
193 |                 - x_0 (torch.FloatTensor): the clean (scaled) latents;
194 |                 - noise (torch.FloatTensor): the Gaussian noise added to the latents;
195 |                 - pred (torch.FloatTensor): the denoiser prediction.
196 | 
197 |             split (str): logging prefix, e.g. "train" or "val".
198 | 
199 |         Returns:
200 |             total_loss (torch.FloatTensor): scalar loss;
201 |             loss_dict (dict): detached loss terms for logging.
202 | 
203 |         """
204 |
205 | pred = model_outputs["pred"]
206 |
207 | if self.noise_scheduler.prediction_type == "epsilon":
208 | target = model_outputs["noise"]
209 | elif self.noise_scheduler.prediction_type == "sample":
210 | target = model_outputs["x_0"]
211 | else:
212 | raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
213 |
214 | if self.loss_cfg.loss_type == "l1":
215 | simple = F.l1_loss(pred, target, reduction="mean")
216 | elif self.loss_cfg.loss_type in ["mse", "l2"]:
217 | simple = F.mse_loss(pred, target, reduction="mean")
218 | else:
219 | raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
220 |
221 | total_loss = simple
222 |
223 | loss_dict = {
224 | f"{split}/total_loss": total_loss.clone().detach(),
225 | f"{split}/simple": simple.detach(),
226 | }
227 |
228 | return total_loss, loss_dict
229 |
230 | def forward(self, batch):
231 |         """Run one diffusion training forward pass.
232 | 
233 |         Args:
234 |             batch (dict): must contain `self.first_stage_key` and `self.cond_stage_key`.
235 | 
236 |         Returns:
237 |             diffusion_outputs (dict): with keys "x_0", "noise" and "pred".
238 |         """
239 |
240 | latents = self.encode_first_stage(batch[self.first_stage_key])
241 | conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
242 |
243 |         # Sample noise that we'll add to the latents
244 | # [batch_size, n_token, latent_dim]
245 | noise = torch.randn_like(latents)
246 | bs = latents.shape[0]
247 |         # Sample a random timestep for each sample in the batch
248 | timesteps = torch.randint(
249 | 0,
250 | self.noise_scheduler.config.num_train_timesteps,
251 | (bs,),
252 | device=latents.device,
253 | )
254 | timesteps = timesteps.long()
255 | # Add noise to the latents according to the noise magnitude at each timestep
256 | noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
257 |
258 | # diffusion model forward
259 | noise_pred = self.model(noisy_z, timesteps, conditions)
260 |
261 | diffusion_outputs = {
262 |             "x_0": latents,  # clean (scaled) latents; the target when prediction_type == "sample"
263 | "noise": noise,
264 | "pred": noise_pred
265 | }
266 |
267 | return diffusion_outputs
268 |
269 | def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
270 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
271 | """
272 |
273 | Args:
274 |             batch (dict): the batch sample, which contains:
275 |                 - surface (torch.FloatTensor): the surface point samples;
276 |                 - image (torch.FloatTensor): if provided, [bs, 3, h, w], values in [0, 1];
277 |                 - depth (torch.FloatTensor): if provided, [bs, 1, h, w], values in [-1, 1];
278 |                 - normal (torch.FloatTensor): if provided, [bs, 3, h, w], values in [-1, 1];
279 |                 - text (list of str): the text prompts.
280 |
281 | batch_idx (int):
282 |
283 | optimizer_idx (int):
284 |
285 | Returns:
286 | loss (torch.FloatTensor):
287 |
288 | """
289 |
290 | diffusion_outputs = self(batch)
291 |
292 | loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
293 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
294 |
295 | return loss
296 |
297 | def validation_step(self, batch: Dict[str, torch.FloatTensor],
298 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
299 | """
300 |
301 | Args:
302 |             batch (dict): the batch sample, which contains:
303 | - surface_pc (torch.FloatTensor): [n_pts, 4]
304 | - surface_feats (torch.FloatTensor): [n_pts, c]
305 | - text (list of str):
306 |
307 | batch_idx (int):
308 |
309 | optimizer_idx (int):
310 |
311 | Returns:
312 | loss (torch.FloatTensor):
313 |
314 | """
315 |
316 | diffusion_outputs = self(batch)
317 |
318 | loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
319 | self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
320 |
321 | return loss
322 |
323 | @torch.no_grad()
324 | def sample(self,
325 | batch: Dict[str, Union[torch.FloatTensor, List[str]]],
326 | sample_times: int = 1,
327 | steps: Optional[int] = None,
328 | guidance_scale: Optional[float] = None,
329 | eta: float = 0.0,
330 | return_intermediates: bool = False, **kwargs):
331 |
332 | if steps is None:
333 | steps = self.scheduler_cfg.num_inference_steps
334 |
335 | if guidance_scale is None:
336 | guidance_scale = self.scheduler_cfg.guidance_scale
337 | do_classifier_free_guidance = guidance_scale > 0
338 |
339 | # conditional encode
340 | xc = batch[self.cond_stage_key]
341 |
342 | # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
343 |
344 | cond = self.cond_stage_model(xc)
345 |
346 | if do_classifier_free_guidance:
347 | un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
348 | cond = torch.cat([un_cond, cond], dim=0)
349 |
350 | outputs = []
351 | latents = None
352 |
353 | if not return_intermediates:
354 | for _ in range(sample_times):
355 | sample_loop = ddim_sample(
356 | self.denoise_scheduler,
357 | self.model,
358 | shape=self.first_stage_model.latent_shape,
359 | cond=cond,
360 | steps=steps,
361 | guidance_scale=guidance_scale,
362 | do_classifier_free_guidance=do_classifier_free_guidance,
363 | device=self.device,
364 | eta=eta,
365 | disable_prog=not self.zero_rank
366 | )
367 | for sample, t in sample_loop:
368 | latents = sample
369 | outputs.append(self.decode_first_stage(latents, **kwargs))
370 | else:
371 |
372 | sample_loop = ddim_sample(
373 | self.denoise_scheduler,
374 | self.model,
375 | shape=self.first_stage_model.latent_shape,
376 | cond=cond,
377 | steps=steps,
378 | guidance_scale=guidance_scale,
379 | do_classifier_free_guidance=do_classifier_free_guidance,
380 | device=self.device,
381 | eta=eta,
382 | disable_prog=not self.zero_rank
383 | )
384 |
385 | iter_size = steps // sample_times
386 | i = 0
387 | for sample, t in sample_loop:
388 | latents = sample
389 | if i % iter_size == 0 or i == steps - 1:
390 | outputs.append(self.decode_first_stage(latents, **kwargs))
391 | i += 1
392 |
393 | return outputs
394 |
--------------------------------------------------------------------------------
/michelangelo/models/asl_diffusion/inference_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from typing import Tuple, List, Union, Optional
6 | from diffusers.schedulers import DDIMScheduler
7 |
8 |
9 | __all__ = ["ddim_sample"]
10 |
11 |
12 | def ddim_sample(ddim_scheduler: DDIMScheduler,
13 | diffusion_model: torch.nn.Module,
14 | shape: Union[List[int], Tuple[int]],
15 | cond: torch.FloatTensor,
16 | steps: int,
17 | eta: float = 0.0,
18 | guidance_scale: float = 3.0,
19 | do_classifier_free_guidance: bool = True,
20 | generator: Optional[torch.Generator] = None,
21 | device: torch.device = "cuda:0",
22 | disable_prog: bool = True):
23 |
24 |     assert steps > 0, f"steps must be > 0, got {steps}."
25 |
26 | # init latents
27 | bsz = cond.shape[0]
28 | if do_classifier_free_guidance:
29 | bsz = bsz // 2
30 |
31 | latents = torch.randn(
32 | (bsz, *shape),
33 | generator=generator,
34 | device=cond.device,
35 | dtype=cond.dtype,
36 | )
37 | # scale the initial noise by the standard deviation required by the scheduler
38 | latents = latents * ddim_scheduler.init_noise_sigma
39 | # set timesteps
40 | ddim_scheduler.set_timesteps(steps)
41 | timesteps = ddim_scheduler.timesteps.to(device)
42 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
43 | # eta (η) is only used with the DDIMScheduler, and between [0, 1]
44 | extra_step_kwargs = {
45 | "eta": eta,
46 | "generator": generator
47 | }
48 |
49 | # reverse
50 | for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
51 | # expand the latents if we are doing classifier free guidance
52 | latent_model_input = (
53 | torch.cat([latents] * 2)
54 | if do_classifier_free_guidance
55 | else latents
56 | )
57 | # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
58 | # predict the noise residual
59 | timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
60 | timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
61 | noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
62 |
63 | # perform guidance
64 | if do_classifier_free_guidance:
65 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
66 | noise_pred = noise_pred_uncond + guidance_scale * (
67 | noise_pred_text - noise_pred_uncond
68 | )
69 | # text_embeddings_for_guidance = encoder_hidden_states.chunk(
70 | # 2)[1] if do_classifier_free_guidance else encoder_hidden_states
71 | # compute the previous noisy sample x_t -> x_t-1
72 | latents = ddim_scheduler.step(
73 | noise_pred, t, latents, **extra_step_kwargs
74 | ).prev_sample
75 |
76 | yield latents, t
77 |
78 |
79 | def karra_sample():
80 | pass
81 |
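82 | # Illustrative usage sketch of `ddim_sample` (not part of the training code); the
83 | # latent shape and conditioning tensors below are assumptions for demonstration only.
84 | # When `do_classifier_free_guidance` is True, `cond` must already be the
85 | # concatenation [unconditional_embedding, conditional_embedding] along dim 0.
86 | #
87 | #   scheduler = DDIMScheduler(num_train_timesteps=1000)
88 | #   sample_loop = ddim_sample(scheduler, denoiser, shape=(256, 64), cond=cond,
89 | #                             steps=50, guidance_scale=7.5,
90 | #                             do_classifier_free_guidance=True, device=cond.device)
91 | #   for latents, t in sample_loop:
92 | #       pass  # `latents` after the last step is the fully denoised sample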
--------------------------------------------------------------------------------
/michelangelo/models/conditional_encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .clip import CLIPEncoder
4 |
--------------------------------------------------------------------------------
/michelangelo/models/conditional_encoders/clip.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from dataclasses import dataclass
7 | from torchvision.transforms import Normalize
8 | from transformers import CLIPModel, CLIPTokenizer
9 | from transformers.utils import ModelOutput
10 | from typing import Iterable, Optional, Union, List
11 |
12 |
13 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
14 |
15 |
16 | @dataclass
17 | class CLIPEmbedOutput(ModelOutput):
18 | last_hidden_state: torch.FloatTensor = None
19 | pooler_output: torch.FloatTensor = None
20 | embeds: torch.FloatTensor = None
21 |
22 |
23 | class CLIPEncoder(torch.nn.Module):
24 |
25 | def __init__(self, model_path="openai/clip-vit-base-patch32"):
26 |
27 | super().__init__()
28 |
29 | # Load the CLIP model and processor
30 | self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
31 | self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
32 | self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33 |
34 |         self.model.eval()
35 | for p in self.model.parameters():
36 | p.requires_grad = False
37 |
38 | @torch.no_grad()
39 | def encode_image(self, images: Iterable[Optional[ImageType]]):
40 | pixel_values = self.image_preprocess(images)
41 |
42 | vision_outputs = self.model.vision_model(pixel_values=pixel_values)
43 |
44 | pooler_output = vision_outputs[1] # pooled_output
45 | image_features = self.model.visual_projection(pooler_output)
46 |
47 | visual_embeds = CLIPEmbedOutput(
48 | last_hidden_state=vision_outputs.last_hidden_state,
49 | pooler_output=pooler_output,
50 | embeds=image_features
51 | )
52 |
53 | return visual_embeds
54 |
55 | @torch.no_grad()
56 | def encode_text(self, texts: List[str]):
57 |         text_inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
58 | 
59 |         text_outputs = self.model.text_model(input_ids=text_inputs.input_ids, attention_mask=text_inputs.attention_mask)
60 |
61 | pooler_output = text_outputs[1] # pooled_output
62 | text_features = self.model.text_projection(pooler_output)
63 |
64 | text_embeds = CLIPEmbedOutput(
65 | last_hidden_state=text_outputs.last_hidden_state,
66 | pooler_output=pooler_output,
67 | embeds=text_features
68 | )
69 |
70 | return text_embeds
71 |
72 | def forward(self,
73 | images: Iterable[Optional[ImageType]],
74 | texts: List[str]):
75 |
76 | visual_embeds = self.encode_image(images)
77 | text_embeds = self.encode_text(texts)
78 |
79 | return visual_embeds, text_embeds
80 |
81 |
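82 | # Illustrative usage sketch (the model id and prompt are just examples; weights
83 | # are downloaded from the Hugging Face hub on first use):
84 | #
85 | #   encoder = CLIPEncoder(model_path="openai/clip-vit-base-patch32")
86 | #   text_embeds = encoder.encode_text(["a photo of a car"])
87 | #   print(text_embeds.embeds.shape)   # [1, 512] for clip-vit-base-patch32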
--------------------------------------------------------------------------------
/michelangelo/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .checkpoint import checkpoint
4 |
--------------------------------------------------------------------------------
/michelangelo/models/modules/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4 | """
5 |
6 | import torch
7 | from typing import Callable, Iterable, Sequence, Union
8 |
9 |
10 | def checkpoint(
11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12 | inputs: Sequence[torch.Tensor],
13 | params: Iterable[torch.Tensor],
14 | flag: bool,
15 | use_deepspeed: bool = False
16 | ):
17 | """
18 | Evaluate a function without caching intermediate activations, allowing for
19 | reduced memory at the expense of extra compute in the backward pass.
20 | :param func: the function to evaluate.
21 | :param inputs: the argument sequence to pass to `func`.
22 | :param params: a sequence of parameters `func` depends on but does not
23 | explicitly take as arguments.
24 | :param flag: if False, disable gradient checkpointing.
25 | :param use_deepspeed: if True, use deepspeed
26 | """
27 | if flag:
28 | if use_deepspeed:
29 | import deepspeed
30 | return deepspeed.checkpointing.checkpoint(func, *inputs)
31 |
32 | args = tuple(inputs) + tuple(params)
33 | return CheckpointFunction.apply(func, len(inputs), *args)
34 | else:
35 | return func(*inputs)
36 |
37 |
38 | class CheckpointFunction(torch.autograd.Function):
39 | @staticmethod
40 | @torch.cuda.amp.custom_fwd
41 | def forward(ctx, run_function, length, *args):
42 | ctx.run_function = run_function
43 | ctx.input_tensors = list(args[:length])
44 | ctx.input_params = list(args[length:])
45 |
46 | with torch.no_grad():
47 | output_tensors = ctx.run_function(*ctx.input_tensors)
48 | return output_tensors
49 |
50 | @staticmethod
51 | @torch.cuda.amp.custom_bwd
52 | def backward(ctx, *output_grads):
53 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
54 | with torch.enable_grad():
55 | # Fixes a bug where the first op in run_function modifies the
56 | # Tensor storage in place, which is not allowed for detach()'d
57 | # Tensors.
58 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
59 | output_tensors = ctx.run_function(*shallow_copies)
60 | input_grads = torch.autograd.grad(
61 | output_tensors,
62 | ctx.input_tensors + ctx.input_params,
63 | output_grads,
64 | allow_unused=True,
65 | )
66 | del ctx.input_tensors
67 | del ctx.input_params
68 | del output_tensors
69 | return (None, None) + input_grads
70 |
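71 | # Minimal usage sketch (assumes some module `mlp` whose forward we want to
72 | # checkpoint; illustrative only):
73 | #
74 | #   def _forward(x):
75 | #       return mlp(x)
76 | #
77 | #   y = checkpoint(_forward, (x,), mlp.parameters(), flag=True)
78 | #
79 | # With flag=False the function is evaluated normally; with flag=True the
80 | # intermediate activations are recomputed during the backward pass instead of
81 | # being stored, trading compute for memory.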
--------------------------------------------------------------------------------
/michelangelo/models/modules/diffusion_transformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from typing import Optional
7 |
8 | from michelangelo.models.modules.checkpoint import checkpoint
9 | from michelangelo.models.modules.transformer_blocks import (
10 | init_linear,
11 | MLP,
12 | MultiheadCrossAttention,
13 | MultiheadAttention,
14 | ResidualAttentionBlock
15 | )
16 |
17 |
18 | class AdaLayerNorm(nn.Module):
19 | def __init__(self,
20 | device: torch.device,
21 | dtype: torch.dtype,
22 | width: int):
23 |
24 | super().__init__()
25 |
26 | self.silu = nn.SiLU(inplace=True)
27 | self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
28 | self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
29 |
30 | def forward(self, x, timestep):
31 | emb = self.linear(timestep)
32 | scale, shift = torch.chunk(emb, 2, dim=2)
33 | x = self.layernorm(x) * (1 + scale) + shift
34 | return x
35 |
36 |
37 | class DitBlock(nn.Module):
38 | def __init__(
39 | self,
40 | *,
41 | device: torch.device,
42 | dtype: torch.dtype,
43 | n_ctx: int,
44 | width: int,
45 | heads: int,
46 | context_dim: int,
47 | qkv_bias: bool = False,
48 | init_scale: float = 1.0,
49 | use_checkpoint: bool = False
50 | ):
51 | super().__init__()
52 |
53 | self.use_checkpoint = use_checkpoint
54 |
55 | self.attn = MultiheadAttention(
56 | device=device,
57 | dtype=dtype,
58 | n_ctx=n_ctx,
59 | width=width,
60 | heads=heads,
61 | init_scale=init_scale,
62 | qkv_bias=qkv_bias
63 | )
64 | self.ln_1 = AdaLayerNorm(device, dtype, width)
65 |
66 | if context_dim is not None:
67 | self.ln_2 = AdaLayerNorm(device, dtype, width)
68 | self.cross_attn = MultiheadCrossAttention(
69 | device=device,
70 | dtype=dtype,
71 | width=width,
72 | heads=heads,
73 | data_width=context_dim,
74 | init_scale=init_scale,
75 | qkv_bias=qkv_bias
76 | )
77 |
78 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
79 | self.ln_3 = AdaLayerNorm(device, dtype, width)
80 |
81 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
82 | return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
83 |
84 | def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
85 | x = x + self.attn(self.ln_1(x, t))
86 | if context is not None:
87 | x = x + self.cross_attn(self.ln_2(x, t), context)
88 | x = x + self.mlp(self.ln_3(x, t))
89 | return x
90 |
91 |
92 | class DiT(nn.Module):
93 | def __init__(
94 | self,
95 | *,
96 | device: Optional[torch.device],
97 | dtype: Optional[torch.dtype],
98 | n_ctx: int,
99 | width: int,
100 | layers: int,
101 | heads: int,
102 | context_dim: int,
103 | init_scale: float = 0.25,
104 | qkv_bias: bool = False,
105 | use_checkpoint: bool = False
106 | ):
107 | super().__init__()
108 | self.n_ctx = n_ctx
109 | self.width = width
110 | self.layers = layers
111 |
112 | self.resblocks = nn.ModuleList(
113 | [
114 | DitBlock(
115 | device=device,
116 | dtype=dtype,
117 | n_ctx=n_ctx,
118 | width=width,
119 | heads=heads,
120 | context_dim=context_dim,
121 | qkv_bias=qkv_bias,
122 | init_scale=init_scale,
123 | use_checkpoint=use_checkpoint
124 | )
125 | for _ in range(layers)
126 | ]
127 | )
128 |
129 | def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
130 | for block in self.resblocks:
131 | x = block(x, t, context)
132 | return x
133 |
134 |
135 | class UNetDiffusionTransformer(nn.Module):
136 | def __init__(
137 | self,
138 | *,
139 | device: Optional[torch.device],
140 | dtype: Optional[torch.dtype],
141 | n_ctx: int,
142 | width: int,
143 | layers: int,
144 | heads: int,
145 | init_scale: float = 0.25,
146 | qkv_bias: bool = False,
147 | skip_ln: bool = False,
148 | use_checkpoint: bool = False
149 | ):
150 | super().__init__()
151 |
152 | self.n_ctx = n_ctx
153 | self.width = width
154 | self.layers = layers
155 |
156 | self.encoder = nn.ModuleList()
157 | for _ in range(layers):
158 | resblock = ResidualAttentionBlock(
159 | device=device,
160 | dtype=dtype,
161 | n_ctx=n_ctx,
162 | width=width,
163 | heads=heads,
164 | init_scale=init_scale,
165 | qkv_bias=qkv_bias,
166 | use_checkpoint=use_checkpoint
167 | )
168 | self.encoder.append(resblock)
169 |
170 | self.middle_block = ResidualAttentionBlock(
171 | device=device,
172 | dtype=dtype,
173 | n_ctx=n_ctx,
174 | width=width,
175 | heads=heads,
176 | init_scale=init_scale,
177 | qkv_bias=qkv_bias,
178 | use_checkpoint=use_checkpoint
179 | )
180 |
181 | self.decoder = nn.ModuleList()
182 | for _ in range(layers):
183 | resblock = ResidualAttentionBlock(
184 | device=device,
185 | dtype=dtype,
186 | n_ctx=n_ctx,
187 | width=width,
188 | heads=heads,
189 | init_scale=init_scale,
190 | qkv_bias=qkv_bias,
191 | use_checkpoint=use_checkpoint
192 | )
193 | linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
194 | init_linear(linear, init_scale)
195 |
196 | layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
197 |
198 | self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
199 |
200 | def forward(self, x: torch.Tensor):
201 |
202 | enc_outputs = []
203 | for block in self.encoder:
204 | x = block(x)
205 | enc_outputs.append(x)
206 |
207 | x = self.middle_block(x)
208 |
209 | for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
210 | x = torch.cat([enc_outputs.pop(), x], dim=-1)
211 | x = linear(x)
212 |
213 | if layer_norm is not None:
214 | x = layer_norm(x)
215 |
216 | x = resblock(x)
217 |
218 | return x
219 |
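220 | # Illustrative usage sketch (the sizes below are assumptions, not the repo's
221 | # configuration): the U-Net style transformer keeps the sequence length fixed
222 | # and fuses encoder/decoder features through linear skip connections.
223 | #
224 | #   net = UNetDiffusionTransformer(device=None, dtype=None, n_ctx=256,
225 | #                                  width=768, layers=6, heads=12)
226 | #   tokens = torch.randn(2, 260, 768)   # e.g. [time + context + latent tokens]
227 | #   out = net(tokens)                   # [2, 260, 768]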
--------------------------------------------------------------------------------
/michelangelo/models/modules/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Union, List
4 |
5 |
6 | class AbstractDistribution(object):
7 | def sample(self):
8 | raise NotImplementedError()
9 |
10 | def mode(self):
11 | raise NotImplementedError()
12 |
13 |
14 | class DiracDistribution(AbstractDistribution):
15 | def __init__(self, value):
16 | self.value = value
17 |
18 | def sample(self):
19 | return self.value
20 |
21 | def mode(self):
22 | return self.value
23 |
24 |
25 | class DiagonalGaussianDistribution(object):
26 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
27 | self.feat_dim = feat_dim
28 | self.parameters = parameters
29 |
30 | if isinstance(parameters, list):
31 | self.mean = parameters[0]
32 | self.logvar = parameters[1]
33 | else:
34 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
35 |
36 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
37 | self.deterministic = deterministic
38 | self.std = torch.exp(0.5 * self.logvar)
39 | self.var = torch.exp(self.logvar)
40 | if self.deterministic:
41 | self.var = self.std = torch.zeros_like(self.mean)
42 |
43 | def sample(self):
44 | x = self.mean + self.std * torch.randn_like(self.mean)
45 | return x
46 |
47 | def kl(self, other=None, dims=(1, 2, 3)):
48 | if self.deterministic:
49 | return torch.Tensor([0.])
50 | else:
51 | if other is None:
52 | return 0.5 * torch.mean(torch.pow(self.mean, 2)
53 | + self.var - 1.0 - self.logvar,
54 | dim=dims)
55 | else:
56 | return 0.5 * torch.mean(
57 | torch.pow(self.mean - other.mean, 2) / other.var
58 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
59 | dim=dims)
60 |
61 | def nll(self, sample, dims=(1, 2, 3)):
62 | if self.deterministic:
63 | return torch.Tensor([0.])
64 | logtwopi = np.log(2.0 * np.pi)
65 | return 0.5 * torch.sum(
66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
67 | dim=dims)
68 |
69 | def mode(self):
70 | return self.mean
71 |
72 |
73 | def normal_kl(mean1, logvar1, mean2, logvar2):
74 | """
75 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
76 | Compute the KL divergence between two gaussians.
77 | Shapes are automatically broadcasted, so batches can be compared to
78 | scalars, among other use cases.
79 | """
80 | tensor = None
81 | for obj in (mean1, logvar1, mean2, logvar2):
82 | if isinstance(obj, torch.Tensor):
83 | tensor = obj
84 | break
85 | assert tensor is not None, "at least one argument must be a Tensor"
86 |
87 | # Force variances to be Tensors. Broadcasting helps convert scalars to
88 | # Tensors, but it does not work for torch.exp().
89 | logvar1, logvar2 = [
90 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
91 | for x in (logvar1, logvar2)
92 | ]
93 |
94 | return 0.5 * (
95 | -1.0
96 | + logvar2
97 | - logvar1
98 | + torch.exp(logvar1 - logvar2)
99 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
100 | )
101 |
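102 | # Illustrative usage sketch: split a [bs, n_tokens, 2 * c] tensor into mean and
103 | # logvar along the last dim and sample with the reparameterization trick (the
104 | # sizes below are arbitrary):
105 | #
106 | #   params = torch.randn(4, 256, 2 * 64)
107 | #   posterior = DiagonalGaussianDistribution(params, feat_dim=-1)
108 | #   z = posterior.sample()              # [4, 256, 64]
109 | #   kl = posterior.kl(dims=(1, 2))      # [4]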
--------------------------------------------------------------------------------
/michelangelo/models/modules/embedder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import math
7 |
8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
9 |
10 |
11 | class FourierEmbedder(nn.Module):
12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
13 | each feature dimension of `x[..., i]` into:
14 | [
15 | sin(x[..., i]),
16 | sin(f_1*x[..., i]),
17 | sin(f_2*x[..., i]),
18 | ...
19 | sin(f_N * x[..., i]),
20 | cos(x[..., i]),
21 | cos(f_1*x[..., i]),
22 | cos(f_2*x[..., i]),
23 | ...
24 | cos(f_N * x[..., i]),
25 | x[..., i] # only present if include_input is True.
26 | ], here f_i is the frequency.
27 |
28 |     The frequencies are indexed by i = 0, 1, ..., num_freqs - 1.
29 |     If logspace is True, the frequency f_i is 2^i, i.e. [1, 2, 4, ..., 2^(num_freqs - 1)];
30 |     otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
31 |
32 | Args:
33 | num_freqs (int): the number of frequencies, default is 6;
34 |         logspace (bool): if True, the frequencies are powers of two, [1, 2, ..., 2^(num_freqs - 1)];
35 |             otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
36 | input_dim (int): the input dimension, default is 3;
37 | include_input (bool): include the input tensor or not, default is True.
38 |
39 | Attributes:
40 |         frequencies (torch.Tensor): the frequency tensor; powers of two if logspace is True,
41 |             otherwise linearly spaced between 1.0 and 2^(num_freqs - 1);
42 |
43 | out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
44 | otherwise, it is input_dim * num_freqs * 2.
45 |
46 | """
47 |
48 | def __init__(self,
49 | num_freqs: int = 6,
50 | logspace: bool = True,
51 | input_dim: int = 3,
52 | include_input: bool = True,
53 | include_pi: bool = True) -> None:
54 |
55 | """The initialization"""
56 |
57 | super().__init__()
58 |
59 | if logspace:
60 | frequencies = 2.0 ** torch.arange(
61 | num_freqs,
62 | dtype=torch.float32
63 | )
64 | else:
65 | frequencies = torch.linspace(
66 | 1.0,
67 | 2.0 ** (num_freqs - 1),
68 | num_freqs,
69 | dtype=torch.float32
70 | )
71 |
72 | if include_pi:
73 | frequencies *= torch.pi
74 |
75 | self.register_buffer("frequencies", frequencies, persistent=False)
76 | self.include_input = include_input
77 | self.num_freqs = num_freqs
78 |
79 | self.out_dim = self.get_dims(input_dim)
80 |
81 | def get_dims(self, input_dim):
82 | temp = 1 if self.include_input or self.num_freqs == 0 else 0
83 | out_dim = input_dim * (self.num_freqs * 2 + temp)
84 |
85 | return out_dim
86 |
87 | def forward(self, x: torch.Tensor) -> torch.Tensor:
88 | """ Forward process.
89 |
90 | Args:
91 | x: tensor of shape [..., dim]
92 |
93 | Returns:
94 | embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
95 | where temp is 1 if include_input is True and 0 otherwise.
96 | """
97 |
98 | if self.num_freqs > 0:
99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
100 | if self.include_input:
101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
102 | else:
103 | return torch.cat((embed.sin(), embed.cos()), dim=-1)
104 | else:
105 | return x
106 |
107 |
108 | class LearnedFourierEmbedder(nn.Module):
109 |     """Following @crowsonkb's lead with learned sinusoidal positional embeddings.
110 |     See https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
111 |
112 | def __init__(self, in_channels, dim):
113 | super().__init__()
114 | assert (dim % 2) == 0
115 | half_dim = dim // 2
116 | per_channel_dim = half_dim // in_channels
117 | self.weights = nn.Parameter(torch.randn(per_channel_dim))
118 |
119 | def forward(self, x):
120 | """
121 |
122 | Args:
123 | x (torch.FloatTensor): [..., c]
124 |
125 | Returns:
126 | x (torch.FloatTensor): [..., d]
127 | """
128 |
129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
132 | return fouriered
133 |
134 |
135 | class TriplaneLearnedFourierEmbedder(nn.Module):
136 | def __init__(self, in_channels, dim):
137 | super().__init__()
138 |
139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
142 |
143 | self.out_dim = in_channels + dim
144 |
145 | def forward(self, x):
146 |
147 | yz_embed = self.yz_plane_embedder(x)
148 | xz_embed = self.xz_plane_embedder(x)
149 | xy_embed = self.xy_plane_embedder(x)
150 |
151 | embed = yz_embed + xz_embed + xy_embed
152 |
153 | return embed
154 |
155 |
156 | def sequential_pos_embed(num_len, embed_dim):
157 | assert embed_dim % 2 == 0
158 |
159 | pos = torch.arange(num_len, dtype=torch.float32)
160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32)
161 | omega /= embed_dim / 2.
162 | omega = 1. / 10000 ** omega # (D/2,)
163 |
164 | pos = pos.reshape(-1) # (M,)
165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
166 |
167 | emb_sin = torch.sin(out) # (M, D/2)
168 | emb_cos = torch.cos(out) # (M, D/2)
169 |
170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
171 |
172 | return embeddings
173 |
174 |
175 | def timestep_embedding(timesteps, dim, max_period=10000):
176 | """
177 | Create sinusoidal timestep embeddings.
178 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
179 | These may be fractional.
180 | :param dim: the dimension of the output.
181 | :param max_period: controls the minimum frequency of the embeddings.
182 | :return: an [N x dim] Tensor of positional embeddings.
183 | """
184 | half = dim // 2
185 | freqs = torch.exp(
186 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
187 | ).to(device=timesteps.device)
188 | args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
189 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190 | if dim % 2:
191 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
192 | return embedding
193 |
194 |
195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
196 | num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
197 | log2_hashmap_size=19, desired_resolution=None):
198 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
199 | return nn.Identity(), input_dim
200 |
201 | elif embed_type == "fourier":
202 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
203 | logspace=True, include_input=True)
204 | return embedder_obj, embedder_obj.out_dim
205 |
206 | elif embed_type == "hashgrid":
207 | raise NotImplementedError
208 |
209 | elif embed_type == "sphere_harmonic":
210 | raise NotImplementedError
211 |
212 | else:
213 |         raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
214 |
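215 | # Illustrative usage sketch of the Fourier embedder (sizes are arbitrary):
216 | #
217 | #   embedder, out_dim = get_embedder("fourier", num_freqs=8, input_dim=3)
218 | #   xyz = torch.rand(1024, 3)
219 | #   feats = embedder(xyz)   # [1024, out_dim] with out_dim = 3 * (8 * 2 + 1) = 51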
--------------------------------------------------------------------------------
/michelangelo/models/modules/transformer_blocks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from typing import Optional
8 |
9 | from michelangelo.models.modules.checkpoint import checkpoint
10 |
11 |
12 | def init_linear(l, stddev):
13 | nn.init.normal_(l.weight, std=stddev)
14 | if l.bias is not None:
15 | nn.init.constant_(l.bias, 0.0)
16 |
17 |
18 | class MultiheadAttention(nn.Module):
19 | def __init__(
20 | self,
21 | *,
22 | device: torch.device,
23 | dtype: torch.dtype,
24 | n_ctx: int,
25 | width: int,
26 | heads: int,
27 | init_scale: float,
28 | qkv_bias: bool,
29 | flash: bool = False
30 | ):
31 | super().__init__()
32 | self.n_ctx = n_ctx
33 | self.width = width
34 | self.heads = heads
35 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
36 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
37 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
38 | init_linear(self.c_qkv, init_scale)
39 | init_linear(self.c_proj, init_scale)
40 |
41 | def forward(self, x):
42 | x = self.c_qkv(x)
43 | x = checkpoint(self.attention, (x,), (), True)
44 | x = self.c_proj(x)
45 | return x
46 |
47 |
48 | class QKVMultiheadAttention(nn.Module):
49 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
50 | super().__init__()
51 | self.device = device
52 | self.dtype = dtype
53 | self.heads = heads
54 | self.n_ctx = n_ctx
55 | self.flash = flash
56 |
57 | def forward(self, qkv):
58 | bs, n_ctx, width = qkv.shape
59 | attn_ch = width // self.heads // 3
60 | scale = 1 / math.sqrt(math.sqrt(attn_ch))
61 | qkv = qkv.view(bs, n_ctx, self.heads, -1)
62 | q, k, v = torch.split(qkv, attn_ch, dim=-1)
63 |
64 | if self.flash:
65 |             out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs, n_ctx, -1)
66 | else:
67 | weight = torch.einsum(
68 | "bthc,bshc->bhts", q * scale, k * scale
69 | ) # More stable with f16 than dividing afterwards
70 | wdtype = weight.dtype
71 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
72 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
73 |
74 | return out
75 |
76 |
77 | class ResidualAttentionBlock(nn.Module):
78 | def __init__(
79 | self,
80 | *,
81 | device: torch.device,
82 | dtype: torch.dtype,
83 | n_ctx: int,
84 | width: int,
85 | heads: int,
86 | init_scale: float = 1.0,
87 | qkv_bias: bool = True,
88 | flash: bool = False,
89 | use_checkpoint: bool = False
90 | ):
91 | super().__init__()
92 |
93 | self.use_checkpoint = use_checkpoint
94 |
95 | self.attn = MultiheadAttention(
96 | device=device,
97 | dtype=dtype,
98 | n_ctx=n_ctx,
99 | width=width,
100 | heads=heads,
101 | init_scale=init_scale,
102 | qkv_bias=qkv_bias,
103 | flash=flash
104 | )
105 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
106 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
107 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
108 |
109 | def _forward(self, x: torch.Tensor):
110 | x = x + self.attn(self.ln_1(x))
111 | x = x + self.mlp(self.ln_2(x))
112 | return x
113 |
114 | def forward(self, x: torch.Tensor):
115 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
116 |
117 |
118 | class MultiheadCrossAttention(nn.Module):
119 | def __init__(
120 | self,
121 | *,
122 | device: torch.device,
123 | dtype: torch.dtype,
124 | width: int,
125 | heads: int,
126 | init_scale: float,
127 | qkv_bias: bool = True,
128 | flash: bool = False,
129 | n_data: Optional[int] = None,
130 | data_width: Optional[int] = None,
131 | ):
132 | super().__init__()
133 | self.n_data = n_data
134 | self.width = width
135 | self.heads = heads
136 | self.data_width = width if data_width is None else data_width
137 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
138 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
139 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
140 | self.attention = QKVMultiheadCrossAttention(
141 | device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
142 | )
143 | init_linear(self.c_q, init_scale)
144 | init_linear(self.c_kv, init_scale)
145 | init_linear(self.c_proj, init_scale)
146 |
147 | def forward(self, x, data):
148 | x = self.c_q(x)
149 | data = self.c_kv(data)
150 | x = checkpoint(self.attention, (x, data), (), True)
151 | x = self.c_proj(x)
152 | return x
153 |
154 |
155 | class QKVMultiheadCrossAttention(nn.Module):
156 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
157 | flash: bool = False, n_data: Optional[int] = None):
158 |
159 | super().__init__()
160 | self.device = device
161 | self.dtype = dtype
162 | self.heads = heads
163 | self.n_data = n_data
164 | self.flash = flash
165 |
166 | def forward(self, q, kv):
167 | _, n_ctx, _ = q.shape
168 | bs, n_data, width = kv.shape
169 | attn_ch = width // self.heads // 2
170 | scale = 1 / math.sqrt(math.sqrt(attn_ch))
171 | q = q.view(bs, n_ctx, self.heads, -1)
172 | kv = kv.view(bs, n_data, self.heads, -1)
173 | k, v = torch.split(kv, attn_ch, dim=-1)
174 |
175 | if self.flash:
176 |             out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs, n_ctx, -1)
177 | else:
178 | weight = torch.einsum(
179 | "bthc,bshc->bhts", q * scale, k * scale
180 | ) # More stable with f16 than dividing afterwards
181 | wdtype = weight.dtype
182 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
183 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
184 |
185 | return out
186 |
187 |
188 | class ResidualCrossAttentionBlock(nn.Module):
189 | def __init__(
190 | self,
191 | *,
192 | device: Optional[torch.device],
193 | dtype: Optional[torch.dtype],
194 | n_data: Optional[int] = None,
195 | width: int,
196 | heads: int,
197 | data_width: Optional[int] = None,
198 | init_scale: float = 0.25,
199 | qkv_bias: bool = True,
200 | flash: bool = False
201 | ):
202 | super().__init__()
203 |
204 | if data_width is None:
205 | data_width = width
206 |
207 | self.attn = MultiheadCrossAttention(
208 | device=device,
209 | dtype=dtype,
210 | n_data=n_data,
211 | width=width,
212 | heads=heads,
213 | data_width=data_width,
214 | init_scale=init_scale,
215 | qkv_bias=qkv_bias,
216 | flash=flash,
217 | )
218 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
219 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
220 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
221 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
222 |
223 | def forward(self, x: torch.Tensor, data: torch.Tensor):
224 | x = x + self.attn(self.ln_1(x), self.ln_2(data))
225 | x = x + self.mlp(self.ln_3(x))
226 | return x
227 |
228 |
229 | class MLP(nn.Module):
230 | def __init__(self, *,
231 | device: Optional[torch.device],
232 | dtype: Optional[torch.dtype],
233 | width: int,
234 | init_scale: float):
235 | super().__init__()
236 | self.width = width
237 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
238 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
239 | self.gelu = nn.GELU()
240 | init_linear(self.c_fc, init_scale)
241 | init_linear(self.c_proj, init_scale)
242 |
243 | def forward(self, x):
244 | return self.c_proj(self.gelu(self.c_fc(x)))
245 |
246 |
247 | class Transformer(nn.Module):
248 | def __init__(
249 | self,
250 | *,
251 | device: Optional[torch.device],
252 | dtype: Optional[torch.dtype],
253 | n_ctx: int,
254 | width: int,
255 | layers: int,
256 | heads: int,
257 | init_scale: float = 0.25,
258 | qkv_bias: bool = True,
259 | flash: bool = False,
260 | use_checkpoint: bool = False
261 | ):
262 | super().__init__()
263 | self.n_ctx = n_ctx
264 | self.width = width
265 | self.layers = layers
266 | self.resblocks = nn.ModuleList(
267 | [
268 | ResidualAttentionBlock(
269 | device=device,
270 | dtype=dtype,
271 | n_ctx=n_ctx,
272 | width=width,
273 | heads=heads,
274 | init_scale=init_scale,
275 | qkv_bias=qkv_bias,
276 | flash=flash,
277 | use_checkpoint=use_checkpoint
278 | )
279 | for _ in range(layers)
280 | ]
281 | )
282 |
283 | def forward(self, x: torch.Tensor):
284 | for block in self.resblocks:
285 | x = block(x)
286 | return x
287 |
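288 | # Illustrative usage sketch (the hyper-parameters are assumptions, not the
289 | # repo's configuration):
290 | #
291 | #   blocks = Transformer(device=None, dtype=None, n_ctx=257,
292 | #                        width=768, layers=12, heads=12)
293 | #   tokens = torch.randn(2, 257, 768)
294 | #   out = blocks(tokens)    # [2, 257, 768]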
--------------------------------------------------------------------------------
/michelangelo/models/modules/transformer_vit.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from typing import Optional
7 | import warnings
8 |
9 | from michelangelo.models.modules.checkpoint import checkpoint
10 |
11 |
12 | def _trunc_normal_(tensor, mean, std, a, b):
13 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
14 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
15 | def norm_cdf(x):
16 | # Computes standard normal cumulative distribution function
17 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
18 |
19 | if (mean < a - 2 * std) or (mean > b + 2 * std):
20 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
21 | "The distribution of values may be incorrect.",
22 | stacklevel=2)
23 |
24 | # Values are generated by using a truncated uniform distribution and
25 | # then using the inverse CDF for the normal distribution.
26 | # Get upper and lower cdf values
27 | l = norm_cdf((a - mean) / std)
28 | u = norm_cdf((b - mean) / std)
29 |
30 | # Uniformly fill tensor with values from [l, u], then translate to
31 | # [2l-1, 2u-1].
32 | tensor.uniform_(2 * l - 1, 2 * u - 1)
33 |
34 | # Use inverse cdf transform for normal distribution to get truncated
35 | # standard normal
36 | tensor.erfinv_()
37 |
38 | # Transform to proper mean, std
39 | tensor.mul_(std * math.sqrt(2.))
40 | tensor.add_(mean)
41 |
42 | # Clamp to ensure it's in the proper range
43 | tensor.clamp_(min=a, max=b)
44 | return tensor
45 |
46 |
47 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
48 | # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor
49 | r"""Fills the input Tensor with values drawn from a truncated
50 | normal distribution. The values are effectively drawn from the
51 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
52 | with values outside :math:`[a, b]` redrawn until they are within
53 | the bounds. The method used for generating the random values works
54 | best when :math:`a \leq \text{mean} \leq b`.
55 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
56 | applied while sampling the normal with mean/std applied, therefore a, b args
57 | should be adjusted to match the range of mean, std args.
58 | Args:
59 | tensor: an n-dimensional `torch.Tensor`
60 | mean: the mean of the normal distribution
61 | std: the standard deviation of the normal distribution
62 | a: the minimum cutoff value
63 | b: the maximum cutoff value
64 | Examples:
65 | >>> w = torch.empty(3, 5)
66 | >>> nn.init.trunc_normal_(w)
67 | """
68 | with torch.no_grad():
69 | return _trunc_normal_(tensor, mean, std, a, b)
70 |
71 |
72 | def init_weights(m):
73 | if isinstance(m, nn.Linear):
74 | trunc_normal_(m.weight, std=.02)
75 | if isinstance(m, nn.Linear) and m.bias is not None:
76 | nn.init.constant_(m.bias, 0)
77 | elif isinstance(m, nn.LayerNorm):
78 | nn.init.constant_(m.bias, 0)
79 | nn.init.constant_(m.weight, 1.0)
80 |
81 |
82 | class MultiheadAttention(nn.Module):
83 | def __init__(
84 | self,
85 | *,
86 | device: torch.device,
87 | dtype: torch.dtype,
88 | n_ctx: int,
89 | width: int,
90 | heads: int,
91 | qkv_bias: bool
92 | ):
93 | super().__init__()
94 | self.n_ctx = n_ctx
95 | self.width = width
96 | self.heads = heads
97 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
98 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
99 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
100 |
101 | def forward(self, x):
102 | x = self.c_qkv(x)
103 | x = checkpoint(self.attention, (x,), (), True)
104 | x = self.c_proj(x)
105 | return x
106 |
107 |
108 | class QKVMultiheadAttention(nn.Module):
109 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
110 | super().__init__()
111 | self.device = device
112 | self.dtype = dtype
113 | self.heads = heads
114 | self.n_ctx = n_ctx
115 |
116 | def forward(self, qkv):
117 | bs, n_ctx, width = qkv.shape
118 | attn_ch = width // self.heads // 3
119 | scale = 1 / math.sqrt(attn_ch)
120 | qkv = qkv.view(bs, n_ctx, self.heads, -1)
121 | q, k, v = torch.split(qkv, attn_ch, dim=-1)
122 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
123 | wdtype = weight.dtype
124 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
125 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
126 |
127 |
128 | class ResidualAttentionBlock(nn.Module):
129 | def __init__(
130 | self,
131 | *,
132 | device: torch.device,
133 | dtype: torch.dtype,
134 | n_ctx: int,
135 | width: int,
136 | heads: int,
137 | qkv_bias: bool = True,
138 | use_checkpoint: bool = False
139 | ):
140 | super().__init__()
141 |
142 | self.use_checkpoint = use_checkpoint
143 |
144 | self.attn = MultiheadAttention(
145 | device=device,
146 | dtype=dtype,
147 | n_ctx=n_ctx,
148 | width=width,
149 | heads=heads,
150 | qkv_bias=qkv_bias
151 | )
152 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
153 | self.mlp = MLP(device=device, dtype=dtype, width=width)
154 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
155 |
156 | def _forward(self, x: torch.Tensor):
157 | x = x + self.attn(self.ln_1(x))
158 | x = x + self.mlp(self.ln_2(x))
159 | return x
160 |
161 | def forward(self, x: torch.Tensor):
162 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
163 |
164 |
165 | class MultiheadCrossAttention(nn.Module):
166 | def __init__(
167 | self,
168 | *,
169 | device: torch.device,
170 | dtype: torch.dtype,
171 | width: int,
172 | heads: int,
173 | qkv_bias: bool = True,
174 | n_data: Optional[int] = None,
175 | data_width: Optional[int] = None,
176 | ):
177 | super().__init__()
178 | self.n_data = n_data
179 | self.width = width
180 | self.heads = heads
181 | self.data_width = width if data_width is None else data_width
182 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
183 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
184 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
185 | self.attention = QKVMultiheadCrossAttention(
186 | device=device, dtype=dtype, heads=heads, n_data=n_data
187 | )
188 |
189 | def forward(self, x, data):
190 | x = self.c_q(x)
191 | data = self.c_kv(data)
192 | x = checkpoint(self.attention, (x, data), (), True)
193 | x = self.c_proj(x)
194 | return x
195 |
196 |
197 | class QKVMultiheadCrossAttention(nn.Module):
198 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None):
199 | super().__init__()
200 | self.device = device
201 | self.dtype = dtype
202 | self.heads = heads
203 | self.n_data = n_data
204 |
205 | def forward(self, q, kv):
206 | _, n_ctx, _ = q.shape
207 | bs, n_data, width = kv.shape
208 | attn_ch = width // self.heads // 2
209 | scale = 1 / math.sqrt(attn_ch)
210 | q = q.view(bs, n_ctx, self.heads, -1)
211 | kv = kv.view(bs, n_data, self.heads, -1)
212 | k, v = torch.split(kv, attn_ch, dim=-1)
213 | weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
214 | wdtype = weight.dtype
215 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
216 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
217 |
218 |
219 | class ResidualCrossAttentionBlock(nn.Module):
220 | def __init__(
221 | self,
222 | *,
223 | device: Optional[torch.device],
224 | dtype: Optional[torch.dtype],
225 | n_data: Optional[int] = None,
226 | width: int,
227 | heads: int,
228 | data_width: Optional[int] = None,
229 | qkv_bias: bool = True
230 | ):
231 | super().__init__()
232 |
233 | if data_width is None:
234 | data_width = width
235 |
236 | self.attn = MultiheadCrossAttention(
237 | device=device,
238 | dtype=dtype,
239 | n_data=n_data,
240 | width=width,
241 | heads=heads,
242 | data_width=data_width,
243 | qkv_bias=qkv_bias
244 | )
245 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
246 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
247 | self.mlp = MLP(device=device, dtype=dtype, width=width)
248 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
249 |
250 | def forward(self, x: torch.Tensor, data: torch.Tensor):
251 | x = x + self.attn(self.ln_1(x), self.ln_2(data))
252 | x = x + self.mlp(self.ln_3(x))
253 | return x
254 |
255 |
256 | class MLP(nn.Module):
257 | def __init__(self, *,
258 | device: Optional[torch.device],
259 | dtype: Optional[torch.dtype],
260 | width: int):
261 | super().__init__()
262 | self.width = width
263 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
264 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
265 | self.gelu = nn.GELU()
266 |
267 | def forward(self, x):
268 | return self.c_proj(self.gelu(self.c_fc(x)))
269 |
270 |
271 | class Transformer(nn.Module):
272 | def __init__(
273 | self,
274 | *,
275 | device: Optional[torch.device],
276 | dtype: Optional[torch.dtype],
277 | n_ctx: int,
278 | width: int,
279 | layers: int,
280 | heads: int,
281 | qkv_bias: bool = True,
282 | use_checkpoint: bool = False
283 | ):
284 | super().__init__()
285 | self.n_ctx = n_ctx
286 | self.width = width
287 | self.layers = layers
288 | self.resblocks = nn.ModuleList(
289 | [
290 | ResidualAttentionBlock(
291 | device=device,
292 | dtype=dtype,
293 | n_ctx=n_ctx,
294 | width=width,
295 | heads=heads,
296 | qkv_bias=qkv_bias,
297 | use_checkpoint=use_checkpoint
298 | )
299 | for _ in range(layers)
300 | ]
301 | )
302 |
303 | self.apply(init_weights)
304 |
305 | def forward(self, x: torch.Tensor):
306 | for block in self.resblocks:
307 | x = block(x)
308 | return x
309 |
--------------------------------------------------------------------------------
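The file above is a self-contained pre-LN transformer stack, so it can be smoke-tested in isolation. Below is a minimal sketch with illustrative hyperparameters (the released configs use much larger widths and depths); it only checks that shapes are preserved through the residual attention and MLP blocks.

```python
# Minimal sketch (illustrative hyperparameters): run a random sequence through the
# ViT-style pre-LN transformer defined in michelangelo/models/modules/transformer_vit.py.
import torch
from michelangelo.models.modules.transformer_vit import Transformer

model = Transformer(
    device=None,   # let nn.Linear / nn.LayerNorm pick the default device
    dtype=None,
    n_ctx=16,      # sequence length
    width=64,      # embedding width
    layers=2,
    heads=4,
)

x = torch.randn(2, 16, 64)   # [batch, n_ctx, width]
y = model(x)                  # residual attention + MLP blocks, shape preserved
print(y.shape)                # torch.Size([2, 16, 64])
```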
/michelangelo/models/tsal/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/michelangelo/models/tsal/asl_pl_module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import List, Tuple, Dict, Optional
4 | from omegaconf import DictConfig
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.optim import lr_scheduler
9 | import pytorch_lightning as pl
10 | from typing import Union
11 | from functools import partial
12 |
13 | from michelangelo.utils import instantiate_from_config
14 |
15 | from .inference_utils import extract_geometry
16 | from .tsal_base import (
17 | AlignedShapeAsLatentModule,
18 | ShapeAsLatentModule,
19 | Latent2MeshOutput,
20 | AlignedMeshOutput
21 | )
22 |
23 |
24 | class AlignedShapeAsLatentPLModule(pl.LightningModule):
25 |
26 | def __init__(self, *,
27 | shape_module_cfg,
28 | aligned_module_cfg,
29 | loss_cfg,
30 | optimizer_cfg: Optional[DictConfig] = None,
31 | ckpt_path: Optional[str] = None,
32 | ignore_keys: Union[Tuple[str], List[str]] = ()):
33 |
34 | super().__init__()
35 |
36 | shape_model: ShapeAsLatentModule = instantiate_from_config(
37 | shape_module_cfg, device=None, dtype=None
38 | )
39 | self.model: AlignedShapeAsLatentModule = instantiate_from_config(
40 | aligned_module_cfg, shape_model=shape_model
41 | )
42 |
43 | self.loss = instantiate_from_config(loss_cfg)
44 |
45 | self.optimizer_cfg = optimizer_cfg
46 |
47 | if ckpt_path is not None:
48 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
49 |
50 | self.save_hyperparameters()
51 |
52 | def set_shape_model_only(self):
53 | self.model.set_shape_model_only()
54 |
55 | @property
56 | def latent_shape(self):
57 | return self.model.shape_model.latent_shape
58 |
59 | @property
60 | def zero_rank(self):
61 | if self._trainer:
62 | zero_rank = self.trainer.local_rank == 0
63 | else:
64 | zero_rank = True
65 |
66 | return zero_rank
67 |
68 | def init_from_ckpt(self, path, ignore_keys=()):
69 | state_dict = torch.load(path, map_location="cpu")["state_dict"]
70 |
71 | keys = list(state_dict.keys())
72 | for k in keys:
73 | for ik in ignore_keys:
74 | if k.startswith(ik):
75 | print("Deleting key {} from state_dict.".format(k))
76 | del state_dict[k]
77 |
78 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
79 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
80 | if len(missing) > 0:
81 | print(f"Missing Keys: {missing}")
82 | print(f"Unexpected Keys: {unexpected}")
83 |
84 | def configure_optimizers(self) -> Tuple[List, List]:
85 | lr = self.learning_rate
86 |
87 | trainable_parameters = list(self.model.parameters())
88 |
89 | if self.optimizer_cfg is None:
90 | optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
91 | schedulers = []
92 | else:
93 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
94 | scheduler_func = instantiate_from_config(
95 | self.optimizer_cfg.scheduler,
96 | max_decay_steps=self.trainer.max_steps,
97 | lr_max=lr
98 | )
99 | scheduler = {
100 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
101 | "interval": "step",
102 | "frequency": 1
103 | }
104 | optimizers = [optimizer]
105 | schedulers = [scheduler]
106 |
107 | return optimizers, schedulers
108 |
109 | def forward(self,
110 | surface: torch.FloatTensor,
111 | image: torch.FloatTensor,
112 | text: torch.FloatTensor,
113 | volume_queries: torch.FloatTensor):
114 |
115 | """
116 |
117 | Args:
118 |             surface (torch.FloatTensor): [bs, n_surface, 3 + c]
119 |             image (torch.FloatTensor): [bs, 3, 224, 224]
120 |             text (torch.FloatTensor): [bs, num_templates, 77]
121 |             volume_queries (torch.FloatTensor): [bs, n_pts, 3]
122 | 
123 |         Returns:
124 |             embed_outputs (dict), logits (torch.FloatTensor): [bs, n_pts], posterior (DiagonalGaussianDistribution or None)
125 | """
126 |
127 | embed_outputs, shape_z = self.model(surface, image, text)
128 |
129 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
130 | latents = self.model.shape_model.decode(shape_zq)
131 | logits = self.model.shape_model.query_geometry(volume_queries, latents)
132 |
133 | return embed_outputs, logits, posterior
134 |
135 | def encode(self, surface: torch.FloatTensor, sample_posterior=True):
136 |
137 | pc = surface[..., 0:3]
138 | feats = surface[..., 3:6]
139 |
140 | shape_embed, shape_zq, posterior = self.model.shape_model.encode(
141 | pc=pc, feats=feats, sample_posterior=sample_posterior
142 | )
143 |
144 | return shape_zq
145 |
146 | def decode(self,
147 | z_q,
148 | bounds: Union[Tuple[float], List[float], float] = 1.1,
149 | octree_depth: int = 7,
150 | num_chunks: int = 10000) -> List[Latent2MeshOutput]:
151 |
152 | latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim]
153 | outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
154 |
155 | return outputs
156 |
157 | def training_step(self, batch: Dict[str, torch.FloatTensor],
158 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
159 | """
160 |
161 | Args:
162 | batch (dict): the batch sample, and it contains:
163 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
164 | - image (torch.FloatTensor): [bs, 3, 224, 224]
165 | - text (torch.FloatTensor): [bs, num_templates, 77]
166 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
167 |
168 | batch_idx (int):
169 |
170 | optimizer_idx (int):
171 |
172 | Returns:
173 | loss (torch.FloatTensor):
174 |
175 | """
176 |
177 | surface = batch["surface"]
178 | image = batch["image"]
179 | text = batch["text"]
180 |
181 | volume_queries = batch["geo_points"][..., 0:3]
182 | shape_labels = batch["geo_points"][..., -1]
183 |
184 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
185 |
186 | aeloss, log_dict_ae = self.loss(
187 | **embed_outputs,
188 | posteriors=posteriors,
189 | shape_logits=shape_logits,
190 | shape_labels=shape_labels,
191 | split="train"
192 | )
193 |
194 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
195 | sync_dist=False, rank_zero_only=True)
196 |
197 | return aeloss
198 |
199 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
200 |
201 | surface = batch["surface"]
202 | image = batch["image"]
203 | text = batch["text"]
204 |
205 | volume_queries = batch["geo_points"][..., 0:3]
206 | shape_labels = batch["geo_points"][..., -1]
207 |
208 | embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
209 |
210 | aeloss, log_dict_ae = self.loss(
211 | **embed_outputs,
212 | posteriors=posteriors,
213 | shape_logits=shape_logits,
214 | shape_labels=shape_labels,
215 | split="val"
216 | )
217 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
218 | sync_dist=False, rank_zero_only=True)
219 |
220 | return aeloss
221 |
222 | def visual_alignment(self,
223 | surface: torch.FloatTensor,
224 | image: torch.FloatTensor,
225 | text: torch.FloatTensor,
226 | description: Optional[List[str]] = None,
227 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
228 | octree_depth: int = 7,
229 | num_chunks: int = 10000) -> List[AlignedMeshOutput]:
230 |
231 | """
232 |
233 | Args:
234 |             surface (torch.FloatTensor): [bs, n_surface, 3 + c]
235 |             image (torch.FloatTensor): [bs, 3, 224, 224]
236 |             text (torch.FloatTensor): [bs, num_templates, 77]
237 |             description (List[str] or None): optional captions stored on the outputs
238 |             bounds: axis-aligned bounding box of the query grid
239 |             octree_depth: resolution of the dense sampling grid
240 |             num_chunks: number of query points evaluated per chunk
241 |
242 | Returns:
243 | mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.
244 |
245 | """
246 |
247 | outputs = []
248 |
249 | device = surface.device
250 | bs = surface.shape[0]
251 |
252 | embed_outputs, shape_z = self.model(surface, image, text)
253 |
254 | # calculate the similarity
255 | image_embed = embed_outputs["image_embed"]
256 | text_embed = embed_outputs["text_embed"]
257 | shape_embed = embed_outputs["shape_embed"]
258 |
259 | # normalized features
260 | shape_embed = F.normalize(shape_embed, dim=-1, p=2)
261 | text_embed = F.normalize(text_embed, dim=-1, p=2)
262 | image_embed = F.normalize(image_embed, dim=-1, p=2)
263 |
264 | # B x B
265 | shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)
266 |
267 | # B x B
268 | shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)
269 |
270 | # shape reconstruction
271 | shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
272 | latents = self.model.shape_model.decode(shape_zq)
273 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
274 |
275 | # 2. decode geometry
276 | mesh_v_f, has_surface = extract_geometry(
277 | geometric_func=geometric_func,
278 | device=device,
279 | batch_size=bs,
280 | bounds=bounds,
281 | octree_depth=octree_depth,
282 | num_chunks=num_chunks,
283 | disable=not self.zero_rank
284 | )
285 |
286 |         # 3. assemble aligned mesh outputs
287 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
288 | if not is_surface:
289 | outputs.append(None)
290 | continue
291 |
292 | out = AlignedMeshOutput()
293 | out.mesh_v = mesh_v
294 | out.mesh_f = mesh_f
295 | out.surface = surface[i].cpu().numpy()
296 | out.image = image[i].cpu().numpy()
297 | if description is not None:
298 | out.text = description[i]
299 | out.shape_text_similarity = shape_text_similarity[i, i]
300 | out.shape_image_similarity = shape_image_similarity[i, i]
301 |
302 | outputs.append(out)
303 |
304 | return outputs
305 |
306 | def latent2mesh(self,
307 | latents: torch.FloatTensor,
308 | bounds: Union[Tuple[float], List[float], float] = 1.1,
309 | octree_depth: int = 7,
310 | num_chunks: int = 10000) -> List[Latent2MeshOutput]:
311 |
312 | """
313 |
314 | Args:
315 | latents: [bs, num_latents, dim]
316 | bounds:
317 | octree_depth:
318 | num_chunks:
319 |
320 | Returns:
321 | mesh_outputs (List[MeshOutput]): the mesh outputs list.
322 |
323 | """
324 |
325 | outputs = []
326 |
327 | geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
328 |
329 | # 2. decode geometry
330 | device = latents.device
331 | mesh_v_f, has_surface = extract_geometry(
332 | geometric_func=geometric_func,
333 | device=device,
334 | batch_size=len(latents),
335 | bounds=bounds,
336 | octree_depth=octree_depth,
337 | num_chunks=num_chunks,
338 | disable=not self.zero_rank
339 | )
340 |
341 |         # 3. assemble mesh outputs
342 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
343 | if not is_surface:
344 | outputs.append(None)
345 | continue
346 |
347 | out = Latent2MeshOutput()
348 | out.mesh_v = mesh_v
349 | out.mesh_f = mesh_f
350 |
351 | outputs.append(out)
352 |
353 | return outputs
354 |
355 |
--------------------------------------------------------------------------------
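A minimal reconstruction sketch for the Lightning module above. The checkpoint path is a placeholder for a locally available trained SITA-VAE checkpoint, and the random surface tensor merely stands in for a real point-cloud sample such as example_data/surface/surface.npz.

```python
# Hedged reconstruction sketch for AlignedShapeAsLatentPLModule.
# The checkpoint path below is a placeholder; the random surface stands in for a real
# [bs, n_surface, xyz + feature] sample.
import torch
from michelangelo.models.tsal.asl_pl_module import AlignedShapeAsLatentPLModule

ckpt = "checkpoints/shapevae-256.ckpt"                    # placeholder path
model = AlignedShapeAsLatentPLModule.load_from_checkpoint(ckpt).eval()

surface = torch.randn(1, 4096, 6)                         # [bs, n_surface, 3 (xyz) + 3 (features)]

with torch.no_grad():
    z_q = model.encode(surface, sample_posterior=False)   # [bs, num_latents, embed_dim]
    meshes = model.decode(z_q, octree_depth=7)             # List[Latent2MeshOutput]

if meshes[0] is not None:
    print(meshes[0].mesh_v.shape, meshes[0].mesh_f.shape)  # reconstructed vertices / faces
```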
/michelangelo/models/tsal/clip_asl_module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from torch import nn
5 | from einops import rearrange
6 | from transformers import CLIPModel
7 |
8 | from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule
9 |
10 |
11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
12 |
13 | def __init__(self, *,
14 | shape_model,
15 | clip_model_version: str = "openai/clip-vit-large-patch14"):
16 |
17 | super().__init__()
18 |
19 | self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version)
20 | for params in self.clip_model.parameters():
21 | params.requires_grad = False
22 |
23 | self.shape_model = shape_model
24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.clip_model.projection_dim))
25 | nn.init.normal_(self.shape_projection, std=self.clip_model.projection_dim ** -0.5)
26 |
27 | def set_shape_model_only(self):
28 | self.clip_model = None
29 |
30 | def encode_shape_embed(self, surface, return_latents: bool = False):
31 | """
32 |
33 | Args:
34 | surface (torch.FloatTensor): [bs, n, 3 + c]
35 | return_latents (bool):
36 |
37 | Returns:
38 | x (torch.FloatTensor): [bs, projection_dim]
39 | shape_latents (torch.FloatTensor): [bs, m, d]
40 | """
41 |
42 | pc = surface[..., 0:3]
43 | feats = surface[..., 3:]
44 |
45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
46 | x = shape_embed @ self.shape_projection
47 |
48 | if return_latents:
49 | return x, shape_latents
50 | else:
51 | return x
52 |
53 | def encode_image_embed(self, image):
54 | """
55 |
56 | Args:
57 | image (torch.FloatTensor): [bs, 3, h, w]
58 |
59 | Returns:
60 | x (torch.FloatTensor): [bs, projection_dim]
61 | """
62 |
63 | x = self.clip_model.get_image_features(image)
64 |
65 | return x
66 |
67 | def encode_text_embed(self, text):
68 | x = self.clip_model.get_text_features(text)
69 | return x
70 |
71 | def forward(self, surface, image, text):
72 | """
73 |
74 | Args:
75 | surface (torch.FloatTensor):
76 | image (torch.FloatTensor): [bs, 3, 224, 224]
77 | text (torch.LongTensor): [bs, num_templates, 77]
78 |
79 | Returns:
80 | embed_outputs (dict): the embedding outputs, and it contains:
81 | - image_embed (torch.FloatTensor):
82 | - text_embed (torch.FloatTensor):
83 | - shape_embed (torch.FloatTensor):
84 | - logit_scale (float):
85 | """
86 |
87 | # # text embedding
88 | # text_embed_all = []
89 | # for i in range(text.shape[0]):
90 | # text_for_one_sample = text[i]
91 | # text_embed = self.encode_text_embed(text_for_one_sample)
92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
93 | # text_embed = text_embed.mean(dim=0)
94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
95 | # text_embed_all.append(text_embed)
96 | # text_embed_all = torch.stack(text_embed_all)
97 |
98 | b = text.shape[0]
99 | text_tokens = rearrange(text, "b t l -> (b t) l")
100 | text_embed = self.encode_text_embed(text_tokens)
101 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
102 | text_embed = text_embed.mean(dim=1)
103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
104 |
105 | # image embedding
106 | image_embed = self.encode_image_embed(image)
107 |
108 | # shape embedding
109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
110 |
111 | embed_outputs = {
112 | "image_embed": image_embed,
113 | "text_embed": text_embed,
114 | "shape_embed": shape_embed,
115 | "logit_scale": self.clip_model.logit_scale.exp()
116 | }
117 |
118 | return embed_outputs, shape_latents
119 |
--------------------------------------------------------------------------------
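For reference, the multi-template text pooling performed in CLIPAlignedShapeAsLatentModule.forward can be reproduced standalone with the Hugging Face CLIP classes. The two prompt templates below are illustrative stand-ins; the repository's actual templates live in michelangelo/data/templates.json.

```python
# Hedged standalone sketch of the multi-template text pooling in
# CLIPAlignedShapeAsLatentModule.forward. The templates below are illustrative only.
import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPTokenizer

version = "openai/clip-vit-large-patch14"
clip = CLIPModel.from_pretrained(version).eval()
tokenizer = CLIPTokenizer.from_pretrained(version)

caption = "a red sports car"
templates = ["a photo of {}.", "a 3D rendering of {}."]   # illustrative stand-ins
prompts = [t.format(caption) for t in templates]

tokens = tokenizer(prompts, padding="max_length", max_length=77,
                   truncation=True, return_tensors="pt").input_ids   # [num_templates, 77]

with torch.no_grad():
    text_embed = clip.get_text_features(tokens)      # [num_templates, projection_dim]

text_embed = text_embed.mean(dim=0, keepdim=True)    # average over templates
text_embed = F.normalize(text_embed, dim=-1)         # unit-norm, as in forward()
print(text_embed.shape)                               # torch.Size([1, 768]) for ViT-L/14
```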
/michelangelo/models/tsal/inference_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from tqdm import tqdm
5 | from einops import repeat
6 | import numpy as np
7 | from typing import Callable, Tuple, List, Union, Optional
8 | from skimage import measure
9 |
10 | from michelangelo.graphics.primitives import generate_dense_grid_points
11 |
12 |
13 | @torch.no_grad()
14 | def extract_geometry(geometric_func: Callable,
15 | device: torch.device,
16 | batch_size: int = 1,
17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
18 | octree_depth: int = 7,
19 | num_chunks: int = 10000,
20 | disable: bool = True):
21 | """
22 |
23 | Args:
 24 |         geometric_func: maps query points [b, p, 3] to occupancy logits [b, p]
 25 |         device: device on which the query chunks are evaluated
 26 |         bounds: scalar half-extent or (x_min, y_min, z_min, x_max, y_max, z_max)
 27 |         octree_depth: resolution of the dense sampling grid
 28 |         batch_size: number of shapes decoded in parallel
 29 |         num_chunks: number of query points evaluated per chunk
 30 |         disable: disables the tqdm progress bar
 31 | 
 32 |     Returns:
 33 |         mesh_v_f (list of (vertices, faces) tuples), has_surface (np.ndarray of bool)
34 | """
35 |
36 | if isinstance(bounds, float):
37 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
38 |
39 | bbox_min = np.array(bounds[0:3])
40 | bbox_max = np.array(bounds[3:6])
41 | bbox_size = bbox_max - bbox_min
42 |
43 | xyz_samples, grid_size, length = generate_dense_grid_points(
44 | bbox_min=bbox_min,
45 | bbox_max=bbox_max,
46 | octree_depth=octree_depth,
47 | indexing="ij"
48 | )
49 | xyz_samples = torch.FloatTensor(xyz_samples)
50 |
51 | batch_logits = []
52 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
53 | desc="Implicit Function:", disable=disable, leave=False):
54 | queries = xyz_samples[start: start + num_chunks, :].to(device)
55 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
56 |
57 | logits = geometric_func(batch_queries)
58 | batch_logits.append(logits.cpu())
59 |
60 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()
61 |
62 | mesh_v_f = []
63 | has_surface = np.zeros((batch_size,), dtype=np.bool_)
64 | for i in range(batch_size):
65 | try:
66 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
67 | vertices = vertices / grid_size * bbox_size + bbox_min
68 | # vertices[:, [0, 1]] = vertices[:, [1, 0]]
69 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
70 | has_surface[i] = True
71 |
72 | except ValueError:
73 | mesh_v_f.append((None, None))
74 | has_surface[i] = False
75 |
76 | except RuntimeError:
77 | mesh_v_f.append((None, None))
78 | has_surface[i] = False
79 |
80 | return mesh_v_f, has_surface
81 |
--------------------------------------------------------------------------------
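extract_geometry only needs a callable mapping query points to occupancy logits, so it can be exercised without a trained decoder. The sketch below uses a toy sphere occupancy (radius, bounds, and grid depth are illustrative) to run the dense-grid sampling and marching-cubes path end to end.

```python
# Hedged sketch: drive extract_geometry with a toy occupancy function (a sphere)
# instead of a trained decoder, just to exercise the dense-grid + marching-cubes path.
import torch
from michelangelo.models.tsal.inference_utils import extract_geometry

def sphere_logits(queries: torch.FloatTensor) -> torch.FloatTensor:
    # queries: [b, p, 3]; positive inside a sphere of radius 0.5, negative outside,
    # matching the logits >= 0 "inside" convention used with marching_cubes(level=0).
    return 0.5 - queries.norm(dim=-1)

mesh_v_f, has_surface = extract_geometry(
    geometric_func=sphere_logits,
    device=torch.device("cpu"),
    batch_size=1,
    bounds=1.0,          # scalar expands to (-1, -1, -1, 1, 1, 1)
    octree_depth=6,      # coarser than the default of 7, keeps the sweep quick
    num_chunks=65536,
)

vertices, faces = mesh_v_f[0]
print(has_surface[0], vertices.shape, faces.shape)
```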
/michelangelo/models/tsal/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from typing import Optional, Tuple, Dict
7 |
8 | from michelangelo.models.modules.distributions import DiagonalGaussianDistribution
9 | from michelangelo.utils.eval import compute_psnr
10 | from michelangelo.utils import misc
11 |
12 |
13 | class KLNearFar(nn.Module):
14 | def __init__(self,
15 | near_weight: float = 0.1,
16 | kl_weight: float = 1.0,
17 | num_near_samples: Optional[int] = None):
18 |
19 | super().__init__()
20 |
21 | self.near_weight = near_weight
22 | self.kl_weight = kl_weight
23 | self.num_near_samples = num_near_samples
24 | self.geo_criterion = nn.BCEWithLogitsLoss()
25 |
26 | def forward(self,
27 | posteriors: Optional[DiagonalGaussianDistribution],
28 | logits: torch.FloatTensor,
29 | labels: torch.FloatTensor,
30 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
31 |
32 | """
33 |
34 | Args:
35 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
36 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
37 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
38 | split (str):
39 | **kwargs:
40 |
41 | Returns:
42 | loss (torch.Tensor): (,)
43 | log (dict):
44 |
45 | """
46 |
47 | if self.num_near_samples is None:
48 | num_vol = logits.shape[1] // 2
49 | else:
50 | num_vol = logits.shape[1] - self.num_near_samples
51 |
52 | vol_logits = logits[:, 0:num_vol]
53 | vol_labels = labels[:, 0:num_vol]
54 |
55 | near_logits = logits[:, num_vol:]
56 | near_labels = labels[:, num_vol:]
57 |
58 | # occupancy loss
59 | # vol_bce = self.geo_criterion(vol_logits, vol_labels)
60 | # near_bce = self.geo_criterion(near_logits, near_labels)
61 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
62 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
63 |
64 | if posteriors is None:
65 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
66 | else:
67 | kl_loss = posteriors.kl(dims=(1, 2))
68 | kl_loss = torch.mean(kl_loss)
69 |
70 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight
71 |
72 | with torch.no_grad():
73 | preds = logits >= 0
74 | accuracy = (preds == labels).float()
75 | accuracy = accuracy.mean()
76 | pos_ratio = torch.mean(labels)
77 |
78 | log = {
79 | "{}/total_loss".format(split): loss.clone().detach(),
80 | "{}/near".format(split): near_bce.detach(),
81 | "{}/far".format(split): vol_bce.detach(),
82 | "{}/kl".format(split): kl_loss.detach(),
83 | "{}/accuracy".format(split): accuracy,
84 | "{}/pos_ratio".format(split): pos_ratio
85 | }
86 |
87 | if posteriors is not None:
88 | log[f"{split}/mean"] = posteriors.mean.mean().detach()
89 | log[f"{split}/std_mean"] = posteriors.std.mean().detach()
90 | log[f"{split}/std_max"] = posteriors.std.max().detach()
91 |
92 | return loss, log
93 |
94 |
95 | class KLNearFarColor(nn.Module):
96 | def __init__(self,
97 | near_weight: float = 0.1,
98 | kl_weight: float = 1.0,
99 | color_weight: float = 1.0,
100 | color_criterion: str = "mse",
101 | num_near_samples: Optional[int] = None):
102 |
103 | super().__init__()
104 |
105 | self.color_weight = color_weight
106 | self.near_weight = near_weight
107 | self.kl_weight = kl_weight
108 | self.num_near_samples = num_near_samples
109 |
110 | if color_criterion == "mse":
111 | self.color_criterion = nn.MSELoss()
112 |
113 | elif color_criterion == "l1":
114 | self.color_criterion = nn.L1Loss()
115 |
116 | else:
117 |             raise ValueError(f"color_criterion must be one of [`mse`, `l1`], got {color_criterion}.")
118 |
119 | self.geo_criterion = nn.BCEWithLogitsLoss()
120 |
121 | def forward(self,
122 | posteriors: Optional[DiagonalGaussianDistribution],
123 | logits: torch.FloatTensor,
124 | labels: torch.FloatTensor,
125 | pred_colors: torch.FloatTensor,
126 | gt_colors: torch.FloatTensor,
127 | split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
128 |
129 | """
130 |
131 | Args:
132 | posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
133 | logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
134 | labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
135 | pred_colors (torch.FloatTensor): [B, M, 3]
136 | gt_colors (torch.FloatTensor): [B, M, 3]
137 | split (str):
138 | **kwargs:
139 |
140 | Returns:
141 | loss (torch.Tensor): (,)
142 | log (dict):
143 |
144 | """
145 |
146 | if self.num_near_samples is None:
147 | num_vol = logits.shape[1] // 2
148 | else:
149 | num_vol = logits.shape[1] - self.num_near_samples
150 |
151 | vol_logits = logits[:, 0:num_vol]
152 | vol_labels = labels[:, 0:num_vol]
153 |
154 | near_logits = logits[:, num_vol:]
155 | near_labels = labels[:, num_vol:]
156 |
157 | # occupancy loss
158 | # vol_bce = self.geo_criterion(vol_logits, vol_labels)
159 | # near_bce = self.geo_criterion(near_logits, near_labels)
160 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
161 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
162 |
163 | # surface color loss
164 | color = self.color_criterion(pred_colors, gt_colors)
165 |
166 | if posteriors is None:
167 | kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
168 | else:
169 | kl_loss = posteriors.kl(dims=(1, 2))
170 | kl_loss = torch.mean(kl_loss)
171 |
172 | loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight
173 |
174 | with torch.no_grad():
175 | preds = logits >= 0
176 | accuracy = (preds == labels).float()
177 | accuracy = accuracy.mean()
178 | psnr = compute_psnr(pred_colors, gt_colors)
179 |
180 | log = {
181 | "{}/total_loss".format(split): loss.clone().detach(),
182 | "{}/near".format(split): near_bce.detach(),
183 | "{}/far".format(split): vol_bce.detach(),
184 | "{}/color".format(split): color.detach(),
185 | "{}/kl".format(split): kl_loss.detach(),
186 | "{}/psnr".format(split): psnr.detach(),
187 | "{}/accuracy".format(split): accuracy
188 | }
189 |
190 | return loss, log
191 |
192 |
193 | class ContrastKLNearFar(nn.Module):
194 | def __init__(self,
195 | contrast_weight: float = 1.0,
196 | near_weight: float = 0.1,
197 | kl_weight: float = 1.0,
198 | num_near_samples: Optional[int] = None):
199 |
200 | super().__init__()
201 |
202 | self.labels = None
203 | self.last_local_batch_size = None
204 |
205 | self.contrast_weight = contrast_weight
206 | self.near_weight = near_weight
207 | self.kl_weight = kl_weight
208 | self.num_near_samples = num_near_samples
209 | self.geo_criterion = nn.BCEWithLogitsLoss()
210 |
211 | def forward(self,
212 | shape_embed: torch.FloatTensor,
213 | text_embed: torch.FloatTensor,
214 | image_embed: torch.FloatTensor,
215 | logit_scale: torch.FloatTensor,
216 | posteriors: Optional[DiagonalGaussianDistribution],
217 | shape_logits: torch.FloatTensor,
218 | shape_labels: torch.FloatTensor,
219 | split: Optional[str] = "train", **kwargs):
220 |
221 | local_batch_size = shape_embed.size(0)
222 |
223 | if local_batch_size != self.last_local_batch_size:
224 | self.labels = local_batch_size * misc.get_rank() + torch.arange(
225 | local_batch_size, device=shape_embed.device
226 | ).long()
227 | self.last_local_batch_size = local_batch_size
228 |
229 | # normalized features
230 | shape_embed = F.normalize(shape_embed, dim=-1, p=2)
231 | text_embed = F.normalize(text_embed, dim=-1, p=2)
232 | image_embed = F.normalize(image_embed, dim=-1, p=2)
233 |
234 | # gather features from all GPUs
235 | shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
236 | [shape_embed, text_embed, image_embed]
237 | )
238 |
239 | # cosine similarity as logits
240 | logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
241 | logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
242 | logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
243 | logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
244 | contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
245 | F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
246 | (F.cross_entropy(logits_per_shape_image, self.labels) +
247 | F.cross_entropy(logits_per_image_shape, self.labels)) / 2
248 |
249 | # shape reconstruction
250 | if self.num_near_samples is None:
251 | num_vol = shape_logits.shape[1] // 2
252 | else:
253 | num_vol = shape_logits.shape[1] - self.num_near_samples
254 |
255 | vol_logits = shape_logits[:, 0:num_vol]
256 | vol_labels = shape_labels[:, 0:num_vol]
257 |
258 | near_logits = shape_logits[:, num_vol:]
259 | near_labels = shape_labels[:, num_vol:]
260 |
261 | # occupancy loss
262 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
263 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
264 |
265 | if posteriors is None:
266 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
267 | else:
268 | kl_loss = posteriors.kl(dims=(1, 2))
269 | kl_loss = torch.mean(kl_loss)
270 |
271 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight
272 |
273 | # compute accuracy
274 | with torch.no_grad():
275 | pred = torch.argmax(logits_per_shape_text, dim=-1)
276 | correct = pred.eq(self.labels).sum()
277 | shape_text_acc = 100 * correct / local_batch_size
278 |
279 | pred = torch.argmax(logits_per_shape_image, dim=-1)
280 | correct = pred.eq(self.labels).sum()
281 | shape_image_acc = 100 * correct / local_batch_size
282 |
283 | preds = shape_logits >= 0
284 | accuracy = (preds == shape_labels).float()
285 | accuracy = accuracy.mean()
286 |
287 | log = {
288 | "{}/contrast".format(split): contrast_loss.clone().detach(),
289 | "{}/near".format(split): near_bce.detach(),
290 | "{}/far".format(split): vol_bce.detach(),
291 | "{}/kl".format(split): kl_loss.detach(),
292 | "{}/shape_text_acc".format(split): shape_text_acc,
293 | "{}/shape_image_acc".format(split): shape_image_acc,
294 | "{}/total_loss".format(split): loss.clone().detach(),
295 | "{}/accuracy".format(split): accuracy,
296 | }
297 |
298 | if posteriors is not None:
299 | log[f"{split}/mean"] = posteriors.mean.mean().detach()
300 | log[f"{split}/std_mean"] = posteriors.std.mean().detach()
301 | log[f"{split}/std_max"] = posteriors.std.max().detach()
302 |
303 | return loss, log
304 |
--------------------------------------------------------------------------------
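A toy call to KLNearFar (hedged example with random tensors): passing posteriors=None drops the KL term, leaving only the BCE terms for the uniform "far" volume points and the near-surface points.

```python
# Hedged toy call: KLNearFar on random data, with posteriors=None so only the two
# BCE terms (uniform "far" points and near-surface points) contribute.
import torch
from michelangelo.models.tsal.loss import KLNearFar

criterion = KLNearFar(near_weight=0.1, kl_weight=1.0)

logits = torch.randn(4, 2048)                  # [B, 2N]: first N = volume points, last N = near points
labels = (torch.rand(4, 2048) > 0.5).float()   # random occupancy labels in {0, 1}

loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="train")
print(loss.item(), log["train/accuracy"].item())
```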
/michelangelo/models/tsal/sal_perceiver.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | from typing import Optional
6 | from einops import repeat
7 | import math
8 |
9 | from michelangelo.models.modules import checkpoint
10 | from michelangelo.models.modules.embedder import FourierEmbedder
11 | from michelangelo.models.modules.distributions import DiagonalGaussianDistribution
12 | from michelangelo.models.modules.transformer_blocks import (
13 | ResidualCrossAttentionBlock,
14 | Transformer
15 | )
16 |
17 | from .tsal_base import ShapeAsLatentModule
18 |
19 |
20 | class CrossAttentionEncoder(nn.Module):
21 |
22 | def __init__(self, *,
23 | device: Optional[torch.device],
24 | dtype: Optional[torch.dtype],
25 | num_latents: int,
26 | fourier_embedder: FourierEmbedder,
27 | point_feats: int,
28 | width: int,
29 | heads: int,
30 | layers: int,
31 | init_scale: float = 0.25,
32 | qkv_bias: bool = True,
33 | flash: bool = False,
34 | use_ln_post: bool = False,
35 | use_checkpoint: bool = False):
36 |
37 | super().__init__()
38 |
39 | self.use_checkpoint = use_checkpoint
40 | self.num_latents = num_latents
41 |
42 | self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
43 |
44 | self.fourier_embedder = fourier_embedder
45 | self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
46 | self.cross_attn = ResidualCrossAttentionBlock(
47 | device=device,
48 | dtype=dtype,
49 | width=width,
50 | heads=heads,
51 | init_scale=init_scale,
52 | qkv_bias=qkv_bias,
53 | flash=flash,
54 | )
55 |
56 | self.self_attn = Transformer(
57 | device=device,
58 | dtype=dtype,
59 | n_ctx=num_latents,
60 | width=width,
61 | layers=layers,
62 | heads=heads,
63 | init_scale=init_scale,
64 | qkv_bias=qkv_bias,
65 | flash=flash,
66 | use_checkpoint=False
67 | )
68 |
69 | if use_ln_post:
70 | self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
71 | else:
72 | self.ln_post = None
73 |
74 | def _forward(self, pc, feats):
75 | """
76 |
77 | Args:
78 | pc (torch.FloatTensor): [B, N, 3]
79 | feats (torch.FloatTensor or None): [B, N, C]
80 |
81 | Returns:
82 |
83 | """
84 |
85 | bs = pc.shape[0]
86 |
87 | data = self.fourier_embedder(pc)
88 | if feats is not None:
89 | data = torch.cat([data, feats], dim=-1)
90 | data = self.input_proj(data)
91 |
92 | query = repeat(self.query, "m c -> b m c", b=bs)
93 | latents = self.cross_attn(query, data)
94 | latents = self.self_attn(latents)
95 |
96 | if self.ln_post is not None:
97 | latents = self.ln_post(latents)
98 |
99 | return latents, pc
100 |
101 | def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
102 | """
103 |
104 | Args:
105 | pc (torch.FloatTensor): [B, N, 3]
106 | feats (torch.FloatTensor or None): [B, N, C]
107 |
108 | Returns:
109 |             (latents, pc): latents (torch.FloatTensor) [B, num_latents, width] and the input point positions
110 | """
111 |
112 | return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
113 |
114 |
115 | class CrossAttentionDecoder(nn.Module):
116 |
117 | def __init__(self, *,
118 | device: Optional[torch.device],
119 | dtype: Optional[torch.dtype],
120 | num_latents: int,
121 | out_channels: int,
122 | fourier_embedder: FourierEmbedder,
123 | width: int,
124 | heads: int,
125 | init_scale: float = 0.25,
126 | qkv_bias: bool = True,
127 | flash: bool = False,
128 | use_checkpoint: bool = False):
129 |
130 | super().__init__()
131 |
132 | self.use_checkpoint = use_checkpoint
133 | self.fourier_embedder = fourier_embedder
134 |
135 | self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
136 |
137 | self.cross_attn_decoder = ResidualCrossAttentionBlock(
138 | device=device,
139 | dtype=dtype,
140 | n_data=num_latents,
141 | width=width,
142 | heads=heads,
143 | init_scale=init_scale,
144 | qkv_bias=qkv_bias,
145 | flash=flash
146 | )
147 |
148 | self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
149 | self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
150 |
151 | def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
152 | queries = self.query_proj(self.fourier_embedder(queries))
153 | x = self.cross_attn_decoder(queries, latents)
154 | x = self.ln_post(x)
155 | x = self.output_proj(x)
156 | return x
157 |
158 | def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
159 | return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
160 |
161 |
162 | class ShapeAsLatentPerceiver(ShapeAsLatentModule):
163 | def __init__(self, *,
164 | device: Optional[torch.device],
165 | dtype: Optional[torch.dtype],
166 | num_latents: int,
167 | point_feats: int = 0,
168 | embed_dim: int = 0,
169 | num_freqs: int = 8,
170 | include_pi: bool = True,
171 | width: int,
172 | heads: int,
173 | num_encoder_layers: int,
174 | num_decoder_layers: int,
175 | init_scale: float = 0.25,
176 | qkv_bias: bool = True,
177 | flash: bool = False,
178 | use_ln_post: bool = False,
179 | use_checkpoint: bool = False):
180 |
181 | super().__init__()
182 |
183 | self.use_checkpoint = use_checkpoint
184 |
185 | self.num_latents = num_latents
186 | self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
187 |
188 | init_scale = init_scale * math.sqrt(1.0 / width)
189 | self.encoder = CrossAttentionEncoder(
190 | device=device,
191 | dtype=dtype,
192 | fourier_embedder=self.fourier_embedder,
193 | num_latents=num_latents,
194 | point_feats=point_feats,
195 | width=width,
196 | heads=heads,
197 | layers=num_encoder_layers,
198 | init_scale=init_scale,
199 | qkv_bias=qkv_bias,
200 | flash=flash,
201 | use_ln_post=use_ln_post,
202 | use_checkpoint=use_checkpoint
203 | )
204 |
205 | self.embed_dim = embed_dim
206 | if embed_dim > 0:
207 | # VAE embed
208 | self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
209 | self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
210 | self.latent_shape = (num_latents, embed_dim)
211 | else:
212 | self.latent_shape = (num_latents, width)
213 |
214 | self.transformer = Transformer(
215 | device=device,
216 | dtype=dtype,
217 | n_ctx=num_latents,
218 | width=width,
219 | layers=num_decoder_layers,
220 | heads=heads,
221 | init_scale=init_scale,
222 | qkv_bias=qkv_bias,
223 | flash=flash,
224 | use_checkpoint=use_checkpoint
225 | )
226 |
227 | # geometry decoder
228 | self.geo_decoder = CrossAttentionDecoder(
229 | device=device,
230 | dtype=dtype,
231 | fourier_embedder=self.fourier_embedder,
232 | out_channels=1,
233 | num_latents=num_latents,
234 | width=width,
235 | heads=heads,
236 | init_scale=init_scale,
237 | qkv_bias=qkv_bias,
238 | flash=flash,
239 | use_checkpoint=use_checkpoint
240 | )
241 |
242 | def encode(self,
243 | pc: torch.FloatTensor,
244 | feats: Optional[torch.FloatTensor] = None,
245 | sample_posterior: bool = True):
246 | """
247 |
248 | Args:
249 | pc (torch.FloatTensor): [B, N, 3]
250 | feats (torch.FloatTensor or None): [B, N, C]
251 | sample_posterior (bool):
252 |
253 | Returns:
254 | latents (torch.FloatTensor)
255 | center_pos (torch.FloatTensor or None):
256 | posterior (DiagonalGaussianDistribution or None):
257 | """
258 |
259 | latents, center_pos = self.encoder(pc, feats)
260 |
261 | posterior = None
262 | if self.embed_dim > 0:
263 | moments = self.pre_kl(latents)
264 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
265 |
266 | if sample_posterior:
267 | latents = posterior.sample()
268 | else:
269 | latents = posterior.mode()
270 |
271 | return latents, center_pos, posterior
272 |
273 | def decode(self, latents: torch.FloatTensor):
274 | latents = self.post_kl(latents)
275 | return self.transformer(latents)
276 |
277 | def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
278 | logits = self.geo_decoder(queries, latents).squeeze(-1)
279 | return logits
280 |
281 | def forward(self,
282 | pc: torch.FloatTensor,
283 | feats: torch.FloatTensor,
284 | volume_queries: torch.FloatTensor,
285 | sample_posterior: bool = True):
286 | """
287 |
288 | Args:
289 | pc (torch.FloatTensor): [B, N, 3]
290 | feats (torch.FloatTensor or None): [B, N, C]
291 | volume_queries (torch.FloatTensor): [B, P, 3]
292 | sample_posterior (bool):
293 |
294 | Returns:
295 | logits (torch.FloatTensor): [B, P]
296 | center_pos (torch.FloatTensor): [B, M, 3]
297 | posterior (DiagonalGaussianDistribution or None).
298 |
299 | """
300 |
301 | latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
302 |
303 | latents = self.decode(latents)
304 | logits = self.query_geometry(volume_queries, latents)
305 |
306 | return logits, center_pos, posterior
307 |
308 |
309 | class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
310 |
311 | def __init__(self, *,
312 | device: Optional[torch.device],
313 | dtype: Optional[torch.dtype],
314 | num_latents: int,
315 | point_feats: int = 0,
316 | embed_dim: int = 0,
317 | num_freqs: int = 8,
318 | include_pi: bool = True,
319 | width: int,
320 | heads: int,
321 | num_encoder_layers: int,
322 | num_decoder_layers: int,
323 | init_scale: float = 0.25,
324 | qkv_bias: bool = True,
325 | flash: bool = False,
326 | use_ln_post: bool = False,
327 | use_checkpoint: bool = False):
328 |
329 | super().__init__(
330 | device=device,
331 | dtype=dtype,
332 | num_latents=1 + num_latents,
333 | point_feats=point_feats,
334 | embed_dim=embed_dim,
335 | num_freqs=num_freqs,
336 | include_pi=include_pi,
337 | width=width,
338 | heads=heads,
339 | num_encoder_layers=num_encoder_layers,
340 | num_decoder_layers=num_decoder_layers,
341 | init_scale=init_scale,
342 | qkv_bias=qkv_bias,
343 | flash=flash,
344 | use_ln_post=use_ln_post,
345 | use_checkpoint=use_checkpoint
346 | )
347 |
348 | self.width = width
349 |
350 | def encode(self,
351 | pc: torch.FloatTensor,
352 | feats: Optional[torch.FloatTensor] = None,
353 | sample_posterior: bool = True):
354 | """
355 |
356 | Args:
357 | pc (torch.FloatTensor): [B, N, 3]
358 | feats (torch.FloatTensor or None): [B, N, c]
359 | sample_posterior (bool):
360 |
361 | Returns:
362 | shape_embed (torch.FloatTensor)
363 | kl_embed (torch.FloatTensor):
364 | posterior (DiagonalGaussianDistribution or None):
365 | """
366 |
367 | shape_embed, latents = self.encode_latents(pc, feats)
368 | kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
369 |
370 | return shape_embed, kl_embed, posterior
371 |
372 | def encode_latents(self,
373 | pc: torch.FloatTensor,
374 | feats: Optional[torch.FloatTensor] = None):
375 |
376 | x, _ = self.encoder(pc, feats)
377 |
378 | shape_embed = x[:, 0]
379 | latents = x[:, 1:]
380 |
381 | return shape_embed, latents
382 |
383 | def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
384 | posterior = None
385 | if self.embed_dim > 0:
386 | moments = self.pre_kl(latents)
387 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
388 |
389 | if sample_posterior:
390 | kl_embed = posterior.sample()
391 | else:
392 | kl_embed = posterior.mode()
393 | else:
394 | kl_embed = latents
395 |
396 | return kl_embed, posterior
397 |
398 | def forward(self,
399 | pc: torch.FloatTensor,
400 | feats: torch.FloatTensor,
401 | volume_queries: torch.FloatTensor,
402 | sample_posterior: bool = True):
403 | """
404 |
405 | Args:
406 | pc (torch.FloatTensor): [B, N, 3]
407 | feats (torch.FloatTensor or None): [B, N, C]
408 | volume_queries (torch.FloatTensor): [B, P, 3]
409 | sample_posterior (bool):
410 |
411 | Returns:
412 | shape_embed (torch.FloatTensor): [B, projection_dim]
413 |             logits (torch.FloatTensor): [B, P]
414 | posterior (DiagonalGaussianDistribution or None).
415 |
416 | """
417 |
418 | shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
419 |
420 | latents = self.decode(kl_embed)
421 | logits = self.query_geometry(volume_queries, latents)
422 |
423 | return shape_embed, logits, posterior
424 |
--------------------------------------------------------------------------------
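A hedged shape-checking sketch for AlignedShapeLatentPerceiver: the hyperparameters below are deliberately tiny (the released shapevae-256 config is far larger) and only serve to trace tensor shapes through encode, decode, and query_geometry.

```python
# Hedged sketch: a deliberately tiny AlignedShapeLatentPerceiver, just to trace the
# shapes through encode -> decode -> query_geometry. Real configs are much larger.
import torch
from michelangelo.models.tsal.sal_perceiver import AlignedShapeLatentPerceiver

vae = AlignedShapeLatentPerceiver(
    device=None, dtype=None,
    num_latents=32,          # one extra "shape embedding" token is added internally
    point_feats=3,           # e.g. per-point normals
    embed_dim=8,             # KL bottleneck width
    width=64, heads=4,
    num_encoder_layers=1, num_decoder_layers=2,
)

pc = torch.randn(2, 1024, 3)              # surface points
feats = torch.randn(2, 1024, 3)           # per-point features
queries = torch.rand(2, 256, 3) * 2 - 1   # query positions in [-1, 1]^3

shape_embed, logits, posterior = vae(pc, feats, queries)
print(shape_embed.shape, logits.shape)    # torch.Size([2, 64]) torch.Size([2, 256])
```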
/michelangelo/models/tsal/sal_pl_module.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import List, Tuple, Dict, Optional
4 | from omegaconf import DictConfig
5 |
6 | import torch
7 | from torch.optim import lr_scheduler
8 | import pytorch_lightning as pl
9 | from typing import Union
10 | from functools import partial
11 |
12 | from michelangelo.utils import instantiate_from_config
13 |
14 | from .inference_utils import extract_geometry
15 | from .tsal_base import (
16 | ShapeAsLatentModule,
17 | Latent2MeshOutput,
18 | Point2MeshOutput
19 | )
20 |
21 |
22 | class ShapeAsLatentPLModule(pl.LightningModule):
23 |
24 | def __init__(self, *,
25 | module_cfg,
26 | loss_cfg,
27 | optimizer_cfg: Optional[DictConfig] = None,
28 | ckpt_path: Optional[str] = None,
29 | ignore_keys: Union[Tuple[str], List[str]] = ()):
30 |
31 | super().__init__()
32 |
33 | self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None)
34 |
35 | self.loss = instantiate_from_config(loss_cfg)
36 |
37 | self.optimizer_cfg = optimizer_cfg
38 |
39 | if ckpt_path is not None:
40 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
41 |
42 | self.save_hyperparameters()
43 |
44 | @property
45 | def latent_shape(self):
46 | return self.sal.latent_shape
47 |
48 | @property
49 | def zero_rank(self):
50 | if self._trainer:
51 | zero_rank = self.trainer.local_rank == 0
52 | else:
53 | zero_rank = True
54 |
55 | return zero_rank
56 |
57 | def init_from_ckpt(self, path, ignore_keys=()):
58 | state_dict = torch.load(path, map_location="cpu")["state_dict"]
59 |
60 | keys = list(state_dict.keys())
61 | for k in keys:
62 | for ik in ignore_keys:
63 | if k.startswith(ik):
64 | print("Deleting key {} from state_dict.".format(k))
65 | del state_dict[k]
66 |
67 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
68 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
69 | if len(missing) > 0:
70 | print(f"Missing Keys: {missing}")
71 | print(f"Unexpected Keys: {unexpected}")
72 |
73 | def configure_optimizers(self) -> Tuple[List, List]:
74 | lr = self.learning_rate
75 |
76 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
77 | # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
78 |
79 | if self.optimizer_cfg is None:
80 | optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
81 | schedulers = []
82 | else:
83 | optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters())
84 | scheduler_func = instantiate_from_config(
85 | self.optimizer_cfg.scheduler,
86 | max_decay_steps=self.trainer.max_steps,
87 | lr_max=lr
88 | )
89 | scheduler = {
90 | "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
91 | "interval": "step",
92 | "frequency": 1
93 | }
94 | optimizers = [optimizer]
95 | schedulers = [scheduler]
96 |
97 | return optimizers, schedulers
98 |
99 | def forward(self,
100 | pc: torch.FloatTensor,
101 | feats: torch.FloatTensor,
102 | volume_queries: torch.FloatTensor):
103 |
104 | logits, center_pos, posterior = self.sal(pc, feats, volume_queries)
105 |
106 | return posterior, logits
107 |
108 | def encode(self, surface: torch.FloatTensor, sample_posterior=True):
109 |
110 | pc = surface[..., 0:3]
111 | feats = surface[..., 3:6]
112 |
113 | latents, center_pos, posterior = self.sal.encode(
114 | pc=pc, feats=feats, sample_posterior=sample_posterior
115 | )
116 |
117 | return latents
118 |
119 | def decode(self,
120 | z_q,
121 | bounds: Union[Tuple[float], List[float], float] = 1.1,
122 | octree_depth: int = 7,
123 | num_chunks: int = 10000) -> List[Latent2MeshOutput]:
124 |
125 | latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim]
126 | outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
127 |
128 | return outputs
129 |
130 | def training_step(self, batch: Dict[str, torch.FloatTensor],
131 | batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
132 | """
133 |
134 | Args:
135 | batch (dict): the batch sample, and it contains:
136 | - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
137 | - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
138 |
139 | batch_idx (int):
140 |
141 | optimizer_idx (int):
142 |
143 | Returns:
144 | loss (torch.FloatTensor):
145 |
146 | """
147 |
148 | pc = batch["surface"][..., 0:3]
149 | feats = batch["surface"][..., 3:]
150 |
151 | volume_queries = batch["geo_points"][..., 0:3]
152 | volume_labels = batch["geo_points"][..., -1]
153 |
154 | posterior, logits = self(
155 | pc=pc, feats=feats, volume_queries=volume_queries
156 | )
157 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train")
158 |
159 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
160 | sync_dist=False, rank_zero_only=True)
161 |
162 | return aeloss
163 |
164 | def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
165 |
166 | pc = batch["surface"][..., 0:3]
167 | feats = batch["surface"][..., 3:]
168 |
169 | volume_queries = batch["geo_points"][..., 0:3]
170 | volume_labels = batch["geo_points"][..., -1]
171 |
172 | posterior, logits = self(
173 | pc=pc, feats=feats, volume_queries=volume_queries,
174 | )
175 | aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val")
176 |
177 | self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
178 | sync_dist=False, rank_zero_only=True)
179 |
180 | return aeloss
181 |
182 | def point2mesh(self,
183 | pc: torch.FloatTensor,
184 | feats: torch.FloatTensor,
185 | bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
186 | octree_depth: int = 7,
187 | num_chunks: int = 10000) -> List[Point2MeshOutput]:
188 |
189 | """
190 |
191 | Args:
192 | pc:
193 | feats:
194 | bounds:
195 | octree_depth:
196 | num_chunks:
197 |
198 | Returns:
199 | mesh_outputs (List[MeshOutput]): the mesh outputs list.
200 |
201 | """
202 |
203 | outputs = []
204 |
205 | device = pc.device
206 | bs = pc.shape[0]
207 |
208 | # 1. point encoder + latents transformer
209 | latents, center_pos, posterior = self.sal.encode(pc, feats)
210 | latents = self.sal.decode(latents) # latents: [bs, num_latents, dim]
211 |
212 | geometric_func = partial(self.sal.query_geometry, latents=latents)
213 |
214 | # 2. decode geometry
215 | mesh_v_f, has_surface = extract_geometry(
216 | geometric_func=geometric_func,
217 | device=device,
218 | batch_size=bs,
219 | bounds=bounds,
220 | octree_depth=octree_depth,
221 | num_chunks=num_chunks,
222 | disable=not self.zero_rank
223 | )
224 |
225 |         # 3. assemble mesh outputs
226 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
227 | if not is_surface:
228 | outputs.append(None)
229 | continue
230 |
231 | out = Point2MeshOutput()
232 | out.mesh_v = mesh_v
233 | out.mesh_f = mesh_f
234 | out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy()
235 |
236 | if center_pos is not None:
237 | out.center = center_pos[i].cpu().numpy()
238 |
239 | outputs.append(out)
240 |
241 | return outputs
242 |
243 | def latent2mesh(self,
244 | latents: torch.FloatTensor,
245 | bounds: Union[Tuple[float], List[float], float] = 1.1,
246 | octree_depth: int = 7,
247 | num_chunks: int = 10000) -> List[Latent2MeshOutput]:
248 |
249 | """
250 |
251 | Args:
252 | latents: [bs, num_latents, dim]
253 | bounds:
254 | octree_depth:
255 | num_chunks:
256 |
257 | Returns:
258 | mesh_outputs (List[MeshOutput]): the mesh outputs list.
259 |
260 | """
261 |
262 | outputs = []
263 |
264 | geometric_func = partial(self.sal.query_geometry, latents=latents)
265 |
266 | # 2. decode geometry
267 | device = latents.device
268 | mesh_v_f, has_surface = extract_geometry(
269 | geometric_func=geometric_func,
270 | device=device,
271 | batch_size=len(latents),
272 | bounds=bounds,
273 | octree_depth=octree_depth,
274 | num_chunks=num_chunks,
275 | disable=not self.zero_rank
276 | )
277 |
278 |         # 3. assemble mesh outputs
279 | for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
280 | if not is_surface:
281 | outputs.append(None)
282 | continue
283 |
284 | out = Latent2MeshOutput()
285 | out.mesh_v = mesh_v
286 | out.mesh_f = mesh_f
287 |
288 | outputs.append(out)
289 |
290 | return outputs
291 |
--------------------------------------------------------------------------------
/michelangelo/models/tsal/tsal_base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch.nn as nn
4 | from typing import Tuple, List, Optional
5 | import pytorch_lightning as pl
6 |
7 |
8 | class Point2MeshOutput(object):
9 | def __init__(self):
10 | self.mesh_v = None
11 | self.mesh_f = None
12 | self.center = None
13 | self.pc = None
14 |
15 |
16 | class Latent2MeshOutput(object):
17 |
18 | def __init__(self):
19 | self.mesh_v = None
20 | self.mesh_f = None
21 |
22 |
23 | class AlignedMeshOutput(object):
24 |
25 | def __init__(self):
26 | self.mesh_v = None
27 | self.mesh_f = None
28 | self.surface = None
29 | self.image = None
30 | self.text: Optional[str] = None
31 | self.shape_text_similarity: Optional[float] = None
32 | self.shape_image_similarity: Optional[float] = None
33 |
34 |
35 | class ShapeAsLatentPLModule(pl.LightningModule):
36 | latent_shape: Tuple[int]
37 |
38 | def encode(self, surface, *args, **kwargs):
39 | raise NotImplementedError
40 |
41 | def decode(self, z_q, *args, **kwargs):
42 | raise NotImplementedError
43 |
44 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
45 | raise NotImplementedError
46 |
47 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
48 | raise NotImplementedError
49 |
50 |
51 | class ShapeAsLatentModule(nn.Module):
52 | latent_shape: Tuple[int, int]
53 |
54 | def __init__(self, *args, **kwargs):
55 | super().__init__()
56 |
57 | def encode(self, *args, **kwargs):
58 | raise NotImplementedError
59 |
60 | def decode(self, *args, **kwargs):
61 | raise NotImplementedError
62 |
63 | def query_geometry(self, *args, **kwargs):
64 | raise NotImplementedError
65 |
66 |
67 | class AlignedShapeAsLatentPLModule(pl.LightningModule):
68 | latent_shape: Tuple[int]
69 |
70 | def set_shape_model_only(self):
71 | raise NotImplementedError
72 |
73 | def encode(self, surface, *args, **kwargs):
74 | raise NotImplementedError
75 |
76 | def decode(self, z_q, *args, **kwargs):
77 | raise NotImplementedError
78 |
79 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
80 | raise NotImplementedError
81 |
82 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
83 | raise NotImplementedError
84 |
85 |
86 | class AlignedShapeAsLatentModule(nn.Module):
87 | shape_model: ShapeAsLatentModule
88 | latent_shape: Tuple[int, int]
89 |
90 | def __init__(self, *args, **kwargs):
91 | super().__init__()
92 |
93 | def set_shape_model_only(self):
94 | raise NotImplementedError
95 |
96 | def encode_image_embed(self, *args, **kwargs):
97 | raise NotImplementedError
98 |
99 | def encode_text_embed(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 | def encode_shape_embed(self, *args, **kwargs):
103 | raise NotImplementedError
104 |
105 |
106 | class TexturedShapeAsLatentModule(nn.Module):
107 |
108 | def __init__(self, *args, **kwargs):
109 | super().__init__()
110 |
111 | def encode(self, *args, **kwargs):
112 | raise NotImplementedError
113 |
114 | def decode(self, *args, **kwargs):
115 | raise NotImplementedError
116 |
117 | def query_geometry(self, *args, **kwargs):
118 | raise NotImplementedError
119 |
120 | def query_color(self, *args, **kwargs):
121 | raise NotImplementedError
122 |
--------------------------------------------------------------------------------
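The classes in `tsal_base.py` are interface stubs; the concrete modules elsewhere in this package (e.g. `sal_perceiver.py`, `clip_asl_module.py`) fill them in. A purely illustrative subclass sketch, with made-up widths and pooling, could look like:

```python
import torch
import torch.nn as nn

from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule


class ToyAlignedShapeModule(AlignedShapeAsLatentModule):
    """Illustrative only: projects pooled shape/image/text features into one embedding space."""

    def __init__(self, width: int = 768, embed_dim: int = 512):
        super().__init__()
        self.shape_proj = nn.Linear(width, embed_dim)
        self.image_proj = nn.Linear(width, embed_dim)
        self.text_proj = nn.Linear(width, embed_dim)

    def set_shape_model_only(self):
        # Drop the image/text branches when only shape reconstruction is needed.
        self.image_proj = None
        self.text_proj = None

    def encode_shape_embed(self, surface_feats: torch.Tensor):
        return self.shape_proj(surface_feats.mean(dim=1))   # pool over points

    def encode_image_embed(self, image_feats: torch.Tensor):
        return self.image_proj(image_feats)

    def encode_text_embed(self, text_feats: torch.Tensor):
        return self.text_proj(text_feats)
```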
/michelangelo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .misc import get_config_from_file
4 | from .misc import instantiate_from_config
5 |
--------------------------------------------------------------------------------
/michelangelo/utils/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 |
6 | def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
7 |     # PSNR = 10 * log10(peak^2 / MSE); `data_range` is the dynamic range of the signal (2 for data in [-1, 1]).
8 |     mse = torch.mean((x - y) ** 2)
9 |     psnr = 10 * torch.log10(data_range ** 2 / (mse + eps))
10 |
11 | return psnr
12 |
13 |
--------------------------------------------------------------------------------
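A quick sanity check for `compute_psnr`; the tensors here are synthetic, and `data_range=2` presumes signals living in [-1, 1] as the default suggests:

```python
import torch
from michelangelo.utils.eval import compute_psnr

pred = torch.rand(4, 1024) * 2 - 1                            # synthetic values in [-1, 1]
target = (pred + 0.01 * torch.randn_like(pred)).clamp(-1, 1)  # lightly perturbed copy

print(f"PSNR: {compute_psnr(pred, target):.2f} dB")           # high PSNR for near-identical inputs
```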
/michelangelo/utils/io.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import io
5 | import tarfile
6 | import json
7 | import numpy as np
8 | import numpy.lib.format
9 |
10 |
11 | def mkdir(path):
12 | os.makedirs(path, exist_ok=True)
13 | return path
14 |
15 |
16 | def npy_loads(data):
17 | stream = io.BytesIO(data)
18 | return np.lib.format.read_array(stream)
19 |
20 |
21 | def npz_loads(data):
22 | return np.load(io.BytesIO(data))
23 |
24 |
25 | def json_loads(data):
26 | return json.loads(data)
27 |
28 |
29 | def load_json(filepath):
30 | with open(filepath, "r") as f:
31 | data = json.load(f)
32 | return data
33 |
34 |
35 | def write_json(filepath, data):
36 | with open(filepath, "w") as f:
37 | json.dump(data, f, indent=2)
38 |
39 |
40 | def extract_tar(tar_path, tar_cache_folder):
41 |
42 | with tarfile.open(tar_path, "r") as tar:
43 | tar.extractall(path=tar_cache_folder)
44 |
45 | tar_uids = sorted(os.listdir(tar_cache_folder))
46 | print(f"extract tar: {tar_path} to {tar_cache_folder}")
47 | return tar_uids
48 |
--------------------------------------------------------------------------------
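A sketch of how these helpers compose when reading samples out of a tar shard; the shard path and the member names (`<uid>/surface.npy`, `<uid>/meta.json`) are placeholders, not a documented layout:

```python
import tarfile
from michelangelo.utils.io import mkdir, npy_loads, json_loads, extract_tar

cache_dir = mkdir("./cache/shard_0000")                  # placeholder paths
uids = extract_tar("./data/shard_0000.tar", cache_dir)   # unpack and list sample uids
print(f"{len(uids)} samples extracted")

# Alternatively, stream members without writing them to disk first:
with tarfile.open("./data/shard_0000.tar", "r") as tar:
    surface = npy_loads(tar.extractfile(f"{uids[0]}/surface.npy").read())
    meta = json_loads(tar.extractfile(f"{uids[0]}/meta.json").read())
```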
/michelangelo/utils/misc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import importlib
4 | from omegaconf import OmegaConf, DictConfig, ListConfig
5 |
6 | import torch
7 | import torch.distributed as dist
8 | from typing import Union
9 |
10 |
11 | def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
12 | config_file = OmegaConf.load(config_file)
13 |
14 | if 'base_config' in config_file.keys():
15 | if config_file['base_config'] == "default_base":
16 | base_config = OmegaConf.create()
17 | # base_config = get_default_config()
18 | elif config_file['base_config'].endswith(".yaml"):
19 | base_config = get_config_from_file(config_file['base_config'])
20 | else:
21 |             raise ValueError(f"`base_config` must be `default_base` or a path to a `.yaml` file, got: {config_file['base_config']}.")
22 |
23 |         config_file = {key: value for key, value in config_file.items() if key != "base_config"}
24 |
25 | return OmegaConf.merge(base_config, config_file)
26 |
27 | return config_file
28 |
29 |
30 | def get_obj_from_str(string, reload=False):
31 | module, cls = string.rsplit(".", 1)
32 | if reload:
33 | module_imp = importlib.import_module(module)
34 | importlib.reload(module_imp)
35 | return getattr(importlib.import_module(module, package=None), cls)
36 |
37 |
38 | def get_obj_from_config(config):
39 | if "target" not in config:
40 | raise KeyError("Expected key `target` to instantiate.")
41 |
42 | return get_obj_from_str(config["target"])
43 |
44 |
45 | def instantiate_from_config(config, **kwargs):
46 | if "target" not in config:
47 | raise KeyError("Expected key `target` to instantiate.")
48 |
49 | cls = get_obj_from_str(config["target"])
50 |
51 | params = config.get("params", dict())
52 | # params.update(kwargs)
53 | # instance = cls(**params)
54 | kwargs.update(params)
55 | instance = cls(**kwargs)
56 |
57 | return instance
58 |
59 |
60 | def is_dist_avail_and_initialized():
61 | if not dist.is_available():
62 | return False
63 | if not dist.is_initialized():
64 | return False
65 | return True
66 |
67 |
68 | def get_rank():
69 | if not is_dist_avail_and_initialized():
70 | return 0
71 | return dist.get_rank()
72 |
73 |
74 | def get_world_size():
75 | if not is_dist_avail_and_initialized():
76 | return 1
77 | return dist.get_world_size()
78 |
79 |
80 | def all_gather_batch(tensors):
81 | """
82 | Performs all_gather operation on the provided tensors.
83 | """
84 | # Queue the gathered tensors
85 | world_size = get_world_size()
86 | # There is no need for reduction in the single-proc case
87 | if world_size == 1:
88 | return tensors
89 | tensor_list = []
90 | output_tensor = []
91 | for tensor in tensors:
92 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
93 | dist.all_gather(
94 | tensor_all,
95 | tensor,
96 |             async_op=False  # blocking collective; gathered tensors are ready right after the call
97 | )
98 |
99 | tensor_list.append(tensor_all)
100 |
101 | for tensor_all in tensor_list:
102 | output_tensor.append(torch.cat(tensor_all, dim=0))
103 | return output_tensor
104 |
--------------------------------------------------------------------------------
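`get_config_from_file` and `instantiate_from_config` are typically chained as below (in the same spirit as `inference.py`); the top-level `model` key with `target`/`params` follows the stable-diffusion-style layout of the configs under `configs/`, so treat the exact key names as an assumption:

```python
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config

# Load a config; nested `base_config` references are merged recursively.
config = get_config_from_file("./configs/aligned_shape_latents/shapevae-256.yaml")

# An instantiable node is a mapping with a `target` import path plus optional `params`;
# any extra kwargs passed here are merged into `params` before the constructor call.
model = instantiate_from_config(config.model)
model = model.eval()
```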
/michelangelo/utils/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/michelangelo/utils/visualizers/color_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | # Helper functions
6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
7 | colormap = plt.cm.get_cmap(colormap)
8 | if normalize:
9 | vmin = np.min(inp)
10 | vmax = np.max(inp)
11 |
12 | norm = plt.Normalize(vmin, vmax)
13 | return colormap(norm(inp))[:, :3]
14 |
15 |
16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
17 | # tex dims need to be power of two.
18 | array = np.ones((width, height, 3), dtype='float32')
19 |
20 | # width in texels of each checker
21 | checker_w = width / n_checkers_x
22 | checker_h = height / n_checkers_y
23 |
24 | for y in range(height):
25 | for x in range(width):
26 | color_key = int(x / checker_w) + int(y / checker_h)
27 | if color_key % 2 == 0:
28 | array[x, y, :] = [1., 0.874, 0.0]
29 | else:
30 | array[x, y, :] = [0., 0., 0.]
31 | return array
32 |
33 |
34 | def gen_circle(width=256, height=256):
35 | xx, yy = np.mgrid[:width, :height]
36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
37 | array = np.ones((width, height, 4), dtype='float32')
38 | array[:, :, 0] = (circle <= width)
39 | array[:, :, 1] = (circle <= width)
40 | array[:, :, 2] = (circle <= width)
41 | array[:, :, 3] = circle <= width
42 | return array
43 |
44 |
--------------------------------------------------------------------------------
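A short sketch of these helpers in use, e.g. coloring a synthetic per-vertex scalar field and generating a checkerboard texture:

```python
import numpy as np
from michelangelo.utils.visualizers.color_util import get_colors, gen_checkers

heights = np.random.rand(2048)                          # synthetic per-vertex scalars
vertex_rgb = get_colors(heights, colormap="viridis")    # -> (2048, 3), floats in [0, 1]

texture = gen_checkers(n_checkers_x=8, n_checkers_y=8)  # -> (256, 256, 3) checkerboard
print(vertex_rgb.shape, texture.shape)
```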
/michelangelo/utils/visualizers/html_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | import base64
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def to_html_frame(content):
9 |
10 | html_frame = f"""
11 |     <html>
12 |     <body>
13 |     {content}
14 |     </body>
15 |     </html>
16 | """
17 |
18 | return html_frame
19 |
20 |
21 | def to_single_row_table(caption: str, content: str):
22 |
23 | table_html = f"""
24 |     <table border="1">
25 |     <caption>{caption}</caption>
26 |     <tr>
27 |     <td>{content}</td>
28 |     </tr>
29 |     </table>
30 | """
31 |
32 | return table_html
33 |
34 |
35 | def to_image_embed_tag(image: np.ndarray):
36 |
37 | # Convert np.ndarray to bytes
38 | img = Image.fromarray(image)
39 | raw_bytes = io.BytesIO()
40 | img.save(raw_bytes, "PNG")
41 |
42 | # Encode bytes to base64
43 | image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
44 |
45 | image_tag = f"""
46 |     <img src="data:image/png;base64,{image_base64}" alt="embedded image">
47 | """
48 |
49 | return image_tag
50 |
--------------------------------------------------------------------------------
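The three helpers nest as below to produce a standalone HTML preview with the image inlined as base64; the render and the output path are placeholders:

```python
import numpy as np
from michelangelo.utils.visualizers.html_util import (
    to_html_frame, to_image_embed_tag, to_single_row_table,
)

image = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)   # placeholder render

img_tag = to_image_embed_tag(image)                 # <img src="data:image/png;base64,...">
table = to_single_row_table("rendered view", img_tag)
html = to_html_frame(table)

with open("./preview.html", "w") as f:              # arbitrary output path
    f.write(html)
```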
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | accelerate==0.20.3
3 | addict==2.4.0
4 | aiofiles==23.1.0
5 | aiohttp==3.8.4
6 | aiosignal==1.3.1
7 | altair==5.0.1
8 | antlr4-python3-runtime==4.9.3
9 | anyio==3.6.2
10 | appdirs==1.4.4
11 | argon2-cffi==21.3.0
12 | argon2-cffi-bindings==21.2.0
13 | arrow==1.2.3
14 | asttokens==2.2.1
15 | async-timeout==4.0.2
16 | attrs==22.2.0
17 | backcall==0.2.0
18 | beautifulsoup4==4.11.2
19 | bleach==6.0.0
20 | braceexpand==0.1.7
21 | cachetools==5.3.0
22 | cffi==1.15.1
23 | charset-normalizer==3.0.1
24 | click==8.1.3
25 | coloredlogs==15.0.1
26 | comm==0.1.2
27 | configargparse==1.5.3
28 | contourpy==1.0.7
29 | controlnet-aux==0.0.5
30 | cycler==0.11.0
31 | cython==0.29.33
32 | dash==2.8.1
33 | dash-core-components==2.0.0
34 | dash-html-components==2.0.0
35 | dash-table==5.0.0
36 | dataclasses-json==0.6.0
37 | debugpy==1.6.6
38 | decorator==5.1.1
39 | deepspeed==0.8.1
40 | defusedxml==0.7.1
41 | deprecated==1.2.14
42 | diffusers==0.18.2
43 | docker-pycreds==0.4.0
44 | einops==0.6.0
45 | executing==1.2.0
46 | fastapi==0.101.0
47 | fastjsonschema==2.16.2
48 | ffmpy==0.3.1
49 | filelock==3.9.0
50 | flask==2.2.3
51 | flatbuffers==23.5.26
52 | fonttools==4.38.0
53 | fqdn==1.5.1
54 | frozenlist==1.3.3
55 | fsspec==2023.1.0
56 | ftfy==6.1.1
57 | fvcore==0.1.5.post20221221
58 | gitdb==4.0.10
59 | gitpython==3.1.31
60 | google-auth==2.16.1
61 | google-auth-oauthlib==0.4.6
62 | gradio==3.39.0
63 | gradio-client==0.3.0
64 | grpcio==1.51.3
65 | h11==0.14.0
66 | hjson==3.1.0
67 | httpcore==0.17.3
68 | httpx==0.24.1
69 | huggingface-hub==0.16.4
70 | humanfriendly==10.0
71 | idna==3.4
72 | imageio==2.25.1
73 | importlib-metadata==6.0.0
74 | iopath==0.1.10
75 | ipydatawidgets==4.3.3
76 | ipykernel==6.21.2
77 | ipython==8.10.0
78 | ipython-genutils==0.2.0
79 | ipywidgets==8.0.4
80 | isoduration==20.11.0
81 | itsdangerous==2.1.2
82 | jedi==0.18.2
83 | jinja2==3.1.2
84 | joblib==1.2.0
85 | jsonpointer==2.3
86 | jsonschema==4.17.3
87 | jupyter==1.0.0
88 | jupyter-client==8.0.3
89 | jupyter-console==6.6.1
90 | jupyter-core==5.2.0
91 | jupyter-events==0.6.3
92 | jupyter-server==2.3.0
93 | jupyter-server-terminals==0.4.4
94 | jupyterlab-pygments==0.2.2
95 | jupyterlab-widgets==3.0.5
96 | kiwisolver==1.4.4
97 | lightning-utilities==0.7.1
98 | linkify-it-py==2.0.2
99 | lmdb==1.4.1
100 | markdown==3.4.1
101 | markdown-it-py==2.2.0
102 | markupsafe==2.1.2
103 | marshmallow==3.20.1
104 | matplotlib==3.6.3
105 | matplotlib-inline==0.1.6
106 | mdit-py-plugins==0.3.3
107 | mdurl==0.1.2
108 | mesh2sdf==1.1.0
109 | mistune==2.0.5
110 | mpmath==1.3.0
111 | multidict==6.0.4
112 | mypy-extensions==1.0.0
113 | nbclassic==0.5.2
114 | nbclient==0.7.2
115 | nbconvert==7.2.9
116 | nbformat==5.5.0
117 | nest-asyncio==1.5.6
118 | networkx==3.0
119 | ninja==1.11.1
120 | notebook==6.5.2
121 | notebook-shim==0.2.2
122 | numpy==1.23.1
123 | oauthlib==3.2.2
124 | objaverse==0.0.7
125 | omegaconf==2.3.0
126 | onnxruntime==1.15.1
127 | opencv-contrib-python==4.8.0.74
128 | opencv-python==4.7.0.72
129 | orjson==3.9.2
130 | packaging==21.3
131 | pandas==1.4.4
132 | pandocfilters==1.5.0
133 | parso==0.8.3
134 | pathtools==0.1.2
135 | pexpect==4.8.0
136 | pickleshare==0.7.5
137 | pillow==9.2.0
138 | platformdirs==3.0.0
139 | plotly==5.13.0
140 | portalocker==2.7.0
141 | prometheus-client==0.16.0
142 | prompt-toolkit==3.0.37
143 | protobuf==3.19.6
144 | psutil==5.9.4
145 | ptyprocess==0.7.0
146 | pure-eval==0.2.2
147 | py-cpuinfo==9.0.0
148 | pyasn1==0.4.8
149 | pyasn1-modules==0.2.8
150 | pycparser==2.21
151 | pydantic==1.10.5
152 | pydub==0.25.1
153 | pygltflib==1.16.0
154 | pygments==2.14.0
155 | pymeshlab==2022.2.post3
156 | pyparsing==3.0.9
157 | pyquaternion==0.9.9
158 | pyrsistent==0.19.3
159 | pysdf==0.1.8
160 | python-dateutil==2.8.2
161 | python-json-logger==2.0.7
162 | python-multipart==0.0.6
163 | pythreejs==2.4.2
164 | pytz==2022.7.1
165 | pywavelets==1.4.1
166 | pyyaml==6.0
167 | pyzmq==25.0.0
168 | qtconsole==5.4.0
169 | qtpy==2.3.0
170 | regex==2022.10.31
171 | requests==2.28.2
172 | requests-oauthlib==1.3.1
173 | rfc3339-validator==0.1.4
174 | rfc3986-validator==0.1.1
175 | rsa==4.9
176 | rtree==1.0.1
177 | safetensors==0.3.1
178 | scikit-image==0.19.3
179 | scikit-learn==1.2.1
180 | scipy==1.10.1
181 | semantic-version==2.10.0
182 | send2trash==1.8.0
183 | sentencepiece==0.1.97
184 | sentry-sdk==1.15.0
185 | setproctitle==1.3.2
186 | setuptools==63.4.3
187 | sh==2.0.2
188 | shapely==2.0.1
189 | six==1.16.0
190 | smmap==5.0.0
191 | sniffio==1.3.0
192 | soupsieve==2.4
193 | stack-data==0.6.2
194 | starlette==0.27.0
195 | sympy==1.12
196 | tabulate==0.9.0
197 | tenacity==8.2.1
198 | tensorboard==2.10.1
199 | tensorboard-data-server==0.6.1
200 | tensorboard-plugin-wit==1.8.1
201 | termcolor==2.3.0
202 | terminado==0.17.1
203 | threadpoolctl==3.1.0
204 | tifffile==2023.2.3
205 | timm==0.9.2
206 | tinycss2==1.2.1
207 | tokenizers==0.13.2
208 | toolz==0.12.0
209 | tornado==6.2
210 | tqdm==4.64.1
211 | traitlets==5.9.0
212 | traittypes==0.2.1
213 | transformers==4.30.2
214 | trimesh==3.18.3
215 | triton==1.1.1
216 | typing-extensions==4.5.0
217 | typing-inspect==0.9.0
218 | uc-micro-py==1.0.2
219 | uri-template==1.2.0
220 | urllib3==1.26.14
221 | uvicorn==0.23.2
222 | wandb==0.13.10
223 | wcwidth==0.2.6
224 | webcolors==1.12
225 | webdataset==0.2.33
226 | webencodings==0.5.1
227 | websocket-client==1.5.1
228 | websockets==11.0.3
229 | werkzeug==2.2.3
230 | widgetsnbextension==4.0.5
231 | wrapt==1.15.0
232 | xatlas==0.0.7
233 | yacs==0.1.8
234 | yarl==1.8.2
235 | zipp==3.14.0
236 | # torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
237 | https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=db457a822d736013b6ffe509053001bc918bdd78fe68967b605f53984a9afac5
238 | https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=a9fc38040e133d1779f131b4497caef830e9e699faf89cd323cd58794ffb305b
239 | https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp39-cp39-linux_x86_64.whl#sha256=5bc0e29cb78f7c452eeb4f27029c40049770d51553bf840b4ca2edd63da289ee
240 | # torch-cluster
241 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_cluster-1.6.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl
242 | # torch-scatter
243 | https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl
244 | torchmetrics==0.11.1
245 | pytorch_lightning~=1.9.3
246 | git+https://github.com/pyvista/fast-simplification.git
247 | git+https://github.com/skoch9/meshplot.git
248 | git+https://github.com/NVlabs/nvdiffrast/
--------------------------------------------------------------------------------
/scripts/infer.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task reconstruction \
3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \
5 | --pointcloud_path ./example_data/surface/surface.npz
6 |
7 | python inference.py \
8 | --task image2mesh \
9 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \
10 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \
11 | --image_path ./example_data/image/car.jpg
12 |
13 | python inference.py \
14 | --task text2mesh \
15 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \
16 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \
17 | --text "A 3D model of motorcar; Porsche Cayenne Turbo."
--------------------------------------------------------------------------------
/scripts/inference/image2mesh.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task image2mesh \
3 | --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml \
4 | --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt \
5 | --image_path ./example_data/image/car.jpg
--------------------------------------------------------------------------------
/scripts/inference/reconstruction.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task reconstruction \
3 | --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
4 | --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \
5 | --pointcloud_path ./example_data/surface/surface.npz
--------------------------------------------------------------------------------
/scripts/inference/text2mesh.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --task text2mesh \
3 | --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml \
4 | --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt \
5 | --text "A 3D model of motorcar; Porsche Cayenne Turbo."
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | 
4 | setup(
5 |     name="michelangelo",
6 |     version="0.4.1",
7 |     author="Zibo Zhao, Wen Liu and Xin Chen",
8 |     author_email="liuwen@shanghaitech.edu.cn",
9 |     description="Michelangelo: a 3D Shape Generation System.",
10 |     packages=find_packages(exclude=("configs", "tests", "scripts", "example_data")),
11 |     python_requires=">=3.8",
12 |     install_requires=[
13 |         "torch",
14 |         "numpy",
15 |         "cython",
16 |         "tqdm",
17 |     ],
18 | )
19 | 
--------------------------------------------------------------------------------