├── LICENSE ├── README.md ├── assets ├── images │ ├── meshxl_logo.jpg │ ├── objaverse-samples.png │ ├── pipeline.png │ ├── teaser.png │ └── text-to-mesh-samples.png └── videos │ └── mesh_grid.mp4 ├── config └── deepspeed_stage2.yaml ├── data └── README.md ├── datasets ├── __init__.py ├── dummy_dataset.py └── sft │ ├── base_dataset.py │ ├── shapenet_bench.py │ ├── shapenet_chair.py │ ├── shapenet_lamp.py │ └── shapenet_table.py ├── engine.py ├── eval_utils ├── perplexity.py └── sample_generation.py ├── main.py ├── mesh-xl ├── mesh-xl-1.3b │ ├── config.json │ ├── generation_config.json │ └── pytorch_model.bin ├── mesh-xl-125m │ ├── config.json │ ├── generation_config.json │ └── pytorch_model.bin ├── mesh-xl-350m │ ├── config.json │ ├── generation_config.json │ └── pytorch_model.bin └── x-mesh-xl-350m │ ├── config.json │ ├── generation_config.json │ └── pytorch_model.bin ├── models ├── mesh_xl │ ├── get_model.py │ └── tokenizer.py └── x_mesh_xl │ ├── get_model.py │ └── tokenizer.py ├── openai └── clip-vit-base-patch32 │ ├── config.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer_config.json │ └── vocab.json ├── requirement.txt ├── sample_t2m.py ├── scripts ├── meshxl-sft-shapenet.sh ├── sample-1.3b.sh ├── sample-125m.sh ├── sample-350m.sh └── sample-t2mesh.sh ├── set_env.sh └── utils ├── dist.py ├── io.py ├── logger.py ├── misc.py ├── nms.py ├── pc_util.py └── ply_helper.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OpenMeshLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
<div align='center'> 3 | <h3>Official repo for MeshXL</h3> 4 | 5 | <img src='./assets/images/meshxl_logo.jpg' alt='MeshXL logo'> 6 | 7 | 8 | <h1>MeshXL: Neural Coordinate Field for Generative 3D Foundation Models</h1> 9 | 10 | <p> 11 | <a href='https://meshxl.github.io/'>Project Page</a> • 12 | <a href='https://arxiv.org/abs/2405.20853'>Arxiv Paper</a> • 13 | HuggingFace Demo • 14 | Citation 15 | </p> 16 | 17 | </div>
18 | 19 | 20 | ## 🏃 Intro MeshXL 21 | 22 | **MeshXL** is a family of generative pre-trained foundation models for 3D mesh generation. With the Neural Coordinate Field representation, the generation of unstructured 3D mesh data can be seamlessly addressed by modern LLM methods. In this paper, we validate that the **Neur**al **C**oordinate **F**ield (NeurCF), an explicit coordinate representation with implicit neural embeddings, is a simple-yet-effective representation for large-scale sequential mesh modeling. 23 | 24 | ![pipeline](./assets/images/pipeline.png) 25 | 26 | 27 | ## 🚩 News 28 | 29 | - [2025/04/04] Upload pre-processed text-to-mesh generation data to [huggingface🤗](https://huggingface.co/datasets/CH3COOK/MeshXL-text-to-mesh-sketchfab). 30 | - [2024/12/12] Upload pre-processed ShapeNet data to [huggingface🤗](https://huggingface.co/datasets/CH3COOK/MeshXL-shapenet-data) and supervised fine-tuning scripts on specified categories. 31 | - [2024/09/26] MeshXL is accepted to **NeurIPS 2024**🔥! See you in Vancouver! 32 | - [2024/08/29] Upload code and 🤗[weights](https://huggingface.co/CH3COOK/x-mesh-xl-350m/blob/main/pytorch_model.bin) for text-to-mesh generation, welcome to check it out! 33 | - [2024/07/24] Upload the inference code and pre-trained weights. 34 | - [2024/06/02] Upload paper and init project. 35 | 36 | 37 | ## ⚡ Quick Start 38 | 39 | 40 |
41 | ### Environment Setup 42 | 43 | You can build the environment using the provided script: 44 | ```{bashrc} 45 | bash set_env.sh 46 | ``` 47 | 48 |
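Once the script finishes, a quick sanity check of the core dependencies can save debugging time later. The sketch below is hypothetical and not part of the official setup; it only assumes that `torch`, `transformers`, and `accelerate` were installed by `set_env.sh`:

```{python}
# minimal environment sanity check (hypothetical; not part of the official instructions)
import torch
import transformers
import accelerate

print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('transformers:', transformers.__version__)
print('accelerate:', accelerate.__version__)
# the provided sampling scripts default to bf16 mixed precision
print('bf16 supported:', torch.cuda.is_available() and torch.cuda.is_bf16_supported())
```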
49 | 50 | 51 |
52 | ### Data 53 | Work in Progress... 54 | 55 |
56 | 57 | 58 | 59 | 60 | 61 | ## 💻 Training and Evaluation 62 | 63 |
64 | ### Download Pre-Trained Weights 65 | 66 | **\[Special Notes\]**: All the following models are generative pre-trained base models. They are capable of **unconditional** 3D mesh generation and **partial mesh completion**. 67 | 68 | We provide pre-trained weights for different sizes of models (i.e., `125m`, `350m`, and `1.3b`) on huggingface🤗. Download the pre-trained weights from the links below to replace the `pytorch_model.bin` files in the corresponding folders under `./mesh-xl/`. The model details are shown below: 69 | 70 | | Model Size | #Layers | #Heads | $d_\text{model}$ | $d_\text{FFN}$ | GPU Hours | Download Link | 71 | |:----------:|:-------:|:------:|:----------------:|:--------------:|:---------:|:---------------------------------------------------:| 72 | | 125M | 12 | 12 | 768 | 3072 | 1944 | [download link](https://huggingface.co/CH3COOK/mesh-xl-125m) | 73 | | 350M | 24 | 16 | 1024 | 4096 | 6000 | [download link](https://huggingface.co/CH3COOK/mesh-xl-350m) | 74 | | 1.3B | 24 | 32 | 2048 | 8192 | 23232 | [download link](https://huggingface.co/CH3COOK/mesh-xl-1.3b) | 75 | 76 | Use the following commands for fast downloading: 77 | ``` 78 | cd ./mesh-xl 79 | git lfs clone https://huggingface.co/CH3COOK/mesh-xl-125m 80 | git lfs clone https://huggingface.co/CH3COOK/mesh-xl-350m 81 | git lfs clone https://huggingface.co/CH3COOK/mesh-xl-1.3b 82 | cd .. 83 | ``` 84 | 85 |
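If `git lfs` is not available on your machine, the same repositories can also be fetched from Python. This is a sketch, not part of the official instructions; it assumes the `huggingface_hub` package is installed:

```{python}
# alternative download via huggingface_hub (assumption: pip install huggingface_hub)
from huggingface_hub import snapshot_download

# mirror the ./mesh-xl/<model> layout that the sampling scripts expect
for repo in ('mesh-xl-125m', 'mesh-xl-350m', 'mesh-xl-1.3b'):
    snapshot_download(repo_id=f'CH3COOK/{repo}', local_dir=f'./mesh-xl/{repo}')
```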
86 | 87 | 88 |
89 | ### MeshXL Generative Pre-Training 90 | 91 | Work in progress... 92 | 93 |
94 | 95 | 96 | 97 |
98 | ### Generating Samples 99 | 100 | ![samples](./assets/images/objaverse-samples.png) 101 | 102 | To generate 3D meshes with models of different sizes, feel free to use the following commands. By default, we generate samples on 8 GPUs with the top-k top-p sampling strategy for diverse samples. 103 | 104 | ```{bashrc} 105 | bash scripts/sample-1.3b.sh 106 | bash scripts/sample-350m.sh 107 | bash scripts/sample-125m.sh 108 | ``` 109 | 110 | **\[Special Notes\]**: The following weights are fine-tuned for **unconditional** 3D mesh generation on a **specified** category. 111 | 112 | Want to generate shapes for a specified category? We have also uploaded supervised fine-tuned checkpoints for `chair`, `table`, `bench`, and `lamp` to huggingface! Download the fine-tuned weights from the links🤗 below. 113 | 114 | | Model Size | Table | Chair | Lamp | Bench | 115 | |:----------:|:-----------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:| 116 | | 125M | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-table.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-chair.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-lamp.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-bench.pth) | 117 | | 350M | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-table.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-chair.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-lamp.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-bench.pth) | 118 | | 1.3B | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-table.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-chair.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-lamp.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-bench.pth) | 119 | 120 | 121 | After you have downloaded the corresponding checkpoints, please use the following command to generate samples. 122 | 123 | ```{bashrc} 124 | export LLM_CONFIG='mesh-xl/mesh-xl-125m' 125 | # the checkpoint must align with the $LLM_CONFIG 126 | export TEST_CKPT='./ckpts-meshxl-125m-sft/meshxl-125m-bench.pth' 127 | 128 | accelerate launch \ 129 | --num_machines 1 \ 130 | --num_processes 8 \ 131 | --mixed_precision bf16 \ 132 | main.py \ 133 | --dataset dummy_dataset \ 134 | --n_max_triangles 800 \ 135 | --n_discrete_size 128 \ 136 | --llm mesh-xl/mesh-xl-125m \ 137 | --model mesh_xl \ 138 | --checkpoint_dir ./outputs \ 139 | --batchsize_per_gpu 2 \ 140 | --test_ckpt $TEST_CKPT \ 141 | --sample_rounds 100 \ 142 | --dataset_num_workers 0 \ 143 | --test_only 144 | ``` 145 | 146 | Want to see more results? Check out our project page [here](https://meshxl.github.io/)! 147 | 148 |
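Generated shapes are written as `.ply` files named `<rank>_<round>_<batch>_generated.ply` under `<checkpoint_dir>/sampled` (see `eval_utils/sample_generation.py`). To quickly inspect one, a short sketch like the following works; it assumes the third-party `trimesh` package, which is not a repo dependency:

```{python}
# inspect a generated sample (assumption: pip install trimesh)
import trimesh

# with the command above, samples land under ./outputs/sampled
mesh = trimesh.load('./outputs/sampled/0000_0000_0000_generated.ply')
print(f'{len(mesh.vertices)} vertices, {len(mesh.faces)} faces')
mesh.show()  # opens an interactive viewer window
```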
149 | 150 | 151 |
152 | ### Text-to-Mesh Generation 153 | 154 | ![samples](./assets/images/text-to-mesh-samples.png) 155 | 156 | We thank [PointLLM](https://github.com/OpenRobotLab/PointLLM) for the awesome language annotations we use as object captions. We fine-tune a `350m` MeshXL model on Objaverse with 8 RTX-3090 GPUs. 157 | 158 | **Note:** please download the pre-trained checkpoint from [huggingface](https://huggingface.co/CH3COOK/x-mesh-xl-350m/blob/main/pytorch_model.bin)🤗 to replace the `./mesh-xl/x-mesh-xl-350m/pytorch_model.bin` file. 159 | 160 | We are actively working on Gradio demos. Currently, we encourage you to generate samples locally with at least one GPU using the following code: 161 | 162 | ```{bashrc} 163 | bash scripts/sample-t2mesh.sh 164 | ``` 165 | 166 | You are also welcome to explore other text conditions and hyper-parameters for better control: 167 | 168 | ```{bashrc} 169 | # change --text to the prompt you need; larger top_k / top_p / temperature -> larger randomness 170 | accelerate launch \ 171 | --num_machines 1 \ 172 | --num_processes 1 \ 173 | --mixed_precision bf16 \ 174 | sample_t2m.py \ 175 | --test_ckpt mesh-xl/x-mesh-xl-350m/pytorch_model.bin \ 176 | --text '3d model of a table' \ 177 | --top_k 50 \ 178 | --top_p 0.95 \ 179 | --temperature 0.1 180 | ``` 181 | 182 | [Update 2025-04-04] We have uploaded the training and evaluation data used for text-to-mesh generation to [huggingface](https://huggingface.co/datasets/CH3COOK/MeshXL-text-to-mesh-sketchfab). You can load one data chunk through: 183 | 184 | ```{python} 185 | import numpy as np 186 | data = np.load('ext_text_to_mesh/sketchfab_chunk/sketchfab_pointllm_train_0000.npz', allow_pickle=True)["arr_0"].tolist() 187 | # To check the caption for a specified object: 188 | data[0]['caption'] 189 | 190 | >>> ['3D white arrow on gray background'] 191 | ```
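Each chunk is a list of per-object dictionaries. Beyond `caption`, the entries should also carry the mesh geometry; the sketch below assumes `vertices`/`faces` keys mirroring the ShapeNet SFT chunks in `datasets/sft/`, which is an assumption worth verifying on your download:

```{python}
import numpy as np

data = np.load('ext_text_to_mesh/sketchfab_chunk/sketchfab_pointllm_train_0000.npz', allow_pickle=True)["arr_0"].tolist()
print(len(data), 'objects in this chunk')
for item in data[:3]:
    # print every available key; 'vertices'/'faces' presence is assumed, not guaranteed
    print(sorted(item.keys()), '|', item['caption'][0])
```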
192 | 193 | 194 | 195 |
196 | ### MeshXL Supervised Fine-Tuning 197 | 198 | Please first download the pre-processed ShapeNet data to the `./data` folder from huggingface: 199 | ``` 200 | cd ./data 201 | git lfs clone https://huggingface.co/datasets/CH3COOK/MeshXL-shapenet-data 202 | cd .. 203 | ``` 204 | Then, use the following command to fine-tune on a specified category: 205 | ``` 206 | export BASE_MESHXL=mesh-xl/mesh-xl-1.3b # TODO: change the MeshXL config 207 | export BATCHSIZE_PER_GPU=4 # TODO: change the training batch size 208 | 209 | # TODO: change --dataset to the target category, e.g. sft.shapenet_table / sft.shapenet_chair / sft.shapenet_bench / sft.shapenet_lamp 210 | accelerate launch \ 211 | --config_file ./config/deepspeed_stage2.yaml \ 212 | --num_machines 1 \ 213 | --num_processes 8 \ 214 | --mixed_precision bf16 \ 215 | main.py \ 216 | --dataset sft.shapenet_table \ 217 | --n_max_triangles 800 \ 218 | --n_discrete_size 128 \ 219 | --warm_lr_iters -1 \ 220 | --base_lr 1e-6 \ 221 | --llm $BASE_MESHXL \ 222 | --model mesh_xl \ 223 | --checkpoint_dir ./ckpts/meshxl-shapenet-sft-table \ 224 | --batchsize_per_gpu $BATCHSIZE_PER_GPU \ 225 | --dataset_num_workers 0 \ 226 | --augment \ 227 | --eval_every_iteration 10000 \ 228 | --save_every 20000 \ 229 | --max_epoch 1024 230 | ``` 231 | 232 | 233 |
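Before launching a run, it can help to check that a downloaded chunk matches what `datasets/sft/base_dataset.py` expects. This sketch only relies on the filename convention visible in `datasets/sft/shapenet_table.py` (ShapeNet category id `04379243`, split name in the filename, `.npz` extension); treat it as a quick, unofficial smoke test:

```{python}
import os
import numpy as np

DATA_DIR = './data/MeshXL-shapenet-data'
# chunk filenames embed the ShapeNet category id (04379243 = table) and the split name
fname = next(f for f in os.listdir(DATA_DIR)
             if '04379243' in f and 'train' in f and f.endswith('.npz'))
chunk = np.load(os.path.join(DATA_DIR, fname), allow_pickle=True)["arr_0"].tolist()
sample = chunk[0]
print(fname, '|', len(chunk), 'shapes')
print('vertices:', np.asarray(sample['vertices']).shape,
      '| faces:', np.asarray(sample['faces']).shape)
```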
234 | 235 | 236 | ## 📖 Citation 237 | 238 | If you find our code or paper helpful, please consider citing: 239 | 240 | ```bibtex 241 | @misc{chen2024meshxl, 242 | title={MeshXL: Neural Coordinate Field for Generative 3D Foundation Models}, 243 | author={Sijin Chen and Xin Chen and Anqi Pang and Xianfang Zeng and Wei Cheng and Yijun Fu and Fukun Yin and Yanru Wang and Zhibin Wang and Chi Zhang and Jingyi Yu and Gang Yu and Bin Fu and Tao Chen}, 244 | year={2024}, 245 | eprint={2405.20853}, 246 | archivePrefix={arXiv}, 247 | primaryClass={cs.CV} 248 | } 249 | ``` 250 | 251 | ## Acknowledgments 252 | We use [Paint3D](https://github.com/OpenTexture/Paint3D) for texturing generated 3D meshes. 253 | 254 | We express our sincere thanks to the amazing works: [ShapeNet](https://shapenet.org/), [3D-FUTURE](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future-cn), [Objaverse](https://github.com/allenai/objaverse-xl), [Objaverse-XL](https://github.com/allenai/objaverse-xl), [PolyGen](https://github.com/google-deepmind/deepmind-research/blob/master/polygen/README.md), [Get3D](https://github.com/nv-tlabs/GET3D) and [MeshGPT](https://github.com/nihalsid/mesh-gpt), and the amazing [MeshGPT-pytorch](https://github.com/lucidrains/meshgpt-pytorch) codebase. 255 | -------------------------------------------------------------------------------- /assets/images/meshxl_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/meshxl_logo.jpg -------------------------------------------------------------------------------- /assets/images/objaverse-samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/objaverse-samples.png -------------------------------------------------------------------------------- /assets/images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/pipeline.png -------------------------------------------------------------------------------- /assets/images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/teaser.png -------------------------------------------------------------------------------- /assets/images/text-to-mesh-samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/text-to-mesh-samples.png -------------------------------------------------------------------------------- /assets/videos/mesh_grid.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/videos/mesh_grid.mp4 -------------------------------------------------------------------------------- /config/deepspeed_stage2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | 
offload_param_device: none 7 | zero3_init_flag: true 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_process_ip: null 13 | main_process_port: null 14 | main_training_function: main 15 | mixed_precision: fp16 16 | num_machines: 4 17 | num_processes: 32 18 | use_cpu: false -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Dataset Preparation 2 | 3 | 4 | ### ShapeNet SFT Data 5 | 6 | Download the pre-processed ShapeNet data from huggingface: 7 | ``` 8 | git lfs clone https://huggingface.co/datasets/CH3COOK/MeshXL-shapenet-data 9 | ``` -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/dummy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from copy import deepcopy 5 | from eval_utils.sample_generation import evaluate 6 | 7 | class Dataset: 8 | 9 | def __init__(self, *args, **kwargs): 10 | super().__init__() 11 | self.eval_func = evaluate 12 | 13 | def __len__(self): 14 | return 10 15 | 16 | def __getitem__(self, idx): 17 | data_dict = {} 18 | data_dict['shape_idx'] = np.asarray(idx).astype(np.int64) 19 | return data_dict 20 | -------------------------------------------------------------------------------- /datasets/sft/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from copy import deepcopy 5 | from eval_utils.perplexity import evaluate 6 | 7 | BASE_DIR = os.path.join('.', 'data') 8 | 9 | 10 | def scale_mesh(vertices: np.ndarray, scale_factor: tuple=(0.75, 1.25)) -> np.ndarray: 11 | lower, upper = scale_factor 12 | scale_axis = lower + np.random.rand(vertices.shape[-1]) * (upper - lower) 13 | vertices = vertices * scale_axis 14 | return vertices 15 | 16 | 17 | def normalize_mesh(vertices: np.ndarray, scale_range: tuple=(-1.0, 1.0)) -> np.ndarray: 18 | lower, upper = scale_range 19 | scale_per_axis = (vertices.max(0) - vertices.min(0)).max() 20 | center_xyz = 0.5 * (vertices.max(0) + vertices.min(0)) 21 | normalized_xyz = (vertices - center_xyz) / scale_per_axis # centered and scaled into range (-0.5, 0.5) 22 | vertices = normalized_xyz * (upper - lower) 23 | return vertices 24 | 25 | 26 | class BaseDataset: 27 | 28 | def __init__(self, *args, **kwargs): 29 | super().__init__() 30 | self.data = [] 31 | self.num_repeat = 1 32 | 33 | def _preprocess_data(self, data_chunk): 34 | processed = [] 35 | for data in data_chunk: 36 | processed.append( 37 | dict( 38 | vertices = np.asarray(data['vertices']), 39 | faces = np.asarray(data['faces']), 40 | ) 41 | ) 42 | return processed 43 | 44 | def _fetch_data(self, idx): 45 | idx = idx % len(self.data) 46 | return deepcopy(self.data[idx]) 47 | 48 | def __len__(self): 49 | return len(self.data) * self.num_repeat 50 | 51 | def __getitem__(self, idx): 52 | data = self._fetch_data(idx) 53 | data['vertices'] = np.asarray(data['vertices']) 54 | data['faces'] = np.asarray(data['faces']) 55 | 56 | num_vertices = len(data['vertices']) 57
| num_faces = len(data['faces']) 58 | 59 | vertices = np.ones((self.max_vertices, 3)) * self.pad_id 60 | faces = np.ones((self.max_triangles, 3)) * self.pad_id 61 | 62 | if self.augment is True: 63 | data['vertices'] = scale_mesh(data['vertices']) 64 | data['vertices'] = normalize_mesh(data['vertices']) 65 | 66 | vertices[:num_vertices] = data['vertices'] 67 | faces[:num_faces] = data['faces'] 68 | 69 | gt_vertices = vertices[faces.clip(0).astype(np.int64)] # nface x 3 x 3 70 | gt_vertices[faces[:, 0] == self.pad_id] = float('nan') 71 | 72 | data_dict = {} 73 | data_dict['shape_idx'] = np.asarray(idx).astype(np.int64) 74 | data_dict['vertices'] = np.asarray(vertices).astype(np.float32) 75 | data_dict['faces'] = np.asarray(faces).astype(np.int64) 76 | data_dict['gt_vertices'] = np.asarray(gt_vertices).astype(np.float32) 77 | 78 | return data_dict 79 | -------------------------------------------------------------------------------- /datasets/sft/shapenet_bench.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from eval_utils.perplexity import evaluate 5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR 6 | 7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data') 8 | 9 | 10 | 11 | class Dataset(BaseDataset): 12 | 13 | def __init__(self, args, split_set="train", augment=False): 14 | super().__init__() 15 | 16 | # base dataset config 17 | self.dataset_name = 'shapenet_bench' 18 | self.category_id = '02828884' 19 | self.eval_func = evaluate 20 | self.augment = augment and (split_set == 'train') 21 | self.num_repeat = 1 22 | self.pad_id = -1 23 | self.max_triangles = args.n_max_triangles 24 | self.max_vertices = self.max_triangles * 3 25 | 26 | # pre-load data into memory 27 | full_data = [] 28 | for filename in tqdm(os.listdir(DATASET_DIR)): 29 | if self.category_id not in filename: 30 | continue 31 | if (split_set in filename) and filename.endswith('.npz'): 32 | loaded_data = np.load( 33 | os.path.join(DATASET_DIR, filename), 34 | allow_pickle=True 35 | ) 36 | loaded_data = loaded_data["arr_0"].tolist() 37 | loaded_data = self._preprocess_data(loaded_data) 38 | full_data = full_data + loaded_data 39 | 40 | self.data = full_data 41 | 42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}") 43 | -------------------------------------------------------------------------------- /datasets/sft/shapenet_chair.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from eval_utils.perplexity import evaluate 5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR 6 | 7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data') 8 | 9 | 10 | 11 | class Dataset(BaseDataset): 12 | 13 | def __init__(self, args, split_set="train", augment=False): 14 | super().__init__() 15 | 16 | # base dataset config 17 | self.dataset_name = 'shapenet_chair' 18 | self.category_id = '03001627' 19 | self.eval_func = evaluate 20 | self.augment = augment and (split_set == 'train') 21 | self.num_repeat = 1 22 | self.pad_id = -1 23 | self.max_triangles = args.n_max_triangles 24 | self.max_vertices = self.max_triangles * 3 25 | 26 | # pre-load data into memory 27 | full_data = [] 28 | for filename in tqdm(os.listdir(DATASET_DIR)): 29 | if self.category_id not in filename: 30 | continue 31 | if (split_set in filename) and filename.endswith('.npz'): 32 | 
loaded_data = np.load( 33 | os.path.join(DATASET_DIR, filename), 34 | allow_pickle=True 35 | ) 36 | loaded_data = loaded_data["arr_0"].tolist() 37 | loaded_data = self._preprocess_data(loaded_data) 38 | full_data = full_data + loaded_data 39 | 40 | self.data = full_data 41 | 42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}") 43 | -------------------------------------------------------------------------------- /datasets/sft/shapenet_lamp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from eval_utils.perplexity import evaluate 5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR 6 | 7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data') 8 | 9 | 10 | 11 | class Dataset(BaseDataset): 12 | 13 | def __init__(self, args, split_set="train", augment=False): 14 | super().__init__() 15 | 16 | # base dataset config 17 | self.dataset_name = 'shapenet_lamp' 18 | self.category_id = '03636649' 19 | self.eval_func = evaluate 20 | self.augment = augment and (split_set == 'train') 21 | self.num_repeat = 1 22 | self.pad_id = -1 23 | self.max_triangles = args.n_max_triangles 24 | self.max_vertices = self.max_triangles * 3 25 | 26 | # pre-load data into memory 27 | full_data = [] 28 | for filename in tqdm(os.listdir(DATASET_DIR)): 29 | if self.category_id not in filename: 30 | continue 31 | if (split_set in filename) and filename.endswith('.npz'): 32 | loaded_data = np.load( 33 | os.path.join(DATASET_DIR, filename), 34 | allow_pickle=True 35 | ) 36 | loaded_data = loaded_data["arr_0"].tolist() 37 | loaded_data = self._preprocess_data(loaded_data) 38 | full_data = full_data + loaded_data 39 | 40 | self.data = full_data 41 | 42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}") 43 | -------------------------------------------------------------------------------- /datasets/sft/shapenet_table.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from eval_utils.perplexity import evaluate 5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR 6 | 7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data') 8 | 9 | 10 | 11 | class Dataset(BaseDataset): 12 | 13 | def __init__(self, args, split_set="train", augment=False): 14 | super().__init__() 15 | 16 | # base dataset config 17 | self.dataset_name = 'shapenet_table' 18 | self.category_id = '04379243' 19 | self.eval_func = evaluate 20 | self.augment = augment and (split_set == 'train') 21 | self.num_repeat = 1 22 | self.pad_id = -1 23 | self.max_triangles = args.n_max_triangles 24 | self.max_vertices = self.max_triangles * 3 25 | 26 | # pre-load data into memory 27 | full_data = [] 28 | for filename in tqdm(os.listdir(DATASET_DIR)): 29 | if self.category_id not in filename: 30 | continue 31 | if (split_set in filename) and filename.endswith('.npz'): 32 | loaded_data = np.load( 33 | os.path.join(DATASET_DIR, filename), 34 | allow_pickle=True 35 | ) 36 | loaded_data = loaded_data["arr_0"].tolist() 37 | loaded_data = self._preprocess_data(loaded_data) 38 | full_data = full_data + loaded_data 39 | 40 | self.data = full_data 41 | 42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}") 43 | -------------------------------------------------------------------------------- /engine.py: 
-------------------------------------------------------------------------------- 1 | import time, math 2 | import torch 3 | import datetime 4 | 5 | from utils.io import save_checkpoint 6 | from utils.misc import SmoothedValue 7 | 8 | 9 | 10 | def compute_learning_rate(args, curr_iter, max_iters): 11 | assert curr_iter <= max_iters and curr_iter >= 0 12 | if (curr_iter <= args.warm_lr_iters) and args.warm_lr_iters > 0: 13 | # Linear Warmup: warm_lr -> base_lr 14 | curr_lr = args.warm_lr + curr_iter / args.warm_lr_iters * (args.base_lr - args.warm_lr) 15 | else: 16 | # Cosine Learning Rate Schedule 17 | curr_lr = args.final_lr + 0.5 * (args.base_lr - args.final_lr) * ( 18 | 1 + math.cos(math.pi * curr_iter / max_iters) 19 | ) 20 | return curr_lr 21 | 22 | 23 | 24 | def adjust_learning_rate(args, optimizer, curr_iter, max_iters): 25 | curr_lr = compute_learning_rate(args, curr_iter, max_iters) 26 | for param_group in optimizer.param_groups: 27 | param_group["lr"] = curr_lr 28 | return curr_lr 29 | 30 | 31 | 32 | def do_train( 33 | args, 34 | model, 35 | accelerator, 36 | optimizer, 37 | dataloaders, 38 | best_val_metrics, 39 | logger 40 | ): 41 | 42 | if accelerator.is_main_process: 43 | logger.log_messages(f"call with args: {args}") 44 | logger.log_messages(f"{model}") 45 | 46 | curr_iter = args.start_epoch * len(dataloaders['train']) 47 | max_iters = args.max_epoch * len(dataloaders['train']) 48 | 49 | time_delta = SmoothedValue(window_size=10) 50 | loss_avg = SmoothedValue(window_size=10) 51 | loss_break_down_avg = {} 52 | 53 | model.train() 54 | accelerator.wait_for_everyone() 55 | 56 | for curr_epoch in range(args.start_epoch, args.max_epoch): 57 | 58 | for batch_idx, batch_data_label in enumerate(dataloaders['train']): 59 | 60 | curr_time = time.time() 61 | 62 | ### core for model training 63 | 64 | curr_iter = curr_epoch * len(dataloaders['train']) + batch_idx 65 | curr_lr = adjust_learning_rate(args, optimizer, curr_iter, max_iters) 66 | 67 | with accelerator.accumulate(model): 68 | 69 | with accelerator.autocast(): 70 | outputs = model(batch_data_label) 71 | loss = outputs['loss'] 72 | 73 | # sanity check, skip the non-finite loss 74 | if not math.isfinite(loss.item()): 75 | logger.log_messages("Loss is not finite. 
Skip this iteration.") 76 | model.eval() 77 | model.train() 78 | torch.cuda.empty_cache() 79 | continue 80 | 81 | accelerator.backward(loss) 82 | if args.clip_gradient > 0: 83 | accelerator.clip_grad_norm_(model.parameters(), args.clip_gradient) 84 | 85 | optimizer.step() 86 | optimizer.zero_grad() 87 | 88 | ### logging training loss status 89 | 90 | time_delta.update(time.time() - curr_time) 91 | loss_avg.update(loss.item()) 92 | 93 | for key, value in outputs.items(): 94 | if 'loss' in key.lower(): 95 | loss_break_down_avg[key] = loss_break_down_avg.get(key, SmoothedValue(window_size=10)) 96 | loss_break_down_avg[key].update(value.item()) 97 | 98 | ### writing logs 99 | 100 | if accelerator.is_main_process and curr_iter % args.log_every == 0: 101 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2) 102 | eta_seconds = (max_iters - curr_iter) * time_delta.avg 103 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds))) 104 | 105 | logger.log_messages( 106 | '; '.join( 107 | [ 108 | f"Epoch [{curr_epoch}/{args.max_epoch}]", 109 | f"Iter [{curr_iter}/{max_iters}]", 110 | # loss string 111 | *( 112 | f'{key} {avg.avg:0.4f}' \ 113 | for key, avg in loss_break_down_avg.items() 114 | ), 115 | # status string 116 | f"LR {curr_lr:0.2e}", 117 | f"Iter time {time_delta.avg:0.2f}s", 118 | f"ETA {eta_str}", 119 | f"Mem {mem_mb:0.2f}MB" 120 | ] 121 | ) 122 | ) 123 | train_loss_log = {k: v.avg for k, v in loss_break_down_avg.items()} 124 | train_loss_log["learning_rate"] = curr_lr 125 | logger.log_scalars(train_loss_log, prefix='train', step=curr_iter) 126 | 127 | ### saving checkpoints 128 | 129 | if accelerator.is_main_process and (curr_iter + 1) % args.save_every == 0: 130 | save_checkpoint( 131 | args.checkpoint_dir, 132 | accelerator.unwrap_model(model), 133 | optimizer, 134 | curr_epoch, 135 | args, 136 | best_val_metrics, 137 | filename=f"checkpoint_{(curr_iter + 1) // 1000}k.pth", 138 | ) 139 | 140 | ### pending and doing evaluations: every xxx after xxx iterations 141 | 142 | do_eval_flag = (curr_iter + 1) % args.eval_every_iteration == 0 143 | do_eval_flag &= (curr_iter + 1) > args.start_eval_after 144 | do_eval_flag |= (curr_iter + 1) == max_iters 145 | 146 | if do_eval_flag is True: 147 | eval_metrics = {} 148 | model.eval() 149 | with accelerator.autocast(): 150 | for test_loader in dataloaders['test']: 151 | task_metrics, eval_loss_dict = test_loader.dataset.eval_func( 152 | args, 153 | curr_epoch, 154 | accelerator.unwrap_model(model), 155 | accelerator, 156 | test_loader, 157 | logger, 158 | curr_train_iter=curr_iter 159 | ) 160 | eval_metrics.update(task_metrics) 161 | logger.log_scalars(eval_loss_dict, prefix='val', step=curr_iter) 162 | model.train() 163 | 164 | ### saving `checkpoint_best.pth` do nothing for unknown criterion 165 | 166 | if args.criterion is None: 167 | continue 168 | 169 | if not best_val_metrics or ( 170 | best_val_metrics[args.criterion] < eval_metrics[args.criterion] 171 | ): 172 | best_val_metrics = eval_metrics 173 | filename = "checkpoint_best.pth" 174 | save_checkpoint( 175 | args.checkpoint_dir, 176 | accelerator.unwrap_model(model), 177 | optimizer, 178 | curr_epoch, 179 | args, 180 | best_val_metrics, 181 | filename="checkpoint_best.pth", 182 | ) 183 | if accelerator.is_main_process: 184 | logger.log_messages( 185 | f"Epoch [{curr_epoch}/{args.max_epoch}] " 186 | f"saved current best val checkpoint at {filename}; " 187 | f"{args.criterion} {eval_metrics[args.criterion]}" 188 | ) 189 | 190 | ### end of an iteration 191 | 192 | ### end of an 
epoch 193 | 194 | save_checkpoint( 195 | args.checkpoint_dir, 196 | accelerator.unwrap_model(model), 197 | optimizer, 198 | curr_epoch, 199 | args, 200 | best_val_metrics, 201 | filename="checkpoint.pth", 202 | ) 203 | 204 | # end of training 205 | 206 | return 207 | -------------------------------------------------------------------------------- /eval_utils/perplexity.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import numpy as np 5 | from torch import nn, Tensor 6 | from collections import defaultdict, OrderedDict 7 | from utils.ply_helper import write_ply 8 | from utils.misc import SmoothedValue 9 | from accelerate.utils import set_seed 10 | 11 | 12 | 13 | def perplexity(neg_log_likelihood: list) -> Tensor: 14 | # gather per-sequence log likelihood for perplexity 15 | nll_chunk = torch.cat(neg_log_likelihood, dim=0) 16 | return torch.exp(nll_chunk.mean()) 17 | 18 | 19 | 20 | def post_process_mesh(mesh_coords: Tensor, filename: str): 21 | mesh_coords = mesh_coords[~torch.isnan(mesh_coords[:, 0, 0])] # nvalid_face x 3 x 3 22 | vertices = mesh_coords.reshape(-1, 3) 23 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face 24 | triangles = vertices_index.reshape(-1, 3) 25 | write_ply( 26 | np.asarray(vertices.cpu()), 27 | None, 28 | np.asarray(triangles), 29 | filename 30 | ) 31 | return vertices 32 | 33 | 34 | 35 | @torch.no_grad() 36 | def evaluate( 37 | args, 38 | curr_epoch, 39 | model, 40 | accelerator, 41 | dataset_loader, 42 | logger, 43 | curr_train_iter=-1, 44 | ): 45 | 46 | model.eval() 47 | net_device = next(model.parameters()).device 48 | num_batches = len(dataset_loader) 49 | 50 | ### parse evaluation status 51 | if hasattr(dataset_loader.dataset, "dataset_name"): 52 | dataset_name = dataset_loader.dataset.dataset_name 53 | else: 54 | dataset_name = "default" 55 | task_name_prefix = dataset_name + '_' 56 | 57 | time_delta = SmoothedValue(window_size=10) 58 | 59 | accelerator.wait_for_everyone() 60 | 61 | epoch_str = f"[{curr_epoch}/{args.max_epoch}]" if curr_epoch > 0 else "" 62 | 63 | if accelerator.is_main_process: 64 | logger.log_messages("==" * 10) 65 | logger.log_messages(f"Evaluate Epoch [{curr_epoch}/{args.max_epoch}]") 66 | logger.log_messages("==" * 10) 67 | 68 | ### calculate perplexity 69 | neg_log_likelihood = [] 70 | for curr_iter, batch_data_label in enumerate(dataset_loader): 71 | 72 | curr_time = time.time() 73 | 74 | # forward pass to calculate per-sequence negative log likelihood 75 | with accelerator.autocast(): 76 | outputs = model(batch_data_label, is_eval=True) 77 | # [(batch,), (batch,), ...] 
78 | neg_log_likelihood.append(outputs['neg_log_likelihood']) 79 | 80 | ### log status 81 | time_delta.update(time.time() - curr_time) 82 | 83 | if accelerator.is_main_process and curr_iter % args.log_every == 0: 84 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2) 85 | moving_average_ppl = perplexity(neg_log_likelihood) 86 | logger.log_messages( 87 | '; '.join( 88 | ( 89 | f"Evaluate {epoch_str}", 90 | f"Batch [{curr_iter}/{num_batches}]", 91 | f"perplexity: {moving_average_ppl:0.4f}", 92 | f"Evaluating on iter: {curr_train_iter}", 93 | f"Iter time {time_delta.avg:0.2f}", 94 | f"Mem {mem_mb:0.2f}MB", 95 | ) 96 | ) 97 | ) 98 | 99 | ### end of an iteration 100 | 101 | ### end of a round 102 | 103 | quantitative = { 104 | task_name_prefix + 'ppl': perplexity(neg_log_likelihood).item() 105 | } 106 | 107 | ### do sampling every evaluation 108 | curr_time = time.time() 109 | 110 | set_seed(accelerator.process_index) 111 | 112 | # create storage directory 113 | storage_dir = os.path.join(args.checkpoint_dir, task_name_prefix + 'visualization') 114 | if accelerator.is_main_process: 115 | os.makedirs(storage_dir, exist_ok = True) 116 | accelerator.wait_for_everyone() 117 | 118 | # just sample one round for checking training status 119 | for round_idx in range(1): 120 | outputs = model( 121 | data_dict=dict(), 122 | is_eval=True, 123 | is_generate=True, 124 | num_return_sequences=args.batchsize_per_gpu, 125 | ) 126 | 127 | generated_meshes = outputs["recon_faces"] # nsample x nf x 3 x 3 128 | 129 | for sample_idx in range(args.batchsize_per_gpu): 130 | # store the generated meshes 131 | post_process_mesh( 132 | generated_meshes[sample_idx], 133 | os.path.join( 134 | storage_dir, 135 | '_'.join( 136 | ( 137 | f'{accelerator.process_index:04d}', 138 | f'{round_idx:04d}', 139 | f'{sample_idx:04d}.ply', 140 | ) 141 | ) 142 | ) 143 | ) 144 | 145 | accelerator.wait_for_everyone() 146 | 147 | return {}, quantitative -------------------------------------------------------------------------------- /eval_utils/sample_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import tqdm 4 | import torch 5 | import numpy as np 6 | from torch import Tensor 7 | from collections import defaultdict, OrderedDict 8 | from utils.ply_helper import write_ply 9 | from utils.misc import SmoothedValue 10 | from accelerate.utils import set_seed 11 | 12 | 13 | def process_mesh(mesh_coords: Tensor, filename: str): 14 | mesh_coords = mesh_coords[~torch.isnan(mesh_coords[:, 0, 0])] # nvalid_face x 3 x 3 15 | vertices = mesh_coords.reshape(-1, 3) 16 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face 17 | triangles = vertices_index.reshape(-1, 3) 18 | write_ply( 19 | np.asarray(vertices.cpu()), 20 | None, 21 | np.asarray(triangles), 22 | filename 23 | ) 24 | return vertices 25 | 26 | 27 | @torch.no_grad() 28 | def evaluate( 29 | args, 30 | curr_epoch, 31 | model, 32 | accelerator, 33 | dataset_loader, 34 | logout=print, 35 | curr_train_iter=-1, 36 | ): 37 | 38 | net_device = next(model.parameters()).device 39 | num_batches = len(dataset_loader) 40 | 41 | time_delta = SmoothedValue(window_size=10) 42 | 43 | storage_dir = os.path.join(args.checkpoint_dir, 'sampled') 44 | if accelerator.is_main_process: 45 | os.makedirs(storage_dir, exist_ok = True) 46 | 47 | model.eval() 48 | accelerator.wait_for_everyone() 49 | 50 | # do sampling 51 | curr_time = time.time() 52 | 53 | set_seed(accelerator.process_index) 54 | 55 | for sample_round in 
tqdm.tqdm(range(args.sample_rounds)): 56 | 57 | outputs = model(None, num_return_sequences=args.batchsize_per_gpu, is_eval=True, is_generate=True) 58 | 59 | batch_size = outputs['recon_faces'].shape[0] 60 | generated_faces = outputs["recon_faces"] 61 | 62 | for batch_id in range(batch_size): 63 | process_info = f'{accelerator.process_index:04d}' 64 | sample_info = f'{sample_round:04d}' 65 | batch_sample_info = f'{batch_id:04d}' 66 | sample_id = '_'.join( 67 | [ 68 | process_info, 69 | sample_info, 70 | batch_sample_info 71 | ] 72 | ) 73 | process_mesh( 74 | generated_faces[batch_id], 75 | os.path.join(storage_dir, f'{sample_id}_generated.ply') 76 | ) 77 | 78 | # track the elapsed time of this sampling round 79 | time_delta.update(time.time() - curr_time) 80 | accelerator.wait_for_everyone() 81 | 82 | return {}, {} -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, argparse, importlib 2 | import numpy as np 3 | import torch 4 | 5 | from engine import do_train 6 | from accelerate import Accelerator 7 | from accelerate.utils import set_seed 8 | from utils.io import resume_if_possible 9 | from utils.misc import my_worker_init_fn 10 | from utils.logger import Logger 11 | 12 | 13 | def make_args_parser(): 14 | 15 | parser = argparse.ArgumentParser( 16 | "MeshXL: Neural Coordinate Field for Generative 3D Foundation Models", 17 | add_help=False 18 | ) 19 | 20 | ##### Optimizer ##### 21 | parser.add_argument("--base_lr", default=1e-4, type=float) 22 | parser.add_argument("--final_lr", default=1e-6, type=float) 23 | parser.add_argument("--weight_decay", default=0.1, type=float) 24 | parser.add_argument( 25 | "--clip_gradient", default=0.1, type=float, 26 | help="Max L2 norm of the gradient" 27 | ) 28 | parser.add_argument("--warm_lr", default=1e-6, type=float) 29 | parser.add_argument("--warm_lr_iters", default=1000, type=int) 30 | 31 | ##### Dataset Setups ##### 32 | parser.add_argument("--pad_id", default=-1, type=int, help="padding id") 33 | parser.add_argument("--dataset", default='shapenet_chair', help="comma-separated list of datasets") 34 | parser.add_argument("--augment", default=False, action='store_true', help="whether to use data augmentation") 35 | parser.add_argument("--n_discrete_size", default=128, type=int, help="discretized 3D space") 36 | parser.add_argument("--n_max_triangles", default=800, type=int, help="max number of triangles") 37 | 38 | ##### Model Setups ##### 39 | parser.add_argument( 40 | '--model', 41 | default=None, 42 | type=str, 43 | help="The model folder: unconditional / conditional mesh generation" 44 | ) 45 | parser.add_argument( 46 | '--llm', 47 | default=None, 48 | type=str, 49 | help="The LLM super config and pre-trained weights" 50 | ) 51 | # conditional mesh generation, set to None for unconditional generation 52 | parser.add_argument('--text_condition', default=None, type=str, help="the conditional language model") 53 | parser.add_argument('--image_condition', default=None, type=str, help="the conditional vision model") 54 | parser.add_argument('--pretrained_weights', default=None, type=str, help='checkpoint to pre-trained weights') 55 | parser.add_argument("--dataset_num_workers", default=4, type=int, help='number of workers for dataloader') 56 | parser.add_argument("--batchsize_per_gpu", default=8, type=int, help='batch size for each GPU') 57 | 58 | ##### Training ##### 59 | parser.add_argument("--start_epoch", 
default=-1, type=int, help='overwritten when resuming from a checkpoint') 60 | parser.add_argument("--max_epoch", default=16, type=int, help='number of traversals for the dataset') 61 | parser.add_argument("--start_eval_after", default=-1, type=int, help='do not evaluate the model before xxx iterations') 62 | parser.add_argument("--eval_every_iteration", default=4000, type=int, help='evaluate the model every xxx iterations') 63 | parser.add_argument("--seed", default=0, type=int, help='random seed') 64 | 65 | ##### Testing ##### 66 | parser.add_argument("--test_only", default=False, action="store_true") 67 | parser.add_argument("--sample_rounds", default=100, type=int, help='sample for xxx rounds to produce 3D meshes') 68 | 69 | parser.add_argument( 70 | "--criterion", default=None, type=str, 71 | help='metrics for saving the best model, set to None for not saving any' 72 | ) 73 | parser.add_argument("--test_ckpt", default="", type=str, help='test checkpoint directory') 74 | 75 | ##### I/O ##### 76 | parser.add_argument("--checkpoint_dir", default=None, type=str, help='path to save the checkpoints and visualization samples') 77 | parser.add_argument("--save_every", default=20000, type=int, help='save checkpoints every xxx iterations') 78 | parser.add_argument("--log_every", default=10, type=int, help='write training logs every xxx iterations') 79 | 80 | args = parser.parse_args() 81 | 82 | return args 83 | 84 | 85 | def build_dataloader_func(args, dataset, split): 86 | 87 | if split == "train": 88 | sampler = torch.utils.data.RandomSampler(dataset) 89 | else: 90 | sampler = torch.utils.data.SequentialSampler(dataset) 91 | 92 | dataloader = torch.utils.data.DataLoader( 93 | dataset, 94 | sampler=sampler, 95 | batch_size=args.batchsize_per_gpu, 96 | num_workers=args.dataset_num_workers, 97 | worker_init_fn=my_worker_init_fn, 98 | # drop the last incomplete batch 99 | drop_last = True, 100 | ) 101 | return sampler, dataloader 102 | 103 | 104 | def build_dataset_func(args): 105 | 106 | datasets = { 107 | 'train': [], 108 | 'test': [] 109 | } 110 | 111 | for dataset in args.dataset.split(','): 112 | dataset_module = importlib.import_module(f'datasets.{dataset}') 113 | datasets['train'].append( 114 | dataset_module.Dataset(args, split_set="train", augment=args.augment) 115 | ) 116 | datasets['test'].append( 117 | dataset_module.Dataset(args, split_set="val", augment=False) 118 | ) 119 | datasets['train'] = torch.utils.data.ConcatDataset(datasets['train']) 120 | 121 | train_sampler, train_loader = build_dataloader_func(args, datasets['train'], split='train') 122 | dataloaders = { 123 | 'train': train_loader, 124 | 'test': [], 125 | 'train_sampler': train_sampler, 126 | } 127 | 128 | for dataset in datasets['test']: 129 | _, test_loader = build_dataloader_func(args, dataset, split='test') 130 | dataloaders['test'].append(test_loader) 131 | 132 | return datasets, dataloaders 133 | 134 | 135 | def build_model_func(args): 136 | model_module = importlib.import_module(f'models.{args.model}.get_model') 137 | model = model_module.get_model(args) 138 | return model 139 | 140 | 141 | def main(args): 142 | 143 | np.random.seed(args.seed) 144 | torch.cuda.manual_seed_all(args.seed) 145 | 146 | if args.checkpoint_dir is not None: 147 | pass 148 | elif args.test_ckpt is not None: 149 | # if checkpoint_dir is not defined, default to the test checkpoint folder 150 | args.checkpoint_dir = os.path.dirname(args.test_ckpt) 151 | print(f'testing directory: {args.checkpoint_dir}') 152 | else: 153 | raise AssertionError( 154 | 
'Either checkpoint_dir or test_ckpt should be provided!' 155 | ) 156 | 157 | os.makedirs(args.checkpoint_dir, exist_ok=True) 158 | accelerator = Accelerator() 159 | set_seed(args.seed) 160 | 161 | ### build datasets and dataloaders 162 | datasets, dataloaders = build_dataset_func(args) 163 | 164 | ### build models 165 | model = build_model_func(args) 166 | ### set default checkpoint 167 | checkpoint = None 168 | 169 | # testing phase 170 | if args.test_only: 171 | 172 | try: 173 | checkpoint = torch.load(args.test_ckpt, map_location=torch.device("cpu")) 174 | model.load_state_dict(checkpoint["model"], strict=False) 175 | except: 176 | print('failed to load the test checkpoint, testing the model from scratch...') 177 | 178 | model, dataloaders['train'], *dataloaders['test'] = accelerator.prepare( 179 | model, dataloaders['train'], *dataloaders['test'] 180 | ) 181 | 182 | for test_loader in dataloaders['test']: 183 | test_loader.dataset.eval_func( 184 | args, 185 | -1, 186 | model, 187 | accelerator, 188 | test_loader 189 | ) 190 | 191 | # training phase 192 | else: 193 | 194 | assert ( 195 | args.checkpoint_dir is not None 196 | ), "`--checkpoint_dir` is required to identify the directory to store the checkpoint" 197 | os.makedirs(args.checkpoint_dir, exist_ok=True) 198 | 199 | logger = Logger(args.checkpoint_dir, accelerator) 200 | 201 | ### whether or not to use pretrained weights 202 | if args.pretrained_weights is not None: 203 | checkpoint = torch.load(args.pretrained_weights, map_location="cpu") 204 | model.load_state_dict(checkpoint['model'], strict=False) 205 | 206 | if accelerator.is_main_process: 207 | if checkpoint is not None: 208 | logger.log_messages('Loaded the parameters for weight initialization:') 209 | for name, param in checkpoint['model'].items(): 210 | logger.log_messages('\t'.join(['', name + ':', f'{param.shape}'])) 211 | logger.log_messages('\n' * 10) 212 | logger.log_messages('====\n') 213 | logger.log_messages('The trainable parameters are:') 214 | 215 | for name, param in model.named_parameters(): 216 | status = '[train]' if param.requires_grad else '[eval]' 217 | logger.log_messages('\t'.join(['', status, name + ':', f'{param.shape}'])) 218 | 219 | optimizer = torch.optim.AdamW( 220 | filter(lambda params: params.requires_grad, model.parameters()), 221 | lr=args.base_lr, 222 | weight_decay=args.weight_decay 223 | ) 224 | 225 | loaded_epoch, best_val_metrics = resume_if_possible( 226 | args.checkpoint_dir, model, optimizer 227 | ) 228 | args.start_epoch = loaded_epoch + 1 229 | 230 | model, optimizer, dataloaders['train'], *dataloaders['test'] = accelerator.prepare( 231 | model, optimizer, dataloaders['train'], *dataloaders['test'] 232 | ) 233 | 234 | do_train( 235 | args, 236 | model, 237 | accelerator, 238 | optimizer, 239 | dataloaders, 240 | best_val_metrics, 241 | logger 242 | ) 243 | 244 | 245 | if __name__ == "__main__": 246 | args = make_args_parser() 247 | 248 | os.environ['PYTHONWARNINGS']='ignore:semaphore_tracker:UserWarning' 249 | 250 | main(args) -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-1.3b/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "mesh-xl/mesh-xl-1.3b", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 8192, 14 | "hidden_size": 2048, 15 | 
"init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 32, 20 | "num_hidden_layers": 24, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "float16", 24 | "transformers_version": "4.21.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 50272, 27 | "word_embed_proj_dim": 2048 28 | } 29 | -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-1.3b/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 2, 4 | "eos_token_id": 2, 5 | "pad_token_id": 1, 6 | "transformers_version": "4.27.0.dev0" 7 | } 8 | -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-1.3b/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/mesh-xl-1.3b/pytorch_model.bin -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-125m/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "mesh-xl/mesh-xl-125m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 3072, 14 | "hidden_size": 768, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 12, 20 | "num_hidden_layers": 12, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "float16", 24 | "transformers_version": "4.21.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 50272, 27 | "word_embed_proj_dim": 768 28 | } 29 | -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-125m/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 2, 4 | "eos_token_id": 2, 5 | "pad_token_id": 1, 6 | "transformers_version": "4.27.0.dev0" 7 | } 8 | -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-125m/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/mesh-xl-125m/pytorch_model.bin -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-350m/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "mesh-xl/mesh-xl-350m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 4096, 14 | "hidden_size": 1024, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 16, 20 | "num_hidden_layers": 24, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "float16", 24 | "transformers_version": "4.20.0.dev0", 25 
| "use_cache": true, 26 | "vocab_size": 50272, 27 | "word_embed_proj_dim": 1024 28 | } 29 | -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-350m/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 2, 4 | "eos_token_id": 2, 5 | "pad_token_id": 1, 6 | "transformers_version": "4.27.0.dev0" 7 | } 8 | -------------------------------------------------------------------------------- /mesh-xl/mesh-xl-350m/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/mesh-xl-350m/pytorch_model.bin -------------------------------------------------------------------------------- /mesh-xl/x-mesh-xl-350m/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "mesh-xl/mesh-xl-350m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 4096, 14 | "hidden_size": 1024, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 16, 20 | "num_hidden_layers": 24, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "float16", 24 | "transformers_version": "4.20.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 50272, 27 | "word_embed_proj_dim": 1024 28 | } 29 | -------------------------------------------------------------------------------- /mesh-xl/x-mesh-xl-350m/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 2, 4 | "eos_token_id": 2, 5 | "pad_token_id": 1, 6 | "transformers_version": "4.27.0.dev0" 7 | } 8 | -------------------------------------------------------------------------------- /mesh-xl/x-mesh-xl-350m/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/x-mesh-xl-350m/pytorch_model.bin -------------------------------------------------------------------------------- /models/mesh_xl/get_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as nnf 3 | from torch import nn, Tensor 4 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 5 | from models.mesh_xl.tokenizer import MeshTokenizer 6 | from typing import Dict 7 | 8 | 9 | class MeshXL(nn.Module): 10 | 11 | def train(self, mode: bool = True): 12 | super().train(mode) 13 | return self 14 | 15 | def __init__(self, args): 16 | super().__init__() 17 | 18 | self.tokenizer = MeshTokenizer(args) 19 | 20 | # causal LM model initialization 21 | self.vocab_size = self.tokenizer.codebook_size + 3 22 | self.bos_token_id = self.tokenizer.codebook_size 23 | self.eos_token_id = self.tokenizer.codebook_size + 1 24 | self.pad_token_id = self.tokenizer.codebook_size + 2 25 | 26 | config = AutoConfig.from_pretrained( 27 | args.llm, 28 | n_positions=8192, 29 | max_position_embeddings=8192, 30 | vocab_size=self.vocab_size, 31 | 
bos_token_id=self.bos_token_id, 32 | eos_token_id=self.eos_token_id, 33 | pad_token_id=self.pad_token_id 34 | ) 35 | 36 | config.word_embed_proj_dim = config.hidden_size 37 | self.transformer = AutoModelForCausalLM.from_pretrained( 38 | args.llm, 39 | config=config, 40 | ignore_mismatched_sizes=True 41 | ) 42 | self.transformer.to_bettertransformer() 43 | 44 | # setting status for all parameters 45 | self.train() 46 | 47 | 48 | def forward( 49 | self, 50 | data_dict: dict=None, 51 | is_eval: bool=False, 52 | is_generate: bool=False, 53 | num_return_sequences: int=8, 54 | generation_config: Dict=dict( 55 | do_sample=True, 56 | top_k=50, 57 | top_p=0.95, 58 | # no_repeat_ngram_size=9, 59 | ) 60 | ) -> dict: 61 | 62 | if not is_eval: 63 | return self.train_one_step(data_dict) 64 | 65 | if is_eval and not is_generate: 66 | return self.perplexity(data_dict) 67 | 68 | if is_eval and is_generate: 69 | return self.generate( 70 | data_dict=data_dict, 71 | num_return_sequences=num_return_sequences, 72 | generation_config=generation_config 73 | ) 74 | 75 | raise NotImplementedError('training status undefined!') 76 | 77 | 78 | def loss_wrapper(self, loss: Tensor) -> Tensor: 79 | # parameter activation: an l2 loss with 0 weight 80 | for param in self.parameters(): 81 | loss += 0 * torch.sum(param ** 2) 82 | return loss 83 | 84 | 85 | def train_one_step(self, data_dict: dict) -> dict: 86 | 87 | data_dict = self.tokenizer.tokenize(data_dict) 88 | 89 | input_ids = data_dict['input_ids'] # batch x ntoken 90 | attention_mask = data_dict['attention_mask'] # batch x ntoken 91 | 92 | # parse input with <bos> and <eos> tokens 93 | input_ids[input_ids == self.tokenizer.pad_id] = self.pad_token_id # map padding to <pad> 94 | input_ids[:, 0] = self.bos_token_id # <bos> at the first slot 95 | eos_pos_id = attention_mask.sum(1, keepdim=True) - 1 96 | input_ids = torch.scatter( # <eos> at the end of each sequence 97 | input_ids, 98 | 1, 99 | eos_pos_id.long(), 100 | torch.ones_like(input_ids) * self.eos_token_id 101 | ) 102 | 103 | target = input_ids.clone() 104 | target[attention_mask == 0] = -100 # no loss for the padding tokens 105 | 106 | 107 | # Forward pass, calling the causal LLM with BetterTransformer. 
108 | output = self.transformer( 109 | input_ids = input_ids.long(), 110 | ) 111 | 112 | # compute loss with a one-token right shift 113 | logit = output.logits[:, :-1] # batch x (ntoken - 1) x vocab 114 | label = target[:, 1:] # batch x (ntoken - 1) 115 | 116 | final_loss = nnf.cross_entropy( 117 | logit.permute(0, 2, 1), # batch x vocab x (ntoken - 1) 118 | label, 119 | ) # scalar: mean cross-entropy over supervised tokens 120 | 121 | data_dict['loss'] = self.loss_wrapper(final_loss) 122 | data_dict['gen_loss'] = final_loss 123 | 124 | return data_dict 125 | 126 | 127 | @torch.no_grad() 128 | def perplexity(self, data_dict: dict) -> dict: 129 | 130 | data_dict = self.tokenizer.tokenize(data_dict) 131 | 132 | input_ids = data_dict['input_ids'] # batch x ntoken 133 | attention_mask = data_dict['attention_mask'] # batch x ntoken 134 | 135 | # parse input with <bos> and <eos> tokens, same packing as train_one_step 136 | input_ids[input_ids == self.tokenizer.pad_id] = self.pad_token_id # map tokenizer padding to the model's <pad> id 137 | input_ids[:, 0] = self.bos_token_id # the first slot becomes <bos> 138 | eos_pos_id = attention_mask.sum(1, keepdim=True) - 1 139 | input_ids = torch.scatter( # write <eos> at the last attended position 140 | input_ids, 141 | 1, 142 | eos_pos_id.long(), 143 | torch.ones_like(input_ids) * self.eos_token_id 144 | ) 145 | 146 | # llm loss calculation 147 | output = self.transformer( 148 | input_ids = input_ids.long(), 149 | ) 150 | 151 | # compute loss with a one-token right shift 152 | logit = output.logits[:, :-1] # batch x (ntoken - 1) x vocab 153 | label = input_ids[:, 1:] # batch x (ntoken - 1) 154 | masks = attention_mask[:, 1:] # batch x (ntoken - 1) 155 | loss_per_token = nnf.cross_entropy( 156 | logit.permute(0, 2, 1), # batch x vocab x (ntoken - 1) 157 | label, # batch x (ntoken - 1) 158 | reduction='none' 159 | ) # batch x (ntoken - 1) 160 | 161 | # compute the average negative log-likelihood of each sequence 162 | neg_log_likelihood = torch.sum(loss_per_token * masks, dim=1) / torch.sum(masks, dim=1) 163 | 164 | data_dict['neg_log_likelihood'] = neg_log_likelihood # batch, 165 | return data_dict 166 | 167 | 168 | 169 | @torch.no_grad() 170 | def generate(self, data_dict: dict=None, num_return_sequences: int=8, generation_config: dict=dict()) -> dict: 171 | 172 | net_device = next(self.parameters()).device 173 | max_length = 8192 174 | output_ids = torch.ones(num_return_sequences, max_length).long().to(net_device) * self.eos_token_id 175 | 176 | # batch x ntokens 177 | results = self.transformer.generate( 178 | max_new_tokens=max_length-1, 179 | num_return_sequences=num_return_sequences, 180 | bos_token_id=self.bos_token_id, 181 | eos_token_id=self.eos_token_id, 182 | pad_token_id=self.eos_token_id, 183 | **generation_config 184 | ) 185 | output_ids[:, :results.shape[1]] = results 186 | 187 | # discard the <bos> and <eos> tokens: crop the ends and map any remaining <eos> to pad tokens 188 | output_ids = output_ids[:, 1: -1] 189 | output_ids[output_ids == self.eos_token_id] = self.tokenizer.pad_id 190 | 191 | decoder_output = self.tokenizer.detokenize(input_ids=output_ids) 192 | 193 | return decoder_output 194 | 195 | 196 | 197 | def get_model(args): 198 | model = MeshXL(args) 199 | return model
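# --- Annotation (not part of the original repo) ---
# The MeshTokenizer defined in tokenizer.py below is a deterministic
# discretize/undiscretize pair, so a tokenize -> detokenize round trip only
# loses quantization precision. A minimal sketch, assuming `args` has
# n_discrete_size=128 and vertices normalized to [-1, 1]:
#
#   >>> tok = MeshTokenizer(args)
#   >>> out = tok(dict(vertices=vertices, faces=faces))
#   >>> out['recon_faces']   # batch x nface x 3 x 3; matches the input face
#   ...                      # coordinates up to half a bin, i.e. 1/128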
-------------------------------------------------------------------------------- /models/mesh_xl/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import Tuple 4 | from einops import rearrange, repeat, reduce 5 | 6 | 7 | 8 | def discretize( 9 | t: Tensor, 10 | continuous_range: Tuple[float, float], 11 | num_discrete: int = 128 12 | ) -> Tensor: 13 | 14 | lo, hi = continuous_range 15 | assert hi > lo 16 | t = (t - lo) / (hi - lo) # cube normalize 17 | t *= num_discrete 18 | t -= 0.5 19 | return t.round().long().clamp(min = 0, max = num_discrete - 1) 20 | 21 | 22 | 23 | def undiscretize( 24 | t: Tensor, 25 | continuous_range: Tuple[float, float], 26 | num_discrete: int = 128 27 | ) -> Tensor: 28 | lo, hi = continuous_range 29 | assert hi > lo 30 | t = t.float() 31 | t += 0.5 32 | t /= num_discrete # cube normalize 33 | return t * (hi - lo) + lo 34 | 35 | 36 | 37 | class MeshTokenizer(nn.Module): 38 | 39 | def __init__(self, args): 40 | super().__init__() 41 | self.pad_id = -1 42 | self.num_discrete_coors = args.n_discrete_size # default: 128 43 | self.codebook_size = args.n_discrete_size # default: 128 44 | self.coor_continuous_range = (-1., 1.) 45 | 46 | 47 | def tokenize(self, data_dict: dict) -> dict: 48 | ''' 49 | Turn 3D meshes into sequential coordinate tokens: [<x>, <y>, <z>], ... 50 | ''' 51 | 52 | ### 3D mesh face parsing 53 | vertices = data_dict['vertices'] # batch x nv x 3 54 | faces = data_dict['faces'] # batch x nf x 3 55 | face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all') # batch x nf 56 | 57 | batch, num_vertices, num_coors = vertices.shape 58 | _, num_faces, _ = faces.shape 59 | 60 | # fill padding tokens with 0, to prevent gather idx errors 61 | face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0) 62 | 63 | # collect vertex coordinates per face: b x nf x nv x c 64 | faces_vertices = repeat(face_without_pad, 'b nf nv -> b nf nv c', c = num_coors) 65 | vertices = repeat(vertices, 'b nv c -> b nf nv c', nf = num_faces) 66 | face_coords = vertices.gather(-2, faces_vertices.long()) 67 | 68 | # continuous to discrete face coords: b x nf x nv x c 69 | discrete_face_coords = discretize( 70 | face_coords, 71 | continuous_range=self.coor_continuous_range, 72 | num_discrete=self.num_discrete_coors 73 | ) 74 | 75 | # pad invalid faces with <pad>: batch x nf x nv x c 76 | discrete_padded_coords = discrete_face_coords.masked_fill( 77 | ~rearrange(face_mask, 'b nf -> b nf 1 1'), 78 | self.pad_id 79 | ) 80 | 81 | 82 | ### mesh to sequence conversion: batch x ntokens 83 | input_ids = discrete_padded_coords.reshape(batch, -1) 84 | attention_mask = (input_ids != self.pad_id).float() 85 | # reserve two spots for the <bos> and <eos> tokens (filled in by the model): 86 | # input_ids: <pad> <face tokens> <pad> ... => <bos> <face tokens> <eos> ... 87 | # attn_mask: 1 1 ... 1 0 ... => 1 1 ... 1 0 ... 88 | place_holder = torch.ones_like(input_ids[:, [0]]) # batch x 1 89 | input_ids = torch.cat((place_holder * self.pad_id, input_ids, place_holder * self.pad_id), dim=1) 90 | attention_mask = torch.cat((place_holder, place_holder, attention_mask), dim=1) 91 | 92 | ### meshXL inputs 93 | data_dict['input_ids'] = input_ids.long() # batch x (nf * 3 * 3 + 2) 94 | data_dict['attention_mask'] = attention_mask.float() # batch x (nf * 3 * 3 + 2) 95 | 96 | # codes do not include the <bos> and <eos> tokens 97 | data_dict['codes'] = discrete_padded_coords.long() # batch x (nf * 3 * 3) 98 | data_dict['discrete_face_coords'] = discrete_face_coords 99 | 100 | return data_dict 101 | 102 | 103 | def detokenize(self, input_ids: Tensor) -> dict: 104 | ''' 105 | Turn sequential coordinate tokens [<x>, <y>, <z>], ... back into 3D meshes 106 | ''' 107 | # input_ids: b (n q) or b n q, without <bos> or <eos> 108 | input_ids = input_ids.reshape(input_ids.shape[0], -1) 109 | # batch x nface 110 | face_mask = reduce( 111 | input_ids != self.pad_id, 'b (nf c) -> b nf', 'all', c = 9 112 | ) 113 | 114 | # batch x (nface x 9) -> batch x nface x 3 x 3 115 | pred_face_coords = input_ids.reshape(input_ids.shape[0], -1, 9) 116 | pred_face_coords = rearrange( 117 | pred_face_coords, '... (v c) -> ... 
v c', v = 3 118 | ) 119 | 120 | # back to continuous space 121 | continuous_coors = undiscretize( 122 | pred_face_coords, 123 | num_discrete = self.num_discrete_coors, 124 | continuous_range = self.coor_continuous_range 125 | ) 126 | # mask padding coordinates out with nan 127 | continuous_coors = continuous_coors.masked_fill( 128 | ~rearrange(face_mask, 'b nf -> b nf 1 1'), 129 | float('nan') 130 | ) 131 | output_dict = {} 132 | output_dict['recon_faces'] = continuous_coors 133 | 134 | return output_dict 135 | 136 | 137 | def forward(self, data_dict: dict) -> dict: 138 | 139 | encoder_output = self.tokenize(data_dict) 140 | decoder_output = self.detokenize( 141 | input_ids = encoder_output['codes'], 142 | ) 143 | data_dict.update(encoder_output) 144 | data_dict.update(decoder_output) 145 | return data_dict 146 | -------------------------------------------------------------------------------- /models/x_mesh_xl/get_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as nnf 3 | from torch import nn, Tensor 4 | from einops import repeat, rearrange 5 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, Blip2QFormerModel, Blip2QFormerConfig 6 | from models.x_mesh_xl.tokenizer import MeshTokenizer 7 | from typing import Dict 8 | 9 | 10 | class ConditionEncoder(nn.Module): 11 | 12 | def train(self, mode: bool = True): 13 | super().train(mode) 14 | self.multi_encoder.eval() 15 | for param in self.multi_encoder.parameters(): 16 | param.requires_grad = False 17 | return self 18 | 19 | def __init__(self, args, hidden_size): 20 | super().__init__() 21 | self.n_learnable_queries = 32 22 | config = AutoConfig.from_pretrained(args.text_condition) 23 | self.multi_encoder = AutoModel.from_config(config) 24 | qformer_config = Blip2QFormerConfig( 25 | num_hidden_layers=12, 26 | encoder_hidden_size=self.multi_encoder.config.text_config.hidden_size 27 | ) 28 | self.qformer = Blip2QFormerModel(qformer_config) 29 | self.query_embeds = nn.Embedding(self.n_learnable_queries, qformer_config.hidden_size) 30 | self.out_project = nn.Linear(qformer_config.hidden_size, hidden_size) 31 | 32 | @torch.no_grad() 33 | def encode_text(self, input_ids, attention_mask): 34 | text_encoder_output = self.multi_encoder.text_model( 35 | input_ids=input_ids, 36 | attention_mask=attention_mask 37 | ) 38 | text_embeds = text_encoder_output.last_hidden_state 39 | return text_embeds # bs x ntoken x ch 40 | 41 | def forward(self, input_ids, attention_mask): 42 | net_device = next(self.parameters()).device 43 | text_embeds = self.encode_text(input_ids=input_ids, attention_mask=attention_mask) 44 | query_embeds = self.query_embeds( 45 | repeat( 46 | torch.arange(0, self.n_learnable_queries, dtype=torch.int64).to(net_device), 47 | 'src -> bs src', 48 | bs = text_embeds.shape[0] 49 | ) 50 | ) 51 | query_outputs = self.qformer( 52 | query_embeds=query_embeds, 53 | encoder_hidden_states=text_embeds, 54 | encoder_attention_mask=attention_mask 55 | ) 56 | query_outputs = query_outputs[0][:, : self.n_learnable_queries, :] 57 | return self.out_project(query_outputs) 58 | 59 | 60 | 61 | class MeshXL(nn.Module): 62 | 63 | def train(self, mode: bool = True): 64 | super().train(mode) 65 | # self.transformer.eval() 66 | # for param in self.transformer.parameters(): 67 | # param.requires_grad = False 68 | return self 69 | 70 | def __init__(self, args): 71 | super().__init__() 72 | 73 | self.tokenizer = MeshTokenizer(args) 74 | 75 | # 
causal LM model initialization 76 | self.vocab_size = self.tokenizer.codebook_size + 3 77 | self.bos_token_id = self.tokenizer.codebook_size 78 | self.eos_token_id = self.tokenizer.codebook_size + 1 79 | self.pad_token_id = self.tokenizer.codebook_size + 2 80 | 81 | config = AutoConfig.from_pretrained( 82 | args.llm, 83 | n_positions=8192, 84 | max_position_embeddings=8192, 85 | vocab_size=self.vocab_size, 86 | bos_token_id=self.bos_token_id, 87 | eos_token_id=self.eos_token_id, 88 | pad_token_id=self.pad_token_id 89 | ) 90 | 91 | config.word_embed_proj_dim = config.hidden_size 92 | self.transformer = AutoModelForCausalLM.from_config(config=config) 93 | 94 | try: 95 | self.transformer.to_bettertransformer() 96 | except Exception: 97 | pass # fall back to the vanilla attention implementation 98 | 99 | self.condition_encoder = ConditionEncoder(args, config.hidden_size) 100 | 101 | # setting status for all parameters 102 | self.train() 103 | 104 | 105 | def forward( 106 | self, 107 | data_dict: dict=None, 108 | is_eval: bool=False, 109 | is_generate: bool=False, 110 | num_return_sequences: int=8, 111 | generation_config: Dict=dict( 112 | do_sample=True, 113 | top_k=50, 114 | top_p=0.95, 115 | # no_repeat_ngram_size=9, 116 | ) 117 | ) -> dict: 118 | 119 | data_dict['prefix_embeds'] = self.condition_encoder( 120 | input_ids = data_dict['text_input_ids'], 121 | attention_mask = data_dict['text_attention_mask'] 122 | ) 123 | 124 | if not is_eval: 125 | raise NotImplementedError('training is not implemented for the text-conditioned model') 126 | 127 | if is_eval and not is_generate: 128 | raise NotImplementedError('perplexity evaluation is not implemented for the text-conditioned model') 129 | 130 | if is_eval and is_generate: 131 | return self.generate( 132 | data_dict=data_dict, 133 | num_return_sequences=num_return_sequences, 134 | generation_config=generation_config 135 | ) 136 | 137 | raise NotImplementedError('training status undefined!') 138 | return 139 | 140 | @torch.no_grad() 141 | def generate(self, data_dict: dict=None, num_return_sequences: int=8, generation_config: dict=dict()) -> dict: 142 | 143 | net_device = next(self.parameters()).device 144 | max_length = 8191 145 | output_ids = torch.ones(num_return_sequences, max_length).long().to(net_device) * self.eos_token_id 146 | 147 | # batch x ntokens 148 | results = self.transformer.generate( 149 | inputs_embeds=data_dict['prefix_embeds'], 150 | max_length=max_length-1, 151 | num_return_sequences=num_return_sequences, 152 | bos_token_id=self.bos_token_id, 153 | eos_token_id=self.eos_token_id, 154 | pad_token_id=self.eos_token_id, 155 | **generation_config 156 | ) 157 | output_ids[:, :results.shape[1]] = results 158 | 159 | # discard the <eos> tokens: map them to pad (no leading <bos> is produced when generating from prefix embeddings) 160 | output_ids = output_ids[:, :-1] 161 | output_ids[output_ids == self.eos_token_id] = self.tokenizer.pad_id 162 | 163 | decoder_output = self.tokenizer.detokenize(input_ids=output_ids) 164 | 165 | return decoder_output 166 | 167 | 168 | 169 | def get_model(args): 170 | model = MeshXL(args) 171 | return model
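# --- Annotation (not part of the original repo) ---
# A minimal usage sketch for the text-conditioned model, assuming `args`
# carries llm, n_discrete_size=128, and
# text_condition='openai/clip-vit-base-patch32' (see sample_t2m.py below);
# `ids` and `mask` are CLIP-tokenized text tensors (hypothetical names):
#
#   >>> model = get_model(args).eval()
#   >>> out = model(dict(text_input_ids=ids, text_attention_mask=mask),
#   ...             is_eval=True, is_generate=True, num_return_sequences=4)
#   >>> out['recon_faces']   # 4 x nface x 3 x 3, invalid faces filled with NaN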
-------------------------------------------------------------------------------- /models/x_mesh_xl/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import Tuple 4 | from einops import rearrange, repeat, reduce 5 | 6 | 7 | 8 | def discretize( 9 | t: Tensor, 10 | continuous_range: Tuple[float, float], 11 | num_discrete: int = 128 12 | ) -> Tensor: 13 | 14 | lo, hi = continuous_range 15 | assert hi > lo 16 | t = (t - lo) / (hi - lo) # cube normalize 17 | t *= num_discrete 18 | t -= 0.5 19 | return t.round().long().clamp(min = 0, max = num_discrete - 1) 20 | 21 | 22 | 23 | def undiscretize( 24 | t: Tensor, 25 | continuous_range: Tuple[float, float], 26 | num_discrete: int = 128 27 | ) -> Tensor: 28 | lo, hi = continuous_range 29 | assert hi > lo 30 | t = t.float() 31 | t += 0.5 32 | t /= num_discrete # cube normalize 33 | return t * (hi - lo) + lo 34 | 35 | 36 | 37 | class MeshTokenizer(nn.Module): 38 | 39 | def __init__(self, args): 40 | super().__init__() 41 | self.pad_id = -1 42 | self.num_discrete_coors = args.n_discrete_size # default: 128 43 | self.codebook_size = args.n_discrete_size # default: 128 44 | self.coor_continuous_range = (-1., 1.) 45 | 46 | 47 | def tokenize(self, data_dict: dict) -> dict: 48 | ''' 49 | Turn 3D meshes into sequential coordinate tokens: [<x>, <y>, <z>], ... 50 | ''' 51 | 52 | ### 3D mesh face parsing 53 | vertices = data_dict['vertices'] # batch x nv x 3 54 | faces = data_dict['faces'] # batch x nf x 3 55 | face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all') # batch x nf 56 | 57 | batch, num_vertices, num_coors = vertices.shape 58 | _, num_faces, _ = faces.shape 59 | 60 | # fill padding tokens with 0, to prevent gather idx errors 61 | face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0) 62 | 63 | # collect vertex coordinates per face: b x nf x nv x c 64 | faces_vertices = repeat(face_without_pad, 'b nf nv -> b nf nv c', c = num_coors) 65 | vertices = repeat(vertices, 'b nv c -> b nf nv c', nf = num_faces) 66 | face_coords = vertices.gather(-2, faces_vertices.long()) 67 | 68 | # continuous to discrete face coords: b x nf x nv x c 69 | discrete_face_coords = discretize( 70 | face_coords, 71 | continuous_range=self.coor_continuous_range, 72 | num_discrete=self.num_discrete_coors 73 | ) 74 | 75 | # pad invalid faces with <pad>: batch x nf x nv x c 76 | discrete_padded_coords = discrete_face_coords.masked_fill( 77 | ~rearrange(face_mask, 'b nf -> b nf 1 1'), 78 | self.pad_id 79 | ) 80 | 81 | 82 | ### mesh to sequence conversion: batch x ntokens 83 | input_ids = discrete_padded_coords.reshape(batch, -1) 84 | attention_mask = (input_ids != self.pad_id).float() 85 | # reserve two spots for the <bos> and <eos> tokens (filled in by the model): 86 | # input_ids: <pad> <face tokens> <pad> ... => <bos> <face tokens> <eos> ... 87 | # attn_mask: 1 1 ... 1 0 ... => 1 1 ... 1 0 ... 88 | place_holder = torch.ones_like(input_ids[:, [0]]) # batch x 1 89 | input_ids = torch.cat((place_holder * self.pad_id, input_ids, place_holder * self.pad_id), dim=1) 90 | attention_mask = torch.cat((place_holder, place_holder, attention_mask), dim=1) 91 | 92 | ### meshXL inputs 93 | data_dict['input_ids'] = input_ids.long() # batch x (nf * 3 * 3 + 2) 94 | data_dict['attention_mask'] = attention_mask.float() # batch x (nf * 3 * 3 + 2) 95 | 96 | # codes do not include the <bos> and <eos> tokens 97 | data_dict['codes'] = discrete_padded_coords.long() # batch x (nf * 3 * 3) 98 | data_dict['discrete_face_coords'] = discrete_face_coords 99 | 100 | return data_dict 101 | 102 | 103 | def detokenize(self, input_ids: Tensor) -> dict: 104 | ''' 105 | Turn sequential coordinate tokens [<x>, <y>, <z>], ... back into 3D meshes 106 | ''' 107 | # input_ids: b (n q) or b n q, without <bos> or <eos> 108 | input_ids = input_ids.reshape(input_ids.shape[0], -1) 109 | # batch x nface 110 | face_mask = reduce( 111 | input_ids != self.pad_id, 'b (nf c) -> b nf', 'all', c = 9 112 | ) 113 | 114 | # batch x (nface x 9) -> batch x nface x 3 x 3 115 | pred_face_coords = input_ids.reshape(input_ids.shape[0], -1, 9) 116 | pred_face_coords = rearrange( 117 | pred_face_coords, '... (v c) -> ... 
v c', v = 3 118 | ) 119 | 120 | # back to continuous space 121 | continuous_coors = undiscretize( 122 | pred_face_coords, 123 | num_discrete = self.num_discrete_coors, 124 | continuous_range = self.coor_continuous_range 125 | ) 126 | # mask padding coordinates out with nan 127 | continuous_coors = continuous_coors.masked_fill( 128 | ~rearrange(face_mask, 'b nf -> b nf 1 1'), 129 | float('nan') 130 | ) 131 | output_dict = {} 132 | output_dict['recon_faces'] = continuous_coors 133 | 134 | return output_dict 135 | 136 | 137 | def forward(self, data_dict: dict) -> dict: 138 | 139 | encoder_output = self.tokenize(data_dict) 140 | decoder_output = self.detokenize( 141 | input_ids = encoder_output['codes'], 142 | ) 143 | data_dict.update(encoder_output) 144 | data_dict.update(decoder_output) 145 | return data_dict 146 | -------------------------------------------------------------------------------- /openai/clip-vit-base-patch32/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-base-patch32", 3 | "architectures": [ 4 | "CLIPModel" 5 | ], 6 | "initializer_factor": 1.0, 7 | "logit_scale_init_value": 2.6592, 8 | "model_type": "clip", 9 | "projection_dim": 512, 10 | "text_config": { 11 | "_name_or_path": "", 12 | "add_cross_attention": false, 13 | "architectures": null, 14 | "attention_dropout": 0.0, 15 | "bad_words_ids": null, 16 | "bos_token_id": 0, 17 | "chunk_size_feed_forward": 0, 18 | "cross_attention_hidden_size": null, 19 | "decoder_start_token_id": null, 20 | "diversity_penalty": 0.0, 21 | "do_sample": false, 22 | "dropout": 0.0, 23 | "early_stopping": false, 24 | "encoder_no_repeat_ngram_size": 0, 25 | "eos_token_id": 2, 26 | "finetuning_task": null, 27 | "forced_bos_token_id": null, 28 | "forced_eos_token_id": null, 29 | "hidden_act": "quick_gelu", 30 | "hidden_size": 512, 31 | "id2label": { 32 | "0": "LABEL_0", 33 | "1": "LABEL_1" 34 | }, 35 | "initializer_factor": 1.0, 36 | "initializer_range": 0.02, 37 | "intermediate_size": 2048, 38 | "is_decoder": false, 39 | "is_encoder_decoder": false, 40 | "label2id": { 41 | "LABEL_0": 0, 42 | "LABEL_1": 1 43 | }, 44 | "layer_norm_eps": 1e-05, 45 | "length_penalty": 1.0, 46 | "max_length": 20, 47 | "max_position_embeddings": 77, 48 | "min_length": 0, 49 | "model_type": "clip_text_model", 50 | "no_repeat_ngram_size": 0, 51 | "num_attention_heads": 8, 52 | "num_beam_groups": 1, 53 | "num_beams": 1, 54 | "num_hidden_layers": 12, 55 | "num_return_sequences": 1, 56 | "output_attentions": false, 57 | "output_hidden_states": false, 58 | "output_scores": false, 59 | "pad_token_id": 1, 60 | "prefix": null, 61 | "projection_dim": 512, 62 | "problem_type": null, 63 | "pruned_heads": {}, 64 | "remove_invalid_values": false, 65 | "repetition_penalty": 1.0, 66 | "return_dict": true, 67 | "return_dict_in_generate": false, 68 | "sep_token_id": null, 69 | "task_specific_params": null, 70 | "temperature": 1.0, 71 | "tie_encoder_decoder": false, 72 | "tie_word_embeddings": true, 73 | "tokenizer_class": null, 74 | "top_k": 50, 75 | "top_p": 1.0, 76 | "torch_dtype": null, 77 | "torchscript": false, 78 | "transformers_version": "4.16.0.dev0", 79 | "use_bfloat16": false, 80 | "vocab_size": 49408 81 | }, 82 | "text_config_dict": null, 83 | "transformers_version": null, 84 | "vision_config": { 85 | "_name_or_path": "", 86 | "add_cross_attention": false, 87 | "architectures": null, 88 | "attention_dropout": 0.0, 89 | "bad_words_ids": null, 90 | "bos_token_id": null, 91 | 
"chunk_size_feed_forward": 0, 92 | "cross_attention_hidden_size": null, 93 | "decoder_start_token_id": null, 94 | "diversity_penalty": 0.0, 95 | "do_sample": false, 96 | "dropout": 0.0, 97 | "early_stopping": false, 98 | "encoder_no_repeat_ngram_size": 0, 99 | "eos_token_id": null, 100 | "finetuning_task": null, 101 | "forced_bos_token_id": null, 102 | "forced_eos_token_id": null, 103 | "hidden_act": "quick_gelu", 104 | "hidden_size": 768, 105 | "id2label": { 106 | "0": "LABEL_0", 107 | "1": "LABEL_1" 108 | }, 109 | "image_size": 224, 110 | "initializer_factor": 1.0, 111 | "initializer_range": 0.02, 112 | "intermediate_size": 3072, 113 | "is_decoder": false, 114 | "is_encoder_decoder": false, 115 | "label2id": { 116 | "LABEL_0": 0, 117 | "LABEL_1": 1 118 | }, 119 | "layer_norm_eps": 1e-05, 120 | "length_penalty": 1.0, 121 | "max_length": 20, 122 | "min_length": 0, 123 | "model_type": "clip_vision_model", 124 | "no_repeat_ngram_size": 0, 125 | "num_attention_heads": 12, 126 | "num_beam_groups": 1, 127 | "num_beams": 1, 128 | "num_hidden_layers": 12, 129 | "num_return_sequences": 1, 130 | "output_attentions": false, 131 | "output_hidden_states": false, 132 | "output_scores": false, 133 | "pad_token_id": null, 134 | "patch_size": 32, 135 | "prefix": null, 136 | "projection_dim" : 512, 137 | "problem_type": null, 138 | "pruned_heads": {}, 139 | "remove_invalid_values": false, 140 | "repetition_penalty": 1.0, 141 | "return_dict": true, 142 | "return_dict_in_generate": false, 143 | "sep_token_id": null, 144 | "task_specific_params": null, 145 | "temperature": 1.0, 146 | "tie_encoder_decoder": false, 147 | "tie_word_embeddings": true, 148 | "tokenizer_class": null, 149 | "top_k": 50, 150 | "top_p": 1.0, 151 | "torch_dtype": null, 152 | "torchscript": false, 153 | "transformers_version": "4.16.0.dev0", 154 | "use_bfloat16": false 155 | }, 156 | "vision_config_dict": null 157 | } 158 | -------------------------------------------------------------------------------- /openai/clip-vit-base-patch32/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /openai/clip-vit-base-patch32/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/", "model_max_length": 77} -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torchtyping 2 | pytorch_custom_utils 
3 | beartype 4 | x_transformers 5 | local_attention 6 | vector_quantize_pytorch 7 | classifier_free_guidance_pytorch 8 | torch_geometric 9 | gateloop_transformer 10 | ema_pytorch 11 | trimesh 12 | wandb 13 | libigl 14 | matplotlib 15 | plyfile 16 | optimum 17 | transformers==4.38.2 18 | accelerate 19 | tensorboard 20 | deepspeed -------------------------------------------------------------------------------- /sample_t2m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | from torch import Tensor 6 | from accelerate import Accelerator 7 | from transformers import AutoTokenizer 8 | from utils.ply_helper import write_ply 9 | from models.x_mesh_xl.get_model import get_model 10 | 11 | 12 | def post_process_mesh(mesh_coords: Tensor, filename: str): 13 | mesh_coords = mesh_coords[~torch.isnan(mesh_coords[:, 0, 0])] # nvalid_face x 3 x 3 14 | vertices = mesh_coords.reshape(-1, 3) 15 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face 16 | triangles = vertices_index.reshape(-1, 3) 17 | write_ply( 18 | np.asarray(vertices.cpu()), 19 | None, 20 | np.asarray(triangles), 21 | filename 22 | ) 23 | return vertices 24 | 25 | 26 | def make_args_parser(): 27 | parser = argparse.ArgumentParser( 28 | "MeshXL: Neural Coordinate Field for Generative 3D Foundation Models", 29 | add_help=False 30 | ) 31 | ##### model config ##### 32 | parser.add_argument("--llm", default='mesh-xl/mesh-xl-350m', type=str) 33 | parser.add_argument("--n_discrete_size", default=128, type=int) 34 | parser.add_argument("--text_condition", default='openai/clip-vit-base-patch32', type=str) 35 | parser.add_argument("--test_ckpt", default='mesh-xl/x-mesh-xl-350m/pytorch_model.bin', type=str) 36 | parser.add_argument("--text", default='3d model of a chair', type=str) 37 | parser.add_argument("--output_dir", default='outputs', type=str) 38 | parser.add_argument("--num_samples", default=4, type=int) 39 | parser.add_argument("--top_k", default=50, type=int) 40 | parser.add_argument("--top_p", default=0.95, type=float) 41 | parser.add_argument("--temperature", default=0.1, type=float) 42 | 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | if __name__ == '__main__': 48 | 49 | args = make_args_parser() 50 | accelerator = Accelerator() 51 | 52 | # prepare model 53 | tokenizer = AutoTokenizer.from_pretrained(args.text_condition) 54 | mesh_xl = get_model(args) 55 | mesh_xl.load_state_dict(torch.load(args.test_ckpt, map_location='cpu')) 56 | mesh_xl = accelerator.prepare(mesh_xl) 57 | 58 | net_device = next(mesh_xl.parameters()).device 59 | num_samples = args.num_samples 60 | 61 | text_inputs = tokenizer.batch_encode_plus( 62 | [args.text], 63 | max_length=64, 64 | padding='max_length', 65 | truncation='longest_first', 66 | return_tensors='pt' 67 | ) 68 | text_inputs = dict( 69 | text_input_ids=text_inputs['input_ids'].to(net_device), 70 | text_attention_mask=text_inputs['attention_mask'].to(net_device) 71 | ) 72 | 73 | # model forward 74 | output_dict = mesh_xl( 75 | text_inputs, 76 | is_eval=True, 77 | is_generate=True, 78 | num_return_sequences=args.num_samples, 79 | generation_config=dict( 80 | do_sample=True, 81 | top_k=args.top_k, 82 | top_p=args.top_p, 83 | temperature=args.temperature 84 | ) 85 | ) 86 | 87 | # save samples 88 | os.makedirs(args.output_dir, exist_ok=True) 89 | for gen_id, sample in enumerate(output_dict['recon_faces']): 90 | post_process_mesh( 91 | sample, 92 | os.path.join( 93 | 
args.output_dir, 94 | f'{accelerator.process_index:04d}_{gen_id}.ply' 95 | ) 96 | ) 97 | -------------------------------------------------------------------------------- /scripts/meshxl-sft-shapenet.sh: -------------------------------------------------------------------------------- 1 | export BASE_MESHXL=mesh-xl/mesh-xl-1.3b 2 | export BATCHSIZE_PER_GPU=2 3 | 4 | accelerate launch \ 5 | --config_file ./config/deepspeed_stage2.yaml \ 6 | --num_machines 1 \ 7 | --num_processes 8 \ 8 | --mixed_precision bf16 \ 9 | main.py \ 10 | --dataset sft.shapenet_table \ 11 | --n_max_triangles 800 \ 12 | --n_discrete_size 128 \ 13 | --warm_lr_iters -1 \ 14 | --base_lr 1e-6 \ 15 | --llm $BASE_MESHXL \ 16 | --model mesh_xl \ 17 | --checkpoint_dir ./ckpts/mesh_xl_1.3b_base_pretrain_bs2_8a100 \ 18 | --batchsize_per_gpu $BATCHSIZE_PER_GPU \ 19 | --dataset_num_workers 0 \ 20 | --augment \ 21 | --eval_every_iteration 10000 \ 22 | --save_every 20000 \ 23 | --max_epoch 1024 -------------------------------------------------------------------------------- /scripts/sample-1.3b.sh: -------------------------------------------------------------------------------- 1 | export LLM_CONFIG='mesh-xl/mesh-xl-1.3b' 2 | export NSAMPLE_PER_GPU=2 3 | export SAMPLE_ROUNDS=100 4 | export OUTPUT_DIR='./output-samples-1.3b' 5 | 6 | accelerate launch \ 7 | --num_machines 1 \ 8 | --num_processes 8 \ 9 | --mixed_precision bf16 \ 10 | main.py \ 11 | --dataset dummy_dataset \ 12 | --n_max_triangles 800 \ 13 | --n_discrete_size 128 \ 14 | --llm $LLM_CONFIG \ 15 | --model mesh_xl \ 16 | --checkpoint_dir $OUTPUT_DIR \ 17 | --batchsize_per_gpu $NSAMPLE_PER_GPU \ 18 | --sample_rounds $SAMPLE_ROUNDS \ 19 | --dataset_num_workers 0 \ 20 | --test_only 21 | -------------------------------------------------------------------------------- /scripts/sample-125m.sh: -------------------------------------------------------------------------------- 1 | export LLM_CONFIG='mesh-xl/mesh-xl-125m' 2 | export NSAMPLE_PER_GPU=2 3 | export SAMPLE_ROUNDS=100 4 | export OUTPUT_DIR='./output-samples-125m' 5 | 6 | accelerate launch \ 7 | --num_machines 1 \ 8 | --num_processes 8 \ 9 | --mixed_precision bf16 \ 10 | main.py \ 11 | --dataset dummy_dataset \ 12 | --n_max_triangles 800 \ 13 | --n_discrete_size 128 \ 14 | --llm $LLM_CONFIG \ 15 | --model mesh_xl \ 16 | --checkpoint_dir $OUTPUT_DIR \ 17 | --batchsize_per_gpu $NSAMPLE_PER_GPU \ 18 | --sample_rounds $SAMPLE_ROUNDS \ 19 | --dataset_num_workers 0 \ 20 | --test_only 21 | -------------------------------------------------------------------------------- /scripts/sample-350m.sh: -------------------------------------------------------------------------------- 1 | export LLM_CONFIG='mesh-xl/mesh-xl-350m' 2 | export NSAMPLE_PER_GPU=2 3 | export SAMPLE_ROUNDS=100 4 | export OUTPUT_DIR='./output-samples-350m' 5 | 6 | accelerate launch \ 7 | --num_machines 1 \ 8 | --num_processes 8 \ 9 | --mixed_precision bf16 \ 10 | main.py \ 11 | --dataset dummy_dataset \ 12 | --n_max_triangles 800 \ 13 | --n_discrete_size 128 \ 14 | --llm $LLM_CONFIG \ 15 | --model mesh_xl \ 16 | --checkpoint_dir $OUTPUT_DIR \ 17 | --batchsize_per_gpu $NSAMPLE_PER_GPU \ 18 | --sample_rounds $SAMPLE_ROUNDS \ 19 | --dataset_num_workers 0 \ 20 | --test_only 21 | -------------------------------------------------------------------------------- /scripts/sample-t2mesh.sh: -------------------------------------------------------------------------------- 1 | accelerate launch \ 2 | --num_machines 1 \ 3 | --num_processes 1 \ 4 | --mixed_precision bf16 \ 5 | 
sample_t2m.py \ 6 | --test_ckpt mesh-xl/x-mesh-xl-350m/pytorch_model.bin \ 7 | --text '3d model of a table' \ 8 | --top_k 25 \ 9 | --top_p 0.95 \ 10 | --temperature 0.1 -------------------------------------------------------------------------------- /set_env.sh: -------------------------------------------------------------------------------- 1 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 2 | pip install -r requirement.txt 3 | pip install deepspeed -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import pickle 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def is_distributed(): 9 | if not dist.is_available() or not dist.is_initialized(): 10 | return False 11 | return True 12 | 13 | 14 | def get_rank(): 15 | if not is_distributed(): 16 | return 0 17 | return dist.get_rank() 18 | 19 | 20 | def is_primary(): 21 | return get_rank() == 0 22 | 23 | 24 | def get_world_size(): 25 | if not is_distributed(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def barrier(): 31 | if not is_distributed(): 32 | return 33 | torch.distributed.barrier() 34 | 35 | 36 | def setup_print_for_distributed(is_primary): 37 | """ 38 | This function disables printing when not in primary process 39 | """ 40 | import builtins as __builtin__ 41 | builtin_print = __builtin__.print 42 | 43 | def print(*args, **kwargs): 44 | force = kwargs.pop('force', False) 45 | if is_primary or force: 46 | builtin_print(*args, **kwargs) 47 | 48 | __builtin__.print = print 49 | 50 | 51 | def init_distributed(gpu_id, global_rank, world_size, dist_url, dist_backend): 52 | torch.cuda.set_device(gpu_id) 53 | print( 54 | f"| distributed init (rank {global_rank}) (world {world_size}): {dist_url}", 55 | flush=True, 56 | ) 57 | torch.distributed.init_process_group( 58 | backend=dist_backend, 59 | init_method=dist_url, 60 | world_size=world_size, 61 | rank=global_rank, 62 | ) 63 | torch.distributed.barrier() 64 | setup_print_for_distributed(is_primary()) 65 | 66 | 67 | def all_reduce_sum(tensor): 68 | if not is_distributed(): 69 | return tensor 70 | dim_squeeze = False 71 | if tensor.ndim == 0: 72 | tensor = tensor[None, ...] 73 | dim_squeeze = True 74 | torch.distributed.all_reduce(tensor) 75 | if dim_squeeze: 76 | tensor = tensor.squeeze(0) 77 | return tensor 78 | 79 | 80 | def all_reduce_average(tensor): 81 | val = all_reduce_sum(tensor) 82 | return val / get_world_size() 83 | 84 | 85 | # Function from DETR - https://github.com/facebookresearch/detr/blob/master/util/misc.py 86 | def reduce_dict(input_dict, average=True): 87 | """ 88 | Args: 89 | input_dict (dict): all the values will be reduced 90 | average (bool): whether to do average or sum 91 | Reduce the values in the dictionary from all processes so that all processes 92 | have the averaged results. Returns a dict with the same fields as 93 | input_dict, after reduction. 
94 | """ 95 | world_size = get_world_size() 96 | if world_size < 2: 97 | return input_dict 98 | with torch.no_grad(): 99 | names = [] 100 | values = [] 101 | # sort the keys so that they are consistent across processes 102 | for k in sorted(input_dict.keys()): 103 | names.append(k) 104 | values.append(input_dict[k]) 105 | values = torch.stack(values, dim=0) 106 | torch.distributed.all_reduce(values) 107 | if average: 108 | values /= world_size 109 | reduced_dict = {k: v for k, v in zip(names, values)} 110 | return reduced_dict 111 | 112 | 113 | # Function from https://github.com/facebookresearch/detr/blob/master/util/misc.py 114 | def all_gather_pickle(data, device): 115 | """ 116 | Run all_gather on arbitrary picklable data (not necessarily tensors) 117 | Args: 118 | data: any picklable object 119 | Returns: 120 | list[data]: list of data gathered from each rank 121 | """ 122 | world_size = get_world_size() 123 | if world_size == 1: 124 | return [data] 125 | 126 | # serialized to a Tensor 127 | buffer = pickle.dumps(data) 128 | storage = torch.ByteStorage.from_buffer(buffer) 129 | tensor = torch.ByteTensor(storage).to(device) 130 | 131 | # obtain Tensor size of each rank 132 | local_size = torch.tensor([tensor.numel()], device=device) 133 | size_list = [torch.tensor([0], device=device) for _ in range(world_size)] 134 | dist.all_gather(size_list, local_size) 135 | size_list = [int(size.item()) for size in size_list] 136 | max_size = max(size_list) 137 | 138 | # receiving Tensor from all ranks 139 | # we pad the tensor because torch all_gather does not support 140 | # gathering tensors of different shapes 141 | tensor_list = [] 142 | for _ in size_list: 143 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) 144 | if local_size != max_size: 145 | padding = torch.empty( 146 | size=(max_size - local_size,), dtype=torch.uint8, device=device 147 | ) 148 | tensor = torch.cat((tensor, padding), dim=0) 149 | dist.all_gather(tensor_list, tensor) 150 | 151 | data_list = [] 152 | for size, tensor in zip(size_list, tensor_list): 153 | buffer = tensor.cpu().numpy().tobytes()[:size] 154 | data_list.append(pickle.loads(buffer)) 155 | 156 | return data_list 157 | 158 | 159 | def all_gather_dict(data): 160 | """ 161 | Run all_gather on data which is a dictionary of Tensors 162 | """ 163 | assert isinstance(data, dict) 164 | 165 | gathered_dict = {} 166 | for item_key in data: 167 | if isinstance(data[item_key], torch.Tensor): 168 | if is_distributed(): 169 | data[item_key] = data[item_key].contiguous() 170 | tensor_list = [torch.empty_like(data[item_key]) for _ in range(get_world_size())] 171 | dist.all_gather(tensor_list, data[item_key]) 172 | gathered_tensor = torch.cat(tensor_list, dim=0) 173 | else: 174 | gathered_tensor = data[item_key] 175 | gathered_dict[item_key] = gathered_tensor 176 | return gathered_dict 177 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
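# --- Annotation (not part of the original file) ---
# A minimal sketch of the checkpoint round trip defined below, assuming
# `model`, `optimizer`, and an `args` namespace already exist (hypothetical
# names); note that resume_if_possible only looks for 'checkpoint.pth':
#
#   >>> save_checkpoint('./ckpts', model, optimizer, epoch=10, args=args,
#   ...                 best_val_metrics={}, filename='checkpoint.pth')
#   >>> epoch, metrics = resume_if_possible('./ckpts', model, optimizer)
#   >>> epoch   # 10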
2 | 3 | import torch 4 | import os 5 | from utils.dist import is_primary 6 | 7 | 8 | def save_checkpoint( 9 | checkpoint_dir, 10 | model_no_ddp, 11 | optimizer, 12 | epoch, 13 | args, 14 | best_val_metrics, 15 | filename=None, 16 | ): 17 | if not is_primary(): 18 | return 19 | if filename is None: 20 | filename = f"checkpoint_{epoch:04d}.pth" 21 | checkpoint_name = os.path.join(checkpoint_dir, filename) 22 | 23 | weight_ckpt = model_no_ddp.state_dict() 24 | sd = { 25 | "model": weight_ckpt, 26 | "optimizer": optimizer.state_dict(), 27 | "epoch": epoch, 28 | "args": args, 29 | "best_val_metrics": best_val_metrics, 30 | } 31 | torch.save(sd, checkpoint_name) 32 | 33 | 34 | def resume_if_possible(checkpoint_dir, model_no_ddp, optimizer): 35 | """ 36 | Resume if checkpoint is available. 37 | Return 38 | - epoch of loaded checkpoint. 39 | """ 40 | epoch = -1 41 | best_val_metrics = {} 42 | if not os.path.isdir(checkpoint_dir): 43 | return epoch, best_val_metrics 44 | 45 | last_checkpoint = os.path.join(checkpoint_dir, "checkpoint.pth") 46 | if not os.path.isfile(last_checkpoint): 47 | return epoch, best_val_metrics 48 | 49 | sd = torch.load(last_checkpoint, map_location=torch.device("cpu")) 50 | epoch = sd["epoch"] 51 | best_val_metrics = sd["best_val_metrics"] 52 | print(f"Found checkpoint at {epoch}. Resuming.") 53 | 54 | model_no_ddp.load_state_dict(sd["model"], strict=False) 55 | try: 56 | optimizer.load_state_dict(sd["optimizer"]) 57 | except: 58 | print('optimizer weights could not be loaded') 59 | print( 60 | f"Loaded model and optimizer state at {epoch}. Loaded best val metrics so far." 61 | ) 62 | return epoch, best_val_metrics 63 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | 7 | class Logger(object): 8 | 9 | def __init__(self, log_dir=None, accelerator=None) -> None: 10 | self.log_dir = log_dir 11 | self.accelerator = accelerator 12 | 13 | if self.log_dir is not None: 14 | self.txt_writer = open(os.path.join(self.log_dir, 'logger.log'), 'a') 15 | else: 16 | self.txt_writer = None 17 | 18 | if SummaryWriter is not None and self.accelerator.is_main_process: 19 | self.writer = SummaryWriter(self.log_dir) 20 | else: 21 | self.writer = None 22 | 23 | def log_scalars(self, scalar_dict, step, prefix=None): 24 | if self.writer is None: 25 | return 26 | for k in scalar_dict: 27 | v = scalar_dict[k] 28 | if isinstance(v, torch.Tensor): 29 | v = v.detach().cpu().item() 30 | if prefix is not None: 31 | k = prefix + '_' + k 32 | self.writer.add_scalar(k, v, step) 33 | return 34 | 35 | def log_messages(self, message: str): 36 | if self.txt_writer is not None: 37 | self.txt_writer.write(message + "\n") 38 | self.txt_writer.flush() 39 | print(message, flush=True) 40 | return 41 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
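# --- Annotation (not part of the original file) ---
# A minimal sketch for the SmoothedValue tracker defined below, which keeps a
# sliding window plus global totals:
#
#   >>> meter = SmoothedValue(window_size=3)
#   >>> for v in (1.0, 2.0, 3.0, 4.0):
#   ...     meter.update(v)
#   >>> meter.avg          # 3.0: mean of the last 3 values (2, 3, 4)
#   >>> meter.global_avg   # 2.5: mean of all 4 values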
2 | import torch 3 | import numpy as np 4 | from collections import deque 5 | from typing import List 6 | from utils.dist import is_distributed, barrier, all_reduce_sum 7 | 8 | 9 | def my_worker_init_fn(worker_id): 10 | np.random.seed(np.random.get_state()[1][0] + worker_id) 11 | 12 | 13 | @torch.jit.ignore 14 | def to_list_1d(arr) -> List[float]: 15 | arr = arr.detach().cpu().numpy().tolist() 16 | return arr 17 | 18 | 19 | @torch.jit.ignore 20 | def to_list_3d(arr) -> List[List[List[float]]]: 21 | arr = arr.detach().cpu().numpy().tolist() 22 | return arr 23 | 24 | 25 | def huber_loss(error, delta=1.0): 26 | """ 27 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 28 | x = error = pred - gt or dist(pred,gt) 29 | 0.5 * |x|^2 if |x|<=d 30 | 0.5 * d^2 + d * (|x|-d) if |x|>d 31 | """ 32 | abs_error = torch.abs(error) 33 | quadratic = torch.clamp(abs_error, max=delta) 34 | linear = abs_error - quadratic 35 | loss = 0.5 * quadratic ** 2 + delta * linear 36 | return loss 37 | 38 | 39 | # From https://github.com/facebookresearch/detr/blob/master/util/misc.py 40 | class SmoothedValue(object): 41 | """Track a series of values and provide access to smoothed values over a 42 | window or the global series average. 43 | """ 44 | 45 | def __init__(self, window_size=20, fmt=None): 46 | if fmt is None: 47 | fmt = "{median:.4f} ({global_avg:.4f})" 48 | self.deque = deque(maxlen=window_size) 49 | self.total = 0.0 50 | self.count = 0 51 | self.fmt = fmt 52 | 53 | def update(self, value, n=1): 54 | self.deque.append(value) 55 | self.count += n 56 | self.total += value * n 57 | 58 | def synchronize_between_processes(self): 59 | """ 60 | Warning: does not synchronize the deque! 61 | """ 62 | if not is_distributed(): 63 | return 64 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 65 | barrier() 66 | all_reduce_sum(t) 67 | t = t.tolist() 68 | self.count = int(t[0]) 69 | self.total = t[1] 70 | 71 | @property 72 | def median(self): 73 | d = torch.tensor(list(self.deque)) 74 | return d.median().item() 75 | 76 | @property 77 | def avg(self): 78 | d = torch.tensor(list(self.deque), dtype=torch.float32) 79 | return d.mean().item() 80 | 81 | @property 82 | def global_avg(self): 83 | return self.total / self.count 84 | 85 | @property 86 | def max(self): 87 | return max(self.deque) 88 | 89 | @property 90 | def value(self): 91 | return self.deque[-1] 92 | 93 | def __str__(self): 94 | return self.fmt.format( 95 | median=self.median, 96 | avg=self.avg, 97 | global_avg=self.global_avg, 98 | max=self.max, 99 | value=self.value, 100 | ) 101 | -------------------------------------------------------------------------------- /utils/nms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
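# --- Annotation (not part of the original file) ---
# A minimal sketch for nms_2d_faster defined below; rows are
# (x1, y1, x2, y2, score):
#
#   >>> boxes = np.array([[ 0.,  0., 10., 10., 0.9],
#   ...                   [ 1.,  1., 10., 10., 0.8],   # IoU 0.81 with box 0
#   ...                   [20., 20., 30., 30., 0.7]])
#   >>> nms_2d_faster(boxes, overlap_threshold=0.5)
#   [0, 2]   # the overlapping lower-score box is suppressed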
2 | 3 | import numpy as np 4 | 5 | # boxes are axis aigned 2D boxes of shape (n,5) in FLOAT numbers with (x1,y1,x2,y2,score) 6 | """ Ref: https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/ 7 | Ref: https://github.com/vickyboy47/nms-python/blob/master/nms.py 8 | """ 9 | 10 | 11 | def nms_2d(boxes, overlap_threshold): 12 | x1 = boxes[:, 0] 13 | y1 = boxes[:, 1] 14 | x2 = boxes[:, 2] 15 | y2 = boxes[:, 3] 16 | score = boxes[:, 4] 17 | area = (x2 - x1) * (y2 - y1) 18 | 19 | I = np.argsort(score) 20 | pick = [] 21 | while I.size != 0: 22 | last = I.size 23 | i = I[-1] 24 | pick.append(i) 25 | suppress = [last - 1] 26 | for pos in range(last - 1): 27 | j = I[pos] 28 | xx1 = max(x1[i], x1[j]) 29 | yy1 = max(y1[i], y1[j]) 30 | xx2 = min(x2[i], x2[j]) 31 | yy2 = min(y2[i], y2[j]) 32 | w = xx2 - xx1 33 | h = yy2 - yy1 34 | if w > 0 and h > 0: 35 | o = w * h / area[j] 36 | print("Overlap is", o) 37 | if o > overlap_threshold: 38 | suppress.append(pos) 39 | I = np.delete(I, suppress) 40 | return pick 41 | 42 | 43 | def nms_2d_faster(boxes, overlap_threshold, old_type=False): 44 | x1 = boxes[:, 0] 45 | y1 = boxes[:, 1] 46 | x2 = boxes[:, 2] 47 | y2 = boxes[:, 3] 48 | score = boxes[:, 4] 49 | area = (x2 - x1) * (y2 - y1) 50 | 51 | I = np.argsort(score) 52 | pick = [] 53 | while I.size != 0: 54 | last = I.size 55 | i = I[-1] 56 | pick.append(i) 57 | 58 | xx1 = np.maximum(x1[i], x1[I[: last - 1]]) 59 | yy1 = np.maximum(y1[i], y1[I[: last - 1]]) 60 | xx2 = np.minimum(x2[i], x2[I[: last - 1]]) 61 | yy2 = np.minimum(y2[i], y2[I[: last - 1]]) 62 | 63 | w = np.maximum(0, xx2 - xx1) 64 | h = np.maximum(0, yy2 - yy1) 65 | 66 | if old_type: 67 | o = (w * h) / area[I[: last - 1]] 68 | else: 69 | inter = w * h 70 | o = inter / (area[i] + area[I[: last - 1]] - inter) 71 | 72 | I = np.delete( 73 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])) 74 | ) 75 | 76 | return pick 77 | 78 | 79 | def nms_3d_faster(boxes, overlap_threshold, old_type=False): 80 | x1 = boxes[:, 0] 81 | y1 = boxes[:, 1] 82 | z1 = boxes[:, 2] 83 | x2 = boxes[:, 3] 84 | y2 = boxes[:, 4] 85 | z2 = boxes[:, 5] 86 | score = boxes[:, 6] 87 | area = (x2 - x1) * (y2 - y1) * (z2 - z1) 88 | 89 | I = np.argsort(score) 90 | pick = [] 91 | while I.size != 0: 92 | last = I.size 93 | i = I[-1] 94 | pick.append(i) 95 | 96 | xx1 = np.maximum(x1[i], x1[I[: last - 1]]) 97 | yy1 = np.maximum(y1[i], y1[I[: last - 1]]) 98 | zz1 = np.maximum(z1[i], z1[I[: last - 1]]) 99 | xx2 = np.minimum(x2[i], x2[I[: last - 1]]) 100 | yy2 = np.minimum(y2[i], y2[I[: last - 1]]) 101 | zz2 = np.minimum(z2[i], z2[I[: last - 1]]) 102 | 103 | l = np.maximum(0, xx2 - xx1) 104 | w = np.maximum(0, yy2 - yy1) 105 | h = np.maximum(0, zz2 - zz1) 106 | 107 | if old_type: 108 | o = (l * w * h) / area[I[: last - 1]] 109 | else: 110 | inter = l * w * h 111 | o = inter / (area[i] + area[I[: last - 1]] - inter) 112 | 113 | I = np.delete( 114 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])) 115 | ) 116 | 117 | return pick 118 | 119 | 120 | def nms_3d_faster_samecls(boxes, overlap_threshold, old_type=False): 121 | x1 = boxes[:, 0] 122 | y1 = boxes[:, 1] 123 | z1 = boxes[:, 2] 124 | x2 = boxes[:, 3] 125 | y2 = boxes[:, 4] 126 | z2 = boxes[:, 5] 127 | score = boxes[:, 6] 128 | cls = boxes[:, 7] 129 | area = (x2 - x1) * (y2 - y1) * (z2 - z1) 130 | 131 | I = np.argsort(score) 132 | pick = [] 133 | while I.size != 0: 134 | last = I.size 135 | i = I[-1] 136 | pick.append(i) 137 | 138 | xx1 = np.maximum(x1[i], x1[I[: last - 1]]) 139 | yy1 = 
np.maximum(y1[i], y1[I[: last - 1]]) 140 | zz1 = np.maximum(z1[i], z1[I[: last - 1]]) 141 | xx2 = np.minimum(x2[i], x2[I[: last - 1]]) 142 | yy2 = np.minimum(y2[i], y2[I[: last - 1]]) 143 | zz2 = np.minimum(z2[i], z2[I[: last - 1]]) 144 | cls1 = cls[i] 145 | cls2 = cls[I[: last - 1]] 146 | 147 | l = np.maximum(0, xx2 - xx1) 148 | w = np.maximum(0, yy2 - yy1) 149 | h = np.maximum(0, zz2 - zz1) 150 | 151 | if old_type: 152 | o = (l * w * h) / area[I[: last - 1]] 153 | else: 154 | inter = l * w * h 155 | o = inter / (area[i] + area[I[: last - 1]] - inter) 156 | o = o * (cls1 == cls2) 157 | 158 | I = np.delete( 159 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])) 160 | ) 161 | 162 | return pick 163 | -------------------------------------------------------------------------------- /utils/pc_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | """ Utility functions for processing point clouds. 4 | 5 | Author: Charles R. Qi and Or Litany 6 | """ 7 | 8 | import os 9 | import sys 10 | import torch 11 | 12 | # Point cloud IO 13 | import numpy as np 14 | from plyfile import PlyData, PlyElement 15 | 16 | # Mesh IO 17 | import trimesh 18 | 19 | # ---------------------------------------- 20 | # Point Cloud Sampling 21 | # ---------------------------------------- 22 | 23 | 24 | def random_sampling(pc, num_sample, replace=None, return_choices=False): 25 | """Input is NxC, output is num_samplexC""" 26 | if replace is None: 27 | replace = pc.shape[0] < num_sample 28 | choices = np.random.choice(pc.shape[0], num_sample, replace=replace) 29 | if return_choices: 30 | return pc[choices], choices 31 | else: 32 | return pc[choices] 33 | 34 | 35 | # ---------------------------------------- 36 | # Simple Point manipulations 37 | # ---------------------------------------- 38 | def shift_scale_points(pred_xyz, src_range, dst_range=None): 39 | """ 40 | pred_xyz: B x N x 3 41 | src_range: [[B x 3], [B x 3]] - min and max XYZ coords 42 | dst_range: [[B x 3], [B x 3]] - min and max XYZ coords 43 | """ 44 | if dst_range is None: 45 | dst_range = [ 46 | torch.zeros((src_range[0].shape[0], 3), device=src_range[0].device), 47 | torch.ones((src_range[0].shape[0], 3), device=src_range[0].device), 48 | ] 49 | 50 | if pred_xyz.ndim == 4: 51 | src_range = [x[:, None] for x in src_range] 52 | dst_range = [x[:, None] for x in dst_range] 53 | 54 | assert src_range[0].shape[0] == pred_xyz.shape[0] 55 | assert dst_range[0].shape[0] == pred_xyz.shape[0] 56 | assert src_range[0].shape[-1] == pred_xyz.shape[-1] 57 | assert src_range[0].shape == src_range[1].shape 58 | assert dst_range[0].shape == dst_range[1].shape 59 | assert src_range[0].shape == dst_range[1].shape 60 | 61 | src_diff = src_range[1][:, None, :] - src_range[0][:, None, :] 62 | dst_diff = dst_range[1][:, None, :] - dst_range[0][:, None, :] 63 | prop_xyz = ( 64 | ((pred_xyz - src_range[0][:, None, :]) * dst_diff) / src_diff 65 | ) + dst_range[0][:, None, :] 66 | return prop_xyz 67 | 68 | 69 | def scale_points(pred_xyz, mult_factor): 70 | if pred_xyz.ndim == 4: 71 | mult_factor = mult_factor[:, None] 72 | scaled_xyz = pred_xyz * mult_factor[:, None, :] 73 | return scaled_xyz 74 | 75 | 76 | def rotate_point_cloud(points, rotation_matrix=None): 77 | """Input: (n,3), Output: (n,3)""" 78 | # Rotate in-place around Z axis. 
79 | if rotation_matrix is None: 80 | rotation_angle = np.random.uniform() * 2 * np.pi 81 | sinval, cosval = np.sin(rotation_angle), np.cos(rotation_angle) 82 | rotation_matrix = np.array( 83 | [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]] 84 | ) 85 | ctr = points.mean(axis=0) 86 | rotated_data = np.dot(points - ctr, rotation_matrix) + ctr 87 | return rotated_data, rotation_matrix 88 | 89 | 90 | def rotate_pc_along_y(pc, rot_angle): 91 | """Input ps is NxC points with first 3 channels as XYZ 92 | z is facing forward, x is left ward, y is downward 93 | """ 94 | cosval = np.cos(rot_angle) 95 | sinval = np.sin(rot_angle) 96 | rotmat = np.array([[cosval, -sinval], [sinval, cosval]]) 97 | pc[:, [0, 2]] = np.dot(pc[:, [0, 2]], np.transpose(rotmat)) 98 | return pc 99 | 100 | 101 | def roty(t): 102 | """Rotation about the y-axis.""" 103 | c = np.cos(t) 104 | s = np.sin(t) 105 | return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]) 106 | 107 | 108 | def roty_batch(t): 109 | """Rotation about the y-axis. 110 | t: (x1,x2,...xn) 111 | return: (x1,x2,...,xn,3,3) 112 | """ 113 | input_shape = t.shape 114 | output = np.zeros(tuple(list(input_shape) + [3, 3])) 115 | c = np.cos(t) 116 | s = np.sin(t) 117 | output[..., 0, 0] = c 118 | output[..., 0, 2] = s 119 | output[..., 1, 1] = 1 120 | output[..., 2, 0] = -s 121 | output[..., 2, 2] = c 122 | return output 123 | 124 | 125 | def rotz(t): 126 | """Rotation about the z-axis.""" 127 | c = np.cos(t) 128 | s = np.sin(t) 129 | return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) 130 | 131 | 132 | def point_cloud_to_bbox(points): 133 | """Extract the axis aligned box from a pcl or batch of pcls 134 | Args: 135 | points: Nx3 points or BxNx3 136 | output is 6 dim: xyz pos of center and 3 lengths 137 | """ 138 | which_dim = len(points.shape) - 2 # first dim if a single cloud and second if batch 139 | mn, mx = points.min(which_dim), points.max(which_dim) 140 | lengths = mx - mn 141 | cntr = 0.5 * (mn + mx) 142 | return np.concatenate([cntr, lengths], axis=which_dim) 143 | 144 | 145 | def write_bbox(scene_bbox, out_filename): 146 | """Export scene bbox to meshes 147 | Args: 148 | scene_bbox: (N x 6 numpy array): xyz pos of center and 3 lengths 149 | out_filename: (string) filename 150 | 151 | Note: 152 | To visualize the boxes in MeshLab. 153 | 1. Select the objects (the boxes) 154 | 2. Filters -> Polygon and Quad Mesh -> Turn into Quad-Dominant Mesh 155 | 3. Select Wireframe view. 156 | """ 157 | 158 | def convert_box_to_trimesh_fmt(box): 159 | ctr = box[:3] 160 | lengths = box[3:] 161 | trns = np.eye(4) 162 | trns[0:3, 3] = ctr 163 | trns[3, 3] = 1.0 164 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 165 | return box_trimesh_fmt 166 | 167 | scene = trimesh.scene.Scene() 168 | for box in scene_bbox: 169 | scene.add_geometry(convert_box_to_trimesh_fmt(box)) 170 | 171 | mesh_list = trimesh.util.concatenate(scene.dump()) 172 | # save to ply file 173 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type="ply") 174 | 175 | return 176 | 177 | 178 | def write_oriented_bbox(scene_bbox, out_filename, colors=None): 179 | """Export oriented (around Z axis) scene bbox to meshes 180 | Args: 181 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz) 182 | and heading angle around Z axis. 183 | Y forward, X right, Z upward. heading angle of positive X is 0, 184 | heading angle of positive Y is 90 degrees. 
185 | out_filename: (string) filename 186 | """ 187 | 188 | def heading2rotmat(heading_angle): 189 | pass 190 | rotmat = np.zeros((3, 3)) 191 | rotmat[2, 2] = 1 192 | cosval = np.cos(heading_angle) 193 | sinval = np.sin(heading_angle) 194 | rotmat[0:2, 0:2] = np.array([[cosval, -sinval], [sinval, cosval]]) 195 | return rotmat 196 | 197 | def convert_oriented_box_to_trimesh_fmt(box): 198 | ctr = box[:3] 199 | lengths = box[3:6] 200 | trns = np.eye(4) 201 | trns[0:3, 3] = ctr 202 | trns[3, 3] = 1.0 203 | trns[0:3, 0:3] = heading2rotmat(box[6]) 204 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 205 | return box_trimesh_fmt 206 | 207 | if colors is not None: 208 | if colors.shape[0] != len(scene_bbox): 209 | colors = [colors for _ in range(len(scene_bbox))] 210 | colors = np.array(colors).astype(np.uint8) 211 | assert colors.shape[0] == len(scene_bbox) 212 | assert colors.shape[1] == 4 213 | 214 | scene = trimesh.scene.Scene() 215 | for idx, box in enumerate(scene_bbox): 216 | box_tr = convert_oriented_box_to_trimesh_fmt(box) 217 | if colors is not None: 218 | box_tr.visual.main_color[:] = colors[idx] 219 | box_tr.visual.vertex_colors[:] = colors[idx] 220 | for facet in box_tr.facets: 221 | box_tr.visual.face_colors[facet] = colors[idx] 222 | scene.add_geometry(box_tr) 223 | 224 | mesh_list = trimesh.util.concatenate(scene.dump()) 225 | # save to ply file 226 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type="ply") 227 | 228 | return 229 | 230 | 231 | def write_oriented_bbox_camera_coord(scene_bbox, out_filename): 232 | """Export oriented (around Y axis) scene bbox to meshes 233 | Args: 234 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz) 235 | and heading angle around Y axis. 236 | Z forward, X rightward, Y downward. heading angle of positive X is 0, 237 | heading angle of negative Z is 90 degrees. 238 | out_filename: (string) filename 239 | """ 240 | 241 | def heading2rotmat(heading_angle): 242 | pass 243 | rotmat = np.zeros((3, 3)) 244 | rotmat[1, 1] = 1 245 | cosval = np.cos(heading_angle) 246 | sinval = np.sin(heading_angle) 247 | rotmat[0, :] = np.array([cosval, 0, sinval]) 248 | rotmat[2, :] = np.array([-sinval, 0, cosval]) 249 | return rotmat 250 | 251 | def convert_oriented_box_to_trimesh_fmt(box): 252 | ctr = box[:3] 253 | lengths = box[3:6] 254 | trns = np.eye(4) 255 | trns[0:3, 3] = ctr 256 | trns[3, 3] = 1.0 257 | trns[0:3, 0:3] = heading2rotmat(box[6]) 258 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 259 | return box_trimesh_fmt 260 | 261 | scene = trimesh.scene.Scene() 262 | for box in scene_bbox: 263 | scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box)) 264 | 265 | mesh_list = trimesh.util.concatenate(scene.dump()) 266 | # save to ply file 267 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type="ply") 268 | 269 | return 270 | 271 | 272 | def write_lines_as_cylinders(pcl, filename, rad=0.005, res=64): 273 | """Create lines represented as cylinders connecting pairs of 3D points 274 | Args: 275 | pcl: (N x 2 x 3 numpy array): N pairs of xyz pos 276 | filename: (string) filename for the output mesh (ply) file 277 | rad: radius for the cylinder 278 | res: number of sections used to create the cylinder 279 | """ 280 | scene = trimesh.scene.Scene() 281 | for src, tgt in pcl: 282 | # compute line 283 | vec = tgt - src 284 | M = trimesh.geometry.align_vectors([0, 0, 1], vec, False) 285 | vec = tgt - src # compute again since align_vectors modifies vec in-place! 


def write_lines_as_cylinders(pcl, filename, rad=0.005, res=64):
    """Create lines represented as cylinders connecting pairs of 3D points.
    Args:
        pcl: (N x 2 x 3 numpy array): N pairs of xyz pos
        filename: (string) filename for the output mesh (ply) file
        rad: radius for the cylinder
        res: number of sections used to create the cylinder
    """
    scene = trimesh.scene.Scene()
    for src, tgt in pcl:
        # compute line
        vec = tgt - src
        M = trimesh.geometry.align_vectors([0, 0, 1], vec, False)
        vec = tgt - src  # compute again since align_vectors modifies vec in-place!
        M[:3, 3] = 0.5 * src + 0.5 * tgt
        height = np.sqrt(np.dot(vec, vec))
        scene.add_geometry(
            trimesh.creation.cylinder(
                radius=rad, height=height, sections=res, transform=M
            )
        )
    mesh_list = trimesh.util.concatenate(scene.dump())
    trimesh.io.export.export_mesh(mesh_list, "%s.ply" % (filename), file_type="ply")
--------------------------------------------------------------------------------
/utils/ply_helper.py:
--------------------------------------------------------------------------------
import os
import numpy as np

from plyfile import PlyData, PlyElement


def read_mesh_vertices_rgb_normal(filename):
    """Read an XYZ-RGB point cloud and face indices from a PLY file.
    Despite the name, normals are not computed here; the function returns the
    per-vertex positions/colors and the face index array.
    """
    assert os.path.isfile(filename)
    with open(filename, 'rb') as f:
        plydata = PlyData.read(f)
        num_verts = plydata['vertex'].count
        vertices = np.zeros(shape=[num_verts, 6], dtype=np.float32)
        vertices[:, 0] = plydata['vertex'].data['x']
        vertices[:, 1] = plydata['vertex'].data['y']
        vertices[:, 2] = plydata['vertex'].data['z']
        vertices[:, 3] = plydata['vertex'].data['red']
        vertices[:, 4] = plydata['vertex'].data['green']
        vertices[:, 5] = plydata['vertex'].data['blue']

    # read face indices
    face = np.array([f[0] for f in plydata["face"].data])

    return vertices, face


def write_ply(verts, colors, indices, output_file):
    """Write an ASCII PLY mesh; verts are xyz floats, colors are RGB floats in [0, 1]."""
    if colors is None:
        colors = np.zeros_like(verts)
    if indices is None:
        indices = []

    file = open(output_file, 'w')
    file.write('ply\n')
    file.write('format ascii 1.0\n')
    file.write('element vertex {:d}\n'.format(len(verts)))
    file.write('property float x\n')
    file.write('property float y\n')
    file.write('property float z\n')
    file.write('property uchar red\n')
    file.write('property uchar green\n')
    file.write('property uchar blue\n')
    file.write('element face {:d}\n'.format(len(indices)))
    file.write('property list uchar uint vertex_indices\n')
    file.write('end_header\n')
    for vert, color in zip(verts, colors):
        file.write("{:f} {:f} {:f} {:d} {:d} {:d}\n".format(
            vert[0], vert[1], vert[2],
            int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)))
    for ind in indices:
        file.write('3 {:d} {:d} {:d}\n'.format(ind[0], ind[1], ind[2]))
    file.close()
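
# Editor's note: a minimal usage sketch (not part of the original file), assuming
# a single white triangle and an illustrative output path:
#
#     verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
#     colors = np.ones_like(verts)                      # per-vertex RGB in [0, 1]
#     write_ply(verts, colors, [[0, 1, 2]], "tri.ply")  # one face over three verts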


def create_cylinder_mesh(radius, p0, p1, stacks=10, slices=10):

    import math

    def compute_length_vec3(vec3):
        return math.sqrt(vec3[0] * vec3[0] + vec3[1] * vec3[1] + vec3[2] * vec3[2])

    def rotation(axis, angle):
        # axis-angle rotation matrix (Rodrigues' formula written out element-wise)
        rot = np.eye(4)
        c = np.cos(-angle)
        s = np.sin(-angle)
        t = 1.0 - c
        axis /= compute_length_vec3(axis)
        x = axis[0]
        y = axis[1]
        z = axis[2]
        rot[0, 0] = 1 + t * (x * x - 1)
        rot[0, 1] = z * s + t * x * y
        rot[0, 2] = -y * s + t * x * z
        rot[1, 0] = -z * s + t * x * y
        rot[1, 1] = 1 + t * (y * y - 1)
        rot[1, 2] = x * s + t * y * z
        rot[2, 0] = y * s + t * x * z
        rot[2, 1] = -x * s + t * y * z
        rot[2, 2] = 1 + t * (z * z - 1)
        return rot

    verts = []
    indices = []
    diff = (p1 - p0).astype(np.float32)
    height = compute_length_vec3(diff)
    # build a canonical cylinder along +z, then rotate/translate it onto p0 -> p1
    for i in range(stacks + 1):
        for i2 in range(slices):
            theta = i2 * 2.0 * math.pi / slices
            pos = np.array([radius * math.cos(theta), radius * math.sin(theta), height * i / stacks])
            verts.append(pos)
    for i in range(stacks):
        for i2 in range(slices):
            i2p1 = math.fmod(i2 + 1, slices)
            # two triangles per quad on the cylinder wall
            indices.append(np.array([(i + 1) * slices + i2, i * slices + i2, i * slices + i2p1], dtype=np.uint32))
            indices.append(np.array([(i + 1) * slices + i2, i * slices + i2p1, (i + 1) * slices + i2p1], dtype=np.uint32))
    transform = np.eye(4)
    va = np.array([0, 0, 1], dtype=np.float32)
    vb = diff
    vb /= compute_length_vec3(vb)
    axis = np.cross(vb, va)
    angle = np.arccos(np.clip(np.dot(va, vb), -1, 1))
    if angle != 0:
        if compute_length_vec3(axis) == 0:
            # va and vb are anti-parallel: pick any axis orthogonal to va
            dotx = va[0]
            if math.fabs(dotx) != 1.0:
                axis = np.array([1, 0, 0]) - dotx * va
            else:
                axis = np.array([0, 1, 0]) - va[1] * va
            axis /= compute_length_vec3(axis)
        transform = rotation(axis, -angle)
    transform[:3, 3] += p0
    verts = [np.dot(transform, np.array([v[0], v[1], v[2], 1.0])) for v in verts]
    verts = [np.array([v[0], v[1], v[2]]) / v[3] for v in verts]

    return verts, indices


def write_bbox(corners, color, output_file):
    """
    corners: (N x 3 numpy array) xyz coordinates of the box corners (typically 8)
    color: (r, g, b) values in [0, 255]
    output_file: string
    """

    def get_bbox_edges(bbox_min, bbox_max):
        def get_bbox_verts(bbox_min, bbox_max):
            verts = [
                np.array([bbox_min[0], bbox_min[1], bbox_min[2]]),
                np.array([bbox_max[0], bbox_min[1], bbox_min[2]]),
                np.array([bbox_max[0], bbox_max[1], bbox_min[2]]),
                np.array([bbox_min[0], bbox_max[1], bbox_min[2]]),

                np.array([bbox_min[0], bbox_min[1], bbox_max[2]]),
                np.array([bbox_max[0], bbox_min[1], bbox_max[2]]),
                np.array([bbox_max[0], bbox_max[1], bbox_max[2]]),
                np.array([bbox_min[0], bbox_max[1], bbox_max[2]])
            ]
            return verts

        box_verts = get_bbox_verts(bbox_min, bbox_max)
        edges = [
            (box_verts[0], box_verts[1]),
            (box_verts[1], box_verts[2]),
            (box_verts[2], box_verts[3]),
            (box_verts[3], box_verts[0]),

            (box_verts[4], box_verts[5]),
            (box_verts[5], box_verts[6]),
            (box_verts[6], box_verts[7]),
            (box_verts[7], box_verts[4]),

            (box_verts[0], box_verts[4]),
            (box_verts[1], box_verts[5]),
            (box_verts[2], box_verts[6]),
            (box_verts[3], box_verts[7])
        ]
        return edges

    radius = 0.03
    offset = [0, 0, 0]
    verts = []
    indices = []
    colors = []

    box_min = np.min(corners, axis=0)
    box_max = np.max(corners, axis=0)
    edges = get_bbox_edges(box_min, box_max)
    for k in range(len(edges)):
        cyl_verts, cyl_ind = create_cylinder_mesh(radius, edges[k][0], edges[k][1])
        cur_num_verts = len(verts)
        cyl_color = [[c / 255 for c in color] for _ in cyl_verts]
        cyl_verts = [x + offset for x in cyl_verts]
        cyl_ind = [x + cur_num_verts for x in cyl_ind]
        verts.extend(cyl_verts)
        indices.extend(cyl_ind)
        colors.extend(cyl_color)

    write_ply(verts, colors, indices, output_file)
    return
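
# Editor's note: a minimal usage sketch (not part of the original file), assuming
# a unit cube and an illustrative output path:
#
#     corners = np.array([[x, y, z] for x in (0, 1) for y in (0, 1) for z in (0, 1)],
#                        dtype=np.float32)
#     write_bbox(corners, (255, 0, 0), "bbox_edges.ply")  # red wireframe as cylinders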


def write_path(points, color, output_file):
    """Export a polyline through `points` as a chain of cylinders, colored with (r, g, b) in [0, 255]."""
    radius = 0.03
    offset = [0, 0, 0]
    verts = []
    indices = []
    colors = []

    for start, end in zip(points[:-1], points[1:]):
        cyl_verts, cyl_ind = create_cylinder_mesh(radius, start, end)
        cur_num_verts = len(verts)
        cyl_color = [[c / 255 for c in color] for _ in cyl_verts]
        cyl_verts = [x + offset for x in cyl_verts]
        cyl_ind = [x + cur_num_verts for x in cyl_ind]
        verts.extend(cyl_verts)
        indices.extend(cyl_ind)
        colors.extend(cyl_color)

    write_ply(verts, colors, indices, output_file)
    return
--------------------------------------------------------------------------------