├── LICENSE
├── README.md
├── config
│   └── meshgpt.yaml
├── dataset
│   ├── __init__.py
│   ├── quantized_soup.py
│   └── triangles.py
├── inference
│   └── infer_meshgpt.py
├── model
│   ├── LICENSE
│   ├── README.md
│   ├── decoder.py
│   ├── encoder.py
│   ├── nanogpt.py
│   ├── pointnet.py
│   ├── softargmax.py
│   ├── transformer.py
│   └── transformer_base.py
├── requirements.txt
├── trainer
│   ├── __init__.py
│   ├── train_transformer.py
│   └── train_vocabulary.py
└── util
    ├── __init__.py
    ├── filesystem_logger.py
    ├── meshlab.py
    ├── misc.py
    ├── positional_encoding.py
    └── visualization.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Automotive Development Public Non-Commercial License Version 1.0
2 | (Based on Mozilla Public License Version 2.0 with additions regarding modifications)
3 |
4 | 1. Definitions
5 |
6 | 1.1 "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software.
7 |
8 | 1.2 "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution.
9 |
10 | 1.3 "Contribution" means Covered Software of a particular Contributor.
11 |
12 | 1.4 "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof.
13 |
14 | 1.5 "Executable Form" means any form of the work other than Source Code Form.
15 |
16 | 1.6 "Larger Work" means a work that combines Covered Software statically or dynamically with code not governed by the terms of this license.
17 |
18 | 1.7 "License" means this document.
19 |
20 | 1.8 "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License.
21 |
22 | 1.9 "Modifications" means any of the following:
23 |
24 | (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or
25 |
26 | (b) any new file in Source Code Form that contains any Covered Software.
27 | 1.10 "Non-Commercial" means not primarily intended for or directed towards commercial advantage or monetary compensation.
28 |
29 | 1.11 "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version.
30 |
31 | 1.12 "Source Code Form" means the form of the work preferred for making modifications.
32 |
33 | 1.13 "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity.
34 |
35 | 2. License Grants and Conditions
36 |
37 | 2.1 Grants
38 | Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license:
39 |
40 | (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions for Non-Commercial purposes only, either on an unmodified basis, with Modifications, or as part of a Larger Work; and
41 |
42 | (b) under Patent Claims of such Contributor to make, use, have made, import, and otherwise transfer either its Contributions or its Contributor Version for Non-Commercial purposes only.
43 |
44 | 2.2 Effective Date
45 | The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution.
46 |
47 | 2.3 Limitations on Grant Scope
48 | The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor:
49 |
50 | (a) for any code that a Contributor has removed from Covered Software; or
51 |
52 | (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or
53 |
54 | (c) under Patent Claims infringed by Covered Software in the absence of its Contributions.
55 |
56 | This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.5).
57 |
58 | 2.4 Subsequent Licenses
59 | No Contributor makes additional grants as result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of any other license of Your choice (if permitted under the terms of Section 3.3).
60 |
61 | 2.5 Representation
62 | Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License.
63 |
64 | 2.6 Fair Use
65 | This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents.
66 |
67 | 2.7 Conditions
68 | Sections 3.1, 3.2, 3.3, 3.4 and 3.5 are conditions of the licenses granted in Section 2.1.
69 |
70 | 3. Responsibilities
71 |
72 | 3.1 Distribution of Source Form
73 | All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License and include a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form.
74 |
75 | 3.2 Distribution of Executable Form
76 | If You distribute Covered Software in Executable Form then:
77 |
78 | (a) such Covered Software must also be made available as described in Section 3.1 in Source Code Form including any Modifications that You create or to which You contribute and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and
79 |
80 | (b) You may distribute such Executable Form under the terms of this License and include this license, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. In any case, you must inform the recipients about the Initial Contributor providing the Covered Software under the terms of this License (we recommend the use of a notice like that in Exhibit A).
81 |
82 | 3.3 Distribution of a Larger Work
83 | You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software.
84 | 3.4 Modifications
85 | (a) You may modify Covered Software, if you cause all Modifications to be supplemented by a file documenting the differences between the Covered Software and the Modifications and the date of the creation of any Modification. You must include a prominent statement that any Modification is derived, directly or indirectly, from the Covered Software provided by the Contributor and including the name of the Contributor in (i) the Source Code Form, and (ii) in a notice according to Exhibit A in an Executable Form or related documentation in which You describe the origin or ownership of the Covered Software.
86 |
87 | (b) You are obliged to provide this documentation supplementing your Modifications and the Covered Software including your Modifications to the Contributor of the Covered Software as specified in the notices according to Exhibit A irrespective of any distribution of the Covered Software. You may provide this information and code directly to the Contributor via email or another accepted electronic distribution mechanism.
88 |
89 | 3.5 Notices
90 | You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies.
91 |
92 | 4. Inability to Comply Due to Statute or Regulation
93 |
94 | If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it.
95 |
96 | 5. Termination
97 |
98 | 5.1 The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated, unless and until such Contributor explicitly and finally terminates Your grants. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice and inform the Contributor of Your compliance.
99 |
100 | 5.2 If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate.
101 |
102 | 6. Disclaimer of Warranty
103 |
104 | Covered Software is provided under this License basically free of charge in a development status and on an "as is" basis, without warranty of any kind, either expressed, implied, or statutory, including, without limitation, warranties that the Covered Software is free of defects, merchantable, fit for a particular purpose or non-infringing. The entire risk as to the interaction and performance of the Covered Software is with You. This disclaimer of warranty constitutes an essential part of this License. No use of any Covered Software is authorized under this License except under this disclaimer.
105 |
106 | 7. Limitation of Liability
107 |
108 | Any Contributor shall only be liable for damages other than those resulting from the detriment to life, body and health to the extent such damages arise from willful misconduct, gross negligence or the culpable violation of a fundamental contractual obligation on the part of the Contributor or its vicarious agents. Any further liability for damages shall be excluded, especially liability for the loss of data and the recovery of this data if this loss could have been avoided by You through appropriate precautionary measures, in particular by creating daily backups of all data. The provisions of the German Product Liability Act and other mandatory legal statutes shall remain unaffected.
109 |
110 | 8. Litigation
111 |
112 | Any litigation relating to this License may be brought only in courts of the German jurisdiction and such litigation shall be governed by German law. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims.
113 |
114 | 9. Miscellaneous
115 |
116 | This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor.
117 |
118 | 10. Versions of the License
119 |
120 | 10.1 New Versions
121 | Audi AG is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number.
122 |
123 | 10.2 Effect of New Versions
124 | You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward.
125 |
126 | 10.3 Modified Versions
127 | If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License).
128 |
129 |
130 | Exhibit A - Source Code Form License Notice
131 |
132 | Copyright (c) 2024, Audi AG, ai-licenses@audi.de. All rights reserved.
133 |
134 | This Source Code Form is subject to the terms of the Automotive Development Public Non-Commercial License, v. 1.0. If a copy of the ADPNCL was not distributed with this file, You can obtain one from the distributor.
135 |
136 | If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice.
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers
2 |
3 |
4 |
5 | [**arXiv**](http://arxiv.org/abs/2311.15475) | [**Video**](https://www.youtube.com/watch?v=UV90O1_69_o) | [**Project Page**](https://nihalsid.github.io/mesh-gpt/)
6 |
7 |
8 | This repository contains the implementation for the paper:
9 |
10 | [**MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers**](http://arxiv.org/abs/2311.15475) by Yawar Siddiqui, Antonio Alliegro, Alexey Artemov, Tatiana Tommasi, Daniele Sirigatti, Vladislav Rosov, Angela Dai, Matthias Nießner.
11 |
12 |
13 |
14 |
15 |
16 |
17 | MeshGPT creates triangle meshes by autoregressively sampling from a transformer model that has been trained to produce tokens from a learned geometric vocabulary. These tokens can then be decoded into the faces of a triangle mesh. Our method generates clean, coherent, and compact meshes, characterized by sharp edges and high fidelity.
18 |
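As a rough, illustrative sketch of this pipeline, the loop below shows what such autoregressive sampling could look like. The `transformer.sample_next_token` and `sequence_to_mesh` interfaces are hypothetical placeholders, not the actual API of this repository; the real implementation lives in `inference/infer_meshgpt.py`, `model/transformer.py` and `dataset/quantized_soup.py`.

```python
import torch

# Special tokens used throughout the codebase (see dataset/__init__.py):
# 0 = start / new face, 1 = stop, 2 = padding; codebook indices are shifted by +3.
START_TOKEN, STOP_TOKEN = 0, 1

@torch.no_grad()
def generate_mesh(transformer, sequence_to_mesh, max_tokens=5000, device="cuda"):
    """Illustrative only: `transformer.sample_next_token` and `sequence_to_mesh`
    are assumed interfaces, not functions defined in this repository."""
    tokens = torch.tensor([[START_TOKEN]], device=device)  # (1, T) running token sequence
    for _ in range(max_tokens):
        next_token = transformer.sample_next_token(tokens)  # assumed to return a (1, 1) tensor
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == STOP_TOKEN:
            break
    # Decode the sampled vocabulary tokens back into triangles, analogous to
    # QuantizedSoupTripletsCreator.decode in dataset/quantized_soup.py.
    vertices, faces = sequence_to_mesh(tokens.squeeze(0))
    return vertices, faces
```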
19 |
20 |
21 | ## Dependencies
22 |
23 | Install requirements from the project root directory:
24 |
25 | ```bash
26 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
27 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
28 | pip install packaging
29 | pip install -r requirements.txt
30 | ```
31 | If errors about missing packages show up, install them manually.
32 |
33 | ## Structure
34 |
35 | The overall code structure is as follows:
36 |
37 | | Folder | Description |
38 | |------------------------|----------------------------------------------------------------------------------------------|
39 | | `config/` | hydra configs |
40 | | `data/`                 | processed dataset                                                                              |
41 | | `dataset/` | pytorch datasets and dataloaders |
42 | | `docs/` | project webpage files |
43 | | `inference/` | scripts for inferencing trained model |
44 | | `model/` | pytorch modules for encoder, decoder and the transformer |
45 | | `pretrained/` | pretrained models on shapenet chairs and tables |
46 | | `runs/` | model training logs and checkpoints go here in addition to wandb |
47 | | `trainer/` | pytorch-lightning module for training |
48 | | `util/` | misc utilities for positional encoding, visualization, logging etc. |
49 |
50 | ## Pre-trained Models and Data
51 |
52 | Download the pretrained models and the data from [here](https://drive.google.com/drive/folders/1Gzuxn6c1pguvRWrsedmCa9xKtest8aC2?usp=drive_link). Place them in the project such that the trained models end up in the `pretrained/` directory and the data in the `data/shapenet` directory.
53 |
54 | ### Running inference
55 |
56 | To run inference, use the following command:
57 |
58 | ```bash
59 | python inference/infer_meshgpt.py
60 | ```
61 |
62 | Examples:
63 |
64 | ```bash
65 | # for chairs
66 | python inference/infer_meshgpt.py pretrained/transformer_ft_03001627/checkpoints/2287-0.ckpt beam 25
67 |
68 | # for tables
69 | python inference/infer_meshgpt.py pretrained/transformer_ft_04379243/checkpoints/1607-0.ckpt beam 25
70 | ```
71 |
72 | ## Training
73 |
74 | To launch training, use the following commands from the project root:
75 |
76 | ```bash
77 | # vocabulary
78 | python trainer/train_vocabulary.py vq_resume=
79 |
80 | # transformer
81 | python trainer/train_transformer.py vq_resume= ft_category= ft_resume=
82 | ```
83 |
84 | Some example trainings:
85 |
86 | #### Vocabulary training
87 | ```bash
88 | python trainer/train_vocabulary.py batch_size=32 shift_augment=True scale_augment=True wandb_main=True experiment=vq128 val_check_percent=1.0 val_check_interval=5 overfit=False max_epoch=2000 only_chairs=False use_smoothed_loss=True graph_conv=sage use_point_feats=False num_workers=24 n_embed=16384 num_tokens=131 embed_levels=2 num_val_samples=16 use_multimodal_loss=True weight_decay=0.1 embed_dim=192 code_decay=0.99 embed_share=True distribute_features=True
89 | ```
90 | #### Base transformer training
91 | ```bash
92 |
93 | # run over multiple GPUs (>= 8 GPUs recommended); with a larger compute budget, you can use higher gradient_accumulation_steps
94 |
95 | python trainer/train_transformer.py wandb_main=True batch_size=8 gradient_accumulation_steps=8 max_val_tokens=5000 max_epoch=2000 sanity_steps=0 val_check_interval=1 val_check_percent=1 block_size=4608 model.n_layer=24 model.n_head=16 model.n_embd=768 model.dropout=0 scale_augment=True shift_augment=True num_workers=24 experiment=bl4608-GPT2_m24-16-768-0_b8x8x8_lr1e-4 use_smoothed_loss=True num_tokens=131 vq_resume= padding=0
96 | ```
97 | #### Transformer finetuning
98 | ```bash
99 |
100 | # run over multiple GPUs (>= 8 GPUs recommended); with a larger compute budget, you can use higher gradient_accumulation_steps
101 |
102 | python trainer/train_transformer.py wandb_main=True batch_size=8 gradient_accumulation_steps=8 max_val_tokens=5000 max_epoch=2400 sanity_steps=0 val_check_interval=8 val_check_percent=1 block_size=4608 model.n_layer=24 model.n_head=16 model.n_embd=768 model.dropout=0 scale_augment=True shift_augment=True num_workers=24 experiment=bl4608-GPT2_m24-16-768-0_b8x8x8_FT04379243 use_smoothed_loss=True num_tokens=131 vq_resume= padding=0 num_val_samples=4 ft_category=04379243 ft_resume= warmup_steps=100
103 |
104 | ```
105 |
106 | ## License
107 | MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers by Mohd Yawar Nihal Siddiqui is licensed under the [Automotive Development Public Non-Commercial License Version 1.0](LICENSE); however, portions of the project are available under separate license terms, e.g. the NanoGPT code is under the MIT license.
108 |
109 | ## Citation
110 |
111 | If you wish to cite us, please use the following BibTeX entry:
112 |
113 | ```BibTeX
114 | @InProceedings{siddiqui_meshgpt_2024,
115 | title={MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers},
116 | author={Siddiqui, Yawar and Alliegro, Antonio and Artemov, Alexey and Tommasi, Tatiana and Sirigatti, Daniele and Rosov, Vladislav and Dai, Angela and Nie{\ss}ner, Matthias},
117 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
118 | year = {2024},
119 | }
120 |
121 | ```
--------------------------------------------------------------------------------
/config/meshgpt.yaml:
--------------------------------------------------------------------------------
1 | # data
2 | dataset: 'shapenet'
3 | dataset_root: 'data/shapenet/processed_data.pkl'
4 | # not actually the number of tokens, but the quantization resolution + 3
5 | num_tokens: 131
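# e.g. 131 = 128-level quantization grid + 3 special tokens (new-face / stop / pad, see dataset/__init__.py)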
6 |
7 | gradient_accumulation_steps: 2 # used to simulate larger batch sizes
8 | batch_size: 32 # if gradient_accumulation_steps > 1, this is the micro-batch size
9 | block_size: 4608 # context size of the transformer; for good performance, block_size should be larger than the max sequence length
10 | padding: 0.0 # fraction of padding allowed when sequences exceed the transformer context size
11 |
12 | scale_augment: True # augment shapes by random scaling
13 | scale_augment_val: False # augment val shapes by random scaling
14 |
15 | shift_augment: True # augment shapes by shifting them in space
16 | shift_augment_val: False
17 |
18 | wandb_main: False # set to true to log to main board rather than debug board
19 | suffix: '' # suffix for the wandb project name; if wandb_main is false, it is auto-populated so that debug experiments are dumped to a different project
20 | experiment: snet # experiment name
21 | seed: null
22 | save_epoch: 1 # save every n epochs
23 | sanity_steps: 1 # sanity steps before the run
24 | val_check_percent: 1.0 # check this proportion of val set when evaluation runs
25 | val_check_interval: 1 # run evaluation every x% of the train set steps
26 | resume: null # resume from a checkpoint
27 | num_workers: 24
28 | logger: wandb
29 | overfit: False # overfitting dataloaders
30 |
31 | num_val_samples: 16 # number of meshes to visualize in evaluation
32 | max_val_tokens: 5000
33 | top_k_tokens: 200 # sampling top-k tokens
34 | top_p: 0.9 # p val for nucleus sampling
35 | temperature: 0.8 # temperature for sampling
36 | sequence_stride: 32 # stride used when sequences are longer than the context length
37 | use_smoothed_loss: True # smoothing over neighboring tokens in the quantized space
38 |
39 | use_point_feats: False # point net like point features in graph network
40 | graph_conv: sage # flavor of graph convs
41 | g_no_max_pool: True # no max pooling in graph conv
42 | g_aggr: mean # aggregation op in graph conv
43 | ce_output: True # use quantized predictions
44 | embed_dim: 192 # vq embed dim
45 | n_embed: 16384 # vq num embeddings
46 | embed_loss_weight: 1.0
47 | embed_levels: 2 # rvq levels
48 | tri_weight: 0.00 # weight on geometric predictions
49 | norm_weight: 0.00
50 | area_weight: 0.00
51 | angle_weight: 0.00
52 | code_decay: 0.99 # code decay for vq
53 | embed_share: True # share embeddings across rvq levels
54 | use_multimodal_loss: True # multiple modes in ce loss when training vocabulary
55 |
56 | vq_resume: null # path to trained vocab when training the transformer
57 | ft_resume: null # path to transformer trained on all categories when finetuning
58 | ft_category: null # shapenet category to finetune
59 | distribute_features: True # distribute face features across vertices
60 | low_augment: False # lower the scale of augmentation
61 |
62 | # model
63 | model:
64 | in_emb: 3
65 | n_layer: 24
66 | n_head: 16
67 | n_embd: 768
68 | dropout: 0.0 # for pretraining 0 is good, for finetuning try 0.1+
69 | bias: False # do we use bias inside LayerNorm and Linear layers?
70 |
71 | # adamw optimizer
72 | lr: 1e-4 # max learning rate
73 | force_lr: null
74 | max_epoch: 2000 # total number of training epochs
75 | weight_decay: 1e-1
76 | beta1: 0.9
77 | beta2: 0.95
78 | grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0
79 |
80 | # learning rate decay settings
81 | warmup_steps: 2000 # how many steps to warm up for
82 | min_lr: 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
83 |
84 | only_chairs: False
85 | stochasticity: 0.1
86 |
87 | hydra:
88 | output_subdir: null # Disable saving of config files. We'll do that ourselves.
89 | run:
90 | dir: .
91 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import networkx as nx
4 |
5 |
6 | newface_token = 0
7 | stopface_token = 1
8 | padface_token = 2
9 |
10 |
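# Rebase a (possibly cropped) sequence of outer-face position ids so that its first
# non-special value becomes 3 while preserving relative offsets; 0, 1 and 2 are the special tokens above.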
11 | def get_shifted_sequence(sequence):
12 | non_special = np.flatnonzero(np.isin(sequence, [0, 1, 2], invert=True))
13 | if non_special.shape[0] > 0:
14 | idx = non_special[0]
15 | val = sequence[idx]
16 | sequence[non_special] -= (val - 3)
17 | return sequence
18 |
19 |
20 | def read_faces(text):
21 | all_lines = text.splitlines()
22 | all_face_lines = [x for x in all_lines if x.startswith('f ')]
23 | all_faces = [[int(y.split('/')[0]) - 1 for y in x.strip().split(' ')[1:]] for x in all_face_lines]
24 | return all_faces
25 |
26 |
27 | def read_vertices(text):
28 | all_lines = text.splitlines()
29 | all_vertex_lines = [x for x in all_lines if x.startswith('v ')]
30 | all_vertices = np.array([[float(y) for y in x.strip().split(' ')[1:]] for x in all_vertex_lines])
31 | assert all_vertices.shape[1] == 3, 'vertices should have 3 coordinates'
32 | return all_vertices
33 |
34 |
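# Map coordinates from [-0.5, 0.5] to integer bins in [0, num_tokens]; out-of-range values are clipped.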
35 | def quantize_coordinates(coords, num_tokens=256):
36 | if torch.is_tensor(coords):
37 | coords = torch.clip((coords + 0.5), 0, 1) * num_tokens # type: ignore
38 | coords_quantized = coords.round().long()
39 | else:
40 | coords = np.clip((coords + 0.5), 0, 1) * num_tokens # type: ignore
41 | coords_quantized = coords.round().astype(int)
42 | return coords_quantized
43 |
44 |
45 | def face_to_cycles(face):
46 | """Find cycles in face."""
47 | g = nx.Graph()
48 | for v in range(len(face) - 1):
49 | g.add_edge(face[v], face[v + 1])
50 | g.add_edge(face[-1], face[0])
51 | return list(nx.cycle_basis(g))
52 |
53 |
54 | def sort_vertices_and_faces(vertices_, faces_, num_tokens=256):
55 | vertices = np.clip((vertices_ + 0.5), 0, 1) * num_tokens # type: ignore
56 | vertices_quantized_ = vertices.round().astype(int)
57 |
58 | vertices_quantized_ = vertices_quantized_[:, [2, 0, 1]]
59 | vertices_quantized, unique_inverse = np.unique(vertices_quantized_, axis=0, return_inverse=True)
60 |
61 | sort_inds = np.lexsort(vertices_quantized.T)
62 |
63 | vertices_quantized = vertices_quantized[sort_inds]
64 | vertices_quantized = np.stack([vertices_quantized[:, 2], vertices_quantized[:, 1], vertices_quantized[:, 0]], axis=-1)
65 |
66 | # Re-index faces and tris to re-ordered vertices.
67 | faces = [np.argsort(sort_inds)[unique_inverse[f]] for f in faces_]
68 | # Merging duplicate vertices and re-indexing the faces causes some faces to
69 | # contain loops (e.g. [2, 3, 5, 2, 4]). Split these faces into distinct
70 | # sub-faces.
71 | sub_faces = []
72 | for f in faces:
73 | cliques = face_to_cycles(f)
74 | for c in cliques:
75 | c_length = len(c)
76 | # Only append faces with more than two verts.
77 | if c_length > 2:
78 | d = np.argmin(c)
79 | # Cyclically permute faces so that the first index is the smallest.
80 | sub_faces.append([c[(d + i) % c_length] for i in range(c_length)])
81 | faces = sub_faces
82 | # Sort faces by lowest vertex indices. If two faces have the same lowest
83 | # index then sort by next lowest and so on.
84 | faces.sort(key=lambda f: tuple(sorted(f)))
85 |
86 | # After removing degenerate faces some vertices are now unreferenced.
87 | # Remove these.
88 | num_verts = vertices_quantized.shape[0]
89 | vert_connected = np.equal(
90 | np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1)
91 | vertices_quantized = vertices_quantized[vert_connected]
92 | # Re-index faces and tris to re-ordered vertices.
93 | vert_indices = (
94 | np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int')))
95 | faces = [vert_indices[f].tolist() for f in faces]
96 | vertices = vertices_quantized / num_tokens - 0.5
97 | # order: Z, Y, X --> X, Y, Z
98 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)
99 | return vertices, faces
--------------------------------------------------------------------------------
/dataset/quantized_soup.py:
--------------------------------------------------------------------------------
1 | import random
2 | import omegaconf
3 | import torch
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from pathlib import Path
7 | import torch.utils.data
8 | import pickle
9 | from numpy import random
10 | import trimesh
11 | import torch_scatter
12 |
13 | from dataset import get_shifted_sequence
14 | from dataset.triangles import create_feature_stack
15 | from trainer import create_conv_batch, get_rvqvae_v0_encoder_vq, get_rvqvae_v1_encoder_vq
16 | from util.misc import normalize_vertices, scale_vertices
17 | from util.visualization import triangle_sequence_to_mesh
18 |
19 |
20 | class QuantizedSoup(Dataset):
21 |
22 | def __init__(self, config, split, scale_augment):
23 | super().__init__()
24 | data_path = Path(config.dataset_root)
25 | self.block_size = config.block_size
26 | self.vq_depth = config.embed_levels
27 | self.vq_is_shared = config.embed_share
28 | self.vq_num_codes_per_level = config.n_embed
29 | self.cached_vertices = []
30 | self.cached_faces = []
31 | self.names = []
32 | self.num_tokens = config.num_tokens - 3
33 | self.scale_augment = scale_augment
34 | with open(data_path, 'rb') as fptr:
35 | data = pickle.load(fptr)
36 | if config.only_chairs:
37 | for s in ['train', 'val']:
38 | data[f'vertices_{s}'] = [data[f'vertices_{s}'][i] for i in range(len(data[f'vertices_{s}'])) if data[f'name_{s}'][i].split('_')[0] == '03001627']
39 | data[f'faces_{s}'] = [data[f'faces_{s}'][i] for i in range(len(data[f'faces_{s}'])) if data[f'name_{s}'][i].split('_')[0] == '03001627']
40 | data[f'name_{s}'] = [data[f'name_{s}'][i] for i in range(len(data[f'name_{s}'])) if data[f'name_{s}'][i].split('_')[0] == '03001627']
41 | if not config.overfit:
42 | self.names = data[f'name_{split}']
43 | self.cached_vertices = data[f'vertices_{split}']
44 | self.cached_faces = data[f'faces_{split}']
45 | else:
46 | multiplier = 16 if split == 'val' else 512
47 | self.names = data[f'name_train'][:1] * multiplier
48 | self.cached_vertices = data[f'vertices_train'][:1] * multiplier
49 | self.cached_faces = data[f'faces_train'][:1] * multiplier
50 |
51 | print(len(self.cached_vertices), "meshes loaded")
52 |
53 | max_inner_face_len = 0
54 | for i in range(len(self.cached_vertices)):
55 | self.cached_vertices[i] = np.array(self.cached_vertices[i])
56 | for j in range(len(self.cached_faces[i])):
57 | max_inner_face_len = max(max_inner_face_len, len(self.cached_faces[i][j]))
58 | print('Longest inner face sequence', max_inner_face_len)
59 | assert max_inner_face_len == 3, "Only triangles are supported"
60 |
61 | self.start_sequence_token = 0
62 | self.end_sequence_token = 1
63 | self.pad_face_token = 2
64 | self.padding = int(config.padding * self.block_size)
65 | self.indices = []
66 |
67 | max_face_sequence_len = 0
68 | min_face_sequence_len = 1e7
69 | for i in range(len(self.cached_faces)):
70 | sequence_len = len(self.cached_faces[i]) * self.vq_depth + 1 + 1
71 | max_face_sequence_len = max(max_face_sequence_len, sequence_len)
72 | min_face_sequence_len = min(min_face_sequence_len, sequence_len)
73 | for j in range(0, max(1, sequence_len - self.block_size + self.padding + 1), config.sequence_stride): # todo: possible bug? +1 added recently
74 | self.indices.append((i, j))
75 | print('Length of', split, len(self.indices))
76 | print('Shortest face sequence', min_face_sequence_len)
77 | print('Longest face sequence', max_face_sequence_len)
78 | self.encoder = None
79 | self.pre_quant = None
80 | self.vq = None
81 | self.post_quant = None
82 | self.decoder = None
83 | self.device = None
84 |
85 | def set_quantizer(self, encoder, pre_quant, vq, post_quant, decoder, device):
86 | self.encoder = encoder.eval()
87 | self.pre_quant = pre_quant.eval()
88 | self.decoder = decoder.eval()
89 | self.vq = vq.eval()
90 | self.post_quant = post_quant.eval()
91 | self.device = device
92 |
93 | @torch.no_grad()
94 | def get_codes(self, vertices, faces):
95 | triangles, normals, areas, angles, vertices, faces = create_feature_stack(vertices, faces, self.num_tokens)
96 | features = np.hstack([triangles, normals, areas, angles])
97 | face_neighborhood = np.array(trimesh.Trimesh(vertices=vertices, faces=faces, process=False).face_neighborhood) # type: ignore
98 |
99 | encoded_x = self.encoder(
100 | torch.from_numpy(features).float().to(self.device),
101 | torch.from_numpy(face_neighborhood.T).long().to(self.device),
102 | torch.zeros([features.shape[0]], device=self.device).long()
103 | )
104 |
105 | encoded_x = self.pre_quant(encoded_x)
106 | _, all_indices, _ = self.vq(encoded_x.unsqueeze(0))
107 | all_indices = all_indices.squeeze(0)
108 | if not self.vq_is_shared:
109 | correction = (torch.arange(0, self.vq_depth, device=self.device) * self.vq_num_codes_per_level).reshape(1, -1)
110 | all_indices = all_indices + correction
111 | inner_face_id = torch.arange(0, self.vq_depth, device=self.device).reshape(1, -1).expand(all_indices.shape[0], -1)
112 | outer_face_id = torch.arange(0, all_indices.shape[0], device=self.device).reshape(-1, 1).expand(-1, self.vq_depth)
113 |
114 | # adjust for start, end and padding tokens
115 | all_indices = all_indices.reshape(-1) + 3
116 | inner_face_id = inner_face_id.reshape(-1) + 3
117 | outer_face_id = outer_face_id.reshape(-1) + 3
118 |
119 | # add start token and end token
120 | all_indices = torch.cat((
121 | torch.tensor([self.start_sequence_token], device=self.device),
122 | all_indices,
123 | torch.tensor([self.end_sequence_token], device=self.device)
124 | )).long().cpu()
125 |
126 | inner_face_id = torch.cat((
127 | torch.tensor([self.start_sequence_token], device=self.device),
128 | inner_face_id,
129 | torch.tensor([self.end_sequence_token], device=self.device)
130 | )).long().cpu()
131 |
132 | outer_face_id = torch.cat((
133 | torch.tensor([self.start_sequence_token], device=self.device),
134 | outer_face_id,
135 | torch.tensor([self.end_sequence_token], device=self.device)
136 | )).long().cpu()
137 |
138 | return all_indices, inner_face_id, outer_face_id
139 |
140 | def __getitem__(self, index: int):
141 | i, j = self.indices[index]
142 | vertices = self.cached_vertices[i]
143 | faces = self.cached_faces[i]
144 | if self.scale_augment:
145 | vertices = normalize_vertices(scale_vertices(vertices))
146 | else:
147 | vertices = normalize_vertices(vertices)
148 |
149 | soup_sequence, face_in_idx, face_out_idx = self.get_codes(vertices, faces)
150 |
151 | # face sequence in block format
152 | end_index = min(j + self.block_size, len(soup_sequence))
153 | x_in = soup_sequence[j:end_index]
154 | y_in = soup_sequence[j + 1:end_index + 1]
155 | fpi_in = face_in_idx[j:end_index]
156 | fpo_in = face_out_idx[j:end_index]
157 |
158 | x_pad = torch.tensor([self.pad_face_token for _ in range(0, self.block_size - len(x_in))])
159 | y_pad = torch.tensor([self.pad_face_token for _ in range(0, len(x_in) + len(x_pad) - len(y_in))])
160 | fpi_in_pad = torch.tensor([self.pad_face_token for _ in range(0, self.block_size - len(fpi_in))])
161 | fpo_in_pad = torch.tensor([self.pad_face_token for _ in range(0, self.block_size - len(fpo_in))])
162 |
163 | x = torch.cat((x_in, x_pad)).long()
164 | y = torch.cat((y_in, y_pad)).long()
165 | fpi = torch.cat((fpi_in, fpi_in_pad)).long()
166 | fpo = torch.from_numpy(get_shifted_sequence(torch.cat((fpo_in, fpo_in_pad)).numpy())).long()
167 |
168 | return {
169 | 'name': self.names[i],
170 | 'seq_in': x,
171 | 'seq_out': y,
172 | 'seq_pos_inner': fpi,
173 | 'seq_pos_outer': fpo,
174 | }
175 |
176 | def get_completion_sequence(self, i, tokens, device=torch.device("cpu")):
177 | vertices = normalize_vertices(self.cached_vertices[i])
178 | faces = self.cached_faces[i]
179 | soup_sequence, face_in_idx, face_out_idx = self.get_codes(vertices, faces)
180 | face_out_idx = torch.from_numpy(get_shifted_sequence(face_out_idx.numpy()))
181 | original_fseq = soup_sequence.to(device)
182 | if isinstance(tokens, int):
183 | num_pre_tokens = tokens
184 | else:
185 | num_pre_tokens = int(len(original_fseq) * tokens)
186 | x = (
187 | soup_sequence[:num_pre_tokens].to(device)[None, ...],
188 | face_in_idx[:num_pre_tokens].to(device)[None, ...],
189 | face_out_idx[:num_pre_tokens].to(device)[None, ...],
190 | original_fseq[None, ...],
191 | )
192 | return x
193 |
194 | def get_start(self, device=torch.device("cpu")):
195 | i = random.choice(list(range(len(self.cached_vertices))))
196 | x = self.get_completion_sequence(i, 11, device)
197 | return x
198 |
199 | def __len__(self) -> int:
200 | return len(self.indices)
201 |
202 | def decode(self, sequence):
203 | mask = torch.isin(sequence, torch.tensor([self.start_sequence_token, self.end_sequence_token, self.pad_face_token], device=sequence.device)).logical_not()
204 | sequence = sequence[mask]
205 | sequence = sequence - 3
206 |
207 | sequence_len = (sequence.shape[0] // self.vq_depth) * self.vq_depth
208 | sequence = sequence[:sequence_len].reshape(-1, self.vq_depth)
209 | N = sequence.shape[0]
210 | E, D = self.vq_num_codes_per_level, self.vq_depth
211 | all_codes = self.vq.get_codes_from_indices(sequence).permute(1, 2, 0)
212 | encoded_x = all_codes.reshape(N, E, D).sum(-1)
213 | encoded_x = self.post_quant(encoded_x)
214 | encoded_x_conv, conv_mask = create_conv_batch(encoded_x, torch.zeros([sequence.shape[0]], device=self.device).long(), 1, self.device)
215 | decoded_x_conv = self.decoder(encoded_x_conv)
216 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :]
217 | coords = decoded_x.argmax(-1).detach().cpu().numpy() / self.num_tokens - 0.5
218 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords)
219 | return gen_vertices, gen_faces
220 |
221 | def plot_sequence_lenght_stats(self):
222 | sequence_lengths = []
223 | for i in range(len(self.cached_faces)):
224 | sequence_len = len(self.cached_faces[i]) * self.vq_depth + 1 + 1
225 | sequence_lengths.append(sequence_len)
226 | import matplotlib.pyplot as plt
227 | plt.hist(sequence_lengths, bins=32)
228 | plt.ylim(0, 100)
229 | plt.show()
230 |
231 |
232 | class QuantizedSoupCreator(torch.nn.Module):
233 | vq_depth_factor = 1
234 | def __init__(self, config, vq_cfg):
235 | super().__init__()
236 | self.vq_cfg = vq_cfg
237 | self.vq_depth = self.vq_cfg.embed_levels
238 | self.vq_is_shared = self.vq_cfg.embed_share
239 | self.vq_num_codes_per_level = self.vq_cfg.n_embed
240 | self.vq_dim = self.vq_cfg.embed_dim
241 | assert config.num_tokens == self.vq_cfg.num_tokens, "Number of tokens must match"
242 | self.block_size = config.block_size
243 | self.start_sequence_token = 0
244 | self.end_sequence_token = 1
245 | self.pad_face_token = 2
246 | self.num_tokens = config.num_tokens - 3
247 | self.padding = int(config.padding * self.block_size)
248 | self.rq_transformer_input = False
249 | self.encoder, self.pre_quant, self.post_quant, self.vq = self.get_rvq_encoders(config.vq_resume)
250 |
251 | def get_rvq_encoders(self, resume):
252 | return get_rvqvae_v0_encoder_vq(self.vq_cfg, resume)
253 |
254 | def freeze_vq(self):
255 | for model in [self.encoder, self.pre_quant, self.post_quant, self.vq]:
256 | for param in model.parameters():
257 | param.requires_grad = False
258 |
259 | @torch.no_grad()
260 | def embed(self, idx):
261 | assert self.vq_is_shared, "Only shared embedding is supported"
262 | all_codes = self.vq.codebooks[0][idx].reshape(-1, self.vq_dim)
263 | return all_codes
264 |
265 | @torch.no_grad()
266 | def get_indices(self, x, edge_index, batch, _faces, _num_vertices):
267 | encoded_x = self.encoder(x, edge_index, batch)
268 | encoded_x = self.pre_quant(encoded_x)
269 | _, all_indices, _ = self.vq(encoded_x.unsqueeze(0))
270 | all_indices = all_indices.squeeze(0)
271 | if not self.vq_is_shared:
272 | correction = (torch.arange(0, self.vq_depth * self.vq_depth_factor, device=x.device) * self.vq_num_codes_per_level).reshape(1, -1)
273 | all_indices = all_indices + correction
274 | return all_indices
275 |
276 | @torch.no_grad()
277 | def forward(self, x, edge_index, batch, faces, num_vertices, js, force_full_sequence=False):
278 | for model in [self.encoder, self.pre_quant, self.post_quant, self.vq]:
279 | model.eval()
280 | if force_full_sequence:
281 | assert js.shape[0] == 1, "Only single mesh supported"
282 | all_indices = self.get_indices(x, edge_index, batch, faces, num_vertices)
283 | batch_size = js.shape[0]
284 | sequences = []
285 | targets = []
286 | position_inners = []
287 | position_outers = []
288 | max_sequence_length_x = 0
289 | for k in range(batch_size):
290 | sequence_k = all_indices[batch == k, :]
291 | inner_face_id_k = torch.arange(0, self.vq_depth * self.vq_depth_factor, device=x.device).reshape(1, -1).expand(sequence_k.shape[0], -1)
292 | outer_face_id_k = torch.arange(0, sequence_k.shape[0], device=x.device).reshape(-1, 1).expand(-1, self.vq_depth * self.vq_depth_factor)
293 | sequence_k = sequence_k.reshape(-1) + 3
294 | inner_face_id_k = inner_face_id_k.reshape(-1) + 3
295 | outer_face_id_k = outer_face_id_k.reshape(-1) + 3
296 | # add start token and end token
297 |
298 | prefix = [torch.tensor([self.start_sequence_token], device=x.device)]
299 | postfix = [torch.tensor([self.end_sequence_token], device=x.device)]
300 |
301 | if self.rq_transformer_input:
302 | prefix = prefix * self.vq_depth
303 | postfix = postfix * self.vq_depth
304 |
305 | sequence_k = torch.cat(prefix + [sequence_k] + postfix).long()
306 |
307 | inner_face_id_k = torch.cat(prefix + [inner_face_id_k] + postfix).long()
308 |
309 | outer_face_id_k = torch.cat(prefix + [outer_face_id_k] + postfix).long()
310 |
311 | j = js[k]
312 | if force_full_sequence:
313 | end_index = len(sequence_k)
314 | else:
315 | end_index = min(j + self.block_size, len(sequence_k))
316 | x_in = sequence_k[j:end_index]
317 | if self.rq_transformer_input:
318 | y_in = sequence_k[j + self.vq_depth:end_index + self.vq_depth]
319 | else:
320 | y_in = sequence_k[j + 1:end_index + 1]
321 | fpi_in = inner_face_id_k[j:end_index]
322 | fpo_in = outer_face_id_k[j:end_index].cpu()
323 |
324 | max_sequence_length_x = max(max_sequence_length_x, len(x_in))
325 | pad_len_x = self.block_size - len(x_in)
326 | if self.rq_transformer_input:
327 | pad_len_x = pad_len_x + (self.vq_depth - (len(x_in) + pad_len_x) % self.vq_depth)
328 | x_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_x)], device=x.device)
329 |
330 | pad_len_y = len(x_in) + len(x_pad) - len(y_in)
331 | pad_len_fpi = self.block_size - len(fpi_in)
332 | pad_len_fpo = self.block_size - len(fpo_in)
333 |
334 | if self.rq_transformer_input:
335 | pad_len_fpi = pad_len_fpi + (self.vq_depth - (len(fpi_in) + pad_len_fpi) % self.vq_depth)
336 | pad_len_fpo = pad_len_fpo + (self.vq_depth - (len(fpo_in) + pad_len_fpo) % self.vq_depth)
337 |
338 | y_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_y)], device=x.device)
339 | fpi_in_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_fpi)], device=x.device)
340 | fpo_in_pad = torch.tensor([self.pad_face_token for _ in range(0, pad_len_fpo)])
341 |
342 | x = torch.cat((x_in, x_pad)).long()
343 | y = torch.cat((y_in, y_pad)).long()
344 | fpi = torch.cat((fpi_in, fpi_in_pad)).long()
345 | fpo = torch.from_numpy(get_shifted_sequence(torch.cat((fpo_in, fpo_in_pad)).numpy())).long().to(x.device)
346 |
347 | sequences.append(x)
348 | targets.append(y)
349 | position_inners.append(fpi)
350 | position_outers.append(fpo)
351 |
352 | sequences = torch.stack(sequences, dim=0)[:, :max_sequence_length_x].contiguous()
353 | targets = torch.stack(targets, dim=0)[:, :max_sequence_length_x].contiguous()
354 | position_inners = torch.stack(position_inners, dim=0)[:, :max_sequence_length_x].contiguous()
355 | position_outers = torch.stack(position_outers, dim=0)[:, :max_sequence_length_x].contiguous()
356 | return sequences, targets, position_inners, position_outers
357 |
358 | @torch.no_grad()
359 | def get_completion_sequence(self, x, edge_index, face, num_vertices, tokens):
360 | soup_sequence, target, face_in_idx, face_out_idx = self.forward(
361 | x, edge_index,
362 | torch.zeros([x.shape[0]], device=x.device).long(), face,
363 | num_vertices,
364 | torch.zeros([1], device=x.device).long(),
365 | force_full_sequence=True
366 | )
367 | soup_sequence = soup_sequence[0]
368 | face_in_idx = face_in_idx[0]
369 | face_out_idx = face_out_idx[0]
370 | target = target[0]
371 | if isinstance(tokens, int):
372 | num_pre_tokens = tokens
373 | else:
374 | num_pre_tokens = int(len(target) * tokens)
375 | x = (
376 | soup_sequence[:num_pre_tokens][None, ...],
377 | face_in_idx[:num_pre_tokens][None, ...],
378 | face_out_idx[:num_pre_tokens][None, ...],
379 | target[None, ...],
380 | )
381 | return x
382 |
383 | def encode_sequence(self, sequence):
384 | N = sequence.shape[0]
385 | E, D = self.vq_dim, self.vq_depth
386 | all_codes = self.vq.get_codes_from_indices(sequence).permute(1, 2, 0)
387 | encoded_x = all_codes.reshape(N, E, D).sum(-1)
388 | return encoded_x
389 |
390 | @torch.no_grad()
391 | def decode(self, sequence, decoder):
392 | mask = torch.isin(sequence, torch.tensor([self.start_sequence_token, self.end_sequence_token, self.pad_face_token], device=sequence.device)).logical_not()
393 | sequence = sequence[mask]
394 | sequence = sequence - 3
395 | sequence_len = (sequence.shape[0] // (self.vq_depth * self.vq_depth_factor)) * (self.vq_depth * self.vq_depth_factor)
396 | sequence = sequence[:sequence_len].reshape(-1, self.vq_depth)
397 | encoded_x = self.encode_sequence(sequence)
398 | encoded_x = self.post_quant(encoded_x)
399 | encoded_x_conv, conv_mask = create_conv_batch(encoded_x, torch.zeros([encoded_x.shape[0]], device=sequence.device).long(), 1, sequence.device)
400 | decoded_x_conv = decoder(encoded_x_conv)
401 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :]
402 | coords = decoded_x.argmax(-1).detach().cpu().numpy() / self.num_tokens - 0.5
403 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords)
404 | return gen_vertices, gen_faces
405 |
406 |
407 | class QuantizedSoupTripletsCreator(QuantizedSoupCreator):
408 | vq_depth_factor = 3
409 | def __init__(self, config, vq_cfg):
410 | super().__init__(config, vq_cfg)
411 |
412 | def get_rvq_encoders(self, resume):
413 | return get_rvqvae_v1_encoder_vq(self.vq_cfg, resume)
414 |
415 | @torch.no_grad()
416 | def get_indices(self, x, edge_index, batch, faces, num_vertices):
417 | encoded_x = self.encoder(x, edge_index, batch)
418 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192)
419 | encoded_x = distribute_features(encoded_x, faces, num_vertices, x.device)
420 | encoded_x = self.pre_quant(encoded_x)
421 | _, all_indices, _ = self.vq(encoded_x.unsqueeze(0))
422 | all_indices = all_indices.squeeze(0).reshape(-1, self.vq_depth * 3)
423 | if not self.vq_is_shared:
424 | correction = (torch.arange(0, self.vq_depth, device=x.device) * self.vq_num_codes_per_level).reshape(1, -1)
425 | all_indices = all_indices + correction
426 | return all_indices
427 |
428 | def encode_sequence(self, sequence):
429 | N = sequence.shape[0]
430 | E, D = self.vq_dim, self.vq_depth
431 | all_codes = self.vq.get_codes_from_indices(sequence).permute(1, 2, 0)
432 | encoded_x = all_codes.reshape(-1, 3 * E, D).sum(-1)
433 | return encoded_x
434 |
435 |
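# Average the features of all face corners that share a vertex, then gather the averaged
# feature back for every corner, i.e. distribute face features across shared vertices.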
436 | def distribute_features(features, face_indices, num_vertices, device):
437 | assert features.shape[0] == face_indices.shape[0] * face_indices.shape[1], "Features and face indices must match in size"
438 | vertex_features = torch.zeros([num_vertices, features.shape[1]], device=device)
439 | torch_scatter.scatter_mean(features, face_indices.reshape(-1), out=vertex_features, dim=0)
440 | distributed_features = vertex_features[face_indices.reshape(-1), :]
441 | return distributed_features
442 |
--------------------------------------------------------------------------------
/dataset/triangles.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping, Sequence
2 |
3 | import omegaconf
4 | import torch
5 | import numpy as np
6 | import trimesh
7 | from torch.utils.data import Dataset, default_collate
8 | from pathlib import Path
9 | import torch.utils.data
10 | import pickle
11 |
12 | from torch_geometric.data.data import BaseData
13 | from tqdm import tqdm
14 |
15 | from dataset import sort_vertices_and_faces, quantize_coordinates
16 | from util.misc import normalize_vertices, scale_vertices, shift_vertices
17 | from torch_geometric.data import Dataset as GeometricDataset, Batch
18 | from torch_geometric.data import Data as GeometricData
19 | from torch_geometric.loader.dataloader import Collater as GeometricCollator
20 |
21 |
22 | class TriangleNodes(GeometricDataset):
23 |
24 | def __init__(self, config, split, scale_augment, shift_augment, force_category, use_start_stop=False, only_backward_edges=False):
25 | super().__init__()
26 | data_path = Path(config.dataset_root)
27 | self.cached_vertices = []
28 | self.cached_faces = []
29 | self.names = []
30 | self.scale_augment = scale_augment
31 | self.shift_augment = shift_augment
32 | self.low_augment = config.low_augment
33 | self.use_start_stop = use_start_stop
34 | self.ce_output = config.ce_output
35 | self.only_backward_edges = only_backward_edges
36 | self.num_tokens = config.num_tokens - 3
37 | with open(data_path, 'rb') as fptr:
38 | data = pickle.load(fptr)
39 | if force_category is not None:
40 | for s in ['train', 'val']:
41 | data[f'vertices_{s}'] = [data[f'vertices_{s}'][i] for i in range(len(data[f'vertices_{s}'])) if data[f'name_{s}'][i].split('_')[0] == force_category]
42 | data[f'faces_{s}'] = [data[f'faces_{s}'][i] for i in range(len(data[f'faces_{s}'])) if data[f'name_{s}'][i].split('_')[0] == force_category]
43 | data[f'name_{s}'] = [data[f'name_{s}'][i] for i in range(len(data[f'name_{s}'])) if data[f'name_{s}'][i].split('_')[0] == force_category]
44 | if len(data[f'vertices_val']) == 0:
45 | data[f'vertices_val'] = data[f'vertices_train']
46 | data[f'faces_val'] = data[f'faces_train']
47 | data[f'name_val'] = data[f'name_train']
48 | if not config.overfit:
49 | self.names = data[f'name_{split}']
50 | self.cached_vertices = data[f'vertices_{split}']
51 | self.cached_faces = data[f'faces_{split}']
52 | else:
53 | multiplier = 16 if split == 'val' else 512
54 | self.names = data[f'name_train'][:1] * multiplier
55 | self.cached_vertices = data[f'vertices_train'][:1] * multiplier
56 | self.cached_faces = data[f'faces_train'][:1] * multiplier
57 |
58 | print(len(self.cached_vertices), "meshes loaded")
59 |
60 | def len(self):
61 | return len(self.cached_vertices)
62 |
63 | def get_all_features_for_shape(self, idx):
64 | vertices = self.cached_vertices[idx]
65 | faces = self.cached_faces[idx]
66 | if self.scale_augment:
67 | if self.low_augment:
68 | x_lims = (0.9, 1.1)
69 | y_lims = (0.9, 1.1)
70 | z_lims = (0.9, 1.1)
71 | else:
72 | x_lims = (0.75, 1.25)
73 | y_lims = (0.75, 1.25)
74 | z_lims = (0.75, 1.25)
75 | vertices = scale_vertices(vertices, x_lims=x_lims, y_lims=y_lims, z_lims=z_lims)
76 | vertices = normalize_vertices(vertices)
77 | if self.shift_augment:
78 | vertices = shift_vertices(vertices)
79 | triangles, normals, areas, angles, vertices, faces = create_feature_stack(vertices, faces, self.num_tokens)
80 | features = np.hstack([triangles, normals, areas, angles])
81 | face_neighborhood = np.array(trimesh.Trimesh(vertices=vertices, faces=faces, process=False).face_neighborhood) # type: ignore
82 | target = torch.from_numpy(features[:, :9]).float()
83 | if self.use_start_stop:
84 | features = np.concatenate([np.zeros((1, features.shape[1])), features], axis=0)
85 | target = torch.cat([target, torch.ones(1, 9) * 0.5], dim=0)
86 | face_neighborhood = face_neighborhood + 1
87 | if self.only_backward_edges:
88 | face_neighborhood = face_neighborhood[face_neighborhood[:, 1] > face_neighborhood[:, 0], :]
89 | # face_neighborhood = modify so that only edges in backward direction are present
90 | if self.ce_output:
91 | target = quantize_coordinates(target, self.num_tokens)
92 | return features, target, vertices, faces, face_neighborhood
93 |
94 | def get(self, idx):
95 | features, target, _, _, face_neighborhood = self.get_all_features_for_shape(idx)
96 | return GeometricData(x=torch.from_numpy(features).float(), y=target, edge_index=torch.from_numpy(face_neighborhood.T).long())
97 |
98 |
99 | class TriangleNodesWithFaces(TriangleNodes):
100 |
101 | def __init__(self, config, split, scale_augment, shift_augment, force_category):
102 | super().__init__(config, split, scale_augment, shift_augment, force_category)
103 |
104 | def get(self, idx):
105 | features, target, vertices, faces, face_neighborhood = self.get_all_features_for_shape(idx)
106 | return GeometricData(x=torch.from_numpy(features).float(), y=target,
107 | edge_index=torch.from_numpy(face_neighborhood.T).long(),
108 | num_vertices=vertices.shape[0], faces=torch.from_numpy(np.array(faces)).long())
109 |
110 |
111 | class TriangleNodesWithFacesDataloader(torch.utils.data.DataLoader):
112 |
113 | def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=None, exclude_keys=None, **kwargs):
114 | # Remove for PyTorch Lightning:
115 | kwargs.pop('collate_fn', None)
116 | # Save for PyTorch Lightning < 1.6:
117 | self.follow_batch = follow_batch
118 | self.exclude_keys = exclude_keys
119 | super().__init__(
120 | dataset,
121 | batch_size,
122 | shuffle,
123 | collate_fn=FaceCollator(follow_batch, exclude_keys),
124 | **kwargs,
125 | )
126 |
127 |
128 | class FaceCollator(GeometricCollator):
129 |
130 | def __init__(self, follow_batch, exclude_keys):
131 | super().__init__(follow_batch, exclude_keys)
132 |
133 | def __call__(self, batch):
134 | elem = batch[0]
135 |
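        # Offset each mesh's face indices by the running vertex count so that faces index
        # correctly into the concatenated vertex list of the batch.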
136 | num_vertices = 0
137 | for b_idx in range(len(batch)):
138 | batch[b_idx].faces += num_vertices
139 | num_vertices += batch[b_idx].num_vertices
140 |
141 | if isinstance(elem, BaseData):
142 | return Batch.from_data_list(batch, self.follow_batch, self.exclude_keys)
143 | elif isinstance(elem, torch.Tensor):
144 | return default_collate(batch)
145 | elif isinstance(elem, float):
146 | return torch.tensor(batch, dtype=torch.float)
147 | elif isinstance(elem, int):
148 | return torch.tensor(batch)
149 | elif isinstance(elem, str):
150 | return batch
151 | elif isinstance(elem, Mapping):
152 | return {key: self([data[key] for data in batch]) for key in elem}
153 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
154 | return type(elem)(*(self(s) for s in zip(*batch)))
155 | elif isinstance(elem, Sequence) and not isinstance(elem, str):
156 | return [self(s) for s in zip(*batch)]
157 |
158 | raise TypeError(f'DataLoader found invalid type: {type(elem)}')
159 |
160 | def collate(self, batch): # pragma: no cover
161 | raise NotImplementedError
162 |
163 | class TriangleNodesWithSequenceIndices(TriangleNodes):
164 |
165 | vq_depth_factor = 1
166 |
167 | def __init__(self, config, split, scale_augment, shift_augment, force_category):
168 | super().__init__(config, split, scale_augment=scale_augment, shift_augment=shift_augment, force_category=force_category)
169 | vq_cfg = omegaconf.OmegaConf.load(Path(config.vq_resume).parents[1] / "config.yaml")
170 | self.vq_depth = vq_cfg.embed_levels
171 | self.block_size = config.block_size
172 | max_inner_face_len = 0
173 | self.padding = int(config.padding * self.block_size)
174 | self.sequence_stride = config.sequence_stride
175 | for i in range(len(self.cached_vertices)):
176 | self.cached_vertices[i] = np.array(self.cached_vertices[i])
177 | for j in range(len(self.cached_faces[i])):
178 | max_inner_face_len = max(max_inner_face_len, len(self.cached_faces[i][j]))
179 | print('Longest inner face sequence', max_inner_face_len)
180 |         assert max_inner_face_len == 3, f"Only triangles are supported, but found a face with {max_inner_face_len} vertices."
181 | self.sequence_indices = []
182 | max_face_sequence_len = 0
183 | min_face_sequence_len = 1e7
184 | for i in range(len(self.cached_faces)):
185 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1
186 | max_face_sequence_len = max(max_face_sequence_len, sequence_len)
187 | min_face_sequence_len = min(min_face_sequence_len, sequence_len)
188 | self.sequence_indices.append((i, 0, False))
189 | for j in range(config.sequence_stride, max(1, sequence_len - self.block_size + self.padding + 1), config.sequence_stride): # todo: possible bug? +1 added recently
190 | self.sequence_indices.append((i, j, True if split == 'train' else False))
191 | if sequence_len > self.block_size:
192 | self.sequence_indices.append((i, sequence_len - self.block_size, False))
193 | print('Length of', split, len(self.sequence_indices))
194 | print('Shortest face sequence', min_face_sequence_len)
195 | print('Longest face sequence', max_face_sequence_len)
196 |
197 | def len(self):
198 | return len(self.sequence_indices)
199 |
200 | def get(self, idx):
201 | i, j, randomness = self.sequence_indices[idx]
202 | if randomness:
203 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1
204 | j = min(max(0, j + np.random.randint(-self.sequence_stride // 2, self.sequence_stride // 2)), sequence_len - self.block_size + self.padding)
205 | features, target, _, _, face_neighborhood = self.get_all_features_for_shape(i)
206 | return GeometricData(x=torch.from_numpy(features).float(), y=target, edge_index=torch.from_numpy(face_neighborhood.T).long(), js=torch.tensor(j).long())
207 |
208 |     def plot_sequence_length_stats(self):
209 | sequence_lengths = []
210 | for i in range(len(self.cached_faces)):
211 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1
212 | sequence_lengths.append(sequence_len)
213 | import matplotlib.pyplot as plt
214 | plt.hist(sequence_lengths, bins=32)
215 | plt.ylim(0, 100)
216 | plt.show()
217 | return sequence_lengths
218 |
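The window enumeration in `TriangleNodesWithSequenceIndices.__init__` above always keeps a window at offset 0, adds strided windows (which get jittered by up to half a stride at train time in `get`), and appends a final window flush with the end of any sequence longer than the block size. A standalone sketch of the same logic with toy numbers:

```python
# toy values; in the dataset they come from the config (block_size, padding, sequence_stride)
block_size, padding, stride, sequence_len = 16, 0, 8, 30

windows = [(0, False)]                                       # always keep the start
for j in range(stride, max(1, sequence_len - block_size + padding + 1), stride):
    windows.append((j, True))                                # jittered during training
if sequence_len > block_size:
    windows.append((sequence_len - block_size, False))       # right-aligned final window
print(windows)  # [(0, False), (8, True), (14, False)]
```
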
219 |
220 | class TriangleNodesWithFacesAndSequenceIndices(TriangleNodesWithSequenceIndices):
221 | vq_depth_factor = 3
222 | def __init__(self, config, split, scale_augment, shift_augment, force_category):
223 | super().__init__(config, split, scale_augment, shift_augment, force_category)
224 |
225 | def get(self, idx):
226 | i, j, randomness = self.sequence_indices[idx]
227 | if randomness:
228 | sequence_len = len(self.cached_faces[i]) * self.vq_depth * self.vq_depth_factor + 1 + 1
229 | j = min(max(0, j + np.random.randint(-self.sequence_stride // 2, self.sequence_stride // 2)), sequence_len - self.block_size + self.padding)
230 | features, target, vertices, faces, face_neighborhood = self.get_all_features_for_shape(i)
231 | return GeometricData(x=torch.from_numpy(features).float(),
232 | y=target, mesh_name=self.names[i], edge_index=torch.from_numpy(face_neighborhood.T).long(),
233 | js=torch.tensor(j).long(), num_vertices=vertices.shape[0],
234 | faces=torch.from_numpy(np.array(faces)).long())
235 |
236 |
237 | class Triangles(Dataset):
238 |
239 | def __init__(self, config, split, scale_augment, shift_augment):
240 | super().__init__()
241 | data_path = Path(config.dataset_root)
242 | self.cached_vertices = []
243 | self.cached_faces = []
244 | self.names = []
245 | self.scale_augment = scale_augment
246 | self.shift_augment = shift_augment
247 | with open(data_path, 'rb') as fptr:
248 | data = pickle.load(fptr)
249 | if not config.overfit:
250 | self.names = data[f'name_{split}']
251 | self.cached_vertices = data[f'vertices_{split}']
252 | self.cached_faces = data[f'faces_{split}']
253 | else:
254 | multiplier = 1 if split == 'val' else 500
255 |             self.names = data['name_train'][:1] * multiplier
256 |             self.cached_vertices = data['vertices_train'][:1] * multiplier
257 |             self.cached_faces = data['faces_train'][:1] * multiplier
258 |
259 | print(len(self.cached_vertices), "meshes loaded")
260 |         self.num_tokens = config.num_tokens  # quantization resolution required by create_feature_stack (assumed to be present in the config); self.features is built right below
261 | self.setup_triangles_for_epoch()
262 |
263 | def __len__(self):
264 | return self.features.shape[0]
265 |
266 | def setup_triangles_for_epoch(self):
267 | all_features = []
268 | for idx in tqdm(range(len(self.cached_vertices)), desc="refresh augs"):
269 | vertices = self.cached_vertices[idx]
270 | faces = self.cached_faces[idx]
271 | if self.scale_augment:
272 | vertices = scale_vertices(vertices)
273 | vertices = normalize_vertices(vertices)
274 | if self.shift_augment:
275 | vertices = shift_vertices(vertices)
276 |             all_features.append(create_feature_stack(vertices, faces, self.num_tokens)[0])
277 | self.features = np.vstack(all_features)
278 |
279 | def __getitem__(self, idx):
280 | return {
281 | 'features': self.features[idx],
282 | 'target': self.features[idx, :9]
283 | }
284 |
285 | def get_all_features_for_shape(self, idx):
286 | vertices = self.cached_vertices[idx]
287 | faces = self.cached_faces[idx]
288 |         feature_stack = create_feature_stack(vertices, faces, self.num_tokens)[0]
289 | return torch.from_numpy(feature_stack).float(), torch.from_numpy(feature_stack[:, :9]).float()
290 |
291 |
292 | def normal(triangles):
293 | # The cross product of two sides is a normal vector
294 | if torch.is_tensor(triangles):
295 | return torch.cross(triangles[:, 1] - triangles[:, 0], triangles[:, 2] - triangles[:, 0], dim=1)
296 | else:
297 | return np.cross(triangles[:, 1] - triangles[:, 0], triangles[:, 2] - triangles[:, 0], axis=1)
298 |
299 |
300 | def area(triangles):
301 | # The norm of the cross product of two sides is twice the area
302 | if torch.is_tensor(triangles):
303 | return torch.norm(normal(triangles), dim=1) / 2
304 | else:
305 | return np.linalg.norm(normal(triangles), axis=1) / 2
306 |
307 |
308 | def angle(triangles):
309 | v_01 = triangles[:, 1] - triangles[:, 0]
310 | v_02 = triangles[:, 2] - triangles[:, 0]
311 | v_10 = -v_01
312 | v_12 = triangles[:, 2] - triangles[:, 1]
313 | v_20 = -v_02
314 | v_21 = -v_12
315 | if torch.is_tensor(triangles):
316 | return torch.stack([angle_between(v_01, v_02), angle_between(v_10, v_12), angle_between(v_20, v_21)], dim=1)
317 | else:
318 | return np.stack([angle_between(v_01, v_02), angle_between(v_10, v_12), angle_between(v_20, v_21)], axis=1)
319 |
320 |
321 | def angle_between(v0, v1):
322 | v0_u = unit_vector(v0)
323 | v1_u = unit_vector(v1)
324 | if torch.is_tensor(v0):
325 | return torch.arccos(torch.clip(torch.einsum('ij,ij->i', v0_u, v1_u), -1.0, 1.0))
326 | else:
327 | return np.arccos(np.clip(np.einsum('ij,ij->i', v0_u, v1_u), -1.0, 1.0))
328 |
329 |
330 | def unit_vector(vector):
331 | if torch.is_tensor(vector):
332 | return vector / (torch.norm(vector, dim=-1)[:, None] + 1e-8)
333 | else:
334 | return vector / (np.linalg.norm(vector, axis=-1)[:, None] + 1e-8)
335 |
336 |
337 | def create_feature_stack(vertices, faces, num_tokens):
338 | vertices, faces = sort_vertices_and_faces(vertices, faces, num_tokens)
339 |     # per-face features: vertex positions, plus normals, areas, and angles computed below
340 | triangles = vertices[faces, :]
341 | triangles, normals, areas, angles = create_feature_stack_from_triangles(triangles)
342 | return triangles, normals, areas, angles, vertices, faces
343 |
344 |
345 | def create_feature_stack_from_triangles(triangles):
346 | t_areas = area(triangles) * 1e3
347 | t_angles = angle(triangles) / float(np.pi)
348 | t_normals = unit_vector(normal(triangles))
349 | return triangles.reshape(-1, 9), t_normals.reshape(-1, 3), t_areas.reshape(-1, 1), t_angles.reshape(-1, 3)
350 |
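To make the feature scaling above concrete, here is a small illustrative check (not part of the file) that runs `create_feature_stack_from_triangles` on a single unit right triangle, assuming the repository root is on the Python path:

```python
import numpy as np
from dataset.triangles import create_feature_stack_from_triangles

tri = np.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # (1, 3, 3)
flat, normals, areas, angles = create_feature_stack_from_triangles(tri)
print(flat.shape, normals.shape, areas.shape, angles.shape)  # (1, 9) (1, 3) (1, 1) (1, 3)
print(areas)    # [[500.]]            -> area 0.5, scaled by 1e3
print(normals)  # ~[[0. 0. 1.]]       -> unit normal (a small epsilon is used when normalizing)
print(angles)   # [[0.5  0.25 0.25]]  -> 90°, 45°, 45°, divided by pi
```
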
--------------------------------------------------------------------------------
/inference/infer_meshgpt.py:
--------------------------------------------------------------------------------
1 | import random
2 | import sys
3 | from model.pointnet import get_pointnet_classifier
4 | import omegaconf
5 | import torch
6 | from pathlib import Path
7 |
8 | import trimesh
9 |
10 | from dataset.quantized_soup import QuantizedSoupTripletsCreator
11 | from dataset.triangles import TriangleNodesWithFacesAndSequenceIndices
12 | from trainer import get_rvqvae_v0_decoder
13 | from trainer.train_transformer import get_qsoup_model_config
14 | from util.meshlab import meshlab_proc
15 | from util.misc import get_parameters_from_state_dict
16 | from util.visualization import plot_vertices_and_faces
17 | from tqdm import tqdm
18 | from model.transformer import QuantSoupTransformer
19 | from pytorch_lightning import seed_everything
20 |
21 |
22 | @torch.no_grad()
23 | def main(config, mode):
24 | seed_everything(42)
25 | device = torch.device('cuda:0')
26 | vq_cfg = omegaconf.OmegaConf.load(Path(config.vq_resume).parents[1] / "config.yaml")
27 |     dataset = TriangleNodesWithFacesAndSequenceIndices(config, 'train', True, True, config.ft_category)
28 | prompt_num_faces = 4
29 | output_dir_image = Path(f'runs/{config.experiment}/inf_image_{mode}')
30 | output_dir_image.mkdir(exist_ok=True, parents=True)
31 | output_dir_mesh = Path(f'runs/{config.experiment}/inf_mesh_{mode}')
32 | output_dir_mesh.mkdir(exist_ok=True, parents=True)
33 | model_cfg = get_qsoup_model_config(config, vq_cfg.embed_levels)
34 | model = QuantSoupTransformer(model_cfg, vq_cfg)
35 | state_dict = torch.load(config.resume, map_location="cpu")["state_dict"]
36 | sequencer = QuantizedSoupTripletsCreator(config, vq_cfg)
37 | model.load_state_dict(get_parameters_from_state_dict(state_dict, "model"))
38 | model = model.to(device)
39 | model = model.eval()
40 | sequencer = sequencer.to(device)
41 | sequencer = sequencer.eval()
42 | decoder = get_rvqvae_v0_decoder(vq_cfg, config.vq_resume, device)
43 | pnet = get_pointnet_classifier().to(device)
44 |
45 | k = 0
46 | while k < config.num_val_samples:
47 |
48 | data = dataset.get(random.randint(0, len(dataset) - 1))
49 | soup_sequence, face_in_idx, face_out_idx, target = sequencer.get_completion_sequence(
50 | data.x.to(device),
51 | data.edge_index.to(device),
52 | data.faces.to(device),
53 | data.num_vertices,
54 | 1 + 6 * prompt_num_faces
55 | )
56 |
57 | y = None
58 |
59 | if mode == "topp":
60 |             # more diversity but a higher chance of bad sequences
61 | y = model.generate(
62 | soup_sequence, face_in_idx, face_out_idx, sequencer, config.max_val_tokens,
63 | temperature=config.temperature, top_k=config.top_k_tokens, top_p=config.top_p, use_kv_cache=True
64 | )
65 | elif mode == "beam":
66 | # less diversity but more robust
67 | y = model.generate_with_beamsearch(
68 | soup_sequence, face_in_idx, face_out_idx, sequencer, config.max_val_tokens, use_kv_cache=True, beam_width=6
69 | )
70 |
71 | if y is None:
72 | continue
73 |
74 | gen_vertices, gen_faces = sequencer.decode(y[0], decoder)
75 |
76 | try:
77 | mesh = trimesh.Trimesh(vertices=gen_vertices, faces=gen_faces, process=False)
78 | if pnet.classifier_guided_filter(mesh, config.ft_category):
79 | mesh.export(output_dir_mesh / f"{k:06d}.obj")
80 | meshlab_proc(output_dir_mesh / f"{k:06d}.obj")
81 | plot_vertices_and_faces(gen_vertices, gen_faces, output_dir_image / f"{k:06d}.jpg")
82 | print('Generated mesh', k + 1)
83 | k = k + 1
84 | except Exception as e:
85 |             print('Exception occurred:', e)
86 | pass # sometimes the mesh is invalid (ngon) and we don't want to crash
87 |
88 |
89 | if __name__ == "__main__":
90 | cfg = omegaconf.OmegaConf.load(Path(sys.argv[1]).parents[1] / "config.yaml")
91 | cfg.resume = sys.argv[1]
92 | cfg.padding = 0.0
93 | cfg.num_val_samples = int(sys.argv[3])
94 | cfg.sequence_stride = cfg.block_size
95 | cfg.top_p = 0.95
96 | cfg.temperature = 1.0
97 | cfg.low_augment = True
98 | main(cfg, sys.argv[2])
99 |
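For reference, the `__main__` block above expects three positional arguments: the path to a transformer checkpoint (the run's `config.yaml` is loaded from two directory levels above it), the sampling mode (`topp` or `beam`), and the number of meshes to generate, e.g. `python inference/infer_meshgpt.py <path/to/checkpoint.ckpt> topp 32`. Assuming a two-level vocabulary (`vq_cfg.embed_levels == 2`), each face contributes 2 × 3 = 6 tokens, so the prompt length `1 + 6 * prompt_num_faces` corresponds to the start token plus the first four faces of a dataset mesh.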
--------------------------------------------------------------------------------
/model/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Andrej Karpathy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/model/README.md:
--------------------------------------------------------------------------------
1 |
2 | # nanoGPT
3 |
4 | 
5 |
6 | The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training. The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it.
7 |
8 | 
9 |
10 | Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. biggest one currently available as a starting point would be the GPT-2 1.5B model from OpenAI).
11 |
12 | ## install
13 |
14 | ```
15 | pip install torch numpy transformers datasets tiktoken wandb tqdm
16 | ```
17 |
18 | Dependencies:
19 |
20 | - [pytorch](https://pytorch.org) <3
21 | - [numpy](https://numpy.org/install/) <3
22 | - `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints)
23 | - `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
24 | - `tiktoken` for OpenAI's fast BPE code <3
25 | - `wandb` for optional logging <3
26 | - `tqdm` for progress bars <3
27 |
28 | ## quick start
29 |
30 | If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers:
31 |
32 | ```
33 | $ python data/shakespeare_char/prepare.py
34 | ```
35 |
36 | This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT. The size of it very much depends on the computational resources of your system:
37 |
38 | **I have a GPU**. Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file:
39 |
40 | ```
41 | $ python train.py config/train_shakespeare_char.py
42 | ```
43 |
44 | If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
45 |
46 | ```
47 | $ python sample.py --out_dir=out-shakespeare-char
48 | ```
49 |
50 | This generates a few samples, for example:
51 |
52 | ```
53 | ANGELO:
54 | And cowards it be strawn to my bed,
55 | And thrust the gates of my threats,
56 | Because he that ale away, and hang'd
57 | An one with him.
58 |
59 | DUKE VINCENTIO:
60 | I thank your eyes against it.
61 |
62 | DUKE VINCENTIO:
63 | Then will answer him to save the malm:
64 | And what have you tyrannous shall do this?
65 |
66 | DUKE VINCENTIO:
67 | If you have done evils of all disposition
68 | To end his power, the day of thrust for a common men
69 | That I leave, to fight with over-liking
70 | Hasting in a roseman.
71 | ```
72 |
73 | lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later).
74 |
75 | **I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. But even without it, a simple train run could look as follows:
76 |
77 | ```
78 | $ python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0
79 | ```
80 |
81 | Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. Then when we evaluate we get a bit more noisy but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`). This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples, but it's still good fun:
82 |
83 | ```
84 | $ python sample.py --out_dir=out-shakespeare-char --device=cpu
85 | ```
86 | Generates samples like this:
87 |
88 | ```
89 | GLEORKEN VINGHARD III:
90 | Whell's the couse, the came light gacks,
91 | And the for mought you in Aut fries the not high shee
92 | bot thou the sought bechive in that to doth groan you,
93 | No relving thee post mose the wear
94 | ```
95 |
96 | Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
97 |
98 | Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device=mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more.
99 |
100 | ## reproducing GPT-2
101 |
102 | A more serious deep learning professional may be more interested in reproducing GPT-2 results. So here we go - we first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText:
103 |
104 | ```
105 | $ python data/openwebtext/prepare.py
106 | ```
107 |
108 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run:
109 |
110 | ```
111 | $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
112 | ```
113 |
114 | This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it, it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match.
115 |
116 | If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr e.g. across 2 nodes like:
117 |
118 | ```
119 | Run on the first (master) node with example IP 123.456.123.456:
120 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
121 | Run on the worker node:
122 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
123 | ```
124 |
125 | It is a good idea to benchmark your interconnect (e.g. iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply `$ python sample.py`.
126 |
127 | Finally, to train on a single GPU simply run the `$ python train.py` script. Have a look at all of its args, the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs.
128 |
129 | ## baselines
130 |
131 | OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows:
132 |
133 | ```
134 | $ python train.py eval_gpt2
135 | $ python train.py eval_gpt2_medium
136 | $ python train.py eval_gpt2_large
137 | $ python train.py eval_gpt2_xl
138 | ```
139 |
140 | and observe the following losses on train and val:
141 |
142 | | model | params | train loss | val loss |
143 | | ------| ------ | ---------- | -------- |
144 | | gpt2 | 124M | 3.11 | 3.12 |
145 | | gpt2-medium | 350M | 2.85 | 2.84 |
146 | | gpt2-large | 774M | 2.66 | 2.67 |
147 | | gpt2-xl | 1558M | 2.56 | 2.54 |
148 |
149 | However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while reaches loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction.
150 |
151 | ## finetuning
152 |
153 | Finetuning is no different than training, we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like:
154 |
155 | ```
156 | $ python train.py config/finetune_shakespeare.py
157 | ```
158 |
159 | This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory try decreasing the model size (they are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py --out_dir=out-shakespeare`:
160 |
161 | ```
162 | THEODORE:
163 | Thou shalt sell me to the highest bidder: if I die,
164 | I sell thee to the first; if I go mad,
165 | I sell thee to the second; if I
166 | lie, I sell thee to the third; if I slay,
167 | I sell thee to the fourth: so buy or sell,
168 | I tell thee again, thou shalt not sell my
169 | possession.
170 |
171 | JULIET:
172 | And if thou steal, thou shalt not sell thyself.
173 |
174 | THEODORE:
175 | I do not steal; I sell the stolen goods.
176 |
177 | THEODORE:
178 | Thou know'st not what thou sell'st; thou, a woman,
179 | Thou art ever a victim, a thing of no worth:
180 | Thou hast no right, no right, but to be sold.
181 | ```
182 |
183 | Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try!
184 |
185 | ## sampling / inference
186 |
187 | Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. For example, here is a way to sample from the largest available `gpt2-xl` model:
188 |
189 | ```
190 | $ python sample.py \
191 | --init_from=gpt2-xl \
192 | --start="What is the answer to life, the universe, and everything?" \
193 | --num_samples=5 --max_new_tokens=100
194 | ```
195 |
196 | If you'd like to sample from a model you trained, use the `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. `$ python sample.py --start=FILE:prompt.txt`.
197 |
198 | ## efficiency notes
199 |
200 | For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities.
201 |
202 | Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to 135ms / iter. Nice work PyTorch team!
203 |
204 | ## todos
205 |
206 | - Investigate and add FSDP instead of DDP
207 | - Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.)
208 | - Finetune the finetuning script, I think the hyperparams are not great
209 | - Schedule for linear batch size increase during training
210 | - Incorporate other embeddings (rotary, alibi)
211 | - Separate out the optim buffers from model params in checkpoints I think
212 | - Additional logging around network health (e.g. gradient clip events, magnitudes)
213 | - Few more investigations around better init etc.
214 |
215 | ## troubleshooting
216 |
217 | Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages try to disable this by adding `--compile=False` flag. This will slow down the code but at least it will run.
218 |
219 | For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context.
220 |
221 | For more questions/discussions feel free to stop by **#nanoGPT** on Discord:
222 |
223 | [](https://discord.gg/3zy8kqD9Cp)
224 |
225 | ## acknowledgements
226 |
227 | All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT!
228 |
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from util.positional_encoding import get_embedder
4 |
5 |
6 | class BasicBlock(nn.Module):
7 | expansion: int = 1
8 |
9 | def __init__(self, inplanes, planes, base_width=64):
10 | super().__init__()
11 | norm_layer = nn.BatchNorm1d
12 |         # 1D residual basic block (adapted from torchvision's ResNet; stride is fixed to 1 here)
13 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=3, padding=1)
14 | self.bn1 = norm_layer(planes)
15 | self.relu = nn.ReLU(inplace=True)
16 | self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, padding=1)
17 | self.bn2 = norm_layer(planes)
18 |
19 | if inplanes != planes:
20 | self.identity_fn = nn.Sequential(
21 | nn.Conv1d(inplanes, planes * self.expansion, 1),
22 | norm_layer(planes * self.expansion),
23 | )
24 | else:
25 | self.identity_fn = nn.Identity()
26 |
27 | def forward(self, x):
28 | identity = x
29 |
30 | out = self.conv1(x)
31 | out = self.bn1(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv2(out)
35 | out = self.bn2(out)
36 |
37 | out += self.identity_fn(identity)
38 | out = self.relu(out)
39 |
40 | return out
41 |
42 |
43 | class Bottleneck(nn.Module):
44 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
45 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
46 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
47 | # This variant is also known as ResNet V1.5 and improves accuracy according to
48 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
49 |
50 | expansion: int = 4
51 |
52 | def __init__(self, inplanes, planes, base_width=64):
53 | super().__init__()
54 | norm_layer = nn.BatchNorm1d
55 | width = int(planes * (base_width / 64.0))
56 |         # 1D bottleneck block (adapted from torchvision's ResNet; stride is fixed to 1 here)
57 | self.conv1 = nn.Conv1d(inplanes, width, 1)
58 | self.bn1 = norm_layer(width)
59 | self.conv2 = nn.Conv1d(width, width, 3, padding=1)
60 | self.bn2 = norm_layer(width)
61 | self.conv3 = nn.Conv1d(width, planes * self.expansion, 1)
62 | self.bn3 = norm_layer(planes * self.expansion)
63 | self.relu = nn.ReLU(inplace=True)
64 | if inplanes != planes:
65 | self.identity_fn = nn.Sequential(
66 | nn.Conv1d(inplanes, planes * self.expansion, 1),
67 | norm_layer(planes * self.expansion),
68 | )
69 | else:
70 | self.identity_fn = nn.Identity()
71 |
72 | def forward(self, x):
73 | identity = x
74 |
75 | out = self.conv1(x)
76 | out = self.bn1(out)
77 | out = self.relu(out)
78 |
79 | out = self.conv2(out)
80 | out = self.bn2(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | out += self.identity_fn(identity)
87 | out = self.relu(out)
88 |
89 | return out
90 |
91 |
92 | class ResNetDecoder(nn.Module):
93 |
94 | def __init__(self, in_feats, num_tokens, block, layers, zero_init_residual=False, width_per_group=64, ce_output=True):
95 | super().__init__()
96 | norm_layer = nn.BatchNorm1d
97 | self._norm_layer = norm_layer
98 | self.num_tokens = num_tokens
99 | self.inplanes = 512
100 | self.base_width = width_per_group
101 | self.conv1 = nn.Conv1d(in_feats, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)
102 | self.bn1 = norm_layer(self.inplanes)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.layer1 = self._make_layer(block, 512, layers[0])
105 | self.layer2 = self._make_layer(block, 384, layers[1])
106 | self.layer3 = self._make_layer(block, 384, layers[2])
107 | self.layer4 = self._make_layer(block, 320, layers[3])
108 | self.ce_output = ce_output
109 |
110 | if ce_output:
111 | self.fc = nn.Conv1d(320 * block.expansion, self.num_tokens * 9, 1)
112 | else:
113 | self.fc = nn.Conv1d(320 * block.expansion, 9, 1)
114 |
115 | for m in self.modules():
116 | if isinstance(m, nn.Conv1d):
117 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
118 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
119 | nn.init.constant_(m.weight, 1)
120 | nn.init.constant_(m.bias, 0)
121 |
122 | # Zero-initialize the last BN in each residual branch,
123 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
124 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
125 | if zero_init_residual:
126 | for m in self.modules():
127 | if isinstance(m, Bottleneck) and m.bn3.weight is not None:
128 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
129 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
130 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
131 |
132 | def _make_layer(self, block, planes: int, blocks: int):
133 | layers = [block(self.inplanes, planes, self.base_width)]
134 | self.inplanes = planes * block.expansion
135 | for _ in range(1, blocks):
136 | layers.append(
137 | block(
138 | self.inplanes,
139 | planes,
140 | base_width=self.base_width,
141 | )
142 | )
143 | return nn.Sequential(*layers)
144 |
145 | def _forward_impl(self, x):
146 | B, _, N = x.shape
147 | x = self.conv1(x)
148 | x = self.bn1(x)
149 | x = self.relu(x)
150 |
151 | x = self.layer1(x)
152 | x = self.layer2(x)
153 | x = self.layer3(x)
154 | x = self.layer4(x)
155 | x = self.fc(x)
156 | if self.ce_output:
157 | x = x.permute((0, 2, 1)).reshape(B, N, 9, self.num_tokens)
158 | else:
159 | x = x.permute((0, 2, 1)).reshape(B, N, 9)
160 | return x
161 |
162 | def forward(self, x):
163 | return self._forward_impl(x)
164 |
165 |
166 | class ResNetEncoder(nn.Module):
167 |
168 | def __init__(self, in_feats, block, layers, zero_init_residual=False, width_per_group=64):
169 | super().__init__()
170 | norm_layer = nn.BatchNorm1d
171 | self._norm_layer = norm_layer
172 | self.inplanes = 128
173 | self.base_width = width_per_group
174 | self.conv1 = nn.Conv1d(in_feats, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)
175 | self.bn1 = norm_layer(self.inplanes)
176 | self.relu = nn.ReLU(inplace=True)
177 | self.layer1 = self._make_layer(block, 128, layers[0])
178 | self.layer2 = self._make_layer(block, 192, layers[1])
179 | self.layer3 = self._make_layer(block, 256, layers[2])
180 | self.layer4 = self._make_layer(block, 384, layers[3])
181 |
182 | self.fc = nn.Conv1d(384 * block.expansion, 512, 1)
183 |
184 | for m in self.modules():
185 | if isinstance(m, nn.Conv1d):
186 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
187 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
188 | nn.init.constant_(m.weight, 1)
189 | nn.init.constant_(m.bias, 0)
190 |
191 | # Zero-initialize the last BN in each residual branch,
192 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
193 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
194 | if zero_init_residual:
195 | for m in self.modules():
196 | if isinstance(m, Bottleneck) and m.bn3.weight is not None:
197 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
198 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
199 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
200 |
201 | def _make_layer(self, block, planes: int, blocks: int):
202 | layers = [block(self.inplanes, planes, self.base_width)]
203 | self.inplanes = planes * block.expansion
204 | for _ in range(1, blocks):
205 | layers.append(
206 | block(
207 | self.inplanes,
208 | planes,
209 | base_width=self.base_width,
210 | )
211 | )
212 | return nn.Sequential(*layers)
213 |
214 | def _forward_impl(self, x):
215 | x = self.conv1(x)
216 | x = self.bn1(x)
217 | x = self.relu(x)
218 |
219 | x = self.layer1(x)
220 | x = self.layer2(x)
221 | x = self.layer3(x)
222 | x = self.layer4(x)
223 | x = self.fc(x)
224 | B, C, N = x.shape
225 | x = x.permute((0, 2, 1)).reshape(-1, C)
226 | return x
227 |
228 | def forward(self, x):
229 | return self._forward_impl(x)
230 |
231 |
232 | def resnet18_decoder(in_feats, num_tokens, ce_output=True):
233 | return ResNetDecoder(in_feats, num_tokens, BasicBlock, [2, 2, 2, 2], zero_init_residual=True, ce_output=ce_output)
234 |
235 |
236 | def resnet34_decoder(in_feats, num_tokens, ce_output=True):
237 | return ResNetDecoder(in_feats, num_tokens, BasicBlock, [3, 4, 6, 3], zero_init_residual=True, ce_output=ce_output)
238 |
239 |
240 | def resnet18_encoder(in_feats):
241 | return ResNetEncoder(in_feats, BasicBlock, [2, 2, 2, 2], zero_init_residual=True)
242 |
243 |
244 | def resnet34_encoder(in_feats):
245 | return ResNetEncoder(in_feats, BasicBlock, [3, 4, 6, 3], zero_init_residual=True)
246 |
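A small sketch (toy sizes, assuming the repository root is on the Python path) of the two output modes wired up in `ResNetDecoder`: with `ce_output` the head emits per-coordinate logits over the `num_tokens` vocabulary, without it a direct 9-dimensional regression per position:

```python
import torch
from model.decoder import resnet18_decoder

dec_ce = resnet18_decoder(in_feats=512, num_tokens=65, ce_output=True)
dec_reg = resnet18_decoder(in_feats=512, num_tokens=65, ce_output=False)
x = torch.rand(2, 512, 16)    # (batch, input channels, positions)
print(dec_ce(x).shape)        # torch.Size([2, 16, 9, 65])
print(dec_reg(x).shape)       # torch.Size([2, 16, 9])
```
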
247 |
248 | def test_resnet_decoder():
249 | import torch
250 | model = resnet18_decoder(512, 65)
251 | print(model)
252 | print(sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6)
253 |
254 | model = model.cuda()
255 | x_ = torch.rand(1, 512, 2048).cuda()
256 |     y_ = model(x_)  # ResNetDecoder.forward returns a single tensor
257 |     print(y_.shape)
258 |
259 |
260 | def test_resnet_encoder():
261 | import torch
262 | model = resnet18_encoder(70)
263 | print(model)
264 | print(sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6)
265 |
266 | model = model.cuda()
267 | x_ = torch.rand(1, 70, 2048).cuda()
268 | y_ = model(x_)
269 | print(y_.shape)
270 |
271 |
272 | if __name__ == "__main__":
273 | test_resnet_encoder()
274 |
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric
4 | import torch_scatter
5 |
6 | from util.positional_encoding import get_embedder
7 |
8 |
9 | class GraphEncoder(nn.Module):
10 |
11 | def __init__(self, no_max_pool=True, aggr='mean', graph_conv="edge", use_point_features=False, output_dim=512):
12 | super().__init__()
13 | self.no_max_pool = no_max_pool
14 | self.use_point_features = use_point_features
15 | self.embedder, self.embed_dim = get_embedder(10)
16 | self.conv = graph_conv
17 | self.gc1 = get_conv(self.conv, self.embed_dim * 3 + 7, 64, aggr=aggr)
18 | self.gc2 = get_conv(self.conv, 64, 128, aggr=aggr)
19 | self.gc3 = get_conv(self.conv, 128, 256, aggr=aggr)
20 | self.gc4 = get_conv(self.conv, 256, 256, aggr=aggr)
21 | self.gc5 = get_conv(self.conv, 256, output_dim, aggr=aggr)
22 |
23 | self.norm1 = torch_geometric.nn.BatchNorm(64)
24 | self.norm2 = torch_geometric.nn.BatchNorm(128)
25 | self.norm3 = torch_geometric.nn.BatchNorm(256)
26 | self.norm4 = torch_geometric.nn.BatchNorm(256)
27 |
28 | self.relu = nn.ReLU()
29 |
30 | def forward(self, x, edge_index, batch):
31 | x_0 = self.embedder(x[:, :3])
32 | x_1 = self.embedder(x[:, 3:6])
33 | x_2 = self.embedder(x[:, 6:9])
34 | x_n = x[:, 9:12]
35 | x_ar = x[:, 12:13]
36 | x_an_0 = x[:, 13:14]
37 | x_an_1 = x[:, 14:15]
38 | x_an_2 = x[:, 15:]
39 | x = torch.cat([x_0, x_1, x_2, x_n, x_ar, x_an_0, x_an_1, x_an_2], dim=-1)
40 | x = self.relu(self.norm1(self.gc1(x, edge_index)))
41 | x = self.norm2(self.gc2(x, edge_index))
42 | point_features = x
43 | x = self.relu(x)
44 | x = self.relu(self.norm3(self.gc3(x, edge_index)))
45 | x = self.relu(self.norm4(self.gc4(x, edge_index)))
46 | x = self.gc5(x, edge_index)
47 | if not self.no_max_pool:
48 | x = torch_scatter.scatter_max(x, batch, dim=0)[0]
49 | x = x[batch, :]
50 | if self.use_point_features:
51 | return torch.cat([x, point_features], dim=-1)
52 | return x
53 |
54 |
55 | class GraphEncoderTriangleSoup(nn.Module):
56 |
57 | def __init__(self, aggr='mean', graph_conv="edge"):
58 | super().__init__()
59 | self.embedder, self.embed_dim = get_embedder(10)
60 | self.conv = graph_conv
61 | self.gc1 = get_conv(self.conv, self.embed_dim * 3 + 7, 96, aggr=aggr)
62 | self.gc2 = get_conv(self.conv, 96, 192, aggr=aggr)
63 | self.gc3 = get_conv(self.conv, 192, 384, aggr=aggr)
64 | self.gc4 = get_conv(self.conv, 384, 384, aggr=aggr)
65 | self.gc5 = get_conv(self.conv, 384, 576, aggr=aggr)
66 |
67 | self.norm1 = torch_geometric.nn.BatchNorm(96)
68 | self.norm2 = torch_geometric.nn.BatchNorm(192)
69 | self.norm3 = torch_geometric.nn.BatchNorm(384)
70 | self.norm4 = torch_geometric.nn.BatchNorm(384)
71 |
72 | self.relu = nn.ReLU()
73 |
74 | @staticmethod
75 | def distribute_features(features, face_indices, num_vertices):
76 | N, F = features.shape
77 | features = features.reshape(N * 3, F // 3)
78 | assert features.shape[0] == face_indices.shape[0] * face_indices.shape[1], "Features and face indices must match in size"
79 | vertex_features = torch.zeros([num_vertices, features.shape[1]], device=features.device)
80 | torch_scatter.scatter_mean(features, face_indices.reshape(-1), out=vertex_features, dim=0)
81 | distributed_features = vertex_features[face_indices.reshape(-1), :]
82 | distributed_features = distributed_features.reshape(N, F)
83 | return distributed_features
84 |
85 | def forward(self, x, edge_index, faces, num_vertices):
86 | x_0 = self.embedder(x[:, :3])
87 | x_1 = self.embedder(x[:, 3:6])
88 | x_2 = self.embedder(x[:, 6:9])
89 | x = torch.cat([x_0, x_1, x_2, x[:, 9:]], dim=-1)
90 | x = self.relu(self.norm1(self.gc1(x, edge_index)))
91 | x = self.distribute_features(x, faces, num_vertices)
92 | x = self.relu(self.norm2(self.gc2(x, edge_index)))
93 | x = self.distribute_features(x, faces, num_vertices)
94 | x = self.relu(self.norm3(self.gc3(x, edge_index)))
95 | x = self.distribute_features(x, faces, num_vertices)
96 | x = self.relu(self.norm4(self.gc4(x, edge_index)))
97 | x = self.distribute_features(x, faces, num_vertices)
98 | x = self.gc5(x, edge_index)
99 | x = self.distribute_features(x, faces, num_vertices)
100 | return x
101 |
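The `distribute_features` step above averages the per-corner features of all faces that share a vertex and hands the averaged value back to every corner, which is what couples otherwise independent triangle features across the mesh. A toy sketch of the same scatter/gather pattern (two triangles sharing an edge, one scalar feature per corner):

```python
import torch
import torch_scatter

faces = torch.tensor([[0, 1, 2], [1, 2, 3]])   # two triangles sharing the edge (1, 2)
corner_feats = torch.tensor([[1.0], [2.0], [3.0], [4.0], [6.0], [8.0]])  # one row per corner
vertex_feats = torch.zeros(4, 1)
torch_scatter.scatter_mean(corner_feats, faces.reshape(-1), out=vertex_feats, dim=0)
print(vertex_feats.squeeze())                  # tensor([1.0000, 3.0000, 4.5000, 8.0000])
shared_back = vertex_feats[faces.reshape(-1)]  # gather the vertex averages back per corner
print(shared_back.reshape(2, 3))               # corners at shared vertices now agree
```
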
102 |
103 | def get_conv(conv, in_dim, out_dim, aggr):
104 | if conv == 'sage':
105 | return torch_geometric.nn.SAGEConv(in_dim, out_dim, aggr=aggr)
106 | elif conv == 'gat':
107 | return torch_geometric.nn.GATv2Conv(in_dim, out_dim, fill_value=aggr)
108 | elif conv == 'edge':
109 | return torch_geometric.nn.EdgeConv(
110 | torch.nn.Sequential(
111 | torch.nn.Linear(in_dim * 2, 2 * out_dim),
112 | torch.nn.ReLU(),
113 | torch.nn.Linear(2 * out_dim, out_dim),
114 | ),
115 | aggr=aggr,
116 | )
117 |
--------------------------------------------------------------------------------
/model/nanogpt.py:
--------------------------------------------------------------------------------
1 | """
2 | Full definition of a GPT Language Model, all of it in this single file.
3 | References:
4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5 | https://github.com/openai/gpt-2/blob/master/src/model.py
6 | 2) huggingface/transformers PyTorch implementation:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8 | """
9 |
10 | import math
11 | import inspect
12 | from dataclasses import dataclass
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 |
18 | from util.misc import top_p_sampling
19 |
20 |
21 | class LayerNorm(nn.Module):
22 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
23 |
24 | def __init__(self, ndim, bias):
25 | super().__init__()
26 | self.weight = nn.Parameter(torch.ones(ndim))
27 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
28 |
29 | def forward(self, input):
30 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
31 |
32 |
33 | class AttentionBase(nn.Module):
34 |
35 | def __init__(self, config):
36 | super().__init__()
37 | assert config.n_embd % config.n_head == 0
38 | # output projection
39 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
40 | # regularization
41 | self.attn_dropout = nn.Dropout(config.dropout)
42 | self.resid_dropout = nn.Dropout(config.dropout)
43 | self.n_head = config.n_head
44 | self.n_embd = config.n_embd
45 | self.dropout = config.dropout
46 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
47 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
48 | if not self.flash:
49 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
50 | # causal mask to ensure that attention is only applied to the left in the input sequence
51 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
52 | .view(1, 1, config.block_size, config.block_size))
53 |
54 |
55 | class CausalSelfAttention(AttentionBase):
56 |
57 | def __init__(self, config):
58 | super().__init__(config)
59 | # key, query, value projections for all heads, but in a batch
60 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
61 |
62 | def forward(self, x, kv_cache=None, mask=None):
63 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
64 |
65 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
66 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
67 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
68 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
69 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
70 |
71 | if kv_cache is not None:
72 | key_cached, value_cached = kv_cache.unbind(dim=0) # (2, B, T, head_size) -> 2 * (B, T, head_size)
73 | k = torch.cat((key_cached, k), dim=-2) # (B, cache + T, head_size)
74 | v = torch.cat((value_cached, v), dim=-2) # (B, cache + T, head_size)
75 |
76 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
77 | if self.flash:
78 | # efficient attention using Flash Attention CUDA kernels
79 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0, is_causal=mask is None)
80 | else:
81 | # manual implementation of attention
82 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
83 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
84 | att = F.softmax(att, dim=-1)
85 | att = self.attn_dropout(att)
86 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
87 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
88 |
89 | # output projection
90 | y = self.resid_dropout(self.c_proj(y))
91 | return y, None if kv_cache is None else torch.stack((k, v))
92 |
93 |
94 | class CrossAttention(AttentionBase):
95 |
96 | def __init__(self, config):
97 | super().__init__(config)
98 | self.ckv_attn = nn.Linear(config.n_embd, 2 * config.n_embd, bias=config.bias)
99 | self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
100 |
101 | def forward(self, x, encoding):
102 |         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
103 |
104 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
105 | k, v = self.ckv_attn(encoding).split(self.n_embd, dim=2)
106 | q = self.cq_attn(x)
107 |
108 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
109 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
110 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
111 |
112 |         # cross-attention: queries from x attend over keys/values computed from the encoding
113 | if self.flash:
114 | # efficient attention using Flash Attention CUDA kernels
115 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
116 | else:
117 | # manual implementation of attention
118 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
119 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
120 | att = F.softmax(att, dim=-1)
121 | att = self.attn_dropout(att)
122 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
123 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
124 |
125 | # output projection
126 | y = self.resid_dropout(self.c_proj(y))
127 | return y
128 |
129 |
130 | class MLP(nn.Module):
131 |
132 | def __init__(self, config):
133 | super().__init__()
134 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
135 | self.gelu = nn.GELU()
136 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
137 | self.dropout = nn.Dropout(config.dropout)
138 |
139 | def forward(self, x):
140 | x = self.c_fc(x)
141 | x = self.gelu(x)
142 | x = self.c_proj(x)
143 | x = self.dropout(x)
144 | return x
145 |
146 |
147 | class Block(nn.Module):
148 |
149 | def __init__(self, config):
150 | super().__init__()
151 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
152 | self.attn = CausalSelfAttention(config)
153 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
154 | self.mlp = MLP(config)
155 |
156 | def forward(self, x, kv_cache=None, mask=None):
157 | out, kv_cache = self.attn(self.ln_1(x), kv_cache, mask)
158 | x = x + out
159 | x = x + self.mlp(self.ln_2(x))
160 | return x, kv_cache
161 |
162 |
163 | class BlockWithCrossAttention(nn.Module):
164 |
165 | def __init__(self, config):
166 | super().__init__()
167 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
168 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
169 | self.ln_3 = LayerNorm(config.n_embd, bias=config.bias)
170 | self.attn = CausalSelfAttention(config)
171 | self.xattn = CrossAttention(config)
172 | self.mlp = MLP(config)
173 |
174 | def forward(self, x, encoding):
175 |         x = x + self.attn(self.ln_1(x))[0]  # CausalSelfAttention returns (out, kv_cache)
176 |         x = x + self.xattn(self.ln_2(x), encoding)
177 | x = x + self.mlp(self.ln_3(x))
178 | return x
179 |
180 |
181 | @dataclass
182 | class GPTConfig:
183 | block_size: int = 1024
184 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
185 | n_layer: int = 12
186 | n_head: int = 12
187 | n_embd: int = 768
188 | dropout: float = 0.0
189 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
190 |
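For orientation, a minimal sketch (toy sizes, not the project's training configuration) of how `GPTConfig` and the `GPT` class defined just below fit together; `idx` carries token ids and `coord` the 0/1/2 within-triplet index consumed by the `wce` embedding:

```python
import torch
from model.nanogpt import GPT, GPTConfig

config = GPTConfig(block_size=24, vocab_size=131, n_layer=2, n_head=2, n_embd=64, dropout=0.0, bias=False)
model = GPT(config)
idx = torch.randint(0, config.vocab_size - 1, (1, 9))  # (B, T) token ids
coord = torch.arange(9).remainder(3).unsqueeze(0)      # (B, T) values in {0, 1, 2}
logits, loss = model(idx, coord)                       # no targets -> loss is None
print(logits.shape)                                    # torch.Size([1, 1, 131])
```
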
191 |
192 | class GPT(nn.Module):
193 |
194 | def __init__(self, config):
195 | super().__init__()
196 | assert config.vocab_size is not None
197 | assert config.block_size is not None
198 | self.config = config
199 | self.padding_idx = config.vocab_size - 1
200 | self.vocab_size = config.vocab_size
201 | print('Model Padding Index:', self.padding_idx)
202 | self.transformer = nn.ModuleDict(dict(
203 | wte=nn.Embedding(config.vocab_size, config.n_embd, padding_idx=self.padding_idx),
204 | wpe=nn.Embedding(config.block_size, config.n_embd),
205 | wce=nn.Embedding(3, config.n_embd),
206 | drop=nn.Dropout(config.dropout),
207 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
208 | ln_f=LayerNorm(config.n_embd, bias=config.bias),
209 | ))
210 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
211 | # with weight tying when using torch.compile() some warnings get generated:
212 | # "UserWarning: functional_call was passed multiple values for tied weights.
213 | # This behavior is deprecated and will be an error in future versions"
214 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
215 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
216 |
217 | # init all weights
218 | self.apply(self._init_weights)
219 | # apply special scaled init to the residual projections, per GPT-2 paper
220 | for pn, p in self.named_parameters():
221 | if pn.endswith('c_proj.weight'):
222 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
223 |
224 | # report number of parameters
225 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
226 |
227 | def get_num_params(self, non_embedding=True):
228 | """
229 | Return the number of parameters in the model.
230 | For non-embedding count (default), the position embeddings get subtracted.
231 | The token embeddings would too, except due to the parameter sharing these
232 | params are actually used as weights in the final layer, so we include them.
233 | """
234 | n_params = sum(p.numel() for p in self.parameters())
235 | if non_embedding:
236 | n_params -= self.transformer.wpe.weight.numel()
237 | return n_params
238 |
239 | def _init_weights(self, module):
240 | if isinstance(module, nn.Linear):
241 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
242 | if module.bias is not None:
243 | torch.nn.init.zeros_(module.bias)
244 | elif isinstance(module, nn.Embedding):
245 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
246 |
247 | def forward(self, idx, coord, targets=None, kv_cache=None, mask_cache=None):
248 | use_kv_cache = kv_cache is not None
249 | device = idx.device
250 | b, t = idx.size()
251 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
252 |
253 | if kv_cache is not None and kv_cache[0].numel():
254 | pos = kv_cache[0].shape[-2] # kv_cache of shape: num_layers * (2, B, nh, T, hs)
255 | pos_emb = self.transformer.wpe.weight[None, pos // 3] # 1 x n_embd
256 | mask = mask_cache.index_select(2, torch.LongTensor([pos]).to(pos_emb.device))[:, :, :, :pos + 1]
257 | else:
258 | pos = torch.tensor([i // 3 for i in range(t)], dtype=torch.long, device=device)
259 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
260 | mask = None
261 |
262 | # print('embs:', idx.min(), idx.max(), pos.min(), pos.max(), coord.min(), coord.max())
263 | # forward the GPT model itself
264 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
265 |         coord_emb = self.transformer.wce(coord) # coordinate embeddings of shape (b, t, n_embd)
266 | # print('shapes:', tok_emb.shape, pos_emb.shape, coord_emb.shape)
267 | x = self.transformer.drop(tok_emb + pos_emb + coord_emb)
268 |
269 | # apply multiple transformer blocks
270 | new_kv_cache = []
271 | kv_cache = kv_cache or [None] * self.config.n_layer
272 |
273 | for block, kv_cache_layer in zip(self.transformer.h, kv_cache):
274 | x, new_kv = block(x, kv_cache_layer, mask)
275 | new_kv_cache.append(new_kv)
276 |
277 | x = self.transformer.ln_f(x)
278 |
279 | if targets is not None:
280 | # if we are given some desired targets also calculate the loss
281 | logits = self.lm_head(x)
282 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=self.padding_idx)
283 | else:
284 | # inference-time mini-optimization: only forward the lm_head on the very last position
285 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
286 | loss = None
287 |
288 | if not use_kv_cache:
289 | return logits, loss
290 | else:
291 | return logits, new_kv_cache
292 |
293 | def crop_block_size(self, block_size):
294 | # model surgery to decrease the block size if necessary
295 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
296 | # but want to use a smaller block size for some smaller, simpler model
297 | assert block_size <= self.config.block_size
298 | self.config.block_size = block_size
299 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
300 | for block in self.transformer.h:
301 | if hasattr(block.attn, 'bias'):
302 | block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
303 |
304 | @classmethod
305 | def from_pretrained(cls, model_type, override_args=None):
306 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
307 | override_args = override_args or {} # default to empty dict
308 | # only dropout can be overridden see more notes below
309 | assert all(k == 'dropout' for k in override_args)
310 | from transformers import GPT2LMHeadModel
311 | print("loading weights from pretrained gpt: %s" % model_type)
312 |
313 | # n_layer, n_head and n_embd are determined from model_type
314 | config_args = {
315 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
316 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
317 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
318 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
319 | }[model_type]
320 | print("forcing vocab_size=50257, block_size=1024, bias=True")
321 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
322 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
323 | config_args['bias'] = True # always True for GPT model checkpoints
324 | # we can override the dropout rate, if desired
325 | if 'dropout' in override_args:
326 | print(f"overriding dropout rate to {override_args['dropout']}")
327 | config_args['dropout'] = override_args['dropout']
328 | # create a from-scratch initialized minGPT model
329 | config = GPTConfig(**config_args)
330 | model = GPT(config)
331 | sd = model.state_dict()
332 | sd_keys = sd.keys()
333 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
334 |
335 | # init a huggingface/transformers model
336 | model_hf = GPT2LMHeadModel.from_pretrained(model_type)
337 | sd_hf = model_hf.state_dict()
338 |
339 | # copy while ensuring all of the parameters are aligned and match in names and shapes
340 | sd_keys_hf = sd_hf.keys()
341 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
342 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
343 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
344 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
345 | # this means that we have to transpose these weights when we import them
346 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
347 | for k in sd_keys_hf:
348 | if any(k.endswith(w) for w in transposed):
349 | # special treatment for the Conv1D weights we need to transpose
350 | assert sd_hf[k].shape[::-1] == sd[k].shape
351 | with torch.no_grad():
352 | sd[k].copy_(sd_hf[k].t())
353 | else:
354 | # vanilla copy over the other parameters
355 | assert sd_hf[k].shape == sd[k].shape
356 | with torch.no_grad():
357 | sd[k].copy_(sd_hf[k])
358 |
359 | return model
360 |
361 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
362 | return configure_optimizers(self.named_parameters(), weight_decay, learning_rate, betas, device_type)
363 |
364 | def estimate_mfu(self, fwdbwd_per_iter, dt):
365 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
366 | # first estimate the number of flops we do per iteration.
367 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
368 | N = self.get_num_params()
369 | cfg = self.config
370 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
371 | flops_per_token = 6 * N + 12 * L * H * Q * T
372 | flops_per_fwdbwd = flops_per_token * T
373 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
374 | # express our flops throughput as ratio of A100 bfloat16 peak flops
375 | flops_achieved = flops_per_iter * (1.0 / dt) # per second
376 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
377 | mfu = flops_achieved / flops_promised
378 | return mfu
379 |
380 | @torch.no_grad()
381 | def generate(self, idx, coord, max_new_tokens, temperature=1.0, top_k=None, top_p=0.9, use_kv_cache=False):
382 |
383 | if use_kv_cache and (max_new_tokens + idx.shape[-1] - 1) > self.config.block_size:
384 | # print(f"Cannot generate more than {self.config.block_size} tokens with kv cache, setting max new tokens to {self.config.block_size - idx.shape[-1]}")
385 | max_new_tokens = self.config.block_size - idx.shape[-1]
386 |
387 | kv_cache = (
388 | [torch.empty(2, 0, device=idx.device, dtype=idx.dtype) for _ in range(self.config.n_layer)]
389 | if use_kv_cache
390 | else None
391 | )
392 | mask_cache = None
393 | if use_kv_cache:
394 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
395 | mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
396 |
397 | current_coord = coord
398 | one_t = torch.LongTensor([1]).to(coord.device)
399 | for iteration in range(max_new_tokens):
400 |
401 | if not use_kv_cache or (iteration == 0 and idx.shape[-1] > 1):
402 | # if the sequence context is growing too long we must crop it at block_size
403 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
404 | coord_cond = current_coord if current_coord.size(0) <= self.config.block_size else current_coord[-self.config.block_size:]
405 | else:
406 | idx_cond = idx[:, -1:]
407 | coord_cond = current_coord[-1:]
408 | # forward the model to get the logits for the index in the sequence
409 | logits, kv_cache = self(idx_cond, coord_cond, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache)
410 | # pluck the logits at the final step and scale by desired temperature
411 | logits = logits[:, -1, :] / temperature
412 | # optionally crop the logits to only the top k options
413 | if top_k is not None:
414 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
415 | logits[logits < v[:, [-1]]] = -float('Inf')
416 | # apply softmax to convert logits to (normalized) probabilities
417 |
418 | use_hard_constraints = False
419 | if use_hard_constraints:
420 | # Use hard constraints
421 | # stop can only occur when 3 coords are generated
422 | if idx.shape[1] % 3 != 0:
423 | logits[:, self.vocab_size - 2] = -float('Inf')
424 |
425 | # all z less than last z are not allowed
426 | if idx.shape[1] % 3 == 0:
427 | if idx.shape[1] > 2:
428 | last_z = idx[0, -3]
429 | if last_z > 0:
430 | logits[:, :last_z - 1] = -float('Inf')
431 | if idx.shape[1] % 3 == 1:
432 | if idx.shape[1] > 3:
433 | last_z = idx[0, -1]
434 | last_to_last_z = idx[0, -4]
435 | last_y = idx[0, -3]
436 | if last_z == last_to_last_z:
437 | if last_y > 0:
438 | logits[:, :last_y - 1] = -float('Inf')
439 | if idx.shape[1] % 3 == 2:
440 | if idx.shape[1] > 4:
441 | last_z = idx[0, -2]
442 | last_to_last_z = idx[0, -5]
443 | last_y = idx[0, -1]
444 | last_to_last_y = idx[0, -4]
445 | last_x = idx[0, -3]
446 | if last_z == last_to_last_z and last_to_last_y == last_y:
447 | if last_x > 0:
448 | logits[:, :last_x - 1] = -float('Inf')
449 |
450 | # sample from the distribution
451 | if top_p is not None:
452 | idx_next = top_p_sampling(logits, top_p)
453 | else:
454 | # apply softmax to convert logits to (normalized) probabilities
455 | probs = F.softmax(logits, dim=-1)
456 | idx_next = torch.multinomial(probs, num_samples=1)
457 |
458 | # append sampled index to the running sequence and continue
459 | idx = torch.cat((idx, idx_next), dim=1)
460 | current_coord = torch.cat((current_coord, one_t * (current_coord[-1] + 1) % 3), dim=0)
461 | if idx[0, -1] == (self.vocab_size - 2):
462 | break
463 | return idx
464 |
465 |
466 | def configure_optimizers(named_parameters, weight_decay, learning_rate, betas, device_type, additional_params=None):
467 | # start with all of the candidate parameters
468 | param_dict = {pn: p for pn, p in named_parameters}
469 | # filter out those that do not require grad
470 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
471 |     # create optim groups. Any parameter that is 2D will be weight decayed; all others will not.
472 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
473 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
474 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
475 | if additional_params is not None:
476 | if isinstance(additional_params, dict):
477 |             for n, p in additional_params.items():
478 | nodecay_params.append(p)
479 | else:
480 | for additional_param in additional_params:
481 | for n, p in additional_param:
482 | nodecay_params.append(p)
483 | optim_groups = [
484 | {'params': decay_params, 'weight_decay': weight_decay},
485 | {'params': nodecay_params, 'weight_decay': 0.0}
486 | ]
487 | num_decay_params = sum(p.numel() for p in decay_params)
488 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
489 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
490 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
491 | # Create AdamW optimizer and use the fused version if it is available
492 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
493 | use_fused = fused_available and device_type == 'cuda'
494 | extra_args = dict(fused=True) if use_fused else dict()
495 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
496 | print(f"using fused AdamW: {use_fused}")
497 |
498 | return optimizer
499 |
--------------------------------------------------------------------------------
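
A note on the generate() loop in /model/nanogpt.py above: when top_p is set, sampling is delegated to top_p_sampling from util/misc.py, whose body is not reproduced in this excerpt. The following is only a minimal nucleus-sampling sketch consistent with how the helper is called (logits of shape (B, vocab_size) in, a (B, 1) LongTensor of sampled token ids out); the name top_p_sampling_sketch and the exact masking convention are assumptions, not the repository's code.

import torch
import torch.nn.functional as F


def top_p_sampling_sketch(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus sampling sketch: keep the smallest set of tokens whose cumulative
    probability exceeds top_p, renormalize, and sample from that set."""
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # drop a token if the cumulative mass *before* it already exceeds top_p,
    # so the token that crosses the threshold is still kept
    remove = (cumulative - sorted_probs) > top_p
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(sorted_probs, num_samples=1)  # index into the sorted order
    return sorted_idx.gather(-1, sampled)                     # map back to vocabulary ids
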
/model/pointnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from time import time
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from trimesh import transformations
9 | 
10 | 
11 | 
12 |
13 |
14 | def timeit(tag, t):
15 | print("{}: {}s".format(tag, time() - t))
16 | return time()
17 |
18 | def pc_normalize(pc):
19 | l = pc.shape[0]
20 | centroid = np.mean(pc, axis=0)
21 | pc = pc - centroid
22 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
23 | pc = pc / m
24 | return pc
25 |
26 | def square_distance(src, dst):
27 | """
28 |     Calculate the squared Euclidean distance between each pair of points.
29 |
30 | src^T * dst = xn * xm + yn * ym + zn * zm;
31 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
32 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
33 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
34 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
35 |
36 | Input:
37 | src: source points, [B, N, C]
38 | dst: target points, [B, M, C]
39 | Output:
40 | dist: per-point square distance, [B, N, M]
41 | """
42 | B, N, _ = src.shape
43 | _, M, _ = dst.shape
44 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
45 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
46 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
47 | return dist
48 |
49 |
50 | def index_points(points, idx):
51 | """
52 |
53 | Input:
54 | points: input points data, [B, N, C]
55 | idx: sample index data, [B, S]
56 | Return:
57 | new_points:, indexed points data, [B, S, C]
58 | """
59 | device = points.device
60 | B = points.shape[0]
61 | view_shape = list(idx.shape)
62 | view_shape[1:] = [1] * (len(view_shape) - 1)
63 | repeat_shape = list(idx.shape)
64 | repeat_shape[0] = 1
65 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
66 | new_points = points[batch_indices, idx, :]
67 | return new_points
68 |
69 |
70 | def farthest_point_sample(xyz, npoint):
71 | """
72 | Input:
73 | xyz: pointcloud data, [B, N, 3]
74 | npoint: number of samples
75 | Return:
76 | centroids: sampled pointcloud index, [B, npoint]
77 | """
78 | device = xyz.device
79 | B, N, C = xyz.shape
80 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
81 | distance = torch.ones(B, N).to(device) * 1e10
82 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
83 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
84 | for i in range(npoint):
85 | centroids[:, i] = farthest
86 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
87 | dist = torch.sum((xyz - centroid) ** 2, -1)
88 | mask = dist < distance
89 | distance[mask] = dist[mask]
90 | farthest = torch.max(distance, -1)[1]
91 | return centroids
92 |
93 |
94 | def query_ball_point(radius, nsample, xyz, new_xyz):
95 | """
96 | Input:
97 | radius: local region radius
98 | nsample: max sample number in local region
99 | xyz: all points, [B, N, 3]
100 | new_xyz: query points, [B, S, 3]
101 | Return:
102 | group_idx: grouped points index, [B, S, nsample]
103 | """
104 | device = xyz.device
105 | B, N, C = xyz.shape
106 | _, S, _ = new_xyz.shape
107 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
108 | sqrdists = square_distance(new_xyz, xyz)
109 | group_idx[sqrdists > radius ** 2] = N
110 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
111 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
112 | mask = group_idx == N
113 | group_idx[mask] = group_first[mask]
114 | return group_idx
115 |
116 |
117 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
118 | """
119 | Input:
120 | npoint:
121 | radius:
122 | nsample:
123 | xyz: input points position data, [B, N, 3]
124 | points: input points data, [B, N, D]
125 | Return:
126 |         new_xyz: sampled points position data, [B, npoint, 3]
127 | new_points: sampled points data, [B, npoint, nsample, 3+D]
128 | """
129 | B, N, C = xyz.shape
130 | S = npoint
131 |     fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
132 | new_xyz = index_points(xyz, fps_idx)
133 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
134 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
135 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
136 |
137 | if points is not None:
138 | grouped_points = index_points(points, idx)
139 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
140 | else:
141 | new_points = grouped_xyz_norm
142 | if returnfps:
143 | return new_xyz, new_points, grouped_xyz, fps_idx
144 | else:
145 | return new_xyz, new_points
146 |
147 |
148 | def sample_and_group_all(xyz, points):
149 | """
150 | Input:
151 | xyz: input points position data, [B, N, 3]
152 | points: input points data, [B, N, D]
153 | Return:
154 | new_xyz: sampled points position data, [B, 1, 3]
155 | new_points: sampled points data, [B, 1, N, 3+D]
156 | """
157 | device = xyz.device
158 | B, N, C = xyz.shape
159 | new_xyz = torch.zeros(B, 1, C).to(device)
160 | grouped_xyz = xyz.view(B, 1, N, C)
161 | if points is not None:
162 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
163 | else:
164 | new_points = grouped_xyz
165 | return new_xyz, new_points
166 |
167 |
168 | class PointNetSetAbstraction(nn.Module):
169 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
170 | super(PointNetSetAbstraction, self).__init__()
171 | self.npoint = npoint
172 | self.radius = radius
173 | self.nsample = nsample
174 | self.mlp_convs = nn.ModuleList()
175 | self.mlp_bns = nn.ModuleList()
176 | last_channel = in_channel
177 | for out_channel in mlp:
178 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
179 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
180 | last_channel = out_channel
181 | self.group_all = group_all
182 |
183 | def forward(self, xyz, points):
184 | """
185 | Input:
186 | xyz: input points position data, [B, C, N]
187 | points: input points data, [B, D, N]
188 | Return:
189 | new_xyz: sampled points position data, [B, C, S]
190 | new_points_concat: sample points feature data, [B, D', S]
191 | """
192 | xyz = xyz.permute(0, 2, 1)
193 | if points is not None:
194 | points = points.permute(0, 2, 1)
195 |
196 | if self.group_all:
197 | new_xyz, new_points = sample_and_group_all(xyz, points)
198 | else:
199 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
200 | # new_xyz: sampled points position data, [B, npoint, C]
201 | # new_points: sampled points data, [B, npoint, nsample, C+D]
202 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
203 | for i, conv in enumerate(self.mlp_convs):
204 | bn = self.mlp_bns[i]
205 | new_points = F.relu(bn(conv(new_points)))
206 |
207 | new_points = torch.max(new_points, 2)[0]
208 | new_xyz = new_xyz.permute(0, 2, 1)
209 | return new_xyz, new_points
210 |
211 |
212 | class PointNetSetAbstractionMsg(nn.Module):
213 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
214 | super(PointNetSetAbstractionMsg, self).__init__()
215 | self.npoint = npoint
216 | self.radius_list = radius_list
217 | self.nsample_list = nsample_list
218 | self.conv_blocks = nn.ModuleList()
219 | self.bn_blocks = nn.ModuleList()
220 | for i in range(len(mlp_list)):
221 | convs = nn.ModuleList()
222 | bns = nn.ModuleList()
223 | last_channel = in_channel + 3
224 | for out_channel in mlp_list[i]:
225 | convs.append(nn.Conv2d(last_channel, out_channel, 1))
226 | bns.append(nn.BatchNorm2d(out_channel))
227 | last_channel = out_channel
228 | self.conv_blocks.append(convs)
229 | self.bn_blocks.append(bns)
230 |
231 | def forward(self, xyz, points):
232 | """
233 | Input:
234 | xyz: input points position data, [B, C, N]
235 | points: input points data, [B, D, N]
236 | Return:
237 | new_xyz: sampled points position data, [B, C, S]
238 | new_points_concat: sample points feature data, [B, D', S]
239 | """
240 | xyz = xyz.permute(0, 2, 1)
241 | if points is not None:
242 | points = points.permute(0, 2, 1)
243 |
244 | B, N, C = xyz.shape
245 | S = self.npoint
246 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
247 | new_points_list = []
248 | for i, radius in enumerate(self.radius_list):
249 | K = self.nsample_list[i]
250 | group_idx = query_ball_point(radius, K, xyz, new_xyz)
251 | grouped_xyz = index_points(xyz, group_idx)
252 | grouped_xyz -= new_xyz.view(B, S, 1, C)
253 | if points is not None:
254 | grouped_points = index_points(points, group_idx)
255 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
256 | else:
257 | grouped_points = grouped_xyz
258 |
259 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
260 | for j in range(len(self.conv_blocks[i])):
261 | conv = self.conv_blocks[i][j]
262 | bn = self.bn_blocks[i][j]
263 | grouped_points = F.relu(bn(conv(grouped_points)))
264 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
265 | new_points_list.append(new_points)
266 |
267 | new_xyz = new_xyz.permute(0, 2, 1)
268 | new_points_concat = torch.cat(new_points_list, dim=1)
269 | return new_xyz, new_points_concat
270 |
271 |
272 | class PointNetFeaturePropagation(nn.Module):
273 | def __init__(self, in_channel, mlp):
274 | super(PointNetFeaturePropagation, self).__init__()
275 | self.mlp_convs = nn.ModuleList()
276 | self.mlp_bns = nn.ModuleList()
277 | last_channel = in_channel
278 | for out_channel in mlp:
279 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
280 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
281 | last_channel = out_channel
282 |
283 | def forward(self, xyz1, xyz2, points1, points2):
284 | """
285 | Input:
286 | xyz1: input points position data, [B, C, N]
287 | xyz2: sampled input points position data, [B, C, S]
288 | points1: input points data, [B, D, N]
289 | points2: input points data, [B, D, S]
290 | Return:
291 | new_points: upsampled points data, [B, D', N]
292 | """
293 | xyz1 = xyz1.permute(0, 2, 1)
294 | xyz2 = xyz2.permute(0, 2, 1)
295 |
296 | points2 = points2.permute(0, 2, 1)
297 | B, N, C = xyz1.shape
298 | _, S, _ = xyz2.shape
299 |
300 | if S == 1:
301 | interpolated_points = points2.repeat(1, N, 1)
302 | else:
303 | dists = square_distance(xyz1, xyz2)
304 | dists, idx = dists.sort(dim=-1)
305 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
306 |
307 | dist_recip = 1.0 / (dists + 1e-8)
308 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
309 | weight = dist_recip / norm
310 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
311 |
312 | if points1 is not None:
313 | points1 = points1.permute(0, 2, 1)
314 | new_points = torch.cat([points1, interpolated_points], dim=-1)
315 | else:
316 | new_points = interpolated_points
317 |
318 | new_points = new_points.permute(0, 2, 1)
319 | for i, conv in enumerate(self.mlp_convs):
320 | bn = self.mlp_bns[i]
321 | new_points = F.relu(bn(conv(new_points)))
322 | return new_points
323 |
324 |
325 | class PointNet(nn.Module):
326 |
327 | def __init__(self, num_class, normal_channel=True):
328 | super(PointNet, self).__init__()
329 | in_channel = 6 if normal_channel else 3
330 | self.normal_channel = normal_channel
331 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
332 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
333 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
334 | self.fc1 = nn.Linear(1024, 512)
335 | self.bn1 = nn.BatchNorm1d(512)
336 | self.drop1 = nn.Dropout(0.4)
337 | self.fc2 = nn.Linear(512, 256)
338 | self.bn2 = nn.BatchNorm1d(256)
339 | self.drop2 = nn.Dropout(0.4)
340 | self.fc3 = nn.Linear(256, num_class)
341 |
342 | def forward(self, xyz, return_feats=False):
343 | B, _, _ = xyz.shape
344 | if self.normal_channel:
345 | norm = xyz[:, 3:, :]
346 | xyz = xyz[:, :3, :]
347 | else:
348 | norm = None
349 | l1_xyz, l1_points = self.sa1(xyz, norm)
350 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
351 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
352 | x = l3_points.view(B, 1024)
353 | x = self.drop1(F.relu(self.bn1(self.fc1(x))))
354 | x = self.drop2(F.relu(self.bn2(self.fc2(x))))
355 | feats = x
356 | x = self.fc3(x)
357 | x = F.log_softmax(x, -1)
358 | if return_feats:
359 | return x, feats
360 | return x, l3_points
361 |
362 | @torch.no_grad()
363 | def classifier_guided_filter(self, mesh, expected_category):
364 | interesting_categories = {
365 | '03001627': ([8, 30], 0.05),
366 | '04379243': ([33, 3], 0.03),
367 | }
368 | category = interesting_categories[expected_category][0]
369 | threshold = interesting_categories[expected_category][1]
370 | points = get_point_cloud(mesh, nvotes=8).cuda()
371 | points = points.transpose(2, 1)
372 | pred, _ = self(points)
373 | pred = pred.mean(dim=0)
374 | pred_probability = torch.nn.functional.softmax(pred.unsqueeze(0), dim=-1)[0]
375 | pval = pred_probability[category].max().item()
376 | return pval > threshold
377 |
378 |
379 | def get_pointnet_classifier():
380 | classifier = PointNet(40, normal_channel=False)
381 | checkpoint = torch.load('pretrained/pointnet.pth')
382 | classifier.load_state_dict(checkpoint['model_state_dict'])
383 | classifier = classifier.eval()
384 | return classifier
385 |
386 |
387 | def pc_normalize(pc):
388 | centroid = np.mean(pc, axis=0)
389 | pc = pc - centroid
390 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
391 | pc = pc / m
392 | return pc
393 |
394 |
395 | def get_point_cloud(mesh, nvotes=4):
396 | point_batch = []
397 | for _ in range(nvotes):
398 | tmesh = mesh
399 | rot_matrix = transformations.rotation_matrix(-math.pi / 2, [1, 0, 0], [0, 0, 0])
400 | tmesh = tmesh.apply_transform(rot_matrix)
401 | rot_matrix = transformations.rotation_matrix(-math.pi / 2, [0, 1, 0], [0, 0, 0])
402 | tmesh = tmesh.apply_transform(rot_matrix)
403 | point_set = pc_normalize(tmesh.sample(1024))
404 | point_batch.append(point_set[np.newaxis, :, :])
405 | point_batch = torch.from_numpy(np.concatenate(point_batch, axis=0)).float()
406 | return point_batch
407 |
--------------------------------------------------------------------------------
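
A note on square_distance() in /model/pointnet.py above: it computes squared pairwise distances via the expansion ||src - dst||^2 = ||src||^2 + ||dst||^2 - 2 <src, dst>, avoiding an explicit (B, N, M, C) difference tensor. A quick self-contained check against torch.cdist (toy shapes, run from the repository root so that model.pointnet is importable):

import torch

from model.pointnet import square_distance

src = torch.rand(2, 5, 3)   # [B, N, C] source points
dst = torch.rand(2, 7, 3)   # [B, M, C] target points
reference = torch.cdist(src, dst) ** 2  # squared pairwise Euclidean distances, [B, N, M]
print(torch.allclose(square_distance(src, dst), reference, atol=1e-4))  # expected: True
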
/model/softargmax.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def softargmax(logits):
5 | smax = torch.nn.functional.softmax(logits, dim=-1)
6 | tokens = torch.tensor(list(range(logits.shape[-1])), dtype=torch.float32, device=logits.device).reshape(1, 1, -1).expand_as(smax)
7 | return torch.sum(tokens * smax, dim=-1)
8 |
9 |
10 | if __name__ == '__main__':
11 | logits_ = torch.tensor([
12 | [[[2, 0.1, 0.2, 0.5]],
13 | [[0.7, 0.1, 0.2, 5]],
14 | [[0, 10, 0.4, 0.5]],
15 | ],
16 | [[[0, 0.1, 0.2, 0.9]],
17 | [[0.7, 0.1, 100, 5]],
18 | [[0, 1, 4, 5]],
19 | ]
20 | ], requires_grad=True)
21 | print(softargmax(logits_), softargmax(logits_).requires_grad)
22 |
--------------------------------------------------------------------------------
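
A note on softargmax() in /model/softargmax.py above: it returns the expected index under the softmax distribution, sum_i i * softmax(logits)_i, which stays differentiable where a hard argmax would not. A small illustration (arbitrary toy logits, run from the repository root) showing that sharper logits push the soft estimate towards the hard argmax:

import torch

from model.softargmax import softargmax

logits = torch.tensor([[[0.0, 1.0, 3.0, 2.0]]])  # hard argmax is index 2
for scale in (1.0, 5.0, 50.0):
    # scaling the logits sharpens the softmax; the soft index estimate moves
    # towards 2.0 while remaining differentiable with respect to the logits
    print(scale, softargmax(logits * scale))
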
/model/transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | from dataset import get_shifted_sequence
8 | from model.transformer_base import TransformerBase
9 | from model.nanogpt import Block, LayerNorm
10 | from util.misc import top_p_sampling
11 | from tqdm import tqdm
12 |
13 |
14 | class QuantSoupTransformer(TransformerBase):
15 |
16 | def __init__(self, config, vq_config):
17 | super().__init__()
18 | assert config.block_size is not None
19 | self.config = config
20 | self.padding_idx = 2
21 | self.tokens_per_face = config.finemb_size
22 |         self.finemb_size = 3 + config.finemb_size # 3 extra slots for start/stop/pad + the face-inner positions
23 | self.foutemb_size = 3 + config.foutemb_size
24 | vocab_size = vq_config.n_embed + 1 + 1 + 1 # +1 for start, +1 for stop, +1 for pad
25 | self.vocab_size = vocab_size
26 | print('Model Vocab Size:', vocab_size)
27 | print('Model Padding Index:', self.padding_idx)
28 | print('Model Fin Size:', self.finemb_size)
29 | print('Model Fout Size:', self.foutemb_size)
30 | self.input_layer = nn.Linear(vq_config.embed_dim, config.n_embd)
31 | self.extra_embeds = nn.Embedding(3, config.n_embd, padding_idx=self.padding_idx)
32 | self.transformer = nn.ModuleDict(dict(
33 | wpe=nn.Embedding(config.block_size, config.n_embd),
34 | wfie=nn.Embedding(self.finemb_size, config.n_embd, padding_idx=self.padding_idx),
35 | wfoe=nn.Embedding(self.foutemb_size, config.n_embd, padding_idx=self.padding_idx),
36 | drop=nn.Dropout(config.dropout),
37 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
38 | ln_f=LayerNorm(config.n_embd, bias=config.bias),
39 | ))
40 | self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
41 |
42 | # init all weights
43 | self.apply(self._init_weights)
44 | # apply special scaled init to the residual projections, per GPT-2 paper
45 | for pn, p in self.named_parameters():
46 | if pn.endswith('c_proj.weight'):
47 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
48 |
49 | # report number of parameters
50 | print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
51 |
52 | def forward(self, idx, fin, fout, tokenizer, targets=None, kv_cache=None, mask_cache=None):
53 | use_kv_cache = kv_cache is not None
54 | device = idx.device
55 | b, t = idx.size()
56 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
57 |
58 | embed = torch.zeros((b * t, self.config.n_embd), dtype=torch.float32, device=device)
59 | idx_in_extra = torch.isin(idx, torch.LongTensor([0, 1, 2]).to(device)).reshape(-1)
60 | idx_flat = idx.reshape(-1)
61 | embed[idx_in_extra, :] = self.extra_embeds(idx_flat[idx_in_extra])
62 | embed[~idx_in_extra, :] = self.input_layer(tokenizer.embed(idx_flat[~idx_in_extra] - 3))
63 | tok_emb = embed.reshape(b, t, -1) # token embeddings of shape (b, t, n_embd)
64 | finemb = self.transformer.wfie(fin) # face inner embeddings of shape (t, n_embd)
65 | foutemb = self.transformer.wfoe(fout) # face outer embeddings of shape (t, n_embd)
66 |
67 | # position embedding
68 |
69 | if kv_cache is not None and kv_cache[0].numel():
70 | pos = kv_cache[0].shape[-2] # kv_cache of shape: num_layers * (2, B, nh, T, hs)
71 | pos_emb = self.transformer.wpe.weight[None, pos] # 1 x n_embd
72 | mask = mask_cache.index_select(2, torch.LongTensor([pos]).to(pos_emb.device))[:, :, :, :pos + 1]
73 | else:
74 | pos = torch.tensor([i for i in range(t)], dtype=torch.long, device=device)
75 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
76 | mask = None
77 |
78 | sum_emb = tok_emb + pos_emb + finemb + foutemb
79 | # print('shapes:', tok_emb.shape, pos_emb.shape, coord_emb.shape)
80 | x = self.transformer.drop(sum_emb)
81 |
82 | # apply multiple transformer blocks
83 | new_kv_cache = []
84 | kv_cache = kv_cache or [None] * self.config.n_layer
85 |
86 | for block, kv_cache_layer in zip(self.transformer.h, kv_cache):
87 | x, new_kv = block(x, kv_cache_layer, mask)
88 | new_kv_cache.append(new_kv)
89 |
90 | x = self.transformer.ln_f(x)
91 |
92 | if targets is not None:
93 | # if we are given some desired targets also calculate the loss
94 | logits = self.lm_head(x)
95 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=self.padding_idx)
96 | else:
97 | # inference-time mini-optimization: only forward the lm_head on the very last position
98 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
99 | loss = None
100 |
101 | if not use_kv_cache:
102 | return logits, loss
103 | else:
104 | return logits, new_kv_cache
105 |
106 | @torch.no_grad()
107 | def generate(self, idx, fin, fout, tokenizer, max_new_tokens=10000, temperature=1.0, top_k=None, top_p=0.9, use_kv_cache=False):
108 | """
109 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
110 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
111 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
112 | """
113 | if use_kv_cache and (max_new_tokens + idx.shape[-1] - 1) > self.config.block_size:
114 | # print(f"Cannot generate more than {self.config.block_size} tokens with kv cache, setting max new tokens to {self.config.block_size - idx.shape[-1]}")
115 | max_new_tokens = self.config.block_size - idx.shape[-1]
116 |
117 | kv_cache = (
118 | [torch.empty(2, 0, device=idx.device, dtype=idx.dtype) for _ in range(self.config.n_layer)]
119 | if use_kv_cache
120 | else None
121 | )
122 | mask_cache = None
123 | if use_kv_cache:
124 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
125 | mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
126 |
127 | current_fin = fin
128 | current_fout = fout
129 | one_t = torch.LongTensor([1]).to(fin.device)
130 | for iteration in range(max_new_tokens):
131 |
132 | if not use_kv_cache or (iteration == 0 and idx.shape[-1] > 1):
133 | # if the sequence context is growing too long we must crop it at block_size
134 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
135 | fin_cond = current_fin if current_fin.size(1) <= self.config.block_size else current_fin[:, -self.config.block_size:]
136 | fout_cond = current_fout if current_fout.size(1) <= self.config.block_size else current_fout[:, -self.config.block_size:]
137 | fout_cond = torch.from_numpy(get_shifted_sequence(fout_cond[0].cpu().numpy())).to(idx_cond.device).unsqueeze(0)
138 | else:
139 | idx_cond = idx[:, -1:]
140 | fin_cond = current_fin[:, -1:]
141 |                 fout_cond = current_fout[:, -1:] # note: no shifting needed here since we assume block_size is large enough
142 | # forward the model to get the logits for the index in the sequence
143 | logits, kv_cache = self(idx_cond, fin_cond, fout_cond, tokenizer, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache)
144 | # pluck the logits at the final step and scale by desired temperature
145 | logits = logits[:, -1, :] / temperature
146 | # optionally crop the logits to only the top k options
147 | if top_k is not None:
148 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
149 | logits[logits < v[:, [-1]]] = -float('Inf')
150 |
151 | # TODO: Introduce hard constraints
152 |
153 | # sample from the distribution
154 | # apply softmax to convert logits to (normalized) probabilities
155 | if top_p is not None:
156 | idx_next = top_p_sampling(logits, top_p)
157 | else:
158 | probs = F.softmax(logits, dim=-1)
159 | idx_next = torch.multinomial(probs, num_samples=1)
160 | # append sampled index to the running sequence and continue
161 | idx = torch.cat((idx, idx_next), dim=1)
162 | last_fin_cond = current_fin[0, -1]
163 | if last_fin_cond == self.finemb_size - 1 or (iteration == 0 and idx.shape[-1] == 2):
164 | current_fin = torch.cat((current_fin, (3 * one_t[0]).unsqueeze(0).unsqueeze(0)), dim=1)
165 | current_fout = torch.cat((current_fout, (current_fout[0, -1] + 1).unsqueeze(0).unsqueeze(0)), dim=1)
166 | else:
167 | current_fin = torch.cat((current_fin, (current_fin[0, -1] + 1).unsqueeze(0).unsqueeze(0)), dim=1)
168 | current_fout = torch.cat((current_fout, (current_fout[0, -1]).unsqueeze(0).unsqueeze(0)), dim=1)
169 | if idx_next == 1:
170 | return idx
171 | return None
172 |
173 |
174 | @torch.no_grad()
175 | def generate_with_beamsearch(self, idx, fin, fout, tokenizer, max_new_tokens=10000, use_kv_cache=False, beam_width=6):
176 |
177 | backup_beams = []
178 | backup_beam_prob = []
179 | max_new_tokens = self.config.block_size - idx.shape[-1]
180 |
181 | kv_cache = (
182 | [torch.empty(2, 0, device=idx.device, dtype=idx.dtype) for _ in range(self.config.n_layer)]
183 | if use_kv_cache
184 | else None
185 | )
186 |
187 | mask_cache = None
188 |
189 | if use_kv_cache:
190 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
191 | mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
192 |
193 | current_fin = fin
194 | current_fout = fout
195 | one_t = torch.LongTensor([1]).to(fin.device)
196 |
197 | idx = idx.repeat((beam_width, 1, 1)).transpose(0, 1).flatten(end_dim=-2)
198 | current_fin = current_fin.repeat((beam_width, 1, 1)).transpose(0, 1).flatten(end_dim=-2)
199 | current_fout = current_fout.repeat((beam_width, 1, 1)).transpose(0, 1).flatten(end_dim=-2)
200 |
201 | logits, kv_cache = self(idx, fin, fout, tokenizer, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache)
202 |
203 | vocabulary_size = logits.shape[-1]
204 | probabilities, top_k_indices = logits[0, 0, :].squeeze().log_softmax(-1).topk(k=beam_width, axis=-1)
205 |
206 | next_chars = top_k_indices.reshape(-1, 1)
207 | idx = torch.cat((idx, next_chars), axis=-1)
208 |
209 | last_fin_cond = current_fin[0, -1] # same for all beams
210 | if last_fin_cond == self.finemb_size - 1 or (idx.shape[-1] == 2):
211 | current_fin = torch.cat((current_fin, (3 * one_t[0]).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
212 | current_fout = torch.cat((current_fout, (current_fout[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
213 | else:
214 | current_fin = torch.cat((current_fin, (current_fin[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
215 | current_fout = torch.cat((current_fout, current_fout[0, -1].unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
216 |
217 | for iteration in tqdm(range(max_new_tokens - 1), desc='beam_search'):
218 | if not use_kv_cache:
219 | idx_cond = idx
220 | fin_cond = current_fin
221 | fout_cond = current_fout
222 | else:
223 | idx_cond = idx[:, -1:]
224 | fin_cond = current_fin[:, -1:]
225 | fout_cond = current_fout[:, -1:]
226 |
227 | # forward the model to get the logits for the index in the sequence
228 | logits, kv_cache = self(idx_cond, fin_cond, fout_cond, tokenizer, kv_cache=kv_cache if use_kv_cache else None, mask_cache=mask_cache)
229 |
230 | next_probabilities = logits.log_softmax(-1)
231 |
232 | next_probabilities = next_probabilities.reshape((-1, beam_width, next_probabilities.shape[-1]))
233 | probabilities = probabilities.unsqueeze(-1) + next_probabilities
234 | probabilities = probabilities.flatten(start_dim=1)
235 | probabilities, top_k_indices = probabilities.topk(k=beam_width, axis=-1)
236 | next_indices = torch.remainder(top_k_indices, vocabulary_size).flatten().unsqueeze(-1)
237 | best_candidates = (top_k_indices / vocabulary_size).long()
238 | best_candidates += torch.arange(idx.shape[0] // beam_width, device=idx.device).unsqueeze(-1) * beam_width
239 | idx = idx[best_candidates].flatten(end_dim=-2)
240 | for block_idx in range(len(kv_cache)):
241 | kv_cache[block_idx] = kv_cache[block_idx][:, best_candidates.flatten(), :, :, :]
242 | idx = torch.cat((idx, next_indices), axis=1)
243 |
244 | # update fin and fout
245 | last_fin_cond = current_fin[0, -1] # same for all beams
246 | if last_fin_cond == self.finemb_size - 1 or (idx.shape[-1] == 2):
247 | current_fin = torch.cat((current_fin, (3 * one_t[0]).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
248 | current_fout = torch.cat((current_fout, (current_fout[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
249 | else:
250 | current_fin = torch.cat((current_fin, (current_fin[0, -1] + 1).unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
251 | current_fout = torch.cat((current_fout, current_fout[0, -1].unsqueeze(0).unsqueeze(0).expand((beam_width, -1))), dim=1)
252 |
253 | amax = probabilities.flatten().argmax()
254 | if idx[amax, -1] == 1:
255 | return idx[amax: amax + 1, :]
256 | for beam_idx in range(beam_width):
257 | if idx[beam_idx, -1] == 1:
258 | backup_beams.append(idx[beam_idx: beam_idx + 1, :])
259 | backup_beam_prob.append(probabilities[0, beam_idx].item())
260 | return None
261 |
--------------------------------------------------------------------------------
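
A note on generate_with_beamsearch() in /model/transformer.py above: the core bookkeeping is a topk over the flattened beam x vocabulary scores, after which the parent beam is recovered by integer division and the appended token by the remainder. A standalone sketch of just that step with hypothetical toy sizes (not the model's actual vocabulary):

import torch

beam_width, vocab_size = 3, 5
beam_logprobs = torch.tensor([-0.1, -0.5, -0.9])  # running score of each current beam (toy values)
next_logprobs = torch.log_softmax(torch.randn(beam_width, vocab_size), dim=-1)  # per-beam next-token log-probs

# total score of every (beam, token) continuation, flattened to one axis
scores = (beam_logprobs.unsqueeze(-1) + next_logprobs).flatten()
top_scores, flat_idx = scores.topk(k=beam_width)

# recover which beam each winner extends and which token it appends,
# mirroring the `/ vocabulary_size` and `torch.remainder` lines above
parent_beam = torch.div(flat_idx, vocab_size, rounding_mode='floor')
next_token = torch.remainder(flat_idx, vocab_size)
print(parent_beam, next_token, top_scores)
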
/model/transformer_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from model.nanogpt import configure_optimizers
5 |
6 |
7 | class TransformerBase(nn.Module):
8 |
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def get_num_params(self):
13 | """
14 | Return the number of parameters in the model.
15 |         Unlike the original nanoGPT helper, nothing is subtracted here:
16 |         the count includes all parameters returned by named_parameters(),
17 |         embeddings included.
18 | """
19 | n_params = sum(p.numel() for n,p in self.named_parameters())
20 | return n_params
21 |
22 | def _init_weights(self, module):
23 | if isinstance(module, nn.Linear):
24 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
25 | if module.bias is not None:
26 | torch.nn.init.zeros_(module.bias)
27 | elif isinstance(module, nn.Embedding):
28 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
29 |
30 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
31 | return configure_optimizers(self.named_parameters(), weight_decay, learning_rate, betas, device_type)
32 |
--------------------------------------------------------------------------------
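
A note on TransformerBase in /model/transformer_base.py above: subclasses call self.apply(self._init_weights) to re-initialize every Linear and Embedding, and get_num_params() simply sums over named_parameters() without subtracting embeddings. A toy subclass (purely illustrative, not part of the repository; assumes the repository root is on PYTHONPATH and the requirements are installed) makes that contract concrete:

import torch
from torch import nn

from model.transformer_base import TransformerBase


class TinyModel(TransformerBase):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 8)            # 80 parameters
        self.head = nn.Linear(8, 10, bias=False)  # 80 parameters
        self.apply(self._init_weights)            # normal(0, 0.02) init for both modules


model = TinyModel()
print(model.get_num_params())  # 160
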
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch_geometric==2.4.0
2 | torchmetrics==0.11.0
3 | cosine-annealing-warmup @ git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup@12d03c07553aedd3d9e9155e2b3e31ce8c64081a
4 | easydict==1.10
5 | einops==0.7.0
6 | einops-exts==0.0.4
7 | hydra-core==1.3.2
8 | imageio==2.31.0
9 | imageio-ffmpeg==0.4.2
10 | lightning @ git+https://github.com/Lightning-AI/lightning@eb8b314f7cbb5e87e0d149f95f45f1409a7715c2
11 | lightning-utilities==0.8.0
12 | matplotlib==3.7.2
13 | matplotlib-inline==0.1.6
14 | moviepy==1.0.3
15 | numpy==1.24.4
16 | omegaconf==2.3.0
17 | Pillow==10.3.0
18 | pymeshlab==2022.2.post2
19 | pytorch-lightning==2.0.3
20 | scikit-image==0.21.0
21 | scikit-learn==1.2.1
22 | scipy==1.10.1
23 | tqdm==4.64.1
24 | transformers==4.48.0
25 | trimesh==3.18.0
26 | vector-quantize-pytorch==1.8.1
27 | wandb==0.13.4
28 | randomname
29 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import signal
3 | import sys
4 | import traceback
5 | from pathlib import Path
6 | from random import randint
7 | import datetime
8 |
9 | import torch
10 | import wandb
11 | import randomname
12 | from pytorch_lightning.strategies.ddp import DDPStrategy
13 |
14 | from pytorch_lightning import seed_everything, Trainer
15 | from pytorch_lightning.callbacks import ModelCheckpoint
16 | from pytorch_lightning.loggers.wandb import WandbLogger
17 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
18 | from vector_quantize_pytorch import ResidualVQ
19 |
20 | from model.encoder import GraphEncoder
21 | from model.decoder import resnet34_decoder
22 | from util.filesystem_logger import FilesystemLogger
23 | from util.misc import get_parameters_from_state_dict
24 |
25 |
26 | def print_traceback_handler(sig, _frame):
27 | print(f'Received signal {sig}')
28 | bt = ''.join(traceback.format_stack())
29 | print(f'Requested stack trace:\n{bt}')
30 |
31 |
32 | def quit_handler(sig, frame):
33 | print(f'Received signal {sig}, quitting.')
34 | sys.exit(1)
35 |
36 |
37 | def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler):
38 | print(f'Setting signal {sig} handler {handler}')
39 | signal.signal(sig, handler)
40 |
41 |
42 | def register_quit_signal_handlers(sig=signal.SIGUSR2, handler=quit_handler):
43 | print(f'Setting signal {sig} handler {handler}')
44 | signal.signal(sig, handler)
45 |
46 |
47 | def generate_experiment_name(name, config):
48 | if config.resume is not None:
49 | experiment = Path(config.resume).parents[1].name
50 | os.environ['experiment'] = experiment
51 | elif not os.environ.get('experiment'):
52 | experiment = f"{datetime.datetime.now().strftime('%m%d%H%M')}_{name}_{config.experiment}_{randomname.get_name()}"
53 | os.environ['experiment'] = experiment
54 | else:
55 | experiment = os.environ['experiment']
56 | return experiment
57 |
58 |
59 | def create_trainer(name, config):
60 | if not config.wandb_main and config.suffix == '':
61 | config.suffix = '-dev'
62 | config.experiment = generate_experiment_name(name, config)
63 | if config.val_check_interval > 1:
64 | config.val_check_interval = int(config.val_check_interval)
65 | if config.seed is None:
66 | config.seed = randint(0, 999)
67 |
68 | # config.dataset_root = Path(config.dataset_root)
69 |
70 | seed_everything(1337 + config.seed)
71 | torch.backends.cuda.matmul.allow_tf32 = True # type: ignore # allow tf32 on matmul
72 | torch.backends.cudnn.allow_tf32 = True # type: ignore # allow tf32 on cudnn
73 | torch.multiprocessing.set_sharing_strategy('file_system') # possible fix for the "OSError: too many files" exception
74 |
75 | register_debug_signal_handlers()
76 | register_quit_signal_handlers()
77 |
78 | # noinspection PyUnusedLocal
79 | filesystem_logger = FilesystemLogger(config)
80 |
81 | # use wandb logger instead
82 | if config.logger == 'wandb':
83 | logger = WandbLogger(project=f'{name}{config.suffix}', name=config.experiment, id=config.experiment, settings=wandb.Settings(start_method='thread'))
84 | else:
85 | logger = TensorBoardLogger(name='tb', save_dir=(Path("runs") / config.experiment))
86 |
87 | checkpoint_callback = ModelCheckpoint(
88 | dirpath=(Path("runs") / config.experiment / "checkpoints"),
89 | save_top_k=-1,
90 | verbose=False,
91 | every_n_epochs=config.save_epoch,
92 | filename='{epoch:02d}-{global_step}',
93 | auto_insert_metric_name=False,
94 | )
95 |
96 | gpu_count = torch.cuda.device_count()
97 |
98 | precision = 'bf16' if torch.cuda.is_bf16_supported() else 16
99 |     precision = 32  # note: full precision is forced here, overriding the bf16/16 selection above
100 |
101 | if gpu_count > 1:
102 | trainer = Trainer(
103 | accelerator='gpu',
104 | strategy=DDPStrategy(find_unused_parameters=False),
105 | num_nodes=1,
106 | precision=precision,
107 | devices=gpu_count,
108 | num_sanity_val_steps=config.sanity_steps,
109 | max_epochs=config.max_epoch,
110 | limit_val_batches=config.val_check_percent,
111 | callbacks=[checkpoint_callback],
112 | val_check_interval=float(min(config.val_check_interval, 1)),
113 | check_val_every_n_epoch=max(1, config.val_check_interval),
114 | logger=logger,
115 | deterministic=False,
116 | benchmark=True,
117 | )
118 | elif gpu_count == 1:
119 | trainer = Trainer(
120 | devices=[0],
121 | accelerator='gpu',
122 | precision=precision,
123 | strategy=DDPStrategy(find_unused_parameters=False),
124 | num_sanity_val_steps=config.sanity_steps,
125 | max_epochs=config.max_epoch,
126 | limit_val_batches=config.val_check_percent,
127 | callbacks=[checkpoint_callback],
128 | val_check_interval=float(min(config.val_check_interval, 1)),
129 | check_val_every_n_epoch=max(1, config.val_check_interval),
130 | logger=logger,
131 | deterministic=False,
132 | benchmark=True,
133 | )
134 | else:
135 | trainer = Trainer(
136 | accelerator='cpu',
137 | precision=precision,
138 | num_sanity_val_steps=config.sanity_steps,
139 | max_epochs=config.max_epoch,
140 | limit_val_batches=config.val_check_percent,
141 | callbacks=[checkpoint_callback],
142 | val_check_interval=float(min(config.val_check_interval, 1)),
143 | check_val_every_n_epoch=max(1, config.val_check_interval),
144 | logger=logger,
145 | deterministic=False,
146 | benchmark=True,
147 | )
148 | return trainer
149 |
150 |
151 | def step(opt, modules):
152 | for module in modules:
153 | for param in module.parameters():
154 | if param.grad is not None:
155 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
156 | torch.nn.utils.clip_grad_norm_(module.parameters(), 1) # type: ignore
157 | opt.step()
158 |
159 |
160 | def create_conv_batch(encoded_features, batch, batch_size, device):
161 | conv_input, conv_mask = [], []
162 | max_sequence_length = 0
163 | for k in range(batch_size):
164 | features = encoded_features[batch == k, :].T.contiguous().unsqueeze(0)
165 | max_sequence_length = max(max_sequence_length, features.shape[2])
166 | conv_input.append(features)
167 | conv_mask.append(torch.ones([features.shape[2]], device=device, dtype=torch.bool))
168 | for k in range(batch_size):
169 | conv_input[k] = torch.nn.functional.pad(conv_input[k], (0, max_sequence_length - conv_input[k].shape[2]), 'replicate')
170 | conv_mask[k] = torch.nn.functional.pad(conv_mask[k], (0, max_sequence_length - conv_mask[k].shape[0]), 'constant', False)
171 | conv_input = torch.cat(conv_input, dim=0)
172 | conv_mask = torch.cat(conv_mask, dim=0)
173 | return conv_input, conv_mask
174 |
175 |
176 | def get_rvqvae_v0_all(config, resume):
177 | encoder, pre_quant, post_quant, vq = get_rvqvae_v0_encoder_vq(config, resume)
178 | decoder = get_rvqvae_v0_decoder(config, resume)
179 | return encoder, decoder, pre_quant, post_quant, vq
180 |
181 |
182 | def get_rvqvae_v0_encoder_vq(config, resume):
183 | state_dict = torch.load(resume, map_location="cpu")["state_dict"]
184 | encoder = GraphEncoder(no_max_pool=config.g_no_max_pool, aggr=config.g_aggr, graph_conv=config.graph_conv, use_point_features=config.use_point_feats)
185 | pre_quant = torch.nn.Linear(512, config.embed_dim)
186 | post_quant = torch.nn.Linear(config.embed_dim, 512)
187 |
188 | encoder.load_state_dict(get_parameters_from_state_dict(state_dict, "encoder"))
189 | pre_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "pre_quant"))
190 | post_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "post_quant"))
191 |
192 | vq = ResidualVQ(
193 | dim=config.embed_dim,
194 | codebook_size=config.n_embed, # codebook size
195 | num_quantizers=config.embed_levels,
196 | commitment_weight=config.embed_loss_weight, # the weight on the commitment loss
197 | stochastic_sample_codes=True,
198 | sample_codebook_temp=0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
199 | shared_codebook=config.embed_share,
200 | decay=config.code_decay,
201 | )
202 | vq.load_state_dict(get_parameters_from_state_dict(state_dict, "vq"))
203 | return encoder, pre_quant, post_quant, vq
204 |
205 |
206 | def get_rvqvae_v0_decoder(config, resume, device=torch.device("cpu")):
207 | state_dict = torch.load(resume, map_location="cpu")["state_dict"]
208 | decoder = resnet34_decoder(512, config.num_tokens - 2, config.ce_output)
209 | decoder.load_state_dict(get_parameters_from_state_dict(state_dict, "decoder"))
210 | decoder = decoder.to(device).eval()
211 | return decoder
212 |
213 |
214 | def get_rvqvae_v1_encoder_vq(config, resume):
215 | state_dict = torch.load(resume, map_location="cpu")["state_dict"]
216 | encoder = GraphEncoder(no_max_pool=config.g_no_max_pool, aggr=config.g_aggr, graph_conv=config.graph_conv, use_point_features=config.use_point_feats, output_dim=576)
217 | pre_quant = torch.nn.Linear(192, config.embed_dim)
218 | post_quant = torch.nn.Linear(config.embed_dim * 3, 512)
219 |
220 | encoder.load_state_dict(get_parameters_from_state_dict(state_dict, "encoder"))
221 | pre_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "pre_quant"))
222 | post_quant.load_state_dict(get_parameters_from_state_dict(state_dict, "post_quant"))
223 |
224 | vq = ResidualVQ(
225 | dim=config.embed_dim,
226 | codebook_size=config.n_embed, # codebook size
227 | num_quantizers=config.embed_levels,
228 | commitment_weight=config.embed_loss_weight, # the weight on the commitment loss
229 | stochastic_sample_codes=True,
230 | sample_codebook_temp=0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
231 | shared_codebook=config.embed_share,
232 | decay=config.code_decay,
233 | )
234 | vq.load_state_dict(get_parameters_from_state_dict(state_dict, "vq"))
235 | return encoder, pre_quant, post_quant, vq
236 |
237 |
238 | def get_rvqvae_v1_all(config, resume):
239 | encoder, pre_quant, post_quant, vq = get_rvqvae_v1_encoder_vq(config, resume)
240 | decoder = get_rvqvae_v0_decoder(config, resume)
241 | return encoder, decoder, pre_quant, post_quant, vq
242 |
--------------------------------------------------------------------------------
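
A note on create_conv_batch() in /trainer/__init__.py above: it regroups a flat, graph-style feature tensor into equal-length per-item sequences (padding short sequences by edge replication) and returns a concatenated validity mask. A toy invocation (made-up sizes; assumes the repository root is on PYTHONPATH and the requirements are installed) to illustrate the shapes:

import torch

from trainer import create_conv_batch

# 7 encoded faces with 4 features each, split into two meshes (3 + 4 faces)
encoded_features = torch.randn(7, 4)
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])

conv_input, conv_mask = create_conv_batch(encoded_features, batch, batch_size=2, device=torch.device('cpu'))
print(conv_input.shape)  # torch.Size([2, 4, 4]): (item, feature channels, padded sequence length)
print(conv_mask.shape)   # torch.Size([8]): per-position validity flags, concatenated over items
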
/trainer/train_transformer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import omegaconf
4 | import trimesh
5 | from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
6 | import pytorch_lightning as pl
7 | import hydra
8 | from easydict import EasyDict
9 | from lightning_utilities.core.rank_zero import rank_zero_only
10 | from pathlib import Path
11 | import torch
12 |
13 | from dataset.quantized_soup import QuantizedSoupTripletsCreator
14 | from dataset.triangles import TriangleNodesWithFacesAndSequenceIndices, TriangleNodesWithFacesDataloader
15 | from model.transformer import QuantSoupTransformer
16 | from trainer import create_trainer, step, get_rvqvae_v0_decoder
17 | from util.misc import accuracy
18 | from util.visualization import plot_vertices_and_faces
19 | from util.misc import get_parameters_from_state_dict
20 |
21 |
22 | class QuantSoupModelTrainer(pl.LightningModule):
23 |
24 | def __init__(self, config):
25 | super().__init__()
26 | self.config = config
27 | self.vq_cfg = omegaconf.OmegaConf.load(Path(config.vq_resume).parents[1] / "config.yaml")
28 | self.save_hyperparameters()
29 | self.train_dataset = TriangleNodesWithFacesAndSequenceIndices(config, 'train', config.scale_augment, config.shift_augment, config.ft_category)
30 | self.val_dataset = TriangleNodesWithFacesAndSequenceIndices(config, 'val', config.scale_augment_val, False, config.ft_category)
31 | print("Dataset Lengths:", len(self.train_dataset), len(self.val_dataset))
32 | print("Batch Size:", self.config.batch_size)
33 | print("Dataloader Lengths:", len(self.train_dataset) // self.config.batch_size, len(self.val_dataset) // self.config.batch_size)
34 | model_cfg = get_qsoup_model_config(config, self.vq_cfg.embed_levels)
35 | self.model = QuantSoupTransformer(model_cfg, self.vq_cfg)
36 | self.sequencer = QuantizedSoupTripletsCreator(self.config, self.vq_cfg)
37 | self.sequencer.freeze_vq()
38 | # print('compiling model...')
39 | # self.model = torch.compile(model) # requires PyTorch 2.0
40 | self.output_dir_image = Path(f'runs/{self.config.experiment}/image')
41 | self.output_dir_image.mkdir(exist_ok=True, parents=True)
42 | self.output_dir_mesh = Path(f'runs/{self.config.experiment}/mesh')
43 | self.output_dir_mesh.mkdir(exist_ok=True, parents=True)
44 | if self.config.ft_resume is not None:
45 | self.model.load_state_dict(get_parameters_from_state_dict(torch.load(self.config.ft_resume, map_location='cpu')['state_dict'], "model"))
46 | self.automatic_optimization = False
47 |
48 | def configure_optimizers(self):
49 | optimizer = self.model.configure_optimizers(
50 | self.config.weight_decay, self.config.lr,
51 | (self.config.beta1, self.config.beta2), 'cuda'
52 | )
53 | max_steps = int(self.config.max_epoch * len(self.train_dataset) / self.config.batch_size / 2)
54 | print('Max Steps | First cycle:', max_steps)
55 | scheduler = CosineAnnealingWarmupRestarts(
56 | optimizer, first_cycle_steps=max_steps, cycle_mult=1.0,
57 | max_lr=self.config.lr, min_lr=self.config.min_lr,
58 | warmup_steps=self.config.warmup_steps, gamma=1.0
59 | )
60 | return [optimizer], [scheduler]
61 |
62 | def training_step(self, data, batch_idx):
63 | optimizer = self.optimizers()
64 | scheduler = self.lr_schedulers()
65 | scheduler.step() # type: ignore
66 | if self.config.force_lr is not None:
67 | for param_group in optimizer.param_groups:
68 | param_group['lr'] = self.config.force_lr
69 | sequence_in, sequence_out, pfin, pfout = self.sequencer(data.x, data.edge_index, data.batch, data.faces, data.num_vertices.sum(), data.js)
70 | logits, loss = self.model(sequence_in, pfin, pfout, self.sequencer, targets=sequence_out)
71 | acc = accuracy(logits.detach(), sequence_out, ignore_label=2, device=self.device)
72 | self.log("train/ce_loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
73 | self.log("train/acc", acc.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
74 | loss = loss / self.config.gradient_accumulation_steps # scale the loss to account for gradient accumulation
75 | self.manual_backward(loss)
76 | # accumulate gradients of `n` batches
77 | if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
78 | step(optimizer, [self.model])
79 | optimizer.zero_grad(set_to_none=True) # type: ignore
80 | self.log("lr", optimizer.param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) # type: ignore
81 |
82 | def validation_step(self, data, batch_idx):
83 | sequence_in, sequence_out, pfin, pfout = self.sequencer(data.x, data.edge_index, data.batch, data.faces, data.num_vertices.sum(), data.js)
84 | logits, loss = self.model(sequence_in, pfin, pfout, self.sequencer, targets=sequence_out)
85 | acc = accuracy(logits.detach(), sequence_out, ignore_label=2, device=self.device)
86 | if not torch.isnan(loss).any():
87 | self.log("val/ce_loss", loss.item(), on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
88 | if not torch.isnan(acc).any():
89 | self.log("val/acc", acc.item(), on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
90 |
91 | @rank_zero_only
92 | def on_validation_epoch_end(self):
93 | decoder = get_rvqvae_v0_decoder(self.vq_cfg, self.config.vq_resume, self.device)
94 | for k in range(self.config.num_val_samples):
95 | data = self.val_dataset.get(random.randint(0, len(self.val_dataset) - 1))
96 | soup_sequence, face_in_idx, face_out_idx, target = self.sequencer.get_completion_sequence(
97 | data.x.to(self.device),
98 | data.edge_index.to(self.device),
99 | data.faces.to(self.device),
100 | data.num_vertices,
101 | 12
102 | )
103 | y = self.model.generate_with_beamsearch(
104 | soup_sequence, face_in_idx, face_out_idx, self.sequencer, self.config.max_val_tokens, use_kv_cache=True
105 | )
106 | if y is None:
107 | continue
108 | gen_vertices, gen_faces = self.sequencer.decode(y[0], decoder)
109 | plot_vertices_and_faces(gen_vertices, gen_faces, self.output_dir_image / f"{self.global_step:06d}_{k}.jpg")
110 |
111 | try:
112 | trimesh.Trimesh(vertices=gen_vertices, faces=gen_faces, process=False).export(self.output_dir_mesh / f"{self.global_step:06d}_{k}.obj")
113 |             except Exception:
114 | pass # sometimes the mesh is invalid (ngon) and we don't want to crash
115 |
116 | def train_dataloader(self):
117 | return TriangleNodesWithFacesDataloader(self.train_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=not self.config.overfit, num_workers=self.config.num_workers, pin_memory=True)
118 |
119 | def val_dataloader(self):
120 | return TriangleNodesWithFacesDataloader(self.val_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=False, num_workers=self.config.num_workers)
121 |
122 |
123 | def get_qsoup_model_config(config, vq_embed_levels):
124 | cfg = EasyDict({
125 | 'block_size': config.block_size,
126 | 'n_embd': config.model.n_embd,
127 | 'dropout': config.model.dropout,
128 | 'n_layer': config.model.n_layer,
129 | 'n_head': config.model.n_head,
130 | 'bias': config.model.bias,
131 | 'finemb_size': vq_embed_levels * 3,
132 | 'foutemb_size': config.block_size * 3,
133 | })
134 | return cfg
135 |
136 |
137 | @hydra.main(config_path='../config', config_name='meshgpt', version_base='1.2')
138 | def main(config):
139 | trainer = create_trainer("MeshTriSoup", config)
140 | model = QuantSoupModelTrainer(config)
141 | trainer.fit(model, ckpt_path=config.resume)
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
--------------------------------------------------------------------------------
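
A note on the training/validation steps in /trainer/train_transformer.py above: accuracy comes from util.misc and is called as accuracy(logits, targets, ignore_label=2, device=...), but its body is not part of this excerpt. The version below is only a plausible minimal sketch consistent with that call signature; the name accuracy_sketch and the exact reduction are assumptions, not the repository's implementation.

import torch


def accuracy_sketch(logits: torch.Tensor, targets: torch.Tensor, ignore_label: int, device) -> torch.Tensor:
    """Token-level accuracy over positions whose target is not the padding label.
    logits: (B, T, vocab_size), targets: (B, T)."""
    predictions = logits.argmax(dim=-1)
    valid = targets != ignore_label
    correct = (predictions == targets) & valid
    # clamp guards against a batch that is entirely padding
    return (correct.sum().float() / valid.sum().clamp(min=1).float()).to(device)
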
/trainer/train_vocabulary.py:
--------------------------------------------------------------------------------
1 | import omegaconf
2 | import torch_scatter
3 | import trimesh
4 | from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
5 | import pytorch_lightning as pl
6 | import hydra
7 | from lightning_utilities.core.rank_zero import rank_zero_only
8 | from pathlib import Path
9 | import torch
10 | from vector_quantize_pytorch import ResidualVQ
11 |
12 | from dataset import quantize_coordinates
13 | from dataset.triangles import create_feature_stack_from_triangles, TriangleNodesWithFaces, TriangleNodesWithFacesDataloader
14 | from model.encoder import GraphEncoder
15 | from model.decoder import resnet34_decoder
16 | from model.softargmax import softargmax
17 | from trainer import create_trainer, step, create_conv_batch
18 | from util.visualization import plot_vertices_and_faces, triangle_sequence_to_mesh
19 |
20 |
21 | class TriangleTokenizationGraphConv(pl.LightningModule):
22 |
23 | def __init__(self, config):
24 | super().__init__()
25 | self.config = config
26 | self.save_hyperparameters()
27 | if config.only_chairs:
28 | self.train_dataset = TriangleNodesWithFaces(config, 'train', config.scale_augment, config.shift_augment, '03001627')
29 | self.interesting_categories = [('03001627', "")]
30 | self.val_datasets = [TriangleNodesWithFaces(config, 'val', config.scale_augment_val, config.shift_augment_val, '03001627')]
31 | else:
32 | self.train_dataset = TriangleNodesWithFaces(config, 'train', config.scale_augment, config.shift_augment, None)
33 | self.interesting_categories = [('02828884', '_bench'), ('02871439', '_bookshelf'), ('03001627', ""), ('03211117', '_display'), ('04379243', '_table')]
34 | self.val_datasets = []
35 | for cat, name in self.interesting_categories:
36 | self.val_datasets.append(TriangleNodesWithFaces(config, 'val', config.scale_augment_val, config.shift_augment_val, cat))
37 | self.encoder = GraphEncoder(no_max_pool=config.g_no_max_pool, aggr=config.g_aggr, graph_conv=config.graph_conv, use_point_features=config.use_point_feats, output_dim=576)
38 | self.decoder = resnet34_decoder(512, config.num_tokens - 2, config.ce_output)
39 | self.pre_quant = torch.nn.Linear(192, config.embed_dim)
40 | self.post_quant = torch.nn.Linear(config.embed_dim * 3, 512)
41 | self.vq = ResidualVQ(
42 | dim=self.config.embed_dim,
43 | codebook_size=self.config.n_embed, # codebook size
44 | num_quantizers=config.embed_levels,
45 | commitment_weight=self.config.embed_loss_weight, # the weight on the commitment loss
46 | stochastic_sample_codes=True,
47 | sample_codebook_temp=config.stochasticity, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
48 | shared_codebook=self.config.embed_share,
49 | decay=self.config.code_decay,
50 | )
51 | self.register_buffer('smoothing_weight', torch.tensor([2, 10, 200, 10, 2], dtype=torch.float32).unsqueeze(0).unsqueeze(0))
52 | # print('compiling model...')
53 | # self.model = torch.compile(model) # requires PyTorch 2.0
54 | self.output_dir_image_val = Path(f'runs/{self.config.experiment}/image_val')
55 | self.output_dir_image_val.mkdir(exist_ok=True, parents=True)
56 | self.output_dir_mesh_val = Path(f'runs/{self.config.experiment}/mesh_val')
57 | self.output_dir_mesh_val.mkdir(exist_ok=True, parents=True)
58 | self.output_dir_image_train = Path(f'runs/{self.config.experiment}/image_train')
59 | self.output_dir_image_train.mkdir(exist_ok=True, parents=True)
60 | self.output_dir_mesh_train = Path(f'runs/{self.config.experiment}/mesh_train')
61 | self.output_dir_mesh_train.mkdir(exist_ok=True, parents=True)
62 | self.automatic_optimization = False
63 | self.distribute_features_fn = distribute_features if self.config.distribute_features else dummy_distribute
64 | self.visualize_groundtruth()
65 |
66 | def configure_optimizers(self):
67 | parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.pre_quant.parameters()) + list(self.post_quant.parameters()) + list(self.vq.parameters())
68 | optimizer = torch.optim.AdamW(parameters, lr=self.config.lr, amsgrad=True, weight_decay=self.config.weight_decay)
69 | max_steps = int(self.config.max_epoch * len(self.train_dataset) / self.config.batch_size)
70 | print('Max Steps | First cycle:', max_steps)
71 | scheduler = CosineAnnealingWarmupRestarts(
72 | optimizer, first_cycle_steps=max_steps, cycle_mult=1.0,
73 | max_lr=self.config.lr, min_lr=self.config.min_lr,
74 | warmup_steps=self.config.warmup_steps, gamma=1.0
75 | )
76 | return [optimizer], [scheduler]
77 |
78 | def create_conv_batch(self, encoded_features, batch, batch_size):
79 | return create_conv_batch(encoded_features, batch, batch_size, self.device)
80 |
81 | def training_step(self, data, batch_idx):
82 | optimizer = self.optimizers()
83 | scheduler = self.lr_schedulers()
84 | scheduler.step() # type: ignore
85 | encoded_x = self.encoder(data.x, data.edge_index, data.batch)
86 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192) # 3N x 192
87 | encoded_x = self.distribute_features_fn(encoded_x, data.faces, data.num_vertices.sum(), self.device)
88 |         encoded_x = self.pre_quant(encoded_x) # 3N x embed_dim
89 | encoded_x, _, commit_loss = self.vq(encoded_x.unsqueeze(0))
90 | encoded_x = encoded_x.squeeze(0)
91 | commit_loss = commit_loss.mean()
92 | encoded_x = encoded_x.reshape(-1, 3 * encoded_x.shape[-1])
93 | encoded_x = self.post_quant(encoded_x)
94 | encoded_x_conv, conv_mask = self.create_conv_batch(encoded_x, data.batch, self.config.batch_size)
95 | decoded_x_conv = self.decoder(encoded_x_conv)
96 | if self.config.ce_output:
97 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :]
98 | decoded_tri = softargmax(decoded_x) / (self.config.num_tokens - 3) - 0.5
99 | _, decoded_normals, decoded_areas, decoded_angles = create_feature_stack_from_triangles(decoded_tri.reshape(-1, 3, 3))
100 | if self.config.use_smoothed_loss:
101 | otarget = torch.nn.functional.one_hot(data.y.reshape(-1), num_classes=self.config.num_tokens - 2).float()
102 | otarget = otarget.unsqueeze(1)
103 | starget = torch.nn.functional.conv1d(otarget, self.smoothing_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
104 | if self.config.use_multimodal_loss:
105 | starget_a = starget.reshape(-1, decoded_x.shape[-2] * decoded_x_conv.shape[-1])
106 | starget_a = torch.nn.functional.normalize(starget_a, p=1.0, dim=-1, eps=1e-12).squeeze(1)
107 | starget_b = torch.nn.functional.normalize(starget, p=1.0, dim=-1, eps=1e-12).squeeze(1)
108 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-2] * decoded_x_conv.shape[-1]), starget_a).mean()
109 | loss = loss * 0.1 + torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), starget_b).mean()
110 | else:
111 | starget = torch.nn.functional.normalize(starget, p=1.0, dim=-1, eps=1e-12).squeeze(1)
112 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), starget).mean()
113 | else:
114 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), data.y.reshape(-1), reduction='mean')
115 | y_coords = data.y / (self.config.num_tokens - 3) - 0.5
116 | loss_tri = torch.nn.functional.mse_loss(decoded_tri, y_coords, reduction='mean')
117 | loss_normals = torch.nn.functional.mse_loss(decoded_normals, data.x[:, 9:12], reduction='mean')
118 | loss_areas = torch.nn.functional.mse_loss(decoded_areas, data.x[:, 12:13], reduction='mean')
119 | loss_angles = torch.nn.functional.mse_loss(decoded_angles, data.x[:, 13:16], reduction='mean')
120 | else:
121 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-1])[conv_mask, :]
122 | _, decoded_normals, decoded_areas, decoded_angles = create_feature_stack_from_triangles(decoded_x.reshape(-1, 3, 3))
123 | loss = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean')
124 | loss_tri = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean')
125 | loss_normals = torch.nn.functional.mse_loss(decoded_normals, data.x[:, 9:12], reduction='mean')
126 | loss_areas = torch.nn.functional.mse_loss(decoded_areas, data.x[:, 12:13], reduction='mean')
127 | loss_angles = torch.nn.functional.mse_loss(decoded_angles, data.x[:, 13:16], reduction='mean')
128 |
129 | acc = self.get_accuracy(decoded_x, data.y)
130 | acc_triangle = self.get_triangle_accuracy(decoded_x, data.y)
131 | self.log("train/ce_loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
132 | self.log("train/mse_loss", loss_tri.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
133 | self.log("train/norm_loss", loss_normals.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True)
134 | self.log("train/area_loss", loss_areas.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True)
135 | self.log("train/angle_loss", loss_angles.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True)
136 | self.log("train/embed_loss", commit_loss.item(), on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True)
137 | self.log("train/acc", acc.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
138 | self.log("train/acc_tri", acc_triangle.item(), on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
139 | loss = loss + loss_tri * self.config.tri_weight + loss_normals * self.config.norm_weight + loss_areas * self.config.area_weight + loss_angles * self.config.angle_weight + commit_loss
140 | # loss = loss + loss_tri * self.config.tri_weight + commit_loss
141 | loss = loss / self.config.gradient_accumulation_steps # scale the loss to account for gradient accumulation
142 | self.manual_backward(loss)
143 | # accumulate gradients of `n` batches
144 | if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
145 | step(optimizer, [self.encoder, self.decoder, self.pre_quant, self.post_quant])
146 | optimizer.zero_grad(set_to_none=True) # type: ignore
147 | self.log("lr", optimizer.param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True) # type: ignore
148 |
149 | def validation_step(self, data, batch_idx, dataloader_idx):
150 | encoded_x = self.encoder(data.x, data.edge_index, data.batch)
151 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192)
152 | encoded_x = self.distribute_features_fn(encoded_x, data.faces, data.num_vertices.sum(), self.device)
153 | encoded_x = self.pre_quant(encoded_x)
154 | encoded_x, _, commit_loss = self.vq(encoded_x.unsqueeze(0))
155 | encoded_x = encoded_x.squeeze(0)
156 | commit_loss = commit_loss.mean()
157 | encoded_x = encoded_x.reshape(-1, 3 * encoded_x.shape[-1])
158 | encoded_x = self.post_quant(encoded_x)
159 | encoded_x_conv, conv_mask = self.create_conv_batch(encoded_x, data.batch, self.config.batch_size)
160 | decoded_x_conv = self.decoder(encoded_x_conv)
161 | if self.config.ce_output:
162 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :]
163 | decoded_c = softargmax(decoded_x) / (self.config.num_tokens - 3) - 0.5
164 | loss = torch.nn.functional.cross_entropy(decoded_x.reshape(-1, decoded_x.shape[-1]), data.y.reshape(-1), reduction='mean')
165 | y_coords = data.y / (self.config.num_tokens - 3) - 0.5
166 | loss_c = torch.nn.functional.mse_loss(decoded_c, y_coords, reduction='mean')
167 | else:
168 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-1])[conv_mask, :]
169 | loss = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean')
170 | loss_c = torch.nn.functional.mse_loss(decoded_x, data.y, reduction='mean')
171 | acc = self.get_accuracy(decoded_x, data.y)
172 | acc_triangle = self.get_triangle_accuracy(decoded_x, data.y)
173 | if not torch.isnan(loss).any():
174 | self.log(f"val/ce_loss{self.interesting_categories[dataloader_idx][1]}", loss.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
175 | if not torch.isnan(loss_c).any():
176 | self.log(f"val/mse_loss{self.interesting_categories[dataloader_idx][1]}", loss_c.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
177 | if not torch.isnan(commit_loss).any():
178 | self.log(f"val/embed_loss{self.interesting_categories[dataloader_idx][1]}", commit_loss.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
179 | if not torch.isnan(acc).any():
180 | self.log(f"val/acc{self.interesting_categories[dataloader_idx][1]}", acc.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
181 | if not torch.isnan(acc_triangle).any():
182 | self.log(f"val/acc_tri{self.interesting_categories[dataloader_idx][1]}", acc_triangle.item(), add_dataloader_idx=False, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
183 |
184 | @rank_zero_only
185 | def on_validation_epoch_end(self):
186 |         category_names = [""] + [x[1].strip("_") for x in self.interesting_categories]
187 | for didx, dataset in enumerate([self.train_dataset] + self.val_datasets):
188 | output_dir_image = self.output_dir_image_train if didx == 0 else self.output_dir_image_val
189 | output_dir_mesh = self.output_dir_mesh_train if didx == 0 else self.output_dir_mesh_val
190 | for k in range(self.config.num_val_samples):
191 | data = dataset.get(k * (len(dataset) // self.config.num_val_samples) % len(dataset))
192 | encoded_x = self.encoder(data.x.to(self.device), data.edge_index.to(self.device), torch.zeros([data.x.shape[0]], device=self.device).long())
193 | encoded_x = encoded_x.reshape(encoded_x.shape[0] * 3, 192)
194 | encoded_x = self.distribute_features_fn(encoded_x, data.faces.to(self.device), data.num_vertices, self.device)
195 | encoded_x = self.pre_quant(encoded_x)
196 | encoded_x, _, _ = self.vq(encoded_x.unsqueeze(0))
197 | encoded_x = encoded_x.squeeze(0)
198 | encoded_x = encoded_x.reshape(-1, 3 * encoded_x.shape[-1])
199 | encoded_x = self.post_quant(encoded_x)
200 | encoded_x_conv, conv_mask = self.create_conv_batch(encoded_x, torch.zeros([data.x.shape[0]], device=self.device).long(), 1)
201 | decoded_x_conv = self.decoder(encoded_x_conv)
202 | if self.config.ce_output:
203 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-2], decoded_x_conv.shape[-1])[conv_mask, :, :]
204 | coords = decoded_x.argmax(-1).detach().cpu().numpy() / (self.config.num_tokens - 3) - 0.5
205 | else:
206 | decoded_x = decoded_x_conv.reshape(-1, decoded_x_conv.shape[-1])[conv_mask, :]
207 | coords = decoded_x.detach().cpu().numpy()
208 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords)
209 | plot_vertices_and_faces(gen_vertices, gen_faces, output_dir_image / f"{self.global_step:06d}_{category_names[didx]}_{k}.jpg")
210 | try:
211 | trimesh.Trimesh(vertices=gen_vertices, faces=gen_faces, process=False).export(output_dir_mesh / f"{self.global_step:06d}_{category_names[didx]}_{k}.obj")
212 |                     except Exception:
213 | pass # sometimes the mesh is invalid (ngon) and we don't want to crash
214 |
215 | def visualize_groundtruth(self):
216 |         category_names = [""] + [x[1].strip("_") for x in self.interesting_categories]
217 | for didx, dataset in enumerate([self.train_dataset] + self.val_datasets):
218 | output_dir_image = self.output_dir_image_train if didx == 0 else self.output_dir_image_val
219 | for k in range(self.config.num_val_samples):
220 | data = dataset.get(k * (len(dataset) // self.config.num_val_samples) % len(dataset))
221 | if self.config.ce_output:
222 | coords = data.y / (self.config.num_tokens - 3) - 0.5
223 | else:
224 | coords = data.y
225 | gen_vertices, gen_faces = triangle_sequence_to_mesh(coords)
226 | plot_vertices_and_faces(gen_vertices, gen_faces, output_dir_image / f"GT_{category_names[didx]}_{k}.jpg")
227 |
228 | def train_dataloader(self):
229 | return TriangleNodesWithFacesDataloader(self.train_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=not self.config.overfit, num_workers=self.config.num_workers, pin_memory=True)
230 |
231 | def val_dataloader(self):
232 | dataloaders = []
233 | for val_dataset in self.val_datasets:
234 | dataloaders.append(TriangleNodesWithFacesDataloader(val_dataset, batch_size=self.config.batch_size, shuffle=True, drop_last=True, num_workers=self.config.num_workers))
235 | return dataloaders
236 |
237 | def get_accuracy(self, x, y):
238 | if self.config.ce_output:
239 | return (x.argmax(-1).reshape(-1) == y.reshape(-1)).sum() / (x.shape[0] * x.shape[1])
240 | return (quantize_coordinates(x, self.config.num_tokens - 2).reshape(-1) == quantize_coordinates(y, self.config.num_tokens - 2).reshape(-1)).sum() / (x.shape[0] * x.shape[1])
241 |
242 | def get_triangle_accuracy(self, x, y):
243 | if self.config.ce_output:
244 | return torch.all(x.argmax(-1).reshape(-1, 9) == y.reshape(-1, 9), dim=-1).sum() / x.shape[0]
245 | return torch.all((quantize_coordinates(x, self.config.num_tokens - 2).reshape(-1, 9) == quantize_coordinates(y, self.config.num_tokens - 2).reshape(-1, 9)), dim=-1).sum() / (x.shape[0])
246 |
247 |
248 | def distribute_features(features, face_indices, num_vertices, device):
249 |     # N = number of triangles
250 |     # features is 3N x 192 (one row per face corner)
251 |     # face_indices is N x 3 (vertex index of each corner)
252 | assert features.shape[0] == face_indices.shape[0] * face_indices.shape[1], "Features and face indices must match in size"
253 | vertex_features = torch.zeros([num_vertices, features.shape[1]], device=device)
254 | torch_scatter.scatter_mean(features, face_indices.reshape(-1), out=vertex_features, dim=0)
255 | distributed_features = vertex_features[face_indices.reshape(-1), :]
256 | return distributed_features
257 |
258 |
259 | def dummy_distribute(features, _face_indices, _n, _device):
260 | return features
261 |
262 |
263 | @hydra.main(config_path='../config', config_name='meshgpt', version_base='1.2')
264 | def main(config):
265 | trainer = create_trainer("TriangleTokens", config)
266 | model = TriangleTokenizationGraphConv(config)
267 | trainer.fit(model, ckpt_path=config.resume)
268 |
269 |
270 | if __name__ == '__main__':
271 | main()
272 |
--------------------------------------------------------------------------------
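distribute_features above makes the corners of adjacent faces agree: scatter_mean averages every per-corner feature row into its vertex slot, and the averaged vertex features are then gathered back to the corners. A toy illustration with hypothetical sizes (two triangles sharing an edge, 4-dimensional features):

import torch
import torch_scatter

faces = torch.tensor([[0, 1, 2],
                      [2, 1, 3]])               # N x 3 vertex indices, N = 2
corner_feats = torch.randn(faces.numel(), 4)    # 3N x 4, one row per face corner
num_vertices = 4

vertex_feats = torch.zeros(num_vertices, corner_feats.shape[1])
# average all corner rows that reference the same vertex
torch_scatter.scatter_mean(corner_feats, faces.reshape(-1), out=vertex_feats, dim=0)
# gather back: corners that share a vertex now carry identical features
shared = vertex_feats[faces.reshape(-1), :]
print(shared.shape)                             # torch.Size([6, 4])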
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/audi/MeshGPT/0c871aa4acad3c316404630601c73aea5338082b/util/__init__.py
--------------------------------------------------------------------------------
/util/filesystem_logger.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | from pathlib import Path
4 | from typing import Dict, Optional, Union
5 |
6 | from omegaconf import OmegaConf
7 | from pytorch_lightning.loggers.logger import Logger
8 | from lightning_fabric.loggers.logger import rank_zero_experiment
9 | from lightning_fabric.loggers.logger import _DummyExperiment
10 |
11 |
12 | class FilesystemLogger(Logger):
13 |
14 | @property
15 | def version(self) -> Union[int, str]:
16 | return 0
17 |
18 | @property
19 | def name(self) -> str:
20 | return "fslogger"
21 |
22 | # noinspection PyMethodOverriding
23 | def log_hyperparams(self, params: argparse.Namespace):
24 | pass
25 |
26 | def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
27 | pass
28 |
29 | def __init__(self, experiment_config, **_kwargs):
30 | super().__init__()
31 | self.experiment_config = experiment_config
32 | self._experiment = None
33 | # noinspection PyStatementEffect
34 | self.experiment
35 |
36 | @property
37 | @rank_zero_experiment
38 | def experiment(self):
39 | if self._experiment is None:
40 | self._experiment = _DummyExperiment()
41 | experiment_dir = Path("runs", self.experiment_config["experiment"])
42 | experiment_dir.mkdir(exist_ok=True, parents=True)
43 |
44 | src_folders = ['config', 'model', 'tests', 'trainer', 'util', 'visualizer', 'dataset']
45 | sources = []
46 | for src in src_folders:
47 | sources.extend(list(Path(".").glob(f'{src}/**/*')))
48 |
49 | files_to_copy = [x for x in sources if x.suffix in [".py", ".pyx", ".txt", ".so", ".pyd", ".h", ".cu", ".c", '.cpp', ".html"] and x.parts[0] != "runs" and x.parts[0] != "wandb"]
50 |
51 | for f in files_to_copy:
52 | Path(experiment_dir, "code", f).parents[0].mkdir(parents=True, exist_ok=True)
53 | shutil.copyfile(f, Path(experiment_dir, "code", f))
54 |
55 | Path(experiment_dir, "config.yaml").write_text(OmegaConf.to_yaml(self.experiment_config))
56 |
57 | return self._experiment
--------------------------------------------------------------------------------
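FilesystemLogger is a no-op metrics logger whose only job is to snapshot the experiment: on first access of .experiment (which happens in the constructor) it copies the source folders into runs/<experiment>/code and writes the resolved config to runs/<experiment>/config.yaml. A hedged usage sketch, assuming the repository root is the working directory and "debug_run" is a placeholder experiment name:

from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from util.filesystem_logger import FilesystemLogger

config = OmegaConf.create({"experiment": "debug_run"})   # placeholder config
fs_logger = FilesystemLogger(config)                     # copies code and config on construction
trainer = Trainer(logger=[fs_logger], max_epochs=1)      # can be combined with other loggers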
/util/meshlab.py:
--------------------------------------------------------------------------------
1 | import pymeshlab
2 |
3 |
4 | def meshlab_proc(meshpath):
5 | ms = pymeshlab.MeshSet()
6 | ms.load_new_mesh(str(meshpath))
7 | ms.meshing_merge_close_vertices(threshold=pymeshlab.Percentage(1))
8 | ms.meshing_remove_duplicate_faces()
9 | ms.meshing_remove_null_faces()
10 | ms.meshing_remove_duplicate_vertices()
11 | ms.meshing_remove_unreferenced_vertices()
12 | ms.save_current_mesh(str(meshpath), save_vertex_color=False, save_vertex_coord=False, save_face_color=False, save_wedge_texcoord=False)
13 | ms.clear()
14 |
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from collections import OrderedDict
4 | from pathlib import Path
5 |
6 | def get_parameters_from_state_dict(state_dict, filter_key):
7 | new_state_dict = OrderedDict()
8 | for k in state_dict:
9 | if k.startswith(filter_key):
10 | new_state_dict[k.replace(filter_key + '.', '')] = state_dict[k]
11 | return new_state_dict
12 |
13 |
14 | def accuracy(y_pred, y_true, ignore_label=None, device=None):
15 | y_pred = y_pred.argmax(dim=-1)
16 |
17 |     if ignore_label is not None:
18 | normalizer = torch.sum(y_true != ignore_label) # type: ignore
19 | ignore_mask = torch.where( # type: ignore
20 | y_true == ignore_label,
21 | torch.zeros_like(y_true, device=device),
22 | torch.ones_like(y_true, device=device)
23 | ).type(torch.float32)
24 | else:
25 | normalizer = y_true.shape[0]
26 | ignore_mask = torch.ones_like(y_true, device=device).type(torch.float32)
27 | acc = (y_pred.reshape(-1) == y_true.reshape(-1)).type(torch.float32) # type: ignore
28 | acc = torch.sum(acc*ignore_mask.flatten())
29 | return acc / normalizer
30 |
31 |
32 | def rmse(y_pred, y_true, num_tokens, ignore_labels=(0, 1, 2)):
33 | mask = torch.logical_and(y_true != ignore_labels[0], y_pred != ignore_labels[0])
34 | for i in range(1, len(ignore_labels)):
35 | mask = torch.logical_and(mask, y_true != ignore_labels[i])
36 | mask = torch.logical_and(mask, y_pred != ignore_labels[i])
37 | y_pred = y_pred[mask]
38 | y_true = y_true[mask]
39 | vertices_pred = (y_pred - 3) / num_tokens - 0.5
40 | vertices_true = (y_true - 3) / num_tokens - 0.5
41 | return torch.sqrt(torch.mean((vertices_pred - vertices_true)**2))
42 |
43 |
44 | def get_create_shapenet_train_val(path):
45 | from collections import Counter
46 | import random
47 |
48 | all_shapes = [x.stem for x in list(path.iterdir())]
49 | all_categories = sorted(list(set(list(x.split("_")[0] for x in all_shapes))))
50 | counts_all = Counter()
51 | for s in all_shapes:
52 | counts_all[s.split("_")[0]] += 1
53 | validation = random.sample(all_shapes, int(len(all_shapes) * 0.05))
54 | counts_val = Counter()
55 | for s in validation:
56 | counts_val[s.split("_")[0]] += 1
57 | for c in all_categories:
58 | print(c, f"{counts_all[c] / len(all_shapes) * 100:.2f}", f"{counts_val[c] / len(validation) * 100:.2f}")
59 | train = [x for x in all_shapes if x not in validation]
60 | Path("val.txt").write_text("\n".join(validation))
61 | Path("train.txt").write_text("\n".join(train))
62 |
63 |
64 | def scale_vertices(vertices, x_lims=(0.75, 1.25), y_lims=(0.75, 1.25), z_lims=(0.75, 1.25)):
65 | # scale x, y, z
66 | x = np.random.uniform(low=x_lims[0], high=x_lims[1], size=(1,))
67 | y = np.random.uniform(low=y_lims[0], high=y_lims[1], size=(1,))
68 | z = np.random.uniform(low=z_lims[0], high=z_lims[1], size=(1,))
69 | vertices = np.stack([vertices[:, 0] * x, vertices[:, 1] * y, vertices[:, 2] * z], axis=-1)
70 | return vertices
71 |
72 |
73 | def shift_vertices(vertices, x_lims=(-0.1, 0.1), y_lims=(-0.1, 0.1), z_lims=(-0.075, 0.075)):
74 | # shift x, y, z
75 | x = np.random.uniform(low=x_lims[0], high=x_lims[1], size=(1,))
76 | y = np.random.uniform(low=y_lims[0], high=y_lims[1], size=(1,))
77 | z = np.random.uniform(low=z_lims[0], high=z_lims[1], size=(1,))
78 | x = max(min(x, 0.5 - vertices[:, 0].max()), -0.5 - vertices[:, 0].min())
79 | y = max(min(y, 0.5 - vertices[:, 1].max()), -0.5 - vertices[:, 1].min())
80 | z = max(min(z, 0.5 - vertices[:, 2].max()), -0.5 - vertices[:, 2].min())
81 | vertices = np.stack([vertices[:, 0] + x, vertices[:, 1] + y, vertices[:, 2] + z], axis=-1)
82 | return vertices
83 |
84 |
85 | def normalize_vertices(vertices):
86 | bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) # type: ignore
87 | vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
88 | vertices = vertices / (bounds[1] - bounds[0]).max()
89 | return vertices
90 |
91 |
92 | def top_p_sampling(logits, p):
93 | probs = torch.softmax(logits, dim=-1)
94 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
95 | probs_sum = torch.cumsum(probs_sort, dim=-1)
96 | mask = probs_sum - probs_sort > p
97 | probs_sort[mask] = 0.0
98 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
99 | next_token = torch.multinomial(probs_sort, num_samples=1)
100 | next_token = torch.gather(probs_idx, -1, next_token)
101 | return next_token
102 |
103 |
104 | if __name__ == "__main__":
105 | from pathlib import Path
106 | # get_create_shapenet_train_val(Path("/cluster/gimli/ysiddiqui/ShapeNetCore.v2.meshlab/"))
107 | logits_ = torch.FloatTensor([[3, -1, 0.5, 0.1],
108 | [0, 9, 0.5, 0.1],
109 | [0, 0, 5, 0.1]])
110 | for i in range(5):
111 | print(top_p_sampling(logits_, 0.9))
112 |
--------------------------------------------------------------------------------
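The augmentation helpers above can be chained: normalize_vertices centres a mesh and scales its longest extent to 1 (so coordinates fall in [-0.5, 0.5]), shift_vertices applies a random translation whose offsets are clamped against those bounds, and scale_vertices applies a random per-axis scale. A short sketch with a hypothetical random point set, assuming the repository root is on PYTHONPATH:

import numpy as np
from util.misc import normalize_vertices, scale_vertices, shift_vertices

raw = np.random.rand(100, 3) * 10.0      # arbitrary raw coordinates (hypothetical)
v = normalize_vertices(raw)              # centred, longest extent 1 -> within [-0.5, 0.5]
v = shift_vertices(v)                    # clamped random translation, stays inside the cube
v = scale_vertices(v)                    # random per-axis scale in [0.75, 1.25]
print(v.min(axis=0), v.max(axis=0))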
/util/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Embedder:
6 | def __init__(self, **kwargs):
7 | self.kwargs = kwargs
8 | self.create_embedding_fn()
9 |
10 | def create_embedding_fn(self):
11 | embed_fns = []
12 | d = self.kwargs['input_dims']
13 | out_dim = 0
14 | if self.kwargs['include_input']:
15 | embed_fns.append(lambda x: x)
16 | out_dim += d
17 |
18 | max_freq = self.kwargs['max_freq_log2']
19 | N_freqs = self.kwargs['num_freqs']
20 |
21 | if self.kwargs['log_sampling']:
22 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
23 | else:
24 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
25 |
26 | for freq in freq_bands:
27 | for p_fn in self.kwargs['periodic_fns']:
28 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
29 | out_dim += d
30 |
31 | self.embed_fns = embed_fns
32 | self.out_dim = out_dim
33 |
34 | def embed(self, inputs):
35 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
36 |
37 |
38 | def get_embedder(multires, i=0):
39 | if i == -1:
40 | return nn.Identity(), 3
41 |
42 | embed_kwargs = {
43 | 'include_input': True,
44 | 'input_dims': 3,
45 | 'max_freq_log2': multires - 1,
46 | 'num_freqs': multires,
47 | 'log_sampling': True,
48 | 'periodic_fns': [torch.sin, torch.cos],
49 | }
50 |
51 | embedder_obj = Embedder(**embed_kwargs)
52 | embed = lambda x, eo=embedder_obj: eo.embed(x)
53 | return embed, embedder_obj.out_dim
54 |
--------------------------------------------------------------------------------
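get_embedder builds a NeRF-style frequency encoding: the input is optionally kept, and sin/cos of the input at num_freqs frequencies are concatenated, giving 3 + 3 * 2 * multires output channels for 3-D points. A short usage sketch, assuming the repository root is on PYTHONPATH:

import torch
from util.positional_encoding import get_embedder

embed_fn, out_dim = get_embedder(multires=4)   # 3 + 3 * 2 * 4 = 27 channels
points = torch.rand(5, 3) - 0.5                # hypothetical 3-D points
encoded = embed_fn(points)
print(out_dim, encoded.shape)                  # 27 torch.Size([5, 27])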
/util/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from pathlib import Path
4 | from PIL import Image
5 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
6 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
7 |
8 | from dataset import newface_token, stopface_token, padface_token
9 |
10 |
11 | def visualize_points(points, vis_path, colors=None):
12 | if colors is None:
13 | Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} 127 127 127" for p in points))
14 | else:
15 | Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}" for i, p in enumerate(points)))
16 |
17 |
18 | def tokens_to_vertices(token_sequence, num_tokens):
19 | try:
20 | end = token_sequence.index(num_tokens + 1)
21 | except ValueError:
22 | end = len(token_sequence)
23 | token_sequence = token_sequence[:end]
24 | token_sequence = token_sequence[:(len(token_sequence) // 3) * 3]
25 | vertices = (np.array(token_sequence).reshape(-1, 3)) / num_tokens - 0.5
26 | # order: Z, Y, X --> X, Y, Z
27 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)
28 | return vertices
29 |
30 |
31 | def visualize_quantized_mesh_vertices(token_sequence, num_tokens, output_path):
32 | vertices = tokens_to_vertices(token_sequence, num_tokens)
33 | plot_vertices(vertices, output_path)
34 |
35 |
36 | def visualize_quantized_mesh_vertices_and_faces(token_sequence_vertex, token_sequence_face, num_tokens, output_path):
37 | vertices, faces = tokens_to_mesh(token_sequence_vertex, token_sequence_face, num_tokens)
38 | plot_vertices_and_faces(vertices, faces, output_path)
39 |
40 |
41 | def plot_vertices(vertices, output_path):
42 | fig = plt.figure(figsize=(4, 4))
43 | ax = fig.add_subplot(111, projection="3d")
44 | plt.xlim(-0.35, 0.35)
45 | plt.ylim(-0.35, 0.35)
46 | # Don't mess with the limits!
47 | plt.autoscale(False)
48 | ax.set_axis_off()
49 | ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], c='g', s=10)
50 | ax.set_zlim(-0.35, 0.35)
51 | ax.view_init(25, -120, 0)
52 | plt.tight_layout()
53 | plt.savefig(output_path)
54 | plt.close("all")
55 |
56 |
57 | def plot_vertices_and_faces(vertices, faces, output_path):
58 | ngons = [[vertices[v, :].tolist() for v in f] for f in faces]
59 | fig = plt.figure(figsize=(8, 8))
60 | ax = fig.add_subplot(111, projection="3d")
61 | plt.xlim(-0.45, 0.45)
62 | plt.ylim(-0.45, 0.45)
63 | # Don't mess with the limits!
64 | plt.autoscale(False)
65 | ax.set_axis_off()
66 | ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], c='black', s=10)
67 | polygon_collection = Poly3DCollection(ngons)
68 | polygon_collection.set_alpha(0.3)
69 | polygon_collection.set_color('b')
70 | ax.add_collection(polygon_collection)
71 | ax.set_zlim(-0.35, 0.35)
72 | ax.view_init(25, -120, 0)
73 | plt.tight_layout()
74 | plt.savefig(output_path)
75 | plt.close("all")
76 |
77 |
78 | def visualize_quantized_mesh_vertices_gif(token_sequence, num_tokens, output_dir):
79 | vertices = tokens_to_vertices(token_sequence, num_tokens)
80 | visualize_mesh_vertices_gif(vertices, output_dir)
81 |
82 |
83 | def visualize_mesh_vertices_gif(vertices, output_dir):
84 | for i in range(1, len(vertices), 1):
85 | fig = plt.figure(figsize=(4, 4))
86 | ax = fig.add_subplot(111, projection="3d")
87 | plt.xlim(-0.35, 0.35)
88 | plt.ylim(-0.35, 0.35)
89 | # Don't mess with the limits!
90 | plt.autoscale(False)
91 | ax.set_axis_off()
92 | ax.scatter(vertices[:i, 0], vertices[:i, 1], vertices[:i, 2], c='g', s=10)
93 | ax.set_zlim(-0.35, 0.35)
94 | ax.view_init(25, -120, 0)
95 | plt.tight_layout()
96 | plt.savefig(output_dir / f"{i:05d}.png")
97 | plt.close("all")
98 | create_gif(output_dir, 40, output_dir / "vis.gif")
99 |
100 |
101 | def visualize_quantized_mesh_vertices_and_faces_gif(token_sequence_vertex, token_sequence_face, num_tokens, output_dir):
102 | visualize_quantized_mesh_vertices_gif(token_sequence_vertex, num_tokens, output_dir)
103 | vertices, faces = tokens_to_mesh(token_sequence_vertex, token_sequence_face, num_tokens)
104 | visualize_mesh_vertices_and_faces_gif(vertices, faces, output_dir)
105 |
106 |
107 | def visualize_mesh_vertices_and_faces_gif(vertices, faces, output_dir):
108 | ngons = [[vertices[v, :].tolist() for v in f] for f in faces]
109 | for i in range(1, len(ngons) + 1, 1):
110 | fig = plt.figure(figsize=(9, 9))
111 | ax = fig.add_subplot(111, projection="3d")
112 | plt.xlim(-0.35, 0.35)
113 | plt.ylim(-0.35, 0.35)
114 | # Don't mess with the limits!
115 | plt.autoscale(False)
116 | ax.set_axis_off()
117 | ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], c='black', s=10)
118 | polygon_collection = Poly3DCollection(ngons[:i])
119 | polygon_collection.set_alpha(0.3)
120 | polygon_collection.set_color('b')
121 | ax.add_collection(polygon_collection)
122 | ax.set_zlim(-0.35, 0.35)
123 | ax.view_init(25, -120, 0)
124 | plt.tight_layout()
125 | plt.savefig(output_dir / f"{len(vertices) + i:05d}.png")
126 | plt.close("all")
127 | create_gif(output_dir, 40, output_dir / "vis.gif")
128 |
129 |
130 | def create_gif(folder, fps, output_path):
131 | collection_rgb = []
132 | for f in sorted([x for x in folder.iterdir() if x.suffix == ".png" or x.suffix == ".jpg"]):
133 | img_rgb = np.array(Image.open(f).resize((384, 384)))
134 | collection_rgb.append(img_rgb)
135 | clip = ImageSequenceClip(collection_rgb, fps=fps)
136 | clip.write_gif(output_path, verbose=False, logger=None)
137 |
138 |
139 | def tokens_to_mesh(vertices_q, face_sequence, num_tokens):
140 | vertices = (np.array(vertices_q).reshape(-1, 3)) / num_tokens - 0.5
141 | # order: Z, Y, X --> X, Y, Z
142 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)
143 | try:
144 | end = face_sequence.index(stopface_token)
145 | except ValueError:
146 | end = len(face_sequence)
147 | face_sequence = face_sequence[:end]
148 |     face_sequence = [x for x in face_sequence if x != padface_token] # remove padding
149 | faces = []
150 | current_face = []
151 | for i in range(len(face_sequence)):
152 | if face_sequence[i] == newface_token:
153 | if len(current_face) > 2:
154 | faces.append(current_face)
155 | current_face = []
156 | else:
157 | current_face.append(face_sequence[i] - 3)
158 | if len(current_face) != 0:
159 | faces.append(current_face)
160 | return vertices, faces
161 |
162 |
163 | def ngon_to_obj(vertices, faces):
164 | obj = ""
165 | for i in range(len(vertices)):
166 | obj += f"v {vertices[i, 0]} {vertices[i, 1]} {vertices[i, 2]}\n"
167 | for i in range(len(faces)):
168 | fline = "f"
169 | for j in range(len(faces[i])):
170 | fline += f" {faces[i][j] + 1} "
171 | fline += "\n"
172 | obj += fline
173 | return obj
174 |
175 |
176 | def trisoup_sequence_to_mesh(soup_sequence, num_tokens):
177 | try:
178 | end = soup_sequence.index(stopface_token)
179 | except ValueError:
180 | end = len(soup_sequence)
181 | soup_sequence = soup_sequence[:end]
182 | vertices_q = []
183 | current_subsequence = []
184 | for i in range(len(soup_sequence)):
185 | if soup_sequence[i] == newface_token:
186 | if len(current_subsequence) >= 9:
187 | current_subsequence = current_subsequence[:9]
188 | vertices_q.append(np.array(current_subsequence).reshape(3, 3))
189 | current_subsequence = []
190 | elif soup_sequence[i] != padface_token:
191 | current_subsequence.append(soup_sequence[i] - 3)
192 | if len(current_subsequence) >= 9:
193 | current_subsequence = current_subsequence[:9]
194 | vertices_q.append(np.array(current_subsequence).reshape(3, 3))
195 | vertices = (np.array(vertices_q).reshape(-1, 3)) / num_tokens - 0.5
196 | # order: Z, Y, X --> X, Y, Z
197 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)
198 | faces = np.array(list(range(len(vertices_q) * 3)), dtype=np.int32).reshape(-1, 3)
199 | return vertices, faces
200 |
201 |
202 | def ngonsoup_sequence_to_mesh(soup_sequence, num_tokens):
203 | try:
204 | end = soup_sequence.index(stopface_token)
205 | except ValueError:
206 | end = len(soup_sequence)
207 | soup_sequence = soup_sequence[:end]
208 | vertices_q = []
209 | face_ctr = 0
210 | faces = []
211 | current_subsequence = []
212 | for i in range(len(soup_sequence)):
213 | if soup_sequence[i] == newface_token:
214 | current_subsequence = current_subsequence[:len(current_subsequence) // 3 * 3]
215 | if len(current_subsequence) > 0:
216 | vertices_q.append(np.array(current_subsequence).reshape(-1, 3))
217 | faces.append([x for x in range(face_ctr, face_ctr + len(current_subsequence) // 3)])
218 | face_ctr += (len(current_subsequence) // 3)
219 | current_subsequence = []
220 | elif soup_sequence[i] != padface_token:
221 | current_subsequence.append(soup_sequence[i] - 3)
222 |
223 | current_subsequence = current_subsequence[:len(current_subsequence) // 3 * 3]
224 | if len(current_subsequence) > 0:
225 | vertices_q.append(np.array(current_subsequence).reshape(-1, 3))
226 | faces.append([x for x in range(face_ctr, face_ctr + len(current_subsequence) // 3)])
227 | face_ctr += (len(current_subsequence) // 3)
228 |
229 | vertices = np.vstack(vertices_q) / num_tokens - 0.5
230 | # order: Z, Y, X --> X, Y, Z
231 | vertices = np.stack([vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)
232 | return vertices, faces
233 |
234 |
235 | def triangle_sequence_to_mesh(triangles):
236 | vertices = triangles.reshape(-1, 3)
237 | faces = np.array(list(range(vertices.shape[0]))).reshape(-1, 3)
238 | return vertices, faces
239 |
--------------------------------------------------------------------------------
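triangle_sequence_to_mesh turns an N x 3 x 3 array of per-triangle corner coordinates into a triangle-soup mesh: the vertices are simply the flattened corners, and the faces index them consecutively. A short sketch with two hypothetical triangles, assuming the repository root is on PYTHONPATH:

import numpy as np
from util.visualization import triangle_sequence_to_mesh, plot_vertices_and_faces

triangles = np.array([
    [[-0.2, 0.0, 0.0], [0.2, 0.0, 0.0], [0.0, 0.3, 0.0]],
    [[-0.2, 0.0, 0.1], [0.2, 0.0, 0.1], [0.0, 0.3, 0.1]],
])                                               # 2 triangles, 3 corners each, xyz
vertices, faces = triangle_sequence_to_mesh(triangles)
print(vertices.shape, faces.shape)               # (6, 3) (2, 3)
plot_vertices_and_faces(vertices, faces, "soup_preview.jpg")  # hypothetical output path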