├── aion
│   ├── fourm
│   │   ├── __init__.py
│   │   ├── modality_transforms.py
│   │   ├── generation_utils.py
│   │   ├── text_utils.py
│   │   └── lora_utils.py
│   ├── codecs
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── subsampler.py
│   │   │   ├── ema.py
│   │   │   ├── spectrum.py
│   │   │   ├── convnext.py
│   │   │   └── magvit.py
│   │   ├── preprocessing
│   │   │   ├── band_to_index.py
│   │   │   └── image.py
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── catalog.py
│   │   ├── base.py
│   │   ├── scalar.py
│   │   ├── manager.py
│   │   ├── utils.py
│   │   ├── image.py
│   │   ├── spectrum.py
│   │   └── scalar_field.py
│   ├── __init__.py
│   └── model.py
├── assets
│   └── aion.png
├── docs
│   ├── _static
│   │   ├── aion.png
│   │   ├── image.png
│   │   ├── embeddings.png
│   │   ├── data_mixture.png
│   │   └── polymathic_logo.png
│   ├── Makefile
│   ├── index.md
│   ├── conf.py
│   └── api.rst
├── tests
│   ├── test_data
│   │   ├── EBV_codec_decoded_batch.pt
│   │   ├── EBV_codec_encoded_batch.pt
│   │   ├── EBV_codec_input_batch.pt
│   │   ├── Z_codec_decoded_batch.pt
│   │   ├── Z_codec_encoded_batch.pt
│   │   ├── Z_codec_input_batch.pt
│   │   ├── a_g_codec_decoded_batch.pt
│   │   ├── a_g_codec_encoded_batch.pt
│   │   ├── a_g_codec_input_batch.pt
│   │   ├── a_i_codec_decoded_batch.pt
│   │   ├── a_i_codec_encoded_batch.pt
│   │   ├── a_i_codec_input_batch.pt
│   │   ├── a_r_codec_decoded_batch.pt
│   │   ├── a_r_codec_encoded_batch.pt
│   │   ├── a_r_codec_input_batch.pt
│   │   ├── a_y_codec_decoded_batch.pt
│   │   ├── a_y_codec_encoded_batch.pt
│   │   ├── a_y_codec_input_batch.pt
│   │   ├── a_z_codec_decoded_batch.pt
│   │   ├── a_z_codec_encoded_batch.pt
│   │   ├── a_z_codec_input_batch.pt
│   │   ├── dec_codec_decoded_batch.pt
│   │   ├── dec_codec_encoded_batch.pt
│   │   ├── dec_codec_input_batch.pt
│   │   ├── ra_codec_decoded_batch.pt
│   │   ├── ra_codec_encoded_batch.pt
│   │   ├── ra_codec_input_batch.pt
│   │   ├── FLUX_G_codec_decoded_batch.pt
│   │   ├── FLUX_G_codec_encoded_batch.pt
│   │   ├── FLUX_G_codec_input_batch.pt
│   │   ├── FLUX_I_codec_decoded_batch.pt
│   │   ├── FLUX_I_codec_encoded_batch.pt
│   │   ├── FLUX_I_codec_input_batch.pt
│   │   ├── FLUX_R_codec_decoded_batch.pt
│   │   ├── FLUX_R_codec_encoded_batch.pt
│   │   ├── FLUX_R_codec_input_batch.pt
│   │   ├── FLUX_W1_codec_decoded_batch.pt
│   │   ├── FLUX_W1_codec_encoded_batch.pt
│   │   ├── FLUX_W1_codec_input_batch.pt
│   │   ├── FLUX_W2_codec_decoded_batch.pt
│   │   ├── FLUX_W2_codec_encoded_batch.pt
│   │   ├── FLUX_W2_codec_input_batch.pt
│   │   ├── FLUX_W3_codec_decoded_batch.pt
│   │   ├── FLUX_W3_codec_encoded_batch.pt
│   │   ├── FLUX_W3_codec_input_batch.pt
│   │   ├── FLUX_W4_codec_decoded_batch.pt
│   │   ├── FLUX_W4_codec_encoded_batch.pt
│   │   ├── FLUX_W4_codec_input_batch.pt
│   │   ├── FLUX_Z_codec_decoded_batch.pt
│   │   ├── FLUX_Z_codec_encoded_batch.pt
│   │   ├── FLUX_Z_codec_input_batch.pt
│   │   ├── SHAPE_E1_codec_decoded_batch.pt
│   │   ├── SHAPE_E1_codec_encoded_batch.pt
│   │   ├── SHAPE_E1_codec_input_batch.pt
│   │   ├── SHAPE_E2_codec_decoded_batch.pt
│   │   ├── SHAPE_E2_codec_encoded_batch.pt
│   │   ├── SHAPE_E2_codec_input_batch.pt
│   │   ├── SHAPE_R_codec_decoded_batch.pt
│   │   ├── SHAPE_R_codec_encoded_batch.pt
│   │   ├── SHAPE_R_codec_input_batch.pt
│   │   ├── SPECTRUM_decoded_batch.pt
│   │   ├── SPECTRUM_encoded_batch.pt
│   │   ├── SPECTRUM_input_batch.pt
│   │   ├── catalog_codec_input_batch.pt
│   │   ├── image_codec_encoded_batch.pt
│   │   ├── image_codec_input_batch.pt
│   │   ├── parallax_codec_decoded_batch.pt
│   │   ├── parallax_codec_encoded_batch.pt
│   │   ├── parallax_codec_input_batch.pt
│   │   ├── catalog_codec_decoded_batch.pt
│   │   ├── catalog_codec_encoded_batch.pt
│   │   ├── g_cmodel_mag_codec_decoded_batch.pt
│   │   ├── g_cmodel_mag_codec_encoded_batch.pt
│   │   ├── g_cmodel_mag_codec_input_batch.pt
│   │   ├── i_cmodel_mag_codec_decoded_batch.pt
│   │   ├── i_cmodel_mag_codec_encoded_batch.pt
│   │   ├── i_cmodel_mag_codec_input_batch.pt
│   │   ├── image_codec_decoded_batch.pt
│   │   ├── r_cmodel_mag_codec_decoded_batch.pt
│   │   ├── r_cmodel_mag_codec_encoded_batch.pt
│   │   ├── r_cmodel_mag_codec_input_batch.pt
│   │   ├── scalar-field_codec_encoded_batch.pt
│   │   ├── y_cmodel_mag_codec_decoded_batch.pt
│   │   ├── y_cmodel_mag_codec_encoded_batch.pt
│   │   ├── y_cmodel_mag_codec_input_batch.pt
│   │   ├── z_cmodel_mag_codec_decoded_batch.pt
│   │   ├── z_cmodel_mag_codec_encoded_batch.pt
│   │   ├── z_cmodel_mag_codec_input_batch.pt
│   │   ├── bp_coefficients_codec_decoded_batch.pt
│   │   ├── bp_coefficients_codec_encoded_batch.pt
│   │   ├── bp_coefficients_codec_input_batch.pt
│   │   ├── i_sdssshape_shape11_codec_input_batch.pt
│   │   ├── i_sdssshape_shape12_codec_input_batch.pt
│   │   ├── i_sdssshape_shape22_codec_input_batch.pt
│   │   ├── phot_bp_mean_flux_codec_decoded_batch.pt
│   │   ├── phot_bp_mean_flux_codec_encoded_batch.pt
│   │   ├── phot_bp_mean_flux_codec_input_batch.pt
│   │   ├── phot_g_mean_flux_codec_decoded_batch.pt
│   │   ├── phot_g_mean_flux_codec_encoded_batch.pt
│   │   ├── phot_g_mean_flux_codec_input_batch.pt
│   │   ├── phot_rp_mean_flux_codec_decoded_batch.pt
│   │   ├── phot_rp_mean_flux_codec_encoded_batch.pt
│   │   ├── phot_rp_mean_flux_codec_input_batch.pt
│   │   ├── rp_coefficients_codec_decoded_batch.pt
│   │   ├── rp_coefficients_codec_encoded_batch.pt
│   │   ├── rp_coefficients_codec_input_batch.pt
│   │   ├── scalar-field_codec_decoded_batch.pt
│   │   ├── scalar-field_codec_input_batch.pt
│   │   ├── i_sdssshape_shape11_codec_decoded_batch.pt
│   │   ├── i_sdssshape_shape11_codec_encoded_batch.pt
│   │   ├── i_sdssshape_shape12_codec_decoded_batch.pt
│   │   ├── i_sdssshape_shape12_codec_encoded_batch.pt
│   │   ├── i_sdssshape_shape22_codec_decoded_batch.pt
│   │   └── i_sdssshape_shape22_codec_encoded_batch.pt
│   ├── conftest.py
│   └── codecs
│       ├── test_load_codecs.py
│       ├── test_scalar_field_codec.py
│       ├── test_catalog_codec.py
│       ├── test_spectrum_codec.py
│       ├── test_scalar_codec.py
│       ├── test_image_codec.py
│       └── test_codec_manager.py
├── .readthedocs.yaml
├── .pre-commit-config.yaml
├── .github
│   └── workflows
│       ├── test-doc.yml
│       ├── test.yaml
│       ├── publish-pypi.yml
│       └── deploy-doc.yml
├── LICENSE
├── pyproject.toml
├── .gitattributes
├── CLAUDE.md
├── .gitignore
└── README.md
/aion/fourm/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/aion/codecs/modules/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/aion/__init__.py:
--------------------------------------------------------------------------------
from .model import AION

__all__ = ["AION"]
--------------------------------------------------------------------------------
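The `AION` class re-exported above is the package's public entry point. A minimal usage sketch (not a repository file) follows; it assumes `AION` exposes the same `huggingface_hub` `from_pretrained` convention that the codec tests below rely on, and the checkpoint id is hypothetical.

import torch

from aion import AION

# Assumption: AION follows the same from_pretrained convention as the codec
# classes in aion/codecs; "polymathic-ai/aion-base" is a hypothetical
# checkpoint id and may require a Hugging Face token.
model = AION.from_pretrained("polymathic-ai/aion-base")
model.eval()  # inference only; tokenized inputs come from the codecs below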
/assets/aion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PolymathicAI/AION/HEAD/assets/aion.png
--------------------------------------------------------------------------------
/docs/_static/aion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PolymathicAI/AION/HEAD/docs/_static/aion.png
--------------------------------------------------------------------------------
/docs/_static/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PolymathicAI/AION/HEAD/docs/_static/image.png
--------------------------------------------------------------------------------
/docs/_static/embeddings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PolymathicAI/AION/HEAD/docs/_static/embeddings.png
--------------------------------------------------------------------------------
/docs/_static/data_mixture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PolymathicAI/AION/HEAD/docs/_static/data_mixture.png
--------------------------------------------------------------------------------
/docs/_static/polymathic_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PolymathicAI/AION/HEAD/docs/_static/polymathic_logo.png
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
SPHINXBUILD := sphinx-build
SOURCEDIR := .
BUILDDIR := _build

.PHONY: html
html:
	$(SPHINXBUILD) -M html $(SOURCEDIR) $(BUILDDIR)
--------------------------------------------------------------------------------
/tests/test_data/EBV_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:04f898ae89211cda3f6415706f7ed12f0888ab46edf1bcd81d5c47b8c64a2f54
size 2348
--------------------------------------------------------------------------------
/tests/test_data/EBV_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:9cbee933f51244cf071bd8b52fcf6f4bb7caa82b410a77a09178124d69d58d5a
size 3372
--------------------------------------------------------------------------------
/tests/test_data/EBV_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:a6ff629d1ec7bab53e420ec1435311a78707db36bd2c6d5efbf41832417997b0
size 2274
--------------------------------------------------------------------------------
/tests/test_data/Z_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:58c44b67b276d65517634634474f28c3eea874954edd630dfdb3510c4bbaf637
size 5346
--------------------------------------------------------------------------------
/tests/test_data/Z_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:7b6d417ae5fc52575a825a28a48d5a4bb56bd76a8fb30d33ecfd16c309729a9c
size 5346
--------------------------------------------------------------------------------
/tests/test_data/Z_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:05b1c96641bfb55c6e9b8096a2d4adc3813a609f8c5a788c33f6e2137d196bf0
size 5336
--------------------------------------------------------------------------------
/tests/test_data/a_g_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:44c7e436f4fb148e8caea7f363bf1539c40ec5382d59f448ef503d7e0bcd80f4
size 5292
--------------------------------------------------------------------------------
/tests/test_data/a_g_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:3408d4101126dc73315b13fe1187816a214f9560c11df894687f53c6523be4fe
size 9324
--------------------------------------------------------------------------------
/tests/test_data/a_g_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e7273e0b0472e3db58de19156820a81c1d4d796cee9aa6e66a70b1774b23dfad
size 5218
--------------------------------------------------------------------------------
/tests/test_data/a_i_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:ab6aa38b3c9740a596eba81e2651c5500461f6de94ed5608766bac5622d45003
size 5292
--------------------------------------------------------------------------------
/tests/test_data/a_i_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:ecf7ac3dde28c9083447b2e0612051613c89d3db4129233797272dcb06a9e1df
size 9324
--------------------------------------------------------------------------------
/tests/test_data/a_i_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e3f5e22f15e2d5379b789733fd4959bd8d1afc73b1b724a8a625736b9f2667d9
size 5218
--------------------------------------------------------------------------------
/tests/test_data/a_r_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:eb19709935f3d6eb729b850cc59ac35a1d9bac553da347cf099267c3fa304b77
size 5292
--------------------------------------------------------------------------------
/tests/test_data/a_r_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:33edc4149e5095c067f00a2fce8d86ee5b5dcd0caad37f347dc6a1b16a5b6631
size 9324
--------------------------------------------------------------------------------
/tests/test_data/a_r_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:cd71cc586ad9d77c831b3178da004060b8fdb3103dda91741468ee53ac19db26
size 5218
--------------------------------------------------------------------------------
/tests/test_data/a_y_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:94bb238f37d3176bf4015918e570106e33e03bec0963b8df7cfc8377f95d23ec
size 5292
--------------------------------------------------------------------------------
/tests/test_data/a_y_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:b9e0a93677556c4546260d95b9c12eef7e9831f0e95c0a81d29383c4f4f5df4b
size 9324
--------------------------------------------------------------------------------
/tests/test_data/a_y_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:487f91e411d2eef5799ebfb2041e502ca83748f57da356ac88f047598b30bff1
size 5218
--------------------------------------------------------------------------------
/tests/test_data/a_z_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6e0a31e9b24f372ee6dce0710fbce4102512054ef9cbc273be1b61519da7a56f
size 5292
--------------------------------------------------------------------------------
/tests/test_data/a_z_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:da0c0cf296c1674395389ca61eeeafcda112eabdeb1eb58940c1e504b9209713
size 9324
--------------------------------------------------------------------------------
/tests/test_data/a_z_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:3af8a7012cbf4a5fe03fac6451257f23fb6699a59f33c16df2b843a8517a3763
size 5218
--------------------------------------------------------------------------------
/tests/test_data/dec_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:ddc1ae399a3a0f527bec1f78c919353ec2c529be9ef0a0b04814b31f003ca4fb
size 5420
--------------------------------------------------------------------------------
/tests/test_data/dec_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8ffbe131993ccd27e116a88f265d373e0315988ce48bd3f1e4e7df7f0d52460e
size 9516
--------------------------------------------------------------------------------
/tests/test_data/dec_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e1c28304251535d026f8614235f7fc1de9abd8211899c645ed9e5866cd8c339d
size 5346
--------------------------------------------------------------------------------
/tests/test_data/ra_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:d8138ca6c8ffcae0d12df103a81f933de8f86612dafdf45d211d8ad1899cd24d
size 5415
--------------------------------------------------------------------------------
/tests/test_data/ra_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:446469f0ba0c362a4e9a2f04e091b0fa0b1d58dfd85044d3a202d8387dc303aa
size 9511
--------------------------------------------------------------------------------
/tests/test_data/ra_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c98db9d770d5a3941752501f5f0f136955b3bfd0c7f070af8662e69a6ab53093
size 5341
--------------------------------------------------------------------------------
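Every .pt file under tests/test_data is stored as a Git LFS pointer like the ones above, not as the tensor itself; the real fixtures are fetched with `git lfs pull` (the CI workflow below checks out with `lfs: true`). A small sketch of reading the pointer metadata, assuming only the three-line key/value layout shown above; `read_lfs_pointer` is a hypothetical helper, not part of the repository.

from pathlib import Path


def read_lfs_pointer(path: Path) -> dict:
    # Parse the "key value" lines of a Git LFS pointer file. This only works
    # before `git lfs pull` replaces the pointer with the binary tensor.
    fields = {}
    for line in path.read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


pointer = read_lfs_pointer(Path("tests/test_data/EBV_codec_input_batch.pt"))
print(pointer["oid"], pointer["size"])  # sha256:a6ff62... 2274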
/tests/conftest.py:
--------------------------------------------------------------------------------
from pathlib import Path

import pytest


@pytest.fixture(scope="session")
def data_dir():
    return Path(__file__).parent / "test_data"
--------------------------------------------------------------------------------
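The session-scoped `data_dir` fixture above is how every codec test locates its fixtures. A sketch of a test consuming it (the test name is hypothetical; the loading pattern mirrors tests/codecs/*):

import torch


def test_ebv_input_batch_loads(data_dir):
    # `data_dir` is injected by the session-scoped fixture in conftest.py
    # and resolves to tests/test_data; requires the LFS fixtures to be pulled.
    batch = torch.load(data_dir / "EBV_codec_input_batch.pt", weights_only=False)
    assert batch is not None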
/tests/test_data/FLUX_G_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:b33d4f263123d081dca4329ec7f6f5a3c771b4bf54c9c631f505573656c32a71
size 2363
--------------------------------------------------------------------------------
/tests/test_data/FLUX_G_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:0ef5379e891fa93ca35fd6bb780fec7fa95c16392169a9dc8afa91955bc5ae6f
size 3387
--------------------------------------------------------------------------------
/tests/test_data/FLUX_G_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:afc2d7f110a5a6eaf6af1d98f15693c2396fcc6d06c0734ccf35eb57eeebf68e
size 2353
--------------------------------------------------------------------------------
/tests/test_data/FLUX_I_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:74a7ff526e6ca4e2df561e1eca03c5f5fbe165752512363b18502f58cbbe859b
size 2363
--------------------------------------------------------------------------------
/tests/test_data/FLUX_I_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:bcaf0c82868ef1b8659cf425543693abeec525c5f22dc1eafd91862b06dd38cd
size 3387
--------------------------------------------------------------------------------
/tests/test_data/FLUX_I_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:32e1ad1e4d2a7aca9e72cf6b97ab1c540e65b2637663eaf65fd0ecde6b5bdf5b
size 2353
--------------------------------------------------------------------------------
/tests/test_data/FLUX_R_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:fddb08432e740000cde5221259b3c74d30abe40b70e75f1c1060979b8e0c49e8
size 2363
--------------------------------------------------------------------------------
/tests/test_data/FLUX_R_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:7e07cf090e1665531eb5cffe1c5c0621abd78228d95c5ebdf3027c2b6a80df8f
size 3387
--------------------------------------------------------------------------------
/tests/test_data/FLUX_R_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c9c631f2c2775c4787abd7256e5abf16da79db8b853bdb904dc28fdf138694cb
size 2353
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W1_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:54c356e62d24deb744036aacad1f9cd24767e64fa901306a6b6dbb57835d4410
size 2368
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W1_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:115b0afdf57c20320841fca9d7c825db78299a61dd894bfa353cce070b79c59a
size 3392
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W1_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e27bca92ebf33e15c3900fba01a9ce2881a23afa0b50ac2a1b77e7d392bfd489
size 2358
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W2_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:07fc4f92c601612566bd92315b4cf8f14be6247e0f6e91e3c61f12e37aca5d13
size 2368
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W2_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:d4e5edd6591bc94c6e6d87258bf68c34b37ec0d925913e21c844f4e64169c7b8
size 3392
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W2_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:4e770b19ed996f1fa3202d511ea5798b15c4b1d7ce9d5ba0a87a5d7b80c5bce8
size 2358
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W3_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:f7544d34d4ed5a3557180d509824e640441d7d8640963b9beecc940540f35ef4
size 2368
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W3_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:bc3652ec155f1e1f22745b1e1248f66285a838e6f36945bcc0b2ddb84d535c6d
size 3392
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W3_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:59d70cf5329662b6745f5f581280b32265659a4c4966f551fc862fdd9888bd82
size 2358
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W4_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:82e9c5697882f7486e7fd07ff3c3321d76001f077c939580ba73aa5868580800
size 2368
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W4_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8995dcad6425e93104c3ab6f86ac9856308196792de0538d57483c60e9596e69
size 3392
--------------------------------------------------------------------------------
/tests/test_data/FLUX_W4_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e19af753588313f0e645d758b0948d7615604a5b56a3b4221dcafe4b426a7a31
size 2358
--------------------------------------------------------------------------------
/tests/test_data/FLUX_Z_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8e5c9437f9e8727118c606e0bda3214d7b0be9b829955d2c208392b2c9547208
size 2363
--------------------------------------------------------------------------------
/tests/test_data/FLUX_Z_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e97c5ae6733854dd70c783fad442408c75459b163d58b450725519b10baba5c6
size 3387
--------------------------------------------------------------------------------
/tests/test_data/FLUX_Z_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:5a06f5e3b7e04f296ea97d4bde358623929bd4452dbae1ed9780519159edac6f
size 2353
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_E1_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:988f4992151cb226782cd120a4f0aab7768e05ddf4a2a562422febc913b4f35c
size 2373
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_E1_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:32bba5c63884efc8f3d340809d5eb577c93d9b88610aaafa8b8e3d04a3ae1814
size 3397
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_E1_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e293edd1d683145b892c0eb796725143e78ada599668d02fabb0eebcec77646c
size 2363
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_E2_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:68a8e2d3bc39187d17598f994ae86182828b79da0d07739f6173b20769805fea
size 2373
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_E2_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6701df912978f9b5dbc0b15b5da6ddb44d8e6f1e6c5e70fd62c1f47620e68400
size 3397
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_E2_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:031cbf7a92a7122f53bfb01a223dac1d8aed1d057e1b3a8a1b8123cf2e4e927b
size 2363
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_R_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:95db6b09d7d03ae849ff355f3c2730ecf9a23a22167de9ec42bdf2d9747824e2
size 2368
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_R_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:3c46b4bc1c01fbaa03b18be80d19a9ac28ebdb7e237d3c043da11a62fb5e9bac
size 3392
--------------------------------------------------------------------------------
/tests/test_data/SHAPE_R_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8b54f21088d454e3e88ef3e13c21b1d400ca96790acc7ed156df78c33a3755eb
size 2358
--------------------------------------------------------------------------------
/tests/test_data/SPECTRUM_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:7f01193b6a2e5284e419451c614157955559d4500ae4614da0ab4405e70d85cb
size 13371261
--------------------------------------------------------------------------------
/tests/test_data/SPECTRUM_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:67b58211ca19f193da68a2a281ff01f3ab50bddb950bf7a20b89cd638ad41d5d
size 280871
--------------------------------------------------------------------------------
/tests/test_data/SPECTRUM_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:524953206a9e1f439984d31c56b49d9fe1af40bcdcafed314763f91a88426c85
size 16975169
--------------------------------------------------------------------------------
/tests/test_data/catalog_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:627dc9ab69c3d26f3293029f9106fba20868e5aa4b245d00f1723f947f6e6c35
size 330158
--------------------------------------------------------------------------------
/tests/test_data/image_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6736810b57fa075fde9bafff62caa8ab1c22e0d675d43181ae25b3a4ad05d3db
size 148790
--------------------------------------------------------------------------------
/tests/test_data/image_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8060ce523ec4ca7beed303744cf4affeea7983045779a93c82d63e5e15ff5f16
size 21235896
--------------------------------------------------------------------------------
/tests/test_data/parallax_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:9bdb625fbc70864218deadd00d416f94450518c0b52a90ee8187cf6b552ad389
size 5445
--------------------------------------------------------------------------------
/tests/test_data/parallax_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:988d32f6bd6886319cad00551f4e6bc80396cb9044baffaa19130d3a48a349ce
size 9541
--------------------------------------------------------------------------------
/tests/test_data/parallax_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c8c5b88cb4097b8f8ad64bfdf9322f82bee682b367f1a0d8262c396d7d8f4cba
size 5435
--------------------------------------------------------------------------------
/tests/test_data/catalog_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:436f2a56496712c65fa571c91d85e15b58268673ac0ab77facebe510219466d5
size 575936
--------------------------------------------------------------------------------
/tests/test_data/catalog_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6091832e50b9060014065e3ab85130250fe02c2960dd16a10b78561ca20abbb3
size 820544
--------------------------------------------------------------------------------
/tests/test_data/g_cmodel_mag_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:b35c8a37beab0f0a873e3937f62acc80302118cebc06f2fd68392519114ba362
size 5337
--------------------------------------------------------------------------------
/tests/test_data/g_cmodel_mag_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:db97d0309f9b50b2b0afee08b42483789dd0294444f7c099af26ebfe10e6a481
size 9369
--------------------------------------------------------------------------------
/tests/test_data/g_cmodel_mag_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:0fe9ce80f50d78260c4a72b8884c4f7de971e9169cb8075f3ae62a51e272ff01
size 5327
--------------------------------------------------------------------------------
/tests/test_data/i_cmodel_mag_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:2b55e00a4567fc6ea3eeb33f2c47d1d41a3defa26c83d9a548ae5b89d0231060
size 5337
--------------------------------------------------------------------------------
/tests/test_data/i_cmodel_mag_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:7bc1f5795d08ecac573a4aef1676591b905c05c3f5e5cea8720851b6b7db6669
size 9369
--------------------------------------------------------------------------------
/tests/test_data/i_cmodel_mag_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:9c4e637f979fa9d03aefc516def05ea072a90370d449443495ca463da31a06c7
size 5327
--------------------------------------------------------------------------------
/tests/test_data/image_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:42e8b42e27dbb00314ff56e40470035afaa78ed7d9e39b3189a7eee35803882b
size 21234998
--------------------------------------------------------------------------------
/tests/test_data/r_cmodel_mag_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:50cc2ae3a156f47156302a5b4c5d605cf0afadbd931343f00853ffa47283e9e4
size 5337
--------------------------------------------------------------------------------
/tests/test_data/r_cmodel_mag_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:4635aebfe84585fc7ca978b3dd0ca9eacaf05d1f8759a24471b63a96e0d8e19a
size 9369
--------------------------------------------------------------------------------
/tests/test_data/r_cmodel_mag_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c25163f70867884d7f00d6686e593268110792a128bac0ada84da2a560419aca
size 5327
--------------------------------------------------------------------------------
/tests/test_data/scalar-field_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:1e3cc973342b955e4d449ea41809ad55ea038b419ccc9dc2fab648537fda5d1f
size 75097
--------------------------------------------------------------------------------
/tests/test_data/y_cmodel_mag_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:ae94754215c6a80d59bae3daf8b6e4a261c5592cb001ffc3920af9536339a8da
size 5337
--------------------------------------------------------------------------------
/tests/test_data/y_cmodel_mag_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c9d1c5668060576a6a854de94d492207fe744a399834a618810081ebaf5b27b6
size 9369
--------------------------------------------------------------------------------
/tests/test_data/y_cmodel_mag_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:a45c15800d3607d72d07c5ba619e51398b30551cf363bf28417bb4eef4136f04
size 5327
--------------------------------------------------------------------------------
/tests/test_data/z_cmodel_mag_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:48fb6306e0ef2905a875d8173d9b38b7ee703d9422c40ec0b6db297396dcc8b0
size 5337
--------------------------------------------------------------------------------
/tests/test_data/z_cmodel_mag_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8b13785bf64cddf9549600db4091537414d117ce5d3420bf7b726fe38cbccad4
size 9369
--------------------------------------------------------------------------------
/tests/test_data/z_cmodel_mag_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:aa5ae56dce54118343b9d10fcf8908d4c929e26440ec940a08a574bec8a9d8db
size 5327
--------------------------------------------------------------------------------
/tests/test_data/bp_coefficients_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:16a5d3a2fa3aa41f064fb690fc111236e85f0e3d55b9d1e3a98051e00ab281f5
size 226664
--------------------------------------------------------------------------------
/tests/test_data/bp_coefficients_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:9bf9770aa3f506854701ed0c34903e343098e92c8f0fcf44e273ec7120973eb1
size 451944
--------------------------------------------------------------------------------
/tests/test_data/bp_coefficients_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:add58e0e8185075540d9aededdca02916e96a0e03bb0284fe3daf478091bed63
size 226654
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape11_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:f264979f9c17d0dc6e2aba3a59a49797de3bd26ab9dd4c8791d9a8b2f2451a43
size 5362
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape12_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:cea800dae8b2fef070edb53fabec8c47624014654b410f5e0b11987119a9711c
size 5362
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape22_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6915da9c92aa31508c809a39fa8f5669ef25e81313a37b9cffb67aa470267efe
size 5362
--------------------------------------------------------------------------------
/tests/test_data/phot_bp_mean_flux_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:4eb832cd5e130df476a74fe8ad6ec98cca2eb1dacbfd2c502faa2c0ea1ae3ff9
size 5490
--------------------------------------------------------------------------------
/tests/test_data/phot_bp_mean_flux_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:2d7cd98f48b67a22bf4e74252a19a8610788f377fb1785362a6af8a2b5bdaa93
size 9586
--------------------------------------------------------------------------------
/tests/test_data/phot_bp_mean_flux_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:344dfff0df2c41c683fb2cca65e82b532d86dc4773d6109948bd379ecdf7b86c
size 5480
--------------------------------------------------------------------------------
/tests/test_data/phot_g_mean_flux_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:0eeebba1a3c56b67c4541e407083c9729a279974085310fc384fdf0a8fe2c071
size 5485
--------------------------------------------------------------------------------
/tests/test_data/phot_g_mean_flux_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e30d0fac7c5eb91afcd3d29f3974e17640b193861bee50b0c7f96d2211321034
size 9581
--------------------------------------------------------------------------------
/tests/test_data/phot_g_mean_flux_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:3df4b9771b2d10dda268350eed658a761c5a80d0c903d81bec51c82cdd892ac5
size 5475
--------------------------------------------------------------------------------
/tests/test_data/phot_rp_mean_flux_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:708dbc5900c7c578e9892f0b4d5ed6689613f4d3d89e6acc35607f47b7c4cde7
size 5490
--------------------------------------------------------------------------------
/tests/test_data/phot_rp_mean_flux_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:2f33f20b49504599de2967687ba50472f359d23c56867801a2f43a76452aa951
size 9586
--------------------------------------------------------------------------------
/tests/test_data/phot_rp_mean_flux_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8c76c40a9e3a150419c4fe6fae1d9fb65c27d60c4136a99cad8648da96594014
size 5480
--------------------------------------------------------------------------------
/tests/test_data/rp_coefficients_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6ed8e367ff1cf86946442c6632974cb0589298223aabaebb883e908c6cb41373
size 226664
--------------------------------------------------------------------------------
/tests/test_data/rp_coefficients_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:69e13590b019bdde0c1afb9d24996412e4e3def82c1ca4e4f27bbf8021650294
size 451944
--------------------------------------------------------------------------------
/tests/test_data/rp_coefficients_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:17154c009a1d9d67db2fbd796c5b174676b2d3d478804a62fac8ad26f26a46d6
size 226654
--------------------------------------------------------------------------------
/tests/test_data/scalar-field_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:15f55b8a2a22c27dee212386c5de5374095711ba395cefdcefb0dbe584276900
size 4719961
--------------------------------------------------------------------------------
/tests/test_data/scalar-field_codec_input_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:7bfaf6edd2874da62a3c2fd2a6739ece3b602e0d23a43531a6ab744ef498cc96
size 13108559
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
version: 2

# Read the Docs config v2 no longer accepts `python.version`; the interpreter
# must be selected via build.tools instead.
build:
  os: ubuntu-22.04
  tools:
    python: "3.10"

sphinx:
  configuration: docs/conf.py
  fail_on_warning: false

python:
  install:
    - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape11_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:fbb0f72bd880ddf3464e2991471513f9c34dd92f3abf1d9d8f1de49dd66c2ebd
size 5372
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape11_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e4d85c4f9f3f7e5ed4f1b6049612fd443fbbda2a70b92e792db45c48424566aa
size 9404
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape12_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:02e13aa4d7f37b5d49721fecd7356a4a4b752ba14a2f8cb9f1a20a3305523708
size 5372
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape12_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:44e2ac8f07e1008a2e3796c63c23ad360e7b449a743d93c7dfd7577c7804652d
size 9340
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape22_codec_decoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:d82cf5f47ab2a925038ac3d7c6d6769c1ae7ed29179c22bf32acd118a0f95fcc
size 5372
--------------------------------------------------------------------------------
/tests/test_data/i_sdssshape_shape22_codec_encoded_batch.pt:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6af60d2f45b793ab1aaf3b0dfeb793c59bfe41abb95700df9a3e16d4c22840e6
size 9404
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.9.6
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-merge-conflict
      - id: check-toml
      - id: check-yaml
        args: [--unsafe]
      - id: end-of-file-fixer
      - id: mixed-line-ending
        args: [--fix=lf]
      - id: trailing-whitespace
--------------------------------------------------------------------------------
/aion/codecs/preprocessing/band_to_index.py:
--------------------------------------------------------------------------------
# Keeps track of the band indices for HSC and DES bands
BAND_TO_INDEX = {
    "HSC-G": 0,
    "HSC-R": 1,
    "HSC-I": 2,
    "HSC-Z": 3,
    "HSC-Y": 4,
    "DES-G": 5,
    "DES-R": 6,
    "DES-I": 7,
    "DES-Z": 8,
}

# Maximum band center values for HSC and DES bands
BAND_CENTER_MAX = {
    "HSC-G": 80,
    "HSC-R": 110,
    "HSC-I": 200,
    "HSC-Z": 330,
    "HSC-Y": 500,
    "DES-G": 6,
    "DES-R": 15,
    "DES-I": 20,
    "DES-Z": 25,
}
--------------------------------------------------------------------------------
/aion/codecs/__init__.py:
--------------------------------------------------------------------------------
from .image import ImageCodec
from .scalar import ScalarCodec, LogScalarCodec, MultiScalarCodec, GridScalarCodec
from .spectrum import SpectrumCodec
from .catalog import CatalogCodec
from .scalar_field import ScalarFieldCodec
from .base import Codec
from .manager import CodecManager

__all__ = [
    "ImageCodec",
    "ScalarCodec",
    "LogScalarCodec",
    "MultiScalarCodec",
    "GridScalarCodec",
    "SpectrumCodec",
    "CatalogCodec",
    "ScalarFieldCodec",
    "Codec",
    "CodecManager",
]
--------------------------------------------------------------------------------
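The `aion.codecs` package above is the tokenizer layer: each codec maps one modality to discrete tokens and back. A short sketch of the loading pattern (grounded in tests/codecs/test_load_codecs.py; loading from the Hub may require an HF token, as the CI workflow below suggests). The assumed field layout of the `Image` modality is not shown in this dump, so the sketch stops short of constructing one.

import torch

from aion.codecs import ImageCodec
from aion.codecs.config import HF_REPO_ID
from aion.codecs.preprocessing.band_to_index import BAND_TO_INDEX
from aion.modalities import Image

# Band names map to fixed channel indices shared by HSC and DES imaging.
assert BAND_TO_INDEX["HSC-G"] == 0 and BAND_TO_INDEX["DES-Z"] == 8

# Codecs are loaded from the Hugging Face Hub, keyed by the modality they
# tokenize (this mirrors tests/codecs/test_load_codecs.py below).
codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image)
codec.eval()

# codec.encode() maps a modality instance to flattened integer tokens and
# codec.decode() inverts it; the codec tests below show concrete round trips.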
/.github/workflows/test-doc.yml:
--------------------------------------------------------------------------------
name: Check Documentation Build

on:
  pull_request:
    paths:
      - 'docs/**'
      - 'aion/**'
      - '.github/workflows/test-doc.yml'
      - 'pyproject.toml'

jobs:
  docs:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/setup-uv@v5
        with:
          enable-cache: true
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          uv sync --all-extras --dev
      - name: Build documentation
        run: |
          cd docs
          uv run sphinx-build -W -b html . _build/html
      - name: Check for broken links
        run: |
          cd docs
          uv run sphinx-build -b linkcheck . _build/linkcheck || true
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
name: Tests

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Run pre-commit hooks
        uses: pre-commit/action@v3.0.1
        with:
          extra_args: --all-files
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
      - uses: astral-sh/setup-uv@v5
        with:
          enable-cache: true
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install AION
        run: uv sync --all-extras --dev
      - name: Run tests
        env:
          HF_TOKEN: ${{ secrets.AION_HF_TOKEN }}
          PY_COLORS: "1"
        run: uv run pytest tests
--------------------------------------------------------------------------------
/tests/codecs/test_load_codecs.py:
--------------------------------------------------------------------------------
import pytest
import torch

from aion.codecs import ImageCodec
from aion.codecs.config import HF_REPO_ID
from aion.modalities import Image, LegacySurveyCatalog, LegacySurveyImage


def test_load_invalid_modality():
    """Test that loading a codec with an incompatible modality raises a TypeError."""
    with pytest.raises(TypeError):
        ImageCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyCatalog)


def test_load_image_codec():
    """Test that Image and LegacySurveyImage resolve to the same codec weights."""
    codec_image = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image)
    codec_legacy_survey_image = ImageCodec.from_pretrained(
        HF_REPO_ID, modality=LegacySurveyImage
    )
    for param_image, param_legacy_survey_image in zip(
        codec_image.parameters(), codec_legacy_survey_image.parameters()
    ):
        assert torch.equal(param_image, param_legacy_survey_image)
--------------------------------------------------------------------------------
22 | -------------------------------------------------------------------------------- /.github/workflows/publish-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' # Triggers on version tags like v1.0.0, v0.1.2, etc. 7 | 8 | jobs: 9 | publish: 10 | name: Publish to PyPI 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/polymathic-aion 15 | permissions: 16 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 17 | contents: read 18 | 19 | steps: 20 | - name: Checkout code 21 | uses: actions/checkout@v4 22 | with: 23 | fetch-depth: 0 # Fetch full history for setuptools_scm 24 | 25 | - name: Set up Python 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: '3.11' 29 | 30 | - name: Install build dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | python -m pip install build 34 | 35 | - name: Build package 36 | run: python -m build 37 | 38 | - name: Publish to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | # This action uses OIDC trusted publishing - no API tokens needed! 41 | # Just configure your PyPI project to trust this GitHub repository 42 | -------------------------------------------------------------------------------- /tests/codecs/test_scalar_field_codec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from aion.codecs import ScalarFieldCodec 4 | from aion.codecs.config import HF_REPO_ID 5 | from aion.modalities import LegacySurveySegmentationMap 6 | 7 | 8 | def test_scalar_field_tokenizer(data_dir): 9 | codec = ScalarFieldCodec.from_pretrained( 10 | HF_REPO_ID, modality=LegacySurveySegmentationMap 11 | ) 12 | codec.eval() 13 | input_batch = torch.load( 14 | data_dir / "scalar-field_codec_input_batch.pt", weights_only=False 15 | ) 16 | reference_encoded_batch = torch.load( 17 | data_dir / "scalar-field_codec_encoded_batch.pt", weights_only=False 18 | ) 19 | reference_decoded_batch = torch.load( 20 | data_dir / "scalar-field_codec_decoded_batch.pt", weights_only=False 21 | ) 22 | 23 | # We flatten the reference encoded output to match the encoded output 24 | # as we now make all codecs return flattened outputs 25 | reference_encoded_batch = reference_encoded_batch.reshape( 26 | reference_encoded_batch.shape[0], -1 27 | ) 28 | 29 | with torch.no_grad(): 30 | output = codec.encode(LegacySurveySegmentationMap(field=input_batch)) 31 | decoded_output = codec.decode(output) 32 | 33 | assert torch.allclose(output, reference_encoded_batch) 34 | assert torch.allclose( 35 | decoded_output.field, reference_decoded_batch, atol=1e-4, rtol=1e-4 36 | ) 37 | -------------------------------------------------------------------------------- /tests/codecs/test_catalog_codec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from aion.codecs import CatalogCodec 4 | from aion.modalities import LegacySurveyCatalog 5 | 6 | from aion.codecs.config import HF_REPO_ID 7 | 8 | 9 | def test_catalog_tokenizer(data_dir): 10 | codec = CatalogCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyCatalog) 11 | codec.eval() 12 | input_batch = torch.load( 13 | data_dir / "catalog_codec_input_batch.pt", weights_only=False 14 | ) 15 | reference_encoded_batch = torch.load( 16 | data_dir / "catalog_codec_encoded_batch.pt", weights_only=False 17 | ) 18 | reference_decoded_batch = 
torch.load( 19 | data_dir / "catalog_codec_decoded_batch.pt", weights_only=False 20 | ) 21 | 22 | with torch.no_grad(): 23 | output = codec.encode(LegacySurveyCatalog(**input_batch)) 24 | decoded_output = codec.decode(output) 25 | 26 | assert torch.allclose(output, reference_encoded_batch) 27 | assert torch.allclose(decoded_output.X, reference_decoded_batch["X"], atol=1e-5) 28 | assert torch.allclose(decoded_output.Y, reference_decoded_batch["Y"], atol=1e-5) 29 | assert torch.allclose( 30 | decoded_output.SHAPE_E1, reference_decoded_batch["SHAPE_E1"], atol=1e-5 31 | ) 32 | assert torch.allclose( 33 | decoded_output.SHAPE_E2, reference_decoded_batch["SHAPE_E2"], atol=1e-5 34 | ) 35 | assert torch.allclose( 36 | decoded_output.SHAPE_R, reference_decoded_batch["SHAPE_R"], atol=1e-5 37 | ) 38 | -------------------------------------------------------------------------------- /.github/workflows/deploy-doc.yml: -------------------------------------------------------------------------------- 1 | name: Build Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | # Allows you to run this workflow manually from the Actions tab 8 | workflow_dispatch: 9 | 10 | # Allow one concurrent deployment 11 | concurrency: 12 | group: "pages" 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - uses: astral-sh/setup-uv@v5 22 | with: 23 | enable-cache: true 24 | - name: Set up Python 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: '3.11' 28 | - name: Install dependencies 29 | run: | 30 | uv sync --all-extras --dev 31 | - name: Build HTML documentation 32 | run: | 33 | cd docs 34 | uv run sphinx-build -b html . _build/html 35 | 36 | - name: Upload artifact 37 | uses: actions/upload-pages-artifact@v3 38 | with: 39 | path: docs/_build/html 40 | 41 | # Deploy job - only runs on main branch 42 | deploy: 43 | needs: build 44 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 45 | permissions: 46 | contents: read 47 | pages: write 48 | id-token: write 49 | environment: 50 | name: github-pages 51 | url: ${{ steps.deployment.outputs.page_url }} 52 | runs-on: ubuntu-latest 53 | 54 | steps: 55 | - name: Deploy to GitHub Pages 56 | id: deployment 57 | uses: actions/deploy-pages@v4 58 | -------------------------------------------------------------------------------- /aion/codecs/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | 6 | class LayerNorm(torch.nn.Module): 7 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. 8 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 9 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 10 | with shape (batch_size, channels, height, width). 
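
    A minimal usage sketch (tensor shapes are arbitrary examples, not part of the module):

        norm = LayerNorm(64, data_format="channels_first")
        x = torch.randn(8, 64, 32, 32)  # (batch, channels, height, width)
        y = norm(x)                     # same shape, normalized over the channel dim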
11 | """ 12 | 13 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 14 | super().__init__() 15 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape)) 16 | self.bias = torch.nn.Parameter(torch.zeros(normalized_shape)) 17 | self.eps = eps 18 | self.data_format = data_format 19 | if self.data_format not in ["channels_last", "channels_first"]: 20 | raise NotImplementedError 21 | self.normalized_shape = (normalized_shape,) 22 | 23 | def forward(self, x): 24 | if self.data_format == "channels_last": 25 | return F.layer_norm( 26 | x, self.normalized_shape, self.weight, self.bias, self.eps 27 | ) 28 | elif self.data_format == "channels_first": 29 | x = rearrange(x, "b c ... -> b ... c") 30 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 31 | return rearrange(x, "b ... c -> b c ...") 32 | 33 | 34 | class GRN(torch.nn.Module): 35 | """GRN (Global Response Normalization) layer""" 36 | 37 | def __init__(self, dim): 38 | super().__init__() 39 | self.gamma = torch.nn.Parameter(torch.zeros(1, 1, dim)) 40 | self.beta = torch.nn.Parameter(torch.zeros(1, 1, dim)) 41 | 42 | def forward(self, x): 43 | Gx = torch.norm(x, p=2, dim=(1,), keepdim=True) 44 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 45 | return self.gamma * (x * Nx) + self.beta + x 46 | -------------------------------------------------------------------------------- /aion/fourm/modality_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 EPFL and Apple Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Optional, Tuple 15 | from abc import ABC, abstractmethod 16 | 17 | import numpy as np 18 | import torch 19 | 20 | 21 | class AbstractTransform(ABC): 22 | @abstractmethod 23 | def load(self, sample): 24 | pass 25 | 26 | @abstractmethod 27 | def preprocess(self, sample): 28 | pass 29 | 30 | @abstractmethod 31 | def image_augment( 32 | self, 33 | v, 34 | crop_coords: Tuple, 35 | flip: bool, 36 | orig_size: Tuple, 37 | target_size: Tuple, 38 | rand_aug_idx: Optional[int], 39 | resample_mode: str = None, 40 | ): 41 | pass 42 | 43 | @abstractmethod 44 | def postprocess(self, v): 45 | pass 46 | 47 | 48 | class TokTransform(AbstractTransform): 49 | def __init__(self): 50 | pass 51 | 52 | def load(self, path): 53 | sample = np.load(path).astype(int) 54 | return sample 55 | 56 | def preprocess(self, sample): 57 | return sample 58 | 59 | def image_augment( 60 | self, 61 | v, 62 | crop_coords: Tuple, 63 | flip: bool, 64 | orig_size: Tuple, 65 | target_size: Tuple, 66 | rand_aug_idx: Optional[int], 67 | resample_mode: str = None, 68 | ): 69 | if rand_aug_idx is None: 70 | raise ValueError( 71 | "Crop settings / augmentation index are missing but a pre-tokenized modality is being used" 72 | ) 73 | v = torch.tensor(v[rand_aug_idx]) 74 | return v 75 | 76 | def postprocess(self, sample): 77 | return sample 78 | -------------------------------------------------------------------------------- /tests/codecs/test_spectrum_codec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from aion.codecs import SpectrumCodec 4 | from aion.codecs.config import HF_REPO_ID 5 | from aion.modalities import Spectrum 6 | 7 | 8 | def test_hf_previous_predictions(data_dir): 9 | codec = SpectrumCodec.from_pretrained(HF_REPO_ID, modality=Spectrum) 10 | 11 | input_batch = torch.load(data_dir / "SPECTRUM_input_batch.pt", weights_only=False)[ 12 | "spectrum" 13 | ] 14 | reference_encoded_output = torch.load( 15 | data_dir / "SPECTRUM_encoded_batch.pt", weights_only=False 16 | ) 17 | reference_decoded_output = torch.load( 18 | data_dir / "SPECTRUM_decoded_batch.pt", weights_only=False 19 | ) 20 | 21 | with torch.no_grad(): 22 | # Create Spectrum modality instance 23 | spectrum_input = Spectrum( 24 | flux=input_batch["flux"], 25 | ivar=input_batch["ivar"], 26 | mask=input_batch["mask"], 27 | wavelength=input_batch["lambda"], 28 | ) 29 | 30 | encoded_output = codec.encode(spectrum_input) 31 | assert encoded_output.shape == reference_encoded_output.shape 32 | assert torch.allclose(encoded_output, reference_encoded_output) 33 | 34 | decoded_spectrum = codec.decode(encoded_output) 35 | 36 | assert ( 37 | decoded_spectrum.flux.shape 38 | == reference_decoded_output["spectrum"]["flux"].shape 39 | ) 40 | assert torch.allclose( 41 | decoded_spectrum.flux, 42 | reference_decoded_output["spectrum"]["flux"], 43 | rtol=1e-3, 44 | atol=1e-4, 45 | ) 46 | assert ( 47 | decoded_spectrum.wavelength.shape 48 | == reference_decoded_output["spectrum"]["lambda"].shape 49 | ) 50 | assert torch.allclose( 51 | decoded_spectrum.wavelength, 52 | reference_decoded_output["spectrum"]["lambda"], 53 | rtol=1e-3, 54 | atol=1e-4, 55 | ) 56 | assert ( 57 | decoded_spectrum.mask.shape 58 | == reference_decoded_output["spectrum"]["mask"].shape 59 | ) 60 | assert torch.allclose( 61 | decoded_spectrum.mask, reference_decoded_output["spectrum"]["mask"].bool() 62 | ) 63 | -------------------------------------------------------------------------------- /aion/codecs/modules/subsampler.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | from jaxtyping import Bool, Float 5 | 6 | 7 | class SubsampledLinear(torch.nn.Module): 8 | def __init__(self, dim_in: int, dim_out: int, subsample_in: bool = True): 9 | """ 10 | Subsampled linear layer for the encoder. 11 | It takes in a zero-padded tensor and a mask. 12 | It projects the tensor into some shared projection space. 13 | It can also be used to reverse out of the space with the mask. 14 | 15 | Args: 16 | dim_in : Number of total possible bands. 17 | dim_out : Number of embedding dimensions. 18 | subsample_in : Whether to subsample the input. Defaults to True. 19 | """ 20 | super().__init__() 21 | self.subsample_in = subsample_in 22 | self.dim_in = dim_in # Number of total possible bands 23 | self.dim_out = dim_out # Number of embedding dimensions 24 | temp_linear = torch.nn.Linear(dim_in, dim_out) 25 | self.weight = torch.nn.Parameter(temp_linear.weight) 26 | self.bias = torch.nn.Parameter(temp_linear.bias) 27 | 28 | def _subsample_in(self, x, labels: Bool[torch.Tensor, " b c"]): 29 | # Get mask 30 | mask = labels[:, None, None, :].float() 31 | x = x * mask 32 | 33 | # Normalize 34 | label_sizes = labels.sum(dim=1, keepdim=True) 35 | scales = ((self.dim_in / label_sizes) ** 0.5).squeeze(-1) 36 | 37 | # Apply linear layer 38 | return scales[:, None, None, None] * F.linear(x, self.weight, self.bias) 39 | 40 | def _subsample_out(self, x, labels): 41 | # Get mask 42 | mask = labels[:, None, None, :].float() 43 | 44 | # Apply linear layer and mask 45 | return F.linear(x, self.weight, self.bias) * mask 46 | 47 | def forward( 48 | self, x: Float[torch.Tensor, " b c h w"], labels: Bool[torch.Tensor, " b c"] 49 | ) -> Float[torch.Tensor, " b c h w"]: 50 | x = rearrange(x, "b c h w -> b h w c") 51 | 52 | if self.subsample_in: 53 | x = self._subsample_in(x, labels) 54 | 55 | else: 56 | x = self._subsample_out(x, labels) 57 | 58 | x = rearrange(x, "b h w c -> b c h w") 59 | 60 | return x 61 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "setuptools_scm>=8.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "polymathic-aion" 7 | dynamic = ["version"] 8 | description = "AstronomIcal Omnimodal Network - Polymathic's Large Omnimodal Model for Astronomy" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {text = "MIT"} 12 | authors = [ 13 | {name = "Polymathic AI", email = "info@polymathic-ai.org"}, 14 | ] 15 | maintainers = [ 16 | {name = "Polymathic AI", email = "info@polymathic-ai.org"}, 17 | ] 18 | keywords = [ 19 | "astronomy", 20 | "astrophysics", 21 | "machine learning", 22 | "transformer", 23 | "multimodal", 24 | "deep learning", 25 | "scientific computing", 26 | "AI", 27 | ] 28 | classifiers = [ 29 | "Development Status :: 4 - Beta", 30 | "Intended Audience :: Science/Research", 31 | "License :: OSI Approved :: MIT License", 32 | "Operating System :: OS Independent", 33 | "Programming Language :: Python :: 3", 34 | "Programming Language :: Python :: 3.10", 35 | "Programming Language :: Python :: 3.11", 36 | "Programming Language :: Python :: 3.12", 37 | "Topic :: Scientific/Engineering :: Astronomy", 38 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 39 | "Topic :: 
Scientific/Engineering :: Physics", 40 | ] 41 | dependencies = [ 42 | "einops>=0.7.0", 43 | "huggingface_hub>=0.24.7", 44 | "jaxtyping>=0.2.28", 45 | "numpy", 46 | "scipy", 47 | "tokenizers>=0.15.2", 48 | ] 49 | 50 | [project.urls] 51 | Homepage = "https://polymathic-ai.org/" 52 | Documentation = "https://polymathic-ai.github.io/AION/" 53 | Repository = "https://github.com/PolymathicAI/AION" 54 | "Bug Tracker" = "https://github.com/PolymathicAI/AION/issues" 55 | "Discussion" = "https://github.com/PolymathicAI/AION/discussions" 56 | "Hugging Face" = "https://huggingface.co/polymathic-ai/aion-base" 57 | 58 | [project.optional-dependencies] 59 | torch = [ 60 | "torch>=2.4.0", 61 | "huggingface_hub[torch]>=0.24.7", 62 | ] 63 | dev = [ 64 | "pre-commit", 65 | "pytest", 66 | "ruff", 67 | ] 68 | docs = [ 69 | "furo", 70 | "myst-parser>=1.0", 71 | "sphinx-copybutton", 72 | "sphinx-design", 73 | "sphinxcontrib-mermaid", 74 | "sphinx>=7.2", 75 | ] 76 | 77 | [tool.ruff.lint] 78 | # Ignore space in shape notation for jaxtyping 79 | # See https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error 80 | ignore = ["F722"] 81 | 82 | [tool.setuptools.packages.find] 83 | include = ["aion*"] 84 | # Only include the 'aion' package and its subpackages 85 | 86 | [tool.setuptools_scm] 87 | # This section configures setuptools_scm 88 | # It will use git tags to determine version numbers 89 | # If no tags exist, it will use 0.1.dev0+g 90 | fallback_version = "0.1.dev0" 91 | -------------------------------------------------------------------------------- /aion/codecs/modules/ema.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on the timm code base 3 | # https://github.com/huggingface/pytorch-image-models 4 | # -------------------------------------------------------- 5 | """Exponential Moving Average (EMA) of model updates 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | 10 | from copy import deepcopy 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | class ModelEmaV2(nn.Module): 17 | """Model Exponential Moving Average V2 18 | 19 | Keep a moving average of everything in the model state_dict (parameters and buffers). 20 | V2 of this module is simpler, it does not match params/buffers based on name but simply 21 | iterates in order. It works with torchscript (JIT of full model). 22 | 23 | This is intended to allow functionality like 24 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 25 | 26 | A smoothed version of the weights is necessary for some training schemes to perform well. 27 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use 28 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA 29 | smoothing of weights to match results. Pay attention to the decay constant you are using 30 | relative to your update count per epoch. 31 | 32 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but 33 | disable validation of the EMA weights. Validation will have to be done manually in a separate 34 | process, or after the training stops converging. 35 | 36 | This class is sensitive where it is initialized in the sequence of model init, 37 | GPU assignment and distributed training wrappers. 
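
    A minimal usage sketch (`loader` and `train_step` are placeholders, not part of this module):

        ema = ModelEmaV2(model, decay=0.9999)
        for batch in loader:
            train_step(model, batch)   # regular optimizer update
            ema.update(model)          # fold the new weights into the moving average
        ema.module.eval()              # evaluate with the smoothed copy of the model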
38 | """ 39 | 40 | def __init__(self, model, decay=0.9999, device=None): 41 | super().__init__() 42 | # make a copy of the model for accumulating moving average of weights 43 | self.module = deepcopy(model) 44 | self.module.eval() 45 | self.decay = decay 46 | self.device = device # perform ema on different device from model if set 47 | if self.device is not None: 48 | self.module.to(device=device) 49 | 50 | def _update(self, model, update_fn): 51 | with torch.no_grad(): 52 | for ema_v, model_v in zip( 53 | self.module.state_dict().values(), model.state_dict().values() 54 | ): 55 | if self.device is not None: 56 | model_v = model_v.to(device=self.device) 57 | ema_v.copy_(update_fn(ema_v, model_v)) 58 | 59 | def update(self, model): 60 | self._update( 61 | model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m 62 | ) 63 | 64 | def set(self, model): 65 | self._update(model, update_fn=lambda e, m: m) 66 | 67 | def forward(self, *args, **kwargs): 68 | return self.module(*args, **kwargs) 69 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{raw} html 2 |
<!-- hero banner (markup stripped during extraction): title "AION-1", subtitle "AstronomIcal Omnimodal Network", tagline "Large-Scale Multimodal Foundation Model for Astronomy", buttons "Get Started →" and "Run on Colab" -->
13 | ``` 14 | 15 | # AION-1 Documentation 16 | 17 | ## 🌟 Why AION-1? 18 | 19 | Trained on over 200 million astronomical objects, AION-1 (AstronomIcal Omnimodal Network) is the first Foundation Model capable of unifying multiband imaging, spectroscopy, and photometry from major ground- and space-based observatories into a single framework. 20 | 21 | Compared to traditional machine learning approaches in Astronomy, AION-1 stands out on several points: 22 | - **Enabling Flexible Data Fusion**: Scientists can use any combination of available observations without redesigning their analysis pipeline 23 | - **Enabling Easy Adaptation to Downstream Tasks**: Scientists can adapt AION-1 to new tasks in a matter of minutes and reach SOTA performance 24 | - **Excelling in Low-Data Regimes**: AION-1 achieves competitive results with orders of magnitude less labeled data than supervised approaches 25 | - **Providing Universal Representations**: The learned embeddings capture physically meaningful structure useful across diverse downstream tasks 26 | 27 | ## 🚀 Quick Start 28 | 29 | Assuming you have PyTorch installed, you can install AION trivially with: 30 | ```bash 31 | pip install polymathic-aion 32 | ``` 33 | 34 | Then you can load the pretrained model and start analyzing astronomical data: 35 | ```python 36 | import torch 37 | from aion import AION 38 | from aion.codecs import CodecManager 39 | from aion.modalities import LegacySurveyImage 40 | 41 | # Load model and codec manager 42 | model = AION.from_pretrained('aion-base').to('cuda') # or 'aion-large', 'aion-xlarge' 43 | codec_manager = CodecManager(device='cuda') 44 | 45 | # Prepare your astronomical data (example: Legacy Survey image) 46 | image = LegacySurveyImage( 47 | flux=your_image_tensor, # Shape: [batch, 4, height, width] for g,r,i,z bands 48 | bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] 49 | ) 50 | 51 | # Encode data to tokens 52 | tokens = codec_manager.encode(image) 53 | 54 | # Option 1: Extract embeddings for downstream tasks 55 | embeddings = model.encode(tokens, num_encoder_tokens=600) 56 | 57 | # Option 2: Generate predictions (e.g., redshift) 58 | from aion.modalities import Z 59 | preds = model( 60 | codec_manager.encode(image), 61 | target_modality=Z, 62 | ) 63 | ``` 64 | 65 | ## 📚 Documentation 66 | 67 | ```{eval-rst} 68 | .. grid:: 1 1 1 2 69 | :gutter: 3 70 | 71 | .. 
grid-item-card:: API Reference 72 | :link: api.html 73 | :class-card: doc-card 74 | 75 | Complete API documentation with all classes and methods 76 | ``` 77 | 78 | ```{toctree} 79 | :hidden: 80 | :maxdepth: 2 81 | 82 | api 83 | ``` 84 | -------------------------------------------------------------------------------- /aion/codecs/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TypeVar 3 | 4 | from aion.codecs.catalog import CatalogCodec 5 | from aion.codecs.image import ImageCodec 6 | from aion.codecs.scalar import ( 7 | GridScalarCodec, 8 | LogScalarCodec, 9 | MultiScalarCodec, 10 | ScalarCodec, 11 | ) 12 | from aion.codecs.scalar_field import ScalarFieldCodec 13 | from aion.codecs.spectrum import SpectrumCodec 14 | from aion.modalities import ( 15 | HSCAG, 16 | HSCAI, 17 | HSCAR, 18 | HSCAY, 19 | HSCAZ, 20 | Dec, 21 | DESISpectrum, 22 | GaiaFluxBp, 23 | GaiaFluxG, 24 | GaiaFluxRp, 25 | GaiaParallax, 26 | GaiaXpBp, 27 | GaiaXpRp, 28 | HSCImage, 29 | HSCMagG, 30 | HSCMagI, 31 | HSCMagR, 32 | HSCMagY, 33 | HSCMagZ, 34 | HSCShape11, 35 | HSCShape12, 36 | HSCShape22, 37 | Image, 38 | LegacySurveyCatalog, 39 | LegacySurveyEBV, 40 | LegacySurveyFluxG, 41 | LegacySurveyFluxI, 42 | LegacySurveyFluxR, 43 | LegacySurveyFluxW1, 44 | LegacySurveyFluxW2, 45 | LegacySurveyFluxW3, 46 | LegacySurveyFluxW4, 47 | LegacySurveyFluxZ, 48 | LegacySurveyImage, 49 | LegacySurveySegmentationMap, 50 | LegacySurveyShapeE1, 51 | LegacySurveyShapeE2, 52 | LegacySurveyShapeR, 53 | Ra, 54 | SDSSSpectrum, 55 | Spectrum, 56 | Z, 57 | ) 58 | 59 | CodecType = TypeVar( 60 | "CodecModel", 61 | bound=type[ 62 | CatalogCodec 63 | | GridScalarCodec 64 | | ImageCodec 65 | | LogScalarCodec 66 | | MultiScalarCodec 67 | | ScalarCodec 68 | | ScalarFieldCodec 69 | | SpectrumCodec 70 | ], 71 | ) 72 | 73 | 74 | @dataclass 75 | class CodecHFConfig: 76 | """Codec configuration for AION.""" 77 | 78 | codec_class: CodecType 79 | repo_id: str 80 | 81 | 82 | MODALITY_CODEC_MAPPING = { 83 | Dec: ScalarCodec, 84 | DESISpectrum: SpectrumCodec, 85 | GaiaFluxBp: LogScalarCodec, 86 | GaiaFluxG: LogScalarCodec, 87 | GaiaFluxRp: LogScalarCodec, 88 | GaiaParallax: LogScalarCodec, 89 | GaiaXpBp: MultiScalarCodec, 90 | GaiaXpRp: MultiScalarCodec, 91 | HSCAG: ScalarCodec, 92 | HSCAI: ScalarCodec, 93 | HSCAR: ScalarCodec, 94 | HSCAY: ScalarCodec, 95 | HSCAZ: ScalarCodec, 96 | HSCImage: ImageCodec, 97 | HSCMagG: ScalarCodec, 98 | HSCMagI: ScalarCodec, 99 | HSCMagR: ScalarCodec, 100 | HSCMagY: ScalarCodec, 101 | HSCMagZ: ScalarCodec, 102 | HSCShape11: ScalarCodec, 103 | HSCShape12: ScalarCodec, 104 | HSCShape22: ScalarCodec, 105 | Image: ImageCodec, 106 | LegacySurveyCatalog: CatalogCodec, 107 | LegacySurveyEBV: ScalarCodec, 108 | LegacySurveyFluxG: LogScalarCodec, 109 | LegacySurveyFluxI: LogScalarCodec, 110 | LegacySurveyFluxR: LogScalarCodec, 111 | LegacySurveyFluxW1: LogScalarCodec, 112 | LegacySurveyFluxW2: LogScalarCodec, 113 | LegacySurveyFluxW3: LogScalarCodec, 114 | LegacySurveyFluxW4: LogScalarCodec, 115 | LegacySurveyFluxZ: LogScalarCodec, 116 | LegacySurveyImage: ImageCodec, 117 | LegacySurveySegmentationMap: ScalarFieldCodec, 118 | LegacySurveyShapeE1: ScalarCodec, 119 | LegacySurveyShapeE2: ScalarCodec, 120 | LegacySurveyShapeR: LogScalarCodec, 121 | Ra: ScalarCodec, 122 | SDSSSpectrum: SpectrumCodec, 123 | Spectrum: SpectrumCodec, 124 | Z: GridScalarCodec, 125 | } 126 | 127 | HF_REPO_ID = "polymathic-ai/aion-base" 128 | 
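A minimal sketch of how the registry above is meant to be used, mirroring the pattern in the test suite (the random redshift batch is a placeholder):

```python
import torch

from aion.codecs.config import HF_REPO_ID, MODALITY_CODEC_MAPPING
from aion.modalities import Z

# Look up the codec class registered for a modality, then load its
# pretrained weights from the shared HuggingFace repo.
codec_cls = MODALITY_CODEC_MAPPING[Z]  # -> GridScalarCodec
codec = codec_cls.from_pretrained(HF_REPO_ID, modality=Z)
codec.eval()

z_batch = torch.rand(16)  # placeholder redshifts in [0, 1)
with torch.no_grad():
    tokens = codec.encode(Z(value=z_batch))  # discrete token ids
    decoded = codec.decode(tokens)           # Z instance with .value restored
```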
-------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | tests/test_data/image_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 2 | tests/test_data/image_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 3 | tests/test_data/image_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 4 | tests/test_data/FLUX_I_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 5 | tests/test_data/FLUX_W1_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 6 | tests/test_data/FLUX_W3_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 7 | tests/test_data/SHAPE_E1_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 8 | tests/test_data/SHAPE_E1_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 9 | tests/test_data/EBV_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 10 | tests/test_data/FLUX_W2_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 11 | tests/test_data/FLUX_W1_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 12 | tests/test_data/FLUX_Z_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 13 | tests/test_data/FLUX_W3_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 14 | tests/test_data/FLUX_W4_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 15 | tests/test_data/SHAPE_E2_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 16 | tests/test_data/FLUX_G_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 17 | tests/test_data/FLUX_R_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 18 | tests/test_data/FLUX_R_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 19 | tests/test_data/FLUX_W1_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 20 | tests/test_data/SHAPE_E2_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 21 | tests/test_data/SHAPE_R_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 22 | tests/test_data/FLUX_G_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 23 | tests/test_data/FLUX_I_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 24 | tests/test_data/FLUX_I_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 25 | tests/test_data/FLUX_W4_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 26 | tests/test_data/FLUX_Z_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 27 | tests/test_data/FLUX_G_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 28 | tests/test_data/FLUX_W2_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 29 | tests/test_data/FLUX_W2_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 30 | tests/test_data/SHAPE_E2_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 31 | tests/test_data/SHAPE_R_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 32 | tests/test_data/EBV_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text 33 | tests/test_data/FLUX_R_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 34 | tests/test_data/FLUX_W3_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 35 | tests/test_data/FLUX_W4_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 36 | tests/test_data/FLUX_Z_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 37 | tests/test_data/SHAPE_E1_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 38 | tests/test_data/SHAPE_R_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 39 | tests/test_data/EBV_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 40 
| tests/test_data/SPECTRUM_input_batch.pt filter=lfs diff=lfs merge=lfs -text 41 | tests/test_data/SPECTRUM_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text 42 | tests/test_data/SPECTRUM_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text 43 | *.pt filter=lfs diff=lfs merge=lfs -text 44 | -------------------------------------------------------------------------------- /tests/codecs/test_scalar_codec.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from aion.codecs import GridScalarCodec, LogScalarCodec, MultiScalarCodec, ScalarCodec 5 | from aion.codecs.config import HF_REPO_ID 6 | from aion.modalities import ( 7 | HSCAG, 8 | HSCAI, 9 | HSCAR, 10 | HSCAY, 11 | HSCAZ, 12 | Dec, 13 | GaiaFluxBp, 14 | # Gaia modalities 15 | GaiaFluxG, 16 | GaiaFluxRp, 17 | GaiaParallax, 18 | GaiaXpBp, 19 | GaiaXpRp, 20 | HSCMagG, 21 | HSCMagI, 22 | HSCMagR, 23 | HSCMagY, 24 | HSCMagZ, 25 | HSCShape11, 26 | HSCShape12, 27 | HSCShape22, 28 | LegacySurveyEBV, 29 | LegacySurveyFluxG, 30 | LegacySurveyFluxI, 31 | LegacySurveyFluxR, 32 | LegacySurveyFluxW1, 33 | LegacySurveyFluxW2, 34 | LegacySurveyFluxW3, 35 | LegacySurveyFluxW4, 36 | LegacySurveyFluxZ, 37 | LegacySurveyShapeE1, 38 | LegacySurveyShapeE2, 39 | LegacySurveyShapeR, 40 | Ra, 41 | Z, 42 | ) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "codec_class,modality", 47 | [ 48 | # LogScalarCodec tests 49 | (LogScalarCodec, LegacySurveyFluxG), 50 | (LogScalarCodec, LegacySurveyFluxR), 51 | (LogScalarCodec, LegacySurveyFluxI), 52 | (LogScalarCodec, LegacySurveyFluxZ), 53 | (LogScalarCodec, LegacySurveyFluxW1), 54 | (LogScalarCodec, LegacySurveyFluxW2), 55 | (LogScalarCodec, LegacySurveyFluxW3), 56 | (LogScalarCodec, LegacySurveyFluxW4), 57 | (LogScalarCodec, LegacySurveyShapeR), 58 | # Gaia LogScalarCodec tests 59 | (LogScalarCodec, GaiaFluxG), 60 | (LogScalarCodec, GaiaFluxBp), 61 | (LogScalarCodec, GaiaFluxRp), 62 | (LogScalarCodec, GaiaParallax), 63 | # ScalarCodec tests 64 | (ScalarCodec, LegacySurveyShapeE1), 65 | (ScalarCodec, LegacySurveyShapeE2), 66 | (ScalarCodec, LegacySurveyEBV), 67 | (ScalarCodec, HSCMagG), 68 | (ScalarCodec, HSCMagR), 69 | (ScalarCodec, HSCMagI), 70 | (ScalarCodec, HSCMagZ), 71 | (ScalarCodec, HSCMagY), 72 | (ScalarCodec, HSCShape11), 73 | (ScalarCodec, HSCShape22), 74 | (ScalarCodec, HSCShape12), 75 | (ScalarCodec, HSCAG), 76 | (ScalarCodec, HSCAR), 77 | (ScalarCodec, HSCAI), 78 | (ScalarCodec, HSCAZ), 79 | (ScalarCodec, HSCAY), 80 | # Gaia ScalarCodec tests 81 | (ScalarCodec, Ra), 82 | (ScalarCodec, Dec), 83 | # Gaia MultiScalarCodec tests 84 | (MultiScalarCodec, GaiaXpBp), 85 | (MultiScalarCodec, GaiaXpRp), 86 | # Grid tokenizer 87 | (GridScalarCodec, Z), 88 | ], 89 | ) 90 | def test_scalar_tokenizer(data_dir, codec_class, modality): 91 | codec = codec_class.from_pretrained(HF_REPO_ID, modality=modality) 92 | codec.eval() 93 | input_batch = torch.load( 94 | data_dir / f"{modality.name}_codec_input_batch.pt", weights_only=False 95 | ) 96 | reference_encoded_batch = torch.load( 97 | data_dir / f"{modality.name}_codec_encoded_batch.pt", weights_only=False 98 | ) 99 | reference_decoded_batch = torch.load( 100 | data_dir / f"{modality.name}_codec_decoded_batch.pt", weights_only=False 101 | ) 102 | 103 | with torch.no_grad(): 104 | output = codec.encode(modality(value=input_batch)) 105 | decoded_output = codec.decode(output) 106 | 107 | assert torch.allclose(output, reference_encoded_batch) 108 | assert torch.allclose(decoded_output.value, 
reference_decoded_batch) 109 | -------------------------------------------------------------------------------- /aion/codecs/catalog.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict, Optional, Type 3 | 4 | import torch 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from aion.codecs.base import Codec 9 | from aion.codecs.quantizers import Quantizer 10 | from aion.codecs.quantizers.scalar import ( 11 | ComposedScalarQuantizer, 12 | IdentityQuantizer, 13 | ScalarReservoirQuantizer, 14 | ) 15 | from aion.codecs.utils import CodecPytorchHubMixin 16 | from aion.modalities import LegacySurveyCatalog 17 | 18 | 19 | class CatalogCodec(Codec, CodecPytorchHubMixin): 20 | """Codec for catalog quantities. 21 | 22 | A codec that embeds catalog quantities through an identity mapping. A 23 | quantizer is applied if specified. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | mask_value: int = 9999, 29 | ): 30 | super().__init__() 31 | self._modality = LegacySurveyCatalog 32 | catalog_keys = ["X", "Y", "SHAPE_E1", "SHAPE_E2", "SHAPE_R"] 33 | quantizers = [ 34 | IdentityQuantizer(96), 35 | IdentityQuantizer(96), 36 | ScalarReservoirQuantizer(1024, 100000), 37 | ScalarReservoirQuantizer(1024, 100000), 38 | ScalarReservoirQuantizer(1024, 100000), 39 | ] 40 | self.mask_value = mask_value 41 | self._catalog_keys = catalog_keys 42 | assert len(catalog_keys) == len(quantizers), ( 43 | "Number of catalog keys and quantizers must match" 44 | ) 45 | _quantizer = OrderedDict() 46 | for key, quantizer in zip(catalog_keys, quantizers): 47 | _quantizer[key] = quantizer 48 | self._quantizer = ComposedScalarQuantizer(_quantizer) 49 | 50 | @property 51 | def modality(self) -> Type[LegacySurveyCatalog]: 52 | return self._modality 53 | 54 | @property 55 | def quantizer(self) -> Optional[Quantizer]: 56 | return self._quantizer 57 | 58 | def _encode(self, x: LegacySurveyCatalog) -> Dict[str, Tensor]: 59 | encoded = OrderedDict() 60 | for key in self._catalog_keys: 61 | catalog_value = getattr(x, key) 62 | mask = catalog_value != self.mask_value 63 | catalog_value = catalog_value[mask] 64 | encoded[key] = catalog_value 65 | encoded["mask"] = mask 66 | return encoded 67 | 68 | def encode(self, x: LegacySurveyCatalog) -> Float[Tensor, "b c1 *code_shape"]: 69 | """Encodes a given batch of samples into latent space.""" 70 | embedding = self._encode(x) 71 | _encoded = self.quantizer.encode( 72 | embedding 73 | ) # (b, C), where b is the number of non-masked samples 74 | 75 | mask = embedding["mask"] 76 | # B: batch size, L: sequence length (20) for each catalog key 77 | B, L = mask.shape 78 | C = len(self._catalog_keys) 79 | encoded = self.mask_value * torch.ones( 80 | B, L, C, dtype=_encoded.dtype, device=_encoded.device 81 | ) 82 | encoded[mask] = _encoded 83 | encoded = encoded.reshape(B, -1) 84 | return encoded 85 | 86 | def _decode(self, z: Dict[str, Tensor]) -> LegacySurveyCatalog: 87 | return LegacySurveyCatalog(**z) 88 | 89 | def decode(self, z: Float[Tensor, "b c1 *code_shape"]) -> LegacySurveyCatalog: 90 | B, LC = z.shape 91 | C = len(self._catalog_keys) 92 | L = LC // C 93 | z = z[:, : C * L] # Truncate the z if it is longer than the expected length 94 | z = z.reshape(B * L, C) 95 | if self._quantizer is not None: 96 | z = self.quantizer.decode(z) 97 | for key in self._catalog_keys: 98 | z[key] = z[key].reshape(B, L) 99 | return self._decode(z) 100 | 
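A small round-trip sketch of the flattened token layout produced by `encode` and consumed by `decode` above (the batch size, the 20-object sequence length, and all field value ranges are invented placeholders; real values come from Legacy Survey catalogs):

```python
import torch

from aion.codecs import CatalogCodec
from aion.codecs.config import HF_REPO_ID
from aion.modalities import LegacySurveyCatalog

codec = CatalogCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyCatalog)
codec.eval()

B, L = 2, 20  # batch size, objects per catalog
catalog = LegacySurveyCatalog(
    X=torch.randint(0, 96, (B, L)).float(),  # pixel coordinates (identity-quantized)
    Y=torch.randint(0, 96, (B, L)).float(),
    SHAPE_E1=0.1 * torch.randn(B, L),        # ellipticities (reservoir-quantized)
    SHAPE_E2=0.1 * torch.randn(B, L),
    SHAPE_R=torch.rand(B, L) + 0.5,          # half-light radius
)

with torch.no_grad():
    tokens = codec.encode(catalog)
    # One token per (object, key) pair, flattened to (B, L * num_keys)
    assert tokens.shape == (B, L * 5)
    decoded = codec.decode(tokens)  # each field restored to shape (B, L)
```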
-------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Project Overview 6 | 7 | AION (AstronomIcal Omnimodal Network) is a large omnimodal transformer model for astronomical surveys. It processes 39 distinct astronomical data modalities using a two-stage architecture: 8 | 9 | 1. **Modality-specific tokenizers** transform raw inputs (images, spectra, catalogs, scalars) into discrete tokens 2. **Unified encoder-decoder transformer** processes all token streams via multimodal masked modeling (4M) 10 | 11 | 12 | The model comes in three variants: Base (300M), Large (800M), and XLarge (3B parameters). 13 | 14 | ## Development Commands 15 | 16 | ### Testing 17 | ```bash 18 | pytest # Run all tests 19 | pytest tests/codecs/ # Run codec tests only 20 | pytest tests/codecs/test_image_codec.py # Run a single test module (validates against pre-computed data in tests/test_data/) 21 | ``` 22 | 23 | ### Linting and Code Quality 24 | ```bash 25 | ruff check . # Check code style and lint 26 | ruff check . --fix # Auto-fix linting issues 27 | ``` 28 | 29 | ### Installation for Development 30 | ```bash 31 | pip install -e .[torch,dev] # Install in editable mode with dev dependencies 32 | ``` 33 | 34 | ### Documentation 35 | ```bash 36 | cd docs && make html # Build Sphinx documentation 37 | ``` 38 | 39 | ## Architecture Overview 40 | 41 | ### Core Components 42 | 43 | - **`aion/model.py`**: Main AION wrapper class, inherits from FM (4M) transformer 44 | - **`aion/fourm/`**: 4M (Massively Multimodal Masked Modeling) transformer implementation 45 | - `fm.py`: Core transformer architecture with encoder-decoder blocks 46 | - `modality_info.py`: Configuration for all 39 supported modalities 47 | - `encoder_embeddings.py` / `decoder_embeddings.py`: Modality-specific embedding layers 48 | - **`aion/codecs/`**: Modality tokenization system 49 | - `manager.py`: Dynamic codec loading and management 50 | - `base.py`: Abstract base codec class 51 | - Individual codec implementations for images, spectra, scalars, etc. 52 | - **`aion/modalities.py`**: Type definitions for all astronomical data types 53 | 54 | ### Key Design Patterns 55 | 56 | 1. **Modality System**: Each astronomical data type (flux, spectrum, catalog) has: 57 | - A modality class in `modalities.py` defining data structure 58 | - A codec in `codecs/` for tokenization 59 | - Embedding layers in `fourm/` for the transformer 60 | 61 | 2. **Token Keys**: Each modality has a `token_key` (e.g., `tok_image`, `tok_spectrum_sdss`) that maps between modalities and model components 62 | 63 | 3. **HuggingFace Integration**: Models and codecs are distributed via HuggingFace Hub with `from_pretrained()` methods 64 | 65 | ## Code Conventions 66 | 67 | - Type hints are mandatory, using `jaxtyping` for tensor shapes (e.g., `Float[Tensor, "batch height width"]`) 68 | - Modality classes use `@dataclass` and inherit from `BaseModality` 69 | - All tensor operations should handle device placement explicitly 70 | - Test data is pre-computed and stored in `tests/test_data/` as `.pt` files 71 | 72 | ## Testing Strategy 73 | 74 | Tests validate both encoding and decoding for each modality using pre-computed reference data. The test pattern is: 75 | 1. Load input, encoded, and decoded reference tensors 76 | 2. Run codec encode/decode operations 77 | 3. 
Assert outputs match reference data within tolerance 78 | 79 | Test files follow naming: `test_{modality}_codec.py` 80 | 81 | ## Astronomical Context 82 | 83 | The model processes data from major surveys: 84 | - **Legacy Survey**: Optical images and catalogs (g,r,i,z bands + WISE) 85 | - **HSC (Hyper Suprime-Cam)**: Deep optical imaging (g,r,i,z,y bands) 86 | - **Gaia**: Astrometry, photometry, and BP/RP spectra 87 | - **SDSS/DESI**: Optical spectra 88 | 89 | Each modality represents different physical measurements (flux, shape parameters, coordinates, extinction, etc.) that the model learns to correlate. 90 | -------------------------------------------------------------------------------- /tests/codecs/test_image_codec.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from aion.codecs import ImageCodec 5 | from aion.codecs.config import HF_REPO_ID 6 | from aion.modalities import Image 7 | 8 | 9 | @pytest.mark.parametrize("embedding_dim", [5, 10]) 10 | @pytest.mark.parametrize("multisurvey_projection_dims", [12, 24]) 11 | @pytest.mark.parametrize("hidden_dims", [8, 16]) 12 | def test_magvit_image_tokenizer( 13 | embedding_dim, multisurvey_projection_dims, hidden_dims 14 | ): 15 | tokenizer = ImageCodec( 16 | quantizer_levels=[1] * embedding_dim, 17 | hidden_dims=hidden_dims, 18 | multisurvey_projection_dims=multisurvey_projection_dims, 19 | n_compressions=2, 20 | num_consecutive=4, 21 | embedding_dim=embedding_dim, 22 | range_compression_factor=0.01, 23 | mult_factor=10, 24 | ) 25 | batch_size = 4 26 | flux_tensor = torch.randn(batch_size, 4, 96, 96) 27 | input_image_obj = Image( 28 | flux=flux_tensor, 29 | bands=["DES-G", "DES-R", "DES-I", "DES-Z"], 30 | ) 31 | 32 | encoded = tokenizer.encode(input_image_obj) 33 | assert encoded.shape == (batch_size, 24 * 24) 34 | 35 | decoded_image_obj = tokenizer.decode( 36 | encoded, bands=["DES-G", "DES-R", "DES-I", "DES-Z"] 37 | ) 38 | 39 | assert isinstance(decoded_image_obj, Image) 40 | assert decoded_image_obj.flux.shape == flux_tensor.shape 41 | 42 | 43 | def test_hf_previous_predictions(data_dir): 44 | codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image) 45 | 46 | input_batch_dict = torch.load( 47 | data_dir / "image_codec_input_batch.pt", weights_only=False 48 | ) 49 | reference_encoded_output = torch.load( 50 | data_dir / "image_codec_encoded_batch.pt", weights_only=False 51 | ) 52 | reference_decoded_output_tensor = torch.load( 53 | data_dir / "image_codec_decoded_batch.pt", weights_only=False 54 | ) 55 | with torch.no_grad(): 56 | input_image_obj = Image( 57 | flux=input_batch_dict["image"]["array"][:, 5:], 58 | bands=["DES-G", "DES-R", "DES-I", "DES-Z"], 59 | ) 60 | encoded_output = codec.encode(input_image_obj) 61 | decoded_image_obj = codec.decode( 62 | encoded_output, bands=["DES-G", "DES-R", "DES-I", "DES-Z"] 63 | ) 64 | 65 | # We flatten the reference encoded output to match the encoded output 66 | # as we now make all codecs return flattened outputs 67 | reference_encoded_output = reference_encoded_output.reshape( 68 | reference_encoded_output.shape[0], -1 69 | ) 70 | 71 | assert encoded_output.shape == reference_encoded_output.shape 72 | assert torch.allclose( 73 | encoded_output, 74 | reference_encoded_output, 75 | ) 76 | 77 | assert isinstance(decoded_image_obj, Image) 78 | assert torch.allclose( 79 | decoded_image_obj.flux, 80 | reference_decoded_output_tensor[:, 5:], 81 | rtol=1e-3, 82 | atol=1e-4, 83 | ) 84 | 85 | 86 | def 
test_batch_size_one(): 87 | """Test ImageCodec with batch_size=1 to ensure subsampler works correctly.""" 88 | codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image) 89 | 90 | # Test with batch_size=1 91 | batch_size = 1 92 | flux_tensor = torch.randn(batch_size, 4, 96, 96) 93 | input_image_obj = Image( 94 | flux=flux_tensor, 95 | bands=["DES-G", "DES-R", "DES-I", "DES-Z"], 96 | ) 97 | 98 | # This should not raise an error (previously failed due to squeeze() issue) 99 | with torch.no_grad(): 100 | encoded = codec.encode(input_image_obj) 101 | decoded_image_obj = codec.decode( 102 | encoded, bands=["DES-G", "DES-R", "DES-I", "DES-Z"] 103 | ) 104 | 105 | assert isinstance(decoded_image_obj, Image) 106 | assert decoded_image_obj.flux.shape == flux_tensor.shape 107 | -------------------------------------------------------------------------------- /aion/codecs/preprocessing/image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from aion.codecs.preprocessing.band_to_index import BAND_TO_INDEX, BAND_CENTER_MAX 3 | 4 | 5 | class ImagePadder: 6 | """Formatter that pads the images to have a fixed number of bands.""" 7 | 8 | def __init__(self): 9 | self.nbands = max(BAND_TO_INDEX.values()) + 1 10 | 11 | def _check_bands(self, bands: list[str]): 12 | for band in bands: 13 | if band not in BAND_TO_INDEX: 14 | raise ValueError( 15 | f"Invalid band: {band}. Valid bands are: {list(BAND_TO_INDEX.keys())}" 16 | ) 17 | 18 | def forward(self, image, bands): 19 | num_channels = self.nbands 20 | batch, _, height, width = image.shape 21 | 22 | # Check if bands are valid 23 | self._check_bands(bands) 24 | 25 | # Create a new image array with the correct number of channels 26 | padded_image = torch.zeros( 27 | (batch, num_channels, height, width), dtype=image.dtype 28 | ).to(image.device) 29 | 30 | # Create a list of new channel indices based on the order of bands 31 | new_channel_indices = [ 32 | BAND_TO_INDEX[band] for band in bands if band in BAND_TO_INDEX 33 | ] 34 | 35 | # Vectorized assignment of the original channels to the new positions 36 | padded_image[:, new_channel_indices, :, :] = image[ 37 | :, : len(new_channel_indices), :, : 38 | ] 39 | 40 | # Get boolean mask of channels that are present 41 | channel_mask = torch.zeros(num_channels, dtype=torch.bool).to(image.device) 42 | channel_mask[new_channel_indices] = True 43 | channel_mask = channel_mask.unsqueeze(0).expand(batch, -1) 44 | return padded_image, channel_mask 45 | 46 | def backward(self, padded_image, bands): 47 | # Check if bands are valid 48 | self._check_bands(bands) 49 | 50 | # Get the indices for the requested bands 51 | channel_indices = [BAND_TO_INDEX[b] for b in bands] 52 | 53 | # Select those channels along dim=1 54 | selected_image = padded_image[:, channel_indices, :, :] 55 | return selected_image 56 | 57 | 58 | class CenterCrop: 59 | """Formatter that center-crops images to a fixed spatial size.""" 60 | 61 | def __init__(self, crop_size: int = 96): 62 | self.crop_size = crop_size 63 | 64 | def __call__(self, image): 65 | _, _, height, width = image.shape 66 | start_x = (width - self.crop_size) // 2 67 | start_y = (height - self.crop_size) // 2 68 | return image[ 69 | :, :, start_y : start_y + self.crop_size, start_x : start_x + self.crop_size 70 | ] 71 | 72 | 73 | class Clamp: 74 | """Formatter that clamps the images to a given range.""" 75 | 76 | def __init__(self): 77 | self.clamp_dict = BAND_CENTER_MAX 78 | 79 | def __call__(self, image, bands): 80 | for i,
band in enumerate(bands): 81 | image[:, i, :, :] = torch.clip( 82 | image[:, i, :, :], -self.clamp_dict[band], self.clamp_dict[band] 83 | ) 84 | return image 85 | 86 | 87 | class RescaleToLegacySurvey: 88 | """Formatter that rescales image fluxes to and from the Legacy Survey zeropoint convention.""" 89 | 90 | def __init__(self): 91 | pass 92 | 93 | def convert_zeropoint(self, zp: float) -> float: 94 | return 10.0 ** ((zp - 22.5) / 2.5) 95 | 96 | def reverse_zeropoint(self, scale: float) -> float: 97 | return 22.5 + 2.5 * float(torch.log10(torch.as_tensor(scale)))  # inverse of convert_zeropoint 98 | 99 | def forward(self, image, survey): 100 | zpscale = self.convert_zeropoint(27.0) if survey == "HSC" else 1.0 101 | image /= zpscale 102 | return image 103 | 104 | def backward(self, image, survey): 105 | zpscale = self.convert_zeropoint(27.0) if survey == "HSC" else 1.0  # same scale as forward, so backward inverts it 106 | image *= zpscale 107 | return image 108 | -------------------------------------------------------------------------------- /aion/codecs/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | from typing import Dict, Type, Optional, Any 7 | 8 | from aion.modalities import ModalityType, Modality 9 | from aion.codecs.quantizers import Quantizer 10 | 11 | 12 | class Codec(ABC, torch.nn.Module): 13 | """Abstract definition of the Codec API. 14 | 15 | A Codec is responsible for transforming data of a specific modality into a 16 | continuous latent representation, which is then quantized into discrete tokens. 17 | It also provides the functionality to decode these tokens back into the 18 | original data space. 19 | """ 20 | 21 | @property 22 | @abstractmethod 23 | def modality(self) -> Type[Modality]: 24 | """Returns the modality key that this codec can operate on.""" 25 | raise NotImplementedError 26 | 27 | @abstractmethod 28 | def _encode(self, x: ModalityType) -> Float[Tensor, "b c n_tokens"]: 29 | """Function to be implemented by subclasses which 30 | takes a batch of input samples (as a ModalityType instance) 31 | and embeds it into a latent space, before any quantization. 32 | """ 33 | raise NotImplementedError 34 | 35 | @abstractmethod 36 | def _decode( 37 | self, z: Float[Tensor, "b c n_tokens"], **metadata: Optional[Dict[str, Any]] 38 | ) -> ModalityType: 39 | """Function to be implemented by subclasses which 40 | takes a batch of latent space embeddings (after dequantization) 41 | and decodes it into the original input space as a ModalityType instance. 42 | 43 | Args: 44 | z: The batch of latent space embeddings after dequantization. 45 | **metadata: Optional keyword arguments containing metadata that might be 46 | necessary for the decoding process (e.g., original dimensions, 47 | specific modality parameters). 48 | """ 49 | raise NotImplementedError 50 | 51 | @property 52 | @abstractmethod 53 | def quantizer(self) -> "Quantizer": 54 | """Returns the quantizer.""" 55 | raise NotImplementedError 56 | 57 | def encode(self, x: ModalityType) -> Float[Tensor, "b n_tokens"]: 58 | """Encodes a batch of input samples into quantized discrete tokens. 59 | 60 | 61 | This involves first embedding the input into a continuous latent space 62 | using `_encode`, and then quantizing this embedding using the 63 | associated `quantizer`. 64 | 65 | Args: 66 | x: A batch of input samples (as a ModalityType instance). 67 | 68 | Returns: 69 | A tensor representing the quantized discrete tokens.
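
        Example sketch (mirrors the image codec tests; `image` is assumed to be
        an `Image` instance whose bands match the codec):

            codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image)
            tokens = codec.encode(image)                        # (b, n_tokens)
            restored = codec.decode(tokens, bands=image.bands)  # metadata kwargs go to _decode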
70 | """ 71 | # Verify that the input type corresponds to the modality of the codec 72 | if not isinstance(x, self.modality): 73 | raise ValueError( 74 | f"Input type {type(x).__name__} does not match the modality of the codec {self.modality.__name__}" 75 | ) 76 | embedding = self._encode(x) 77 | return self.quantizer.encode(embedding) 78 | 79 | def decode( 80 | self, z: Float[Tensor, "b n_tokens"], **metadata: Optional[Dict[str, Any]] 81 | ) -> ModalityType: 82 | """Decodes a batch of quantized discrete tokens back into the original data space. 83 | 84 | This involves first dequantizing the tokens using the associated `quantizer`, 85 | and then decoding the resulting continuous latent representation using `_decode`. 86 | 87 | Args: 88 | z: A tensor representing the quantized discrete tokens. 89 | **metadata: Optional keyword arguments containing metadata that might be 90 | necessary for the decoding process, passed to `_decode`. 91 | 92 | Returns: 93 | The decoded batch of samples as a ModalityType instance. 94 | """ 95 | z = self.quantizer.decode(z) 96 | return self._decode(z, **metadata) 97 | -------------------------------------------------------------------------------- /aion/codecs/scalar.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Optional, Dict, Any 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from aion.codecs.quantizers import Quantizer, ScalarLinearQuantizer 7 | from aion.codecs.quantizers.scalar import ( 8 | ScalarLogReservoirQuantizer, 9 | ScalarReservoirQuantizer, 10 | MultiScalarCompressedReservoirQuantizer, 11 | ) 12 | from aion.codecs.base import Codec 13 | from aion.codecs.utils import CodecPytorchHubMixin 14 | from aion.modalities import Scalar, ScalarModalities 15 | 16 | 17 | class BaseScalarIdentityCodec(Codec, CodecPytorchHubMixin): 18 | """Codec for scalar quantities. 19 | 20 | A codec that embeds scalar quantities through an identity mapping. A 21 | quantizer is applied if specified. 22 | 23 | Args: 24 | modality_class: Type[ScalarModality] 25 | The modality class this codec is designed for. 26 | quantizer: Quantizer 27 | Optional quantizer for the scalar values. 
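
    Example sketch (mirrors the scalar codec tests; `ebv` is a placeholder 1-D tensor):

        codec = ScalarCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyEBV)
        tokens = codec.encode(LegacySurveyEBV(value=ebv))
        decoded = codec.decode(tokens)  # LegacySurveyEBV with .value restored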
28 | """ 29 | 30 | @property 31 | def quantizer(self) -> Quantizer: 32 | return self._quantizer 33 | 34 | @property 35 | def modality(self) -> Type[Scalar]: 36 | return self._modality_class 37 | 38 | def _encode(self, x: Scalar) -> Float[Tensor, " b"]: 39 | return x.value 40 | 41 | def _decode( 42 | self, z: Float[Tensor, " b"], **metadata: Optional[Dict[str, Any]] 43 | ) -> Scalar: 44 | return self._modality_class(value=z) 45 | 46 | def load_state_dict(self, state_dict, strict=True): 47 | # This function is just because the scalar codecs were saved with 'quantizer' instead of '_quantizer' 48 | remapped_state_dict = { 49 | ( 50 | k.replace("quantizer", "_quantizer", 1) 51 | if k.startswith("quantizer") 52 | else k 53 | ): v 54 | for k, v in state_dict.items() 55 | } 56 | return super().load_state_dict(remapped_state_dict, strict=strict) 57 | 58 | 59 | class ScalarCodec(BaseScalarIdentityCodec): 60 | def __init__( 61 | self, 62 | modality: str, 63 | codebook_size: int, 64 | reservoir_size: int, 65 | ): 66 | super().__init__() 67 | self._modality_class = ScalarModalities[modality] 68 | self._quantizer = ScalarReservoirQuantizer( 69 | codebook_size=codebook_size, 70 | reservoir_size=reservoir_size, 71 | ) 72 | 73 | 74 | class LogScalarCodec(BaseScalarIdentityCodec): 75 | def __init__( 76 | self, 77 | modality: str, 78 | codebook_size: int, 79 | reservoir_size: int, 80 | min_log_value: float | None = -3, 81 | ): 82 | super().__init__() 83 | self._modality_class = ScalarModalities[modality] 84 | self._quantizer = ScalarLogReservoirQuantizer( 85 | codebook_size=codebook_size, 86 | reservoir_size=reservoir_size, 87 | min_log_value=min_log_value, 88 | ) 89 | 90 | 91 | class MultiScalarCodec(BaseScalarIdentityCodec): 92 | def __init__( 93 | self, 94 | modality: str, 95 | compression_fns: list[str], 96 | decompression_fns: list[str], 97 | codebook_size: int, 98 | reservoir_size: int, 99 | num_quantizers: int, 100 | ): 101 | super().__init__() 102 | self._modality_class = ScalarModalities[modality] 103 | self._quantizer = MultiScalarCompressedReservoirQuantizer( 104 | compression_fns=compression_fns, 105 | decompression_fns=decompression_fns, 106 | codebook_size=codebook_size, 107 | reservoir_size=reservoir_size, 108 | num_quantizers=num_quantizers, 109 | ) 110 | 111 | 112 | class GridScalarCodec(BaseScalarIdentityCodec): 113 | def __init__(self, modality: str, codebook_size: int): 114 | super().__init__() 115 | self._modality_class = ScalarModalities[modality] 116 | self._quantizer = ScalarLinearQuantizer( 117 | codebook_size=codebook_size, 118 | range=(0.0, 1.0), 119 | ) 120 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath("..")) 5 | 6 | project = "AION-1" 7 | author = "Polymathic AI" 8 | html_title = "AION" 9 | 10 | extensions = [ 11 | "myst_parser", 12 | "sphinx_copybutton", 13 | "sphinx_design", # For cards and grids 14 | "sphinxcontrib.mermaid", 15 | "sphinx.ext.autodoc", 16 | "sphinx.ext.autosummary", 17 | "sphinx.ext.napoleon", 18 | ] 19 | 20 | autosummary_generate = True 21 | 22 | # MyST parser configuration 23 | myst_enable_extensions = [ 24 | "colon_fence", 25 | "deflist", 26 | "html_image", 27 | ] 28 | 29 | myst_heading_anchors = 3 30 | 31 | html_theme = "furo" 32 | html_static_path = ["_static"] 33 | html_css_files = ["style.css"] 34 | 35 | # Theme customizations - separate light and dark 
themes 36 | html_theme_options = { 37 | "light_css_variables": { 38 | "color-brand-primary": "#CA0E4C", 39 | "color-brand-content": "#CA0E4C", 40 | "color-foreground-primary": "#2c3e50", # Dark text for light mode 41 | "color-foreground-secondary": "#546e7a", 42 | "color-foreground-muted": "#90a4ae", 43 | "color-foreground-border": "#e0e0e0", 44 | "color-background-primary": "#ffffff", # White background for light mode 45 | "color-background-secondary": "#f5f5f5", 46 | "color-background-hover": "#fafafa", 47 | "color-background-border": "#e0e0e0", 48 | "color-sidebar-background": "#fafafa", 49 | "color-sidebar-background-border": "#e0e0e0", 50 | "color-sidebar-brand-text": "#2c3e50", 51 | "color-sidebar-caption-text": "#546e7a", 52 | "color-sidebar-link-text": "#2c3e50", 53 | "color-sidebar-link-text--top-level": "#2c3e50", 54 | "color-sidebar-search-background": "#ffffff", 55 | "color-sidebar-search-border": "#e0e0e0", 56 | "color-sidebar-search-foreground": "#2c3e50", 57 | "color-admonition-background": "#f5f5f5", 58 | "color-api-background": "#f5f5f5", 59 | "color-api-background-hover": "#eeeeee", 60 | "color-highlight-on-target": "rgba(202, 14, 76, 0.1)", 61 | "color-inline-code-background": "rgba(202, 14, 76, 0.08)", 62 | "color-inline-code-text": "#CA0E4C", 63 | }, 64 | "dark_css_variables": { 65 | "color-brand-primary": "#CA0E4C", 66 | "color-brand-content": "#CA0E4C", 67 | "color-foreground-primary": "#e0e0e0", 68 | "color-foreground-secondary": "#b0b0b0", 69 | "color-foreground-muted": "#909090", 70 | "color-foreground-border": "#2a2a2a", 71 | "color-background-primary": "#0a0a0a", 72 | "color-background-secondary": "#171717", 73 | "color-background-hover": "#1a1a1a", 74 | "color-background-border": "#2a2a2a", 75 | "color-sidebar-background": "#0f0f0f", 76 | "color-sidebar-background-border": "#2a2a2a", 77 | "color-sidebar-brand-text": "#e0e0e0", 78 | "color-sidebar-caption-text": "#b0b0b0", 79 | "color-sidebar-link-text": "#cccccc", 80 | "color-sidebar-link-text--top-level": "#e0e0e0", 81 | "color-sidebar-search-background": "#1a1a1a", 82 | "color-sidebar-search-border": "#2a2a2a", 83 | "color-sidebar-search-foreground": "#e0e0e0", 84 | "color-admonition-background": "#1a1a1a", 85 | "color-api-background": "#1a1a1a", 86 | "color-api-background-hover": "#262626", 87 | "color-highlight-on-target": "rgba(202, 14, 76, 0.15)", 88 | "color-inline-code-background": "rgba(202, 14, 76, 0.15)", 89 | "color-inline-code-text": "#ff7a9a", 90 | }, 91 | "sidebar_hide_name": False, 92 | "navigation_with_keys": True, 93 | } 94 | 95 | # Add custom footer 96 | html_context = { 97 | "default_mode": "auto", # Let the user's browser preference decide 98 | } 99 | 100 | # Customize source link text 101 | html_copy_source = True 102 | html_show_sourcelink = True 103 | html_sourcelink_suffix = "" 104 | 105 | # Add custom favicon if available 106 | # html_favicon = "_static/favicon.ico" 107 | 108 | # Set custom logo for the top left 109 | # html_logo = "_static/polymathic_logo.png" 110 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | 
*.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Ruff 155 | .ruff_cache 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | notebooks 176 | data 177 | old_impl 178 | -------------------------------------------------------------------------------- /tests/codecs/test_codec_manager.py: -------------------------------------------------------------------------------- 1 | """Test the CodecManager class.""" 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | import torch 7 | 8 | from aion.codecs.manager import CodecManager, ModalityTypeError 9 | from aion.modalities import ( 10 | DESISpectrum, 11 | LegacySurveyFluxG, 12 | LegacySurveyImage, 13 | LegacySurveyShapeE1, 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def manager(): 19 | """Create a CodecManager instance.""" 20 | manager = CodecManager(device="cpu") 21 | yield manager 22 | manager._load_codec.cache_clear() 23 | 24 | 25 | def test_encode_decode_image(manager: CodecManager, data_dir: Path): 26 | """Test encoding and decoding Image modality.""" 27 | # Load test data 28 | input_batch_dict = torch.load( 29 | data_dir / "image_codec_input_batch.pt", weights_only=False 30 | ) 31 | 32 | # Create Image modality 33 | image = LegacySurveyImage( 34 | flux=input_batch_dict["image"]["array"][:, 5:], 35 | bands=["DES-G", "DES-R", "DES-I", "DES-Z"], 36 | ) 37 | 38 | # Encode 39 | tokens = manager.encode(image) 40 | assert "tok_image" in tokens 41 | assert tokens["tok_image"].shape[0] == image.flux.shape[0] 42 | 43 | # Decode using modality type 44 | decoded_image = manager.decode( 45 | tokens, LegacySurveyImage, bands=["DES-G", "DES-R", "DES-I", "DES-Z"] 46 | ) 47 | assert isinstance(decoded_image, LegacySurveyImage) 48 | assert decoded_image.flux.shape == image.flux.shape 49 | 50 | 51 | def test_encode_decode_spectrum(manager: CodecManager, data_dir: Path): 52 | """Test encoding and decoding Spectrum modality.""" 53 | # Load test data 54 | input_batch = torch.load(data_dir / "SPECTRUM_input_batch.pt", weights_only=False)[ 55 | "spectrum" 56 | ] 57 | 58 | # Create Spectrum modality 59 | spectrum = DESISpectrum( 60 | flux=input_batch["flux"], 61 | ivar=input_batch["ivar"], 62 | mask=input_batch["mask"], 63 | wavelength=input_batch["lambda"], 64 | ) 65 | 66 | # Encode 67 | tokens = manager.encode(spectrum) 68 | assert "tok_spectrum_desi" in tokens 69 | 70 | # Decode 71 | decoded_spectrum = manager.decode(tokens, DESISpectrum) 72 | assert 
isinstance(decoded_spectrum, DESISpectrum) 73 | assert decoded_spectrum.flux.shape[0] == spectrum.flux.shape[0] 74 | # Spectra are returned with a fixed length 75 | assert decoded_spectrum.flux.shape[1] >= spectrum.flux.shape[1] 76 | 77 | 78 | def test_codec_caching(manager: CodecManager): 79 | """Test that codecs are properly cached and reused.""" 80 | # Create two modalities that use the same codec type 81 | flux_g1 = LegacySurveyFluxG(value=torch.randn(4, 1)) 82 | flux_g2 = LegacySurveyFluxG(value=torch.randn(4, 1)) 83 | 84 | # Encode both 85 | manager.encode(flux_g1) 86 | manager.encode(flux_g2) 87 | 88 | # Check that only one codec was loaded 89 | assert manager._load_codec.cache_info().hits == 1 90 | 91 | # Check that the same codec instance is used 92 | codec1 = manager._load_codec(LegacySurveyFluxG) 93 | codec2 = manager._load_codec(LegacySurveyFluxG) 94 | assert codec1 is codec2 95 | 96 | 97 | def test_error_handling(manager: CodecManager): 98 | """Test error handling in CodecManager.""" 99 | 100 | # Test with invalid modality type 101 | class InvalidModality: 102 | pass 103 | 104 | with pytest.raises(ModalityTypeError): 105 | manager._load_codec(InvalidModality) 106 | 107 | 108 | @pytest.mark.parametrize("batch_size", [1, 4, 16]) 109 | def test_different_batch_sizes(manager: CodecManager, batch_size: int): 110 | """Test that CodecManager handles different batch sizes correctly.""" 111 | # Create modalities with different batch sizes 112 | flux_g = LegacySurveyFluxG(value=torch.randn(batch_size, 1)) 113 | shape_e1 = LegacySurveyShapeE1(value=torch.randn(batch_size, 1)) 114 | 115 | # Encode 116 | tokens = manager.encode(flux_g, shape_e1) 117 | 118 | # Check batch sizes 119 | assert tokens["tok_flux_g"].shape[0] == batch_size 120 | assert tokens["tok_shape_e1"].shape[0] == batch_size 121 | 122 | # Decode and verify 123 | decoded_flux = manager.decode(tokens, LegacySurveyFluxG) 124 | decoded_shape = manager.decode(tokens, LegacySurveyShapeE1) 125 | 126 | assert decoded_flux.value.shape[0] == batch_size 127 | assert decoded_shape.value.shape[0] == batch_size 128 | -------------------------------------------------------------------------------- /aion/fourm/generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 EPFL and Apple Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | import numpy as np 15 | import math 16 | 17 | 18 | def sample_to_batch(mod_dict, device, domains): 19 | mod_dict = { 20 | modality: { 21 | k: v.unsqueeze(0).to(device, non_blocking=True) for k, v in d.items() 22 | } 23 | for modality, d in mod_dict.items() 24 | if modality in domains 25 | } 26 | 27 | return mod_dict 28 | 29 | 30 | def unbatch(tensor): 31 | return tensor.detach().squeeze(0).cpu() 32 | 33 | 34 | def batch_to_sample(mod_dict, domains): 35 | mod_dict = { 36 | modality: {k: unbatch(v) for k, v in d.items()} 37 | for modality, d in mod_dict.items() 38 | if modality in domains 39 | } 40 | 41 | return mod_dict 42 | 43 | 44 | def batch_to_device(mod_dict, device, domains): 45 | mod_dict = { 46 | modality: {k: v.to(device, non_blocking=True) for k, v in d.items()} 47 | for modality, d in mod_dict.items() 48 | if modality in domains 49 | } 50 | 51 | return mod_dict 52 | 53 | 54 | def cosine_schedule(num_steps, total_tokens): 55 | iters = np.arange(num_steps) 56 | base_value = 1 57 | final_value = 0 58 | schedule = np.array( 59 | [ 60 | final_value 61 | + 0.5 62 | * (base_value - final_value) 63 | * (1 + math.cos(math.pi * i / (len(iters)))) 64 | for i in iters 65 | ] 66 | ) 67 | schedule_tokens = [round(total_tokens * i) for i in (schedule[:-1] - schedule[1:])] 68 | schedule_tokens.append(total_tokens - sum(schedule_tokens)) 69 | return np.array(schedule_tokens) 70 | 71 | 72 | def linear_schedule(num_steps, total_tokens): 73 | schedule = np.linspace(0, total_tokens, num_steps + 1, dtype=int) 74 | schedule_tokens = np.diff(schedule)[::-1] 75 | schedule_tokens.sort() # Sorts the array in ascending order. 76 | schedule_tokens = schedule_tokens[::-1] # Reverses the array to descending order. 77 | return np.trim_zeros(schedule_tokens, "b") # Trims trailing zeros. 
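# Illustrative check (editorial annotation, not part of the original 4M code): both
# schedules return per-step token counts that sum exactly to `total_tokens`, which a
# MaskGIT-style sampler then consumes one entry per decoding step.
if __name__ == "__main__":
    for schedule_fn in (cosine_schedule, linear_schedule):
        steps = schedule_fn(num_steps=8, total_tokens=256)
        assert steps.sum() == 256  # every token is eventually unmasked
        print(schedule_fn.__name__, steps)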
78 | 79 | 80 | def continue_schedule(schedule, num_current_tokens): 81 | schedule_cumsum = np.cumsum(schedule) 82 | keep_mask = schedule_cumsum > num_current_tokens 83 | diff = schedule_cumsum[keep_mask][0] - num_current_tokens 84 | new_schedule = schedule[keep_mask] 85 | new_schedule[0] = diff 86 | return new_schedule 87 | 88 | 89 | def decreasing_temp_schedule(max, min, token_schedule): 90 | schedule_cumsum = np.cumsum(token_schedule) / np.sum(token_schedule) 91 | temp_schedule = np.array([min + (max - min) * (1 - s) for s in schedule_cumsum]) 92 | return temp_schedule 93 | 94 | 95 | def onex_temp_schedule( 96 | max_t, min_t, token_schedule, power=0.5, min_linspace=1, max_linspace=100 97 | ): 98 | """Arbitrary temperature schedule for one over x""" 99 | x = np.linspace(min_linspace, max_linspace, num=sum(token_schedule)) 100 | y = 1 / (x**power) 101 | y = y - min(y) 102 | y = y / max(y) 103 | unscaled_schedule = y 104 | schedule_cumsum = np.cumsum(token_schedule) / np.sum(token_schedule) 105 | unscaled_schedule = [ 106 | (1 - cs) * us for us, cs in zip(unscaled_schedule, schedule_cumsum) 107 | ] 108 | 109 | temp_schedule = np.array( 110 | [min_t + (max_t - min_t) * s for s in unscaled_schedule] 111 | ).clip(min=1e-9) 112 | return temp_schedule 113 | 114 | 115 | def linear_temp_schedule(temp, token_schedule): 116 | """Linearly decays the temperature in proportion to the cumulative token schedule.""" 117 | return np.concatenate( 118 | [ 119 | np.array([temp * 1.0]), 120 | ( 121 | temp 122 | * (token_schedule.sum() - token_schedule.cumsum()) 123 | / token_schedule.sum() 124 | )[:-1], 125 | ] 126 | ).clip(min=1e-9) 127 | -------------------------------------------------------------------------------- /aion/fourm/text_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | from typing import Optional, Union, List 5 | 6 | from tokenizers import AddedToken, decoders, trainers 7 | from tokenizers import Tokenizer 8 | from tokenizers.models import WordPiece 9 | from tokenizers.normalizers import BertNormalizer 10 | from tokenizers.pre_tokenizers import BertPreTokenizer 11 | 12 | 13 | def generate_sentinel_tokens(num=100, start_id=0): 14 | tokens = [ 15 | AddedToken(content=f"[S_{i}]", single_word=True, normalized=False) 16 | for i in range(start_id, num + start_id) 17 | ] 18 | 19 | return tokens 20 | 21 | 22 | def generate_coord_tokens(bins=1000): 23 | tokens = [] 24 | coords_str = ["xmin={}", "ymin={}", "xmax={}", "ymax={}"] 25 | 26 | for s in coords_str: 27 | for i in range(bins): 28 | tokens.append( 29 | AddedToken(content=s.format(i), single_word=True, normalized=False) 30 | ) 31 | 32 | return tokens 33 | 34 | 35 | def generate_object_class_tokens(dataset="coco"): 36 | with open(os.path.join(os.path.dirname(__file__), "object_classes.json")) as f: 37 | object_classes = json.load(f)[dataset] 38 | 39 | tokens = [ 40 | AddedToken(content=class_name, single_word=True, normalized=True) 41 | for class_name in object_classes 42 | ] 43 | 44 | return tokens 45 | 46 | 47 | def train_unified_wordpiece_tokenizer( 48 | files, 49 | vocab_size, 50 | sentinel_tokens: List[Union[str, AddedToken]] = None, 51 | coord_tokens: List[Union[str, AddedToken]] = None, 52 | object_class_tokens: List[Union[str, AddedToken]] = None, 53 | unk_token: Union[str, AddedToken] = "[UNK]", 54 | pad_token: Union[str, AddedToken] = "[PAD]", 55 | sos_token: Union[str, AddedToken] = "[SOS]", 56 | eos_token: Union[str,
AddedToken] = "[EOS]", 57 | additional_special_tokens: List[Union[str, AddedToken]] = None, 58 | min_frequency=0, 59 | clean_text: bool = True, 60 | handle_chinese_chars: bool = True, 61 | strip_accents: Optional[bool] = None, 62 | lowercase: bool = True, 63 | wordpieces_prefix: str = "##", 64 | show_progress=True, 65 | ): 66 | tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token))) 67 | 68 | tokenizer.normalizer = BertNormalizer( 69 | clean_text=clean_text, 70 | handle_chinese_chars=handle_chinese_chars, 71 | strip_accents=strip_accents, 72 | lowercase=lowercase, 73 | ) 74 | tokenizer.pre_tokenizer = BertPreTokenizer() 75 | tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix) 76 | 77 | special_tokens = [] 78 | special_tokens.append(pad_token) 79 | special_tokens.append(unk_token) 80 | special_tokens.append(sos_token) 81 | special_tokens.append(eos_token) 82 | 83 | if sentinel_tokens is not None: 84 | special_tokens.extend(sentinel_tokens) 85 | if coord_tokens is not None: 86 | special_tokens.extend(coord_tokens) 87 | if object_class_tokens is not None: 88 | special_tokens.extend(object_class_tokens) 89 | if additional_special_tokens is not None: 90 | special_tokens.extend(additional_special_tokens) 91 | 92 | trainer = trainers.WordPieceTrainer( 93 | vocab_size=vocab_size, 94 | min_frequency=min_frequency, 95 | show_progress=show_progress, 96 | continuing_subword_prefix=wordpieces_prefix, 97 | special_tokens=special_tokens, 98 | ) 99 | 100 | if isinstance(files, str): 101 | files = [files] 102 | 103 | tokenizer.train(files, trainer=trainer) 104 | 105 | return tokenizer 106 | 107 | 108 | def get_sentinel_to_id_mapping(tokenizer, match_str="[S_"): 109 | sentinel_tokens = { 110 | k: v for k, v in tokenizer.get_vocab().items() if k.startswith(match_str) 111 | } 112 | # Extract the sentinel token id, the id is of the form "[S_0]", "[S_1]", etc. 113 | sentinel_to_id = { 114 | int(k.split("_")[1][:-1]): v 115 | for k, v in sorted(sentinel_tokens.items(), key=lambda x: x[1]) 116 | } 117 | return sentinel_to_id 118 | 119 | 120 | def split_by_sentinel(seq_ids, sentinel_ids): 121 | splits = defaultdict(list) 122 | cur_sentinel = None 123 | for token in seq_ids: 124 | if token in sentinel_ids: 125 | cur_sentinel = token 126 | else: 127 | splits[cur_sentinel].append(token) 128 | 129 | return splits 130 | 131 | 132 | def merge_span_masking(input_seq, decoder_seq, sentinel_ids): 133 | decoder_splits = split_by_sentinel(decoder_seq, sentinel_ids) 134 | out_seq = [] 135 | for token in input_seq: 136 | if token in sentinel_ids: 137 | out_seq.extend(decoder_splits[token]) 138 | else: 139 | out_seq.append(token) 140 | return out_seq 141 | -------------------------------------------------------------------------------- /aion/codecs/modules/spectrum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | 4 | 5 | def interp1d( 6 | x: Float[torch.Tensor, " b n"], 7 | y: Float[torch.Tensor, " b n"], 8 | xnew: Float[torch.Tensor, " b m"], 9 | mask_value: float | None = 0.0, 10 | ) -> Float[torch.Tensor, " b m"]: 11 | """Linear interpolation of a 1-D tensor using torch.searchsorted. 12 | Assumes that x and xnew are sorted in increasing order. 13 | 14 | Args: 15 | x: The x-coordinates of the data points, shape [batch, N]. 16 | y: The y-coordinates of the data points, shape [batch, N]. 17 | xnew: The x-coordinates of the interpolated points, shape [batch, M]. 
18 | mask_value: The value to use for xnew outside the range of x. 19 | Returns: 20 | The y-coordinates of the interpolated points, shape [batch, M]. 21 | """ 22 | # Find the indices where xnew should be inserted in sorted_x 23 | # Given a point xnew[i] in xnew, return j where x[j] is the nearest point in x such that 24 | # x[j] < xnew[i], except if the nearest point in x has x[j] = xnew[i] then return j - 1. 25 | indices = torch.searchsorted(x, xnew) - 1 26 | 27 | # We can define a local linear approx of the grad in each interval 28 | # between two points in x, and we would like to use this to interpolate 29 | # y at those points in xnew which lie inside the range of x, otherwise 30 | # interpolated_y is masked for points in xnew outside the range of x. 31 | # There are len(x) - 1 such intervals between points in x, having indices 32 | # ranging between 0 and len(x) - 2. Points with xnew < min(x) will be 33 | # assigned indices of -1 and points with xnew > max(x) will be assigned 34 | # indices equal to len(x). These are not valid segment indices, but we can 35 | # clamp them to 0 and len(x) - 2 respectively to avoid breaking the 36 | # calculation of the slope variable. The nonsense values we obtain outside 37 | # the range of x will be discarded when masking. 38 | indices = torch.clamp(indices, 0, x.shape[1] - 1 - 1) 39 | 40 | slopes = (y[:, :-1] - y[:, 1:]) / (x[:, :-1] - x[:, 1:]) 41 | 42 | # Interpolate the y-coordinates 43 | ynew = torch.gather(y, 1, indices) + ( 44 | xnew - torch.gather(x, 1, indices) 45 | ) * torch.gather(slopes, 1, indices) 46 | 47 | # Mask out the values that are outside the valid range 48 | mask = (xnew < x[..., 0].reshape(-1, 1)) | (xnew > x[..., -1].reshape(-1, 1)) 49 | ynew[mask] = mask_value 50 | 51 | return ynew 52 | 53 | 54 | class LatentSpectralGrid(torch.nn.Module): 55 | def __init__(self, lambda_min: float, resolution: float, num_pixels: int): 56 | """ 57 | Initialize a latent grid to represent spectra from multiple resolutions. 58 | 59 | Args: 60 | lambda_min: The minimum wavelength value, in Angstrom. 61 | resolution: The resolution of the spectra, in Angstrom per pixel. 62 | num_pixels: The number of pixels in the spectra. 63 | 64 | """ 65 | super().__init__() 66 | self.register_buffer("lambda_min", torch.tensor(lambda_min)) 67 | self.register_buffer("resolution", torch.tensor(resolution)) 68 | self.register_buffer("length", torch.tensor(num_pixels)) 69 | self.register_buffer( 70 | "_wavelength", 71 | (torch.arange(0, num_pixels) * resolution + lambda_min).reshape( 72 | 1, num_pixels 73 | ), 74 | ) 75 | 76 | @property 77 | def wavelength(self) -> Float[torch.Tensor, " n"]: 78 | return self._wavelength.squeeze() 79 | 80 | def to_observed( 81 | self, 82 | x_latent: Float[torch.Tensor, " b n"], 83 | wavelength: Float[torch.Tensor, " b m"], 84 | ) -> Float[torch.Tensor, " b m"]: 85 | """Transforms the latent representation to the observed wavelength grid. 86 | 87 | Args: 88 | x_latent: The latent representation, [batch, self.num_pixels]. 89 | wavelength: The observed wavelength grid, [batch, M]. 90 | 91 | Returns: 92 | The transformed representation on the observed wavelength grid. 93 | """ 94 | b = x_latent.shape[0] 95 | return interp1d(self._wavelength.repeat([b, 1]), x_latent, wavelength) 96 | 97 | def to_latent( 98 | self, x_obs: Float[torch.Tensor, "b m"], wavelength: Float[torch.Tensor, "b m"] 99 | ) -> Float[torch.Tensor, "b n"]: 100 | """Transforms the observed representation to the latent wavelength grid. 
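        Latent grid points falling outside the observed wavelength range are
        filled with the ``mask_value`` of :func:`interp1d` (0.0 by default).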
101 | 102 | Args: 103 | x_obs: The observed representation, [batch, M]. 104 | wavelength: The observed wavelength grid, [batch, M]. 105 | 106 | Returns: 107 | The transformed representation on the latent wavelength grid, [batch, self.num_pixels]. 108 | """ 109 | b = x_obs.shape[0] 110 | return interp1d(wavelength, x_obs, self._wavelength.repeat([b, 1])) 111 | -------------------------------------------------------------------------------- /aion/codecs/manager.py: -------------------------------------------------------------------------------- 1 | """Codec Manager for AION. 2 | 3 | Handles dynamic loading and management of codecs for different modalities. 4 | """ 5 | 6 | from dataclasses import asdict 7 | from functools import lru_cache 8 | 9 | import torch 10 | 11 | from aion.codecs.base import Codec 12 | from aion.codecs.config import MODALITY_CODEC_MAPPING, CodecType, HF_REPO_ID 13 | from aion.modalities import Modality 14 | 15 | 16 | class ModalityTypeError(TypeError): 17 | """Error raised when a modality type is not supported.""" 18 | 19 | 20 | class TokenKeyError(ValueError): 21 | """Error raised when a token key is not found in the tokens dictionary.""" 22 | 23 | 24 | class CodecManager: 25 | """Manager for loading and using codecs for different modalities.""" 26 | 27 | def __init__(self, device: str | torch.device = "cpu"): 28 | """Initialize the codec manager. 29 | 30 | Args: 31 | device: Device to load codecs on 32 | (e.g. "cpu", "cuda", or a torch.device instance) 33 | """ 34 | self.device = device 35 | 36 | @staticmethod 37 | @lru_cache 38 | def _load_codec_from_hf( 39 | codec_class: CodecType, modality_type: type[Modality] 40 | ) -> Codec: 41 | """Load a codec from HuggingFace. 42 | Although the HF download is already cached, 43 | this method is also cached to avoid reloading the same codec. 44 | 45 | Args: 46 | codec_class: The class of the codec to load 47 | modality_type: The modality type the codec is associated with 48 | 49 | Returns: 50 | The loaded codec 51 | """ 52 | 53 | codec = codec_class.from_pretrained(HF_REPO_ID, modality=modality_type) 54 | codec = codec.eval() 55 | return codec 56 | 57 | @lru_cache 58 | def _load_codec(self, modality_type: type[Modality]) -> Codec: 59 | """Load a codec for the given modality type.""" 60 | # Look up the codec class in MODALITY_CODEC_MAPPING 61 | if modality_type in MODALITY_CODEC_MAPPING: 62 | codec_class = MODALITY_CODEC_MAPPING[modality_type] 63 | else: 64 | raise ModalityTypeError( 65 | f"No codec configuration found for modality type: {modality_type.__name__}" 66 | ) 67 | 68 | codec = self._load_codec_from_hf(codec_class, modality_type) 69 | 70 | return codec 71 | 72 | @torch.no_grad() 73 | def encode(self, *modalities: Modality) -> dict[str, torch.Tensor]: 74 | """Encode multiple modalities.
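
        Codecs are resolved from each modality's concrete type and cached via
        ``lru_cache``, so repeated calls reuse the same codec instance.

        Example (illustrative sketch; assumes the pretrained codec weights are
        reachable on the HuggingFace Hub)::

            manager = CodecManager(device="cpu")
            tokens = manager.encode(LegacySurveyFluxG(value=torch.randn(4, 1)))
            # tokens == {"tok_flux_g": <token tensor with batch dimension 4>}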
75 | 76 | Args: 77 | *modalities: Variable number of modality instances to encode 78 | 79 | Returns: 80 | Dictionary mapping token keys to encoded tensors 81 | """ 82 | tokens = {} 83 | 84 | for modality in modalities: 85 | if not isinstance(modality, Modality): 86 | raise ModalityTypeError( 87 | f"Modality {type(modality).__name__} does not have a token_key attribute" 88 | ) 89 | # Get the appropriate codec 90 | codec = self._load_codec(type(modality)) 91 | codec = codec.to(self.device) 92 | 93 | # Tokenize the modality 94 | tokenized = codec.encode(modality) 95 | 96 | tokens[modality.token_key] = tokenized 97 | 98 | return tokens 99 | 100 | @torch.no_grad() 101 | def decode( 102 | self, 103 | tokens: dict[str, torch.Tensor], 104 | modality_type: type[Modality], 105 | **metadata, 106 | ) -> Modality: 107 | """Decode tokens back to a modality. 108 | 109 | Args: 110 | tokens: Dictionary mapping token keys to tokenized tensors 111 | modality_type: The modality type (e.g., DESISpectrum) to decode into 112 | **metadata: Additional metadata required by the specific codec 113 | (e.g., wavelength for spectra, bands for images) 114 | 115 | Returns: 116 | Decoded modality instance 117 | """ 118 | if not issubclass(modality_type, Modality): 119 | raise ModalityTypeError( 120 | f"Modality type {modality_type} does not have a token_key attribute" 121 | ) 122 | 123 | token_key = modality_type.token_key 124 | if token_key not in tokens: 125 | raise TokenKeyError( 126 | f"Token key '{token_key}' for modality {modality_type} not found in tokens dictionary" 127 | ) 128 | 129 | # Get the appropriate codec 130 | codec = self._load_codec(modality_type) 131 | codec = codec.to(self.device) 132 | 133 | # Decode using the codec with any provided metadata 134 | decoded_modality = codec.decode(tokens[token_key], **metadata) 135 | 136 | # Cast decoded modality to the correct type 137 | decoded_modality = modality_type(**asdict(decoded_modality)) 138 | 139 | return decoded_modality 140 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | This page provides comprehensive API documentation for all AION components, automatically generated from the source code. 5 | 6 | .. currentmodule:: aion 7 | 8 | Main Model 9 | ---------- 10 | 11 | .. automodule:: aion.model 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: 15 | 16 | Modalities 17 | ---------- 18 | 19 | The modality system defines data structures for all 39 astronomical data types supported by AION. 20 | 21 | Base Classes 22 | ~~~~~~~~~~~~ 23 | 24 | .. automodule:: aion.modalities 25 | :members: Modality, Image, Spectrum, Scalar 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | Image Modalities 30 | ~~~~~~~~~~~~~~~~ 31 | 32 | .. automodule:: aion.modalities 33 | :members: LegacySurveyImage, HSCImage 34 | :undoc-members: 35 | :show-inheritance: 36 | 37 | Spectrum Modalities 38 | ~~~~~~~~~~~~~~~~~~~ 39 | 40 | .. automodule:: aion.modalities 41 | :members: DESISpectrum, SDSSSpectrum 42 | :undoc-members: 43 | :show-inheritance: 44 | 45 | Catalog Modalities 46 | ~~~~~~~~~~~~~~~~~~ 47 | 48 | .. automodule:: aion.modalities 49 | :members: LegacySurveyCatalog, LegacySurveySegmentationMap 50 | :undoc-members: 51 | :show-inheritance: 52 | 53 | Scalar Modalities 54 | ~~~~~~~~~~~~~~~~~ 55 | 56 | Legacy Survey Scalars 57 | ^^^^^^^^^^^^^^^^^^^^^^ 58 | 59 | .. 
automodule:: aion.modalities 60 | :members: LegacySurveyFluxG, LegacySurveyFluxR, LegacySurveyFluxI, LegacySurveyFluxZ, LegacySurveyFluxW1, LegacySurveyFluxW2, LegacySurveyFluxW3, LegacySurveyFluxW4, LegacySurveyShapeR, LegacySurveyShapeE1, LegacySurveyShapeE2, LegacySurveyEBV 61 | :undoc-members: 62 | :show-inheritance: 63 | 64 | HSC Scalars 65 | ~~~~~~~~~~~ 66 | 67 | .. automodule:: aion.modalities 68 | :members: HSCAG, HSCAR, HSCAI, HSCAZ, HSCAY, HSCMagG, HSCMagR, HSCMagI, HSCMagZ, HSCMagY, HSCShape11, HSCShape22, HSCShape12 69 | :undoc-members: 70 | :show-inheritance: 71 | 72 | Gaia Scalars 73 | ~~~~~~~~~~~~ 74 | 75 | .. automodule:: aion.modalities 76 | :members: GaiaFluxG, GaiaFluxBp, GaiaFluxRp, GaiaParallax, GaiaXpBp, GaiaXpRp 77 | :undoc-members: 78 | :show-inheritance: 79 | 80 | Coordinate Scalars 81 | ~~~~~~~~~~~~~~~~~~ 82 | 83 | .. automodule:: aion.modalities 84 | :members: Ra, Dec, Z 85 | :undoc-members: 86 | :show-inheritance: 87 | 88 | Utility Types 89 | ~~~~~~~~~~~~~ 90 | 91 | .. automodule:: aion.modalities 92 | :members: ScalarModalities, ModalityType 93 | :undoc-members: 94 | :show-inheritance: 95 | 96 | Codec System 97 | ------------ 98 | 99 | The codec system handles tokenization of different modality types. 100 | 101 | Core Codec Classes 102 | ~~~~~~~~~~~~~~~~~~ 103 | 104 | .. automodule:: aion.codecs.manager 105 | :members: 106 | :undoc-members: 107 | :show-inheritance: 108 | 109 | .. automodule:: aion.codecs.base 110 | :members: 111 | :undoc-members: 112 | :show-inheritance: 113 | 114 | Codec Implementations 115 | ~~~~~~~~~~~~~~~~~~~~~ 116 | 117 | .. automodule:: aion.codecs.image 118 | :members: 119 | :undoc-members: 120 | :show-inheritance: 121 | 122 | .. automodule:: aion.codecs.spectrum 123 | :members: 124 | :undoc-members: 125 | :show-inheritance: 126 | 127 | .. automodule:: aion.codecs.catalog 128 | :members: 129 | :undoc-members: 130 | :show-inheritance: 131 | 132 | .. automodule:: aion.codecs.scalar_field 133 | :members: 134 | :undoc-members: 135 | :show-inheritance: 136 | 137 | .. automodule:: aion.codecs.scalar 138 | :members: 139 | :undoc-members: 140 | :show-inheritance: 141 | 142 | Quantizers 143 | ~~~~~~~~~~ 144 | 145 | .. automodule:: aion.codecs.quantizers 146 | :members: 147 | :undoc-members: 148 | :show-inheritance: 149 | 150 | .. automodule:: aion.codecs.quantizers.scalar 151 | :members: 152 | :undoc-members: 153 | :show-inheritance: 154 | 155 | 4M Transformer 156 | -------------- 157 | 158 | Core transformer architecture and components. 159 | 160 | Main Transformer 161 | ~~~~~~~~~~~~~~~~ 162 | 163 | .. automodule:: aion.fourm.fm 164 | :members: 165 | :undoc-members: 166 | :show-inheritance: 167 | 168 | Embedding Layers 169 | ~~~~~~~~~~~~~~~~ 170 | 171 | .. automodule:: aion.fourm.encoder_embeddings 172 | :members: 173 | :undoc-members: 174 | :show-inheritance: 175 | 176 | .. automodule:: aion.fourm.decoder_embeddings 177 | :members: 178 | :undoc-members: 179 | :show-inheritance: 180 | 181 | Transformer Components 182 | ~~~~~~~~~~~~~~~~~~~~~~ 183 | 184 | .. automodule:: aion.fourm.fm_utils 185 | :members: 186 | :undoc-members: 187 | :show-inheritance: 188 | 189 | Generation 190 | ~~~~~~~~~~ 191 | 192 | .. automodule:: aion.fourm.generate 193 | :members: 194 | :undoc-members: 195 | :show-inheritance: 196 | 197 | LoRA Support 198 | ~~~~~~~~~~~~ 199 | 200 | .. automodule:: aion.fourm.lora_utils 201 | :members: 202 | :undoc-members: 203 | :show-inheritance: 204 | 205 | Modality Configuration 206 | ~~~~~~~~~~~~~~~~~~~~~~ 207 | 208 | .. 
automodule:: aion.fourm.modality_info 209 | :members: 210 | :undoc-members: 211 | :show-inheritance: 212 | 213 | .. automodule:: aion.fourm.modality_transforms 214 | :members: 215 | :undoc-members: 216 | :show-inheritance: 217 | 218 | Codec Modules 219 | ------------- 220 | 221 | Specialized neural network modules used in codecs. 222 | 223 | Architecture Components 224 | ~~~~~~~~~~~~~~~~~~~~~~~ 225 | 226 | .. automodule:: aion.codecs.modules.magvit 227 | :members: 228 | :undoc-members: 229 | :show-inheritance: 230 | 231 | .. automodule:: aion.codecs.modules.convnext 232 | :members: 233 | :undoc-members: 234 | :show-inheritance: 235 | 236 | .. automodule:: aion.codecs.modules.convblocks 237 | :members: 238 | :undoc-members: 239 | :show-inheritance: 240 | 241 | Specialized Modules 242 | ~~~~~~~~~~~~~~~~~~~ 243 | 244 | .. automodule:: aion.codecs.modules.spectrum 245 | :members: 246 | :undoc-members: 247 | :show-inheritance: 248 | 249 | .. automodule:: aion.codecs.modules.ema 250 | :members: 251 | :undoc-members: 252 | :show-inheritance: 253 | 254 | .. automodule:: aion.codecs.modules.subsampler 255 | :members: 256 | :undoc-members: 257 | :show-inheritance: 258 | 259 | Configuration and Utilities 260 | ---------------------------- 261 | 262 | .. automodule:: aion.codecs.config 263 | :members: 264 | :undoc-members: 265 | :show-inheritance: 266 | -------------------------------------------------------------------------------- /aion/codecs/modules/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from aion.codecs.modules.utils import LayerNorm, GRN 4 | 5 | 6 | class ConvNextBlock1d(torch.nn.Module): 7 | """ConvNeXtV2 Block. 8 | Modified to 1D from the original 2D implementation from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py 9 | 10 | Args: 11 | dim (int): Number of input channels. 12 | drop_path (float): Stochastic depth rate. Default: 0.0 13 | """ 14 | 15 | def __init__(self, dim: int): 16 | super().__init__() 17 | self.dwconv = torch.nn.Conv1d( 18 | dim, dim, kernel_size=7, padding=3, groups=dim 19 | ) # depthwise conv 20 | self.norm = LayerNorm(dim, eps=1e-6) 21 | self.pwconv1 = torch.nn.Linear( 22 | dim, 4 * dim 23 | ) # pointwise/1x1 convs, implemented with linear layers 24 | self.act = torch.nn.GELU() 25 | self.grn = GRN(4 * dim) 26 | self.pwconv2 = torch.nn.Linear(4 * dim, dim) 27 | 28 | def forward(self, x): 29 | y = self.dwconv(x) 30 | y = y.permute(0, 2, 1) # (B, C, N) -> (B, N, C) 31 | y = self.norm(y) 32 | y = self.pwconv1(y) 33 | y = self.act(y) 34 | y = self.grn(y) 35 | y = self.pwconv2(y) 36 | y = y.permute(0, 2, 1) # (B, N, C) -> (B, C, N) 37 | 38 | y = x + y 39 | return y 40 | 41 | 42 | class ConvNextEncoder1d(torch.nn.Module): 43 | r"""ConvNeXt encoder. 44 | 45 | Modified from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py 46 | 47 | Args: 48 | in_chans : Number of input image channels. Default: 3 49 | depths : Number of blocks at each stage. Default: [3, 3, 9, 3] 50 | dims : Feature dimension at each stage. Default: [96, 192, 384, 768] 51 | drop_path_rate : Stochastic depth rate. Default: 0. 52 | layer_scale_init_value : Init value for Layer Scale. Default: 1e-6. 53 | """ 54 | 55 | def __init__( 56 | self, 57 | in_chans: int = 2, 58 | depths: tuple[int, ...] = (3, 3, 9, 3), 59 | dims: tuple[int, ...] 
= (96, 192, 384, 768), 60 | ): 61 | super().__init__() 62 | assert len(depths) == len(dims), "depths and dims should have the same length" 63 | num_layers = len(depths) 64 | 65 | self.downsample_layers = ( 66 | torch.nn.ModuleList() 67 | ) # stem and 3 intermediate downsampling conv layers 68 | stem = torch.nn.Sequential( 69 | torch.nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4), 70 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), 71 | ) 72 | self.downsample_layers.append(stem) 73 | for i in range(num_layers - 1): 74 | downsample_layer = torch.nn.Sequential( 75 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 76 | torch.nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2), 77 | ) 78 | self.downsample_layers.append(downsample_layer) 79 | 80 | self.stages = torch.nn.ModuleList() 81 | for i in range(num_layers): 82 | stage = torch.nn.Sequential( 83 | *[ 84 | ConvNextBlock1d( 85 | dim=dims[i], 86 | ) 87 | for j in range(depths[i]) 88 | ] 89 | ) 90 | self.stages.append(stage) 91 | 92 | self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") 93 | 94 | self.apply(self._init_weights) 95 | 96 | def _init_weights(self, m): 97 | if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): 98 | torch.nn.init.trunc_normal_(m.weight, std=0.02) 99 | torch.nn.init.constant_(m.bias, 0) 100 | 101 | def forward(self, x): 102 | for ds, st in zip(self.downsample_layers, self.stages): 103 | x = ds(x) 104 | x = st(x) 105 | return self.norm(x) 106 | 107 | 108 | class ConvNextDecoder1d(torch.nn.Module): 109 | r"""ConvNeXt decoder. Essentially a mirrored version of the encoder. 110 | 111 | Args: 112 | in_chans (int): Number of input latent channels. Default: 768 113 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 114 | dims (tuple(int)): Feature dimension at each stage. Default: [384, 192, 96, 2] 115 | Note: the last upsampling stage uses a stride-4 transposed convolution 116 | to mirror the stride-4 stem of the encoder.
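    Example (illustrative sketch; uses the defaults shown above):

        >>> import torch
        >>> decoder = ConvNextDecoder1d()        # 768 latent channels in, 2 out
        >>> z = torch.randn(1, 768, 8)           # [batch, channels, length]
        >>> decoder(z).shape                     # 2x * 2x * 2x * 4x = 32x upsampling
        torch.Size([1, 2, 256])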
117 | """ 118 | 119 | def __init__( 120 | self, 121 | in_chans=768, 122 | depths=[3, 3, 9, 3], 123 | dims=[384, 192, 96, 2], 124 | ): 125 | super().__init__() 126 | assert len(depths) == len(dims), "depths and dims should have the same length" 127 | num_layers = len(depths) 128 | 129 | self.upsample_layers = torch.nn.ModuleList() 130 | 131 | stem = torch.nn.Sequential( 132 | torch.nn.ConvTranspose1d(in_chans, dims[0], kernel_size=2, stride=2), 133 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), 134 | ) 135 | self.upsample_layers.append(stem) 136 | 137 | for i in range(num_layers - 1): 138 | upsample_layer = torch.nn.Sequential( 139 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 140 | torch.nn.ConvTranspose1d( 141 | dims[i], 142 | dims[i + 1], 143 | kernel_size=2 if i < (num_layers - 2) else 4, 144 | stride=2 if i < (num_layers - 2) else 4, 145 | ), 146 | ) 147 | self.upsample_layers.append(upsample_layer) 148 | 149 | self.stages = torch.nn.ModuleList() 150 | for i in range(num_layers): 151 | stage = torch.nn.Sequential( 152 | *[ 153 | ConvNextBlock1d( 154 | dim=dims[i], 155 | ) 156 | for j in range(depths[i]) 157 | ] 158 | ) 159 | self.stages.append(stage) 160 | 161 | self.apply(self._init_weights) 162 | 163 | def _init_weights(self, m): 164 | if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): 165 | torch.nn.init.trunc_normal_(m.weight, std=0.02) 166 | torch.nn.init.constant_(m.bias, 0) 167 | 168 | def forward(self, x): 169 | for us, st in zip(self.upsample_layers, self.stages): 170 | x = us(x) 171 | x = st(x) 172 | return x 173 | -------------------------------------------------------------------------------- /aion/codecs/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from threading import local 3 | from typing import Optional 4 | 5 | from huggingface_hub import hub_mixin 6 | 7 | from aion.codecs.base import Codec 8 | from aion.modalities import Modality 9 | 10 | 11 | ORIGINAL_CONFIG_NAME = hub_mixin.constants.CONFIG_NAME 12 | ORIGINAL_PYTORCH_WEIGHTS_NAME = hub_mixin.constants.PYTORCH_WEIGHTS_NAME 13 | ORIGINAL_SAFETENSORS_SINGLE_FILE = hub_mixin.constants.SAFETENSORS_SINGLE_FILE 14 | 15 | # Thread-local storage for codec context 16 | _thread_local = local() 17 | 18 | 19 | @contextmanager 20 | def _codec_path_context(modality: type[Modality]): 21 | """Thread-safe context manager for temporarily overriding HuggingFace constants. 
22 | 23 | Args: 24 | modality: The modality type to create paths for 25 | 26 | Yields: 27 | None 28 | """ 29 | # Store original values 30 | original_config = hub_mixin.constants.CONFIG_NAME 31 | original_weights = hub_mixin.constants.PYTORCH_WEIGHTS_NAME 32 | original_safetensors = hub_mixin.constants.SAFETENSORS_SINGLE_FILE 33 | 34 | try: 35 | # Set codec-specific paths 36 | hub_mixin.constants.CONFIG_NAME = ( 37 | f"codecs/{modality.name}/{ORIGINAL_CONFIG_NAME}" 38 | ) 39 | hub_mixin.constants.PYTORCH_WEIGHTS_NAME = ( 40 | f"codecs/{modality.name}/{ORIGINAL_PYTORCH_WEIGHTS_NAME}" 41 | ) 42 | hub_mixin.constants.SAFETENSORS_SINGLE_FILE = ( 43 | f"codecs/{modality.name}/{ORIGINAL_SAFETENSORS_SINGLE_FILE}" 44 | ) 45 | yield 46 | finally: 47 | # Always restore original values 48 | hub_mixin.constants.CONFIG_NAME = original_config 49 | hub_mixin.constants.PYTORCH_WEIGHTS_NAME = original_weights 50 | hub_mixin.constants.SAFETENSORS_SINGLE_FILE = original_safetensors 51 | 52 | 53 | def _validate_modality(modality: type[Modality]) -> None: 54 | """Validate that the modality is properly configured. 55 | 56 | Args: 57 | modality: The modality type to validate 58 | 59 | Raises: 60 | ValueError: If the modality is invalid 61 | """ 62 | if not isinstance(modality, type): 63 | raise ValueError(f"Expected modality to be a type, got {type(modality)}") 64 | 65 | if not issubclass(modality, Modality): 66 | raise ValueError(f"Modality {modality} must be a subclass of Modality") 67 | 68 | if not hasattr(modality, "name") or not isinstance(modality.name, str): 69 | raise ValueError( 70 | f"Modality {modality} must have a 'name' class attribute of type str" 71 | ) 72 | 73 | if not modality.name.strip(): 74 | raise ValueError(f"Modality {modality} name cannot be empty") 75 | 76 | 77 | class CodecPytorchHubMixin(hub_mixin.PyTorchModelHubMixin): 78 | """Mixin for PyTorch models that correspond to codecs. 79 | Codecs don't have their own model repos. 80 | Instead they live in the transformer model repo as subfolders. 81 | """ 82 | 83 | @staticmethod 84 | def _validate_codec_modality(codec: type[Codec], modality: type[Modality]): 85 | """Validate that a codec class is compatible with a modality. 86 | 87 | Args: 88 | codec: The codec class to validate 89 | modality: The modality type to validate against 90 | 91 | Raises: 92 | TypeError: If the codec is not a valid codec class or is incompatible with the modality 93 | ValueError: If the modality has no corresponding codec configuration 94 | """ 95 | # Import MODALITY_CODEC_MAPPING here to avoid circular import 96 | from aion.codecs.config import MODALITY_CODEC_MAPPING 97 | 98 | if not issubclass(codec, Codec): 99 | raise TypeError("Only codecs can be loaded using this method.") 100 | if modality not in MODALITY_CODEC_MAPPING: 101 | raise ValueError(f"Modality {modality} has no corresponding codec.") 102 | elif MODALITY_CODEC_MAPPING[modality] != codec: 103 | raise TypeError( 104 | f"Modality {modality} is associated with {MODALITY_CODEC_MAPPING[modality]} codec but {codec} requested." 105 | ) 106 | 107 | @classmethod 108 | def from_pretrained( 109 | cls, 110 | pretrained_model_name_or_path, 111 | modality: type[Modality], 112 | *model_args, 113 | **kwargs, 114 | ): 115 | """Load a codec model from a pretrained model repository. 116 | 117 | Args: 118 | pretrained_model_name_or_path (str): The name or path of the pretrained 119 | model repository. 120 | modality (type[Modality]): The modality type for this codec.
121 | *model_args: Additional positional arguments to pass to the model 122 | constructor. 123 | **kwargs: Additional keyword arguments to pass to the model 124 | constructor. 125 | 126 | Returns: 127 | The loaded codec model. 128 | 129 | Raises: 130 | ValueError: If the class is not a codec subclass or modality is invalid. 131 | """ 132 | # Validate codec-modality compatibility 133 | cls._validate_codec_modality(cls, modality) 134 | 135 | # Validate modality 136 | _validate_modality(modality) 137 | 138 | # Use thread-safe context manager to override paths 139 | with _codec_path_context(modality): 140 | model = super().from_pretrained( 141 | pretrained_model_name_or_path, *model_args, **kwargs 142 | ) 143 | 144 | # Store modality reference on the model instance for later use 145 | model._modality = modality 146 | return model 147 | 148 | def save_pretrained( 149 | self, save_directory, modality: Optional[type[Modality]] = None, *args, **kwargs 150 | ): 151 | """Save the codec model to a pretrained model repository. 152 | 153 | Args: 154 | save_directory (str): The directory to save the model to. 155 | modality (Optional[type[Modality]]): The modality type for this codec. 156 | If not provided, will use the modality stored during from_pretrained. 157 | *args: Additional positional arguments to pass to the save method. 158 | **kwargs: Additional keyword arguments to pass to the save method. 159 | 160 | Raises: 161 | ValueError: If the instance is not a codec or modality cannot be determined. 162 | """ 163 | if not issubclass(self.__class__, Codec): 164 | raise ValueError("Only codec instances can be saved using this method.") 165 | 166 | # Determine modality to use 167 | if modality is not None: 168 | _validate_modality(modality) 169 | target_modality = modality 170 | elif hasattr(self, "_modality"): 171 | target_modality = self._modality 172 | else: 173 | raise ValueError( 174 | "No modality specified. Either provide modality parameter or " 175 | "load the codec using from_pretrained() which stores the modality." 
176 | ) 177 | 178 | # Construct the path to the codec subfolder 179 | codec_path = f"{save_directory}/codecs/{target_modality.name}" 180 | super().save_pretrained(codec_path, *args, **kwargs) 181 | -------------------------------------------------------------------------------- /aion/codecs/modules/magvit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from einops.layers.torch import Rearrange 4 | 5 | 6 | def cast_tuple(t, length=1): 7 | return t if isinstance(t, tuple) else ((t,) * length) 8 | 9 | 10 | class SameConv2d(torch.nn.Module): 11 | def __init__(self, dim_in, dim_out, kernel_size): 12 | super().__init__() 13 | kernel_size = cast_tuple(kernel_size, 2) 14 | padding = [k // 2 for k in kernel_size] 15 | self.conv = torch.nn.Conv2d( 16 | dim_in, dim_out, kernel_size=kernel_size, padding=padding 17 | ) 18 | 19 | def forward(self, x: torch.Tensor): 20 | return self.conv(x) 21 | 22 | 23 | class SqueezeExcite(torch.nn.Module): 24 | # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375) 25 | 26 | def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10): 27 | super().__init__() 28 | dim_out = dim_out if dim_out is not None else dim 29 | 30 | self.to_k = torch.nn.Conv2d(dim, 1, 1) 31 | dim_hidden = max(dim_hidden_min, dim_out // 2) 32 | 33 | self.net = torch.nn.Sequential( 34 | torch.nn.Conv2d(dim, dim_hidden, 1), 35 | torch.nn.LeakyReLU(0.1), 36 | torch.nn.Conv2d(dim_hidden, dim_out, 1), 37 | torch.nn.Sigmoid(), 38 | ) 39 | 40 | torch.nn.init.zeros_(self.net[-2].weight) 41 | torch.nn.init.constant_(self.net[-2].bias, init_bias) 42 | 43 | def forward(self, x): 44 | context = self.to_k(x) 45 | 46 | context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1) 47 | spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)") 48 | 49 | out = torch.einsum("b i n, b c n -> b c i", context, spatial_flattened_input) 50 | out = rearrange(out, "... -> ... 
1") 51 | gates = self.net(out) 52 | 53 | return gates * x 54 | 55 | 56 | class ResidualUnit(torch.nn.Module): 57 | def __init__(self, dim: int, kernel_size: int | tuple[int, int, int]): 58 | super().__init__() 59 | self.net = torch.nn.Sequential( 60 | SameConv2d(dim, dim, kernel_size), 61 | torch.nn.ELU(), 62 | torch.nn.Conv2d(dim, dim, 1), 63 | torch.nn.ELU(), 64 | SqueezeExcite(dim), 65 | ) 66 | 67 | def forward(self, x: torch.Tensor): 68 | return self.net(x) + x 69 | 70 | 71 | class SpatialDownsample2x(torch.nn.Module): 72 | def __init__( 73 | self, 74 | dim: int, 75 | dim_out: int = None, 76 | kernel_size: int = 3, 77 | ): 78 | super().__init__() 79 | dim_out = dim_out if dim_out is not None else dim 80 | self.conv = torch.nn.Conv2d( 81 | dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2 82 | ) 83 | 84 | def forward(self, x: torch.Tensor): 85 | out = self.conv(x) 86 | return out 87 | 88 | 89 | class SpatialUpsample2x(torch.nn.Module): 90 | def __init__(self, dim: int, dim_out: int = None): 91 | super().__init__() 92 | dim_out = dim_out if dim_out is not None else dim 93 | conv = torch.nn.Conv2d(dim, dim_out * 4, 1) 94 | 95 | self.net = torch.nn.Sequential( 96 | conv, 97 | torch.nn.SiLU(), 98 | Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2), 99 | ) 100 | 101 | self.init_conv_(conv) 102 | 103 | def init_conv_(self, conv: torch.nn.Module): 104 | o, i, h, w = conv.weight.shape 105 | conv_weight = torch.empty(o // 4, i, h, w) 106 | torch.nn.init.kaiming_uniform_(conv_weight) 107 | conv_weight = repeat(conv_weight, "o ... -> (o 4) ...") 108 | 109 | conv.weight.data.copy_(conv_weight) 110 | torch.nn.init.zeros_(conv.bias.data) 111 | 112 | def forward(self, x: torch.Tensor): 113 | out = self.net(x) 114 | return out 115 | 116 | 117 | class MagVitAE(torch.nn.Module): 118 | """MagViTAE implementation from Yu, et al. (2024), adapted for Pytorch. 119 | Code borrowed from https://github.com/lucidrains/magvit2-pytorch, and adapted for images. 
120 | """ 121 | 122 | def __init__( 123 | self, 124 | n_bands: int = 3, 125 | hidden_dims: int = 512, 126 | residual_conv_kernel_size: int = 3, 127 | n_compressions: int = 2, 128 | num_consecutive: int = 2, 129 | ): 130 | super().__init__() 131 | 132 | self.encoder_layers = torch.nn.ModuleList([]) 133 | self.decoder_layers = torch.nn.ModuleList([]) 134 | init_dim = int(hidden_dims / 2**n_compressions) 135 | dim = init_dim 136 | 137 | self.conv_in = SameConv2d(n_bands, init_dim, 7) 138 | self.conv_out = SameConv2d(init_dim, n_bands, 3) 139 | 140 | # Residual layers 141 | encoder_layer = ResidualUnit(dim, residual_conv_kernel_size) 142 | decoder_layer = ResidualUnit(dim, residual_conv_kernel_size) 143 | self.encoder_layers.append(encoder_layer) 144 | self.decoder_layers.insert(0, decoder_layer) 145 | 146 | # Compressions 147 | for i in range(n_compressions): 148 | dim_out = dim * 2 149 | encoder_layer = SpatialDownsample2x(dim, dim_out) 150 | decoder_layer = SpatialUpsample2x(dim_out, dim) 151 | self.encoder_layers.append(encoder_layer) 152 | self.decoder_layers.insert(0, decoder_layer) 153 | dim = dim_out 154 | 155 | # Consecutive residual layers 156 | encoder_layer = torch.nn.Sequential( 157 | *[ 158 | ResidualUnit(dim, residual_conv_kernel_size) 159 | for _ in range(num_consecutive) 160 | ] 161 | ) 162 | decoder_layer = torch.nn.Sequential( 163 | *[ 164 | ResidualUnit(dim, residual_conv_kernel_size) 165 | for _ in range(num_consecutive) 166 | ] 167 | ) 168 | self.encoder_layers.append(encoder_layer) 169 | self.decoder_layers.insert(0, decoder_layer) 170 | 171 | # Add a final non-compress layer 172 | dim_out = dim 173 | encoder_layer = SameConv2d(dim, dim_out, 7) 174 | decoder_layer = SameConv2d(dim_out, dim, 3) 175 | self.encoder_layers.append(encoder_layer) 176 | self.decoder_layers.insert(0, decoder_layer) 177 | dim = dim_out 178 | 179 | # Consecutive residual layers 180 | encoder_layer = torch.nn.Sequential( 181 | *[ 182 | ResidualUnit(dim, residual_conv_kernel_size) 183 | for _ in range(num_consecutive) 184 | ] 185 | ) 186 | decoder_layer = torch.nn.Sequential( 187 | *[ 188 | ResidualUnit(dim, residual_conv_kernel_size) 189 | for _ in range(num_consecutive) 190 | ] 191 | ) 192 | self.encoder_layers.append(encoder_layer) 193 | self.decoder_layers.insert(0, decoder_layer) 194 | 195 | # add a final norm just before quantization layer 196 | self.encoder_layers.append( 197 | torch.nn.Sequential( 198 | Rearrange("b c ... -> b ... c"), 199 | torch.nn.LayerNorm(dim), 200 | Rearrange("b ... 
c -> b c ..."), 201 | ) 202 | ) 203 | 204 | def encode(self, x: torch.Tensor): 205 | x = self.conv_in(x) 206 | for layer in self.encoder_layers: 207 | x = layer(x) 208 | return x 209 | 210 | def decode(self, x: torch.Tensor): 211 | for layer in self.decoder_layers: 212 | x = layer(x) 213 | x = self.conv_out(x) 214 | return x 215 | -------------------------------------------------------------------------------- /aion/codecs/image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | from typing import Type, Optional, List 5 | 6 | from aion.modalities import Image 7 | from aion.codecs.modules.magvit import MagVitAE 8 | from aion.codecs.modules.subsampler import SubsampledLinear 9 | from aion.codecs.quantizers import FiniteScalarQuantizer, Quantizer 10 | from aion.codecs.base import Codec 11 | from aion.codecs.preprocessing.image import ( 12 | ImagePadder, 13 | CenterCrop, 14 | RescaleToLegacySurvey, 15 | Clamp, 16 | ) 17 | from aion.codecs.preprocessing.band_to_index import BAND_TO_INDEX 18 | from aion.codecs.utils import CodecPytorchHubMixin 19 | 20 | 21 | class AutoencoderImageCodec(Codec): 22 | """Meta-class for autoencoder codecs for images, does not actually contain a network.""" 23 | 24 | def __init__( 25 | self, 26 | quantizer: Quantizer, 27 | encoder: torch.nn.Module, 28 | decoder: torch.nn.Module, 29 | hidden_dims: int = 64, 30 | embedding_dim: int = 5, 31 | multisurvey_projection_dims: int = 54, 32 | range_compression_factor: float = 0.01, 33 | mult_factor: float = 10.0, 34 | ): 35 | super().__init__() 36 | self._quantizer = quantizer 37 | self.range_compression_factor = range_compression_factor 38 | self.mult_factor = mult_factor 39 | self.encoder = encoder 40 | self.decoder = decoder 41 | 42 | # Preprocessing 43 | self.clamp = Clamp() 44 | self.center_crop = CenterCrop(crop_size=96) 45 | self.rescaler = RescaleToLegacySurvey() 46 | 47 | # Handle multi-survey projection 48 | self.image_padder = ImagePadder() 49 | self.subsample_in = SubsampledLinear( 50 | dim_in=self.image_padder.nbands, 51 | dim_out=multisurvey_projection_dims, 52 | subsample_in=True, 53 | ) 54 | self.subsample_out = SubsampledLinear( 55 | dim_in=multisurvey_projection_dims, 56 | dim_out=self.image_padder.nbands, 57 | subsample_in=False, 58 | ) 59 | # Go down to size of levels 60 | self.pre_quant_proj = torch.nn.Conv2d( 61 | hidden_dims, embedding_dim, kernel_size=1, stride=1, padding=0 62 | ) 63 | 64 | # Go back to the original size 65 | self.post_quant_proj = torch.nn.Conv2d( 66 | embedding_dim, hidden_dims, kernel_size=1, stride=1, padding=0 67 | ) 68 | 69 | @property 70 | def quantizer(self) -> Quantizer: 71 | return self._quantizer 72 | 73 | @property 74 | def modality(self) -> Type[Image]: 75 | return Image 76 | 77 | def _get_survey(self, bands: List[str]) -> str: 78 | survey = bands[0].split("-")[0] 79 | return survey 80 | 81 | def _range_compress(self, x: Tensor) -> Tensor: 82 | x = ( 83 | torch.arcsinh(x / self.range_compression_factor) 84 | * self.range_compression_factor 85 | ) 86 | x = x * self.mult_factor 87 | return x 88 | 89 | def _reverse_range_compress(self, x: Tensor) -> Tensor: 90 | x = x / self.mult_factor 91 | x = ( 92 | torch.sinh(x / self.range_compression_factor) 93 | * self.range_compression_factor 94 | ) 95 | return x 96 | 97 | def _encode(self, x: Image) -> Float[torch.Tensor, "b c w*h"]: 98 | flux_tensor = x.flux 99 | bands_in = x.bands 100 | 101 | processed_flux = 
self.center_crop(flux_tensor) 102 | processed_flux = self.clamp(processed_flux, bands_in) 103 | processed_flux = self.rescaler.forward( 104 | processed_flux, self._get_survey(bands_in) 105 | ) 106 | processed_flux = self._range_compress(processed_flux) 107 | 108 | processed_flux, channel_mask = self.image_padder.forward( 109 | processed_flux, bands_in 110 | ) 111 | processed_flux = self.subsample_in(processed_flux, channel_mask) 112 | 113 | h = self.encoder(processed_flux) 114 | h = self.pre_quant_proj(h) 115 | 116 | # Flatten the spatial dimensions 117 | h = h.reshape(h.shape[0], h.shape[1], -1) 118 | return h 119 | 120 | def _decode( 121 | self, z: Float[torch.Tensor, "b c w*h"], bands: Optional[List[str]] = None 122 | ) -> Image: 123 | # z is flattened, need to reshape 124 | batch_size, embedding_dim, n_tokens = z.shape 125 | spatial_size = int(n_tokens**0.5) 126 | z = z.reshape(batch_size, embedding_dim, spatial_size, spatial_size) 127 | 128 | h = self.post_quant_proj(z) 129 | decoded_flux_raw = self.decoder(h) 130 | 131 | full_dim_channel_mask = torch.ones( 132 | (z.shape[0], self.image_padder.nbands), device=z.device, dtype=torch.bool 133 | ) 134 | decoded_flux_padded = self.subsample_out( 135 | decoded_flux_raw, full_dim_channel_mask 136 | ) 137 | 138 | decoded_flux_compressed = self._reverse_range_compress(decoded_flux_padded) 139 | 140 | if bands is None: 141 | target_bands = list(BAND_TO_INDEX.keys()) 142 | else: 143 | target_bands = bands 144 | 145 | final_flux = self.image_padder.backward(decoded_flux_compressed, target_bands) 146 | final_flux = self.rescaler.backward(final_flux, self._get_survey(target_bands)) 147 | 148 | return Image(flux=final_flux, bands=target_bands) 149 | 150 | def decode( 151 | self, z: Float[Tensor, "b c"], bands: Optional[List[str]] = None 152 | ) -> Image: 153 | """ 154 | Decodes the given latent tensor `z` back into an Image object. 155 | 156 | Args: 157 | z: The latent tensor to decode. 158 | bands (Optional[List[str]]): A list of band names to decode. 159 | If None or not provided, all default bands ('DES-G', 'DES-R', 'DES-I', 'DES-Z', 160 | 'HSC-G', 'HSC-R', 'HSC-I', 'HSC-Z', 'HSC-Y') 161 | will be decoded. 162 | Returns: 163 | An Image object. 164 | """ 165 | return super().decode(z, bands=bands) 166 | 167 | 168 | class ImageCodec(AutoencoderImageCodec, CodecPytorchHubMixin): 169 | def __init__( 170 | self, 171 | quantizer_levels: List[int], 172 | hidden_dims: int = 512, 173 | multisurvey_projection_dims: int = 54, 174 | n_compressions: int = 2, 175 | num_consecutive: int = 4, 176 | embedding_dim: int = 5, 177 | range_compression_factor: float = 0.01, 178 | mult_factor: float = 10.0, 179 | ): 180 | """ 181 | MagViT Autoencoder for images. 182 | 183 | Args: 184 | quantizer_levels: Levels for the FiniteScalarQuantizer. 185 | hidden_dims: Number of hidden dimensions in the network. 186 | n_compressions: Number of compressions in the network. 187 | num_consecutive: Number of consecutive residual layers per compression. 188 | embedding_dim: Dimension of the latent space. 189 | range_compression_factor: Range compression factor. 190 | mult_factor: Multiplication factor. 
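        Example (illustrative sketch; assumes the pretrained weights are
        available and that `image` is a LegacySurveyImage batch):

            codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyImage)
            tokens = codec.encode(image)                     # quantized token indices
            recon = codec.decode(tokens, bands=image.bands)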
191 | """ 192 | model = MagVitAE( 193 | n_bands=multisurvey_projection_dims, 194 | hidden_dims=hidden_dims, 195 | n_compressions=n_compressions, 196 | num_consecutive=num_consecutive, 197 | ) 198 | quantizer = FiniteScalarQuantizer(levels=quantizer_levels) 199 | super().__init__( 200 | quantizer, 201 | model.encode, 202 | model.decode, 203 | hidden_dims, 204 | embedding_dim, 205 | multisurvey_projection_dims, 206 | range_compression_factor, 207 | mult_factor, 208 | ) 209 | self.model = model 210 | -------------------------------------------------------------------------------- /aion/fourm/lora_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 EPFL and Apple Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import List, Set, Optional, Type 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | SELF_ATTENTION_MODULES = {"Attention", "NormAttention"} 21 | CROSS_ATTENTION_MODULES = {"CrossAttention", "NormCrossAttention"} 22 | ATTENTION_MODULES = SELF_ATTENTION_MODULES | CROSS_ATTENTION_MODULES 23 | MLP_MODULES = {"Mlp", "GatedMlp", "SwiGLUFFNFused"} # SwiGLUFFNFused is from DINOv2 24 | TRANSFORMER_MODULES = ATTENTION_MODULES | MLP_MODULES 25 | 26 | 27 | def get_LoRA_module_names(id: str) -> Set[str]: 28 | """Returns the set of module names that are LoRA-adapted for the given id.""" 29 | id = id.lower() 30 | if id in ["selfattn", "selfattention", "self_attn", "self_attention"]: 31 | return SELF_ATTENTION_MODULES 32 | elif id in ["crossattn", "crossattention", "cross_attn", "cross_attention"]: 33 | return CROSS_ATTENTION_MODULES 34 | elif id in ["attn", "attention"]: 35 | return ATTENTION_MODULES 36 | elif id in ["mlp"]: 37 | return MLP_MODULES 38 | elif id in ["all", "transformer"]: 39 | return TRANSFORMER_MODULES 40 | else: 41 | raise ValueError(f"Unknown LoRA module id {id}.") 42 | 43 | 44 | class LoRAWrapper(nn.Module): 45 | """Low-Rank Adaptation Wrapper for linear layers. 46 | See https://arxiv.org/abs/2106.09685 47 | 48 | Args: 49 | linear: nn.Linear layer to wrap 50 | rank: Rank of adaptation matrix B@A 51 | scale: x = W_0@x + scale * B@A@x 52 | num_packed_linear: Set to > 1 when wrapping e.g. packed kv, or qkv attention weights. 53 | Weights will be initialized as if num_packed_linear = 1, but the LoRA bottleneck will 54 | be num_packed_linear times larger.
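Example (illustrative usage, not from the original source):
    >>> layer = nn.Linear(768, 768)
    >>> lora = LoRAWrapper(layer, rank=4)
    >>> y = lora(torch.randn(2, 768))  # equals the wrapped layer's output at init, since lora_up is zero-initialized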
55 | """ 56 | 57 | def __init__( 58 | self, 59 | linear: nn.Module, 60 | rank: int = 4, 61 | scale: float = 1.0, 62 | num_packed_linear: int = 1, 63 | ): 64 | super().__init__() 65 | self.rank = rank 66 | self.scale = scale 67 | self.in_features, self.out_features = linear.in_features, linear.out_features 68 | assert num_packed_linear * rank <= min(self.in_features, self.out_features), ( 69 | f"LoRA rank {num_packed_linear} * {rank} must be less than or equal to {min(self.in_features, self.out_features)}" 70 | ) 71 | 72 | self.linear = linear 73 | self.lora_down = nn.Linear( 74 | self.in_features, num_packed_linear * rank, bias=False 75 | ) 76 | self.lora_up = nn.Linear( 77 | num_packed_linear * rank, self.out_features, bias=False 78 | ) 79 | 80 | nn.init.normal_(self.lora_down.weight, std=1 / rank) 81 | nn.init.zeros_(self.lora_up.weight) 82 | 83 | def fuse_LoRA_into_linear(self) -> nn.Linear: 84 | """Returns a single nn.Linear layer with the LoRA matrix fused into the original one.""" 85 | fused_linear = nn.Linear( 86 | self.in_features, self.out_features, bias=self.linear.bias is not None 87 | ) 88 | fused_linear.weight.data = self.linear.weight + self.scale * ( 89 | self.lora_up.weight @ self.lora_down.weight 90 | ) 91 | if self.linear.bias is not None: 92 | fused_linear.bias.data = self.linear.bias 93 | return fused_linear 94 | 95 | def forward(self, x: torch.Tensor) -> torch.Tensor: 96 | """LoRA adapted linear layer forward pass.""" 97 | return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale 98 | 99 | 100 | def _find_modules( 101 | model, 102 | ancestor_class: Optional[Set[str]] = None, 103 | search_class: List[Type[nn.Module]] = [nn.Linear], 104 | exclude_children_of: Optional[List[Type[nn.Module]]] = [LoRAWrapper], 105 | ): 106 | """ 107 | Find all modules of a certain class (or union of classes) that are direct or 108 | indirect descendants of other modules of a certain class (or union of classes). 109 | 110 | Returns all matching modules, along with the parent of those modules and the 111 | names they are referenced by. 112 | 113 | Adapted from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py 114 | """ 115 | # Get the targets we should replace all linears under 116 | if ancestor_class is not None: 117 | ancestors = ( 118 | module 119 | for module in model.modules() 120 | if module.__class__.__name__ in ancestor_class 121 | ) 122 | else: 123 | # this is in case you want to naively iterate over all modules. 124 | ancestors = [module for module in model.modules()] 125 | 126 | # For each target find every linear_class module that isn't a child of a LoRA layer 127 | for ancestor in ancestors: 128 | for fullname, module in ancestor.named_modules(): 129 | if any([isinstance(module, _class) for _class in search_class]): 130 | # Find the direct parent if this is a descendant, not a child, of target 131 | *path, name = fullname.split(".") 132 | parent = ancestor 133 | while path: 134 | parent = parent.get_submodule(path.pop(0)) 135 | # Skip this linear if it's a child of a LoRA layer 136 | if exclude_children_of and any( 137 | [isinstance(parent, _class) for _class in exclude_children_of] 138 | ): 139 | continue 140 | # Otherwise, yield it 141 | yield parent, name, module 142 | 143 | 144 | def inject_trainable_LoRA( 145 | model: nn.Module, 146 | rank: int = 4, 147 | scale: float = 1.0, 148 | target_replace_modules: Set[str] = ATTENTION_MODULES, 149 | ) -> None: 150 | """Replaces all linear layers of the specified modules with LoRA-adapted linear layers.
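Packed projections (e.g. a fused `qkv` linear) are detected from the attribute name and given a proportionally wider LoRA bottleneck.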
151 | Modifies the model in-place. 152 | 153 | Args: 154 | model: nn.Module to modify 155 | rank: Rank of adaptation matrix B@A 156 | scale: x = W_0@x + scale * B@A@x 157 | target_replace_modules: Set of module names to replace linear layers in. 158 | """ 159 | for _module, name, _child_module in _find_modules( 160 | model, target_replace_modules, search_class=[nn.Linear] 161 | ): 162 | if sorted(name) == sorted("qkv"): 163 | num_packed_linear = 3 164 | elif sorted(name) in [sorted("kv"), sorted("qk"), sorted("qv")]: 165 | num_packed_linear = 2 166 | else: 167 | num_packed_linear = 1 168 | 169 | _module._modules[name] = LoRAWrapper( 170 | _child_module, rank=rank, scale=scale, num_packed_linear=num_packed_linear 171 | ) 172 | 173 | 174 | def fuse_LoRA_into_linear( 175 | model: nn.Module, target_replace_modules: Set[str] = ATTENTION_MODULES 176 | ) -> None: 177 | """Fuses all LoRA-adapted linear layers back into single linear layers. 178 | Modifies the model in-place. 179 | 180 | Args: 181 | model: nn.Module to modify 182 | target_replace_modules: Set of module names to fuse LoRA layers in. 183 | """ 184 | for _module, name, _child_module in _find_modules( 185 | model, target_replace_modules, search_class=[LoRAWrapper] 186 | ): 187 | _module._modules[name] = _module._modules[name].fuse_LoRA_into_linear() 188 | 189 | 190 | def unfreeze_all_LoRA_layers(model: nn.Module) -> None: 191 | """Unfreezes all LoRA-adapted linear layers.""" 192 | for name, param in model.named_parameters(): 193 | if "lora" in name: 194 | param.requires_grad = True 195 | -------------------------------------------------------------------------------- /aion/codecs/spectrum.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | import torch 4 | from jaxtyping import Float, Real 5 | 6 | from aion.codecs.base import Codec 7 | from aion.codecs.modules.convnext import ConvNextDecoder1d, ConvNextEncoder1d 8 | from aion.codecs.modules.spectrum import LatentSpectralGrid 9 | from aion.codecs.quantizers import LucidrainsLFQ, Quantizer, ScalarLinearQuantizer 10 | from aion.codecs.utils import CodecPytorchHubMixin 11 | from aion.modalities import Spectrum 12 | 13 | 14 | class AutoencoderSpectrumCodec(Codec): 15 | """Meta-class for autoencoder codecs for spectra, does not actually contain a network.""" 16 | 17 | def __init__( 18 | self, 19 | quantizer: Quantizer, 20 | encoder: torch.nn.Module, 21 | decoder: torch.nn.Module, 22 | normalization_quantizer: Quantizer, 23 | lambda_min: float = 3500.0, 24 | resolution: float = 0.8, 25 | num_pixels: int = 8704, 26 | latent_channels: int = 512, 27 | embedding_dim: int = 4, 28 | clip_ivar: float = 100, 29 | clip_flux: float | None = None, 30 | input_scaling: float = 0.2, 31 | ): 32 | super().__init__() 33 | self._quantizer = quantizer 34 | self.encoder = encoder 35 | self.decoder = decoder 36 | self.normalization_quantizer = normalization_quantizer 37 | self.latent_grid = LatentSpectralGrid( 38 | lambda_min=lambda_min, resolution=resolution, num_pixels=num_pixels 39 | ) 40 | self.embedding_dim = embedding_dim 41 | self.clip_ivar = clip_ivar 42 | self.clip_flux = clip_flux 43 | self.input_scaling = input_scaling 44 | self.pre_quant_norm = torch.nn.LayerNorm(latent_channels) 45 | self.quant_conv = torch.nn.Conv1d(latent_channels, embedding_dim, 1) 46 | self.post_quant_conv = torch.nn.Conv1d(embedding_dim, latent_channels, 1) 47 | 48 | @property 49 | def modality(self) -> Type[Spectrum]: 50 | return Spectrum 51 | 52 | 
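# Editorial note (illustrative, not part of the original source): the public
# `encode` below prepends one quantized normalization token to the spectral
# codes, and `decode` expects that same layout. A sketch of the round trip,
# assuming a hypothetical pretrained checkpoint id:
#
#     codec = SpectrumCodec.from_pretrained(...)   # checkpoint id elided
#     tokens = codec.encode(spectrum)               # shape [B, 1 + n_codes]
#     recon = codec.decode(tokens, wavelength=spectrum.wavelength)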
@property 53 | def quantizer(self) -> Quantizer: 54 | return self._quantizer 55 | 56 | def _encode(self, x: Spectrum) -> tuple[Float[torch.Tensor, "b c t"], Float[torch.Tensor, " b"]]: 57 | # Extract fields from Spectrum instance 58 | flux = x.flux 59 | ivar = x.ivar 60 | mask = x.mask 61 | wavelength = x.wavelength 62 | 63 | # Robustify the model against NaN values in the input 64 | # And add optional clipping of extreme values 65 | spectrum = torch.nan_to_num(flux) 66 | if self.clip_flux is not None: 67 | spectrum = torch.clamp(spectrum, -self.clip_flux, self.clip_flux) 68 | ivar = torch.nan_to_num(ivar) 69 | if self.clip_ivar is not None: 70 | ivar = torch.clamp(ivar, 0, self.clip_ivar) 71 | istd = torch.sqrt(ivar) 72 | 73 | # Normalize input spectrum 74 | normalization = (spectrum * (1.0 - mask.float())).sum(dim=-1) / ( 75 | torch.count_nonzero(~mask, dim=-1) + 1.0 76 | ) 77 | 78 | normalization = torch.clamp(normalization, 0.1) 79 | 80 | # Compressing the range of this normalization factor 81 | normalization = torch.log10(normalization + 1.0) 82 | 83 | # Apply quantization to normalization factor 84 | normalization = self.normalization_quantizer.quantize(normalization) 85 | 86 | # Normalize the spectrum 87 | n = torch.clamp((10 ** normalization[..., None] - 1.0), 0.1) 88 | spectrum = (spectrum / n - 1.0) * self.input_scaling 89 | istd = (istd / n) * self.input_scaling 90 | 91 | # Project spectra on the latent grid 92 | spectrum = self.latent_grid.to_latent(spectrum, wavelength) 93 | istd = self.latent_grid.to_latent(istd, wavelength) 94 | 95 | # Apply additional range compression for good measure 96 | x = torch.arcsinh(torch.stack([spectrum, istd], dim=1)) 97 | h = self.encoder(x) 98 | h = self.pre_quant_norm(h.moveaxis(1, -1)).moveaxis(-1, 1) 99 | h = self.quant_conv(h) 100 | return h, normalization 101 | 102 | def encode(self, x: Spectrum) -> Real[torch.Tensor, " b code"]: 103 | # Override to handle normalization token 104 | # First verify input type 105 | if not isinstance(x, self.modality): 106 | raise ValueError( 107 | f"Input type {type(x).__name__} does not match the modality of the codec {self.modality.__name__}" 108 | ) 109 | 110 | # Get embedding using _encode 111 | embedding, normalization = self._encode(x) 112 | 113 | # Quantize embedding 114 | embedding = self.quantizer.encode(embedding) 115 | 116 | # Quantize normalization 117 | normalization = self.normalization_quantizer.encode(normalization) 118 | 119 | # Concatenate normalization token with embedding 120 | embedding = torch.cat([normalization[..., None], embedding], dim=-1) 121 | 122 | return embedding 123 | 124 | def decode( 125 | self, 126 | z: Real[torch.Tensor, " b code"], 127 | wavelength: Float[torch.Tensor, " b t"] | None = None, 128 | ) -> Spectrum: 129 | # Override to handle normalization token extraction 130 | # Extract the normalization token from the sequence 131 | norm_token, z = z[..., 0], z[..., 1:] 132 | 133 | normalization = self.normalization_quantizer.decode(norm_token) 134 | 135 | z = self.quantizer.decode(z) 136 | 137 | return self._decode(z, normalization=normalization, wavelength=wavelength) 138 | 139 | def _decode( 140 | self, 141 | z: Float[torch.Tensor, " b c l"], 142 | normalization: Float[torch.Tensor, " b"], 143 | wavelength: Float[torch.Tensor, " b t"] | None = None, 144 | ) -> Spectrum: 145 | h = self.post_quant_conv(z) 146 | spectra = self.decoder(h) 147 | 148 | if spectra.shape[1] == 1: # just flux 149 | spectra = spectra.squeeze(1) 150 | mask = torch.ones_like(spectra) * -torch.inf 151 | elif spectra.shape[1] == 2: # 
flux and mask 152 | spectra, mask = spectra.chunk(2, dim=1) 153 | spectra, mask = spectra.squeeze(1), mask.squeeze(1) 154 | else: 155 | raise ValueError("Invalid number of output channels, must be 1 or 2") 156 | 157 | # If wavelengths are provided, interpolate the spectrum on the observed grid 158 | if wavelength is not None: 159 | spectra = self.latent_grid.to_observed(spectra, wavelength) 160 | mask = self.latent_grid.to_observed(mask, wavelength) 161 | else: 162 | b = spectra.shape[0] 163 | wavelength = self.latent_grid.wavelength.reshape(1, -1).repeat(b, 1) 164 | 165 | # Decode the spectrum on the latent grid and apply normalization 166 | if normalization is not None: 167 | spectra = (spectra + 1.0) * torch.clamp( 168 | 10 ** normalization[..., None] - 1.0, 0.1 169 | ) 170 | 171 | # Round mask 172 | mask = torch.round(torch.sigmoid(mask)).bool().detach() 173 | 174 | # Return Spectrum instance 175 | return Spectrum( 176 | flux=spectra, 177 | ivar=torch.ones_like(spectra), # We don't decode ivar, so set to ones 178 | mask=mask, 179 | wavelength=wavelength, 180 | ) 181 | 182 | 183 | class SpectrumCodec(AutoencoderSpectrumCodec, CodecPytorchHubMixin): 184 | """Spectrum codec based on convnext blocks.""" 185 | 186 | def __init__( 187 | self, 188 | encoder_depths: tuple[int, ...] = (3, 3, 9, 3), 189 | encoder_dims: tuple[int, ...] = (96, 192, 384, 768), 190 | decoder_depths: tuple[int, ...] = (3, 3, 9, 3), 191 | decoder_dims: tuple[int, ...] = (384, 192, 96, 1), 192 | lambda_min: float = 3500.0, 193 | resolution: float = 0.8, 194 | num_pixels: int = 8704, 195 | latent_channels: int = 512, 196 | embedding_dim: int = 4, 197 | clip_ivar: float = 100, 198 | clip_flux: float | None = None, 199 | input_scaling: float = 0.2, 200 | normalization_range: tuple[float, float] = (-1, 5), 201 | codebook_size: int = 1024, 202 | dim: int = 10, 203 | ): 204 | assert encoder_dims[-1] == latent_channels, ( 205 | "Last encoder dim must match latent_channels" 206 | ) 207 | quantizer = LucidrainsLFQ(dim=dim, codebook_size=codebook_size) 208 | normalization_quantizer = ScalarLinearQuantizer( 209 | codebook_size=codebook_size, range=normalization_range 210 | ) 211 | encoder = ConvNextEncoder1d( 212 | in_chans=2, 213 | depths=encoder_depths, 214 | dims=encoder_dims, 215 | ) 216 | 217 | decoder = ConvNextDecoder1d( 218 | in_chans=latent_channels, 219 | depths=decoder_depths, 220 | dims=decoder_dims, 221 | ) 222 | super().__init__( 223 | quantizer=quantizer, 224 | encoder=encoder, 225 | decoder=decoder, 226 | normalization_quantizer=normalization_quantizer, 227 | lambda_min=lambda_min, 228 | resolution=resolution, 229 | num_pixels=num_pixels, 230 | latent_channels=latent_channels, 231 | embedding_dim=embedding_dim, 232 | clip_ivar=clip_ivar, 233 | clip_flux=clip_flux, 234 | input_scaling=input_scaling, 235 | ) 236 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌌 AION-1: AstronomIcal Omnimodal Network 2 | 3 |
4 | 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | [![PyTorch](https://img.shields.io/badge/PyTorch-≥2.4.0-ee4c2c.svg)](https://pytorch.org/) 7 | [![Tests](https://github.com/PolymathicAI/AION/actions/workflows/test.yaml/badge.svg)](https://github.com/PolymathicAI/AION/actions/workflows/test.yaml) 8 | [![arXiv](https://img.shields.io/badge/arXiv-2510.17960-b31b1b.svg)](https://arxiv.org/abs/2510.17960) 9 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb) 10 | [![Model on HF](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow)](https://huggingface.co/polymathic-ai/aion-base) 11 | 12 | **Polymathic's Large Omnimodal Model for Astronomy** 13 | 14 | [🚀 Quick Start](#-quick-start) • [🎓 Tutorials](#-tutorials) • [🔬 Scientific Overview](#-scientific-overview) • [📦 Advanced Installation](#-advanced-installation) 15 | 16 |
17 | 18 | --- 19 | 20 | ## 🎯 Overview 21 | 22 |
23 | AION Logo 24 |
25 | 26 | AION-1 is a cutting-edge large omnimodal model specifically designed for astronomical surveys. It seamlessly integrates multiple data modalities, and enables simple adaptation to a wide range of astronomical tasks. 27 | 28 | ## 🚀 Quick Start 29 | 30 | Assuming you have PyTorch installed, you can install AION trivially with: 31 | ```bash 32 | pip install polymathic-aion 33 | ``` 34 | 35 | Then you can load the pretrained model and start analyzing astronomical data: 36 | ```python 37 | import torch 38 | from aion import AION 39 | from aion.codecs import CodecManager 40 | from aion.modalities import LegacySurveyImage 41 | 42 | # Load model and codec manager 43 | model = AION.from_pretrained('aion-base').to('cuda') # or 'aion-large', 'aion-xlarge' 44 | codec_manager = CodecManager(device='cuda') 45 | 46 | # Prepare your astronomical data (example: Legacy Survey image) 47 | image = LegacySurveyImage( 48 | flux=your_image_tensor, # Shape: [batch, 4, height, width] for g,r,i,z bands 49 | bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] 50 | ) 51 | 52 | # Encode data to tokens 53 | tokens = codec_manager.encode(image) 54 | 55 | # Option 1: Extract embeddings for downstream tasks 56 | embeddings = model.encode(tokens, num_encoder_tokens=600) 57 | 58 | # Option 2: Generate predictions (e.g., redshift) 59 | from aion.modalities import Z 60 | preds = model( 61 | codec_manager.encode(image), 62 | target_modality=Z, 63 | ) 64 | ``` 65 | 66 | ## 🎓 Tutorials 67 | 68 | Start with our interactive tutorial: 69 | - **[Open in Google Colab](https://colab.research.google.com/github/PolymathicAI/AION/blob/main/notebooks/Tutorial.ipynb)** - Learn AION basics interactively, no local setup required! 70 | 71 | 72 | ## 🔬 Scientific Overview 73 | 74 | ### 🧬 Architecture 75 | AION-1 employs a two-stage, transformer-based design: 76 | 1. **Modality-Specific Tokenizers** transform raw inputs into discrete tokens 77 | 2. 
**Unified Encoder–Decoder Transformer** ingests all token streams via a multimodal masked modeling (4M) objective 78 | 79 | --- 80 | 81 | ### 🗂️ Supported Modalities 82 | AION-1’s tokenizers cover **39 distinct data types**, grouped by survey and data category: 83 | 84 | | **Category** | **Description** | **Token Name(s)** | 85 | |-------------------------|-----------------------------------------|--------------------------| 86 | | **Imaging (2)** | Legacy Survey, HSC Wide | `tok_image_ls`, `tok_image_hsc` | 87 | | **Catalog (1)** | Legacy Survey catalog entries | `catalog` | 88 | | **Spectra (2)** | SDSS, DESI | `tok_spectrum_sdss`, `tok_spectrum_desi` | 89 | | **Gaia (5)** | BP/RP spectra, parallax, sky coords | `tok_xp_bp`, `tok_xp_rp`, `tok_parallax`, `tok_ra`, `tok_dec` | 90 | | **Gaia Photometry (3)** | G/BP/RP flux | `tok_flux_g_gaia`, `tok_flux_bp_gaia`, `tok_flux_rp_gaia` | 91 | | **Legacy Survey (9)** | g,r,i,z bands & WISE W1–W4 flux, E(B–V) | `tok_flux_g`,…,`tok_flux_w4`, `tok_ebv` | 92 | | **Legacy Shape (3)** | Ellipticity components & effective radius | `tok_shape_e1`, `tok_shape_e2`, `tok_shape_r` | 93 | | **HSC Photometry (5)** | g,r,i,z,y magnitudes | `tok_mag_g`,…,`tok_mag_y` | 94 | | **HSC Extinction (5)** | g,r,i,z,y extinctions | `tok_a_g`,…,`tok_a_y` | 95 | | **HSC Shape (3)** | Shape components 11,22,12 | `tok_shape11`, `tok_shape22`, `tok_shape12` | 96 | | **Other (1)** | Spectroscopic redshift | `tok_z` | 97 | 98 | --- 99 | 100 | ### 📈 Model Variants 101 | 102 | | **Variant** | **Encoder Blocks** | **Decoder Blocks** | **Model Dim** | **Heads** | **Total Params** | **Model** | 103 | |------------:|-------------------:|-------------------:|--------------:|----------:|-----------------:|-----------| 104 | | **Base** | 12 | 12 | 768 | 12 | 300 M | [aion-base](https://huggingface.co/polymathic-ai/aion-base) | 105 | | **Large** | 24 | 24 | 1024 | 16 | 800 M | soon | 106 | | **XLarge** | 24 | 24 | 2048 | 32 | 3 B | soon | 107 | 108 | > **Pretraining** 109 | > – Global batch size: 8,192 110 | > – Steps: Base (1.5 days on 64 H100), Large (2.5 days on 100 H100), XLarge (3.5 days on 288 H100) 111 | > – Optimizer: AdamW, peak LR 2 × 10⁻⁴, linear warmup + cosine decay 112 | 113 | ## 🔧 Data Preparation 114 | 115 | AION uses a typed data system to understand the provenance of each astronomical observation. Each modality must be properly formatted: 116 | 117 | ### Modality Types 118 | ```python 119 | from aion.modalities import ( 120 | LegacySurveyImage, HSCImage, # Images 121 | DESISpectrum, SDSSSpectrum, # Spectra 122 | LegacySurveyFluxG, HSCMagG, # Photometry 123 | GaiaParallax, Z, # Scalars 124 | # ...
and 30+ more modalities 125 | ) 126 | ``` 127 | 128 | ### Example: Preparing Legacy Survey Data 129 | ```python 130 | import torch 131 | from aion.modalities import LegacySurveyImage, LegacySurveyFluxG 132 | 133 | # Format image data (shape: [batch, 4, height, width]) 134 | image = LegacySurveyImage( 135 | flux=torch.tensor(image_data, dtype=torch.float32), 136 | bands=['DES-G', 'DES-R', 'DES-I', 'DES-Z'] 137 | ) 138 | 139 | # Format scalar photometry 140 | flux_g = LegacySurveyFluxG(value=torch.tensor([flux_values])) 141 | ``` 142 | 143 | ### Supported Data Formats 144 | | Survey | Modality | Required Format | 145 | |--------|----------|-----------------| 146 | | **Legacy Survey** | Images | 4-band (g,r,i,z), any resolution (auto-cropped to 96×96) | 147 | | **HSC** | Images | 5-band (g,r,i,z,y), any resolution | 148 | | **DESI/SDSS** | Spectra | Flux, inverse variance, wavelength arrays | 149 | | **Gaia** | BP/RP | Coefficient arrays (55 coefficients each) | 150 | | **All Surveys** | Scalars | Single values or 1D tensors | 151 | 152 | --- 153 | 154 | ## 💡 Example Use Cases 155 | 156 | ### 🔍 Similarity Search 157 | Find galaxies similar to a query object across different modalities: 158 | ```python 159 | # Extract embeddings for similarity search 160 | query_embedding = model.encode(codec_manager.encode(query_image)) 161 | all_embeddings = model.encode(codec_manager.encode(*dataset_images)) 162 | 163 | # Find most similar objects using cosine similarity 164 | from sklearn.metrics.pairwise import cosine_similarity 165 | similarity_scores = cosine_similarity(query_embedding, all_embeddings) 166 | similar_objects = similarity_scores.argsort()[::-1][:10] # Top 10 similar 167 | ``` 168 | 169 | ### 📊 Property Prediction 170 | Build lightweight models on AION embeddings: 171 | ```python 172 | # Extract embeddings from multiple modalities 173 | embeddings = model.encode(codec_manager.encode( 174 | image, spectrum, flux_g, flux_r, flux_i, flux_z 175 | ), num_encoder_tokens=900) 176 | 177 | # Train simple regressor for stellar mass, redshift, etc. 178 | from sklearn.neighbors import KNeighborsRegressor 179 | regressor = KNeighborsRegressor(n_neighbors=5) 180 | regressor.fit(embeddings.mean(axis=1), target_property) 181 | ``` 182 | 183 | ### 🌌 Generative Modeling 184 | Predict missing astronomical properties: 185 | ```python 186 | # Predict redshift from photometry + morphology 187 | predictions = model( 188 | codec_manager.encode(image, flux_g, flux_r, flux_i, flux_z), 189 | target_mask={'tok_z': torch.zeros(batch_size, 1)}, 190 | num_encoder_tokens=600 191 | ) 192 | redshift_probs = torch.softmax(predictions['tok_z'], dim=-1) 193 | ``` 194 | 195 | ## 📦 Advanced Installation 196 | 197 | AION offers flexible installation options to suit your environment and requirements. 198 | 199 | To install AION with PyTorch included: 200 | 201 | ```bash 202 | pip install polymathic-aion[torch] 203 | ``` 204 | 205 | For contributors and developers: 206 | 207 | ```bash 208 | pip install polymathic-aion[torch,dev] 209 | ``` 210 | 211 | This includes testing frameworks, linting tools, and development dependencies. 
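If you are contributing from a source checkout instead, an editable install with the same extras is a reasonable workflow (a sketch; the `pre-commit` and `pytest` steps assume the `.pre-commit-config.yaml` file and `tests/` directory shipped in this repository):

```bash
git clone https://github.com/PolymathicAI/AION.git
cd AION
pip install -e ".[torch,dev]"

# Optional: enable the repository's lint hooks and run the test suite
pre-commit install
pytest tests/
```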
212 | 213 | For specific PyTorch versions (e.g., CUDA support): 214 | 215 | ```bash 216 | # Install PyTorch with CUDA 12.4 support 217 | pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124 218 | 219 | # Then install AION 220 | pip install polymathic-aion 221 | ``` 222 | 223 | ## 📄 License 224 | 225 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 226 | 227 | ## 🌟 Acknowledgments 228 | 229 | AION is developed by [Polymathic AI](https://polymathic-ai.org/), advancing the frontier of AI for scientific applications. We would like to acknowledge the support of the Simons Foundation and of Schmidt Sciences. This project was provided with computer and storage resources by GENCI at IDRIS thanks to the grant 2024-GC011015468 on the supercomputer 230 | Jean Zay’s H100 partition. Additionally, some of the computations in this work were run at facilities supported by the Scientific Computing Core at the Flatiron Institute, a division of the Simons Foundation. 231 | 232 | ## 📬 Contact 233 | 234 | - **Issues**: [GitHub Issues](https://github.com/PolymathicAI/AION/issues) 235 | - **Discussions**: [GitHub Discussions](https://github.com/PolymathicAI/AION/discussions) 236 | 237 | --- 238 | 239 |
240 | Built with ❤️ for the astronomical community 241 |
242 | -------------------------------------------------------------------------------- /aion/codecs/scalar_field.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Type 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from aion.codecs.utils import CodecPytorchHubMixin 10 | from aion.modalities import LegacySurveySegmentationMap 11 | 12 | from .base import Codec 13 | from .modules.convblocks import Decoder2d, Encoder2d 14 | from .modules.ema import ModelEmaV2 15 | from .preprocessing.image import CenterCrop 16 | from .quantizers import FiniteScalarQuantizer, Quantizer 17 | 18 | 19 | class AutoencoderScalarFieldCodec(Codec): 20 | """Abstract class for autoencoding scalar field codecs.""" 21 | 22 | def __init__( 23 | # ------------------------------------------------------------------------------ 24 | self, 25 | # Code dimensions -------------------------------------------------------------- 26 | encoder_output_dim: int, 27 | decoder_input_dim: Optional[int] = None, 28 | embedding_dim: Optional[int] = None, 29 | # Quantisation ----------------------------------------------------------------- 30 | quantizer: Optional[Quantizer] = None, 31 | # VAE operation ---------------------------------------------------------------- 32 | variational: bool = False, 33 | # Model outputs ---------------------------------------------------------------- 34 | output_activation: Optional[Callable] = torch.sigmoid, 35 | output_activation_extension: Optional[float] = None, 36 | # Loss calculation ------------------------------------------------------------- 37 | reconstruction_loss: Callable = F.mse_loss, 38 | quantisation_loss_weight: float = 1.0, 39 | # Loss optimisation ------------------------------------------------------------ 40 | lr: float = 1e-3, 41 | lr_warmup: Optional[int] = None, 42 | begin_cosine_annealing: Optional[int] = None, 43 | lr_cosine_period: Optional[int] = None, 44 | # Model weights EMA ------------------------------------------------------------ 45 | ema_model_weights: bool = False, 46 | ema_decay: float = 0.9999, 47 | ema_update_freq: int = 1, 48 | # ------------------------------------------------------------------------------ 49 | ): 50 | super().__init__() 51 | 52 | # Code dimensions -------------------------------------------------------------- 53 | 54 | self.encoder_output_dim = encoder_output_dim 55 | decoder_input_dim = decoder_input_dim or encoder_output_dim 56 | self.decoder_input_dim = decoder_input_dim 57 | embedding_dim = embedding_dim or encoder_output_dim 58 | self.embedding_dim = embedding_dim 59 | 60 | # VAE operation --------------------------------------------------------------- 61 | 62 | self.variational = variational 63 | 64 | # Preprocessing ---------------------------------------------------------------- 65 | 66 | self.center_crop = CenterCrop(crop_size=96) 67 | 68 | # Quantisation ----------------------------------------------------------------- 69 | 70 | # Pre/post quantisation projections 71 | encode_proj_dim = 2 * embedding_dim if variational else embedding_dim 72 | self.encode_proj = nn.Conv2d(encoder_output_dim, encode_proj_dim, 1) 73 | self.decode_proj = nn.Conv2d(embedding_dim, decoder_input_dim, 1) 74 | 75 | # Quantiser 76 | self._quantizer = quantizer 77 | assert ( 78 | self.quantizer.embedding_dim == embedding_dim 79 | if self.quantizer is not None 80 | else True 81 | ) 82 | 83 | 
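# Editorial note (illustrative): the FiniteScalarQuantizer levels [8, 5, 5, 5]
# used by ScalarFieldCodec below quantize each latent channel to a small fixed
# set of values (8, 5, 5, and 5 levels respectively), giving an implied
# codebook of 8 * 5 * 5 * 5 = 1000 codes per token.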
# Model outputs ---------------------------------------------------------------- 84 | 85 | self.output_activation = output_activation or nn.Identity() 86 | self.output_activation_extension = output_activation_extension 87 | 88 | # Loss calculation ------------------------------------------------------------ 89 | 90 | self.reconstruction_loss = reconstruction_loss 91 | self.quantization_loss_weight = quantisation_loss_weight 92 | 93 | # Loss optimisation ------------------------------------------------------------ 94 | 95 | self.lr = lr 96 | self.lr_warmup = lr_warmup 97 | self.begin_cosine_annealing = begin_cosine_annealing 98 | self.lr_cosine_period = lr_cosine_period 99 | 100 | # Model weights EMA ------------------------------------------------------------ 101 | 102 | self.ema_model_weights = ema_model_weights 103 | self.ema_decay = ema_decay 104 | self.ema_update_freq = ema_update_freq 105 | 106 | # ------------------------------------------------------------------------------ 107 | 108 | @property 109 | def modality(self) -> Type[LegacySurveySegmentationMap]: 110 | return LegacySurveySegmentationMap 111 | 112 | @property 113 | def quantizer(self) -> Optional[Quantizer]: 114 | return self._quantizer 115 | 116 | def _encode(self, x: LegacySurveySegmentationMap) -> Float[Tensor, "b c h*w"]: 117 | # Extract the field tensor from the ScalarField modality 118 | field_tensor = x.field 119 | 120 | # Add channel dimension if needed (ScalarField is batch x height x width) 121 | if field_tensor.dim() == 3: 122 | field_tensor = field_tensor.unsqueeze(1) # Add channel dimension 123 | 124 | # Apply center cropping to 96x96 125 | processed_field = self.center_crop(field_tensor) 126 | 127 | h = self.encoder(processed_field) 128 | h = self.encode_proj(h) 129 | h = h.reshape(h.shape[0], h.shape[1], -1) 130 | return h 131 | 132 | def _decode(self, z: Float[Tensor, "b c h*w"]) -> LegacySurveySegmentationMap: 133 | batch_size, embedding_dim, n_tokens = z.shape 134 | spatial_size = int(n_tokens**0.5) 135 | assert spatial_size * spatial_size == n_tokens, ( 136 | f"n_tokens ({n_tokens}) is not a perfect square. " 137 | f"Calculated spatial_size: {spatial_size}." 
138 | ) 139 | z = z.reshape(batch_size, embedding_dim, spatial_size, spatial_size) 140 | h = self.decode_proj(z) 141 | x_hat = self.decoder(h) 142 | x_hat = self._output_activation(x_hat).clip(0.0, 1.0) 143 | 144 | # Remove channel dimension for ScalarField (expects batch x height x width) 145 | if x_hat.shape[1] == 1: 146 | x_hat = x_hat.squeeze(1) # Remove channel dimension 147 | 148 | return LegacySurveySegmentationMap(field=x_hat) 149 | 150 | def _output_activation( 151 | self, 152 | x_hat: Float[Tensor, "b c h w"], 153 | ) -> Float[Tensor, "b c h w"]: 154 | x_hat = self.output_activation(x_hat) 155 | 156 | d = self.output_activation_extension 157 | if d is not None: 158 | x_hat = (1 + 2 * d) * x_hat - d 159 | 160 | return x_hat 161 | 162 | 163 | # ====================================================================================== 164 | # Specific subclasses of AutoencoderScalarFieldCodec 165 | # ====================================================================================== 166 | 167 | 168 | class ScalarFieldCodec(AutoencoderScalarFieldCodec, CodecPytorchHubMixin): 169 | """Convolutional autoencoder codec for scalar fields.""" 170 | 171 | def __init__( 172 | # ------------------------------------------------------------------------------ 173 | self, 174 | # Code dimensions -------------------------------------------------------------- 175 | encoder_output_dim: int, 176 | decoder_input_dim: Optional[int] = None, 177 | embedding_dim: Optional[int] = None, 178 | # Encoder / decoder architecture ----------------------------------------------- 179 | res_hidden_dims: int = 64, 180 | num_res_layers: int = 2, 181 | num_downsamples: int = 3, 182 | # VAE operation ---------------------------------------------------------------- 183 | variational: bool = False, 184 | # Model outputs ---------------------------------------------------------------- 185 | output_activation: Optional[Callable] = F.sigmoid, 186 | output_activation_extension: Optional[float] = None, 187 | # Loss calculation ------------------------------------------------------------- 188 | reconstruction_loss: Callable = F.mse_loss, 189 | quantisation_loss_weight: float = 1.0, 190 | # Loss optimisation ------------------------------------------------------------ 191 | lr: float = 1e-3, 192 | lr_warmup: Optional[int] = None, 193 | begin_cosine_annealing: Optional[int] = None, 194 | lr_cosine_period: Optional[int] = None, 195 | # Model weights EMA ------------------------------------------------------------ 196 | ema_model_weights: bool = False, 197 | ema_decay: float = 0.9999, 198 | ema_update_freq: int = 1, 199 | levels=[8, 5, 5, 5], 200 | # ------------------------------------------------------------------------------ 201 | ): 202 | super().__init__( 203 | encoder_output_dim=encoder_output_dim, 204 | decoder_input_dim=decoder_input_dim, 205 | embedding_dim=embedding_dim, 206 | variational=variational, 207 | output_activation=output_activation, 208 | output_activation_extension=output_activation_extension, 209 | reconstruction_loss=reconstruction_loss, 210 | quantisation_loss_weight=quantisation_loss_weight, 211 | lr=lr, 212 | lr_warmup=lr_warmup, 213 | begin_cosine_annealing=begin_cosine_annealing, 214 | lr_cosine_period=lr_cosine_period, 215 | ema_model_weights=ema_model_weights, 216 | ema_decay=ema_decay, 217 | ema_update_freq=ema_update_freq, 218 | ) 219 | 220 | self._quantizer = FiniteScalarQuantizer(levels=levels) 221 | 222 | # Encoder ---------------------------------------------------------------------- 223 | 
self.encoder = Encoder2d( 224 | in_dims=1, 225 | out_dims=self.encoder_output_dim, 226 | res_hidden_dims=res_hidden_dims, 227 | num_res_layers=num_res_layers, 228 | num_downsamples=num_downsamples, 229 | ) 230 | 231 | # Decoder ---------------------------------------------------------------------- 232 | self.decoder = Decoder2d( 233 | in_dims=self.decoder_input_dim, # = encoder_output_dim unless overridden 234 | out_dims=1, 235 | hidden_dims=self.decoder_input_dim, 236 | res_hidden_dims=res_hidden_dims, 237 | num_res_layers=num_res_layers, 238 | num_upsamples=num_downsamples, 239 | ) 240 | 241 | # Model weights EMA ------------------------------------------------------------ 242 | if ema_model_weights: 243 | self.ema = ModelEmaV2(self, decay=ema_decay, device=None) 244 | else: 245 | self.ema = None 246 | -------------------------------------------------------------------------------- /aion/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, Tuple, Optional 3 | 4 | from .fourm.fm import FM 5 | from .fourm.modality_info import MODALITY_INFO 6 | 7 | 8 | class AION(FM): 9 | """ 10 | Wrapper for the 4M model including additional utilities. 11 | """ 12 | 13 | def embed_inputs( 14 | self, 15 | input_dict: Dict[str, torch.Tensor], 16 | mask: Optional[Dict[str, torch.Tensor]] = None, 17 | num_encoder_tokens: int = 256, 18 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 19 | """ 20 | Builds the encoder modality dictionary given some input data. 21 | Optionally, if mask is provided, input token masking can be used. 22 | 23 | Args: 24 | input_dict (Dict[str, torch.Tensor]): Input data dictionary. 25 | mask (Dict[str, torch.Tensor], optional): Mask dictionary. Defaults to None. 26 | num_encoder_tokens (int, optional): Maximum number of encoder tokens. Defaults to 256. 27 | 28 | Returns: 29 | tuple: 30 | - encoder_tokens (torch.Tensor): Selected encoder tokens from all modalities. Shape (B, N, D) where N is the number of selected encoder tokens. 31 | - encoder_emb (torch.Tensor): Corresponding embeddings for encoder tokens. Shape (B, N, D) 32 | - encoder_mask (torch.Tensor): A boolean mask indicating which encoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N) 33 | - mod_mask (torch.Tensor): An integer mask marking the modality type for each encoder token (with -1 indicating unassigned pad tokens).
Shape (B, N) 34 | """ 35 | if mask is None: 36 | mask = {} 37 | assert isinstance(input_dict, dict), "first input must be a dictionary" 38 | assert isinstance(mask, dict), "Mask must be a dictionary if provided" 39 | assert all(key in input_dict for key in mask), ( 40 | "All keys in the input mask must be in input_dict" 41 | ) 42 | assert all(key in self.encoder_embeddings for key in input_dict.keys()), ( 43 | "All keys in input_dict must be in self.encoder_embeddings" 44 | ) 45 | 46 | device = next(self.parameters()).device 47 | 48 | encoder_mod_dict = {} 49 | for mod, tensor in input_dict.items(): 50 | tensor = tensor.to(torch.long).to(device) 51 | if tensor.dim() == 1: 52 | tensor = tensor.unsqueeze(1) 53 | input_mask = mask.get( 54 | mod, 55 | torch.zeros( 56 | tensor.shape[0], tensor.shape[1], dtype=torch.bool, device=device 57 | ), 58 | ) 59 | if MODALITY_INFO[mod]["type"] == "img": 60 | assert tensor.shape[1] == self.encoder_embeddings[mod].num_patches, ( 61 | f"Expected size {self.encoder_embeddings[mod].num_patches} for modality {mod}, but got {tensor.shape[1]}" 62 | ) 63 | 64 | encoder_mod_dict[mod] = self.encoder_embeddings[mod]( 65 | {"tensor": tensor, "input_mask": input_mask} 66 | ) 67 | 68 | encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = ( 69 | self.forward_mask_encoder(encoder_mod_dict, num_encoder_tokens) 70 | ) 71 | 72 | return encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask 73 | 74 | def embed_targets( 75 | self, target_mask: Dict[str, torch.Tensor], num_decoder_tokens: int = 256 76 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 77 | """ 78 | Builds the decoder tokens and target ids for the requested modalities, given a target mask. 79 | Returns: 80 | tuple: 81 | - decoder_tokens (torch.Tensor): Selected decoder tokens from all modalities. Shape (B, M, D) where M is the number of selected decoder tokens. 82 | - decoder_emb (torch.Tensor): Corresponding embeddings for decoder tokens. Shape (B, M, D) 83 | - decoder_mask (torch.Tensor): A boolean mask indicating which decoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, M) 84 | - target_ids (torch.Tensor): IDs of the target tokens corresponding to the decoder tokens. Shape (B, M) 85 | - decoder_attention_mask (torch.Tensor): Mask for the decoder self-attention layers. Shape (B, M, M) 86 | - mod_mask (torch.Tensor): An integer mask marking the modality type for each decoder token (with -1 indicating unassigned pad tokens).
Shape (B, M) 87 | """ 88 | assert isinstance(target_mask, dict), "Target mask must be a dictionary" 89 | assert all(key in self.decoder_embeddings for key in target_mask.keys()), ( 90 | "All keys in target mask must be in self.decoder_embeddings" 91 | ) 92 | 93 | device = next(self.parameters()).device 94 | 95 | decoder_mod_dict = {} 96 | for mod, mask in target_mask.items(): 97 | mask = mask.to(torch.bool).to(device) 98 | tensor = torch.zeros_like(mask).to(torch.long).to(device) 99 | decoder_attention_mask = torch.zeros_like(mask).to(torch.bool).to(device) 100 | decoder_mod_dict[mod] = self.decoder_embeddings[mod].forward_embed( 101 | { 102 | "tensor": tensor, 103 | "target_mask": mask, 104 | "decoder_attention_mask": decoder_attention_mask, 105 | } 106 | ) 107 | 108 | ( 109 | decoder_tokens, 110 | decoder_emb, 111 | decoder_mask, 112 | target_ids, 113 | decoder_attention_mask, 114 | decoder_mod_mask, 115 | ) = self.forward_mask_decoder(decoder_mod_dict, num_decoder_tokens) 116 | 117 | return ( 118 | decoder_tokens, 119 | decoder_emb, 120 | decoder_mask, 121 | target_ids, 122 | decoder_attention_mask, 123 | decoder_mod_mask, 124 | ) 125 | 126 | def _encode(self, encoder_tokens, encoder_emb, encoder_mask): 127 | x = encoder_tokens + encoder_emb 128 | x = self.forward_encoder(x, encoder_mask=encoder_mask) 129 | context = self.decoder_proj_context(x) + encoder_emb 130 | return context 131 | 132 | def _decode( 133 | self, 134 | encoder_outputs, 135 | encoder_mask, 136 | decoder_tokens, 137 | decoder_emb, 138 | decoder_attention_mask, 139 | ): 140 | x = decoder_tokens + decoder_emb 141 | x = self.forward_decoder( 142 | x, 143 | encoder_outputs, 144 | encoder_mask=encoder_mask, 145 | decoder_attention_mask=decoder_attention_mask, 146 | ) 147 | return x 148 | 149 | def encode( 150 | self, 151 | input_dict: Dict[str, torch.Tensor], 152 | input_mask: Optional[Dict[str, torch.Tensor]] = None, 153 | num_encoder_tokens: int = 256, 154 | ) -> torch.Tensor: 155 | """ 156 | Encode input data using the model. 157 | 158 | Args: 159 | input_dict (Dict[str, torch.Tensor]): Input data dictionary. input_mask (Dict[str, torch.Tensor], optional): Mask dictionary. Defaults to None. num_encoder_tokens (int, optional): Maximum number of encoder tokens. Defaults to 256. 160 | """ 161 | encoder_tokens, encoder_emb, encoder_mask, _ = self.embed_inputs( 162 | input_dict, mask=input_mask, num_encoder_tokens=num_encoder_tokens 163 | ) 164 | return self._encode(encoder_tokens, encoder_emb, encoder_mask) 165 | 166 | def forward( 167 | self, 168 | input_dict: Dict[str, torch.Tensor], 169 | target_modality: list[object], 170 | input_mask: Optional[Dict[str, torch.Tensor]] = None, 171 | ) -> Dict[str, torch.Tensor]: 172 | """ 173 | Convenience function to compute the logits of the requested target outputs, given the input data. 174 | 175 | Args: 176 | input_dict (Dict[str, torch.Tensor]): Input data dictionary. 177 | target_modality (list[object]): List of target modalities to be predicted. 178 | input_mask (Dict[str, torch.Tensor], optional): Mask dictionary. Defaults to None. 179 | 180 | Returns: 181 | Dict[str, torch.Tensor]: Logits for each requested target modality, keyed by its token name.
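Example (illustrative; mirrors the README quick start, where `codec_manager` and `image` are set up beforehand and `Z` is imported from `aion.modalities`):
    >>> logits = model(codec_manager.encode(image), target_modality=Z)
    >>> probs = torch.softmax(logits[Z.token_key], dim=-1)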
182 | """ 183 | # Get batch size: 184 | B = list(input_dict.values())[0].shape[0] 185 | 186 | # Dynamically compute the number of encoder tokens 187 | num_encoder_tokens = 0 188 | for mod in input_dict.keys(): 189 | num_encoder_tokens += ( 190 | input_dict[mod].shape[1] if input_dict[mod].dim() == 2 else 1 191 | ) 192 | 193 | # Dynamically build the target mask and decoder tokens 194 | target_mask = {} 195 | num_decoder_tokens = 0 196 | target_modality = ( 197 | [target_modality] 198 | if not isinstance(target_modality, list) 199 | else target_modality 200 | ) 201 | for mod in target_modality: 202 | target_mask[mod.token_key] = torch.zeros(B, mod.num_tokens).to(torch.bool) 203 | num_decoder_tokens += mod.num_tokens 204 | 205 | logit_dict = self._forward( 206 | input_dict, 207 | target_mask=target_mask, 208 | input_mask=input_mask, 209 | num_decoder_tokens=num_decoder_tokens, 210 | num_encoder_tokens=num_encoder_tokens, 211 | ) 212 | 213 | for mod in logit_dict.keys(): 214 | logit_dict[mod] = logit_dict[mod].view(B, target_mask[mod].shape[1], -1) 215 | 216 | return logit_dict 217 | 218 | def _forward( 219 | self, 220 | input_dict: Dict[str, torch.Tensor], 221 | target_mask: Dict[str, torch.Tensor], 222 | input_mask: Optional[Dict[str, torch.Tensor]] = None, 223 | num_decoder_tokens: int = 256, 224 | num_encoder_tokens: int = 256, 225 | ) -> torch.Tensor: 226 | """ 227 | The forward function returns the logits of the requested target outputs, given the input data. 228 | """ 229 | # Embedding inputs and targets 230 | encoder_tokens, encoder_emb, encoder_mask, _ = self.embed_inputs( 231 | input_dict, mask=input_mask, num_encoder_tokens=num_encoder_tokens 232 | ) 233 | ( 234 | decoder_tokens, 235 | decoder_emb, 236 | decoder_mask, 237 | target_ids, 238 | decoder_attention_mask, 239 | decoder_mod_mask, 240 | ) = self.embed_targets(target_mask, num_decoder_tokens=num_decoder_tokens) 241 | 242 | # Run the encoder 243 | encoder_output = self._encode(encoder_tokens, encoder_emb, encoder_mask) 244 | decoder_output = self._decode( 245 | encoder_output, 246 | encoder_mask, 247 | decoder_tokens, 248 | decoder_emb, 249 | decoder_attention_mask, 250 | ) 251 | 252 | # Now, we compute the logits for the requested tokens and return them 253 | mod_logits = {} 254 | for mod in target_mask.keys(): 255 | idx = self.modality_info[mod]["id"] 256 | mod_logits[mod] = self.decoder_embeddings[mod].forward_logits( 257 | decoder_output[decoder_mod_mask == idx] 258 | ) 259 | 260 | return mod_logits 261 | --------------------------------------------------------------------------------