├── .env.sample ├── .github └── workflows │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── MeshGPT_demo.ipynb ├── README.md ├── demo_mesh ├── bar chair.glb ├── circle chair.glb ├── corner table.glb ├── designer chair.glb ├── designer sloped chair.glb ├── glass table.glb ├── high chair.glb ├── office table.glb └── tv table.glb ├── meshgpt.png ├── meshgpt_pytorch ├── __init__.py ├── data.py ├── mesh_dataset.py ├── mesh_render.py ├── meshgpt_pytorch.py ├── trainer.py ├── typing.py └── version.py ├── setup.cfg ├── setup.py ├── shapenet_labels.json └── tests └── test_meshgpt.py /.env.sample: -------------------------------------------------------------------------------- 1 | TYPECHECK=True 2 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@v1.9.0 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cpu 30 | python -m pip install scipy 31 | - name: Test with pytest 32 | run: | 33 | python setup.py test 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MeshGPT_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!pip install -q git+https://github.com/MarcusLoppe/meshgpt-pytorch.git" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 33, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import trimesh\n", 20 | "import numpy as np\n", 21 | "import os\n", 22 | "import csv\n", 23 | "import json\n", 24 | "from collections import OrderedDict\n", 25 | "\n", 26 | "from meshgpt_pytorch import (\n", 27 | " MeshTransformerTrainer,\n", 28 | " MeshAutoencoderTrainer,\n", 29 | " MeshAutoencoder,\n", 30 | " MeshTransformer\n", 31 | ")\n", 32 | "from meshgpt_pytorch.data import ( \n", 33 | " derive_face_edges_from_faces\n", 34 | ") \n", 35 | "\n", 36 | "def get_mesh(file_path): \n", 37 | " mesh = trimesh.load(file_path, force='mesh') \n", 38 | " vertices = mesh.vertices.tolist()\n", 39 | " if \".off\" in file_path: # ModelNet dataset\n", 40 | " mesh.vertices[:, [1, 2]] = mesh.vertices[:, [2, 1]] \n", 41 | " rotation_matrix = trimesh.transformations.rotation_matrix(np.radians(-90), [0, 1, 0])\n", 42 | " mesh.apply_transform(rotation_matrix) \n", 43 | " # Extract vertices and faces from the rotated mesh\n", 44 | " vertices = mesh.vertices.tolist()\n", 45 | " \n", 46 | " faces = mesh.faces.tolist()\n", 47 | " \n", 48 | " centered_vertices = vertices - np.mean(vertices, axis=0) \n", 49 | " max_abs = np.max(np.abs(centered_vertices))\n", 50 | " vertices = centered_vertices / (max_abs / 0.95) # Limit vertices to [-0.95, 0.95]\n", 51 | " \n", 52 | " min_y = np.min(vertices[:, 1]) \n", 53 | " difference = -0.95 - min_y \n", 54 | " vertices[:, 1] += difference\n", 55 | " \n", 56 | " def sort_vertices(vertex):\n", 57 | " return vertex[1], vertex[2], vertex[0] \n", 58 | " \n", 59 | " seen = OrderedDict()\n", 60 | " for point in vertices: \n", 61 | " key = tuple(point)\n", 62 | " if key not in seen:\n", 63 | " seen[key] = point\n", 64 | " \n", 65 | " unique_vertices = list(seen.values()) \n", 66 | " sorted_vertices = sorted(unique_vertices, key=sort_vertices)\n", 67 | " \n", 68 | " vertices_as_tuples = [tuple(v) for v in vertices]\n", 69 | " sorted_vertices_as_tuples = [tuple(v) for v in sorted_vertices]\n", 70 | "\n", 71 | " vertex_map = {old_index: new_index for old_index, vertex_tuple in enumerate(vertices_as_tuples) for new_index, sorted_vertex_tuple in enumerate(sorted_vertices_as_tuples) if vertex_tuple == sorted_vertex_tuple} \n", 72 | " reindexed_faces = [[vertex_map[face[0]], vertex_map[face[1]], vertex_map[face[2]]] for face in faces] \n", 73 | " sorted_faces = [sorted(sub_arr) for sub_arr in reindexed_faces] \n", 74 | " return np.array(sorted_vertices), np.array(sorted_faces)\n", 75 | " \n", 76 | " \n", 77 | "\n", 78 | "def augment_mesh(vertices, scale_factor): \n", 79 | " jitter_factor=0.01 \n", 80 | " possible_values = np.arange(-jitter_factor, jitter_factor , 0.0005) \n", 81 | " offsets = np.random.choice(possible_values, size=vertices.shape) \n", 82 | " vertices = vertices + offsets \n", 83 | " \n", 84 | " vertices = vertices * scale_factor \n", 85 | " # To ensure that the mesh models are on the \"ground\"\n", 86 | " min_y = np.min(vertices[:, 1]) \n", 87 | " difference = -0.95 - min_y \n", 88 | " vertices[:, 1] += difference\n", 89 | " return vertices\n", 90 | "\n", 91 | "\n", 92 | "#load_shapenet(\"./shapenet\", \"./shapenet_csv_files\", 10, 10) \n", 93 | "#Find the csv files with the labels in the ShapeNetCore.v1.zip, download at https://huggingface.co/datasets/ShapeNet/ShapeNetCore-archive \n", 94 | "def load_shapenet(directory, per_category, variations ):\n", 95 | " obj_datas = [] \n", 96 | " chosen_models_count = {} \n", 97 | " print(f\"per_category: {per_category} variations {variations}\")\n", 98 | " \n", 99 | " with open('shapenet_labels.json' , 'r') as f:\n", 100 | " id_info = json.load(f) \n", 101 | " \n", 102 | " possible_values = np.arange(0.75, 1.0 , 0.005) \n", 103 | " scale_factors = np.random.choice(possible_values, size=variations) \n", 104 | " \n", 105 | " for category in os.listdir(directory): \n", 106 | " category_path = os.path.join(directory, category) \n", 107 | " if os.path.isdir(category_path) == False:\n", 108 | " continue \n", 109 | " \n", 110 | " num_files_in_category = len(os.listdir(category_path))\n", 111 | " print(f\"{category_path} got {num_files_in_category} files\") \n", 112 | " chosen_models_count[category] = 0 \n", 113 | " \n", 114 | " for filename in os.listdir(category_path):\n", 115 | " if filename.endswith((\".obj\", \".glb\", \".off\")):\n", 116 | " file_path = os.path.join(category_path, filename)\n", 117 | " \n", 118 | " if chosen_models_count[category] >= per_category:\n", 119 | " break \n", 120 | " if os.path.getsize(file_path) > 20 * 1024: # 20 kb limit = less then 400-600 faces\n", 121 | " continue \n", 122 | " if filename[:-4] not in id_info:\n", 123 | " print(\"Unable to find id info for \", filename)\n", 124 | " continue \n", 125 | " vertices, faces = get_mesh(file_path) \n", 126 | " if len(faces) > 800: \n", 127 | " continue\n", 128 | " \n", 129 | " chosen_models_count[category] += 1 \n", 130 | " textName = id_info[filename[:-4]] \n", 131 | " \n", 132 | " face_edges = derive_face_edges_from_faces(faces) \n", 133 | " for scale_factor in scale_factors: \n", 134 | " aug_vertices = augment_mesh(vertices.copy(), scale_factor) \n", 135 | " obj_data = {\"vertices\": torch.tensor(aug_vertices.tolist(), dtype=torch.float).to(\"cuda\"), \"faces\": torch.tensor(faces.tolist(), dtype=torch.long).to(\"cuda\"), \"face_edges\" : face_edges, \"texts\": textName } \n", 136 | " obj_datas.append(obj_data)\n", 137 | " \n", 138 | " print(\"=\"*25)\n", 139 | " print(\"Chosen models count for each category:\")\n", 140 | " for category, count in chosen_models_count.items():\n", 141 | " print(f\"{category}: {count}\") \n", 142 | " total_chosen_models = sum(chosen_models_count.values())\n", 143 | " print(f\"Total number of chosen models: {total_chosen_models}\")\n", 144 | " return obj_datas\n", 145 | "\n", 146 | " \n", 147 | " \n", 148 | "def load_filename(directory, variations):\n", 149 | " obj_datas = [] \n", 150 | " possible_values = np.arange(0.75, 1.0 , 0.005) \n", 151 | " scale_factors = np.random.choice(possible_values, size=variations) \n", 152 | " \n", 153 | " for filename in os.listdir(directory):\n", 154 | " if filename.endswith((\".obj\", \".glb\", \".off\")): \n", 155 | " file_path = os.path.join(directory, filename) \n", 156 | " vertices, faces = get_mesh(file_path) \n", 157 | " \n", 158 | " faces = torch.tensor(faces.tolist(), dtype=torch.long).to(\"cuda\")\n", 159 | " face_edges = derive_face_edges_from_faces(faces) \n", 160 | " texts, ext = os.path.splitext(filename) \n", 161 | " \n", 162 | " for scale_factor in scale_factors: \n", 163 | " aug_vertices = augment_mesh(vertices.copy(), scale_factor) \n", 164 | " obj_data = {\"vertices\": torch.tensor(aug_vertices.tolist(), dtype=torch.float).to(\"cuda\"), \"faces\": faces, \"face_edges\" : face_edges, \"texts\": texts } \n", 165 | " obj_datas.append(obj_data)\n", 166 | " \n", 167 | " print(f\"[create_mesh_dataset] Returning {len(obj_data)} meshes\")\n", 168 | " return obj_datas" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "import gzip,json\n", 178 | "from tqdm import tqdm\n", 179 | "import pandas as pd\n", 180 | "\n", 181 | "# Instruction to download objverse meshes: https://github.com/MarcusLoppe/Objaverse-downloader/tree/main\n", 182 | "def load_objverse(directory, variations ):\n", 183 | " obj_datas = [] \n", 184 | " id_info = {} \n", 185 | " pali_captions = pd.read_csv('.\\pali_captions.csv', sep=';') # https://github.com/google-deepmind/objaverse_annotations/blob/main/pali_captions.csv\n", 186 | " pali_captions_dict = pali_captions.set_index(\"object_uid\").to_dict()[\"top_aggregate_caption\"] \n", 187 | " \n", 188 | " possible_values = np.arange(0.75, 1.0) \n", 189 | " scale_factors = np.random.choice(possible_values, size=variations) \n", 190 | " \n", 191 | " for folder in os.listdir(directory): \n", 192 | " full_folder_path = os.path.join(directory, folder) \n", 193 | " if os.path.isdir(full_folder_path) == False:\n", 194 | " continue \n", 195 | " \n", 196 | " for filename in tqdm(os.listdir(full_folder_path)): \n", 197 | " if filename.endswith((\".obj\", \".glb\", \".off\")):\n", 198 | " file_path = os.path.join(full_folder_path, filename)\n", 199 | " kb = os.path.getsize(file_path) / 1024 \n", 200 | " if kb < 1 or kb > 30:\n", 201 | " continue\n", 202 | " \n", 203 | " if filename[:-4] not in pali_captions_dict: \n", 204 | " continue \n", 205 | " textName = pali_captions_dict[filename[:-4]]\n", 206 | " try: \n", 207 | " vertices, faces = get_mesh(file_path) \n", 208 | " except Exception as e:\n", 209 | " continue\n", 210 | " \n", 211 | " if len(faces) > 250 or len(faces) < 50: \n", 212 | " continue\n", 213 | " \n", 214 | " faces = torch.tensor(faces.tolist(), dtype=torch.long).to(\"cuda\")\n", 215 | " face_edges = derive_face_edges_from_faces(faces) \n", 216 | " for scale_factor in scale_factors: \n", 217 | " aug_vertices = augment_mesh(vertices.copy(), scale_factor) \n", 218 | " obj_data = {\"filename\": filename, \"vertices\": torch.tensor(aug_vertices.tolist(), dtype=torch.float).to(\"cuda\"), \"faces\": faces, \"face_edges\" : face_edges, \"texts\": textName } \n", 219 | " obj_datas.append(obj_data) \n", 220 | " return obj_datas" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "from pathlib import Path \n", 230 | "import gc \n", 231 | "import os\n", 232 | "from meshgpt_pytorch import MeshDataset \n", 233 | " \n", 234 | "project_name = \"demo_mesh\" \n", 235 | "\n", 236 | "working_dir = f'.\\{project_name}'\n", 237 | "\n", 238 | "working_dir = Path(working_dir)\n", 239 | "working_dir.mkdir(exist_ok = True, parents = True)\n", 240 | "dataset_path = working_dir / (project_name + \".npz\")\n", 241 | " \n", 242 | "if not os.path.isfile(dataset_path):\n", 243 | " data = load_filename(\"./demo_mesh\",50) \n", 244 | " dataset = MeshDataset(data) \n", 245 | " dataset.generate_face_edges() \n", 246 | " dataset.save(dataset_path)\n", 247 | " \n", 248 | "dataset = MeshDataset.load(dataset_path) \n", 249 | "print(dataset.data[0].keys())" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "#### Inspect imported meshes (optional)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "from pathlib import Path\n", 266 | " \n", 267 | "folder = working_dir / f'renders' \n", 268 | "obj_file_path = Path(folder)\n", 269 | "obj_file_path.mkdir(exist_ok = True, parents = True)\n", 270 | " \n", 271 | "all_vertices = []\n", 272 | "all_faces = []\n", 273 | "vertex_offset = 0\n", 274 | "translation_distance = 0.5 \n", 275 | "\n", 276 | "for r, item in enumerate(data): \n", 277 | " vertices_copy = np.copy(item['vertices'])\n", 278 | " vertices_copy += translation_distance * (r / 0.2 - 1) \n", 279 | " \n", 280 | " for vert in vertices_copy:\n", 281 | " vertex = vert.to('cpu')\n", 282 | " all_vertices.append(f\"v {float(vertex[0])} {float(vertex[1])} {float(vertex[2])}\\n\") \n", 283 | " for face in item['faces']:\n", 284 | " all_faces.append(f\"f {face[0]+1+ vertex_offset} {face[1]+ 1+vertex_offset} {face[2]+ 1+vertex_offset}\\n\") \n", 285 | " vertex_offset = len(all_vertices)\n", 286 | " \n", 287 | "obj_file_content = \"\".join(all_vertices) + \"\".join(all_faces)\n", 288 | " \n", 289 | "obj_file_path = f'{folder}/3d_models_inspect.obj' \n", 290 | "with open(obj_file_path, \"w\") as file:\n", 291 | " file.write(obj_file_content) \n", 292 | " " 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "### Train!" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "autoencoder = MeshAutoencoder( \n", 309 | " decoder_dims_through_depth = (128,) * 6 + (192,) * 12 + (256,) * 24 + (384,) * 6, \n", 310 | " codebook_size = 2048, # Smaller vocab size will speed up the transformer training, however if you are training on meshes more then 250 triangle, I'd advice to use 16384 codebook size\n", 311 | " dim_codebook = 192, \n", 312 | " dim_area_embed = 16,\n", 313 | " dim_coor_embed = 16, \n", 314 | " dim_normal_embed = 16,\n", 315 | " dim_angle_embed = 8,\n", 316 | " \n", 317 | " attn_decoder_depth = 4,\n", 318 | " attn_encoder_depth = 2\n", 319 | ").to(\"cuda\") \n", 320 | "total_params = sum(p.numel() for p in autoencoder.parameters()) \n", 321 | "total_params = f\"{total_params / 1000000:.1f}M\"\n", 322 | "print(f\"Total parameters: {total_params}\")" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "**Have at least 400-2000 items in the dataset, use this to multiply the dataset** " 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "dataset.data = [dict(d) for d in dataset.data] * 10\n", 339 | "print(len(dataset.data))" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "*Load previous saved model if you had to restart session*" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "pkg = torch.load(str(f'{working_dir}\\mesh-encoder_{project_name}.pt')) \n", 356 | "autoencoder.load_state_dict(pkg['model'])\n", 357 | "for param in autoencoder.parameters():\n", 358 | " param.requires_grad = True" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "**Train to about 0.3 loss if you are using a small dataset**" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "batch_size=16 # The batch size should be max 64.\n", 375 | "grad_accum_every = 4\n", 376 | "# So set the maximal batch size (max 64) that your VRAM can handle and then use grad_accum_every to create a effective batch size of 64, e.g 16 * 4 = 64\n", 377 | "learning_rate = 1e-3 # Start with 1e-3 then at staggnation around 0.35, you can lower it to 1e-4.\n", 378 | "\n", 379 | "autoencoder.commit_loss_weight = 0.2 # Set dependant on the dataset size, on smaller datasets, 0.1 is fine, otherwise try from 0.25 to 0.4.\n", 380 | "autoencoder_trainer = MeshAutoencoderTrainer(model =autoencoder ,warmup_steps = 10, dataset = dataset, num_train_steps=100,\n", 381 | " batch_size=batch_size,\n", 382 | " grad_accum_every = grad_accum_every,\n", 383 | " learning_rate = learning_rate,\n", 384 | " checkpoint_every_epoch=1) \n", 385 | "loss = autoencoder_trainer.train(480,stop_at_loss = 0.2, diplay_graph= True) " 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "autoencoder_trainer.save(f'{working_dir}\\mesh-encoder_{project_name}.pt') " 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "### Inspect how the autoencoder can encode and then provide the decoder with the codes to reconstruct the mesh" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "import torch\n", 411 | "import random\n", 412 | "from tqdm import tqdm \n", 413 | "from meshgpt_pytorch import mesh_render \n", 414 | "\n", 415 | "min_mse, max_mse = float('inf'), float('-inf')\n", 416 | "min_coords, min_orgs, max_coords, max_orgs = None, None, None, None\n", 417 | "random_samples, random_samples_pred, all_random_samples = [], [], []\n", 418 | "total_mse, sample_size = 0.0, 200\n", 419 | "\n", 420 | "random.shuffle(dataset.data)\n", 421 | "\n", 422 | "for item in tqdm(dataset.data[:sample_size]):\n", 423 | " codes = autoencoder.tokenize(vertices=item['vertices'], faces=item['faces'], face_edges=item['face_edges']) \n", 424 | " \n", 425 | " codes = codes.flatten().unsqueeze(0)\n", 426 | " codes = codes[:, :codes.shape[-1] // autoencoder.num_quantizers * autoencoder.num_quantizers] \n", 427 | " \n", 428 | " coords, mask = autoencoder.decode_from_codes_to_faces(codes)\n", 429 | " orgs = item['vertices'][item['faces']].unsqueeze(0)\n", 430 | "\n", 431 | " mse = torch.mean((orgs.view(-1, 3).cpu() - coords.view(-1, 3).cpu())**2)\n", 432 | " total_mse += mse \n", 433 | "\n", 434 | " if mse < min_mse: min_mse, min_coords, min_orgs = mse, coords, orgs\n", 435 | " if mse > max_mse: max_mse, max_coords, max_orgs = mse, coords, orgs\n", 436 | " \n", 437 | " if len(random_samples) <= 30:\n", 438 | " random_samples.append(coords)\n", 439 | " random_samples_pred.append(orgs)\n", 440 | " else:\n", 441 | " all_random_samples.extend([random_samples_pred, random_samples])\n", 442 | " random_samples, random_samples_pred = [], []\n", 443 | "\n", 444 | "print(f'MSE AVG: {total_mse / sample_size:.10f}, Min: {min_mse:.10f}, Max: {max_mse:.10f}') \n", 445 | "mesh_render.combind_mesh_with_rows(f'{working_dir}\\mse_rows.obj', all_random_samples)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": {}, 451 | "source": [ 452 | "### Training & fine-tuning\n", 453 | "\n", 454 | "**Pre-train:** Train the transformer on the full dataset with all the augmentations, the longer / more epochs will create a more robust model.
\n", 455 | "\n", 456 | "**Fine-tune:** Since it will take a long time to train on all the possible augmentations of the meshes, I recommend that you remove all the augmentations so you are left with x1 model per mesh.
\n", 457 | "Below is the function **filter_dataset** that will return a single copy of each mesh.
\n", 458 | "The function can also check for duplicate labels, this may speed up the fine-tuning process (not recommanded) however this most likely will remove it's ability for novel mesh generation." 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "import gc \n", 468 | "torch.cuda.empty_cache()\n", 469 | "gc.collect() \n", 470 | "max_seq = max(len(d[\"faces\"]) for d in dataset if \"faces\" in d) * (autoencoder.num_vertices_per_face * autoencoder.num_quantizers) \n", 471 | "print(\"Max token sequence:\" , max_seq) \n", 472 | "\n", 473 | "# GPT2-Small model\n", 474 | "transformer = MeshTransformer(\n", 475 | " autoencoder,\n", 476 | " dim = 768,\n", 477 | " coarse_pre_gateloop_depth = 3, \n", 478 | " fine_pre_gateloop_depth= 3, \n", 479 | " attn_depth = 12, \n", 480 | " attn_heads = 12, \n", 481 | " max_seq_len = max_seq, \n", 482 | " condition_on_text = True, \n", 483 | " gateloop_use_heinsen = False,\n", 484 | " dropout = 0.0,\n", 485 | " text_condition_model_types = \"bge\", \n", 486 | " text_condition_cond_drop_prob = 0.0\n", 487 | ") \n", 488 | "\n", 489 | "total_params = sum(p.numel() for p in transformer.decoder.parameters())\n", 490 | "total_params = f\"{total_params / 1000000:.1f}M\"\n", 491 | "print(f\"Decoder total parameters: {total_params}\")" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "def filter_dataset(dataset, unique_labels = False):\n", 501 | " unique_dicts = []\n", 502 | " unique_tensors = set()\n", 503 | " texts = set()\n", 504 | " for d in dataset.data:\n", 505 | " tensor = d[\"faces\"]\n", 506 | " tensor_tuple = tuple(tensor.cpu().numpy().flatten())\n", 507 | " if unique_labels and d['texts'] in texts:\n", 508 | " continue\n", 509 | " if tensor_tuple not in unique_tensors:\n", 510 | " unique_tensors.add(tensor_tuple)\n", 511 | " unique_dicts.append(d)\n", 512 | " texts.add(d['texts'])\n", 513 | " return unique_dicts \n", 514 | "#dataset.data = filter_dataset(dataset.data, unique_labels = False)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": {}, 520 | "source": [ 521 | "## **Required!**, embed the text and run generate_codes to save 4-96 GB VRAM (dependant on dataset) ##\n", 522 | "\n", 523 | "**If you don't;**
\n", 524 | "During each during each training step the autoencoder will generate the codes and the text encoder will embed the text.\n", 525 | "
\n", 526 | "After these fields are generate: **they will be deleted and next time it generates the code again:**
\n", 527 | "\n", 528 | "This is due to the dataloaders nature, it writes this information to a temporary COPY of the dataset\n" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "labels = list(set(item[\"texts\"] for item in dataset.data))\n", 538 | "dataset.embed_texts(transformer, batch_size = 25)\n", 539 | "dataset.generate_codes(autoencoder, batch_size = 50)\n", 540 | "print(dataset.data[0].keys())" 541 | ] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": {}, 546 | "source": [ 547 | "*Load previous saved model if you had to restart session*" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": null, 553 | "metadata": {}, 554 | "outputs": [], 555 | "source": [ 556 | "pkg = torch.load(str(f'{working_dir}\\mesh-transformer_{project_name}.pt')) \n", 557 | "transformer.load_state_dict(pkg['model'])" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "metadata": {}, 563 | "source": [ 564 | "**Train to about 0.0001 loss (or less) if you are using a small dataset**" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": null, 570 | "metadata": {}, 571 | "outputs": [], 572 | "source": [ 573 | "batch_size = 4 # Max 64\n", 574 | "grad_accum_every = 16\n", 575 | "\n", 576 | "# Set the maximal batch size (max 64) that your VRAM can handle and then use grad_accum_every to create a effective batch size of 64, e.g 4 * 16 = 64\n", 577 | "learning_rate = 1e-2 # Start training with the learning rate at 1e-2 then lower it to 1e-3 at stagnation or at 0.5 loss.\n", 578 | "\n", 579 | "trainer = MeshTransformerTrainer(model = transformer,warmup_steps = 10,num_train_steps=100, dataset = dataset,\n", 580 | " grad_accum_every=grad_accum_every,\n", 581 | " learning_rate = learning_rate,\n", 582 | " batch_size=batch_size,\n", 583 | " checkpoint_every_epoch = 1,\n", 584 | " # FP16 training, it doesn't speed up very much but can increase the batch size which will in turn speed up the training.\n", 585 | " # However it might cause nan after a while.\n", 586 | " # accelerator_kwargs = {\"mixed_precision\" : \"fp16\"}, optimizer_kwargs = { \"eps\": 1e-7} \n", 587 | " )\n", 588 | "loss = trainer.train(300, stop_at_loss = 0.005) " 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "trainer.save(f'{working_dir}\\mesh-transformer_{project_name}.pt') " 598 | ] 599 | }, 600 | { 601 | "cell_type": "markdown", 602 | "metadata": {}, 603 | "source": [ 604 | "## Generate and view mesh" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": {}, 610 | "source": [ 611 | "**Using only text**" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": {}, 618 | "outputs": [], 619 | "source": [ 620 | " \n", 621 | "from meshgpt_pytorch import mesh_render \n", 622 | "from pathlib import Path\n", 623 | " \n", 624 | "folder = working_dir / 'renders'\n", 625 | "obj_file_path = Path(folder)\n", 626 | "obj_file_path.mkdir(exist_ok = True, parents = True) \n", 627 | " \n", 628 | "text_coords = [] \n", 629 | "for text in labels[:10]:\n", 630 | " print(f\"Generating {text}\") \n", 631 | " text_coords.append(transformer.generate(texts = [text], temperature = 0.0)) \n", 632 | " \n", 633 | "mesh_render.save_rendering(f'{folder}/3d_models_all.obj', text_coords)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "markdown", 638 | "metadata": {}, 639 | "source": [ 640 | "**Text + prompt of tokens**" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "metadata": {}, 646 | "source": [ 647 | "**Prompt with 10% of codes/tokens**" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": null, 653 | "metadata": {}, 654 | "outputs": [], 655 | "source": [ 656 | "from pathlib import Path \n", 657 | "from meshgpt_pytorch import mesh_render \n", 658 | "folder = working_dir / f'renders/text+codes'\n", 659 | "obj_file_path = Path(folder)\n", 660 | "obj_file_path.mkdir(exist_ok = True, parents = True) \n", 661 | "\n", 662 | "token_length_procent = 0.10 \n", 663 | "codes = []\n", 664 | "texts = []\n", 665 | "for label in labels:\n", 666 | " for item in dataset.data: \n", 667 | " if item['texts'] == label:\n", 668 | " tokens = autoencoder.tokenize(\n", 669 | " vertices = item['vertices'],\n", 670 | " faces = item['faces'],\n", 671 | " face_edges = item['face_edges']\n", 672 | " ) \n", 673 | " num_tokens = int(tokens.shape[0] * token_length_procent) \n", 674 | " texts.append(item['texts']) \n", 675 | " codes.append(tokens.flatten()[:num_tokens].unsqueeze(0)) \n", 676 | " break\n", 677 | " \n", 678 | "coords = [] \n", 679 | "for text, prompt in zip(texts, codes): \n", 680 | " print(f\"Generating {text} with {prompt.shape[1]} tokens\") \n", 681 | " coords.append(transformer.generate(texts = [text], prompt = prompt, temperature = 0) ) \n", 682 | " \n", 683 | "mesh_render.save_rendering(f'{folder}/text+prompt_{token_length_procent*100}.obj', coords)" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": {}, 689 | "source": [ 690 | "**Prompt with 0% to 80% of tokens**" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "metadata": {}, 697 | "outputs": [], 698 | "source": [ 699 | "from pathlib import Path\n", 700 | "from meshgpt_pytorch import mesh_render \n", 701 | " \n", 702 | "folder = working_dir / f'renders/text+codes_rows'\n", 703 | "obj_file_path = Path(folder)\n", 704 | "obj_file_path.mkdir(exist_ok = True, parents = True) \n", 705 | "\n", 706 | "mesh_rows = []\n", 707 | "for token_length_procent in np.arange(0, 0.8, 0.1):\n", 708 | " codes = []\n", 709 | " texts = []\n", 710 | " for label in labels:\n", 711 | " for item in dataset.data: \n", 712 | " if item['texts'] == label:\n", 713 | " tokens = autoencoder.tokenize(\n", 714 | " vertices = item['vertices'],\n", 715 | " faces = item['faces'],\n", 716 | " face_edges = item['face_edges']\n", 717 | " ) \n", 718 | " num_tokens = int(tokens.shape[0] * token_length_procent) \n", 719 | " \n", 720 | " texts.append(item['texts']) \n", 721 | " codes.append(tokens.flatten()[:num_tokens].unsqueeze(0)) \n", 722 | " break\n", 723 | " \n", 724 | " coords = [] \n", 725 | " for text, prompt in zip(texts, codes): \n", 726 | " print(f\"Generating {text} with {prompt.shape[1]} tokens\") \n", 727 | " coords.append(transformer.generate(texts = [text], prompt = prompt, temperature = 0)) \n", 728 | " \n", 729 | " mesh_rows.append(coords) \n", 730 | " \n", 731 | "mesh_render.save_rendering(f'{folder}/all.obj', mesh_rows)\n", 732 | " " 733 | ] 734 | }, 735 | { 736 | "cell_type": "markdown", 737 | "metadata": {}, 738 | "source": [ 739 | "**Just some testing for text embedding similarity**" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "metadata": {}, 746 | "outputs": [], 747 | "source": [ 748 | "import numpy as np \n", 749 | "texts = list(labels)\n", 750 | "vectors = [transformer.conditioner.text_models[0].embed_text([text], return_text_encodings = False).cpu().flatten() for text in texts]\n", 751 | " \n", 752 | "max_label_length = max(len(text) for text in texts)\n", 753 | " \n", 754 | "# Print the table header\n", 755 | "print(f\"{'Text':<{max_label_length}} |\", end=\" \")\n", 756 | "for text in texts:\n", 757 | " print(f\"{text:<{max_label_length}} |\", end=\" \")\n", 758 | "print()\n", 759 | "\n", 760 | "# Print the similarity matrix as a table with fixed-length columns\n", 761 | "for i in range(len(texts)):\n", 762 | " print(f\"{texts[i]:<{max_label_length}} |\", end=\" \")\n", 763 | " for j in range(len(texts)):\n", 764 | " # Encode the texts and calculate cosine similarity manually\n", 765 | " vector_i = vectors[i]\n", 766 | " vector_j = vectors[j]\n", 767 | " \n", 768 | " dot_product = torch.sum(vector_i * vector_j)\n", 769 | " norm_vector1 = torch.norm(vector_i)\n", 770 | " norm_vector2 = torch.norm(vector_j)\n", 771 | " similarity_score = dot_product / (norm_vector1 * norm_vector2)\n", 772 | " \n", 773 | " # Print with fixed-length columns\n", 774 | " print(f\"{similarity_score.item():<{max_label_length}.4f} |\", end=\" \")\n", 775 | " print()" 776 | ] 777 | } 778 | ], 779 | "metadata": { 780 | "kaggle": { 781 | "accelerator": "gpu", 782 | "dataSources": [], 783 | "dockerImageVersionId": 30627, 784 | "isGpuEnabled": true, 785 | "isInternetEnabled": true, 786 | "language": "python", 787 | "sourceType": "notebook" 788 | }, 789 | "kernelspec": { 790 | "display_name": "Python 3", 791 | "language": "python", 792 | "name": "python3" 793 | }, 794 | "language_info": { 795 | "codemirror_mode": { 796 | "name": "ipython", 797 | "version": 3 798 | }, 799 | "file_extension": ".py", 800 | "mimetype": "text/x-python", 801 | "name": "python", 802 | "nbconvert_exporter": "python", 803 | "pygments_lexer": "ipython3", 804 | "version": "3.11.5" 805 | } 806 | }, 807 | "nbformat": 4, 808 | "nbformat_minor": 4 809 | } 810 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## MeshGPT - Pytorch 4 | 5 | Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch 6 | 7 | Will also add text conditioning, for eventual text-to-3d asset 8 | 9 | 10 | Please visit the orginal repo for more details: 11 | https://github.com/lucidrains/meshgpt-pytorch 12 | 13 | ### Data sources: 14 | #### ModelNet40: https://www.kaggle.com/datasets/balraj98/modelnet40-princeton-3d-object-dataset/data 15 | 16 | #### ShapeNet - [Extracted model labels](https://github.com/MarcusLoppe/meshgpt-pytorch/blob/main/shapenet_labels.json) Repository: https://huggingface.co/datasets/ShapeNet/shapenetcore-gltf 17 | 18 | #### Objaverse - [Downloader](https://github.com/MarcusLoppe/Objaverse-downloader/tree/main) Repository: https://huggingface.co/datasets/allenai/objaverse 19 | 20 | ## Pre-trained autoencoder on the Objaverse dataset (14k meshes, only meshes that have max 250 faces): 21 | This is contains only autoencoder model, I'm currently training the transformer model.
22 | Visit the discussions [Pre-trained autoencoder & data sourcing](https://github.com/lucidrains/meshgpt-pytorch/discussions/66) for more information about the training and details about the progression. 23 | 24 | https://drive.google.com/drive/folders/1C1l5QrCtg9UulMJE5n_on4A9O9Gn0CC5?usp=sharing 25 | 26 |
27 | The auto-encoder results shows that it's possible to compress many mesh models into tokens which then can be decoded and reconstruct a mesh near perfection!
28 | The auto-encoder was trained for 9 epochs for 20hrs on a single P100 GPU.

29 | 30 | The more compute heavy part is to train a transformer that can use these tokens learn the auto-encoder 'language'.
31 | Using the codes as a vocabablity and learn the relationship between the the codes and it's ordering requires a lot compute to train compared to the auto-encoder.
32 | So by using a single P100 GPU it will probaly take a few weeks till I can get out a pre-trained transformer. 33 |
34 | Let me know if you wish to donate any compute or I can provide you with the dataset + training notebook. 35 |

36 | 37 | ``` 38 | num_layers = 23 39 | autoencoder = MeshAutoencoder( 40 | decoder_dims_through_depth = (128,) * 3 + (192,) * 4 + (256,) * num_layers + (384,) * 3, 41 | dim_codebook = 192 , 42 | codebook_size = 16384 , 43 | dim_area_embed = 16, 44 | dim_coor_embed = 16, 45 | dim_normal_embed = 16, 46 | dim_angle_embed = 8, 47 | 48 | attn_decoder_depth = 8, 49 | attn_encoder_depth = 4 50 | ).to("cuda") 51 | ``` 52 | 53 | #### Results, it's about 14k models so with the limited training time and hardware It's a great result. 54 | ![bild](https://github.com/lucidrains/meshgpt-pytorch/assets/65302107/18949b70-a982-4d22-9346-0f40ecf21cae) 55 | 56 | ## Citations 57 | 58 | ```bibtex 59 | @inproceedings{Siddiqui2023MeshGPTGT, 60 | title = {MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers}, 61 | author = {Yawar Siddiqui and Antonio Alliegro and Alexey Artemov and Tatiana Tommasi and Daniele Sirigatti and Vladislav Rosov and Angela Dai and Matthias Nie{\ss}ner}, 62 | year = {2023}, 63 | url = {https://api.semanticscholar.org/CorpusID:265457242} 64 | } 65 | ``` 66 | 67 | ```bibtex 68 | @inproceedings{dao2022flashattention, 69 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, 70 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, 71 | booktitle = {Advances in Neural Information Processing Systems}, 72 | year = {2022} 73 | } 74 | ``` 75 | 76 | ```bibtex 77 | @inproceedings{Leviathan2022FastIF, 78 | title = {Fast Inference from Transformers via Speculative Decoding}, 79 | author = {Yaniv Leviathan and Matan Kalman and Y. Matias}, 80 | booktitle = {International Conference on Machine Learning}, 81 | year = {2022}, 82 | url = {https://api.semanticscholar.org/CorpusID:254096365} 83 | } 84 | ``` 85 | 86 | ```bibtex 87 | @misc{yu2023language, 88 | title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation}, 89 | author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang}, 90 | year = {2023}, 91 | eprint = {2310.05737}, 92 | archivePrefix = {arXiv}, 93 | primaryClass = {cs.CV} 94 | } 95 | ``` 96 | 97 | ```bibtex 98 | @article{Lee2022AutoregressiveIG, 99 | title = {Autoregressive Image Generation using Residual Quantization}, 100 | author = {Doyup Lee and Chiheon Kim and Saehoon Kim and Minsu Cho and Wook-Shin Han}, 101 | journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 102 | year = {2022}, 103 | pages = {11513-11522}, 104 | url = {https://api.semanticscholar.org/CorpusID:247244535} 105 | } 106 | ``` 107 | 108 | ```bibtex 109 | @inproceedings{Katsch2023GateLoopFD, 110 | title = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling}, 111 | author = {Tobias Katsch}, 112 | year = {2023}, 113 | url = {https://api.semanticscholar.org/CorpusID:265018962} 114 | } 115 | ``` 116 | -------------------------------------------------------------------------------- /demo_mesh/bar chair.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/bar chair.glb -------------------------------------------------------------------------------- /demo_mesh/circle chair.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/circle chair.glb -------------------------------------------------------------------------------- /demo_mesh/corner table.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/corner table.glb -------------------------------------------------------------------------------- /demo_mesh/designer chair.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/designer chair.glb -------------------------------------------------------------------------------- /demo_mesh/designer sloped chair.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/designer sloped chair.glb -------------------------------------------------------------------------------- /demo_mesh/glass table.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/glass table.glb -------------------------------------------------------------------------------- /demo_mesh/high chair.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/high chair.glb -------------------------------------------------------------------------------- /demo_mesh/office table.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/office table.glb -------------------------------------------------------------------------------- /demo_mesh/tv table.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/demo_mesh/tv table.glb -------------------------------------------------------------------------------- /meshgpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusLoppe/meshgpt-pytorch/8985669c284bd01740e6890b654a288a37896856/meshgpt.png -------------------------------------------------------------------------------- /meshgpt_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from meshgpt_pytorch.meshgpt_pytorch import ( 2 | MeshAutoencoder, 3 | MeshTransformer 4 | ) 5 | 6 | from meshgpt_pytorch.trainer import ( 7 | MeshAutoencoderTrainer, 8 | MeshTransformerTrainer 9 | ) 10 | 11 | from meshgpt_pytorch.data import ( 12 | DatasetFromTransforms, 13 | cache_text_embeds_for_dataset, 14 | cache_face_edges_for_dataset 15 | ) 16 | 17 | from meshgpt_pytorch.mesh_dataset import ( 18 | MeshDataset 19 | ) 20 | from meshgpt_pytorch.mesh_render import ( 21 | save_rendering, 22 | combind_mesh_with_rows 23 | ) 24 | 25 | -------------------------------------------------------------------------------- /meshgpt_pytorch/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from functools import partial 5 | 6 | import torch 7 | from torch import Tensor 8 | from torch import is_tensor 9 | from torch.utils.data import Dataset 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | from numpy.lib.format import open_memmap 13 | 14 | from einops import rearrange, reduce 15 | from torch import nn, Tensor 16 | 17 | from beartype.typing import Tuple, List, Callable, Dict 18 | from meshgpt_pytorch.typing import typecheck 19 | 20 | from torchtyping import TensorType 21 | 22 | from pytorch_custom_utils.utils import pad_or_slice_to 23 | 24 | # helper fn 25 | 26 | def exists(v): 27 | return v is not None 28 | 29 | def identity(t): 30 | return t 31 | 32 | # constants 33 | 34 | Vertices = TensorType['nv', 3, float] # 3 coordinates 35 | Faces = TensorType['nf', 3, int] # 3 vertices 36 | 37 | # decorator for auto-caching texts -> text embeds 38 | 39 | # you would decorate your Dataset class with this 40 | # and then change your `data_kwargs = ["text_embeds", "vertices", "faces"]` 41 | 42 | @typecheck 43 | def cache_text_embeds_for_dataset( 44 | embed_texts_fn: Callable[[List[str]], Tensor], 45 | max_text_len: int, 46 | cache_path: str = './text_embed_cache' 47 | ): 48 | # create path to cache folder 49 | 50 | path = Path(cache_path) 51 | path.mkdir(exist_ok = True, parents = True) 52 | assert path.is_dir() 53 | 54 | # global memmap handles 55 | 56 | text_embed_cache = None 57 | is_cached = None 58 | 59 | # cache function 60 | 61 | def get_maybe_cached_text_embed( 62 | idx: int, 63 | dataset_len: int, 64 | text: str, 65 | memmap_file_mode = 'w+' 66 | ): 67 | nonlocal text_embed_cache 68 | nonlocal is_cached 69 | 70 | # init cache on first call 71 | 72 | if not exists(text_embed_cache): 73 | test_embed = embed_texts_fn(['test']) 74 | feat_dim = test_embed.shape[-1] 75 | shape = (dataset_len, max_text_len, feat_dim) 76 | 77 | text_embed_cache = open_memmap(str(path / 'cache.text_embed.memmap.npy'), mode = memmap_file_mode, dtype = 'float32', shape = shape) 78 | is_cached = open_memmap(str(path / 'cache.is_cached.memmap.npy'), mode = memmap_file_mode, dtype = 'bool', shape = (dataset_len,)) 79 | 80 | # determine whether to fetch from cache 81 | # or call text model 82 | 83 | if is_cached[idx]: 84 | text_embed = torch.from_numpy(text_embed_cache[idx]) 85 | else: 86 | # cache 87 | 88 | text_embed = get_text_embed(text) 89 | text_embed = pad_or_slice_to(text_embed, max_text_len, dim = 0, pad_value = 0.) 90 | 91 | is_cached[idx] = True 92 | text_embed_cache[idx] = text_embed.cpu().numpy() 93 | 94 | mask = ~reduce(text_embed == 0, 'n d -> n', 'all') 95 | return text_embed[mask] 96 | 97 | # get text embedding 98 | 99 | def get_text_embed(text: str): 100 | text_embeds = embed_texts_fn([text]) 101 | return text_embeds[0] 102 | 103 | # inner function 104 | 105 | def inner(dataset_klass): 106 | assert issubclass(dataset_klass, Dataset) 107 | 108 | orig_init = dataset_klass.__init__ 109 | orig_get_item = dataset_klass.__getitem__ 110 | 111 | def __init__( 112 | self, 113 | *args, 114 | cache_memmap_file_mode = 'w+', 115 | **kwargs 116 | ): 117 | orig_init(self, *args, **kwargs) 118 | 119 | self._cache_memmap_file_mode = cache_memmap_file_mode 120 | 121 | if hasattr(self, 'data_kwargs'): 122 | self.data_kwargs = [('text_embeds' if data_kwarg == 'texts' else data_kwarg) for data_kwarg in self.data_kwargs] 123 | 124 | def __getitem__(self, idx): 125 | items = orig_get_item(self, idx) 126 | 127 | get_text_embed_ = partial(get_maybe_cached_text_embed, idx, len(self), memmap_file_mode = self._cache_memmap_file_mode) 128 | 129 | if isinstance(items, dict): 130 | if 'texts' in items: 131 | text_embed = get_text_embed_(items['texts']) 132 | items['text_embeds'] = text_embed 133 | del items['texts'] 134 | 135 | elif isinstance(items, tuple): 136 | new_items = [] 137 | 138 | for maybe_text in items: 139 | if not isinstance(maybe_text, str): 140 | new_items.append(maybe_text) 141 | continue 142 | 143 | new_items.append(get_text_embed_(maybe_text)) 144 | 145 | items = tuple(new_items) 146 | 147 | return items 148 | 149 | dataset_klass.__init__ = __init__ 150 | dataset_klass.__getitem__ = __getitem__ 151 | 152 | return dataset_klass 153 | 154 | return inner 155 | 156 | # decorator for auto-caching face edges 157 | 158 | # you would decorate your Dataset class with this function 159 | # and then change your `data_kwargs = ["vertices", "faces", "face_edges"]` 160 | 161 | @typecheck 162 | def cache_face_edges_for_dataset( 163 | max_edges_len: int, 164 | cache_path: str = './face_edges_cache', 165 | assert_edge_len_lt_max: bool = True, 166 | pad_id = -1 167 | ): 168 | # create path to cache folder 169 | 170 | path = Path(cache_path) 171 | path.mkdir(exist_ok = True, parents = True) 172 | assert path.is_dir() 173 | 174 | # global memmap handles 175 | 176 | face_edges_cache = None 177 | is_cached = None 178 | 179 | # cache function 180 | 181 | def get_maybe_cached_face_edges( 182 | idx: int, 183 | dataset_len: int, 184 | faces: Tensor, 185 | memmap_file_mode = 'w+' 186 | ): 187 | nonlocal face_edges_cache 188 | nonlocal is_cached 189 | 190 | if not exists(face_edges_cache): 191 | # init cache on first call 192 | 193 | shape = (dataset_len, max_edges_len, 2) 194 | face_edges_cache = open_memmap(str(path / 'cache.face_edges_embed.memmap.npy'), mode = memmap_file_mode, dtype = 'float32', shape = shape) 195 | is_cached = open_memmap(str(path / 'cache.is_cached.memmap.npy'), mode = memmap_file_mode, dtype = 'bool', shape = (dataset_len,)) 196 | 197 | # determine whether to fetch from cache 198 | # or call derive face edges function 199 | 200 | if is_cached[idx]: 201 | face_edges = torch.from_numpy(face_edges_cache[idx]) 202 | else: 203 | # cache 204 | 205 | face_edges = derive_face_edges_from_faces(faces, pad_id = pad_id) 206 | 207 | edge_len = face_edges.shape[0] 208 | assert not assert_edge_len_lt_max or (edge_len <= max_edges_len), f'mesh #{idx} has {edge_len} which exceeds the cache length of {max_edges_len}' 209 | 210 | face_edges = pad_or_slice_to(face_edges, max_edges_len, dim = 0, pad_value = pad_id) 211 | 212 | is_cached[idx] = True 213 | face_edges_cache[idx] = face_edges.cpu().numpy() 214 | 215 | mask = reduce(face_edges != pad_id, 'n d -> n', 'all') 216 | return face_edges[mask] 217 | 218 | # inner function 219 | 220 | def inner(dataset_klass): 221 | assert issubclass(dataset_klass, Dataset) 222 | 223 | orig_init = dataset_klass.__init__ 224 | orig_get_item = dataset_klass.__getitem__ 225 | 226 | def __init__( 227 | self, 228 | *args, 229 | cache_memmap_file_mode = 'w+', 230 | **kwargs 231 | ): 232 | orig_init(self, *args, **kwargs) 233 | 234 | self._cache_memmap_file_mode = cache_memmap_file_mode 235 | 236 | if hasattr(self, 'data_kwargs'): 237 | self.data_kwargs.append('face_edges') 238 | 239 | def __getitem__(self, idx): 240 | items = orig_get_item(self, idx) 241 | 242 | get_face_edges_ = partial(get_maybe_cached_face_edges, idx, len(self), memmap_file_mode = self._cache_memmap_file_mode) 243 | 244 | if isinstance(items, dict): 245 | face_edges = get_face_edges_(items['faces']) 246 | items['face_edges'] = face_edges 247 | 248 | elif isinstance(items, tuple): 249 | _, faces, *_ = items 250 | face_edges = get_face_edges_(faces) 251 | items = (*items, face_edges) 252 | 253 | return items 254 | 255 | dataset_klass.__init__ = __init__ 256 | dataset_klass.__getitem__ = __getitem__ 257 | 258 | return dataset_klass 259 | 260 | return inner 261 | 262 | # dataset 263 | 264 | class DatasetFromTransforms(Dataset): 265 | @typecheck 266 | def __init__( 267 | self, 268 | folder: str, 269 | transforms: Dict[str, Callable[[Path], Tuple[Vertices, Faces]]], 270 | data_kwargs: List[str] | None = None, 271 | augment_fn: Callable = identity 272 | ): 273 | folder = Path(folder) 274 | assert folder.exists and folder.is_dir() 275 | self.folder = folder 276 | 277 | exts = transforms.keys() 278 | self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')] 279 | 280 | print(f'{len(self.paths)} training samples found at {folder}') 281 | assert len(self.paths) > 0 282 | 283 | self.transforms = transforms 284 | self.data_kwargs = data_kwargs 285 | self.augment_fn = augment_fn 286 | 287 | def __len__(self): 288 | return len(self.paths) 289 | 290 | def __getitem__(self, idx): 291 | path = self.paths[idx] 292 | ext = path.suffix[1:] 293 | fn = self.transforms[ext] 294 | 295 | out = fn(path) 296 | return self.augment_fn(out) 297 | 298 | # tensor helper functions 299 | 300 | def derive_face_edges_from_faces( 301 | faces: TensorType['b', 'nf', 3, int], 302 | pad_id = -1, 303 | neighbor_if_share_one_vertex = False, 304 | include_self = True 305 | ) -> TensorType['b', 'e', 2, int]: 306 | 307 | is_one_face, device = faces.ndim == 2, faces.device 308 | 309 | if is_one_face: 310 | faces = rearrange(faces, 'nf c -> 1 nf c') 311 | 312 | max_num_faces = faces.shape[1] 313 | face_edges_vertices_threshold = 1 if neighbor_if_share_one_vertex else 2 314 | 315 | all_edges = torch.stack(torch.meshgrid( 316 | torch.arange(max_num_faces, device = device), 317 | torch.arange(max_num_faces, device = device), 318 | indexing = 'ij'), dim = -1) 319 | 320 | face_masks = reduce(faces != pad_id, 'b nf c -> b nf', 'all') 321 | face_edges_masks = rearrange(face_masks, 'b i -> b i 1') & rearrange(face_masks, 'b j -> b 1 j') 322 | 323 | face_edges = [] 324 | 325 | for face, face_edge_mask in zip(faces, face_edges_masks): 326 | 327 | shared_vertices = rearrange(face, 'i c -> i 1 c 1') == rearrange(face, 'j c -> 1 j 1 c') 328 | num_shared_vertices = shared_vertices.any(dim = -1).sum(dim = -1) 329 | 330 | is_neighbor_face = (num_shared_vertices >= face_edges_vertices_threshold) & face_edge_mask 331 | 332 | if not include_self: 333 | is_neighbor_face &= num_shared_vertices != 3 334 | 335 | face_edge = all_edges[is_neighbor_face] 336 | face_edges.append(face_edge) 337 | 338 | face_edges = pad_sequence(face_edges, padding_value = pad_id, batch_first = True) 339 | 340 | if is_one_face: 341 | face_edges = rearrange(face_edges, '1 e ij -> e ij') 342 | 343 | return face_edges 344 | 345 | # custom collater 346 | 347 | def first(it): 348 | return it[0] 349 | 350 | def custom_collate(data, pad_id = -1): 351 | is_dict = isinstance(first(data), dict) 352 | 353 | if is_dict: 354 | keys = first(data).keys() 355 | data = [d.values() for d in data] 356 | 357 | output = [] 358 | 359 | for datum in zip(*data): 360 | if is_tensor(first(datum)): 361 | datum = pad_sequence(datum, batch_first = True, padding_value = pad_id) 362 | else: 363 | datum = list(datum) 364 | output.append(datum) 365 | 366 | output.append(datum) 367 | 368 | output = tuple(output) 369 | 370 | if is_dict: 371 | output = dict(zip(keys, output)) 372 | 373 | return output -------------------------------------------------------------------------------- /meshgpt_pytorch/mesh_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | from torch.nn.utils.rnn import pad_sequence 4 | from tqdm import tqdm 5 | import torch 6 | from meshgpt_pytorch import ( 7 | MeshAutoencoder, 8 | MeshTransformer 9 | ) 10 | 11 | from meshgpt_pytorch.data import ( 12 | derive_face_edges_from_faces 13 | ) 14 | 15 | class MeshDataset(Dataset): 16 | """ 17 | A PyTorch Dataset to load and process mesh data. 18 | The `MeshDataset` provides functions to load mesh data from a file, embed text information, generate face edges, and generate codes. 19 | 20 | Attributes: 21 | data (list): A list of mesh data entries. Each entry is a dictionary containing the following keys: 22 | vertices (torch.Tensor): A tensor of vertices with shape (num_vertices, 3). 23 | faces (torch.Tensor): A tensor of faces with shape (num_faces, 3). 24 | text (str): A string containing the associated text information for the mesh. 25 | text_embeds (torch.Tensor): A tensor of text embeddings for the mesh. 26 | face_edges (torch.Tensor): A tensor of face edges with shape (num_faces, num_edges). 27 | codes (torch.Tensor): A tensor of codes generated from the mesh data. 28 | 29 | Example usage: 30 | 31 | ``` 32 | data = [ 33 | {'vertices': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'faces': torch.tensor([[0, 1, 2]], dtype=torch.long), 'text': 'table'}, 34 | {'vertices': torch.tensor([[10, 20, 30], [40, 50, 60]], dtype=torch.float32), 'faces': torch.tensor([[1, 2, 0]], dtype=torch.long), "text": "chair"}, 35 | ] 36 | 37 | # Create a MeshDataset instance 38 | mesh_dataset = MeshDataset(data) 39 | 40 | # Save the MeshDataset to disk 41 | mesh_dataset.save('mesh_dataset.npz') 42 | 43 | # Load the MeshDataset from disk 44 | loaded_mesh_dataset = MeshDataset.load('mesh_dataset.npz') 45 | 46 | # Generate face edges so it doesn't need to be done every time during training 47 | dataset.generate_face_edges() 48 | ``` 49 | """ 50 | def __init__(self, data): 51 | self.data = data 52 | print(f"[MeshDataset] Created from {len(self.data)} entries") 53 | 54 | def __len__(self): 55 | return len(self.data) 56 | 57 | def __getitem__(self, idx): 58 | data = self.data[idx] 59 | return data 60 | 61 | def save(self, path): 62 | np.savez_compressed(path, self.data, allow_pickle=True) 63 | print(f"[MeshDataset] Saved {len(self.data)} entries at {path}") 64 | 65 | @classmethod 66 | def load(cls, path): 67 | loaded_data = np.load(path, allow_pickle=True) 68 | data = [] 69 | for item in loaded_data["arr_0"]: 70 | data.append(item) 71 | print(f"[MeshDataset] Loaded {len(data)} entries") 72 | return cls(data) 73 | 74 | def sort_dataset_keys(self): 75 | desired_order = ['vertices', 'faces', 'face_edges', 'texts','text_embeds','codes'] 76 | self.data = [ 77 | {key: d[key] for key in desired_order if key in d} for d in self.data 78 | ] 79 | 80 | def generate_face_edges(self, batch_size = 5): 81 | data_to_process = [item for item in self.data if 'faces_edges' not in item] 82 | 83 | total_batches = (len(data_to_process) + batch_size - 1) // batch_size 84 | device = "cuda" if torch.cuda.is_available() else "cpu" 85 | 86 | for i in tqdm(range(0, len(data_to_process), batch_size), total=total_batches): 87 | batch_data = data_to_process[i:i+batch_size] 88 | 89 | if not batch_data: 90 | continue 91 | 92 | padded_batch_faces = pad_sequence( 93 | [item['faces'] for item in batch_data], 94 | batch_first=True, 95 | padding_value=-1 96 | ).to(device) 97 | 98 | batched_faces_edges = derive_face_edges_from_faces(padded_batch_faces, pad_id=-1) 99 | 100 | mask = (batched_faces_edges != -1).all(dim=-1) 101 | for item_idx, (item_edges, item_mask) in enumerate(zip(batched_faces_edges, mask)): 102 | item_edges_masked = item_edges[item_mask] 103 | item = batch_data[item_idx] 104 | item['face_edges'] = item_edges_masked 105 | 106 | self.sort_dataset_keys() 107 | print(f"[MeshDataset] Generated face_edges for {len(data_to_process)} entries") 108 | 109 | def generate_codes(self, autoencoder : MeshAutoencoder, batch_size = 25): 110 | total_batches = (len(self.data) + batch_size - 1) // batch_size 111 | 112 | for i in tqdm(range(0, len(self.data), batch_size), total=total_batches): 113 | batch_data = self.data[i:i+batch_size] 114 | 115 | padded_batch_vertices = pad_sequence([item['vertices'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device) 116 | padded_batch_faces = pad_sequence([item['faces'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device) 117 | padded_batch_face_edges = pad_sequence([item['face_edges'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device) 118 | 119 | batch_codes = autoencoder.tokenize( 120 | vertices=padded_batch_vertices, 121 | faces=padded_batch_faces, 122 | face_edges=padded_batch_face_edges 123 | ) 124 | 125 | 126 | mask = (batch_codes != autoencoder.pad_id).all(dim=-1) 127 | for item_idx, (item_codes, item_mask) in enumerate(zip(batch_codes, mask)): 128 | item_codes_masked = item_codes[item_mask] 129 | item = batch_data[item_idx] 130 | item['codes'] = item_codes_masked 131 | 132 | self.sort_dataset_keys() 133 | print(f"[MeshDataset] Generated codes for {len(self.data)} entries") 134 | 135 | def embed_texts(self, transformer : MeshTransformer, batch_size = 50): 136 | unique_texts = list(set(item['texts'] for item in self.data)) 137 | text_embedding_dict = {} 138 | for i in tqdm(range(0,len(unique_texts), batch_size)): 139 | batch_texts = unique_texts[i:i+batch_size] 140 | text_embeddings = transformer.embed_texts(batch_texts) 141 | mask = (text_embeddings != transformer.conditioner.text_embed_pad_value).all(dim=-1) 142 | 143 | for idx, text in enumerate(batch_texts): 144 | masked_embedding = text_embeddings[idx][mask[idx]] 145 | text_embedding_dict[text] = masked_embedding 146 | 147 | for item in self.data: 148 | if 'texts' in item: 149 | item['text_embeds'] = text_embedding_dict.get(item['texts'], None) 150 | del item['texts'] 151 | 152 | self.sort_dataset_keys() 153 | print(f"[MeshDataset] Generated {len(text_embedding_dict)} text_embeddings") 154 | -------------------------------------------------------------------------------- /meshgpt_pytorch/mesh_render.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | def orient_triangle_upward(v1, v2, v3): 5 | edge1 = v2 - v1 6 | edge2 = v3 - v1 7 | normal = np.cross(edge1, edge2) 8 | normal = normal / np.linalg.norm(normal) 9 | 10 | up = np.array([0, 1, 0]) 11 | if np.dot(normal, up) < 0: 12 | v1, v3 = v3, v1 13 | return v1, v2, v3 14 | 15 | def get_angle(v1, v2, v3): 16 | v1, v2, v3 = orient_triangle_upward(v1, v2, v3) 17 | vec1 = v2 - v1 18 | vec2 = v3 - v1 19 | angle_rad = np.arccos(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))) 20 | return math.degrees(angle_rad) 21 | 22 | def combind_mesh_with_rows(path, meshes): 23 | all_vertices = [] 24 | all_faces = [] 25 | vertex_offset = 0 26 | translation_distance = 0.5 27 | obj_file_content = "" 28 | 29 | for row, mesh in enumerate(meshes): 30 | for r, faces_coordinates in enumerate(mesh): 31 | numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3) 32 | numpy_data[:, 0] += translation_distance * (r / 0.2 - 1) 33 | numpy_data[:, 2] += translation_distance * (row / 0.2 - 1) 34 | 35 | for vertex in numpy_data: 36 | all_vertices.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n") 37 | 38 | for i in range(1, len(numpy_data), 3): 39 | all_faces.append(f"f {i + vertex_offset} {i + 1 + vertex_offset} {i + 2 + vertex_offset}\n") 40 | 41 | vertex_offset += len(numpy_data) 42 | 43 | obj_file_content = "".join(all_vertices) + "".join(all_faces) 44 | 45 | with open(path , "w") as file: 46 | file.write(obj_file_content) 47 | 48 | 49 | def save_rendering(path, input_meshes): 50 | all_vertices,all_faces = [],[] 51 | vertex_offset = 0 52 | translation_distance = 0.5 53 | obj_file_content = "" 54 | meshes = input_meshes if isinstance(input_meshes, list) else [input_meshes] 55 | 56 | for row, mesh in enumerate(meshes): 57 | mesh = mesh if isinstance(mesh, list) else [mesh] 58 | cell_offset = 0 59 | for tensor, mask in mesh: 60 | for tensor_batch, mask_batch in zip(tensor,mask): 61 | numpy_data = tensor_batch[mask_batch].cpu().numpy().reshape(-1, 3) 62 | numpy_data[:, 0] += translation_distance * (cell_offset / 0.2 - 1) 63 | numpy_data[:, 2] += translation_distance * (row / 0.2 - 1) 64 | cell_offset += 1 65 | for vertex in numpy_data: 66 | all_vertices.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n") 67 | 68 | mesh_center = np.mean(numpy_data, axis=0) 69 | for i in range(0, len(numpy_data), 3): 70 | v1 = numpy_data[i] 71 | v2 = numpy_data[i + 1] 72 | v3 = numpy_data[i + 2] 73 | 74 | normal = np.cross(v2 - v1, v3 - v1) 75 | if get_angle(v1, v2, v3) > 60: 76 | direction_vector = mesh_center - np.mean([v1, v2, v3], axis=0) 77 | direction_vector = -direction_vector 78 | else: 79 | direction_vector = [0, 1, 0] 80 | 81 | if np.dot(normal, direction_vector) > 0: 82 | order = [0, 1, 2] 83 | else: 84 | order = [0, 2, 1] 85 | 86 | reordered_vertices = [v1, v2, v3][order[0]], [v1, v2, v3][order[1]], [v1, v2, v3][order[2]] 87 | indices = [np.where((numpy_data == vertex).all(axis=1))[0][0] + 1 + vertex_offset for vertex in reordered_vertices] 88 | all_faces.append(f"f {indices[0]} {indices[1]} {indices[2]}\n") 89 | 90 | vertex_offset += len(numpy_data) 91 | obj_file_content = "".join(all_vertices) + "".join(all_faces) 92 | 93 | with open(path , "w") as file: 94 | file.write(obj_file_content) 95 | 96 | print(f"[Save_rendering] Saved at {path}") 97 | -------------------------------------------------------------------------------- /meshgpt_pytorch/meshgpt_pytorch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from functools import partial 5 | from math import ceil, pi, sqrt 6 | 7 | import torch 8 | from torch import nn, Tensor, einsum 9 | from torch.nn import Module, ModuleList 10 | import torch.nn.functional as F 11 | from torch.utils.checkpoint import checkpoint 12 | from torch.cuda.amp import autocast 13 | 14 | from pytorch_custom_utils import save_load 15 | 16 | from beartype.typing import Tuple, Callable, List, Dict, Any 17 | from meshgpt_pytorch.typing import Float, Int, Bool, typecheck 18 | 19 | from huggingface_hub import PyTorchModelHubMixin, hf_hub_download 20 | 21 | from einops import rearrange, repeat, reduce, pack, unpack 22 | from einops.layers.torch import Rearrange 23 | 24 | from einx import get_at 25 | 26 | from x_transformers import Decoder 27 | from x_transformers.x_transformers import RMSNorm, FeedForward, LayerIntermediates 28 | 29 | from x_transformers.autoregressive_wrapper import ( 30 | eval_decorator, 31 | top_k, 32 | top_p, 33 | ) 34 | 35 | from local_attention import LocalMHA 36 | 37 | from vector_quantize_pytorch import ( 38 | ResidualVQ, 39 | ResidualLFQ 40 | ) 41 | 42 | from meshgpt_pytorch.data import derive_face_edges_from_faces 43 | from meshgpt_pytorch.version import __version__ 44 | 45 | from taylor_series_linear_attention import TaylorSeriesLinearAttn 46 | 47 | from classifier_free_guidance_pytorch import ( 48 | classifier_free_guidance, 49 | TextEmbeddingReturner 50 | ) 51 | 52 | from torch_geometric.nn.conv import SAGEConv 53 | 54 | from gateloop_transformer import SimpleGateLoopLayer 55 | 56 | from tqdm import tqdm 57 | 58 | # helper functions 59 | 60 | def exists(v): 61 | return v is not None 62 | 63 | def default(v, d): 64 | return v if exists(v) else d 65 | 66 | def first(it): 67 | return it[0] 68 | 69 | def identity(t, *args, **kwargs): 70 | return t 71 | 72 | def divisible_by(num, den): 73 | return (num % den) == 0 74 | 75 | def is_odd(n): 76 | return not divisible_by(n, 2) 77 | 78 | def is_empty(x): 79 | return len(x) == 0 80 | 81 | def is_tensor_empty(t: Tensor): 82 | return t.numel() == 0 83 | 84 | def set_module_requires_grad_( 85 | module: Module, 86 | requires_grad: bool 87 | ): 88 | for param in module.parameters(): 89 | param.requires_grad = requires_grad 90 | 91 | def l1norm(t): 92 | return F.normalize(t, dim = -1, p = 1) 93 | 94 | def l2norm(t): 95 | return F.normalize(t, dim = -1, p = 2) 96 | 97 | def safe_cat(tensors, dim): 98 | tensors = [*filter(exists, tensors)] 99 | 100 | if len(tensors) == 0: 101 | return None 102 | elif len(tensors) == 1: 103 | return first(tensors) 104 | 105 | return torch.cat(tensors, dim = dim) 106 | 107 | def pad_at_dim(t, padding, dim = -1, value = 0): 108 | ndim = t.ndim 109 | right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1) 110 | zeros = (0, 0) * right_dims 111 | return F.pad(t, (*zeros, *padding), value = value) 112 | 113 | def pad_to_length(t, length, dim = -1, value = 0, right = True): 114 | curr_length = t.shape[dim] 115 | remainder = length - curr_length 116 | 117 | if remainder <= 0: 118 | return t 119 | 120 | padding = (0, remainder) if right else (remainder, 0) 121 | return pad_at_dim(t, padding, dim = dim, value = value) 122 | 123 | def masked_mean(tensor, mask, dim = -1, eps = 1e-5): 124 | if not exists(mask): 125 | return tensor.mean(dim = dim) 126 | 127 | mask = rearrange(mask, '... -> ... 1') 128 | tensor = tensor.masked_fill(~mask, 0.) 129 | 130 | total_el = mask.sum(dim = dim) 131 | num = tensor.sum(dim = dim) 132 | den = total_el.float().clamp(min = eps) 133 | mean = num / den 134 | mean = mean.masked_fill(total_el == 0, 0.) 135 | return mean 136 | 137 | # continuous embed 138 | 139 | def ContinuousEmbed(dim_cont): 140 | return nn.Sequential( 141 | Rearrange('... -> ... 1'), 142 | nn.Linear(1, dim_cont), 143 | nn.SiLU(), 144 | nn.Linear(dim_cont, dim_cont), 145 | nn.LayerNorm(dim_cont) 146 | ) 147 | 148 | # additional encoder features 149 | # 1. angle (3), 2. area (1), 3. normals (3) 150 | 151 | def derive_angle(x, y, eps = 1e-5): 152 | z = einsum('... d, ... d -> ...', l2norm(x), l2norm(y)) 153 | return z.clip(-1 + eps, 1 - eps).arccos() 154 | 155 | @torch.no_grad() 156 | @typecheck 157 | def get_derived_face_features( 158 | face_coords: Float['b nf nvf 3'] # 3 or 4 vertices with 3 coordinates 159 | ): 160 | is_quad = face_coords.shape[-2] == 4 161 | 162 | # shift face coordinates depending on triangles or quads 163 | 164 | shifted_face_coords = torch.roll(face_coords, 1, dims = (2,)) 165 | 166 | angles = derive_angle(face_coords, shifted_face_coords) 167 | 168 | if is_quad: 169 | # @sbriseid says quads need to be shifted by 2 170 | shifted_face_coords = torch.roll(shifted_face_coords, 1, dims = (2,)) 171 | 172 | edge1, edge2, *_ = (face_coords - shifted_face_coords).unbind(dim = 2) 173 | 174 | cross_product = torch.cross(edge1, edge2, dim = -1) 175 | 176 | normals = l2norm(cross_product) 177 | area = cross_product.norm(dim = -1, keepdim = True) * 0.5 178 | 179 | return dict( 180 | angles = angles, 181 | area = area, 182 | normals = normals 183 | ) 184 | 185 | # tensor helper functions 186 | 187 | @typecheck 188 | def discretize( 189 | t: Tensor, 190 | *, 191 | continuous_range: Tuple[float, float], 192 | num_discrete: int = 128 193 | ) -> Tensor: 194 | lo, hi = continuous_range 195 | assert hi > lo 196 | 197 | t = (t - lo) / (hi - lo) 198 | t *= num_discrete 199 | t -= 0.5 200 | 201 | return t.round().long().clamp(min = 0, max = num_discrete - 1) 202 | 203 | @typecheck 204 | def undiscretize( 205 | t: Tensor, 206 | *, 207 | continuous_range = Tuple[float, float], 208 | num_discrete: int = 128 209 | ) -> Tensor: 210 | lo, hi = continuous_range 211 | assert hi > lo 212 | 213 | t = t.float() 214 | 215 | t += 0.5 216 | t /= num_discrete 217 | return t * (hi - lo) + lo 218 | 219 | @typecheck 220 | def gaussian_blur_1d( 221 | t: Tensor, 222 | *, 223 | sigma: float = 1. 224 | ) -> Tensor: 225 | 226 | _, _, channels, device, dtype = *t.shape, t.device, t.dtype 227 | 228 | width = int(ceil(sigma * 5)) 229 | width += (width + 1) % 2 230 | half_width = width // 2 231 | 232 | distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device) 233 | 234 | gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2)) 235 | gaussian = l1norm(gaussian) 236 | 237 | kernel = repeat(gaussian, 'n -> c 1 n', c = channels) 238 | 239 | t = rearrange(t, 'b n c -> b c n') 240 | out = F.conv1d(t, kernel, padding = half_width, groups = channels) 241 | return rearrange(out, 'b c n -> b n c') 242 | 243 | @typecheck 244 | def scatter_mean( 245 | tgt: Tensor, 246 | indices: Tensor, 247 | src = Tensor, 248 | *, 249 | dim: int = -1, 250 | eps: float = 1e-5 251 | ): 252 | """ 253 | todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_ 254 | """ 255 | num = tgt.scatter_add(dim, indices, src) 256 | den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src)) 257 | return num / den.clamp(min = eps) 258 | 259 | # resnet block 260 | 261 | class FiLM(Module): 262 | def __init__(self, dim, dim_out = None): 263 | super().__init__() 264 | dim_out = default(dim_out, dim) 265 | 266 | self.to_gamma = nn.Linear(dim, dim_out, bias = False) 267 | self.to_beta = nn.Linear(dim, dim_out) 268 | 269 | self.gamma_mult = nn.Parameter(torch.zeros(1,)) 270 | self.beta_mult = nn.Parameter(torch.zeros(1,)) 271 | 272 | def forward(self, x, cond): 273 | gamma, beta = self.to_gamma(cond), self.to_beta(cond) 274 | gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta)) 275 | 276 | # for initializing to identity 277 | 278 | gamma = (1 + self.gamma_mult * gamma.tanh()) 279 | beta = beta.tanh() * self.beta_mult 280 | 281 | # classic film 282 | 283 | return x * gamma + beta 284 | 285 | class PixelNorm(Module): 286 | def __init__(self, dim, eps = 1e-4): 287 | super().__init__() 288 | self.dim = dim 289 | self.eps = eps 290 | 291 | def forward(self, x): 292 | dim = self.dim 293 | return F.normalize(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim]) 294 | 295 | class SqueezeExcite(Module): 296 | def __init__( 297 | self, 298 | dim, 299 | reduction_factor = 4, 300 | min_dim = 16 301 | ): 302 | super().__init__() 303 | dim_inner = max(dim // reduction_factor, min_dim) 304 | 305 | self.net = nn.Sequential( 306 | nn.Linear(dim, dim_inner), 307 | nn.SiLU(), 308 | nn.Linear(dim_inner, dim), 309 | nn.Sigmoid(), 310 | Rearrange('b c -> b c 1') 311 | ) 312 | 313 | def forward(self, x, mask = None): 314 | if exists(mask): 315 | x = x.masked_fill(~mask, 0.) 316 | 317 | num = reduce(x, 'b c n -> b c', 'sum') 318 | den = reduce(mask.float(), 'b 1 n -> b 1', 'sum') 319 | avg = num / den.clamp(min = 1e-5) 320 | else: 321 | avg = reduce(x, 'b c n -> b c', 'mean') 322 | 323 | return x * self.net(avg) 324 | 325 | class Block(Module): 326 | def __init__( 327 | self, 328 | dim, 329 | dim_out = None, 330 | dropout = 0. 331 | ): 332 | super().__init__() 333 | dim_out = default(dim_out, dim) 334 | 335 | self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1) 336 | self.norm = PixelNorm(dim = 1) 337 | self.dropout = nn.Dropout(dropout) 338 | self.act = nn.SiLU() 339 | 340 | def forward(self, x, mask = None): 341 | if exists(mask): 342 | x = x.masked_fill(~mask, 0.) 343 | 344 | x = self.proj(x) 345 | 346 | if exists(mask): 347 | x = x.masked_fill(~mask, 0.) 348 | 349 | x = self.norm(x) 350 | x = self.act(x) 351 | x = self.dropout(x) 352 | 353 | return x 354 | 355 | class ResnetBlock(Module): 356 | def __init__( 357 | self, 358 | dim, 359 | dim_out = None, 360 | *, 361 | dropout = 0. 362 | ): 363 | super().__init__() 364 | dim_out = default(dim_out, dim) 365 | self.block1 = Block(dim, dim_out, dropout = dropout) 366 | self.block2 = Block(dim_out, dim_out, dropout = dropout) 367 | self.excite = SqueezeExcite(dim_out) 368 | self.residual_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 369 | 370 | def forward( 371 | self, 372 | x, 373 | mask = None 374 | ): 375 | res = self.residual_conv(x) 376 | h = self.block1(x, mask = mask) 377 | h = self.block2(h, mask = mask) 378 | h = self.excite(h, mask = mask) 379 | return h + res 380 | 381 | # gateloop layers 382 | 383 | class GateLoopBlock(Module): 384 | def __init__( 385 | self, 386 | dim, 387 | *, 388 | depth, 389 | use_heinsen = True 390 | ): 391 | super().__init__() 392 | self.gateloops = ModuleList([]) 393 | 394 | for _ in range(depth): 395 | gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen) 396 | self.gateloops.append(gateloop) 397 | 398 | def forward( 399 | self, 400 | x, 401 | cache = None 402 | ): 403 | received_cache = exists(cache) 404 | 405 | if is_tensor_empty(x): 406 | return x, None 407 | 408 | if received_cache: 409 | prev, x = x[:, :-1], x[:, -1:] 410 | 411 | cache = default(cache, []) 412 | cache = iter(cache) 413 | 414 | new_caches = [] 415 | for gateloop in self.gateloops: 416 | layer_cache = next(cache, None) 417 | out, new_cache = gateloop(x, cache = layer_cache, return_cache = True) 418 | new_caches.append(new_cache) 419 | x = x + out 420 | 421 | if received_cache: 422 | x = torch.cat((prev, x), dim = -2) 423 | 424 | return x, new_caches 425 | 426 | # main classes 427 | 428 | @save_load(version = __version__) 429 | class MeshAutoencoder(Module): 430 | @typecheck 431 | def __init__( 432 | self, 433 | num_discrete_coors = 128, 434 | coor_continuous_range: Tuple[float, float] = (-1., 1.), 435 | dim_coor_embed = 64, 436 | num_discrete_area = 128, 437 | dim_area_embed = 16, 438 | num_discrete_normals = 128, 439 | dim_normal_embed = 64, 440 | num_discrete_angle = 128, 441 | dim_angle_embed = 16, 442 | encoder_dims_through_depth: Tuple[int, ...] = ( 443 | 64, 128, 256, 256, 576 444 | ), 445 | init_decoder_conv_kernel = 7, 446 | decoder_dims_through_depth: Tuple[int, ...] = ( 447 | 128, 128, 128, 128, 448 | 192, 192, 192, 192, 449 | 256, 256, 256, 256, 256, 256, 450 | 384, 384, 384 451 | ), 452 | dim_codebook = 192, 453 | num_quantizers = 2, # or 'D' in the paper 454 | codebook_size = 16384, # they use 16k, shared codebook between layers 455 | use_residual_lfq = True, # whether to use the latest lookup-free quantization 456 | rq_kwargs: dict = dict( 457 | quantize_dropout = True, 458 | quantize_dropout_cutoff_index = 1, 459 | quantize_dropout_multiple_of = 1, 460 | ), 461 | rvq_kwargs: dict = dict( 462 | kmeans_init = True, 463 | threshold_ema_dead_code = 2, 464 | ), 465 | rlfq_kwargs: dict = dict( 466 | frac_per_sample_entropy = 1., 467 | soft_clamp_input_value = 10. 468 | ), 469 | rvq_stochastic_sample_codes = True, 470 | sageconv_kwargs: dict = dict( 471 | normalize = True, 472 | project = True 473 | ), 474 | commit_loss_weight = 0.1, 475 | bin_smooth_blur_sigma = 0.4, # they blur the one hot discretized coordinate positions 476 | attn_encoder_depth = 0, 477 | attn_decoder_depth = 0, 478 | local_attn_kwargs: dict = dict( 479 | dim_head = 32, 480 | heads = 8 481 | ), 482 | local_attn_window_size = 64, 483 | linear_attn_kwargs: dict = dict( 484 | dim_head = 8, 485 | heads = 16 486 | ), 487 | use_linear_attn = True, 488 | pad_id = -1, 489 | flash_attn = True, 490 | attn_dropout = 0., 491 | ff_dropout = 0., 492 | resnet_dropout = 0, 493 | checkpoint_quantizer = False, 494 | quads = False 495 | ): 496 | super().__init__() 497 | 498 | self.num_vertices_per_face = 3 if not quads else 4 499 | total_coordinates_per_face = self.num_vertices_per_face * 3 500 | 501 | # main face coordinate embedding 502 | 503 | self.num_discrete_coors = num_discrete_coors 504 | self.coor_continuous_range = coor_continuous_range 505 | 506 | self.discretize_face_coords = partial(discretize, num_discrete = num_discrete_coors, continuous_range = coor_continuous_range) 507 | self.coor_embed = nn.Embedding(num_discrete_coors, dim_coor_embed) 508 | 509 | # derived feature embedding 510 | 511 | self.discretize_angle = partial(discretize, num_discrete = num_discrete_angle, continuous_range = (0., pi)) 512 | self.angle_embed = nn.Embedding(num_discrete_angle, dim_angle_embed) 513 | 514 | lo, hi = coor_continuous_range 515 | self.discretize_area = partial(discretize, num_discrete = num_discrete_area, continuous_range = (0., (hi - lo) ** 2)) 516 | self.area_embed = nn.Embedding(num_discrete_area, dim_area_embed) 517 | 518 | self.discretize_normals = partial(discretize, num_discrete = num_discrete_normals, continuous_range = coor_continuous_range) 519 | self.normal_embed = nn.Embedding(num_discrete_normals, dim_normal_embed) 520 | 521 | # attention related 522 | 523 | attn_kwargs = dict( 524 | causal = False, 525 | prenorm = True, 526 | dropout = attn_dropout, 527 | window_size = local_attn_window_size, 528 | ) 529 | 530 | # initial dimension 531 | 532 | init_dim = dim_coor_embed * (3 * self.num_vertices_per_face) + dim_angle_embed * self.num_vertices_per_face + dim_normal_embed * 3 + dim_area_embed 533 | 534 | # project into model dimension 535 | 536 | self.project_in = nn.Linear(init_dim, dim_codebook) 537 | 538 | # initial sage conv 539 | init_encoder_dim, *encoder_dims_through_depth = encoder_dims_through_depth 540 | curr_dim = init_encoder_dim 541 | 542 | self.init_sage_conv = SAGEConv(dim_codebook, init_encoder_dim, **sageconv_kwargs) 543 | 544 | self.init_encoder_act_and_norm = nn.Sequential( 545 | nn.SiLU(), 546 | nn.LayerNorm(init_encoder_dim) 547 | ) 548 | 549 | self.encoders = ModuleList([]) 550 | 551 | for dim_layer in encoder_dims_through_depth: 552 | sage_conv = SAGEConv( 553 | curr_dim, 554 | dim_layer, 555 | **sageconv_kwargs 556 | ) 557 | 558 | self.encoders.append(sage_conv) 559 | curr_dim = dim_layer 560 | 561 | self.encoder_attn_blocks = ModuleList([]) 562 | 563 | for _ in range(attn_encoder_depth): 564 | self.encoder_attn_blocks.append(nn.ModuleList([ 565 | TaylorSeriesLinearAttn(curr_dim, prenorm = True, **linear_attn_kwargs) if use_linear_attn else None, 566 | LocalMHA(dim = curr_dim, **attn_kwargs, **local_attn_kwargs), 567 | nn.Sequential(RMSNorm(curr_dim), FeedForward(curr_dim, glu = True, dropout = ff_dropout)) 568 | ])) 569 | 570 | # residual quantization 571 | 572 | self.codebook_size = codebook_size 573 | self.num_quantizers = num_quantizers 574 | 575 | self.project_dim_codebook = nn.Linear(curr_dim, dim_codebook * self.num_vertices_per_face) 576 | 577 | if use_residual_lfq: 578 | self.quantizer = ResidualLFQ( 579 | dim = dim_codebook, 580 | num_quantizers = num_quantizers, 581 | codebook_size = codebook_size, 582 | commitment_loss_weight = 1., 583 | **rlfq_kwargs, 584 | **rq_kwargs 585 | ) 586 | else: 587 | self.quantizer = ResidualVQ( 588 | dim = dim_codebook, 589 | num_quantizers = num_quantizers, 590 | codebook_size = codebook_size, 591 | shared_codebook = True, 592 | commitment_weight = 1., 593 | stochastic_sample_codes = rvq_stochastic_sample_codes, 594 | **rvq_kwargs, 595 | **rq_kwargs 596 | ) 597 | 598 | self.checkpoint_quantizer = checkpoint_quantizer # whether to memory checkpoint the quantizer 599 | 600 | self.pad_id = pad_id # for variable lengthed faces, padding quantized ids will be set to this value 601 | 602 | # decoder 603 | 604 | decoder_input_dim = dim_codebook * 3 605 | 606 | self.decoder_attn_blocks = ModuleList([]) 607 | 608 | for _ in range(attn_decoder_depth): 609 | self.decoder_attn_blocks.append(nn.ModuleList([ 610 | TaylorSeriesLinearAttn(decoder_input_dim, prenorm = True, **linear_attn_kwargs) if use_linear_attn else None, 611 | LocalMHA(dim = decoder_input_dim, **attn_kwargs, **local_attn_kwargs), 612 | nn.Sequential(RMSNorm(decoder_input_dim), FeedForward(decoder_input_dim, glu = True, dropout = ff_dropout)) 613 | ])) 614 | 615 | init_decoder_dim, *decoder_dims_through_depth = decoder_dims_through_depth 616 | curr_dim = init_decoder_dim 617 | 618 | assert is_odd(init_decoder_conv_kernel) 619 | 620 | self.init_decoder_conv = nn.Sequential( 621 | nn.Conv1d(dim_codebook * self.num_vertices_per_face, init_decoder_dim, kernel_size = init_decoder_conv_kernel, padding = init_decoder_conv_kernel // 2), 622 | nn.SiLU(), 623 | Rearrange('b c n -> b n c'), 624 | nn.LayerNorm(init_decoder_dim), 625 | Rearrange('b n c -> b c n') 626 | ) 627 | 628 | self.decoders = ModuleList([]) 629 | 630 | for dim_layer in decoder_dims_through_depth: 631 | resnet_block = ResnetBlock(curr_dim, dim_layer, dropout = resnet_dropout) 632 | 633 | self.decoders.append(resnet_block) 634 | curr_dim = dim_layer 635 | 636 | self.to_coor_logits = nn.Sequential( 637 | nn.Linear(curr_dim, num_discrete_coors * total_coordinates_per_face), 638 | Rearrange('... (v c) -> ... v c', v = total_coordinates_per_face) 639 | ) 640 | 641 | # loss related 642 | 643 | self.commit_loss_weight = commit_loss_weight 644 | self.bin_smooth_blur_sigma = bin_smooth_blur_sigma 645 | 646 | @property 647 | def device(self): 648 | return next(self.parameters()).device 649 | 650 | @classmethod 651 | def _from_pretrained( 652 | cls, 653 | *, 654 | model_id: str, 655 | revision: str | None, 656 | cache_dir: str | Path | None, 657 | force_download: bool, 658 | proxies: Dict | None, 659 | resume_download: bool, 660 | local_files_only: bool, 661 | token: str | bool | None, 662 | map_location: str = "cpu", 663 | strict: bool = False, 664 | **model_kwargs, 665 | ): 666 | model_filename = "mesh-autoencoder.bin" 667 | model_file = Path(model_id) / model_filename 668 | if not model_file.exists(): 669 | model_file = hf_hub_download( 670 | repo_id=model_id, 671 | filename=model_filename, 672 | revision=revision, 673 | cache_dir=cache_dir, 674 | force_download=force_download, 675 | proxies=proxies, 676 | resume_download=resume_download, 677 | token=token, 678 | local_files_only=local_files_only, 679 | ) 680 | model = cls.init_and_load(model_file,strict=strict) 681 | model.to(map_location) 682 | return model 683 | 684 | @typecheck 685 | def encode( 686 | self, 687 | *, 688 | vertices: Float['b nv 3'], 689 | faces: Int['b nf nvf'], 690 | face_edges: Int['b e 2'], 691 | face_mask: Bool['b nf'], 692 | face_edges_mask: Bool['b e'], 693 | return_face_coordinates = False 694 | ): 695 | """ 696 | einops: 697 | b - batch 698 | nf - number of faces 699 | nv - number of vertices (3) 700 | nvf - number of vertices per face (3 or 4) - triangles vs quads 701 | c - coordinates (3) 702 | d - embed dim 703 | """ 704 | 705 | _, num_faces, num_vertices_per_face = faces.shape 706 | 707 | assert self.num_vertices_per_face == num_vertices_per_face 708 | 709 | face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0) 710 | 711 | # continuous face coords 712 | 713 | face_coords = get_at('b [nv] c, b nf mv -> b nf mv c', vertices, face_without_pad) 714 | 715 | # compute derived features and embed 716 | 717 | derived_features = get_derived_face_features(face_coords) 718 | 719 | discrete_angle = self.discretize_angle(derived_features['angles']) 720 | angle_embed = self.angle_embed(discrete_angle) 721 | 722 | discrete_area = self.discretize_area(derived_features['area']) 723 | area_embed = self.area_embed(discrete_area) 724 | 725 | discrete_normal = self.discretize_normals(derived_features['normals']) 726 | normal_embed = self.normal_embed(discrete_normal) 727 | 728 | # discretize vertices for face coordinate embedding 729 | 730 | discrete_face_coords = self.discretize_face_coords(face_coords) 731 | discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)') # 9 or 12 coordinates per face 732 | 733 | face_coor_embed = self.coor_embed(discrete_face_coords) 734 | face_coor_embed = rearrange(face_coor_embed, 'b nf c d -> b nf (c d)') 735 | 736 | # combine all features and project into model dimension 737 | 738 | face_embed, _ = pack([face_coor_embed, angle_embed, area_embed, normal_embed], 'b nf *') 739 | face_embed = self.project_in(face_embed) 740 | 741 | # handle variable lengths by using masked_select and masked_scatter 742 | 743 | # first handle edges 744 | # needs to be offset by number of faces for each batch 745 | 746 | face_index_offsets = reduce(face_mask.long(), 'b nf -> b', 'sum') 747 | face_index_offsets = F.pad(face_index_offsets.cumsum(dim = 0), (1, -1), value = 0) 748 | face_index_offsets = rearrange(face_index_offsets, 'b -> b 1 1') 749 | 750 | face_edges = face_edges + face_index_offsets 751 | face_edges = face_edges[face_edges_mask] 752 | face_edges = rearrange(face_edges, 'be ij -> ij be') 753 | 754 | # next prepare the face_mask for using masked_select and masked_scatter 755 | 756 | orig_face_embed_shape = face_embed.shape[:2] 757 | 758 | face_embed = face_embed[face_mask] 759 | 760 | # initial sage conv followed by activation and norm 761 | 762 | face_embed = self.init_sage_conv(face_embed, face_edges) 763 | 764 | face_embed = self.init_encoder_act_and_norm(face_embed) 765 | 766 | for conv in self.encoders: 767 | face_embed = conv(face_embed, face_edges) 768 | 769 | shape = (*orig_face_embed_shape, face_embed.shape[-1]) 770 | 771 | face_embed = face_embed.new_zeros(shape).masked_scatter(rearrange(face_mask, '... -> ... 1'), face_embed) 772 | 773 | for linear_attn, attn, ff in self.encoder_attn_blocks: 774 | if exists(linear_attn): 775 | face_embed = linear_attn(face_embed, mask = face_mask) + face_embed 776 | 777 | face_embed = attn(face_embed, mask = face_mask) + face_embed 778 | face_embed = ff(face_embed) + face_embed 779 | 780 | if not return_face_coordinates: 781 | return face_embed 782 | 783 | return face_embed, discrete_face_coords 784 | 785 | @typecheck 786 | def quantize( 787 | self, 788 | *, 789 | faces: Int['b nf nvf'], 790 | face_mask: Bool['b n'], 791 | face_embed: Float['b nf d'], 792 | pad_id = None, 793 | rvq_sample_codebook_temp = 1. 794 | ): 795 | pad_id = default(pad_id, self.pad_id) 796 | batch, device = faces.shape[0], faces.device 797 | 798 | max_vertex_index = faces.amax() 799 | num_vertices = int(max_vertex_index.item() + 1) 800 | 801 | face_embed = self.project_dim_codebook(face_embed) 802 | face_embed = rearrange(face_embed, 'b nf (nvf d) -> b nf nvf d', nvf = self.num_vertices_per_face) 803 | 804 | vertex_dim = face_embed.shape[-1] 805 | vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device) 806 | 807 | # create pad vertex, due to variable lengthed faces 808 | 809 | pad_vertex_id = num_vertices 810 | vertices = pad_at_dim(vertices, (0, 1), dim = -2, value = 0.) 811 | 812 | faces = faces.masked_fill(~rearrange(face_mask, 'b n -> b n 1'), pad_vertex_id) 813 | 814 | # prepare for scatter mean 815 | 816 | faces_with_dim = repeat(faces, 'b nf nvf -> b (nf nvf) d', d = vertex_dim) 817 | 818 | face_embed = rearrange(face_embed, 'b ... d -> b (...) d') 819 | 820 | # scatter mean 821 | 822 | averaged_vertices = scatter_mean(vertices, faces_with_dim, face_embed, dim = -2) 823 | 824 | # mask out null vertex token 825 | 826 | mask = torch.ones((batch, num_vertices + 1), device = device, dtype = torch.bool) 827 | mask[:, -1] = False 828 | 829 | # rvq specific kwargs 830 | 831 | quantize_kwargs = dict(mask = mask) 832 | 833 | if isinstance(self.quantizer, ResidualVQ): 834 | quantize_kwargs.update(sample_codebook_temp = rvq_sample_codebook_temp) 835 | 836 | # a quantize function that makes it memory checkpointable 837 | 838 | def quantize_wrapper_fn(inp): 839 | unquantized, quantize_kwargs = inp 840 | return self.quantizer(unquantized, **quantize_kwargs) 841 | 842 | # maybe checkpoint the quantize fn 843 | 844 | if self.checkpoint_quantizer: 845 | quantize_wrapper_fn = partial(checkpoint, quantize_wrapper_fn, use_reentrant = False) 846 | 847 | # residual VQ 848 | 849 | quantized, codes, commit_loss = quantize_wrapper_fn((averaged_vertices, quantize_kwargs)) 850 | 851 | # gather quantized vertexes back to faces for decoding 852 | # now the faces have quantized vertices 853 | 854 | face_embed_output = get_at('b [n] d, b nf nvf -> b nf (nvf d)', quantized, faces) 855 | 856 | # vertex codes also need to be gathered to be organized by face sequence 857 | # for autoregressive learning 858 | 859 | codes_output = get_at('b [n] q, b nf nvf -> b (nf nvf) q', codes, faces) 860 | 861 | # make sure codes being outputted have this padding 862 | 863 | face_mask = repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face) 864 | codes_output = codes_output.masked_fill(~face_mask, self.pad_id) 865 | 866 | # output quantized, codes, as well as commitment loss 867 | 868 | return face_embed_output, codes_output, commit_loss 869 | 870 | @typecheck 871 | def decode( 872 | self, 873 | quantized: Float['b n d'], 874 | face_mask: Bool['b n'] 875 | ): 876 | conv_face_mask = rearrange(face_mask, 'b n -> b 1 n') 877 | 878 | x = quantized 879 | 880 | for linear_attn, attn, ff in self.decoder_attn_blocks: 881 | if exists(linear_attn): 882 | x = linear_attn(x, mask = face_mask) + x 883 | 884 | x = attn(x, mask = face_mask) + x 885 | x = ff(x) + x 886 | 887 | x = rearrange(x, 'b n d -> b d n') 888 | x = x.masked_fill(~conv_face_mask, 0.) 889 | x = self.init_decoder_conv(x) 890 | 891 | for resnet_block in self.decoders: 892 | x = resnet_block(x, mask = conv_face_mask) 893 | 894 | return rearrange(x, 'b d n -> b n d') 895 | 896 | @typecheck 897 | @torch.no_grad() 898 | def decode_from_codes_to_faces( 899 | self, 900 | codes: Tensor, 901 | face_mask: Bool['b n'] | None = None, 902 | return_discrete_codes = False 903 | ): 904 | codes = rearrange(codes, 'b ... -> b (...)') 905 | 906 | if not exists(face_mask): 907 | face_mask = reduce(codes != self.pad_id, 'b (nf nvf q) -> b nf', 'all', nvf = self.num_vertices_per_face, q = self.num_quantizers) 908 | 909 | # handle different code shapes 910 | 911 | codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers) 912 | 913 | # decode 914 | 915 | quantized = self.quantizer.get_output_from_indices(codes) 916 | quantized = rearrange(quantized, 'b (nf nvf) d -> b nf (nvf d)', nvf = self.num_vertices_per_face) 917 | 918 | decoded = self.decode( 919 | quantized, 920 | face_mask = face_mask 921 | ) 922 | 923 | decoded = decoded.masked_fill(~face_mask[..., None], 0.) 924 | pred_face_coords = self.to_coor_logits(decoded) 925 | 926 | pred_face_coords = pred_face_coords.argmax(dim = -1) 927 | 928 | pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = self.num_vertices_per_face) 929 | 930 | # back to continuous space 931 | 932 | continuous_coors = undiscretize( 933 | pred_face_coords, 934 | num_discrete = self.num_discrete_coors, 935 | continuous_range = self.coor_continuous_range 936 | ) 937 | 938 | # mask out with nan 939 | 940 | continuous_coors = continuous_coors.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1 1'), float('nan')) 941 | 942 | if not return_discrete_codes: 943 | return continuous_coors, face_mask 944 | 945 | return continuous_coors, pred_face_coords, face_mask 946 | 947 | @torch.no_grad() 948 | def tokenize(self, vertices, faces, face_edges = None, **kwargs): 949 | assert 'return_codes' not in kwargs 950 | 951 | inputs = [vertices, faces, face_edges] 952 | inputs = [*filter(exists, inputs)] 953 | ndims = {i.ndim for i in inputs} 954 | 955 | assert len(ndims) == 1 956 | batch_less = first(list(ndims)) == 2 957 | 958 | if batch_less: 959 | inputs = [rearrange(i, '... -> 1 ...') for i in inputs] 960 | 961 | input_kwargs = dict(zip(['vertices', 'faces', 'face_edges'], inputs)) 962 | 963 | self.eval() 964 | 965 | codes = self.forward( 966 | **input_kwargs, 967 | return_codes = True, 968 | **kwargs 969 | ) 970 | 971 | if batch_less: 972 | codes = rearrange(codes, '1 ... -> ...') 973 | 974 | return codes 975 | 976 | @typecheck 977 | def forward( 978 | self, 979 | *, 980 | vertices: Float['b nv 3'], 981 | faces: Int['b nf nvf'], 982 | face_edges: Int['b e 2'] | None = None, 983 | return_codes = False, 984 | return_loss_breakdown = False, 985 | return_recon_faces = False, 986 | only_return_recon_faces = False, 987 | rvq_sample_codebook_temp = 1. 988 | ): 989 | if not exists(face_edges): 990 | face_edges = derive_face_edges_from_faces(faces, pad_id = self.pad_id) 991 | 992 | device = faces.device 993 | 994 | face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all') 995 | face_edges_mask = reduce(face_edges != self.pad_id, 'b e ij -> b e', 'all') 996 | 997 | encoded, face_coordinates = self.encode( 998 | vertices = vertices, 999 | faces = faces, 1000 | face_edges = face_edges, 1001 | face_edges_mask = face_edges_mask, 1002 | face_mask = face_mask, 1003 | return_face_coordinates = True 1004 | ) 1005 | 1006 | quantized, codes, commit_loss = self.quantize( 1007 | face_embed = encoded, 1008 | faces = faces, 1009 | face_mask = face_mask, 1010 | rvq_sample_codebook_temp = rvq_sample_codebook_temp 1011 | ) 1012 | 1013 | if return_codes: 1014 | assert not return_recon_faces, 'cannot return reconstructed faces when just returning raw codes' 1015 | 1016 | codes = codes.masked_fill(~repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face), self.pad_id) 1017 | return codes 1018 | 1019 | decode = self.decode( 1020 | quantized, 1021 | face_mask = face_mask 1022 | ) 1023 | 1024 | pred_face_coords = self.to_coor_logits(decode) 1025 | 1026 | # compute reconstructed faces if needed 1027 | 1028 | if return_recon_faces or only_return_recon_faces: 1029 | 1030 | recon_faces = undiscretize( 1031 | pred_face_coords.argmax(dim = -1), 1032 | num_discrete = self.num_discrete_coors, 1033 | continuous_range = self.coor_continuous_range, 1034 | ) 1035 | 1036 | recon_faces = rearrange(recon_faces, 'b nf (nvf c) -> b nf nvf c', nvf = self.num_vertices_per_face) 1037 | face_mask = rearrange(face_mask, 'b nf -> b nf 1 1') 1038 | recon_faces = recon_faces.masked_fill(~face_mask, float('nan')) 1039 | face_mask = rearrange(face_mask, 'b nf 1 1 -> b nf') 1040 | 1041 | if only_return_recon_faces: 1042 | return recon_faces 1043 | 1044 | # prepare for recon loss 1045 | 1046 | pred_face_coords = rearrange(pred_face_coords, 'b ... c -> b c (...)') 1047 | face_coordinates = rearrange(face_coordinates, 'b ... -> b 1 (...)') 1048 | 1049 | # reconstruction loss on discretized coordinates on each face 1050 | # they also smooth (blur) the one hot positions, localized label smoothing basically 1051 | 1052 | with autocast(enabled = False): 1053 | pred_log_prob = pred_face_coords.log_softmax(dim = 1) 1054 | 1055 | target_one_hot = torch.zeros_like(pred_log_prob).scatter(1, face_coordinates, 1.) 1056 | 1057 | if self.bin_smooth_blur_sigma >= 0.: 1058 | target_one_hot = gaussian_blur_1d(target_one_hot, sigma = self.bin_smooth_blur_sigma) 1059 | 1060 | # cross entropy with localized smoothing 1061 | 1062 | recon_losses = (-target_one_hot * pred_log_prob).sum(dim = 1) 1063 | 1064 | face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = self.num_vertices_per_face * 3) 1065 | recon_loss = recon_losses[face_mask].mean() 1066 | 1067 | # calculate total loss 1068 | 1069 | total_loss = recon_loss + \ 1070 | commit_loss.sum() * self.commit_loss_weight 1071 | 1072 | # calculate loss breakdown if needed 1073 | 1074 | loss_breakdown = (recon_loss, commit_loss) 1075 | 1076 | # some return logic 1077 | 1078 | if not return_loss_breakdown: 1079 | if not return_recon_faces: 1080 | return total_loss 1081 | 1082 | return recon_faces, total_loss 1083 | 1084 | if not return_recon_faces: 1085 | return total_loss, loss_breakdown 1086 | 1087 | return recon_faces, total_loss, loss_breakdown 1088 | 1089 | @save_load(version = __version__) 1090 | class MeshTransformer(Module, PyTorchModelHubMixin): 1091 | @typecheck 1092 | def __init__( 1093 | self, 1094 | autoencoder: MeshAutoencoder, 1095 | *, 1096 | dim: int | Tuple[int, int] = 512, 1097 | max_seq_len = 8192, 1098 | flash_attn = True, 1099 | attn_depth = 12, 1100 | attn_dim_head = 64, 1101 | attn_heads = 16, 1102 | attn_kwargs: dict = dict( 1103 | ff_glu = True, 1104 | attn_num_mem_kv = 4 1105 | ), 1106 | cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition 1107 | dropout = 0., 1108 | coarse_pre_gateloop_depth = 2, 1109 | coarse_post_gateloop_depth = 0, 1110 | coarse_adaptive_rmsnorm = False, 1111 | fine_pre_gateloop_depth = 2, 1112 | gateloop_use_heinsen = False, 1113 | fine_attn_depth = 2, 1114 | fine_attn_dim_head = 32, 1115 | fine_attn_heads = 8, 1116 | fine_cross_attend_text = False, # additional conditioning - fine transformer cross attention to text tokens 1117 | pad_id = -1, 1118 | num_sos_tokens = None, 1119 | condition_on_text = False, 1120 | text_cond_with_film = False, 1121 | text_condition_model_types = ('t5',), 1122 | text_condition_model_kwargs = (dict(),), 1123 | text_condition_cond_drop_prob = 0.25, 1124 | quads = False, 1125 | ): 1126 | super().__init__() 1127 | self.num_vertices_per_face = 3 if not quads else 4 1128 | 1129 | assert autoencoder.num_vertices_per_face == self.num_vertices_per_face, 'autoencoder and transformer must both support the same type of mesh (either all triangles, or all quads)' 1130 | 1131 | dim, dim_fine = (dim, dim) if isinstance(dim, int) else dim 1132 | 1133 | self.autoencoder = autoencoder 1134 | set_module_requires_grad_(autoencoder, False) 1135 | 1136 | self.codebook_size = autoencoder.codebook_size 1137 | self.num_quantizers = autoencoder.num_quantizers 1138 | 1139 | self.eos_token_id = self.codebook_size 1140 | 1141 | # the fine transformer sos token 1142 | # as well as a projection of pooled text embeddings to condition it 1143 | 1144 | num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_text else 4) 1145 | assert num_sos_tokens > 0 1146 | 1147 | self.num_sos_tokens = num_sos_tokens 1148 | self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim)) 1149 | 1150 | # they use axial positional embeddings 1151 | 1152 | assert divisible_by(max_seq_len, self.num_vertices_per_face * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by (3 x {self.num_quantizers}) = {3 * self.num_quantizers}' # 3 or 4 vertices per face, with D codes per vertex 1153 | 1154 | self.token_embed = nn.Embedding(self.codebook_size + 1, dim) 1155 | 1156 | self.quantize_level_embed = nn.Parameter(torch.randn(self.num_quantizers, dim)) 1157 | self.vertex_embed = nn.Parameter(torch.randn(self.num_vertices_per_face, dim)) 1158 | 1159 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) 1160 | 1161 | self.max_seq_len = max_seq_len 1162 | 1163 | # text condition 1164 | 1165 | self.condition_on_text = condition_on_text 1166 | self.conditioner = None 1167 | 1168 | cross_attn_dim_context = None 1169 | dim_text = None 1170 | 1171 | if condition_on_text: 1172 | self.conditioner = TextEmbeddingReturner( 1173 | model_types = text_condition_model_types, 1174 | model_kwargs = text_condition_model_kwargs, 1175 | cond_drop_prob = text_condition_cond_drop_prob, 1176 | text_embed_pad_value = -1. 1177 | ) 1178 | 1179 | dim_text = self.conditioner.dim_latent 1180 | cross_attn_dim_context = dim_text 1181 | 1182 | self.text_coarse_film_cond = FiLM(dim_text, dim) if text_cond_with_film else identity 1183 | self.text_fine_film_cond = FiLM(dim_text, dim_fine) if text_cond_with_film else identity 1184 | 1185 | # for summarizing the vertices of each face 1186 | 1187 | self.to_face_tokens = nn.Sequential( 1188 | nn.Linear(self.num_quantizers * self.num_vertices_per_face * dim, dim), 1189 | nn.LayerNorm(dim) 1190 | ) 1191 | 1192 | self.coarse_gateloop_block = GateLoopBlock(dim, depth = coarse_pre_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None 1193 | 1194 | self.coarse_post_gateloop_block = GateLoopBlock(dim, depth = coarse_post_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None 1195 | 1196 | # main autoregressive attention network 1197 | # attending to a face token 1198 | 1199 | self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm 1200 | 1201 | self.decoder = Decoder( 1202 | dim = dim, 1203 | depth = attn_depth, 1204 | heads = attn_heads, 1205 | attn_dim_head = attn_dim_head, 1206 | attn_flash = flash_attn, 1207 | attn_dropout = dropout, 1208 | ff_dropout = dropout, 1209 | use_adaptive_rmsnorm = coarse_adaptive_rmsnorm, 1210 | dim_condition = dim_text, 1211 | cross_attend = condition_on_text, 1212 | cross_attn_dim_context = cross_attn_dim_context, 1213 | cross_attn_num_mem_kv = cross_attn_num_mem_kv, 1214 | **attn_kwargs 1215 | ) 1216 | 1217 | # projection from coarse to fine, if needed 1218 | 1219 | self.maybe_project_coarse_to_fine = nn.Linear(dim, dim_fine) if dim != dim_fine else nn.Identity() 1220 | 1221 | # address a weakness in attention 1222 | 1223 | self.fine_gateloop_block = GateLoopBlock(dim, depth = fine_pre_gateloop_depth, use_heinsen = gateloop_use_heinsen) if fine_pre_gateloop_depth > 0 else None 1224 | 1225 | # decoding the vertices, 2-stage hierarchy 1226 | 1227 | self.fine_cross_attend_text = condition_on_text and fine_cross_attend_text 1228 | 1229 | self.fine_decoder = Decoder( 1230 | dim = dim_fine, 1231 | depth = fine_attn_depth, 1232 | heads = attn_heads, 1233 | attn_dim_head = attn_dim_head, 1234 | attn_flash = flash_attn, 1235 | attn_dropout = dropout, 1236 | ff_dropout = dropout, 1237 | cross_attend = self.fine_cross_attend_text, 1238 | cross_attn_dim_context = cross_attn_dim_context, 1239 | cross_attn_num_mem_kv = cross_attn_num_mem_kv, 1240 | **attn_kwargs 1241 | ) 1242 | 1243 | # to logits 1244 | 1245 | self.to_logits = nn.Linear(dim_fine, self.codebook_size + 1) 1246 | 1247 | # padding id 1248 | # force the autoencoder to use the same pad_id given in transformer 1249 | 1250 | self.pad_id = pad_id 1251 | autoencoder.pad_id = pad_id 1252 | 1253 | @classmethod 1254 | def _from_pretrained( 1255 | cls, 1256 | *, 1257 | model_id: str, 1258 | revision: str | None, 1259 | cache_dir: str | Path | None, 1260 | force_download: bool, 1261 | proxies: Dict | None, 1262 | resume_download: bool, 1263 | local_files_only: bool, 1264 | token: str | bool | None, 1265 | map_location: str = "cpu", 1266 | strict: bool = False, 1267 | **model_kwargs, 1268 | ): 1269 | model_filename = "mesh-transformer.bin" 1270 | model_file = Path(model_id) / model_filename 1271 | 1272 | if not model_file.exists(): 1273 | model_file = hf_hub_download( 1274 | repo_id=model_id, 1275 | filename=model_filename, 1276 | revision=revision, 1277 | cache_dir=cache_dir, 1278 | force_download=force_download, 1279 | proxies=proxies, 1280 | resume_download=resume_download, 1281 | token=token, 1282 | local_files_only=local_files_only, 1283 | ) 1284 | 1285 | model = cls.init_and_load(model_file,strict=strict) 1286 | model.to(map_location) 1287 | return model 1288 | 1289 | @property 1290 | def device(self): 1291 | return next(self.parameters()).device 1292 | 1293 | @typecheck 1294 | @torch.no_grad() 1295 | def embed_texts(self, texts: str | List[str]): 1296 | single_text = not isinstance(texts, list) 1297 | if single_text: 1298 | texts = [texts] 1299 | 1300 | assert exists(self.conditioner) 1301 | text_embeds = self.conditioner.embed_texts(texts).detach() 1302 | 1303 | if single_text: 1304 | text_embeds = text_embeds[0] 1305 | 1306 | return text_embeds 1307 | 1308 | @eval_decorator 1309 | @torch.no_grad() 1310 | @typecheck 1311 | def generate( 1312 | self, 1313 | prompt: Tensor | None = None, 1314 | batch_size: int | None = None, 1315 | filter_logits_fn: Callable = top_k, 1316 | filter_kwargs: dict = dict(), 1317 | temperature = 1., 1318 | return_codes = False, 1319 | texts: List[str] | None = None, 1320 | text_embeds: Tensor | None = None, 1321 | cond_scale = 1., 1322 | cache_kv = True, 1323 | max_seq_len = None, 1324 | face_coords_to_file: Callable[[Tensor], Any] | None = None 1325 | ): 1326 | max_seq_len = default(max_seq_len, self.max_seq_len) 1327 | 1328 | if exists(prompt): 1329 | assert not exists(batch_size) 1330 | 1331 | prompt = rearrange(prompt, 'b ... -> b (...)') 1332 | assert prompt.shape[-1] <= self.max_seq_len 1333 | 1334 | batch_size = prompt.shape[0] 1335 | 1336 | if self.condition_on_text: 1337 | assert exists(texts) ^ exists(text_embeds), '`text` or `text_embeds` must be passed in if `condition_on_text` is set to True' 1338 | if exists(texts): 1339 | text_embeds = self.embed_texts(texts) 1340 | 1341 | batch_size = default(batch_size, text_embeds.shape[0]) 1342 | 1343 | batch_size = default(batch_size, 1) 1344 | 1345 | codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device)) 1346 | 1347 | curr_length = codes.shape[-1] 1348 | 1349 | cache = (None, None) 1350 | 1351 | for i in tqdm(range(curr_length, max_seq_len)): 1352 | 1353 | # example below for triangles, extrapolate for quads 1354 | # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F) 1355 | 1356 | can_eos = i != 0 and divisible_by(i, self.num_quantizers * self.num_vertices_per_face) # only allow for eos to be decoded at the end of each face, defined as 3 or 4 vertices with D residual VQ codes 1357 | 1358 | output = self.forward_on_codes( 1359 | codes, 1360 | text_embeds = text_embeds, 1361 | return_loss = False, 1362 | return_cache = cache_kv, 1363 | append_eos = False, 1364 | cond_scale = cond_scale, 1365 | cfg_routed_kwargs = dict( 1366 | cache = cache 1367 | ) 1368 | ) 1369 | 1370 | if cache_kv: 1371 | logits, cache = output 1372 | 1373 | if cond_scale == 1.: 1374 | cache = (cache, None) 1375 | else: 1376 | logits = output 1377 | 1378 | logits = logits[:, -1] 1379 | 1380 | if not can_eos: 1381 | logits[:, -1] = -torch.finfo(logits.dtype).max 1382 | 1383 | filtered_logits = filter_logits_fn(logits, **filter_kwargs) 1384 | 1385 | if temperature == 0.: 1386 | sample = filtered_logits.argmax(dim = -1) 1387 | else: 1388 | probs = F.softmax(filtered_logits / temperature, dim = -1) 1389 | sample = torch.multinomial(probs, 1) 1390 | 1391 | codes, _ = pack([codes, sample], 'b *') 1392 | 1393 | # check for all rows to have [eos] to terminate 1394 | 1395 | is_eos_codes = (codes == self.eos_token_id) 1396 | 1397 | if is_eos_codes.any(dim = -1).all(): 1398 | break 1399 | 1400 | # mask out to padding anything after the first eos 1401 | 1402 | mask = is_eos_codes.float().cumsum(dim = -1) >= 1 1403 | codes = codes.masked_fill(mask, self.pad_id) 1404 | 1405 | # remove a potential extra token from eos, if breaked early 1406 | 1407 | code_len = codes.shape[-1] 1408 | round_down_code_len = code_len // self.num_vertices_per_face * self.num_vertices_per_face 1409 | codes = codes[:, :round_down_code_len] 1410 | 1411 | # early return of raw residual quantizer codes 1412 | 1413 | if return_codes: 1414 | codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers) 1415 | return codes 1416 | 1417 | self.autoencoder.eval() 1418 | face_coords, face_mask = self.autoencoder.decode_from_codes_to_faces(codes) 1419 | 1420 | if not exists(face_coords_to_file): 1421 | return face_coords, face_mask 1422 | 1423 | files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)] 1424 | return files 1425 | 1426 | def forward( 1427 | self, 1428 | *, 1429 | vertices: Int['b nv 3'], 1430 | faces: Int['b nf nvf'], 1431 | face_edges: Int['b e 2'] | None = None, 1432 | codes: Tensor | None = None, 1433 | cache: LayerIntermediates | None = None, 1434 | **kwargs 1435 | ): 1436 | if not exists(codes): 1437 | codes = self.autoencoder.tokenize( 1438 | vertices = vertices, 1439 | faces = faces, 1440 | face_edges = face_edges 1441 | ) 1442 | 1443 | return self.forward_on_codes(codes, cache = cache, **kwargs) 1444 | 1445 | @classifier_free_guidance 1446 | def forward_on_codes( 1447 | self, 1448 | codes = None, 1449 | return_loss = True, 1450 | return_cache = False, 1451 | append_eos = True, 1452 | cache = None, 1453 | texts: List[str] | None = None, 1454 | text_embeds: Tensor | None = None, 1455 | cond_drop_prob = None 1456 | ): 1457 | # handle text conditions 1458 | 1459 | attn_context_kwargs = dict() 1460 | 1461 | if self.condition_on_text: 1462 | assert exists(texts) ^ exists(text_embeds), '`text` or `text_embeds` must be passed in if `condition_on_text` is set to True' 1463 | 1464 | if exists(texts): 1465 | text_embeds = self.conditioner.embed_texts(texts) 1466 | 1467 | if exists(codes): 1468 | assert text_embeds.shape[0] == codes.shape[0], 'batch size of texts or text embeddings is not equal to the batch size of the mesh codes' 1469 | 1470 | _, maybe_dropped_text_embeds = self.conditioner( 1471 | text_embeds = text_embeds, 1472 | cond_drop_prob = cond_drop_prob 1473 | ) 1474 | 1475 | text_embed, text_mask = maybe_dropped_text_embeds 1476 | 1477 | pooled_text_embed = masked_mean(text_embed, text_mask, dim = 1) 1478 | 1479 | attn_context_kwargs = dict( 1480 | context = text_embed, 1481 | context_mask = text_mask 1482 | ) 1483 | 1484 | if self.coarse_adaptive_rmsnorm: 1485 | attn_context_kwargs.update( 1486 | condition = pooled_text_embed 1487 | ) 1488 | 1489 | # take care of codes that may be flattened 1490 | 1491 | if codes.ndim > 2: 1492 | codes = rearrange(codes, 'b ... -> b (...)') 1493 | 1494 | # get some variable 1495 | 1496 | batch, seq_len, device = *codes.shape, codes.device 1497 | 1498 | assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}' 1499 | 1500 | # auto append eos token 1501 | 1502 | if append_eos: 1503 | assert exists(codes) 1504 | 1505 | code_lens = ((codes == self.pad_id).cumsum(dim = -1) == 0).sum(dim = -1) 1506 | 1507 | codes = F.pad(codes, (0, 1), value = 0) 1508 | 1509 | batch_arange = torch.arange(batch, device = device) 1510 | 1511 | batch_arange = rearrange(batch_arange, '... -> ... 1') 1512 | code_lens = rearrange(code_lens, '... -> ... 1') 1513 | 1514 | codes[batch_arange, code_lens] = self.eos_token_id 1515 | 1516 | # if returning loss, save the labels for cross entropy 1517 | 1518 | if return_loss: 1519 | assert seq_len > 0 1520 | codes, labels = codes[:, :-1], codes 1521 | 1522 | # token embed (each residual VQ id) 1523 | 1524 | codes = codes.masked_fill(codes == self.pad_id, 0) 1525 | codes = self.token_embed(codes) 1526 | 1527 | # codebook embed + absolute positions 1528 | 1529 | seq_arange = torch.arange(codes.shape[-2], device = device) 1530 | 1531 | codes = codes + self.abs_pos_emb(seq_arange) 1532 | 1533 | # embedding for quantizer level 1534 | 1535 | code_len = codes.shape[1] 1536 | 1537 | level_embed = repeat(self.quantize_level_embed, 'q d -> (r q) d', r = ceil(code_len / self.num_quantizers)) 1538 | codes = codes + level_embed[:code_len] 1539 | 1540 | # embedding for each vertex 1541 | 1542 | vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (self.num_vertices_per_face * self.num_quantizers)), q = self.num_quantizers) 1543 | codes = codes + vertex_embed[:code_len] 1544 | 1545 | # create a token per face, by summarizing the 3 or 4 vertices 1546 | # this is similar in design to the RQ transformer from Lee et al. https://arxiv.org/abs/2203.01941 1547 | 1548 | num_tokens_per_face = self.num_quantizers * self.num_vertices_per_face 1549 | 1550 | curr_vertex_pos = code_len % num_tokens_per_face # the current intra-face vertex-code position id, needed for caching at the fine decoder stage 1551 | 1552 | code_len_is_multiple_of_face = divisible_by(code_len, num_tokens_per_face) 1553 | 1554 | next_multiple_code_len = ceil(code_len / num_tokens_per_face) * num_tokens_per_face 1555 | 1556 | codes = pad_to_length(codes, next_multiple_code_len, dim = -2) 1557 | 1558 | # grouped codes will be used for the second stage 1559 | 1560 | grouped_codes = rearrange(codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face) 1561 | 1562 | # create the coarse tokens for the first attention network 1563 | 1564 | face_codes = grouped_codes if code_len_is_multiple_of_face else grouped_codes[:, :-1] 1565 | face_codes = rearrange(face_codes, 'b nf n d -> b nf (n d)') 1566 | face_codes = self.to_face_tokens(face_codes) 1567 | 1568 | face_codes_len = face_codes.shape[-2] 1569 | 1570 | # cache logic 1571 | 1572 | ( 1573 | cached_attended_face_codes, 1574 | coarse_cache, 1575 | fine_cache, 1576 | coarse_gateloop_cache, 1577 | coarse_post_gateloop_cache, 1578 | fine_gateloop_cache 1579 | ) = cache if exists(cache) else ((None,) * 6) 1580 | 1581 | if exists(cache): 1582 | cached_face_codes_len = cached_attended_face_codes.shape[-2] 1583 | cached_face_codes_len_without_sos = cached_face_codes_len - 1 1584 | 1585 | need_call_first_transformer = face_codes_len > cached_face_codes_len_without_sos 1586 | else: 1587 | # auto prepend sos token 1588 | 1589 | sos = repeat(self.sos_token, 'n d -> b n d', b = batch) 1590 | face_codes, packed_sos_shape = pack([sos, face_codes], 'b * d') 1591 | 1592 | # if no kv cache, always call first transformer 1593 | 1594 | need_call_first_transformer = True 1595 | 1596 | should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face) 1597 | 1598 | # condition face codes with text if needed 1599 | 1600 | if self.condition_on_text: 1601 | face_codes = self.text_coarse_film_cond(face_codes, pooled_text_embed) 1602 | 1603 | # attention on face codes (coarse) 1604 | 1605 | if need_call_first_transformer: 1606 | if exists(self.coarse_gateloop_block): 1607 | face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache) 1608 | 1609 | attended_face_codes, coarse_cache = self.decoder( 1610 | face_codes, 1611 | cache = coarse_cache, 1612 | return_hiddens = True, 1613 | **attn_context_kwargs 1614 | ) 1615 | 1616 | if exists(self.coarse_post_gateloop_block): 1617 | face_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(face_codes, cache = coarse_post_gateloop_cache) 1618 | 1619 | else: 1620 | attended_face_codes = None 1621 | 1622 | attended_face_codes = safe_cat((cached_attended_face_codes, attended_face_codes), dim = -2) 1623 | 1624 | # if calling without kv cache, pool the sos tokens, if greater than 1 sos token 1625 | 1626 | if not exists(cache): 1627 | sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d') 1628 | last_sos_token = sos_tokens[:, -1:] 1629 | attended_face_codes = torch.cat((last_sos_token, attended_face_codes), dim = 1) 1630 | 1631 | # maybe project from coarse to fine dimension for hierarchical transformers 1632 | 1633 | attended_face_codes = self.maybe_project_coarse_to_fine(attended_face_codes) 1634 | 1635 | grouped_codes = pad_to_length(grouped_codes, attended_face_codes.shape[-2], dim = 1) 1636 | fine_vertex_codes, _ = pack([attended_face_codes, grouped_codes], 'b n * d') 1637 | 1638 | fine_vertex_codes = fine_vertex_codes[..., :-1, :] 1639 | 1640 | # gateloop layers 1641 | 1642 | if exists(self.fine_gateloop_block): 1643 | fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> b (nf n) d') 1644 | orig_length = fine_vertex_codes.shape[-2] 1645 | fine_vertex_codes = fine_vertex_codes[:, :(code_len + 1)] 1646 | 1647 | fine_vertex_codes, fine_gateloop_cache = self.fine_gateloop_block(fine_vertex_codes, cache = fine_gateloop_cache) 1648 | 1649 | fine_vertex_codes = pad_to_length(fine_vertex_codes, orig_length, dim = -2) 1650 | fine_vertex_codes = rearrange(fine_vertex_codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face) 1651 | 1652 | # fine attention - 2nd stage 1653 | 1654 | if exists(cache): 1655 | fine_vertex_codes = fine_vertex_codes[:, -1:] 1656 | 1657 | if exists(fine_cache): 1658 | for attn_intermediate in fine_cache.attn_intermediates: 1659 | ck, cv = attn_intermediate.cached_kv 1660 | ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)] 1661 | 1662 | # when operating on the cached key / values, treat self attention and cross attention differently 1663 | 1664 | layer_type = attn_intermediate.layer_type 1665 | 1666 | if layer_type == 'a': 1667 | ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)] 1668 | elif layer_type == 'c': 1669 | ck, cv = [t[:, -1, ...] for t in (ck, cv)] 1670 | 1671 | attn_intermediate.cached_kv = (ck, cv) 1672 | 1673 | num_faces = fine_vertex_codes.shape[1] 1674 | one_face = num_faces == 1 1675 | 1676 | fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> (b nf) n d') 1677 | 1678 | if one_face: 1679 | fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)] 1680 | 1681 | # handle maybe cross attention conditioning of fine transformer with text 1682 | 1683 | fine_attn_context_kwargs = dict() 1684 | 1685 | # optional text cross attention conditioning for fine transformer 1686 | 1687 | if self.fine_cross_attend_text: 1688 | repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0] 1689 | 1690 | text_embed = repeat(text_embed, 'b ... -> (b r) ...' , r = repeat_batch) 1691 | text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch) 1692 | 1693 | fine_attn_context_kwargs = dict( 1694 | context = text_embed, 1695 | context_mask = text_mask 1696 | ) 1697 | 1698 | # also film condition the fine vertex codes 1699 | 1700 | if self.condition_on_text: 1701 | repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0] 1702 | 1703 | pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch) 1704 | fine_vertex_codes = self.text_fine_film_cond(fine_vertex_codes, pooled_text_embed) 1705 | 1706 | # fine transformer 1707 | 1708 | attended_vertex_codes, fine_cache = self.fine_decoder( 1709 | fine_vertex_codes, 1710 | cache = fine_cache, 1711 | **fine_attn_context_kwargs, 1712 | return_hiddens = True 1713 | ) 1714 | 1715 | if not should_cache_fine: 1716 | fine_cache = None 1717 | 1718 | if not one_face: 1719 | # reconstitute original sequence 1720 | 1721 | embed = rearrange(attended_vertex_codes, '(b nf) n d -> b (nf n) d', b = batch) 1722 | embed = embed[:, :(code_len + 1)] 1723 | else: 1724 | embed = attended_vertex_codes 1725 | 1726 | # logits 1727 | 1728 | logits = self.to_logits(embed) 1729 | 1730 | if not return_loss: 1731 | if not return_cache: 1732 | return logits 1733 | 1734 | next_cache = ( 1735 | attended_face_codes, 1736 | coarse_cache, 1737 | fine_cache, 1738 | coarse_gateloop_cache, 1739 | coarse_post_gateloop_cache, 1740 | fine_gateloop_cache 1741 | ) 1742 | 1743 | return logits, next_cache 1744 | 1745 | # loss 1746 | 1747 | ce_loss = F.cross_entropy( 1748 | rearrange(logits, 'b n c -> b c n'), 1749 | labels, 1750 | ignore_index = self.pad_id 1751 | ) 1752 | 1753 | return ce_loss 1754 | -------------------------------------------------------------------------------- /meshgpt_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from functools import partial 5 | from packaging import version 6 | from contextlib import nullcontext 7 | 8 | import torch 9 | from torch.nn import Module 10 | from torch.utils.data import Dataset, DataLoader 11 | from torch.optim.lr_scheduler import _LRScheduler 12 | 13 | from pytorch_custom_utils import ( 14 | get_adam_optimizer, 15 | OptimizerWithWarmupSchedule 16 | ) 17 | 18 | from accelerate import Accelerator 19 | from accelerate.utils import DistributedDataParallelKwargs 20 | 21 | 22 | from beartype.typing import Tuple, Type, List 23 | from meshgpt_pytorch.typing import typecheck, beartype_isinstance 24 | 25 | from ema_pytorch import EMA 26 | 27 | from meshgpt_pytorch.data import custom_collate 28 | 29 | from meshgpt_pytorch.version import __version__ 30 | import matplotlib.pyplot as plt 31 | from tqdm import tqdm 32 | from meshgpt_pytorch.meshgpt_pytorch import ( 33 | MeshAutoencoder, 34 | MeshTransformer 35 | ) 36 | 37 | # constants 38 | 39 | DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs( 40 | find_unused_parameters = True 41 | ) 42 | 43 | # helper functions 44 | 45 | def exists(v): 46 | return v is not None 47 | 48 | def default(v, d): 49 | return v if exists(v) else d 50 | 51 | def divisible_by(num, den): 52 | return (num % den) == 0 53 | 54 | def cycle(dl): 55 | while True: 56 | for data in dl: 57 | yield data 58 | 59 | def maybe_del(d: dict, *keys): 60 | for key in keys: 61 | if key not in d: 62 | continue 63 | 64 | del d[key] 65 | 66 | # autoencoder trainer 67 | 68 | class MeshAutoencoderTrainer(Module): 69 | @typecheck 70 | def __init__( 71 | self, 72 | model: MeshAutoencoder, 73 | dataset: Dataset, 74 | num_train_steps: int, 75 | batch_size: int, 76 | grad_accum_every: int, 77 | val_dataset: Dataset | None = None, 78 | val_every: int = 100, 79 | val_num_batches: int = 5, 80 | learning_rate: float = 1e-4, 81 | weight_decay: float = 0., 82 | max_grad_norm: float | None = None, 83 | ema_kwargs: dict = dict( 84 | use_foreach = True 85 | ), 86 | scheduler: Type[_LRScheduler] | None = None, 87 | scheduler_kwargs: dict = dict(), 88 | accelerator_kwargs: dict = dict(), 89 | optimizer_kwargs: dict = dict(), 90 | checkpoint_every = 1000, 91 | checkpoint_every_epoch: Type[int] | None = None, 92 | checkpoint_folder = './checkpoints', 93 | data_kwargs: Tuple[str, ...] = ('vertices', 'faces', 'face_edges'), 94 | warmup_steps = 1000, 95 | use_wandb_tracking = False 96 | ): 97 | super().__init__() 98 | 99 | # experiment tracker 100 | 101 | self.use_wandb_tracking = use_wandb_tracking 102 | 103 | if use_wandb_tracking: 104 | accelerator_kwargs['log_with'] = 'wandb' 105 | 106 | if 'kwargs_handlers' not in accelerator_kwargs: 107 | accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS] 108 | 109 | # accelerator 110 | 111 | self.accelerator = Accelerator(**accelerator_kwargs) 112 | 113 | self.model = model 114 | 115 | if self.is_main: 116 | self.ema_model = EMA(model, **ema_kwargs) 117 | 118 | self.optimizer = OptimizerWithWarmupSchedule( 119 | accelerator = self.accelerator, 120 | optimizer = get_adam_optimizer(model.parameters(), lr = learning_rate, wd = weight_decay, **optimizer_kwargs), 121 | scheduler = scheduler, 122 | scheduler_kwargs = scheduler_kwargs, 123 | warmup_steps = warmup_steps, 124 | max_grad_norm = max_grad_norm 125 | ) 126 | 127 | self.dataloader = DataLoader( 128 | dataset, 129 | shuffle = True, 130 | batch_size = batch_size, 131 | drop_last = True, 132 | collate_fn = partial(custom_collate, pad_id = model.pad_id) 133 | ) 134 | 135 | self.should_validate = exists(val_dataset) 136 | 137 | if self.should_validate: 138 | assert len(val_dataset) > 0, 'your validation dataset is empty' 139 | 140 | self.val_every = val_every 141 | self.val_num_batches = val_num_batches 142 | 143 | self.val_dataloader = DataLoader( 144 | val_dataset, 145 | shuffle = True, 146 | batch_size = batch_size, 147 | drop_last = True, 148 | collate_fn = partial(custom_collate, pad_id = model.pad_id) 149 | ) 150 | 151 | if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs): 152 | assert beartype_isinstance(dataset.data_kwargs, List[str]) 153 | self.data_kwargs = dataset.data_kwargs 154 | else: 155 | self.data_kwargs = data_kwargs 156 | 157 | 158 | ( 159 | self.model, 160 | self.dataloader 161 | ) = self.accelerator.prepare( 162 | self.model, 163 | self.dataloader 164 | ) 165 | 166 | self.grad_accum_every = grad_accum_every 167 | self.num_train_steps = num_train_steps 168 | self.register_buffer('step', torch.tensor(0)) 169 | 170 | self.checkpoint_every_epoch = checkpoint_every_epoch 171 | self.checkpoint_every = checkpoint_every 172 | self.checkpoint_folder = Path(checkpoint_folder) 173 | self.checkpoint_folder.mkdir(exist_ok = True, parents = True) 174 | 175 | @property 176 | def ema_tokenizer(self): 177 | return self.ema_model.ema_model 178 | 179 | def tokenize(self, *args, **kwargs): 180 | return self.ema_tokenizer.tokenize(*args, **kwargs) 181 | 182 | def log(self, **data_kwargs): 183 | self.accelerator.log(data_kwargs, step = self.step.item()) 184 | 185 | @property 186 | def device(self): 187 | return self.unwrapped_model.device 188 | 189 | @property 190 | def is_main(self): 191 | return self.accelerator.is_main_process 192 | 193 | @property 194 | def unwrapped_model(self): 195 | return self.accelerator.unwrap_model(self.model) 196 | 197 | @property 198 | def is_local_main(self): 199 | return self.accelerator.is_local_main_process 200 | 201 | def wait(self): 202 | return self.accelerator.wait_for_everyone() 203 | 204 | def print(self, msg): 205 | return self.accelerator.print(msg) 206 | 207 | def save(self, path, overwrite = True): 208 | path = Path(path) 209 | assert overwrite or not path.exists() 210 | 211 | pkg = dict( 212 | model = self.unwrapped_model.state_dict(), 213 | ema_model = self.ema_model.state_dict(), 214 | optimizer = self.optimizer.state_dict(), 215 | version = __version__, 216 | step = self.step.item(), 217 | config = self.unwrapped_model._config 218 | ) 219 | 220 | torch.save(pkg, str(path)) 221 | 222 | def load(self, path): 223 | path = Path(path) 224 | assert path.exists() 225 | 226 | pkg = torch.load(str(path)) 227 | 228 | if version.parse(__version__) != version.parse(pkg['version']): 229 | self.print(f'loading saved mesh autoencoder at version {pkg["version"]}, but current package version is {__version__}') 230 | 231 | self.model.load_state_dict(pkg['model']) 232 | self.ema_model.load_state_dict(pkg['ema_model']) 233 | self.optimizer.load_state_dict(pkg['optimizer']) 234 | 235 | self.step.copy_(pkg['step']) 236 | 237 | def next_data_to_forward_kwargs(self, dl_iter) -> dict: 238 | data = next(dl_iter) 239 | 240 | if isinstance(data, tuple): 241 | forward_kwargs = dict(zip(self.data_kwargs, data)) 242 | 243 | elif isinstance(data, dict): 244 | forward_kwargs = data 245 | 246 | maybe_del(forward_kwargs, 'texts', 'text_embeds') 247 | return forward_kwargs 248 | 249 | def forward(self): 250 | step = self.step.item() 251 | dl_iter = cycle(self.dataloader) 252 | 253 | if self.is_main and self.should_validate: 254 | val_dl_iter = cycle(self.val_dataloader) 255 | 256 | while step < self.num_train_steps: 257 | 258 | for i in range(self.grad_accum_every): 259 | is_last = i == (self.grad_accum_every - 1) 260 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext 261 | 262 | forward_kwargs = self.next_data_to_forward_kwargs(dl_iter) 263 | 264 | with self.accelerator.autocast(), maybe_no_sync(): 265 | 266 | total_loss, (recon_loss, commit_loss) = self.model( 267 | **forward_kwargs, 268 | return_loss_breakdown = True 269 | ) 270 | 271 | self.accelerator.backward(total_loss / self.grad_accum_every) 272 | 273 | self.print(f'recon loss: {recon_loss.item():.3f} | commit loss: {commit_loss.sum().item():.3f}') 274 | 275 | self.log( 276 | total_loss = total_loss.item(), 277 | commit_loss = commit_loss.sum().item(), 278 | recon_loss = recon_loss.item() 279 | ) 280 | 281 | self.optimizer.step() 282 | self.optimizer.zero_grad() 283 | 284 | step += 1 285 | self.step.add_(1) 286 | 287 | self.wait() 288 | 289 | if self.is_main: 290 | self.ema_model.update() 291 | 292 | self.wait() 293 | 294 | if self.is_main and self.should_validate and divisible_by(step, self.val_every): 295 | 296 | total_val_recon_loss = 0. 297 | self.ema_model.eval() 298 | 299 | num_val_batches = self.val_num_batches * self.grad_accum_every 300 | 301 | for _ in range(num_val_batches): 302 | with self.accelerator.autocast(), torch.no_grad(): 303 | 304 | forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter) 305 | 306 | val_loss, (val_recon_loss, val_commit_loss) = self.ema_model( 307 | **forward_kwargs, 308 | return_loss_breakdown = True 309 | ) 310 | 311 | total_val_recon_loss += (val_recon_loss / num_val_batches) 312 | 313 | self.print(f'valid recon loss: {total_val_recon_loss:.3f}') 314 | 315 | self.log(val_loss = total_val_recon_loss) 316 | 317 | self.wait() 318 | 319 | if self.is_main and divisible_by(step, self.checkpoint_every): 320 | checkpoint_num = step // self.checkpoint_every 321 | self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.{checkpoint_num}.pt') 322 | 323 | self.wait() 324 | 325 | self.print('training complete') 326 | 327 | def train(self, num_epochs, stop_at_loss = None, diplay_graph = False): 328 | epoch_losses, epoch_recon_losses, epoch_commit_losses = [] , [],[] 329 | self.model.train() 330 | 331 | for epoch in range(num_epochs): 332 | total_epoch_loss, total_epoch_recon_loss, total_epoch_commit_loss = 0.0, 0.0, 0.0 333 | 334 | progress_bar = tqdm(enumerate(self.dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}', total=len(self.dataloader)) 335 | for batch_idx, batch in progress_bar: 336 | is_last = (batch_idx+1) % self.grad_accum_every == 0 337 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext 338 | 339 | if isinstance(batch, tuple): 340 | forward_kwargs = dict(zip(self.data_kwargs, batch)) 341 | elif isinstance(batch, dict): 342 | forward_kwargs = batch 343 | maybe_del(forward_kwargs, 'texts', 'text_embeds') 344 | 345 | with self.accelerator.autocast(), maybe_no_sync(): 346 | total_loss, (recon_loss, commit_loss) = self.model( 347 | **forward_kwargs, 348 | return_loss_breakdown = True 349 | ) 350 | self.accelerator.backward(total_loss / self.grad_accum_every) 351 | 352 | current_loss = total_loss.item() 353 | total_epoch_loss += current_loss 354 | total_epoch_recon_loss += recon_loss.item() 355 | total_epoch_commit_loss += commit_loss.sum().item() 356 | 357 | progress_bar.set_postfix(loss=current_loss, recon_loss = round(recon_loss.item(),3), commit_loss = round(commit_loss.sum().item(),4)) 358 | 359 | if is_last or (batch_idx + 1 == len(self.dataloader)): 360 | self.optimizer.step() 361 | self.optimizer.zero_grad() 362 | 363 | 364 | 365 | avg_recon_loss = total_epoch_recon_loss / len(self.dataloader) 366 | avg_commit_loss = total_epoch_commit_loss / len(self.dataloader) 367 | avg_epoch_loss = total_epoch_loss / len(self.dataloader) 368 | 369 | epoch_losses.append(avg_epoch_loss) 370 | epoch_recon_losses.append(avg_recon_loss) 371 | epoch_commit_losses.append(avg_commit_loss) 372 | 373 | epochOut = f'Epoch {epoch + 1} average loss: {avg_epoch_loss} recon loss: {avg_recon_loss:.4f}: commit_loss {avg_commit_loss:.4f}' 374 | 375 | if len(epoch_losses) >= 4 and avg_epoch_loss > 0: 376 | avg_loss_improvement = sum(epoch_losses[-4:-1]) / 3 - avg_epoch_loss 377 | epochOut += f' avg loss speed: {avg_loss_improvement}' 378 | if avg_loss_improvement > 0 and avg_loss_improvement < 0.2: 379 | epochs_until_0_3 = max(0, abs(avg_epoch_loss-0.3) / avg_loss_improvement) 380 | if epochs_until_0_3> 0: 381 | epochOut += f' epochs left: {epochs_until_0_3:.2f}' 382 | 383 | self.wait() 384 | self.print(epochOut) 385 | 386 | 387 | if self.is_main and self.checkpoint_every_epoch is not None and (self.checkpoint_every_epoch == 1 or (epoch != 0 and epoch % self.checkpoint_every_epoch == 0)): 388 | self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.5f}_recon_{avg_recon_loss:.4f}_commit_{avg_commit_loss:.4f}.pt') 389 | 390 | if stop_at_loss is not None and avg_epoch_loss < stop_at_loss: 391 | self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}') 392 | if self.is_main and self.checkpoint_every_epoch is not None: 393 | self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt') 394 | break 395 | 396 | self.print('Training complete') 397 | if diplay_graph: 398 | plt.figure(figsize=(10, 5)) 399 | plt.plot(range(1, len(epoch_losses)+1), epoch_losses, marker='o', label='Total Loss') 400 | plt.plot(range(1, len(epoch_losses)+1), epoch_recon_losses, marker='o', label='Recon Loss') 401 | plt.plot(range(1, len(epoch_losses)+1), epoch_commit_losses, marker='o', label='Commit Loss') 402 | plt.title('Training Loss Over Epochs') 403 | plt.xlabel('Epoch') 404 | plt.ylabel('Average Loss') 405 | plt.grid(True) 406 | plt.show() 407 | return epoch_losses[-1] 408 | # mesh transformer trainer 409 | 410 | class MeshTransformerTrainer(Module): 411 | @typecheck 412 | def __init__( 413 | self, 414 | model: MeshTransformer, 415 | dataset: Dataset, 416 | num_train_steps: int, 417 | batch_size: int, 418 | grad_accum_every: int, 419 | learning_rate: float = 2e-4, 420 | weight_decay: float = 0., 421 | max_grad_norm: float | None = 0.5, 422 | val_dataset: Dataset | None = None, 423 | val_every = 1, 424 | val_num_batches = 5, 425 | scheduler: Type[_LRScheduler] | None = None, 426 | scheduler_kwargs: dict = dict(), 427 | ema_kwargs: dict = dict(), 428 | accelerator_kwargs: dict = dict(), 429 | optimizer_kwargs: dict = dict(), 430 | 431 | checkpoint_every = 1000, 432 | checkpoint_every_epoch: Type[int] | None = None, 433 | checkpoint_folder = './checkpoints', 434 | data_kwargs: Tuple[str, ...] = ('vertices', 'faces', 'face_edges', 'text'), 435 | warmup_steps = 1000, 436 | use_wandb_tracking = False 437 | ): 438 | super().__init__() 439 | 440 | # experiment tracker 441 | 442 | self.use_wandb_tracking = use_wandb_tracking 443 | 444 | if use_wandb_tracking: 445 | accelerator_kwargs['log_with'] = 'wandb' 446 | 447 | if 'kwargs_handlers' not in accelerator_kwargs: 448 | accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS] 449 | 450 | self.accelerator = Accelerator(**accelerator_kwargs) 451 | 452 | self.model = model 453 | 454 | optimizer = get_adam_optimizer( 455 | model.parameters(), 456 | lr = learning_rate, 457 | wd = weight_decay, 458 | filter_by_requires_grad = True, 459 | **optimizer_kwargs 460 | ) 461 | 462 | self.optimizer = OptimizerWithWarmupSchedule( 463 | accelerator = self.accelerator, 464 | optimizer = optimizer, 465 | scheduler = scheduler, 466 | scheduler_kwargs = scheduler_kwargs, 467 | warmup_steps = warmup_steps, 468 | max_grad_norm = max_grad_norm 469 | ) 470 | 471 | self.dataloader = DataLoader( 472 | dataset, 473 | shuffle = True, 474 | batch_size = batch_size, 475 | drop_last = True, 476 | collate_fn = partial(custom_collate, pad_id = model.pad_id) 477 | ) 478 | 479 | self.should_validate = exists(val_dataset) 480 | 481 | if self.should_validate: 482 | assert len(val_dataset) > 0, 'your validation dataset is empty' 483 | 484 | self.val_every = val_every 485 | self.val_num_batches = val_num_batches 486 | 487 | self.val_dataloader = DataLoader( 488 | val_dataset, 489 | shuffle = True, 490 | batch_size = batch_size, 491 | drop_last = True, 492 | collate_fn = partial(custom_collate, pad_id = model.pad_id) 493 | ) 494 | 495 | if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs): 496 | assert beartype_isinstance(dataset.data_kwargs, List[str]) 497 | self.data_kwargs = dataset.data_kwargs 498 | else: 499 | self.data_kwargs = data_kwargs 500 | 501 | ( 502 | self.model, 503 | self.dataloader 504 | ) = self.accelerator.prepare( 505 | self.model, 506 | self.dataloader 507 | ) 508 | 509 | self.grad_accum_every = grad_accum_every 510 | self.num_train_steps = num_train_steps 511 | self.register_buffer('step', torch.tensor(0)) 512 | 513 | self.checkpoint_every_epoch = checkpoint_every_epoch 514 | self.checkpoint_every = checkpoint_every 515 | self.checkpoint_folder = Path(checkpoint_folder) 516 | self.checkpoint_folder.mkdir(exist_ok = True, parents = True) 517 | 518 | def log(self, **data_kwargs): 519 | self.accelerator.log(data_kwargs, step = self.step.item()) 520 | 521 | @property 522 | def device(self): 523 | return self.unwrapped_model.device 524 | 525 | @property 526 | def is_main(self): 527 | return self.accelerator.is_main_process 528 | 529 | @property 530 | def unwrapped_model(self): 531 | return self.accelerator.unwrap_model(self.model) 532 | 533 | @property 534 | def is_local_main(self): 535 | return self.accelerator.is_local_main_process 536 | 537 | def wait(self): 538 | return self.accelerator.wait_for_everyone() 539 | 540 | def print(self, msg): 541 | return self.accelerator.print(msg) 542 | 543 | def next_data_to_forward_kwargs(self, dl_iter) -> dict: 544 | data = next(dl_iter) 545 | 546 | if isinstance(data, tuple): 547 | forward_kwargs = dict(zip(self.data_kwargs, data)) 548 | 549 | elif isinstance(data, dict): 550 | forward_kwargs = data 551 | 552 | return forward_kwargs 553 | 554 | def save(self, path, overwrite = True): 555 | path = Path(path) 556 | assert overwrite or not path.exists() 557 | 558 | pkg = dict( 559 | model = self.unwrapped_model.state_dict(), 560 | optimizer = self.optimizer.state_dict(), 561 | step = self.step.item(), 562 | version = __version__ 563 | ) 564 | 565 | torch.save(pkg, str(path)) 566 | 567 | def load(self, path): 568 | path = Path(path) 569 | assert path.exists() 570 | 571 | pkg = torch.load(str(path)) 572 | 573 | if version.parse(__version__) != version.parse(pkg['version']): 574 | self.print(f'loading saved mesh transformer at version {pkg["version"]}, but current package version is {__version__}') 575 | 576 | self.model.load_state_dict(pkg['model']) 577 | self.optimizer.load_state_dict(pkg['optimizer']) 578 | self.step.copy_(pkg['step']) 579 | 580 | def forward(self): 581 | step = self.step.item() 582 | dl_iter = cycle(self.dataloader) 583 | 584 | if self.should_validate: 585 | val_dl_iter = cycle(self.val_dataloader) 586 | 587 | while step < self.num_train_steps: 588 | 589 | for i in range(self.grad_accum_every): 590 | is_last = i == (self.grad_accum_every - 1) 591 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext 592 | 593 | forward_kwargs = self.next_data_to_forward_kwargs(dl_iter) 594 | 595 | with self.accelerator.autocast(), maybe_no_sync(): 596 | loss = self.model(**forward_kwargs) 597 | 598 | self.accelerator.backward(loss / self.grad_accum_every) 599 | 600 | self.print(f'loss: {loss.item():.3f}') 601 | 602 | self.log(loss = loss.item()) 603 | 604 | self.optimizer.step() 605 | self.optimizer.zero_grad() 606 | 607 | step += 1 608 | self.step.add_(1) 609 | 610 | self.wait() 611 | 612 | if self.is_main and self.should_validate and divisible_by(step, self.val_every): 613 | 614 | total_val_loss = 0. 615 | self.unwrapped_model.eval() 616 | 617 | num_val_batches = self.val_num_batches * self.grad_accum_every 618 | 619 | for _ in range(num_val_batches): 620 | with self.accelerator.autocast(), torch.no_grad(): 621 | 622 | forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter) 623 | 624 | val_loss = self.unwrapped_model(**forward_kwargs) 625 | 626 | total_val_loss += (val_loss / num_val_batches) 627 | 628 | self.print(f'valid recon loss: {total_val_loss:.3f}') 629 | 630 | self.log(val_loss = total_val_loss) 631 | 632 | self.wait() 633 | 634 | if self.is_main and divisible_by(step, self.checkpoint_every): 635 | checkpoint_num = step // self.checkpoint_every 636 | self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.{checkpoint_num}.pt') 637 | 638 | self.wait() 639 | 640 | self.print('training complete') 641 | 642 | 643 | def train(self, num_epochs, stop_at_loss = None, diplay_graph = False): 644 | epoch_losses = [] 645 | epoch_size = len(self.dataloader) 646 | self.model.train() 647 | 648 | for epoch in range(num_epochs): 649 | total_epoch_loss = 0.0 650 | 651 | progress_bar = tqdm(enumerate(self.dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}', total=len(self.dataloader)) 652 | for batch_idx, batch in progress_bar: 653 | 654 | is_last = (batch_idx+1) % self.grad_accum_every == 0 655 | maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext 656 | 657 | with self.accelerator.autocast(), maybe_no_sync(): 658 | total_loss = self.model(**batch) 659 | self.accelerator.backward(total_loss / self.grad_accum_every) 660 | 661 | current_loss = total_loss.item() 662 | total_epoch_loss += current_loss 663 | 664 | progress_bar.set_postfix(loss=current_loss) 665 | 666 | if is_last or (batch_idx + 1 == len(self.dataloader)): 667 | self.optimizer.step() 668 | self.optimizer.zero_grad() 669 | 670 | avg_epoch_loss = total_epoch_loss / epoch_size 671 | epochOut = f'Epoch {epoch + 1} average loss: {avg_epoch_loss}' 672 | 673 | 674 | epoch_losses.append(avg_epoch_loss) 675 | 676 | if len(epoch_losses) >= 4 and avg_epoch_loss > 0: 677 | avg_loss_improvement = sum(epoch_losses[-4:-1]) / 3 - avg_epoch_loss 678 | epochOut += f' avg loss speed: {avg_loss_improvement}' 679 | if avg_loss_improvement > 0 and avg_loss_improvement < 0.2: 680 | epochs_until_0_3 = max(0, abs(avg_epoch_loss-0.3) / avg_loss_improvement) 681 | if epochs_until_0_3> 0: 682 | epochOut += f' epochs left: {epochs_until_0_3:.2f}' 683 | 684 | self.wait() 685 | self.print(epochOut) 686 | 687 | if self.is_main and self.checkpoint_every_epoch is not None and (self.checkpoint_every_epoch == 1 or (epoch != 0 and epoch % self.checkpoint_every_epoch == 0)): 688 | self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.3f}.pt') 689 | 690 | if stop_at_loss is not None and avg_epoch_loss < stop_at_loss: 691 | self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}') 692 | if self.is_main and self.checkpoint_every_epoch is not None: 693 | self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt') 694 | break 695 | 696 | 697 | self.print('Training complete') 698 | if diplay_graph: 699 | plt.figure(figsize=(10, 5)) 700 | plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, marker='o', label='Total Loss') 701 | plt.title('Training Loss Over Epochs') 702 | plt.xlabel('Epoch') 703 | plt.ylabel('Average Loss') 704 | plt.grid(True) 705 | plt.show() 706 | return epoch_losses[-1] 707 | -------------------------------------------------------------------------------- /meshgpt_pytorch/typing.py: -------------------------------------------------------------------------------- 1 | from environs import Env 2 | 3 | from torch import Tensor 4 | 5 | from beartype import beartype 6 | from beartype.door import is_bearable 7 | 8 | from jaxtyping import ( 9 | Float, 10 | Int, 11 | Bool, 12 | jaxtyped 13 | ) 14 | 15 | # environment 16 | 17 | env = Env() 18 | env.read_env() 19 | 20 | # function 21 | 22 | def always(value): 23 | def inner(*args, **kwargs): 24 | return value 25 | return inner 26 | 27 | def identity(t): 28 | return t 29 | 30 | # jaxtyping is a misnomer, works for pytorch 31 | 32 | class TorchTyping: 33 | def __init__(self, abstract_dtype): 34 | self.abstract_dtype = abstract_dtype 35 | 36 | def __getitem__(self, shapes: str): 37 | return self.abstract_dtype[Tensor, shapes] 38 | 39 | Float = TorchTyping(Float) 40 | Int = TorchTyping(Int) 41 | Bool = TorchTyping(Bool) 42 | 43 | # use env variable TYPECHECK to control whether to use beartype + jaxtyping 44 | 45 | should_typecheck = env.bool('TYPECHECK', False) 46 | 47 | typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity 48 | 49 | beartype_isinstance = is_bearable if should_typecheck else always(True) 50 | 51 | __all__ = [ 52 | Float, 53 | Int, 54 | Bool, 55 | typecheck, 56 | beartype_isinstance 57 | ] 58 | -------------------------------------------------------------------------------- /meshgpt_pytorch/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.5.12' 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --verbose -s 6 | python_files = tests/*.py 7 | python_paths = "." 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | exec(open('meshgpt_pytorch/version.py').read()) 4 | 5 | setup( 6 | name = 'meshgpt-pytorch', 7 | packages = find_packages(exclude=[]), 8 | version = __version__, 9 | license='MIT', 10 | description = 'MeshGPT Pytorch', 11 | author = 'Phil Wang', 12 | author_email = 'lucidrains@gmail.com', 13 | long_description_content_type = 'text/markdown', 14 | url = 'https://github.com/lucidrains/meshgpt-pytorch', 15 | keywords = [ 16 | 'artificial intelligence', 17 | 'deep learning', 18 | 'attention mechanisms', 19 | 'transformers', 20 | 'mesh generation' 21 | ], 22 | install_requires=[ 23 | 'matplotlib', 24 | 'accelerate>=0.25.0', 25 | 'beartype', 26 | "huggingface_hub>=0.21.4", 27 | 'classifier-free-guidance-pytorch>=0.6.10', 28 | 'einops>=0.8.0', 29 | 'einx[torch]>=0.3.0', 30 | 'ema-pytorch>=0.5.1', 31 | 'environs', 32 | 'gateloop-transformer>=0.2.2', 33 | 'jaxtyping', 34 | 'local-attention>=1.9.11', 35 | 'numpy', 36 | 'pytorch-custom-utils>=0.0.9', 37 | 'rotary-embedding-torch>=0.6.4', 38 | 'sentencepiece', 39 | 'taylor-series-linear-attention>=0.1.6', 40 | 'torch>=2.1', 41 | 'torch_geometric', 42 | 'tqdm', 43 | 'vector-quantize-pytorch>=1.14.22', 44 | 'x-transformers>=1.30.19,<1.31', 45 | ], 46 | setup_requires=[ 47 | 'pytest-runner', 48 | ], 49 | tests_require=[ 50 | 'pytest' 51 | ], 52 | classifiers=[ 53 | 'Development Status :: 4 - Beta', 54 | 'Intended Audience :: Developers', 55 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 56 | 'License :: OSI Approved :: MIT License', 57 | 'Programming Language :: Python :: 3.6', 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /tests/test_meshgpt.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from meshgpt_pytorch import ( 5 | MeshAutoencoder, 6 | MeshTransformer 7 | ) 8 | 9 | @pytest.mark.parametrize('adaptive_rmsnorm', (True, False)) 10 | def test_readme(adaptive_rmsnorm): 11 | 12 | autoencoder = MeshAutoencoder( 13 | num_discrete_coors = 128 14 | ) 15 | 16 | # mock inputs 17 | 18 | vertices = torch.randn((2, 121, 3)) # (batch, num vertices, coor (3)) 19 | faces = torch.randint(0, 121, (2, 2, 3)) # (batch, num faces, vertices (3)) 20 | 21 | # forward in the faces 22 | 23 | loss = autoencoder( 24 | vertices = vertices, 25 | faces = faces 26 | ) 27 | 28 | loss.backward() 29 | 30 | # after much training... 31 | # you can pass in the raw face data above to train a transformer to model this sequence of face vertices 32 | 33 | transformer = MeshTransformer( 34 | autoencoder, 35 | dim = 512, 36 | max_seq_len = 60, 37 | num_sos_tokens = 1, 38 | fine_cross_attend_text = True, 39 | text_cond_with_film = False, 40 | condition_on_text = True, 41 | coarse_post_gateloop_depth = 1, 42 | coarse_adaptive_rmsnorm = adaptive_rmsnorm 43 | ) 44 | 45 | loss = transformer( 46 | vertices = vertices, 47 | faces = faces, 48 | texts = ['a high chair', 'a small teapot'] 49 | ) 50 | 51 | loss.backward() 52 | 53 | faces_coordinates, face_mask = transformer.generate(texts = ['a small chair'], cond_scale = 3.) 54 | 55 | def test_cache(): 56 | # test that the output for generation with and without kv (and optional gateloop) cache is equivalent 57 | 58 | autoencoder = MeshAutoencoder( 59 | num_discrete_coors = 128 60 | ) 61 | 62 | transformer = MeshTransformer( 63 | autoencoder, 64 | dim = 512, 65 | max_seq_len = 12 66 | ) 67 | 68 | uncached_faces_coors, _ = transformer.generate(cache_kv = False, temperature = 0) 69 | cached_faces_coors, _ = transformer.generate(cache_kv = True, temperature = 0) 70 | 71 | assert torch.allclose(uncached_faces_coors, cached_faces_coors) 72 | --------------------------------------------------------------------------------