├── .env.sample
├── .github
└── workflows
│ ├── python-publish.yml
│ └── test.yml
├── .gitignore
├── LICENSE
├── MeshGPT_demo.ipynb
├── README.md
├── demo_mesh
├── bar chair.glb
├── circle chair.glb
├── corner table.glb
├── designer chair.glb
├── designer sloped chair.glb
├── glass table.glb
├── high chair.glb
├── office table.glb
└── tv table.glb
├── meshgpt.png
├── meshgpt_pytorch
├── __init__.py
├── data.py
├── mesh_dataset.py
├── mesh_render.py
├── meshgpt_pytorch.py
├── trainer.py
├── typing.py
└── version.py
├── setup.cfg
├── setup.py
├── shapenet_labels.json
└── tests
└── test_meshgpt.py
/.env.sample:
--------------------------------------------------------------------------------
1 | TYPECHECK=True
2 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@v1.9.0
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.9]
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | python -m pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cpu
30 | python -m pip install scipy
31 | - name: Test with pytest
32 | run: |
33 | python setup.py test
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MeshGPT_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "!pip install -q git+https://github.com/MarcusLoppe/meshgpt-pytorch.git"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 33,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import torch\n",
19 | "import trimesh\n",
20 | "import numpy as np\n",
21 | "import os\n",
22 | "import csv\n",
23 | "import json\n",
24 | "from collections import OrderedDict\n",
25 | "\n",
26 | "from meshgpt_pytorch import (\n",
27 | " MeshTransformerTrainer,\n",
28 | " MeshAutoencoderTrainer,\n",
29 | " MeshAutoencoder,\n",
30 | " MeshTransformer\n",
31 | ")\n",
32 | "from meshgpt_pytorch.data import ( \n",
33 | " derive_face_edges_from_faces\n",
34 | ") \n",
35 | "\n",
36 | "def get_mesh(file_path): \n",
37 | " mesh = trimesh.load(file_path, force='mesh') \n",
38 | " vertices = mesh.vertices.tolist()\n",
39 | " if \".off\" in file_path: # ModelNet dataset\n",
40 | " mesh.vertices[:, [1, 2]] = mesh.vertices[:, [2, 1]] \n",
41 | " rotation_matrix = trimesh.transformations.rotation_matrix(np.radians(-90), [0, 1, 0])\n",
42 | " mesh.apply_transform(rotation_matrix) \n",
43 | " # Extract vertices and faces from the rotated mesh\n",
44 | " vertices = mesh.vertices.tolist()\n",
45 | " \n",
46 | " faces = mesh.faces.tolist()\n",
47 | " \n",
48 | " centered_vertices = vertices - np.mean(vertices, axis=0) \n",
49 | " max_abs = np.max(np.abs(centered_vertices))\n",
50 | " vertices = centered_vertices / (max_abs / 0.95) # Limit vertices to [-0.95, 0.95]\n",
51 | " \n",
52 | " min_y = np.min(vertices[:, 1]) \n",
53 | " difference = -0.95 - min_y \n",
54 | " vertices[:, 1] += difference\n",
55 | " \n",
56 | " def sort_vertices(vertex):\n",
57 | " return vertex[1], vertex[2], vertex[0] \n",
58 | " \n",
59 | " seen = OrderedDict()\n",
60 | " for point in vertices: \n",
61 | " key = tuple(point)\n",
62 | " if key not in seen:\n",
63 | " seen[key] = point\n",
64 | " \n",
65 | " unique_vertices = list(seen.values()) \n",
66 | " sorted_vertices = sorted(unique_vertices, key=sort_vertices)\n",
67 | " \n",
68 | " vertices_as_tuples = [tuple(v) for v in vertices]\n",
69 | " sorted_vertices_as_tuples = [tuple(v) for v in sorted_vertices]\n",
70 | "\n",
71 | " vertex_map = {old_index: new_index for old_index, vertex_tuple in enumerate(vertices_as_tuples) for new_index, sorted_vertex_tuple in enumerate(sorted_vertices_as_tuples) if vertex_tuple == sorted_vertex_tuple} \n",
72 | " reindexed_faces = [[vertex_map[face[0]], vertex_map[face[1]], vertex_map[face[2]]] for face in faces] \n",
73 | " sorted_faces = [sorted(sub_arr) for sub_arr in reindexed_faces] \n",
74 | " return np.array(sorted_vertices), np.array(sorted_faces)\n",
75 | " \n",
76 | " \n",
77 | "\n",
78 | "def augment_mesh(vertices, scale_factor): \n",
79 | " jitter_factor=0.01 \n",
80 | " possible_values = np.arange(-jitter_factor, jitter_factor , 0.0005) \n",
81 | " offsets = np.random.choice(possible_values, size=vertices.shape) \n",
82 | " vertices = vertices + offsets \n",
83 | " \n",
84 | " vertices = vertices * scale_factor \n",
85 | " # To ensure that the mesh models are on the \"ground\"\n",
86 | " min_y = np.min(vertices[:, 1]) \n",
87 | " difference = -0.95 - min_y \n",
88 | " vertices[:, 1] += difference\n",
89 | " return vertices\n",
90 | "\n",
91 | "\n",
92 | "#load_shapenet(\"./shapenet\", \"./shapenet_csv_files\", 10, 10) \n",
93 | "#Find the csv files with the labels in the ShapeNetCore.v1.zip, download at https://huggingface.co/datasets/ShapeNet/ShapeNetCore-archive \n",
94 | "def load_shapenet(directory, per_category, variations ):\n",
95 | " obj_datas = [] \n",
96 | " chosen_models_count = {} \n",
97 | " print(f\"per_category: {per_category} variations {variations}\")\n",
98 | " \n",
99 | " with open('shapenet_labels.json' , 'r') as f:\n",
100 | " id_info = json.load(f) \n",
101 | " \n",
102 | " possible_values = np.arange(0.75, 1.0 , 0.005) \n",
103 | " scale_factors = np.random.choice(possible_values, size=variations) \n",
104 | " \n",
105 | " for category in os.listdir(directory): \n",
106 | " category_path = os.path.join(directory, category) \n",
107 | " if os.path.isdir(category_path) == False:\n",
108 | " continue \n",
109 | " \n",
110 | " num_files_in_category = len(os.listdir(category_path))\n",
111 | " print(f\"{category_path} got {num_files_in_category} files\") \n",
112 | " chosen_models_count[category] = 0 \n",
113 | " \n",
114 | " for filename in os.listdir(category_path):\n",
115 | " if filename.endswith((\".obj\", \".glb\", \".off\")):\n",
116 | " file_path = os.path.join(category_path, filename)\n",
117 | " \n",
118 | " if chosen_models_count[category] >= per_category:\n",
119 | " break \n",
120 | if os.path.getsize(file_path) > 20 * 1024: # 20 kb limit = less than 400-600 faces\n",
121 | " continue \n",
122 | " if filename[:-4] not in id_info:\n",
123 | " print(\"Unable to find id info for \", filename)\n",
124 | " continue \n",
125 | " vertices, faces = get_mesh(file_path) \n",
126 | " if len(faces) > 800: \n",
127 | " continue\n",
128 | " \n",
129 | " chosen_models_count[category] += 1 \n",
130 | " textName = id_info[filename[:-4]] \n",
131 | " \n",
132 | " face_edges = derive_face_edges_from_faces(faces) \n",
133 | " for scale_factor in scale_factors: \n",
134 | " aug_vertices = augment_mesh(vertices.copy(), scale_factor) \n",
135 | " obj_data = {\"vertices\": torch.tensor(aug_vertices.tolist(), dtype=torch.float).to(\"cuda\"), \"faces\": torch.tensor(faces.tolist(), dtype=torch.long).to(\"cuda\"), \"face_edges\" : face_edges, \"texts\": textName } \n",
136 | " obj_datas.append(obj_data)\n",
137 | " \n",
138 | " print(\"=\"*25)\n",
139 | " print(\"Chosen models count for each category:\")\n",
140 | " for category, count in chosen_models_count.items():\n",
141 | " print(f\"{category}: {count}\") \n",
142 | " total_chosen_models = sum(chosen_models_count.values())\n",
143 | " print(f\"Total number of chosen models: {total_chosen_models}\")\n",
144 | " return obj_datas\n",
145 | "\n",
146 | " \n",
147 | " \n",
148 | "def load_filename(directory, variations):\n",
149 | " obj_datas = [] \n",
150 | " possible_values = np.arange(0.75, 1.0 , 0.005) \n",
151 | " scale_factors = np.random.choice(possible_values, size=variations) \n",
152 | " \n",
153 | " for filename in os.listdir(directory):\n",
154 | " if filename.endswith((\".obj\", \".glb\", \".off\")): \n",
155 | " file_path = os.path.join(directory, filename) \n",
156 | " vertices, faces = get_mesh(file_path) \n",
157 | " \n",
158 | " faces = torch.tensor(faces.tolist(), dtype=torch.long).to(\"cuda\")\n",
159 | " face_edges = derive_face_edges_from_faces(faces) \n",
160 | " texts, ext = os.path.splitext(filename) \n",
161 | " \n",
162 | " for scale_factor in scale_factors: \n",
163 | " aug_vertices = augment_mesh(vertices.copy(), scale_factor) \n",
164 | " obj_data = {\"vertices\": torch.tensor(aug_vertices.tolist(), dtype=torch.float).to(\"cuda\"), \"faces\": faces, \"face_edges\" : face_edges, \"texts\": texts } \n",
165 | " obj_datas.append(obj_data)\n",
166 | " \n",
167 | " print(f\"[create_mesh_dataset] Returning {len(obj_data)} meshes\")\n",
168 | " return obj_datas"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "import gzip,json\n",
178 | "from tqdm import tqdm\n",
179 | "import pandas as pd\n",
180 | "\n",
181 | "# Instruction to download objverse meshes: https://github.com/MarcusLoppe/Objaverse-downloader/tree/main\n",
182 | "def load_objverse(directory, variations ):\n",
183 | " obj_datas = [] \n",
184 | " id_info = {} \n",
185 | " pali_captions = pd.read_csv('.\\pali_captions.csv', sep=';') # https://github.com/google-deepmind/objaverse_annotations/blob/main/pali_captions.csv\n",
186 | " pali_captions_dict = pali_captions.set_index(\"object_uid\").to_dict()[\"top_aggregate_caption\"] \n",
187 | " \n",
188 | " possible_values = np.arange(0.75, 1.0) \n",
189 | " scale_factors = np.random.choice(possible_values, size=variations) \n",
190 | " \n",
191 | " for folder in os.listdir(directory): \n",
192 | " full_folder_path = os.path.join(directory, folder) \n",
193 | " if os.path.isdir(full_folder_path) == False:\n",
194 | " continue \n",
195 | " \n",
196 | " for filename in tqdm(os.listdir(full_folder_path)): \n",
197 | " if filename.endswith((\".obj\", \".glb\", \".off\")):\n",
198 | " file_path = os.path.join(full_folder_path, filename)\n",
199 | " kb = os.path.getsize(file_path) / 1024 \n",
200 | " if kb < 1 or kb > 30:\n",
201 | " continue\n",
202 | " \n",
203 | " if filename[:-4] not in pali_captions_dict: \n",
204 | " continue \n",
205 | " textName = pali_captions_dict[filename[:-4]]\n",
206 | " try: \n",
207 | " vertices, faces = get_mesh(file_path) \n",
208 | " except Exception as e:\n",
209 | " continue\n",
210 | " \n",
211 | " if len(faces) > 250 or len(faces) < 50: \n",
212 | " continue\n",
213 | " \n",
214 | " faces = torch.tensor(faces.tolist(), dtype=torch.long).to(\"cuda\")\n",
215 | " face_edges = derive_face_edges_from_faces(faces) \n",
216 | " for scale_factor in scale_factors: \n",
217 | " aug_vertices = augment_mesh(vertices.copy(), scale_factor) \n",
218 | " obj_data = {\"filename\": filename, \"vertices\": torch.tensor(aug_vertices.tolist(), dtype=torch.float).to(\"cuda\"), \"faces\": faces, \"face_edges\" : face_edges, \"texts\": textName } \n",
219 | " obj_datas.append(obj_data) \n",
220 | " return obj_datas"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "from pathlib import Path \n",
230 | "import gc \n",
231 | "import os\n",
232 | "from meshgpt_pytorch import MeshDataset \n",
233 | " \n",
234 | "project_name = \"demo_mesh\" \n",
235 | "\n",
236 | "working_dir = f'.\\{project_name}'\n",
237 | "\n",
238 | "working_dir = Path(working_dir)\n",
239 | "working_dir.mkdir(exist_ok = True, parents = True)\n",
240 | "dataset_path = working_dir / (project_name + \".npz\")\n",
241 | " \n",
242 | "if not os.path.isfile(dataset_path):\n",
243 | " data = load_filename(\"./demo_mesh\",50) \n",
244 | " dataset = MeshDataset(data) \n",
245 | " dataset.generate_face_edges() \n",
246 | " dataset.save(dataset_path)\n",
247 | " \n",
248 | "dataset = MeshDataset.load(dataset_path) \n",
249 | "print(dataset.data[0].keys())"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {},
255 | "source": [
256 | "#### Inspect imported meshes (optional)"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "from pathlib import Path\n",
266 | " \n",
267 | "folder = working_dir / f'renders' \n",
268 | "obj_file_path = Path(folder)\n",
269 | "obj_file_path.mkdir(exist_ok = True, parents = True)\n",
270 | " \n",
271 | "all_vertices = []\n",
272 | "all_faces = []\n",
273 | "vertex_offset = 0\n",
274 | "translation_distance = 0.5 \n",
275 | "\n",
276 | "for r, item in enumerate(data): \n",
277 | " vertices_copy = np.copy(item['vertices'])\n",
278 | " vertices_copy += translation_distance * (r / 0.2 - 1) \n",
279 | " \n",
280 | " for vert in vertices_copy:\n",
281 | " vertex = vert.to('cpu')\n",
282 | " all_vertices.append(f\"v {float(vertex[0])} {float(vertex[1])} {float(vertex[2])}\\n\") \n",
283 | " for face in item['faces']:\n",
284 | " all_faces.append(f\"f {face[0]+1+ vertex_offset} {face[1]+ 1+vertex_offset} {face[2]+ 1+vertex_offset}\\n\") \n",
285 | " vertex_offset = len(all_vertices)\n",
286 | " \n",
287 | "obj_file_content = \"\".join(all_vertices) + \"\".join(all_faces)\n",
288 | " \n",
289 | "obj_file_path = f'{folder}/3d_models_inspect.obj' \n",
290 | "with open(obj_file_path, \"w\") as file:\n",
291 | " file.write(obj_file_content) \n",
292 | " "
293 | ]
294 | },
295 | {
296 | "cell_type": "markdown",
297 | "metadata": {},
298 | "source": [
299 | "### Train!"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "autoencoder = MeshAutoencoder( \n",
309 | " decoder_dims_through_depth = (128,) * 6 + (192,) * 12 + (256,) * 24 + (384,) * 6, \n",
310 | " codebook_size = 2048, # Smaller vocab size will speed up the transformer training, however if you are training on meshes more then 250 triangle, I'd advice to use 16384 codebook size\n",
311 | " dim_codebook = 192, \n",
312 | " dim_area_embed = 16,\n",
313 | " dim_coor_embed = 16, \n",
314 | " dim_normal_embed = 16,\n",
315 | " dim_angle_embed = 8,\n",
316 | " \n",
317 | " attn_decoder_depth = 4,\n",
318 | " attn_encoder_depth = 2\n",
319 | ").to(\"cuda\") \n",
320 | "total_params = sum(p.numel() for p in autoencoder.parameters()) \n",
321 | "total_params = f\"{total_params / 1000000:.1f}M\"\n",
322 | "print(f\"Total parameters: {total_params}\")"
323 | ]
324 | },
325 | {
326 | "cell_type": "markdown",
327 | "metadata": {},
328 | "source": [
329 | "**Have at least 400-2000 items in the dataset, use this to multiply the dataset** "
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": null,
335 | "metadata": {},
336 | "outputs": [],
337 | "source": [
338 | "dataset.data = [dict(d) for d in dataset.data] * 10\n",
339 | "print(len(dataset.data))"
340 | ]
341 | },
342 | {
343 | "cell_type": "markdown",
344 | "metadata": {},
345 | "source": [
346 | "*Load previous saved model if you had to restart session*"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "pkg = torch.load(str(f'{working_dir}\\mesh-encoder_{project_name}.pt')) \n",
356 | "autoencoder.load_state_dict(pkg['model'])\n",
357 | "for param in autoencoder.parameters():\n",
358 | " param.requires_grad = True"
359 | ]
360 | },
361 | {
362 | "cell_type": "markdown",
363 | "metadata": {},
364 | "source": [
365 | "**Train to about 0.3 loss if you are using a small dataset**"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": null,
371 | "metadata": {},
372 | "outputs": [],
373 | "source": [
374 | "batch_size=16 # The batch size should be max 64.\n",
375 | "grad_accum_every = 4\n",
376 | "# So set the maximal batch size (max 64) that your VRAM can handle and then use grad_accum_every to create a effective batch size of 64, e.g 16 * 4 = 64\n",
377 | "learning_rate = 1e-3 # Start with 1e-3 then at staggnation around 0.35, you can lower it to 1e-4.\n",
378 | "\n",
379 | "autoencoder.commit_loss_weight = 0.2 # Set dependant on the dataset size, on smaller datasets, 0.1 is fine, otherwise try from 0.25 to 0.4.\n",
380 | "autoencoder_trainer = MeshAutoencoderTrainer(model =autoencoder ,warmup_steps = 10, dataset = dataset, num_train_steps=100,\n",
381 | " batch_size=batch_size,\n",
382 | " grad_accum_every = grad_accum_every,\n",
383 | " learning_rate = learning_rate,\n",
384 | " checkpoint_every_epoch=1) \n",
385 | "loss = autoencoder_trainer.train(480,stop_at_loss = 0.2, diplay_graph= True) "
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": null,
391 | "metadata": {},
392 | "outputs": [],
393 | "source": [
394 | "autoencoder_trainer.save(f'{working_dir}\\mesh-encoder_{project_name}.pt') "
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "metadata": {},
400 | "source": [
401 | "### Inspect how the autoencoder can encode and then provide the decoder with the codes to reconstruct the mesh"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {},
408 | "outputs": [],
409 | "source": [
410 | "import torch\n",
411 | "import random\n",
412 | "from tqdm import tqdm \n",
413 | "from meshgpt_pytorch import mesh_render \n",
414 | "\n",
415 | "min_mse, max_mse = float('inf'), float('-inf')\n",
416 | "min_coords, min_orgs, max_coords, max_orgs = None, None, None, None\n",
417 | "random_samples, random_samples_pred, all_random_samples = [], [], []\n",
418 | "total_mse, sample_size = 0.0, 200\n",
419 | "\n",
420 | "random.shuffle(dataset.data)\n",
421 | "\n",
422 | "for item in tqdm(dataset.data[:sample_size]):\n",
423 | " codes = autoencoder.tokenize(vertices=item['vertices'], faces=item['faces'], face_edges=item['face_edges']) \n",
424 | " \n",
425 | " codes = codes.flatten().unsqueeze(0)\n",
426 | " codes = codes[:, :codes.shape[-1] // autoencoder.num_quantizers * autoencoder.num_quantizers] \n",
427 | " \n",
428 | " coords, mask = autoencoder.decode_from_codes_to_faces(codes)\n",
429 | " orgs = item['vertices'][item['faces']].unsqueeze(0)\n",
430 | "\n",
431 | " mse = torch.mean((orgs.view(-1, 3).cpu() - coords.view(-1, 3).cpu())**2)\n",
432 | " total_mse += mse \n",
433 | "\n",
434 | " if mse < min_mse: min_mse, min_coords, min_orgs = mse, coords, orgs\n",
435 | " if mse > max_mse: max_mse, max_coords, max_orgs = mse, coords, orgs\n",
436 | " \n",
437 | " if len(random_samples) <= 30:\n",
438 | " random_samples.append(coords)\n",
439 | " random_samples_pred.append(orgs)\n",
440 | " else:\n",
441 | " all_random_samples.extend([random_samples_pred, random_samples])\n",
442 | " random_samples, random_samples_pred = [], []\n",
443 | "\n",
444 | "print(f'MSE AVG: {total_mse / sample_size:.10f}, Min: {min_mse:.10f}, Max: {max_mse:.10f}') \n",
445 | "mesh_render.combind_mesh_with_rows(f'{working_dir}\\mse_rows.obj', all_random_samples)"
446 | ]
447 | },
448 | {
449 | "cell_type": "markdown",
450 | "metadata": {},
451 | "source": [
452 | "### Training & fine-tuning\n",
453 | "\n",
454 | "**Pre-train:** Train the transformer on the full dataset with all the augmentations, the longer / more epochs will create a more robust model.
\n",
455 | "\n",
456 | "**Fine-tune:** Since it will take a long time to train on all the possible augmentations of the meshes, I recommend that you remove all the augmentations so you are left with x1 model per mesh.
\n",
457 | "Below is the function **filter_dataset** that will return a single copy of each mesh.
\n",
458 | "The function can also check for duplicate labels, this may speed up the fine-tuning process (not recommanded) however this most likely will remove it's ability for novel mesh generation."
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": null,
464 | "metadata": {},
465 | "outputs": [],
466 | "source": [
467 | "import gc \n",
468 | "torch.cuda.empty_cache()\n",
469 | "gc.collect() \n",
470 | "max_seq = max(len(d[\"faces\"]) for d in dataset if \"faces\" in d) * (autoencoder.num_vertices_per_face * autoencoder.num_quantizers) \n",
471 | "print(\"Max token sequence:\" , max_seq) \n",
472 | "\n",
473 | "# GPT2-Small model\n",
474 | "transformer = MeshTransformer(\n",
475 | " autoencoder,\n",
476 | " dim = 768,\n",
477 | " coarse_pre_gateloop_depth = 3, \n",
478 | " fine_pre_gateloop_depth= 3, \n",
479 | " attn_depth = 12, \n",
480 | " attn_heads = 12, \n",
481 | " max_seq_len = max_seq, \n",
482 | " condition_on_text = True, \n",
483 | " gateloop_use_heinsen = False,\n",
484 | " dropout = 0.0,\n",
485 | " text_condition_model_types = \"bge\", \n",
486 | " text_condition_cond_drop_prob = 0.0\n",
487 | ") \n",
488 | "\n",
489 | "total_params = sum(p.numel() for p in transformer.decoder.parameters())\n",
490 | "total_params = f\"{total_params / 1000000:.1f}M\"\n",
491 | "print(f\"Decoder total parameters: {total_params}\")"
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": null,
497 | "metadata": {},
498 | "outputs": [],
499 | "source": [
500 | "def filter_dataset(dataset, unique_labels = False):\n",
501 | " unique_dicts = []\n",
502 | " unique_tensors = set()\n",
503 | " texts = set()\n",
504 | " for d in dataset.data:\n",
505 | " tensor = d[\"faces\"]\n",
506 | " tensor_tuple = tuple(tensor.cpu().numpy().flatten())\n",
507 | " if unique_labels and d['texts'] in texts:\n",
508 | " continue\n",
509 | " if tensor_tuple not in unique_tensors:\n",
510 | " unique_tensors.add(tensor_tuple)\n",
511 | " unique_dicts.append(d)\n",
512 | " texts.add(d['texts'])\n",
513 | " return unique_dicts \n",
514 | "#dataset.data = filter_dataset(dataset.data, unique_labels = False)"
515 | ]
516 | },
517 | {
518 | "cell_type": "markdown",
519 | "metadata": {},
520 | "source": [
521 | "## **Required!**, embed the text and run generate_codes to save 4-96 GB VRAM (dependant on dataset) ##\n",
522 | "\n",
523 | "**If you don't;**
\n",
524 | "During each during each training step the autoencoder will generate the codes and the text encoder will embed the text.\n",
525 | "
\n",
526 | "After these fields are generate: **they will be deleted and next time it generates the code again:**
\n",
527 | "\n",
528 | "This is due to the dataloaders nature, it writes this information to a temporary COPY of the dataset\n"
529 | ]
530 | },
531 | {
532 | "cell_type": "code",
533 | "execution_count": null,
534 | "metadata": {},
535 | "outputs": [],
536 | "source": [
537 | "labels = list(set(item[\"texts\"] for item in dataset.data))\n",
538 | "dataset.embed_texts(transformer, batch_size = 25)\n",
539 | "dataset.generate_codes(autoencoder, batch_size = 50)\n",
540 | "print(dataset.data[0].keys())"
541 | ]
542 | },
543 | {
544 | "cell_type": "markdown",
545 | "metadata": {},
546 | "source": [
547 | "*Load previous saved model if you had to restart session*"
548 | ]
549 | },
550 | {
551 | "cell_type": "code",
552 | "execution_count": null,
553 | "metadata": {},
554 | "outputs": [],
555 | "source": [
556 | "pkg = torch.load(str(f'{working_dir}\\mesh-transformer_{project_name}.pt')) \n",
557 | "transformer.load_state_dict(pkg['model'])"
558 | ]
559 | },
560 | {
561 | "cell_type": "markdown",
562 | "metadata": {},
563 | "source": [
564 | "**Train to about 0.0001 loss (or less) if you are using a small dataset**"
565 | ]
566 | },
567 | {
568 | "cell_type": "code",
569 | "execution_count": null,
570 | "metadata": {},
571 | "outputs": [],
572 | "source": [
573 | "batch_size = 4 # Max 64\n",
574 | "grad_accum_every = 16\n",
575 | "\n",
576 | "# Set the maximal batch size (max 64) that your VRAM can handle and then use grad_accum_every to create a effective batch size of 64, e.g 4 * 16 = 64\n",
577 | "learning_rate = 1e-2 # Start training with the learning rate at 1e-2 then lower it to 1e-3 at stagnation or at 0.5 loss.\n",
578 | "\n",
579 | "trainer = MeshTransformerTrainer(model = transformer,warmup_steps = 10,num_train_steps=100, dataset = dataset,\n",
580 | " grad_accum_every=grad_accum_every,\n",
581 | " learning_rate = learning_rate,\n",
582 | " batch_size=batch_size,\n",
583 | " checkpoint_every_epoch = 1,\n",
584 | " # FP16 training, it doesn't speed up very much but can increase the batch size which will in turn speed up the training.\n",
585 | " # However it might cause nan after a while.\n",
586 | " # accelerator_kwargs = {\"mixed_precision\" : \"fp16\"}, optimizer_kwargs = { \"eps\": 1e-7} \n",
587 | " )\n",
588 | "loss = trainer.train(300, stop_at_loss = 0.005) "
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": null,
594 | "metadata": {},
595 | "outputs": [],
596 | "source": [
597 | "trainer.save(f'{working_dir}\\mesh-transformer_{project_name}.pt') "
598 | ]
599 | },
600 | {
601 | "cell_type": "markdown",
602 | "metadata": {},
603 | "source": [
604 | "## Generate and view mesh"
605 | ]
606 | },
607 | {
608 | "cell_type": "markdown",
609 | "metadata": {},
610 | "source": [
611 | "**Using only text**"
612 | ]
613 | },
614 | {
615 | "cell_type": "code",
616 | "execution_count": null,
617 | "metadata": {},
618 | "outputs": [],
619 | "source": [
620 | " \n",
621 | "from meshgpt_pytorch import mesh_render \n",
622 | "from pathlib import Path\n",
623 | " \n",
624 | "folder = working_dir / 'renders'\n",
625 | "obj_file_path = Path(folder)\n",
626 | "obj_file_path.mkdir(exist_ok = True, parents = True) \n",
627 | " \n",
628 | "text_coords = [] \n",
629 | "for text in labels[:10]:\n",
630 | " print(f\"Generating {text}\") \n",
631 | " text_coords.append(transformer.generate(texts = [text], temperature = 0.0)) \n",
632 | " \n",
633 | "mesh_render.save_rendering(f'{folder}/3d_models_all.obj', text_coords)"
634 | ]
635 | },
636 | {
637 | "cell_type": "markdown",
638 | "metadata": {},
639 | "source": [
640 | "**Text + prompt of tokens**"
641 | ]
642 | },
643 | {
644 | "cell_type": "markdown",
645 | "metadata": {},
646 | "source": [
647 | "**Prompt with 10% of codes/tokens**"
648 | ]
649 | },
650 | {
651 | "cell_type": "code",
652 | "execution_count": null,
653 | "metadata": {},
654 | "outputs": [],
655 | "source": [
656 | "from pathlib import Path \n",
657 | "from meshgpt_pytorch import mesh_render \n",
658 | "folder = working_dir / f'renders/text+codes'\n",
659 | "obj_file_path = Path(folder)\n",
660 | "obj_file_path.mkdir(exist_ok = True, parents = True) \n",
661 | "\n",
662 | "token_length_procent = 0.10 \n",
663 | "codes = []\n",
664 | "texts = []\n",
665 | "for label in labels:\n",
666 | " for item in dataset.data: \n",
667 | " if item['texts'] == label:\n",
668 | " tokens = autoencoder.tokenize(\n",
669 | " vertices = item['vertices'],\n",
670 | " faces = item['faces'],\n",
671 | " face_edges = item['face_edges']\n",
672 | " ) \n",
673 | " num_tokens = int(tokens.shape[0] * token_length_procent) \n",
674 | " texts.append(item['texts']) \n",
675 | " codes.append(tokens.flatten()[:num_tokens].unsqueeze(0)) \n",
676 | " break\n",
677 | " \n",
678 | "coords = [] \n",
679 | "for text, prompt in zip(texts, codes): \n",
680 | " print(f\"Generating {text} with {prompt.shape[1]} tokens\") \n",
681 | " coords.append(transformer.generate(texts = [text], prompt = prompt, temperature = 0) ) \n",
682 | " \n",
683 | "mesh_render.save_rendering(f'{folder}/text+prompt_{token_length_procent*100}.obj', coords)"
684 | ]
685 | },
686 | {
687 | "cell_type": "markdown",
688 | "metadata": {},
689 | "source": [
690 | "**Prompt with 0% to 80% of tokens**"
691 | ]
692 | },
693 | {
694 | "cell_type": "code",
695 | "execution_count": null,
696 | "metadata": {},
697 | "outputs": [],
698 | "source": [
699 | "from pathlib import Path\n",
700 | "from meshgpt_pytorch import mesh_render \n",
701 | " \n",
702 | "folder = working_dir / f'renders/text+codes_rows'\n",
703 | "obj_file_path = Path(folder)\n",
704 | "obj_file_path.mkdir(exist_ok = True, parents = True) \n",
705 | "\n",
706 | "mesh_rows = []\n",
707 | "for token_length_procent in np.arange(0, 0.8, 0.1):\n",
708 | " codes = []\n",
709 | " texts = []\n",
710 | " for label in labels:\n",
711 | " for item in dataset.data: \n",
712 | " if item['texts'] == label:\n",
713 | " tokens = autoencoder.tokenize(\n",
714 | " vertices = item['vertices'],\n",
715 | " faces = item['faces'],\n",
716 | " face_edges = item['face_edges']\n",
717 | " ) \n",
718 | " num_tokens = int(tokens.shape[0] * token_length_procent) \n",
719 | " \n",
720 | " texts.append(item['texts']) \n",
721 | " codes.append(tokens.flatten()[:num_tokens].unsqueeze(0)) \n",
722 | " break\n",
723 | " \n",
724 | " coords = [] \n",
725 | " for text, prompt in zip(texts, codes): \n",
726 | " print(f\"Generating {text} with {prompt.shape[1]} tokens\") \n",
727 | " coords.append(transformer.generate(texts = [text], prompt = prompt, temperature = 0)) \n",
728 | " \n",
729 | " mesh_rows.append(coords) \n",
730 | " \n",
731 | "mesh_render.save_rendering(f'{folder}/all.obj', mesh_rows)\n",
732 | " "
733 | ]
734 | },
735 | {
736 | "cell_type": "markdown",
737 | "metadata": {},
738 | "source": [
739 | "**Just some testing for text embedding similarity**"
740 | ]
741 | },
742 | {
743 | "cell_type": "code",
744 | "execution_count": null,
745 | "metadata": {},
746 | "outputs": [],
747 | "source": [
748 | "import numpy as np \n",
749 | "texts = list(labels)\n",
750 | "vectors = [transformer.conditioner.text_models[0].embed_text([text], return_text_encodings = False).cpu().flatten() for text in texts]\n",
751 | " \n",
752 | "max_label_length = max(len(text) for text in texts)\n",
753 | " \n",
754 | "# Print the table header\n",
755 | "print(f\"{'Text':<{max_label_length}} |\", end=\" \")\n",
756 | "for text in texts:\n",
757 | " print(f\"{text:<{max_label_length}} |\", end=\" \")\n",
758 | "print()\n",
759 | "\n",
760 | "# Print the similarity matrix as a table with fixed-length columns\n",
761 | "for i in range(len(texts)):\n",
762 | " print(f\"{texts[i]:<{max_label_length}} |\", end=\" \")\n",
763 | " for j in range(len(texts)):\n",
764 | " # Encode the texts and calculate cosine similarity manually\n",
765 | " vector_i = vectors[i]\n",
766 | " vector_j = vectors[j]\n",
767 | " \n",
768 | " dot_product = torch.sum(vector_i * vector_j)\n",
769 | " norm_vector1 = torch.norm(vector_i)\n",
770 | " norm_vector2 = torch.norm(vector_j)\n",
771 | " similarity_score = dot_product / (norm_vector1 * norm_vector2)\n",
772 | " \n",
773 | " # Print with fixed-length columns\n",
774 | " print(f\"{similarity_score.item():<{max_label_length}.4f} |\", end=\" \")\n",
775 | " print()"
776 | ]
777 | }
778 | ],
779 | "metadata": {
780 | "kaggle": {
781 | "accelerator": "gpu",
782 | "dataSources": [],
783 | "dockerImageVersionId": 30627,
784 | "isGpuEnabled": true,
785 | "isInternetEnabled": true,
786 | "language": "python",
787 | "sourceType": "notebook"
788 | },
789 | "kernelspec": {
790 | "display_name": "Python 3",
791 | "language": "python",
792 | "name": "python3"
793 | },
794 | "language_info": {
795 | "codemirror_mode": {
796 | "name": "ipython",
797 | "version": 3
798 | },
799 | "file_extension": ".py",
800 | "mimetype": "text/x-python",
801 | "name": "python",
802 | "nbconvert_exporter": "python",
803 | "pygments_lexer": "ipython3",
804 | "version": "3.11.5"
805 | }
806 | },
807 | "nbformat": 4,
808 | "nbformat_minor": 4
809 | }
810 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## MeshGPT - Pytorch
4 |
5 | Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
6 |
7 | Will also add text conditioning, for eventual text-to-3d asset
8 |
9 |
10 | Please visit the original repo for more details:
11 | https://github.com/lucidrains/meshgpt-pytorch
12 |
13 | ### Data sources:
14 | #### ModelNet40: https://www.kaggle.com/datasets/balraj98/modelnet40-princeton-3d-object-dataset/data
15 |
16 | #### ShapeNet - [Extracted model labels](https://github.com/MarcusLoppe/meshgpt-pytorch/blob/main/shapenet_labels.json) Repository: https://huggingface.co/datasets/ShapeNet/shapenetcore-gltf
17 |
18 | #### Objaverse - [Downloader](https://github.com/MarcusLoppe/Objaverse-downloader/tree/main) Repository: https://huggingface.co/datasets/allenai/objaverse
19 |
20 | ## Pre-trained autoencoder on the Objaverse dataset (14k meshes, only meshes that have max 250 faces):
21 | This contains only the autoencoder model; I'm currently training the transformer model.
22 | Visit the discussions [Pre-trained autoencoder & data sourcing](https://github.com/lucidrains/meshgpt-pytorch/discussions/66) for more information about the training and details about the progression.
23 |
24 | https://drive.google.com/drive/folders/1C1l5QrCtg9UulMJE5n_on4A9O9Gn0CC5?usp=sharing
25 |
26 |
27 | The auto-encoder results show that it's possible to compress many mesh models into tokens which can then be decoded to reconstruct the mesh nearly perfectly!
28 | The auto-encoder was trained for 9 epochs over 20 hrs on a single P100 GPU.
29 |
30 | The more compute-heavy part is training a transformer that can use these tokens and learn the auto-encoder's 'language'.
31 | Using the codes as a vocabulary and learning the relationship between the codes and their ordering requires far more compute than training the auto-encoder.
32 | So, using a single P100 GPU, it will probably take a few weeks until I can get out a pre-trained transformer.
33 |
34 | Let me know if you wish to donate any compute; I can provide you with the dataset + training notebook.
35 |
36 |
37 | ```
38 | num_layers = 23
39 | autoencoder = MeshAutoencoder(
40 | decoder_dims_through_depth = (128,) * 3 + (192,) * 4 + (256,) * num_layers + (384,) * 3,
41 | dim_codebook = 192 ,
42 | codebook_size = 16384 ,
43 | dim_area_embed = 16,
44 | dim_coor_embed = 16,
45 | dim_normal_embed = 16,
46 | dim_angle_embed = 8,
47 |
48 | attn_decoder_depth = 8,
49 | attn_encoder_depth = 4
50 | ).to("cuda")
51 | ```
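
If you just want to try the downloaded checkpoint, here is a minimal sketch of loading it and round-tripping a mesh through the autoencoder. The filename `mesh-autoencoder.pt` and the toy tetrahedron are placeholders, and the checkpoint is assumed to store its weights under a `'model'` key, which is how the demo notebook's trainers save and load them.

```python
import torch
from meshgpt_pytorch import MeshAutoencoder
from meshgpt_pytorch.data import derive_face_edges_from_faces

# recreate the autoencoder with the same settings it was trained with (config above)
autoencoder = MeshAutoencoder(
    decoder_dims_through_depth = (128,) * 3 + (192,) * 4 + (256,) * 23 + (384,) * 3,
    dim_codebook = 192,
    codebook_size = 16384,
    dim_area_embed = 16,
    dim_coor_embed = 16,
    dim_normal_embed = 16,
    dim_angle_embed = 8,
    attn_decoder_depth = 8,
    attn_encoder_depth = 4
).to("cuda")

# load the downloaded checkpoint; 'mesh-autoencoder.pt' is a placeholder filename,
# and the weights are assumed to live under the 'model' key
pkg = torch.load('mesh-autoencoder.pt', map_location = 'cuda')
autoencoder.load_state_dict(pkg['model'])

# round-trip a toy mesh: tokenize to codes, then decode the codes back to face coordinates
vertices = torch.tensor([[0., 0., 0.], [0.9, 0., 0.], [0., 0.9, 0.], [0., 0., 0.9]], device = 'cuda')
faces = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype = torch.long, device = 'cuda')
face_edges = derive_face_edges_from_faces(faces)

codes = autoencoder.tokenize(vertices = vertices, faces = faces, face_edges = face_edges)
coords, mask = autoencoder.decode_from_codes_to_faces(codes.flatten().unsqueeze(0))
```

For real meshes, preprocess them (normalize and sort vertices and faces) the same way `get_mesh` does in `MeshGPT_demo.ipynb`.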
52 |
53 | #### Results: it's about 14k models, so with the limited training time and hardware it's a great result.
54 | 
55 |
56 | ## Citations
57 |
58 | ```bibtex
59 | @inproceedings{Siddiqui2023MeshGPTGT,
60 | title = {MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers},
61 | author = {Yawar Siddiqui and Antonio Alliegro and Alexey Artemov and Tatiana Tommasi and Daniele Sirigatti and Vladislav Rosov and Angela Dai and Matthias Nie{\ss}ner},
62 | year = {2023},
63 | url = {https://api.semanticscholar.org/CorpusID:265457242}
64 | }
65 | ```
66 |
67 | ```bibtex
68 | @inproceedings{dao2022flashattention,
69 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
70 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
71 | booktitle = {Advances in Neural Information Processing Systems},
72 | year = {2022}
73 | }
74 | ```
75 |
76 | ```bibtex
77 | @inproceedings{Leviathan2022FastIF,
78 | title = {Fast Inference from Transformers via Speculative Decoding},
79 | author = {Yaniv Leviathan and Matan Kalman and Y. Matias},
80 | booktitle = {International Conference on Machine Learning},
81 | year = {2022},
82 | url = {https://api.semanticscholar.org/CorpusID:254096365}
83 | }
84 | ```
85 |
86 | ```bibtex
87 | @misc{yu2023language,
88 | title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
89 | author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
90 | year = {2023},
91 | eprint = {2310.05737},
92 | archivePrefix = {arXiv},
93 | primaryClass = {cs.CV}
94 | }
95 | ```
96 |
97 | ```bibtex
98 | @article{Lee2022AutoregressiveIG,
99 | title = {Autoregressive Image Generation using Residual Quantization},
100 | author = {Doyup Lee and Chiheon Kim and Saehoon Kim and Minsu Cho and Wook-Shin Han},
101 | journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
102 | year = {2022},
103 | pages = {11513-11522},
104 | url = {https://api.semanticscholar.org/CorpusID:247244535}
105 | }
106 | ```
107 |
108 | ```bibtex
109 | @inproceedings{Katsch2023GateLoopFD,
110 | title = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
111 | author = {Tobias Katsch},
112 | year = {2023},
113 | url = {https://api.semanticscholar.org/CorpusID:265018962}
114 | }
115 | ```
116 |
--------------------------------------------------------------------------------
/demo_mesh/bar chair.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/bar chair.glb
--------------------------------------------------------------------------------
/demo_mesh/circle chair.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/circle chair.glb
--------------------------------------------------------------------------------
/demo_mesh/corner table.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/corner table.glb
--------------------------------------------------------------------------------
/demo_mesh/designer chair.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/designer chair.glb
--------------------------------------------------------------------------------
/demo_mesh/designer sloped chair.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/designer sloped chair.glb
--------------------------------------------------------------------------------
/demo_mesh/glass table.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/glass table.glb
--------------------------------------------------------------------------------
/demo_mesh/high chair.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/high chair.glb
--------------------------------------------------------------------------------
/demo_mesh/office table.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/office table.glb
--------------------------------------------------------------------------------
/demo_mesh/tv table.glb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/tv table.glb
--------------------------------------------------------------------------------
/meshgpt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/meshgpt.png
--------------------------------------------------------------------------------
/meshgpt_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from meshgpt_pytorch.meshgpt_pytorch import (
2 | MeshAutoencoder,
3 | MeshTransformer
4 | )
5 |
6 | from meshgpt_pytorch.trainer import (
7 | MeshAutoencoderTrainer,
8 | MeshTransformerTrainer
9 | )
10 |
11 | from meshgpt_pytorch.data import (
12 | DatasetFromTransforms,
13 | cache_text_embeds_for_dataset,
14 | cache_face_edges_for_dataset
15 | )
16 |
17 | from meshgpt_pytorch.mesh_dataset import (
18 | MeshDataset
19 | )
20 | from meshgpt_pytorch.mesh_render import (
21 | save_rendering,
22 | combind_mesh_with_rows
23 | )
24 |
25 |
--------------------------------------------------------------------------------
/meshgpt_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from functools import partial
5 |
6 | import torch
7 | from torch import Tensor
8 | from torch import is_tensor
9 | from torch.utils.data import Dataset
10 | from torch.nn.utils.rnn import pad_sequence
11 |
12 | from numpy.lib.format import open_memmap
13 |
14 | from einops import rearrange, reduce
15 | from torch import nn, Tensor
16 |
17 | from beartype.typing import Tuple, List, Callable, Dict
18 | from meshgpt_pytorch.typing import typecheck
19 |
20 | from torchtyping import TensorType
21 |
22 | from pytorch_custom_utils.utils import pad_or_slice_to
23 |
24 | # helper fn
25 |
26 | def exists(v):
27 | return v is not None
28 |
29 | def identity(t):
30 | return t
31 |
32 | # constants
33 |
34 | Vertices = TensorType['nv', 3, float] # 3 coordinates
35 | Faces = TensorType['nf', 3, int] # 3 vertices
36 |
37 | # decorator for auto-caching texts -> text embeds
38 |
39 | # you would decorate your Dataset class with this
40 | # and then change your `data_kwargs = ["text_embeds", "vertices", "faces"]`
41 |
42 | @typecheck
43 | def cache_text_embeds_for_dataset(
44 | embed_texts_fn: Callable[[List[str]], Tensor],
45 | max_text_len: int,
46 | cache_path: str = './text_embed_cache'
47 | ):
48 | # create path to cache folder
49 |
50 | path = Path(cache_path)
51 | path.mkdir(exist_ok = True, parents = True)
52 | assert path.is_dir()
53 |
54 | # global memmap handles
55 |
56 | text_embed_cache = None
57 | is_cached = None
58 |
59 | # cache function
60 |
61 | def get_maybe_cached_text_embed(
62 | idx: int,
63 | dataset_len: int,
64 | text: str,
65 | memmap_file_mode = 'w+'
66 | ):
67 | nonlocal text_embed_cache
68 | nonlocal is_cached
69 |
70 | # init cache on first call
71 |
72 | if not exists(text_embed_cache):
73 | test_embed = embed_texts_fn(['test'])
74 | feat_dim = test_embed.shape[-1]
75 | shape = (dataset_len, max_text_len, feat_dim)
76 |
77 | text_embed_cache = open_memmap(str(path / 'cache.text_embed.memmap.npy'), mode = memmap_file_mode, dtype = 'float32', shape = shape)
78 | is_cached = open_memmap(str(path / 'cache.is_cached.memmap.npy'), mode = memmap_file_mode, dtype = 'bool', shape = (dataset_len,))
79 |
80 | # determine whether to fetch from cache
81 | # or call text model
82 |
83 | if is_cached[idx]:
84 | text_embed = torch.from_numpy(text_embed_cache[idx])
85 | else:
86 | # cache
87 |
88 | text_embed = get_text_embed(text)
89 | text_embed = pad_or_slice_to(text_embed, max_text_len, dim = 0, pad_value = 0.)
90 |
91 | is_cached[idx] = True
92 | text_embed_cache[idx] = text_embed.cpu().numpy()
93 |
94 | mask = ~reduce(text_embed == 0, 'n d -> n', 'all')
95 | return text_embed[mask]
96 |
97 | # get text embedding
98 |
99 | def get_text_embed(text: str):
100 | text_embeds = embed_texts_fn([text])
101 | return text_embeds[0]
102 |
103 | # inner function
104 |
105 | def inner(dataset_klass):
106 | assert issubclass(dataset_klass, Dataset)
107 |
108 | orig_init = dataset_klass.__init__
109 | orig_get_item = dataset_klass.__getitem__
110 |
111 | def __init__(
112 | self,
113 | *args,
114 | cache_memmap_file_mode = 'w+',
115 | **kwargs
116 | ):
117 | orig_init(self, *args, **kwargs)
118 |
119 | self._cache_memmap_file_mode = cache_memmap_file_mode
120 |
121 | if hasattr(self, 'data_kwargs'):
122 | self.data_kwargs = [('text_embeds' if data_kwarg == 'texts' else data_kwarg) for data_kwarg in self.data_kwargs]
123 |
124 | def __getitem__(self, idx):
125 | items = orig_get_item(self, idx)
126 |
127 | get_text_embed_ = partial(get_maybe_cached_text_embed, idx, len(self), memmap_file_mode = self._cache_memmap_file_mode)
128 |
129 | if isinstance(items, dict):
130 | if 'texts' in items:
131 | text_embed = get_text_embed_(items['texts'])
132 | items['text_embeds'] = text_embed
133 | del items['texts']
134 |
135 | elif isinstance(items, tuple):
136 | new_items = []
137 |
138 | for maybe_text in items:
139 | if not isinstance(maybe_text, str):
140 | new_items.append(maybe_text)
141 | continue
142 |
143 | new_items.append(get_text_embed_(maybe_text))
144 |
145 | items = tuple(new_items)
146 |
147 | return items
148 |
149 | dataset_klass.__init__ = __init__
150 | dataset_klass.__getitem__ = __getitem__
151 |
152 | return dataset_klass
153 |
154 | return inner
155 |
156 | # decorator for auto-caching face edges
157 |
158 | # you would decorate your Dataset class with this function
159 | # and then change your `data_kwargs = ["vertices", "faces", "face_edges"]`
160 |
161 | @typecheck
162 | def cache_face_edges_for_dataset(
163 | max_edges_len: int,
164 | cache_path: str = './face_edges_cache',
165 | assert_edge_len_lt_max: bool = True,
166 | pad_id = -1
167 | ):
168 | # create path to cache folder
169 |
170 | path = Path(cache_path)
171 | path.mkdir(exist_ok = True, parents = True)
172 | assert path.is_dir()
173 |
174 | # global memmap handles
175 |
176 | face_edges_cache = None
177 | is_cached = None
178 |
179 | # cache function
180 |
181 | def get_maybe_cached_face_edges(
182 | idx: int,
183 | dataset_len: int,
184 | faces: Tensor,
185 | memmap_file_mode = 'w+'
186 | ):
187 | nonlocal face_edges_cache
188 | nonlocal is_cached
189 |
190 | if not exists(face_edges_cache):
191 | # init cache on first call
192 |
193 | shape = (dataset_len, max_edges_len, 2)
194 | face_edges_cache = open_memmap(str(path / 'cache.face_edges_embed.memmap.npy'), mode = memmap_file_mode, dtype = 'float32', shape = shape)
195 | is_cached = open_memmap(str(path / 'cache.is_cached.memmap.npy'), mode = memmap_file_mode, dtype = 'bool', shape = (dataset_len,))
196 |
197 | # determine whether to fetch from cache
198 | # or call derive face edges function
199 |
200 | if is_cached[idx]:
201 | face_edges = torch.from_numpy(face_edges_cache[idx])
202 | else:
203 | # cache
204 |
205 | face_edges = derive_face_edges_from_faces(faces, pad_id = pad_id)
206 |
207 | edge_len = face_edges.shape[0]
208 | assert not assert_edge_len_lt_max or (edge_len <= max_edges_len), f'mesh #{idx} has {edge_len} face edges, which exceeds the cache length of {max_edges_len}'
209 |
210 | face_edges = pad_or_slice_to(face_edges, max_edges_len, dim = 0, pad_value = pad_id)
211 |
212 | is_cached[idx] = True
213 | face_edges_cache[idx] = face_edges.cpu().numpy()
214 |
215 | mask = reduce(face_edges != pad_id, 'n d -> n', 'all')
216 | return face_edges[mask]
217 |
218 | # inner function
219 |
220 | def inner(dataset_klass):
221 | assert issubclass(dataset_klass, Dataset)
222 |
223 | orig_init = dataset_klass.__init__
224 | orig_get_item = dataset_klass.__getitem__
225 |
226 | def __init__(
227 | self,
228 | *args,
229 | cache_memmap_file_mode = 'w+',
230 | **kwargs
231 | ):
232 | orig_init(self, *args, **kwargs)
233 |
234 | self._cache_memmap_file_mode = cache_memmap_file_mode
235 |
236 | if hasattr(self, 'data_kwargs'):
237 | self.data_kwargs.append('face_edges')
238 |
239 | def __getitem__(self, idx):
240 | items = orig_get_item(self, idx)
241 |
242 | get_face_edges_ = partial(get_maybe_cached_face_edges, idx, len(self), memmap_file_mode = self._cache_memmap_file_mode)
243 |
244 | if isinstance(items, dict):
245 | face_edges = get_face_edges_(items['faces'])
246 | items['face_edges'] = face_edges
247 |
248 | elif isinstance(items, tuple):
249 | _, faces, *_ = items
250 | face_edges = get_face_edges_(faces)
251 | items = (*items, face_edges)
252 |
253 | return items
254 |
255 | dataset_klass.__init__ = __init__
256 | dataset_klass.__getitem__ = __getitem__
257 |
258 | return dataset_klass
259 |
260 | return inner
261 |
262 | # dataset
263 |
264 | class DatasetFromTransforms(Dataset):
265 | @typecheck
266 | def __init__(
267 | self,
268 | folder: str,
269 | transforms: Dict[str, Callable[[Path], Tuple[Vertices, Faces]]],
270 | data_kwargs: List[str] | None = None,
271 | augment_fn: Callable = identity
272 | ):
273 | folder = Path(folder)
274 | assert folder.exists() and folder.is_dir()
275 | self.folder = folder
276 |
277 | exts = transforms.keys()
278 | self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]
279 |
280 | print(f'{len(self.paths)} training samples found at {folder}')
281 | assert len(self.paths) > 0
282 |
283 | self.transforms = transforms
284 | self.data_kwargs = data_kwargs
285 | self.augment_fn = augment_fn
286 |
287 | def __len__(self):
288 | return len(self.paths)
289 |
290 | def __getitem__(self, idx):
291 | path = self.paths[idx]
292 | ext = path.suffix[1:]
293 | fn = self.transforms[ext]
294 |
295 | out = fn(path)
296 | return self.augment_fn(out)
297 |
298 | # tensor helper functions
299 |
300 | def derive_face_edges_from_faces(
301 | faces: TensorType['b', 'nf', 3, int],
302 | pad_id = -1,
303 | neighbor_if_share_one_vertex = False,
304 | include_self = True
305 | ) -> TensorType['b', 'e', 2, int]:
306 |
307 | is_one_face, device = faces.ndim == 2, faces.device
308 |
309 | if is_one_face:
310 | faces = rearrange(faces, 'nf c -> 1 nf c')
311 |
312 | max_num_faces = faces.shape[1]
313 | face_edges_vertices_threshold = 1 if neighbor_if_share_one_vertex else 2
314 |
315 | all_edges = torch.stack(torch.meshgrid(
316 | torch.arange(max_num_faces, device = device),
317 | torch.arange(max_num_faces, device = device),
318 | indexing = 'ij'), dim = -1)
319 |
320 | face_masks = reduce(faces != pad_id, 'b nf c -> b nf', 'all')
321 | face_edges_masks = rearrange(face_masks, 'b i -> b i 1') & rearrange(face_masks, 'b j -> b 1 j')
322 |
323 | face_edges = []
324 |
325 | for face, face_edge_mask in zip(faces, face_edges_masks):
326 |
327 | shared_vertices = rearrange(face, 'i c -> i 1 c 1') == rearrange(face, 'j c -> 1 j 1 c')
328 | num_shared_vertices = shared_vertices.any(dim = -1).sum(dim = -1)
329 |
330 | is_neighbor_face = (num_shared_vertices >= face_edges_vertices_threshold) & face_edge_mask
331 |
332 | if not include_self:
333 | is_neighbor_face &= num_shared_vertices != 3
334 |
335 | face_edge = all_edges[is_neighbor_face]
336 | face_edges.append(face_edge)
337 |
338 | face_edges = pad_sequence(face_edges, padding_value = pad_id, batch_first = True)
339 |
340 | if is_one_face:
341 | face_edges = rearrange(face_edges, '1 e ij -> e ij')
342 |
343 | return face_edges
344 |
345 | # custom collater
346 |
347 | def first(it):
348 | return it[0]
349 |
350 | def custom_collate(data, pad_id = -1):
351 | is_dict = isinstance(first(data), dict)
352 |
353 | if is_dict:
354 | keys = first(data).keys()
355 | data = [d.values() for d in data]
356 |
357 | output = []
358 |
359 | for datum in zip(*data):
360 | if is_tensor(first(datum)):
361 | datum = pad_sequence(datum, batch_first = True, padding_value = pad_id)
362 | else:
363 | datum = list(datum)
364 |
365 |
366 | output.append(datum)
367 |
368 | output = tuple(output)
369 |
370 | if is_dict:
371 | output = dict(zip(keys, output))
372 |
373 | return output
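374 |
375 | # usage sketch (illustrative only, not part of the module's API): how
376 | # `derive_face_edges_from_faces` and `custom_collate` might be combined with a
377 | # DataLoader; `my_dataset` is a hypothetical Dataset yielding dicts of tensors
378 | #
379 | # faces = torch.tensor([[0, 1, 2], [1, 2, 3]])        # two triangles sharing an edge
380 | # face_edges = derive_face_edges_from_faces(faces)    # -> (num_edges, 2) neighboring face index pairs
381 | #
382 | # from functools import partial
383 | # from torch.utils.data import DataLoader
384 | # loader = DataLoader(my_dataset, batch_size = 4, collate_fn = partial(custom_collate, pad_id = -1))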
--------------------------------------------------------------------------------
/meshgpt_pytorch/mesh_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import numpy as np
3 | from torch.nn.utils.rnn import pad_sequence
4 | from tqdm import tqdm
5 | import torch
6 | from meshgpt_pytorch import (
7 | MeshAutoencoder,
8 | MeshTransformer
9 | )
10 |
11 | from meshgpt_pytorch.data import (
12 | derive_face_edges_from_faces
13 | )
14 |
15 | class MeshDataset(Dataset):
16 | """
17 | A PyTorch Dataset to load and process mesh data.
18 | The `MeshDataset` provides functions to load mesh data from a file, embed text information, generate face edges, and generate codes.
19 |
20 | Attributes:
21 | data (list): A list of mesh data entries. Each entry is a dictionary containing the following keys:
22 | vertices (torch.Tensor): A tensor of vertices with shape (num_vertices, 3).
23 | faces (torch.Tensor): A tensor of faces with shape (num_faces, 3).
24 | texts (str): A string containing the associated text information for the mesh.
25 | text_embeds (torch.Tensor): A tensor of text embeddings for the mesh.
26 | face_edges (torch.Tensor): A tensor of face edges with shape (num_edges, 2).
27 | codes (torch.Tensor): A tensor of codes generated from the mesh data.
28 |
29 | Example usage:
30 |
31 | ```
32 | data = [
33 | {'vertices': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'faces': torch.tensor([[0, 1, 2]], dtype=torch.long), 'texts': 'table'},
34 | {'vertices': torch.tensor([[10, 20, 30], [40, 50, 60]], dtype=torch.float32), 'faces': torch.tensor([[1, 2, 0]], dtype=torch.long), 'texts': 'chair'},
35 | ]
36 |
37 | # Create a MeshDataset instance
38 | mesh_dataset = MeshDataset(data)
39 |
40 | # Save the MeshDataset to disk
41 | mesh_dataset.save('mesh_dataset.npz')
42 |
43 | # Load the MeshDataset from disk
44 | loaded_mesh_dataset = MeshDataset.load('mesh_dataset.npz')
45 |
46 | # Generate face edges up front so it doesn't need to be done every time during training
47 | mesh_dataset.generate_face_edges()
48 | ```
49 | """
50 | def __init__(self, data):
51 | self.data = data
52 | print(f"[MeshDataset] Created from {len(self.data)} entries")
53 |
54 | def __len__(self):
55 | return len(self.data)
56 |
57 | def __getitem__(self, idx):
58 | data = self.data[idx]
59 | return data
60 |
61 | def save(self, path):
62 | np.savez_compressed(path, self.data, allow_pickle=True)
63 | print(f"[MeshDataset] Saved {len(self.data)} entries at {path}")
64 |
65 | @classmethod
66 | def load(cls, path):
67 | loaded_data = np.load(path, allow_pickle=True)
68 | data = []
69 | for item in loaded_data["arr_0"]:
70 | data.append(item)
71 | print(f"[MeshDataset] Loaded {len(data)} entries")
72 | return cls(data)
73 |
74 | def sort_dataset_keys(self):
75 | desired_order = ['vertices', 'faces', 'face_edges', 'texts', 'text_embeds', 'codes']
76 | self.data = [
77 | {key: d[key] for key in desired_order if key in d} for d in self.data
78 | ]
79 |
80 | def generate_face_edges(self, batch_size = 5):
81 | data_to_process = [item for item in self.data if 'face_edges' not in item]
82 |
83 | total_batches = (len(data_to_process) + batch_size - 1) // batch_size
84 | device = "cuda" if torch.cuda.is_available() else "cpu"
85 |
86 | for i in tqdm(range(0, len(data_to_process), batch_size), total=total_batches):
87 | batch_data = data_to_process[i:i+batch_size]
88 |
89 | if not batch_data:
90 | continue
91 |
92 | padded_batch_faces = pad_sequence(
93 | [item['faces'] for item in batch_data],
94 | batch_first=True,
95 | padding_value=-1
96 | ).to(device)
97 |
98 | batched_faces_edges = derive_face_edges_from_faces(padded_batch_faces, pad_id=-1)
99 |
100 | mask = (batched_faces_edges != -1).all(dim=-1)
101 | for item_idx, (item_edges, item_mask) in enumerate(zip(batched_faces_edges, mask)):
102 | item_edges_masked = item_edges[item_mask]
103 | item = batch_data[item_idx]
104 | item['face_edges'] = item_edges_masked
105 |
106 | self.sort_dataset_keys()
107 | print(f"[MeshDataset] Generated face_edges for {len(data_to_process)} entries")
108 |
109 | def generate_codes(self, autoencoder : MeshAutoencoder, batch_size = 25):
110 | total_batches = (len(self.data) + batch_size - 1) // batch_size
111 |
112 | for i in tqdm(range(0, len(self.data), batch_size), total=total_batches):
113 | batch_data = self.data[i:i+batch_size]
114 |
115 | padded_batch_vertices = pad_sequence([item['vertices'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
116 | padded_batch_faces = pad_sequence([item['faces'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
117 | padded_batch_face_edges = pad_sequence([item['face_edges'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
118 |
119 | batch_codes = autoencoder.tokenize(
120 | vertices=padded_batch_vertices,
121 | faces=padded_batch_faces,
122 | face_edges=padded_batch_face_edges
123 | )
124 |
125 |
126 | mask = (batch_codes != autoencoder.pad_id).all(dim=-1)
127 | for item_idx, (item_codes, item_mask) in enumerate(zip(batch_codes, mask)):
128 | item_codes_masked = item_codes[item_mask]
129 | item = batch_data[item_idx]
130 | item['codes'] = item_codes_masked
131 |
132 | self.sort_dataset_keys()
133 | print(f"[MeshDataset] Generated codes for {len(self.data)} entries")
134 |
135 | def embed_texts(self, transformer : MeshTransformer, batch_size = 50):
136 | unique_texts = list(set(item['texts'] for item in self.data if 'texts' in item))
137 | text_embedding_dict = {}
138 | for i in tqdm(range(0,len(unique_texts), batch_size)):
139 | batch_texts = unique_texts[i:i+batch_size]
140 | text_embeddings = transformer.embed_texts(batch_texts)
141 | mask = (text_embeddings != transformer.conditioner.text_embed_pad_value).all(dim=-1)
142 |
143 | for idx, text in enumerate(batch_texts):
144 | masked_embedding = text_embeddings[idx][mask[idx]]
145 | text_embedding_dict[text] = masked_embedding
146 |
147 | for item in self.data:
148 | if 'texts' in item:
149 | item['text_embeds'] = text_embedding_dict.get(item['texts'], None)
150 | del item['texts']
151 |
152 | self.sort_dataset_keys()
153 | print(f"[MeshDataset] Generated {len(text_embedding_dict)} text_embeddings")
154 |
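155 | # usage sketch (illustrative only): the typical preprocessing order implied by the methods
156 | # above; `autoencoder` and `transformer` are assumed to be already-constructed
157 | # MeshAutoencoder / MeshTransformer instances (the transformer built with condition_on_text = True)
158 | #
159 | # dataset = MeshDataset(data)          # data = list of dicts with 'vertices', 'faces', 'texts'
160 | # dataset.generate_face_edges()        # adds 'face_edges' to each entry
161 | # dataset.embed_texts(transformer)     # replaces 'texts' with 'text_embeds'
162 | # dataset.generate_codes(autoencoder)  # adds 'codes' tokenized by the autoencoder
163 | # dataset.save('mesh_dataset.npz')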
--------------------------------------------------------------------------------
/meshgpt_pytorch/mesh_render.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 |
4 | def orient_triangle_upward(v1, v2, v3):
5 | edge1 = v2 - v1
6 | edge2 = v3 - v1
7 | normal = np.cross(edge1, edge2)
8 | normal = normal / np.linalg.norm(normal)
9 |
10 | up = np.array([0, 1, 0])
11 | if np.dot(normal, up) < 0:
12 | v1, v3 = v3, v1
13 | return v1, v2, v3
14 |
15 | def get_angle(v1, v2, v3):
16 | v1, v2, v3 = orient_triangle_upward(v1, v2, v3)
17 | vec1 = v2 - v1
18 | vec2 = v3 - v1
19 | angle_rad = np.arccos(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
20 | return math.degrees(angle_rad)
21 |
22 | def combind_mesh_with_rows(path, meshes):
23 | all_vertices = []
24 | all_faces = []
25 | vertex_offset = 0
26 | translation_distance = 0.5
27 | obj_file_content = ""
28 |
29 | for row, mesh in enumerate(meshes):
30 | for r, faces_coordinates in enumerate(mesh):
31 | numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3)
32 | numpy_data[:, 0] += translation_distance * (r / 0.2 - 1)
33 | numpy_data[:, 2] += translation_distance * (row / 0.2 - 1)
34 |
35 | for vertex in numpy_data:
36 | all_vertices.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n")
37 |
38 | for i in range(1, len(numpy_data), 3):
39 | all_faces.append(f"f {i + vertex_offset} {i + 1 + vertex_offset} {i + 2 + vertex_offset}\n")
40 |
41 | vertex_offset += len(numpy_data)
42 |
43 | obj_file_content = "".join(all_vertices) + "".join(all_faces)
44 |
45 | with open(path, "w") as file:
46 | file.write(obj_file_content)
47 |
48 |
49 | def save_rendering(path, input_meshes):
50 | all_vertices,all_faces = [],[]
51 | vertex_offset = 0
52 | translation_distance = 0.5
53 | obj_file_content = ""
54 | meshes = input_meshes if isinstance(input_meshes, list) else [input_meshes]
55 |
56 | for row, mesh in enumerate(meshes):
57 | mesh = mesh if isinstance(mesh, list) else [mesh]
58 | cell_offset = 0
59 | for tensor, mask in mesh:
60 | for tensor_batch, mask_batch in zip(tensor,mask):
61 | numpy_data = tensor_batch[mask_batch].cpu().numpy().reshape(-1, 3)
62 | numpy_data[:, 0] += translation_distance * (cell_offset / 0.2 - 1)
63 | numpy_data[:, 2] += translation_distance * (row / 0.2 - 1)
64 | cell_offset += 1
65 | for vertex in numpy_data:
66 | all_vertices.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n")
67 |
68 | mesh_center = np.mean(numpy_data, axis=0)
69 | for i in range(0, len(numpy_data), 3):
70 | v1 = numpy_data[i]
71 | v2 = numpy_data[i + 1]
72 | v3 = numpy_data[i + 2]
73 |
74 | normal = np.cross(v2 - v1, v3 - v1)
75 | if get_angle(v1, v2, v3) > 60:
76 | direction_vector = mesh_center - np.mean([v1, v2, v3], axis=0)
77 | direction_vector = -direction_vector
78 | else:
79 | direction_vector = [0, 1, 0]
80 |
81 | if np.dot(normal, direction_vector) > 0:
82 | order = [0, 1, 2]
83 | else:
84 | order = [0, 2, 1]
85 |
86 | reordered_vertices = [v1, v2, v3][order[0]], [v1, v2, v3][order[1]], [v1, v2, v3][order[2]]
87 | indices = [np.where((numpy_data == vertex).all(axis=1))[0][0] + 1 + vertex_offset for vertex in reordered_vertices]
88 | all_faces.append(f"f {indices[0]} {indices[1]} {indices[2]}\n")
89 |
90 | vertex_offset += len(numpy_data)
91 | obj_file_content = "".join(all_vertices) + "".join(all_faces)
92 |
93 | with open(path, "w") as file:
94 | file.write(obj_file_content)
95 |
96 | print(f"[Save_rendering] Saved at {path}")
97 |
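98 | # usage sketch (illustrative only): `save_rendering` expects (face_coordinates, face_mask)
99 | # pairs such as those returned by MeshTransformer.generate; `transformer` below is a
100 | # hypothetical trained, text-conditioned model
101 | #
102 | # face_coords, face_mask = transformer.generate(texts = ['chair'])
103 | # save_rendering('./renders/chair.obj', (face_coords, face_mask))
104 | #
105 | # several generations can be laid out in a grid by passing a list of such pairs, one entry per row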
--------------------------------------------------------------------------------
/meshgpt_pytorch/meshgpt_pytorch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from functools import partial
5 | from math import ceil, pi, sqrt
6 |
7 | import torch
8 | from torch import nn, Tensor, einsum
9 | from torch.nn import Module, ModuleList
10 | import torch.nn.functional as F
11 | from torch.utils.checkpoint import checkpoint
12 | from torch.cuda.amp import autocast
13 |
14 | from pytorch_custom_utils import save_load
15 |
16 | from beartype.typing import Tuple, Callable, List, Dict, Any
17 | from meshgpt_pytorch.typing import Float, Int, Bool, typecheck
18 |
19 | from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
20 |
21 | from einops import rearrange, repeat, reduce, pack, unpack
22 | from einops.layers.torch import Rearrange
23 |
24 | from einx import get_at
25 |
26 | from x_transformers import Decoder
27 | from x_transformers.x_transformers import RMSNorm, FeedForward, LayerIntermediates
28 |
29 | from x_transformers.autoregressive_wrapper import (
30 | eval_decorator,
31 | top_k,
32 | top_p,
33 | )
34 |
35 | from local_attention import LocalMHA
36 |
37 | from vector_quantize_pytorch import (
38 | ResidualVQ,
39 | ResidualLFQ
40 | )
41 |
42 | from meshgpt_pytorch.data import derive_face_edges_from_faces
43 | from meshgpt_pytorch.version import __version__
44 |
45 | from taylor_series_linear_attention import TaylorSeriesLinearAttn
46 |
47 | from classifier_free_guidance_pytorch import (
48 | classifier_free_guidance,
49 | TextEmbeddingReturner
50 | )
51 |
52 | from torch_geometric.nn.conv import SAGEConv
53 |
54 | from gateloop_transformer import SimpleGateLoopLayer
55 |
56 | from tqdm import tqdm
57 |
58 | # helper functions
59 |
60 | def exists(v):
61 | return v is not None
62 |
63 | def default(v, d):
64 | return v if exists(v) else d
65 |
66 | def first(it):
67 | return it[0]
68 |
69 | def identity(t, *args, **kwargs):
70 | return t
71 |
72 | def divisible_by(num, den):
73 | return (num % den) == 0
74 |
75 | def is_odd(n):
76 | return not divisible_by(n, 2)
77 |
78 | def is_empty(x):
79 | return len(x) == 0
80 |
81 | def is_tensor_empty(t: Tensor):
82 | return t.numel() == 0
83 |
84 | def set_module_requires_grad_(
85 | module: Module,
86 | requires_grad: bool
87 | ):
88 | for param in module.parameters():
89 | param.requires_grad = requires_grad
90 |
91 | def l1norm(t):
92 | return F.normalize(t, dim = -1, p = 1)
93 |
94 | def l2norm(t):
95 | return F.normalize(t, dim = -1, p = 2)
96 |
97 | def safe_cat(tensors, dim):
98 | tensors = [*filter(exists, tensors)]
99 |
100 | if len(tensors) == 0:
101 | return None
102 | elif len(tensors) == 1:
103 | return first(tensors)
104 |
105 | return torch.cat(tensors, dim = dim)
106 |
107 | def pad_at_dim(t, padding, dim = -1, value = 0):
108 | ndim = t.ndim
109 | right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
110 | zeros = (0, 0) * right_dims
111 | return F.pad(t, (*zeros, *padding), value = value)
112 |
113 | def pad_to_length(t, length, dim = -1, value = 0, right = True):
114 | curr_length = t.shape[dim]
115 | remainder = length - curr_length
116 |
117 | if remainder <= 0:
118 | return t
119 |
120 | padding = (0, remainder) if right else (remainder, 0)
121 | return pad_at_dim(t, padding, dim = dim, value = value)
122 |
123 | def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
124 | if not exists(mask):
125 | return tensor.mean(dim = dim)
126 |
127 | mask = rearrange(mask, '... -> ... 1')
128 | tensor = tensor.masked_fill(~mask, 0.)
129 |
130 | total_el = mask.sum(dim = dim)
131 | num = tensor.sum(dim = dim)
132 | den = total_el.float().clamp(min = eps)
133 | mean = num / den
134 | mean = mean.masked_fill(total_el == 0, 0.)
135 | return mean
136 |
137 | # continuous embed
138 |
139 | def ContinuousEmbed(dim_cont):
140 | return nn.Sequential(
141 | Rearrange('... -> ... 1'),
142 | nn.Linear(1, dim_cont),
143 | nn.SiLU(),
144 | nn.Linear(dim_cont, dim_cont),
145 | nn.LayerNorm(dim_cont)
146 | )
147 |
148 | # additional encoder features
149 | # 1. angle (3), 2. area (1), 3. normals (3)
150 |
151 | def derive_angle(x, y, eps = 1e-5):
152 | z = einsum('... d, ... d -> ...', l2norm(x), l2norm(y))
153 | return z.clip(-1 + eps, 1 - eps).arccos()
154 |
155 | @torch.no_grad()
156 | @typecheck
157 | def get_derived_face_features(
158 | face_coords: Float['b nf nvf 3'] # 3 or 4 vertices with 3 coordinates
159 | ):
160 | is_quad = face_coords.shape[-2] == 4
161 |
162 | # shift face coordinates depending on triangles or quads
163 |
164 | shifted_face_coords = torch.roll(face_coords, 1, dims = (2,))
165 |
166 | angles = derive_angle(face_coords, shifted_face_coords)
167 |
168 | if is_quad:
169 | # @sbriseid says quads need to be shifted by 2
170 | shifted_face_coords = torch.roll(shifted_face_coords, 1, dims = (2,))
171 |
172 | edge1, edge2, *_ = (face_coords - shifted_face_coords).unbind(dim = 2)
173 |
174 | cross_product = torch.cross(edge1, edge2, dim = -1)
175 |
176 | normals = l2norm(cross_product)
177 | area = cross_product.norm(dim = -1, keepdim = True) * 0.5
178 |
179 | return dict(
180 | angles = angles,
181 | area = area,
182 | normals = normals
183 | )
184 |
185 | # tensor helper functions
186 |
187 | @typecheck
188 | def discretize(
189 | t: Tensor,
190 | *,
191 | continuous_range: Tuple[float, float],
192 | num_discrete: int = 128
193 | ) -> Tensor:
194 | lo, hi = continuous_range
195 | assert hi > lo
196 |
197 | t = (t - lo) / (hi - lo)
198 | t *= num_discrete
199 | t -= 0.5
200 |
201 | return t.round().long().clamp(min = 0, max = num_discrete - 1)
202 |
203 | @typecheck
204 | def undiscretize(
205 | t: Tensor,
206 | *,
207 | continuous_range: Tuple[float, float],
208 | num_discrete: int = 128
209 | ) -> Tensor:
210 | lo, hi = continuous_range
211 | assert hi > lo
212 |
213 | t = t.float()
214 |
215 | t += 0.5
216 | t /= num_discrete
217 | return t * (hi - lo) + lo
218 |
219 | @typecheck
220 | def gaussian_blur_1d(
221 | t: Tensor,
222 | *,
223 | sigma: float = 1.
224 | ) -> Tensor:
225 |
226 | _, _, channels, device, dtype = *t.shape, t.device, t.dtype
227 |
228 | width = int(ceil(sigma * 5))
229 | width += (width + 1) % 2
230 | half_width = width // 2
231 |
232 | distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)
233 |
234 | gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
235 | gaussian = l1norm(gaussian)
236 |
237 | kernel = repeat(gaussian, 'n -> c 1 n', c = channels)
238 |
239 | t = rearrange(t, 'b n c -> b c n')
240 | out = F.conv1d(t, kernel, padding = half_width, groups = channels)
241 | return rearrange(out, 'b c n -> b n c')
242 |
243 | @typecheck
244 | def scatter_mean(
245 | tgt: Tensor,
246 | indices: Tensor,
247 | src: Tensor,
248 | *,
249 | dim: int = -1,
250 | eps: float = 1e-5
251 | ):
252 | """
253 | todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
254 | """
255 | num = tgt.scatter_add(dim, indices, src)
256 | den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
257 | return num / den.clamp(min = eps)
258 |
259 | # resnet block
260 |
261 | class FiLM(Module):
262 | def __init__(self, dim, dim_out = None):
263 | super().__init__()
264 | dim_out = default(dim_out, dim)
265 |
266 | self.to_gamma = nn.Linear(dim, dim_out, bias = False)
267 | self.to_beta = nn.Linear(dim, dim_out)
268 |
269 | self.gamma_mult = nn.Parameter(torch.zeros(1,))
270 | self.beta_mult = nn.Parameter(torch.zeros(1,))
271 |
272 | def forward(self, x, cond):
273 | gamma, beta = self.to_gamma(cond), self.to_beta(cond)
274 | gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta))
275 |
276 | # for initializing to identity
277 |
278 | gamma = (1 + self.gamma_mult * gamma.tanh())
279 | beta = beta.tanh() * self.beta_mult
280 |
281 | # classic film
282 |
283 | return x * gamma + beta
284 |
285 | class PixelNorm(Module):
286 | def __init__(self, dim, eps = 1e-4):
287 | super().__init__()
288 | self.dim = dim
289 | self.eps = eps
290 |
291 | def forward(self, x):
292 | dim = self.dim
293 | return F.normalize(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])
294 |
295 | class SqueezeExcite(Module):
296 | def __init__(
297 | self,
298 | dim,
299 | reduction_factor = 4,
300 | min_dim = 16
301 | ):
302 | super().__init__()
303 | dim_inner = max(dim // reduction_factor, min_dim)
304 |
305 | self.net = nn.Sequential(
306 | nn.Linear(dim, dim_inner),
307 | nn.SiLU(),
308 | nn.Linear(dim_inner, dim),
309 | nn.Sigmoid(),
310 | Rearrange('b c -> b c 1')
311 | )
312 |
313 | def forward(self, x, mask = None):
314 | if exists(mask):
315 | x = x.masked_fill(~mask, 0.)
316 |
317 | num = reduce(x, 'b c n -> b c', 'sum')
318 | den = reduce(mask.float(), 'b 1 n -> b 1', 'sum')
319 | avg = num / den.clamp(min = 1e-5)
320 | else:
321 | avg = reduce(x, 'b c n -> b c', 'mean')
322 |
323 | return x * self.net(avg)
324 |
325 | class Block(Module):
326 | def __init__(
327 | self,
328 | dim,
329 | dim_out = None,
330 | dropout = 0.
331 | ):
332 | super().__init__()
333 | dim_out = default(dim_out, dim)
334 |
335 | self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1)
336 | self.norm = PixelNorm(dim = 1)
337 | self.dropout = nn.Dropout(dropout)
338 | self.act = nn.SiLU()
339 |
340 | def forward(self, x, mask = None):
341 | if exists(mask):
342 | x = x.masked_fill(~mask, 0.)
343 |
344 | x = self.proj(x)
345 |
346 | if exists(mask):
347 | x = x.masked_fill(~mask, 0.)
348 |
349 | x = self.norm(x)
350 | x = self.act(x)
351 | x = self.dropout(x)
352 |
353 | return x
354 |
355 | class ResnetBlock(Module):
356 | def __init__(
357 | self,
358 | dim,
359 | dim_out = None,
360 | *,
361 | dropout = 0.
362 | ):
363 | super().__init__()
364 | dim_out = default(dim_out, dim)
365 | self.block1 = Block(dim, dim_out, dropout = dropout)
366 | self.block2 = Block(dim_out, dim_out, dropout = dropout)
367 | self.excite = SqueezeExcite(dim_out)
368 | self.residual_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
369 |
370 | def forward(
371 | self,
372 | x,
373 | mask = None
374 | ):
375 | res = self.residual_conv(x)
376 | h = self.block1(x, mask = mask)
377 | h = self.block2(h, mask = mask)
378 | h = self.excite(h, mask = mask)
379 | return h + res
380 |
381 | # gateloop layers
382 |
383 | class GateLoopBlock(Module):
384 | def __init__(
385 | self,
386 | dim,
387 | *,
388 | depth,
389 | use_heinsen = True
390 | ):
391 | super().__init__()
392 | self.gateloops = ModuleList([])
393 |
394 | for _ in range(depth):
395 | gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
396 | self.gateloops.append(gateloop)
397 |
398 | def forward(
399 | self,
400 | x,
401 | cache = None
402 | ):
403 | received_cache = exists(cache)
404 |
405 | if is_tensor_empty(x):
406 | return x, None
407 |
408 | if received_cache:
409 | prev, x = x[:, :-1], x[:, -1:]
410 |
411 | cache = default(cache, [])
412 | cache = iter(cache)
413 |
414 | new_caches = []
415 | for gateloop in self.gateloops:
416 | layer_cache = next(cache, None)
417 | out, new_cache = gateloop(x, cache = layer_cache, return_cache = True)
418 | new_caches.append(new_cache)
419 | x = x + out
420 |
421 | if received_cache:
422 | x = torch.cat((prev, x), dim = -2)
423 |
424 | return x, new_caches
425 |
426 | # main classes
427 |
428 | @save_load(version = __version__)
429 | class MeshAutoencoder(Module):
430 | @typecheck
431 | def __init__(
432 | self,
433 | num_discrete_coors = 128,
434 | coor_continuous_range: Tuple[float, float] = (-1., 1.),
435 | dim_coor_embed = 64,
436 | num_discrete_area = 128,
437 | dim_area_embed = 16,
438 | num_discrete_normals = 128,
439 | dim_normal_embed = 64,
440 | num_discrete_angle = 128,
441 | dim_angle_embed = 16,
442 | encoder_dims_through_depth: Tuple[int, ...] = (
443 | 64, 128, 256, 256, 576
444 | ),
445 | init_decoder_conv_kernel = 7,
446 | decoder_dims_through_depth: Tuple[int, ...] = (
447 | 128, 128, 128, 128,
448 | 192, 192, 192, 192,
449 | 256, 256, 256, 256, 256, 256,
450 | 384, 384, 384
451 | ),
452 | dim_codebook = 192,
453 | num_quantizers = 2, # or 'D' in the paper
454 | codebook_size = 16384, # they use 16k, shared codebook between layers
455 | use_residual_lfq = True, # whether to use the latest lookup-free quantization
456 | rq_kwargs: dict = dict(
457 | quantize_dropout = True,
458 | quantize_dropout_cutoff_index = 1,
459 | quantize_dropout_multiple_of = 1,
460 | ),
461 | rvq_kwargs: dict = dict(
462 | kmeans_init = True,
463 | threshold_ema_dead_code = 2,
464 | ),
465 | rlfq_kwargs: dict = dict(
466 | frac_per_sample_entropy = 1.,
467 | soft_clamp_input_value = 10.
468 | ),
469 | rvq_stochastic_sample_codes = True,
470 | sageconv_kwargs: dict = dict(
471 | normalize = True,
472 | project = True
473 | ),
474 | commit_loss_weight = 0.1,
475 | bin_smooth_blur_sigma = 0.4, # they blur the one hot discretized coordinate positions
476 | attn_encoder_depth = 0,
477 | attn_decoder_depth = 0,
478 | local_attn_kwargs: dict = dict(
479 | dim_head = 32,
480 | heads = 8
481 | ),
482 | local_attn_window_size = 64,
483 | linear_attn_kwargs: dict = dict(
484 | dim_head = 8,
485 | heads = 16
486 | ),
487 | use_linear_attn = True,
488 | pad_id = -1,
489 | flash_attn = True,
490 | attn_dropout = 0.,
491 | ff_dropout = 0.,
492 | resnet_dropout = 0,
493 | checkpoint_quantizer = False,
494 | quads = False
495 | ):
496 | super().__init__()
497 |
498 | self.num_vertices_per_face = 3 if not quads else 4
499 | total_coordinates_per_face = self.num_vertices_per_face * 3
500 |
501 | # main face coordinate embedding
502 |
503 | self.num_discrete_coors = num_discrete_coors
504 | self.coor_continuous_range = coor_continuous_range
505 |
506 | self.discretize_face_coords = partial(discretize, num_discrete = num_discrete_coors, continuous_range = coor_continuous_range)
507 | self.coor_embed = nn.Embedding(num_discrete_coors, dim_coor_embed)
508 |
509 | # derived feature embedding
510 |
511 | self.discretize_angle = partial(discretize, num_discrete = num_discrete_angle, continuous_range = (0., pi))
512 | self.angle_embed = nn.Embedding(num_discrete_angle, dim_angle_embed)
513 |
514 | lo, hi = coor_continuous_range
515 | self.discretize_area = partial(discretize, num_discrete = num_discrete_area, continuous_range = (0., (hi - lo) ** 2))
516 | self.area_embed = nn.Embedding(num_discrete_area, dim_area_embed)
517 |
518 | self.discretize_normals = partial(discretize, num_discrete = num_discrete_normals, continuous_range = coor_continuous_range)
519 | self.normal_embed = nn.Embedding(num_discrete_normals, dim_normal_embed)
520 |
521 | # attention related
522 |
523 | attn_kwargs = dict(
524 | causal = False,
525 | prenorm = True,
526 | dropout = attn_dropout,
527 | window_size = local_attn_window_size,
528 | )
529 |
530 | # initial dimension
531 |
532 | init_dim = dim_coor_embed * (3 * self.num_vertices_per_face) + dim_angle_embed * self.num_vertices_per_face + dim_normal_embed * 3 + dim_area_embed
533 |
534 | # project into model dimension
535 |
536 | self.project_in = nn.Linear(init_dim, dim_codebook)
537 |
538 | # initial sage conv
539 | init_encoder_dim, *encoder_dims_through_depth = encoder_dims_through_depth
540 | curr_dim = init_encoder_dim
541 |
542 | self.init_sage_conv = SAGEConv(dim_codebook, init_encoder_dim, **sageconv_kwargs)
543 |
544 | self.init_encoder_act_and_norm = nn.Sequential(
545 | nn.SiLU(),
546 | nn.LayerNorm(init_encoder_dim)
547 | )
548 |
549 | self.encoders = ModuleList([])
550 |
551 | for dim_layer in encoder_dims_through_depth:
552 | sage_conv = SAGEConv(
553 | curr_dim,
554 | dim_layer,
555 | **sageconv_kwargs
556 | )
557 |
558 | self.encoders.append(sage_conv)
559 | curr_dim = dim_layer
560 |
561 | self.encoder_attn_blocks = ModuleList([])
562 |
563 | for _ in range(attn_encoder_depth):
564 | self.encoder_attn_blocks.append(nn.ModuleList([
565 | TaylorSeriesLinearAttn(curr_dim, prenorm = True, **linear_attn_kwargs) if use_linear_attn else None,
566 | LocalMHA(dim = curr_dim, **attn_kwargs, **local_attn_kwargs),
567 | nn.Sequential(RMSNorm(curr_dim), FeedForward(curr_dim, glu = True, dropout = ff_dropout))
568 | ]))
569 |
570 | # residual quantization
571 |
572 | self.codebook_size = codebook_size
573 | self.num_quantizers = num_quantizers
574 |
575 | self.project_dim_codebook = nn.Linear(curr_dim, dim_codebook * self.num_vertices_per_face)
576 |
577 | if use_residual_lfq:
578 | self.quantizer = ResidualLFQ(
579 | dim = dim_codebook,
580 | num_quantizers = num_quantizers,
581 | codebook_size = codebook_size,
582 | commitment_loss_weight = 1.,
583 | **rlfq_kwargs,
584 | **rq_kwargs
585 | )
586 | else:
587 | self.quantizer = ResidualVQ(
588 | dim = dim_codebook,
589 | num_quantizers = num_quantizers,
590 | codebook_size = codebook_size,
591 | shared_codebook = True,
592 | commitment_weight = 1.,
593 | stochastic_sample_codes = rvq_stochastic_sample_codes,
594 | **rvq_kwargs,
595 | **rq_kwargs
596 | )
597 |
598 | self.checkpoint_quantizer = checkpoint_quantizer # whether to memory checkpoint the quantizer
599 |
600 | self.pad_id = pad_id # for variable lengthed faces, padding quantized ids will be set to this value
601 |
602 | # decoder
603 |
604 | decoder_input_dim = dim_codebook * 3
605 |
606 | self.decoder_attn_blocks = ModuleList([])
607 |
608 | for _ in range(attn_decoder_depth):
609 | self.decoder_attn_blocks.append(nn.ModuleList([
610 | TaylorSeriesLinearAttn(decoder_input_dim, prenorm = True, **linear_attn_kwargs) if use_linear_attn else None,
611 | LocalMHA(dim = decoder_input_dim, **attn_kwargs, **local_attn_kwargs),
612 | nn.Sequential(RMSNorm(decoder_input_dim), FeedForward(decoder_input_dim, glu = True, dropout = ff_dropout))
613 | ]))
614 |
615 | init_decoder_dim, *decoder_dims_through_depth = decoder_dims_through_depth
616 | curr_dim = init_decoder_dim
617 |
618 | assert is_odd(init_decoder_conv_kernel)
619 |
620 | self.init_decoder_conv = nn.Sequential(
621 | nn.Conv1d(dim_codebook * self.num_vertices_per_face, init_decoder_dim, kernel_size = init_decoder_conv_kernel, padding = init_decoder_conv_kernel // 2),
622 | nn.SiLU(),
623 | Rearrange('b c n -> b n c'),
624 | nn.LayerNorm(init_decoder_dim),
625 | Rearrange('b n c -> b c n')
626 | )
627 |
628 | self.decoders = ModuleList([])
629 |
630 | for dim_layer in decoder_dims_through_depth:
631 | resnet_block = ResnetBlock(curr_dim, dim_layer, dropout = resnet_dropout)
632 |
633 | self.decoders.append(resnet_block)
634 | curr_dim = dim_layer
635 |
636 | self.to_coor_logits = nn.Sequential(
637 | nn.Linear(curr_dim, num_discrete_coors * total_coordinates_per_face),
638 | Rearrange('... (v c) -> ... v c', v = total_coordinates_per_face)
639 | )
640 |
641 | # loss related
642 |
643 | self.commit_loss_weight = commit_loss_weight
644 | self.bin_smooth_blur_sigma = bin_smooth_blur_sigma
645 |
646 | @property
647 | def device(self):
648 | return next(self.parameters()).device
649 |
650 | @classmethod
651 | def _from_pretrained(
652 | cls,
653 | *,
654 | model_id: str,
655 | revision: str | None,
656 | cache_dir: str | Path | None,
657 | force_download: bool,
658 | proxies: Dict | None,
659 | resume_download: bool,
660 | local_files_only: bool,
661 | token: str | bool | None,
662 | map_location: str = "cpu",
663 | strict: bool = False,
664 | **model_kwargs,
665 | ):
666 | model_filename = "mesh-autoencoder.bin"
667 | model_file = Path(model_id) / model_filename
668 | if not model_file.exists():
669 | model_file = hf_hub_download(
670 | repo_id=model_id,
671 | filename=model_filename,
672 | revision=revision,
673 | cache_dir=cache_dir,
674 | force_download=force_download,
675 | proxies=proxies,
676 | resume_download=resume_download,
677 | token=token,
678 | local_files_only=local_files_only,
679 | )
680 | model = cls.init_and_load(model_file,strict=strict)
681 | model.to(map_location)
682 | return model
683 |
684 | @typecheck
685 | def encode(
686 | self,
687 | *,
688 | vertices: Float['b nv 3'],
689 | faces: Int['b nf nvf'],
690 | face_edges: Int['b e 2'],
691 | face_mask: Bool['b nf'],
692 | face_edges_mask: Bool['b e'],
693 | return_face_coordinates = False
694 | ):
695 | """
696 | einops:
697 | b - batch
698 | nf - number of faces
699 | nv - number of vertices (3)
700 | nvf - number of vertices per face (3 or 4) - triangles vs quads
701 | c - coordinates (3)
702 | d - embed dim
703 | """
704 |
705 | _, num_faces, num_vertices_per_face = faces.shape
706 |
707 | assert self.num_vertices_per_face == num_vertices_per_face
708 |
709 | face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0)
710 |
711 | # continuous face coords
712 |
713 | face_coords = get_at('b [nv] c, b nf mv -> b nf mv c', vertices, face_without_pad)
714 |
715 | # compute derived features and embed
716 |
717 | derived_features = get_derived_face_features(face_coords)
718 |
719 | discrete_angle = self.discretize_angle(derived_features['angles'])
720 | angle_embed = self.angle_embed(discrete_angle)
721 |
722 | discrete_area = self.discretize_area(derived_features['area'])
723 | area_embed = self.area_embed(discrete_area)
724 |
725 | discrete_normal = self.discretize_normals(derived_features['normals'])
726 | normal_embed = self.normal_embed(discrete_normal)
727 |
728 | # discretize vertices for face coordinate embedding
729 |
730 | discrete_face_coords = self.discretize_face_coords(face_coords)
731 | discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)') # 9 or 12 coordinates per face
732 |
733 | face_coor_embed = self.coor_embed(discrete_face_coords)
734 | face_coor_embed = rearrange(face_coor_embed, 'b nf c d -> b nf (c d)')
735 |
736 | # combine all features and project into model dimension
737 |
738 | face_embed, _ = pack([face_coor_embed, angle_embed, area_embed, normal_embed], 'b nf *')
739 | face_embed = self.project_in(face_embed)
740 |
741 | # handle variable lengths by using masked_select and masked_scatter
742 |
743 | # first handle edges
744 | # needs to be offset by number of faces for each batch
745 |
746 | face_index_offsets = reduce(face_mask.long(), 'b nf -> b', 'sum')
747 | face_index_offsets = F.pad(face_index_offsets.cumsum(dim = 0), (1, -1), value = 0)
748 | face_index_offsets = rearrange(face_index_offsets, 'b -> b 1 1')
749 |
750 | face_edges = face_edges + face_index_offsets
751 | face_edges = face_edges[face_edges_mask]
752 | face_edges = rearrange(face_edges, 'be ij -> ij be')
753 |
754 | # next prepare the face_mask for using masked_select and masked_scatter
755 |
756 | orig_face_embed_shape = face_embed.shape[:2]
757 |
758 | face_embed = face_embed[face_mask]
759 |
760 | # initial sage conv followed by activation and norm
761 |
762 | face_embed = self.init_sage_conv(face_embed, face_edges)
763 |
764 | face_embed = self.init_encoder_act_and_norm(face_embed)
765 |
766 | for conv in self.encoders:
767 | face_embed = conv(face_embed, face_edges)
768 |
769 | shape = (*orig_face_embed_shape, face_embed.shape[-1])
770 |
771 | face_embed = face_embed.new_zeros(shape).masked_scatter(rearrange(face_mask, '... -> ... 1'), face_embed)
772 |
773 | for linear_attn, attn, ff in self.encoder_attn_blocks:
774 | if exists(linear_attn):
775 | face_embed = linear_attn(face_embed, mask = face_mask) + face_embed
776 |
777 | face_embed = attn(face_embed, mask = face_mask) + face_embed
778 | face_embed = ff(face_embed) + face_embed
779 |
780 | if not return_face_coordinates:
781 | return face_embed
782 |
783 | return face_embed, discrete_face_coords
784 |
785 | @typecheck
786 | def quantize(
787 | self,
788 | *,
789 | faces: Int['b nf nvf'],
790 | face_mask: Bool['b n'],
791 | face_embed: Float['b nf d'],
792 | pad_id = None,
793 | rvq_sample_codebook_temp = 1.
794 | ):
795 | pad_id = default(pad_id, self.pad_id)
796 | batch, device = faces.shape[0], faces.device
797 |
798 | max_vertex_index = faces.amax()
799 | num_vertices = int(max_vertex_index.item() + 1)
800 |
801 | face_embed = self.project_dim_codebook(face_embed)
802 | face_embed = rearrange(face_embed, 'b nf (nvf d) -> b nf nvf d', nvf = self.num_vertices_per_face)
803 |
804 | vertex_dim = face_embed.shape[-1]
805 | vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device)
806 |
807 | # create pad vertex, due to variable lengthed faces
808 |
809 | pad_vertex_id = num_vertices
810 | vertices = pad_at_dim(vertices, (0, 1), dim = -2, value = 0.)
811 |
812 | faces = faces.masked_fill(~rearrange(face_mask, 'b n -> b n 1'), pad_vertex_id)
813 |
814 | # prepare for scatter mean
815 |
816 | faces_with_dim = repeat(faces, 'b nf nvf -> b (nf nvf) d', d = vertex_dim)
817 |
818 | face_embed = rearrange(face_embed, 'b ... d -> b (...) d')
819 |
820 | # scatter mean
821 |
822 | averaged_vertices = scatter_mean(vertices, faces_with_dim, face_embed, dim = -2)
823 |
824 | # mask out null vertex token
825 |
826 | mask = torch.ones((batch, num_vertices + 1), device = device, dtype = torch.bool)
827 | mask[:, -1] = False
828 |
829 | # rvq specific kwargs
830 |
831 | quantize_kwargs = dict(mask = mask)
832 |
833 | if isinstance(self.quantizer, ResidualVQ):
834 | quantize_kwargs.update(sample_codebook_temp = rvq_sample_codebook_temp)
835 |
836 | # a quantize function that makes it memory checkpointable
837 |
838 | def quantize_wrapper_fn(inp):
839 | unquantized, quantize_kwargs = inp
840 | return self.quantizer(unquantized, **quantize_kwargs)
841 |
842 | # maybe checkpoint the quantize fn
843 |
844 | if self.checkpoint_quantizer:
845 | quantize_wrapper_fn = partial(checkpoint, quantize_wrapper_fn, use_reentrant = False)
846 |
847 | # residual VQ
848 |
849 | quantized, codes, commit_loss = quantize_wrapper_fn((averaged_vertices, quantize_kwargs))
850 |
851 | # gather quantized vertexes back to faces for decoding
852 | # now the faces have quantized vertices
853 |
854 | face_embed_output = get_at('b [n] d, b nf nvf -> b nf (nvf d)', quantized, faces)
855 |
856 | # vertex codes also need to be gathered to be organized by face sequence
857 | # for autoregressive learning
858 |
859 | codes_output = get_at('b [n] q, b nf nvf -> b (nf nvf) q', codes, faces)
860 |
861 | # make sure codes being outputted have this padding
862 |
863 | face_mask = repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face)
864 | codes_output = codes_output.masked_fill(~face_mask, self.pad_id)
865 |
866 | # output quantized, codes, as well as commitment loss
867 |
868 | return face_embed_output, codes_output, commit_loss
869 |
870 | @typecheck
871 | def decode(
872 | self,
873 | quantized: Float['b n d'],
874 | face_mask: Bool['b n']
875 | ):
876 | conv_face_mask = rearrange(face_mask, 'b n -> b 1 n')
877 |
878 | x = quantized
879 |
880 | for linear_attn, attn, ff in self.decoder_attn_blocks:
881 | if exists(linear_attn):
882 | x = linear_attn(x, mask = face_mask) + x
883 |
884 | x = attn(x, mask = face_mask) + x
885 | x = ff(x) + x
886 |
887 | x = rearrange(x, 'b n d -> b d n')
888 | x = x.masked_fill(~conv_face_mask, 0.)
889 | x = self.init_decoder_conv(x)
890 |
891 | for resnet_block in self.decoders:
892 | x = resnet_block(x, mask = conv_face_mask)
893 |
894 | return rearrange(x, 'b d n -> b n d')
895 |
896 | @typecheck
897 | @torch.no_grad()
898 | def decode_from_codes_to_faces(
899 | self,
900 | codes: Tensor,
901 | face_mask: Bool['b n'] | None = None,
902 | return_discrete_codes = False
903 | ):
904 | codes = rearrange(codes, 'b ... -> b (...)')
905 |
906 | if not exists(face_mask):
907 | face_mask = reduce(codes != self.pad_id, 'b (nf nvf q) -> b nf', 'all', nvf = self.num_vertices_per_face, q = self.num_quantizers)
908 |
909 | # handle different code shapes
910 |
911 | codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
912 |
913 | # decode
914 |
915 | quantized = self.quantizer.get_output_from_indices(codes)
916 | quantized = rearrange(quantized, 'b (nf nvf) d -> b nf (nvf d)', nvf = self.num_vertices_per_face)
917 |
918 | decoded = self.decode(
919 | quantized,
920 | face_mask = face_mask
921 | )
922 |
923 | decoded = decoded.masked_fill(~face_mask[..., None], 0.)
924 | pred_face_coords = self.to_coor_logits(decoded)
925 |
926 | pred_face_coords = pred_face_coords.argmax(dim = -1)
927 |
928 | pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = self.num_vertices_per_face)
929 |
930 | # back to continuous space
931 |
932 | continuous_coors = undiscretize(
933 | pred_face_coords,
934 | num_discrete = self.num_discrete_coors,
935 | continuous_range = self.coor_continuous_range
936 | )
937 |
938 | # mask out with nan
939 |
940 | continuous_coors = continuous_coors.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1 1'), float('nan'))
941 |
942 | if not return_discrete_codes:
943 | return continuous_coors, face_mask
944 |
945 | return continuous_coors, pred_face_coords, face_mask
946 |
947 | @torch.no_grad()
948 | def tokenize(self, vertices, faces, face_edges = None, **kwargs):
949 | assert 'return_codes' not in kwargs
950 |
951 | inputs = [vertices, faces, face_edges]
952 | inputs = [*filter(exists, inputs)]
953 | ndims = {i.ndim for i in inputs}
954 |
955 | assert len(ndims) == 1
956 | batch_less = first(list(ndims)) == 2
957 |
958 | if batch_less:
959 | inputs = [rearrange(i, '... -> 1 ...') for i in inputs]
960 |
961 | input_kwargs = dict(zip(['vertices', 'faces', 'face_edges'], inputs))
962 |
963 | self.eval()
964 |
965 | codes = self.forward(
966 | **input_kwargs,
967 | return_codes = True,
968 | **kwargs
969 | )
970 |
971 | if batch_less:
972 | codes = rearrange(codes, '1 ... -> ...')
973 |
974 | return codes
975 |
976 | @typecheck
977 | def forward(
978 | self,
979 | *,
980 | vertices: Float['b nv 3'],
981 | faces: Int['b nf nvf'],
982 | face_edges: Int['b e 2'] | None = None,
983 | return_codes = False,
984 | return_loss_breakdown = False,
985 | return_recon_faces = False,
986 | only_return_recon_faces = False,
987 | rvq_sample_codebook_temp = 1.
988 | ):
989 | if not exists(face_edges):
990 | face_edges = derive_face_edges_from_faces(faces, pad_id = self.pad_id)
991 |
992 | device = faces.device
993 |
994 | face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all')
995 | face_edges_mask = reduce(face_edges != self.pad_id, 'b e ij -> b e', 'all')
996 |
997 | encoded, face_coordinates = self.encode(
998 | vertices = vertices,
999 | faces = faces,
1000 | face_edges = face_edges,
1001 | face_edges_mask = face_edges_mask,
1002 | face_mask = face_mask,
1003 | return_face_coordinates = True
1004 | )
1005 |
1006 | quantized, codes, commit_loss = self.quantize(
1007 | face_embed = encoded,
1008 | faces = faces,
1009 | face_mask = face_mask,
1010 | rvq_sample_codebook_temp = rvq_sample_codebook_temp
1011 | )
1012 |
1013 | if return_codes:
1014 | assert not return_recon_faces, 'cannot return reconstructed faces when just returning raw codes'
1015 |
1016 | codes = codes.masked_fill(~repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face), self.pad_id)
1017 | return codes
1018 |
1019 | decode = self.decode(
1020 | quantized,
1021 | face_mask = face_mask
1022 | )
1023 |
1024 | pred_face_coords = self.to_coor_logits(decode)
1025 |
1026 | # compute reconstructed faces if needed
1027 |
1028 | if return_recon_faces or only_return_recon_faces:
1029 |
1030 | recon_faces = undiscretize(
1031 | pred_face_coords.argmax(dim = -1),
1032 | num_discrete = self.num_discrete_coors,
1033 | continuous_range = self.coor_continuous_range,
1034 | )
1035 |
1036 | recon_faces = rearrange(recon_faces, 'b nf (nvf c) -> b nf nvf c', nvf = self.num_vertices_per_face)
1037 | face_mask = rearrange(face_mask, 'b nf -> b nf 1 1')
1038 | recon_faces = recon_faces.masked_fill(~face_mask, float('nan'))
1039 | face_mask = rearrange(face_mask, 'b nf 1 1 -> b nf')
1040 |
1041 | if only_return_recon_faces:
1042 | return recon_faces
1043 |
1044 | # prepare for recon loss
1045 |
1046 | pred_face_coords = rearrange(pred_face_coords, 'b ... c -> b c (...)')
1047 | face_coordinates = rearrange(face_coordinates, 'b ... -> b 1 (...)')
1048 |
1049 | # reconstruction loss on discretized coordinates on each face
1050 | # they also smooth (blur) the one hot positions, localized label smoothing basically
1051 |
1052 | with autocast(enabled = False):
1053 | pred_log_prob = pred_face_coords.log_softmax(dim = 1)
1054 |
1055 | target_one_hot = torch.zeros_like(pred_log_prob).scatter(1, face_coordinates, 1.)
1056 |
1057 | if self.bin_smooth_blur_sigma >= 0.:
1058 | target_one_hot = gaussian_blur_1d(target_one_hot, sigma = self.bin_smooth_blur_sigma)
1059 |
1060 | # cross entropy with localized smoothing
1061 |
1062 | recon_losses = (-target_one_hot * pred_log_prob).sum(dim = 1)
1063 |
1064 | face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = self.num_vertices_per_face * 3)
1065 | recon_loss = recon_losses[face_mask].mean()
1066 |
1067 | # calculate total loss
1068 |
1069 | total_loss = recon_loss + \
1070 | commit_loss.sum() * self.commit_loss_weight
1071 |
1072 | # calculate loss breakdown if needed
1073 |
1074 | loss_breakdown = (recon_loss, commit_loss)
1075 |
1076 | # some return logic
1077 |
1078 | if not return_loss_breakdown:
1079 | if not return_recon_faces:
1080 | return total_loss
1081 |
1082 | return recon_faces, total_loss
1083 |
1084 | if not return_recon_faces:
1085 | return total_loss, loss_breakdown
1086 |
1087 | return recon_faces, total_loss, loss_breakdown
1088 |
1089 | @save_load(version = __version__)
1090 | class MeshTransformer(Module, PyTorchModelHubMixin):
1091 | @typecheck
1092 | def __init__(
1093 | self,
1094 | autoencoder: MeshAutoencoder,
1095 | *,
1096 | dim: int | Tuple[int, int] = 512,
1097 | max_seq_len = 8192,
1098 | flash_attn = True,
1099 | attn_depth = 12,
1100 | attn_dim_head = 64,
1101 | attn_heads = 16,
1102 | attn_kwargs: dict = dict(
1103 | ff_glu = True,
1104 | attn_num_mem_kv = 4
1105 | ),
1106 | cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition
1107 | dropout = 0.,
1108 | coarse_pre_gateloop_depth = 2,
1109 | coarse_post_gateloop_depth = 0,
1110 | coarse_adaptive_rmsnorm = False,
1111 | fine_pre_gateloop_depth = 2,
1112 | gateloop_use_heinsen = False,
1113 | fine_attn_depth = 2,
1114 | fine_attn_dim_head = 32,
1115 | fine_attn_heads = 8,
1116 | fine_cross_attend_text = False, # additional conditioning - fine transformer cross attention to text tokens
1117 | pad_id = -1,
1118 | num_sos_tokens = None,
1119 | condition_on_text = False,
1120 | text_cond_with_film = False,
1121 | text_condition_model_types = ('t5',),
1122 | text_condition_model_kwargs = (dict(),),
1123 | text_condition_cond_drop_prob = 0.25,
1124 | quads = False,
1125 | ):
1126 | super().__init__()
1127 | self.num_vertices_per_face = 3 if not quads else 4
1128 |
1129 | assert autoencoder.num_vertices_per_face == self.num_vertices_per_face, 'autoencoder and transformer must both support the same type of mesh (either all triangles, or all quads)'
1130 |
1131 | dim, dim_fine = (dim, dim) if isinstance(dim, int) else dim
1132 |
1133 | self.autoencoder = autoencoder
1134 | set_module_requires_grad_(autoencoder, False)
1135 |
1136 | self.codebook_size = autoencoder.codebook_size
1137 | self.num_quantizers = autoencoder.num_quantizers
1138 |
1139 | self.eos_token_id = self.codebook_size
1140 |
1141 | # the fine transformer sos token
1142 | # as well as a projection of pooled text embeddings to condition it
1143 |
1144 | num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_text else 4)
1145 | assert num_sos_tokens > 0
1146 |
1147 | self.num_sos_tokens = num_sos_tokens
1148 | self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))
1149 |
1150 | # they use axial positional embeddings
1151 |
1152 | assert divisible_by(max_seq_len, self.num_vertices_per_face * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by ({self.num_vertices_per_face} x {self.num_quantizers}) = {self.num_vertices_per_face * self.num_quantizers}' # 3 or 4 vertices per face, with D codes per vertex
1153 |
1154 | self.token_embed = nn.Embedding(self.codebook_size + 1, dim)
1155 |
1156 | self.quantize_level_embed = nn.Parameter(torch.randn(self.num_quantizers, dim))
1157 | self.vertex_embed = nn.Parameter(torch.randn(self.num_vertices_per_face, dim))
1158 |
1159 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim)
1160 |
1161 | self.max_seq_len = max_seq_len
1162 |
1163 | # text condition
1164 |
1165 | self.condition_on_text = condition_on_text
1166 | self.conditioner = None
1167 |
1168 | cross_attn_dim_context = None
1169 | dim_text = None
1170 |
1171 | if condition_on_text:
1172 | self.conditioner = TextEmbeddingReturner(
1173 | model_types = text_condition_model_types,
1174 | model_kwargs = text_condition_model_kwargs,
1175 | cond_drop_prob = text_condition_cond_drop_prob,
1176 | text_embed_pad_value = -1.
1177 | )
1178 |
1179 | dim_text = self.conditioner.dim_latent
1180 | cross_attn_dim_context = dim_text
1181 |
1182 | self.text_coarse_film_cond = FiLM(dim_text, dim) if text_cond_with_film else identity
1183 | self.text_fine_film_cond = FiLM(dim_text, dim_fine) if text_cond_with_film else identity
1184 |
1185 | # for summarizing the vertices of each face
1186 |
1187 | self.to_face_tokens = nn.Sequential(
1188 | nn.Linear(self.num_quantizers * self.num_vertices_per_face * dim, dim),
1189 | nn.LayerNorm(dim)
1190 | )
1191 |
1192 | self.coarse_gateloop_block = GateLoopBlock(dim, depth = coarse_pre_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None
1193 |
1194 | self.coarse_post_gateloop_block = GateLoopBlock(dim, depth = coarse_post_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None
1195 |
1196 | # main autoregressive attention network
1197 | # attending to a face token
1198 |
1199 | self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm
1200 |
1201 | self.decoder = Decoder(
1202 | dim = dim,
1203 | depth = attn_depth,
1204 | heads = attn_heads,
1205 | attn_dim_head = attn_dim_head,
1206 | attn_flash = flash_attn,
1207 | attn_dropout = dropout,
1208 | ff_dropout = dropout,
1209 | use_adaptive_rmsnorm = coarse_adaptive_rmsnorm,
1210 | dim_condition = dim_text,
1211 | cross_attend = condition_on_text,
1212 | cross_attn_dim_context = cross_attn_dim_context,
1213 | cross_attn_num_mem_kv = cross_attn_num_mem_kv,
1214 | **attn_kwargs
1215 | )
1216 |
1217 | # projection from coarse to fine, if needed
1218 |
1219 | self.maybe_project_coarse_to_fine = nn.Linear(dim, dim_fine) if dim != dim_fine else nn.Identity()
1220 |
1221 | # address a weakness in attention
1222 |
1223 | self.fine_gateloop_block = GateLoopBlock(dim, depth = fine_pre_gateloop_depth, use_heinsen = gateloop_use_heinsen) if fine_pre_gateloop_depth > 0 else None
1224 |
1225 | # decoding the vertices, 2-stage hierarchy
1226 |
1227 | self.fine_cross_attend_text = condition_on_text and fine_cross_attend_text
1228 |
1229 | self.fine_decoder = Decoder(
1230 | dim = dim_fine,
1231 | depth = fine_attn_depth,
1232 | heads = attn_heads,
1233 | attn_dim_head = attn_dim_head,
1234 | attn_flash = flash_attn,
1235 | attn_dropout = dropout,
1236 | ff_dropout = dropout,
1237 | cross_attend = self.fine_cross_attend_text,
1238 | cross_attn_dim_context = cross_attn_dim_context,
1239 | cross_attn_num_mem_kv = cross_attn_num_mem_kv,
1240 | **attn_kwargs
1241 | )
1242 |
1243 | # to logits
1244 |
1245 | self.to_logits = nn.Linear(dim_fine, self.codebook_size + 1)
1246 |
1247 | # padding id
1248 | # force the autoencoder to use the same pad_id given in transformer
1249 |
1250 | self.pad_id = pad_id
1251 | autoencoder.pad_id = pad_id
1252 |
1253 | @classmethod
1254 | def _from_pretrained(
1255 | cls,
1256 | *,
1257 | model_id: str,
1258 | revision: str | None,
1259 | cache_dir: str | Path | None,
1260 | force_download: bool,
1261 | proxies: Dict | None,
1262 | resume_download: bool,
1263 | local_files_only: bool,
1264 | token: str | bool | None,
1265 | map_location: str = "cpu",
1266 | strict: bool = False,
1267 | **model_kwargs,
1268 | ):
1269 | model_filename = "mesh-transformer.bin"
1270 | model_file = Path(model_id) / model_filename
1271 |
1272 | if not model_file.exists():
1273 | model_file = hf_hub_download(
1274 | repo_id=model_id,
1275 | filename=model_filename,
1276 | revision=revision,
1277 | cache_dir=cache_dir,
1278 | force_download=force_download,
1279 | proxies=proxies,
1280 | resume_download=resume_download,
1281 | token=token,
1282 | local_files_only=local_files_only,
1283 | )
1284 |
1285 | model = cls.init_and_load(model_file,strict=strict)
1286 | model.to(map_location)
1287 | return model
1288 |
1289 | @property
1290 | def device(self):
1291 | return next(self.parameters()).device
1292 |
1293 | @typecheck
1294 | @torch.no_grad()
1295 | def embed_texts(self, texts: str | List[str]):
1296 | single_text = not isinstance(texts, list)
1297 | if single_text:
1298 | texts = [texts]
1299 |
1300 | assert exists(self.conditioner)
1301 | text_embeds = self.conditioner.embed_texts(texts).detach()
1302 |
1303 | if single_text:
1304 | text_embeds = text_embeds[0]
1305 |
1306 | return text_embeds
1307 |
1308 | @eval_decorator
1309 | @torch.no_grad()
1310 | @typecheck
1311 | def generate(
1312 | self,
1313 | prompt: Tensor | None = None,
1314 | batch_size: int | None = None,
1315 | filter_logits_fn: Callable = top_k,
1316 | filter_kwargs: dict = dict(),
1317 | temperature = 1.,
1318 | return_codes = False,
1319 | texts: List[str] | None = None,
1320 | text_embeds: Tensor | None = None,
1321 | cond_scale = 1.,
1322 | cache_kv = True,
1323 | max_seq_len = None,
1324 | face_coords_to_file: Callable[[Tensor], Any] | None = None
1325 | ):
1326 | max_seq_len = default(max_seq_len, self.max_seq_len)
1327 |
1328 | if exists(prompt):
1329 | assert not exists(batch_size)
1330 |
1331 | prompt = rearrange(prompt, 'b ... -> b (...)')
1332 | assert prompt.shape[-1] <= self.max_seq_len
1333 |
1334 | batch_size = prompt.shape[0]
1335 |
1336 | if self.condition_on_text:
1337 | assert exists(texts) ^ exists(text_embeds), '`texts` or `text_embeds` must be passed in if `condition_on_text` is set to True'
1338 | if exists(texts):
1339 | text_embeds = self.embed_texts(texts)
1340 |
1341 | batch_size = default(batch_size, text_embeds.shape[0])
1342 |
1343 | batch_size = default(batch_size, 1)
1344 |
1345 | codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))
1346 |
1347 | curr_length = codes.shape[-1]
1348 |
1349 | cache = (None, None)
1350 |
1351 | for i in tqdm(range(curr_length, max_seq_len)):
1352 |
1353 | # example below for triangles, extrapolate for quads
1354 | # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F)
1355 |
1356 | can_eos = i != 0 and divisible_by(i, self.num_quantizers * self.num_vertices_per_face) # only allow for eos to be decoded at the end of each face, defined as 3 or 4 vertices with D residual VQ codes
1357 |
1358 | output = self.forward_on_codes(
1359 | codes,
1360 | text_embeds = text_embeds,
1361 | return_loss = False,
1362 | return_cache = cache_kv,
1363 | append_eos = False,
1364 | cond_scale = cond_scale,
1365 | cfg_routed_kwargs = dict(
1366 | cache = cache
1367 | )
1368 | )
1369 |
1370 | if cache_kv:
1371 | logits, cache = output
1372 |
1373 | if cond_scale == 1.:
1374 | cache = (cache, None)
1375 | else:
1376 | logits = output
1377 |
1378 | logits = logits[:, -1]
1379 |
1380 | if not can_eos:
1381 | logits[:, -1] = -torch.finfo(logits.dtype).max
1382 |
1383 | filtered_logits = filter_logits_fn(logits, **filter_kwargs)
1384 |
1385 | if temperature == 0.:
1386 | sample = filtered_logits.argmax(dim = -1)
1387 | else:
1388 | probs = F.softmax(filtered_logits / temperature, dim = -1)
1389 | sample = torch.multinomial(probs, 1)
1390 |
1391 | codes, _ = pack([codes, sample], 'b *')
1392 |
1393 | # check for all rows to have [eos] to terminate
1394 |
1395 | is_eos_codes = (codes == self.eos_token_id)
1396 |
1397 | if is_eos_codes.any(dim = -1).all():
1398 | break
1399 |
1400 | # mask out to padding anything after the first eos
1401 |
1402 | mask = is_eos_codes.float().cumsum(dim = -1) >= 1
1403 | codes = codes.masked_fill(mask, self.pad_id)
1404 |
1405 | # remove a potential extra token from eos, if we broke out of the loop early
1406 |
1407 | code_len = codes.shape[-1]
1408 | round_down_code_len = code_len // self.num_vertices_per_face * self.num_vertices_per_face
1409 | codes = codes[:, :round_down_code_len]
1410 |
1411 | # early return of raw residual quantizer codes
1412 |
1413 | if return_codes:
1414 | codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
1415 | return codes
1416 |
1417 | self.autoencoder.eval()
1418 | face_coords, face_mask = self.autoencoder.decode_from_codes_to_faces(codes)
1419 |
1420 | if not exists(face_coords_to_file):
1421 | return face_coords, face_mask
1422 |
1423 | files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)]
1424 | return files
1425 |
1426 | def forward(
1427 | self,
1428 | *,
1429 | vertices: Int['b nv 3'],
1430 | faces: Int['b nf nvf'],
1431 | face_edges: Int['b e 2'] | None = None,
1432 | codes: Tensor | None = None,
1433 | cache: LayerIntermediates | None = None,
1434 | **kwargs
1435 | ):
1436 | if not exists(codes):
1437 | codes = self.autoencoder.tokenize(
1438 | vertices = vertices,
1439 | faces = faces,
1440 | face_edges = face_edges
1441 | )
1442 |
1443 | return self.forward_on_codes(codes, cache = cache, **kwargs)
1444 |
1445 | @classifier_free_guidance
1446 | def forward_on_codes(
1447 | self,
1448 | codes = None,
1449 | return_loss = True,
1450 | return_cache = False,
1451 | append_eos = True,
1452 | cache = None,
1453 | texts: List[str] | None = None,
1454 | text_embeds: Tensor | None = None,
1455 | cond_drop_prob = None
1456 | ):
1457 | # handle text conditions
1458 |
1459 | attn_context_kwargs = dict()
1460 |
1461 | if self.condition_on_text:
1462 | assert exists(texts) ^ exists(text_embeds), '`texts` or `text_embeds` must be passed in if `condition_on_text` is set to True'
1463 |
1464 | if exists(texts):
1465 | text_embeds = self.conditioner.embed_texts(texts)
1466 |
1467 | if exists(codes):
1468 | assert text_embeds.shape[0] == codes.shape[0], 'batch size of texts or text embeddings is not equal to the batch size of the mesh codes'
1469 |
1470 | _, maybe_dropped_text_embeds = self.conditioner(
1471 | text_embeds = text_embeds,
1472 | cond_drop_prob = cond_drop_prob
1473 | )
1474 |
1475 | text_embed, text_mask = maybe_dropped_text_embeds
1476 |
1477 | pooled_text_embed = masked_mean(text_embed, text_mask, dim = 1)
1478 |
1479 | attn_context_kwargs = dict(
1480 | context = text_embed,
1481 | context_mask = text_mask
1482 | )
1483 |
1484 | if self.coarse_adaptive_rmsnorm:
1485 | attn_context_kwargs.update(
1486 | condition = pooled_text_embed
1487 | )
1488 |
1489 | # take care of codes that may be flattened
1490 |
1491 | if codes.ndim > 2:
1492 | codes = rearrange(codes, 'b ... -> b (...)')
1493 |
1494 | # get some variables
1495 |
1496 | batch, seq_len, device = *codes.shape, codes.device
1497 |
1498 | assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'
1499 |
1500 | # auto append eos token
1501 |
1502 | if append_eos:
1503 | assert exists(codes)
1504 |
1505 | code_lens = ((codes == self.pad_id).cumsum(dim = -1) == 0).sum(dim = -1)
1506 |
1507 | codes = F.pad(codes, (0, 1), value = 0)
1508 |
1509 | batch_arange = torch.arange(batch, device = device)
1510 |
1511 | batch_arange = rearrange(batch_arange, '... -> ... 1')
1512 | code_lens = rearrange(code_lens, '... -> ... 1')
1513 |
1514 | codes[batch_arange, code_lens] = self.eos_token_id
1515 |
1516 | # if returning loss, save the labels for cross entropy
1517 |
1518 | if return_loss:
1519 | assert seq_len > 0
1520 | codes, labels = codes[:, :-1], codes
1521 |
1522 | # token embed (each residual VQ id)
1523 |
1524 | codes = codes.masked_fill(codes == self.pad_id, 0)
1525 | codes = self.token_embed(codes)
1526 |
1527 | # codebook embed + absolute positions
1528 |
1529 | seq_arange = torch.arange(codes.shape[-2], device = device)
1530 |
1531 | codes = codes + self.abs_pos_emb(seq_arange)
1532 |
1533 | # embedding for quantizer level
1534 |
1535 | code_len = codes.shape[1]
1536 |
1537 | level_embed = repeat(self.quantize_level_embed, 'q d -> (r q) d', r = ceil(code_len / self.num_quantizers))
1538 | codes = codes + level_embed[:code_len]
1539 |
1540 | # embedding for each vertex
1541 |
1542 | vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (self.num_vertices_per_face * self.num_quantizers)), q = self.num_quantizers)
1543 | codes = codes + vertex_embed[:code_len]
1544 |
1545 | # create a token per face, by summarizing the 3 or 4 vertices
1546 | # this is similar in design to the RQ transformer from Lee et al. https://arxiv.org/abs/2203.01941
1547 |
1548 | num_tokens_per_face = self.num_quantizers * self.num_vertices_per_face
1549 |
1550 | curr_vertex_pos = code_len % num_tokens_per_face # the current intra-face vertex-code position id, needed for caching at the fine decoder stage
1551 |
1552 | code_len_is_multiple_of_face = divisible_by(code_len, num_tokens_per_face)
1553 |
1554 | next_multiple_code_len = ceil(code_len / num_tokens_per_face) * num_tokens_per_face
1555 |
1556 | codes = pad_to_length(codes, next_multiple_code_len, dim = -2)
1557 |
1558 | # grouped codes will be used for the second stage
1559 |
1560 | grouped_codes = rearrange(codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face)
1561 |
1562 | # create the coarse tokens for the first attention network
1563 |
1564 | face_codes = grouped_codes if code_len_is_multiple_of_face else grouped_codes[:, :-1]
1565 | face_codes = rearrange(face_codes, 'b nf n d -> b nf (n d)')
1566 | face_codes = self.to_face_tokens(face_codes)
1567 |
1568 | face_codes_len = face_codes.shape[-2]
1569 |
1570 | # cache logic
1571 |
1572 | (
1573 | cached_attended_face_codes,
1574 | coarse_cache,
1575 | fine_cache,
1576 | coarse_gateloop_cache,
1577 | coarse_post_gateloop_cache,
1578 | fine_gateloop_cache
1579 | ) = cache if exists(cache) else ((None,) * 6)
1580 |
1581 | if exists(cache):
1582 | cached_face_codes_len = cached_attended_face_codes.shape[-2]
1583 | cached_face_codes_len_without_sos = cached_face_codes_len - 1
1584 |
1585 | need_call_first_transformer = face_codes_len > cached_face_codes_len_without_sos
1586 | else:
1587 | # auto prepend sos token
1588 |
1589 | sos = repeat(self.sos_token, 'n d -> b n d', b = batch)
1590 | face_codes, packed_sos_shape = pack([sos, face_codes], 'b * d')
1591 |
1592 | # if no kv cache, always call first transformer
1593 |
1594 | need_call_first_transformer = True
1595 |
1596 | should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face)
1597 |
1598 | # condition face codes with text if needed
1599 |
1600 | if self.condition_on_text:
1601 | face_codes = self.text_coarse_film_cond(face_codes, pooled_text_embed)
1602 |
1603 | # attention on face codes (coarse)
1604 |
1605 | if need_call_first_transformer:
1606 | if exists(self.coarse_gateloop_block):
1607 | face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache)
1608 |
1609 | attended_face_codes, coarse_cache = self.decoder(
1610 | face_codes,
1611 | cache = coarse_cache,
1612 | return_hiddens = True,
1613 | **attn_context_kwargs
1614 | )
1615 |
1616 | if exists(self.coarse_post_gateloop_block):
1617 | face_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(face_codes, cache = coarse_post_gateloop_cache)
1618 |
1619 | else:
1620 | attended_face_codes = None
1621 |
1622 | attended_face_codes = safe_cat((cached_attended_face_codes, attended_face_codes), dim = -2)
1623 |
1624 | # if calling without kv cache, pool the sos tokens, if greater than 1 sos token
1625 |
1626 | if not exists(cache):
1627 | sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
1628 | last_sos_token = sos_tokens[:, -1:]
1629 | attended_face_codes = torch.cat((last_sos_token, attended_face_codes), dim = 1)
1630 |
1631 | # maybe project from coarse to fine dimension for hierarchical transformers
1632 |
1633 | attended_face_codes = self.maybe_project_coarse_to_fine(attended_face_codes)
1634 |
1635 | grouped_codes = pad_to_length(grouped_codes, attended_face_codes.shape[-2], dim = 1)
1636 | fine_vertex_codes, _ = pack([attended_face_codes, grouped_codes], 'b n * d')
1637 |
1638 | fine_vertex_codes = fine_vertex_codes[..., :-1, :]
1639 |
1640 | # gateloop layers
1641 |
1642 | if exists(self.fine_gateloop_block):
1643 | fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> b (nf n) d')
1644 | orig_length = fine_vertex_codes.shape[-2]
1645 | fine_vertex_codes = fine_vertex_codes[:, :(code_len + 1)]
1646 |
1647 | fine_vertex_codes, fine_gateloop_cache = self.fine_gateloop_block(fine_vertex_codes, cache = fine_gateloop_cache)
1648 |
1649 | fine_vertex_codes = pad_to_length(fine_vertex_codes, orig_length, dim = -2)
1650 | fine_vertex_codes = rearrange(fine_vertex_codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face)
1651 |
1652 | # fine attention - 2nd stage
1653 |
1654 | if exists(cache):
1655 | fine_vertex_codes = fine_vertex_codes[:, -1:]
1656 |
1657 | if exists(fine_cache):
1658 | for attn_intermediate in fine_cache.attn_intermediates:
1659 | ck, cv = attn_intermediate.cached_kv
1660 | ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]
1661 |
1662 | # when operating on the cached key / values, treat self attention and cross attention differently
1663 |
1664 | layer_type = attn_intermediate.layer_type
1665 |
1666 | if layer_type == 'a':
1667 | ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)]
1668 | elif layer_type == 'c':
1669 | ck, cv = [t[:, -1, ...] for t in (ck, cv)]
1670 |
1671 | attn_intermediate.cached_kv = (ck, cv)
1672 |
1673 | num_faces = fine_vertex_codes.shape[1]
1674 | one_face = num_faces == 1
1675 |
1676 | fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> (b nf) n d')
1677 |
1678 | if one_face:
1679 | fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
1680 |
1681 | # handle maybe cross attention conditioning of fine transformer with text
1682 |
1683 | fine_attn_context_kwargs = dict()
1684 |
1685 | # optional text cross attention conditioning for fine transformer
1686 |
1687 | if self.fine_cross_attend_text:
1688 | repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
1689 |
1690 | text_embed = repeat(text_embed, 'b ... -> (b r) ...' , r = repeat_batch)
1691 | text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)
1692 |
1693 | fine_attn_context_kwargs = dict(
1694 | context = text_embed,
1695 | context_mask = text_mask
1696 | )
1697 |
1698 | # also film condition the fine vertex codes
1699 |
1700 | if self.condition_on_text:
1701 | repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]
1702 |
1703 | pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch)
1704 | fine_vertex_codes = self.text_fine_film_cond(fine_vertex_codes, pooled_text_embed)
1705 |
1706 | # fine transformer
1707 |
1708 | attended_vertex_codes, fine_cache = self.fine_decoder(
1709 | fine_vertex_codes,
1710 | cache = fine_cache,
1711 | **fine_attn_context_kwargs,
1712 | return_hiddens = True
1713 | )
1714 |
1715 | if not should_cache_fine:
1716 | fine_cache = None
1717 |
1718 | if not one_face:
1719 | # reconstitute original sequence
1720 |
1721 | embed = rearrange(attended_vertex_codes, '(b nf) n d -> b (nf n) d', b = batch)
1722 | embed = embed[:, :(code_len + 1)]
1723 | else:
1724 | embed = attended_vertex_codes
1725 |
1726 | # logits
1727 |
1728 | logits = self.to_logits(embed)
1729 |
1730 | if not return_loss:
1731 | if not return_cache:
1732 | return logits
1733 |
1734 | next_cache = (
1735 | attended_face_codes,
1736 | coarse_cache,
1737 | fine_cache,
1738 | coarse_gateloop_cache,
1739 | coarse_post_gateloop_cache,
1740 | fine_gateloop_cache
1741 | )
1742 |
1743 | return logits, next_cache
1744 |
1745 | # loss
1746 |
1747 | ce_loss = F.cross_entropy(
1748 | rearrange(logits, 'b n c -> b c n'),
1749 | labels,
1750 | ignore_index = self.pad_id
1751 | )
1752 |
1753 | return ce_loss
1754 |
--------------------------------------------------------------------------------
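Note (illustrative, not part of the file above): `generate` samples one residual VQ code per step, only allowing the eos token at face boundaries, while `forward` tokenizes the mesh with the autoencoder before delegating to `forward_on_codes`. A minimal usage sketch, assuming an autoencoder and transformer configured as in `tests/test_meshgpt.py` and already trained (the texts and tensor shapes below are placeholders):

import torch
from meshgpt_pytorch import MeshAutoencoder, MeshTransformer

autoencoder = MeshAutoencoder(num_discrete_coors = 128)

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 60,
    condition_on_text = True
)

# training step: vertices / faces are tokenized by the autoencoder, then a cross entropy loss is returned
vertices = torch.randn((2, 121, 3))
faces = torch.randint(0, 121, (2, 2, 3))

loss = transformer(
    vertices = vertices,
    faces = faces,
    texts = ['a high chair', 'a small teapot']
)
loss.backward()

# sampling: temperature = 0. is greedy, cache_kv = True reuses the coarse / fine kv caches across steps
face_coords, face_mask = transformer.generate(
    texts = ['a small chair'],
    cond_scale = 3.,
    temperature = 0.,
    cache_kv = True
)

# or return the raw residual VQ codes, shaped (batch, num codes per face group, num quantizers)
codes = transformer.generate(texts = ['a small chair'], return_codes = True)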
/meshgpt_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from functools import partial
5 | from packaging import version
6 | from contextlib import nullcontext
7 |
8 | import torch
9 | from torch.nn import Module
10 | from torch.utils.data import Dataset, DataLoader
11 | from torch.optim.lr_scheduler import _LRScheduler
12 |
13 | from pytorch_custom_utils import (
14 | get_adam_optimizer,
15 | OptimizerWithWarmupSchedule
16 | )
17 |
18 | from accelerate import Accelerator
19 | from accelerate.utils import DistributedDataParallelKwargs
20 |
21 |
22 | from beartype.typing import Tuple, Type, List
23 | from meshgpt_pytorch.typing import typecheck, beartype_isinstance
24 |
25 | from ema_pytorch import EMA
26 |
27 | from meshgpt_pytorch.data import custom_collate
28 |
29 | from meshgpt_pytorch.version import __version__
30 | import matplotlib.pyplot as plt
31 | from tqdm import tqdm
32 | from meshgpt_pytorch.meshgpt_pytorch import (
33 | MeshAutoencoder,
34 | MeshTransformer
35 | )
36 |
37 | # constants
38 |
39 | DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
40 | find_unused_parameters = True
41 | )
42 |
43 | # helper functions
44 |
45 | def exists(v):
46 | return v is not None
47 |
48 | def default(v, d):
49 | return v if exists(v) else d
50 |
51 | def divisible_by(num, den):
52 | return (num % den) == 0
53 |
54 | def cycle(dl):
55 | while True:
56 | for data in dl:
57 | yield data
58 |
59 | def maybe_del(d: dict, *keys):
60 | for key in keys:
61 | if key not in d:
62 | continue
63 |
64 | del d[key]
65 |
66 | # autoencoder trainer
67 |
68 | class MeshAutoencoderTrainer(Module):
69 | @typecheck
70 | def __init__(
71 | self,
72 | model: MeshAutoencoder,
73 | dataset: Dataset,
74 | num_train_steps: int,
75 | batch_size: int,
76 | grad_accum_every: int,
77 | val_dataset: Dataset | None = None,
78 | val_every: int = 100,
79 | val_num_batches: int = 5,
80 | learning_rate: float = 1e-4,
81 | weight_decay: float = 0.,
82 | max_grad_norm: float | None = None,
83 | ema_kwargs: dict = dict(
84 | use_foreach = True
85 | ),
86 | scheduler: Type[_LRScheduler] | None = None,
87 | scheduler_kwargs: dict = dict(),
88 | accelerator_kwargs: dict = dict(),
89 | optimizer_kwargs: dict = dict(),
90 | checkpoint_every = 1000,
91 | checkpoint_every_epoch: Type[int] | None = None,
92 | checkpoint_folder = './checkpoints',
93 | data_kwargs: Tuple[str, ...] = ('vertices', 'faces', 'face_edges'),
94 | warmup_steps = 1000,
95 | use_wandb_tracking = False
96 | ):
97 | super().__init__()
98 |
99 | # experiment tracker
100 |
101 | self.use_wandb_tracking = use_wandb_tracking
102 |
103 | if use_wandb_tracking:
104 | accelerator_kwargs['log_with'] = 'wandb'
105 |
106 | if 'kwargs_handlers' not in accelerator_kwargs:
107 | accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]
108 |
109 | # accelerator
110 |
111 | self.accelerator = Accelerator(**accelerator_kwargs)
112 |
113 | self.model = model
114 |
115 | if self.is_main:
116 | self.ema_model = EMA(model, **ema_kwargs)
117 |
118 | self.optimizer = OptimizerWithWarmupSchedule(
119 | accelerator = self.accelerator,
120 | optimizer = get_adam_optimizer(model.parameters(), lr = learning_rate, wd = weight_decay, **optimizer_kwargs),
121 | scheduler = scheduler,
122 | scheduler_kwargs = scheduler_kwargs,
123 | warmup_steps = warmup_steps,
124 | max_grad_norm = max_grad_norm
125 | )
126 |
127 | self.dataloader = DataLoader(
128 | dataset,
129 | shuffle = True,
130 | batch_size = batch_size,
131 | drop_last = True,
132 | collate_fn = partial(custom_collate, pad_id = model.pad_id)
133 | )
134 |
135 | self.should_validate = exists(val_dataset)
136 |
137 | if self.should_validate:
138 | assert len(val_dataset) > 0, 'your validation dataset is empty'
139 |
140 | self.val_every = val_every
141 | self.val_num_batches = val_num_batches
142 |
143 | self.val_dataloader = DataLoader(
144 | val_dataset,
145 | shuffle = True,
146 | batch_size = batch_size,
147 | drop_last = True,
148 | collate_fn = partial(custom_collate, pad_id = model.pad_id)
149 | )
150 |
151 | if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
152 | assert beartype_isinstance(dataset.data_kwargs, List[str])
153 | self.data_kwargs = dataset.data_kwargs
154 | else:
155 | self.data_kwargs = data_kwargs
156 |
157 |
158 | (
159 | self.model,
160 | self.dataloader
161 | ) = self.accelerator.prepare(
162 | self.model,
163 | self.dataloader
164 | )
165 |
166 | self.grad_accum_every = grad_accum_every
167 | self.num_train_steps = num_train_steps
168 | self.register_buffer('step', torch.tensor(0))
169 |
170 | self.checkpoint_every_epoch = checkpoint_every_epoch
171 | self.checkpoint_every = checkpoint_every
172 | self.checkpoint_folder = Path(checkpoint_folder)
173 | self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
174 |
175 | @property
176 | def ema_tokenizer(self):
177 | return self.ema_model.ema_model
178 |
179 | def tokenize(self, *args, **kwargs):
180 | return self.ema_tokenizer.tokenize(*args, **kwargs)
181 |
182 | def log(self, **data_kwargs):
183 | self.accelerator.log(data_kwargs, step = self.step.item())
184 |
185 | @property
186 | def device(self):
187 | return self.unwrapped_model.device
188 |
189 | @property
190 | def is_main(self):
191 | return self.accelerator.is_main_process
192 |
193 | @property
194 | def unwrapped_model(self):
195 | return self.accelerator.unwrap_model(self.model)
196 |
197 | @property
198 | def is_local_main(self):
199 | return self.accelerator.is_local_main_process
200 |
201 | def wait(self):
202 | return self.accelerator.wait_for_everyone()
203 |
204 | def print(self, msg):
205 | return self.accelerator.print(msg)
206 |
207 | def save(self, path, overwrite = True):
208 | path = Path(path)
209 | assert overwrite or not path.exists()
210 |
211 | pkg = dict(
212 | model = self.unwrapped_model.state_dict(),
213 | ema_model = self.ema_model.state_dict(),
214 | optimizer = self.optimizer.state_dict(),
215 | version = __version__,
216 | step = self.step.item(),
217 | config = self.unwrapped_model._config
218 | )
219 |
220 | torch.save(pkg, str(path))
221 |
222 | def load(self, path):
223 | path = Path(path)
224 | assert path.exists()
225 |
226 | pkg = torch.load(str(path))
227 |
228 | if version.parse(__version__) != version.parse(pkg['version']):
229 | self.print(f'loading saved mesh autoencoder at version {pkg["version"]}, but current package version is {__version__}')
230 |
231 | self.model.load_state_dict(pkg['model'])
232 | self.ema_model.load_state_dict(pkg['ema_model'])
233 | self.optimizer.load_state_dict(pkg['optimizer'])
234 |
235 | self.step.copy_(pkg['step'])
236 |
237 | def next_data_to_forward_kwargs(self, dl_iter) -> dict:
238 | data = next(dl_iter)
239 |
240 | if isinstance(data, tuple):
241 | forward_kwargs = dict(zip(self.data_kwargs, data))
242 |
243 | elif isinstance(data, dict):
244 | forward_kwargs = data
245 |
246 | maybe_del(forward_kwargs, 'texts', 'text_embeds')
247 | return forward_kwargs
248 |
249 | def forward(self):
250 | step = self.step.item()
251 | dl_iter = cycle(self.dataloader)
252 |
253 | if self.is_main and self.should_validate:
254 | val_dl_iter = cycle(self.val_dataloader)
255 |
256 | while step < self.num_train_steps:
257 |
258 | for i in range(self.grad_accum_every):
259 | is_last = i == (self.grad_accum_every - 1)
260 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext
261 |
262 | forward_kwargs = self.next_data_to_forward_kwargs(dl_iter)
263 |
264 | with self.accelerator.autocast(), maybe_no_sync():
265 |
266 | total_loss, (recon_loss, commit_loss) = self.model(
267 | **forward_kwargs,
268 | return_loss_breakdown = True
269 | )
270 |
271 | self.accelerator.backward(total_loss / self.grad_accum_every)
272 |
273 | self.print(f'recon loss: {recon_loss.item():.3f} | commit loss: {commit_loss.sum().item():.3f}')
274 |
275 | self.log(
276 | total_loss = total_loss.item(),
277 | commit_loss = commit_loss.sum().item(),
278 | recon_loss = recon_loss.item()
279 | )
280 |
281 | self.optimizer.step()
282 | self.optimizer.zero_grad()
283 |
284 | step += 1
285 | self.step.add_(1)
286 |
287 | self.wait()
288 |
289 | if self.is_main:
290 | self.ema_model.update()
291 |
292 | self.wait()
293 |
294 | if self.is_main and self.should_validate and divisible_by(step, self.val_every):
295 |
296 | total_val_recon_loss = 0.
297 | self.ema_model.eval()
298 |
299 | num_val_batches = self.val_num_batches * self.grad_accum_every
300 |
301 | for _ in range(num_val_batches):
302 | with self.accelerator.autocast(), torch.no_grad():
303 |
304 | forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter)
305 |
306 | val_loss, (val_recon_loss, val_commit_loss) = self.ema_model(
307 | **forward_kwargs,
308 | return_loss_breakdown = True
309 | )
310 |
311 | total_val_recon_loss += (val_recon_loss / num_val_batches)
312 |
313 | self.print(f'valid recon loss: {total_val_recon_loss:.3f}')
314 |
315 | self.log(val_loss = total_val_recon_loss)
316 |
317 | self.wait()
318 |
319 | if self.is_main and divisible_by(step, self.checkpoint_every):
320 | checkpoint_num = step // self.checkpoint_every
321 | self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.{checkpoint_num}.pt')
322 |
323 | self.wait()
324 |
325 | self.print('training complete')
326 |
327 | def train(self, num_epochs, stop_at_loss = None, display_graph = False):
328 | epoch_losses, epoch_recon_losses, epoch_commit_losses = [], [], []
329 | self.model.train()
330 |
331 | for epoch in range(num_epochs):
332 | total_epoch_loss, total_epoch_recon_loss, total_epoch_commit_loss = 0.0, 0.0, 0.0
333 |
334 | progress_bar = tqdm(enumerate(self.dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}', total=len(self.dataloader))
335 | for batch_idx, batch in progress_bar:
336 | is_last = (batch_idx+1) % self.grad_accum_every == 0
337 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext
338 |
339 | if isinstance(batch, tuple):
340 | forward_kwargs = dict(zip(self.data_kwargs, batch))
341 | elif isinstance(batch, dict):
342 | forward_kwargs = batch
343 | maybe_del(forward_kwargs, 'texts', 'text_embeds')
344 |
345 | with self.accelerator.autocast(), maybe_no_sync():
346 | total_loss, (recon_loss, commit_loss) = self.model(
347 | **forward_kwargs,
348 | return_loss_breakdown = True
349 | )
350 | self.accelerator.backward(total_loss / self.grad_accum_every)
351 |
352 | current_loss = total_loss.item()
353 | total_epoch_loss += current_loss
354 | total_epoch_recon_loss += recon_loss.item()
355 | total_epoch_commit_loss += commit_loss.sum().item()
356 |
357 | progress_bar.set_postfix(loss=current_loss, recon_loss = round(recon_loss.item(),3), commit_loss = round(commit_loss.sum().item(),4))
358 |
359 | if is_last or (batch_idx + 1 == len(self.dataloader)):
360 | self.optimizer.step()
361 | self.optimizer.zero_grad()
362 |
363 |
364 |
365 | avg_recon_loss = total_epoch_recon_loss / len(self.dataloader)
366 | avg_commit_loss = total_epoch_commit_loss / len(self.dataloader)
367 | avg_epoch_loss = total_epoch_loss / len(self.dataloader)
368 |
369 | epoch_losses.append(avg_epoch_loss)
370 | epoch_recon_losses.append(avg_recon_loss)
371 | epoch_commit_losses.append(avg_commit_loss)
372 |
373 | epochOut = f'Epoch {epoch + 1} average loss: {avg_epoch_loss} recon loss: {avg_recon_loss:.4f} commit loss: {avg_commit_loss:.4f}'
374 |
375 | if len(epoch_losses) >= 4 and avg_epoch_loss > 0:
376 | avg_loss_improvement = sum(epoch_losses[-4:-1]) / 3 - avg_epoch_loss
377 | epochOut += f' avg loss speed: {avg_loss_improvement}'
378 | if avg_loss_improvement > 0 and avg_loss_improvement < 0.2:
379 | epochs_until_0_3 = max(0, abs(avg_epoch_loss-0.3) / avg_loss_improvement)
380 | if epochs_until_0_3 > 0:
381 | epochOut += f' epochs left: {epochs_until_0_3:.2f}'
382 |
383 | self.wait()
384 | self.print(epochOut)
385 |
386 |
387 | if self.is_main and self.checkpoint_every_epoch is not None and (self.checkpoint_every_epoch == 1 or (epoch != 0 and epoch % self.checkpoint_every_epoch == 0)):
388 | self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.5f}_recon_{avg_recon_loss:.4f}_commit_{avg_commit_loss:.4f}.pt')
389 |
390 | if stop_at_loss is not None and avg_epoch_loss < stop_at_loss:
391 | self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}')
392 | if self.is_main and self.checkpoint_every_epoch is not None:
393 | self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt')
394 | break
395 |
396 | self.print('Training complete')
397 | if display_graph:
398 | plt.figure(figsize=(10, 5))
399 | plt.plot(range(1, len(epoch_losses)+1), epoch_losses, marker='o', label='Total Loss')
400 | plt.plot(range(1, len(epoch_losses)+1), epoch_recon_losses, marker='o', label='Recon Loss')
401 | plt.plot(range(1, len(epoch_losses)+1), epoch_commit_losses, marker='o', label='Commit Loss')
402 | plt.title('Training Loss Over Epochs')
403 | plt.xlabel('Epoch')
404 | plt.ylabel('Average Loss')
405 | plt.grid(True)
406 | plt.show()
407 | return epoch_losses[-1]
408 | # mesh transformer trainer
409 |
410 | class MeshTransformerTrainer(Module):
411 | @typecheck
412 | def __init__(
413 | self,
414 | model: MeshTransformer,
415 | dataset: Dataset,
416 | num_train_steps: int,
417 | batch_size: int,
418 | grad_accum_every: int,
419 | learning_rate: float = 2e-4,
420 | weight_decay: float = 0.,
421 | max_grad_norm: float | None = 0.5,
422 | val_dataset: Dataset | None = None,
423 | val_every = 1,
424 | val_num_batches = 5,
425 | scheduler: Type[_LRScheduler] | None = None,
426 | scheduler_kwargs: dict = dict(),
427 | ema_kwargs: dict = dict(),
428 | accelerator_kwargs: dict = dict(),
429 | optimizer_kwargs: dict = dict(),
430 |
431 | checkpoint_every = 1000,
432 | checkpoint_every_epoch: Type[int] | None = None,
433 | checkpoint_folder = './checkpoints',
434 | data_kwargs: Tuple[str, ...] = ('vertices', 'faces', 'face_edges', 'text'),
435 | warmup_steps = 1000,
436 | use_wandb_tracking = False
437 | ):
438 | super().__init__()
439 |
440 | # experiment tracker
441 |
442 | self.use_wandb_tracking = use_wandb_tracking
443 |
444 | if use_wandb_tracking:
445 | accelerator_kwargs['log_with'] = 'wandb'
446 |
447 | if 'kwargs_handlers' not in accelerator_kwargs:
448 | accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]
449 |
450 | self.accelerator = Accelerator(**accelerator_kwargs)
451 |
452 | self.model = model
453 |
454 | optimizer = get_adam_optimizer(
455 | model.parameters(),
456 | lr = learning_rate,
457 | wd = weight_decay,
458 | filter_by_requires_grad = True,
459 | **optimizer_kwargs
460 | )
461 |
462 | self.optimizer = OptimizerWithWarmupSchedule(
463 | accelerator = self.accelerator,
464 | optimizer = optimizer,
465 | scheduler = scheduler,
466 | scheduler_kwargs = scheduler_kwargs,
467 | warmup_steps = warmup_steps,
468 | max_grad_norm = max_grad_norm
469 | )
470 |
471 | self.dataloader = DataLoader(
472 | dataset,
473 | shuffle = True,
474 | batch_size = batch_size,
475 | drop_last = True,
476 | collate_fn = partial(custom_collate, pad_id = model.pad_id)
477 | )
478 |
479 | self.should_validate = exists(val_dataset)
480 |
481 | if self.should_validate:
482 | assert len(val_dataset) > 0, 'your validation dataset is empty'
483 |
484 | self.val_every = val_every
485 | self.val_num_batches = val_num_batches
486 |
487 | self.val_dataloader = DataLoader(
488 | val_dataset,
489 | shuffle = True,
490 | batch_size = batch_size,
491 | drop_last = True,
492 | collate_fn = partial(custom_collate, pad_id = model.pad_id)
493 | )
494 |
495 | if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
496 | assert beartype_isinstance(dataset.data_kwargs, List[str])
497 | self.data_kwargs = dataset.data_kwargs
498 | else:
499 | self.data_kwargs = data_kwargs
500 |
501 | (
502 | self.model,
503 | self.dataloader
504 | ) = self.accelerator.prepare(
505 | self.model,
506 | self.dataloader
507 | )
508 |
509 | self.grad_accum_every = grad_accum_every
510 | self.num_train_steps = num_train_steps
511 | self.register_buffer('step', torch.tensor(0))
512 |
513 | self.checkpoint_every_epoch = checkpoint_every_epoch
514 | self.checkpoint_every = checkpoint_every
515 | self.checkpoint_folder = Path(checkpoint_folder)
516 | self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
517 |
518 | def log(self, **data_kwargs):
519 | self.accelerator.log(data_kwargs, step = self.step.item())
520 |
521 | @property
522 | def device(self):
523 | return self.unwrapped_model.device
524 |
525 | @property
526 | def is_main(self):
527 | return self.accelerator.is_main_process
528 |
529 | @property
530 | def unwrapped_model(self):
531 | return self.accelerator.unwrap_model(self.model)
532 |
533 | @property
534 | def is_local_main(self):
535 | return self.accelerator.is_local_main_process
536 |
537 | def wait(self):
538 | return self.accelerator.wait_for_everyone()
539 |
540 | def print(self, msg):
541 | return self.accelerator.print(msg)
542 |
543 | def next_data_to_forward_kwargs(self, dl_iter) -> dict:
544 | data = next(dl_iter)
545 |
546 | if isinstance(data, tuple):
547 | forward_kwargs = dict(zip(self.data_kwargs, data))
548 |
549 | elif isinstance(data, dict):
550 | forward_kwargs = data
551 |
552 | return forward_kwargs
553 |
554 | def save(self, path, overwrite = True):
555 | path = Path(path)
556 | assert overwrite or not path.exists()
557 |
558 | pkg = dict(
559 | model = self.unwrapped_model.state_dict(),
560 | optimizer = self.optimizer.state_dict(),
561 | step = self.step.item(),
562 | version = __version__
563 | )
564 |
565 | torch.save(pkg, str(path))
566 |
567 | def load(self, path):
568 | path = Path(path)
569 | assert path.exists()
570 |
571 | pkg = torch.load(str(path))
572 |
573 | if version.parse(__version__) != version.parse(pkg['version']):
574 | self.print(f'loading saved mesh transformer at version {pkg["version"]}, but current package version is {__version__}')
575 |
576 | self.model.load_state_dict(pkg['model'])
577 | self.optimizer.load_state_dict(pkg['optimizer'])
578 | self.step.copy_(pkg['step'])
579 |
580 | def forward(self):
581 | step = self.step.item()
582 | dl_iter = cycle(self.dataloader)
583 |
584 | if self.should_validate:
585 | val_dl_iter = cycle(self.val_dataloader)
586 |
587 | while step < self.num_train_steps:
588 |
589 | for i in range(self.grad_accum_every):
590 | is_last = i == (self.grad_accum_every - 1)
591 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext
592 |
593 | forward_kwargs = self.next_data_to_forward_kwargs(dl_iter)
594 |
595 | with self.accelerator.autocast(), maybe_no_sync():
596 | loss = self.model(**forward_kwargs)
597 |
598 | self.accelerator.backward(loss / self.grad_accum_every)
599 |
600 | self.print(f'loss: {loss.item():.3f}')
601 |
602 | self.log(loss = loss.item())
603 |
604 | self.optimizer.step()
605 | self.optimizer.zero_grad()
606 |
607 | step += 1
608 | self.step.add_(1)
609 |
610 | self.wait()
611 |
612 | if self.is_main and self.should_validate and divisible_by(step, self.val_every):
613 |
614 | total_val_loss = 0.
615 | self.unwrapped_model.eval()
616 |
617 | num_val_batches = self.val_num_batches * self.grad_accum_every
618 |
619 | for _ in range(num_val_batches):
620 | with self.accelerator.autocast(), torch.no_grad():
621 |
622 | forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter)
623 |
624 | val_loss = self.unwrapped_model(**forward_kwargs)
625 |
626 | total_val_loss += (val_loss / num_val_batches)
627 |
628 | self.print(f'valid loss: {total_val_loss:.3f}')
629 |
630 | self.log(val_loss = total_val_loss)
631 |
632 | self.wait()
633 |
634 | if self.is_main and divisible_by(step, self.checkpoint_every):
635 | checkpoint_num = step // self.checkpoint_every
636 | self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.{checkpoint_num}.pt')
637 |
638 | self.wait()
639 |
640 | self.print('training complete')
641 |
642 |
643 | def train(self, num_epochs, stop_at_loss = None, display_graph = False):
644 | epoch_losses = []
645 | epoch_size = len(self.dataloader)
646 | self.model.train()
647 |
648 | for epoch in range(num_epochs):
649 | total_epoch_loss = 0.0
650 |
651 | progress_bar = tqdm(enumerate(self.dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}', total=len(self.dataloader))
652 | for batch_idx, batch in progress_bar:
653 |
654 | is_last = (batch_idx+1) % self.grad_accum_every == 0
655 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext
656 |
657 | with self.accelerator.autocast(), maybe_no_sync():
658 | total_loss = self.model(**batch)
659 | self.accelerator.backward(total_loss / self.grad_accum_every)
660 |
661 | current_loss = total_loss.item()
662 | total_epoch_loss += current_loss
663 |
664 | progress_bar.set_postfix(loss=current_loss)
665 |
666 | if is_last or (batch_idx + 1 == len(self.dataloader)):
667 | self.optimizer.step()
668 | self.optimizer.zero_grad()
669 |
670 | avg_epoch_loss = total_epoch_loss / epoch_size
671 | epochOut = f'Epoch {epoch + 1} average loss: {avg_epoch_loss}'
672 |
673 |
674 | epoch_losses.append(avg_epoch_loss)
675 |
676 | if len(epoch_losses) >= 4 and avg_epoch_loss > 0:
677 | avg_loss_improvement = sum(epoch_losses[-4:-1]) / 3 - avg_epoch_loss
678 | epochOut += f' avg loss speed: {avg_loss_improvement}'
679 | if avg_loss_improvement > 0 and avg_loss_improvement < 0.2:
680 | epochs_until_0_3 = max(0, abs(avg_epoch_loss-0.3) / avg_loss_improvement)
681 | if epochs_until_0_3 > 0:
682 | epochOut += f' epochs left: {epochs_until_0_3:.2f}'
683 |
684 | self.wait()
685 | self.print(epochOut)
686 |
687 | if self.is_main and self.checkpoint_every_epoch is not None and (self.checkpoint_every_epoch == 1 or (epoch != 0 and epoch % self.checkpoint_every_epoch == 0)):
688 | self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.3f}.pt')
689 |
690 | if stop_at_loss is not None and avg_epoch_loss < stop_at_loss:
691 | self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}')
692 | if self.is_main and self.checkpoint_every_epoch is not None:
693 | self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt')
694 | break
695 |
696 |
697 | self.print('Training complete')
698 | if display_graph:
699 | plt.figure(figsize=(10, 5))
700 | plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, marker='o', label='Total Loss')
701 | plt.title('Training Loss Over Epochs')
702 | plt.xlabel('Epoch')
703 | plt.ylabel('Average Loss')
704 | plt.grid(True)
705 | plt.show()
706 | return epoch_losses[-1]
707 |
--------------------------------------------------------------------------------
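Note (illustrative, not part of the file above): both trainers accept a model plus a `torch.utils.data.Dataset` and can be driven step-wise by calling the trainer itself (`num_train_steps`, with accelerate handling distributed training and autocast) or epoch-wise via `.train(num_epochs)`. A minimal sketch for the autoencoder trainer; `MeshDatasetStub` is a hypothetical stand-in for a real dataset (e.g. the one in `meshgpt_pytorch/mesh_dataset.py`) whose items are dicts keyed by 'vertices' and 'faces':

import torch
from torch.utils.data import Dataset

from meshgpt_pytorch import MeshAutoencoder
from meshgpt_pytorch.trainer import MeshAutoencoderTrainer

class MeshDatasetStub(Dataset):
    # hypothetical toy dataset, purely for illustration
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return dict(
            vertices = torch.randn(121, 3),
            faces = torch.randint(0, 121, (2, 3))
        )

autoencoder = MeshAutoencoder(num_discrete_coors = 128)

trainer = MeshAutoencoderTrainer(
    autoencoder,
    dataset = MeshDatasetStub(),
    num_train_steps = 10_000,
    batch_size = 8,
    grad_accum_every = 4,
    learning_rate = 1e-4,
    checkpoint_every = 1000,
    checkpoint_folder = './checkpoints'
)

trainer()          # step-based loop, with validation / checkpointing as configured
# or: trainer.train(num_epochs = 10, stop_at_loss = 0.3)

trainer.save('./checkpoints/mesh-autoencoder.final.pt')

`MeshTransformerTrainer` is used the same way, with a `MeshTransformer` as the model and, for text conditioning, a dataset that also yields a 'texts' entry.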
/meshgpt_pytorch/typing.py:
--------------------------------------------------------------------------------
1 | from environs import Env
2 |
3 | from torch import Tensor
4 |
5 | from beartype import beartype
6 | from beartype.door import is_bearable
7 |
8 | from jaxtyping import (
9 | Float,
10 | Int,
11 | Bool,
12 | jaxtyped
13 | )
14 |
15 | # environment
16 |
17 | env = Env()
18 | env.read_env()
19 |
20 | # function
21 |
22 | def always(value):
23 | def inner(*args, **kwargs):
24 | return value
25 | return inner
26 |
27 | def identity(t):
28 | return t
29 |
30 | # jaxtyping is a misnomer, works for pytorch
31 |
32 | class TorchTyping:
33 | def __init__(self, abstract_dtype):
34 | self.abstract_dtype = abstract_dtype
35 |
36 | def __getitem__(self, shapes: str):
37 | return self.abstract_dtype[Tensor, shapes]
38 |
39 | Float = TorchTyping(Float)
40 | Int = TorchTyping(Int)
41 | Bool = TorchTyping(Bool)
42 |
43 | # use env variable TYPECHECK to control whether to use beartype + jaxtyping
44 |
45 | should_typecheck = env.bool('TYPECHECK', False)
46 |
47 | typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity
48 |
49 | beartype_isinstance = is_bearable if should_typecheck else always(True)
50 |
51 | __all__ = [
52 | 'Float',
53 | 'Int',
54 | 'Bool',
55 | 'typecheck',
56 | 'beartype_isinstance'
57 | ]
58 |
--------------------------------------------------------------------------------
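Note (illustrative, not part of the file above): runtime type and shape checking is opt-in. `typing.py` reads the TYPECHECK environment variable through `environs` (a local `.env` file is also honoured via `env.read_env()`); when it is truthy, `typecheck` becomes `jaxtyped(typechecker = beartype)` and `beartype_isinstance` becomes `is_bearable`, otherwise both are no-ops. A minimal sketch, where `face_count` is a hypothetical function written just to show the decorator:

import os

# must be set before meshgpt_pytorch.typing is imported
os.environ['TYPECHECK'] = 'True'

import torch
from meshgpt_pytorch.typing import typecheck, Int

@typecheck
def face_count(faces: Int['b nf 3']) -> int:
    # raises at call time if `faces` is not an integer tensor of shape (batch, num faces, 3)
    return faces.shape[1]

face_count(torch.randint(0, 121, (2, 4, 3)))   # ok
face_count(torch.randn(2, 4, 3))               # raises a beartype / jaxtyping error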
/meshgpt_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.5.12'
2 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts = --verbose -s
6 | python_files = tests/*.py
7 | python_paths = "."
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | exec(open('meshgpt_pytorch/version.py').read())
4 |
5 | setup(
6 | name = 'meshgpt-pytorch',
7 | packages = find_packages(exclude=[]),
8 | version = __version__,
9 | license='MIT',
10 | description = 'MeshGPT Pytorch',
11 | author = 'Phil Wang',
12 | author_email = 'lucidrains@gmail.com',
13 | long_description_content_type = 'text/markdown',
14 | url = 'https://github.com/lucidrains/meshgpt-pytorch',
15 | keywords = [
16 | 'artificial intelligence',
17 | 'deep learning',
18 | 'attention mechanisms',
19 | 'transformers',
20 | 'mesh generation'
21 | ],
22 | install_requires=[
23 | 'matplotlib',
24 | 'accelerate>=0.25.0',
25 | 'beartype',
26 | "huggingface_hub>=0.21.4",
27 | 'classifier-free-guidance-pytorch>=0.6.10',
28 | 'einops>=0.8.0',
29 | 'einx[torch]>=0.3.0',
30 | 'ema-pytorch>=0.5.1',
31 | 'environs',
32 | 'gateloop-transformer>=0.2.2',
33 | 'jaxtyping',
34 | 'local-attention>=1.9.11',
35 | 'numpy',
36 | 'pytorch-custom-utils>=0.0.9',
37 | 'rotary-embedding-torch>=0.6.4',
38 | 'sentencepiece',
39 | 'taylor-series-linear-attention>=0.1.6',
40 | 'torch>=2.1',
41 | 'torch_geometric',
42 | 'tqdm',
43 | 'vector-quantize-pytorch>=1.14.22',
44 | 'x-transformers>=1.30.19,<1.31',
45 | ],
46 | setup_requires=[
47 | 'pytest-runner',
48 | ],
49 | tests_require=[
50 | 'pytest'
51 | ],
52 | classifiers=[
53 | 'Development Status :: 4 - Beta',
54 | 'Intended Audience :: Developers',
55 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
56 | 'License :: OSI Approved :: MIT License',
57 | 'Programming Language :: Python :: 3.6',
58 | ],
59 | )
60 |
--------------------------------------------------------------------------------
/tests/test_meshgpt.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from meshgpt_pytorch import (
5 | MeshAutoencoder,
6 | MeshTransformer
7 | )
8 |
9 | @pytest.mark.parametrize('adaptive_rmsnorm', (True, False))
10 | def test_readme(adaptive_rmsnorm):
11 |
12 | autoencoder = MeshAutoencoder(
13 | num_discrete_coors = 128
14 | )
15 |
16 | # mock inputs
17 |
18 | vertices = torch.randn((2, 121, 3)) # (batch, num vertices, coor (3))
19 | faces = torch.randint(0, 121, (2, 2, 3)) # (batch, num faces, vertices (3))
20 |
21 | # forward in the faces
22 |
23 | loss = autoencoder(
24 | vertices = vertices,
25 | faces = faces
26 | )
27 |
28 | loss.backward()
29 |
30 | # after much training...
31 | # you can pass in the raw face data above to train a transformer to model this sequence of face vertices
32 |
33 | transformer = MeshTransformer(
34 | autoencoder,
35 | dim = 512,
36 | max_seq_len = 60,
37 | num_sos_tokens = 1,
38 | fine_cross_attend_text = True,
39 | text_cond_with_film = False,
40 | condition_on_text = True,
41 | coarse_post_gateloop_depth = 1,
42 | coarse_adaptive_rmsnorm = adaptive_rmsnorm
43 | )
44 |
45 | loss = transformer(
46 | vertices = vertices,
47 | faces = faces,
48 | texts = ['a high chair', 'a small teapot']
49 | )
50 |
51 | loss.backward()
52 |
53 | faces_coordinates, face_mask = transformer.generate(texts = ['a small chair'], cond_scale = 3.)
54 |
55 | def test_cache():
56 | # test that the output for generation with and without kv (and optional gateloop) cache is equivalent
57 |
58 | autoencoder = MeshAutoencoder(
59 | num_discrete_coors = 128
60 | )
61 |
62 | transformer = MeshTransformer(
63 | autoencoder,
64 | dim = 512,
65 | max_seq_len = 12
66 | )
67 |
68 | uncached_faces_coors, _ = transformer.generate(cache_kv = False, temperature = 0)
69 | cached_faces_coors, _ = transformer.generate(cache_kv = True, temperature = 0)
70 |
71 | assert torch.allclose(uncached_faces_coors, cached_faces_coors)
72 |
--------------------------------------------------------------------------------
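Note (illustrative, not part of the file above): the suite is plain pytest; `setup.cfg` aliases `python setup.py test` to pytest and applies `--verbose -s` over `tests/*.py`. A sketch of an equivalent direct invocation, assuming the package and a CPU build of torch are installed:

import pytest

# same flags that setup.cfg passes via addopts
pytest.main(['--verbose', '-s', 'tests/'])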