├── .gitignore ├── 1_preliminaries.ipynb ├── 2_image_classification.ipynb ├── 3_image_autoencoder.ipynb ├── 4_pointcloud_classification.ipynb ├── 5_pointcloud_autoencoder.ipynb ├── 6_transient_flow_cfd.ipynb ├── LICENSE ├── README.md ├── data └── simulation │ └── case_000000 │ ├── 00000000_mesh.th │ ├── 00000001_mesh.th │ ├── 00000002_mesh.th │ ├── 00000003_mesh.th │ ├── 00000004_mesh.th │ ├── 00000005_mesh.th │ ├── 00000006_mesh.th │ ├── 00000007_mesh.th │ ├── 00000008_mesh.th │ ├── 00000009_mesh.th │ ├── 00000010_mesh.th │ ├── 00000011_mesh.th │ ├── 00000012_mesh.th │ ├── 00000013_mesh.th │ ├── 00000014_mesh.th │ ├── 00000015_mesh.th │ ├── 00000016_mesh.th │ ├── 00000017_mesh.th │ ├── 00000018_mesh.th │ ├── 00000019_mesh.th │ ├── x.th │ └── y.th ├── schematics ├── architecture.svg ├── perceiver_decoder.svg ├── perceiver_pooling.svg ├── schematics.drawio ├── upt_dense_autoencoder.svg ├── upt_dense_classifier.svg ├── upt_sparse_autoencoder.svg └── upt_sparse_classifier.svg └── upt ├── __init__.py ├── collators ├── __init__.py ├── simulation_collator.py ├── sparseimage_autoencoder_collator.py └── sparseimage_classifier_collator.py ├── datasets ├── __init__.py ├── simulation_dataset.py ├── sparse_cifar10_autoencoder_dataset.py └── sparse_cifar10_classifier_dataset.py ├── models ├── __init__.py ├── approximator.py ├── conditioner_timestep.py ├── decoder_classifier.py ├── decoder_perceiver.py ├── encoder_image.py ├── encoder_supernodes.py ├── upt.py ├── upt_image_autoencoder.py ├── upt_image_classifier.py ├── upt_sparseimage_autoencoder.py └── upt_sparseimage_classifier.py └── modules ├── __init__.py └── supernode_pooling.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.json 3 | *.png 4 | *.jpg 5 | *.JPEG 6 | *.mp4 7 | *.pdf 8 | *.bkp 9 | *.dtmp 10 | 11 | data/ 12 | yamls_run/ 13 | temp/ 14 | sbatch_config.yaml 15 | sbatch_template_nodes.sh 16 | submit/ 17 | wandb_config.yaml 18 | wandb_configs/ 19 | 20 | hyperparams*.yaml 21 | static_config.yaml 22 | scratchpad*.py 23 | 24 | 25 | # Byte-compiled / optimized / DLL files 26 | __pycache__/ 27 | *.py[cod] 28 | *$py.class 29 | 30 | # C extensions 31 | *.so 32 | 33 | # Distribution / packaging 34 | .Python 35 | build/ 36 | develop-eggs/ 37 | dist/ 38 | downloads/ 39 | eggs/ 40 | .eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | wheels/ 47 | pip-wheel-metadata/ 48 | share/python-wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | MANIFEST 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .nox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # celery beat schedule file 118 | celerybeat-schedule 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Benedikt Alkin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Resources for UPT Tutorials 2 | 3 | [[`Project Page`](https://ml-jku.github.io/UPT)] [[`Paper (arxiv)`](https://arxiv.org/abs/2402.12365)] 4 | 5 | We recommend familiarizing yourself with the following papers if you haven't already: 6 | - [Transformer](https://arxiv.org/abs/1706.03762) 7 | - [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929) 8 | - [Perceiver IO](https://arxiv.org/abs/2107.14795) 9 | 10 | 11 | # Tutorials 12 | 13 | 14 | The following notebooks provide a tutorial to get familiar with universal physics transformers. 15 | 16 | - [Preliminaries](https://github.com/BenediktAlkin/upt-tutorial/blob/main/1_preliminaries.ipynb): provides an introduction to the concepts behind UPT 17 | - Sparse tensors 18 | - Positional encoding 19 | - Architecture overview 20 | - [CIFAR10 image classification](https://github.com/BenediktAlkin/upt-tutorial/blob/main/2_image_classification.ipynb): start from a basic example (regular grid input, scalar output, simple encoder, simple classification decoder) 21 | - [CIFAR10 autoencoder](https://github.com/BenediktAlkin/upt-tutorial/blob/main/3_image_autoencoder.ipynb): introduce the perceiver decoder to query at arbitrary positions 22 | - [SparseCIFAR10 image classification](https://github.com/BenediktAlkin/upt-tutorial/blob/main/4_pointcloud_classification.ipynb): introduce handling point clouds via sparse tensors and supernode message passing 23 | - [SparseCIFAR10 image autoencoder](https://github.com/BenediktAlkin/upt-tutorial/blob/main/5_pointcloud_autoencoder.ipynb): combine the handling of input point clouds with decoding arbitrarily many positions 24 | - [Simple Transient Flow](https://github.com/BenediktAlkin/upt-tutorial/blob/main/6_transient_flow_cfd.ipynb): put everything together to train UPT on a single trajectory of our transient flow simulations 25 | 26 | 27 | -------------------------------------------------------------------------------- /data/simulation/case_000000/00000000_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000000_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000001_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000001_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000002_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000002_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000003_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000003_mesh.th --------------------------------------------------------------------------------
/data/simulation/case_000000/00000004_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000004_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000005_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000005_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000006_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000006_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000007_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000007_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000008_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000008_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000009_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000009_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000010_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000010_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000011_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000011_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000012_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000012_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000013_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000013_mesh.th 
-------------------------------------------------------------------------------- /data/simulation/case_000000/00000014_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000014_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000015_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000015_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000016_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000016_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000017_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000017_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000018_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000018_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/00000019_mesh.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/00000019_mesh.th -------------------------------------------------------------------------------- /data/simulation/case_000000/x.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/x.th -------------------------------------------------------------------------------- /data/simulation/case_000000/y.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/data/simulation/case_000000/y.th -------------------------------------------------------------------------------- /schematics/perceiver_decoder.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | MLPMLPCross-attentionBlockCross-attention...qqkvkvOutput PositionOutput PositionText is not SVG - cannot display -------------------------------------------------------------------------------- /schematics/perceiver_pooling.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | Learnable QueryLearnable QueryCross-attentionBlockCross-attention...qqkvkvText is not SVG - cannot display -------------------------------------------------------------------------------- /schematics/schematics.drawio: 
-------------------------------------------------------------------------------- 1 | 7Vzfd6I4FP5rfJw5hPBDHqut287pnPGczuxs9y0jUdlBwomx6v71GyQoklixFNJFfGglJBG+e7+b3Jub9OBwsfmDonj+lfg47JmGv+nB255pun2b/00KtmmBkxXMaOCnReBQ8BT8i0WhIUpXgY+XRxUZISEL4uPCCYkiPGFHZYhSsj6uNiXh8a/GaCZ+0TgUPE1QiKVqPwOfzdPSvp2rfY+D2Tz7ZWCIOwuUVRZdLOfIJ+vcb8G7HhxSQlj6bbEZ4jDBLsPFdTabhyfav49/jKc4/hb+vn/5lPY+uqTJ/hUojtj7dm2mXb+gcCXwEu/KthmAlKwiHyedgB4crOcBw08xmiR311xjeNmcLUJxexqE4ZCEhO7aQn/3ScpJxHLl6YeXLxklv3F2JyIR73YgnglThjcFIZ5BAOzFwtUZkwVmdMvbiV76lhCtUGXTEKq8PigGcESdeV4pXFGIhDLO9n0fAOdfBOYX4A+vCX+ngD/IKHsOf2jbNeFvSfh/xctlYlJMY4yWyyCa9Uwn5M8x+EX5t1nyjRF+92kVYxqRxLa9p8RG3o03GDYuGafADKBghqmQjFUXMWwJVexzuy4uCWVzMiMRCu8OpYMD7ga/OtR5JCQWaP+DGduKQQqtuCCPZIE3Afsraf7ZFlfPuTu3G9Hz7mKbXUT8dXONksvn/L1Ds91V1i59v+Sl3iBIDgxZ0Qk+r9kM0Rl+rT9XrRgUh4gFL8cPp5KyaDomAX/svUK5ZkGh3AKD0+cXrQq6sn+Mt6uP06lPNfVxS6pPOmBV0J9KVsKVzPd3iqLllNAFpm20y25Ju2zWZZc9CfEx5prEJU2TIZPP7HdDZv3I8zvG7sPv+Gg53/felCwsqFsWmYt1HbNHq8iFsrNHYNU1ewSgXgGMRiPD8xoHul8AWqHodqN6Lnupt4ghCWqOAzvGU4lbHmRRhMJgFvHLEE+THhJMgwkKb0TxIvD93RRBJb7jaUNdQrGLvqtCKCrlr08oNbuuegZcSfc93bqv8FDTEN01K79CKs0qv61F+Ruf8xTpAE3ddHAk4L/FLEjctKtmhEowzTLClfDvPO6LPG7Q75VzuT21ZjTjcWdPeTUuN1SMNQ27eV5HrWrUykRzllqpduvilin78y3nlq2bW/u13Y5bb+WWWZJbuTVzHdySYwgt51bZJbz6uAU7blXkVtlFPKiXW3KMol3c6pcJPTfLLdkPbjfk0NIOeefhVjVnZT3cVLu1mbO2u7gStwzt3Opc3IrcgmVdXFOriwvb7uJKUwVHN7dgzYvUHytLABjFJFPX/VwyzbdGGciQd/btMvsGS9o3qDU8DuVV8XbZN2AUDBzwlPxq1sRZHb0q0ssuSS/rhHo0RC953b3l9ILgA9Cry6auSq+y6dSpgmujV9vzqSV6merZYbP06nf0qkgvryy9qu52qSZoOXm+7fSC+umV+YMdvd68lSzbt32OXrbWuK0lBzoegwgjzixjGCabOadBO2kG7Q9AM1NC/12h9jwfT6dNQy3lgzuaE2AtOcDwiBgWuyivNf1VJZZG01+tLgJRdZApHYHQuoBhyRGIhyheJVufHxa7cwPec3jRsv+reEyDq1jBUJGrtrMALDnZIcN8TJZBkv7fAthts7hwoRjSjUaNWpfxUNWolc14sLWGVS054+HP4HtCL8Qmc/7/bvGLi/T/P3MuJkkC7Xu4LT2ZD3u+HLHlQJ4TfMm4CXLMPPBUzc3aOeY0xTHRtHDUiQMLc1Gj2aNO7C6+UVGD7LLxDUunkbbl8EbPHMiyD8MgXp5yz3ISPLbepzeaKhzEusyzC85PgRq1zrYcw+i4dRG3ykbmzarc2jW9oRRtcxXixOguT9vu4jFVtlE4Q7FQv1+xPvBer++9Xp9/Sd/wxNv01U/X2EjUBUGqZkiXncukeUZN06WonvCMOgOjaoMs0HCyAXi9wRnGFF/IAg3P3eTIxphiP5i0JKYhZWNawCydjfkOgzoh68nIvbEXP/7+8v3xMX7+MkCfVPOok0gblyEt8JRmU6OR5+WSYbOTiJ0640mFyZRiry1UhZOAXRfw5nUAD4yCRwgt5fJcXwG+t6fHBejzy8Mp1KlpOhzlDe/+Aw==7Vzdd6I4FP9rfOwcQkDksdo6O2c7u93TOZ2PNypRmSLxxDjV/es3SEBIYkURMov40JKbD+De+7vJvbmkB0eLzUfiLeefsY/Cnmn4mx6865mmM7DZ35iwTQj9lDAjgZ+QwJ7wFPyLONHg1HXgo1WhIcU4pMGySJzgKEITWqB5hOC3YrMpDot3XXozfkdjT3iaeCGSmn0NfDpPqAM71/oPFMzm6Z2BwWsWXtqYD7Gaez5+y90L3vfgiGBMk6vFZoTCmHcpX14n8E/X/2aPHyj9B/6Y/vXySm+S0cendMlegaCIXnZoMxn6lxeuOb/4u9JtykCC15GP4kFADw7f5gFFT0tvEte+MY1htDldhLx6GoThCIeY7PpCf/eL6TiiOXryY/QVJfgVpTURjtiwQ/5MiFC0EYR4hAMgEwtTZ4QXiJIt68dHGVhctFyVTYOr8tteMUCft5nnlcLhRI8r4ywbe89wdsF5fgL/4TXxvy/wH6SQPcZ/aNs18d+S+P8ZrVaxSTGNR2+1CqJZz+yH7DmGL4RdzeIrilnt03qJSIRj23ZJiY3dW3c4alwyfQEZQIEMUyEZqy5g2BJXkc/sOi9iQud4hiMvvN9Th3u+G6y0b/OA8ZJz+yeidMsnKW/NBFmQBdoE9Fvc/YPNS99zNXcbPvKusE0LEXvdXKe4+D1ft++2K6X9kveLX+oMQTLG4DWZoOOaTT0yQ++N56gVg6DQo8Gv4sOppMy7PuKAPXamUI4pKJQjIDh5ft5L0JXsMc5Xn36nPtXUxympPuYBw1JafypZCUcy31+IF62mmCwQaaNddkraZbMuuzyQOP6ImCYxSZN4ymQr+92UWT/nWY2x+7Ea31vNs9GbkoUFdcvCvabFoyVCoeziEVh1LR5TF7cuAYzHY8N1G2f0QGC0Qs/tJtUcAInNdx71JFYzPtAiP5V8yzOZk7wwmEWsGKJpPELM02DihbecvAh8f7dCUImvuGqoSyi26LoqhKJS/vqEUnPkQM98K+m+q1v35QABj9Bds/IrpNKs8sthg1YueUQ4QFM3HGyJ8X8vaRB7aVeNCJVgmkVE53BXdLhBWY97oNaMZhxucG0eN1TMNc16eWDQQasitNyS0AJVg6HVBC378y3Hlq0bW+nAHbbOxVa2p38UW65ObJlyDKHl2Cq7g1cftmSfqMPWadiCJbFlal0TmnKMol3YGpQJPTeLLTkA0W6WQ0s7y7uMhKrmrF/WnFlazVn/2rBlaMeW02GrIrYGZbHV14otOX2g3djKorD6sHVVWQLAEHNMHedDySzf+mQAuzBDRfsGy4YZoNYQHmx7mAEYgoEDrhJfjZo42EUaqsILVkSNOgsYZHqQRX
wzbUlHSRBdWyIwbHtwQkIkBL8BIq0OkRURaZed8KpCt5qg5YSMlsPLVC8om4VXl21RFV5lsy0SBdcGr7anW8jwgvrhlbqQdbnMruuj6bRpVku5rX3NyXyW7Cw9eBTxdeC1pvKpxNJoKp/VeVNVP70su29raQ3GWrJr9ClarmP/7dNi9wn0JacXLd+yiF+cO4porApctX3WbMkbtynPH/EqiFOZW8B22xSDsIop3WjUqHW7t1WNWtndW0tropcl794+B19ieHl0Mmf/7xcvTKT//5WzmPAFtH+OaunZxc3wUkDLHjwH8JJiE+SQucepGpu1Y6zsoQ2VMaaO1/ahsBY1mj21wepS3atqUNlUd60pNulTFhzsoSz7MAyWq0PuWU6CRet9+KM5hYNYl3l2wPElUKPWOXXtO2ydi63SKTYX2Uy7JcTb5hosY6O7Omy7xRN3bEM4Dk5oP6jYHrjvt3ffb88ukjc88DYD9dM1NRPZqVJ0aDk3I63sWgZWPUHoLLiI6gmPqDMwqnZIAw0HO4ib5UKHI4gRX8gCza7d7JOOMjRO87L4ZC3N7OOx6+aSzNIDPvt1xjaEiV3xDRtUhTaAfYGZHYc/xz+sryPrmXx8HtHX59nN7U03sZc3VUoGKnLLlO3MioaqkpTlzZL8wVt3aIL9du4IQnsgIay2wMZhuRcPnVjTJGg7Rh5dk0ufAaolaCulzlrp4TNHE2frYjzUatlOCAxllg0U7NrRIFTdlk2xBjtsXHRZNnlLJINXi/ZEgCkuIFV27UKbIqy4P5c8WeHtD3eH9/8B7Vxtk6o2FP41zrQfukMIL/Jxdde2M3undrYzt/djrkSlReKEeFf76xs0CJK4RlnIDuL4gbxCnnOek+RwyACOV9tfKVovv5AQxwPbCrcD+DSwbWDZ9iD7W+HukDPMMxY0CkWlIuM1+g/nLUXuJgpxelKRERKzaH2aOSNJgmfsJA9RSt5Oq81JfHrXNVqIO1pFxusMxViq9jUK2VKMwi3V/g1Hi2V+Z2CJkhXKK4su0iUKyVvpXvB5AMeUEHa4Wm3HOM7Ay3EZRnS+WvhT/H3nLlj8p4+8118OvU+uaXIcAsUJ+9iuhSx/oHgj8BJjZbscQEo2SYizTsAAjt6WEcOvazTLSt+4yvC8JVvFongexfGYxITu28Jw/8vyScJK+Ycfz08ZJf/ivCQhCe92JJ4JU4a3FSFeQAAcxcL1GZMVZnTH24leho4QrVBl23IP6bdCMYAn6izLSuGLTCSUcXHsuwCcXwjMr8Af3hP+XgV/kFP2Ev7QdRvC35Hw/4LTNDMptjVFaRoli4Htxfw5Rt8pv1pkV4zw0tfNGtOEZLbtIyU2CR6D0bh1yXgVZgAFM2yFZJymiOFKqOKQ23WRJJQtyYIkKH4uckcF7hZPFXVeCFkLtP/BjO3EJIU2XJAnssDbiP2dNX9wRepbqeRpK3reJ3Z5IuHDLTXKkt/KZUWzfSpvdxhfNqgbBMmBIRs6w5c1myG6wO/156sVg+IYsejH6cOppCyaTknEH/uoUL5dUSi/wuDD84tWFV05Psbt6uP16lNPfXxN9bHPGBZt/allJXzJfP9FUZLOCV1h2kW77GvaZbspuzyUEJ9irklc0jSbMvnKfj9lNo88L7H2P14SonR57L0tWTjQtCyCe1o8OlUq6C4egdPU4jHf4jYlgMlkYgVB60APK0Ar9NxtU80BkGB+QgxJUHMc2CmeStzKIIssFEeLhCdjPM96yDCNZih+FNmrKAz3KwSV+E5XDU0Jxa1uXRVCUSl/c0Jp2HNgZr6VdD8wrfuyg0C46O5Z+RVSaVf5ZbdBJ5c8VTpA2zQdXAn4P9YsynZpd80IlWDaZUS/4a654Qa6O+6hWjPa2XCDe9txQ8Vc0+4uDwx7atWkVqBJLVDXGVpP0PJ+vuPcck1zK++459at3Dq+07/IrcAkt2zZh9Bxbum+wWuOW/KeqOfWddyCmtyyja4JbdlH0S1uDXVcz+1yS3ZAdBty6BiHvI9IqGvOPF1z5hg1Z969ccsyzi2/51ZNbg11ueUZ5ZYcPtBtbh29sOa4dVdRAsCqxpj6/oNmlG9zMoC9m6GmfYO6bgZo1IUHu+5mAFbFwIFAya9WTRzsPQ116aXraYBGvXiw654GiV4QfAJ6OT29atLL1aUXNEovObqi4/Sy1avDdunVh07UpZdu6MRBwY3Rq+uxEzK94CegVx8+UZdeuuETrm2UXrKj4yVKMOLMssZx9i3nPOomzaBrnma526UpN1MQhHg+bxtqKR7cMxwA68gOhhfEsPiI8l7DX1ViaTX81ek9EDUnGUfXA+EYfYHhyB6I35P1JvvyeYIR29DsOIDqeQE/rXl2yssG9njPh5jMIrY7pB4eHn7+2EnJyFdj1bMdfMV7DxUlGztAwJFDJHJJTUkaZR8NdAB2166+7lAsBKxWTWEfJ1HXFOrGSThGnbGOHCfx5WX6sZz6FKGUwPiH3o6Z+IgjP07YUZDlDD9yLoISEwteqrnYOKd0j0OpzSnRtHIeigcrK1ar3fNQnN4LUleDdL0gRoPXHNkJMrBHsuzjOFqn5zZxJQmeWu/zn6MqtpFNmWcfXF7ytGqdcwdAz61buaUdvFb39di+6SOlaFeqsM6MbnredlfPsnKtykGLlfrDmvVB8H794P36/OIwwjOjGaqfrq2ZyM2VomfLrbGeumsZWPdsrpvoUlVPeEGdgVW3Qe5YONsAvN/gAmOqA3JAu2s315bm9CnFYTTriA9Ditl0gK0ds9nYpH7VER/WdVALQKXl1GQSBKWY2fy8Yq9R8Cs7E+goXyYNFeAHRzFdAT9PFkcmHyhSHDwNn/8H7Vxfc6M2EP80nmkfLoMQGPMYO3GvbdJmLp25S186nJFtWoxcIZ/tfPoKI7CR5JgYg66YvASt/hh297fsrhb14Gix+Yl4y/kj9lHYMw1/04N3PdMEpgPZv4SyTSlu30kJMxL4fNCe8By8Ik40OHUV+CguDKQYhzRYFokTHEVoQgs0jxC8Lg6b4rD4q0tvxn/R2BOeJ16IpGGfA5/OU+rAPhj9EQWzefbLwOA9Cy8bzJeI556P1we/Be97cEQwpunVYjNCYcK8jC+fXu+H/z69jh7sj1/+2tzZX4dfXj6kq4/fMyV/BIIietmlzXTpb1644vziz0q3GQMJXkU+ShYBPThczwOKnpfeJOldM5VhtDldhLx7GoThCIeY7OZCf/eX0HFED+jpH6PHlOB/UNYT4YgtO+T3hAhFG0GIJzgAcrEwfUZ4gSjZsnl8lYHFRctV2TTstL3eKwbo8zHzQ6VwONHjyjjL194znF1wnr+D//Ca+N8X+A8yyJ7iP7TtmvhvSfx/RHGcmBTTePLiOIhmPbMfsvsYfiXsapZcUcx6n1dLRCKc2LZLSmzs3rrDUeOS6QvIAApkmArJWHUBw5a4inxm13kTEzrHMxx54f2eOtzz3WCt/ZgHjJec238jSrf8JeWtmCALskCbgH5Jpt/YvPVy0HO34SvvGtusEbHHPZiUNF8O+/bTdq1sXvp8y
UOdIUjGGLwiE3Ras6lHZuit9Ry1YhAUejT4Vrw5lZT51CccsNvOFcoxBYVyBASn989nCbqS38b56tPv1Kea+jgl1cc8YlhK608lK+FI5vsP4kXxFJMFIm20y05Ju2zWZZcHEsefENMkJmmSvDKZZ797ZdbPedZj7P5Yj+/F83z1pmRhQd2ycK/JebREKJR1HoFVl/OYhbh1CWA8Hhuu2zijBwKjFXpuN6nmAEhsvvOoJ7Ga8YEW+ank2yGTOckLg1nEmiGaJiskPA0mXnjLyYvA93cegkp8Ra+hLqHYYuiqEIpK+esTSs2ZAz3vW0n3Xd26LycIeIrumpVfIZVmlV9OG7TS5RHhAE3dcLAlxv++pEESpV01IlSCaRYRXcBdMeAGZSPugVozmgm4wbVF3FDxrmk2ygODDloVoeWWhBaomgytJmg5nm85tmzd2MoW7rB1LrbyPf2T2HJ1YsuUcwgtx1bZHbz6sCXHRB223octWBJbplaf0JRzFO3C1qBM6rlZbMkJiHazHFraWd5VJFQ1Z/2y5szSas7614YtQzu2nA5bFbE1KIutvlZsyeUD7cZWnoXVh62rqhIAhlhj6jg3Jat865MB7NIMFe0bLJtmgFpTeLDtaQZgCAYOuEp8NWriYJdpqAovWBE16ipgkOtBnvHNtSVbJUV0bYXAsO3JCQmREHwHiLQ6RFZEpF32hVcVutUELRdktBxeptqhbBZeXbVFVXiVrbZIFVwbvNpebiHDC+qHVxZC1hUyu66PptOmWS3VtvY1F/NZcrD04FHE/cBrLeVTiaXRUj6ri6aqfnpZdt/W0pqMteTQ6OdouUritzHy6IoknzaL3z7/sGTkmPX1zNEODyGeBHSbtm5ubn687EtJyxcw4nfqjiKHq4JkbR9DW/J2byapJxwHSQF0C9hum2LqVuEIGI2awm7Pt6opLLvna2ktD7PkPd/Hh6fLYuq7KAsD2j9atfTs9eb4KKBjD5Yj+MiwCA6QuMelGou1Y6rs0Q6VMaXO6vah4LEazZ7tYHUF8VU1qGxBvNZCnOwuC2H4UJZ9GAbL+FgQdyDBovU+/mmdIoysyzw74LTL06h1zhIAHbbOxVbpQpyLbLndEuJtDwYsE6MbH7fd4rk8tiEcGieMH1QcD9y3x7tvj2cX6RMeeZqB+u6aehPZmVJ0aDm3bq2sLwOrnjN0FlxE9YQn1BkYVSdkiYWjE8QtdWHCCcSID2SB2ny3h1+no7ULf9v+8pn8OXoxbiF4/NC9XsoDRslARR2UcpzZkOf21k0eOSTqDk2w387dK2gPJP+ttvD6uNyLBySsqJjU/d+nCqUyTys7KOVkkWddjIdaLds70hO5ZQMFu3YyFVK3ZVN4AseNiy7LJific3i1KBMPTNGNUdm1C6XiWXN/hnbqZ+xPIof3/wE=7V1bb6M4FP41kXYfpsIYh/DYJE13pVZTqSPtzNOKCU7CLsGRcdpkfv2aYALBJDEttyVEfcBX4JzzHdsfx+4ATta7R2pvVs/Ewd5A15zdAE4Hug40XR+Ef5qzj3JGccaSuo6olGS8ur9w3FLkbl0HBycVGSEeczenmXPi+3jOTvJsSsn7abUF8U7vurGX4o5akvE6tz0sVfvLddhKvAVK1f4Du8tVfGegiZK1HVcWXQQr2yHvqXvBhwGcUEJYdLXeTbAXCi+WC5ptvzv2oxd8M7yvzk/rC3v+9SXqfVakyfEVKPZZuV0LXb7Z3lbIS7wr28cCpGTrOzjsBAzg+H3lMvy6sedh6Ts3GZ63YmtPFC9cz5sQj9BDW+gcfmE+8VkqP/rx/IBR8i+OS3zi827Hiq8rxPKGKcO7lKbE6z9issaM7nkVUToyhGqFKesaitLviWGAoaizShuFKTJtYYzLY9+JwPmFkHkB+cNbkv8wI38QQ/aa/CFCFcnfkOT/jIMgdCm69mIHgesvB/rQ488x/kn51TK8YoSXvm43mPok9G1lamxm3VvjSe2aGWaQAXKQoedoxqgKGEiSKna4XxdJQtmKLIlvew9J7jiRu8ZTSZ0nQjZC2v9gxvZikLK3XJEnusA7l30Pm98hkfqRKpnuRM+HxD5O+Px1U43C5I90WdLskIrbFVNlQLZ0jq8bMrPpEl/qz4zqhcK8aBgUezZz304H0Twti6YvxOXvcTQoU88YlJlBcPRColXGVo6P8XHzGfbmU8h8TEXzicen8uznU17ClNz3N2r7wYLQNaZd9Mumol/Wq/LLox5YhYBlKQILwVYBy5KA9YL5i/L70nBmxBdwh5lR9QDjJdrhx0scO1gde68LcgZsGnLxSvo2FglG1uWpLhKAUdUiAYBqFTCbzTTLql3Qo4ygcwwd1WrnMhkxtZktiZq/HzuVZ67c0kIWWbbnLn2e9PAi7CGUlTu3vXuRvXYd5zBg5anvdBCrSikoS1HkKCXP+KtTSsUMRTPzKsn2raZtP4eIiKjYWzb+HK3Ua/yoEeOvfc6ThQPUm4bDUBL81w1zw0XDTSMiTzH1IsKU5N+v/y6t/4BYMF9dAIoVV0vWf/Fj3wyzAnPGmpqXeVYPrULQijVxFVqxMbcEW7q8nu84tlDT2Dp+wu+xpYgtXRFbscragi2ZQ+g4tlS/1FaHLdhjqxi2VL/VwpZhS+YouoWtkQr1XC+2+jCIgtgaKmIrnpK0BVsy4dFtbEGjcWz1VEZBbKlSGbExtwVbXecyJGxpjWOr5zKKYQuqchl6u7gM2HUuQ5oTDpvGFqw4GqFd4SBAywaNm+adYth+hTqQRd77t4v+DSr6N9iu7yBQDn/oln8DWsbBASsXX/W6OKOHVzF4IUV4GaBd8JIDLDoOLwhaAK9+d0RBeKluj4DtYpVg1/dHSPDS82eH9cKr3yNREF6qmyRg6bvXPqdoeZdE1+EFm4dXvB7s4aUIr3jOd/2bSLt4W0MmOtKbkKZ4TpxuwgyiFsBMl6Rfqqgty8GLRd2ilgL/hw1HOhsywfBkMyx2Rd9qnHOeWmqNczZ6BqLgIKPMQLTrA4YhMxB/+ptteJTBDNtsS8PzPbIHgPy24dkBLxvok4Ode2Tusn2Uuru7+73cQamR7YHZw1rMnO8eeZCs7EQQQw6RiDX1QgI33B3SAbEjPfu5I2cioNXqCvs4iYKuUDVOArWLjDXkOInnp5dyMdWKmFnQ+JZ+o5n4iCM+TtCRgOUMPmIsghQSE1zmY7FsTCnH9ZWOKdE0c8DREGZmrFq9BxyhngUpZkFIlQUxWuWUkUyCDPSxrHvPczfBuUVcSoOn3vv8vuOcZWRV7tkE16c8tXpnJDMdNWKriJdtC7ZUozv00sF1aHpPqb1PVdiEXjc477xHmUk20jJHp16pD6zL9a3L9flF9MTljgh5J3UY4R+aHLgjc+xFzJI5/ZuJonIZvenDbNb05M6wmg7QQz13VNB9qHJHsPRD0j7kPrLwNuBldwC0zzaImZazDcDlBhV5nLzjUc55HJ7gDzwG/LozvkcKvG/8EAGUc3DKlmX5zP89SyZFBRtAV44Krkz2zRJlH5k2ghOvf5UIKNvvqxJlLYtaRDJR
dgRZh6hooGcGlTOfpUtio3ky+V8K0RCV/EcK+PAf5VhLc5swEP41HJ3h4Qc+GsdOD0mbGXfaJjcFFFAts1QIG/fXVxhhIWDcpCH2tPX4wH5ardC3LwnDmW/yG4aS6A4CTA3bDHLDuTZs23Jt2yj+ZrAvkWEFhIwEUkkBK/ITS9CUaEYCnGqKHIBykuigD3GMfa5hiDHY6WrPQPVVExTKFU0FrHxEcUvtKwl4VKLuqKb9AZMwqla2TDmyQZWyNJFGKIBdbS1nYThzBsDLp00+x7Qgr+Llc367p8yNP67Xs+xpnHyxHgeD0vryNVOOW2A45v2aHrul7S2imSRMbpbvKwYZZHGACyuW4Xi7iHC8SpBfjO5EzAgs4hsqh58JpXOgwA5zneDwK3CIeQ0vfwJPOYM1rkZiiIVZ74X7lbxsMeM4r7lK7v8GwwZzthcquR6TMpTHVRDsVGBYQ4lFelDIgJTBGB5NK8LFg+T8FfxPzBbdOBDxK0VgPIIQYkQXCvWUQ0whKZ1bgES64TvmfC+TEWUcdCfhnPBvxfSrkZQeaiPXubR8EPaVEIv9lpPsUSU/1AfVvINUTXydM1PImI9PMVYVEcRCfMrgZFgqFnyeDA6GKeJkq9eL/j1dRavKtLvb+16TbTmdTb35uZPKaSTVuJ1Uo46cst8tpdoV7BwppdJj8o7pcYGol1PvgYgXPHrdbXh92HBn+aJyUr0rNexY5kg35DQMlQWhZegQGsftvCFanL+lAPddRy8TKBP3ZYHSm3+HrbIrjmxpOkCcCyIJxIY9poJQ74mJp7B48ij463+vMk+sjuOO2VGam7ncX2ketZzxo8Wz2BzXyewkrc6whBAlYSxEX/CIBe4VVBFxFZjJgQ0JgkMCd3lPT+pzdctp2yf2WbvluOWS9fY/84nVvBdc3CmTSzYlq9aSVIP63a1AO/NcHY9AZ7oUuC89Hll9d723Obp9/f6U8SQr+uc9pKTsT302ouXSnE7P3YgGrQzruCR0tqI/SDEhqm8y5TFCfdlyFr8A5VfbbtpAEP0aP1L5giE8xkDSB5DSUqlJ3jb2xN6yeMiy5tKv79oe3zBCoBKQWsTDztnZWeacmTE2nOFi+yjZMppiAMKwzWBrOCPDtq072zbSrxnscqQEQskDcqqAGf8NBJqEJjyAVcNRIQrFl03QxzgGXzUwJiVumm7vKJq3LllIN5oVMPOZgJbbTx6oiLJwa95fgYdRcbNl0s6CFc4UYhWxADe1u5yx4QwlospXi+0QREpewYt6ms4TmL5+t7sfsjOZ8dHrj04e/eGcI2UKEmJ12dDuII+9ZiIhwihZtSsYlJjEAaRRLMPxNhFXMFsyP93d6JrRWKQWgrbfuRBDFCizs06QfVIcY1XD84/GV0riHIqdGGMd1jsxX+JlDVLBtiYV5f8IuAAld9qFdjt3pC3VcreomU1VGbZDWFSviqIoGFVjWMauGNcLIv0MAXpWi28IdAGTiVJFGGLMxLhCvUoRU1uVzwRxSTr8AqV21I0sUdhUCbZcPdfWL2moLy5Zoy1FzoxdYcQ63+fM0XYL+6W+WZ3LrOLgeWquMJE+HGOMKlQxGcKxgD03d0z5PFodEgRTfN0cGJdX2m612gSYjNmb1tQ2vyVAl1yq9x4G9wNvePUe6+/1mNPusbKd6j1mf1qLOTdpsapd+p/YLjfoAjr6hFz/wFL3/dG6PzHzH0qH6o+pvTiW6R4PlA+IVqCsNMp0/qJaurccyNUQfqntHB7Il56rtykUt3taoVxMX7c1hvV/uNWqw5TSRHKMDbsnNKHem9SrMF15Av35PzCbnb0e7Z84mvclutxo7rXE+GjxrJNTTTIPklZnmCAmeBhr09c8gsa9lCqu3w3uaWPBgyBr4EPqNZv6Wpq4B/6SXvVp2W9JMl//Z5pY1tVE0Wb1BpnPuOo93Bn/AQ== -------------------------------------------------------------------------------- /schematics/upt_dense_classifier.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | Message Passingto SupernodesMessage Passing...TransformerTransformerPerceiver PoolingPerceiver PoolingDataDataModelModelOptionalOptionalTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerLinear ClassifierLinear ClassifierLatentLatentInput ImageInput ImageInput PositionInput PositionViT Patch EmbedViT Patch Embed++PredictionPredictionViewer does not support full SVG 1.1 -------------------------------------------------------------------------------- /schematics/upt_sparse_autoencoder.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | Message Passingto SupernodesMessage Passing...TransformerTransformerPerceiver PoolingPerceiver PoolingDataDataModelModelOptionalOptionalTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerLatentLatentInput Features(pressure, velocity, ...)Input Features...Input PositionInput PositionMLPMLP++Perceiver DecoderPerceiver DecoderOutput FeaturesOutput FeaturesOutput PositionOutput PositionViewer does not support full SVG 1.1 
-------------------------------------------------------------------------------- /schematics/upt_sparse_classifier.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | Message Passingto SupernodesMessage Passing...TransformerTransformerPerceiver PoolingPerceiver PoolingDataDataModelModelOptionalOptionalTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerTransformerLinear ClassifierLinear ClassifierLatentLatentInput Features(pressure, velocity, ...)Input Features...Input PositionInput PositionMLPMLP++PredictionPredictionViewer does not support full SVG 1.1 -------------------------------------------------------------------------------- /upt/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * -------------------------------------------------------------------------------- /upt/collators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/collators/__init__.py -------------------------------------------------------------------------------- /upt/collators/simulation_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import default_collate 3 | 4 | class SimulationCollator: 5 | def __init__(self, num_supernodes, deterministic): 6 | self.num_supernodes = num_supernodes 7 | self.deterministic = deterministic 8 | 9 | def __call__(self, batch): 10 | collated_batch = {} 11 | 12 | # inputs to sparse tensors 13 | # position: batch_size * (num_inputs, ndim) -> (batch_size * num_inputs, ndim) 14 | # features: batch_size * (num_inputs, dim) -> (batch_size * num_inputs, dim) 15 | input_pos = [] 16 | input_feat = [] 17 | input_lens = [] 18 | for i in range(len(batch)): 19 | pos = batch[i]["input_pos"] 20 | feat = batch[i]["input_feat"] 21 | assert len(pos) == len(feat) 22 | input_pos.append(pos) 23 | input_feat.append(feat) 24 | input_lens.append(len(pos)) 25 | collated_batch["input_pos"] = torch.concat(input_pos) 26 | collated_batch["input_feat"] = torch.concat(input_feat) 27 | 28 | # select supernodes 29 | supernodes_offset = 0 30 | supernode_idxs = [] 31 | for i in range(len(input_lens)): 32 | if self.deterministic: 33 | rng = torch.Generator().manual_seed(batch[i]["index"]) 34 | else: 35 | rng = None 36 | perm = torch.randperm(len(input_pos[i]), generator=rng)[:self.num_supernodes] + supernodes_offset 37 | supernode_idxs.append(perm) 38 | supernodes_offset += input_lens[i] 39 | collated_batch["supernode_idxs"] = torch.concat(supernode_idxs) 40 | 41 | # create batch_idx tensor 42 | batch_idx = torch.empty(sum(input_lens), dtype=torch.long) 43 | start = 0 44 | cur_batch_idx = 0 45 | for i in range(len(input_lens)): 46 | end = start + input_lens[i] 47 | batch_idx[start:end] = cur_batch_idx 48 | start = end 49 | cur_batch_idx += 1 50 | collated_batch["batch_idx"] = batch_idx 51 | 52 | # output_feat to sparse tensor 53 | output_feat = [] 54 | for i in range(len(batch)): 55 | feat = batch[i]["output_feat"] 56 | output_feat.append(feat) 57 | # output_feat is either list of tensors (for training) or list of list of tensors (for rollout) 58 | if torch.is_tensor(output_feat[0]): 59 |
collated_batch["output_feat"] = torch.concat(output_feat) 60 | else: 61 | collated_batch["output_feat"] = output_feat 62 | 63 | # collate dense tensors 64 | collated_batch["output_pos"] = default_collate([batch[i]["output_pos"] for i in range(len(batch))]) 65 | collated_batch["timestep"] = default_collate([batch[i]["timestep"] for i in range(len(batch))]) 66 | 67 | return collated_batch 68 | -------------------------------------------------------------------------------- /upt/collators/sparseimage_autoencoder_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import default_collate 3 | 4 | class SparseImageAutoencoderCollator: 5 | def __init__(self, num_supernodes, deterministic): 6 | self.num_supernodes = num_supernodes 7 | self.deterministic = deterministic 8 | 9 | def __call__(self, batch): 10 | collated_batch = {} 11 | 12 | # inputs to sparse tensors 13 | # position: batch_size * (num_inputs, ndim) -> (batch_size * num_inputs, ndim) 14 | # features: batch_size * (num_inputs, dim) -> (batch_size * num_inputs, dim) 15 | input_pos = [] 16 | input_feat = [] 17 | input_lens = [] 18 | for i in range(len(batch)): 19 | pos = batch[i]["input_pos"] 20 | feat = batch[i]["input_feat"] 21 | assert len(pos) == len(feat) 22 | input_pos.append(pos) 23 | input_feat.append(feat) 24 | input_lens.append(len(pos)) 25 | collated_batch["input_pos"] = torch.concat(input_pos) 26 | collated_batch["input_feat"] = torch.concat(input_feat) 27 | 28 | # select supernodes 29 | supernodes_offset = 0 30 | supernode_idxs = [] 31 | for i in range(len(input_lens)): 32 | if self.deterministic: 33 | rng = torch.Generator().manual_seed(batch[i]["index"]) 34 | else: 35 | rng = None 36 | perm = torch.randperm(len(input_pos[i]), generator=rng)[:self.num_supernodes] + supernodes_offset 37 | supernode_idxs.append(perm) 38 | supernodes_offset += input_lens[i] 39 | collated_batch["supernode_idxs"] = torch.concat(supernode_idxs) 40 | 41 | # create batch_idx tensor 42 | batch_idx = torch.empty(sum(input_lens), dtype=torch.long) 43 | start = 0 44 | cur_batch_idx = 0 45 | for i in range(len(input_lens)): 46 | end = start + input_lens[i] 47 | batch_idx[start:end] = cur_batch_idx 48 | start = end 49 | cur_batch_idx += 1 50 | collated_batch["batch_idx"] = batch_idx 51 | 52 | # output_pos 53 | collated_batch["output_pos"] = default_collate([batch[i]["output_pos"] for i in range(len(batch))]) 54 | 55 | # target_feat to sparse tensor 56 | # batch_size * (num_outputs, dim) -> (batch_size * num_outputs, dim) 57 | collated_batch["target_feat"] = torch.concat([batch[i]["target_feat"] for i in range(len(batch))]) 58 | 59 | return collated_batch 60 | -------------------------------------------------------------------------------- /upt/collators/sparseimage_classifier_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import default_collate 3 | 4 | class SparseImageClassifierCollator: 5 | def __init__(self, num_supernodes, deterministic): 6 | self.num_supernodes = num_supernodes 7 | self.deterministic = deterministic 8 | 9 | def __call__(self, batch): 10 | collated_batch = {} 11 | 12 | # inputs to sparse tensors 13 | # position: batch_size * (num_inputs, ndim) -> (batch_size * num_inputs, ndim) 14 | # features: batch_size * (num_inputs, dim) -> (batch_size * num_inputs, dim) 15 | input_pos = [] 16 | input_feat = [] 17 | input_lens = [] 18 | for i in range(len(batch)): 19 | pos =
batch[i]["input_pos"] 20 | feat = batch[i]["input_feat"] 21 | assert len(pos) == len(feat) 22 | input_pos.append(pos) 23 | input_feat.append(feat) 24 | input_lens.append(len(pos)) 25 | collated_batch["input_pos"] = torch.concat(input_pos) 26 | collated_batch["input_feat"] = torch.concat(input_feat) 27 | 28 | # select supernodes 29 | supernodes_offset = 0 30 | supernode_idxs = [] 31 | for i in range(len(input_lens)): 32 | if self.deterministic: 33 | rng = torch.Generator().manual_seed(batch[i]["index"]) 34 | else: 35 | rng = None 36 | perm = torch.randperm(len(input_pos[i]), generator=rng)[:self.num_supernodes] + supernodes_offset 37 | supernode_idxs.append(perm) 38 | supernodes_offset += input_lens[i] 39 | collated_batch["supernode_idxs"] = torch.concat(supernode_idxs) 40 | 41 | # create batch_idx tensor 42 | batch_idx = torch.empty(sum(input_lens), dtype=torch.long) 43 | start = 0 44 | cur_batch_idx = 0 45 | for i in range(len(input_lens)): 46 | end = start + input_lens[i] 47 | batch_idx[start:end] = cur_batch_idx 48 | start = end 49 | cur_batch_idx += 1 50 | collated_batch["batch_idx"] = batch_idx 51 | 52 | # targets can be collated normally 53 | collated_batch["target_class"] = default_collate([batch[i]["target_class"] for i in range(len(batch))]) 54 | 55 | return collated_batch 56 | -------------------------------------------------------------------------------- /upt/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/datasets/__init__.py -------------------------------------------------------------------------------- /upt/datasets/simulation_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class SimulationDataset(Dataset): 9 | def __init__( 10 | self, 11 | root, 12 | # how many input points to sample 13 | num_inputs, 14 | # how many output points to sample 15 | num_outputs, 16 | # train or rollout mode 17 | # - train: next timestep prediction 18 | # - rollout: return all timesteps for visualization 19 | mode, 20 | ): 21 | super().__init__() 22 | root = Path(root).expanduser() 23 | self.root = root 24 | self.num_inputs = num_inputs 25 | self.num_outputs = num_outputs 26 | self.mode = mode 27 | # discover simulations 28 | self.case_names = list(sorted(os.listdir(root))) 29 | self.num_timesteps = len( 30 | [ 31 | fname for fname in os.listdir(root / self.case_names[0]) 32 | if fname.endswith("_mesh.th") 33 | ], 34 | ) 35 | # these values were mistakenly copied from the "v1-10000sims" dataset version of the original codebase, which was a preliminary dataset that we generated during development. 36 | # the correct values would be from the "v3-10000sims" dataset version, which would be the following (we keep the old values as it's not too important for the tutorial). 37 | # self.mean = torch.tensor([0.03648518770933151, 1.927249059008318e-06, 0.000112384237581864]) 38 | # self.std = torch.tensor([0.005249467678368092, 0.003499444341287017, 0.0002817418717313558]) 39 | # For more details on normalization, see the UPT paper, Appendix D.5.
40 | self.mean = torch.tensor([0.0152587890625, -1.7881393432617188e-06, 0.0003612041473388672]) 41 | self.std = torch.tensor([0.0233612060546875, 0.0184173583984375, 0.0019378662109375]) 42 | 43 | def __len__(self): 44 | if self.mode == "train": 45 | # first timestep cant be predicted 46 | return len(self.case_names) * (self.num_timesteps - 1) 47 | elif self.mode == "rollout": 48 | return len(self.case_names) 49 | else: 50 | raise NotImplementedError(f"invalid mode: '{self.mode}'") 51 | 52 | @staticmethod 53 | def _load_positions(case_uri): 54 | x = torch.load(case_uri / "x.th", weights_only=True).float() 55 | y = torch.load(case_uri / "y.th", weights_only=True).float() 56 | # x is in [-0.5, 1.0] -> rescale to [0, 300] positional embedding is designed for positive values in the 100s 57 | x = (x + 0.5) * 200 58 | # y is in [-0.5, 0.5] -> rescale to [0, 200] positional embedding is designed for positive values in the 100s 59 | y = (y + 0.5) * 200 60 | return torch.stack([x, y], dim=1) 61 | 62 | def __getitem__(self, idx): 63 | if self.mode == "train": 64 | # return t and t + 1 65 | case_idx = idx // (self.num_timesteps - 1) 66 | timestep = idx % (self.num_timesteps - 1) 67 | case_uri = self.root / self.case_names[case_idx] 68 | pos = self._load_positions(case_uri) 69 | input_pos = pos 70 | output_pos = pos 71 | input_feat = torch.load(case_uri / f"{timestep:08d}_mesh.th", weights_only=True).float().T 72 | output_feat = torch.load(case_uri / f"{timestep + 1:08d}_mesh.th", weights_only=True).float().T 73 | # subsample inputs 74 | if self.num_inputs != float("inf"): 75 | input_perm = torch.randperm(len(input_feat))[:self.num_inputs] 76 | input_feat = input_feat[input_perm] 77 | input_pos = input_pos[input_perm] 78 | # subsample outputs 79 | if self.num_outputs != float("inf"): 80 | output_perm = torch.randperm(len(output_feat))[:self.num_outputs] 81 | output_feat = output_feat[output_perm] 82 | output_pos = output_pos[output_perm] 83 | # create input dependence to make sure that encoder works by flipping the sign of input/output features 84 | # - if the dataset consists of only a single sample the decoder could learn it by heart 85 | # - if the encoder doesnt work it would not get recognized as it is not needed 86 | # - flipping the sign creates an input dependence (if input sign is flipped -> flip output sign) 87 | # - if the encoder does not work, it will learn an average of the two samples 88 | # - this is only relevant because this tutorial overfits on 1 trajectory for simplicity 89 | if torch.rand(size=(1,)) < 0.5: 90 | input_feat *= -1 91 | output_feat *= -1 92 | elif self.mode == "rollout": 93 | # return all timesteps 94 | timestep = 0 95 | case_uri = self.root / self.case_names[idx] 96 | pos = self._load_positions(case_uri) 97 | input_pos = pos 98 | output_pos = pos 99 | data = [ 100 | torch.load(case_uri / f"{i:08d}_mesh.th", weights_only=True).float().T 101 | for i in range(self.num_timesteps) 102 | ] 103 | input_feat = data[0] 104 | output_feat = data[1:] 105 | # deterministically downsample (for fast evaluation during training) 106 | # subsample inputs 107 | if self.num_inputs != float("inf"): 108 | rng = torch.Generator().manual_seed(idx) 109 | input_perm = torch.randperm(len(input_feat), generator=rng)[:self.num_inputs] 110 | input_feat = input_feat[input_perm] 111 | input_pos = input_pos[input_perm] 112 | # subsample outputs 113 | if self.num_outputs != float("inf"): 114 | rng = torch.Generator().manual_seed(idx) 115 | output_perm = torch.randperm(len(output_pos), 
generator=rng)[:self.num_outputs] 116 | output_pos = output_pos[output_perm] 117 | for i in range(len(output_feat)): 118 | output_feat[i] = output_feat[i][output_perm] 119 | else: 120 | raise NotImplementedError 121 | 122 | # normalize 123 | input_feat -= self.mean.unsqueeze(0) 124 | input_feat /= self.std.unsqueeze(0) 125 | if isinstance(output_feat, list): 126 | for i in range(len(output_feat)): 127 | output_feat[i] -= self.mean.unsqueeze(0) 128 | output_feat[i] /= self.std.unsqueeze(0) 129 | else: 130 | output_feat -= self.mean.unsqueeze(0) 131 | output_feat /= self.std.unsqueeze(0) 132 | 133 | return dict( 134 | index=idx, 135 | input_feat=input_feat, 136 | input_pos=input_pos, 137 | output_feat=output_feat, 138 | output_pos=output_pos, 139 | timestep=timestep, 140 | ) 141 | -------------------------------------------------------------------------------- /upt/datasets/sparse_cifar10_autoencoder_dataset.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from torchvision.datasets import CIFAR10 4 | 5 | 6 | class SparseCIFAR10AutoencoderDataset(CIFAR10): 7 | def __init__( 8 | self, 9 | # how many input pixels to sample (<= 1024) 10 | num_inputs, 11 | # how many output pixels to sample (<= 1024) 12 | num_outputs, 13 | # CIFAR10 properties 14 | root, 15 | train=True, 16 | transform=None, 17 | download=False, 18 | ): 19 | super().__init__( 20 | root=root, 21 | train=train, 22 | transform=transform, 23 | download=download, 24 | ) 25 | assert num_inputs <= 1024, "CIFAR10 only has 1024 pixels, use less or equal 1024 num_inputs" 26 | self.num_inputs = num_inputs 27 | self.num_outputs = num_outputs 28 | # CIFAR has 32x32 pixels 29 | # output_pos will be a tensor of shape (32 * 32, 2) with and will contain x and y indices 30 | # output_pos[0] = [0, 0] 31 | # output_pos[1] = [0, 1] 32 | # output_pos[2] = [0, 2] 33 | # ... 
34 | # output_pos[32] = [1, 0] 35 | # output_pos[1024] = [31, 31] 36 | self.output_pos = einops.rearrange( 37 | torch.stack(torch.meshgrid([torch.arange(32), torch.arange(32)], indexing="ij")), 38 | "ndim height width -> (height width) ndim", 39 | ).float() 40 | 41 | def __getitem__(self, idx): 42 | image, _ = super().__getitem__(idx) 43 | assert image.shape == (3, 32, 32) 44 | # reshape image to sparse tensor 45 | x = einops.rearrange(image, "dim height width -> (height width) dim") 46 | pos = self.output_pos.clone() 47 | 48 | # subsample random input pixels (locations of inputs and outputs does not have to be the same) 49 | if self.num_inputs < 1024: 50 | if self.train: 51 | rng = None 52 | else: 53 | rng = torch.Generator().manual_seed(idx) 54 | input_perm = torch.randperm(len(x), generator=rng)[:self.num_inputs] 55 | input_feat = x[input_perm] 56 | input_pos = pos[input_perm].clone() 57 | else: 58 | input_feat = x 59 | input_pos = pos.clone() 60 | 61 | # subsample random output pixels (locations of inputs and outputs does not have to be the same) 62 | if self.num_outputs < 1024: 63 | if self.train: 64 | rng = None 65 | else: 66 | rng = torch.Generator().manual_seed(idx + 1) 67 | output_perm = torch.randperm(len(x), generator=rng)[:self.num_outputs] 68 | target_feat = x[output_perm] 69 | output_pos = pos[output_perm].clone() 70 | else: 71 | target_feat = x 72 | output_pos = pos.clone() 73 | 74 | return dict( 75 | index=idx, 76 | input_feat=input_feat, 77 | input_pos=input_pos, 78 | target_feat=target_feat, 79 | output_pos=output_pos, 80 | ) 81 | -------------------------------------------------------------------------------- /upt/datasets/sparse_cifar10_classifier_dataset.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from torchvision.datasets import CIFAR10 4 | 5 | 6 | class SparseCIFAR10ClassifierDataset(CIFAR10): 7 | def __init__(self, root, num_inputs, train=True, transform=None, download=False): 8 | super().__init__( 9 | root=root, 10 | train=train, 11 | transform=transform, 12 | download=download, 13 | ) 14 | assert num_inputs <= 1024, "CIFAR10 only has 1024 pixels, use less or equal 1024 num_inputs" 15 | self.num_inputs = num_inputs 16 | # CIFAR has 32x32 pixels 17 | # output_pos will be a tensor of shape (32 * 32, 2) with and will contain x and y indices 18 | # output_pos[0] = [0, 0] 19 | # output_pos[1] = [0, 1] 20 | # output_pos[2] = [0, 2] 21 | # ... 
22 | # output_pos[32] = [1, 0] 23 | # output_pos[1024] = [31, 31] 24 | self.output_pos = einops.rearrange( 25 | torch.stack(torch.meshgrid([torch.arange(32), torch.arange(32)], indexing="ij")), 26 | "ndim height width -> (height width) ndim", 27 | ).float() 28 | 29 | def __getitem__(self, idx): 30 | image, y = super().__getitem__(idx) 31 | assert image.shape == (3, 32, 32) 32 | # reshape image to sparse tensor 33 | x = einops.rearrange(image, "dim height width -> (height width) dim") 34 | pos = self.output_pos.clone() 35 | 36 | # subsample random input pixels (locations of inputs and outputs does not have to be the same) 37 | if self.num_inputs < 1024: 38 | if self.train: 39 | rng = None 40 | else: 41 | rng = torch.Generator().manual_seed(idx) 42 | input_perm = torch.randperm(len(x), generator=rng)[:self.num_inputs] 43 | input_feat = x[input_perm] 44 | input_pos = pos[input_perm].clone() 45 | else: 46 | input_feat = x 47 | input_pos = pos.clone() 48 | 49 | return dict( 50 | index=idx, 51 | input_feat=input_feat, 52 | input_pos=input_pos, 53 | target_class=y, 54 | ) 55 | -------------------------------------------------------------------------------- /upt/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/models/__init__.py -------------------------------------------------------------------------------- /upt/models/approximator.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from kappamodules.layers import LinearProjection, Sequential 5 | from kappamodules.transformer import DitBlock, PrenormBlock 6 | from torch import nn 7 | 8 | 9 | class Approximator(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | depth, 14 | num_heads, 15 | dim=None, 16 | cond_dim=None, 17 | init_weights="truncnormal002", 18 | **kwargs, 19 | ): 20 | super().__init__(**kwargs) 21 | dim = dim or input_dim 22 | self.dim = dim 23 | self.depth = depth 24 | self.num_heads = num_heads 25 | self.cond_dim = cond_dim 26 | self.init_weights = init_weights 27 | 28 | # project 29 | self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True) 30 | 31 | # blocks 32 | if cond_dim is None: 33 | block_ctor = PrenormBlock 34 | else: 35 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 36 | self.blocks = Sequential( 37 | *[ 38 | block_ctor( 39 | dim=dim, 40 | num_heads=num_heads, 41 | init_weights=init_weights, 42 | ) 43 | for _ in range(depth) 44 | ], 45 | ) 46 | 47 | def forward(self, x, condition=None): 48 | # check inputs 49 | assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)" 50 | if condition is not None: 51 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 52 | 53 | # pass condition to DiT blocks 54 | cond_kwargs = {} 55 | if condition is not None: 56 | cond_kwargs["cond"] = condition 57 | 58 | # project to decoder dim 59 | x = self.input_proj(x) 60 | 61 | # apply blocks 62 | x = self.blocks(x, **cond_kwargs) 63 | 64 | return x 65 | -------------------------------------------------------------------------------- /upt/models/conditioner_timestep.py: -------------------------------------------------------------------------------- 1 | from kappamodules.functional.pos_embed import get_sincos_1d_from_seqlen 2 | from torch import nn 3 | 4 | 5 | class ConditionerTimestep(nn.Module): 6 | def 
__init__(self, dim, num_timesteps): 7 | super().__init__() 8 | cond_dim = dim * 4 9 | self.num_timesteps = num_timesteps 10 | self.dim = dim 11 | self.cond_dim = cond_dim 12 | self.register_buffer( 13 | "timestep_embed", 14 | get_sincos_1d_from_seqlen(seqlen=num_timesteps, dim=dim), 15 | ) 16 | self.mlp = nn.Sequential( 17 | nn.Linear(dim, cond_dim), 18 | nn.SiLU(), 19 | ) 20 | 21 | def forward(self, timestep): 22 | # checks + preprocess 23 | assert timestep.numel() == len(timestep) 24 | timestep = timestep.flatten() 25 | # embed 26 | embed = self.mlp(self.timestep_embed[timestep]) 27 | return embed 28 | -------------------------------------------------------------------------------- /upt/models/decoder_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from kappamodules.layers import LinearProjection, Sequential 4 | from kappamodules.transformer import DitBlock 5 | from kappamodules.vit import VitBlock 6 | from torch import nn 7 | 8 | 9 | class DecoderClassifier(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | num_classes, 14 | dim, 15 | depth, 16 | num_heads, 17 | cond_dim=None, 18 | init_weights="truncnormal002", 19 | **kwargs, 20 | ): 21 | super().__init__(**kwargs) 22 | self.input_dim = input_dim 23 | self.num_classes = num_classes 24 | self.dim = dim 25 | self.depth = depth 26 | self.num_heads = num_heads 27 | self.cond_dim = cond_dim 28 | self.init_weights = init_weights 29 | 30 | # input projection 31 | self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True) 32 | 33 | # blocks 34 | if cond_dim is None: 35 | block_ctor = VitBlock 36 | else: 37 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 38 | self.blocks = Sequential( 39 | *[ 40 | block_ctor( 41 | dim=dim, 42 | num_heads=num_heads, 43 | init_weights=init_weights, 44 | ) 45 | for _ in range(depth) 46 | ], 47 | ) 48 | 49 | # classifier 50 | self.head = nn.Sequential( 51 | nn.LayerNorm(dim, eps=1e-6), 52 | LinearProjection(dim, num_classes), 53 | ) 54 | 55 | def forward(self, x, condition=None): 56 | # check inputs 57 | assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)" 58 | if condition is not None: 59 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 60 | 61 | # pass condition to DiT blocks 62 | cond_kwargs = {} 63 | if condition is not None: 64 | cond_kwargs["cond"] = condition 65 | 66 | # input projection 67 | x = self.input_proj(x) 68 | 69 | # apply blocks 70 | x = self.blocks(x, **cond_kwargs) 71 | 72 | # pool 73 | x = x.mean(dim=1) 74 | 75 | # classify 76 | x = self.head(x) 77 | 78 | return x 79 | -------------------------------------------------------------------------------- /upt/models/decoder_perceiver.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import einops 4 | import torch 5 | from kappamodules.layers import ContinuousSincosEmbed, LinearProjection, Sequential 6 | from kappamodules.transformer import PerceiverBlock, DitPerceiverBlock, DitBlock 7 | from kappamodules.vit import VitBlock 8 | from torch import nn 9 | import math 10 | 11 | 12 | class DecoderPerceiver(nn.Module): 13 | def __init__( 14 | self, 15 | input_dim, 16 | output_dim, 17 | ndim, 18 | dim, 19 | depth, 20 | num_heads, 21 | unbatch_mode="dense_to_sparse_unpadded", 22 | perc_dim=None, 23 | perc_num_heads=None, 24 | cond_dim=None, 25 | init_weights="truncnormal002", 26 | **kwargs, 27 | ): 28 | 
super().__init__(**kwargs) 29 | perc_dim = perc_dim or dim 30 | perc_num_heads = perc_num_heads or num_heads 31 | self.input_dim = input_dim 32 | self.output_dim = output_dim 33 | self.ndim = ndim 34 | self.dim = dim 35 | self.depth = depth 36 | self.num_heads = num_heads 37 | self.perc_dim = perc_dim 38 | self.perc_num_heads = perc_num_heads 39 | self.cond_dim = cond_dim 40 | self.init_weights = init_weights 41 | self.unbatch_mode = unbatch_mode 42 | 43 | # input projection 44 | self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True) 45 | 46 | # blocks 47 | if cond_dim is None: 48 | block_ctor = VitBlock 49 | else: 50 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 51 | self.blocks = Sequential( 52 | *[ 53 | block_ctor( 54 | dim=dim, 55 | num_heads=num_heads, 56 | init_weights=init_weights, 57 | ) 58 | for _ in range(depth) 59 | ], 60 | ) 61 | 62 | # prepare perceiver 63 | self.pos_embed = ContinuousSincosEmbed( 64 | dim=perc_dim, 65 | ndim=ndim, 66 | ) 67 | if cond_dim is None: 68 | block_ctor = PerceiverBlock 69 | else: 70 | block_ctor = partial(DitPerceiverBlock, cond_dim=cond_dim) 71 | 72 | # decoder 73 | self.query_proj = nn.Sequential( 74 | LinearProjection(perc_dim, perc_dim, init_weights=init_weights), 75 | nn.GELU(), 76 | LinearProjection(perc_dim, perc_dim, init_weights=init_weights), 77 | ) 78 | self.perc = block_ctor(dim=perc_dim, kv_dim=dim, num_heads=perc_num_heads, init_weights=init_weights) 79 | self.pred = nn.Sequential( 80 | nn.LayerNorm(perc_dim, eps=1e-6), 81 | LinearProjection(perc_dim, output_dim, init_weights=init_weights), 82 | ) 83 | 84 | def forward(self, x, output_pos, condition=None): 85 | # check inputs 86 | assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)" 87 | assert output_pos.ndim == 3, "expected shape (batch_size, num_outputs, dim) num_outputs might be padded" 88 | if condition is not None: 89 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 90 | 91 | # pass condition to DiT blocks 92 | cond_kwargs = {} 93 | if condition is not None: 94 | cond_kwargs["cond"] = condition 95 | 96 | # input projection 97 | x = self.input_proj(x) 98 | 99 | # apply blocks 100 | x = self.blocks(x, **cond_kwargs) 101 | 102 | # create query 103 | query = self.pos_embed(output_pos) 104 | query = self.query_proj(query) 105 | 106 | x = self.perc(q=query, kv=x, **cond_kwargs) 107 | x = self.pred(x) 108 | if self.unbatch_mode == "dense_to_sparse_unpadded": 109 | # dense to sparse where no padding needs to be considered 110 | x = einops.rearrange( 111 | x, 112 | "batch_size seqlen dim -> (batch_size seqlen) dim", 113 | ) 114 | elif self.unbatch_mode == "image": 115 | # rearrange to square image 116 | height = math.sqrt(x.size(1)) 117 | assert height.is_integer() 118 | x = einops.rearrange( 119 | x, 120 | "batch_size (height width) dim -> batch_size dim height width", 121 | height=int(height), 122 | ) 123 | else: 124 | raise NotImplementedError(f"invalid unbatch_mode '{self.unbatch_mode}'") 125 | 126 | return x 127 | -------------------------------------------------------------------------------- /upt/models/encoder_image.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import einops 4 | from kappamodules.layers import Sequential 5 | from kappamodules.transformer import PerceiverPoolingBlock, PrenormBlock, DitPerceiverPoolingBlock, DitBlock 6 | from kappamodules.utils.param_checking import to_2tuple 7 | from kappamodules.vit 
import VitPatchEmbed, VitPosEmbed2d 8 | from torch import nn 9 | 10 | 11 | class EncoderImage(nn.Module): 12 | def __init__( 13 | self, 14 | input_dim, 15 | patch_size, 16 | resolution, 17 | enc_dim, 18 | enc_num_heads, 19 | enc_depth, 20 | perc_dim=None, 21 | perc_num_heads=None, 22 | num_latent_tokens=None, 23 | cond_dim=None, 24 | init_weights="truncnormal", 25 | ): 26 | super().__init__() 27 | patch_size = to_2tuple(patch_size) 28 | resolution = to_2tuple(resolution) 29 | self.input_dim = input_dim 30 | self.patch_size = patch_size 31 | self.resolution = resolution 32 | self.enc_dim = enc_dim 33 | self.enc_depth = enc_depth 34 | self.enc_num_heads = enc_num_heads 35 | self.perc_dim = perc_dim 36 | self.perc_num_heads = perc_num_heads 37 | self.num_latent_tokens = num_latent_tokens 38 | self.condition_dim = cond_dim 39 | self.init_weights = init_weights 40 | 41 | # embed 42 | self.patch_embed = VitPatchEmbed( 43 | dim=enc_dim, 44 | num_channels=input_dim, 45 | resolution=resolution, 46 | patch_size=patch_size, 47 | ) 48 | self.pos_embed = VitPosEmbed2d(seqlens=self.patch_embed.seqlens, dim=enc_dim, is_learnable=False) 49 | 50 | # blocks 51 | if cond_dim is None: 52 | block_ctor = PrenormBlock 53 | else: 54 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 55 | self.blocks = Sequential( 56 | *[ 57 | block_ctor(dim=enc_dim, num_heads=enc_num_heads, init_weights=init_weights) 58 | for _ in range(enc_depth) 59 | ], 60 | ) 61 | 62 | # perceiver pooling 63 | if num_latent_tokens is None: 64 | self.perceiver = None 65 | else: 66 | if cond_dim is None: 67 | block_ctor = partial( 68 | PerceiverPoolingBlock, 69 | perceiver_kwargs=dict( 70 | kv_dim=enc_dim, 71 | init_weights=init_weights, 72 | ), 73 | ) 74 | else: 75 | block_ctor = partial( 76 | DitPerceiverPoolingBlock, 77 | perceiver_kwargs=dict( 78 | kv_dim=enc_dim, 79 | cond_dim=cond_dim, 80 | init_weights=init_weights, 81 | ), 82 | ) 83 | self.perceiver = block_ctor( 84 | dim=perc_dim, 85 | num_heads=perc_num_heads, 86 | num_query_tokens=num_latent_tokens, 87 | ) 88 | 89 | def forward(self, input_image, condition=None): 90 | # check inputs 91 | assert input_image.ndim == 4, "expected input image of shape (batch_size, num_channels, height, width)" 92 | if condition is not None: 93 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 94 | 95 | # pass condition to DiT blocks 96 | cond_kwargs = {} 97 | if condition is not None: 98 | cond_kwargs["cond"] = condition 99 | 100 | # patch_embed 101 | x = self.patch_embed(input_image) 102 | # add pos_embed 103 | x = self.pos_embed(x) 104 | # flatten 105 | x = einops.rearrange(x, "b ... d -> b (...) 
d") 106 | 107 | # transformer 108 | x = self.blocks(x, **cond_kwargs) 109 | 110 | # perceiver 111 | if self.perceiver is not None: 112 | x = self.perceiver(kv=x, **cond_kwargs) 113 | 114 | return x 115 | -------------------------------------------------------------------------------- /upt/models/encoder_supernodes.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from kappamodules.layers import LinearProjection, Sequential 4 | from kappamodules.transformer import PerceiverPoolingBlock, PrenormBlock, DitPerceiverPoolingBlock, DitBlock 5 | from upt.modules.supernode_pooling import SupernodePooling 6 | from torch import nn 7 | 8 | 9 | class EncoderSupernodes(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | ndim, 14 | radius, 15 | max_degree, 16 | gnn_dim, 17 | enc_dim, 18 | enc_depth, 19 | enc_num_heads, 20 | perc_dim=None, 21 | perc_num_heads=None, 22 | num_latent_tokens=None, 23 | cond_dim=None, 24 | init_weights="truncnormal", 25 | ): 26 | super().__init__() 27 | self.input_dim = input_dim 28 | self.ndim = ndim 29 | self.radius = radius 30 | self.max_degree = max_degree 31 | self.gnn_dim = gnn_dim 32 | self.enc_dim = enc_dim 33 | self.enc_depth = enc_depth 34 | self.enc_num_heads = enc_num_heads 35 | self.perc_dim = perc_dim 36 | self.perc_num_heads = perc_num_heads 37 | self.num_latent_tokens = num_latent_tokens 38 | self.condition_dim = cond_dim 39 | self.init_weights = init_weights 40 | 41 | # supernode pooling 42 | self.supernode_pooling = SupernodePooling( 43 | radius=radius, 44 | max_degree=max_degree, 45 | input_dim=input_dim, 46 | hidden_dim=gnn_dim, 47 | ndim=ndim, 48 | ) 49 | 50 | # blocks 51 | self.enc_proj = LinearProjection(gnn_dim, enc_dim, init_weights=init_weights, optional=True) 52 | if cond_dim is None: 53 | block_ctor = PrenormBlock 54 | else: 55 | block_ctor = partial(DitBlock, cond_dim=cond_dim) 56 | self.blocks = Sequential( 57 | *[ 58 | block_ctor(dim=enc_dim, num_heads=enc_num_heads, init_weights=init_weights) 59 | for _ in range(enc_depth) 60 | ], 61 | ) 62 | 63 | # perceiver pooling 64 | if num_latent_tokens is None: 65 | self.perceiver = None 66 | else: 67 | if cond_dim is None: 68 | block_ctor = partial( 69 | PerceiverPoolingBlock, 70 | perceiver_kwargs=dict( 71 | kv_dim=enc_dim, 72 | init_weights=init_weights, 73 | ), 74 | ) 75 | else: 76 | block_ctor = partial( 77 | DitPerceiverPoolingBlock, 78 | perceiver_kwargs=dict( 79 | kv_dim=enc_dim, 80 | cond_dim=cond_dim, 81 | init_weights=init_weights, 82 | ), 83 | ) 84 | self.perceiver = block_ctor( 85 | dim=perc_dim, 86 | num_heads=perc_num_heads, 87 | num_query_tokens=num_latent_tokens, 88 | ) 89 | 90 | def forward(self, input_feat, input_pos, supernode_idxs, batch_idx, condition=None): 91 | # check inputs 92 | assert input_feat.ndim == 2, "expected sparse tensor (batch_size * num_inputs, input_dim)" 93 | assert input_pos.ndim == 2, "expected sparse tensor (batch_size * num_inputs, ndim)" 94 | assert len(input_feat) == len(input_pos), "expected input_feat and input_pos to have same length" 95 | assert supernode_idxs.ndim == 1, "supernode_idxs is a 1D tensor of indices that are used as supernodes" 96 | assert batch_idx.ndim == 1, f"batch_idx should be 1D tensor that assigns elements of the input to samples" 97 | if condition is not None: 98 | assert condition.ndim == 2, "expected shape (batch_size, cond_dim)" 99 | 100 | # pass condition to DiT blocks 101 | cond_kwargs = {} 102 | if condition is not None: 103 | 
cond_kwargs["cond"] = condition 104 | 105 | # supernode pooling 106 | x = self.supernode_pooling( 107 | input_feat=input_feat, 108 | input_pos=input_pos, 109 | supernode_idxs=supernode_idxs, 110 | batch_idx=batch_idx, 111 | ) 112 | 113 | # project to encoder dimension 114 | x = self.enc_proj(x) 115 | 116 | # transformer 117 | x = self.blocks(x, **cond_kwargs) 118 | 119 | # perceiver 120 | if self.perceiver is not None: 121 | x = self.perceiver(kv=x, **cond_kwargs) 122 | 123 | return x 124 | -------------------------------------------------------------------------------- /upt/models/upt.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class UPT(nn.Module): 7 | def __init__(self, conditioner, encoder, approximator, decoder): 8 | super().__init__() 9 | self.conditioner = conditioner 10 | self.encoder = encoder 11 | self.approximator = approximator 12 | self.decoder = decoder 13 | 14 | def forward( 15 | self, 16 | input_feat, 17 | input_pos, 18 | supernode_idxs, 19 | output_pos, 20 | batch_idx, 21 | timestep, 22 | ): 23 | condition = self.conditioner(timestep) 24 | 25 | # encode data 26 | latent = self.encoder( 27 | input_feat=input_feat, 28 | input_pos=input_pos, 29 | supernode_idxs=supernode_idxs, 30 | batch_idx=batch_idx, 31 | condition=condition, 32 | ) 33 | 34 | # propagate forward 35 | latent = self.approximator(latent, condition=condition) 36 | 37 | # decode 38 | pred = self.decoder( 39 | x=latent, 40 | output_pos=output_pos, 41 | condition=condition, 42 | ) 43 | 44 | return pred 45 | 46 | @torch.no_grad() 47 | def rollout(self, input_feat, input_pos, supernode_idxs, batch_idx): 48 | batch_size = batch_idx.max() + 1 49 | timestep = torch.zeros(batch_size).long() 50 | 51 | # we assume that output_pos is simply a rearranged input_pos 52 | # i.e. 
num_inputs == num_outputs and num_inputs is constant for all samples 53 | output_pos = einops.rearrange( 54 | input_pos, 55 | "(batch_size num_inputs) ndim -> batch_size num_inputs ndim", 56 | batch_size=batch_size, 57 | ) 58 | 59 | predictions = [] 60 | for i in range(self.conditioner.num_timesteps): 61 | condition = self.conditioner(timestep) 62 | # encode data 63 | latent = self.encoder( 64 | input_feat=input_feat, 65 | input_pos=input_pos, 66 | supernode_idxs=supernode_idxs, 67 | batch_idx=batch_idx, 68 | condition=condition, 69 | ) 70 | 71 | # propagate forward 72 | latent = self.approximator(latent, condition=condition) 73 | 74 | # decode 75 | pred = self.decoder( 76 | x=latent, 77 | output_pos=output_pos, 78 | condition=condition, 79 | ) 80 | predictions.append(pred) 81 | 82 | # increase timestep 83 | timestep += 1 84 | 85 | # feed prediction as next input for autoregressive rollout 86 | input_feat = pred 87 | 88 | return predictions 89 | -------------------------------------------------------------------------------- /upt/models/upt_image_autoencoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class UPTImageAutoencoder(nn.Module): 5 | def __init__(self, encoder, approximator, decoder): 6 | super().__init__() 7 | self.encoder = encoder 8 | self.approximator = approximator 9 | self.decoder = decoder 10 | 11 | def forward(self, x, output_pos): 12 | # encode data 13 | latent = self.encoder(x) 14 | 15 | # propagate forward 16 | latent = self.approximator(latent) 17 | 18 | # decode 19 | pred = self.decoder(latent, output_pos=output_pos) 20 | 21 | return pred 22 | -------------------------------------------------------------------------------- /upt/models/upt_image_classifier.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class UPTImageClassifier(nn.Module): 5 | def __init__(self, encoder, approximator, decoder): 6 | super().__init__() 7 | self.encoder = encoder 8 | self.approximator = approximator 9 | self.decoder = decoder 10 | 11 | def forward(self, x): 12 | # encode data 13 | latent = self.encoder(x) 14 | 15 | # propagate forward 16 | latent = self.approximator(latent) 17 | 18 | # decode 19 | pred = self.decoder(latent) 20 | 21 | return pred 22 | -------------------------------------------------------------------------------- /upt/models/upt_sparseimage_autoencoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class UPTSparseImageAutoencoder(nn.Module): 5 | def __init__(self, encoder, approximator, decoder): 6 | super().__init__() 7 | self.encoder = encoder 8 | self.approximator = approximator 9 | self.decoder = decoder 10 | 11 | def forward(self, input_feat, input_pos, supernode_idxs, batch_idx, output_pos): 12 | # encode data 13 | latent = self.encoder( 14 | input_feat=input_feat, 15 | input_pos=input_pos, 16 | supernode_idxs=supernode_idxs, 17 | batch_idx=batch_idx, 18 | ) 19 | 20 | # propagate forward 21 | latent = self.approximator(latent) 22 | 23 | # decode 24 | pred = self.decoder(latent, output_pos=output_pos) 25 | 26 | return pred 27 | -------------------------------------------------------------------------------- /upt/models/upt_sparseimage_classifier.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class UPTSparseImageClassifier(nn.Module): 5 | def __init__(self, encoder, 
approximator, decoder): 6 | super().__init__() 7 | self.encoder = encoder 8 | self.approximator = approximator 9 | self.decoder = decoder 10 | 11 | def forward(self, input_feat, input_pos, supernode_idxs, batch_idx): 12 | # encode data 13 | latent = self.encoder( 14 | input_feat=input_feat, 15 | input_pos=input_pos, 16 | supernode_idxs=supernode_idxs, 17 | batch_idx=batch_idx, 18 | ) 19 | 20 | # propagate forward 21 | latent = self.approximator(latent) 22 | 23 | # decode 24 | pred = self.decoder(latent) 25 | 26 | return pred 27 | -------------------------------------------------------------------------------- /upt/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenediktAlkin/upt-tutorial/941b574efc556cae607713c98c9ce424fd75c3ee/upt/modules/__init__.py -------------------------------------------------------------------------------- /upt/modules/supernode_pooling.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from kappamodules.layers import ContinuousSincosEmbed, LinearProjection 4 | from torch import nn 5 | from torch_geometric.nn.pool import radius_graph 6 | from torch_scatter import segment_csr 7 | 8 | 9 | class SupernodePooling(nn.Module): 10 | def __init__( 11 | self, 12 | radius, 13 | max_degree, 14 | input_dim, 15 | hidden_dim, 16 | ndim, 17 | init_weights="torch", 18 | ): 19 | super().__init__() 20 | self.radius = radius 21 | self.max_degree = max_degree 22 | self.input_dim = input_dim 23 | self.hidden_dim = hidden_dim 24 | self.ndim = ndim 25 | self.init_weights = init_weights 26 | 27 | self.input_proj = LinearProjection(input_dim, hidden_dim, init_weights=init_weights) 28 | self.pos_embed = ContinuousSincosEmbed(dim=hidden_dim, ndim=ndim) 29 | self.message = nn.Sequential( 30 | LinearProjection(hidden_dim * 2, hidden_dim, init_weights=init_weights), 31 | nn.GELU(), 32 | LinearProjection(hidden_dim, hidden_dim, init_weights=init_weights), 33 | ) 34 | self.output_dim = hidden_dim 35 | 36 | def forward(self, input_feat, input_pos, supernode_idxs, batch_idx): 37 | assert input_feat.ndim == 2 38 | assert input_pos.ndim == 2 39 | assert supernode_idxs.ndim == 1 40 | 41 | # radius graph 42 | input_edges = radius_graph( 43 | x=input_pos, 44 | r=self.radius, 45 | max_num_neighbors=self.max_degree, 46 | batch=batch_idx, 47 | loop=True, 48 | # inverted flow direction is required to have sorted dst_indices 49 | flow="target_to_source", 50 | ) 51 | is_supernode_edge = torch.isin(input_edges[0], supernode_idxs) 52 | input_edges = input_edges[:, is_supernode_edge] 53 | 54 | # embed mesh 55 | x = self.input_proj(input_feat) + self.pos_embed(input_pos) 56 | 57 | # create message input 58 | dst_idx, src_idx = input_edges.unbind() 59 | x = torch.concat([x[src_idx], x[dst_idx]], dim=1) 60 | x = self.message(x) 61 | # accumulate messages 62 | # indptr is a tensor of indices between which to aggregate 63 | # i.e. 
a tensor of [0, 2, 5] would result in [src[0] + src[1], src[2] + src[3] + src[4]] 64 | dst_indices, counts = dst_idx.unique(return_counts=True) 65 | # first index has to be 0 66 | # NOTE: padding for target indices that don't occur is not needed as self-loop is always present 67 | padded_counts = torch.zeros(len(counts) + 1, device=counts.device, dtype=counts.dtype) 68 | padded_counts[1:] = counts 69 | indptr = padded_counts.cumsum(dim=0) 70 | x = segment_csr(src=x, indptr=indptr, reduce="mean") 71 | 72 | # sanity check: dst_indices has length batch_size * num_supernodes and has to be divisible by batch_size 73 | # if num_supernodes is not set in the dataset, this assertion fails 74 | batch_size = batch_idx.max() + 1 75 | assert dst_indices.numel() % batch_size == 0 76 | 77 | # convert to dense tensor (dim last) 78 | x = einops.rearrange( 79 | x, 80 | "(batch_size num_supernodes) dim -> batch_size num_supernodes dim", 81 | batch_size=batch_size, 82 | ) 83 | 84 | return x 85 | --------------------------------------------------------------------------------
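
For orientation, here is a minimal usage sketch of SupernodePooling as defined in upt/modules/supernode_pooling.py above. All sizes, the radius, and max_degree below are illustrative assumptions rather than values taken from the tutorial notebooks; the snippet only relies on the constructor and forward signatures shown in the file and requires the torch_geometric (torch-cluster) and torch_scatter dependencies it imports.

import torch
from upt.modules.supernode_pooling import SupernodePooling

# hypothetical sizes, chosen only for illustration
num_points, input_dim, ndim, hidden_dim = 2048, 3, 2, 192
num_supernodes = 64

pool = SupernodePooling(
    radius=0.1,
    max_degree=32,
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    ndim=ndim,
)

# sparse inputs for a single sample, so batch_idx is all zeros
input_feat = torch.randn(num_points, input_dim)
input_pos = torch.rand(num_points, ndim)
supernode_idxs = torch.randperm(num_points)[:num_supernodes]
batch_idx = torch.zeros(num_points, dtype=torch.long)

x = pool(
    input_feat=input_feat,
    input_pos=input_pos,
    supernode_idxs=supernode_idxs,
    batch_idx=batch_idx,
)
# x is a dense tensor of shape (batch_size, num_supernodes, hidden_dim), here (1, 64, 192)

This is exactly the tensor that EncoderSupernodes builds internally before projecting from gnn_dim to enc_dim and, if num_latent_tokens is set, compressing the supernode tokens with the perceiver pooling block.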