├── .gitignore ├── 1_preliminaries.ipynb ├── 2_image_classification.ipynb ├── 3_image_autoencoder.ipynb ├── 4_pointcloud_classification.ipynb ├── 5_pointcloud_autoencoder.ipynb ├── 6_transient_flow_cfd.ipynb ├── LICENSE ├── README.md ├── data └── simulation │ └── case_000000 │ ├── 00000000_mesh.th │ ├── 00000001_mesh.th │ ├── 00000002_mesh.th │ ├── 00000003_mesh.th │ ├── 00000004_mesh.th │ ├── 00000005_mesh.th │ ├── 00000006_mesh.th │ ├── 00000007_mesh.th │ ├── 00000008_mesh.th │ ├── 00000009_mesh.th │ ├── 00000010_mesh.th │ ├── 00000011_mesh.th │ ├── 00000012_mesh.th │ ├── 00000013_mesh.th │ ├── 00000014_mesh.th │ ├── 00000015_mesh.th │ ├── 00000016_mesh.th │ ├── 00000017_mesh.th │ ├── 00000018_mesh.th │ ├── 00000019_mesh.th │ ├── x.th │ └── y.th ├── schematics ├── architecture.svg ├── perceiver_decoder.svg ├── perceiver_pooling.svg ├── schematics.drawio ├── upt_dense_autoencoder.svg ├── upt_dense_classifier.svg ├── upt_sparse_autoencoder.svg └── upt_sparse_classifier.svg └── upt ├── __init__.py ├── collators ├── __init__.py ├── simulation_collator.py ├── sparseimage_autoencoder_collator.py └── sparseimage_classifier_collator.py ├── datasets ├── __init__.py ├── simulation_dataset.py ├── sparse_cifar10_autoencoder_dataset.py └── sparse_cifar10_classifier_dataset.py ├── models ├── __init__.py ├── approximator.py ├── conditioner_timestep.py ├── decoder_classifier.py ├── decoder_perceiver.py ├── encoder_image.py ├── encoder_supernodes.py ├── upt.py ├── upt_image_autoencoder.py ├── upt_image_classifier.py ├── upt_sparseimage_autoencoder.py └── upt_sparseimage_classifier.py └── modules ├── __init__.py └── supernode_pooling.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.json 3 | *.png 4 | *.jpg 5 | *.JPEG 6 | *.mp4 7 | *.pdf 8 | *.bkp 9 | *.dtmp 10 | 11 | data/ 12 | yamls_run/ 13 | temp/ 14 | sbatch_config.yaml 15 | sbatch_template_nodes.sh 16 | submit/ 17 | wandb_config.yaml 18 | wandb_configs/ 19 | 20 | hyperparams*.yaml 21 | static_config.yaml 22 | scratchpad*.py 23 | 24 | 25 | # Byte-compiled / optimized / DLL files 26 | __pycache__/ 27 | *.py[cod] 28 | *$py.class 29 | 30 | # C extensions 31 | *.so 32 | 33 | # Distribution / packaging 34 | .Python 35 | build/ 36 | develop-eggs/ 37 | dist/ 38 | downloads/ 39 | eggs/ 40 | .eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | wheels/ 47 | pip-wheel-metadata/ 48 | share/python-wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | MANIFEST 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .nox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # celery beat schedule file 118 | celerybeat-schedule 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Benedikt Alkin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Resources for UPT Tutorials 2 | 3 | [[`Project Page`](https://ml-jku.github.io/UPT)] [[`Paper (arxiv)`](https://arxiv.org/abs/2402.12365)] 4 | 5 | We recommend familiarizing yourself with the following papers if you haven't already: 6 | - [Transformer](https://arxiv.org/abs/1706.03762) 7 | - [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929) 8 | - [Perceiver IO](https://arxiv.org/abs/2107.14795) 9 | 10 | 11 | # Tutorials 12 | 13 | 14 | The following notebooks provide a tutorial for getting familiar with Universal Physics Transformers (UPT). 15 | 16 | - [Preliminaries](https://github.com/BenediktAlkin/upt-tutorial/blob/main/1_preliminaries.ipynb): provides an introduction to the concepts behind UPT 17 | - Sparse tensors 18 | - Positional encoding 19 | - Architecture overview 20 | - [CIFAR10 image classification](https://github.com/BenediktAlkin/upt-tutorial/blob/main/2_image_classification.ipynb): starts from a basic example (regular grid input, scalar output, simple encoder, simple classification decoder) 21 | - [CIFAR10 autoencoder](https://github.com/BenediktAlkin/upt-tutorial/blob/main/3_image_autoencoder.ipynb): introduces the perceiver decoder to query at arbitrary positions 22 | - [SparseCIFAR10 image classification](https://github.com/BenediktAlkin/upt-tutorial/blob/main/4_pointcloud_classification.ipynb): introduces handling point clouds via sparse tensors and supernode message passing 23 | - [SparseCIFAR10 image autoencoder](https://github.com/BenediktAlkin/upt-tutorial/blob/main/5_pointcloud_autoencoder.ipynb): combines the handling of input point clouds with decoding at arbitrarily many positions 24 | - [Simple Transient Flow](https://github.com/BenediktAlkin/upt-tutorial/blob/main/6_transient_flow_cfd.ipynb): puts everything together to train UPT on a single trajectory of our transient flow simulations 25 | 26 | 27 | -------------------------------------------------------------------------------- /data/simulation/case_000000/00000000_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000000_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000001_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000001_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000002_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000002_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000003_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000003_mesh.th --------------------------------------------------------------------------------
/data/simulation/case_000000/00000004_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000004_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000005_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000005_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000006_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000006_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000007_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000007_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000008_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000008_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000009_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000009_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000010_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000010_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000011_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000011_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000012_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000012_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000013_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000013_mesh.th 
-------------------------------------------------------------------------------- /data/simulation/case_000000/00000014_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000014_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000015_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000015_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000016_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000016_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000017_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000017_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000018_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000018_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000019_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000019_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/x.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/x.th -------------------------------------------------------------------------------- /data/simulation/case_000000/y.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/y.th -------------------------------------------------------------------------------- /schematics/perceiver_decoder.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
[SVG schematic, recovered text labels only: "Output Position" -> "MLP" -> "q" of a "Cross-attention Block"; "kv" comes from the latent tokens (see decoder_perceiver.py below)]
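The decoder schematic boils down to: embed an arbitrary output coordinate, turn it into a query with an MLP, and cross-attend into the latent tokens. Below is a minimal, illustrative PyTorch sketch of that idea; module and parameter names are invented for this sketch, and a plain linear coordinate embedding stands in for the sincos embedding. The repo's actual implementation is `upt/models/decoder_perceiver.py` further down.

```python
import torch
from torch import nn


class TinyPerceiverDecoder(nn.Module):
    """Illustrative sketch: query a fixed-size latent at arbitrary output positions."""

    def __init__(self, latent_dim=192, ndim=2, num_heads=3, output_dim=3):
        super().__init__()
        # embed the continuous output coordinates and turn them into queries ("Output Position" -> "MLP")
        self.query_mlp = nn.Sequential(
            nn.Linear(ndim, latent_dim),
            nn.GELU(),
            nn.Linear(latent_dim, latent_dim),
        )
        # "Cross-attention Block": queries come from the positions, keys/values from the latent
        self.cross_attn = nn.MultiheadAttention(latent_dim, num_heads, batch_first=True)
        self.pred = nn.Linear(latent_dim, output_dim)

    def forward(self, latent, output_pos):
        # latent:     (batch_size, num_latent_tokens, latent_dim)
        # output_pos: (batch_size, num_outputs, ndim) -- any number of query points
        query = self.query_mlp(output_pos)
        x, _ = self.cross_attn(query, latent, latent)
        return self.pred(x)  # one prediction per query position


latent = torch.randn(4, 32, 192)
output_pos = torch.rand(4, 100, 2) * 200  # 100 arbitrary query coordinates per sample
print(TinyPerceiverDecoder()(latent, output_pos).shape)  # torch.Size([4, 100, 3])
```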
-------------------------------------------------------------------------------- /schematics/perceiver_pooling.svg: -------------------------------------------------------------------------------- [SVG schematic, recovered text labels only: "Learnable Query" -> "q" of a "Cross-attention Block"; "kv" comes from the encoder tokens]
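Perceiver pooling is the mirror image of the decoder: a fixed number of learnable query tokens cross-attends into however many encoder tokens a sample produces, so the latent always has the same shape. A minimal sketch with invented names (the repo uses `PerceiverPoolingBlock` from kappamodules instead):

```python
import torch
from torch import nn


class TinyPerceiverPooling(nn.Module):
    """Illustrative sketch: compress a variable number of tokens into a fixed-size latent."""

    def __init__(self, dim=192, num_heads=3, num_latent_tokens=32):
        super().__init__()
        # "Learnable Query": one learned token per latent slot
        self.query = nn.Parameter(torch.randn(num_latent_tokens, dim) * 0.02)
        # "Cross-attention Block": q = learnable queries, kv = encoder tokens
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, kv):
        # kv: (batch_size, num_tokens, dim) -- num_tokens may differ between inputs
        query = self.query.unsqueeze(0).expand(kv.size(0), -1, -1)
        x, _ = self.cross_attn(query, kv, kv)
        return x  # (batch_size, num_latent_tokens, dim), independent of num_tokens


pool = TinyPerceiverPooling()
print(pool(torch.randn(4, 197, 192)).shape)  # torch.Size([4, 32, 192])
print(pool(torch.randn(4, 650, 192)).shape)  # torch.Size([4, 32, 192])
```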
-------------------------------------------------------------------------------- /schematics/schematics.drawio: -------------------------------------------------------------------------------- [draw.io source of the schematics (compressed XML blob), omitted: not human-readable; see the exported .svg files] -------------------------------------------------------------------------------- /schematics/upt_dense_classifier.svg: --------------------------------------------------------------------------------
[SVG schematic, recovered text labels only: "Input Image" -> "ViT Patch Embed" "+" "Input Position" -> "Transformer" blocks -> "Perceiver Pooling" -> "Latent" -> "Transformer" blocks -> "Linear Classifier" -> "Prediction"; additional label: "Message Passing to Supernodes"; legend: "Data" / "Model" / "Optional"]
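Wiring the dense classifier schematic together end to end, as a rough sketch only: generic `torch.nn` building blocks stand in for the repo's kappamodules layers, and all sizes are arbitrary.

```python
import torch
from torch import nn


def transformer(dim, num_heads, depth):
    layer = nn.TransformerEncoderLayer(dim, num_heads, dim * 4, batch_first=True, norm_first=True)
    return nn.TransformerEncoder(layer, num_layers=depth)


class TinyDenseUPTClassifier(nn.Module):
    """Illustrative sketch of the schematic above, not the repo's implementation."""

    def __init__(self, dim=192, num_heads=3, num_latent_tokens=32, num_classes=10):
        super().__init__()
        # encoder: "ViT Patch Embed" + "Input Position" -> transformer blocks
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=4, stride=4)  # 32x32 CIFAR image -> 8x8 patches
        self.pos_embed = nn.Parameter(torch.randn(1, 64, dim) * 0.02)
        self.encoder = transformer(dim, num_heads, depth=2)
        # optional "Perceiver Pooling" down to a fixed number of latent tokens
        self.latent_query = nn.Parameter(torch.randn(num_latent_tokens, dim) * 0.02)
        self.pool = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # approximator: transformer blocks on the "Latent"
        self.approximator = transformer(dim, num_heads, depth=2)
        # decoder: "Linear Classifier" on the mean-pooled latent
        self.head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, image):
        x = self.patch_embed(image).flatten(2).transpose(1, 2) + self.pos_embed
        x = self.encoder(x)
        query = self.latent_query.unsqueeze(0).expand(x.size(0), -1, -1)
        x, _ = self.pool(query, x, x)
        x = self.approximator(x)
        return self.head(x.mean(dim=1))  # "Prediction": (batch_size, num_classes)


print(TinyDenseUPTClassifier()(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
```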
-------------------------------------------------------------------------------- /schematics/upt_sparse_autoencoder.svg: -------------------------------------------------------------------------------- [SVG schematic, recovered text labels only: "Input Features (pressure, velocity, ...)" -> "MLP" "+" "Input Position" -> "Message Passing to Supernodes" -> "Transformer" blocks -> "Perceiver Pooling" -> "Latent" -> "Transformer" blocks -> "Perceiver Decoder" queried with "Output Position" -> "Output Features"; legend: "Data" / "Model" / "Optional"]
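Putting the full sparse pipeline together for a single un-batched point cloud, again only as an illustrative sketch: supernode message passing is approximated here by a simple radius-based mean over embedded neighbor features, whereas the repo uses a proper message-passing layer (`upt/modules/supernode_pooling.py`) with a sincos position embedding and a bounded neighborhood degree; all other names and sizes below are invented.

```python
import torch
from torch import nn


def transformer(dim, num_heads, depth):
    layer = nn.TransformerEncoderLayer(dim, num_heads, dim * 4, batch_first=True, norm_first=True)
    return nn.TransformerEncoder(layer, num_layers=depth)


class TinySparseUPTAutoencoder(nn.Module):
    """Illustrative sketch of the schematic above for one un-batched sample."""

    def __init__(self, input_dim=3, ndim=2, dim=96, num_heads=3, num_latent_tokens=16, radius=10.0):
        super().__init__()
        self.radius = radius
        # "Input Features" -> "MLP", "+" "Input Position"
        self.feat_embed = nn.Linear(input_dim, dim)
        self.pos_embed = nn.Linear(ndim, dim)  # stands in for the sincos embedding
        # encoder transformer + "Perceiver Pooling"
        self.encoder = transformer(dim, num_heads, depth=1)
        self.latent_query = nn.Parameter(torch.randn(num_latent_tokens, dim) * 0.02)
        self.pool = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # approximator on the "Latent"
        self.approximator = transformer(dim, num_heads, depth=1)
        # "Perceiver Decoder" queried at "Output Position"
        self.query_embed = nn.Linear(ndim, dim)
        self.decoder = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.pred = nn.Linear(dim, input_dim)

    def forward(self, input_feat, input_pos, supernode_idxs, output_pos):
        # "Message Passing to Supernodes" (simplified): every supernode averages the
        # embedded features of all points within `radius` of it
        x = self.feat_embed(input_feat) + self.pos_embed(input_pos)   # (num_points, dim)
        dist = torch.cdist(input_pos[supernode_idxs], input_pos)      # (num_supernodes, num_points)
        within = (dist <= self.radius).float()
        x = within @ x / within.sum(dim=1, keepdim=True)              # (num_supernodes, dim)
        # transformer + perceiver pooling -> fixed-size latent
        x = self.encoder(x.unsqueeze(0))
        latent, _ = self.pool(self.latent_query.unsqueeze(0), x, x)
        latent = self.approximator(latent)
        # decode at arbitrary output positions
        query = self.query_embed(output_pos).unsqueeze(0)
        out, _ = self.decoder(query, latent, latent)
        return self.pred(out).squeeze(0)                              # (num_outputs, input_dim)


model = TinySparseUPTAutoencoder()
input_pos = torch.rand(300, 2) * 200   # 300 irregularly placed points
input_feat = torch.randn(300, 3)       # e.g. pressure + velocity components
supernode_idxs = torch.randperm(300)[:64]
output_pos = torch.rand(50, 2) * 200   # decode at 50 other locations
print(model(input_feat, input_pos, supernode_idxs, output_pos).shape)  # torch.Size([50, 3])
```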
-------------------------------------------------------------------------------- /schematics/upt_sparse_classifier.svg: -------------------------------------------------------------------------------- [SVG schematic, recovered text labels only: "Input Features (pressure, velocity, ...)" -> "MLP" "+" "Input Position" -> "Message Passing to Supernodes" -> "Transformer" blocks -> "Perceiver Pooling" -> "Latent" -> "Transformer" blocks -> "Linear Classifier" -> "Prediction"; legend: "Data" / "Model" / "Optional"]
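The sparse classifier is the same encoder with a linear head on the pooled latent, so rather than repeating the model, here is a sketch of the batch format the sparse models consume and that the collators below construct: variable-size point clouds are concatenated instead of padded, `batch_idx` assigns each point to a sample, and `supernode_idxs` indexes the randomly chosen supernodes in the flattened tensor. The field names mirror the repo's collators; the data is toy data.

```python
import torch

# two point clouds of different sizes (e.g. sparse CIFAR10 pixels)
samples = [
    dict(input_pos=torch.rand(700, 2) * 32, input_feat=torch.randn(700, 3)),
    dict(input_pos=torch.rand(500, 2) * 32, input_feat=torch.randn(500, 3)),
]
num_supernodes = 64

# concatenate along the point dimension instead of padding
input_pos = torch.concat([s["input_pos"] for s in samples])    # (1200, 2)
input_feat = torch.concat([s["input_feat"] for s in samples])  # (1200, 3)

# batch_idx tells the encoder which points belong to which sample
lens = torch.tensor([len(s["input_pos"]) for s in samples])
batch_idx = torch.repeat_interleave(torch.arange(len(samples)), lens)

# supernode_idxs: per sample, pick random points and offset them into the flat tensor
offset = 0
supernode_idxs = []
for length in lens.tolist():
    supernode_idxs.append(torch.randperm(length)[:num_supernodes] + offset)
    offset += length
supernode_idxs = torch.concat(supernode_idxs)

print(input_feat.shape, batch_idx.shape, supernode_idxs.shape)
# torch.Size([1200, 3]) torch.Size([1200]) torch.Size([128])
```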
-------------------------------------------------------------------------------- /upt/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * -------------------------------------------------------------------------------- /upt/collators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/collators/__init__.py -------------------------------------------------------------------------------- /upt/collators/simulation_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import default_collate 3 | 4 | class SimulationCollator: 5 | def __init__(self, num_supernodes, deterministic): 6 | self.num_supernodes = num_supernodes 7 | self.deterministic = deterministic 8 | 9 | def __call__(self, batch): 10 | collated_batch = {} 11 | 12 | # inputs to sparse tensors 13 | # position: batch_size * (num_inputs, ndim) -> (batch_size * num_inputs, ndim) 14 | # features: batch_size * (num_inputs, dim) -> (batch_size * num_inputs, dim) 15 | input_pos = [] 16 | input_feat = [] 17 | input_lens = [] 18 | for i in range(len(batch)): 19 | pos = batch[i]["input_pos"] 20 | feat = batch[i]["input_feat"] 21 | assert len(pos) == len(pos) 22 | input_pos.append(pos) 23 | input_feat.append(feat) 24 | input_lens.append(len(pos)) 25 | collated_batch["input_pos"] = torch.concat(input_pos) 26 | collated_batch["input_feat"] = torch.concat(input_feat) 27 | 28 | # select supernodes 29 | supernodes_offset = 0 30 | supernode_idxs = [] 31 | for i in range(len(input_lens)): 32 | if self.deterministic: 33 | rng = torch.Generator().manual_seed(batch[i]["index"]) 34 | else: 35 | rng = None 36 | perm = torch.randperm(len(input_pos[i]), generator=rng)[:self.num_supernodes] + supernodes_offset 37 | supernode_idxs.append(perm) 38 | supernodes_offset += input_lens[i] 39 | collated_batch["supernode_idxs"] = torch.concat(supernode_idxs) 40 | 41 | # create batch_idx tensor 42 | batch_idx = torch.empty(sum(input_lens), dtype=torch.long) 43 | start = 0 44 | cur_batch_idx = 0 45 | for i in range(len(input_lens)): 46 | end = start + input_lens[i] 47 | batch_idx[start:end] = cur_batch_idx 48 | start = end 49 | cur_batch_idx += 1 50 | collated_batch["batch_idx"] = batch_idx 51 | 52 | # output_feat to sparse tensor 53 | output_feat = [] 54 | for i in range(len(batch)): 55 | feat = batch[i]["output_feat"] 56 | output_feat.append(feat) 57 | # output_feat is either list of tensors (for training) or list of list of tensors (for rollout) 58 | if torch.is_tensor(output_feat[0]): 59 | collated_batch["output_feat"] = torch.concat(output_feat) 60 | else: 61 | collated_batch["output_feat"] = output_feat 62 | 63 | # collate dense tensors 64 | collated_batch["output_pos"] = default_collate([batch[i]["output_pos"] for i in range(len(batch))]) 65 | collated_batch["timestep"] = default_collate([batch[i]["timestep"] for i in range(len(batch))]) 66 | 67 | return collated_batch 68 | -------------------------------------------------------------------------------- /upt/collators/sparseimage_autoencoder_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import default_collate 3 | 4 | class SparseImageAutoencoderCollator: 5 | def __init__(self, num_supernodes, deterministic): 6 | self.num_supernodes 
= num_supernodes 7 | self.deterministic = deterministic 8 | 9 | def __call__(self, batch): 10 | collated_batch = {} 11 | 12 | # inputs to sparse tensors 13 | # position: batch_size * (num_inputs, ndim) -> (batch_size * num_inputs, ndim) 14 | # features: batch_size * (num_inputs, dim) -> (batch_size * num_inputs, dim) 15 | input_pos = [] 16 | input_feat = [] 17 | input_lens = [] 18 | for i in range(len(batch)): 19 | pos = batch[i]["input_pos"] 20 | feat = batch[i]["input_feat"] 21 | assert len(pos) == len(pos) 22 | input_pos.append(pos) 23 | input_feat.append(feat) 24 | input_lens.append(len(pos)) 25 | collated_batch["input_pos"] = torch.concat(input_pos) 26 | collated_batch["input_feat"] = torch.concat(input_feat) 27 | 28 | # select supernodes 29 | supernodes_offset = 0 30 | supernode_idxs = [] 31 | for i in range(len(input_lens)): 32 | if self.deterministic: 33 | rng = torch.Generator().manual_seed(batch[i]["index"]) 34 | else: 35 | rng = None 36 | perm = torch.randperm(len(input_pos[i]), generator=rng)[:self.num_supernodes] + supernodes_offset 37 | supernode_idxs.append(perm) 38 | supernodes_offset += input_lens[i] 39 | collated_batch["supernode_idxs"] = torch.concat(supernode_idxs) 40 | 41 | # create batch_idx tensor 42 | batch_idx = torch.empty(sum(input_lens), dtype=torch.long) 43 | start = 0 44 | cur_batch_idx = 0 45 | for i in range(len(input_lens)): 46 | end = start + input_lens[i] 47 | batch_idx[start:end] = cur_batch_idx 48 | start = end 49 | cur_batch_idx += 1 50 | collated_batch["batch_idx"] = batch_idx 51 | 52 | # output_pos 53 | collated_batch["output_pos"] = default_collate([batch[i]["output_pos"] for i in range(len(batch))]) 54 | 55 | # target_feat to sparse tensor 56 | # batch_size * (num_outputs, dim) -> (batch_size * num_outputs, dim) 57 | collated_batch["target_feat"] = torch.concat([batch[i]["target_feat"] for i in range(len(batch))]) 58 | 59 | return collated_batch 60 | -------------------------------------------------------------------------------- /upt/collators/sparseimage_classifier_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import default_collate 3 | 4 | class SparseImageClassifierCollator: 5 | def __init__(self, num_supernodes, deterministic): 6 | self.num_supernodes = num_supernodes 7 | self.deterministic = deterministic 8 | 9 | def __call__(self, batch): 10 | collated_batch = {} 11 | 12 | # inputs to sparse tensors 13 | # position: batch_size * (num_inputs, ndim) -> (batch_size * num_inputs, ndim) 14 | # features: batch_size * (num_inputs, dim) -> (batch_size * num_inputs, dim) 15 | input_pos = [] 16 | input_feat = [] 17 | input_lens = [] 18 | for i in range(len(batch)): 19 | pos = batch[i]["input_pos"] 20 | feat = batch[i]["input_feat"] 21 | assert len(pos) == len(pos) 22 | input_pos.append(pos) 23 | input_feat.append(feat) 24 | input_lens.append(len(pos)) 25 | collated_batch["input_pos"] = torch.concat(input_pos) 26 | collated_batch["input_feat"] = torch.concat(input_feat) 27 | 28 | # select supernodes 29 | supernodes_offset = 0 30 | supernode_idxs = [] 31 | for i in range(len(input_lens)): 32 | if self.deterministic: 33 | rng = torch.Generator().manual_seed(batch[i]["index"]) 34 | else: 35 | rng = None 36 | perm = torch.randperm(len(input_pos[i]), generator=rng)[:self.num_supernodes] + supernodes_offset 37 | supernode_idxs.append(perm) 38 | supernodes_offset += input_lens[i] 39 | collated_batch["supernode_idxs"] = torch.concat(supernode_idxs) 40 | 41 | # create 
batch_idx tensor 42 | batch_idx = torch.empty(sum(input_lens), dtype=torch.long) 43 | start = 0 44 | cur_batch_idx = 0 45 | for i in range(len(input_lens)): 46 | end = start + input_lens[i] 47 | batch_idx[start:end] = cur_batch_idx 48 | start = end 49 | cur_batch_idx += 1 50 | collated_batch["batch_idx"] = batch_idx 51 | 52 | # targets can be collated normally 53 | collated_batch["target_class"] = default_collate([batch[i]["target_class"] for i in range(len(batch))]) 54 | 55 | return collated_batch 56 | -------------------------------------------------------------------------------- /upt/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/datasets/__init__.py -------------------------------------------------------------------------------- /upt/datasets/simulation_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class SimulationDataset(Dataset): 9 | def __init__( 10 | self, 11 | root, 12 | # how many input points to sample 13 | num_inputs, 14 | # how many output points to sample 15 | num_outputs, 16 | # train or rollout mode 17 | # - train: next timestep prediction 18 | # - rollout: return all timesteps for visualization 19 | mode, 20 | ): 21 | super().__init__() 22 | root = Path(root).expanduser() 23 | self.root = root 24 | self.num_inputs = num_inputs 25 | self.num_outputs = num_outputs 26 | self.mode = mode 27 | # discover simulations 28 | self.case_names = list(sorted(os.listdir(root))) 29 | self.num_timesteps = len( 30 | [ 31 | fname for fname in os.listdir(root / self.case_names[0]) 32 | if fname.endswith("_mesh.th") 33 | ], 34 | ) 35 | # these values were mistakenly copied from the "v1-10000sims" dataset version of the original codebase, which was a preliminary dataset that we generated during development. 36 | # the correct values would be from the "v3-10000sims" dataset version which would be the following (we keep the old values as its not too important for the tutorial). 37 | # self.mean = torch.tensor([0.03648518770933151, 1.927249059008318e-06, 0.000112384237581864]) 38 | # self.std = torch.tensor([0.005249467678368092, 0.003499444341287017, 0.0002817418717313558]) 39 | # For more details on normalization, see the UPT paper, Appendix D.5. 
40 | self.mean = torch.tensor([0.0152587890625, -1.7881393432617188e-06, 0.0003612041473388672]) 41 | self.std = torch.tensor([0.0233612060546875, 0.0184173583984375, 0.0019378662109375]) 42 | 43 | def __len__(self): 44 | if self.mode == "train": 45 | # first timestep cant be predicted 46 | return len(self.case_names) * (self.num_timesteps - 1) 47 | elif self.mode == "rollout": 48 | return len(self.case_names) 49 | else: 50 | raise NotImplementedError(f"invalid mode: '{self.mode}'") 51 | 52 | @staticmethod 53 | def _load_positions(case_uri): 54 | x = torch.load(case_uri / "x.th", weights_only=True).float() 55 | y = torch.load(case_uri / "y.th", weights_only=True).float() 56 | # x is in [-0.5, 1.0] -> rescale to [0, 300] positional embedding is designed for positive values in the 100s 57 | x = (x + 0.5) * 200 58 | # y is in [-0.5, 0.5] -> rescale to [0, 200] positional embedding is designed for positive values in the 100s 59 | y = (y + 0.5) * 200 60 | return torch.stack([x, y], dim=1) 61 | 62 | def __getitem__(self, idx): 63 | if self.mode == "train": 64 | # return t and t + 1 65 | case_idx = idx // (self.num_timesteps - 1) 66 | timestep = idx % (self.num_timesteps - 1) 67 | case_uri = self.root / self.case_names[case_idx] 68 | pos = self._load_positions(case_uri) 69 | input_pos = pos 70 | output_pos = pos 71 | input_feat = torch.load(case_uri / f"{timestep:08d}_mesh.th", weights_only=True).float().T 72 | output_feat = torch.load(case_uri / f"{timestep + 1:08d}_mesh.th", weights_only=True).float().T 73 | # subsample inputs 74 | if self.num_inputs != float("inf"): 75 | input_perm = torch.randperm(len(input_feat))[:self.num_inputs] 76 | input_feat = input_feat[input_perm] 77 | input_pos = input_pos[input_perm] 78 | # subsample outputs 79 | if self.num_outputs != float("inf"): 80 | output_perm = torch.randperm(len(output_feat))[:self.num_outputs] 81 | output_feat = output_feat[output_perm] 82 | output_pos = output_pos[output_perm] 83 | # create input dependence to make sure that encoder works by flipping the sign of input/output features 84 | # - if the dataset consists of only a single sample the decoder could learn it by heart 85 | # - if the encoder doesnt work it would not get recognized as it is not needed 86 | # - flipping the sign creates an input dependence (if input sign is flipped -> flip output sign) 87 | # - if the encoder does not work, it will learn an average of the two samples 88 | # - this is only relevant because this tutorial overfits on 1 trajectory for simplicity 89 | if torch.rand(size=(1,)) < 0.5: 90 | input_feat *= -1 91 | output_feat *= -1 92 | elif self.mode == "rollout": 93 | # return all timesteps 94 | timestep = 0 95 | case_uri = self.root / self.case_names[idx] 96 | pos = self._load_positions(case_uri) 97 | input_pos = pos 98 | output_pos = pos 99 | data = [ 100 | torch.load(case_uri / f"{i:08d}_mesh.th", weights_only=True).float().T 101 | for i in range(self.num_timesteps) 102 | ] 103 | input_feat = data[0] 104 | output_feat = data[1:] 105 | # deterministically downsample (for fast evaluation during training) 106 | # subsample inputs 107 | if self.num_inputs != float("inf"): 108 | rng = torch.Generator().manual_seed(idx) 109 | input_perm = torch.randperm(len(input_feat), generator=rng)[:self.num_inputs] 110 | input_feat = input_feat[input_perm] 111 | input_pos = input_pos[input_perm] 112 | # subsample outputs 113 | if self.num_outputs != float("inf"): 114 | rng = torch.Generator().manual_seed(idx) 115 | output_perm = torch.randperm(len(output_pos), 
generator=rng)[:self.num_outputs] 116 | output_pos = output_pos[output_perm] 117 | for i in range(len(output_feat)): 118 | output_feat[i] = output_feat[i][output_perm] 119 | else: 120 | raise NotImplementedError 121 | 122 | # normalize 123 | input_feat -= self.mean.unsqueeze(0) 124 | input_feat /= self.std.unsqueeze(0) 125 | if isinstance(output_feat, list): 126 | for i in range(len(output_feat)): 127 | output_feat[i] -= self.mean.unsqueeze(0) 128 | output_feat[i] /= self.std.unsqueeze(0) 129 | else: 130 | output_feat -= self.mean.unsqueeze(0) 131 | output_feat /= self.std.unsqueeze(0) 132 | 133 | return dict( 134 | index=idx, 135 | input_feat=input_feat, 136 | input_pos=input_pos, 137 | output_feat=output_feat, 138 | output_pos=output_pos, 139 | timestep=timestep, 140 | ) 141 | -------------------------------------------------------------------------------- /upt/datasets/sparse_cifar10_autoencoder_dataset.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from torchvision.datasets import CIFAR10 4 | 5 | 6 | class SparseCIFAR10AutoencoderDataset(CIFAR10): 7 | def __init__( 8 | self, 9 | # how many input pixels to sample (<= 1024) 10 | num_inputs, 11 | # how many output pixels to sample (<= 1024) 12 | num_outputs, 13 | # CIFAR10 properties 14 | root, 15 | train=True, 16 | transform=None, 17 | download=False, 18 | ): 19 | super().__init__( 20 | root=root, 21 | train=train, 22 | transform=transform, 23 | download=download, 24 | ) 25 | assert num_inputs <= 1024, "CIFAR10 only has 1024 pixels, use less or equal 1024 num_inputs" 26 | self.num_inputs = num_inputs 27 | self.num_outputs = num_outputs 28 | # CIFAR has 32x32 pixels 29 | # output_pos will be a tensor of shape (32 * 32, 2) with and will contain x and y indices 30 | # output_pos[0] = [0, 0] 31 | # output_pos[1] = [0, 1] 32 | # output_pos[2] = [0, 2] 33 | # ... 
34 | # output_pos[32] = [1, 0] 35 | # output_pos[1024] = [31, 31] 36 | self.output_pos = einops.rearrange( 37 | torch.stack(torch.meshgrid([torch.arange(32), torch.arange(32)], indexing="ij")), 38 | "ndim height width -> (height width) ndim", 39 | ).float() 40 | 41 | def __getitem__(self, idx): 42 | image, _ = super().__getitem__(idx) 43 | assert image.shape == (3, 32, 32) 44 | # reshape image to sparse tensor 45 | x = einops.rearrange(image, "dim height width -> (height width) dim") 46 | pos = self.output_pos.clone() 47 | 48 | # subsample random input pixels (locations of inputs and outputs does not have to be the same) 49 | if self.num_inputs < 1024: 50 | if self.train: 51 | rng = None 52 | else: 53 | rng = torch.Generator().manual_seed(idx) 54 | input_perm = torch.randperm(len(x), generator=rng)[:self.num_inputs] 55 | input_feat = x[input_perm] 56 | input_pos = pos[input_perm].clone() 57 | else: 58 | input_feat = x 59 | input_pos = pos.clone() 60 | 61 | # subsample random output pixels (locations of inputs and outputs does not have to be the same) 62 | if self.num_outputs < 1024: 63 | if self.train: 64 | rng = None 65 | else: 66 | rng = torch.Generator().manual_seed(idx + 1) 67 | output_perm = torch.randperm(len(x), generator=rng)[:self.num_outputs] 68 | target_feat = x[output_perm] 69 | output_pos = pos[output_perm].clone() 70 | else: 71 | target_feat = x 72 | output_pos = pos.clone() 73 | 74 | return dict( 75 | index=idx, 76 | input_feat=input_feat, 77 | input_pos=input_pos, 78 | target_feat=target_feat, 79 | output_pos=output_pos, 80 | ) 81 | -------------------------------------------------------------------------------- /upt/datasets/sparse_cifar10_classifier_dataset.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from torchvision.datasets import CIFAR10 4 | 5 | 6 | class SparseCIFAR10ClassifierDataset(CIFAR10): 7 | def __init__(self, root, num_inputs, train=True, transform=None, download=False): 8 | super().__init__( 9 | root=root, 10 | train=train, 11 | transform=transform, 12 | download=download, 13 | ) 14 | assert num_inputs <= 1024, "CIFAR10 only has 1024 pixels, use less or equal 1024 num_inputs" 15 | self.num_inputs = num_inputs 16 | # CIFAR has 32x32 pixels 17 | # output_pos will be a tensor of shape (32 * 32, 2) with and will contain x and y indices 18 | # output_pos[0] = [0, 0] 19 | # output_pos[1] = [0, 1] 20 | # output_pos[2] = [0, 2] 21 | # ... 
22 | # output_pos[32] = [1, 0] 23 | # output_pos[1024] = [31, 31] 24 | self.output_pos = einops.rearrange( 25 | torch.stack(torch.meshgrid([torch.arange(32), torch.arange(32)], indexing="ij")), 26 | "ndim height width -> (height width) ndim", 27 | ).float() 28 | 29 | def __getitem__(self, idx): 30 | image, y = super().__getitem__(idx) 31 | assert image.shape == (3, 32, 32) 32 | # reshape image to sparse tensor 33 | x = einops.rearrange(image, "dim height width -> (height width) dim") 34 | pos = self.output_pos.clone() 35 | 36 | # subsample random input pixels (locations of inputs and outputs does not have to be the same) 37 | if self.num_inputs < 1024: 38 | if self.train: 39 | rng = None 40 | else: 41 | rng = torch.Generator().manual_seed(idx) 42 | input_perm = torch.randperm(len(x), generator=rng)[:self.num_inputs] 43 | input_feat = x[input_perm] 44 | input_pos = pos[input_perm].clone() 45 | else: 46 | input_feat = x 47 | input_pos = pos.clone() 48 | 49 | return dict( 50 | index=idx, 51 | input_feat=input_feat, 52 | input_pos=input_pos, 53 | target_class=y, 54 | ) 55 | -------------------------------------------------------------------------------- /upt/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/models/__init__.py -------------------------------------------------------------------------------- /upt/models/approximator.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from kappamodules.layers import LinearProjection, Sequential 5 | from kappamodules.transformer import DitBlock, PrenormBlock 6 | from torch import nn 7 | 8 | 9 | class Approximator(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | depth, 14 | num_heads, 15 | dim=None, 16 | cond_dim=None, 17 | init_weights="truncnormal002", 18 | **kwargs, 19 | ): 20 | super().__init__(**kwargs) 21 | dim = dim or input_dim 22 | self.dim = dim 23 | self.depth = depth 24 | self.num_heads = num_heads 25 | self.cond_dim = cond_dim 26 | self.init_weights = init_weights 27 | 28 | # project 29 | self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True) 30 | 31 | # blocks 32 | if cond_dim is None: 33 | block_ctor = PrenormBlock 34 | else: 35 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 36 | self.blocks = Sequential( 37 | *[ 38 | block_ctor( 39 | dim=dim, 40 | num_heads=num_heads, 41 | init_weights=init_weights, 42 | ) 43 | for _ in range(depth) 44 | ], 45 | ) 46 | 47 | def forward(self, x, condition=None): 48 | # check inputs 49 | assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)" 50 | if condition is not None: 51 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 52 | 53 | # pass condition to DiT blocks 54 | cond_kwargs = {} 55 | if condition is not None: 56 | cond_kwargs["cond"] = condition 57 | 58 | # project to decoder dim 59 | x = self.input_proj(x) 60 | 61 | # apply blocks 62 | x = self.blocks(x, **cond_kwargs) 63 | 64 | return x 65 | -------------------------------------------------------------------------------- /upt/models/conditioner_timestep.py: -------------------------------------------------------------------------------- 1 | from kappamodules.functional.pos_embed import get_sincos_1d_from_seqlen 2 | from torch import nn 3 | 4 | 5 | class ConditionerTimestep(nn.Module): 6 | def 
__init__(self, dim, num_timesteps): 7 | super().__init__() 8 | cond_dim = dim * 4 9 | self.num_timesteps = num_timesteps 10 | self.dim = dim 11 | self.cond_dim = cond_dim 12 | self.register_buffer( 13 | "timestep_embed", 14 | get_sincos_1d_from_seqlen(seqlen=num_timesteps, dim=dim), 15 | ) 16 | self.mlp = nn.Sequential( 17 | nn.Linear(dim, cond_dim), 18 | nn.SiLU(), 19 | ) 20 | 21 | def forward(self, timestep): 22 | # checks + preprocess 23 | assert timestep.numel() == len(timestep) 24 | timestep = timestep.flatten() 25 | # embed 26 | embed = self.mlp(self.timestep_embed[timestep]) 27 | return embed 28 | -------------------------------------------------------------------------------- /upt/models/decoder_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from kappamodules.layers import LinearProjection, Sequential 4 | from kappamodules.transformer import DitBlock 5 | from kappamodules.vit import VitBlock 6 | from torch import nn 7 | 8 | 9 | class DecoderClassifier(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | num_classes, 14 | dim, 15 | depth, 16 | num_heads, 17 | cond_dim=None, 18 | init_weights="truncnormal002", 19 | **kwargs, 20 | ): 21 | super().__init__(**kwargs) 22 | self.input_dim = input_dim 23 | self.num_classes = num_classes 24 | self.dim = dim 25 | self.depth = depth 26 | self.num_heads = num_heads 27 | self.cond_dim = cond_dim 28 | self.init_weights = init_weights 29 | 30 | # input projection 31 | self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True) 32 | 33 | # blocks 34 | if cond_dim is None: 35 | block_ctor = VitBlock 36 | else: 37 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 38 | self.blocks = Sequential( 39 | *[ 40 | block_ctor( 41 | dim=dim, 42 | num_heads=num_heads, 43 | init_weights=init_weights, 44 | ) 45 | for _ in range(depth) 46 | ], 47 | ) 48 | 49 | # classifier 50 | self.head = nn.Sequential( 51 | nn.LayerNorm(dim, eps=1e-6), 52 | LinearProjection(dim, num_classes), 53 | ) 54 | 55 | def forward(self, x, condition=None): 56 | # check inputs 57 | assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)" 58 | if condition is not None: 59 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 60 | 61 | # pass condition to DiT blocks 62 | cond_kwargs = {} 63 | if condition is not None: 64 | cond_kwargs["cond"] = condition 65 | 66 | # input projection 67 | x = self.input_proj(x) 68 | 69 | # apply blocks 70 | x = self.blocks(x, **cond_kwargs) 71 | 72 | # pool 73 | x = x.mean(dim=1) 74 | 75 | # classify 76 | x = self.head(x) 77 | 78 | return x 79 | -------------------------------------------------------------------------------- /upt/models/decoder_perceiver.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import einops 4 | import torch 5 | from kappamodules.layers import ContinuousSincosEmbed, LinearProjection, Sequential 6 | from kappamodules.transformer import PerceiverBlock, DitPerceiverBlock, DitBlock 7 | from kappamodules.vit import VitBlock 8 | from torch import nn 9 | import math 10 | 11 | 12 | class DecoderPerceiver(nn.Module): 13 | def __init__( 14 | self, 15 | input_dim, 16 | output_dim, 17 | ndim, 18 | dim, 19 | depth, 20 | num_heads, 21 | unbatch_mode="dense_to_sparse_unpadded", 22 | perc_dim=None, 23 | perc_num_heads=None, 24 | cond_dim=None, 25 | init_weights="truncnormal002", 26 | **kwargs, 27 | ): 28 | 
super().__init__(**kwargs) 29 | perc_dim = perc_dim or dim 30 | perc_num_heads = perc_num_heads or num_heads 31 | self.input_dim = input_dim 32 | self.output_dim = output_dim 33 | self.ndim = ndim 34 | self.dim = dim 35 | self.depth = depth 36 | self.num_heads = num_heads 37 | self.perc_dim = perc_dim 38 | self.perc_num_heads = perc_num_heads 39 | self.cond_dim = cond_dim 40 | self.init_weights = init_weights 41 | self.unbatch_mode = unbatch_mode 42 | 43 | # input projection 44 | self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True) 45 | 46 | # blocks 47 | if cond_dim is None: 48 | block_ctor = VitBlock 49 | else: 50 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 51 | self.blocks = Sequential( 52 | *[ 53 | block_ctor( 54 | dim=dim, 55 | num_heads=num_heads, 56 | init_weights=init_weights, 57 | ) 58 | for _ in range(depth) 59 | ], 60 | ) 61 | 62 | # prepare perceiver 63 | self.pos_embed = ContinuousSincosEmbed( 64 | dim=perc_dim, 65 | ndim=ndim, 66 | ) 67 | if cond_dim is None: 68 | block_ctor = PerceiverBlock 69 | else: 70 | block_ctor = partial(DitPerceiverBlock, cond_dim=cond_dim) 71 | 72 | # decoder 73 | self.query_proj = nn.Sequential( 74 | LinearProjection(perc_dim, perc_dim, init_weights=init_weights), 75 | nn.GELU(), 76 | LinearProjection(perc_dim, perc_dim, init_weights=init_weights), 77 | ) 78 | self.perc = block_ctor(dim=perc_dim, kv_dim=dim, num_heads=perc_num_heads, init_weights=init_weights) 79 | self.pred = nn.Sequential( 80 | nn.LayerNorm(perc_dim, eps=1e-6), 81 | LinearProjection(perc_dim, output_dim, init_weights=init_weights), 82 | ) 83 | 84 | def forward(self, x, output_pos, condition=None): 85 | # check inputs 86 | assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)" 87 | assert output_pos.ndim == 3, "expected shape (batch_size, num_outputs, dim) num_outputs might be padded" 88 | if condition is not None: 89 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 90 | 91 | # pass condition to DiT blocks 92 | cond_kwargs = {} 93 | if condition is not None: 94 | cond_kwargs["cond"] = condition 95 | 96 | # input projection 97 | x = self.input_proj(x) 98 | 99 | # apply blocks 100 | x = self.blocks(x, **cond_kwargs) 101 | 102 | # create query 103 | query = self.pos_embed(output_pos) 104 | query = self.query_proj(query) 105 | 106 | x = self.perc(q=query, kv=x, **cond_kwargs) 107 | x = self.pred(x) 108 | if self.unbatch_mode == "dense_to_sparse_unpadded": 109 | # dense to sparse where no padding needs to be considered 110 | x = einops.rearrange( 111 | x, 112 | "batch_size seqlen dim -> (batch_size seqlen) dim", 113 | ) 114 | elif self.unbatch_mode == "image": 115 | # rearrange to square image 116 | height = math.sqrt(x.size(1)) 117 | assert height.is_integer() 118 | x = einops.rearrange( 119 | x, 120 | "batch_size (height width) dim -> batch_size dim height width", 121 | height=int(height), 122 | ) 123 | else: 124 | raise NotImplementedError(f"invalid unbatch_mode '{self.unbatch_mode}'") 125 | 126 | return x 127 | -------------------------------------------------------------------------------- /upt/models/encoder_image.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import einops 4 | from kappamodules.layers import Sequential 5 | from kappamodules.transformer import PerceiverPoolingBlock, PrenormBlock, DitPerceiverPoolingBlock, DitBlock 6 | from kappamodules.utils.param_checking import to_2tuple 7 | from kappamodules.vit 
--------------------------------------------------------------------------------
/upt/models/encoder_image.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import einops
4 | from kappamodules.layers import Sequential
5 | from kappamodules.transformer import PerceiverPoolingBlock, PrenormBlock, DitPerceiverPoolingBlock, DitBlock
6 | from kappamodules.utils.param_checking import to_2tuple
7 | from kappamodules.vit import VitPatchEmbed, VitPosEmbed2d
8 | from torch import nn
9 |
10 |
11 | class EncoderImage(nn.Module):
12 |     def __init__(
13 |         self,
14 |         input_dim,
15 |         patch_size,
16 |         resolution,
17 |         enc_dim,
18 |         enc_num_heads,
19 |         enc_depth,
20 |         perc_dim=None,
21 |         perc_num_heads=None,
22 |         num_latent_tokens=None,
23 |         cond_dim=None,
24 |         init_weights="truncnormal",
25 |     ):
26 |         super().__init__()
27 |         patch_size = to_2tuple(patch_size)
28 |         resolution = to_2tuple(resolution)
29 |         self.input_dim = input_dim
30 |         self.patch_size = patch_size
31 |         self.resolution = resolution
32 |         self.enc_dim = enc_dim
33 |         self.enc_depth = enc_depth
34 |         self.enc_num_heads = enc_num_heads
35 |         self.perc_dim = perc_dim
36 |         self.perc_num_heads = perc_num_heads
37 |         self.num_latent_tokens = num_latent_tokens
38 |         self.condition_dim = cond_dim
39 |         self.init_weights = init_weights
40 |
41 |         # embed
42 |         self.patch_embed = VitPatchEmbed(
43 |             dim=enc_dim,
44 |             num_channels=input_dim,
45 |             resolution=resolution,
46 |             patch_size=patch_size,
47 |         )
48 |         self.pos_embed = VitPosEmbed2d(seqlens=self.patch_embed.seqlens, dim=enc_dim, is_learnable=False)
49 |
50 |         # blocks
51 |         if cond_dim is None:
52 |             block_ctor = PrenormBlock
53 |         else:
54 |             block_ctor = partial(DitBlock, cond_dim=cond_dim)
55 |         self.blocks = Sequential(
56 |             *[
57 |                 block_ctor(dim=enc_dim, num_heads=enc_num_heads, init_weights=init_weights)
58 |                 for _ in range(enc_depth)
59 |             ],
60 |         )
61 |
62 |         # perceiver pooling
63 |         if num_latent_tokens is None:
64 |             self.perceiver = None
65 |         else:
66 |             if cond_dim is None:
67 |                 block_ctor = partial(
68 |                     PerceiverPoolingBlock,
69 |                     perceiver_kwargs=dict(
70 |                         kv_dim=enc_dim,
71 |                         init_weights=init_weights,
72 |                     ),
73 |                 )
74 |             else:
75 |                 block_ctor = partial(
76 |                     DitPerceiverPoolingBlock,
77 |                     perceiver_kwargs=dict(
78 |                         kv_dim=enc_dim,
79 |                         cond_dim=cond_dim,
80 |                         init_weights=init_weights,
81 |                     ),
82 |                 )
83 |             self.perceiver = block_ctor(
84 |                 dim=perc_dim,
85 |                 num_heads=perc_num_heads,
86 |                 num_query_tokens=num_latent_tokens,
87 |             )
88 |
89 |     def forward(self, input_image, condition=None):
90 |         # check inputs
91 |         assert input_image.ndim == 4, "expected input image of shape (batch_size, num_channels, height, width)"
92 |         if condition is not None:
93 |             assert condition.ndim == 2, "expected shape (batch_size, cond_dim)"
94 |
95 |         # pass condition to DiT blocks
96 |         cond_kwargs = {}
97 |         if condition is not None:
98 |             cond_kwargs["cond"] = condition
99 |
100 |         # patch_embed
101 |         x = self.patch_embed(input_image)
102 |         # add pos_embed
103 |         x = self.pos_embed(x)
104 |         # flatten
105 |         x = einops.rearrange(x, "b ... d -> b (...) d")
106 |
107 |         # transformer
108 |         x = self.blocks(x, **cond_kwargs)
109 |
110 |         # perceiver
111 |         if self.perceiver is not None:
112 |             x = self.perceiver(kv=x, **cond_kwargs)
113 |
114 |         return x
115 |
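A minimal usage sketch for EncoderImage with CIFAR-sized inputs; the concrete hyperparameters below are illustrative assumptions, not the values used in the notebooks.

import torch

from upt.models.encoder_image import EncoderImage

encoder = EncoderImage(
    input_dim=3,            # RGB channels
    patch_size=4,
    resolution=32,          # 32x32 images -> 8x8 = 64 patch tokens
    enc_dim=192,
    enc_num_heads=3,
    enc_depth=4,
    perc_dim=192,
    perc_num_heads=3,
    num_latent_tokens=32,   # perceiver pooling compresses 64 patch tokens into 32 latent tokens
)
images = torch.randn(8, 3, 32, 32)
latent = encoder(images)
print(latent.shape)         # torch.Size([8, 32, 192])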
d") 106 | 107 | # transformer 108 | x = self.blocks(x, **cond_kwargs) 109 | 110 | # perceiver 111 | if self.perceiver is not None: 112 | x = self.perceiver(kv=x, **cond_kwargs) 113 | 114 | return x 115 | -------------------------------------------------------------------------------- /upt/models/encoder_supernodes.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from kappamodules.layers import LinearProjection, Sequential 4 | from kappamodules.transformer import PerceiverPoolingBlock, PrenormBlock, DitPerceiverPoolingBlock, DitBlock 5 | from upt.modules.supernode_pooling import SupernodePooling 6 | from torch import nn 7 | 8 | 9 | class EncoderSupernodes(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | ndim, 14 | radius, 15 | max_degree, 16 | gnn_dim, 17 | enc_dim, 18 | enc_depth, 19 | enc_num_heads, 20 | perc_dim=None, 21 | perc_num_heads=None, 22 | num_latent_tokens=None, 23 | cond_dim=None, 24 | init_weights="truncnormal", 25 | ): 26 | super().__init__() 27 | self.input_dim = input_dim 28 | self.ndim = ndim 29 | self.radius = radius 30 | self.max_degree = max_degree 31 | self.gnn_dim = gnn_dim 32 | self.enc_dim = enc_dim 33 | self.enc_depth = enc_depth 34 | self.enc_num_heads = enc_num_heads 35 | self.perc_dim = perc_dim 36 | self.perc_num_heads = perc_num_heads 37 | self.num_latent_tokens = num_latent_tokens 38 | self.condition_dim = cond_dim 39 | self.init_weights = init_weights 40 | 41 | # supernode pooling 42 | self.supernode_pooling = SupernodePooling( 43 | radius=radius, 44 | max_degree=max_degree, 45 | input_dim=input_dim, 46 | hidden_dim=gnn_dim, 47 | ndim=ndim, 48 | ) 49 | 50 | # blocks 51 | self.enc_proj = LinearProjection(gnn_dim, enc_dim, init_weights=init_weights, optional=True) 52 | if cond_dim is None: 53 | block_ctor = PrenormBlock 54 | else: 55 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 56 | self.blocks = Sequential( 57 | *[ 58 | block_ctor(dim=enc_dim, num_heads=enc_num_heads, init_weights=init_weights) 59 | for _ in range(enc_depth) 60 | ], 61 | ) 62 | 63 | # perceiver pooling 64 | if num_latent_tokens is None: 65 | self.perceiver = None 66 | else: 67 | if cond_dim is None: 68 | block_ctor = partial( 69 | PerceiverPoolingBlock, 70 | perceiver_kwargs=dict( 71 | kv_dim=enc_dim, 72 | init_weights=init_weights, 73 | ), 74 | ) 75 | else: 76 | block_ctor = partial( 77 | DitPerceiverPoolingBlock, 78 | perceiver_kwargs=dict( 79 | kv_dim=enc_dim, 80 | cond_dim=cond_dim, 81 | init_weights=init_weights, 82 | ), 83 | ) 84 | self.perceiver = block_ctor( 85 | dim=perc_dim, 86 | num_heads=perc_num_heads, 87 | num_query_tokens=num_latent_tokens, 88 | ) 89 | 90 | def forward(self, input_feat, input_pos, supernode_idxs, batch_idx, condition=None): 91 | # check inputs 92 | assert input_feat.ndim == 2, "expected sparse tensor (batch_size * num_inputs, input_dim)" 93 | assert input_pos.ndim == 2, "expected sparse tensor (batch_size * num_inputs, ndim)" 94 | assert len(input_feat) == len(input_pos), "expected input_feat and input_pos to have same length" 95 | assert supernode_idxs.ndim == 1, "supernode_idxs is a 1D tensor of indices that are used as supernodes" 96 | assert batch_idx.ndim == 1, f"batch_idx should be 1D tensor that assigns elements of the input to samples" 97 | if condition is not None: 98 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 99 | 100 | # pass condition to DiT blocks 101 | cond_kwargs = {} 102 | if condition is not None: 103 | 
cond_kwargs["cond"] = condition 104 | 105 | # supernode pooling 106 | x = self.supernode_pooling( 107 | input_feat=input_feat, 108 | input_pos=input_pos, 109 | supernode_idxs=supernode_idxs, 110 | batch_idx=batch_idx, 111 | ) 112 | 113 | # project to encoder dimension 114 | x = self.enc_proj(x) 115 | 116 | # transformer 117 | x = self.blocks(x, **cond_kwargs) 118 | 119 | # perceiver 120 | if self.perceiver is not None: 121 | x = self.perceiver(kv=x, **cond_kwargs) 122 | 123 | return x 124 | -------------------------------------------------------------------------------- /upt/models/upt.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class UPT(nn.Module): 7 | def __init__(self, conditioner, encoder, approximator, decoder): 8 | super().__init__() 9 | self.conditioner = conditioner 10 | self.encoder = encoder 11 | self.approximator = approximator 12 | self.decoder = decoder 13 | 14 | def forward( 15 | self, 16 | input_feat, 17 | input_pos, 18 | supernode_idxs, 19 | output_pos, 20 | batch_idx, 21 | timestep, 22 | ): 23 | condition = self.conditioner(timestep) 24 | 25 | # encode data 26 | latent = self.encoder( 27 | input_feat=input_feat, 28 | input_pos=input_pos, 29 | supernode_idxs=supernode_idxs, 30 | batch_idx=batch_idx, 31 | condition=condition, 32 | ) 33 | 34 | # propagate forward 35 | latent = self.approximator(latent, condition=condition) 36 | 37 | # decode 38 | pred = self.decoder( 39 | x=latent, 40 | output_pos=output_pos, 41 | condition=condition, 42 | ) 43 | 44 | return pred 45 | 46 | @torch.no_grad() 47 | def rollout(self, input_feat, input_pos, supernode_idxs, batch_idx): 48 | batch_size = batch_idx.max() + 1 49 | timestep = torch.zeros(batch_size).long() 50 | 51 | # we assume that output_pos is simply a rearranged input_pos 52 | # i.e. 
--------------------------------------------------------------------------------
/upt/models/upt_image_autoencoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class UPTImageAutoencoder(nn.Module):
5 |     def __init__(self, encoder, approximator, decoder):
6 |         super().__init__()
7 |         self.encoder = encoder
8 |         self.approximator = approximator
9 |         self.decoder = decoder
10 |
11 |     def forward(self, x, output_pos):
12 |         # encode data
13 |         latent = self.encoder(x)
14 |
15 |         # propagate forward
16 |         latent = self.approximator(latent)
17 |
18 |         # decode
19 |         pred = self.decoder(latent, output_pos=output_pos)
20 |
21 |         return pred
22 |
--------------------------------------------------------------------------------
/upt/models/upt_image_classifier.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class UPTImageClassifier(nn.Module):
5 |     def __init__(self, encoder, approximator, decoder):
6 |         super().__init__()
7 |         self.encoder = encoder
8 |         self.approximator = approximator
9 |         self.decoder = decoder
10 |
11 |     def forward(self, x):
12 |         # encode data
13 |         latent = self.encoder(x)
14 |
15 |         # propagate forward
16 |         latent = self.approximator(latent)
17 |
18 |         # decode
19 |         pred = self.decoder(latent)
20 |
21 |         return pred
22 |
--------------------------------------------------------------------------------
/upt/models/upt_sparseimage_autoencoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class UPTSparseImageAutoencoder(nn.Module):
5 |     def __init__(self, encoder, approximator, decoder):
6 |         super().__init__()
7 |         self.encoder = encoder
8 |         self.approximator = approximator
9 |         self.decoder = decoder
10 |
11 |     def forward(self, input_feat, input_pos, supernode_idxs, batch_idx, output_pos):
12 |         # encode data
13 |         latent = self.encoder(
14 |             input_feat=input_feat,
15 |             input_pos=input_pos,
16 |             supernode_idxs=supernode_idxs,
17 |             batch_idx=batch_idx,
18 |         )
19 |
20 |         # propagate forward
21 |         latent = self.approximator(latent)
22 |
23 |         # decode
24 |         pred = self.decoder(latent, output_pos=output_pos)
25 |
26 |         return pred
27 |
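The sparse-image wrappers above and below consume images already flattened into the sparse (input_feat, input_pos, batch_idx) format. A rough, self-contained sketch of that flattening with made-up sizes; the real logic, including how supernode_idxs is chosen, lives in the collators and datasets of this repo.

import einops
import torch

images = torch.randn(4, 3, 32, 32)                             # (batch_size, channels, height, width)
input_feat = einops.rearrange(images, "b c h w -> (b h w) c")  # one row per pixel, batch concatenated
grid = torch.stack(torch.meshgrid(torch.arange(32), torch.arange(32), indexing="ij"), dim=-1)
input_pos = einops.repeat(grid, "h w ndim -> (b h w) ndim", b=4).float()
batch_idx = torch.arange(4).repeat_interleave(32 * 32)
print(input_feat.shape, input_pos.shape, batch_idx.shape)      # (4096, 3) (4096, 2) (4096,)
# supernode_idxs would additionally pick a fixed number of pixel indices per sample as supernodes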
--------------------------------------------------------------------------------
/upt/models/upt_sparseimage_classifier.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class UPTSparseImageClassifier(nn.Module):
5 |     def __init__(self, encoder, approximator, decoder):
6 |         super().__init__()
7 |         self.encoder = encoder
8 |         self.approximator = approximator
9 |         self.decoder = decoder
10 |
11 |     def forward(self, input_feat, input_pos, supernode_idxs, batch_idx):
12 |         # encode data
13 |         latent = self.encoder(
14 |             input_feat=input_feat,
15 |             input_pos=input_pos,
16 |             supernode_idxs=supernode_idxs,
17 |             batch_idx=batch_idx,
18 |         )
19 |
20 |         # propagate forward
21 |         latent = self.approximator(latent)
22 |
23 |         # decode
24 |         pred = self.decoder(latent)
25 |
26 |         return pred
27 |
--------------------------------------------------------------------------------
/upt/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/modules/__init__.py
--------------------------------------------------------------------------------
/upt/modules/supernode_pooling.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | from kappamodules.layers import ContinuousSincosEmbed, LinearProjection
4 | from torch import nn
5 | from torch_geometric.nn.pool import radius_graph
6 | from torch_scatter import segment_csr
7 |
8 |
9 | class SupernodePooling(nn.Module):
10 |     def __init__(
11 |         self,
12 |         radius,
13 |         max_degree,
14 |         input_dim,
15 |         hidden_dim,
16 |         ndim,
17 |         init_weights="torch",
18 |     ):
19 |         super().__init__()
20 |         self.radius = radius
21 |         self.max_degree = max_degree
22 |         self.input_dim = input_dim
23 |         self.hidden_dim = hidden_dim
24 |         self.ndim = ndim
25 |         self.init_weights = init_weights
26 |
27 |         self.input_proj = LinearProjection(input_dim, hidden_dim, init_weights=init_weights)
28 |         self.pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim)
29 |         self.message = nn.Sequential(
30 |             LinearProjection(hidden_dim * 2, hidden_dim, init_weights=init_weights),
31 |             nn.GELU(),
32 |             LinearProjection(hidden_dim, hidden_dim, init_weights=init_weights),
33 |         )
34 |         self.output_dim = hidden_dim
35 |
36 |     def forward(self, input_feat, input_pos, supernode_idxs, batch_idx):
37 |         assert input_feat.ndim == 2
38 |         assert input_pos.ndim == 2
39 |         assert supernode_idxs.ndim == 1
40 |
41 |         # radius graph
42 |         input_edges = radius_graph(
43 |             x=input_pos,
44 |             r=self.radius,
45 |             max_num_neighbors=self.max_degree,
46 |             batch=batch_idx,
47 |             loop=True,
48 |             # inverted flow direction is required to have sorted dst_indices
49 |             flow="target_to_source",
50 |         )
51 |         is_supernode_edge = torch.isin(input_edges[0], supernode_idxs)
52 |         input_edges = input_edges[:, is_supernode_edge]
53 |
54 |         # embed mesh
55 |         x = self.input_proj(input_feat) + self.pos_embed(input_pos)
56 |
57 |         # create message input
58 |         dst_idx, src_idx = input_edges.unbind()
59 |         x = torch.concat([x[src_idx], x[dst_idx]], dim=1)
60 |         x = self.message(x)
61 |         # accumulate messages
62 |         # indptr is a tensor of indices between which to aggregate
63 |         # i.e. an indptr of [0, 2, 5] aggregates [src[0], src[1]] and [src[2], src[3], src[4]] (here with reduce="mean")
64 |         dst_indices, counts = dst_idx.unique(return_counts=True)
65 |         # first index has to be 0
66 |         # NOTE: padding for target indices that don't occur is not needed as a self-loop is always present
67 |         padded_counts = torch.zeros(len(counts) + 1, device=counts.device, dtype=counts.dtype)
68 |         padded_counts[1:] = counts
69 |         indptr = padded_counts.cumsum(dim=0)
70 |         x = segment_csr(src=x, indptr=indptr, reduce="mean")
71 |
72 |         # sanity check: dst_indices has length batch_size * num_supernodes, so it has to be divisible by batch_size
73 |         # if num_supernodes is not set in dataset this assertion fails
74 |         batch_size = batch_idx.max() + 1
75 |         assert dst_indices.numel() % batch_size == 0
76 |
77 |         # convert to dense tensor (dim last)
78 |         x = einops.rearrange(
79 |             x,
80 |             "(batch_size num_supernodes) dim -> batch_size num_supernodes dim",
81 |             batch_size=batch_size,
82 |         )
83 |
84 |         return x
85 |
--------------------------------------------------------------------------------
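A tiny standalone illustration of the indptr-based aggregation used above (requires torch_scatter); the numbers are arbitrary.

import torch
from torch_scatter import segment_csr

src = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])    # 5 messages with hidden_dim=1
indptr = torch.tensor([0, 2, 5])                           # supernode 0 <- messages 0..1, supernode 1 <- messages 2..4
print(segment_csr(src=src, indptr=indptr, reduce="mean"))  # tensor([[1.5000], [4.0000]])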