├── LICENSE
├── README.md
├── assets
│   ├── images
│   │   ├── meshxl_logo.jpg
│   │   ├── objaverse-samples.png
│   │   ├── pipeline.png
│   │   ├── teaser.png
│   │   └── text-to-mesh-samples.png
│   └── videos
│       └── mesh_grid.mp4
├── config
│   └── deepspeed_stage2.yaml
├── data
│   └── README.md
├── datasets
│   ├── __init__.py
│   ├── dummy_dataset.py
│   └── sft
│       ├── base_dataset.py
│       ├── shapenet_bench.py
│       ├── shapenet_chair.py
│       ├── shapenet_lamp.py
│       └── shapenet_table.py
├── engine.py
├── eval_utils
│   ├── perplexity.py
│   └── sample_generation.py
├── main.py
├── mesh-xl
│   ├── mesh-xl-1.3b
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   └── pytorch_model.bin
│   ├── mesh-xl-125m
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   └── pytorch_model.bin
│   ├── mesh-xl-350m
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   └── pytorch_model.bin
│   └── x-mesh-xl-350m
│       ├── config.json
│       ├── generation_config.json
│       └── pytorch_model.bin
├── models
│   ├── mesh_xl
│   │   ├── get_model.py
│   │   └── tokenizer.py
│   └── x_mesh_xl
│       ├── get_model.py
│       └── tokenizer.py
├── openai
│   └── clip-vit-base-patch32
│       ├── config.json
│       ├── special_tokens_map.json
│       ├── tokenizer.json
│       ├── tokenizer_config.json
│       └── vocab.json
├── requirement.txt
├── sample_t2m.py
├── scripts
│   ├── meshxl-sft-shapenet.sh
│   ├── sample-1.3b.sh
│   ├── sample-125m.sh
│   ├── sample-350m.sh
│   └── sample-t2mesh.sh
├── set_env.sh
└── utils
    ├── dist.py
    ├── io.py
    ├── logger.py
    ├── misc.py
    ├── nms.py
    ├── pc_util.py
    └── ply_helper.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 OpenMeshLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | Official repo for MeshXL
3 |
20 | ## 🏃 Intro MeshXL
21 |
22 | **MeshXL** is a family of generative pre-trained foundation models for 3D mesh generation. With the Neural Coordinate Field representation, the generation of unstructured 3D mesh data can be seamlessly addressed by modern LLM methods. In this paper, we validate that the **Neur**al **C**oordinate **F**ield (NeurCF), an explicit coordinate representation with implicit neural embeddings, is a simple-yet-effective representation for large-scale sequential mesh modeling.
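
For illustration (our sketch, not the repo's exact tokenizer in `models/mesh_xl/tokenizer.py`): each triangle is expanded into its explicit vertex coordinates, which are discretized into tokens so that the mesh becomes one long sequence for a causal LLM.

```{python}
# A minimal sketch of NeurCF-style mesh serialization, assuming vertices are
# normalized into (-1, 1) and discretized at resolution `n_discrete_size`.
import numpy as np

def mesh_to_token_sequence(vertices: np.ndarray, faces: np.ndarray, n_discrete_size: int = 128) -> np.ndarray:
    # explicit coordinates per face: nface x 3 (vertices) x 3 (xyz)
    coords = vertices[faces]
    # map each coordinate to an integer token in [0, n_discrete_size)
    tokens = ((coords + 1.0) / 2.0 * (n_discrete_size - 1)).round().astype(np.int64)
    # flatten into one autoregressive sequence: 9 tokens per triangle
    return tokens.reshape(-1)
```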
23 |
24 |
25 |
26 |
27 | ## 🚩 News
28 |
29 | - [2025/04/04] Upload pre-processed text-to-mesh generation data to [huggingface🤗](https://huggingface.co/datasets/CH3COOK/MeshXL-text-to-mesh-sketchfab).
30 | - [2024/12/12] Upload pre-processed ShapeNet data to [huggingface🤗](https://huggingface.co/datasets/CH3COOK/MeshXL-shapenet-data) and supervised fine-tuning scripts on specified categories.
31 | - [2024/09/26] MeshXL is accepted to **NeurIPS 2024**🔥! See you in Vancouver!
32 | - [2024/08/29] Upload code and 🤗[weights](https://huggingface.co/CH3COOK/x-mesh-xl-350m/blob/main/pytorch_model.bin) for text-to-mesh generation, welcome to check it out!
33 | - [2024/07/24] Upload the inference code and pre-trained weights.
34 | - [2024/06/02] Upload paper and init project.
35 |
36 |
37 | ## ⚡ Quick Start
38 |
39 |
40 |
41 | ### Setting Up the Environment
42 |
43 | You can build the environment using the provided script:
44 | ```{bashrc}
45 | bash set_env.sh
46 | ```
47 |
48 |
49 |
50 |
51 |
52 | ### Data
53 | Work in Progress...
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 | ## 💻 Training and Evaluation
62 |
63 |
64 | ### Download Pre-Trained Weights
65 |
66 | **\[Special Notes\]**: All the following models are generative pre-trained base models. They are capable of **unconditional** 3D mesh generation and **partial mesh completion**.
67 |
68 | We provide pre-trained weights for different sizes of models (i.e. `125m`, `350m`, and `1.3b`) on huggingface🤗. Download the pre-trained weights from the links below to replace the `pytorch_model.bin` files in the corresponding folders under the `./mesh-xl/` folder. The model details are shown below:
69 |
70 | | Model Size | #Layers | #Heads | $d_\text{model}$ | $d_\text{FFN}$ | GPU Hours | Download Link |
71 | |:----------:|:-------:|:------:|:----------------:|:--------------:|:---------:|:---------------------------------------------------:|
72 | | 125M | 12 | 12 | 768 | 3072 | 1944 | [download link](https://huggingface.co/CH3COOK/mesh-xl-125m) |
73 | | 350M | 24 | 16 | 1024 | 4096 | 6000 | [download link](https://huggingface.co/CH3COOK/mesh-xl-350m) |
74 | | 1.3B | 24 | 32 | 2048 | 8192 | 23232 | [download link](https://huggingface.co/CH3COOK/mesh-xl-1.3b) |
75 |
76 | Use the following command for fast downloading:
77 | ```
78 | cd ./mesh-xl
79 | git lfs clone https://huggingface.co/CH3COOK/mesh-xl-125m
80 | git lfs clone https://huggingface.co/CH3COOK/mesh-xl-350m
81 | git lfs clone https://huggingface.co/CH3COOK/mesh-xl-1.3b
82 | cd ..
83 | ```
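
After downloading, you can optionally sanity-check that the weights deserialize (our suggestion, not an official script):

```{python}
# Optional sanity check: each pytorch_model.bin should load on CPU as a plain
# state dict of weight tensors.
import torch

state_dict = torch.load('./mesh-xl/mesh-xl-125m/pytorch_model.bin', map_location='cpu')
print(f'{len(state_dict)} tensors loaded')
```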
84 |
85 |
86 |
87 |
88 |
89 | ### MeshXL Generative Pre-Training
90 |
91 | Work in progress...
92 |
93 |
94 |
95 |
96 |
97 |
98 | ### Generating Samples
99 |
100 |
101 |
102 | To generate 3D meshes with models of different sizes, feel free to use the following commands. By default, we generate samples on 8 GPUs with the top-k top-p sampling strategy for diverse samples.
103 |
104 | ```{bashrc}
105 | bash scripts/sample-1.3b.sh
106 | bash scripts/sample-350m.sh
107 | bash scripts/sample-125m.sh
108 | ```
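
For reference, here is a minimal sketch of what one top-k + top-p (nucleus) sampling step does; the actual decoding in the scripts is handled by the huggingface `generate` API:

```{python}
# Minimal top-k + top-p sketch for a single decoding step (illustration only).
import torch

def sample_top_k_top_p(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.95) -> torch.Tensor:
    # keep only the k most likely tokens
    topk_vals, topk_idx = torch.topk(logits, top_k)
    probs = torch.softmax(topk_vals, dim=-1)
    # nucleus: keep the smallest set of tokens whose cumulative mass reaches top_p
    sorted_probs, order = torch.sort(probs, descending=True)
    keep = torch.cumsum(sorted_probs, dim=-1) - sorted_probs < top_p
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum()
    # draw one token id from the filtered distribution
    choice = torch.multinomial(filtered, 1)
    return topk_idx[order[choice]]
```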
109 |
110 | **\[Special Notes\]**: The following weights are fine-tuned for **unconditional** 3D mesh generation on a **specified** category.
111 |
112 | Want to generate shapes for a specified category? We have also uploaded supervised fine-tuned checkpoints for `chair`, `table`, `bench`, and `lamp` to huggingface! Download the fine-tuned weights from the links🤗 below.
113 |
114 | | Model Size | Table | Chair | Lamp | Bench |
115 | |:----------:|:-----------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|
116 | | 125M | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-table.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-chair.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-lamp.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-125m-sft/blob/main/meshxl-125m-bench.pth) |
117 | | 350M | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-table.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-chair.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-lamp.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-350m-sft/blob/main/meshxl-350m-bench.pth) |
118 | | 1.3B | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-table.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-chair.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-lamp.pth) | [download link](https://huggingface.co/CH3COOK/MeshXL-1.3b-sft/blob/main/meshxl-1.3b-bench.pth) |
119 |
120 |
121 | After you have downloaded the corresponding checkpoints, please use the following command to generate samples.
122 |
123 | ```{bashrc}
124 | export LLM_CONFIG='mesh-xl/mesh-xl-125m'
125 | # the checkpoint must align with the $LLM_CONFIG
126 | export TEST_CKPT='./ckpts-meshxl-125m-sft/meshxl-125m-bench.pth'
127 |
128 | accelerate launch \
129 | --num_machines 1 \
130 | --num_processes 8 \
131 | --mixed_precision bf16 \
132 | main.py \
133 | --dataset dummy_dataset \
134 | --n_max_triangles 800 \
135 | --n_discrete_size 128 \
136 | --llm mesh-xl/mesh-xl-125m \
137 | --model mesh_xl \
138 | --checkpoint_dir ./outputs \
139 | --batchsize_per_gpu 2 \
140 | --test_ckpt $TEST_CKPT \
141 | --sample_rounds 100 \
142 | --dataset_num_workers 0 \
143 | --test_only
144 | ```
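
Under the hood (see `main.py`), `--test_ckpt` points to a `.pth` file that stores the fine-tuned weights under a `"model"` key:

```{python}
# What --test_ckpt loading does internally (mirrors main.py).
import torch

checkpoint = torch.load('./ckpts-meshxl-125m-sft/meshxl-125m-bench.pth', map_location='cpu')
print(checkpoint.keys())  # expect at least: 'model'
# main.py then calls: model.load_state_dict(checkpoint['model'], strict=False)
```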
145 |
146 | Want to see more results? Check our project page out [here](https://meshxl.github.io/)!
147 |
148 |
149 |
150 |
151 |
152 | ### Text-to-Mesh Generation
153 |
154 |
155 |
156 | We thank the awesome language annotations from [PointLLM](https://github.com/OpenRobotLab/PointLLM) for object captions. We fine-tune a `350m` MeshXL model on Objaverse with 8 RTX-3090 GPUs.
157 |
158 | **Note:** please download the pre-trained checkpoint from [huggingface](https://huggingface.co/CH3COOK/x-mesh-xl-350m/blob/main/pytorch_model.bin)🤗 to replace the `./mesh-xl/x-mesh-xl-350m/pytorch_model.bin` file.
159 |
160 | We are actively working on Gradio demos. Currently, we encourage you to generate samples locally with at least 1 GPU with the following code:
161 |
162 | ```{bashrc}
163 | bash scripts/sample-t2mesh.sh
164 | ```
165 |
166 | You are also welcome to explore other text conditions and hyper-parameters for better controls:
167 |
168 | ```{bashrc}
169 | # change `--text` to the prompt you need
170 | # larger top_k / top_p / temperature -> larger randomness
171 | accelerate launch \
172 |     --num_machines 1 \
173 |     --num_processes 1 \
174 |     --mixed_precision bf16 \
175 |     sample_t2m.py \
176 |     --test_ckpt mesh-xl/x-mesh-xl-350m/pytorch_model.bin \
177 |     --text '3d model of a table' \
178 |     --top_k 50 --top_p 0.95 --temperature 0.1
179 | ```
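
These sampling flags map onto standard huggingface generation arguments; for reference, the default unconditional setup in `models/mesh_xl/get_model.py` uses the same knobs:

```{python}
# Default sampling setup mirrored from models/mesh_xl/get_model.py; the
# --top_k / --top_p / --temperature flags of sample_t2m.py override these.
generation_config = dict(
    do_sample=True,   # stochastic decoding instead of greedy argmax
    top_k=50,         # larger k -> larger randomness
    top_p=0.95,       # larger p -> larger randomness
)
```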
180 |
181 | [Update 2025-04-04] We have uploaded the training and evaluation data used for text-to-mesh generation to [huggingface](https://huggingface.co/datasets/CH3COOK/MeshXL-text-to-mesh-sketchfab). You can load one data chunk through:
182 |
183 | ```{python}
184 | import numpy as np
185 | data = np.load('ext_text_to_mesh/sketchfab_chunk/sketchfab_pointllm_train_0000.npz', allow_pickle=True)["arr_0"].tolist()
186 | # To check the caption for a specified object:
187 | print(data[0]['caption'])
188 | # output: ['3D white arrow on gray background']
190 | ```
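
To walk over every chunk in the downloaded folder (a hypothetical helper; the glob pattern assumes the layout shown above):

```{python}
# Hypothetical helper: iterate over all downloaded chunk files and report how
# many captioned meshes each one contains.
import glob
import numpy as np

for path in sorted(glob.glob('ext_text_to_mesh/sketchfab_chunk/*.npz')):
    chunk = np.load(path, allow_pickle=True)["arr_0"].tolist()
    print(path, len(chunk), 'samples')
```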
191 |
192 |
193 |
194 |
195 |
196 | ### MeshXL Supervised Fine-Tuning
197 |
198 | Please first download the pre-processed ShapeNet data to the `./data` folder from huggingface:
199 | ```
200 | cd ./data
201 | git lfs clone https://huggingface.co/datasets/CH3COOK/MeshXL-shapenet-data
202 | cd ..
203 | ```
204 | Then, use the following command for a specified category (change `sft.shapenet_table` in `--dataset` to the dataset filename you need, e.g. `sft.shapenet_chair`):
205 | ```
206 | export BASE_MESHXL=mesh-xl/mesh-xl-1.3b # TODO: change the MeshXL config
207 | export BATCHSIZE_PER_GPU=4 # TODO: change the training batch size
208 |
209 | accelerate launch \
210 | --config_file ./config/deepspeed_stage2.yaml \
211 | --num_machines 1 \
212 | --num_processes 8 \
213 | --mixed_precision bf16 \
214 | main.py \
215 |     --dataset sft.shapenet_table \
216 | --n_max_triangles 800 \
217 | --n_discrete_size 128 \
218 | --warm_lr_iters -1 \
219 | --base_lr 1e-6 \
220 | --llm $BASE_MESHXL \
221 | --model mesh_xl \
222 | --checkpoint_dir ./ckpts/meshxl-shapenet-sft-table \
223 | --batchsize_per_gpu $BATCHSIZE_PER_GPU \
224 | --dataset_num_workers 0 \
225 | --augment \
226 | --eval_every_iteration 10000 \
227 | --save_every 20000 \
228 | --max_epoch 1024
229 | ```
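
For reference, the learning-rate flags follow the schedule implemented in `engine.py`: an optional linear warmup (disabled above via `--warm_lr_iters -1`) followed by cosine decay from `--base_lr` down to `--final_lr`:

```{python}
# Mirrors compute_learning_rate in engine.py (defaults shown are main.py's).
import math

def compute_learning_rate(curr_iter, max_iters, base_lr=1e-4, final_lr=1e-6,
                          warm_lr=1e-6, warm_lr_iters=1000):
    if warm_lr_iters > 0 and curr_iter <= warm_lr_iters:
        # linear warmup: warm_lr -> base_lr
        return warm_lr + curr_iter / warm_lr_iters * (base_lr - warm_lr)
    # cosine decay: base_lr -> final_lr
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * curr_iter / max_iters))
```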
230 |
231 |
232 |
233 |
234 |
235 |
236 | ## 📖 Citation
237 |
238 | If you find our code or paper helpful, please consider citing:
239 |
240 | ```bibtex
241 | @misc{chen2024meshxl,
242 | title={MeshXL: Neural Coordinate Field for Generative 3D Foundation Models},
243 | author={Sijin Chen and Xin Chen and Anqi Pang and Xianfang Zeng and Wei Cheng and Yijun Fu and Fukun Yin and Yanru Wang and Zhibin Wang and Chi Zhang and Jingyi Yu and Gang Yu and Bin Fu and Tao Chen},
244 | year={2024},
245 | eprint={2405.20853},
246 | archivePrefix={arXiv},
247 | primaryClass={cs.CV}
248 | }
249 | ```
250 |
251 | ## Acknowledgments
252 | We use [Paint3D](https://github.com/OpenTexture/Paint3D) for texturing generated 3D meshes.
253 |
254 | We express our genuine thanks to the amazing works: [ShapeNet](https://shapenet.org/), [3D-FUTURE](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future-cn), [Objaverse](https://github.com/allenai/objaverse-xl), [Objaverse-XL](https://github.com/allenai/objaverse-xl), [PolyGen](https://github.com/google-deepmind/deepmind-research/blob/master/polygen/README.md), [Get3D](https://github.com/nv-tlabs/GET3D) and [MeshGPT](https://github.com/nihalsid/mesh-gpt), and the amazing [MeshGPT-pytorch](https://github.com/lucidrains/meshgpt-pytorch) codebase.
255 |
--------------------------------------------------------------------------------
/assets/images/meshxl_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/meshxl_logo.jpg
--------------------------------------------------------------------------------
/assets/images/objaverse-samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/objaverse-samples.png
--------------------------------------------------------------------------------
/assets/images/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/pipeline.png
--------------------------------------------------------------------------------
/assets/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/teaser.png
--------------------------------------------------------------------------------
/assets/images/text-to-mesh-samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/images/text-to-mesh-samples.png
--------------------------------------------------------------------------------
/assets/videos/mesh_grid.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/assets/videos/mesh_grid.mp4
--------------------------------------------------------------------------------
/config/deepspeed_stage2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 1
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | fsdp_config: {}
11 | machine_rank: 0
12 | main_process_ip: null
13 | main_process_port: null
14 | main_training_function: main
15 | mixed_precision: fp16
16 | num_machines: 4
17 | num_processes: 32
18 | use_cpu: false
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Dataset Preparation
2 |
3 |
4 | ### ShapeNet SFT Data
5 |
6 | Download the pre-processed ShapeNet data from huggingface:
7 | ```
8 | git lfs clone https://huggingface.co/datasets/CH3COOK/MeshXL-shapenet-data
9 | ```
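
The training code reads these chunks directly (see `datasets/sft/*.py`). To inspect one in Python (the filename below is a placeholder; pick any `.npz` from the downloaded folder):

```
import numpy as np

chunk = np.load('MeshXL-shapenet-data/<chunk-file>.npz', allow_pickle=True)["arr_0"].tolist()
print(len(chunk), 'meshes;', 'keys:', chunk[0].keys())  # expect 'vertices' and 'faces'
```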
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/dummy_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from copy import deepcopy
5 | from eval_utils.sample_generation import evaluate
6 |
7 | class Dataset:
8 |
9 | def __init__(self, *args, **kwargs):
10 | super().__init__()
11 | self.eval_func = evaluate
12 |
13 | def __len__(self):
14 | return 10
15 |
16 | def __getitem__(self, idx):
17 | data_dict = {}
18 | data_dict['shape_idx'] = np.asarray(idx).astype(np.int64)
19 | return data_dict
20 |
--------------------------------------------------------------------------------
/datasets/sft/base_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from copy import deepcopy
5 | from eval_utils.perplexity import evaluate
6 |
7 | BASE_DIR = os.path.join('.', 'data')
8 |
9 |
10 | def scale_mesh(vertices: np.ndarray, scale_factor: tuple=(0.75, 1.25)) -> np.ndarray:
11 | lower, upper = scale_factor
12 | scale_axis = lower + np.random.rand(vertices.shape[-1]) * (upper - lower)
13 | vertices = vertices * scale_axis
14 | return vertices
15 |
16 |
17 | def normalize_mesh(vertices: np.ndarray, scale_range: tuple=(-1.0, 1.0)) -> np.ndarray:
18 | lower, upper = scale_range
19 | scale_per_axis = (vertices.max(0) - vertices.min(0)).max()
20 | center_xyz = 0.5 * (vertices.max(0) + vertices.min(0))
21 |     normalized_xyz = (vertices - center_xyz) / scale_per_axis # centered at origin, scaled into range (-0.5, 0.5)
22 | vertices = normalized_xyz * (upper - lower)
23 | return vertices
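
# e.g. normalize_mesh(np.array([[0., 0., 0.], [2., 1., 1.]])) returns
# [[-1., -0.5, -0.5], [1., 0.5, 0.5]]: centered at the origin with the
# longest axis spanning (-1, 1)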
24 |
25 |
26 | class BaseDataset:
27 |
28 | def __init__(self, *args, **kwargs):
29 | super().__init__()
30 |         self.data = []
31 | self.num_repeat = 1
32 |
33 | def _preprocess_data(self, data_chunk):
34 | processed = []
35 | for data in data_chunk:
36 | processed.append(
37 | dict(
38 | vertices = np.asarray(data['vertices']),
39 | faces = np.asarray(data['faces']),
40 | )
41 | )
42 | return processed
43 |
44 | def _fetch_data(self, idx):
45 | idx = idx % len(self.data)
46 | return deepcopy(self.data[idx])
47 |
48 | def __len__(self):
49 | return len(self.data) * self.num_repeat
50 |
51 | def __getitem__(self, idx):
52 | data = self._fetch_data(idx)
53 | data['vertices'] = np.asarray(data['vertices'])
54 | data['faces'] = np.asarray(data['faces'])
55 |
56 | num_vertices = len(data['vertices'])
57 | num_faces = len(data['faces'])
58 |
59 | vertices = np.ones((self.max_vertices, 3)) * self.pad_id
60 | faces = np.ones((self.max_triangles, 3)) * self.pad_id
61 |
62 | if self.augment is True:
63 | data['vertices'] = scale_mesh(data['vertices'])
64 | data['vertices'] = normalize_mesh(data['vertices'])
65 |
66 | vertices[:num_vertices] = data['vertices']
67 | faces[:num_faces] = data['faces']
68 |
69 | gt_vertices = vertices[faces.clip(0).astype(np.int64)] # nface x 3 x 3
70 | gt_vertices[faces[:, 0] == self.pad_id] = float('nan')
71 |
72 | data_dict = {}
73 | data_dict['shape_idx'] = np.asarray(idx).astype(np.int64)
74 | data_dict['vertices'] = np.asarray(vertices).astype(np.float32)
75 | data_dict['faces'] = np.asarray(faces).astype(np.int64)
76 | data_dict['gt_vertices'] = np.asarray(gt_vertices).astype(np.float32)
77 |
78 | return data_dict
79 |
--------------------------------------------------------------------------------
/datasets/sft/shapenet_bench.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from eval_utils.perplexity import evaluate
5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR
6 |
7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data')
8 |
9 |
10 |
11 | class Dataset(BaseDataset):
12 |
13 | def __init__(self, args, split_set="train", augment=False):
14 | super().__init__()
15 |
16 | # base dataset config
17 | self.dataset_name = 'shapenet_bench'
18 | self.category_id = '02828884'
19 | self.eval_func = evaluate
20 | self.augment = augment and (split_set == 'train')
21 | self.num_repeat = 1
22 | self.pad_id = -1
23 | self.max_triangles = args.n_max_triangles
24 | self.max_vertices = self.max_triangles * 3
25 |
26 | # pre-load data into memory
27 | full_data = []
28 | for filename in tqdm(os.listdir(DATASET_DIR)):
29 | if self.category_id not in filename:
30 | continue
31 | if (split_set in filename) and filename.endswith('.npz'):
32 | loaded_data = np.load(
33 | os.path.join(DATASET_DIR, filename),
34 | allow_pickle=True
35 | )
36 | loaded_data = loaded_data["arr_0"].tolist()
37 | loaded_data = self._preprocess_data(loaded_data)
38 | full_data = full_data + loaded_data
39 |
40 | self.data = full_data
41 |
42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}")
43 |
--------------------------------------------------------------------------------
/datasets/sft/shapenet_chair.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from eval_utils.perplexity import evaluate
5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR
6 |
7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data')
8 |
9 |
10 |
11 | class Dataset(BaseDataset):
12 |
13 | def __init__(self, args, split_set="train", augment=False):
14 | super().__init__()
15 |
16 | # base dataset config
17 | self.dataset_name = 'shapenet_chair'
18 | self.category_id = '03001627'
19 | self.eval_func = evaluate
20 | self.augment = augment and (split_set == 'train')
21 | self.num_repeat = 1
22 | self.pad_id = -1
23 | self.max_triangles = args.n_max_triangles
24 | self.max_vertices = self.max_triangles * 3
25 |
26 | # pre-load data into memory
27 | full_data = []
28 | for filename in tqdm(os.listdir(DATASET_DIR)):
29 | if self.category_id not in filename:
30 | continue
31 | if (split_set in filename) and filename.endswith('.npz'):
32 | loaded_data = np.load(
33 | os.path.join(DATASET_DIR, filename),
34 | allow_pickle=True
35 | )
36 | loaded_data = loaded_data["arr_0"].tolist()
37 | loaded_data = self._preprocess_data(loaded_data)
38 | full_data = full_data + loaded_data
39 |
40 | self.data = full_data
41 |
42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}")
43 |
--------------------------------------------------------------------------------
/datasets/sft/shapenet_lamp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from eval_utils.perplexity import evaluate
5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR
6 |
7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data')
8 |
9 |
10 |
11 | class Dataset(BaseDataset):
12 |
13 | def __init__(self, args, split_set="train", augment=False):
14 | super().__init__()
15 |
16 | # base dataset config
17 | self.dataset_name = 'shapenet_lamp'
18 | self.category_id = '03636649'
19 | self.eval_func = evaluate
20 | self.augment = augment and (split_set == 'train')
21 | self.num_repeat = 1
22 | self.pad_id = -1
23 | self.max_triangles = args.n_max_triangles
24 | self.max_vertices = self.max_triangles * 3
25 |
26 | # pre-load data into memory
27 | full_data = []
28 | for filename in tqdm(os.listdir(DATASET_DIR)):
29 | if self.category_id not in filename:
30 | continue
31 | if (split_set in filename) and filename.endswith('.npz'):
32 | loaded_data = np.load(
33 | os.path.join(DATASET_DIR, filename),
34 | allow_pickle=True
35 | )
36 | loaded_data = loaded_data["arr_0"].tolist()
37 | loaded_data = self._preprocess_data(loaded_data)
38 | full_data = full_data + loaded_data
39 |
40 | self.data = full_data
41 |
42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}")
43 |
--------------------------------------------------------------------------------
/datasets/sft/shapenet_table.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from eval_utils.perplexity import evaluate
5 | from datasets.sft.base_dataset import BaseDataset, BASE_DIR
6 |
7 | DATASET_DIR = os.path.join(BASE_DIR, 'MeshXL-shapenet-data')
8 |
9 |
10 |
11 | class Dataset(BaseDataset):
12 |
13 | def __init__(self, args, split_set="train", augment=False):
14 | super().__init__()
15 |
16 | # base dataset config
17 | self.dataset_name = 'shapenet_table'
18 | self.category_id = '04379243'
19 | self.eval_func = evaluate
20 | self.augment = augment and (split_set == 'train')
21 | self.num_repeat = 1
22 | self.pad_id = -1
23 | self.max_triangles = args.n_max_triangles
24 | self.max_vertices = self.max_triangles * 3
25 |
26 | # pre-load data into memory
27 | full_data = []
28 | for filename in tqdm(os.listdir(DATASET_DIR)):
29 | if self.category_id not in filename:
30 | continue
31 | if (split_set in filename) and filename.endswith('.npz'):
32 | loaded_data = np.load(
33 | os.path.join(DATASET_DIR, filename),
34 | allow_pickle=True
35 | )
36 | loaded_data = loaded_data["arr_0"].tolist()
37 | loaded_data = self._preprocess_data(loaded_data)
38 | full_data = full_data + loaded_data
39 |
40 | self.data = full_data
41 |
42 | print(f"[MeshDataset] Created from {len(self.data)} shapes for {self.dataset_name} {split_set}")
43 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import time, math
2 | import torch
3 | import datetime
4 |
5 | from utils.io import save_checkpoint
6 | from utils.misc import SmoothedValue
7 |
8 |
9 |
10 | def compute_learning_rate(args, curr_iter, max_iters):
11 | assert curr_iter <= max_iters and curr_iter >= 0
12 | if (curr_iter <= args.warm_lr_iters) and args.warm_lr_iters > 0:
13 | # Linear Warmup: warm_lr -> curr_lr -> base_lr
14 | curr_lr = args.warm_lr + curr_iter / args.warm_lr_iters * (args.base_lr - args.warm_lr)
15 | else:
16 | # Cosine Learning Rate Schedule
17 | curr_lr = args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (
18 | 1 + math.cos(math.pi * curr_iter / max_iters)
19 | )
20 | return curr_lr
21 |
22 |
23 |
24 | def adjust_learning_rate(args, optimizer, curr_iter, max_iters):
25 | curr_lr = compute_learning_rate(args, curr_iter, max_iters)
26 | for param_group in optimizer.param_groups:
27 | param_group["lr"] = curr_lr
28 | return curr_lr
29 |
30 |
31 |
32 | def do_train(
33 | args,
34 | model,
35 | accelerator,
36 | optimizer,
37 | dataloaders,
38 | best_val_metrics,
39 | logger
40 | ):
41 |
42 | if accelerator.is_main_process:
43 | logger.log_messages(f"call with args: {args}")
44 | logger.log_messages(f"{model}")
45 |
46 | curr_iter = args.start_epoch * len(dataloaders['train'])
47 | max_iters = args.max_epoch * len(dataloaders['train'])
48 |
49 | time_delta = SmoothedValue(window_size=10)
50 | loss_avg = SmoothedValue(window_size=10)
51 | loss_break_down_avg = {}
52 |
53 | model.train()
54 | accelerator.wait_for_everyone()
55 |
56 | for curr_epoch in range(args.start_epoch, args.max_epoch):
57 |
58 | for batch_idx, batch_data_label in enumerate(dataloaders['train']):
59 |
60 | curr_time = time.time()
61 |
62 | ### core for model training
63 |
64 | curr_iter = curr_epoch * len(dataloaders['train']) + batch_idx
65 | curr_lr = adjust_learning_rate(args, optimizer, curr_iter, max_iters)
66 |
67 | with accelerator.accumulate(model):
68 |
69 | with accelerator.autocast():
70 | outputs = model(batch_data_label)
71 | loss = outputs['loss']
72 |
73 | # sanity check, skip the infinite loss
74 | if not math.isfinite(loss.item()):
75 | logger.log_messages("Loss in not finite. Skip this iteration.")
76 | model.eval()
77 | model.train()
78 | torch.cuda.empty_cache()
79 | continue
80 |
81 | accelerator.backward(loss)
82 | if args.clip_gradient > 0:
83 | accelerator.clip_grad_norm_(model.parameters(), args.clip_gradient)
84 |
85 | optimizer.step()
86 | optimizer.zero_grad()
87 |
88 | ### logging training loss status
89 |
90 | time_delta.update(time.time() - curr_time)
91 | loss_avg.update(loss.item())
92 |
93 | for key, value in outputs.items():
94 | if 'loss' in key.lower():
95 | loss_break_down_avg[key] = loss_break_down_avg.get(key, SmoothedValue(window_size=10))
96 | loss_break_down_avg[key].update(value.item())
97 |
98 | ### writing logs
99 |
100 | if accelerator.is_main_process and curr_iter % args.log_every == 0:
101 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
102 | eta_seconds = (max_iters - curr_iter) * time_delta.avg
103 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
104 |
105 | logger.log_messages(
106 | '; '.join(
107 | [
108 | f"Epoch [{curr_epoch}/{args.max_epoch}]",
109 | f"Iter [{curr_iter}/{max_iters}]",
110 | # loss string
111 | *(
112 | f'{key} {avg.avg:0.4f}' \
113 | for key, avg in loss_break_down_avg.items()
114 | ),
115 | # status string
116 | f"LR {curr_lr:0.2e}",
117 | f"Iter time {time_delta.avg:0.2f}s",
118 | f"ETA {eta_str}",
119 | f"Mem {mem_mb:0.2f}MB"
120 | ]
121 | )
122 | )
123 | train_loss_log = {k: v.avg for k, v in loss_break_down_avg.items()}
124 | train_loss_log["learning_rate"] = curr_lr
125 | logger.log_scalars(train_loss_log, prefix='train', step=curr_iter)
126 |
127 | ### saving checkpoints
128 |
129 | if accelerator.is_main_process and (curr_iter + 1) % args.save_every == 0:
130 | save_checkpoint(
131 | args.checkpoint_dir,
132 | accelerator.unwrap_model(model),
133 | optimizer,
134 | curr_epoch,
135 | args,
136 | best_val_metrics,
137 | filename=f"checkpoint_{(curr_iter + 1) // 1000}k.pth",
138 | )
139 |
140 |             ### evaluation: every `eval_every_iteration` iterations after `start_eval_after`, plus the final iteration
141 |
142 | do_eval_flag = (curr_iter + 1) % args.eval_every_iteration == 0
143 | do_eval_flag &= (curr_iter + 1) > args.start_eval_after
144 | do_eval_flag |= (curr_iter + 1) == max_iters
145 |
146 | if do_eval_flag is True:
147 | eval_metrics = {}
148 | model.eval()
149 | with accelerator.autocast():
150 | for test_loader in dataloaders['test']:
151 | task_metrics, eval_loss_dict = test_loader.dataset.eval_func(
152 | args,
153 | curr_epoch,
154 | accelerator.unwrap_model(model),
155 | accelerator,
156 | test_loader,
157 | logger,
158 | curr_train_iter=curr_iter
159 | )
160 | eval_metrics.update(task_metrics)
161 | logger.log_scalars(eval_loss_dict, prefix='val', step=curr_iter)
162 | model.train()
163 |
164 |                 ### saving `checkpoint_best.pth`; do nothing if no criterion is specified
165 |
166 | if args.criterion is None:
167 | continue
168 |
169 | if not best_val_metrics or (
170 | best_val_metrics[args.criterion] < eval_metrics[args.criterion]
171 | ):
172 | best_val_metrics = eval_metrics
173 | filename = "checkpoint_best.pth"
174 | save_checkpoint(
175 | args.checkpoint_dir,
176 | accelerator.unwrap_model(model),
177 | optimizer,
178 | curr_epoch,
179 | args,
180 | best_val_metrics,
181 | filename="checkpoint_best.pth",
182 | )
183 | if accelerator.is_main_process:
184 | logger.log_messages(
185 | f"Epoch [{curr_epoch}/{args.max_epoch}] "
186 | f"saved current best val checkpoint at {filename}; "
187 | f"{args.criterion} {eval_metrics[args.criterion]}"
188 | )
189 |
190 | ### end of an iteration
191 |
192 | ### end of an epoch
193 |
194 | save_checkpoint(
195 | args.checkpoint_dir,
196 | accelerator.unwrap_model(model),
197 | optimizer,
198 | curr_epoch,
199 | args,
200 | best_val_metrics,
201 | filename="checkpoint.pth",
202 | )
203 |
204 | # end of training
205 |
206 | return
207 |
--------------------------------------------------------------------------------
/eval_utils/perplexity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import numpy as np
5 | from torch import nn, Tensor
6 | from collections import defaultdict, OrderedDict
7 | from utils.ply_helper import write_ply
8 | from utils.misc import SmoothedValue
9 | from accelerate.utils import set_seed
10 |
11 |
12 |
13 | def perplexity(neg_log_likelihood: list) -> Tensor:
14 |     # gather per-sequence negative log likelihood; perplexity = exp(mean NLL)
15 | nll_chunk = torch.cat(neg_log_likelihood, dim=0)
16 | return torch.exp(nll_chunk.mean())
17 |
18 |
19 |
20 | def post_process_mesh(mesh_coords: Tensor, filename: str):
21 | mesh_coords = mesh_coords[~torch.isnan(mesh_coords[:, 0, 0])] # nvalid_face x 3 x 3
22 | vertices = mesh_coords.reshape(-1, 3)
23 |     vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 * nface - 1
24 | triangles = vertices_index.reshape(-1, 3)
25 | write_ply(
26 | np.asarray(vertices.cpu()),
27 | None,
28 | np.asarray(triangles),
29 | filename
30 | )
31 | return vertices
32 |
33 |
34 |
35 | @torch.no_grad()
36 | def evaluate(
37 | args,
38 | curr_epoch,
39 | model,
40 | accelerator,
41 | dataset_loader,
42 | logger,
43 | curr_train_iter=-1,
44 | ):
45 |
46 | model.eval()
47 | net_device = next(model.parameters()).device
48 | num_batches = len(dataset_loader)
49 |
50 | ### parse evaluation status
51 | if hasattr(dataset_loader.dataset, "dataset_name"):
52 | dataset_name = dataset_loader.dataset.dataset_name
53 | else:
54 | dataset_name = "default"
55 | task_name_prefix = dataset_name + '_'
56 |
57 | time_delta = SmoothedValue(window_size=10)
58 |
59 | accelerator.wait_for_everyone()
60 |
61 | epoch_str = f"[{curr_epoch}/{args.max_epoch}]" if curr_epoch > 0 else ""
62 |
63 | if accelerator.is_main_process:
64 | logger.log_messages("==" * 10)
65 | logger.log_messages(f"Evaluate Epoch [{curr_epoch}/{args.max_epoch}]")
66 | logger.log_messages("==" * 10)
67 |
68 | ### calculate perplexity
69 | neg_log_likelihood = []
70 | for curr_iter, batch_data_label in enumerate(dataset_loader):
71 |
72 | curr_time = time.time()
73 |
74 | # forward pass to calculate per-sequence negative log likelihood
75 | with accelerator.autocast():
76 | outputs = model(batch_data_label, is_eval=True)
77 | # [(batch,), (batch,), ...]
78 | neg_log_likelihood.append(outputs['neg_log_likelihood'])
79 |
80 | ### log status
81 | time_delta.update(time.time() - curr_time)
82 |
83 | if accelerator.is_main_process and curr_iter % args.log_every == 0:
84 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
85 | moving_average_ppl = perplexity(neg_log_likelihood)
86 | logger.log_messages(
87 | '; '.join(
88 | (
89 | f"Evaluate {epoch_str}",
90 | f"Batch [{curr_iter}/{num_batches}]",
91 | f"perplexity: {moving_average_ppl:0.4f}",
92 | f"Evaluating on iter: {curr_train_iter}",
93 | f"Iter time {time_delta.avg:0.2f}",
94 | f"Mem {mem_mb:0.2f}MB",
95 | )
96 | )
97 | )
98 |
99 | ### end of an iteration
100 |
101 | ### end of a round
102 |
103 | quantitative = {
104 | task_name_prefix + 'ppl': perplexity(neg_log_likelihood).item()
105 | }
106 |
107 | ### do sampling every evaluation
108 | curr_time = time.time()
109 |
110 | set_seed(accelerator.process_index)
111 |
112 | # create storage directory
113 | storage_dir = os.path.join(args.checkpoint_dir, task_name_prefix + 'visualization')
114 | if accelerator.is_main_process:
115 | os.makedirs(storage_dir, exist_ok = True)
116 | accelerator.wait_for_everyone()
117 |
118 | # just sample one round for checking training status
119 | for round_idx in range(1):
120 | outputs = model(
121 | data_dict=dict(),
122 | is_eval=True,
123 | is_generate=True,
124 | num_return_sequences=args.batchsize_per_gpu,
125 | )
126 |
127 | generated_meshes = outputs["recon_faces"] # nsample x nf x 3 x 3
128 |
129 | for sample_idx in range(args.batchsize_per_gpu):
130 | # store the generated meshes
131 | post_process_mesh(
132 | generated_meshes[sample_idx],
133 | os.path.join(
134 | storage_dir,
135 | '_'.join(
136 | (
137 | f'{accelerator.process_index:04d}',
138 | f'{round_idx:04d}',
139 | f'{sample_idx:04d}.ply',
140 | )
141 | )
142 | )
143 | )
144 |
145 | accelerator.wait_for_everyone()
146 |
147 | return {}, quantitative
--------------------------------------------------------------------------------
/eval_utils/sample_generation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import tqdm
4 | import torch
5 | import numpy as np
6 | from torch import Tensor
7 | from collections import defaultdict, OrderedDict
8 | from utils.ply_helper import write_ply
9 | from utils.misc import SmoothedValue
10 | from accelerate.utils import set_seed
11 |
12 |
13 | def process_mesh(mesh_coords: Tensor, filename: str):
14 | mesh_coords = mesh_coords[~torch.isnan(mesh_coords[:, 0, 0])] # nvalid_face x 3 x 3
15 | vertices = mesh_coords.reshape(-1, 3)
16 |     vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 * nface - 1
17 | triangles = vertices_index.reshape(-1, 3)
18 | write_ply(
19 | np.asarray(vertices.cpu()),
20 | None,
21 | np.asarray(triangles),
22 | filename
23 | )
24 | return vertices
25 |
26 |
27 | @torch.no_grad()
28 | def evaluate(
29 | args,
30 | curr_epoch,
31 | model,
32 | accelerator,
33 | dataset_loader,
34 | logout=print,
35 | curr_train_iter=-1,
36 | ):
37 |
38 | net_device = next(model.parameters()).device
39 | num_batches = len(dataset_loader)
40 |
41 | time_delta = SmoothedValue(window_size=10)
42 |
43 | storage_dir = os.path.join(args.checkpoint_dir, 'sampled')
44 | if accelerator.is_main_process:
45 | os.makedirs(storage_dir, exist_ok = True)
46 |
47 | model.eval()
48 | accelerator.wait_for_everyone()
49 |
50 | # do sampling
51 | curr_time = time.time()
52 |
53 | set_seed(accelerator.process_index)
54 |
55 | for sample_round in tqdm.tqdm(range(args.sample_rounds)):
56 |
57 | outputs = model(None, num_return_sequences=args.batchsize_per_gpu, is_eval=True, is_generate=True)
58 |
59 | batch_size = outputs['recon_faces'].shape[0]
60 | generated_faces = outputs["recon_faces"]
61 |
62 | for batch_id in range(batch_size):
63 | process_info = f'{accelerator.process_index:04d}'
64 | sample_info = f'{sample_round:04d}'
65 | batch_sample_info = f'{batch_id:04d}'
66 | sample_id = '_'.join(
67 | [
68 | process_info,
69 | sample_info,
70 | batch_sample_info
71 | ]
72 | )
73 | process_mesh(
74 | generated_faces[batch_id],
75 | os.path.join(storage_dir, f'{sample_id}_generated.ply')
76 | )
77 |
78 |         # update timing stats and synchronize all ranks after each sampling round
79 | time_delta.update(time.time() - curr_time)
80 | accelerator.wait_for_everyone()
81 |
82 | return {}, {}
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os, argparse, importlib
2 | import numpy as np
3 | import torch
4 |
5 | from engine import do_train
6 | from accelerate import Accelerator
7 | from accelerate.utils import set_seed
8 | from utils.io import resume_if_possible
9 | from utils.misc import my_worker_init_fn
10 | from utils.logger import Logger
11 |
12 |
13 | def make_args_parser():
14 |
15 | parser = argparse.ArgumentParser(
16 | "MeshXL: Neural Coordinate Field for Generative 3D Foundation Models",
17 | add_help=False
18 | )
19 |
20 | ##### Optimizer #####
21 | parser.add_argument("--base_lr", default=1e-4, type=float)
22 | parser.add_argument("--final_lr", default=1e-6, type=float)
23 | parser.add_argument("--weight_decay", default=0.1, type=float)
24 | parser.add_argument(
25 | "--clip_gradient", default=0.1, type=float,
26 | help="Max L2 norm of the gradient"
27 | )
28 | parser.add_argument("--warm_lr", default=1e-6, type=float)
29 | parser.add_argument("--warm_lr_iters", default=1000, type=int)
30 |
31 | ##### Dataset Setups #####
32 | parser.add_argument("--pad_id", default=-1, type=int, help="padding id")
33 | parser.add_argument("--dataset", default='shapenet_chair', help="dataset list split by ','")
34 | parser.add_argument("--augment", default=False, action='store_true', help="whether use data augmentation")
35 | parser.add_argument("--n_discrete_size", default=128, type=int, help="discretized 3D space")
36 | parser.add_argument("--n_max_triangles", default=800, type=int, help="max number of triangles")
37 |
38 | ##### Model Setups #####
39 | parser.add_argument(
40 | '--model',
41 | default=None,
42 | type=str,
43 | help="The model folder: unconditional / conditional mesh generation"
44 | )
45 | parser.add_argument(
46 | '--llm',
47 | default=None,
48 | type=str,
49 | help="The LLM super config and pre-trained weights"
50 | )
51 |     # conditional mesh generation, set to None for unconditional generation
52 | parser.add_argument('--text_condition', default=None, type=str, help="the conditional language model")
53 | parser.add_argument('--image_condition', default=None, type=str, help="the conditional vision model")
54 | parser.add_argument('--pretrained_weights', default=None, type=str, help='checkpoint to pre-trained weights')
55 | parser.add_argument("--dataset_num_workers", default=4, type=int, help='number of workers for dataloader')
56 | parser.add_argument("--batchsize_per_gpu", default=8, type=int, help='batch size for each GPU')
57 |
58 | ##### Training #####
59 | parser.add_argument("--start_epoch", default=-1, type=int, help='overwrite by pre-trained weights')
60 | parser.add_argument("--max_epoch", default=16, type=int, help='number of traversals for the dataset')
61 | parser.add_argument("--start_eval_after", default=-1, type=int, help='do not evaluate the model before xxx iterations')
62 | parser.add_argument("--eval_every_iteration", default=4000, type=int, help='do evaluate the model every xxx iterations')
63 | parser.add_argument("--seed", default=0, type=int, help='random seed')
64 |
65 | ##### Testing #####
66 | parser.add_argument("--test_only", default=False, action="store_true")
67 | parser.add_argument("--sample_rounds", default=100, type=int, help='do sample for xxx rounds to produce 3D meshes')
68 |
69 | parser.add_argument(
70 | "--criterion", default=None, type=str,
71 | help='metrics for saving the best model, set to None for not saving any'
72 | )
73 | parser.add_argument("--test_ckpt", default="", type=str, help='test checkpoint directory')
74 |
75 | ##### I/O #####
76 | parser.add_argument("--checkpoint_dir", default=None, type=str, help='path to save the checkpoints and visualization samples')
77 | parser.add_argument("--save_every", default=20000, type=int, help='save checkpoints every xxx iterations')
78 | parser.add_argument("--log_every", default=10, type=int, help='write training logs every xxx iterations')
79 |
80 | args = parser.parse_args()
81 |
82 | return args
83 |
84 |
85 | def build_dataloader_func(args, dataset, split):
86 |
87 | if split == "train":
88 | sampler = torch.utils.data.RandomSampler(dataset)
89 | else:
90 | sampler = torch.utils.data.SequentialSampler(dataset)
91 |
92 | dataloader = torch.utils.data.DataLoader(
93 | dataset,
94 | sampler=sampler,
95 | batch_size=args.batchsize_per_gpu,
96 | num_workers=args.dataset_num_workers,
97 | worker_init_fn=my_worker_init_fn,
98 | # add for meshgpt
99 | drop_last = True,
100 | )
101 | return sampler, dataloader
102 |
103 |
104 | def build_dataset_func(args):
105 |
106 | datasets = {
107 | 'train': [],
108 | 'test': []
109 | }
110 |
111 | for dataset in args.dataset.split(','):
112 | dataset_module = importlib.import_module(f'datasets.{dataset}')
113 | datasets['train'].append(
114 | dataset_module.Dataset(args, split_set="train", augment=args.augment)
115 | )
116 | datasets['test'].append(
117 | dataset_module.Dataset(args, split_set="val", augment=False)
118 | )
119 | datasets['train'] = torch.utils.data.ConcatDataset(datasets['train'])
120 |
121 | train_sampler, train_loader = build_dataloader_func(args, datasets['train'], split='train')
122 | dataloaders = {
123 | 'train': train_loader,
124 | 'test': [],
125 | 'train_sampler': train_sampler,
126 | }
127 |
128 | for dataset in datasets['test']:
129 | _, test_loader = build_dataloader_func(args, dataset, split='test')
130 | dataloaders['test'].append(test_loader)
131 |
132 | return datasets, dataloaders
133 |
134 |
135 | def build_model_func(args):
136 | model_module = importlib.import_module(f'models.{args.model}.get_model')
137 | model = model_module.get_model(args)
138 | return model
139 |
140 |
141 | def main(args):
142 |
143 | np.random.seed(args.seed)
144 | torch.cuda.manual_seed_all(args.seed)
145 |
146 | if args.checkpoint_dir is not None:
147 | pass
148 |     elif args.test_ckpt:
149 |         # if checkpoint_dir is not given, default to the test checkpoint's folder
150 | args.checkpoint_dir = os.path.dirname(args.test_ckpt)
151 | print(f'testing directory: {args.checkpoint_dir}')
152 | else:
153 | raise AssertionError(
154 | 'Either checkpoint_dir or test_ckpt should be presented!'
155 | )
156 |
157 | os.makedirs(args.checkpoint_dir, exist_ok=True)
158 | accelerator = Accelerator()
159 | set_seed(args.seed)
160 |
161 | ### build datasets and dataloaders
162 | datasets, dataloaders = build_dataset_func(args)
163 |
164 | ### build models
165 | model = build_model_func(args)
166 | ### set default checkpoint
167 | checkpoint = None
168 |
169 | # testing phase
170 | if args.test_only:
171 |
172 | try:
173 | checkpoint = torch.load(args.test_ckpt, map_location=torch.device("cpu"))
174 | model.load_state_dict(checkpoint["model"], strict=False)
175 | except:
176 |             print('failed to load the checkpoint, testing the model from scratch...')
177 |
178 | model, dataloaders['train'], *dataloaders['test'] = accelerator.prepare(
179 | model, dataloaders['train'], *dataloaders['test']
180 | )
181 |
182 | for test_loader in dataloaders['test']:
183 | test_loader.dataset.eval_func(
184 | args,
185 | -1,
186 | model,
187 | accelerator,
188 | test_loader
189 | )
190 |
191 | # training phase
192 | else:
193 |
194 | assert (
195 | args.checkpoint_dir is not None
196 | ), "`--checkpoint_dir` is required to identify the directory to store the checkpoint"
197 | os.makedirs(args.checkpoint_dir, exist_ok=True)
198 |
199 | logger = Logger(args.checkpoint_dir, accelerator)
200 |
201 | ### whether or not use pretrained weights
202 | if args.pretrained_weights is not None:
203 | checkpoint = torch.load(args.pretrained_weights, map_location="cpu")
204 | model.load_state_dict(checkpoint['model'], strict=False)
205 |
206 | if accelerator.is_main_process:
207 | if checkpoint is not None:
208 | logger.log_messages('Loaded the parameters for weight initialization:')
209 | for name, param in checkpoint['model'].items():
210 | logger.log_messages('\t'.join(['', name + ':', f'{param.shape}']))
211 | logger.log_messages('\n' * 10)
212 | logger.log_messages('====\n')
213 | logger.log_messages('The trainable parameters are:')
214 |
215 | for name, param in model.named_parameters():
216 | status = '[train]' if param.requires_grad else '[eval]'
217 | logger.log_messages('\t'.join(['', status, name + ':', f'{param.shape}']))
218 |
219 | optimizer = torch.optim.AdamW(
220 | filter(lambda params: params.requires_grad, model.parameters()),
221 | lr=args.base_lr,
222 | weight_decay=args.weight_decay
223 | )
224 |
225 | loaded_epoch, best_val_metrics = resume_if_possible(
226 | args.checkpoint_dir, model, optimizer
227 | )
228 | args.start_epoch = loaded_epoch + 1
229 |
230 | model, optimizer, dataloaders['train'], *dataloaders['test'] = accelerator.prepare(
231 | model, optimizer, dataloaders['train'], *dataloaders['test']
232 | )
233 |
234 | do_train(
235 | args,
236 | model,
237 | accelerator,
238 | optimizer,
239 | dataloaders,
240 | best_val_metrics,
241 | logger
242 | )
243 |
244 |
245 | if __name__ == "__main__":
246 | args = make_args_parser()
247 |
248 | os.environ['PYTHONWARNINGS']='ignore:semaphore_tracker:UserWarning'
249 |
250 | main(args)
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-1.3b/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "mesh-xl/mesh-xl-1.3b",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 8192,
14 | "hidden_size": 2048,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 32,
20 | "num_hidden_layers": 24,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "float16",
24 | "transformers_version": "4.21.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 50272,
27 | "word_embed_proj_dim": 2048
28 | }
29 |
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-1.3b/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 2,
4 | "eos_token_id": 2,
5 | "pad_token_id": 1,
6 | "transformers_version": "4.27.0.dev0"
7 | }
8 |
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-1.3b/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/mesh-xl-1.3b/pytorch_model.bin
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-125m/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "mesh-xl/mesh-xl-125m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 3072,
14 | "hidden_size": 768,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 12,
20 | "num_hidden_layers": 12,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "float16",
24 | "transformers_version": "4.21.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 50272,
27 | "word_embed_proj_dim": 768
28 | }
29 |
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-125m/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 2,
4 | "eos_token_id": 2,
5 | "pad_token_id": 1,
6 | "transformers_version": "4.27.0.dev0"
7 | }
8 |
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-125m/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/mesh-xl-125m/pytorch_model.bin
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-350m/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "mesh-xl/mesh-xl-350m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 4096,
14 | "hidden_size": 1024,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 16,
20 | "num_hidden_layers": 24,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "float16",
24 | "transformers_version": "4.20.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 50272,
27 | "word_embed_proj_dim": 1024
28 | }
29 |
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-350m/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 2,
4 | "eos_token_id": 2,
5 | "pad_token_id": 1,
6 | "transformers_version": "4.27.0.dev0"
7 | }
8 |
--------------------------------------------------------------------------------
/mesh-xl/mesh-xl-350m/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/mesh-xl-350m/pytorch_model.bin
--------------------------------------------------------------------------------
/mesh-xl/x-mesh-xl-350m/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "mesh-xl/mesh-xl-350m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 4096,
14 | "hidden_size": 1024,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 16,
20 | "num_hidden_layers": 24,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "float16",
24 | "transformers_version": "4.20.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 50272,
27 | "word_embed_proj_dim": 1024
28 | }
29 |
--------------------------------------------------------------------------------
/mesh-xl/x-mesh-xl-350m/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 2,
4 | "eos_token_id": 2,
5 | "pad_token_id": 1,
6 | "transformers_version": "4.27.0.dev0"
7 | }
8 |
--------------------------------------------------------------------------------
/mesh-xl/x-mesh-xl-350m/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMeshLab/MeshXL/1bfbf568ee188291f7d0bf021506b90d6f282c21/mesh-xl/x-mesh-xl-350m/pytorch_model.bin
--------------------------------------------------------------------------------
/models/mesh_xl/get_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as nnf
3 | from torch import nn, Tensor
4 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
5 | from models.mesh_xl.tokenizer import MeshTokenizer
6 | from typing import Dict
7 |
8 |
9 | class MeshXL(nn.Module):
10 |
11 | def train(self, mode: bool = True):
12 | super().train(mode)
13 | return self
14 |
15 | def __init__(self, args):
16 | super().__init__()
17 |
18 | self.tokenizer = MeshTokenizer(args)
19 |
20 | # causal LM model initialization
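        # vocab layout: coordinate tokens occupy ids [0, codebook_size); the three
        # extra ids appended after them serve as <bos>, <eos>, and <pad>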
21 | self.vocab_size = self.tokenizer.codebook_size + 3
22 | self.bos_token_id = self.tokenizer.codebook_size
23 | self.eos_token_id = self.tokenizer.codebook_size + 1
24 | self.pad_token_id = self.tokenizer.codebook_size + 2
25 |
26 | config = AutoConfig.from_pretrained(
27 | args.llm,
28 | n_positions=8192,
29 | max_position_embeddings=8192,
30 | vocab_size=self.vocab_size,
31 | bos_token_id=self.bos_token_id,
32 | eos_token_id=self.eos_token_id,
33 | pad_token_id=self.pad_token_id
34 | )
35 |
36 | config.word_embed_proj_dim = config.hidden_size
37 | self.transformer = AutoModelForCausalLM.from_pretrained(
38 | args.llm,
39 | config=config,
40 | ignore_mismatched_sizes=True
41 | )
42 | self.transformer.to_bettertransformer()
43 |
44 | # setting status for all parameters
45 | self.train()
46 |
47 |
48 | def forward(
49 | self,
50 | data_dict: dict=None,
51 | is_eval: bool=False,
52 | is_generate: bool=False,
53 | num_return_sequences: int=8,
54 | generation_config: Dict=dict(
55 | do_sample=True,
56 | top_k=50,
57 | top_p=0.95,
58 | # no_repeat_ngram_size=9,
59 | )
60 | ) -> dict:
61 |
62 | if not is_eval:
63 | return self.train_one_step(data_dict)
64 |
65 | if is_eval and not is_generate:
66 | return self.perplexity(data_dict)
67 |
68 | if is_eval and is_generate:
69 | return self.generate(
70 | data_dict=data_dict,
71 | num_return_sequences=num_return_sequences,
72 | generation_config=generation_config
73 | )
74 |
75 | raise NotImplementedError('training status undefined!')
76 | return
77 |
78 |
79 | def loss_wrapper(self, loss: Tensor) -> Tensor:
80 | # parameter activation: an l2 term with zero weight; it keeps every parameter in the autograd graph without changing the loss
81 | for param in self.parameters():
82 | loss += 0 * torch.sum(param ** 2)
83 | return loss
84 |
85 |
86 | def train_one_step(self, data_dict: dict) -> dict:
87 |
88 | data_dict = self.tokenizer.tokenize(data_dict)
89 |
90 | input_ids = data_dict['input_ids'] # batch x ntoken
91 | attention_mask = data_dict['attention_mask'] # batch x ntoken
92 |
93 | # parse input with <bos> and <eos> tokens
94 | input_ids[input_ids == self.tokenizer.pad_id] = self.pad_token_id # tokenizer pads -> <pad>
95 | input_ids[:, 0] = self.bos_token_id # first reserved slot -> <bos>
96 | eos_pos_id = attention_mask.sum(1, keepdim=True) - 1
97 | input_ids = torch.scatter( # write <eos> at the last attended position
98 | input_ids,
99 | 1,
100 | eos_pos_id.long(),
101 | torch.ones_like(input_ids) * self.eos_token_id
102 | )
103 |
104 | target = input_ids.clone()
105 | target[attention_mask == 0] = -100 # no loss for the padding tokens
106 |
107 | # Forward pass, calling the causal LLM with BetterTransformer.
108 | output = self.transformer(
109 | input_ids = input_ids.long(),
110 | )
111 |
112 | # compute loss with a one-token right shift
113 | logit = output.logits[:, :-1] # batch x (ntoken - 1) x vocab
114 | label = target[:, 1:] # batch x (ntoken - 1)
115 |
116 | final_loss = nnf.cross_entropy(
117 | logit.permute(0, 2, 1), # batch x vocab x (ntoken - 1)
118 | label,
119 | ) # scalar: mean over non-padding tokens
120 |
121 | data_dict['loss'] = self.loss_wrapper(final_loss)
122 | data_dict['gen_loss'] = final_loss
123 |
124 | return data_dict
125 |
126 |
127 | @torch.no_grad()
128 | def perplexity(self, data_dict: dict) -> dict:
129 |
130 | data_dict = self.tokenizer.tokenize(data_dict)
131 |
132 | input_ids = data_dict['input_ids'] # batch x ntoken
133 | attention_mask = data_dict['attention_mask'] # batch x ntoken
134 |
135 | # parse input with <bos> and <eos> tokens, as in training
136 | input_ids[input_ids == self.tokenizer.pad_id] = self.pad_token_id # tokenizer pads -> <pad>
137 | input_ids[:, 0] = self.bos_token_id # first reserved slot -> <bos>
138 | eos_pos_id = attention_mask.sum(1, keepdim=True) - 1
139 | input_ids = torch.scatter( # write <eos> at the last attended position
140 | input_ids,
141 | 1,
142 | eos_pos_id.long(),
143 | torch.ones_like(input_ids) * self.eos_token_id
144 | )
145 |
146 | # llm loss calculation
147 | output = self.transformer(
148 | input_ids = input_ids.long(),
149 | )
150 |
151 | # compute per-token loss with a one-token right shift
152 | logit = output.logits[:, :-1] # batch x (ntoken - 1) x vocab
153 | label = input_ids[:, 1:] # batch x (ntoken - 1)
154 | masks = attention_mask[:, 1:] # batch x (ntoken - 1)
155 | loss_per_token = nnf.cross_entropy(
156 | logit.permute(0, 2, 1), # batch x vocab x (ntoken - 1)
157 | label, # batch x (ntoken - 1)
158 | reduction='none'
159 | ) # batch x (ntoken - 1)
160 |
161 | # mean negative log-likelihood per attended token for each sequence
162 | neg_log_likelihood = torch.sum(loss_per_token * masks, dim=1) / torch.sum(masks, dim=1)
163 |
164 | data_dict['neg_log_likelihood'] = neg_log_likelihood # batch,
165 | return data_dict
166 |
167 |
168 |
169 | @torch.no_grad()
170 | def generate(self, data_dict: dict=None, num_return_sequences: int=8, generation_config: dict=dict()) -> dict:
171 |
172 | net_device = next(self.parameters()).device
173 | max_length = 8192
174 | output_ids = torch.ones(num_return_sequences, max_length).long().to(net_device) * self.eos_token_id
175 |
176 | # batch x ntokens
177 | results = self.transformer.generate(
178 | max_new_tokens=max_length-1,
179 | num_return_sequences=num_return_sequences,
180 | bos_token_id=self.bos_token_id,
181 | eos_token_id=self.eos_token_id,
182 | pad_token_id=self.eos_token_id,
183 | **generation_config
184 | )
185 | output_ids[:, :results.shape[1]] = results
186 |
187 | # discard <bos> and <eos> tokens: strip both ends, then map <eos> to pad tokens
188 | output_ids = output_ids[:, 1: -1]
189 | output_ids[output_ids == self.eos_token_id] = self.tokenizer.pad_id
190 |
191 | decoder_output = self.tokenizer.detokenize(input_ids=output_ids)
192 |
193 | return decoder_output
194 |
195 |
196 |
197 | def get_model(args):
198 | model = MeshXL(args)
199 | return model
--------------------------------------------------------------------------------
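
A minimal, self-contained sketch of the `<bos>`/`<eos>` placement and the shifted cross-entropy used in `train_one_step` above; the token ids, sequence length, and random logits are illustrative stand-ins, not values from a real mesh.

```python
import torch
import torch.nn.functional as nnf

codebook_size = 128
bos_id, eos_id, pad_id = codebook_size, codebook_size + 1, codebook_size + 2

# hypothetical tokenized mesh: the tokenizer reserved two pad slots (front/back)
input_ids = torch.tensor([[-1, 3, 7, 9, -1, -1]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]]).float()

# same placement logic as MeshXL.train_one_step
input_ids[input_ids == -1] = pad_id                 # tokenizer pads -> <pad>
input_ids[:, 0] = bos_id                            # first reserved slot -> <bos>
eos_pos = attention_mask.sum(1, keepdim=True) - 1   # last attended position
input_ids = torch.scatter(input_ids, 1, eos_pos.long(),
                          torch.ones_like(input_ids) * eos_id)
print(input_ids)  # -> [[128, 3, 7, 9, 129, 130]]

# loss target: ignore unattended positions, shift labels one step right
target = input_ids.clone()
target[attention_mask == 0] = -100
logits = torch.randn(1, 6, codebook_size + 3)       # stand-in for model output
loss = nnf.cross_entropy(logits[:, :-1].permute(0, 2, 1), target[:, 1:])
```
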
/models/mesh_xl/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from typing import Tuple
4 | from einops import rearrange, repeat, reduce
5 |
6 |
7 |
8 | def discretize(
9 | t: Tensor,
10 | continuous_range: Tuple[float, float],
11 | num_discrete: int = 128
12 | ) -> Tensor:
13 |
14 | lo, hi = continuous_range
15 | assert hi > lo
16 | t = (t - lo) / (hi - lo) # cube normalize
17 | t *= num_discrete
18 | t -= 0.5
19 | return t.round().long().clamp(min = 0, max = num_discrete - 1)
20 |
21 |
22 |
23 | def undiscretize(
24 | t: Tensor,
25 | continuous_range: Tuple[float, float],
26 | num_discrete: int = 128
27 | ) -> Tensor:
28 | lo, hi = continuous_range
29 | assert hi > lo
30 | t = t.float()
31 | t += 0.5
32 | t /= num_discrete # cube normalize
33 | return t * (hi - lo) + lo
34 |
35 |
36 |
37 | class MeshTokenizer(nn.Module):
38 |
39 | def __init__(self, args):
40 | super().__init__()
41 | self.pad_id = -1
42 | self.num_discrete_coors = args.n_discrete_size # default: 128
43 | self.codebook_size = args.n_discrete_size # default: 128
44 | self.coor_continuous_range = (-1., 1.)
45 |
46 |
47 | def tokenize(self, data_dict: dict) -> dict:
48 | '''
49 | Turn 3D meshes into sequential tokens: [<x>, <y>, <z>], ...
50 | '''
51 |
52 | ### 3D mesh face parsing
53 | vertices = data_dict['vertices'] # batch x nv x 3
54 | faces = data_dict['faces'] # batch x nf x 3
55 | face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all') # batch x nf
56 |
57 | batch, num_vertices, num_coors = vertices.shape
58 | _, num_faces, _ = faces.shape
59 |
60 | # fill padding tokens with 0, to prevent gather idx error
61 | face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0)
62 |
63 | # collect vertex coordinates per face: b x nf x nv x c
64 | faces_vertices = repeat(face_without_pad, 'b nf nv -> b nf nv c', c = num_coors)
65 | vertices = repeat(vertices, 'b nv c -> b nf nv c', nf = num_faces)
66 | face_coords = vertices.gather(-2, faces_vertices.long())
67 |
68 | # continuous to discrete face coords: b x nf x nv x c
69 | discrete_face_coords = discretize(
70 | face_coords,
71 | continuous_range=self.coor_continuous_range,
72 | num_discrete=self.num_discrete_coors
73 | )
74 |
75 | # pad invalid faces with <pad_id>: batch x nf x nv x c
76 | discrete_padded_coords = discrete_face_coords.masked_fill(
77 | ~rearrange(face_mask, 'b nf -> b nf 1 1'),
78 | self.pad_id
79 | )
80 |
81 |
82 | ### mesh to sequence conversion: batch x ntokens
83 | input_ids = discrete_padded_coords.reshape(batch, -1)
84 | attention_mask = (input_ids != self.pad_id).float()
85 | # reserve two spots for the <bos> and <eos> tokens (filled in later by the language model):
86 | # input_ids: <pad> <x> <y> <z> ... <pad> ... <pad>
87 | # attn_mask: 1 1 1 ... 1 0 ...  (two extra 1s cover the <bos>/<eos> slots)
88 | place_holder = torch.ones_like(input_ids[:, [0]]) # batch x 1
89 | input_ids = torch.cat((place_holder * self.pad_id, input_ids, place_holder * self.pad_id), dim=1)
90 | attention_mask = torch.cat((place_holder, place_holder, attention_mask), dim=1)
91 |
92 | ### meshXL inputs
93 | data_dict['input_ids'] = input_ids.long() # batch x (nf * 3 * 3 + 2)
94 | data_dict['attention_mask'] = attention_mask.float() # batch x (nf * 3 * 3 + 2)
95 |
96 | # discard <bos> and <eos> tokens
97 | data_dict['codes'] = discrete_padded_coords.long() # batch x (nf * 3 * 3)
98 | data_dict['discrete_face_coords'] = discrete_face_coords
99 |
100 | return data_dict
101 |
102 |
103 | def detokenize(self, input_ids: Tensor) -> dict:
104 | '''
105 | Turn sequential tokens: [<x>, <y>, <z>], ... into 3D meshes
106 | '''
107 | # input_ids: b (n q) or b n q, without <bos> or <eos>
108 | input_ids = input_ids.reshape(input_ids.shape[0], -1)
109 | # batch x nface
110 | face_mask = reduce(
111 | input_ids != self.pad_id, 'b (nf c) -> b nf', 'all', c = 9
112 | )
113 |
114 | # batch x (nface x 9) -> batch x nface x 3 x 3
115 | pred_face_coords = input_ids.reshape(input_ids.shape[0], -1, 9)
116 | pred_face_coords = rearrange(
117 | pred_face_coords, '... (v c) -> ... v c', v = 3
118 | )
119 |
120 | # back to continuous space
121 | continuous_coors = undiscretize(
122 | pred_face_coords,
123 | num_discrete = self.num_discrete_coors,
124 | continuous_range = self.coor_continuous_range
125 | )
126 | # mask padding coordinates out with nan
127 | continuous_coors = continuous_coors.masked_fill(
128 | ~rearrange(face_mask, 'b nf -> b nf 1 1'),
129 | float('nan')
130 | )
131 | output_dict = {}
132 | output_dict['recon_faces'] = continuous_coors
133 |
134 | return output_dict
135 |
136 |
137 | def forward(self, data_dict: dict) -> dict:
138 |
139 | encoder_output = self.tokenize(data_dict)
140 | decoder_output = self.detokenize(
141 | input_ids = encoder_output['codes'],
142 | )
143 | data_dict.update(encoder_output)
144 | data_dict.update(decoder_output)
145 | return data_dict
146 |
--------------------------------------------------------------------------------
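
A small round-trip check for the `discretize`/`undiscretize` pair above, which quantizes coordinates in (-1, 1) into `num_discrete` uniform bins and reconstructs bin centers. This assumes the repo root is on `PYTHONPATH`; the coordinate values are illustrative.

```python
import torch
from models.mesh_xl.tokenizer import discretize, undiscretize

coords = torch.tensor([-1.0, -0.3, 0.0, 0.7, 1.0])
ids = discretize(coords, continuous_range=(-1.0, 1.0), num_discrete=128)
recon = undiscretize(ids, continuous_range=(-1.0, 1.0), num_discrete=128)

# reconstruction lands on bin centers, so the round-trip error is at most
# half a bin width: (hi - lo) / num_discrete / 2 = 2 / 128 / 2
assert torch.all((coords - recon).abs() <= (2.0 / 128) / 2 + 1e-6)
```
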
/models/x_mesh_xl/get_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as nnf
3 | from torch import nn, Tensor
4 | from einops import repeat, rearrange
5 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, Blip2QFormerModel, Blip2QFormerConfig
6 | from models.x_mesh_xl.tokenizer import MeshTokenizer
7 | from typing import Dict
8 |
9 |
10 | class ConditionEncoder(nn.Module):
11 |
12 | def train(self, mode: bool = True):
13 | super().train(mode)
14 | self.multi_encoder.eval()
15 | for param in self.multi_encoder.parameters():
16 | param.requires_grad = False
17 | return self
18 |
19 | def __init__(self, args, hidden_size):
20 | super().__init__()
21 | self.n_learnable_queries = 32
22 | config = AutoConfig.from_pretrained(args.text_condition)
23 | self.multi_encoder = AutoModel.from_config(config)
24 | qformer_config = Blip2QFormerConfig(
25 | num_hidden_layers=12,
26 | encoder_hidden_size=self.multi_encoder.config.text_config.hidden_size
27 | )
28 | self.qformer = Blip2QFormerModel(qformer_config)
29 | self.query_embeds = nn.Embedding(self.n_learnable_queries, qformer_config.hidden_size)
30 | self.out_project = nn.Linear(qformer_config.hidden_size, hidden_size)
31 |
32 | @torch.no_grad()
33 | def encode_text(self, input_ids, attention_mask):
34 | text_encoder_output = self.multi_encoder.text_model(
35 | input_ids=input_ids,
36 | attention_mask=attention_mask
37 | )
38 | text_embeds = text_encoder_output.last_hidden_state
39 | return text_embeds # bs x ntoken x ch
40 |
41 | def forward(self, input_ids, attention_mask):
42 | net_device = next(self.parameters()).device
43 | text_embeds = self.encode_text(input_ids=input_ids, attention_mask=attention_mask)
44 | query_embeds = self.query_embeds(
45 | repeat(
46 | torch.arange(0, self.n_learnable_queries, dtype=torch.int64).to(net_device),
47 | 'src -> bs src',
48 | bs = text_embeds.shape[0]
49 | )
50 | )
51 | query_outputs = self.qformer(
52 | query_embeds=query_embeds,
53 | encoder_hidden_states=text_embeds,
54 | encoder_attention_mask=attention_mask
55 | )
56 | query_outputs = query_outputs[0][:, : self.n_learnable_queries, :]
57 | return self.out_project(query_outputs)
58 |
59 |
60 |
61 | class MeshXL(nn.Module):
62 |
63 | def train(self, mode: bool = True):
64 | super().train(mode)
65 | # self.transformer.eval()
66 | # for param in self.transformer.parameters():
67 | # param.requires_grad = False
68 | return self
69 |
70 | def __init__(self, args):
71 | super().__init__()
72 |
73 | self.tokenizer = MeshTokenizer(args)
74 |
75 | # causal LM model initialization
76 | self.vocab_size = self.tokenizer.codebook_size + 3
77 | self.bos_token_id = self.tokenizer.codebook_size
78 | self.eos_token_id = self.tokenizer.codebook_size + 1
79 | self.pad_token_id = self.tokenizer.codebook_size + 2
80 |
81 | config = AutoConfig.from_pretrained(
82 | args.llm,
83 | n_positions=8192,
84 | max_position_embeddings=8192,
85 | vocab_size=self.vocab_size,
86 | bos_token_id=self.bos_token_id,
87 | eos_token_id=self.eos_token_id,
88 | pad_token_id=self.pad_token_id
89 | )
90 |
91 | config.word_embed_proj_dim = config.hidden_size
92 | self.transformer = AutoModelForCausalLM.from_config(config=config)
93 |
94 | try:
95 | self.transformer.to_bettertransformer()
96 | except Exception:
97 | pass # BetterTransformer is optional; fall back to the vanilla attention implementation
98 |
99 | self.condition_encoder = ConditionEncoder(args, config.hidden_size)
100 |
101 | # setting status for all parameters
102 | self.train()
103 |
104 |
105 | def forward(
106 | self,
107 | data_dict: dict=None,
108 | is_eval: bool=False,
109 | is_generate: bool=False,
110 | num_return_sequences: int=8,
111 | generation_config: Dict=dict(
112 | do_sample=True,
113 | top_k=50,
114 | top_p=0.95,
115 | # no_repeat_ngram_size=9,
116 | )
117 | ) -> dict:
118 |
119 | data_dict['prefix_embeds'] = self.condition_encoder(
120 | input_ids = data_dict['text_input_ids'],
121 | attention_mask = data_dict['text_attention_mask']
122 | )
123 |
124 | if not is_eval:
125 | raise NotImplementedError
126 |
127 | if is_eval and not is_generate:
128 | raise NotImplementedError
129 |
130 | if is_eval and is_generate:
131 | return self.generate(
132 | data_dict=data_dict,
133 | num_return_sequences=num_return_sequences,
134 | generation_config=generation_config
135 | )
136 |
137 | raise NotImplementedError('training status undefined!')
138 | return
139 |
140 | @torch.no_grad()
141 | def generate(self, data_dict: dict=None, num_return_sequences: int=8, generation_config: dict=dict()) -> dict:
142 |
143 | net_device = next(self.parameters()).device
144 | max_length = 8191
145 | output_ids = torch.ones(num_return_sequences, max_length).long().to(net_device) * self.eos_token_id
146 |
147 | # batch x ntokens
148 | results = self.transformer.generate(
149 | inputs_embeds=data_dict['prefix_embeds'],
150 | max_length=max_length-1,
151 | num_return_sequences=num_return_sequences,
152 | bos_token_id=self.bos_token_id,
153 | eos_token_id=self.eos_token_id,
154 | pad_token_id=self.eos_token_id,
155 | **generation_config
156 | )
157 | output_ids[:, :results.shape[1]] = results
158 |
159 | # discard <eos> tokens to pad tokens (no <bos> is emitted when generating from inputs_embeds)
160 | output_ids = output_ids[:, :-1]
161 | output_ids[output_ids == self.eos_token_id] = self.tokenizer.pad_id
162 |
163 | decoder_output = self.tokenizer.detokenize(input_ids=output_ids)
164 |
165 | return decoder_output
166 |
167 |
168 |
169 | def get_model(args):
170 | model = MeshXL(args)
171 | return model
--------------------------------------------------------------------------------
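
For orientation: the conditioning path above compresses a tokenized prompt into 32 learnable query embeddings, which the causal LM consumes as an `inputs_embeds` prefix. Below is a shape-only sketch with randomly initialized weights; it assumes `openai/clip-vit-base-patch32` resolves to the config bundled in this repo (or the hub), and the tensors are illustrative.

```python
import argparse
import torch
from models.x_mesh_xl.get_model import ConditionEncoder

args = argparse.Namespace(text_condition='openai/clip-vit-base-patch32')
encoder = ConditionEncoder(args, hidden_size=1024)

input_ids = torch.randint(0, 49408, (2, 77))        # CLIP vocab, up to 77 tokens
attention_mask = torch.ones(2, 77, dtype=torch.long)
prefix = encoder(input_ids, attention_mask)
print(prefix.shape)  # torch.Size([2, 32, 1024]): 32 query embeddings per prompt
```
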
/models/x_mesh_xl/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from typing import Tuple
4 | from einops import rearrange, repeat, reduce
5 |
6 |
7 |
8 | def discretize(
9 | t: Tensor,
10 | continuous_range: Tuple[float, float],
11 | num_discrete: int = 128
12 | ) -> Tensor:
13 |
14 | lo, hi = continuous_range
15 | assert hi > lo
16 | t = (t - lo) / (hi - lo) # cube normalize
17 | t *= num_discrete
18 | t -= 0.5
19 | return t.round().long().clamp(min = 0, max = num_discrete - 1)
20 |
21 |
22 |
23 | def undiscretize(
24 | t: Tensor,
25 | continuous_range: Tuple[float, float],
26 | num_discrete: int = 128
27 | ) -> Tensor:
28 | lo, hi = continuous_range
29 | assert hi > lo
30 | t = t.float()
31 | t += 0.5
32 | t /= num_discrete # cube normalize
33 | return t * (hi - lo) + lo
34 |
35 |
36 |
37 | class MeshTokenizer(nn.Module):
38 |
39 | def __init__(self, args):
40 | super().__init__()
41 | self.pad_id = -1
42 | self.num_discrete_coors = args.n_discrete_size # default: 128
43 | self.codebook_size = args.n_discrete_size # default: 128
44 | self.coor_continuous_range = (-1., 1.)
45 |
46 |
47 | def tokenize(self, data_dict: dict) -> dict:
48 | '''
49 | Turn 3D meshes into sequential tokens: [<x>, <y>, <z>], ...
50 | '''
51 |
52 | ### 3D mesh face parsing
53 | vertices = data_dict['vertices'] # batch x nv x 3
54 | faces = data_dict['faces'] # batch x nf x 3
55 | face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all') # batch x nf
56 |
57 | batch, num_vertices, num_coors = vertices.shape
58 | _, num_faces, _ = faces.shape
59 |
60 | # fill padding tokens with 0, to prevent gather idx error
61 | face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0)
62 |
63 | # collect vertex coordinates per face: b x nf x nv x c
64 | faces_vertices = repeat(face_without_pad, 'b nf nv -> b nf nv c', c = num_coors)
65 | vertices = repeat(vertices, 'b nv c -> b nf nv c', nf = num_faces)
66 | face_coords = vertices.gather(-2, faces_vertices.long())
67 |
68 | # continuous to discrete face coords: b x nf x nv x c
69 | discrete_face_coords = discretize(
70 | face_coords,
71 | continuous_range=self.coor_continuous_range,
72 | num_discrete=self.num_discrete_coors
73 | )
74 |
75 | # pad invalid faces with <pad_id>: batch x nf x nv x c
76 | discrete_padded_coords = discrete_face_coords.masked_fill(
77 | ~rearrange(face_mask, 'b nf -> b nf 1 1'),
78 | self.pad_id
79 | )
80 |
81 |
82 | ### mesh to sequence conversion: batch x ntokens
83 | input_ids = discrete_padded_coords.reshape(batch, -1)
84 | attention_mask = (input_ids != self.pad_id).float()
85 | # reserve two spots for the <bos> and <eos> tokens (filled in later by the language model):
86 | # input_ids: <pad> <x> <y> <z> ... <pad> ... <pad>
87 | # attn_mask: 1 1 1 ... 1 0 ...  (two extra 1s cover the <bos>/<eos> slots)
88 | place_holder = torch.ones_like(input_ids[:, [0]]) # batch x 1
89 | input_ids = torch.cat((place_holder * self.pad_id, input_ids, place_holder * self.pad_id), dim=1)
90 | attention_mask = torch.cat((place_holder, place_holder, attention_mask), dim=1)
91 |
92 | ### meshXL inputs
93 | data_dict['input_ids'] = input_ids.long() # batch x (nf * 3 * 3 + 2)
94 | data_dict['attention_mask'] = attention_mask.float() # batch x (nf * 3 * 3 + 2)
95 |
96 | # discard <bos> and <eos> tokens
97 | data_dict['codes'] = discrete_padded_coords.long() # batch x (nf * 3 * 3)
98 | data_dict['discrete_face_coords'] = discrete_face_coords
99 |
100 | return data_dict
101 |
102 |
103 | def detokenize(self, input_ids: Tensor) -> dict:
104 | '''
105 | Turn sequential tokens: [<x>, <y>, <z>], ... into 3D meshes
106 | '''
107 | # input_ids: b (n q) or b n q, without <bos> or <eos>
108 | input_ids = input_ids.reshape(input_ids.shape[0], -1)
109 | # batch x nface
110 | face_mask = reduce(
111 | input_ids != self.pad_id, 'b (nf c) -> b nf', 'all', c = 9
112 | )
113 |
114 | # batch x (nface x 9) -> batch x nface x 3 x 3
115 | pred_face_coords = input_ids.reshape(input_ids.shape[0], -1, 9)
116 | pred_face_coords = rearrange(
117 | pred_face_coords, '... (v c) -> ... v c', v = 3
118 | )
119 |
120 | # back to continuous space
121 | continuous_coors = undiscretize(
122 | pred_face_coords,
123 | num_discrete = self.num_discrete_coors,
124 | continuous_range = self.coor_continuous_range
125 | )
126 | # mask padding coordinates out with nan
127 | continuous_coors = continuous_coors.masked_fill(
128 | ~rearrange(face_mask, 'b nf -> b nf 1 1'),
129 | float('nan')
130 | )
131 | output_dict = {}
132 | output_dict['recon_faces'] = continuous_coors
133 |
134 | return output_dict
135 |
136 |
137 | def forward(self, data_dict: dict) -> dict:
138 |
139 | encoder_output = self.tokenize(data_dict)
140 | decoder_output = self.detokenize(
141 | input_ids = encoder_output['codes'],
142 | )
143 | data_dict.update(encoder_output)
144 | data_dict.update(decoder_output)
145 | return data_dict
146 |
--------------------------------------------------------------------------------
/openai/clip-vit-base-patch32/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "openai/clip-vit-base-patch32",
3 | "architectures": [
4 | "CLIPModel"
5 | ],
6 | "initializer_factor": 1.0,
7 | "logit_scale_init_value": 2.6592,
8 | "model_type": "clip",
9 | "projection_dim": 512,
10 | "text_config": {
11 | "_name_or_path": "",
12 | "add_cross_attention": false,
13 | "architectures": null,
14 | "attention_dropout": 0.0,
15 | "bad_words_ids": null,
16 | "bos_token_id": 0,
17 | "chunk_size_feed_forward": 0,
18 | "cross_attention_hidden_size": null,
19 | "decoder_start_token_id": null,
20 | "diversity_penalty": 0.0,
21 | "do_sample": false,
22 | "dropout": 0.0,
23 | "early_stopping": false,
24 | "encoder_no_repeat_ngram_size": 0,
25 | "eos_token_id": 2,
26 | "finetuning_task": null,
27 | "forced_bos_token_id": null,
28 | "forced_eos_token_id": null,
29 | "hidden_act": "quick_gelu",
30 | "hidden_size": 512,
31 | "id2label": {
32 | "0": "LABEL_0",
33 | "1": "LABEL_1"
34 | },
35 | "initializer_factor": 1.0,
36 | "initializer_range": 0.02,
37 | "intermediate_size": 2048,
38 | "is_decoder": false,
39 | "is_encoder_decoder": false,
40 | "label2id": {
41 | "LABEL_0": 0,
42 | "LABEL_1": 1
43 | },
44 | "layer_norm_eps": 1e-05,
45 | "length_penalty": 1.0,
46 | "max_length": 20,
47 | "max_position_embeddings": 77,
48 | "min_length": 0,
49 | "model_type": "clip_text_model",
50 | "no_repeat_ngram_size": 0,
51 | "num_attention_heads": 8,
52 | "num_beam_groups": 1,
53 | "num_beams": 1,
54 | "num_hidden_layers": 12,
55 | "num_return_sequences": 1,
56 | "output_attentions": false,
57 | "output_hidden_states": false,
58 | "output_scores": false,
59 | "pad_token_id": 1,
60 | "prefix": null,
61 | "projection_dim": 512,
62 | "problem_type": null,
63 | "pruned_heads": {},
64 | "remove_invalid_values": false,
65 | "repetition_penalty": 1.0,
66 | "return_dict": true,
67 | "return_dict_in_generate": false,
68 | "sep_token_id": null,
69 | "task_specific_params": null,
70 | "temperature": 1.0,
71 | "tie_encoder_decoder": false,
72 | "tie_word_embeddings": true,
73 | "tokenizer_class": null,
74 | "top_k": 50,
75 | "top_p": 1.0,
76 | "torch_dtype": null,
77 | "torchscript": false,
78 | "transformers_version": "4.16.0.dev0",
79 | "use_bfloat16": false,
80 | "vocab_size": 49408
81 | },
82 | "text_config_dict": null,
83 | "transformers_version": null,
84 | "vision_config": {
85 | "_name_or_path": "",
86 | "add_cross_attention": false,
87 | "architectures": null,
88 | "attention_dropout": 0.0,
89 | "bad_words_ids": null,
90 | "bos_token_id": null,
91 | "chunk_size_feed_forward": 0,
92 | "cross_attention_hidden_size": null,
93 | "decoder_start_token_id": null,
94 | "diversity_penalty": 0.0,
95 | "do_sample": false,
96 | "dropout": 0.0,
97 | "early_stopping": false,
98 | "encoder_no_repeat_ngram_size": 0,
99 | "eos_token_id": null,
100 | "finetuning_task": null,
101 | "forced_bos_token_id": null,
102 | "forced_eos_token_id": null,
103 | "hidden_act": "quick_gelu",
104 | "hidden_size": 768,
105 | "id2label": {
106 | "0": "LABEL_0",
107 | "1": "LABEL_1"
108 | },
109 | "image_size": 224,
110 | "initializer_factor": 1.0,
111 | "initializer_range": 0.02,
112 | "intermediate_size": 3072,
113 | "is_decoder": false,
114 | "is_encoder_decoder": false,
115 | "label2id": {
116 | "LABEL_0": 0,
117 | "LABEL_1": 1
118 | },
119 | "layer_norm_eps": 1e-05,
120 | "length_penalty": 1.0,
121 | "max_length": 20,
122 | "min_length": 0,
123 | "model_type": "clip_vision_model",
124 | "no_repeat_ngram_size": 0,
125 | "num_attention_heads": 12,
126 | "num_beam_groups": 1,
127 | "num_beams": 1,
128 | "num_hidden_layers": 12,
129 | "num_return_sequences": 1,
130 | "output_attentions": false,
131 | "output_hidden_states": false,
132 | "output_scores": false,
133 | "pad_token_id": null,
134 | "patch_size": 32,
135 | "prefix": null,
136 | "projection_dim" : 512,
137 | "problem_type": null,
138 | "pruned_heads": {},
139 | "remove_invalid_values": false,
140 | "repetition_penalty": 1.0,
141 | "return_dict": true,
142 | "return_dict_in_generate": false,
143 | "sep_token_id": null,
144 | "task_specific_params": null,
145 | "temperature": 1.0,
146 | "tie_encoder_decoder": false,
147 | "tie_word_embeddings": true,
148 | "tokenizer_class": null,
149 | "top_k": 50,
150 | "top_p": 1.0,
151 | "torch_dtype": null,
152 | "torchscript": false,
153 | "transformers_version": "4.16.0.dev0",
154 | "use_bfloat16": false
155 | },
156 | "vision_config_dict": null
157 | }
158 |
--------------------------------------------------------------------------------
/openai/clip-vit-base-patch32/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
--------------------------------------------------------------------------------
/openai/clip-vit-base-patch32/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/", "model_max_length": 77}
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torchtyping
2 | pytorch_custom_utils
3 | beartype
4 | x_transformers
5 | local_attention
6 | vector_quantize_pytorch
7 | classifier_free_guidance_pytorch
8 | torch_geometric
9 | gateloop_transformer
10 | ema_pytorch
11 | trimesh
12 | wandb
13 | libigl
14 | matplotlib
15 | plyfile
16 | optimum
17 | transformers==4.38.2
18 | accelerate
19 | tensorboard
20 | deepspeed
--------------------------------------------------------------------------------
/sample_t2m.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import numpy as np
5 | from torch import Tensor
6 | from accelerate import Accelerator
7 | from transformers import AutoTokenizer
8 | from utils.ply_helper import write_ply
9 | from models.x_mesh_xl.get_model import get_model
10 |
11 |
12 | def post_process_mesh(mesh_coords: Tensor, filename: str):
13 | mesh_coords = mesh_coords[~torch.isnan(mesh_coords[:, 0, 0])] # nvalid_face x 3 x 3
14 | vertices = mesh_coords.reshape(-1, 3)
15 | vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 * nvalid_face - 1
16 | triangles = vertices_index.reshape(-1, 3)
17 | write_ply(
18 | np.asarray(vertices.cpu()),
19 | None,
20 | np.asarray(triangles),
21 | filename
22 | )
23 | return vertices
24 |
25 |
26 | def make_args_parser():
27 | parser = argparse.ArgumentParser(
28 | "MeshXL: Neural Coordinate Field for Generative 3D Foundation Models",
29 | add_help=False
30 | )
31 | ##### model config #####
32 | parser.add_argument("--llm", default='mesh-xl/mesh-xl-350m', type=str)
33 | parser.add_argument("--n_discrete_size", default=128, type=int)
34 | parser.add_argument("--text_condition", default='openai/clip-vit-base-patch32', type=str)
35 | parser.add_argument("--test_ckpt", default='mesh-xl/x-mesh-xl-350m/pytorch_model.bin', type=str)
36 | parser.add_argument("--text", default='3d model of a chair', type=str)
37 | parser.add_argument("--output_dir", default='outputs', type=str)
38 | parser.add_argument("--num_samples", default=4, type=int)
39 | parser.add_argument("--top_k", default=50, type=int)
40 | parser.add_argument("--top_p", default=0.95, type=float)
41 | parser.add_argument("--temperature", default=0.1, type=float)
42 |
43 | args = parser.parse_args()
44 | return args
45 |
46 |
47 | if __name__ == '__main__':
48 |
49 | args = make_args_parser()
50 | accelerator = Accelerator()
51 |
52 | # prepare model
53 | tokenizer = AutoTokenizer.from_pretrained(args.text_condition)
54 | mesh_xl = get_model(args)
55 | mesh_xl.load_state_dict(torch.load(args.test_ckpt, map_location='cpu'))
56 | mesh_xl = accelerator.prepare(mesh_xl)
57 |
58 | net_device = next(mesh_xl.parameters()).device
59 | num_samples = args.num_samples
60 |
61 | text_inputs = tokenizer.batch_encode_plus(
62 | [args.text],
63 | max_length=64,
64 | padding='max_length',
65 | truncation='longest_first',
66 | return_tensors='pt'
67 | )
68 | text_inputs = dict(
69 | text_input_ids=text_inputs['input_ids'].to(net_device),
70 | text_attention_mask=text_inputs['attention_mask'].to(net_device)
71 | )
72 |
73 | # model forward
74 | output_dict = mesh_xl(
75 | text_inputs,
76 | is_eval=True,
77 | is_generate=True,
78 | num_return_sequences=args.num_samples,
79 | generation_config=dict(
80 | do_sample=True,
81 | top_k=args.top_k,
82 | top_p=args.top_p,
83 | temperature=args.temperature
84 | )
85 | )
86 |
87 | # save samples
88 | os.makedirs(args.output_dir, exist_ok=True)
89 | for gen_id, sample in enumerate(output_dict['recon_faces']):
90 | post_process_mesh(
91 | sample,
92 | os.path.join(
93 | args.output_dir,
94 | f'{accelerator.process_index:04d}_{gen_id}.ply'
95 | )
96 | )
97 |
--------------------------------------------------------------------------------
/scripts/meshxl-sft-shapenet.sh:
--------------------------------------------------------------------------------
1 | export BASE_MESHXL=mesh-xl/mesh-xl-1.3b
2 | export BATCHSIZE_PER_GPU=2
3 |
4 | accelerate launch \
5 | --config_file ./config/deepspeed_stage2.yaml \
6 | --num_machines 1 \
7 | --num_processes 8 \
8 | --mixed_precision bf16 \
9 | main.py \
10 | --dataset sft.shapenet_table \
11 | --n_max_triangles 800 \
12 | --n_discrete_size 128 \
13 | --warm_lr_iters -1 \
14 | --base_lr 1e-6 \
15 | --llm $BASE_MESHXL \
16 | --model mesh_xl \
17 | --checkpoint_dir ./ckpts/mesh_xl_1.3b_base_pretrain_bs2_8a100 \
18 | --batchsize_per_gpu $BATCHSIZE_PER_GPU \
19 | --dataset_num_workers 0 \
20 | --augment \
21 | --eval_every_iteration 10000 \
22 | --save_every 20000 \
23 | --max_epoch 1024
--------------------------------------------------------------------------------
/scripts/sample-1.3b.sh:
--------------------------------------------------------------------------------
1 | export LLM_CONFIG='mesh-xl/mesh-xl-1.3b'
2 | export NSAMPLE_PER_GPU=2
3 | export SAMPLE_ROUNDS=100
4 | export OUTPUT_DIR='./output-samples-1.3b'
5 |
6 | accelerate launch \
7 | --num_machines 1 \
8 | --num_processes 8 \
9 | --mixed_precision bf16 \
10 | main.py \
11 | --dataset dummy_dataset \
12 | --n_max_triangles 800 \
13 | --n_discrete_size 128 \
14 | --llm $LLM_CONFIG \
15 | --model mesh_xl \
16 | --checkpoint_dir $OUTPUT_DIR \
17 | --batchsize_per_gpu $NSAMPLE_PER_GPU \
18 | --sample_rounds $SAMPLE_ROUNDS \
19 | --dataset_num_workers 0 \
20 | --test_only
21 |
--------------------------------------------------------------------------------
/scripts/sample-125m.sh:
--------------------------------------------------------------------------------
1 | export LLM_CONFIG='mesh-xl/mesh-xl-125m'
2 | export NSAMPLE_PER_GPU=2
3 | export SAMPLE_ROUNDS=100
4 | export OUTPUT_DIR='./output-samples-125m'
5 |
6 | accelerate launch \
7 | --num_machines 1 \
8 | --num_processes 8 \
9 | --mixed_precision bf16 \
10 | main.py \
11 | --dataset dummy_dataset \
12 | --n_max_triangles 800 \
13 | --n_discrete_size 128 \
14 | --llm $LLM_CONFIG \
15 | --model mesh_xl \
16 | --checkpoint_dir $OUTPUT_DIR \
17 | --batchsize_per_gpu $NSAMPLE_PER_GPU \
18 | --sample_rounds $SAMPLE_ROUNDS \
19 | --dataset_num_workers 0 \
20 | --test_only
21 |
--------------------------------------------------------------------------------
/scripts/sample-350m.sh:
--------------------------------------------------------------------------------
1 | export LLM_CONFIG='mesh-xl/mesh-xl-350m'
2 | export NSAMPLE_PER_GPU=2
3 | export SAMPLE_ROUNDS=100
4 | export OUTPUT_DIR='./output-samples-350m'
5 |
6 | accelerate launch \
7 | --num_machines 1 \
8 | --num_processes 8 \
9 | --mixed_precision bf16 \
10 | main.py \
11 | --dataset dummy_dataset \
12 | --n_max_triangles 800 \
13 | --n_discrete_size 128 \
14 | --llm $LLM_CONFIG \
15 | --model mesh_xl \
16 | --checkpoint_dir $OUTPUT_DIR \
17 | --batchsize_per_gpu $NSAMPLE_PER_GPU \
18 | --sample_rounds $SAMPLE_ROUNDS \
19 | --dataset_num_workers 0 \
20 | --test_only
21 |
--------------------------------------------------------------------------------
/scripts/sample-t2mesh.sh:
--------------------------------------------------------------------------------
1 | accelerate launch \
2 | --num_machines 1 \
3 | --num_processes 1 \
4 | --mixed_precision bf16 \
5 | sample_t2m.py \
6 | --test_ckpt mesh-xl/x-mesh-xl-350m/pytorch_model.bin \
7 | --text '3d model of a table' \
8 | --top_k 25 \
9 | --top_p 0.95 \
10 | --temperature 0.1
--------------------------------------------------------------------------------
/set_env.sh:
--------------------------------------------------------------------------------
1 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118
2 | pip install -r requirement.txt
3 | pip install deepspeed
--------------------------------------------------------------------------------
/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import pickle
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 |
8 | def is_distributed():
9 | if not dist.is_available() or not dist.is_initialized():
10 | return False
11 | return True
12 |
13 |
14 | def get_rank():
15 | if not is_distributed():
16 | return 0
17 | return dist.get_rank()
18 |
19 |
20 | def is_primary():
21 | return get_rank() == 0
22 |
23 |
24 | def get_world_size():
25 | if not is_distributed():
26 | return 1
27 | return dist.get_world_size()
28 |
29 |
30 | def barrier():
31 | if not is_distributed():
32 | return
33 | torch.distributed.barrier()
34 |
35 |
36 | def setup_print_for_distributed(is_primary):
37 | """
38 | This function disables printing when not in primary process
39 | """
40 | import builtins as __builtin__
41 | builtin_print = __builtin__.print
42 |
43 | def print(*args, **kwargs):
44 | force = kwargs.pop('force', False)
45 | if is_primary or force:
46 | builtin_print(*args, **kwargs)
47 |
48 | __builtin__.print = print
49 |
50 |
51 | def init_distributed(gpu_id, global_rank, world_size, dist_url, dist_backend):
52 | torch.cuda.set_device(gpu_id)
53 | print(
54 | f"| distributed init (rank {global_rank}) (world {world_size}): {dist_url}",
55 | flush=True,
56 | )
57 | torch.distributed.init_process_group(
58 | backend=dist_backend,
59 | init_method=dist_url,
60 | world_size=world_size,
61 | rank=global_rank,
62 | )
63 | torch.distributed.barrier()
64 | setup_print_for_distributed(is_primary())
65 |
66 |
67 | def all_reduce_sum(tensor):
68 | if not is_distributed():
69 | return tensor
70 | dim_squeeze = False
71 | if tensor.ndim == 0:
72 | tensor = tensor[None, ...]
73 | dim_squeeze = True
74 | torch.distributed.all_reduce(tensor)
75 | if dim_squeeze:
76 | tensor = tensor.squeeze(0)
77 | return tensor
78 |
79 |
80 | def all_reduce_average(tensor):
81 | val = all_reduce_sum(tensor)
82 | return val / get_world_size()
83 |
84 |
85 | # Function from DETR - https://github.com/facebookresearch/detr/blob/master/util/misc.py
86 | def reduce_dict(input_dict, average=True):
87 | """
88 | Args:
89 | input_dict (dict): all the values will be reduced
90 | average (bool): whether to do average or sum
91 | Reduce the values in the dictionary from all processes so that all processes
92 | have the averaged results. Returns a dict with the same fields as
93 | input_dict, after reduction.
94 | """
95 | world_size = get_world_size()
96 | if world_size < 2:
97 | return input_dict
98 | with torch.no_grad():
99 | names = []
100 | values = []
101 | # sort the keys so that they are consistent across processes
102 | for k in sorted(input_dict.keys()):
103 | names.append(k)
104 | values.append(input_dict[k])
105 | values = torch.stack(values, dim=0)
106 | torch.distributed.all_reduce(values)
107 | if average:
108 | values /= world_size
109 | reduced_dict = {k: v for k, v in zip(names, values)}
110 | return reduced_dict
111 |
112 |
113 | # Function from https://github.com/facebookresearch/detr/blob/master/util/misc.py
114 | def all_gather_pickle(data, device):
115 | """
116 | Run all_gather on arbitrary picklable data (not necessarily tensors)
117 | Args:
118 | data: any picklable object
119 | Returns:
120 | list[data]: list of data gathered from each rank
121 | """
122 | world_size = get_world_size()
123 | if world_size == 1:
124 | return [data]
125 |
126 | # serialized to a Tensor
127 | buffer = pickle.dumps(data)
128 | storage = torch.ByteStorage.from_buffer(buffer)
129 | tensor = torch.ByteTensor(storage).to(device)
130 |
131 | # obtain Tensor size of each rank
132 | local_size = torch.tensor([tensor.numel()], device=device)
133 | size_list = [torch.tensor([0], device=device) for _ in range(world_size)]
134 | dist.all_gather(size_list, local_size)
135 | size_list = [int(size.item()) for size in size_list]
136 | max_size = max(size_list)
137 |
138 | # receiving Tensor from all ranks
139 | # we pad the tensor because torch all_gather does not support
140 | # gathering tensors of different shapes
141 | tensor_list = []
142 | for _ in size_list:
143 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
144 | if local_size != max_size:
145 | padding = torch.empty(
146 | size=(max_size - local_size,), dtype=torch.uint8, device=device
147 | )
148 | tensor = torch.cat((tensor, padding), dim=0)
149 | dist.all_gather(tensor_list, tensor)
150 |
151 | data_list = []
152 | for size, tensor in zip(size_list, tensor_list):
153 | buffer = tensor.cpu().numpy().tobytes()[:size]
154 | data_list.append(pickle.loads(buffer))
155 |
156 | return data_list
157 |
158 |
159 | def all_gather_dict(data):
160 | """
161 | Run all_gather on data which is a dictionary of Tensors
162 | """
163 | assert isinstance(data, dict)
164 |
165 | gathered_dict = {}
166 | for item_key in data:
167 | if isinstance(data[item_key], torch.Tensor):
168 | if is_distributed():
169 | data[item_key] = data[item_key].contiguous()
170 | tensor_list = [torch.empty_like(data[item_key]) for _ in range(get_world_size())]
171 | dist.all_gather(tensor_list, data[item_key])
172 | gathered_tensor = torch.cat(tensor_list, dim=0)
173 | else:
174 | gathered_tensor = data[item_key]
175 | gathered_dict[item_key] = gathered_tensor
176 | return gathered_dict
177 |
--------------------------------------------------------------------------------
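
These helpers are deliberately safe to call outside `torch.distributed`: with a world size of 1 the reductions are identity operations. A single-process sanity check (illustrative values):

```python
import torch
from utils.dist import all_reduce_average, reduce_dict, get_world_size

loss = torch.tensor(0.5)
print(get_world_size())             # 1 when torch.distributed is not initialized
print(all_reduce_average(loss))     # tensor(0.5000): sum, then divide by world size
print(reduce_dict({'loss': loss}))  # returned unchanged when world_size < 2
```
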
/utils/io.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | import torch
4 | import os
5 | from utils.dist import is_primary
6 |
7 |
8 | def save_checkpoint(
9 | checkpoint_dir,
10 | model_no_ddp,
11 | optimizer,
12 | epoch,
13 | args,
14 | best_val_metrics,
15 | filename=None,
16 | ):
17 | if not is_primary():
18 | return
19 | if filename is None:
20 | filename = f"checkpoint_{epoch:04d}.pth"
21 | checkpoint_name = os.path.join(checkpoint_dir, filename)
22 |
23 | weight_ckpt = model_no_ddp.state_dict()
24 | sd = {
25 | "model": weight_ckpt,
26 | "optimizer": optimizer.state_dict(),
27 | "epoch": epoch,
28 | "args": args,
29 | "best_val_metrics": best_val_metrics,
30 | }
31 | torch.save(sd, checkpoint_name)
32 |
33 |
34 | def resume_if_possible(checkpoint_dir, model_no_ddp, optimizer):
35 | """
36 | Resume if a checkpoint is available.
37 | Returns
38 | - epoch of the loaded checkpoint and the best validation metrics so far.
39 | """
40 | epoch = -1
41 | best_val_metrics = {}
42 | if not os.path.isdir(checkpoint_dir):
43 | return epoch, best_val_metrics
44 |
45 | last_checkpoint = os.path.join(checkpoint_dir, "checkpoint.pth")
46 | if not os.path.isfile(last_checkpoint):
47 | return epoch, best_val_metrics
48 |
49 | sd = torch.load(last_checkpoint, map_location=torch.device("cpu"))
50 | epoch = sd["epoch"]
51 | best_val_metrics = sd["best_val_metrics"]
52 | print(f"Found checkpoint at {epoch}. Resuming.")
53 |
54 | model_no_ddp.load_state_dict(sd["model"], strict=False)
55 | try:
56 | optimizer.load_state_dict(sd["optimizer"])
57 | except Exception:
58 | print('optimizer weights could not be loaded')
59 | print(
60 | f"Loaded model and optimizer state at {epoch}. Loaded best val metrics so far."
61 | )
62 | return epoch, best_val_metrics
63 |
--------------------------------------------------------------------------------
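
One caveat worth knowing: `resume_if_possible` only probes for a file literally named `checkpoint.pth`, while `save_checkpoint` defaults to `checkpoint_{epoch:04d}.pth`, so a resumable checkpoint needs the filename passed explicitly. A minimal sketch with a throwaway model and a hypothetical `./ckpts` directory:

```python
import os
import torch
from torch import nn
from utils.io import save_checkpoint, resume_if_possible

os.makedirs('./ckpts', exist_ok=True)
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters())

# resume_if_possible only looks for 'checkpoint.pth', so name the file accordingly
save_checkpoint('./ckpts', model, optimizer, epoch=3, args=None,
                best_val_metrics={}, filename='checkpoint.pth')
epoch, best_val_metrics = resume_if_possible('./ckpts', model, optimizer)
print(epoch)  # 3
```
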
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import os
3 | import torch
4 | from torch.utils.tensorboard import SummaryWriter
5 |
6 |
7 | class Logger(object):
8 |
9 | def __init__(self, log_dir=None, accelerator=None) -> None:
10 | self.log_dir = log_dir
11 | self.accelerator = accelerator
12 |
13 | if self.log_dir is not None:
14 | self.txt_writer = open(os.path.join(self.log_dir, 'logger.log'), 'a')
15 | else:
16 | self.txt_writer = None
17 |
18 | if SummaryWriter is not None and self.accelerator is not None and self.accelerator.is_main_process:
19 | self.writer = SummaryWriter(self.log_dir)
20 | else:
21 | self.writer = None
22 |
23 | def log_scalars(self, scalar_dict, step, prefix=None):
24 | if self.writer is None:
25 | return
26 | for k in scalar_dict:
27 | v = scalar_dict[k]
28 | if isinstance(v, torch.Tensor):
29 | v = v.detach().cpu().item()
30 | if prefix is not None:
31 | k = prefix + '_' + k
32 | self.writer.add_scalar(k, v, step)
33 | return
34 |
35 | def log_messages(self, message: str):
36 | if self.txt_writer is not None:
37 | self.txt_writer.write(message + "\n")
38 | self.txt_writer.flush()
39 | print(message, flush=True)
40 | return
41 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | import numpy as np
4 | from collections import deque
5 | from typing import List
6 | from utils.dist import is_distributed, barrier, all_reduce_sum
7 |
8 |
9 | def my_worker_init_fn(worker_id):
10 | np.random.seed(np.random.get_state()[1][0] + worker_id)
11 |
12 |
13 | @torch.jit.ignore
14 | def to_list_1d(arr) -> List[float]:
15 | arr = arr.detach().cpu().numpy().tolist()
16 | return arr
17 |
18 |
19 | @torch.jit.ignore
20 | def to_list_3d(arr) -> List[List[List[float]]]:
21 | arr = arr.detach().cpu().numpy().tolist()
22 | return arr
23 |
24 |
25 | def huber_loss(error, delta=1.0):
26 | """
27 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
28 | x = error = pred - gt or dist(pred,gt)
29 | 0.5 * |x|^2 if |x|<=d
30 | 0.5 * d^2 + d * (|x|-d) if |x|>d
31 | """
32 | abs_error = torch.abs(error)
33 | quadratic = torch.clamp(abs_error, max=delta)
34 | linear = abs_error - quadratic
35 | loss = 0.5 * quadratic ** 2 + delta * linear
36 | return loss
37 |
38 |
39 | # From https://github.com/facebookresearch/detr/blob/master/util/misc.py
40 | class SmoothedValue(object):
41 | """Track a series of values and provide access to smoothed values over a
42 | window or the global series average.
43 | """
44 |
45 | def __init__(self, window_size=20, fmt=None):
46 | if fmt is None:
47 | fmt = "{median:.4f} ({global_avg:.4f})"
48 | self.deque = deque(maxlen=window_size)
49 | self.total = 0.0
50 | self.count = 0
51 | self.fmt = fmt
52 |
53 | def update(self, value, n=1):
54 | self.deque.append(value)
55 | self.count += n
56 | self.total += value * n
57 |
58 | def synchronize_between_processes(self):
59 | """
60 | Warning: does not synchronize the deque!
61 | """
62 | if not is_distributed():
63 | return
64 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
65 | barrier()
66 | all_reduce_sum(t)
67 | t = t.tolist()
68 | self.count = int(t[0])
69 | self.total = t[1]
70 |
71 | @property
72 | def median(self):
73 | d = torch.tensor(list(self.deque))
74 | return d.median().item()
75 |
76 | @property
77 | def avg(self):
78 | d = torch.tensor(list(self.deque), dtype=torch.float32)
79 | return d.mean().item()
80 |
81 | @property
82 | def global_avg(self):
83 | return self.total / self.count
84 |
85 | @property
86 | def max(self):
87 | return max(self.deque)
88 |
89 | @property
90 | def value(self):
91 | return self.deque[-1]
92 |
93 | def __str__(self):
94 | return self.fmt.format(
95 | median=self.median,
96 | avg=self.avg,
97 | global_avg=self.global_avg,
98 | max=self.max,
99 | value=self.value,
100 | )
101 |
--------------------------------------------------------------------------------
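
`SmoothedValue` mixes two horizons: `median` and `avg` look only at the sliding window, while `global_avg` tracks the full running total. For example (illustrative numbers):

```python
from utils.misc import SmoothedValue

meter = SmoothedValue(window_size=3)
for v in (1.0, 2.0, 3.0, 4.0):
    meter.update(v)

print(meter.median)      # 3.0 -- median of the last 3 values (2, 3, 4)
print(meter.global_avg)  # 2.5 -- mean of all 4 values seen so far
print(str(meter))        # '3.0000 (2.5000)' via the default fmt string
```
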
/utils/nms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | import numpy as np
4 |
5 | # boxes are axis aigned 2D boxes of shape (n,5) in FLOAT numbers with (x1,y1,x2,y2,score)
6 | """ Ref: https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
7 | Ref: https://github.com/vickyboy47/nms-python/blob/master/nms.py
8 | """
9 |
10 |
11 | def nms_2d(boxes, overlap_threshold):
12 | x1 = boxes[:, 0]
13 | y1 = boxes[:, 1]
14 | x2 = boxes[:, 2]
15 | y2 = boxes[:, 3]
16 | score = boxes[:, 4]
17 | area = (x2 - x1) * (y2 - y1)
18 |
19 | I = np.argsort(score)
20 | pick = []
21 | while I.size != 0:
22 | last = I.size
23 | i = I[-1]
24 | pick.append(i)
25 | suppress = [last - 1]
26 | for pos in range(last - 1):
27 | j = I[pos]
28 | xx1 = max(x1[i], x1[j])
29 | yy1 = max(y1[i], y1[j])
30 | xx2 = min(x2[i], x2[j])
31 | yy2 = min(y2[i], y2[j])
32 | w = xx2 - xx1
33 | h = yy2 - yy1
34 | if w > 0 and h > 0:
35 | o = w * h / area[j]
36 | print("Overlap is", o)
37 | if o > overlap_threshold:
38 | suppress.append(pos)
39 | I = np.delete(I, suppress)
40 | return pick
41 |
42 |
43 | def nms_2d_faster(boxes, overlap_threshold, old_type=False):
44 | x1 = boxes[:, 0]
45 | y1 = boxes[:, 1]
46 | x2 = boxes[:, 2]
47 | y2 = boxes[:, 3]
48 | score = boxes[:, 4]
49 | area = (x2 - x1) * (y2 - y1)
50 |
51 | I = np.argsort(score)
52 | pick = []
53 | while I.size != 0:
54 | last = I.size
55 | i = I[-1]
56 | pick.append(i)
57 |
58 | xx1 = np.maximum(x1[i], x1[I[: last - 1]])
59 | yy1 = np.maximum(y1[i], y1[I[: last - 1]])
60 | xx2 = np.minimum(x2[i], x2[I[: last - 1]])
61 | yy2 = np.minimum(y2[i], y2[I[: last - 1]])
62 |
63 | w = np.maximum(0, xx2 - xx1)
64 | h = np.maximum(0, yy2 - yy1)
65 |
66 | if old_type:
67 | o = (w * h) / area[I[: last - 1]]
68 | else:
69 | inter = w * h
70 | o = inter / (area[i] + area[I[: last - 1]] - inter)
71 |
72 | I = np.delete(
73 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0]))
74 | )
75 |
76 | return pick
77 |
78 |
79 | def nms_3d_faster(boxes, overlap_threshold, old_type=False):
80 | x1 = boxes[:, 0]
81 | y1 = boxes[:, 1]
82 | z1 = boxes[:, 2]
83 | x2 = boxes[:, 3]
84 | y2 = boxes[:, 4]
85 | z2 = boxes[:, 5]
86 | score = boxes[:, 6]
87 | area = (x2 - x1) * (y2 - y1) * (z2 - z1)
88 |
89 | I = np.argsort(score)
90 | pick = []
91 | while I.size != 0:
92 | last = I.size
93 | i = I[-1]
94 | pick.append(i)
95 |
96 | xx1 = np.maximum(x1[i], x1[I[: last - 1]])
97 | yy1 = np.maximum(y1[i], y1[I[: last - 1]])
98 | zz1 = np.maximum(z1[i], z1[I[: last - 1]])
99 | xx2 = np.minimum(x2[i], x2[I[: last - 1]])
100 | yy2 = np.minimum(y2[i], y2[I[: last - 1]])
101 | zz2 = np.minimum(z2[i], z2[I[: last - 1]])
102 |
103 | l = np.maximum(0, xx2 - xx1)
104 | w = np.maximum(0, yy2 - yy1)
105 | h = np.maximum(0, zz2 - zz1)
106 |
107 | if old_type:
108 | o = (l * w * h) / area[I[: last - 1]]
109 | else:
110 | inter = l * w * h
111 | o = inter / (area[i] + area[I[: last - 1]] - inter)
112 |
113 | I = np.delete(
114 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0]))
115 | )
116 |
117 | return pick
118 |
119 |
120 | def nms_3d_faster_samecls(boxes, overlap_threshold, old_type=False):
121 | x1 = boxes[:, 0]
122 | y1 = boxes[:, 1]
123 | z1 = boxes[:, 2]
124 | x2 = boxes[:, 3]
125 | y2 = boxes[:, 4]
126 | z2 = boxes[:, 5]
127 | score = boxes[:, 6]
128 | cls = boxes[:, 7]
129 | area = (x2 - x1) * (y2 - y1) * (z2 - z1)
130 |
131 | I = np.argsort(score)
132 | pick = []
133 | while I.size != 0:
134 | last = I.size
135 | i = I[-1]
136 | pick.append(i)
137 |
138 | xx1 = np.maximum(x1[i], x1[I[: last - 1]])
139 | yy1 = np.maximum(y1[i], y1[I[: last - 1]])
140 | zz1 = np.maximum(z1[i], z1[I[: last - 1]])
141 | xx2 = np.minimum(x2[i], x2[I[: last - 1]])
142 | yy2 = np.minimum(y2[i], y2[I[: last - 1]])
143 | zz2 = np.minimum(z2[i], z2[I[: last - 1]])
144 | cls1 = cls[i]
145 | cls2 = cls[I[: last - 1]]
146 |
147 | l = np.maximum(0, xx2 - xx1)
148 | w = np.maximum(0, yy2 - yy1)
149 | h = np.maximum(0, zz2 - zz1)
150 |
151 | if old_type:
152 | o = (l * w * h) / area[I[: last - 1]]
153 | else:
154 | inter = l * w * h
155 | o = inter / (area[i] + area[I[: last - 1]] - inter)
156 | o = o * (cls1 == cls2)
157 |
158 | I = np.delete(
159 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0]))
160 | )
161 |
162 | return pick
163 |
--------------------------------------------------------------------------------
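
For reference, `nms_2d_faster` keeps boxes in descending score order and suppresses any remaining box whose IoU with a kept box exceeds the threshold. A small worked example (illustrative boxes):

```python
import numpy as np
from utils.nms import nms_2d_faster

# (x1, y1, x2, y2, score): two heavily overlapping boxes plus one disjoint box
boxes = np.array([
    [0.0, 0.0, 1.0, 1.0, 0.9],
    [0.1, 0.1, 1.1, 1.1, 0.8],
    [5.0, 5.0, 6.0, 6.0, 0.7],
])
keep = nms_2d_faster(boxes, overlap_threshold=0.5)
print(keep)  # [0, 2]: the lower-scoring overlapping box (IoU ~ 0.68) is suppressed
```
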
/utils/pc_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | """ Utility functions for processing point clouds.
4 |
5 | Author: Charles R. Qi and Or Litany
6 | """
7 |
8 | import os
9 | import sys
10 | import torch
11 |
12 | # Point cloud IO
13 | import numpy as np
14 | from plyfile import PlyData, PlyElement
15 |
16 | # Mesh IO
17 | import trimesh
18 |
19 | # ----------------------------------------
20 | # Point Cloud Sampling
21 | # ----------------------------------------
22 |
23 |
24 | def random_sampling(pc, num_sample, replace=None, return_choices=False):
25 | """Input is NxC, output is num_samplexC"""
26 | if replace is None:
27 | replace = pc.shape[0] < num_sample
28 | choices = np.random.choice(pc.shape[0], num_sample, replace=replace)
29 | if return_choices:
30 | return pc[choices], choices
31 | else:
32 | return pc[choices]
33 |
34 |
35 | # ----------------------------------------
36 | # Simple Point manipulations
37 | # ----------------------------------------
38 | def shift_scale_points(pred_xyz, src_range, dst_range=None):
39 | """
40 | pred_xyz: B x N x 3
41 | src_range: [[B x 3], [B x 3]] - min and max XYZ coords
42 | dst_range: [[B x 3], [B x 3]] - min and max XYZ coords
43 | """
44 | if dst_range is None:
45 | dst_range = [
46 | torch.zeros((src_range[0].shape[0], 3), device=src_range[0].device),
47 | torch.ones((src_range[0].shape[0], 3), device=src_range[0].device),
48 | ]
49 |
50 | if pred_xyz.ndim == 4:
51 | src_range = [x[:, None] for x in src_range]
52 | dst_range = [x[:, None] for x in dst_range]
53 |
54 | assert src_range[0].shape[0] == pred_xyz.shape[0]
55 | assert dst_range[0].shape[0] == pred_xyz.shape[0]
56 | assert src_range[0].shape[-1] == pred_xyz.shape[-1]
57 | assert src_range[0].shape == src_range[1].shape
58 | assert dst_range[0].shape == dst_range[1].shape
59 | assert src_range[0].shape == dst_range[1].shape
60 |
61 | src_diff = src_range[1][:, None, :] - src_range[0][:, None, :]
62 | dst_diff = dst_range[1][:, None, :] - dst_range[0][:, None, :]
63 | prop_xyz = (
64 | ((pred_xyz - src_range[0][:, None, :]) * dst_diff) / src_diff
65 | ) + dst_range[0][:, None, :]
66 | return prop_xyz
67 |
68 |
69 | def scale_points(pred_xyz, mult_factor):
70 | if pred_xyz.ndim == 4:
71 | mult_factor = mult_factor[:, None]
72 | scaled_xyz = pred_xyz * mult_factor[:, None, :]
73 | return scaled_xyz
74 |
75 |
76 | def rotate_point_cloud(points, rotation_matrix=None):
77 | """Input: (n,3), Output: (n,3)"""
78 | # Rotate in-place around Z axis.
79 | if rotation_matrix is None:
80 | rotation_angle = np.random.uniform() * 2 * np.pi
81 | sinval, cosval = np.sin(rotation_angle), np.cos(rotation_angle)
82 | rotation_matrix = np.array(
83 | [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
84 | )
85 | ctr = points.mean(axis=0)
86 | rotated_data = np.dot(points - ctr, rotation_matrix) + ctr
87 | return rotated_data, rotation_matrix
88 |
89 |
90 | def rotate_pc_along_y(pc, rot_angle):
91 | """Input ps is NxC points with first 3 channels as XYZ
92 | z is facing forward, x is left ward, y is downward
93 | """
94 | cosval = np.cos(rot_angle)
95 | sinval = np.sin(rot_angle)
96 | rotmat = np.array([[cosval, -sinval], [sinval, cosval]])
97 | pc[:, [0, 2]] = np.dot(pc[:, [0, 2]], np.transpose(rotmat))
98 | return pc
99 |
100 |
101 | def roty(t):
102 | """Rotation about the y-axis."""
103 | c = np.cos(t)
104 | s = np.sin(t)
105 | return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
106 |
107 |
108 | def roty_batch(t):
109 | """Rotation about the y-axis.
110 | t: (x1,x2,...xn)
111 | return: (x1,x2,...,xn,3,3)
112 | """
113 | input_shape = t.shape
114 | output = np.zeros(tuple(list(input_shape) + [3, 3]))
115 | c = np.cos(t)
116 | s = np.sin(t)
117 | output[..., 0, 0] = c
118 | output[..., 0, 2] = s
119 | output[..., 1, 1] = 1
120 | output[..., 2, 0] = -s
121 | output[..., 2, 2] = c
122 | return output
123 |
124 |
125 | def rotz(t):
126 | """Rotation about the z-axis."""
127 | c = np.cos(t)
128 | s = np.sin(t)
129 | return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
130 |
131 |
132 | def point_cloud_to_bbox(points):
133 | """Extract the axis aligned box from a pcl or batch of pcls
134 | Args:
135 | points: Nx3 points or BxNx3
136 | output is 6 dim: xyz pos of center and 3 lengths
137 | """
138 | which_dim = len(points.shape) - 2 # first dim if a single cloud and second if batch
139 | mn, mx = points.min(which_dim), points.max(which_dim)
140 | lengths = mx - mn
141 | cntr = 0.5 * (mn + mx)
142 | return np.concatenate([cntr, lengths], axis=which_dim)
143 |
144 |
145 | def write_bbox(scene_bbox, out_filename):
146 | """Export scene bbox to meshes
147 | Args:
148 | scene_bbox: (N x 6 numpy array): xyz pos of center and 3 lengths
149 | out_filename: (string) filename
150 |
151 | Note:
152 | To visualize the boxes in MeshLab.
153 | 1. Select the objects (the boxes)
154 | 2. Filters -> Polygon and Quad Mesh -> Turn into Quad-Dominant Mesh
155 | 3. Select Wireframe view.
156 | """
157 |
158 | def convert_box_to_trimesh_fmt(box):
159 | ctr = box[:3]
160 | lengths = box[3:]
161 | trns = np.eye(4)
162 | trns[0:3, 3] = ctr
163 | trns[3, 3] = 1.0
164 | box_trimesh_fmt = trimesh.creation.box(lengths, trns)
165 | return box_trimesh_fmt
166 |
167 | scene = trimesh.scene.Scene()
168 | for box in scene_bbox:
169 | scene.add_geometry(convert_box_to_trimesh_fmt(box))
170 |
171 | mesh_list = trimesh.util.concatenate(scene.dump())
172 | # save to ply file
173 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type="ply")
174 |
175 | return
176 |
177 |
178 | def write_oriented_bbox(scene_bbox, out_filename, colors=None):
179 | """Export oriented (around Z axis) scene bbox to meshes
180 | Args:
181 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz)
182 | and heading angle around Z axis.
183 | Y forward, X right, Z upward. heading angle of positive X is 0,
184 | heading angle of positive Y is 90 degrees.
185 | out_filename: (string) filename
186 | """
187 |
188 | def heading2rotmat(heading_angle):
189 | # rotation about the Z axis by heading_angle
190 | rotmat = np.zeros((3, 3))
191 | rotmat[2, 2] = 1
192 | cosval = np.cos(heading_angle)
193 | sinval = np.sin(heading_angle)
194 | rotmat[0:2, 0:2] = np.array([[cosval, -sinval], [sinval, cosval]])
195 | return rotmat
196 |
197 | def convert_oriented_box_to_trimesh_fmt(box):
198 | ctr = box[:3]
199 | lengths = box[3:6]
200 | trns = np.eye(4)
201 | trns[0:3, 3] = ctr
202 | trns[3, 3] = 1.0
203 | trns[0:3, 0:3] = heading2rotmat(box[6])
204 | box_trimesh_fmt = trimesh.creation.box(lengths, trns)
205 | return box_trimesh_fmt
206 |
207 | if colors is not None:
208 | if colors.shape[0] != len(scene_bbox):
209 | colors = [colors for _ in range(len(scene_bbox))]
210 | colors = np.array(colors).astype(np.uint8)
211 | assert colors.shape[0] == len(scene_bbox)
212 | assert colors.shape[1] == 4
213 |
214 | scene = trimesh.scene.Scene()
215 | for idx, box in enumerate(scene_bbox):
216 | box_tr = convert_oriented_box_to_trimesh_fmt(box)
217 | if colors is not None:
218 | box_tr.visual.main_color[:] = colors[idx]
219 | box_tr.visual.vertex_colors[:] = colors[idx]
220 | for facet in box_tr.facets:
221 | box_tr.visual.face_colors[facet] = colors[idx]
222 | scene.add_geometry(box_tr)
223 |
224 | mesh_list = trimesh.util.concatenate(scene.dump())
225 | # save to ply file
226 |     trimesh.exchange.export.export_mesh(mesh_list, out_filename, file_type="ply")  # trimesh>=3.0 renamed trimesh.io to trimesh.exchange
227 |
228 | return
229 |
230 |
231 | def write_oriented_bbox_camera_coord(scene_bbox, out_filename):
232 | """Export oriented (around Y axis) scene bbox to meshes
233 | Args:
234 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz)
235 | and heading angle around Y axis.
236 | Z forward, X rightward, Y downward. heading angle of positive X is 0,
237 | heading angle of negative Z is 90 degrees.
238 | out_filename: (string) filename
239 | """
240 |
241 | def heading2rotmat(heading_angle):
242 |         # build a rotation about the Y axis (camera coordinates) for the given heading angle
243 | rotmat = np.zeros((3, 3))
244 | rotmat[1, 1] = 1
245 | cosval = np.cos(heading_angle)
246 | sinval = np.sin(heading_angle)
247 | rotmat[0, :] = np.array([cosval, 0, sinval])
248 | rotmat[2, :] = np.array([-sinval, 0, cosval])
249 | return rotmat
250 |
251 | def convert_oriented_box_to_trimesh_fmt(box):
252 | ctr = box[:3]
253 | lengths = box[3:6]
254 | trns = np.eye(4)
255 | trns[0:3, 3] = ctr
256 | trns[3, 3] = 1.0
257 | trns[0:3, 0:3] = heading2rotmat(box[6])
258 | box_trimesh_fmt = trimesh.creation.box(lengths, trns)
259 | return box_trimesh_fmt
260 |
261 | scene = trimesh.scene.Scene()
262 | for box in scene_bbox:
263 | scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box))
264 |
265 | mesh_list = trimesh.util.concatenate(scene.dump())
266 | # save to ply file
267 |     trimesh.exchange.export.export_mesh(mesh_list, out_filename, file_type="ply")  # trimesh>=3.0 renamed trimesh.io to trimesh.exchange
268 |
269 | return
270 |
271 |
272 | def write_lines_as_cylinders(pcl, filename, rad=0.005, res=64):
273 | """Create lines represented as cylinders connecting pairs of 3D points
274 | Args:
275 | pcl: (N x 2 x 3 numpy array): N pairs of xyz pos
276 | filename: (string) filename for the output mesh (ply) file
277 | rad: radius for the cylinder
278 | res: number of sections used to create the cylinder
279 | """
280 | scene = trimesh.scene.Scene()
281 | for src, tgt in pcl:
282 | # compute line
283 | vec = tgt - src
284 | M = trimesh.geometry.align_vectors([0, 0, 1], vec, False)
285 | vec = tgt - src # compute again since align_vectors modifies vec in-place!
286 | M[:3, 3] = 0.5 * src + 0.5 * tgt
287 | height = np.sqrt(np.dot(vec, vec))
288 | scene.add_geometry(
289 | trimesh.creation.cylinder(
290 | radius=rad, height=height, sections=res, transform=M
291 | )
292 | )
293 | mesh_list = trimesh.util.concatenate(scene.dump())
294 |     trimesh.exchange.export.export_mesh(mesh_list, "%s.ply" % (filename), file_type="ply")  # trimesh>=3.0 renamed trimesh.io to trimesh.exchange
295 |
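296 |
297 | if __name__ == "__main__":
298 |     # Illustrative smoke test (an added sketch, not part of the original
299 |     # module): export two axis-aligned boxes, one Z-oriented box, and one
300 |     # line segment as meshes. The output file names are hypothetical.
301 |     boxes = np.array([
302 |         [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],  # unit cube centered at the origin
303 |         [2.0, 0.0, 0.0, 1.0, 2.0, 0.5],  # flat box offset along +X
304 |     ])
305 |     write_bbox(boxes, "bbox_demo.ply")
306 |     # N x 7 input: center (cx, cy, cz), lengths (dx, dy, dz), heading around Z
307 |     write_oriented_bbox(np.array([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, np.pi / 4]]),
308 |                         "oriented_bbox_demo.ply")
309 |     # N x 2 x 3 input: endpoint pairs; the helper appends ".ply" itself
310 |     write_lines_as_cylinders(np.array([[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]]),
311 |                              "line_demo")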
--------------------------------------------------------------------------------
/utils/ply_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from plyfile import PlyData, PlyElement
5 |
6 |
7 | def read_mesh_vertices_rgb_normal(filename):
8 |     """ read XYZ RGB vertices and face indices from a PLY file (no normals are computed) """
9 |     assert os.path.isfile(filename)
10 | with open(filename, 'rb') as f:
11 | plydata = PlyData.read(f)
12 | num_verts = plydata['vertex'].count
13 | vertices = np.zeros(shape=[num_verts, 6], dtype=np.float32)
14 | vertices[:,0] = plydata['vertex'].data['x']
15 | vertices[:,1] = plydata['vertex'].data['y']
16 | vertices[:,2] = plydata['vertex'].data['z']
17 | vertices[:,3] = plydata['vertex'].data['red']
18 | vertices[:,4] = plydata['vertex'].data['green']
19 | vertices[:,5] = plydata['vertex'].data['blue']
20 |
21 |     # extract face indices (despite the function name, normals are not computed here)
22 | face = np.array([f[0] for f in plydata["face"].data])
23 |
24 | return vertices, face
25 |
26 |
27 | def write_ply(verts, colors, indices, output_file):
28 | if colors is None:
29 | colors = np.zeros_like(verts)
30 | if indices is None:
31 | indices = []
32 |
33 | file = open(output_file, 'w')
34 | file.write('ply\n')
35 | file.write('format ascii 1.0\n')
36 | file.write('element vertex {:d}\n'.format(len(verts)))
37 | file.write('property float x\n')
38 | file.write('property float y\n')
39 | file.write('property float z\n')
40 | file.write('property uchar red\n')
41 | file.write('property uchar green\n')
42 | file.write('property uchar blue\n')
43 | file.write('element face {:d}\n'.format(len(indices)))
44 | file.write('property list uchar uint vertex_indices\n')
45 | file.write('end_header\n')
46 | for vert, color in zip(verts, colors):
47 | file.write("{:f} {:f} {:f} {:d} {:d} {:d}\n".format(vert[0], vert[1], vert[2] , int(color[0]*255), int(color[1]*255), int(color[2]*255)))
48 | for ind in indices:
49 | file.write('3 {:d} {:d} {:d}\n'.format(ind[0], ind[1], ind[2]))
50 | file.close()
51 |
52 |
53 | def create_cylinder_mesh(radius, p0, p1, stacks=10, slices=10):
54 |     """Build a triangulated cylinder of the given radius from point p0 to p1; returns (verts, indices)."""
55 | import math
56 |
57 | def compute_length_vec3(vec3):
58 | return math.sqrt(vec3[0]*vec3[0] + vec3[1]*vec3[1] + vec3[2]*vec3[2])
59 |
60 | def rotation(axis, angle):
61 | rot = np.eye(4)
62 | c = np.cos(-angle)
63 | s = np.sin(-angle)
64 | t = 1.0 - c
65 | axis /= compute_length_vec3(axis)
66 | x = axis[0]
67 | y = axis[1]
68 | z = axis[2]
69 | rot[0,0] = 1 + t*(x*x-1)
70 | rot[0,1] = z*s+t*x*y
71 | rot[0,2] = -y*s+t*x*z
72 | rot[1,0] = -z*s+t*x*y
73 | rot[1,1] = 1+t*(y*y-1)
74 | rot[1,2] = x*s+t*y*z
75 | rot[2,0] = y*s+t*x*z
76 | rot[2,1] = -x*s+t*y*z
77 | rot[2,2] = 1+t*(z*z-1)
78 | return rot
79 |
80 |
81 | verts = []
82 | indices = []
83 | diff = (p1 - p0).astype(np.float32)
84 | height = compute_length_vec3(diff)
85 | for i in range(stacks+1):
86 | for i2 in range(slices):
87 | theta = i2 * 2.0 * math.pi / slices
88 | pos = np.array([radius*math.cos(theta), radius*math.sin(theta), height*i/stacks])
89 | verts.append(pos)
90 | for i in range(stacks):
91 | for i2 in range(slices):
92 | i2p1 = math.fmod(i2 + 1, slices)
93 | indices.append( np.array([(i + 1)*slices + i2, i*slices + i2, i*slices + i2p1], dtype=np.uint32) )
94 | indices.append( np.array([(i + 1)*slices + i2, i*slices + i2p1, (i + 1)*slices + i2p1], dtype=np.uint32) )
95 | transform = np.eye(4)
96 | va = np.array([0, 0, 1], dtype=np.float32)
97 | vb = diff
98 | vb /= compute_length_vec3(vb)
99 | axis = np.cross(vb, va)
100 | angle = np.arccos(np.clip(np.dot(va, vb), -1, 1))
101 | if angle != 0:
102 | if compute_length_vec3(axis) == 0:
103 | dotx = va[0]
104 | if (math.fabs(dotx) != 1.0):
105 | axis = np.array([1,0,0]) - dotx * va
106 | else:
107 | axis = np.array([0,1,0]) - va[1] * va
108 | axis /= compute_length_vec3(axis)
109 | transform = rotation(axis, -angle)
110 | transform[:3,3] += p0
111 | verts = [np.dot(transform, np.array([v[0], v[1], v[2], 1.0])) for v in verts]
112 | verts = [np.array([v[0], v[1], v[2]]) / v[3] for v in verts]
113 |
114 | return verts, indices
115 |
116 | def write_bbox(corners, color, output_file):
117 | """
118 |     corners: (N x 3 numpy array) of box corner coordinates; an axis-aligned box enclosing them is drawn as edge cylinders
119 |     color: (r, g, b) with components in [0, 255]; output_file: (string) output PLY filename
120 | """
121 |
122 | def get_bbox_edges(bbox_min, bbox_max):
123 | def get_bbox_verts(bbox_min, bbox_max):
124 | verts = [
125 | np.array([bbox_min[0], bbox_min[1], bbox_min[2]]),
126 | np.array([bbox_max[0], bbox_min[1], bbox_min[2]]),
127 | np.array([bbox_max[0], bbox_max[1], bbox_min[2]]),
128 | np.array([bbox_min[0], bbox_max[1], bbox_min[2]]),
129 |
130 | np.array([bbox_min[0], bbox_min[1], bbox_max[2]]),
131 | np.array([bbox_max[0], bbox_min[1], bbox_max[2]]),
132 | np.array([bbox_max[0], bbox_max[1], bbox_max[2]]),
133 | np.array([bbox_min[0], bbox_max[1], bbox_max[2]])
134 | ]
135 | return verts
136 |
137 | box_verts = get_bbox_verts(bbox_min, bbox_max)
138 | edges = [
139 | (box_verts[0], box_verts[1]),
140 | (box_verts[1], box_verts[2]),
141 | (box_verts[2], box_verts[3]),
142 | (box_verts[3], box_verts[0]),
143 |
144 | (box_verts[4], box_verts[5]),
145 | (box_verts[5], box_verts[6]),
146 | (box_verts[6], box_verts[7]),
147 | (box_verts[7], box_verts[4]),
148 |
149 | (box_verts[0], box_verts[4]),
150 | (box_verts[1], box_verts[5]),
151 | (box_verts[2], box_verts[6]),
152 | (box_verts[3], box_verts[7])
153 | ]
154 | return edges
155 |
156 | radius = 0.03
157 | offset = [0,0,0]
158 | verts = []
159 | indices = []
160 | colors = []
161 |
162 | box_min = np.min(corners, axis=0)
163 | box_max = np.max(corners, axis=0)
164 | edges = get_bbox_edges(box_min, box_max)
165 | for k in range(len(edges)):
166 | cyl_verts, cyl_ind = create_cylinder_mesh(radius, edges[k][0], edges[k][1])
167 | cur_num_verts = len(verts)
168 | cyl_color = [[c / 255 for c in color] for _ in cyl_verts]
169 | cyl_verts = [x + offset for x in cyl_verts]
170 | cyl_ind = [x + cur_num_verts for x in cyl_ind]
171 | verts.extend(cyl_verts)
172 | indices.extend(cyl_ind)
173 | colors.extend(cyl_color)
174 |
175 | write_ply(verts, colors, indices, output_file)
176 | return
177 |
178 |
179 | def write_path(points, color, output_file):
180 |     """Export a sequence of 3D points as a chain of connecting cylinders to a PLY file."""
181 | radius = 0.03
182 | offset = [0,0,0]
183 | verts = []
184 | indices = []
185 | colors = []
186 |
187 | for start, end in zip(points[:-1], points[1:]):
188 | cyl_verts, cyl_ind = create_cylinder_mesh(radius, start, end)
189 | cur_num_verts = len(verts)
190 | cyl_color = [[c / 255 for c in color] for _ in cyl_verts]
191 | cyl_verts = [x + offset for x in cyl_verts]
192 | cyl_ind = [x + cur_num_verts for x in cyl_ind]
193 | verts.extend(cyl_verts)
194 | indices.extend(cyl_ind)
195 | colors.extend(cyl_color)
196 |
197 | write_ply(verts, colors, indices, output_file)
198 | return
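199 |
200 |
201 | if __name__ == "__main__":
202 |     # Illustrative smoke test (an added sketch, not part of the original
203 |     # module): draw the twelve edges of the unit cube in red and a short
204 |     # two-segment path in green. The output file names are hypothetical.
205 |     cube_corners = np.array([[x, y, z] for x in (0.0, 1.0)
206 |                              for y in (0.0, 1.0) for z in (0.0, 1.0)])
207 |     write_bbox(cube_corners, (255, 0, 0), "cube_edges.ply")
208 |     write_path(np.array([[0.0, 0.0, 0.0],
209 |                          [1.0, 0.0, 0.0],
210 |                          [1.0, 1.0, 0.0]]), (0, 255, 0), "path_demo.ply")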
--------------------------------------------------------------------------------