├── LICENSE ├── README.md ├── config └── meshgpt.yaml ├── dataset ├── __init__.py ├── quantized_soup.py └── triangles.py ├── inference └── infer_meshgpt.py ├── model ├── LICENSE ├── README.md ├── decoder.py ├── encoder.py ├── nanogpt.py ├── pointnet.py ├── softargmax.py ├── transformer.py └── transformer_base.py ├── requirements.txt ├── trainer ├── __init__.py ├── train_transformer.py └── train_vocabulary.py └── util ├── __init__.py ├── filesystem_logger.py ├── meshlab.py ├── misc.py ├── positional_encoding.py └── visualization.py /LICENSE: -------------------------------------------------------------------------------- 1 | Automotive Development Public Non-Commercial License Version 1.0 2 | (Based on Mozilla Public License Version 2.0 with additions regarding modifications) 3 | 4 | 1. Definitions 5 | 6 | 1.1 "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 7 | 8 | 1.2 "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution. 9 | 10 | 1.3 "Contribution" means Covered Software of a particular Contributor. 11 | 12 | 1.4 "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 13 | 14 | 1.5 "Executable Form" means any form of the work other than Source Code Form. 15 | 16 | 1.6 "Larger Work" means a work that combines Covered Software statically or dynamically with code not governed by the terms of this license. 17 | 18 | 1.7 "License" means this document. 19 | 20 | 1.8 "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 21 | 22 | 1.9 "Modifications" means any of the following: 23 | 24 | (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or 25 | 26 | (b) any new file in Source Code Form that contains any Covered Software. 27 | 1.10 "Non-Commercial" means not primarily intended for or directed towards commercial advantage or monetary compensation. 28 | 29 | 1.11 "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 30 | 31 | 1.12 "Source Code Form" means the form of the work preferred for making modifications. 32 | 33 | 1.13 "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 34 | 35 | 2. 
License Grants and Conditions 36 | 37 | 2.1 Grants 38 | Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: 39 | 40 | (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions for Non-Commercial purposes only, either on an unmodified basis, with Modifications, or as part of a Larger Work; and 41 | 42 | (b) under Patent Claims of such Contributor to make, use, have made, import, and otherwise transfer either its Contributions or its Contributor Version for Non-Commercial purposes only. 43 | 44 | 2.2 Effective Date 45 | The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 46 | 47 | 2.3 Limitations on Grant Scope 48 | The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: 49 | 50 | (a) for any code that a Contributor has removed from Covered Software; or 51 | 52 | (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or 53 | 54 | (c) under Patent Claims infringed by Covered Software in the absence of its Contributions. 55 | 56 | This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.5). 57 | 58 | 2.4 Subsequent Licenses 59 | No Contributor makes additional grants as result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of any other license of Your choice (if permitted under the terms of Section 3.3). 60 | 61 | 2.5 Representation 62 | Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 63 | 64 | 2.6 Fair Use 65 | This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 66 | 67 | 2.7 Conditions 68 | Sections 3.1, 3.2, 3.3, 3.4 and 3.5 are conditions of the licenses granted in Section 2.1. 69 | 70 | 3. Responsibilities 71 | 72 | 3.1 Distribution of Source Form 73 | All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License and include a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form. 
74 | 75 | 3.2 Distribution of Executable Form 76 | If You distribute Covered Software in Executable Form then: 77 | 78 | (a) such Covered Software must also be made available as described in Section 3.1 in Source Code Form including any Modifications that You create or to which You contribute and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and 79 | 80 | (b) You may distribute such Executable Form under the terms of this License and include this license, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. In any case, you must inform the recipients about the Initial Contributor providing the Covered Software under the terms of this License (we recommend the use of a notice like that in Exhibit A). 81 | 82 | 3.3 Distribution of a Larger Work 83 | You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. 84 | 3.4 Modifications 85 | (a) You may modify Covered Software, if you cause all Modifications to be supplemented by a file documenting the differences between the Covered Software and the Modifications and the date of the creation of any Modification. You must include a prominent statement that any Modification is derived, directly or indirectly, from the Covered Software provided by the Contributor and including the name of the Contributor in (i) the Source Code Form, and (ii) in a notice according to Exhibit A in an Executable Form or related documentation in which You describe the origin or ownership of the Covered Software. 86 | 87 | (b) You are obliged to provide this documentation supplementing your Modifications and the Covered Software including your Modifications to the Contributor of the Covered Software as specified in the notices according to Exhibit A irrespective of any distribution of the Covered Software. You may provide this information and code directly to the Contributor via email or another accepted electronic distribution mechanism. 88 | 89 | 3.5 Notices 90 | You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 91 | 92 | 4. Inability to Comply Due to Statute or Regulation 93 | 94 | If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 95 | 96 | 5. Termination 97 | 98 | 5.1 The rights granted under this License will terminate automatically if You fail to comply with any of its terms. 
However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated, unless and until such Contributor explicitly and finally terminates Your grants. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice and inform the Contributor of Your compliance. 99 | 100 | 5.2 If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 101 | 102 | 6. Disclaimer of Warranty 103 | 104 | Covered Software is provided under this License basically free of charge in a development status and on an "as is" basis, without warranty of any kind, either expressed, implied, or statutory, including, without limitation, warranties that the Covered Software is free of defects, merchantable, fit for a particular purpose or non-infringing. The entire risk as to the interaction and performance of the Covered Software is with You. This disclaimer of warranty constitutes an essential part of this License. No use of any Covered Software is authorized under this License except under this disclaimer. 105 | 106 | 7. Limitation of Liability 107 | 108 | Any Contributor shall only be liable for damages other than those resulting from the detriment to life, body and health to the extent such damages arise from willful misconduct, gross negligence or the culpable violation of a fundamental contractual obligation on the part of the Contributor or its vicarious agents. Any further liability for damages shall be excluded, especially liability for the loss of data and the recovery of this data if this loss could have been avoided by You through appropriate precautionary measures, in particular by creating daily backups of all data. The provisions of the German Product Liability Act and other mandatory legal statutes shall remain unaffected. 109 | 110 | 8. Litigation 111 | 112 | Any litigation relating to this License may be brought only in courts of the German jurisdiction and such litigation shall be governed by German law. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. 113 | 114 | 9. Miscellaneous 115 | 116 | This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 117 | 118 | 10. Versions of the License 119 | 120 | 10.1 New Versions 121 | Audi AG is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 
122 | 123 | 10.2 Effect of New Versions 124 | You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 125 | 126 | 10.3 Modified Versions 127 | If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 128 | 129 | 130 | Exhibit A - Source Code Form License Notice 131 | 132 | Copyright (c) 2024, Audi AG, ai-licenses@audi.de. All rights reserved. 133 | 134 | This Source Code Form is subject to the terms of the Automotive Development Public Non-Commercial License, v. 1.0. If a copy of the ADPNCL was not distributed with this file, You can obtain one from the distributor. 135 | 136 | If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers 2 | 3 |
4 | 5 | [**arXiv**](http://arxiv.org/abs/2311.15475) | [**Video**](https://www.youtube.com/watch?v=UV90O1_69_o) | [**Project Page**](https://nihalsid.github.io/mesh-gpt/)
6 | 7 | 8 | This repository contains the implementation for the paper: 9 | 10 | [**MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers**](http://arxiv.org/abs/2311.15475) by Yawar Siddiqui, Antonio Alliegro, Alexey Artemov, Tatiana Tommasi, Daniele Sirigatti, Vladislav Rosov, Angela Dai, Matthias Nießner. 11 | 12 |
13 |
14 | [teaser animation]
16 |
17 | MeshGPT creates triangle meshes by autoregressively sampling from a transformer model that has been trained to produce tokens from a learned geometric vocabulary. These tokens can then be decoded into the faces of a triangle mesh. Our method generates clean, coherent, and compact meshes, characterized by sharp edges and high fidelity. 18 |
19 |
20 | 21 | ## Dependencies 22 | 23 | Install requirements from the project root directory: 24 | 25 | ```bash 26 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu118.html 27 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118 28 | pip install packaging 29 | pip install -r requirements.txt 30 | ``` 31 | In case errors show up for missing packages, install them manually. 32 | 33 | ## Structure 34 | 35 | Overall code structure is as follows: 36 | 37 | | Folder | Description | 38 | |------------------------|----------------------------------------------------------------------------------------------| 39 | | `config/` | hydra configs | 40 | | `data/` | processed dataset 41 | | `dataset/` | pytorch datasets and dataloaders | 42 | | `docs/` | project webpage files | 43 | | `inference/` | scripts for inferencing trained model | 44 | | `model/` | pytorch modules for encoder, decoder and the transformer | 45 | | `pretrained/` | pretrained models on shapenet chairs and tables | 46 | | `runs/` | model training logs and checkpoints go here in addition to wandb | 47 | | `trainer/` | pytorch-lightning module for training | 48 | | `util/` | misc utilities for positional encoding, visualization, logging etc. | 49 | 50 | ## Pre-trained Models and Data 51 | 52 | Download the pretrained models and the data from [here](https://drive.google.com/drive/folders/1Gzuxn6c1pguvRWrsedmCa9xKtest8aC2?usp=drive_link). Place them in the project, such that trained models are in `pretrained/` directory and data is in `data/shapenet` directory. 53 | 54 | ### Running inference 55 | 56 | To run inference use the following command 57 | 58 | ```bash 59 | python inference/infer_meshgpt.py 60 | ``` 61 | 62 | Examples: 63 | 64 | ```bash 65 | # for chairs 66 | python inference/infer_meshgpt.py pretrained/transformer_ft_03001627/checkpoints/2287-0.ckpt beam 25 67 | 68 | # for tables 69 | python inference/infer_meshgpt.py pretrained/transformer_ft_04379243/checkpoints/1607-0.ckpt beam 25 70 | ``` 71 | 72 | ## Training 73 | 74 | For launching training, use the following command from project root 75 | 76 | ``` 77 | # vocabulary 78 | python trainer/train_vocabulary.py vq_resume= 79 | 80 | # transformer 81 | python trainer/train_transformer.py vq_resume= ft_category= ft_resume= 82 | ``` 83 | 84 | Some example trainings: 85 | 86 | #### Vocabulary training 87 | ```bash 88 | python trainer/train_vocabulary.py batch_size=32 shift_augment=True scale_augment=True wandb_main=True experiment=vq128 val_check_percent=1.0 val_check_interval=5 overfit=False max_epoch=2000 only_chairs=False use_smoothed_loss=True graph_conv=sage use_point_feats=False num_workers=24 n_embed=16384 num_tokens=131 embed_levels=2 num_val_samples=16 use_multimodal_loss=True weight_decay=0.1 embed_dim=192 code_decay=0.99 embed_share=True distribute_features=True 89 | ``` 90 | #### Base transformer training 91 | ```bash 92 | 93 | # run over multiple GPUs (recommended GPUs >= 8), if you have a good budget, can use higher gradient_accumulation_steps 94 | 95 | python trainer/train_transformer.py wandb_main=True batch_size=8 gradient_accumulation_steps=8 max_val_tokens=5000 max_epoch=2000 sanity_steps=0 val_check_interval=1 val_check_percent=1 block_size=4608 model.n_layer=24 model.n_head=16 model.n_embd=768 model.dropout=0 scale_augment=True shift_augment=True num_workers=24 experiment=bl4608-GPT2_m24-16-768-0_b8x8x8_lr1e-4 use_smoothed_loss=True num_tokens=131 vq_resume= padding=0 96 | ``` 97 | 
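A note on the `block_size=4608` used in the commands above: with the triplet tokenization, each triangle contributes 3 vertices × `embed_levels` residual-VQ codes (6 tokens per face for `embed_levels=2`), plus one start and one stop token per mesh; longer meshes are handled via windowing with `sequence_stride`. The snippet below is a rough sketch (not part of the repository) for checking whether a mesh fits into the transformer context:

```python
# Back-of-the-envelope token budget for the quantized triangle-soup sequences.
# Assumes the triplet tokenizer: 3 vertices per face, embed_levels codes per vertex,
# plus start and stop tokens (mirrors sequence_len = num_faces * 3 * embed_levels + 2
# as computed in dataset/triangles.py and dataset/quantized_soup.py).

def sequence_length(num_faces: int, embed_levels: int = 2) -> int:
    return num_faces * 3 * embed_levels + 2  # + start/stop tokens


def max_faces(block_size: int = 4608, embed_levels: int = 2) -> int:
    return (block_size - 2) // (3 * embed_levels)


if __name__ == "__main__":
    print(sequence_length(500))  # 3002 tokens -> fits into block_size=4608
    print(max_faces())           # 767 faces fit without relying on windowing
```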
#### Transformer finetuning 98 | ```bash 99 | 100 | # run over multiple GPUs (recommended GPUs >= 8), if you have a good budget, can use higher gradient_accumulation_steps 101 | 102 | python trainer/train_transformer.py wandb_main=True batch_size=8 gradient_accumulation_steps=8 max_val_tokens=5000 max_epoch=2400 sanity_steps=0 val_check_interval=8 val_check_percent=1 block_size=4608 model.n_layer=24 model.n_head=16 model.n_embd=768 model.dropout=0 scale_augment=True shift_augment=True num_workers=24 experiment=bl4608-GPT2_m24-16-768-0_b8x8x8_FT04379243 use_smoothed_loss=True num_tokens=131 vq_resume= padding=0 num_val_samples=4 ft_category=04379243 ft_resume= warmup_steps=100 103 | 104 | ``` 105 | 106 | ## License 107 | MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers by Mohd Yawar Nihal Siddiqui is licensed under [Automotive Development Public Non-Commercial License Version 1.0](LICENSE), however portions of the project are available under separate license terms: e.g. NanoGPT code is under MIT license. 108 | 109 | ## Citation 110 | 111 | If you wish to cite us, please use the following BibTeX entry: 112 | 113 | ```BibTeX 114 | @InProceedings{siddiqui_meshgpt_2024, 115 | title={MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers}, 116 | author={Siddiqui, Yawar and Alliegro, Antonio and Artemov, Alexey and Tommasi, Tatiana and Sirigatti, Daniele and Rosov, Vladislav and Dai, Angela and Nie{\ss}ner, Matthias}, 117 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 118 | year = {2024}, 119 | } 120 | 121 | ``` -------------------------------------------------------------------------------- /config/meshgpt.yaml: -------------------------------------------------------------------------------- 1 | # data 2 | dataset: 'shapenet' 3 | dataset_root: 'data/shapenet/processed_data.pkl' 4 | # not actually number of tokens, but the quantization resolution + 3 5 | num_tokens: 131 6 | 7 | gradient_accumulation_steps: 2 # used to simulate larger batch sizes 8 | batch_size: 32 # if gradient_accumulation_steps > 1, this is the micro-batch size 9 | block_size: 4608 # context_size of the transformer, for good performance block_size > max sequence length 10 | padding: 0.0 # fraction of padding allowed, when sequences are beyond the transformer context size 11 | 12 | scale_augment: True # augment shapes by random scaling 13 | scale_augment_val: False # augment val shapes by random scaling 14 | 15 | shift_augment: True # augment shapes by shifting them in space 16 | shift_augment_val: False 17 | 18 | wandb_main: False # set to true to log to main board rather than debug board 19 | suffix: '' # suffix for project name in wandb, if wandb_main is false, auto populated to dump debug experiments to a different project 20 | experiment: snet # experiment name 21 | seed: null 22 | save_epoch: 1 # save every n epoch 23 | sanity_steps: 1 # sanity steps before the run 24 | val_check_percent: 1.0 # check this proportion of val set when evaluation runs 25 | val_check_interval: 1 # run evaluation every x% of the train set steps 26 | resume: null # resume from a checkpoint 27 | num_workers: 24 28 | logger: wandb 29 | overfit: False # overfitting dataloaders 30 | 31 | num_val_samples: 16 # number of meshes to visualize in evaluation 32 | max_val_tokens: 5000 33 | top_k_tokens: 200 # sampling top-k tokens 34 | top_p: 0.9 # p val for nucleus sampling 35 | temperature: 0.8 # temprature for sampling 36 | sequence_stride: 32 # use when sequences 
are larger than context length 37 | use_smoothed_loss: True # smoothing over neighboring tokens in the quantized space 38 | 39 | use_point_feats: False # point net like point features in graph network 40 | graph_conv: sage # flavor of graph convs 41 | g_no_max_pool: True # no max pooling in graph conv 42 | g_aggr: mean # aggregation op in graph conv 43 | ce_output: True # use quantized predictions 44 | embed_dim: 192 # vq embed dim 45 | n_embed: 16384 # vq num embeddings 46 | embed_loss_weight: 1.0 47 | embed_levels: 2 # rvq levels 48 | tri_weight: 0.00 # weight on geometric predictions 49 | norm_weight: 0.00 50 | area_weight: 0.00 51 | angle_weight: 0.00 52 | code_decay: 0.99 # code decay for vq 53 | embed_share: True # share embeddings across rvq levels 54 | use_multimodal_loss: True # multiple modes in ce loss when training vocabulary 55 | 56 | vq_resume: null # path to trained vocab when training the transformer 57 | ft_resume: null # path to transformer trained on all categories when finetuning 58 | ft_category: null # shapenet category to finetune 59 | distribute_features: True # distribute face features across vertices 60 | low_augment: False # lower the scale of augmentation 61 | 62 | # model 63 | model: 64 | in_emb: 3 65 | n_layer: 24 66 | n_head: 16 67 | n_embd: 768 68 | dropout: 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 69 | bias: False # do we use bias inside LayerNorm and Linear layers? 70 | 71 | # adamw optimizer 72 | lr: 1e-4 # max learning rate 73 | force_lr: null 74 | max_epoch: 2000 # total number of training iterations 75 | weight_decay: 1e-1 76 | beta1: 0.9 77 | beta2: 0.95 78 | grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 79 | 80 | # learning rate decay settings 81 | warmup_steps: 2000 # how many steps to warm up for 82 | min_lr: 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 83 | 84 | only_chairs: False 85 | stochasticity: 0.1 86 | 87 | hydra: 88 | output_subdir: null # Disable saving of config files. We'll do that ourselves. 89 | run: 90 | dir: . 
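The training and inference scripts consume this file through Hydra, and parts of the code re-load it directly with OmegaConf (e.g. `omegaconf.OmegaConf.load(Path(config.vq_resume).parents[1] / "config.yaml")` in `dataset/triangles.py`). Below is a minimal sketch of loading and overriding these values outside the training scripts, assuming `omegaconf` is installed and the file sits at `config/meshgpt.yaml`; the override values are only illustrative:

```python
# Sketch: load the Hydra/OmegaConf config and apply CLI-style dotlist overrides.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/meshgpt.yaml")
overrides = OmegaConf.from_dotlist([
    "batch_size=8",
    "block_size=4608",
    "model.n_layer=24",                            # nested keys use dot notation
    "vq_resume=runs/vq128/checkpoints/last.ckpt",  # hypothetical checkpoint path
])
cfg = OmegaConf.merge(cfg, overrides)
print(cfg.num_tokens, cfg.model.n_embd, cfg.vq_resume)
```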
91 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import networkx as nx 4 | 5 | 6 | newface_token = 0 7 | stopface_token = 1 8 | padface_token = 2 9 | 10 | 11 | def get_shifted_sequence(sequence): 12 | non_special = np.flatnonzero(np.isin(sequence, [0, 1, 2], invert=True)) 13 | if non_special.shape[0] > 0: 14 | idx = non_special[0] 15 | val = sequence[idx] 16 | sequence[non_special] -= (val - 3) 17 | return sequence 18 | 19 | 20 | def read_faces(text): 21 | all_lines = text.splitlines() 22 | all_face_lines = [x for x in all_lines if x.startswith('f ')] 23 | all_faces = [[int(y.split('/')[0]) - 1 for y in x.strip().split(' ')[1:]] for x in all_face_lines] 24 | return all_faces 25 | 26 | 27 | def read_vertices(text): 28 | all_lines = text.splitlines() 29 | all_vertex_lines = [x for x in all_lines if x.startswith('v ')] 30 | all_vertices = np.array([[float(y) for y in x.strip().split(' ')[1:]] for x in all_vertex_lines]) 31 | assert all_vertices.shape[1] == 3, 'vertices should have 3 coordinates' 32 | return all_vertices 33 | 34 | 35 | def quantize_coordinates(coords, num_tokens=256): 36 | if torch.is_tensor(coords): 37 | coords = torch.clip((coords + 0.5), 0, 1) * num_tokens # type: ignore 38 | coords_quantized = coords.round().long() 39 | else: 40 | coords = np.clip((coords + 0.5), 0, 1) * num_tokens # type: ignore 41 | coords_quantized = coords.round().astype(int) 42 | return coords_quantized 43 | 44 | 45 | def face_to_cycles(face): 46 | """Find cycles in face.""" 47 | g = nx.Graph() 48 | for v in range(len(face) - 1): 49 | g.add_edge(face[v], face[v + 1]) 50 | g.add_edge(face[-1], face[0]) 51 | return list(nx.cycle_basis(g)) 52 | 53 | 54 | def sort_vertices_and_faces(vertices_, faces_, num_tokens=256): 55 | vertices = np.clip((vertices_ + 0.5), 0, 1) * num_tokens # type: ignore 56 | vertices_quantized_ = vertices.round().astype(int) 57 | 58 | vertices_quantized_ = vertices_quantized_[:, [2, 0, 1]] 59 | vertices_quantized, unique_inverse = np.unique(vertices_quantized_, axis=0, return_inverse=True) 60 | 61 | sort_inds = np.lexsort(vertices_quantized.T) 62 | 63 | vertices_quantized = vertices_quantized[sort_inds] 64 | vertices_quantized = np.stack([vertices_quantized[:, 2], vertices_quantized[:, 1], vertices_quantized[:, 0]], axis=-1) 65 | 66 | # Re-index faces and tris to re-ordered vertices. 67 | faces = [np.argsort(sort_inds)[unique_inverse[f]] for f in faces_] 68 | # Merging duplicate vertices and re-indexing the faces causes some faces to 69 | # contain loops (e.g [2, 3, 5, 2, 4]). Split these faces into distinct 70 | # sub-faces. 71 | sub_faces = [] 72 | for f in faces: 73 | cliques = face_to_cycles(f) 74 | for c in cliques: 75 | c_length = len(c) 76 | # Only append faces with more than two verts. 77 | if c_length > 2: 78 | d = np.argmin(c) 79 | # Cyclically permute faces just that first index is the smallest. 80 | sub_faces.append([c[(d + i) % c_length] for i in range(c_length)]) 81 | faces = sub_faces 82 | # Sort faces by lowest vertex indices. If two faces have the same lowest 83 | # index then sort by next lowest and so on. 84 | faces.sort(key=lambda f: tuple(sorted(f))) 85 | 86 | # After removing degenerate faces some vertices are now unreferenced. 87 | # Remove these. 
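    # The block below builds a boolean mask of vertices referenced by at least one face,
    # drops the unreferenced vertices, and remaps face indices by subtracting, for each
    # kept vertex, the number of dropped vertices preceding it (the cumulative sum over
    # `1 - vert_connected`).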
88 | num_verts = vertices_quantized.shape[0] 89 | vert_connected = np.equal( 90 | np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1) 91 | vertices_quantized = vertices_quantized[vert_connected] 92 | # Re-index faces and tris to re-ordered vertices. 93 | vert_indices = ( 94 | np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int'))) 95 | faces = [vert_indices[f].tolist() for f in faces] 96 | vertices = vertices_quantized / num_tokens - 0.5 97 | # order: Z, Y, X --> X, Y, Z 98 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1) 99 | return vertices, faces -------------------------------------------------------------------------------- /dataset/quantized_soup.py: -------------------------------------------------------------------------------- 1 | import random 2 | import omegaconf 3 | import torch 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from pathlib import Path 7 | import torch.utils.data 8 | import pickle 9 | from numpy import random 10 | import trimesh 11 | import torch_scatter 12 | 13 | from dataset import get_shifted_sequence 14 | from dataset.triangles import create_feature_stack 15 | from trainer import create_conv_batch, get_rvqvae_v0_encoder_vq, get_rvqvae_v1_encoder_vq 16 | from util.misc import normalize_vertices, scale_vertices 17 | from util.visualization import triangle_sequence_to_mesh 18 | 19 | 20 | class QuantizedSoup(Dataset): 21 | 22 | def __init__(self, config, split, scale_augment): 23 | super().__init__() 24 | data_path = Path(config.dataset_root) 25 | self.block_size = config.block_size 26 | self.vq_depth = config.embed_levels 27 | self.vq_is_shared = config.embed_share 28 | self.vq_num_codes_per_level = config.n_embed 29 | self.cached_vertices = [] 30 | self.cached_faces = [] 31 | self.names = [] 32 | self.num_tokens = config.num_tokens - 3 33 | self.scale_augment = scale_augment 34 | with open(data_path, 'rb') as fptr: 35 | data = pickle.load(fptr) 36 | if config.only_chairs: 37 | for s in ['train', 'val']: 38 | data[f'vertices_{s}'] = [data[f'vertices_{s}'][i] for i in range(len(data[f'vertices_{s}'])) if data[f'name_{s}'][i].split('_')[0] == '03001627'] 39 | data[f'faces_{s}'] = [data[f'faces_{s}'][i] for i in range(len(data[f'faces_{s}'])) if data[f'name_{s}'][i].split('_')[0] == '03001627'] 40 | data[f'name_{s}'] = [data[f'name_{s}'][i] for i in range(len(data[f'name_{s}'])) if data[f'name_{s}'][i].split('_')[0] == '03001627'] 41 | if not config.overfit: 42 | self.names = data[f'name_{split}'] 43 | self.cached_vertices = data[f'vertices_{split}'] 44 | self.cached_faces = data[f'faces_{split}'] 45 | else: 46 | multiplier = 16 if split == 'val' else 512 47 | self.names = data[f'name_train'][:1] * multiplier 48 | self.cached_vertices = data[f'vertices_train'][:1] * multiplier 49 | self.cached_faces = data[f'faces_train'][:1] * multiplier 50 | 51 | print(len(self.cached_vertices), "meshes loaded") 52 | 53 | max_inner_face_len = 0 54 | for i in range(len(self.cached_vertices)): 55 | self.cached_vertices[i] = np.array(self.cached_vertices[i]) 56 | for j in range(len(self.cached_faces[i])): 57 | max_inner_face_len = max(max_inner_face_len, len(self.cached_faces[i][j])) 58 | print('Longest inner face sequence', max_inner_face_len) 59 | assert max_inner_face_len == 3, "Only triangles are supported" 60 | 61 | self.start_sequence_token = 0 62 | self.end_sequence_token = 1 63 | self.pad_face_token = 2 64 | self.padding = int(config.padding * self.block_size) 65 | self.indices = [] 66 | 67 | 
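        # Each mesh maps to a token sequence of length num_faces * vq_depth + 2 (start + stop).
        # Sequences longer than block_size are covered with sliding windows: one
        # (mesh_index, window_start) pair is appended per sequence_stride step, and
        # __getitem__ later crops a block_size chunk starting at window_start.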
max_face_sequence_len = 0 68 | min_face_sequence_len = 1e7 69 | for i in range(len(self.cached_faces)): 70 | sequence_len = len(self.cached_faces[i]) * self.vq_depth + 1 + 1 71 | max_face_sequence_len = max(max_face_sequence_len, sequence_len) 72 | min_face_sequence_len = min(min_face_sequence_len, sequence_len) 73 | for j in range(0, max(1, sequence_len - self.block_size + self.padding + 1), config.sequence_stride): # todo: possible bug? +1 added recently 74 | self.indices.append((i, j)) 75 | print('Length of', split, len(self.indices)) 76 | print('Shortest face sequence', min_face_sequence_len) 77 | print('Longest face sequence', max_face_sequence_len) 78 | self.encoder = None 79 | self.pre_quant = None 80 | self.vq = None 81 | self.post_quant = None 82 | self.decoder = None 83 | self.device = None 84 | 85 | def set_quantizer(self, encoder, pre_quant, vq, post_quant, decoder, device): 86 | self.encoder = encoder.eval() 87 | self.pre_quant = pre_quant.eval() 88 | self.decoder = decoder.eval() 89 | self.vq = vq.eval() 90 | self.post_quant = post_quant.eval() 91 | self.device = device 92 | 93 | @torch.no_grad() 94 | def get_codes(self, vertices, faces): 95 | triangles, normals, areas, angles, vertices, faces = create_feature_stack(vertices, faces, self.num_tokens) 96 | features = np.hstack([triangles, normals, areas, angles]) 97 | face_neighborhood = np.array(trimesh.Trimesh(vertices=vertices, faces=faces, process=False).face_neighborhood) # type: ignore 98 | 99 | encoded_x = self.encoder( 100 | torch.from_numpy(features).float().to(self.device), 101 | torch.from_numpy(face_neighborhood.T).long().to(self.device), 102 | torch.zeros([features.shape[0]], device=self.device).long() 103 | ) 104 | 105 | encoded_x = self.pre_quant(encoded_x) 106 | _, all_indices, _ = self.vq(encoded_x.unsqueeze(0)) 107 | all_indices = all_indices.squeeze(0) 108 | if not self.vq_is_shared: 109 | correction = (torch.arange(0, self.vq_depth, device=self.device) * self.vq_num_codes_per_level).reshape(1, -1) 110 | all_indices = all_indices + correction 111 | inner_face_id = torch.arange(0, self.vq_depth, device=self.device).reshape(1, -1).expand(all_indices.shape[0], -1) 112 | outer_face_id = torch.arange(0, all_indices.shape[0], device=self.device).reshape(-1, 1).expand(-1, self.vq_depth) 113 | 114 | # adjust for start, end and padding tokens 115 | all_indices = all_indices.reshape(-1) + 3 116 | inner_face_id = inner_face_id.reshape(-1) + 3 117 | outer_face_id = outer_face_id.reshape(-1) + 3 118 | 119 | # add start token and end token 120 | all_indices = torch.cat(( 121 | torch.tensor([self.start_sequence_token], device=self.device), 122 | all_indices, 123 | torch.tensor([self.end_sequence_token], device=self.device) 124 | )).long().cpu() 125 | 126 | inner_face_id = torch.cat(( 127 | torch.tensor([self.start_sequence_token], device=self.device), 128 | inner_face_id, 129 | torch.tensor([self.end_sequence_token], device=self.device) 130 | )).long().cpu() 131 | 132 | outer_face_id = torch.cat(( 133 | torch.tensor([self.start_sequence_token], device=self.device), 134 | outer_face_id, 135 | torch.tensor([self.end_sequence_token], device=self.device) 136 | )).long().cpu() 137 | 138 | return all_indices, inner_face_id, outer_face_id 139 | 140 | def __getitem__(self, index: int): 141 | i, j = self.indices[index] 142 | vertices = self.cached_vertices[i] 143 | faces = self.cached_faces[i] 144 | if self.scale_augment: 145 | vertices = normalize_vertices(scale_vertices(vertices)) 146 | else: 147 | vertices = 
normalize_vertices(vertices) 148 | 149 | soup_sequence, face_in_idx, face_out_idx = self.get_codes(vertices, faces) 150 | 151 | # face sequence in block format 152 | end_index = min(j + self.block_size, len(soup_sequence)) 153 | x_in = soup_sequence[j:end_index] 154 | y_in = soup_sequence[j + 1:end_index + 1] 155 | fpi_in = face_in_idx[j:end_index] 156 | fpo_in = face_out_idx[j:end_index] 157 | 158 | x_pad = torch.tensor([self.pad_face_token for _ in range(0, self.block_size - len(x_in))]) 159 | y_pad = torch.tensor([self.pad_face_token for _ in range(0, len(x_in) + len(x_pad) - len(y_in))]) 160 | fpi_in_pad = torch.tensor([self.pad_face_token for _ in range(0, self.block_size - len(fpi_in))]) 161 | fpo_in_pad = torch.tensor([self.pad_face_token for _ in range(0, self.block_size - len(fpo_in))]) 162 | 163 | x = torch.cat((x_in, x_pad)).long() 164 | y = torch.cat((y_in, y_pad)).long() 165 | fpi = torch.cat((fpi_in, fpi_in_pad)).long() 166 | fpo = torch.from_numpy(get_shifted_sequence(torch.cat((fpo_in, fpo_in_pad)).numpy())).long() 167 | 168 | return { 169 | 'name': self.names[i], 170 | 'seq_in': x, 171 | 'seq_out': y, 172 | 'seq_pos_inner': fpi, 173 | 'seq_pos_outer': fpo, 174 | } 175 | 176 | def get_completion_sequence(self, i, tokens, device=torch.device("cpu")): 177 | vertices = normalize_vertices(self.cached_vertices[i]) 178 | faces = self.cached_faces[i] 179 | soup_sequence, face_in_idx, face_out_idx = self.get_codes(vertices, faces) 180 | face_out_idx = torch.from_numpy(get_shifted_sequence(face_out_idx.numpy())) 181 | original_fseq = soup_sequence.to(device) 182 | if isinstance(tokens, int): 183 | num_pre_tokens = tokens 184 | else: 185 | num_pre_tokens = int(len(original_fseq) * tokens) 186 | x = ( 187 | soup_sequence[:num_pre_tokens].to(device)[None, ...], 188 | face_in_idx[:num_pre_tokens].to(device)[None, ...], 189 | face_out_idx[:num_pre_tokens].to(device)[None, ...], 190 | original_fseq[None, ...], 191 | ) 192 | return x 193 | 194 | def get_start(self, device=torch.device("cpu")): 195 | i = random.choice(list(range(len(self.cached_vertices)))) 196 | x = self.get_completion_sequence(i, 11, device) 197 | return x 198 | 199 | def __len__(self) -> int: 200 | return len(self.indices) 201 | 202 | def decode(self, sequence): 203 | mask = torch.isin(sequence, torch.tensor([self.start_sequence_token, self.end_sequence_token, self.pad_face_token], device=sequence.device)).logical_not() 204 | sequence = sequence[mask] 205 | sequence = sequence - 3 206 | 207 | sequence_len = (sequence.shape[0] // self.vq_depth) * self.vq_depth 208 | sequence = sequence[:sequence_len].reshape(-1, self.vq_depth) 209 | N = sequence.shape[0] 210 | E, D = self.vq_num_codes_per_level, self.vq_depth 211 | all_codes = self.vq.get_codes_from_indices(sequence).permute(1, 2, 0) 212 | encoded_x = all_codes.reshape(N, E, D).sum(-1) 213 | encoded_x = self.post_quant(encoded_x) 214 | encoded_x_conv, conv_mask = create_conv_batch(encoded_x, torch.zeros([sequence.shape[0]], device=self.device).long(), 1, self.device) 215 | decoded_x_conv = self.decoder(encoded_x_conv) 216 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :] 217 | coords = decoded_x.argmax(-1).detach().cpu().numpy() / self.num_tokens - 0.5 218 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords) 219 | return gen_vertices, gen_faces 220 | 221 | def plot_sequence_lenght_stats(self): 222 | sequence_lengths = [] 223 | for i in range(len(self.cached_faces)): 224 | sequence_len = 
len(self.cached_faces[i]) * self.vq_depth + 1 + 1 225 | sequence_lengths.append(sequence_len) 226 | import matplotlib.pyplot as plt 227 | plt.hist(sequence_lengths, bins=32) 228 | plt.ylim(0, 100) 229 | plt.show() 230 | 231 | 232 | class QuantizedSoupCreator(torch.nn.Module): 233 | vq_depth_factor = 1 234 | def __init__(self, config, vq_cfg): 235 | super().__init__() 236 | self.vq_cfg = vq_cfg 237 | self.vq_depth = self.vq_cfg.embed_levels 238 | self.vq_is_shared = self.vq_cfg.embed_share 239 | self.vq_num_codes_per_level = self.vq_cfg.n_embed 240 | self.vq_dim = self.vq_cfg.embed_dim 241 | assert config.num_tokens == self.vq_cfg.num_tokens, "Number of tokens must match" 242 | self.block_size = config.block_size 243 | self.start_sequence_token = 0 244 | self.end_sequence_token = 1 245 | self.pad_face_token = 2 246 | self.num_tokens = config.num_tokens - 3 247 | self.padding = int(config.padding * self.block_size) 248 | self.rq_transformer_input = False 249 | self.encoder, self.pre_quant, self.post_quant, self.vq = self.get_rvq_encoders(config.vq_resume) 250 | 251 | def get_rvq_encoders(self, resume): 252 | return get_rvqvae_v0_encoder_vq(self.vq_cfg, resume) 253 | 254 | def freeze_vq(self): 255 | for model in [self.encoder, self.pre_quant, self.post_quant, self.vq]: 256 | for param in model.parameters(): 257 | param.requires_grad = False 258 | 259 | @torch.no_grad() 260 | def embed(self, idx): 261 | assert self.vq_is_shared, "Only shared embedding is supported" 262 | all_codes = self.vq.codebooks[0][idx].reshape(-1, self.vq_dim) 263 | return all_codes 264 | 265 | @torch.no_grad() 266 | def get_indices(self, x, edge_index, batch, _faces, _num_vertices): 267 | encoded_x = self.encoder(x, edge_index, batch) 268 | encoded_x = self.pre_quant(encoded_x) 269 | _, all_indices, _ = self.vq(encoded_x.unsqueeze(0)) 270 | all_indices = all_indices.squeeze(0) 271 | if not self.vq_is_shared: 272 | correction = (torch.arange(0, self.vq_depth * self.vq_depth_factor, device=x.device) * self.vq_num_codes_per_level).reshape(1, -1) 273 | all_indices = all_indices + correction 274 | return all_indices 275 | 276 | @torch.no_grad() 277 | def forward(self, x, edge_index, batch, faces, num_vertices, js, force_full_sequence=False): 278 | for model in [self.encoder, self.pre_quant, self.post_quant, self.vq]: 279 | model.eval() 280 | if force_full_sequence: 281 | assert js.shape[0] == 1, "Only single mesh supported" 282 | all_indices = self.get_indices(x, edge_index, batch, faces, num_vertices) 283 | batch_size = js.shape[0] 284 | sequences = [] 285 | targets = [] 286 | position_inners = [] 287 | position_outers = [] 288 | max_sequence_length_x = 0 289 | for k in range(batch_size): 290 | sequence_k = all_indices[batch == k, :] 291 | inner_face_id_k = torch.arange(0, self.vq_depth * self.vq_depth_factor, device=x.device).reshape(1, -1).expand(sequence_k.shape[0], -1) 292 | outer_face_id_k = torch.arange(0, sequence_k.shape[0], device=x.device).reshape(-1, 1).expand(-1, self.vq_depth * self.vq_depth_factor) 293 | sequence_k = sequence_k.reshape(-1) + 3 294 | inner_face_id_k = inner_face_id_k.reshape(-1) + 3 295 | outer_face_id_k = outer_face_id_k.reshape(-1) + 3 296 | # add start token and end token 297 | 298 | prefix = [torch.tensor([self.start_sequence_token], device=x.device)] 299 | postfix = [torch.tensor([self.end_sequence_token], device=x.device)] 300 | 301 | if self.rq_transformer_input: 302 | prefix = prefix * self.vq_depth 303 | postfix = postfix * self.vq_depth 304 | 305 | sequence_k = torch.cat(prefix + 
[sequence_k] + postfix).long() 306 | 307 | inner_face_id_k = torch.cat(prefix + [inner_face_id_k] + postfix).long() 308 | 309 | outer_face_id_k = torch.cat(prefix + [outer_face_id_k] + postfix).long() 310 | 311 | j = js[k] 312 | if force_full_sequence: 313 | end_index = len(sequence_k) 314 | else: 315 | end_index = min(j + self.block_size, len(sequence_k)) 316 | x_in = sequence_k[j:end_index] 317 | if self.rq_transformer_input: 318 | y_in = sequence_k[j + self.vq_depth:end_index + self.vq_depth] 319 | else: 320 | y_in = sequence_k[j + 1:end_index + 1] 321 | fpi_in = inner_face_id_k[j:end_index] 322 | fpo_in = outer_face_id_k[j:end_index].cpu() 323 | 324 | max_sequence_length_x = max(max_sequence_length_x, len(x_in)) 325 | pad_len_x = self.block_size - len(x_in) 326 | if self.rq_transformer_input: 327 | pad_len_x = pad_len_x + (self.vq_depth - (len(x_in) + pad_len_x) % self.vq_depth) 328 | x_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_x)], device=x.device) 329 | 330 | pad_len_y = len(x_in) + len(x_pad) - len(y_in) 331 | pad_len_fpi = self.block_size - len(fpi_in) 332 | pad_len_fpo = self.block_size - len(fpo_in) 333 | 334 | if self.rq_transformer_input: 335 | pad_len_fpi = pad_len_fpi + (self.vq_depth - (len(fpi_in) + pad_len_fpi) % self.vq_depth) 336 | pad_len_fpo = pad_len_fpo + (self.vq_depth - (len(fpo_in) + pad_len_fpo) % self.vq_depth) 337 | 338 | y_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_y)], device=x.device) 339 | fpi_in_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_fpi)], device=x.device) 340 | fpo_in_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_fpo)]) 341 | 342 | x = torch.cat((x_in, x_pad)).long() 343 | y = torch.cat((y_in, y_pad)).long() 344 | fpi = torch.cat((fpi_in, fpi_in_pad)).long() 345 | fpo = torch.from_numpy(get_shifted_sequence(torch.cat((fpo_in, fpo_in_pad)).numpy())).long().to(x.device) 346 | 347 | sequences.append(x) 348 | targets.append(y) 349 | position_inners.append(fpi) 350 | position_outers.append(fpo) 351 | 352 | sequences = torch.stack(sequences, dim=0)[:, :max_sequence_length_x].contiguous() 353 | targets = torch.stack(targets, dim=0)[:, :max_sequence_length_x].contiguous() 354 | position_inners = torch.stack(position_inners, dim=0)[:, :max_sequence_length_x].contiguous() 355 | position_outers = torch.stack(position_outers, dim=0)[:, :max_sequence_length_x].contiguous() 356 | return sequences, targets, position_inners, position_outers 357 | 358 | @torch.no_grad() 359 | def get_completion_sequence(self, x, edge_index, face, num_vertices, tokens): 360 | soup_sequence, target, face_in_idx, face_out_idx = self.forward( 361 | x, edge_index, 362 | torch.zeros([x.shape[0]], device=x.device).long(), face, 363 | num_vertices, 364 | torch.zeros([1], device=x.device).long(), 365 | force_full_sequence=True 366 | ) 367 | soup_sequence = soup_sequence[0] 368 | face_in_idx = face_in_idx[0] 369 | face_out_idx = face_out_idx[0] 370 | target = target[0] 371 | if isinstance(tokens, int): 372 | num_pre_tokens = tokens 373 | else: 374 | num_pre_tokens = int(len(target) * tokens) 375 | x = ( 376 | soup_sequence[:num_pre_tokens][None, ...], 377 | face_in_idx[:num_pre_tokens][None, ...], 378 | face_out_idx[:num_pre_tokens][None, ...], 379 | target[None, ...], 380 | ) 381 | return x 382 | 383 | def encode_sequence(self, sequence): 384 | N = sequence.shape[0] 385 | E, D = self.vq_dim, self.vq_depth 386 | all_codes = self.vq.get_codes_from_indices(sequence).permute(1, 2, 0) 387 | encoded_x = 
all_codes.reshape(N, E, D).sum(-1) 388 | return encoded_x 389 | 390 | @torch.no_grad() 391 | def decode(self, sequence, decoder): 392 | mask = torch.isin(sequence, torch.tensor([self.start_sequence_token, self.end_sequence_token, self.pad_face_token], device=sequence.device)).logical_not() 393 | sequence = sequence[mask] 394 | sequence = sequence - 3 395 | sequence_len = (sequence.shape[0] // (self.vq_depth * self.vq_depth_factor)) * (self.vq_depth * self.vq_depth_factor) 396 | sequence = sequence[:sequence_len].reshape(-1, self.vq_depth) 397 | encoded_x = self.encode_sequence(sequence) 398 | encoded_x = self.post_quant(encoded_x) 399 | encoded_x_conv, conv_mask = create_conv_batch(encoded_x, torch.zeros([encoded_x.shape[0]], device=sequence.device).long(), 1, sequence.device) 400 | decoded_x_conv = decoder(encoded_x_conv) 401 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :] 402 | coords = decoded_x.argmax(-1).detach().cpu().numpy() / self.num_tokens - 0.5 403 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords) 404 | return gen_vertices, gen_faces 405 | 406 | 407 | class QuantizedSoupTripletsCreator(QuantizedSoupCreator): 408 | vq_depth_factor = 3 409 | def __init__(self, config, vq_cfg): 410 | super().__init__(config, vq_cfg) 411 | 412 | def get_rvq_encoders(self, resume): 413 | return get_rvqvae_v1_encoder_vq(self.vq_cfg, resume) 414 | 415 | @torch.no_grad() 416 | def get_indices(self, x, edge_index, batch, faces, num_vertices): 417 | encoded_x = self.encoder(x, edge_index, batch) 418 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192) 419 | encoded_x = distribute_features(encoded_x, faces, num_vertices, x.device) 420 | encoded_x = self.pre_quant(encoded_x) 421 | _, all_indices, _ = self.vq(encoded_x.unsqueeze(0)) 422 | all_indices = all_indices.squeeze(0).reshape(-1, self.vq_depth * 3) 423 | if not self.vq_is_shared: 424 | correction = (torch.arange(0, self.vq_depth, device=x.device) * self.vq_num_codes_per_level).reshape(1, -1) 425 | all_indices = all_indices + correction 426 | return all_indices 427 | 428 | def encode_sequence(self, sequence): 429 | N = sequence.shape[0] 430 | E, D = self.vq_dim, self.vq_depth 431 | all_codes = self.vq.get_codes_from_indices(sequence).permute(1, 2, 0) 432 | encoded_x = all_codes.reshape(-1, 3 * E, D).sum(-1) 433 | return encoded_x 434 | 435 | 436 | def distribute_features(features, face_indices, num_vertices, device): 437 | assert features.shape[0] == face_indices.shape[0] * face_indices.shape[1], "Features and face indices must match in size" 438 | vertex_features = torch.zeros([num_vertices, features.shape[1]], device=device) 439 | torch_scatter.scatter_mean(features, face_indices.reshape(-1), out=vertex_features, dim=0) 440 | distributed_features = vertex_features[face_indices.reshape(-1), :] 441 | return distributed_features 442 | -------------------------------------------------------------------------------- /dataset/triangles.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Sequence 2 | 3 | import omegaconf 4 | import torch 5 | import numpy as np 6 | import trimesh 7 | from torch.utils.data import Dataset, default_collate 8 | from pathlib import Path 9 | import torch.utils.data 10 | import pickle 11 | 12 | from torch_geometric.data.data import BaseData 13 | from tqdm import tqdm 14 | 15 | from dataset import sort_vertices_and_faces, quantize_coordinates 16 | from util.misc import normalize_vertices, 
scale_vertices, shift_vertices 17 | from torch_geometric.data import Dataset as GeometricDataset, Batch 18 | from torch_geometric.data import Data as GeometricData 19 | from torch_geometric.loader.dataloader import Collater as GeometricCollator 20 | 21 | 22 | class TriangleNodes(GeometricDataset): 23 | 24 | def __init__(self, config, split, scale_augment, shift_augment, force_category, use_start_stop=False, only_backward_edges=False): 25 | super().__init__() 26 | data_path = Path(config.dataset_root) 27 | self.cached_vertices = [] 28 | self.cached_faces = [] 29 | self.names = [] 30 | self.scale_augment = scale_augment 31 | self.shift_augment = shift_augment 32 | self.low_augment = config.low_augment 33 | self.use_start_stop = use_start_stop 34 | self.ce_output = config.ce_output 35 | self.only_backward_edges = only_backward_edges 36 | self.num_tokens = config.num_tokens - 3 37 | with open(data_path, 'rb') as fptr: 38 | data = pickle.load(fptr) 39 | if force_category is not None: 40 | for s in ['train', 'val']: 41 | data[f'vertices_{s}'] = [data[f'vertices_{s}'][i] for i in range(len(data[f'vertices_{s}'])) if data[f'name_{s}'][i].split('_')[0] == force_category] 42 | data[f'faces_{s}'] = [data[f'faces_{s}'][i] for i in range(len(data[f'faces_{s}'])) if data[f'name_{s}'][i].split('_')[0] == force_category] 43 | data[f'name_{s}'] = [data[f'name_{s}'][i] for i in range(len(data[f'name_{s}'])) if data[f'name_{s}'][i].split('_')[0] == force_category] 44 | if len(data[f'vertices_val']) == 0: 45 | data[f'vertices_val'] = data[f'vertices_train'] 46 | data[f'faces_val'] = data[f'faces_train'] 47 | data[f'name_val'] = data[f'name_train'] 48 | if not config.overfit: 49 | self.names = data[f'name_{split}'] 50 | self.cached_vertices = data[f'vertices_{split}'] 51 | self.cached_faces = data[f'faces_{split}'] 52 | else: 53 | multiplier = 16 if split == 'val' else 512 54 | self.names = data[f'name_train'][:1] * multiplier 55 | self.cached_vertices = data[f'vertices_train'][:1] * multiplier 56 | self.cached_faces = data[f'faces_train'][:1] * multiplier 57 | 58 | print(len(self.cached_vertices), "meshes loaded") 59 | 60 | def len(self): 61 | return len(self.cached_vertices) 62 | 63 | def get_all_features_for_shape(self, idx): 64 | vertices = self.cached_vertices[idx] 65 | faces = self.cached_faces[idx] 66 | if self.scale_augment: 67 | if self.low_augment: 68 | x_lims = (0.9, 1.1) 69 | y_lims = (0.9, 1.1) 70 | z_lims = (0.9, 1.1) 71 | else: 72 | x_lims = (0.75, 1.25) 73 | y_lims = (0.75, 1.25) 74 | z_lims = (0.75, 1.25) 75 | vertices = scale_vertices(vertices, x_lims=x_lims, y_lims=y_lims, z_lims=z_lims) 76 | vertices = normalize_vertices(vertices) 77 | if self.shift_augment: 78 | vertices = shift_vertices(vertices) 79 | triangles, normals, areas, angles, vertices, faces = create_feature_stack(vertices, faces, self.num_tokens) 80 | features = np.hstack([triangles, normals, areas, angles]) 81 | face_neighborhood = np.array(trimesh.Trimesh(vertices=vertices, faces=faces, process=False).face_neighborhood) # type: ignore 82 | target = torch.from_numpy(features[:, :9]).float() 83 | if self.use_start_stop: 84 | features = np.concatenate([np.zeros((1, features.shape[1])), features], axis=0) 85 | target = torch.cat([target, torch.ones(1, 9) * 0.5], dim=0) 86 | face_neighborhood = face_neighborhood + 1 87 | if self.only_backward_edges: 88 | face_neighborhood = face_neighborhood[face_neighborhood[:, 1] > face_neighborhood[:, 0], :] 89 | # face_neighborhood = modify so that only edges in backward direction are present 
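        # With ce_output enabled, the 9 per-face coordinate targets are discretized by
        # quantize_coordinates into integer bins in [0, num_tokens], so the vocabulary
        # decoder can be trained with an (optionally smoothed) cross-entropy loss instead
        # of regressing continuous coordinates.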
90 | if self.ce_output: 91 | target = quantize_coordinates(target, self.num_tokens) 92 | return features, target, vertices, faces, face_neighborhood 93 | 94 | def get(self, idx): 95 | features, target, _, _, face_neighborhood = self.get_all_features_for_shape(idx) 96 | return GeometricData(x=torch.from_numpy(features).float(), y=target, edge_index=torch.from_numpy(face_neighborhood.T).long()) 97 | 98 | 99 | class TriangleNodesWithFaces(TriangleNodes): 100 | 101 | def __init__(self, config, split, scale_augment, shift_augment, force_category): 102 | super().__init__(config, split, scale_augment, shift_augment, force_category) 103 | 104 | def get(self, idx): 105 | features, target, vertices, faces, face_neighborhood = self.get_all_features_for_shape(idx) 106 | return GeometricData(x=torch.from_numpy(features).float(), y=target, 107 | edge_index=torch.from_numpy(face_neighborhood.T).long(), 108 | num_vertices=vertices.shape[0], faces=torch.from_numpy(np.array(faces)).long()) 109 | 110 | 111 | class TriangleNodesWithFacesDataloader(torch.utils.data.DataLoader): 112 | 113 | def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=None, exclude_keys=None, **kwargs): 114 | # Remove for PyTorch Lightning: 115 | kwargs.pop('collate_fn', None) 116 | # Save for PyTorch Lightning < 1.6: 117 | self.follow_batch = follow_batch 118 | self.exclude_keys = exclude_keys 119 | super().__init__( 120 | dataset, 121 | batch_size, 122 | shuffle, 123 | collate_fn=FaceCollator(follow_batch, exclude_keys), 124 | **kwargs, 125 | ) 126 | 127 | 128 | class FaceCollator(GeometricCollator): 129 | 130 | def __init__(self, follow_batch, exclude_keys): 131 | super().__init__(follow_batch, exclude_keys) 132 | 133 | def __call__(self, batch): 134 | elem = batch[0] 135 | 136 | num_vertices = 0 137 | for b_idx in range(len(batch)): 138 | batch[b_idx].faces += num_vertices 139 | num_vertices += batch[b_idx].num_vertices 140 | 141 | if isinstance(elem, BaseData): 142 | return Batch.from_data_list(batch, self.follow_batch, self.exclude_keys) 143 | elif isinstance(elem, torch.Tensor): 144 | return default_collate(batch) 145 | elif isinstance(elem, float): 146 | return torch.tensor(batch, dtype=torch.float) 147 | elif isinstance(elem, int): 148 | return torch.tensor(batch) 149 | elif isinstance(elem, str): 150 | return batch 151 | elif isinstance(elem, Mapping): 152 | return {key: self([data[key] for data in batch]) for key in elem} 153 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): 154 | return type(elem)(*(self(s) for s in zip(*batch))) 155 | elif isinstance(elem, Sequence) and not isinstance(elem, str): 156 | return [self(s) for s in zip(*batch)] 157 | 158 | raise TypeError(f'DataLoader found invalid type: {type(elem)}') 159 | 160 | def collate(self, batch): # pragma: no cover 161 | raise NotImplementedError 162 | 163 | class TriangleNodesWithSequenceIndices(TriangleNodes): 164 | 165 | vq_depth_factor = 1 166 | 167 | def __init__(self, config, split, scale_augment, shift_augment, force_category): 168 | super().__init__(config, split, scale_augment=scale_augment, shift_augment=shift_augment, force_category=force_category) 169 | vq_cfg = omegaconf.OmegaConf.load(Path(config.vq_resume).parents[1] / "config.yaml") 170 | self.vq_depth = vq_cfg.embed_levels 171 | self.block_size = config.block_size 172 | max_inner_face_len = 0 173 | self.padding = int(config.padding * self.block_size) 174 | self.sequence_stride = config.sequence_stride 175 | for i in range(len(self.cached_vertices)): 176 | 
self.cached_vertices[i] = np.array(self.cached_vertices[i]) 177 | for j in range(len(self.cached_faces[i])): 178 | max_inner_face_len = max(max_inner_face_len, len(self.cached_faces[i][j])) 179 | print('Longest inner face sequence', max_inner_face_len) 180 | assert max_inner_face_len == 3, f"Only triangles are supported, but found a face with {max_inner_face_len}." 181 | self.sequence_indices = [] 182 | max_face_sequence_len = 0 183 | min_face_sequence_len = 1e7 184 | for i in range(len(self.cached_faces)): 185 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1 186 | max_face_sequence_len = max(max_face_sequence_len, sequence_len) 187 | min_face_sequence_len = min(min_face_sequence_len, sequence_len) 188 | self.sequence_indices.append((i, 0, False)) 189 | for j in range(config.sequence_stride, max(1, sequence_len - self.block_size + self.padding + 1), config.sequence_stride): # todo: possible bug? +1 added recently 190 | self.sequence_indices.append((i, j, True if split == 'train' else False)) 191 | if sequence_len > self.block_size: 192 | self.sequence_indices.append((i, sequence_len - self.block_size, False)) 193 | print('Length of', split, len(self.sequence_indices)) 194 | print('Shortest face sequence', min_face_sequence_len) 195 | print('Longest face sequence', max_face_sequence_len) 196 | 197 | def len(self): 198 | return len(self.sequence_indices) 199 | 200 | def get(self, idx): 201 | i, j, randomness = self.sequence_indices[idx] 202 | if randomness: 203 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1 204 | j = min(max(0, j + np.random.randint(-self.sequence_stride // 2, self.sequence_stride // 2)), sequence_len - self.block_size + self.padding) 205 | features, target, _, _, face_neighborhood = self.get_all_features_for_shape(i) 206 | return GeometricData(x=torch.from_numpy(features).float(), y=target, edge_index=torch.from_numpy(face_neighborhood.T).long(), js=torch.tensor(j).long()) 207 | 208 | def plot_sequence_lenght_stats(self): 209 | sequence_lengths = [] 210 | for i in range(len(self.cached_faces)): 211 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1 212 | sequence_lengths.append(sequence_len) 213 | import matplotlib.pyplot as plt 214 | plt.hist(sequence_lengths, bins=32) 215 | plt.ylim(0, 100) 216 | plt.show() 217 | return sequence_lengths 218 | 219 | 220 | class TriangleNodesWithFacesAndSequenceIndices(TriangleNodesWithSequenceIndices): 221 | vq_depth_factor = 3 222 | def __init__(self, config, split, scale_augment, shift_augment, force_category): 223 | super().__init__(config, split, scale_augment, shift_augment, force_category) 224 | 225 | def get(self, idx): 226 | i, j, randomness = self.sequence_indices[idx] 227 | if randomness: 228 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1 229 | j = min(max(0, j + np.random.randint(-self.sequence_stride // 2, self.sequence_stride // 2)), sequence_len - self.block_size + self.padding) 230 | features, target, vertices, faces, face_neighborhood = self.get_all_features_for_shape(i) 231 | return GeometricData(x=torch.from_numpy(features).float(), 232 | y=target, mesh_name=self.names[i], edge_index=torch.from_numpy(face_neighborhood.T).long(), 233 | js=torch.tensor(j).long(), num_vertices=vertices.shape[0], 234 | faces=torch.from_numpy(np.array(faces)).long()) 235 | 236 | 237 | class Triangles(Dataset): 238 | 239 | def __init__(self, config, split, scale_augment, 
shift_augment): 240 | super().__init__() 241 | data_path = Path(config.dataset_root) 242 | self.cached_vertices = [] 243 | self.cached_faces = [] 244 | self.names = [] 245 | self.scale_augment = scale_augment 246 | self.shift_augment = shift_augment 247 | with open(data_path, 'rb') as fptr: 248 | data = pickle.load(fptr) 249 | if not config.overfit: 250 | self.names = data[f'name_{split}'] 251 | self.cached_vertices = data[f'vertices_{split}'] 252 | self.cached_faces = data[f'faces_{split}'] 253 | else: 254 | multiplier = 1 if split == 'val' else 500 255 | self.names = data[f'name_train'][:1] * multiplier 256 | self.cached_vertices = data[f'vertices_train'][:1] * multiplier 257 | self.cached_faces = data[f'faces_train'][:1] * multiplier 258 | 259 | print(len(self.cached_vertices), "meshes loaded") 260 | self.features = None 261 | self.setup_triangles_for_epoch() 262 | 263 | def __len__(self): 264 | return self.features.shape[0] 265 | 266 | def setup_triangles_for_epoch(self): 267 | all_features = [] 268 | for idx in tqdm(range(len(self.cached_vertices)), desc="refresh augs"): 269 | vertices = self.cached_vertices[idx] 270 | faces = self.cached_faces[idx] 271 | if self.scale_augment: 272 | vertices = scale_vertices(vertices) 273 | vertices = normalize_vertices(vertices) 274 | if self.shift_augment: 275 | vertices = shift_vertices(vertices) 276 | all_features.append(create_feature_stack(vertices, faces)[0]) 277 | self.features = np.vstack(all_features) 278 | 279 | def __getitem__(self, idx): 280 | return { 281 | 'features': self.features[idx], 282 | 'target': self.features[idx, :9] 283 | } 284 | 285 | def get_all_features_for_shape(self, idx): 286 | vertices = self.cached_vertices[idx] 287 | faces = self.cached_faces[idx] 288 | feature_stack = create_feature_stack(vertices, faces)[0] 289 | return torch.from_numpy(feature_stack).float(), torch.from_numpy(feature_stack[:, :9]).float() 290 | 291 | 292 | def normal(triangles): 293 | # The cross product of two sides is a normal vector 294 | if torch.is_tensor(triangles): 295 | return torch.cross(triangles[:, 1] - triangles[:, 0], triangles[:, 2] - triangles[:, 0], dim=1) 296 | else: 297 | return np.cross(triangles[:, 1] - triangles[:, 0], triangles[:, 2] - triangles[:, 0], axis=1) 298 | 299 | 300 | def area(triangles): 301 | # The norm of the cross product of two sides is twice the area 302 | if torch.is_tensor(triangles): 303 | return torch.norm(normal(triangles), dim=1) / 2 304 | else: 305 | return np.linalg.norm(normal(triangles), axis=1) / 2 306 | 307 | 308 | def angle(triangles): 309 | v_01 = triangles[:, 1] - triangles[:, 0] 310 | v_02 = triangles[:, 2] - triangles[:, 0] 311 | v_10 = -v_01 312 | v_12 = triangles[:, 2] - triangles[:, 1] 313 | v_20 = -v_02 314 | v_21 = -v_12 315 | if torch.is_tensor(triangles): 316 | return torch.stack([angle_between(v_01, v_02), angle_between(v_10, v_12), angle_between(v_20, v_21)], dim=1) 317 | else: 318 | return np.stack([angle_between(v_01, v_02), angle_between(v_10, v_12), angle_between(v_20, v_21)], axis=1) 319 | 320 | 321 | def angle_between(v0, v1): 322 | v0_u = unit_vector(v0) 323 | v1_u = unit_vector(v1) 324 | if torch.is_tensor(v0): 325 | return torch.arccos(torch.clip(torch.einsum('ij,ij->i', v0_u, v1_u), -1.0, 1.0)) 326 | else: 327 | return np.arccos(np.clip(np.einsum('ij,ij->i', v0_u, v1_u), -1.0, 1.0)) 328 | 329 | 330 | def unit_vector(vector): 331 | if torch.is_tensor(vector): 332 | return vector / (torch.norm(vector, dim=-1)[:, None] + 1e-8) 333 | else: 334 | return vector / 
(np.linalg.norm(vector, axis=-1)[:, None] + 1e-8) 335 | 336 | 337 | def create_feature_stack(vertices, faces, num_tokens): 338 | vertices, faces = sort_vertices_and_faces(vertices, faces, num_tokens) 339 | # need more features: positions, angles, area, cross_product 340 | triangles = vertices[faces, :] 341 | triangles, normals, areas, angles = create_feature_stack_from_triangles(triangles) 342 | return triangles, normals, areas, angles, vertices, faces 343 | 344 | 345 | def create_feature_stack_from_triangles(triangles): 346 | t_areas = area(triangles) * 1e3 347 | t_angles = angle(triangles) / float(np.pi) 348 | t_normals = unit_vector(normal(triangles)) 349 | return triangles.reshape(-1, 9), t_normals.reshape(-1, 3), t_areas.reshape(-1, 1), t_angles.reshape(-1, 3) 350 | -------------------------------------------------------------------------------- /inference/infer_meshgpt.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | from model.pointnet import get_pointnet_classifier 4 | import omegaconf 5 | import torch 6 | from pathlib import Path 7 | 8 | import trimesh 9 | 10 | from dataset.quantized_soup import QuantizedSoupTripletsCreator 11 | from dataset.triangles import TriangleNodesWithFacesAndSequenceIndices 12 | from trainer import get_rvqvae_v0_decoder 13 | from trainer.train_transformer import get_qsoup_model_config 14 | from util.meshlab import meshlab_proc 15 | from util.misc import get_parameters_from_state_dict 16 | from util.visualization import plot_vertices_and_faces 17 | from tqdm import tqdm 18 | from model.transformer import QuantSoupTransformer 19 | from pytorch_lightning import seed_everything 20 | 21 | 22 | @torch.no_grad() 23 | def main(config, mode): 24 | seed_everything(42) 25 | device = torch.device('cuda:0') 26 | vq_cfg = omegaconf.OmegaConf.load(Path(config.vq_resume).parents[1] / "config.yaml") 27 | dataset = TriangleNodesWithFacesAndSequenceIndices(config, 'train', True, True , config.ft_category) 28 | prompt_num_faces = 4 29 | output_dir_image = Path(f'runs/{config.experiment}/inf_image_{mode}') 30 | output_dir_image.mkdir(exist_ok=True, parents=True) 31 | output_dir_mesh = Path(f'runs/{config.experiment}/inf_mesh_{mode}') 32 | output_dir_mesh.mkdir(exist_ok=True, parents=True) 33 | model_cfg = get_qsoup_model_config(config, vq_cfg.embed_levels) 34 | model = QuantSoupTransformer(model_cfg, vq_cfg) 35 | state_dict = torch.load(config.resume, map_location="cpu")["state_dict"] 36 | sequencer = QuantizedSoupTripletsCreator(config, vq_cfg) 37 | model.load_state_dict(get_parameters_from_state_dict(state_dict, "model")) 38 | model = model.to(device) 39 | model = model.eval() 40 | sequencer = sequencer.to(device) 41 | sequencer = sequencer.eval() 42 | decoder = get_rvqvae_v0_decoder(vq_cfg, config.vq_resume, device) 43 | pnet = get_pointnet_classifier().to(device) 44 | 45 | k = 0 46 | while k < config.num_val_samples: 47 | 48 | data = dataset.get(random.randint(0, len(dataset) - 1)) 49 | soup_sequence, face_in_idx, face_out_idx, target = sequencer.get_completion_sequence( 50 | data.x.to(device), 51 | data.edge_index.to(device), 52 | data.faces.to(device), 53 | data.num_vertices, 54 | 1 + 6 * prompt_num_faces 55 | ) 56 | 57 | y = None 58 | 59 | if mode == "topp": 60 | # more diversity but more change of bad sequences 61 | y = model.generate( 62 | soup_sequence, face_in_idx, face_out_idx, sequencer, config.max_val_tokens, 63 | temperature=config.temperature, top_k=config.top_k_tokens, top_p=config.top_p, 
use_kv_cache=True 64 | ) 65 | elif mode == "beam": 66 | # less diversity but more robust 67 | y = model.generate_with_beamsearch( 68 | soup_sequence, face_in_idx, face_out_idx, sequencer, config.max_val_tokens, use_kv_cache=True, beam_width=6 69 | ) 70 | 71 | if y is None: 72 | continue 73 | 74 | gen_vertices, gen_faces = sequencer.decode(y[0], decoder) 75 | 76 | try: 77 | mesh = trimesh.Trimesh(vertices=gen_vertices, faces=gen_faces, process=False) 78 | if pnet.classifier_guided_filter(mesh, config.ft_category): 79 | mesh.export(output_dir_mesh / f"{k:06d}.obj") 80 | meshlab_proc(output_dir_mesh / f"{k:06d}.obj") 81 | plot_vertices_and_faces(gen_vertices, gen_faces, output_dir_image / f"{k:06d}.jpg") 82 | print('Generated mesh', k + 1) 83 | k = k + 1 84 | except Exception as e: 85 | print('Exception occured: ', e) 86 | pass # sometimes the mesh is invalid (ngon) and we don't want to crash 87 | 88 | 89 | if __name__ == "__main__": 90 | cfg = omegaconf.OmegaConf.load(Path(sys.argv[1]).parents[1] / "config.yaml") 91 | cfg.resume = sys.argv[1] 92 | cfg.padding = 0.0 93 | cfg.num_val_samples = int(sys.argv[3]) 94 | cfg.sequence_stride = cfg.block_size 95 | cfg.top_p = 0.95 96 | cfg.temperature = 1.0 97 | cfg.low_augment = True 98 | main(cfg, sys.argv[2]) 99 | -------------------------------------------------------------------------------- /model/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /model/README.md: -------------------------------------------------------------------------------- 1 | 2 | # nanoGPT 3 | 4 | ![nanoGPT](assets/nanogpt.jpg) 5 | 6 | The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training. The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it. 
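For reference, the optional GPT-2 weight loading mentioned above is a single classmethod call in upstream nanoGPT's `model.py` (a variant of `from_pretrained` also appears in `model/nanogpt.py` later in this repository). The sketch below assumes the upstream class and file layout; `dropout` is the only override the loader accepts:

```
# Minimal sketch, assuming upstream nanoGPT's model.py (not this repo's modified GPT
# in model/nanogpt.py, which adds extra embeddings on top of the GPT-2 layout).
from model import GPT

model = GPT.from_pretrained('gpt2', override_args=dict(dropout=0.0))
model.eval()  # ready for sampling, e.g. python sample.py --init_from=gpt2
```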
7 | 8 | ![repro124m](assets/gpt2_124M_loss.png) 9 | 10 | Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. biggest one currently available as a starting point would be the GPT-2 1.3B model from OpenAI). 11 | 12 | ## install 13 | 14 | ``` 15 | pip install torch numpy transformers datasets tiktoken wandb tqdm 16 | ``` 17 | 18 | Dependencies: 19 | 20 | - [pytorch](https://pytorch.org) <3 21 | - [numpy](https://numpy.org/install/) <3 22 | - `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints) 23 | - `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText) 24 | - `tiktoken` for OpenAI's fast BPE code <3 25 | - `wandb` for optional logging <3 26 | - `tqdm` for progress bars <3 27 | 28 | ## quick start 29 | 30 | If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers: 31 | 32 | ``` 33 | $ python data/shakespeare_char/prepare.py 34 | ``` 35 | 36 | This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT. The size of it very much depends on the computational resources of your system: 37 | 38 | **I have a GPU**. Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file: 39 | 40 | ``` 41 | $ python train.py config/train_shakespeare_char.py 42 | ``` 43 | 44 | If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory: 45 | 46 | ``` 47 | $ python sample.py --out_dir=out-shakespeare-char 48 | ``` 49 | 50 | This generates a few samples, for example: 51 | 52 | ``` 53 | ANGELO: 54 | And cowards it be strawn to my bed, 55 | And thrust the gates of my threats, 56 | Because he that ale away, and hang'd 57 | An one with him. 58 | 59 | DUKE VINCENTIO: 60 | I thank your eyes against it. 61 | 62 | DUKE VINCENTIO: 63 | Then will answer him to save the malm: 64 | And what have you tyrannous shall do this? 65 | 66 | DUKE VINCENTIO: 67 | If you have done evils of all disposition 68 | To end his power, the day of thrust for a common men 69 | That I leave, to fight with over-liking 70 | Hasting in a roseman. 71 | ``` 72 | 73 | lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later). 74 | 75 | **I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. 
But even without it, a simple train run could look as follows: 76 | 77 | ``` 78 | $ python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0 79 | ``` 80 | 81 | Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. Then when we evaluate we get a bit more noisy but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`). This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples, but it's still good fun: 82 | 83 | ``` 84 | $ python sample.py --out_dir=out-shakespeare-char --device=cpu 85 | ``` 86 | Generates samples like this: 87 | 88 | ``` 89 | GLEORKEN VINGHARD III: 90 | Whell's the couse, the came light gacks, 91 | And the for mought you in Aut fries the not high shee 92 | bot thou the sought bechive in that to doth groan you, 93 | No relving thee post mose the wear 94 | ``` 95 | 96 | Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc. 97 | 98 | Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device=mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more. 99 | 100 | ## reproducing GPT-2 101 | 102 | A more serious deep learning professional may be more interested in reproducing GPT-2 results. So here we go - we first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText: 103 | 104 | ``` 105 | $ python data/openwebtext/prepare.py 106 | ``` 107 | 108 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run: 109 | 110 | ``` 111 | $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 112 | ``` 113 | 114 | This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match. 115 | 116 | If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr e.g. 
across 2 nodes like: 117 | 118 | ``` 119 | Run on the first (master) node with example IP 123.456.123.456: 120 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 121 | Run on the worker node: 122 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 123 | ``` 124 | 125 | It is a good idea to benchmark your interconnect (e.g. iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply `$ python sample.py`. 126 | 127 | Finally, to train on a single GPU simply run the `$ python train.py` script. Have a look at all of its args, the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs. 128 | 129 | ## baselines 130 | 131 | OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows: 132 | 133 | ``` 134 | $ python train.py eval_gpt2 135 | $ python train.py eval_gpt2_medium 136 | $ python train.py eval_gpt2_large 137 | $ python train.py eval_gpt2_xl 138 | ``` 139 | 140 | and observe the following losses on train and val: 141 | 142 | | model | params | train loss | val loss | 143 | | ------| ------ | ---------- | -------- | 144 | | gpt2 | 124M | 3.11 | 3.12 | 145 | | gpt2-medium | 350M | 2.85 | 2.84 | 146 | | gpt2-large | 774M | 2.66 | 2.67 | 147 | | gpt2-xl | 1558M | 2.56 | 2.54 | 148 | 149 | However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while reaches loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction. 150 | 151 | ## finetuning 152 | 153 | Finetuning is no different than training, we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like: 154 | 155 | ``` 156 | $ python train.py config/finetune_shakespeare.py 157 | ``` 158 | 159 | This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory try decreasing the model size (they are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. 
You can then run the code in `sample.py --out_dir=out-shakespeare`: 160 | 161 | ``` 162 | THEODORE: 163 | Thou shalt sell me to the highest bidder: if I die, 164 | I sell thee to the first; if I go mad, 165 | I sell thee to the second; if I 166 | lie, I sell thee to the third; if I slay, 167 | I sell thee to the fourth: so buy or sell, 168 | I tell thee again, thou shalt not sell my 169 | possession. 170 | 171 | JULIET: 172 | And if thou steal, thou shalt not sell thyself. 173 | 174 | THEODORE: 175 | I do not steal; I sell the stolen goods. 176 | 177 | THEODORE: 178 | Thou know'st not what thou sell'st; thou, a woman, 179 | Thou art ever a victim, a thing of no worth: 180 | Thou hast no right, no right, but to be sold. 181 | ``` 182 | 183 | Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try! 184 | 185 | ## sampling / inference 186 | 187 | Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. For example, here is a way to sample from the largest available `gpt2-xl` model: 188 | 189 | ``` 190 | $ python sample.py \ 191 | --init_from=gpt2-xl \ 192 | --start="What is the answer to life, the universe, and everything?" \ 193 | --num_samples=5 --max_new_tokens=100 194 | ``` 195 | 196 | If you'd like to sample from a model you trained, use the `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. `$ python sample.py --start=FILE:prompt.txt`. 197 | 198 | ## efficiency notes 199 | 200 | For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities. 201 | 202 | Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to 135ms / iter. Nice work PyTorch team! 203 | 204 | ## todos 205 | 206 | - Investigate and add FSDP instead of DDP 207 | - Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.) 208 | - Finetune the finetuning script, I think the hyperparams are not great 209 | - Schedule for linear batch size increase during training 210 | - Incorporate other embeddings (rotary, alibi) 211 | - Separate out the optim buffers from model params in checkpoints I think 212 | - Additional logging around network health (e.g. gradient clip events, magnitudes) 213 | - Few more investigations around better init etc. 214 | 215 | ## troubleshooting 216 | 217 | Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages try to disable this by adding `--compile=False` flag. This will slow down the code but at least it will run. 218 | 219 | For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context. 
220 | 221 | For more questions/discussions feel free to stop by **#nanoGPT** on Discord: 222 | 223 | [![](https://dcbadge.vercel.app/api/server/3zy8kqD9Cp?compact=true&style=flat)](https://discord.gg/3zy8kqD9Cp) 224 | 225 | ## acknowledgements 226 | 227 | All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT! 228 | -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from util.positional_encoding import get_embedder 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | expansion: int = 1 8 | 9 | def __init__(self, inplanes, planes, base_width=64): 10 | super().__init__() 11 | norm_layer = nn.BatchNorm1d 12 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 13 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=3, padding=1) 14 | self.bn1 = norm_layer(planes) 15 | self.relu = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, padding=1) 17 | self.bn2 = norm_layer(planes) 18 | 19 | if inplanes != planes: 20 | self.identity_fn = nn.Sequential( 21 | nn.Conv1d(inplanes, planes * self.expansion, 1), 22 | norm_layer(planes * self.expansion), 23 | ) 24 | else: 25 | self.identity_fn = nn.Identity() 26 | 27 | def forward(self, x): 28 | identity = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | out += self.identity_fn(identity) 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 45 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 46 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 47 | # This variant is also known as ResNet V1.5 and improves accuracy according to 48 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
49 | 50 | expansion: int = 4 51 | 52 | def __init__(self, inplanes, planes, base_width=64): 53 | super().__init__() 54 | norm_layer = nn.BatchNorm1d 55 | width = int(planes * (base_width / 64.0)) 56 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 57 | self.conv1 = nn.Conv1d(inplanes, width, 1) 58 | self.bn1 = norm_layer(width) 59 | self.conv2 = nn.Conv1d(width, width, 3, padding=1) 60 | self.bn2 = norm_layer(width) 61 | self.conv3 = nn.Conv1d(width, planes * self.expansion, 1) 62 | self.bn3 = norm_layer(planes * self.expansion) 63 | self.relu = nn.ReLU(inplace=True) 64 | if inplanes != planes: 65 | self.identity_fn = nn.Sequential( 66 | nn.Conv1d(inplanes, planes * self.expansion, 1), 67 | norm_layer(planes * self.expansion), 68 | ) 69 | else: 70 | self.identity_fn = nn.Identity() 71 | 72 | def forward(self, x): 73 | identity = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | out += self.identity_fn(identity) 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class ResNetDecoder(nn.Module): 93 | 94 | def __init__(self, in_feats, num_tokens, block, layers, zero_init_residual=False, width_per_group=64, ce_output=True): 95 | super().__init__() 96 | norm_layer = nn.BatchNorm1d 97 | self._norm_layer = norm_layer 98 | self.num_tokens = num_tokens 99 | self.inplanes = 512 100 | self.base_width = width_per_group 101 | self.conv1 = nn.Conv1d(in_feats, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False) 102 | self.bn1 = norm_layer(self.inplanes) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.layer1 = self._make_layer(block, 512, layers[0]) 105 | self.layer2 = self._make_layer(block, 384, layers[1]) 106 | self.layer3 = self._make_layer(block, 384, layers[2]) 107 | self.layer4 = self._make_layer(block, 320, layers[3]) 108 | self.ce_output = ce_output 109 | 110 | if ce_output: 111 | self.fc = nn.Conv1d(320 * block.expansion, self.num_tokens * 9, 1) 112 | else: 113 | self.fc = nn.Conv1d(320 * block.expansion, 9, 1) 114 | 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv1d): 117 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 118 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): 119 | nn.init.constant_(m.weight, 1) 120 | nn.init.constant_(m.bias, 0) 121 | 122 | # Zero-initialize the last BN in each residual branch, 123 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
124 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 125 | if zero_init_residual: 126 | for m in self.modules(): 127 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 128 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 129 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 130 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 131 | 132 | def _make_layer(self, block, planes: int, blocks: int): 133 | layers = [block(self.inplanes, planes, self.base_width)] 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append( 137 | block( 138 | self.inplanes, 139 | planes, 140 | base_width=self.base_width, 141 | ) 142 | ) 143 | return nn.Sequential(*layers) 144 | 145 | def _forward_impl(self, x): 146 | B, _, N = x.shape 147 | x = self.conv1(x) 148 | x = self.bn1(x) 149 | x = self.relu(x) 150 | 151 | x = self.layer1(x) 152 | x = self.layer2(x) 153 | x = self.layer3(x) 154 | x = self.layer4(x) 155 | x = self.fc(x) 156 | if self.ce_output: 157 | x = x.permute((0, 2, 1)).reshape(B, N, 9, self.num_tokens) 158 | else: 159 | x = x.permute((0, 2, 1)).reshape(B, N, 9) 160 | return x 161 | 162 | def forward(self, x): 163 | return self._forward_impl(x) 164 | 165 | 166 | class ResNetEncoder(nn.Module): 167 | 168 | def __init__(self, in_feats, block, layers, zero_init_residual=False, width_per_group=64): 169 | super().__init__() 170 | norm_layer = nn.BatchNorm1d 171 | self._norm_layer = norm_layer 172 | self.inplanes = 128 173 | self.base_width = width_per_group 174 | self.conv1 = nn.Conv1d(in_feats, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False) 175 | self.bn1 = norm_layer(self.inplanes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.layer1 = self._make_layer(block, 128, layers[0]) 178 | self.layer2 = self._make_layer(block, 192, layers[1]) 179 | self.layer3 = self._make_layer(block, 256, layers[2]) 180 | self.layer4 = self._make_layer(block, 384, layers[3]) 181 | 182 | self.fc = nn.Conv1d(384 * block.expansion, 512, 1) 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv1d): 186 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 187 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): 188 | nn.init.constant_(m.weight, 1) 189 | nn.init.constant_(m.bias, 0) 190 | 191 | # Zero-initialize the last BN in each residual branch, 192 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
193 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 194 | if zero_init_residual: 195 | for m in self.modules(): 196 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 197 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 198 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 199 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 200 | 201 | def _make_layer(self, block, planes: int, blocks: int): 202 | layers = [block(self.inplanes, planes, self.base_width)] 203 | self.inplanes = planes * block.expansion 204 | for _ in range(1, blocks): 205 | layers.append( 206 | block( 207 | self.inplanes, 208 | planes, 209 | base_width=self.base_width, 210 | ) 211 | ) 212 | return nn.Sequential(*layers) 213 | 214 | def _forward_impl(self, x): 215 | x = self.conv1(x) 216 | x = self.bn1(x) 217 | x = self.relu(x) 218 | 219 | x = self.layer1(x) 220 | x = self.layer2(x) 221 | x = self.layer3(x) 222 | x = self.layer4(x) 223 | x = self.fc(x) 224 | B, C, N = x.shape 225 | x = x.permute((0, 2, 1)).reshape(-1, C) 226 | return x 227 | 228 | def forward(self, x): 229 | return self._forward_impl(x) 230 | 231 | 232 | def resnet18_decoder(in_feats, num_tokens, ce_output=True): 233 | return ResNetDecoder(in_feats, num_tokens, BasicBlock, [2, 2, 2, 2], zero_init_residual=True, ce_output=ce_output) 234 | 235 | 236 | def resnet34_decoder(in_feats, num_tokens, ce_output=True): 237 | return ResNetDecoder(in_feats, num_tokens, BasicBlock, [3, 4, 6, 3], zero_init_residual=True, ce_output=ce_output) 238 | 239 | 240 | def resnet18_encoder(in_feats): 241 | return ResNetEncoder(in_feats, BasicBlock, [2, 2, 2, 2], zero_init_residual=True) 242 | 243 | 244 | def resnet34_encoder(in_feats): 245 | return ResNetEncoder(in_feats, BasicBlock, [3, 4, 6, 3], zero_init_residual=True) 246 | 247 | 248 | def test_resnet_decoder(): 249 | import torch 250 | model = resnet18_decoder(512, 65) 251 | print(model) 252 | print(sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6) 253 | 254 | model = model.cuda() 255 | x_ = torch.rand(1, 512, 2048).cuda() 256 | y_, c_ = model(x_) 257 | print(y_.shape, c_.shape) 258 | 259 | 260 | def test_resnet_encoder(): 261 | import torch 262 | model = resnet18_encoder(70) 263 | print(model) 264 | print(sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6) 265 | 266 | model = model.cuda() 267 | x_ = torch.rand(1, 70, 2048).cuda() 268 | y_ = model(x_) 269 | print(y_.shape) 270 | 271 | 272 | if __name__ == "__main__": 273 | test_resnet_encoder() 274 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric 4 | import torch_scatter 5 | 6 | from util.positional_encoding import get_embedder 7 | 8 | 9 | class GraphEncoder(nn.Module): 10 | 11 | def __init__(self, no_max_pool=True, aggr='mean', graph_conv="edge", use_point_features=False, output_dim=512): 12 | super().__init__() 13 | self.no_max_pool = no_max_pool 14 | self.use_point_features = use_point_features 15 | self.embedder, self.embed_dim = get_embedder(10) 16 | self.conv = graph_conv 17 | self.gc1 = get_conv(self.conv, self.embed_dim * 3 + 7, 64, aggr=aggr) 18 | self.gc2 = get_conv(self.conv, 64, 128, aggr=aggr) 19 | self.gc3 = get_conv(self.conv, 128, 256, aggr=aggr) 20 | self.gc4 = get_conv(self.conv, 256, 256, aggr=aggr) 21 | self.gc5 = get_conv(self.conv, 256, 
output_dim, aggr=aggr) 22 | 23 | self.norm1 = torch_geometric.nn.BatchNorm(64) 24 | self.norm2 = torch_geometric.nn.BatchNorm(128) 25 | self.norm3 = torch_geometric.nn.BatchNorm(256) 26 | self.norm4 = torch_geometric.nn.BatchNorm(256) 27 | 28 | self.relu = nn.ReLU() 29 | 30 | def forward(self, x, edge_index, batch): 31 | x_0 = self.embedder(x[:, :3]) 32 | x_1 = self.embedder(x[:, 3:6]) 33 | x_2 = self.embedder(x[:, 6:9]) 34 | x_n = x[:, 9:12] 35 | x_ar = x[:, 12:13] 36 | x_an_0 = x[:, 13:14] 37 | x_an_1 = x[:, 14:15] 38 | x_an_2 = x[:, 15:] 39 | x = torch.cat([x_0, x_1, x_2, x_n, x_ar, x_an_0, x_an_1, x_an_2], dim=-1) 40 | x = self.relu(self.norm1(self.gc1(x, edge_index))) 41 | x = self.norm2(self.gc2(x, edge_index)) 42 | point_features = x 43 | x = self.relu(x) 44 | x = self.relu(self.norm3(self.gc3(x, edge_index))) 45 | x = self.relu(self.norm4(self.gc4(x, edge_index))) 46 | x = self.gc5(x, edge_index) 47 | if not self.no_max_pool: 48 | x = torch_scatter.scatter_max(x, batch, dim=0)[0] 49 | x = x[batch, :] 50 | if self.use_point_features: 51 | return torch.cat([x, point_features], dim=-1) 52 | return x 53 | 54 | 55 | class GraphEncoderTriangleSoup(nn.Module): 56 | 57 | def __init__(self, aggr='mean', graph_conv="edge"): 58 | super().__init__() 59 | self.embedder, self.embed_dim = get_embedder(10) 60 | self.conv = graph_conv 61 | self.gc1 = get_conv(self.conv, self.embed_dim * 3 + 7, 96, aggr=aggr) 62 | self.gc2 = get_conv(self.conv, 96, 192, aggr=aggr) 63 | self.gc3 = get_conv(self.conv, 192, 384, aggr=aggr) 64 | self.gc4 = get_conv(self.conv, 384, 384, aggr=aggr) 65 | self.gc5 = get_conv(self.conv, 384, 576, aggr=aggr) 66 | 67 | self.norm1 = torch_geometric.nn.BatchNorm(96) 68 | self.norm2 = torch_geometric.nn.BatchNorm(192) 69 | self.norm3 = torch_geometric.nn.BatchNorm(384) 70 | self.norm4 = torch_geometric.nn.BatchNorm(384) 71 | 72 | self.relu = nn.ReLU() 73 | 74 | @staticmethod 75 | def distribute_features(features, face_indices, num_vertices): 76 | N, F = features.shape 77 | features = features.reshape(N * 3, F // 3) 78 | assert features.shape[0] == face_indices.shape[0] * face_indices.shape[1], "Features and face indices must match in size" 79 | vertex_features = torch.zeros([num_vertices, features.shape[1]], device=features.device) 80 | torch_scatter.scatter_mean(features, face_indices.reshape(-1), out=vertex_features, dim=0) 81 | distributed_features = vertex_features[face_indices.reshape(-1), :] 82 | distributed_features = distributed_features.reshape(N, F) 83 | return distributed_features 84 | 85 | def forward(self, x, edge_index, faces, num_vertices): 86 | x_0 = self.embedder(x[:, :3]) 87 | x_1 = self.embedder(x[:, 3:6]) 88 | x_2 = self.embedder(x[:, 6:9]) 89 | x = torch.cat([x_0, x_1, x_2, x[:, 9:]], dim=-1) 90 | x = self.relu(self.norm1(self.gc1(x, edge_index))) 91 | x = self.distribute_features(x, faces, num_vertices) 92 | x = self.relu(self.norm2(self.gc2(x, edge_index))) 93 | x = self.distribute_features(x, faces, num_vertices) 94 | x = self.relu(self.norm3(self.gc3(x, edge_index))) 95 | x = self.distribute_features(x, faces, num_vertices) 96 | x = self.relu(self.norm4(self.gc4(x, edge_index))) 97 | x = self.distribute_features(x, faces, num_vertices) 98 | x = self.gc5(x, edge_index) 99 | x = self.distribute_features(x, faces, num_vertices) 100 | return x 101 | 102 | 103 | def get_conv(conv, in_dim, out_dim, aggr): 104 | if conv == 'sage': 105 | return torch_geometric.nn.SAGEConv(in_dim, out_dim, aggr=aggr) 106 | elif conv == 'gat': 107 | return 
torch_geometric.nn.GATv2Conv(in_dim, out_dim, fill_value=aggr) 108 | elif conv == 'edge': 109 | return torch_geometric.nn.EdgeConv( 110 | torch.nn.Sequential( 111 | torch.nn.Linear(in_dim * 2, 2 * out_dim), 112 | torch.nn.ReLU(), 113 | torch.nn.Linear(2 * out_dim, out_dim), 114 | ), 115 | aggr=aggr, 116 | ) 117 | -------------------------------------------------------------------------------- /model/nanogpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | from util.misc import top_p_sampling 19 | 20 | 21 | class LayerNorm(nn.Module): 22 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 23 | 24 | def __init__(self, ndim, bias): 25 | super().__init__() 26 | self.weight = nn.Parameter(torch.ones(ndim)) 27 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 28 | 29 | def forward(self, input): 30 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 31 | 32 | 33 | class AttentionBase(nn.Module): 34 | 35 | def __init__(self, config): 36 | super().__init__() 37 | assert config.n_embd % config.n_head == 0 38 | # output projection 39 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 40 | # regularization 41 | self.attn_dropout = nn.Dropout(config.dropout) 42 | self.resid_dropout = nn.Dropout(config.dropout) 43 | self.n_head = config.n_head 44 | self.n_embd = config.n_embd 45 | self.dropout = config.dropout 46 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 47 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 48 | if not self.flash: 49 | print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 50 | # causal mask to ensure that attention is only applied to the left in the input sequence 51 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 52 | .view(1, 1, config.block_size, config.block_size)) 53 | 54 | 55 | class CausalSelfAttention(AttentionBase): 56 | 57 | def __init__(self, config): 58 | super().__init__(config) 59 | # key, query, value projections for all heads, but in a batch 60 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 61 | 62 | def forward(self, x, kv_cache=None, mask=None): 63 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 64 | 65 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 66 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 67 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 68 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 69 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 70 | 71 | if kv_cache is not None: 72 | key_cached, value_cached = kv_cache.unbind(dim=0) # (2, B, T, head_size) -> 2 * (B, T, head_size) 73 | k = torch.cat((key_cached, k), dim=-2) # (B, cache + T, head_size) 74 | v = torch.cat((value_cached, v), dim=-2) # (B, cache + T, head_size) 75 | 76 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 77 | if self.flash: 78 | # efficient attention using Flash Attention CUDA kernels 79 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0, is_causal=mask is None) 80 | else: 81 | # manual implementation of attention 82 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 83 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) 84 | att = F.softmax(att, dim=-1) 85 | att = self.attn_dropout(att) 86 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 87 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 88 | 89 | # output projection 90 | y = self.resid_dropout(self.c_proj(y)) 91 | return y, None if kv_cache is None else torch.stack((k, v)) 92 | 93 | 94 | class CrossAttention(AttentionBase): 95 | 96 | def __init__(self, config): 97 | super().__init__(config) 98 | self.ckv_attn = nn.Linear(config.n_embd, 2 * config.n_embd, bias=config.bias) 99 | self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 100 | 101 | def forward(self, x, encoding): 102 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)4 103 | 104 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 105 | k, v = self.ckv_attn(encoding).split(self.n_embd, dim=2) 106 | q = self.cq_attn(x) 107 | 108 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 109 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 110 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 111 | 112 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 113 | if self.flash: 114 | # efficient attention using Flash Attention CUDA kernels 115 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 116 | else: 117 | # manual 
implementation of attention 118 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 119 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) 120 | att = F.softmax(att, dim=-1) 121 | att = self.attn_dropout(att) 122 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 123 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 124 | 125 | # output projection 126 | y = self.resid_dropout(self.c_proj(y)) 127 | return y 128 | 129 | 130 | class MLP(nn.Module): 131 | 132 | def __init__(self, config): 133 | super().__init__() 134 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 135 | self.gelu = nn.GELU() 136 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 137 | self.dropout = nn.Dropout(config.dropout) 138 | 139 | def forward(self, x): 140 | x = self.c_fc(x) 141 | x = self.gelu(x) 142 | x = self.c_proj(x) 143 | x = self.dropout(x) 144 | return x 145 | 146 | 147 | class Block(nn.Module): 148 | 149 | def __init__(self, config): 150 | super().__init__() 151 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 152 | self.attn = CausalSelfAttention(config) 153 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 154 | self.mlp = MLP(config) 155 | 156 | def forward(self, x, kv_cache=None, mask=None): 157 | out, kv_cache = self.attn(self.ln_1(x), kv_cache, mask) 158 | x = x + out 159 | x = x + self.mlp(self.ln_2(x)) 160 | return x, kv_cache 161 | 162 | 163 | class BlockWithCrossAttention(nn.Module): 164 | 165 | def __init__(self, config): 166 | super().__init__() 167 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 168 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 169 | self.ln_3 = LayerNorm(config.n_embd, bias=config.bias) 170 | self.attn = CausalSelfAttention(config) 171 | self.xattn = CrossAttention(config) 172 | self.mlp = MLP(config) 173 | 174 | def forward(self, x, encoding): 175 | x = x + self.attn(self.ln_1(x))[0] # CausalSelfAttention returns (out, kv_cache); keep only the output 176 | x = x + self.xattn(self.ln_2(x), encoding) 177 | x = x + self.mlp(self.ln_3(x)) 178 | return x 179 | 180 | 181 | @dataclass 182 | class GPTConfig: 183 | block_size: int = 1024 184 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 185 | n_layer: int = 12 186 | n_head: int = 12 187 | n_embd: int = 768 188 | dropout: float = 0.0 189 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 190 | 191 | 192 | class GPT(nn.Module): 193 | 194 | def __init__(self, config): 195 | super().__init__() 196 | assert config.vocab_size is not None 197 | assert config.block_size is not None 198 | self.config = config 199 | self.padding_idx = config.vocab_size - 1 200 | self.vocab_size = config.vocab_size 201 | print('Model Padding Index:', self.padding_idx) 202 | self.transformer = nn.ModuleDict(dict( 203 | wte=nn.Embedding(config.vocab_size, config.n_embd, padding_idx=self.padding_idx), 204 | wpe=nn.Embedding(config.block_size, config.n_embd), 205 | wce=nn.Embedding(3, config.n_embd), 206 | drop=nn.Dropout(config.dropout), 207 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 208 | ln_f=LayerNorm(config.n_embd, bias=config.bias), 209 | )) 210 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 211 | # with weight tying when using torch.compile() some warnings get generated: 212 | # "UserWarning: functional_call was passed multiple values for tied weights.
213 | # This behavior is deprecated and will be an error in future versions" 214 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 215 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 216 | 217 | # init all weights 218 | self.apply(self._init_weights) 219 | # apply special scaled init to the residual projections, per GPT-2 paper 220 | for pn, p in self.named_parameters(): 221 | if pn.endswith('c_proj.weight'): 222 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 223 | 224 | # report number of parameters 225 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) 226 | 227 | def get_num_params(self, non_embedding=True): 228 | """ 229 | Return the number of parameters in the model. 230 | For non-embedding count (default), the position embeddings get subtracted. 231 | The token embeddings would too, except due to the parameter sharing these 232 | params are actually used as weights in the final layer, so we include them. 233 | """ 234 | n_params = sum(p.numel() for p in self.parameters()) 235 | if non_embedding: 236 | n_params -= self.transformer.wpe.weight.numel() 237 | return n_params 238 | 239 | def _init_weights(self, module): 240 | if isinstance(module, nn.Linear): 241 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 242 | if module.bias is not None: 243 | torch.nn.init.zeros_(module.bias) 244 | elif isinstance(module, nn.Embedding): 245 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 246 | 247 | def forward(self, idx, coord, targets=None, kv_cache=None, mask_cache=None): 248 | use_kv_cache = kv_cache is not None 249 | device = idx.device 250 | b, t = idx.size() 251 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 252 | 253 | if kv_cache is not None and kv_cache[0].numel(): 254 | pos = kv_cache[0].shape[-2] # kv_cache of shape: num_layers * (2, B, nh, T, hs) 255 | pos_emb = self.transformer.wpe.weight[None, pos // 3] # 1 x n_embd 256 | mask = mask_cache.index_select(2, torch.LongTensor([pos]).to(pos_emb.device))[:, :, :, :pos + 1] 257 | else: 258 | pos = torch.tensor([i // 3 for i in range(t)], dtype=torch.long, device=device) 259 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 260 | mask = None 261 | 262 | # print('embs:', idx.min(), idx.max(), pos.min(), pos.max(), coord.min(), coord.max()) 263 | # forward the GPT model itself 264 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 265 | coord_emb = self.transformer.wce(coord) # position embeddings of shape (t, n_embd) 266 | # print('shapes:', tok_emb.shape, pos_emb.shape, coord_emb.shape) 267 | x = self.transformer.drop(tok_emb + pos_emb + coord_emb) 268 | 269 | # apply multiple transformer blocks 270 | new_kv_cache = [] 271 | kv_cache = kv_cache or [None] * self.config.n_layer 272 | 273 | for block, kv_cache_layer in zip(self.transformer.h, kv_cache): 274 | x, new_kv = block(x, kv_cache_layer, mask) 275 | new_kv_cache.append(new_kv) 276 | 277 | x = self.transformer.ln_f(x) 278 | 279 | if targets is not None: 280 | # if we are given some desired targets also calculate the loss 281 | logits = self.lm_head(x) 282 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=self.padding_idx) 283 | else: 284 | # inference-time mini-optimization: only forward the lm_head on the very last position 285 | logits = self.lm_head(x[:, 
[-1], :]) # note: using list [-1] to preserve the time dim 286 | loss = None 287 | 288 | if not use_kv_cache: 289 | return logits, loss 290 | else: 291 | return logits, new_kv_cache 292 | 293 | def crop_block_size(self, block_size): 294 | # model surgery to decrease the block size if necessary 295 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 296 | # but want to use a smaller block size for some smaller, simpler model 297 | assert block_size <= self.config.block_size 298 | self.config.block_size = block_size 299 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 300 | for block in self.transformer.h: 301 | if hasattr(block.attn, 'bias'): 302 | block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] 303 | 304 | @classmethod 305 | def from_pretrained(cls, model_type, override_args=None): 306 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 307 | override_args = override_args or {} # default to empty dict 308 | # only dropout can be overridden see more notes below 309 | assert all(k == 'dropout' for k in override_args) 310 | from transformers import GPT2LMHeadModel 311 | print("loading weights from pretrained gpt: %s" % model_type) 312 | 313 | # n_layer, n_head and n_embd are determined from model_type 314 | config_args = { 315 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 316 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 317 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 318 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 319 | }[model_type] 320 | print("forcing vocab_size=50257, block_size=1024, bias=True") 321 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 322 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 323 | config_args['bias'] = True # always True for GPT model checkpoints 324 | # we can override the dropout rate, if desired 325 | if 'dropout' in override_args: 326 | print(f"overriding dropout rate to {override_args['dropout']}") 327 | config_args['dropout'] = override_args['dropout'] 328 | # create a from-scratch initialized minGPT model 329 | config = GPTConfig(**config_args) 330 | model = GPT(config) 331 | sd = model.state_dict() 332 | sd_keys = sd.keys() 333 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 334 | 335 | # init a huggingface/transformers model 336 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 337 | sd_hf = model_hf.state_dict() 338 | 339 | # copy while ensuring all of the parameters are aligned and match in names and shapes 340 | sd_keys_hf = sd_hf.keys() 341 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 342 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 343 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 344 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 345 | # this means that we have to transpose these weights when we import them 346 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 347 | for k in sd_keys_hf: 348 | if any(k.endswith(w) for w in transposed): 349 | # special treatment for the Conv1D weights we need to transpose 350 | assert sd_hf[k].shape[::-1] == sd[k].shape 351 | with 
torch.no_grad(): 352 | sd[k].copy_(sd_hf[k].t()) 353 | else: 354 | # vanilla copy over the other parameters 355 | assert sd_hf[k].shape == sd[k].shape 356 | with torch.no_grad(): 357 | sd[k].copy_(sd_hf[k]) 358 | 359 | return model 360 | 361 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 362 | return configure_optimizers(self.named_parameters(), weight_decay, learning_rate, betas, device_type) 363 | 364 | def estimate_mfu(self, fwdbwd_per_iter, dt): 365 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 366 | # first estimate the number of flops we do per iteration. 367 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 368 | N = self.get_num_params() 369 | cfg = self.config 370 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size 371 | flops_per_token = 6 * N + 12 * L * H * Q * T 372 | flops_per_fwdbwd = flops_per_token * T 373 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 374 | # express our flops throughput as ratio of A100 bfloat16 peak flops 375 | flops_achieved = flops_per_iter * (1.0 / dt) # per second 376 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 377 | mfu = flops_achieved / flops_promised 378 | return mfu 379 | 380 | @torch.no_grad() 381 | def generate(self, idx, coord, max_new_tokens, temperature=1.0, top_k=None, top_p=0.9, use_kv_cache=False): 382 | 383 | if use_kv_cache and (max_new_tokens + idx.shape[-1] - 1) > self.config.block_size: 384 | # print(f"Cannot generate more than {self.config.block_size} tokens with kv cache, setting max new tokens to {self.config.block_size - idx.shape[-1]}") 385 | max_new_tokens = self.config.block_size - idx.shape[-1] 386 | 387 | kv_cache = ( 388 | [torch.empty(2, 0, device=idx.device, dtype=idx.dtype) for _ in range(self.config.n_layer)] 389 | if use_kv_cache 390 | else None 391 | ) 392 | mask_cache = None 393 | if use_kv_cache: 394 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) 395 | mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0) 396 | 397 | current_coord = coord 398 | one_t = torch.LongTensor([1]).to(coord.device) 399 | for iteration in range(max_new_tokens): 400 | 401 | if not use_kv_cache or (iteration == 0 and idx.shape[-1] > 1): 402 | # if the sequence context is growing too long we must crop it at block_size 403 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 404 | coord_cond = current_coord if current_coord.size(0) <= self.config.block_size else current_coord[-self.config.block_size:] 405 | else: 406 | idx_cond = idx[:, -1:] 407 | coord_cond = current_coord[-1:] 408 | # forward the model to get the logits for the index in the sequence 409 | logits, kv_cache = self(idx_cond, coord_cond, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache) 410 | # pluck the logits at the final step and scale by desired temperature 411 | logits = logits[:, -1, :] / temperature 412 | # optionally crop the logits to only the top k options 413 | if top_k is not None: 414 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 415 | logits[logits < v[:, [-1]]] = -float('Inf') 416 | # apply softmax to convert logits to (normalized) probabilities 417 | 418 | use_hard_constraints = False 419 | if use_hard_constraints: 420 | # Use hard constraints 421 | # stop can only occur when 3 coords are generated 422 | if idx.shape[1] % 3 != 0: 423 | logits[:, self.vocab_size - 2] = 
-float('Inf') 424 | 425 | # all z less than last z are not allowed 426 | if idx.shape[1] % 3 == 0: 427 | if idx.shape[1] > 2: 428 | last_z = idx[0, -3] 429 | if last_z > 0: 430 | logits[:, :last_z - 1] = -float('Inf') 431 | if idx.shape[1] % 3 == 1: 432 | if idx.shape[1] > 3: 433 | last_z = idx[0, -1] 434 | last_to_last_z = idx[0, -4] 435 | last_y = idx[0, -3] 436 | if last_z == last_to_last_z: 437 | if last_y > 0: 438 | logits[:, :last_y - 1] = -float('Inf') 439 | if idx.shape[1] % 3 == 2: 440 | if idx.shape[1] > 4: 441 | last_z = idx[0, -2] 442 | last_to_last_z = idx[0, -5] 443 | last_y = idx[0, -1] 444 | last_to_last_y = idx[0, -4] 445 | last_x = idx[0, -3] 446 | if last_z == last_to_last_z and last_to_last_y == last_y: 447 | if last_x > 0: 448 | logits[:, :last_x - 1] = -float('Inf') 449 | 450 | # sample from the distribution 451 | if top_p is not None: 452 | idx_next = top_p_sampling(logits, top_p) 453 | else: 454 | # apply softmax to convert logits to (normalized) probabilities 455 | probs = F.softmax(logits, dim=-1) 456 | idx_next = torch.multinomial(probs, num_samples=1) 457 | 458 | # append sampled index to the running sequence and continue 459 | idx = torch.cat((idx, idx_next), dim=1) 460 | current_coord = torch.cat((current_coord, one_t * (current_coord[-1] + 1) % 3), dim=0) 461 | if idx[0, -1] == (self.vocab_size - 2): 462 | break 463 | return idx 464 | 465 | 466 | def configure_optimizers(named_parameters, weight_decay, learning_rate, betas, device_type, additional_params=None): 467 | # start with all of the candidate parameters 468 | param_dict = {pn: p for pn, p in named_parameters} 469 | # filter out those that do not require grad 470 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 471 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 472 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
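# Illustrative sketch (annotation added for clarity, not part of the original source): for a plain nn.Linear(8, 8), the (8, 8) `weight` has p.dim() == 2 and therefore lands in the decay group built below, while the (8,) `bias` and any LayerNorm gain/offset parameters have p.dim() == 1 and end up in the no-decay group.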
473 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 474 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 475 | if additional_params is not None: 476 | if isinstance(additional_params, dict): 477 | for n, p in additional_params.items(): # bugfix: iterating a bare dict yields only keys; .items() yields (name, param) pairs 478 | nodecay_params.append(p) 479 | else: 480 | for additional_param in additional_params: 481 | for n, p in additional_param: 482 | nodecay_params.append(p) 483 | optim_groups = [ 484 | {'params': decay_params, 'weight_decay': weight_decay}, 485 | {'params': nodecay_params, 'weight_decay': 0.0} 486 | ] 487 | num_decay_params = sum(p.numel() for p in decay_params) 488 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 489 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 490 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 491 | # Create AdamW optimizer and use the fused version if it is available 492 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 493 | use_fused = fused_available and device_type == 'cuda' 494 | extra_args = dict(fused=True) if use_fused else dict() 495 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 496 | print(f"using fused AdamW: {use_fused}") 497 | 498 | return optimizer 499 | -------------------------------------------------------------------------------- /model/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from trimesh import transformations 3 | import math 4 | import numpy as np 5 | from torch import nn 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from time import time 11 | import numpy as np 12 | 13 | 14 | def timeit(tag, t): 15 | print("{}: {}s".format(tag, time() - t)) 16 | return time() 17 | 18 | def pc_normalize(pc): 19 | l = pc.shape[0] 20 | centroid = np.mean(pc, axis=0) 21 | pc = pc - centroid 22 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 23 | pc = pc / m 24 | return pc 25 | 26 | def square_distance(src, dst): 27 | """ 28 | Calculate the squared Euclidean distance between each pair of points.
29 | 30 | src^T * dst = xn * xm + yn * ym + zn * zm; 31 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 32 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 33 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 34 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 35 | 36 | Input: 37 | src: source points, [B, N, C] 38 | dst: target points, [B, M, C] 39 | Output: 40 | dist: per-point square distance, [B, N, M] 41 | """ 42 | B, N, _ = src.shape 43 | _, M, _ = dst.shape 44 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 45 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 46 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 47 | return dist 48 | 49 | 50 | def index_points(points, idx): 51 | """ 52 | 53 | Input: 54 | points: input points data, [B, N, C] 55 | idx: sample index data, [B, S] 56 | Return: 57 | new_points:, indexed points data, [B, S, C] 58 | """ 59 | device = points.device 60 | B = points.shape[0] 61 | view_shape = list(idx.shape) 62 | view_shape[1:] = [1] * (len(view_shape) - 1) 63 | repeat_shape = list(idx.shape) 64 | repeat_shape[0] = 1 65 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 66 | new_points = points[batch_indices, idx, :] 67 | return new_points 68 | 69 | 70 | def farthest_point_sample(xyz, npoint): 71 | """ 72 | Input: 73 | xyz: pointcloud data, [B, N, 3] 74 | npoint: number of samples 75 | Return: 76 | centroids: sampled pointcloud index, [B, npoint] 77 | """ 78 | device = xyz.device 79 | B, N, C = xyz.shape 80 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 81 | distance = torch.ones(B, N).to(device) * 1e10 82 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 83 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 84 | for i in range(npoint): 85 | centroids[:, i] = farthest 86 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 87 | dist = torch.sum((xyz - centroid) ** 2, -1) 88 | mask = dist < distance 89 | distance[mask] = dist[mask] 90 | farthest = torch.max(distance, -1)[1] 91 | return centroids 92 | 93 | 94 | def query_ball_point(radius, nsample, xyz, new_xyz): 95 | """ 96 | Input: 97 | radius: local region radius 98 | nsample: max sample number in local region 99 | xyz: all points, [B, N, 3] 100 | new_xyz: query points, [B, S, 3] 101 | Return: 102 | group_idx: grouped points index, [B, S, nsample] 103 | """ 104 | device = xyz.device 105 | B, N, C = xyz.shape 106 | _, S, _ = new_xyz.shape 107 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 108 | sqrdists = square_distance(new_xyz, xyz) 109 | group_idx[sqrdists > radius ** 2] = N 110 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 111 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 112 | mask = group_idx == N 113 | group_idx[mask] = group_first[mask] 114 | return group_idx 115 | 116 | 117 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 118 | """ 119 | Input: 120 | npoint: 121 | radius: 122 | nsample: 123 | xyz: input points position data, [B, N, 3] 124 | points: input points data, [B, N, D] 125 | Return: 126 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 127 | new_points: sampled points data, [B, npoint, nsample, 3+D] 128 | """ 129 | B, N, C = xyz.shape 130 | S = npoint 131 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 132 | new_xyz = index_points(xyz, fps_idx) 133 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 134 | grouped_xyz = index_points(xyz, idx) # [B, 
npoint, nsample, C] 135 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 136 | 137 | if points is not None: 138 | grouped_points = index_points(points, idx) 139 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 140 | else: 141 | new_points = grouped_xyz_norm 142 | if returnfps: 143 | return new_xyz, new_points, grouped_xyz, fps_idx 144 | else: 145 | return new_xyz, new_points 146 | 147 | 148 | def sample_and_group_all(xyz, points): 149 | """ 150 | Input: 151 | xyz: input points position data, [B, N, 3] 152 | points: input points data, [B, N, D] 153 | Return: 154 | new_xyz: sampled points position data, [B, 1, 3] 155 | new_points: sampled points data, [B, 1, N, 3+D] 156 | """ 157 | device = xyz.device 158 | B, N, C = xyz.shape 159 | new_xyz = torch.zeros(B, 1, C).to(device) 160 | grouped_xyz = xyz.view(B, 1, N, C) 161 | if points is not None: 162 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 163 | else: 164 | new_points = grouped_xyz 165 | return new_xyz, new_points 166 | 167 | 168 | class PointNetSetAbstraction(nn.Module): 169 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 170 | super(PointNetSetAbstraction, self).__init__() 171 | self.npoint = npoint 172 | self.radius = radius 173 | self.nsample = nsample 174 | self.mlp_convs = nn.ModuleList() 175 | self.mlp_bns = nn.ModuleList() 176 | last_channel = in_channel 177 | for out_channel in mlp: 178 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 179 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 180 | last_channel = out_channel 181 | self.group_all = group_all 182 | 183 | def forward(self, xyz, points): 184 | """ 185 | Input: 186 | xyz: input points position data, [B, C, N] 187 | points: input points data, [B, D, N] 188 | Return: 189 | new_xyz: sampled points position data, [B, C, S] 190 | new_points_concat: sample points feature data, [B, D', S] 191 | """ 192 | xyz = xyz.permute(0, 2, 1) 193 | if points is not None: 194 | points = points.permute(0, 2, 1) 195 | 196 | if self.group_all: 197 | new_xyz, new_points = sample_and_group_all(xyz, points) 198 | else: 199 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 200 | # new_xyz: sampled points position data, [B, npoint, C] 201 | # new_points: sampled points data, [B, npoint, nsample, C+D] 202 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 203 | for i, conv in enumerate(self.mlp_convs): 204 | bn = self.mlp_bns[i] 205 | new_points = F.relu(bn(conv(new_points))) 206 | 207 | new_points = torch.max(new_points, 2)[0] 208 | new_xyz = new_xyz.permute(0, 2, 1) 209 | return new_xyz, new_points 210 | 211 | 212 | class PointNetSetAbstractionMsg(nn.Module): 213 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 214 | super(PointNetSetAbstractionMsg, self).__init__() 215 | self.npoint = npoint 216 | self.radius_list = radius_list 217 | self.nsample_list = nsample_list 218 | self.conv_blocks = nn.ModuleList() 219 | self.bn_blocks = nn.ModuleList() 220 | for i in range(len(mlp_list)): 221 | convs = nn.ModuleList() 222 | bns = nn.ModuleList() 223 | last_channel = in_channel + 3 224 | for out_channel in mlp_list[i]: 225 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 226 | bns.append(nn.BatchNorm2d(out_channel)) 227 | last_channel = out_channel 228 | self.conv_blocks.append(convs) 229 | self.bn_blocks.append(bns) 230 | 231 | def forward(self, xyz, points): 232 | """ 233 | 
Input: 234 | xyz: input points position data, [B, C, N] 235 | points: input points data, [B, D, N] 236 | Return: 237 | new_xyz: sampled points position data, [B, C, S] 238 | new_points_concat: sample points feature data, [B, D', S] 239 | """ 240 | xyz = xyz.permute(0, 2, 1) 241 | if points is not None: 242 | points = points.permute(0, 2, 1) 243 | 244 | B, N, C = xyz.shape 245 | S = self.npoint 246 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 247 | new_points_list = [] 248 | for i, radius in enumerate(self.radius_list): 249 | K = self.nsample_list[i] 250 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 251 | grouped_xyz = index_points(xyz, group_idx) 252 | grouped_xyz -= new_xyz.view(B, S, 1, C) 253 | if points is not None: 254 | grouped_points = index_points(points, group_idx) 255 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 256 | else: 257 | grouped_points = grouped_xyz 258 | 259 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 260 | for j in range(len(self.conv_blocks[i])): 261 | conv = self.conv_blocks[i][j] 262 | bn = self.bn_blocks[i][j] 263 | grouped_points = F.relu(bn(conv(grouped_points))) 264 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 265 | new_points_list.append(new_points) 266 | 267 | new_xyz = new_xyz.permute(0, 2, 1) 268 | new_points_concat = torch.cat(new_points_list, dim=1) 269 | return new_xyz, new_points_concat 270 | 271 | 272 | class PointNetFeaturePropagation(nn.Module): 273 | def __init__(self, in_channel, mlp): 274 | super(PointNetFeaturePropagation, self).__init__() 275 | self.mlp_convs = nn.ModuleList() 276 | self.mlp_bns = nn.ModuleList() 277 | last_channel = in_channel 278 | for out_channel in mlp: 279 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 280 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 281 | last_channel = out_channel 282 | 283 | def forward(self, xyz1, xyz2, points1, points2): 284 | """ 285 | Input: 286 | xyz1: input points position data, [B, C, N] 287 | xyz2: sampled input points position data, [B, C, S] 288 | points1: input points data, [B, D, N] 289 | points2: input points data, [B, D, S] 290 | Return: 291 | new_points: upsampled points data, [B, D', N] 292 | """ 293 | xyz1 = xyz1.permute(0, 2, 1) 294 | xyz2 = xyz2.permute(0, 2, 1) 295 | 296 | points2 = points2.permute(0, 2, 1) 297 | B, N, C = xyz1.shape 298 | _, S, _ = xyz2.shape 299 | 300 | if S == 1: 301 | interpolated_points = points2.repeat(1, N, 1) 302 | else: 303 | dists = square_distance(xyz1, xyz2) 304 | dists, idx = dists.sort(dim=-1) 305 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 306 | 307 | dist_recip = 1.0 / (dists + 1e-8) 308 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 309 | weight = dist_recip / norm 310 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 311 | 312 | if points1 is not None: 313 | points1 = points1.permute(0, 2, 1) 314 | new_points = torch.cat([points1, interpolated_points], dim=-1) 315 | else: 316 | new_points = interpolated_points 317 | 318 | new_points = new_points.permute(0, 2, 1) 319 | for i, conv in enumerate(self.mlp_convs): 320 | bn = self.mlp_bns[i] 321 | new_points = F.relu(bn(conv(new_points))) 322 | return new_points 323 | 324 | 325 | class PointNet(nn.Module): 326 | 327 | def __init__(self, num_class, normal_channel=True): 328 | super(PointNet, self).__init__() 329 | in_channel = 6 if normal_channel else 3 330 | self.normal_channel = normal_channel 331 | self.sa1 = 
PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False) 332 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) 333 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) 334 | self.fc1 = nn.Linear(1024, 512) 335 | self.bn1 = nn.BatchNorm1d(512) 336 | self.drop1 = nn.Dropout(0.4) 337 | self.fc2 = nn.Linear(512, 256) 338 | self.bn2 = nn.BatchNorm1d(256) 339 | self.drop2 = nn.Dropout(0.4) 340 | self.fc3 = nn.Linear(256, num_class) 341 | 342 | def forward(self, xyz, return_feats=False): 343 | B, _, _ = xyz.shape 344 | if self.normal_channel: 345 | norm = xyz[:, 3:, :] 346 | xyz = xyz[:, :3, :] 347 | else: 348 | norm = None 349 | l1_xyz, l1_points = self.sa1(xyz, norm) 350 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 351 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 352 | x = l3_points.view(B, 1024) 353 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 354 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 355 | feats = x 356 | x = self.fc3(x) 357 | x = F.log_softmax(x, -1) 358 | if return_feats: 359 | return x, feats 360 | return x, l3_points 361 | 362 | @torch.no_grad() 363 | def classifier_guided_filter(self, mesh, expected_category): 364 | interesting_categories = { 365 | '03001627': ([8, 30], 0.05), 366 | '04379243': ([33, 3], 0.03), 367 | } 368 | category = interesting_categories[expected_category][0] 369 | threshold = interesting_categories[expected_category][1] 370 | points = get_point_cloud(mesh, nvotes=8).cuda() 371 | points = points.transpose(2, 1) 372 | pred, _ = self(points) 373 | pred = pred.mean(dim=0) 374 | pred_probability = torch.nn.functional.softmax(pred.unsqueeze(0), dim=-1)[0] 375 | pval = pred_probability[category].max().item() 376 | return pval > threshold 377 | 378 | 379 | def get_pointnet_classifier(): 380 | classifier = PointNet(40, normal_channel=False) 381 | checkpoint = torch.load('pretrained/pointnet.pth') 382 | classifier.load_state_dict(checkpoint['model_state_dict']) 383 | classifier = classifier.eval() 384 | return classifier 385 | 386 | 387 | def pc_normalize(pc): 388 | centroid = np.mean(pc, axis=0) 389 | pc = pc - centroid 390 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 391 | pc = pc / m 392 | return pc 393 | 394 | 395 | def get_point_cloud(mesh, nvotes=4): 396 | point_batch = [] 397 | for _ in range(nvotes): 398 | tmesh = mesh 399 | rot_matrix = transformations.rotation_matrix(-math.pi / 2, [1, 0, 0], [0, 0, 0]) 400 | tmesh = tmesh.apply_transform(rot_matrix) 401 | rot_matrix = transformations.rotation_matrix(-math.pi / 2, [0, 1, 0], [0, 0, 0]) 402 | tmesh = tmesh.apply_transform(rot_matrix) 403 | point_set = pc_normalize(tmesh.sample(1024)) 404 | point_batch.append(point_set[np.newaxis, :, :]) 405 | point_batch = torch.from_numpy(np.concatenate(point_batch, axis=0)).float() 406 | return point_batch 407 | -------------------------------------------------------------------------------- /model/softargmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def softargmax(logits): 5 | smax = torch.nn.functional.softmax(logits, dim=-1) 6 | tokens = torch.tensor(list(range(logits.shape[-1])), dtype=torch.float32, device=logits.device).reshape(1, 1, -1).expand_as(smax) 7 | return torch.sum(tokens * smax, dim=-1) 8 | 9 | 10 | if __name__ == '__main__': 11 | logits_ = torch.tensor([ 
12 | [[[2, 0.1, 0.2, 0.5]], 13 | [[0.7, 0.1, 0.2, 5]], 14 | [[0, 10, 0.4, 0.5]], 15 | ], 16 | [[[0, 0.1, 0.2, 0.9]], 17 | [[0.7, 0.1, 100, 5]], 18 | [[0, 1, 4, 5]], 19 | ] 20 | ], requires_grad=True) 21 | print(softargmax(logits_), softargmax(logits_).requires_grad) 22 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | from dataset import get_shifted_sequence 8 | from model.transformer_base import TransformerBase 9 | from model.nanogpt import Block, LayerNorm 10 | from util.misc import top_p_sampling 11 | from tqdm import tqdm 12 | 13 | 14 | class QuantSoupTransformer(TransformerBase): 15 | 16 | def __init__(self, config, vq_config): 17 | super().__init__() 18 | assert config.block_size is not None 19 | self.config = config 20 | self.padding_idx = 2 21 | self.tokens_per_face = config.finemb_size 22 | self.finemb_size = 3 + config.finemb_size # 3 for start, stop pad, 3 for fin 23 | self.foutemb_size = 3 + config.foutemb_size 24 | vocab_size = vq_config.n_embed + 1 + 1 + 1 # +1 for start, +1 for stop, +1 for pad 25 | self.vocab_size = vocab_size 26 | print('Model Vocab Size:', vocab_size) 27 | print('Model Padding Index:', self.padding_idx) 28 | print('Model Fin Size:', self.finemb_size) 29 | print('Model Fout Size:', self.foutemb_size) 30 | self.input_layer = nn.Linear(vq_config.embed_dim, config.n_embd) 31 | self.extra_embeds = nn.Embedding(3, config.n_embd, padding_idx=self.padding_idx) 32 | self.transformer = nn.ModuleDict(dict( 33 | wpe=nn.Embedding(config.block_size, config.n_embd), 34 | wfie=nn.Embedding(self.finemb_size, config.n_embd, padding_idx=self.padding_idx), 35 | wfoe=nn.Embedding(self.foutemb_size, config.n_embd, padding_idx=self.padding_idx), 36 | drop=nn.Dropout(config.dropout), 37 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 38 | ln_f=LayerNorm(config.n_embd, bias=config.bias), 39 | )) 40 | self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False) 41 | 42 | # init all weights 43 | self.apply(self._init_weights) 44 | # apply special scaled init to the residual projections, per GPT-2 paper 45 | for pn, p in self.named_parameters(): 46 | if pn.endswith('c_proj.weight'): 47 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 48 | 49 | # report number of parameters 50 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) 51 | 52 | def forward(self, idx, fin, fout, tokenizer, targets=None, kv_cache=None, mask_cache=None): 53 | use_kv_cache = kv_cache is not None 54 | device = idx.device 55 | b, t = idx.size() 56 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 57 | 58 | embed = torch.zeros((b * t, self.config.n_embd), dtype=torch.float32, device=device) 59 | idx_in_extra = torch.isin(idx, torch.LongTensor([0, 1, 2]).to(device)).reshape(-1) 60 | idx_flat = idx.reshape(-1) 61 | embed[idx_in_extra, :] = self.extra_embeds(idx_flat[idx_in_extra]) 62 | embed[~idx_in_extra, :] = self.input_layer(tokenizer.embed(idx_flat[~idx_in_extra] - 3)) 63 | tok_emb = embed.reshape(b, t, -1) # token embeddings of shape (b, t, n_embd) 64 | finemb = self.transformer.wfie(fin) # face inner embeddings of shape (t, n_embd) 65 | foutemb = self.transformer.wfoe(fout) # face outer embeddings of shape (t, n_embd) 66 | 
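# Illustrative note (annotation added for clarity, not part of the original source): when a populated kv_cache is passed, the caller feeds only the newest token, so the block below reads the position from the cached length, e.g. a cache already holding T=10 keys/values places the new token at position 10 and only wpe.weight[10] is looked up; without a cache, positions 0..t-1 are embedded for the whole sequence at once.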
67 | # position embedding 68 | 69 | if kv_cache is not None and kv_cache[0].numel(): 70 | pos = kv_cache[0].shape[-2] # kv_cache of shape: num_layers * (2, B, nh, T, hs) 71 | pos_emb = self.transformer.wpe.weight[None, pos] # 1 x n_embd 72 | mask = mask_cache.index_select(2, torch.LongTensor([pos]).to(pos_emb.device))[:, :, :, :pos + 1] 73 | else: 74 | pos = torch.tensor([i for i in range(t)], dtype=torch.long, device=device) 75 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 76 | mask = None 77 | 78 | sum_emb = tok_emb + pos_emb + finemb + foutemb 79 | # print('shapes:', tok_emb.shape, pos_emb.shape, coord_emb.shape) 80 | x = self.transformer.drop(sum_emb) 81 | 82 | # apply multiple transformer blocks 83 | new_kv_cache = [] 84 | kv_cache = kv_cache or [None] * self.config.n_layer 85 | 86 | for block, kv_cache_layer in zip(self.transformer.h, kv_cache): 87 | x, new_kv = block(x, kv_cache_layer, mask) 88 | new_kv_cache.append(new_kv) 89 | 90 | x = self.transformer.ln_f(x) 91 | 92 | if targets is not None: 93 | # if we are given some desired targets also calculate the loss 94 | logits = self.lm_head(x) 95 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=self.padding_idx) 96 | else: 97 | # inference-time mini-optimization: only forward the lm_head on the very last position 98 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 99 | loss = None 100 | 101 | if not use_kv_cache: 102 | return logits, loss 103 | else: 104 | return logits, new_kv_cache 105 | 106 | @torch.no_grad() 107 | def generate(self, idx, fin, fout, tokenizer, max_new_tokens=10000, temperature=1.0, top_k=None, top_p=0.9, use_kv_cache=False): 108 | """ 109 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 110 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 111 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
112 | """ 113 | if use_kv_cache and (max_new_tokens + idx.shape[-1] - 1) > self.config.block_size: 114 | # print(f"Cannot generate more than {self.config.block_size} tokens with kv cache, setting max new tokens to {self.config.block_size - idx.shape[-1]}") 115 | max_new_tokens = self.config.block_size - idx.shape[-1] 116 | 117 | kv_cache = ( 118 | [torch.empty(2, 0, device=idx.device, dtype=idx.dtype) for _ in range(self.config.n_layer)] 119 | if use_kv_cache 120 | else None 121 | ) 122 | mask_cache = None 123 | if use_kv_cache: 124 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) 125 | mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0) 126 | 127 | current_fin = fin 128 | current_fout = fout 129 | one_t = torch.LongTensor([1]).to(fin.device) 130 | for iteration in range(max_new_tokens): 131 | 132 | if not use_kv_cache or (iteration == 0 and idx.shape[-1] > 1): 133 | # if the sequence context is growing too long we must crop it at block_size 134 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 135 | fin_cond = current_fin if current_fin.size(1) <= self.config.block_size else current_fin[:, -self.config.block_size:] 136 | fout_cond = current_fout if current_fout.size(1) <= self.config.block_size else current_fout[:, -self.config.block_size:] 137 | fout_cond = torch.from_numpy(get_shifted_sequence(fout_cond[0].cpu().numpy())).to(idx_cond.device).unsqueeze(0) 138 | else: 139 | idx_cond = idx[:, -1:] 140 | fin_cond = current_fin[:, -1:] 141 | fout_cond = current_fout[:, -1:] # note: don't need shifting since we assume block_size is huge enough to not need shifting 142 | # forward the model to get the logits for the index in the sequence 143 | logits, kv_cache = self(idx_cond, fin_cond, fout_cond, tokenizer, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache) 144 | # pluck the logits at the final step and scale by desired temperature 145 | logits = logits[:, -1, :] / temperature 146 | # optionally crop the logits to only the top k options 147 | if top_k is not None: 148 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 149 | logits[logits < v[:, [-1]]] = -float('Inf') 150 | 151 | # TODO: Introduce hard constraints 152 | 153 | # sample from the distribution 154 | # apply softmax to convert logits to (normalized) probabilities 155 | if top_p is not None: 156 | idx_next = top_p_sampling(logits, top_p) 157 | else: 158 | probs = F.softmax(logits, dim=-1) 159 | idx_next = torch.multinomial(probs, num_samples=1) 160 | # append sampled index to the running sequence and continue 161 | idx = torch.cat((idx, idx_next), dim=1) 162 | last_fin_cond = current_fin[0, -1] 163 | if last_fin_cond == self.finemb_size - 1 or (iteration == 0 and idx.shape[-1] == 2): 164 | current_fin = torch.cat((current_fin, (3 * one_t[0]).unsqueeze(0).unsqueeze(0)), dim=1) 165 | current_fout = torch.cat((current_fout, (current_fout[0, -1] + 1).unsqueeze(0).unsqueeze(0)), dim=1) 166 | else: 167 | current_fin = torch.cat((current_fin, (current_fin[0, -1] + 1).unsqueeze(0).unsqueeze(0)), dim=1) 168 | current_fout = torch.cat((current_fout, (current_fout[0, -1]).unsqueeze(0).unsqueeze(0)), dim=1) 169 | if idx_next == 1: 170 | return idx 171 | return None 172 | 173 | 174 | @torch.no_grad() 175 | def generate_with_beamsearch(self, idx, fin, fout, tokenizer, max_new_tokens=10000, use_kv_cache=False, beam_width=6): 176 | 177 | backup_beams = [] 178 | backup_beam_prob = [] 179 | max_new_tokens = 
self.config.block_size - idx.shape[-1] 180 | 181 | kv_cache = ( 182 | [torch.empty(2, 0, device=idx.device, dtype=idx.dtype) for _ in range(self.config.n_layer)] 183 | if use_kv_cache 184 | else None 185 | ) 186 | 187 | mask_cache = None 188 | 189 | if use_kv_cache: 190 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) 191 | mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0) 192 | 193 | current_fin = fin 194 | current_fout = fout 195 | one_t = torch.LongTensor([1]).to(fin.device) 196 | 197 | idx = idx.repeat((beam_width, 1, 1)).transpose(0, 1).flatten(end_dim=-2) 198 | current_fin = current_fin.repeat((beam_width, 1, 1)).transpose(0, 1).flatten(end_dim=-2) 199 | current_fout = current_fout.repeat((beam_width, 1, 1)).transpose(0, 1).flatten(end_dim=-2) 200 | 201 | logits, kv_cache = self(idx, fin, fout, tokenizer, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache) 202 | 203 | vocabulary_size = logits.shape[-1] 204 | probabilities, top_k_indices = logits[0, 0, :].squeeze().log_softmax(-1).topk(k=beam_width, axis=-1) 205 | 206 | next_chars = top_k_indices.reshape(-1, 1) 207 | idx = torch.cat((idx, next_chars), axis=-1) 208 | 209 | last_fin_cond = current_fin[0, -1] # same for all beams 210 | if last_fin_cond == self.finemb_size - 1 or (idx.shape[-1] == 2): 211 | current_fin = torch.cat((current_fin, (3 * one_t[0]).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1) 212 | current_fout = torch.cat((current_fout, (current_fout[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1) 213 | else: 214 | current_fin = torch.cat((current_fin, (current_fin[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1) 215 | current_fout = torch.cat((current_fout, current_fout[0, -1].unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1) 216 | 217 | for iteration in tqdm(range(max_new_tokens - 1), desc='beam_search'): 218 | if not use_kv_cache: 219 | idx_cond = idx 220 | fin_cond = current_fin 221 | fout_cond = current_fout 222 | else: 223 | idx_cond = idx[:, -1:] 224 | fin_cond = current_fin[:, -1:] 225 | fout_cond = current_fout[:, -1:] 226 | 227 | # forward the model to get the logits for the index in the sequence 228 | logits, kv_cache = self(idx_cond, fin_cond, fout_cond, tokenizer, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache) 229 | 230 | next_probabilities = logits.log_softmax(-1) 231 | 232 | next_probabilities = next_probabilities.reshape((-1, beam_width, next_probabilities.shape[-1])) 233 | probabilities = probabilities.unsqueeze(-1) + next_probabilities 234 | probabilities = probabilities.flatten(start_dim=1) 235 | probabilities, top_k_indices = probabilities.topk(k=beam_width, axis=-1) 236 | next_indices = torch.remainder(top_k_indices, vocabulary_size).flatten().unsqueeze(-1) 237 | best_candidates = (top_k_indices / vocabulary_size).long() 238 | best_candidates += torch.arange(idx.shape[0] // beam_width, device=idx.device).unsqueeze(-1) * beam_width 239 | idx = idx[best_candidates].flatten(end_dim=-2) 240 | for block_idx in range(len(kv_cache)): 241 | kv_cache[block_idx] = kv_cache[block_idx][:, best_candidates.flatten(), :, :, :] 242 | idx = torch.cat((idx, next_indices), axis=1) 243 | 244 | # update fin and fout 245 | last_fin_cond = current_fin[0, -1] # same for all beams 246 | if last_fin_cond == self.finemb_size - 1 or (idx.shape[-1] == 2): 247 | current_fin = torch.cat((current_fin, (3 * one_t[0]).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), 
dim=1) 248 | current_fout = torch.cat((current_fout, (current_fout[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1) 249 | else: 250 | current_fin = torch.cat((current_fin, (current_fin[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1) 251 | current_fout = torch.cat((current_fout, current_fout[0, -1].unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1) 252 | 253 | amax = probabilities.flatten().argmax() 254 | if idx[amax, -1] == 1: 255 | return idx[amax: amax + 1, :] 256 | for beam_idx in range(beam_width): 257 | if idx[beam_idx, -1] == 1: 258 | backup_beams.append(idx[beam_idx: beam_idx + 1, :]) 259 | backup_beam_prob.append(probabilities[0, beam_idx].item()) 260 | return None 261 | -------------------------------------------------------------------------------- /model/transformer_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from model.nanogpt import configure_optimizers 5 | 6 | 7 | class TransformerBase(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def get_num_params(self): 13 | """ 14 | Return the number of parameters in the model. 15 | For non-embedding count (default), the position embeddings get subtracted. 16 | The token embeddings would too, except due to the parameter sharing these 17 | params are actually used as weights in the final layer, so we include them. 18 | """ 19 | n_params = sum(p.numel() for n,p in self.named_parameters()) 20 | return n_params 21 | 22 | def _init_weights(self, module): 23 | if isinstance(module, nn.Linear): 24 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 25 | if module.bias is not None: 26 | torch.nn.init.zeros_(module.bias) 27 | elif isinstance(module, nn.Embedding): 28 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 29 | 30 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 31 | return configure_optimizers(self.named_parameters(), weight_decay, learning_rate, betas, device_type) 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch_geometric==2.4.0 2 | torchmetrics==0.11.0 3 | cosine-annealing-warmup @ git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup@12d03c07553aedd3d9e9155e2b3e31ce8c64081a 4 | easydict==1.10 5 | einops==0.7.0 6 | einops-exts==0.0.4 7 | hydra-core==1.3.2 8 | imageio==2.31.0 9 | imageio-ffmpeg==0.4.2 10 | lightning @ git+https://github.com/Lightning-AI/lightning@eb8b314f7cbb5e87e0d149f95f45f1409a7715c2 11 | lightning-utilities==0.8.0 12 | matplotlib==3.7.2 13 | matplotlib-inline==0.1.6 14 | moviepy==1.0.3 15 | numpy==1.24.4 16 | omegaconf==2.3.0 17 | Pillow==10.3.0 18 | pymeshlab==2022.2.post2 19 | pytorch-lightning==2.0.3 20 | scikit-image==0.21.0 21 | scikit-learn==1.2.1 22 | scipy==1.10.1 23 | tqdm==4.64.1 24 | transformers==4.48.0 25 | trimesh==3.18.0 26 | vector-quantize-pytorch==1.8.1 27 | wandb==0.13.4 28 | randomname 29 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | import sys 4 | import traceback 5 | from pathlib import Path 6 | from random import randint 7 | import datetime 8 | 9 | import torch 10 | import wandb 11 | import randomname 12 | from 
pytorch_lightning.strategies.ddp import DDPStrategy 13 | 14 | from pytorch_lightning import seed_everything, Trainer 15 | from pytorch_lightning.callbacks import ModelCheckpoint 16 | from pytorch_lightning.loggers.wandb import WandbLogger 17 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 18 | from vector_quantize_pytorch import ResidualVQ 19 | 20 | from model.encoder import GraphEncoder 21 | from model.decoder import resnet34_decoder 22 | from util.filesystem_logger import FilesystemLogger 23 | from util.misc import get_parameters_from_state_dict 24 | 25 | 26 | def print_traceback_handler(sig, _frame): 27 | print(f'Received signal {sig}') 28 | bt = ''.join(traceback.format_stack()) 29 | print(f'Requested stack trace:\n{bt}') 30 | 31 | 32 | def quit_handler(sig, frame): 33 | print(f'Received signal {sig}, quitting.') 34 | sys.exit(1) 35 | 36 | 37 | def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler): 38 | print(f'Setting signal {sig} handler {handler}') 39 | signal.signal(sig, handler) 40 | 41 | 42 | def register_quit_signal_handlers(sig=signal.SIGUSR2, handler=quit_handler): 43 | print(f'Setting signal {sig} handler {handler}') 44 | signal.signal(sig, handler) 45 | 46 | 47 | def generate_experiment_name(name, config): 48 | if config.resume is not None: 49 | experiment = Path(config.resume).parents[1].name 50 | os.environ['experiment'] = experiment 51 | elif not os.environ.get('experiment'): 52 | experiment = f"{datetime.datetime.now().strftime('%m%d%H%M')}_{name}_{config.experiment}_{randomname.get_name()}" 53 | os.environ['experiment'] = experiment 54 | else: 55 | experiment = os.environ['experiment'] 56 | return experiment 57 | 58 | 59 | def create_trainer(name, config): 60 | if not config.wandb_main and config.suffix == '': 61 | config.suffix = '-dev' 62 | config.experiment = generate_experiment_name(name, config) 63 | if config.val_check_interval > 1: 64 | config.val_check_interval = int(config.val_check_interval) 65 | if config.seed is None: 66 | config.seed = randint(0, 999) 67 | 68 | # config.dataset_root = Path(config.dataset_root) 69 | 70 | seed_everything(1337 + config.seed) 71 | torch.backends.cuda.matmul.allow_tf32 = True # type: ignore # allow tf32 on matmul 72 | torch.backends.cudnn.allow_tf32 = True # type: ignore # allow tf32 on cudnn 73 | torch.multiprocessing.set_sharing_strategy('file_system') # possible fix for the "OSError: too many files" exception 74 | 75 | register_debug_signal_handlers() 76 | register_quit_signal_handlers() 77 | 78 | # noinspection PyUnusedLocal 79 | filesystem_logger = FilesystemLogger(config) 80 | 81 | # use wandb logger instead 82 | if config.logger == 'wandb': 83 | logger = WandbLogger(project=f'{name}{config.suffix}', name=config.experiment, id=config.experiment, settings=wandb.Settings(start_method='thread')) 84 | else: 85 | logger = TensorBoardLogger(name='tb', save_dir=(Path("runs") / config.experiment)) 86 | 87 | checkpoint_callback = ModelCheckpoint( 88 | dirpath=(Path("runs") / config.experiment / "checkpoints"), 89 | save_top_k=-1, 90 | verbose=False, 91 | every_n_epochs=config.save_epoch, 92 | filename='{epoch:02d}-{global_step}', 93 | auto_insert_metric_name=False, 94 | ) 95 | 96 | gpu_count = torch.cuda.device_count() 97 | 98 | precision = 'bf16' if torch.cuda.is_bf16_supported() else 16 99 | precision = 32 100 | 101 | if gpu_count > 1: 102 | trainer = Trainer( 103 | accelerator='gpu', 104 | strategy=DDPStrategy(find_unused_parameters=False), 105 | num_nodes=1, 106 | 
precision=precision, 107 | devices=gpu_count, 108 | num_sanity_val_steps=config.sanity_steps, 109 | max_epochs=config.max_epoch, 110 | limit_val_batches=config.val_check_percent, 111 | callbacks=[checkpoint_callback], 112 | val_check_interval=float(min(config.val_check_interval, 1)), 113 | check_val_every_n_epoch=max(1, config.val_check_interval), 114 | logger=logger, 115 | deterministic=False, 116 | benchmark=True, 117 | ) 118 | elif gpu_count == 1: 119 | trainer = Trainer( 120 | devices=[0], 121 | accelerator='gpu', 122 | precision=precision, 123 | strategy=DDPStrategy(find_unused_parameters=False), 124 | num_sanity_val_steps=config.sanity_steps, 125 | max_epochs=config.max_epoch, 126 | limit_val_batches=config.val_check_percent, 127 | callbacks=[checkpoint_callback], 128 | val_check_interval=float(min(config.val_check_interval, 1)), 129 | check_val_every_n_epoch=max(1, config.val_check_interval), 130 | logger=logger, 131 | deterministic=False, 132 | benchmark=True, 133 | ) 134 | else: 135 | trainer = Trainer( 136 | accelerator='cpu', 137 | precision=precision, 138 | num_sanity_val_steps=config.sanity_steps, 139 | max_epochs=config.max_epoch, 140 | limit_val_batches=config.val_check_percent, 141 | callbacks=[checkpoint_callback], 142 | val_check_interval=float(min(config.val_check_interval, 1)), 143 | check_val_every_n_epoch=max(1, config.val_check_interval), 144 | logger=logger, 145 | deterministic=False, 146 | benchmark=True, 147 | ) 148 | return trainer 149 | 150 | 151 | def step(opt, modules): 152 | for module in modules: 153 | for param in module.parameters(): 154 | if param.grad is not None: 155 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) 156 | torch.nn.utils.clip_grad_norm_(module.parameters(), 1) # type: ignore 157 | opt.step() 158 | 159 | 160 | def create_conv_batch(encoded_features, batch, batch_size, device): 161 | conv_input, conv_mask = [], [] 162 | max_sequence_length = 0 163 | for k in range(batch_size): 164 | features = encoded_features[batch == k, :].T.contiguous().unsqueeze(0) 165 | max_sequence_length = max(max_sequence_length, features.shape[2]) 166 | conv_input.append(features) 167 | conv_mask.append(torch.ones([features.shape[2]], device=device, dtype=torch.bool)) 168 | for k in range(batch_size): 169 | conv_input[k] = torch.nn.functional.pad(conv_input[k], (0, max_sequence_length - conv_input[k].shape[2]), 'replicate') 170 | conv_mask[k] = torch.nn.functional.pad(conv_mask[k], (0, max_sequence_length - conv_mask[k].shape[0]), 'constant', False) 171 | conv_input = torch.cat(conv_input, dim=0) 172 | conv_mask = torch.cat(conv_mask, dim=0) 173 | return conv_input, conv_mask 174 | 175 | 176 | def get_rvqvae_v0_all(config, resume): 177 | encoder, pre_quant, post_quant, vq = get_rvqvae_v0_encoder_vq(config, resume) 178 | decoder = get_rvqvae_v0_decoder(config, resume) 179 | return encoder, decoder, pre_quant, post_quant, vq 180 | 181 | 182 | def get_rvqvae_v0_encoder_vq(config, resume): 183 | state_dict = torch.load(resume, map_location="cpu")["state_dict"] 184 | encoder = GraphEncoder(no_max_pool=config.g_no_max_pool, aggr=config.g_aggr, graph_conv=config.graph_conv, use_point_features=config.use_point_feats) 185 | pre_quant = torch.nn.Linear(512, config.embed_dim) 186 | post_quant = torch.nn.Linear(config.embed_dim, 512) 187 | 188 | encoder.load_state_dict(get_parameters_from_state_dict(state_dict, "encoder")) 189 | pre_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "pre_quant")) 190 | 
post_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "post_quant")) 191 | 192 | vq = ResidualVQ( 193 | dim=config.embed_dim, 194 | codebook_size=config.n_embed, # codebook size 195 | num_quantizers=config.embed_levels, 196 | commitment_weight=config.embed_loss_weight, # the weight on the commitment loss 197 | stochastic_sample_codes=True, 198 | sample_codebook_temp=0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic 199 | shared_codebook=config.embed_share, 200 | decay=config.code_decay, 201 | ) 202 | vq.load_state_dict(get_parameters_from_state_dict(state_dict, "vq")) 203 | return encoder, pre_quant, post_quant, vq 204 | 205 | 206 | def get_rvqvae_v0_decoder(config, resume, device=torch.device("cpu")): 207 | state_dict = torch.load(resume, map_location="cpu")["state_dict"] 208 | decoder = resnet34_decoder(512, config.num_tokens - 2, config.ce_output) 209 | decoder.load_state_dict(get_parameters_from_state_dict(state_dict, "decoder")) 210 | decoder = decoder.to(device).eval() 211 | return decoder 212 | 213 | 214 | def get_rvqvae_v1_encoder_vq(config, resume): 215 | state_dict = torch.load(resume, map_location="cpu")["state_dict"] 216 | encoder = GraphEncoder(no_max_pool=config.g_no_max_pool, aggr=config.g_aggr, graph_conv=config.graph_conv, use_point_features=config.use_point_feats, output_dim=576) 217 | pre_quant = torch.nn.Linear(192, config.embed_dim) 218 | post_quant = torch.nn.Linear(config.embed_dim * 3, 512) 219 | 220 | encoder.load_state_dict(get_parameters_from_state_dict(state_dict, "encoder")) 221 | pre_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "pre_quant")) 222 | post_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "post_quant")) 223 | 224 | vq = ResidualVQ( 225 | dim=config.embed_dim, 226 | codebook_size=config.n_embed, # codebook size 227 | num_quantizers=config.embed_levels, 228 | commitment_weight=config.embed_loss_weight, # the weight on the commitment loss 229 | stochastic_sample_codes=True, 230 | sample_codebook_temp=0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic 231 | shared_codebook=config.embed_share, 232 | decay=config.code_decay, 233 | ) 234 | vq.load_state_dict(get_parameters_from_state_dict(state_dict, "vq")) 235 | return encoder, pre_quant, post_quant, vq 236 | 237 | 238 | def get_rvqvae_v1_all(config, resume): 239 | encoder, pre_quant, post_quant, vq = get_rvqvae_v1_encoder_vq(config, resume) 240 | decoder = get_rvqvae_v0_decoder(config, resume) 241 | return encoder, decoder, pre_quant, post_quant, vq 242 | -------------------------------------------------------------------------------- /trainer/train_transformer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import omegaconf 4 | import trimesh 5 | from cosine_annealing_warmup import CosineAnnealingWarmupRestarts 6 | import pytorch_lightning as pl 7 | import hydra 8 | from easydict import EasyDict 9 | from lightning_utilities.core.rank_zero import rank_zero_only 10 | from pathlib import Path 11 | import torch 12 | 13 | from dataset.quantized_soup import QuantizedSoupTripletsCreator 14 | from dataset.triangles import TriangleNodesWithFacesAndSequenceIndices, TriangleNodesWithFacesDataloader 15 | from model.transformer import QuantSoupTransformer 16 | from trainer import create_trainer, step, get_rvqvae_v0_decoder 17 | from util.misc import accuracy 18 | from util.visualization import plot_vertices_and_faces 
19 | from util.misc import get_parameters_from_state_dict 20 | 21 | 22 | class QuantSoupModelTrainer(pl.LightningModule): 23 | 24 | def __init__(self, config): 25 | super().__init__() 26 | self.config = config 27 | self.vq_cfg = omegaconf.OmegaConf.load(Path(config.vq_resume).parents[1] / "config.yaml") 28 | self.save_hyperparameters() 29 | self.train_dataset = TriangleNodesWithFacesAndSequenceIndices(config, 'train', config.scale_augment, config.shift_augment, config.ft_category) 30 | self.val_dataset = TriangleNodesWithFacesAndSequenceIndices(config, 'val', config.scale_augment_val, False, config.ft_category) 31 | print("Dataset Lengths:", len(self.train_dataset), len(self.val_dataset)) 32 | print("Batch Size:", self.config.batch_size) 33 | print("Dataloader Lengths:", len(self.train_dataset) // self.config.batch_size, len(self.val_dataset) // self.config.batch_size) 34 | model_cfg = get_qsoup_model_config(config, self.vq_cfg.embed_levels) 35 | self.model = QuantSoupTransformer(model_cfg, self.vq_cfg) 36 | self.sequencer = QuantizedSoupTripletsCreator(self.config, self.vq_cfg) 37 | self.sequencer.freeze_vq() 38 | # print('compiling model...') 39 | # self.model = torch.compile(model) # requires PyTorch 2.0 40 | self.output_dir_image = Path(f'runs/{self.config.experiment}/image') 41 | self.output_dir_image.mkdir(exist_ok=True, parents=True) 42 | self.output_dir_mesh = Path(f'runs/{self.config.experiment}/mesh') 43 | self.output_dir_mesh.mkdir(exist_ok=True, parents=True) 44 | if self.config.ft_resume is not None: 45 | self.model.load_state_dict(get_parameters_from_state_dict(torch.load(self.config.ft_resume, map_location='cpu')['state_dict'], "model")) 46 | self.automatic_optimization = False 47 | 48 | def configure_optimizers(self): 49 | optimizer = self.model.configure_optimizers( 50 | self.config.weight_decay, self.config.lr, 51 | (self.config.beta1, self.config.beta2), 'cuda' 52 | ) 53 | max_steps = int(self.config.max_epoch * len(self.train_dataset) / self.config.batch_size / 2) 54 | print('Max Steps | First cycle:', max_steps) 55 | scheduler = CosineAnnealingWarmupRestarts( 56 | optimizer, first_cycle_steps=max_steps, cycle_mult=1.0, 57 | max_lr=self.config.lr, min_lr=self.config.min_lr, 58 | warmup_steps=self.config.warmup_steps, gamma=1.0 59 | ) 60 | return [optimizer], [scheduler] 61 | 62 | def training_step(self, data, batch_idx): 63 | optimizer = self.optimizers() 64 | scheduler = self.lr_schedulers() 65 | scheduler.step() # type: ignore 66 | if self.config.force_lr is not None: 67 | for param_group in optimizer.param_groups: 68 | param_group['lr'] = self.config.force_lr 69 | sequence_in, sequence_out, pfin, pfout = self.sequencer(data.x, data.edge_index, data.batch, data.faces, data.num_vertices.sum(), data.js) 70 | logits, loss = self.model(sequence_in, pfin, pfout, self.sequencer, targets=sequence_out) 71 | acc = accuracy(logits.detach(), sequence_out, ignore_label=2, device=self.device) 72 | self.log("train/ce_loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) 73 | self.log("train/acc", acc.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) 74 | loss = loss / self.config.gradient_accumulation_steps # scale the loss to account for gradient accumulation 75 | self.manual_backward(loss) 76 | # accumulate gradients of `n` batches 77 | if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0: 78 | step(optimizer, [self.model]) 79 | optimizer.zero_grad(set_to_none=True) # type: ignore 80 | 
self.log("lr", optimizer.param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) # type: ignore 81 | 82 | def validation_step(self, data, batch_idx): 83 | sequence_in, sequence_out, pfin, pfout = self.sequencer(data.x, data.edge_index, data.batch, data.faces, data.num_vertices.sum(), data.js) 84 | logits, loss = self.model(sequence_in, pfin, pfout, self.sequencer, targets=sequence_out) 85 | acc = accuracy(logits.detach(), sequence_out, ignore_label=2, device=self.device) 86 | if not torch.isnan(loss).any(): 87 | self.log("val/ce_loss", loss.item(), on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 88 | if not torch.isnan(acc).any(): 89 | self.log("val/acc", acc.item(), on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 90 | 91 | @rank_zero_only 92 | def on_validation_epoch_end(self): 93 | decoder = get_rvqvae_v0_decoder(self.vq_cfg, self.config.vq_resume, self.device) 94 | for k in range(self.config.num_val_samples): 95 | data = self.val_dataset.get(random.randint(0, len(self.val_dataset) - 1)) 96 | soup_sequence, face_in_idx, face_out_idx, target = self.sequencer.get_completion_sequence( 97 | data.x.to(self.device), 98 | data.edge_index.to(self.device), 99 | data.faces.to(self.device), 100 | data.num_vertices, 101 | 12 102 | ) 103 | y = self.model.generate_with_beamsearch( 104 | soup_sequence, face_in_idx, face_out_idx, self.sequencer, self.config.max_val_tokens, use_kv_cache=True 105 | ) 106 | if y is None: 107 | continue 108 | gen_vertices, gen_faces = self.sequencer.decode(y[0], decoder) 109 | plot_vertices_and_faces(gen_vertices, gen_faces, self.output_dir_image / f"{self.global_step:06d}_{k}.jpg") 110 | 111 | try: 112 | trimesh.Trimesh(vertices=gen_vertices, faces=gen_faces, process=False).export(self.output_dir_mesh / f"{self.global_step:06d}_{k}.obj") 113 | except Exception as e: 114 | pass # sometimes the mesh is invalid (ngon) and we don't want to crash 115 | 116 | def train_dataloader(self): 117 | return TriangleNodesWithFacesDataloader(self.train_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=not self.config.overfit, num_workers=self.config.num_workers, pin_memory=True) 118 | 119 | def val_dataloader(self): 120 | return TriangleNodesWithFacesDataloader(self.val_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=False, num_workers=self.config.num_workers) 121 | 122 | 123 | def get_qsoup_model_config(config, vq_embed_levels): 124 | cfg = EasyDict({ 125 | 'block_size': config.block_size, 126 | 'n_embd': config.model.n_embd, 127 | 'dropout': config.model.dropout, 128 | 'n_layer': config.model.n_layer, 129 | 'n_head': config.model.n_head, 130 | 'bias': config.model.bias, 131 | 'finemb_size': vq_embed_levels * 3, 132 | 'foutemb_size': config.block_size * 3, 133 | }) 134 | return cfg 135 | 136 | 137 | @hydra.main(config_path='../config', config_name='meshgpt', version_base='1.2') 138 | def main(config): 139 | trainer = create_trainer("MeshTriSoup", config) 140 | model = QuantSoupModelTrainer(config) 141 | trainer.fit(model, ckpt_path=config.resume) 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /trainer/train_vocabulary.py: -------------------------------------------------------------------------------- 1 | import omegaconf 2 | import torch_scatter 3 | import trimesh 4 | from cosine_annealing_warmup import CosineAnnealingWarmupRestarts 5 | import 
pytorch_lightning as pl 6 | import hydra 7 | from lightning_utilities.core.rank_zero import rank_zero_only 8 | from pathlib import Path 9 | import torch 10 | from vector_quantize_pytorch import ResidualVQ 11 | 12 | from dataset import quantize_coordinates 13 | from dataset.triangles import create_feature_stack_from_triangles, TriangleNodesWithFaces, TriangleNodesWithFacesDataloader 14 | from model.encoder import GraphEncoder 15 | from model.decoder import resnet34_decoder 16 | from model.softargmax import softargmax 17 | from trainer import create_trainer, step, create_conv_batch 18 | from util.visualization import plot_vertices_and_faces, triangle_sequence_to_mesh 19 | 20 | 21 | class TriangleTokenizationGraphConv(pl.LightningModule): 22 | 23 | def __init__(self, config): 24 | super().__init__() 25 | self.config = config 26 | self.save_hyperparameters() 27 | if config.only_chairs: 28 | self.train_dataset = TriangleNodesWithFaces(config, 'train', config.scale_augment, config.shift_augment, '03001627') 29 | self.interesting_categories = [('03001627', "")] 30 | self.val_datasets = [TriangleNodesWithFaces(config, 'val', config.scale_augment_val, config.shift_augment_val, '03001627')] 31 | else: 32 | self.train_dataset = TriangleNodesWithFaces(config, 'train', config.scale_augment, config.shift_augment, None) 33 | self.interesting_categories = [('02828884', '_bench'), ('02871439', '_bookshelf'), ('03001627', ""), ('03211117', '_display'), ('04379243', '_table')] 34 | self.val_datasets = [] 35 | for cat, name in self.interesting_categories: 36 | self.val_datasets.append(TriangleNodesWithFaces(config, 'val', config.scale_augment_val, config.shift_augment_val, cat)) 37 | self.encoder = GraphEncoder(no_max_pool=config.g_no_max_pool, aggr=config.g_aggr, graph_conv=config.graph_conv, use_point_features=config.use_point_feats, output_dim=576) 38 | self.decoder = resnet34_decoder(512, config.num_tokens - 2, config.ce_output) 39 | self.pre_quant = torch.nn.Linear(192, config.embed_dim) 40 | self.post_quant = torch.nn.Linear(config.embed_dim * 3, 512) 41 | self.vq = ResidualVQ( 42 | dim=self.config.embed_dim, 43 | codebook_size=self.config.n_embed, # codebook size 44 | num_quantizers=config.embed_levels, 45 | commitment_weight=self.config.embed_loss_weight, # the weight on the commitment loss 46 | stochastic_sample_codes=True, 47 | sample_codebook_temp=config.stochasticity, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic 48 | shared_codebook=self.config.embed_share, 49 | decay=self.config.code_decay, 50 | ) 51 | self.register_buffer('smoothing_weight', torch.tensor([2, 10, 200, 10, 2], dtype=torch.float32).unsqueeze(0).unsqueeze(0)) 52 | # print('compiling model...') 53 | # self.model = torch.compile(model) # requires PyTorch 2.0 54 | self.output_dir_image_val = Path(f'runs/{self.config.experiment}/image_val') 55 | self.output_dir_image_val.mkdir(exist_ok=True, parents=True) 56 | self.output_dir_mesh_val = Path(f'runs/{self.config.experiment}/mesh_val') 57 | self.output_dir_mesh_val.mkdir(exist_ok=True, parents=True) 58 | self.output_dir_image_train = Path(f'runs/{self.config.experiment}/image_train') 59 | self.output_dir_image_train.mkdir(exist_ok=True, parents=True) 60 | self.output_dir_mesh_train = Path(f'runs/{self.config.experiment}/mesh_train') 61 | self.output_dir_mesh_train.mkdir(exist_ok=True, parents=True) 62 | self.automatic_optimization = False 63 | self.distribute_features_fn = distribute_features if self.config.distribute_features else 
dummy_distribute 64 | self.visualize_groundtruth() 65 | 66 | def configure_optimizers(self): 67 | parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.pre_quant.parameters()) + list(self.post_quant.parameters()) + list(self.vq.parameters()) 68 | optimizer = torch.optim.AdamW(parameters, lr=self.config.lr, amsgrad=True, weight_decay=self.config.weight_decay) 69 | max_steps = int(self.config.max_epoch * len(self.train_dataset) / self.config.batch_size) 70 | print('Max Steps | First cycle:', max_steps) 71 | scheduler = CosineAnnealingWarmupRestarts( 72 | optimizer, first_cycle_steps=max_steps, cycle_mult=1.0, 73 | max_lr=self.config.lr, min_lr=self.config.min_lr, 74 | warmup_steps=self.config.warmup_steps, gamma=1.0 75 | ) 76 | return [optimizer], [scheduler] 77 | 78 | def create_conv_batch(self, encoded_features, batch, batch_size): 79 | return create_conv_batch(encoded_features, batch, batch_size, self.device) 80 | 81 | def training_step(self, data, batch_idx): 82 | optimizer = self.optimizers() 83 | scheduler = self.lr_schedulers() 84 | scheduler.step() # type: ignore 85 | encoded_x = self.encoder(data.x, data.edge_index, data.batch) 86 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192) # 3N x 192 87 | encoded_x = self.distribute_features_fn(encoded_x, data.faces, data.num_vertices.sum(), self.device) 88 | encoded_x = self.pre_quant(encoded_x) # 3N x 192 89 | encoded_x, _, commit_loss = self.vq(encoded_x.unsqueeze(0)) 90 | encoded_x = encoded_x.squeeze(0) 91 | commit_loss = commit_loss.mean() 92 | encoded_x = encoded_x.reshape(-1, 3 * encoded_x.shape[-1]) 93 | encoded_x = self.post_quant(encoded_x) 94 | encoded_x_conv, conv_mask = self.create_conv_batch(encoded_x, data.batch, self.config.batch_size) 95 | decoded_x_conv = self.decoder(encoded_x_conv) 96 | if self.config.ce_output: 97 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :] 98 | decoded_tri = softargmax(decoded_x) / (self.config.num_tokens - 3) - 0.5 99 | _, decoded_normals, decoded_areas, decoded_angles = create_feature_stack_from_triangles(decoded_tri.reshape(-1, 3, 3)) 100 | if self.config.use_smoothed_loss: 101 | otarget = torch.nn.functional.one_hot(data.y.reshape(-1), num_classes=self.config.num_tokens - 2).float() 102 | otarget = otarget.unsqueeze(1) 103 | starget = torch.nn.functional.conv1d(otarget, self.smoothing_weight, bias=None, stride=1, padding=2, dilation=1, groups=1) 104 | if self.config.use_multimodal_loss: 105 | starget_a = starget.reshape(-1, decoded_x.shape[-2] * decoded_x_conv.shape[-1]) 106 | starget_a = torch.nn.functional.normalize(starget_a, p=1.0, dim=-1, eps=1e-12).squeeze(1) 107 | starget_b = torch.nn.functional.normalize(starget, p=1.0, dim=-1, eps=1e-12).squeeze(1) 108 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-2] * decoded_x_conv.shape[-1]), starget_a).mean() 109 | loss = loss * 0.1 + torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), starget_b).mean() 110 | else: 111 | starget = torch.nn.functional.normalize(starget, p=1.0, dim=-1, eps=1e-12).squeeze(1) 112 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), starget).mean() 113 | else: 114 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), data.y.reshape(-1), reduction='mean') 115 | y_coords = data.y / (self.config.num_tokens - 3) - 0.5 116 | loss_tri = torch.nn.functional.mse_loss(decoded_tri, y_coords, 
reduction='mean') 117 | loss_normals = torch.nn.functional.mse_loss(decoded_normals, data.x[:, 9:12], reduction='mean') 118 | loss_areas = torch.nn.functional.mse_loss(decoded_areas, data.x[:, 12:13], reduction='mean') 119 | loss_angles = torch.nn.functional.mse_loss(decoded_angles, data.x[:, 13:16], reduction='mean') 120 | else: 121 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-1])[conv_mask, :] 122 | _, decoded_normals, decoded_areas, decoded_angles = create_feature_stack_from_triangles(decoded_x.reshape(-1, 3, 3)) 123 | loss = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean') 124 | loss_tri = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean') 125 | loss_normals = torch.nn.functional.mse_loss(decoded_normals, data.x[:, 9:12], reduction='mean') 126 | loss_areas = torch.nn.functional.mse_loss(decoded_areas, data.x[:, 12:13], reduction='mean') 127 | loss_angles = torch.nn.functional.mse_loss(decoded_angles, data.x[:, 13:16], reduction='mean') 128 | 129 | acc = self.get_accuracy(decoded_x, data.y) 130 | acc_triangle = self.get_triangle_accuracy(decoded_x, data.y) 131 | self.log("train/ce_loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) 132 | self.log("train/mse_loss", loss_tri.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) 133 | self.log("train/norm_loss", loss_normals.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) 134 | self.log("train/area_loss", loss_areas.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) 135 | self.log("train/angle_loss", loss_angles.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) 136 | self.log("train/embed_loss", commit_loss.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) 137 | self.log("train/acc", acc.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) 138 | self.log("train/acc_tri", acc_triangle.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) 139 | loss = loss + loss_tri * self.config.tri_weight + loss_normals * self.config.norm_weight + loss_areas * self.config.area_weight + loss_angles * self.config.angle_weight + commit_loss 140 | # loss = loss + loss_tri * self.config.tri_weight + commit_loss 141 | loss = loss / self.config.gradient_accumulation_steps # scale the loss to account for gradient accumulation 142 | self.manual_backward(loss) 143 | # accumulate gradients of `n` batches 144 | if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0: 145 | step(optimizer, [self.encoder, self.decoder, self.pre_quant, self.post_quant]) 146 | optimizer.zero_grad(set_to_none=True) # type: ignore 147 | self.log("lr", optimizer.param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) # type: ignore 148 | 149 | def validation_step(self, data, batch_idx, dataloader_idx): 150 | encoded_x = self.encoder(data.x, data.edge_index, data.batch) 151 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192) 152 | encoded_x = self.distribute_features_fn(encoded_x, data.faces, data.num_vertices.sum(), self.device) 153 | encoded_x = self.pre_quant(encoded_x) 154 | encoded_x, _, commit_loss = self.vq(encoded_x.unsqueeze(0)) 155 | encoded_x = encoded_x.squeeze(0) 156 | commit_loss = commit_loss.mean() 157 | encoded_x = encoded_x.reshape(-1, 3 * encoded_x.shape[-1]) 158 | encoded_x = self.post_quant(encoded_x) 159 | 
encoded_x_conv, conv_mask = self.create_conv_batch(encoded_x, data.batch, self.config.batch_size) 160 | decoded_x_conv = self.decoder(encoded_x_conv) 161 | if self.config.ce_output: 162 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :] 163 | decoded_c = softargmax(decoded_x) / (self.config.num_tokens - 3) - 0.5 164 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), data.y.reshape(-1), reduction='mean') 165 | y_coords = data.y / (self.config.num_tokens - 3) - 0.5 166 | loss_c = torch.nn.functional.mse_loss(decoded_c, y_coords, reduction='mean') 167 | else: 168 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-1])[conv_mask, :] 169 | loss = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean') 170 | loss_c = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean') 171 | acc = self.get_accuracy(decoded_x, data.y) 172 | acc_triangle = self.get_triangle_accuracy(decoded_x, data.y) 173 | if not torch.isnan(loss).any(): 174 | self.log(f"val/ce_loss{self.interesting_categories[dataloader_idx][1]}", loss.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 175 | if not torch.isnan(loss_c).any(): 176 | self.log(f"val/mse_loss{self.interesting_categories[dataloader_idx][1]}", loss_c.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 177 | if not torch.isnan(commit_loss).any(): 178 | self.log(f"val/embed_loss{self.interesting_categories[dataloader_idx][1]}", commit_loss.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 179 | if not torch.isnan(acc).any(): 180 | self.log(f"val/acc{self.interesting_categories[dataloader_idx][1]}", acc.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 181 | if not torch.isnan(acc_triangle).any(): 182 | self.log(f"val/acc_tri{self.interesting_categories[dataloader_idx][1]}", acc_triangle.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True) 183 | 184 | @rank_zero_only 185 | def on_validation_epoch_end(self): 186 | category_names = [""] + [x[0].strip("_") for x in self.interesting_categories] 187 | for didx, dataset in enumerate([self.train_dataset] + self.val_datasets): 188 | output_dir_image = self.output_dir_image_train if didx == 0 else self.output_dir_image_val 189 | output_dir_mesh = self.output_dir_mesh_train if didx == 0 else self.output_dir_mesh_val 190 | for k in range(self.config.num_val_samples): 191 | data = dataset.get(k * (len(dataset) // self.config.num_val_samples) % len(dataset)) 192 | encoded_x = self.encoder(data.x.to(self.device), data.edge_index.to(self.device), torch.zeros([data.x.shape[0]], device=self.device).long()) 193 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192) 194 | encoded_x = self.distribute_features_fn(encoded_x, data.faces.to(self.device), data.num_vertices, self.device) 195 | encoded_x = self.pre_quant(encoded_x) 196 | encoded_x, _, _ = self.vq(encoded_x.unsqueeze(0)) 197 | encoded_x = encoded_x.squeeze(0) 198 | encoded_x = encoded_x.reshape(-1, 3 * encoded_x.shape[-1]) 199 | encoded_x = self.post_quant(encoded_x) 200 | encoded_x_conv, conv_mask = self.create_conv_batch(encoded_x, torch.zeros([data.x.shape[0]], device=self.device).long(), 1) 201 | decoded_x_conv = self.decoder(encoded_x_conv) 202 | if 
self.config.ce_output: 203 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :] 204 | coords = decoded_x.argmax(-1).detach().cpu().numpy() / (self.config.num_tokens - 3) - 0.5 205 | else: 206 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-1])[conv_mask, :] 207 | coords = decoded_x.detach().cpu().numpy() 208 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords) 209 | plot_vertices_and_faces(gen_vertices, gen_faces, output_dir_image / f"{self.global_step:06d}_{category_names[didx]}_{k}.jpg") 210 | try: 211 | trimesh.Trimesh(vertices=gen_vertices, faces=gen_faces, process=False).export(output_dir_mesh / f"{self.global_step:06d}_{category_names[didx]}_{k}.obj") 212 | except Exception as e: 213 | pass # sometimes the mesh is invalid (ngon) and we don't want to crash 214 | 215 | def visualize_groundtruth(self): 216 | category_names = [""] + [x[0].strip("_") for x in self.interesting_categories] 217 | for didx, dataset in enumerate([self.train_dataset] + self.val_datasets): 218 | output_dir_image = self.output_dir_image_train if didx == 0 else self.output_dir_image_val 219 | for k in range(self.config.num_val_samples): 220 | data = dataset.get(k * (len(dataset) // self.config.num_val_samples) % len(dataset)) 221 | if self.config.ce_output: 222 | coords = data.y / (self.config.num_tokens - 3) - 0.5 223 | else: 224 | coords = data.y 225 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords) 226 | plot_vertices_and_faces(gen_vertices, gen_faces, output_dir_image / f"GT_{category_names[didx]}_{k}.jpg") 227 | 228 | def train_dataloader(self): 229 | return TriangleNodesWithFacesDataloader(self.train_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=not self.config.overfit, num_workers=self.config.num_workers, pin_memory=True) 230 | 231 | def val_dataloader(self): 232 | dataloaders = [] 233 | for val_dataset in self.val_datasets: 234 | dataloaders.append(TriangleNodesWithFacesDataloader(val_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=True, num_workers=self.config.num_workers)) 235 | return dataloaders 236 | 237 | def get_accuracy(self, x, y): 238 | if self.config.ce_output: 239 | return (x.argmax(-1).reshape(-1) == y.reshape(-1)).sum() / (x.shape[0] * x.shape[1]) 240 | return (quantize_coordinates(x, self.config.num_tokens - 2).reshape(-1) == quantize_coordinates(y, self.config.num_tokens - 2).reshape(-1)).sum() / (x.shape[0] * x.shape[1]) 241 | 242 | def get_triangle_accuracy(self, x, y): 243 | if self.config.ce_output: 244 | return torch.all(x.argmax(-1).reshape(-1, 9) == y.reshape(-1, 9), dim=-1).sum() / x.shape[0] 245 | return torch.all((quantize_coordinates(x, self.config.num_tokens - 2).reshape(-1, 9) == quantize_coordinates(y, self.config.num_tokens - 2).reshape(-1, 9)), dim=-1).sum() / (x.shape[0]) 246 | 247 | 248 | def distribute_features(features, face_indices, num_vertices, device): 249 | # N = num triangles 250 | # features is N3 x 192 251 | # face_indices is N x 3 252 | assert features.shape[0] == face_indices.shape[0] * face_indices.shape[1], "Features and face indices must match in size" 253 | vertex_features = torch.zeros([num_vertices, features.shape[1]], device=device) 254 | torch_scatter.scatter_mean(features, face_indices.reshape(-1), out=vertex_features, dim=0) 255 | distributed_features = vertex_features[face_indices.reshape(-1), :] 256 | return distributed_features 257 | 258 | 259 | def dummy_distribute(features, _face_indices, _n, _device): 260 | 
return features 261 | 262 | 263 | @hydra.main(config_path='../config', config_name='meshgpt', version_base='1.2') 264 | def main(config): 265 | trainer = create_trainer("TriangleTokens", config) 266 | model = TriangleTokenizationGraphConv(config) 267 | trainer.fit(model, ckpt_path=config.resume) 268 | 269 | 270 | if __name__ == '__main__': 271 | main() 272 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/audi/MeshGPT/0c871aa4acad3c316404630601c73aea5338082b/util/__init__.py -------------------------------------------------------------------------------- /util/filesystem_logger.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from pathlib import Path 4 | from typing import Dict, Optional, Union 5 | 6 | from omegaconf import OmegaConf 7 | from pytorch_lightning.loggers.logger import Logger 8 | from lightning_fabric.loggers.logger import rank_zero_experiment 9 | from lightning_fabric.loggers.logger import _DummyExperiment 10 | 11 | 12 | class FilesystemLogger(Logger): 13 | 14 | @property 15 | def version(self) -> Union[int, str]: 16 | return 0 17 | 18 | @property 19 | def name(self) -> str: 20 | return "fslogger" 21 | 22 | # noinspection PyMethodOverriding 23 | def log_hyperparams(self, params: argparse.Namespace): 24 | pass 25 | 26 | def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): 27 | pass 28 | 29 | def __init__(self, experiment_config, **_kwargs): 30 | super().__init__() 31 | self.experiment_config = experiment_config 32 | self._experiment = None 33 | # noinspection PyStatementEffect 34 | self.experiment 35 | 36 | @property 37 | @rank_zero_experiment 38 | def experiment(self): 39 | if self._experiment is None: 40 | self._experiment = _DummyExperiment() 41 | experiment_dir = Path("runs", self.experiment_config["experiment"]) 42 | experiment_dir.mkdir(exist_ok=True, parents=True) 43 | 44 | src_folders = ['config', 'model', 'tests', 'trainer', 'util', 'visualizer', 'dataset'] 45 | sources = [] 46 | for src in src_folders: 47 | sources.extend(list(Path(".").glob(f'{src}/**/*'))) 48 | 49 | files_to_copy = [x for x in sources if x.suffix in [".py", ".pyx", ".txt", ".so", ".pyd", ".h", ".cu", ".c", '.cpp', ".html"] and x.parts[0] != "runs" and x.parts[0] != "wandb"] 50 | 51 | for f in files_to_copy: 52 | Path(experiment_dir, "code", f).parents[0].mkdir(parents=True, exist_ok=True) 53 | shutil.copyfile(f, Path(experiment_dir, "code", f)) 54 | 55 | Path(experiment_dir, "config.yaml").write_text(OmegaConf.to_yaml(self.experiment_config)) 56 | 57 | return self._experiment -------------------------------------------------------------------------------- /util/meshlab.py: -------------------------------------------------------------------------------- 1 | import pymeshlab 2 | 3 | 4 | def meshlab_proc(meshpath): 5 | ms = pymeshlab.MeshSet() 6 | ms.load_new_mesh(str(meshpath)) 7 | ms.meshing_merge_close_vertices(threshold=pymeshlab.Percentage(1)) 8 | ms.meshing_remove_duplicate_faces() 9 | ms.meshing_remove_null_faces() 10 | ms.meshing_remove_duplicate_vertices() 11 | ms.meshing_remove_unreferenced_vertices() 12 | ms.save_current_mesh(str(meshpath), save_vertex_color=False, save_vertex_coord=False, save_face_color=False, save_wedge_texcoord=False) 13 | ms.clear() 14 | 
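The meshlab_proc helper above rewrites a mesh file in place after merging near-duplicate vertices and dropping duplicate, null, or unreferenced geometry. Below is a minimal usage sketch (not part of the repository): the mesh path is a hypothetical placeholder, and pymeshlab is assumed to be installed per requirements.txt.

from pathlib import Path

from util.meshlab import meshlab_proc

# Hypothetical input: any mesh file readable by MeshLab (e.g. an .obj export).
mesh_path = Path("data/raw/example_chair.obj")

# Cleans the file in place: merges close vertices, removes duplicate/null faces,
# duplicate vertices, and unreferenced vertices, then saves without color data.
meshlab_proc(mesh_path)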
-------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from collections import OrderedDict 4 | 5 | 6 | def get_parameters_from_state_dict(state_dict, filter_key): 7 | new_state_dict = OrderedDict() 8 | for k in state_dict: 9 | if k.startswith(filter_key): 10 | new_state_dict[k.replace(filter_key + '.', '')] = state_dict[k] 11 | return new_state_dict 12 | 13 | 14 | def accuracy(y_pred, y_true, ignore_label=None, device=None): 15 | y_pred = y_pred.argmax(dim=-1) 16 | 17 | if ignore_label: 18 | normalizer = torch.sum(y_true != ignore_label) # type: ignore 19 | ignore_mask = torch.where( # type: ignore 20 | y_true == ignore_label, 21 | torch.zeros_like(y_true, device=device), 22 | torch.ones_like(y_true, device=device) 23 | ).type(torch.float32) 24 | else: 25 | normalizer = y_true.shape[0] 26 | ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32) 27 | acc = (y_pred.reshape(-1) == y_true.reshape(-1)).type(torch.float32) # type: ignore 28 | acc = torch.sum(acc*ignore_mask.flatten()) 29 | return acc / normalizer 30 | 31 | 32 | def rmse(y_pred, y_true, num_tokens, ignore_labels=(0, 1, 2)): 33 | mask = torch.logical_and(y_true != ignore_labels[0], y_pred != ignore_labels[0]) 34 | for i in range(1, len(ignore_labels)): 35 | mask = torch.logical_and(mask, y_true != ignore_labels[i]) 36 | mask = torch.logical_and(mask, y_pred != ignore_labels[i]) 37 | y_pred = y_pred[mask] 38 | y_true = y_true[mask] 39 | vertices_pred = (y_pred - 3) / num_tokens - 0.5 40 | vertices_true = (y_true - 3) / num_tokens - 0.5 41 | return torch.sqrt(torch.mean((vertices_pred - vertices_true)**2)) 42 | 43 | 44 | def get_create_shapenet_train_val(path): 45 | from collections import Counter 46 | import random 47 | 48 | all_shapes = [x.stem for x in list(path.iterdir())] 49 | all_categories = sorted(list(set(list(x.split("_")[0] for x in all_shapes)))) 50 | counts_all = Counter() 51 | for s in all_shapes: 52 | counts_all[s.split("_")[0]] += 1 53 | validation = random.sample(all_shapes, int(len(all_shapes) * 0.05)) 54 | counts_val = Counter() 55 | for s in validation: 56 | counts_val[s.split("_")[0]] += 1 57 | for c in all_categories: 58 | print(c, f"{counts_all[c] / len(all_shapes) * 100:.2f}", f"{counts_val[c] / len(validation) * 100:.2f}") 59 | train = [x for x in all_shapes if x not in validation] 60 | Path("val.txt").write_text("\n".join(validation)) 61 | Path("train.txt").write_text("\n".join(train)) 62 | 63 | 64 | def scale_vertices(vertices, x_lims=(0.75, 1.25), y_lims=(0.75, 1.25), z_lims=(0.75, 1.25)): 65 | # scale x, y, z 66 | x = np.random.uniform(low=x_lims[0], high=x_lims[1], size=(1,)) 67 | y = np.random.uniform(low=y_lims[0], high=y_lims[1], size=(1,)) 68 | z = np.random.uniform(low=z_lims[0], high=z_lims[1], size=(1,)) 69 | vertices = np.stack([vertices[:, 0] * x, vertices[:, 1] * y, vertices[:, 2] * z], axis=-1) 70 | return vertices 71 | 72 | 73 | def shift_vertices(vertices, x_lims=(-0.1, 0.1), y_lims=(-0.1, 0.1), z_lims=(-0.075, 0.075)): 74 | # shift x, y, z 75 | x = np.random.uniform(low=x_lims[0], high=x_lims[1], size=(1,)) 76 | y = np.random.uniform(low=y_lims[0], high=y_lims[1], size=(1,)) 77 | z = np.random.uniform(low=z_lims[0], high=z_lims[1], size=(1,)) 78 | x = max(min(x, 0.5 - vertices[:, 0].max()), -0.5 - vertices[:, 0].min()) 79 | y = max(min(y, 0.5 - vertices[:, 1].max()), -0.5 - vertices[:, 1].min()) 80 | z = 
max(min(z, 0.5 - vertices[:, 2].max()), -0.5 - vertices[:, 2].min()) 81 | vertices = np.stack([vertices[:, 0] + x, vertices[:, 1] + y, vertices[:, 2] + z], axis=-1) 82 | return vertices 83 | 84 | 85 | def normalize_vertices(vertices): 86 | bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) # type: ignore 87 | vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 88 | vertices = vertices / (bounds[1] - bounds[0]).max() 89 | return vertices 90 | 91 | 92 | def top_p_sampling(logits, p): 93 | probs = torch.softmax(logits, dim=-1) 94 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 95 | probs_sum = torch.cumsum(probs_sort, dim=-1) 96 | mask = probs_sum - probs_sort > p 97 | probs_sort[mask] = 0.0 98 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 99 | next_token = torch.multinomial(probs_sort, num_samples=1) 100 | next_token = torch.gather(probs_idx, -1, next_token) 101 | return next_token 102 | 103 | 104 | if __name__ == "__main__": 105 | from pathlib import Path 106 | # get_create_shapenet_train_val(Path("/cluster/gimli/ysiddiqui/ShapeNetCore.v2.meshlab/")) 107 | logits_ = torch.FloatTensor([[3, -1, 0.5, 0.1], 108 | [0, 9, 0.5, 0.1], 109 | [0, 0, 5, 0.1]]) 110 | for i in range(5): 111 | print(top_p_sampling(logits_, 0.9)) 112 | -------------------------------------------------------------------------------- /util/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Embedder: 6 | def __init__(self, **kwargs): 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | embed_fns = [] 12 | d = self.kwargs['input_dims'] 13 | out_dim = 0 14 | if self.kwargs['include_input']: 15 | embed_fns.append(lambda x: x) 16 | out_dim += d 17 | 18 | max_freq = self.kwargs['max_freq_log2'] 19 | N_freqs = self.kwargs['num_freqs'] 20 | 21 | if self.kwargs['log_sampling']: 22 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 23 | else: 24 | freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq, steps=N_freqs) 25 | 26 | for freq in freq_bands: 27 | for p_fn in self.kwargs['periodic_fns']: 28 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 29 | out_dim += d 30 | 31 | self.embed_fns = embed_fns 32 | self.out_dim = out_dim 33 | 34 | def embed(self, inputs): 35 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 36 | 37 | 38 | def get_embedder(multires, i=0): 39 | if i == -1: 40 | return nn.Identity(), 3 41 | 42 | embed_kwargs = { 43 | 'include_input': True, 44 | 'input_dims': 3, 45 | 'max_freq_log2': multires - 1, 46 | 'num_freqs': multires, 47 | 'log_sampling': True, 48 | 'periodic_fns': [torch.sin, torch.cos], 49 | } 50 | 51 | embedder_obj = Embedder(**embed_kwargs) 52 | embed = lambda x, eo=embedder_obj: eo.embed(x) 53 | return embed, embedder_obj.out_dim 54 | -------------------------------------------------------------------------------- /util/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from pathlib import Path 4 | from PIL import Image 5 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip 6 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 7 | 8 | from dataset import newface_token, stopface_token, padface_token 9 | 10 | 11 | def visualize_points(points, vis_path, colors=None): 12 | if colors is None: 13 | Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} 127 127 127" for p in points)) 14 | else: 15 | Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}" for i, p in enumerate(points))) 16 | 17 | 18 | def tokens_to_vertices(token_sequence, num_tokens): 19 | try: 20 | end = token_sequence.index(num_tokens + 1) 21 | except ValueError: 22 | end = len(token_sequence) 23 | token_sequence = token_sequence[:end] 24 | token_sequence = token_sequence[:(len(token_sequence) // 3) * 3] 25 | vertices = (np.array(token_sequence).reshape(-1, 3)) / num_tokens - 0.5 26 | # order: Z, Y, X --> X, Y, Z 27 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1) 28 | return vertices 29 | 30 | 31 | def visualize_quantized_mesh_vertices(token_sequence, num_tokens, output_path): 32 | vertices = tokens_to_vertices(token_sequence, num_tokens) 33 | plot_vertices(vertices, output_path) 34 | 35 | 36 | def visualize_quantized_mesh_vertices_and_faces(token_sequence_vertex, token_sequence_face, num_tokens, output_path): 37 | vertices, faces = tokens_to_mesh(token_sequence_vertex, token_sequence_face, num_tokens) 38 | plot_vertices_and_faces(vertices, faces, output_path) 39 | 40 | 41 | def plot_vertices(vertices, output_path): 42 | fig = plt.figure(figsize=(4, 4)) 43 | ax = fig.add_subplot(111, projection="3d") 44 | plt.xlim(-0.35, 0.35) 45 | plt.ylim(-0.35, 0.35) 46 | # Don't mess with the limits! 47 | plt.autoscale(False) 48 | ax.set_axis_off() 49 | ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], c='g', s=10) 50 | ax.set_zlim(-0.35, 0.35) 51 | ax.view_init(25, -120, 0) 52 | plt.tight_layout() 53 | plt.savefig(output_path) 54 | plt.close("all") 55 | 56 | 57 | def plot_vertices_and_faces(vertices, faces, output_path): 58 | ngons = [[vertices[v, :].tolist() for v in f] for f in faces] 59 | fig = plt.figure(figsize=(8, 8)) 60 | ax = fig.add_subplot(111, projection="3d") 61 | plt.xlim(-0.45, 0.45) 62 | plt.ylim(-0.45, 0.45) 63 | # Don't mess with the limits! 
64 | plt.autoscale(False) 65 | ax.set_axis_off() 66 | ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], c='black', s=10) 67 | polygon_collection = Poly3DCollection(ngons) 68 | polygon_collection.set_alpha(0.3) 69 | polygon_collection.set_color('b') 70 | ax.add_collection(polygon_collection) 71 | ax.set_zlim(-0.35, 0.35) 72 | ax.view_init(25, -120, 0) 73 | plt.tight_layout() 74 | plt.savefig(output_path) 75 | plt.close("all") 76 | 77 | 78 | def visualize_quantized_mesh_vertices_gif(token_sequence, num_tokens, output_dir): 79 | vertices = tokens_to_vertices(token_sequence, num_tokens) 80 | visualize_mesh_vertices_gif(vertices, output_dir) 81 | 82 | 83 | def visualize_mesh_vertices_gif(vertices, output_dir): 84 | for i in range(1, len(vertices), 1): 85 | fig = plt.figure(figsize=(4, 4)) 86 | ax = fig.add_subplot(111, projection="3d") 87 | plt.xlim(-0.35, 0.35) 88 | plt.ylim(-0.35, 0.35) 89 | # Don't mess with the limits! 90 | plt.autoscale(False) 91 | ax.set_axis_off() 92 | ax.scatter(vertices[:i, 0], vertices[:i, 1], vertices[:i, 2], c='g', s=10) 93 | ax.set_zlim(-0.35, 0.35) 94 | ax.view_init(25, -120, 0) 95 | plt.tight_layout() 96 | plt.savefig(output_dir / f"{i:05d}.png") 97 | plt.close("all") 98 | create_gif(output_dir, 40, output_dir / "vis.gif") 99 | 100 | 101 | def visualize_quantized_mesh_vertices_and_faces_gif(token_sequence_vertex, token_sequence_face, num_tokens, output_dir): 102 | visualize_quantized_mesh_vertices_gif(token_sequence_vertex, num_tokens, output_dir) 103 | vertices, faces = tokens_to_mesh(token_sequence_vertex, token_sequence_face, num_tokens) 104 | visualize_mesh_vertices_and_faces_gif(vertices, faces, output_dir) 105 | 106 | 107 | def visualize_mesh_vertices_and_faces_gif(vertices, faces, output_dir): 108 | ngons = [[vertices[v, :].tolist() for v in f] for f in faces] 109 | for i in range(1, len(ngons) + 1, 1): 110 | fig = plt.figure(figsize=(9, 9)) 111 | ax = fig.add_subplot(111, projection="3d") 112 | plt.xlim(-0.35, 0.35) 113 | plt.ylim(-0.35, 0.35) 114 | # Don't mess with the limits! 
115 | plt.autoscale(False) 116 | ax.set_axis_off() 117 | ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], c='black', s=10) 118 | polygon_collection = Poly3DCollection(ngons[:i]) 119 | polygon_collection.set_alpha(0.3) 120 | polygon_collection.set_color('b') 121 | ax.add_collection(polygon_collection) 122 | ax.set_zlim(-0.35, 0.35) 123 | ax.view_init(25, -120, 0) 124 | plt.tight_layout() 125 | plt.savefig(output_dir / f"{len(vertices) + i:05d}.png") 126 | plt.close("all") 127 | create_gif(output_dir, 40, output_dir / "vis.gif") 128 | 129 | 130 | def create_gif(folder, fps, output_path): 131 | collection_rgb = [] 132 | for f in sorted([x for x in folder.iterdir() if x.suffix == ".png" or x.suffix == ".jpg"]): 133 | img_rgb = np.array(Image.open(f).resize((384, 384))) 134 | collection_rgb.append(img_rgb) 135 | clip = ImageSequenceClip(collection_rgb, fps=fps) 136 | clip.write_gif(output_path, verbose=False, logger=None) 137 | 138 | 139 | def tokens_to_mesh(vertices_q, face_sequence, num_tokens): 140 | vertices = (np.array(vertices_q).reshape(-1, 3)) / num_tokens - 0.5 141 | # order: Z, Y, X --> X, Y, Z 142 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1) 143 | try: 144 | end = face_sequence.index(stopface_token) 145 | except ValueError: 146 | end = len(face_sequence) 147 | face_sequence = face_sequence[:end] 148 | face_sequence = [x for x in face_sequence if x != 2] # remove padding 149 | faces = [] 150 | current_face = [] 151 | for i in range(len(face_sequence)): 152 | if face_sequence[i] == newface_token: 153 | if len(current_face) > 2: 154 | faces.append(current_face) 155 | current_face = [] 156 | else: 157 | current_face.append(face_sequence[i] - 3) 158 | if len(current_face) != 0: 159 | faces.append(current_face) 160 | return vertices, faces 161 | 162 | 163 | def ngon_to_obj(vertices, faces): 164 | obj = "" 165 | for i in range(len(vertices)): 166 | obj += f"v {vertices[i, 0]} {vertices[i, 1]} {vertices[i, 2]}\n" 167 | for i in range(len(faces)): 168 | fline = "f" 169 | for j in range(len(faces[i])): 170 | fline += f" {faces[i][j] + 1} " 171 | fline += "\n" 172 | obj += fline 173 | return obj 174 | 175 | 176 | def trisoup_sequence_to_mesh(soup_sequence, num_tokens): 177 | try: 178 | end = soup_sequence.index(stopface_token) 179 | except ValueError: 180 | end = len(soup_sequence) 181 | soup_sequence = soup_sequence[:end] 182 | vertices_q = [] 183 | current_subsequence = [] 184 | for i in range(len(soup_sequence)): 185 | if soup_sequence[i] == newface_token: 186 | if len(current_subsequence) >= 9: 187 | current_subsequence = current_subsequence[:9] 188 | vertices_q.append(np.array(current_subsequence).reshape(3, 3)) 189 | current_subsequence = [] 190 | elif soup_sequence[i] != padface_token: 191 | current_subsequence.append(soup_sequence[i] - 3) 192 | if len(current_subsequence) >= 9: 193 | current_subsequence = current_subsequence[:9] 194 | vertices_q.append(np.array(current_subsequence).reshape(3, 3)) 195 | vertices = (np.array(vertices_q).reshape(-1, 3)) / num_tokens - 0.5 196 | # order: Z, Y, X --> X, Y, Z 197 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1) 198 | faces = np.array(list(range(len(vertices_q) * 3)), dtype=np.int32).reshape(-1, 3) 199 | return vertices, faces 200 | 201 | 202 | def ngonsoup_sequence_to_mesh(soup_sequence, num_tokens): 203 | try: 204 | end = soup_sequence.index(stopface_token) 205 | except ValueError: 206 | end = len(soup_sequence) 207 | soup_sequence = soup_sequence[:end] 208 | 
vertices_q = [] 209 | face_ctr = 0 210 | faces = [] 211 | current_subsequence = [] 212 | for i in range(len(soup_sequence)): 213 | if soup_sequence[i] == newface_token: 214 | current_subsequence = current_subsequence[:len(current_subsequence) // 3 * 3] 215 | if len(current_subsequence) > 0: 216 | vertices_q.append(np.array(current_subsequence).reshape(-1, 3)) 217 | faces.append([x for x in range(face_ctr, face_ctr + len(current_subsequence) // 3)]) 218 | face_ctr += (len(current_subsequence) // 3) 219 | current_subsequence = [] 220 | elif soup_sequence[i] != padface_token: 221 | current_subsequence.append(soup_sequence[i] - 3) 222 | 223 | current_subsequence = current_subsequence[:len(current_subsequence) // 3 * 3] 224 | if len(current_subsequence) > 0: 225 | vertices_q.append(np.array(current_subsequence).reshape(-1, 3)) 226 | faces.append([x for x in range(face_ctr, face_ctr + len(current_subsequence) // 3)]) 227 | face_ctr += (len(current_subsequence) // 3) 228 | 229 | vertices = np.vstack(vertices_q) / num_tokens - 0.5 230 | # order: Z, Y, X --> X, Y, Z 231 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1) 232 | return vertices, faces 233 | 234 | 235 | def triangle_sequence_to_mesh(triangles): 236 | vertices = triangles.reshape(-1, 3) 237 | faces = np.array(list(range(vertices.shape[0]))).reshape(-1, 3) 238 | return vertices, faces 239 | --------------------------------------------------------------------------------
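A minimal usage sketch (not part of the repository) tying together the helpers in util/visualization.py: triangle_sequence_to_mesh turns a (num_triangles, 3, 3) coordinate array, such as the decoded output visualized in on_validation_epoch_end, into a vertex array with one independent face per triangle, and plot_vertices_and_faces renders it to an image. The random coordinate array and output directory below are hypothetical placeholders.

import numpy as np
from pathlib import Path

from util.visualization import plot_vertices_and_faces, triangle_sequence_to_mesh

# Hypothetical decoded triangle soup: 32 triangles with coordinates in [-0.5, 0.5].
coords = np.random.uniform(-0.5, 0.5, size=(32, 3, 3))

# vertices has shape (96, 3); faces has shape (32, 3), one face per triangle.
vertices, faces = triangle_sequence_to_mesh(coords)

out_dir = Path("runs/example/image_val")  # hypothetical output directory
out_dir.mkdir(parents=True, exist_ok=True)
plot_vertices_and_faces(vertices, faces, out_dir / "sample.jpg")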