├── benchmarks ├── __init__.py ├── README.md ├── data_request_benchmarks.py ├── benchmarks.py └── asv.conf.json ├── src ├── hyrax │ ├── data_sets │ │ ├── random │ │ │ └── __init__.py │ │ ├── hyrax_csv_dataset.py │ │ ├── hyrax_cifar_data_set.py │ │ └── __init__.py │ ├── vector_dbs │ │ ├── __init__.py │ │ ├── vector_db_factory.py │ │ └── vector_db_interface.py │ ├── 3d_viz │ │ ├── .gitattributes │ │ ├── readme.md │ │ └── plotly_3d.py │ ├── downloadCutout │ │ ├── LINCC_README.md │ │ ├── __init__.py │ │ └── README.md │ ├── __init__.py │ ├── prepare.py │ ├── models │ │ ├── __init__.py │ │ ├── hyrax_loopback.py │ │ ├── hsc_autoencoder.py │ │ ├── hsc_dcae.py │ │ ├── simclr.py │ │ └── hyrax_cnn.py │ ├── rebuild_manifest.py │ ├── verbs │ │ ├── model.py │ │ ├── __init__.py │ │ ├── search.py │ │ ├── verb_registry.py │ │ ├── lookup.py │ │ ├── database_connection.py │ │ └── to_onnx.py │ └── gpu_monitor.py └── hyrax_cli │ └── main.py ├── docs ├── notebooks │ ├── README.md │ └── model_input_3.ipynb ├── _static │ ├── hyrax_design.png │ ├── hyrax_header.png │ └── umap_visualization.JPG ├── pre_executed │ ├── mlflow_screenshot.JPG │ ├── umap_visualization.JPG │ ├── tensorboard_screenshot.JPG │ ├── README.md │ └── mpr_demo_plotting.py ├── requirements.txt ├── notebooks.rst ├── architecture_overview.rst ├── Makefile ├── dev_guide.rst ├── model_comparison.rst ├── about.rst ├── data_set_splits.rst ├── verbs.rst ├── conf.py └── configuration.rst ├── .git_archival.txt ├── .github ├── ISSUE_TEMPLATE │ ├── 0-general_issue.md │ ├── README.md │ ├── 2-feature_request.md │ └── 1-bug_report.md ├── dependabot.yml ├── workflows │ ├── README.md │ ├── build-documentation.yml │ ├── publish-to-pypi.yml │ ├── pre-commit-ci.yml │ ├── testing-and-coverage.yml │ ├── smoke-test.yml │ ├── publish-benchmarks-pr.yml │ ├── asv-main.yml │ ├── asv-nightly.yml │ └── asv-pr.yml └── pull_request_template.md ├── tests └── hyrax │ ├── test_data │ ├── small_dataset_hscstars │ │ ├── star_cat_correct.pq │ │ ├── star_cat_correct.fits │ │ ├── images │ │ │ ├── 10-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 11-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 2-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 3-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 4-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 5-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 6-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 7-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ ├── 8-cutout-HSC-I-8279-pdr2_wide.fits │ │ │ └── 9-cutout-HSC-I-8279-pdr2_wide.fits │ │ └── star_cat_correct.astropy.csv │ ├── test_config_quoted_tables.toml │ ├── csv_test │ │ └── sample_data.csv │ ├── test_user_config.toml │ ├── test_user_config_repeated_keys.toml │ └── test_default_config.toml │ ├── test_packaging.py │ ├── test_umap.py │ ├── test_csv_dataset.py │ ├── test_fits_image_dataset.py │ ├── test_save_to_database.py │ ├── test_e2e.py │ ├── test_infer.py │ ├── test_train.py │ ├── test_patch_model.py │ ├── test_to_onnx.py │ └── test_qdrant_impl.py ├── .readthedocs.yml ├── .copier-answers.yml ├── .gitattributes ├── LICENSE ├── .setup_dev.sh ├── README_RSP.md ├── README.md ├── example_notebooks └── GettingStartedDownloader.ipynb ├── .gitignore ├── .pre-commit-config.yaml └── pyproject.toml /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/hyrax/data_sets/random/__init__.py: -------------------------------------------------------------------------------- 1 | 
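With the repository layout above in mind: the library lives under `src/hyrax` (models, data sets, verbs, vector DBs), `src/hyrax_cli/main.py` provides the command-line entry point, and `docs/architecture_overview.rst` (later in this dump) shows the canonical usage. A minimal sketch of a session, with an illustrative config file name:

```python
from hyrax import Hyrax

# Hyrax is configured via TOML files; see the test configs under
# tests/hyrax/test_data for the expected shape. "my_config.toml" is an
# illustrative name, not a file shipped with the repository.
h = Hyrax(config_file="my_config.toml")
h.train()  # verbs (train, infer, umap, ...) are the primary interface
```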
-------------------------------------------------------------------------------- /docs/notebooks/README.md: -------------------------------------------------------------------------------- 1 | Put your Jupyter notebooks here :) 2 | -------------------------------------------------------------------------------- /src/hyrax/vector_dbs/__init__.py: -------------------------------------------------------------------------------- 1 | from .chromadb_impl import ChromaDB 2 | 3 | __all__ = ["ChromaDB"] 4 | -------------------------------------------------------------------------------- /docs/_static/hyrax_design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/docs/_static/hyrax_design.png -------------------------------------------------------------------------------- /docs/_static/hyrax_header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/docs/_static/hyrax_header.png -------------------------------------------------------------------------------- /docs/_static/umap_visualization.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/docs/_static/umap_visualization.JPG -------------------------------------------------------------------------------- /src/hyrax/3d_viz/.gitattributes: -------------------------------------------------------------------------------- 1 | umap_data* filter=lfs diff=lfs merge=lfs -text 2 | *.fits filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /docs/pre_executed/mlflow_screenshot.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/docs/pre_executed/mlflow_screenshot.JPG -------------------------------------------------------------------------------- /docs/pre_executed/umap_visualization.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/docs/pre_executed/umap_visualization.JPG -------------------------------------------------------------------------------- /docs/pre_executed/tensorboard_screenshot.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/docs/pre_executed/tensorboard_screenshot.JPG -------------------------------------------------------------------------------- /.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: bda3e289a68df67221210402ac289ec4c83844e3 2 | node-date: 2025-12-18T11:24:59-08:00 3 | describe-name: v0.6.9 4 | ref-names: HEAD -> main, tag: v0.6.9 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/0-general_issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General issue 3 | about: Quickly create a general issue 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/star_cat_correct.pq: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/star_cat_correct.pq -------------------------------------------------------------------------------- /tests/hyrax/test_packaging.py: -------------------------------------------------------------------------------- 1 | import hyrax 2 | 3 | 4 | def test_version(): 5 | """Check to see that we can get the package version""" 6 | assert hyrax.__version__ is not None 7 | -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/star_cat_correct.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/star_cat_correct.fits -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | ipykernel 3 | ipython 4 | jupytext 5 | nbconvert 6 | nbsphinx 7 | sphinx 8 | sphinx-autoapi 9 | sphinx-copybutton 10 | sphinx-rtd-theme 11 | sphinx-tabs 12 | sphinx-togglebutton -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/10-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/10-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/11-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/11-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/2-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/2-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/3-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/3-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/4-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/4-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/5-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/5-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/6-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/6-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/7-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/7-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/8-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/8-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/images/9-cutout-HSC-I-8279-pdr2_wide.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lincc-frameworks/hyrax/HEAD/tests/hyrax/test_data/small_dataset_hscstars/images/9-cutout-HSC-I-8279-pdr2_wide.fits -------------------------------------------------------------------------------- /tests/hyrax/test_data/test_config_quoted_tables.toml: -------------------------------------------------------------------------------- 1 | # Test config with quoted table names 2 | [general] 3 | dev_mode = true 4 | 5 | ["my.custom.optimizer.Adam"] 6 | lr = 0.01 7 | beta1 = 0.9 8 | 9 | ["my.custom.optimizer.SGD"] 10 | lr = 0.01 11 | momentum = 0.9 12 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "monthly" 11 | -------------------------------------------------------------------------------- /tests/hyrax/test_data/csv_test/sample_data.csv: -------------------------------------------------------------------------------- 1 | object_id,ra,dec,magnitude,flux,classification 2 | 1001,30.5,6.2,19.5,1.234e-27,star 3 | 1002,30.6,6.3,18.8,2.456e-27,galaxy 4 | 1003,30.7,6.4,20.1,0.987e-27,star 5 | 1004,30.8,6.5,19.2,1.567e-27,galaxy 6 | 1005,30.9,6.6,21.3,0.543e-27,star 7 | -------------------------------------------------------------------------------- /tests/hyrax/test_data/test_user_config.toml: -------------------------------------------------------------------------------- 1 | [general] 2 | dev_mode = true 3 | 4 | [train.model] 5 | weights_filename = "final_best.pth" 6 | layers = 3 7 | 8 | [infer] 9 | batch_size = 8 # change batch size 10 | 11 | [bespoke_table] 12 | # this is a bespoke table 13 | key1 = "value1" 14 | key2 = "value2" # unlikely to modify 15 | 
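The user config above overrides only a few keys; at runtime a config like this is layered over the packaged defaults (compare `test_default_config.toml` later in this dump). A minimal sketch of such a deep merge using only the standard library; the `deep_merge` helper is illustrative and not Hyrax's actual merging code:

```python
import tomllib  # standard library in Python 3.11+; use the tomli package on 3.10


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively overlay `override` onto `base`, returning a new dict."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


with open("test_default_config.toml", "rb") as fp:
    defaults = tomllib.load(fp)
with open("test_user_config.toml", "rb") as fp:
    user = tomllib.load(fp)

config = deep_merge(defaults, user)
assert config["infer"]["batch_size"] == 8  # user override wins over the default of 32
assert config["train"]["model"]["weights_filename"] == "final_best.pth"
```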
-------------------------------------------------------------------------------- /src/hyrax/downloadCutout/LINCC_README.md: -------------------------------------------------------------------------------- 1 | # HSC Data download tool. 2 | This tool was downloaded from the [ssp-software/data-access-tools](https://hsc-gitlab.mtk.nao.ac.jp/ssp-software/data-access-tools/) gitlab repository. 3 | 4 | This directory was initialized with a copy of the `pdr3/downloadCutout` directory at rev b628d6089acda041eea1041d1011ea154ebefc28 committed Feb 14 2024. -------------------------------------------------------------------------------- /tests/hyrax/test_data/test_user_config_repeated_keys.toml: -------------------------------------------------------------------------------- 1 | [general] 2 | dev_mode = true 3 | 4 | [model] 5 | name = "resnet" 6 | layers = 3 7 | 8 | [loss] 9 | name = "cross_entropy" 10 | 11 | [optimizer] 12 | name = "adam" 13 | 14 | [infer] 15 | batch_size = 8 # change batch size 16 | 17 | [bespoke_table] 18 | # this is a bespoke table 19 | key1 = "value1" 20 | key2 = "value2" # unlikely to modify 21 | -------------------------------------------------------------------------------- /src/hyrax/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | from .config_utils import log_runtime_config 3 | from .hyrax import Hyrax 4 | from .plugin_utils import get_or_load_class, import_module_from_string, update_registry 5 | 6 | __all__ = [ 7 | "log_runtime_config", 8 | "get_or_load_class", 9 | "import_module_from_string", 10 | "update_registry", 11 | "Hyrax", 12 | "__version__", 13 | ] 14 | -------------------------------------------------------------------------------- /.github/workflows/README.md: -------------------------------------------------------------------------------- 1 | # Workflows 2 | 3 | The .yml files in this directory are used to define the various continuous 4 | integration scripts that will be run on your behalf e.g. nightly as a smoke check, 5 | or when you create a new PR. 6 | 7 | For more information about CI and workflows, look here: https://lincc-ppt.readthedocs.io/en/latest/practices/ci.html 8 | 9 | Or if you still have questions contact us: https://lincc-ppt.readthedocs.io/en/latest/source/contact.html -------------------------------------------------------------------------------- /tests/hyrax/test_data/test_default_config.toml: -------------------------------------------------------------------------------- 1 | # this is the default config file 2 | [general] 3 | # set dev_mode to true when developing 4 | # set to false for production use 5 | dev_mode = true 6 | 7 | [train] 8 | model_name = "example_model" # Use a built-in Hyrax model 9 | model_class = "new_thing.cool_model.CoolModel" # Use a custom model 10 | 11 | [train.model] 12 | weights_filename = "example_model.pth" 13 | layers = 3 14 | 15 | [infer] 16 | batch_size = 32 17 | -------------------------------------------------------------------------------- /src/hyrax/prepare.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from hyrax.pytorch_ignite import setup_dataset 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def run(config): 9 | """Prepare the dataset for a given model and data loader. 
10 | 11 |     Parameters 12 |     ---------- 13 |     config : dict 14 |         The parsed config file as a nested 15 |         dict 16 |     """ 17 | 18 |     data_set = setup_dataset(config) 19 | 20 |     logger.info("Finished Prepare") 21 |     return data_set 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/README.md: -------------------------------------------------------------------------------- 1 | # Configurations 2 | 3 | Templates for various different issue types are defined in this directory 4 | and a pull request template is defined as ``../pull_request_template.md``. Adding, 5 | removing, and modifying these templates to suit the needs of your project is encouraged. 6 | 7 | For more information about these templates, look here: https://lincc-ppt.readthedocs.io/en/latest/practices/issue_pr_templating.html 8 | 9 | Or if you still have questions contact us: https://lincc-ppt.readthedocs.io/en/latest/source/contact.html -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | 2 | # .readthedocs.yml 3 | # Read the Docs configuration file 4 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 5 | 6 | # Required 7 | version: 2 8 | 9 | build: 10 |   os: ubuntu-22.04 11 |   tools: 12 |     python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 |   configuration: docs/conf.py 17 | 18 | # Optionally declare the Python requirements required to build your docs 19 | python: 20 |   install: 21 |     - requirements: docs/requirements.txt 22 |     - method: pip 23 |       path: . 24 | -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks 2 | 3 | This directory contains files that will be run via continuous testing either 4 | nightly or after committing code to a pull request. 5 | 6 | The runtime and/or memory usage of the functions defined in these files will be 7 | tracked and reported to give you a sense of the overall performance of your code. 8 | 9 | You are encouraged to add, update, or remove benchmark functions to suit the needs 10 | of your project. 11 | 12 | For more information, see the documentation here: https://lincc-ppt.readthedocs.io/en/latest/practices/ci_benchmarking.html -------------------------------------------------------------------------------- /docs/notebooks.rst: -------------------------------------------------------------------------------- 1 | Example notebooks 2 | ================= 3 | 4 | .. toctree:: 5 |     :maxdepth: 1 6 | 7 |     Getting started with Hyrax  8 |     Writing a custom dataset  9 |     Providing data 1  10 |     Providing data 2  11 |     Providing data 3  12 |     Training to Similarity Search  13 |     Working with a vector database  14 |     Exporting a Model  15 |     HyraxQL data requests  16 | -------------------------------------------------------------------------------- /src/hyrax/downloadCutout/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a modified copy of the Hyper Suprime-Cam (HSC) download cutout tool published for pdr3 on the 3 | `HSC Gitlab repository <https://hsc-gitlab.mtk.nao.ac.jp/ssp-software/data-access-tools/>`_ 4 | 5 | It is tightly integrated with the ``download`` verb in hyrax and has been modified so as to allow: 6 | #. Better error reporting and recovery 7 | #. Multithreaded use by the ``download`` hyrax verb 8 | #. 
Creation and updating of download manifests when run by the download verb. 9 | 10 | """ 11 | 12 | # ruff: noqa: F403 13 | from .downloadCutout import * 14 | 15 | __all__ = [] 16 | -------------------------------------------------------------------------------- /.copier-answers.yml: -------------------------------------------------------------------------------- 1 | # Changes here will be overwritten by Copier 2 | _commit: v2.1.1 3 | _src_path: gh:lincc-frameworks/python-project-template 4 | author_email: mtauraso@uw.edu 5 | author_name: LINCC Frameworks 6 | create_example_module: false 7 | custom_install: true 8 | enforce_style: 9 | - ruff_lint 10 | - ruff_format 11 | failure_notification: [] 12 | include_benchmarks: true 13 | include_docs: true 14 | include_notebooks: true 15 | mypy_type_checking: none 16 | package_name: hyrax 17 | project_license: MIT 18 | project_name: hyrax 19 | project_organization: lincc-frameworks 20 | python_versions: 21 | - '3.10' 22 | - '3.11' 23 | - '3.12' 24 | - '3.13' 25 | test_lowest_version: none 26 | -------------------------------------------------------------------------------- /docs/pre_executed/README.md: -------------------------------------------------------------------------------- 1 | # Pre-executed Jupyter notebooks 2 | 3 | Jupyter notebooks in this directory will NOT be run in the docs workflows, and will be rendered with 4 | the provided output cells as-is. 5 | 6 | This is useful for notebooks that require large datasets, access to third party APIs, large CPU or GPU requirements. 7 | 8 | Where possible, instead write smaller notebooks that can be run as part of a github worker, and within the ReadTheDocs rendering process. 9 | 10 | To ensure that the notebooks are not run by the notebook conversion process, you can add the following metadata block to the notebook: 11 | 12 | ``` 13 | "nbsphinx": { 14 | "execute": "never" 15 | }, 16 | ``` 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: 'Short description' 5 | labels: 'enhancement' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Feature request** 11 | 12 | 13 | **Before submitting** 14 | Please check the following: 15 | 16 | - [ ] I have described the purpose of the suggested change, specifying what I need the enhancement to accomplish, i.e. what problem it solves. 17 | - [ ] I have included any relevant links, screenshots, environment information, and data relevant to implementing the requested feature, as well as pseudocode for how I want to access the new functionality. 18 | - [ ] If I have ideas for how the new feature could be implemented, I have provided explanations and/or pseudocode and/or task lists for the steps. 19 | -------------------------------------------------------------------------------- /src/hyrax/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Remove import sorting, these are imported in the order written so that 2 | # autoapi docs are generated with ordering controlled below. 
3 | # ruff: noqa: I001 4 | from .hsc_autoencoder import HSCAutoencoder 5 | from .hsc_dcae import HSCDCAE 6 | from .image_dcae import ImageDCAE 7 | from .hyrax_autoencoder import HyraxAutoencoder 8 | from .hyrax_autoencoderv2 import HyraxAutoencoderV2 9 | from .hyrax_cnn import HyraxCNN 10 | from .hyrax_loopback import HyraxLoopback 11 | from .model_registry import hyrax_model 12 | from .simclr import SimCLR 13 | 14 | __all__ = [ 15 |     "hyrax_model", 16 |     "HyraxAutoencoder", 17 |     "HyraxAutoencoderV2", 18 |     "HyraxCNN", 19 |     "HyraxLoopback", 20 |     "HSCAutoencoder", 21 |     "HSCDCAE", 22 |     "ImageDCAE", 23 |     "SimCLR", 24 | ] 25 | -------------------------------------------------------------------------------- /src/hyrax/vector_dbs/vector_db_factory.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from hyrax.vector_dbs.vector_db_interface import VectorDB 4 | 5 | 6 | def vector_db_factory(config: dict, context: dict) -> Union[VectorDB, None]: 7 |     """Factory method to create a database object""" 8 | 9 |     # if the vector_db name is `False`, return None 10 |     if not config["vector_db"]["name"]: 11 |         return None 12 | 13 |     vector_db_name = config["vector_db"]["name"] 14 | 15 |     if vector_db_name == "chromadb": 16 |         from hyrax.vector_dbs.chromadb_impl import ChromaDB 17 | 18 |         return ChromaDB(config, context) 19 |     elif vector_db_name == "qdrant": 20 |         from hyrax.vector_dbs.qdrantdb_impl import QdrantDB 21 | 22 |         return QdrantDB(config, context) 23 |     else: 24 |         return None 25 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # For explanation of this file and uses see 2 | # https://git-scm.com/docs/gitattributes 3 | # https://developer.lsst.io/git/git-lfs.html#using-git-lfs-enabled-repositories 4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/git-lfs.html 5 | # 6 | # Used by https://github.com/lsst/afwdata.git 7 | # *.boost filter=lfs diff=lfs merge=lfs -text 8 | # *.dat filter=lfs diff=lfs merge=lfs -text 9 | # *.fits filter=lfs diff=lfs merge=lfs -text 10 | # *.gz filter=lfs diff=lfs merge=lfs -text 11 | # 12 | # apache parquet files 13 | # *.parq filter=lfs diff=lfs merge=lfs -text 14 | # 15 | # sqlite files 16 | # *.sqlite3 filter=lfs diff=lfs merge=lfs -text 17 | # 18 | # gzip files 19 | # *.gz filter=lfs diff=lfs merge=lfs -text 20 | # 21 | # png image files 22 | # *.png filter=lfs diff=lfs merge=lfs -text 23 | 24 | .git_archival.txt export-subst -------------------------------------------------------------------------------- /src/hyrax/rebuild_manifest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from hyrax.pytorch_ignite import setup_dataset 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def run(config): 9 |     """Rebuild a broken download manifest 10 | 11 |     Parameters 12 |     ---------- 13 |     config : dict 14 |         The parsed config file as a nested 15 |         dict 16 |     """ 17 |     from .data_sets.hsc_data_set import HSCDataSet 18 | 19 |     config["rebuild_manifest"] = True 20 | 21 |     data_set = setup_dataset(config) 22 | 23 |     if not isinstance(data_set, HSCDataSet): 24 |         msg = "rebuild_manifest can only be run on an HSCDataSet." 
25 | raise RuntimeError(msg) 26 | 27 | logger.info("Starting rebuild of manifest") 28 | 29 | data_set._rebuild_manifest(config) 30 | 31 | logger.info("Finished Rebuild Manifest") 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1-bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Tell us about a problem to fix 4 | title: 'Short description' 5 | labels: 'bug' 6 | assignees: '' 7 | 8 | --- 9 | **Bug report** 10 | 11 | 12 | **Environment Information** 13 | 14 | 15 | **Before submitting** 16 | Please check the following: 17 | 18 | - [ ] I have described the situation in which the bug arose, including what code was executed, and any applicable data others will need to reproduce the problem. 19 | - [ ] I have included information about my environment, including the version of this package (e.g. `hyrax.__version__`) 20 | - [ ] I have included available evidence of the unexpected behavior (including error messages, screenshots, and/or plots) as well as a description of what I expected instead. 21 | - [ ] If I have a solution in mind, I have provided an explanation and/or pseudocode and/or task list. 22 | -------------------------------------------------------------------------------- /src/hyrax/verbs/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .verb_registry import Verb, hyrax_verb 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @hyrax_verb 9 | class Model(Verb): 10 | """Resolves the model class that is defined in the config file. 11 | This will return a reference to the model class.""" 12 | 13 | cli_name = "model" 14 | add_parser_kwargs = {} 15 | 16 | @staticmethod 17 | def setup_parser(parser): 18 | """Not implemented""" 19 | pass 20 | 21 | def run_cli(self): 22 | """Not implemented""" 23 | logger.error("Running model from the cli is unimplemented") 24 | 25 | def run(self): 26 | """Fetch and return the model _class_. Does not create an instance of 27 | the model class. 28 | """ 29 | from hyrax.models.model_registry import fetch_model_class 30 | 31 | config = self.config 32 | 33 | model_cls = fetch_model_class(config) 34 | 35 | return model_cls 36 | -------------------------------------------------------------------------------- /docs/architecture_overview.rst: -------------------------------------------------------------------------------- 1 | Architecture overview 2 | ===================== 3 | 4 | Hyrax uses verbs 5 | ---------------- 6 | Hyrax defines a set of commands, called verbs, that are the primary mode of interaction. 7 | Verbs are meant to be intuitive and easy to remember. For instance, to train a model, 8 | you would use the ``train`` verb. 9 | To use a trained model for inference, you would use the ``infer`` verb. 10 | 11 | Notebook, CLI, or Both 12 | -------------------------------- 13 | Hyrax is designed to be used in a Jupyter notebook or from the command line without 14 | modification. This supports exploration and development in a familiar notebook environment 15 | and deployment to an HPC or Slurm system for large scale training. 16 | 17 | .. tabs:: 18 | 19 | .. group-tab:: Notebook 20 | 21 | .. code-block:: python 22 | 23 | from hyrax import Hyrax 24 | 25 | h = Hyrax(config_file = 'my_config.toml') 26 | h.train() 27 | 28 | .. group-tab:: CLI 29 | 30 | .. 
code-block:: bash 31 | 32 | >> hyrax train -c my_config.toml 33 | -------------------------------------------------------------------------------- /.github/workflows/build-documentation.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflow will install Python dependencies, build the package and then build the documentation. 3 | 4 | name: Build documentation 5 | 6 | 7 | on: 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | build: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v6 24 | - name: Set up Python 3.11 25 | uses: actions/setup-python@v6 26 | with: 27 | python-version: '3.11' 28 | - name: Install dependencies 29 | run: | 30 | sudo apt-get update 31 | python -m pip install --upgrade pip 32 | if [ -f docs/requirements.txt ]; then pip install -r docs/requirements.txt; fi 33 | pip install . 34 | - name: Install notebook requirements 35 | run: | 36 | sudo apt-get install pandoc 37 | - name: Build docs 38 | run: | 39 | sphinx-build -T -E -b html -d docs/build/doctrees ./docs docs/build/html 40 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflow will upload a Python Package using Twine when a release is created 3 | # For more information see: https://github.com/pypa/gh-action-pypi-publish#trusted-publishing 4 | 5 | # This workflow uses actions that are not certified by GitHub. 6 | # They are provided by a third-party and are governed by 7 | # separate terms of service, privacy policy, and support 8 | # documentation. 9 | 10 | name: Upload Python Package 11 | 12 | on: 13 | release: 14 | types: [published] 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | deploy: 21 | 22 | runs-on: ubuntu-latest 23 | permissions: 24 | id-token: write 25 | steps: 26 | - uses: actions/checkout@v6 27 | - name: Set up Python 28 | uses: actions/setup-python@v6 29 | with: 30 | python-version: '3.11' 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install build 35 | - name: Build package 36 | run: python -m build 37 | - name: Publish package 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 LINCC Frameworks 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/hyrax/verbs/__init__.py: -------------------------------------------------------------------------------- 1 | # Remove import sorting, these are imported in the order written so that 2 | # autoapi docs are generated with ordering controlled below. 3 | # ruff: noqa: I001 4 | from hyrax.verbs.database_connection import DatabaseConnection 5 | from hyrax.verbs.umap import Umap 6 | from hyrax.verbs.infer import Infer 7 | from hyrax.verbs.train import Train 8 | from hyrax.verbs.visualize import Visualize 9 | from hyrax.verbs.lookup import Lookup 10 | from hyrax.verbs.save_to_database import SaveToDatabase 11 | from hyrax.verbs.model import Model 12 | from hyrax.verbs.to_onnx import ToOnnx 13 | from hyrax.verbs.engine import Engine 14 | from hyrax.verbs.verb_registry import Verb 15 | from hyrax.verbs.verb_registry import all_class_verbs, all_verbs, fetch_verb_class, is_verb_class 16 | 17 | __all__ = [ 18 | "VERB_REGISTRY", 19 | "is_verb_class", 20 | "fetch_verb_class", 21 | "all_class_verbs", 22 | "all_verbs", 23 | "Lookup", 24 | "Umap", 25 | "Visualize", 26 | "Infer", 27 | "Train", 28 | "SaveToDatabase", 29 | "Verb", 30 | "DatabaseConnection", 31 | "Model", 32 | "ToOnnx", 33 | "Engine", 34 | ] 35 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= -T -E -d _build/doctrees -D language=en -W --keep-going -n 7 | EXCLUDENB ?= -D exclude_patterns="notebooks/*","_build","**.ipynb_checkpoints" 8 | SPHINXBUILD ?= sphinx-build 9 | SOURCEDIR = . 10 | BUILDDIR = ../_readthedocs/ 11 | 12 | .PHONY: help clean Makefile no-nb no-notebooks 13 | 14 | # Put it first so that "make" without argument is like "make help". 15 | help: 16 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 17 | 18 | # Build all Sphinx docs locally, except the notebooks 19 | no-nb no-notebooks: 20 | @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(EXCLUDENB) $(O) 21 | 22 | # Cleans up files generated by the build process 23 | clean: 24 | rm -rf "_build/doctrees" 25 | rm -rf "autoapi" 26 | rm -rf "$(BUILDDIR)" 27 | 28 | # Catch-all target: route all unknown targets to Sphinx using the new 29 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 30 | %: Makefile 31 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 32 | 33 | -------------------------------------------------------------------------------- /docs/notebooks/model_input_3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b19e006d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Providing Data - Level 3\n", 9 | "## Directory-based vs. 
Percentage-based dataset splits\n", 10 | "When we don't provide a `model_inputs.validate` key, we'll use the dataset defined in `train` and the percentages defined in `data_set.train_size`, `data_set.validate_size` and `data_set.test_size` to define percentage-based or absolute-count splits.\n", 11 | "\n", 12 | "## Defining what data to use for inference\n", 13 | "Introduce `model_inputs.infer` as the way to define what data is provided to the model during inference.\n", 14 | "\n", 15 | "## Requesting multi-modal data\n", 16 | "\n", 17 | "In this notebook we'll see how to request data from multiple data sources using multiple ``HyraxDataset``s.\n", 18 | "\n", 19 | "To do so we'll start with a ``model_inputs`` data request. " 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "9b20cc7b", 25 | "metadata": {}, 26 | "source": [] 27 | } 28 | ], 29 | "metadata": { 30 | "language_info": { 31 | "name": "python" 32 | } 33 | }, 34 | "nbformat": 4, 35 | "nbformat_minor": 5 36 | } 37 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-ci.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflow runs pre-commit hooks on pushes and pull requests to main 3 | # to enforce coding style. To ensure correct configuration, please refer to: 4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/ci_precommit.html 5 | name: Run pre-commit hooks 6 | 7 | on: 8 |   push: 9 |     branches: [ main ] 10 |   pull_request: 11 |     branches: [ main ] 12 | 13 | jobs: 14 |   pre-commit-ci: 15 |     runs-on: ubuntu-latest 16 |     steps: 17 |       - uses: actions/checkout@v6 18 |         with: 19 |           fetch-depth: 0 20 |       - name: Set up Python 21 |         uses: actions/setup-python@v6 22 |         with: 23 |           python-version: '3.11' 24 |       - name: Install uv 25 |         uses: astral-sh/setup-uv@v5 26 |       - name: Install dependencies 27 |         run: | 28 |           sudo apt-get update 29 |           uv pip install --system .[dev] 30 |           if [ -f requirements.txt ]; then uv pip install --system -r requirements.txt; fi 31 |       - uses: pre-commit/action@v3.0.1 32 |         with: 33 |           extra_args: --all-files --verbose 34 |         env: 35 |           SKIP: "check-lincc-frameworks-template-version,no-commit-to-branch,check-added-large-files,validate-pyproject,sphinx-build,pytest-check" 36 |       - uses: pre-commit-ci/lite-action@v1.1.0 37 |         if: failure() && github.event_name == 'pull_request' && github.event.pull_request.draft == false -------------------------------------------------------------------------------- /.github/workflows/testing-and-coverage.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflow will install Python dependencies, run tests and report code coverage with a variety of Python versions 3 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 4 | 5 | name: Unit test and code coverage 6 | 7 | on: 8 |   push: 9 |     branches: [ main ] 10 |   pull_request: 11 |     branches: [ main ] 12 | 13 | jobs: 14 |   build: 15 | 16 |     runs-on: ubuntu-latest 17 |     strategy: 18 |       matrix: 19 |         python-version: ['3.10', '3.11', '3.12', '3.13'] 20 | 21 |     steps: 22 |     - uses: actions/checkout@v6 23 |     - name: Set up Python ${{ matrix.python-version }} 24 |       uses: actions/setup-python@v6 25 |       with: 26 |         python-version: ${{ matrix.python-version }} 27 |     - name: Install uv 28 |       uses: astral-sh/setup-uv@v5 29 |     - name: Install dependencies 30 |       run: | 31 |         sudo apt-get update 32 |         uv pip install --system -e .[dev] 33 |         if [ -f 
requirements.txt ]; then uv pip install --system -r requirements.txt; fi 34 | - name: Run unit tests with pytest 35 | run: | 36 | python -m pytest --cov=hyrax --cov-report=xml -m "not slow" 37 | - name: Upload coverage report to codecov 38 | uses: codecov/codecov-action@v5 39 | with: 40 | token: ${{ secrets.CODECOV_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/smoke-test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run daily at 06:45. 2 | # It will install Python dependencies and run tests with a variety of Python versions. 3 | # See documentation for help debugging smoke test issues: 4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/ci_testing.html#version-culprit 5 | 6 | name: Unit test smoke test 7 | 8 | on: 9 | 10 | # Runs this workflow automatically 11 | schedule: 12 | - cron: 45 6 * * * 13 | 14 | # Allows you to run this workflow manually from the Actions tab 15 | workflow_dispatch: 16 | 17 | jobs: 18 | build: 19 | 20 | runs-on: ubuntu-latest 21 | strategy: 22 | matrix: 23 | python-version: ['3.10', '3.11', '3.12', '3.13'] 24 | 25 | steps: 26 | - uses: actions/checkout@v6 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v6 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Install uv 32 | uses: astral-sh/setup-uv@v5 33 | - name: Install dependencies 34 | run: | 35 | sudo apt-get update 36 | uv pip install --system -e .[dev] 37 | if [ -f requirements.txt ]; then uv pip install --system -r requirements.txt; fi 38 | - name: List dependencies 39 | run: | 40 | pip list 41 | - name: Run unit tests with pytest 42 | run: | 43 | python -m pytest -m "not slow" -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 12 | 13 | ## Change Description 14 | 19 | - [ ] My PR includes a link to the issue that I am addressing 20 | 21 | 22 | 23 | ## Solution Description 24 | 25 | -------------------------------------------------------------------------------- /src/hyrax/models/hyrax_loopback.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | from .model_registry import hyrax_model 6 | 7 | logger = logging.getLogger() 8 | 9 | 10 | @hyrax_model 11 | class HyraxLoopback(nn.Module): 12 | """Simple model for testing which returns its own input""" 13 | 14 | def __init__(self, config, data_sample=None): 15 | from functools import partial 16 | 17 | super().__init__() 18 | # The optimizer needs at least one weight, so we add a dummy module here 19 | self.unused_module = nn.Linear(1, 1) 20 | self.config = config 21 | 22 | def load(self, weight_file): 23 | """Boilerplate function to load weights. However, this model has no 24 | weights so we do nothing.""" 25 | pass 26 | 27 | # We override this way rather than defining a method because 28 | # Torch has some __init__ related cleverness which stomps our 29 | # load definition when performed in the usual fashion. 
30 |         self.load = partial(load, self) 31 | 32 |     def forward(self, x): 33 |         """We simply return our input""" 34 |         if isinstance(x, (tuple, list)): 35 |             # if x is a tuple, extract the first element (it should be a tensor) 36 |             x, _ = x 37 |         return x 38 | 39 |     def train_step(self, batch): 40 |         """Training is a noop""" 41 |         logger.debug(f"Batch length: {len(batch)}") 42 |         return {"loss": 0.0} 43 | -------------------------------------------------------------------------------- /docs/dev_guide.rst: -------------------------------------------------------------------------------- 1 | Developer guide 2 | =============== 3 | 4 | Getting Started 5 | --------------- 6 | 7 | Before installing any dependencies or writing code, it's a great idea to create a 8 | virtual environment. LINCC-Frameworks engineers primarily use `conda` to manage virtual 9 | environments. If you have conda installed locally, you can run the following to 10 | create and activate a new environment. 11 | 12 | .. code-block:: console 13 | 14 |    >> conda create -n <env_name> python=3.10 15 |    >> conda activate <env_name> 16 | 17 | 18 | Build from Source 19 | ----------------- 20 | 21 | Once you have created a new environment, you can install this project for local 22 | development using the following commands: 23 | 24 | .. code-block:: console 25 | 26 |    >> git clone https://github.com/lincc-frameworks/hyrax.git 27 |    >> pip install -e .'[dev]' 28 |    >> pre-commit install 29 |    >> conda install pandoc 30 | 31 | 32 | Notes: 33 | 34 | 1) The single quotes around ``'[dev]'`` may not be required for your operating system. 35 | 2) ``pre-commit install`` will initialize pre-commit for this local repository, so 36 |    that a set of tests will be run prior to completing a local commit. For more 37 |    information, see the Python Project Template documentation on 38 |    `pre-commit <https://lincc-ppt.readthedocs.io/en/latest/practices/precommit.html>`_. 39 | 3) Installing ``pandoc`` allows you to verify that automatic rendering of Jupyter notebooks 40 |    into documentation for ReadTheDocs works as expected. For more information, see 41 |    the Python Project Template documentation on 42 |    `Sphinx and Python Notebooks <https://lincc-ppt.readthedocs.io/en/latest/practices/sphinx.html>`_. 
43 | -------------------------------------------------------------------------------- /src/hyrax/verbs/search.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import ArgumentParser, Namespace 3 | 4 | from .verb_registry import Verb, hyrax_verb 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hyrax_verb 10 | class Search(Verb): 11 | """Stub of similarity search""" 12 | 13 | cli_name = "search" 14 | add_parser_kwargs = {} 15 | 16 | @staticmethod 17 | def setup_parser(parser: ArgumentParser): 18 | """Stub of parser setup""" 19 | parser.add_argument("-i", "--image-file", type=str, help="Path to image file", required=True) 20 | 21 | # If both of these move to the verb superclass then a new verb is basically 22 | # 23 | # If you want no args, just make the class, define run(self) 24 | # If you want args 25 | # 1) write setup_parser (which sets up for ArgumentParser and name/type info for cli run) 26 | # 2) write run(self, ) to do what you want 27 | # 28 | 29 | # Should there be a version of this on the base class which uses a dict on the Verb 30 | # superclass to build the call to run based on what the subclass verb defined in setup_parser 31 | def run_cli(self, args: Namespace | None = None): 32 | """Stub CLI implementation""" 33 | logger.info("Search run from cli") 34 | if args is None: 35 | raise RuntimeError("Run CLI called with no arguments.") 36 | # This is where we map from CLI parsed args to a 37 | # self.run (args) call. 38 | return self.run(image_file=args.image_file) 39 | 40 | def run(self, image_file: str): 41 | """Search for... todo 42 | 43 | Parameters 44 | ---------- 45 | image_file : str 46 | _description_ 47 | """ 48 | logger.info(f"Got Image {image_file}") 49 | -------------------------------------------------------------------------------- /benchmarks/data_request_benchmarks.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | from hyrax import Hyrax 7 | 8 | 9 | class DatasetRequestBenchmarks: 10 | """Timing benchmarks for requesting data from the Hyrax random dataset""" 11 | 12 | def setup(self): 13 | """Prepare for benchmark by defining and setting up a random dataset""" 14 | self.tmp_dir = tempfile.TemporaryDirectory() 15 | self.input_dir = Path(self.tmp_dir.name) 16 | 17 | self.h = Hyrax() 18 | self.h.config["general"]["results_dir"] = str(self.input_dir) 19 | self.h.config["model_inputs"] = { 20 | "train": { 21 | "data": { 22 | "dataset_class": "HyraxRandomDataset", 23 | "data_location": str(self.input_dir), 24 | "fields": ["image", "label", "object_id"], 25 | } 26 | }, 27 | "infer": { 28 | "data": { 29 | "dataset_class": "HyraxRandomDataset", 30 | "data_location": str(self.input_dir), 31 | "fields": ["image", "label", "object_id"], 32 | } 33 | }, 34 | } 35 | 36 | num_vectors = 4096 37 | self.h.config["data_set"]["HyraxRandomDataset"]["size"] = num_vectors 38 | self.h.config["data_set"]["HyraxRandomDataset"]["seed"] = 0 39 | self.h.config["data_set"]["HyraxRandomDataset"]["shape"] = [3, 64, 64] 40 | 41 | self.ds = self.h.prepare() 42 | 43 | self.indexes = np.random.randint(0, num_vectors, size=128, dtype=int) 44 | 45 | def time_request_all_data(self): 46 | """Benchmark the amount of time needed to retrieve all the data from 47 | the random dataset 48 | """ 49 | for indx in self.indexes: 50 | self.ds["train"][indx] 51 | 
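asv discovers benchmarks by method prefix: `time_*` methods are timed, while `peakmem_*` methods report the peak memory of the process, and `setup`/`teardown` run around each benchmark. A hypothetical companion class (not part of the repository) that reuses the `setup` above via inheritance might look like:

```python
class DatasetRequestMemoryBenchmarks(DatasetRequestBenchmarks):
    """Hypothetical peak-memory counterpart to the timing benchmark above."""

    def peakmem_request_all_data(self):
        # Same access pattern as time_request_all_data; asv records the
        # process's peak memory while this method runs.
        for indx in self.indexes:
            self.ds["train"][indx]

    def teardown(self):
        # Remove the temporary results directory created in setup().
        self.tmp_dir.cleanup()
```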
-------------------------------------------------------------------------------- /src/hyrax/3d_viz/readme.md: -------------------------------------------------------------------------------- 1 | # Hyrax 3D Latent Space Explorer 2 | The Hyrax 3D Latent Space Explorer is a JavaScript-based tool that lets you visualize and interact with three-dimensional UMAP embeddings of your dataset. You can color your embeddings by different parameters and select different objects to see their catalog properties, as well as the source data (e.g., images for an image-based dataset that was projected onto a latent space). 3 | 4 | ## Saving UMAP-ed vectors as JSON 5 | * The first step in running the Hyrax 3D Latent Space Explorer is to convert the outputs from the Hyrax UMAP module into a `.json` file. 6 | * To do this, use save_umap_to_json.py. This can be run using `python save_umap_to_json.py /path/to/results/dir` 7 | * To understand optional arguments, do `python save_umap_to_json.py --help` 8 | 9 | 10 | ## Server Initialization 11 | * To start the Hyrax 3D Latent Space Explorer, type `python start_3d_viz_server.py` 12 | * This will launch the service on port 8181. If you are running this on a remote machine, forward this port appropriately using something like `ssh -N -L 8181:server_name:8181 username@loginnode.com` 13 | * Finally, navigate to http://localhost:8181/ where you will find the Hyrax 3D Latent Space Explorer running. 14 | * You can also change the port the server is served on, and pass a folder containing your cutouts. To see all the command line arguments, do `python start_3d_viz_server.py --help` 15 | * Note that the path passed to `cutouts_dir` is relative to the root of the server (i.e., the location of the `start_3d_viz_server.py` file) 16 | 17 | ## FAQs 18 | 1. If there are repeated Object IDs in your dataset, the second instance of the object will not load in the image viewer; you will keep seeing the image-loading spinner instead. 19 | 2. If images are not being loaded, something has likely gone wrong in the file-loading process. To debug, open the Developer Console of your browser. On Google Chrome, this is View --> Developer --> Developer Tools --> Console 20 | 21 | 22 | ## Simpler Notebook Version -- Deprecated 23 | For a more straightforward Plotly 3D plot, use the function in plotly_3d.py 24 | -------------------------------------------------------------------------------- /.github/workflows/publish-benchmarks-pr.yml: -------------------------------------------------------------------------------- 1 | # This workflow publishes a benchmarks comment on a pull request. It is triggered after the 2 | # benchmarks are computed in the asv-pr workflow. This separation of concerns allows us to limit 3 | # access to the target repository's private tokens and secrets, increasing the level of security. 4 | # Based on https://securitylab.github.com/research/github-actions-preventing-pwn-requests/. 
5 | name: Publish benchmarks comment to PR 6 | 7 | on: 8 |   workflow_run: 9 |     workflows: ["Run benchmarks for PR"] 10 |     types: [completed] 11 | 12 | jobs: 13 |   upload-pr-comment: 14 |     runs-on: ubuntu-latest 15 |     if: > 16 |       github.event.workflow_run.event == 'pull_request' && 17 |       github.event.workflow_run.conclusion == 'success' 18 |     permissions: 19 |       issues: write 20 |       pull-requests: write 21 |     steps: 22 |       - name: Display Workflow Run Information 23 |         run: | 24 |           echo "Workflow Run ID: ${{ github.event.workflow_run.id }}" 25 |           echo "Head SHA: ${{ github.event.workflow_run.head_sha }}" 26 |           echo "Head Branch: ${{ github.event.workflow_run.head_branch }}" 27 |           echo "Conclusion: ${{ github.event.workflow_run.conclusion }}" 28 |           echo "Event: ${{ github.event.workflow_run.event }}" 29 |       - name: Download artifact 30 |         uses: dawidd6/action-download-artifact@v11 31 |         with: 32 |           name: benchmark-artifacts 33 |           run_id: ${{ github.event.workflow_run.id }} 34 |       - name: Extract artifacts information 35 |         id: pr-info 36 |         run: | 37 |           printf "PR number: $(cat pr)\n" 38 |           printf "Output:\n$(cat output)" 39 |           printf "pr=$(cat pr)" >> $GITHUB_OUTPUT 40 |       - name: Find benchmarks comment 41 |         uses: peter-evans/find-comment@v4 42 |         id: find-comment 43 |         with: 44 |           issue-number: ${{ steps.pr-info.outputs.pr }} 45 |           comment-author: 'github-actions[bot]' 46 |           body-includes: view all benchmarks 47 |       - name: Create or update benchmarks comment 48 |         uses: peter-evans/create-or-update-comment@v5 49 |         with: 50 |           comment-id: ${{ steps.find-comment.outputs.comment-id }} 51 |           issue-number: ${{ steps.pr-info.outputs.pr }} 52 |           body-path: output 53 |           edit-mode: replace -------------------------------------------------------------------------------- /docs/model_comparison.rst: -------------------------------------------------------------------------------- 1 | .. _model_comparison: 2 | 3 | Model comparison 4 | ================ 5 | 6 | One goal of Hyrax is to make model evaluation easier. Many tools exist for visualization 7 | and evaluation of models. Hyrax integrates with TensorBoard and MLFlow to provide 8 | easy access to these tools. 9 | 10 | TensorBoard 11 | ----------- 12 | 13 | Hyrax automatically logs training, validation and GPU metrics (when available) to 14 | TensorBoard while training a model. 15 | This allows for easy visualization of the training process. 16 | 17 | For more information about TensorBoard see the 18 | `TensorBoard documentation <https://www.tensorflow.org/tensorboard>`_. 19 | 20 | MLFlow 21 | ------ 22 | 23 | Hyrax supports MLFlow for model tracking and experiment management. 24 | By default the data collected for each run will be nested under the experiment 25 | "notebook" using a run name that is the same as the results directory, 26 | i.e. ``<...>-train-<...>``. 27 | 28 | The MLFlow server can be run from within a notebook or from the command line. 29 | 30 | .. tabs:: 31 | 32 |    .. group-tab:: Notebook 33 | 34 |       .. code-block:: python 35 | 36 |          # Start the MLFlow UI server 37 |          backend_store_uri = f"file://{Path(f.config['general']['results_dir']).resolve() / 'mlflow'}" 38 |          mlflow_ui_process = subprocess.Popen( 39 |              ["mlflow", "ui", "--backend-store-uri", backend_store_uri, "--port", "8080"], 40 |              stdout=subprocess.PIPE, 41 |              stderr=subprocess.PIPE, 42 |          ) 43 | 44 |          # Display the MLFlow UI in an IFrame in the notebook 45 |          IFrame(src="http://localhost:8080", width="100%", height=1000) 46 | 47 |    .. group-tab:: CLI 48 | 49 |       .. 
code-block:: bash 50 | 51 |          >> mlflow ui --port 8080 --backend-store-uri <results_dir>/mlflow 52 | 53 | If you are running MLFlow on a remote server, you will need to add the `--host` flag to the command: 54 | 55 | .. code-block:: bash 56 | 57 |    >> mlflow ui --port 8080 --backend-store-uri <results_dir>/mlflow --host 0.0.0.0 58 | 59 | on the remote server, and then forward the port (8080) to your local machine using SSH. 60 | 61 | 62 | For more information about MLFlow see the 63 | `MLFlow documentation <https://mlflow.org/docs/latest/index.html>`_. 64 | -------------------------------------------------------------------------------- /.github/workflows/asv-main.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run benchmarks with airspeed velocity (asv), 2 | # store the new results in the "benchmarks" branch and publish them 3 | # to a dashboard on GH Pages. 4 | name: Run ASV benchmarks for main 5 | 6 | on: 7 |   push: 8 |     branches: [ main ] 9 | 10 | env: 11 |   PYTHON_VERSION: "3.11" 12 |   ASV_VERSION: "0.6.5" 13 |   WORKING_DIR: ${{github.workspace}}/benchmarks 14 | 15 | concurrency: 16 |   group: ${{github.workflow}}-${{github.ref}} 17 |   cancel-in-progress: true 18 | 19 | jobs: 20 |   asv-main: 21 |     runs-on: ubuntu-latest 22 |     permissions: 23 |       contents: write 24 |     defaults: 25 |       run: 26 |         working-directory: ${{env.WORKING_DIR}} 27 |     steps: 28 |       - name: Set up Python ${{env.PYTHON_VERSION}} 29 |         uses: actions/setup-python@v6 30 |         with: 31 |           python-version: ${{env.PYTHON_VERSION}} 32 |       - name: Checkout main branch of the repository 33 |         uses: actions/checkout@v6 34 |         with: 35 |           fetch-depth: 0 36 |       - name: Install dependencies 37 |         run: pip install asv[virtualenv]==${{env.ASV_VERSION}} 38 |       - name: Configure git 39 |         run: | 40 |           git config user.name "github-actions[bot]" 41 |           git config user.email "41898282+github-actions[bot]@users.noreply.github.com" 42 |       - name: Create ASV machine config file 43 |         run: asv machine --machine gh-runner --yes 44 |       - name: Fetch previous results from the "benchmarks" branch 45 |         run: | 46 |           if git ls-remote --exit-code origin benchmarks > /dev/null 2>&1; then 47 |             git merge origin/benchmarks \ 48 |               --allow-unrelated-histories \ 49 |               --no-commit 50 |             mv ../_results . 51 |           fi 52 |       - name: Run ASV for the main branch 53 |         run: asv run ALL --skip-existing --verbose || true 54 |       - name: Submit new results to the "benchmarks" branch 55 |         uses: JamesIves/github-pages-deploy-action@v4 56 |         with: 57 |           branch: benchmarks 58 |           folder: ${{env.WORKING_DIR}}/_results 59 |           target-folder: _results 60 |       - name: Generate dashboard HTML 61 |         run: | 62 |           asv show 63 |           asv publish 64 |       - name: Deploy to Github pages 65 |         uses: JamesIves/github-pages-deploy-action@v4 66 |         with: 67 |           branch: gh-pages 68 |           folder: ${{env.WORKING_DIR}}/_html -------------------------------------------------------------------------------- /docs/about.rst: -------------------------------------------------------------------------------- 1 | About Hyrax 2 | =========== 3 | 4 | What is Hyrax? 5 | -------------- 6 | Hyrax is a powerful and extensible machine learning framework that automates data 7 | acquisition, scales seamlessly from laptops to HPC, and ensures reproducibility 8 | — freeing astronomers to focus on discovery instead of infrastructure. 9 | 10 | Put another way, it's an effort to bring the best practices of software engineering 11 | to the astronomy machine learning community. 12 | 13 | 14 | Why build Hyrax? 15 | ---------------- 16 | Image-based machine learning in astronomy is challenging work. 
17 | It requires data collection, pre-processing, model training and evaluation, and 18 | analysis of results of inference - all of which introduce potential bottlenecks 19 | for new science. 20 | 21 | We've found that many bottlenecks require significant effort to overcome, and that 22 | effort doesn't accrue to science; it's just a means to an end. 23 | And worse, it's often duplicated effort by many different people, each solving 24 | the same problems in slightly different ways. 25 | 26 | Hyrax is our effort to make the process easier and more efficient by taking care of 27 | the common tasks so that scientists can focus on the science. 28 | 29 | 30 | Guiding principles 31 | ------------------ 32 | * **Principle 1** Empower the scientists to do science - not software engineering. 33 | Hyrax automatically discovers and uses the most performant hardware available 34 | for training without any changes to the user's code. 35 | * **Principle 2** Make the software easy to use. 36 | Hyrax is designed to be used in a Jupyter notebook for exploration or from the 37 | command line within HPC or Slurm environments without modification. 38 | * **Principle 3** Make the software extensible to support many different use cases. 39 | We work closely with scientists to build Hyrax to support their use cases, but 40 | we learned early on that we can't anticipate all the ways that Hyrax will be used. 41 | Hyrax is designed to be easily extended to support new models, data sources, 42 | and functionality. 43 | 44 | 45 | Commitment to open science 46 | --------------------------- 47 | Hyrax is open source software, and makes extensive use of open source libraries. 48 | We envision a Hyrax ecosystem where users can freely share data, trained models, 49 | and other components to accelerate their research and the research of others. 50 | -------------------------------------------------------------------------------- /src/hyrax/models/hsc_autoencoder.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101, D102 2 | 3 | 4 | import torch.nn as nn 5 | 6 | # extra long import here to address a circular import issue 7 | from hyrax.models.model_registry import hyrax_model 8 | 9 | 10 | @hyrax_model 11 | class HSCAutoencoder(nn.Module): # These shapes work with [3,258,258] inputs 12 | """ 13 | This autoencoder is designed to work with datasets that are prepared with Hyrax's HSC Data Set class. 14 | """ 15 | 16 | def __init__(self, config, data_sample=None): 17 | super().__init__() 18 | 19 | # Encoder 20 | self.encoder = nn.Sequential( 21 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 22 | nn.ReLU(), 23 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 24 | nn.ReLU(), 25 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 26 | nn.ReLU(), 27 | ) 28 | 29 | # Decoder 30 | self.decoder = nn.Sequential( 31 | nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), 32 | nn.ReLU(), 33 | nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 34 | nn.ReLU(), 35 | nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=4, output_padding=1), 36 | nn.Sigmoid(), # Output pixel values between 0 and 1 37 | ) 38 | 39 | self.config = config 40 | 41 | def forward(self, x): 42 | encoded = self.encoder(x) 43 | decoded = self.decoder(encoded) 44 | return decoded 45 | 46 | def train_step(self, batch): 47 | """ 48 | This function contains the logic for a single training step, i.e. 
the 49 | contents of the inner loop of an ML training process. 50 | 51 | Parameters 52 | ---------- 53 | batch : tuple 54 | A tuple containing the input data used to compute the loss 55 | 56 | Returns 57 | ------- 58 | Current loss value : dict 59 | Dictionary containing the loss value for the current batch. 60 | """ 61 | 62 | data = batch[0] 63 | self.optimizer.zero_grad() 64 | 65 | decoded = self.forward(data) 66 | loss = self.criterion(decoded, data) 67 | loss.backward() 68 | self.optimizer.step() 69 | 70 | return {"loss": loss.item()} 71 | -------------------------------------------------------------------------------- /src/hyrax/3d_viz/plotly_3d.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import plotly.graph_objects as go 6 | 7 | 8 | def plot_umap_3d_interactive(results_dir, fig_size=(800, 600)): 9 | """Reads UMAP batch results, extracts object IDs, 10 | and creates an interactive 3D scatter plot with tooltips.""" 11 | 12 | results_dir = Path(results_dir) 13 | 14 | # Find batch files matching 'batch_<N>.npy' 15 | batch_files = sorted([f for f in results_dir.glob("batch_*.npy") if re.match(r"batch_\d+\.npy$", f.name)]) 16 | 17 | if not batch_files: 18 | raise FileNotFoundError(f"No valid batch files found in {results_dir}") 19 | 20 | # Load embeddings and object IDs from all batches 21 | embeddings_list = [] 22 | object_ids_list = [] 23 | 24 | for batch_file in batch_files: 25 | data = np.load(batch_file) 26 | embeddings_list.append(data["tensor"]) 27 | object_ids_list.append(data["id"]) # Corrected field name 28 | 29 | # Concatenate all embeddings and object IDs 30 | embeddings = np.concatenate(embeddings_list, axis=0) 31 | object_ids = np.concatenate(object_ids_list, axis=0) 32 | 33 | if embeddings.shape[1] != 3: 34 | raise ValueError(f"Expected 3D embeddings, but got shape {embeddings.shape}") 35 | 36 | hover_texts = [ 37 | f"Object ID: {obj_id}<br>X: {x:.3f}<br>Y: {y:.3f}<br>Z: {z:.3f}" 38 | for (x, y, z), obj_id in zip(embeddings, object_ids) 39 | ] 40 | 41 | # Create interactive 3D scatter plot with Plotly 42 | fig = go.Figure( 43 | data=[ 44 | go.Scatter3d( 45 | x=embeddings[:, 0], 46 | y=embeddings[:, 1], 47 | z=embeddings[:, 2], 48 | mode="markers", 49 | marker=dict(size=3, color=embeddings[:, 2], colorscale="Viridis", opacity=0.8), 50 | text=hover_texts, 51 | hoverinfo="text", 52 | ) 53 | ] 54 | ) 55 | 56 | # Customize layout 57 | fig.update_layout( 58 | title="Interactive 3D UMAP Embeddings", 59 | width=fig_size[0], 60 | height=fig_size[1], 61 | scene=dict( 62 | xaxis_title="UMAP Dimension 1", 63 | yaxis_title="UMAP Dimension 2", 64 | zaxis_title="UMAP Dimension 3", 65 | bgcolor="black", 66 | ), 67 | template="plotly_dark", 68 | ) 69 | 70 | fig.show() 71 | -------------------------------------------------------------------------------- /src/hyrax_cli/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from importlib.metadata import version 4 | 5 | from hyrax import Hyrax 6 | from hyrax.verbs import all_verbs, fetch_verb_class, is_verb_class 7 | 8 | 9 | def main(): 10 | """Primary entry point for the Hyrax CLI. This handles dispatching to the various 11 | Hyrax actions and returning a result. 12 | """ 13 | 14 | description = "Hyrax CLI" 15 | epilog = "Hyrax is the Framework for Image-Based Anomaly Detection" 16 | 17 | parser = argparse.ArgumentParser(description=description, epilog=epilog) 18 | parser.add_argument("--version", dest="version", action="store_true", help="Show version") 19 | 20 | # cut off "usage: " from beginning and "\n" from end so we get an invocation 21 | # which subcommand parsers can add to appropriately. 22 | subparser_usage_prefix = parser.format_usage()[7:-1] 23 | subparsers = parser.add_subparsers(title="Verbs:", required=False) 24 | 25 | # Add a subparser for every verb (whether defined by function or class) 26 | for cli_name in all_verbs(): 27 | subparser_kwargs = {} 28 | 29 | if is_verb_class(cli_name): 30 | verb_class = fetch_verb_class(cli_name) 31 | subparser_kwargs = verb_class.add_parser_kwargs 32 | 33 | verb_parser = subparsers.add_parser( 34 | cli_name, prog=subparser_usage_prefix + " " + cli_name, **subparser_kwargs 35 | ) 36 | 37 | if is_verb_class(cli_name): 38 | verb_class.setup_parser(verb_parser) 39 | 40 | verb_parser.set_defaults(verb=cli_name) 41 | _add_major_arguments(verb_parser) 42 | 43 | args = parser.parse_args() 44 | 45 | if args.version: 46 | print(version("hyrax")) 47 | return 48 | 49 | if not getattr(args, "verb", None): 50 | parser.print_help() 51 | sys.exit(1) 52 | 53 | hyrax_instance = Hyrax(config_file=args.runtime_config) 54 | retval = 0 55 | if is_verb_class(args.verb): 56 | verb = fetch_verb_class(args.verb)(hyrax_instance.config) 57 | retval = verb.run_cli(args) 58 | else: 59 | getattr(hyrax_instance, args.verb)() 60 | 61 | sys.exit(retval) 62 | 63 | 64 | def _add_major_arguments(parser): 65 | parser.add_argument("--version", dest="version", action="store_true", help="Show version") 66 | parser.add_argument("-c", "--runtime-config", type=str, help="Full path to runtime config file") 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /tests/hyrax/test_umap.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | 3 | import numpy as np 4 | 5 | 6 | class FakeUmap: 7 | """ 8 | A Fake 
implementation of umap.UMAP which simply returns what is passed to it. 9 | This works with the loopback model and random dataset since they both output 10 | pairs of points, so the umap output is also pairs of points 11 | 12 | Install on a test like 13 | 14 | @mock.patch("umap.UMAP", FakeUmap) 15 | def test_blah(): 16 | pass 17 | """ 18 | 19 | def __init__(self, *args, **kwargs): 20 | print("Called FakeUmap init") 21 | 22 | def fit(self, data): 23 | """We do nothing when fit on data. Prints are purely to help debug tests""" 24 | print("Called FakeUmap fit:") 25 | print(f"shape: {data.shape}") 26 | print(f"dtype: {data.dtype}") 27 | 28 | def transform(self, data): 29 | """We return our input when called to transform. Prints are purely to help debug tests""" 30 | print("Called FakeUmap transform:") 31 | print(f"shape: {data.shape}") 32 | print(f"dtype: {data.dtype}") 33 | return data 34 | 35 | 36 | @mock.patch("umap.UMAP", FakeUmap) 37 | def test_umap_order(loopback_inferred_hyrax): 38 | """Test that the order of data run through infer 39 | is correct in the presence of several splits 40 | """ 41 | h, dataset, _ = loopback_inferred_hyrax 42 | 43 | dataset = dataset["infer"] 44 | 45 | umap_results = h.umap() 46 | umap_result_ids = list(umap_results.ids()) 47 | original_dataset_ids = list(dataset.ids()) 48 | 49 | if dataset.is_iterable(): 50 | dataset = list(dataset) 51 | original_dataset_ids = np.array([str(s["object_id"]) for s in dataset]) 52 | 53 | data_shape = h.config["data_set"]["HyraxRandomDataset"]["shape"] 54 | 55 | for idx, result_id in enumerate(umap_result_ids): 56 | dataset_idx = None 57 | for i, orig_id in enumerate(original_dataset_ids): 58 | if orig_id == result_id: 59 | dataset_idx = i 60 | break 61 | else: 62 | raise AssertionError("Failed to find a corresponding ID") 63 | 64 | umap_result = umap_results[idx].cpu().numpy().reshape(data_shape) 65 | 66 | print(f"orig idx: {dataset_idx}, umap idx: {idx}") 67 | print(f"orig data: {dataset[dataset_idx]}, umap data: {umap_result}") 68 | assert np.all(np.isclose(dataset[dataset_idx]["data"]["image"], umap_result)) 69 | -------------------------------------------------------------------------------- /.setup_dev.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Bash Unofficial strict mode (http://redsymbol.net/articles/unofficial-bash-strict-mode/) 4 | # and (https://disconnected.systems/blog/another-bash-strict-mode/) 5 | set -o nounset # Any uninitialized variable is an error 6 | set -o errexit # Exit the script on the failure of any command to execute without error 7 | set -o pipefail # Fail command pipelines on the failure of any individual step 8 | IFS=$'\n\t' #set internal field separator to avoid iteration errors 9 | # Trap all exits and output something helpful 10 | trap 's=$?; echo "$0: Error on line "$LINENO": $BASH_COMMAND"; exit $s' ERR 11 | 12 | # This script should be run by new developers to install this package in 13 | # editable mode and configure their local environment 14 | 15 | echo "Checking virtual environment" 16 | if [ "${VIRTUAL_ENV:-missing}" = "missing" ] && [ "${CONDA_PREFIX:-missing}" = "missing" ]; then 17 | echo 'No virtual environment detected: none of $VIRTUAL_ENV or $CONDA_PREFIX is set.' 18 | echo 19 | echo "=== This script is going to install the project in the system python environment ===" 20 | echo "Proceed? 
[y/N]" 21 | read -r RESPONCE 22 | if [ "${RESPONCE}" != "y" ]; then 23 | echo "See https://lincc-ppt.readthedocs.io/ for details." 24 | echo "Exiting." 25 | exit 1 26 | fi 27 | 28 | fi 29 | 30 | echo "Checking pip version" 31 | MINIMUM_PIP_VERSION=22 32 | pipversion=( $(python -m pip --version | awk '{print $2}' | sed 's/\./\n\t/g') ) 33 | if let "${pipversion[0]}<${MINIMUM_PIP_VERSION}"; then 34 | echo "Insufficient version of pip found. Requires at least version ${MINIMUM_PIP_VERSION}." 35 | echo "See https://lincc-ppt.readthedocs.io/ for details." 36 | exit 1 37 | fi 38 | 39 | echo "Installing package and runtime dependencies in local environment" 40 | echo "This might take a few minutes. Only errors will be printed to stdout" 41 | python -m pip install -e . > /dev/null 42 | 43 | echo "Installing developer dependencies in local environment" 44 | echo "This might take a few minutes. Only errors will be printed to stdout" 45 | python -m pip install -e .'[dev]' > /dev/null 46 | if [ -f docs/requirements.txt ]; then python -m pip install -r docs/requirements.txt > /dev/null; fi 47 | 48 | echo "Installing pre-commit" 49 | pre-commit install > /dev/null 50 | 51 | ####################################################### 52 | # Include any additional configurations below this line 53 | ####################################################### 54 | -------------------------------------------------------------------------------- /src/hyrax/verbs/verb_registry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class Verb(ABC): # noqa: B024 8 | """Base class for all hyrax verbs""" 9 | 10 | # Verbs get to define how their parser gets added to the main parser 11 | # This is given in case verbs do not define any keyword args for 12 | # subparser.add_parser() 13 | add_parser_kwargs: dict[str, str] = {} 14 | 15 | def __init__(self, config): 16 | """ 17 | .. py:method:: __init__ 18 | 19 | Overall initialization for all verbs that saves the config 20 | """ 21 | self.config = config 22 | 23 | 24 | # Verbs with no class are assumed to have a function in hyrax.py which 25 | # performs their function. 
All other verbs should be defined by named classes 26 | # in hyrax.verbs and use the @hyrax_verb decorator 27 | VERB_REGISTRY: dict[str, type[Verb] | None] = { 28 | "train": None, 29 | "infer": None, 30 | "download": None, 31 | "prepare": None, 32 | "rebuild_manifest": None, 33 | } 34 | 35 | 36 | def hyrax_verb(cls: type[Verb]) -> type[Verb]: 37 | """Decorator to register a hyrax verb""" 38 | from hyrax.plugin_utils import update_registry 39 | 40 | update_registry(VERB_REGISTRY, cls.cli_name, cls) # type: ignore[attr-defined] 41 | return cls 42 | 43 | 44 | def all_verbs() -> list[str]: 45 | """Returns all verbs that are currently registered""" 46 | return [verb for verb in VERB_REGISTRY] 47 | 48 | 49 | def all_class_verbs() -> list[str]: 50 | """Returns all verbs that are currently registered with a class-based implementation""" 51 | return [verb for verb in VERB_REGISTRY if VERB_REGISTRY.get(verb) is not None] 52 | 53 | 54 | def is_verb_class(cli_name: str) -> bool: 55 | """Returns true if the verb has a class based implementation 56 | 57 | Parameters 58 | ---------- 59 | cli_name : str 60 | The name of the verb on the command line interface 61 | 62 | Returns 63 | ------- 64 | bool 65 | True if the verb has a class-based implementation 66 | """ 67 | return cli_name in VERB_REGISTRY and VERB_REGISTRY.get(cli_name) is not None 68 | 69 | 70 | def fetch_verb_class(cli_name: str) -> type[Verb] | None: 71 | """Gives the class object for the named verb 72 | 73 | Parameters 74 | ---------- 75 | cli_name : str 76 | The name of the verb on the command line interface 77 | 78 | 79 | Returns 80 | ------- 81 | Optional[type[Verb]] 82 | The verb class or None if no such verb class exists. 83 | """ 84 | return VERB_REGISTRY.get(cli_name) 85 | -------------------------------------------------------------------------------- /README_RSP.md: -------------------------------------------------------------------------------- 1 | # Contribution instructions from an RSP 2 | 3 | This is a set of instructions for how to get a source checkout of hyrax working on an RSP. You will need to do this if you are developing a hyrax feature or modifying one of the in-built dataset classes. 4 | 5 | The instructions are tailored to the USDF RSP but should be modifiable for other RSPs. 6 | 7 | ## Setting up your repository 8 | First add your SSH key to GitHub. This is so you will eventually be able to push your changes. 9 | On USDF, the public version of your key is in `~/.ssh/s3df/id_ed25519.pub`. You will need to paste the contents of this key into GitHub [here](https://github.com/settings/ssh/new) to add it to your account. 10 | 11 | Next you need to tell git to use this key when it uses ssh. This does not work by default on USDF, so you will need to clone the repository with the command below. It's recommended you run this from your `~/rubin-user` directory such that your checkout will be in `~/rubin-user/hyrax`. This way it will persist between notebook server invocations in USDF. 12 | 13 | ``` 14 | GIT_SSH_COMMAND='ssh -i ~/.ssh/s3df/id_ed25519 -o IdentitiesOnly=yes' git clone git@github.com:lincc-frameworks/hyrax.git 15 | ``` 16 | 17 | After this you can enter the directory and run the following command to make sure all future git commands use your private key: 18 | 19 | ``` 20 | git config --global core.sshCommand 'ssh -i ~/.ssh/s3df/id_ed25519 -o IdentitiesOnly=yes' 21 | ``` 22 | 23 | At this point you have an active git repository and you should be able to `git fetch` without error. 
It is recommended that you set your global username and email with the following commands if you intend to commit: 24 | 25 | ``` 26 | git config --global user.email "email@domain.tld" 27 | git config --global user.name "Your Name" 28 | ``` 29 | 30 | ## Setting up your notebook 31 | After this setup you can run the following notebook magic to bootstrap an editable install of hyrax in your notebook environment. 32 | 33 | ``` 34 | # In the first cell of your notebook 35 | %pip install -q -e ~/rubin-user/hyrax 2>&1 | grep -vE 'WARNING: Error parsing dependencies of (lsst-|astshim|astro-)' 36 | ``` 37 | 38 | This command may suggest you restart your kernel, and you should do so if it asks. When you edit hyrax you will need to restart your kernel, but will not need to re-run this install command. 39 | 40 | ## Running Hyrax 41 | 42 | You can now run hyrax in a notebook cell. This is a sample that uses the LSSTDataset class, which only functions inside of the RSP. 43 | 44 | ``` 45 | import hyrax 46 | h = hyrax.Hyrax() 47 | h.config["data_set"]["name"] = "LSSTDataset" 48 | 49 | d = h.prepare() 50 | 51 | d[0].shape 52 | ``` 53 | 54 | -------------------------------------------------------------------------------- /.github/workflows/asv-nightly.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run daily at 06:45. 2 | # It will run benchmarks with airspeed velocity (asv) 3 | # and compare performance with the previous nightly build. 4 | name: Run benchmarks nightly job 5 | 6 | on: 7 | schedule: 8 | - cron: 45 6 * * * 9 | workflow_dispatch: 10 | 11 | env: 12 | PYTHON_VERSION: "3.11" 13 | ASV_VERSION: "0.6.5" 14 | WORKING_DIR: ${{github.workspace}}/benchmarks 15 | NIGHTLY_HASH_FILE: nightly-hash 16 | 17 | jobs: 18 | asv-nightly: 19 | runs-on: ubuntu-latest 20 | defaults: 21 | run: 22 | working-directory: ${{env.WORKING_DIR}} 23 | steps: 24 | - name: Set up Python ${{env.PYTHON_VERSION}} 25 | uses: actions/setup-python@v6 26 | with: 27 | python-version: ${{env.PYTHON_VERSION}} 28 | - name: Checkout main branch of the repository 29 | uses: actions/checkout@v6 30 | with: 31 | fetch-depth: 0 32 | - name: Install dependencies 33 | run: pip install asv[virtualenv]==${{env.ASV_VERSION}} 34 | - name: Configure git 35 | run: | 36 | git config user.name "github-actions[bot]" 37 | git config user.email "41898282+github-actions[bot]@users.noreply.github.com" 38 | - name: Create ASV machine config file 39 | run: asv machine --machine gh-runner --yes 40 | - name: Fetch previous results from the "benchmarks" branch 41 | run: | 42 | if git ls-remote --exit-code origin benchmarks > /dev/null 2>&1; then 43 | git merge origin/benchmarks \ 44 | --allow-unrelated-histories \ 45 | --no-commit 46 | mv ../_results . 
47 | fi 48 | - name: Get nightly dates under comparison 49 | id: nightly-dates 50 | run: | 51 | echo "yesterday=$(date -d yesterday +'%Y-%m-%d')" >> $GITHUB_OUTPUT 52 | echo "today=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT 53 | - name: Use last nightly commit hash from cache 54 | uses: actions/cache@v4 55 | with: 56 | path: ${{env.WORKING_DIR}} 57 | key: nightly-results-${{steps.nightly-dates.outputs.yesterday}} 58 | - name: Run comparison of main against last nightly build 59 | run: | 60 | HASH_FILE=${{env.NIGHTLY_HASH_FILE}} 61 | CURRENT_HASH=${{github.sha}} 62 | if [ -f $HASH_FILE ]; then 63 | PREV_HASH=$(cat $HASH_FILE) 64 | asv continuous $PREV_HASH $CURRENT_HASH --verbose || true 65 | asv compare $PREV_HASH $CURRENT_HASH --sort ratio --verbose 66 | fi 67 | echo $CURRENT_HASH > $HASH_FILE 68 | - name: Update last nightly hash in cache 69 | uses: actions/cache@v4 70 | with: 71 | path: ${{env.WORKING_DIR}} 72 | key: nightly-results-${{steps.nightly-dates.outputs.today}} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hyrax 2 | [![Template](https://img.shields.io/badge/Template-LINCC%20Frameworks%20Python%20Project%20Template-brightgreen)](https://lincc-ppt.readthedocs.io/en/latest/) 3 | [![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/lincc-frameworks/hyrax/smoke-test.yml)](https://github.com/lincc-frameworks/hyrax/actions/workflows/smoke-test.yml) 4 | [![codecov](https://codecov.io/gh/lincc-frameworks/hyrax/branch/main/graph/badge.svg)](https://codecov.io/gh/lincc-frameworks/hyrax) 5 | [![Read the Docs](https://img.shields.io/readthedocs/hyrax)](https://hyrax.readthedocs.io/en/latest) 6 | [![PyPI](https://img.shields.io/pypi/v/hyrax?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/hyrax/) 7 |
8 | 9 | ## Introduction 10 | Hyrax is an efficient tool 11 | to hunt for rare and anomalous sources in large astronomical imaging surveys 12 | (e.g., Rubin-LSST, HSC, Euclid, NGRST, etc.). 13 | Hyrax is designed to support four primary steps in the anomaly detection workflow: 14 | 15 | * Downloading large numbers of cutouts from public data repositories 16 | * Building lower dimensional representations of downloaded images -- the latent space 17 | * Interactive visualization and algorithmic exploration (e.g., clustering, similarity-search, etc.) of the latent space 18 | * Identification & rank-ordering of potential anomalous objects 19 | 20 | Hyrax is not tied to a specific anomaly detection algorithm/model or a specific 21 | class of rare/anomalous objects, but rather intended to support any algorithm 22 | that the user may want to apply on imaging data. 23 | If the algorithm you want to use takes in tensors, outputs tensors, and can be 24 | implemented in PyTorch, then chances are Hyrax is the right tool for you! 25 | 26 | ## Getting Started 27 | To get started with Hyrax, clone the repository and create a new virtual environment. 28 | If you plan to develop code, run the ``.setup_dev.sh`` script. 29 | 30 | ``` 31 | >> git clone https://github.com/lincc-frameworks/hyrax.git 32 | >> conda create -n hyrax python=3.10 33 | >> bash .setup_dev.sh (Optional, for developers) 34 | ``` 35 | 36 | ## Additional Information 37 | Hyrax is under active development and has limited documentation at the moment. 38 | We aim to have v1 stability and more documentation in the first half of 2025. 39 | If you are an astronomer trying to use Hyrax before then, please get in touch with us! 40 | 41 | This project started as a collaboration between different units within the 42 | [LSST Discovery Alliance](https://lsstdiscoveryalliance.org/) -- 43 | the [LINCC Frameworks Team](https://lsstdiscoveryalliance.org/programs/lincc-frameworks/) 44 | and LSST-DA Catalyst Fellow, [Aritra Ghosh](https://ghosharitra.com/). 45 | 46 | ## Acknowledgements 47 | 48 | This project is supported by Schmidt Sciences. 49 | -------------------------------------------------------------------------------- /.github/workflows/asv-pr.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run benchmarks with airspeed velocity (asv) for pull requests. 2 | # It will compare the performance of the main branch with the performance of the merge 3 | # with the new changes. It then publishes a comment with this assessment by triggering 4 | # the publish-benchmarks-pr workflow. 5 | # Based on https://securitylab.github.com/research/github-actions-preventing-pwn-requests/. 
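# A sketch of that handoff, using only the step and artifact names already
# defined in these two workflow files (the pull_request trigger runs with a
# read-only token, so this workflow can only upload artifacts; the separately
# triggered publish-benchmarks-pr.yml workflow has the write permissions
# needed to post the comment):
#
#   asv-pr.yml:                save PR number to "pr", benchmark output to "output"  ->  upload "benchmark-artifacts"
#   publish-benchmarks-pr.yml: download "benchmark-artifacts"  ->  read "pr"/"output"  ->  create or update PR comment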
6 | name: Run benchmarks for PR 7 | 8 | on: 9 | pull_request: 10 | branches: [ main ] 11 | workflow_dispatch: 12 | 13 | concurrency: 14 | group: ${{github.workflow}}-${{github.ref}} 15 | cancel-in-progress: true 16 | 17 | env: 18 | PYTHON_VERSION: "3.11" 19 | ASV_VERSION: "0.6.5" 20 | WORKING_DIR: ${{github.workspace}}/benchmarks 21 | ARTIFACTS_DIR: ${{github.workspace}}/artifacts 22 | 23 | jobs: 24 | asv-pr: 25 | runs-on: ubuntu-latest 26 | defaults: 27 | run: 28 | working-directory: ${{env.WORKING_DIR}} 29 | steps: 30 | - name: Set up Python ${{env.PYTHON_VERSION}} 31 | uses: actions/setup-python@v6 32 | with: 33 | python-version: ${{env.PYTHON_VERSION}} 34 | - name: Checkout PR branch of the repository 35 | uses: actions/checkout@v6 36 | with: 37 | fetch-depth: 0 38 | - name: Display Workflow Run Information 39 | run: | 40 | echo "Workflow Run ID: ${{github.run_id}}" 41 | - name: Install dependencies 42 | run: pip install asv[virtualenv]==${{env.ASV_VERSION}} lf-asv-formatter 43 | - name: Make artifacts directory 44 | run: mkdir -p ${{env.ARTIFACTS_DIR}} 45 | - name: Save pull request number 46 | run: echo ${{github.event.pull_request.number}} > ${{env.ARTIFACTS_DIR}}/pr 47 | - name: Get current job logs URL 48 | uses: Tiryoh/gha-jobid-action@v1 49 | id: jobs 50 | with: 51 | github_token: ${{secrets.GITHUB_TOKEN}} 52 | job_name: ${{github.job}} 53 | - name: Create ASV machine config file 54 | run: asv machine --machine gh-runner --yes 55 | - name: Save comparison of PR against main branch 56 | run: | 57 | git remote add upstream https://github.com/${{github.repository}}.git 58 | git fetch upstream 59 | asv continuous upstream/main HEAD --verbose || true 60 | asv compare upstream/main HEAD --sort ratio --verbose | tee output 61 | python -m lf_asv_formatter --asv_version "$(asv --version | awk '{print $2}')" 62 | printf "\n\nClick [here]($STEP_URL) to view all benchmarks." 
>> output 63 | mv output ${{env.ARTIFACTS_DIR}} 64 | env: 65 | STEP_URL: ${{steps.jobs.outputs.html_url}}#step:10:1 66 | - name: Upload artifacts (PR number and benchmarks output) 67 | uses: actions/upload-artifact@v5 68 | with: 69 | name: benchmark-artifacts 70 | path: ${{env.ARTIFACTS_DIR}} -------------------------------------------------------------------------------- /tests/hyrax/test_csv_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | import hyrax 6 | 7 | 8 | @pytest.fixture(scope="function") 9 | def test_hyrax_csv_dataset(): 10 | """Fixture that gives a hyrax object configured to use a simple CSV dataset""" 11 | this_file_dir = Path(__file__).absolute().parent 12 | csv_file = this_file_dir / "test_data" / "csv_test" / "sample_data.csv" 13 | 14 | h = hyrax.Hyrax() 15 | h.config["model_inputs"] = { 16 | "train": { 17 | "data": { 18 | "dataset_class": "HyraxCSVDataset", 19 | "data_location": str(csv_file), 20 | "primary_id_field": "object_id", 21 | }, 22 | }, 23 | } 24 | 25 | return h 26 | 27 | 28 | def test_csv_dataset_initialization(test_hyrax_csv_dataset): 29 | """Check that the CSV dataset is correctly initialized and can be accessed""" 30 | dataset = test_hyrax_csv_dataset.prepare() 31 | 32 | # Dataset has correct length 33 | assert len(dataset["train"]) == 5 34 | 35 | 36 | def test_csv_dataset_column_getters(test_hyrax_csv_dataset): 37 | """Check that column getter methods are dynamically created""" 38 | dataset = test_hyrax_csv_dataset.prepare() 39 | 40 | # Get the underlying HyraxCSVDataset instance 41 | csv_dataset = dataset["train"]._primary_or_first_dataset() 42 | 43 | # Check that getter methods exist for each column 44 | assert hasattr(csv_dataset, "get_object_id") 45 | assert hasattr(csv_dataset, "get_ra") 46 | assert hasattr(csv_dataset, "get_dec") 47 | assert hasattr(csv_dataset, "get_magnitude") 48 | assert hasattr(csv_dataset, "get_flux") 49 | assert hasattr(csv_dataset, "get_classification") 50 | 51 | # Check that getter methods return correct values 52 | assert csv_dataset.get_object_id(0) == 1001 53 | assert csv_dataset.get_ra(0) == 30.5 54 | assert csv_dataset.get_classification(0) == "star" 55 | 56 | 57 | def test_csv_dataset_sample_data(test_hyrax_csv_dataset): 58 | """Check that sample_data returns the first row correctly""" 59 | dataset = test_hyrax_csv_dataset.prepare() 60 | 61 | # Get the underlying HyraxCSVDataset instance 62 | csv_dataset = dataset["train"]._primary_or_first_dataset() 63 | sample = csv_dataset.sample_data() 64 | 65 | # Check that sample has the expected structure 66 | assert "data" in sample 67 | assert "object_id" in sample["data"] 68 | assert "ra" in sample["data"] 69 | assert "dec" in sample["data"] 70 | assert "magnitude" in sample["data"] 71 | assert "flux" in sample["data"] 72 | assert "classification" in sample["data"] 73 | 74 | # Check that sample values match the first row 75 | assert sample["data"]["object_id"] == 1001 76 | assert sample["data"]["ra"] == 30.5 77 | assert sample["data"]["classification"] == "star" 78 | -------------------------------------------------------------------------------- /tests/hyrax/test_fits_image_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from torch import Size 5 | 6 | import hyrax 7 | 8 | 9 | @pytest.fixture(scope="function", params=[".astropy.csv", ".csv", ".fits", ".pq", ".votable"]) 10 | def 
test_hyrax_small_dataset_hscstars(request): 11 | """Fixture that gives a hyrax object configured to use the small_dataset_hscstars dataset 12 | 13 | Several common table formats are available for this dataset 14 | """ 15 | this_file_dir = Path(__file__).absolute().parent 16 | catalog_file = this_file_dir / "test_data" / "small_dataset_hscstars" / f"star_cat_correct{request.param}" 17 | 18 | h = hyrax.Hyrax() 19 | h.config["model_inputs"] = { 20 | "train": { 21 | "data": { 22 | "dataset_class": "FitsImageDataSet", 23 | "data_location": str(catalog_file.parent), 24 | "primary_id_field": "object_id", 25 | }, 26 | }, 27 | "infer": { 28 | "data": { 29 | "dataset_class": "FitsImageDataSet", 30 | "data_location": str(catalog_file.parent), 31 | "primary_id_field": "object_id", 32 | }, 33 | }, 34 | } 35 | h.config["data_set"]["filter_catalog"] = str(catalog_file) 36 | h.config["data_set"]["crop_to"] = [20, 20] 37 | 38 | object_id_column_name = "___object_id" if request.param == ".votable" else "# object_id" 39 | 40 | h.config["data_set"]["object_id_column_name"] = object_id_column_name 41 | h.config["data_set"]["filename_column_name"] = "star_filename" 42 | 43 | return h 44 | 45 | 46 | def test_prepare(test_hyrax_small_dataset_hscstars): 47 | """Check that the hsc stars dataset was correctly read, and that basic access 48 | of FitsImageDataSet returns sane values 49 | """ 50 | ds = test_hyrax_small_dataset_hscstars.prepare() 51 | 52 | # Dataset has correct length 53 | for subset in ds: 54 | a = ds[subset] 55 | assert len(a) == 10 56 | 57 | # All tensors are the correct size 58 | for d in a: 59 | assert d["data"]["image"].shape == Size([1, 20, 20]) 60 | 61 | # Selected columns in the original catalog exist 62 | assert "ira" in a.metadata_fields("data") 63 | assert "idec" in a.metadata_fields("data") 64 | assert "SNR" in a.metadata_fields("data") 65 | 66 | # IDs are correct and in the correct order 67 | assert list(a.ids()) == [ 68 | "36411452835238206", 69 | "36411452835248579", 70 | "36411452835249051", 71 | "36411452835250175", 72 | "36411457130203411", 73 | "36411457130204168", 74 | "36411457130206288", 75 | "36411457130214646", 76 | "36411457130215774", 77 | "36411457130216436", 78 | ] 79 | -------------------------------------------------------------------------------- /tests/hyrax/test_save_to_database.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | 6 | def test_save_to_database(loopback_inferred_hyrax): 7 | """Test that the data inserted into the vector database is not corrupted. i.e. 8 | that we can match ids to input vectors for all values.""" 9 | 10 | h, dataset, inference_results = loopback_inferred_hyrax 11 | inference_result_ids = np.array(list(inference_results.ids())) 12 | original_dataset_ids = np.array(list(dataset["infer"].ids())) 13 | 14 | # If the dataset is iterable, convert it to a list for easier indexing 15 | if dataset["infer"].is_iterable(): 16 | dataset = list(dataset["infer"]) 17 | original_dataset_ids = np.array([str(s["object_id"]) for s in dataset]) 18 | else: 19 | dataset = dataset["infer"] 20 | 21 | h.config["vector_db"]["name"] = "chromadb" 22 | original_shape = h.config["data_set"]["HyraxRandomDataset"]["shape"] 23 | 24 | # Populate the vector database with the results of inference 25 | vdb_path = h.config["general"]["results_dir"] 26 | h.save_to_database(output_dir=vdb_path) 27 | 28 | # Get a connection to the database that was just created. 
29 | db_connection = h.database_connection(database_dir=vdb_path) 30 | 31 | # Verify that every inserted vector id matches the original vector 32 | for id in inference_result_ids: 33 | # Since the ordering of inference results is not guaranteed to match the 34 | # original dataset, we need to find the index of the original dataset id 35 | # that corresponds to the inference result id. 36 | assert id in original_dataset_ids, f"Inference ID, {id} not found in original dataset IDs." 37 | orig_indx = np.where(original_dataset_ids == id)[0][0] 38 | result = db_connection.get_by_id(id) 39 | saved_value = result[id].reshape(original_shape) 40 | original_value = dataset[orig_indx]["data"]["image"] 41 | assert np.all(np.isclose(saved_value, original_value)) 42 | 43 | 44 | def test_save_to_database_tensorboard_logging(loopback_inferred_hyrax): 45 | """Test that Tensorboard logs are created during vector database insertion.""" 46 | 47 | h, dataset, inference_results = loopback_inferred_hyrax 48 | h.config["vector_db"]["name"] = "chromadb" 49 | 50 | # Populate the vector database with the results of inference 51 | vdb_path = h.config["general"]["results_dir"] 52 | h.save_to_database(output_dir=vdb_path) 53 | 54 | # Check that Tensorboard event files were created in the output directory 55 | tensorboard_files = list(Path(vdb_path).glob("events.out.tfevents.*")) 56 | assert len(tensorboard_files) > 0, "No Tensorboard event files found in output directory" 57 | 58 | # Optionally, we could parse the event files to check for our specific metrics 59 | # but that would require additional dependencies, so we'll just check for file existence 60 | -------------------------------------------------------------------------------- /example_notebooks/GettingStartedDownloader.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import hyrax\n", 10 | "\n", 11 | "# We could be all fancy and use a toml file writer here, but\n", 12 | "# I'm trying to show how simple this can be.\n", 13 | "hyrax_config = \"hyrax_config.toml\"\n", 14 | "with open(hyrax_config, \"w\") as f:\n", 15 | " # This simply contains defaults, but you'll have to change at least\n", 16 | " # username, password, fits_file, and cutout_dir to your liking\n", 17 | " # if you want this to work.\n", 18 | "\n", 19 | " f.write(\n", 20 | " \"\"\"\n", 21 | "\n", 22 | " [download]\n", 23 | " sw = \"22asec\"\n", 24 | " sh = \"22asec\"\n", 25 | " filter = [\"HSC-G\", \"HSC-R\", \"HSC-I\", \"HSC-Z\", \"HSC-Y\"]\n", 26 | " type = \"coadd\"\n", 27 | " rerun = \"pdr3_wide\"\n", 28 | "\n", 29 | " username = \"@local\"\n", 30 | " password = \"\"\n", 31 | "\n", 32 | " fits_file = \"\"\n", 33 | " cutout_dir = \"\"\n", 34 | "\n", 35 | " offset = 0\n", 36 | " num_sources = 10\n", 37 | " \n", 38 | " \"\"\"\n", 39 | " )" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import hyrax\n", 49 | "import os\n", 50 | "from pathlib import Path\n", 51 | "\n", 52 | "# os.chdir(Path(hyrax.__file__).parent/\"..\"/\"..\")\n", 53 | "hyrax_instance = hyrax.Hyrax(config_file=\"hyrax_config.toml\")\n", 54 | "\n", 55 | "hyrax_instance.download()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "import matplotlib.pyplot as plt\n", 65 | 
"\n", 66 | "widths, heights = hyrax_instance.raw_data_dimensions()\n", 67 | "\n", 68 | "fig, axs = plt.subplots(1, 2)\n", 69 | "fig.set_figwidth(12)\n", 70 | "\n", 71 | "_, _, _ = axs[0].hist(heights, range=(260, 270), bins=10)\n", 72 | "_, _, _ = axs[1].hist(widths, range=(260, 270), bins=10)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [] 81 | } 82 | ], 83 | "metadata": { 84 | "kernelspec": { 85 | "display_name": "fibad", 86 | "language": "python", 87 | "name": "python3" 88 | }, 89 | "language_info": { 90 | "codemirror_mode": { 91 | "name": "ipython", 92 | "version": 3 93 | }, 94 | "file_extension": ".py", 95 | "mimetype": "text/x-python", 96 | "name": "python", 97 | "nbconvert_exporter": "python", 98 | "pygments_lexer": "ipython3", 99 | "version": "3.10.14" 100 | } 101 | }, 102 | "nbformat": 4, 103 | "nbformat_minor": 2 104 | } 105 | -------------------------------------------------------------------------------- /src/hyrax/data_sets/hyrax_csv_dataset.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | from types import MethodType 4 | 5 | import pandas as pd 6 | 7 | from hyrax.data_sets.data_set_registry import HyraxDataset 8 | 9 | 10 | class HyraxCSVDataset(HyraxDataset): 11 | """A Hyrax Dataset for CSV files. 12 | This class reads a CSV file using pandas with memory mapping enabled. 13 | It dynamically creates getter methods for each column in the CSV file, 14 | allowing users to request data from specific columns. 15 | 16 | Note: Column names found in the CSV file are used to create the getter methods. 17 | If a column name contains characters that are invalid for method names, 18 | those characters are replaced with underscores. 19 | 20 | Example model_inputs configuration: 21 | { 22 | "train": { 23 | "data": { 24 | "dataset_class": "HyraxCSVDataset", 25 | "data_location": , 26 | "fields": ["", "", ...], 27 | "primary_id_field": , 28 | }, 29 | }, 30 | "validate": { }, 31 | "infer": { }, 32 | } 33 | """ 34 | 35 | def __init__(self, config: dict, data_location: Path = None): 36 | self.data_location = data_location 37 | if data_location is None: 38 | raise ValueError("A `data_location` Path to a .csv file must be provided.") 39 | 40 | header_only = pd.read_csv(data_location, nrows=0) 41 | self.column_names = [re.sub(r"\W", "_", col) for col in list(header_only.columns)] 42 | self.mem_mapped_csv = pd.read_csv(data_location, memory_map=True, header=0) 43 | 44 | # Automatically generate all the getter methods based on the column names. 
45 | def _make_getter(column): 46 | def getter(self, idx, _col=column): 47 | ret_val = self.mem_mapped_csv[_col][idx] 48 | if isinstance(ret_val, pd.Series): 49 | ret_val = ret_val.to_list() 50 | return ret_val 51 | 52 | return getter 53 | 54 | for col in self.column_names: 55 | method_name = f"get_{col}" 56 | if not hasattr(self, method_name): 57 | setattr(self, method_name, MethodType(_make_getter(col), self)) 58 | 59 | super().__init__(config) 60 | 61 | def __getitem__(self, idx): 62 | """Currently required by Hyrax machinery, but likely to be phased out.""" 63 | return {} 64 | 65 | def __len__(self) -> int: 66 | """Return the number of records in the CSV.""" 67 | return len(self.mem_mapped_csv) 68 | 69 | def sample_data(self): 70 | """Return the first record, in dictionary form, as the sample.""" 71 | sample = {"data": {}} 72 | 73 | for col in self.column_names: 74 | sample["data"][col] = self.mem_mapped_csv.iloc[0][col] 75 | 76 | return sample 77 | 78 | @classmethod 79 | def is_map(cls) -> bool: 80 | """Boilerplate method to indicate this is a map-style dataset.""" 81 | return True 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | _version.py 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | docs/autoapi/ 75 | _readthedocs/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # vscode 135 | .vscode/ 136 | 137 | # dask 138 | dask-worker-space/ 139 | 140 | # tmp directory 141 | tmp/ 142 | 143 | # Mac OS 144 | .DS_Store 145 | 146 | # Airspeed Velocity performance results 147 | _results/ 148 | _html/ 149 | 150 | # Project initialization script 151 | .initialize_new_project.sh 152 | 153 | # Default save locations for hyrax 154 | data/ 155 | results/ 156 | example_model.pth 157 | 158 | # Common location to stash personal or notebook example config 159 | hyrax_config.toml 160 | credentials.ini 161 | 162 | # MLFlow data directory 163 | mlruns/ 164 | 165 | # JSON data files used by the 3d visualizer 166 | *.json 167 | -------------------------------------------------------------------------------- /docs/pre_executed/mpr_demo_plotting.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from astropy.io import fits 7 | from matplotlib.colors import LogNorm 8 | 9 | 10 | def plot_grid(data_list): 11 | """ 12 | Plots an n x 4 grid of matplotlib plots. 13 | 14 | Parameters 15 | ---------- 16 | data_list : list of tuples 17 | Each tuple in the list is (object id, rounded median distance to NN, file name) 18 | """ 19 | 20 | num_cols = 4 21 | 22 | num_plots = len(data_list) 23 | num_rows = (num_plots + num_cols - 1) // num_cols # Ceiling division: number of rows needed 24 | 25 | fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5 * num_rows), squeeze=False) 26 | 27 | for i, data in enumerate(data_list): 28 | row = i // num_cols 29 | col = i % num_cols 30 | ax = axes[row, col] 31 | plotter(ax, data) 32 | fig.patch.set_facecolor("darkslategrey") 33 | 34 | # Hide any unused subplots 35 | for j in range(num_plots, num_rows * num_cols): 36 | fig.delaxes(axes.flatten()[j]) 37 | 38 | plt.tight_layout() 39 | plt.show() 40 | 41 | 42 | def plotter(ax, data_tuple): 43 | """Plot the R band image for a given object ID. 44 | 45 | Parameters 46 | ---------- 47 | ax : matplotlib.axes.Axes 48 | The axes to plot the image on 49 | data_tuple : (int, float, str) 50 | Each tuple is (object id, rounded median distance to NN, file name) 51 | """ 52 | # Read the FITS files 53 | object_id, dist, file_name = data_tuple 54 | 55 | fits_file = file_name + "_HSC-R.fits" 56 | data = fits.getdata(fits_file) 57 | 58 | # Normalize the data 59 | data = (data - np.min(data)) / (np.max(data) - np.min(data)) 60 | 61 | title = f"Obj ID: {object_id}\nMedian dist: {np.round(dist)}" 62 | 63 | # Display the image 64 | ax.imshow(data, origin="lower", norm=LogNorm(), cmap="Greys") 65 | ax.set_title(title, y=1.0, pad=-30) 66 | ax.axis("off") # Hide the axis 67 | 68 | 69 | def sort_objects_by_median_distance(all_embeddings, median_dist_all_nn, data_directory): 70 | """Order all the objects according to median distance to nearest neighbor. 
71 | Return a tuple for easy plotting: (object id, rounded median distance, file name).""" 72 | 73 | # Use the indexes to gather metadata: object ID, rounded median distance, and file name 74 | data_directory = Path(data_directory).expanduser().resolve() 75 | objects = [] 76 | for indx in np.argsort(median_dist_all_nn): 77 | object_id = all_embeddings["ids"][indx] 78 | 79 | found_files = glob.glob(f"{data_directory / object_id}*.fits") 80 | file_name = found_files[0][:-11] 81 | 82 | objects.append((object_id, np.round(median_dist_all_nn[indx]), file_name)) 83 | 84 | return objects 85 | 86 | 87 | def plot_umap(results_dir): 88 | """Reads in the UMAP results and plots them as a scatter plot""" 89 | a = np.load(results_dir / "batch_0.npy") 90 | b = np.load(results_dir / "batch_1.npy") 91 | out = np.concatenate((a["tensor"], b["tensor"]), axis=0) 92 | fig, ax = plt.subplots(figsize=(12, 12)) 93 | fig.patch.set_facecolor("darkslategrey") 94 | ax.set_facecolor("darkslategrey") 95 | ax.scatter(out[:, 0], out[:, 1], s=3, c="yellow") 96 | plt.show() 97 | -------------------------------------------------------------------------------- /src/hyrax/data_sets/hyrax_cifar_data_set.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101, D102 2 | import logging 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset, IterableDataset 7 | 8 | from .data_set_registry import HyraxDataset 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class HyraxCifarBase: 14 | """Base class for Hyrax Cifar datasets""" 15 | 16 | def __init__(self, config: dict, data_location: Path = None): 17 | import torchvision.transforms as transforms 18 | from astropy.table import Table 19 | from torchvision.datasets import CIFAR10 20 | 21 | self.data_location = data_location if data_location else config["general"]["data_dir"] 22 | 23 | transform = transforms.Compose( 24 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 25 | ) 26 | self.cifar = CIFAR10(root=self.data_location, train=True, download=True, transform=transform) 27 | metadata_table = Table( 28 | {"label": np.array([self.cifar[index][1] for index in range(len(self.cifar))])} 29 | ) 30 | super().__init__(config, metadata_table) 31 | 32 | def get_image(self, idx): 33 | """Get the image at the given index as a NumPy array.""" 34 | image, _ = self.cifar[idx] 35 | return image.numpy() 36 | 37 | def get_label(self, idx): 38 | """Get the label at the given index.""" 39 | _, label = self.cifar[idx] 40 | return label 41 | 42 | def get_index(self, idx): 43 | """Get the index of the item.""" 44 | return idx 45 | 46 | def get_object_id(self, idx): 47 | """Get the object ID for the item.""" 48 | return idx 49 | 50 | 51 | class HyraxCifarDataSet(HyraxCifarBase, HyraxDataset, Dataset): 52 | """Map style CIFAR 10 dataset for Hyrax 53 | 54 | This is simply a version of CIFAR10 that is initialized using Hyrax config with a transformation 55 | that works well for example code. 56 | 57 | We only use the training split in the data, because it is larger (50k images). Hyrax will then divide that 58 | into Train/test/Validate according to configuration. 
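    A minimal usage sketch, following the configuration pattern used elsewhere
    in this repository's examples (the exact return value depends on how the
    dataset splits are configured):

        import hyrax

        h = hyrax.Hyrax()
        h.config["data_set"]["name"] = "HyraxCifarDataSet"
        dataset = h.prepare()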
59 | """ 60 | 61 | def __len__(self): 62 | return len(self.cifar) 63 | 64 | def __getitem__(self, idx): 65 | return { 66 | "data": { 67 | "object_id": self.get_object_id(idx), 68 | "image": self.get_image(idx), 69 | "label": self.get_label(idx), 70 | }, 71 | "object_id": self.get_object_id(idx), 72 | } 73 | 74 | 75 | class HyraxCifarIterableDataSet(HyraxCifarBase, HyraxDataset, IterableDataset): 76 | """Iterable style CIFAR 10 dataset for Hyrax 77 | 78 | This is simply a version of CIFAR10 that is initialized using Hyrax config with a transformation 79 | that works well for example code. This version only supports iteration, and not map-style access 80 | 81 | We only use the training split in the data, because it is larger (50k images). Hyrax will then divide that 82 | into Train/test/Validate according to configuration. 83 | """ 84 | 85 | def __iter__(self): 86 | for idx in range(len(self.cifar)): 87 | yield { 88 | "data": { 89 | "object_id": self.get_object_id(idx), 90 | "image": self.get_image(idx), 91 | "label": self.get_label(idx), 92 | }, 93 | "object_id": self.get_object_id(idx), 94 | } 95 | -------------------------------------------------------------------------------- /tests/hyrax/test_e2e.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import hyrax 4 | 5 | 6 | # Fixure to give us every model class which will be used for e2e esting 7 | @pytest.fixture(scope="module", params=["HyraxAutoencoder"]) 8 | def model_class_name(request): 9 | """Fixture to generate the model names we want to test 10 | For this file all models must work with all data sources 11 | """ 12 | return request.param 13 | 14 | 15 | # Fixture to give us every dataset which will be used for testing 16 | @pytest.fixture( 17 | scope="module", 18 | params=[ 19 | ("HSCDataSet", "hsc1k"), 20 | ("HyraxCifarDataSet", None), 21 | ("HyraxCifarIterableDataSet", None), 22 | ("FitsImageDataSet", "hsc1k"), 23 | ], 24 | ) 25 | def dataset_spec(request): 26 | """Fixture to generate all the Dataset class, sample data pairs 27 | Each of these must work with all models. 28 | """ 29 | return request.param 30 | 31 | 32 | # If the dataset requires a pre-download not performed by the class __init__ 33 | # it needs to be coded here. 34 | @pytest.fixture(scope="module") 35 | def tmp_dataset_path(tmp_path_factory, dataset_spec): 36 | """Fixture to download any needed sample data 37 | 38 | This is at module scope so it should only run once 39 | per run. for each data set, sample data pair. 40 | 41 | Additional sample data should use pooch and instance the directory name 42 | based on the sample_data name, because multiple data classes may parse 43 | the same sample dataset and we only want to download things once. 
44 | """ 45 | import pooch 46 | 47 | class_name, sample_data = dataset_spec 48 | 49 | if sample_data is None: 50 | return tmp_path_factory.mktemp(class_name) 51 | 52 | if sample_data == "hsc1k": 53 | tmp_path = tmp_path_factory.mktemp(sample_data) 54 | pooch.retrieve( 55 | # DOI for Example HSC dataset 56 | url="doi:10.5281/zenodo.14498536/hsc_demo_data.zip", 57 | known_hash="md5:1be05a6b49505054de441a7262a09671", 58 | fname="example_hsc_new.zip", 59 | path=tmp_path, 60 | processor=pooch.Unzip(extract_dir="."), 61 | ) 62 | tmp_path = tmp_path / "hsc_8asec_1000" 63 | 64 | return tmp_path 65 | 66 | 67 | # This gives a configured hyrax instance 68 | @pytest.fixture(scope="function") 69 | def hyrax_instance(tmp_dataset_path, dataset_spec, model_class_name, tmp_path): 70 | """Fixture to configure and initialize the hyrax instance""" 71 | h = hyrax.Hyrax() 72 | dataset_class_name, sample_data = dataset_spec 73 | h.config["general"]["data_dir"] = str(tmp_dataset_path) 74 | h.config["general"]["results_dir"] = str(tmp_path) 75 | h.config["data_set"]["name"] = dataset_class_name 76 | if dataset_class_name == "FitsImageDataSet" and sample_data == "hsc1k": 77 | h.config["data_set"]["filter_catalog"] = str(tmp_dataset_path / "manifest.fits") 78 | h.config["data_set"]["crop_to"] = [100, 100] 79 | h.config["model"]["name"] = model_class_name 80 | h.config["train"]["epochs"] = 1 81 | h.config["data_loader"]["batch_size"] = 128 82 | 83 | return h 84 | 85 | 86 | @pytest.mark.slow 87 | def test_init(hyrax_instance): 88 | """Test that the initialization fixtures function""" 89 | pass 90 | 91 | 92 | @pytest.mark.slow 93 | def test_getting_started(hyrax_instance): 94 | """Test that the basic flow we expect folks to run when 95 | getting started works 96 | """ 97 | hyrax_instance.train() 98 | hyrax_instance.infer() 99 | hyrax_instance.umap() 100 | hyrax_instance.visualize() 101 | -------------------------------------------------------------------------------- /docs/data_set_splits.rst: -------------------------------------------------------------------------------- 1 | .. _data_set_splits: 2 | 3 | Data set splits (subsets) 4 | ============================= 5 | 6 | Datasets used in machine learning are typically split in order to avoid overfitting a particular dataset of 7 | interest, and to perform various sorts of checking that the model is learning what the researcher intends. 8 | In Hyrax there are default conventions for splitting data, which can be configured to the liking of the 9 | investigator. 10 | 11 | Splits in training 12 | ------------------ 13 | By default input datasets are split into train (60%), test (20%), and validate (20%). The ``train`` verb uses 14 | the train split to train and validate splits to create a validation loss statistic every training epoch. The 15 | test split is explicitly left out of training. 16 | 17 | The size of these splits can be configured in the ``[data_set]`` section of the configuration using the 18 | ``train_size``, ``validate_size``, and ``test_size`` configuration keys. The value is either a number of data points 19 | or a ratio of the dataset, where 1.0 represents the entire dataset. For example: 20 | 21 | .. tabs:: 22 | 23 | .. group-tab:: Notebook 24 | 25 | .. code-block:: python 26 | 27 | from hyrax import Hyrax 28 | h = Hyrax() 29 | h.config["data_set"]["train_size"] = 0.6 30 | h.config["data_set"]["validate_size"] = 0.2 31 | h.config["data_set"]["test_size"] = 0.2 32 | 33 | .. group-tab:: CLI 34 | 35 | .. 
code-block:: bash 36 | 37 | $ cat hyrax_config.toml 38 | 39 | [data_set] 40 | train_size = 600 41 | validate_size = 200 42 | test_size = 200 43 | 44 | 45 | It is recommended that all three are provided; however, zeroing some out can create different training effects: 46 | 47 | * If the size of the validate split is zero, then training won't include a validate step. 48 | 49 | * If the size of the test split is zero, then all data will be used in the training process as either training data or for validation. 50 | 51 | * If the size of the test split is zero and the validate split is zero, training will be run on the entire dataset. 52 | 53 | 54 | Splits in inference 55 | ------------------- 56 | 57 | By default the ``infer`` verb uses the entire dataset for inference; however, any of the splits can be used by 58 | specifying the ``[infer]`` ``split`` config value. Valid values are any of the three splits. For example, to 59 | infer on only the test split: 60 | 61 | .. tabs:: 62 | 63 | .. group-tab:: Notebook 64 | 65 | .. code-block:: python 66 | 67 | from hyrax import Hyrax 68 | h = Hyrax() 69 | h.config["infer"]["split"] = "test" 70 | 71 | h.infer() 72 | 73 | .. group-tab:: CLI 74 | 75 | .. code-block:: bash 76 | 77 | $ cat hyrax_config.toml 78 | [infer] 79 | split = "test" 80 | 81 | $ hyrax infer -c hyrax_config.toml 82 | 83 | 84 | Randomness in splits 85 | -------------------- 86 | 87 | The membership in each split is determined randomly. By default, system entropy is used to seed the random number generator for this purpose. 88 | 89 | You can specify a random seed with the ``[data_set]`` ``seed`` configuration key as follows: 90 | 91 | .. tabs:: 92 | 93 | .. group-tab:: Notebook 94 | 95 | .. code-block:: python 96 | 97 | from hyrax import Hyrax 98 | h = Hyrax() 99 | h.config["data_set"]["seed"] = 1 100 | 101 | .. group-tab:: CLI 102 | 103 | .. code-block:: bash 104 | 105 | $ cat hyrax_config.toml 106 | [data_set] 107 | seed = 1 -------------------------------------------------------------------------------- /benchmarks/benchmarks.py: -------------------------------------------------------------------------------- 1 | """Sample benchmarks to compute the runtime of common hyrax operations. 2 | 3 | For more information on writing benchmarks: 4 | https://asv.readthedocs.io/en/stable/writing_benchmarks.html.""" 5 | 6 | import subprocess 7 | 8 | import hyrax 9 | 10 | 11 | def time_import(): 12 | """ 13 | time how long it takes to import our package. This should stay relatively fast. 14 | 15 | Note: the actual import time will be slightly lower than this on a comparable system. 16 | However, high import times do affect this metric proportionally. 
17 | """ 18 | result = subprocess.run(["python", "-c", "import hyrax"]) 19 | assert result.returncode == 0 20 | 21 | 22 | def time_help(): 23 | """ 24 | time how long it takes to run --help from the CLI 25 | """ 26 | result = subprocess.run(["hyrax", "--help"]) 27 | assert result.returncode == 0 28 | 29 | 30 | def time_infer_help(): 31 | """ 32 | time how long it takes to do verb-specific help for infer 33 | """ 34 | result = subprocess.run(["hyrax", "infer", "--help"]) 35 | assert result.returncode == 0 36 | 37 | 38 | def time_train_help(): 39 | """ 40 | time how long it takes to do verb-specific help for train 41 | """ 42 | result = subprocess.run(["hyrax", "train", "--help"]) 43 | assert result.returncode == 0 44 | 45 | 46 | def time_lookup_help(): 47 | """ 48 | time how long it takes to do verb-specific help for lookup 49 | """ 50 | result = subprocess.run(["hyrax", "lookup", "--help"]) 51 | assert result.returncode == 0 52 | 53 | 54 | def time_umap_help(): 55 | """ 56 | time how long it takes to do verb-specific help for umap 57 | """ 58 | result = subprocess.run(["hyrax", "umap", "--help"]) 59 | assert result.returncode == 0 60 | 61 | 62 | def time_save_to_database_help(): 63 | """ 64 | time how long it takes to do verb-specific help for save_to_database 65 | """ 66 | result = subprocess.run(["hyrax", "save_to_database", "--help"]) 67 | assert result.returncode == 0 68 | 69 | 70 | def time_database_connection_help(): 71 | """ 72 | time how long it takes to do verb-specific help for database_connection 73 | """ 74 | result = subprocess.run(["hyrax", "database_connection", "--help"]) 75 | assert result.returncode == 0 76 | 77 | 78 | def time_download_help(): 79 | """ 80 | time how long it takes to do verb-specific help for download 81 | """ 82 | result = subprocess.run(["hyrax", "download", "--help"]) 83 | assert result.returncode == 0 84 | 85 | 86 | def time_prepare_help(): 87 | """ 88 | time how long it takes to do verb-specific help for prepare 89 | """ 90 | result = subprocess.run(["hyrax", "prepare", "--help"]) 91 | assert result.returncode == 0 92 | 93 | 94 | def time_rebuild_manifest_help(): 95 | """ 96 | time how long it takes to do verb-specific help for rebuild_manifest 97 | """ 98 | result = subprocess.run(["hyrax", "rebuild_manifest", "--help"]) 99 | assert result.returncode == 0 100 | 101 | 102 | def time_visualize_help(): 103 | """ 104 | time how long it takes to do verb-specific help for visualize 105 | """ 106 | result = subprocess.run(["hyrax", "visualize", "--help"]) 107 | assert result.returncode == 0 108 | 109 | 110 | def time_nb_obj_construct(): 111 | """ 112 | time how long notebook users must wait for our interface object to construct 113 | """ 114 | hyrax.Hyrax() 115 | 116 | 117 | def time_nb_obj_dir(): 118 | """ 119 | Time how long it takes to construct the interface object and load the 120 | dynamcally generated list of verbs using `dir()` 121 | """ 122 | h = hyrax.Hyrax() 123 | dir(h) 124 | -------------------------------------------------------------------------------- /src/hyrax/gpu_monitor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from dataclasses import dataclass 4 | from subprocess import PIPE, Popen 5 | from threading import Thread 6 | 7 | 8 | class GpuMonitor(Thread): 9 | """General GPU monitor that runs in a separate thread and logs GPU metrics 10 | to Tensorboard. 
11 | """ 12 | 13 | def __init__(self, tensorboard_logger, interval_seconds=1): 14 | super().__init__() 15 | self.stopped = False 16 | self.delay = interval_seconds # Seconds between calls to GPUtil 17 | self.start_time = time.time() 18 | self.tensorboard_logger = tensorboard_logger 19 | self.start() 20 | 21 | def run(self): 22 | """Run loop that logs GPU metrics every `self.delay` seconds.""" 23 | while not self.stopped: 24 | gpus = get_gpu_info() 25 | step = time.time() - self.start_time 26 | for gpu in gpus: 27 | gpu_name = f"GPU_{gpu.id}" 28 | self.tensorboard_logger.add_scalar(f"{gpu_name}/load", gpu.load * 100, step) 29 | self.tensorboard_logger.add_scalar( 30 | f"{gpu_name}/memory_utilization", gpu.memory_util * 100, step 31 | ) 32 | time.sleep(self.delay) 33 | 34 | def stop(self): 35 | """Stop the monitoring thread.""" 36 | self.stopped = True 37 | 38 | 39 | """The following is based on GPUtil. It has been striped down to only the 40 | functionality needed for Hyrax. The original code can be found here: 41 | https://github.com/anderskm/gputil 42 | """ 43 | 44 | 45 | @dataclass 46 | class GPU: 47 | """Holds the GPU metrics retrieved from nvidia-smi.""" 48 | 49 | id: int 50 | load: float 51 | memory_total: float 52 | memory_used: float 53 | 54 | @property 55 | def memory_util(self): 56 | """Return the memory utilization of the GPU.""" 57 | return self.memory_used / self.memory_total 58 | 59 | 60 | def safe_float_cast(str_number): 61 | """Convert a string into a float handling the case of `nan`. 62 | 63 | Parameters 64 | ---------- 65 | str_number : str 66 | The string to convert to a float. 67 | 68 | Returns 69 | ------- 70 | float 71 | The converted float. 72 | """ 73 | try: 74 | number = float(str_number) 75 | except ValueError: 76 | number = float("nan") 77 | return number 78 | 79 | 80 | def get_gpu_info(): 81 | """Get the GPU utilization and memory usage for all GPUs on the system using 82 | nvidia-smi. Returns a list of GPU objects.""" 83 | 84 | try: 85 | p = Popen( 86 | [ 87 | "nvidia-smi", 88 | "--query-gpu=index,utilization.gpu,memory.total,memory.used", 89 | "--format=csv,noheader,nounits", 90 | ], 91 | stdout=PIPE, 92 | ) 93 | stdout, stderror = p.communicate() 94 | except: # noqa: E722 95 | return [] 96 | output = stdout.decode("UTF-8") 97 | 98 | # Parse output 99 | lines = output.split(os.linesep) 100 | 101 | num_devices = len(lines) - 1 102 | gpus = [] 103 | for g in range(num_devices): 104 | line = lines[g] 105 | vals = line.split(", ") 106 | for i in range(4): 107 | if i == 0: 108 | device_ids = int(vals[i]) 109 | elif i == 1: 110 | gpu_util = safe_float_cast(vals[i]) / 100 111 | elif i == 2: 112 | mem_total = safe_float_cast(vals[i]) 113 | elif i == 3: 114 | mem_used = safe_float_cast(vals[i]) 115 | 116 | gpus.append(GPU(device_ids, gpu_util, mem_total, mem_used)) 117 | return gpus 118 | -------------------------------------------------------------------------------- /benchmarks/asv.conf.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | // The version of the config file format. Do not change, unless 4 | // you know what you are doing. 5 | "version": 1, 6 | // The name of the project being benchmarked. 7 | "project": "hyrax", 8 | // The project's homepage. 9 | "project_url": "https://github.com/lincc-frameworks/hyrax", 10 | // The URL or local path of the source code repository for the 11 | // project being benchmarked. 12 | "repo": "..", 13 | // List of branches to benchmark. 
If not provided, defaults to "master" 14 | // (for git) or "tip" (for mercurial). 15 | "branches": [ 16 | "HEAD" 17 | ], 18 | "install_command": [ 19 | "python -m pip install {wheel_file}" 20 | ], 21 | "build_command": [ 22 | "python -m build --wheel -o {build_cache_dir} {build_dir}" 23 | ], 24 | // The DVCS being used. If not set, it will be automatically 25 | // determined from "repo" by looking at the protocol in the URL 26 | // (if remote), or by looking for special directories, such as 27 | // ".git" (if local). 28 | "dvcs": "git", 29 | // The tool to use to create environments. May be "conda", 30 | // "virtualenv" or other value depending on the plugins in use. 31 | // If missing or the empty string, the tool will be automatically 32 | // determined by looking for tools on the PATH environment 33 | // variable. 34 | "environment_type": "virtualenv", 35 | // the base URL to show a commit for the project. 36 | "show_commit_url": "https://github.com/lincc-frameworks/hyrax/commit/", 37 | // The Pythons you'd like to test against. If not provided, defaults 38 | // to the current version of Python used to run `asv`. 39 | "pythons": [ 40 | "3.11" 41 | ], 42 | // The matrix of dependencies to test. Each key is the name of a 43 | // package (in PyPI) and the values are version numbers. An empty 44 | // list indicates to just test against the default (latest) 45 | // version. 46 | "matrix": { 47 | "Cython": [], 48 | "build": [], 49 | "packaging": [] 50 | }, 51 | // The directory (relative to the current directory) that benchmarks are 52 | // stored in. If not provided, defaults to "benchmarks". 53 | "benchmark_dir": ".", 54 | // The directory (relative to the current directory) to cache the Python 55 | // environments in. If not provided, defaults to "env". 56 | "env_dir": "env", 57 | // The directory (relative to the current directory) that raw benchmark 58 | // results are stored in. If not provided, defaults to "results". 59 | "results_dir": "_results", 60 | // The directory (relative to the current directory) that the html tree 61 | // should be written to. If not provided, defaults to "html". 62 | "html_dir": "_html", 63 | // The number of characters to retain in the commit hashes. 64 | // "hash_length": 8, 65 | // `asv` will cache wheels of the recent builds in each 66 | // environment, making them faster to install next time. This is 67 | // number of builds to keep, per environment. 68 | "build_cache_size": 8 69 | // The commits after which the regression search in `asv publish` 70 | // should start looking for regressions. Dictionary whose keys are 71 | // regexps matching to benchmark names, and values corresponding to 72 | // the commit (exclusive) after which to start looking for 73 | // regressions. The default is to start from the first commit 74 | // with results. If the commit is `null`, regression detection is 75 | // skipped for the matching benchmark. 
76 | // 77 | // "regressions_first_commits": { 78 | // "some_benchmark": "352cdf", // Consider regressions only after this commit 79 | // "another_benchmark": null, // Skip regression detection altogether 80 | // } 81 | } -------------------------------------------------------------------------------- /src/hyrax/vector_dbs/vector_db_interface.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Union 3 | 4 | import numpy as np 5 | 6 | 7 | class VectorDB(ABC): 8 | """Interface for a vector database""" 9 | 10 | def __init__(self, config: dict | None = None, context: dict | None = None): 11 | """ 12 | .. py:method:: __init__ 13 | 14 | Create a new instance of a `VectorDB` object. 15 | 16 | Parameters 17 | ---------- 18 | config : dict, optional 19 | An instance of the runtime configuration, by default None 20 | context : dict, optional 21 | An instance of the context object, by default None 22 | """ 23 | self.config = config if config else {} 24 | self.context = context if context else {} 25 | 26 | @abstractmethod 27 | def connect(self): 28 | """Connect to an existing database""" 29 | pass 30 | 31 | @abstractmethod 32 | def create(self): 33 | """Create a new database""" 34 | pass 35 | 36 | @abstractmethod 37 | def insert(self, ids: list[Union[str, int]], vectors: list[np.ndarray]): 38 | """Insert a batch of vectors into the database. 39 | 40 | Parameters 41 | ---------- 42 | ids : list[Union[str, int]] 43 | The ids to associate with the vectors 44 | vectors : list[np.ndarray] 45 | The vectors to insert into the database 46 | """ 47 | pass 48 | 49 | @abstractmethod 50 | def search_by_id(self, id: Union[str, int], k: int = 1) -> dict[int, list[Union[str, int]]]: 51 | """Get the ids of the k nearest neighbors for a given id in the database. 52 | Should use the provided id to look up the vector, then call search_by_vector. 53 | 54 | Parameters 55 | ---------- 56 | id : Union[str, int] 57 | The id of the vector in the database for which we want to find the 58 | k nearest neighbors 59 | k : int, optional 60 | The number of nearest neighbors to return, by default 1, return only 61 | the closest neighbor 62 | 63 | Returns 64 | ------- 65 | dict[int, list[Union[str, int]]] 66 | Dictionary with input vector index as the key and the ids of the k 67 | nearest neighbors as the value. 68 | """ 69 | pass 70 | 71 | @abstractmethod 72 | def search_by_vector( 73 | self, vectors: Union[np.ndarray, list[np.ndarray]], k: int = 1 74 | ) -> dict[int, list[Union[str, int]]]: 75 | """Get the ids of the k nearest neighbors for a given vector. 76 | 77 | Parameters 78 | ---------- 79 | vectors : Union[np.array, list[np.ndarray]] 80 | The one or more vectors to use when searching for nearest neighbors 81 | k : int, optional 82 | The number of nearest neighbors to return, by default 1, return only 83 | the closest neighbor 84 | 85 | Returns 86 | ------- 87 | dict[int, list[Union[str, int]]] 88 | Dictionary with input vector index as the key and the ids of the 89 | k nearest neighbors as the value. 90 | """ 91 | pass 92 | 93 | @abstractmethod 94 | def get_by_id(self, ids: list[Union[str, int]]) -> dict[Union[str, int], list[float]]: 95 | """Retrieve the vectors associated with a list of ids. 96 | 97 | Parameters 98 | ---------- 99 | ids : list[Union[str, int]] 100 | The ids of the vectors to retrieve. 
101 | 102 | Returns 103 | ------- 104 | dict[Union[str, int], list[float]] 105 | Dictionary with the ids as the keys and the vectors as the values. 106 | """ 107 | pass 108 | -------------------------------------------------------------------------------- /src/hyrax/verbs/lookup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import ArgumentParser, Namespace 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import numpy as np 7 | 8 | from .verb_registry import Verb, hyrax_verb 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @hyrax_verb 14 | class Lookup(Verb): 15 | """Look up an inference result using the ID of a data member""" 16 | 17 | cli_name = "lookup" 18 | add_parser_kwargs = {} 19 | 20 | @staticmethod 21 | def setup_parser(parser: ArgumentParser): 22 | """Set up our arguments by configuring a subparser 23 | 24 | Parameters 25 | ---------- 26 | parser : ArgumentParser 27 | The sub-parser to configure 28 | """ 29 | parser.add_argument("-i", "--id", type=str, required=True, help="ID of image") 30 | parser.add_argument( 31 | "-r", "--results-dir", type=str, required=False, help="Directory containing inference results." 32 | ) 33 | 34 | def run_cli(self, args: Namespace | None = None): 35 | """Entrypoint to Lookup from the CLI. 36 | 37 | Parameters 38 | ---------- 39 | args : Optional[Namespace], optional 40 | The parsed command line arguments 41 | 42 | """ 43 | logger.info("Lookup run from cli") 44 | if args is None: 45 | raise RuntimeError("Run CLI called with no arguments.") 46 | # This is where we map from CLI parsed args to a 47 | # self.run(args) call. 48 | vector = self.run(id=args.id, results_dir=args.results_dir) 49 | if vector is None: 50 | logger.info("No inference result found") 51 | else: 52 | logger.info("Inference result found") 53 | print(vector) 54 | 55 | def run(self, id: str, results_dir: Union[Path, str] | None = None) -> np.ndarray | None: 56 | """Look up the latent-space representation of a particular ID 57 | 58 | Requires the relevant dataset to be configured, and for inference to have been run. 59 | 60 | Parameters 61 | ---------- 62 | id : str 63 | The ID of the input data to look up the inference result 64 | 65 | results_dir : str, Optional 66 | The directory containing the inference results. 67 | 68 | Returns 69 | ------- 70 | Optional[np.ndarray] 71 | The output tensor of the model for the given input. 72 | """ 73 | from hyrax.config_utils import find_most_recent_results_dir 74 | from hyrax.data_sets.inference_dataset import InferenceDataSet 75 | 76 | if results_dir is None: 77 | if self.config["results"]["inference_dir"]: 78 | results_dir = self.config["results"]["inference_dir"] 79 | else: 80 | results_dir = find_most_recent_results_dir(self.config, verb="infer") 81 | msg = f"Using most recent results dir {results_dir} for lookup. " 82 | msg += "Use the [results] inference_dir config to set a directory or pass it to this verb." 83 | logger.info(msg) 84 | 85 | if results_dir is None: 86 | msg = "Could not find a results directory. 
Run infer or use " 87 | msg += "[results] inference_dir config to specify a directory" 88 | logger.error(msg) 89 | return None 90 | 91 | if isinstance(results_dir, str): 92 | results_dir = Path(results_dir) 93 | 94 | inference_dataset = InferenceDataSet(self.config, results_dir=results_dir) 95 | 96 | all_ids = np.array(list(inference_dataset.ids())) 97 | lookup_index = np.argwhere(all_ids == id) 98 | 99 | if len(lookup_index) == 1: 100 | return np.array(inference_dataset[lookup_index[0]].numpy()) 101 | elif len(lookup_index) > 1: 102 | raise RuntimeError(f"Inference result directory {results_dir} has duplicate ID numbers") 103 | 104 | return None 105 | -------------------------------------------------------------------------------- /tests/hyrax/test_infer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest.mock import MagicMock, patch 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | 8 | @pytest.mark.parametrize("shuffle", [True, False]) 9 | def test_infer_order(loopback_hyrax, shuffle): 10 | """Test that the order of data run through infer 11 | is correct in the presence of several splits 12 | """ 13 | h, dataset = loopback_hyrax 14 | h.config["data_loader"]["shuffle"] = shuffle 15 | 16 | dataset = dataset["infer"] 17 | inference_results = h.infer() 18 | inference_result_ids = list(inference_results.ids()) 19 | original_dataset_ids = list(dataset.ids()) 20 | 21 | if dataset.is_iterable(): 22 | dataset = list(dataset) 23 | original_dataset_ids = np.array([str(s["object_id"]) for s in dataset]) 24 | 25 | for idx, result_id in enumerate(inference_result_ids): 26 | dataset_idx = None 27 | for i, orig_id in enumerate(original_dataset_ids): 28 | if orig_id == result_id: 29 | dataset_idx = i 30 | break 31 | else: 32 | raise AssertionError("Failed to find a corresponding ID") 33 | 34 | print(f"orig idx: {dataset_idx}, infer idx: {idx}") 35 | print(f"orig data: {dataset[dataset_idx]}, infer data: {inference_results[idx]}") 36 | assert np.all(np.isclose(dataset[dataset_idx]["data"]["image"], inference_results[idx])) 37 | 38 | 39 | def test_load_model_weights_updates_config_when_auto_detected(tmp_path): 40 | """Test that config is updated when model_weights_file is auto-detected from train directory""" 41 | from hyrax.verbs.infer import Infer 42 | 43 | # Create a mock config with no model_weights_file specified 44 | config = {} 45 | config["infer"] = {"model_weights_file": None} 46 | config["train"] = {"weights_filename": "model_weights.pth"} 47 | config["general"] = {"results_dir": str(tmp_path)} 48 | 49 | # Create a fake train results directory 50 | train_dir = tmp_path / "20240101-120000-train-abcd" 51 | train_dir.mkdir(parents=True) 52 | weights_file = train_dir / "model_weights.pth" 53 | weights_file.write_text("fake weights content") 54 | 55 | # Create a mock model 56 | mock_model = MagicMock() 57 | 58 | # Mock find_most_recent_results_dir to return our fake train directory 59 | with patch("hyrax.config_utils.find_most_recent_results_dir", return_value=train_dir): 60 | # Call load_model_weights 61 | Infer.load_model_weights(config, mock_model) 62 | 63 | # Verify that config was updated with the actual weights file path 64 | assert config["infer"]["model_weights_file"] == str(weights_file) 65 | # Verify that model.load was called with the correct path 66 | mock_model.load.assert_called_once_with(weights_file) 67 | 68 | 69 | def test_load_model_weights_preserves_explicit_config(): 70 | """Test that config is still 
updated when model_weights_file is explicitly provided""" 71 | from tempfile import NamedTemporaryFile 72 | 73 | from hyrax.verbs.infer import Infer 74 | 75 | # Create a temporary weights file 76 | with NamedTemporaryFile(suffix=".pth", delete=False) as tmp_file: 77 | tmp_file.write(b"fake weights content") 78 | weights_path = Path(tmp_file.name) 79 | 80 | try: 81 | # Create a mock config with explicit model_weights_file 82 | config = {} 83 | config["infer"] = {"model_weights_file": str(weights_path)} 84 | config["train"] = {"weights_filename": "model_weights.pth"} 85 | 86 | # Create a mock model 87 | mock_model = MagicMock() 88 | 89 | # Call load_model_weights 90 | Infer.load_model_weights(config, mock_model) 91 | 92 | # Verify that config still contains the weights file path (converted to string) 93 | assert config["infer"]["model_weights_file"] == str(weights_path) 94 | # Verify that model.load was called with the correct path 95 | mock_model.load.assert_called_once_with(weights_path) 96 | finally: 97 | # Clean up 98 | weights_path.unlink() 99 | -------------------------------------------------------------------------------- /tests/hyrax/test_train.py: -------------------------------------------------------------------------------- 1 | from hyrax.config_utils import find_most_recent_results_dir 2 | 3 | 4 | def test_train(loopback_hyrax): 5 | """ 6 | Simple test that training succeeds with the loopback 7 | model in use. 8 | """ 9 | h, _ = loopback_hyrax 10 | h.train() 11 | 12 | 13 | def test_train_resume(loopback_hyrax, tmp_path): 14 | """ 15 | Ensure that training can be resumed from a checkpoint 16 | when using the loopback model. 17 | """ 18 | checkpoint_filename = "checkpoint_epoch_1.pt" 19 | 20 | h, _ = loopback_hyrax 21 | # set results directory to a temporary path 22 | h.config["general"]["results_dir"] = str(tmp_path) 23 | 24 | # First, run initial training to create a saved model file 25 | _ = h.train() 26 | 27 | # find the model file in the most recent results directory 28 | results_dir = find_most_recent_results_dir(h.config, "train") 29 | checkpoint_path = results_dir / checkpoint_filename 30 | 31 | # Now, set the resume config to point to this checkpoint 32 | h.config["train"]["resume"] = str(checkpoint_path) 33 | 34 | # Resume training 35 | h.train() 36 | 37 | 38 | def test_train_percent_split(tmp_path): 39 | """ 40 | Ensure backward compatibility with percent-based splits when the 41 | configuration provides only a `train` and `infer` model_inputs section 42 | (no explicit `validate` table). This should exercise the code path 43 | that creates train/validate splits from a single dataset location. 44 | """ 45 | import hyrax 46 | 47 | h = hyrax.Hyrax() 48 | h.config["model"]["name"] = "HyraxLoopback" 49 | h.config["train"]["epochs"] = 1 50 | h.config["data_loader"]["batch_size"] = 4 51 | h.config["general"]["results_dir"] = str(tmp_path) 52 | h.config["general"]["dev_mode"] = True 53 | 54 | # Only provide `train` and `infer` model_inputs (no `validate` key). 
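# (With no explicit `validate` table, Hyrax is expected to fall back to carving
# train/validate splits out of the single `train` dataset using the percent-based
# `train_size`/`validate_size`/`test_size` keys configured below.)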
55 | h.config["model_inputs"] = { 56 | "train": { 57 | "data": { 58 | "dataset_class": "HyraxRandomDataset", 59 | "data_location": str(tmp_path / "data_train"), 60 | "primary_id_field": "object_id", 61 | } 62 | }, 63 | "infer": { 64 | "data": { 65 | "dataset_class": "HyraxRandomDataset", 66 | "data_location": str(tmp_path / "data_infer"), 67 | "primary_id_field": "object_id", 68 | } 69 | }, 70 | } 71 | 72 | # Configure the underlying random dataset used by tests 73 | h.config["data_set"]["HyraxRandomDataset"]["size"] = 20 74 | h.config["data_set"]["HyraxRandomDataset"]["seed"] = 0 75 | h.config["data_set"]["HyraxRandomDataset"]["shape"] = [2, 3] 76 | 77 | # Percent-based split parameters - these should be applied to the single 78 | # location `train` dataset and produce a validate split implicitly. 79 | h.config["data_set"]["train_size"] = 0.6 80 | h.config["data_set"]["validate_size"] = 0.2 81 | h.config["data_set"]["test_size"] = 0.2 82 | 83 | # Instead of running full training, validate that the legacy percent-based 84 | # split path creates both train and validate dataloaders with expected sizes. 85 | from hyrax.pytorch_ignite import dist_data_loader, setup_dataset 86 | 87 | # Create dataset dict using the same logic as training 88 | dataset = setup_dataset(h.config) 89 | 90 | assert "train" in dataset 91 | 92 | data_loaders = dist_data_loader(dataset["train"], h.config, ["train", "validate"]) 93 | 94 | # Should have created both train and validate loaders 95 | assert "train" in data_loaders and "validate" in data_loaders 96 | 97 | train_loader, train_indexes = data_loaders["train"] 98 | validate_loader, validate_indexes = data_loaders["validate"] 99 | 100 | # Assert expected sizes: train 12 (60% of 20), validate 4 (20% of 20) 101 | assert len(train_indexes) == 12 102 | assert len(validate_indexes) == 4 103 | 104 | # Finally, run full training to exercise `train.py` end-to-end and ensure 105 | # the training verb functions correctly with percent-based splits. 106 | h.train() 107 | -------------------------------------------------------------------------------- /src/hyrax/models/hsc_dcae.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101, D102 2 | 3 | # This autoencoder is designed to work with datasets 4 | # that are prepared with Hyrax's HSC Data Set class. 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | # extra long import here to address a circular import issue 11 | from hyrax.models.model_registry import hyrax_model 12 | 13 | 14 | class ArcsinhActivation(nn.Module): 15 | """Helper module for HSCDAE to use the arcsinh function""" 16 | 17 | def forward(self, x): 18 | return torch.arcsinh(x) 19 | 20 | 21 | @hyrax_model 22 | class HSCDCAE(nn.Module): 23 | """ 24 | This autoencoder is designed to work with datasets that are prepared with Hyrax's HSC Data Set class. 25 | """ 26 | 27 | def __init__(self, config, data_sample=None): 28 | super().__init__() 29 | 30 | # The current network works with images of size [3,150,150] 31 | # You will need to updat padding, stride, etc. 
for images 32 | # of other sizes 33 | 34 | # Encoder 35 | self.encoder1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1) 36 | self.encoder2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 37 | self.encoder3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 38 | self.encoder4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 39 | 40 | self.pool = nn.MaxPool2d(2, 2) 41 | 42 | # Decoder 43 | self.decoder4 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=0, output_padding=0) 44 | self.decoder3 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, output_padding=0) 45 | self.decoder2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, output_padding=0) 46 | self.decoder1 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1, output_padding=0) 47 | 48 | self.activation = nn.ReLU() 49 | 50 | final_layer = config["model"]["HSCDCAE_final_layer"] 51 | if final_layer == "sigmoid": 52 | self.final_activation = nn.Sigmoid() 53 | elif final_layer == "tanh": 54 | self.final_activation = nn.Tanh() 55 | elif final_layer == "arcsinh": 56 | self.final_activation = ArcsinhActivation() 57 | else: 58 | self.final_activation = nn.Identity() 59 | 60 | self.config = config 61 | 62 | def forward(self, x): 63 | # Dropping labels if present 64 | x = x[0] if isinstance(x, tuple) else x 65 | 66 | # Encoder only; skip connections are used by the decoder in train_step 67 | x1 = self.activation(self.encoder1(x)) 68 | x2 = self.activation(self.encoder2(self.pool(x1))) 69 | x3 = self.activation(self.encoder3(self.pool(x2))) 70 | x4 = self.activation(self.encoder4(self.pool(x3))) 71 | 72 | return x4 73 | 74 | def train_step(self, batch): 75 | """This function contains the logic for a single training step, i.e. the 76 | contents of the inner loop of an ML training process. 77 | 78 | Parameters 79 | ---------- 80 | batch : tuple 81 | A tuple containing the input images and, optionally, labels 82 | 83 | Returns 84 | ------- 85 | Current loss value : dict 86 | Dictionary containing the loss value for the current batch. 87 | """ 88 | 89 | # Dropping labels if present 90 | data = batch[0] if isinstance(batch, tuple) else batch 91 | self.optimizer.zero_grad() 92 | 93 | # Encoder with skip connections 94 | x1 = self.activation(self.encoder1(data)) 95 | x2 = self.activation(self.encoder2(self.pool(x1))) 96 | x3 = self.activation(self.encoder3(self.pool(x2))) 97 | x4 = self.activation(self.encoder4(self.pool(x3))) 98 | 99 | # Decoder with skip connections 100 | x = self.activation(self.decoder4(x4) + x3) 101 | x = self.activation(self.decoder3(x) + x2) 102 | x = self.activation(self.decoder2(x) + x1) 103 | decoded = self.final_activation(self.decoder1(x)) 104 | 105 | loss = self.criterion(decoded, data) 106 | loss.backward() 107 | self.optimizer.step() 108 | 109 | return {"loss": loss.item()} 110 | -------------------------------------------------------------------------------- /src/hyrax/downloadCutout/README.md: -------------------------------------------------------------------------------- 1 | downloadCutout.py 2 | ============================================================================== 3 | 4 | Download FITS cutouts from the website of the HSC data release. 
5 | 6 | Requirements 7 | ------------------------------------------------------------------------------ 8 | 9 | python >= 3.7 10 | 11 | Usage 12 | ------------------------------------------------------------------------------ 13 | 14 | ### Download images of all bands at a location 15 | 16 | ``` 17 | python3 downloadCutout.py --ra=222.222 --dec=44.444 --sw=0.5arcmin --sh=0.5arcmin --name="cutout-{filter}" 18 | ``` 19 | 20 | Note that `{filter}` must appear in `--name`. 21 | Otherwise, the five images of the five bands will be written 22 | to a single file over and over. 23 | 24 | ### Use a coordinate list 25 | 26 | You can feed a coordinate list that is in nearly the same format as 27 | https://hsc-release.mtk.nao.ac.jp/das_cutout/pdr3/manual.html#list-to-upload 28 | 29 | There are a few differences: 30 | 31 | - Comments must not appear, 32 | except for the mandatory one on the first line. 33 | 34 | - You can use "all" as a value of the "filter" field. 35 | 36 | - There may be columns with unrecognised names, 37 | which are silently ignored. 38 | 39 | It is permissible for the coordinate list to contain only coordinates. 40 | For example: 41 | 42 | ``` 43 | #? ra dec 44 | 222.222 44.444 45 | 222.223 44.445 46 | 222.224 44.446 47 | ``` 48 | 49 | In this case, you have to specify other fields via the command line: 50 | 51 | ``` 52 | python3 downloadCutout.py \ 53 | --sw=5arcsec --sh=5arcsec \ 54 | --image=true --variance=true --mask=true \ 55 | --name="cutout_{tract}_{ra}_{dec}_{filter}" \ 56 | --list=coordlist.txt # <- the name of the above list 57 | ``` 58 | 59 | It is more efficient to use a list like the example above 60 | than to use a for-loop to call the script iteratively. 61 | 62 | ### Stop being asked for a password 63 | 64 | To stop the script from asking for your password, put the password 65 | into an environment variable. (Default: `HSC_SSP_CAS_PASSWORD`) 66 | 67 | ``` 68 | read -s HSC_SSP_CAS_PASSWORD 69 | export HSC_SSP_CAS_PASSWORD 70 | ``` 71 | 72 | Then, run the script with the `--user` option: 73 | 74 | ``` 75 | python3 downloadCutout.py \ 76 | --ra=222.222 --dec=44.444 --sw=0.5arcmin --sh=0.5arcmin \ 77 | --name="cutout-{filter}" \ 78 | --user=USERNAME 79 | ``` 80 | 81 | If you are using your own personal laptop or desktop, 82 | you may pass your password through the `--password` option. 83 | But you must never do so 84 | if other people are using the same computer. 85 | Remember that other people can see your command lines 86 | with, for example, the `top` command. 87 | (If it is GNU's `top`, press the `C` key to see others' command lines). 88 | 89 | ### Synchronize processes 90 | 91 | If you run a program in parallel which calls `downloadCutout.py` sporadically 92 | but frequently, the program needs synchronization---the server refuses 93 | `downloadCutout.py` if many instances of it are run at the same time. 94 | 95 | If your program does not have a synchronization mechanism, 96 | you can run `downloadCutout.py` with synchronization options: 97 | 98 | ``` 99 | python3 downloadCutout.py .... \ 100 | --semaphore=/home/yourname/semaphore --max-connections=4 101 | ``` 102 | 103 | Because the processes synchronize with each other via the specified semaphore 104 | (this is not a posix semaphore but a hand-made semaphore-like object), 105 | the semaphore must be visible to all the processes. 106 | If the processes are distributed over a network, 107 | the semaphore must be placed in an NFS or any other shared filesystem. 
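
For example, here is a sketch of launching a few downloads in parallel that
share one semaphore (the coordinate list file names are purely illustrative):

```
for list in coordlist_a.txt coordlist_b.txt coordlist_c.txt; do
    python3 downloadCutout.py \
        --sw=5arcsec --sh=5arcsec \
        --name="cutout_{tract}_{ra}_{dec}_{filter}" \
        --list="$list" \
        --semaphore=/home/yourname/semaphore --max-connections=4 &
done
wait
```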
108 | 109 | Usage as a python module 110 | ------------------------------------------------------------------------------ 111 | 112 | Here is an example: 113 | 114 | ``` 115 | import downloadCutout 116 | 117 | rect = downloadCutout.Rect.create( 118 | ra="11h11m11.111s", 119 | dec="-1d11m11.111s", 120 | sw="1arcmin", 121 | sh="1arcmin", 122 | ) 123 | 124 | images = downloadCutout.download(rect) 125 | 126 | # Multiple images (of various filters) are returned. 127 | # We look into the first one of them. 128 | metadata, data = images[0] 129 | print(metadata) 130 | 131 | # `data` is just the binary data of a FITS file. 132 | # You can use, for example, `astropy` to decode it. 133 | import io 134 | import astropy.io.fits 135 | hdus = astropy.io.fits.open(io.BytesIO(data)) 136 | print(hdus) 137 | ``` 138 | -------------------------------------------------------------------------------- /src/hyrax/verbs/database_connection.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import ArgumentParser, Namespace 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | from .verb_registry import Verb, hyrax_verb 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @hyrax_verb 12 | class DatabaseConnection(Verb): 13 | """Verb to create a connection to a vector database index of inference 14 | results for fast similarity search.""" 15 | 16 | cli_name = "database_connection" 17 | add_parser_kwargs = {} 18 | 19 | @staticmethod 20 | def setup_parser(parser: ArgumentParser): 21 | """Set up command line arguments for this verb""" 22 | 23 | parser.add_argument( 24 | "-d", 25 | "--database-dir", 26 | type=str, 27 | required=False, 28 | help="Directory of existing vector database.", 29 | ) 30 | 31 | def run_cli(self, args: Namespace | None = None): 32 | """Stub CLI implementation""" 33 | logger.error("Database connection is not supported from the command line.") 34 | 35 | def run(self, database_dir: Union[Path, str] | None = None): 36 | """Create a connection to the vector database for interactive queries. 37 | 38 | Parameters 39 | ---------- 40 | database_dir : str or Path, Optional 41 | The directory containing the database that will be connected to. 42 | If None, attempt to connect to the most recently created `...-vector-db-...` 43 | directory. If specified, it should point to a directory containing an 44 | existing vector database. 45 | """ 46 | from hyrax.config_utils import find_most_recent_results_dir 47 | from hyrax.vector_dbs.vector_db_factory import vector_db_factory 48 | 49 | config = self.config 50 | 51 | # Attempt to find the directory containing the vector database. Check for 52 | # the database_dir argument first, then check the config file for 53 | # vector_db.vector_db_dir, and finally check for the most recently 54 | # created vector-db directory. 55 | vector_db_dir = None 56 | if database_dir is not None: 57 | vector_db_dir = database_dir 58 | elif config["vector_db"]["vector_db_dir"]: 59 | vector_db_dir = config["vector_db"]["vector_db_dir"] 60 | else: 61 | vector_db_dir = find_most_recent_results_dir(config, "vector-db") 62 | 63 | vector_db_path = Path(vector_db_dir).resolve() 64 | if not vector_db_path.is_dir(): 65 | raise RuntimeError( 66 | f"Database directory {str(vector_db_path)} does not exist. \ 67 | Have you run `hyrax.save_to_database(output_dir={vector_db_path})`?" 68 | ) 69 | 70 | # Get the flavor of database (e.g. 
Chroma, Qdrant, etc) from the config 72 | # file saved in `vector_db_path`. This ensures that we will use the correct 73 | # database class when creating the connection. 74 | db_type = self._get_database_type_from_config(vector_db_path) 75 | config["vector_db"]["name"] = db_type 76 | 77 | # Create an instance of the vector database class for the connection 78 | self.vector_db = vector_db_factory(config, context={"results_dir": vector_db_path}) 79 | if self.vector_db is None: 80 | raise RuntimeError(f"Unable to connect to the {db_type} database in directory {vector_db_path}") 81 | 82 | return self.vector_db 83 | 84 | def _get_database_type_from_config(self, database_dir: Path): 85 | """Internal function that will read a config file from a directory and 86 | return the name of the vector database from it, e.g. "chromadb" or "qdrant". 87 | 88 | Parameters 89 | ---------- 90 | database_dir : Path 91 | The directory containing the vector database and the config file that 92 | will be used as reference. 93 | 94 | Returns 95 | ------- 96 | str 97 | The config value for ["vector_db"]["name"] in the reference config. 98 | """ 99 | from hyrax.config_utils import ConfigManager 100 | 101 | config_file = database_dir / "runtime_config.toml" 102 | reference_config = ConfigManager.read_runtime_config(config_filepath=config_file) 103 | return reference_config["vector_db"]["name"] 104 | -------------------------------------------------------------------------------- /src/hyrax/models/simclr.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101, D102 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F # noqa N812 7 | import torchvision.models as models 8 | import torchvision.transforms as T # noqa N812 9 | 10 | from hyrax.models.model_registry import hyrax_model 11 | 12 | 13 | class NTXentLoss(nn.Module): 14 | """Normalized Temperature-scaled Cross Entropy Loss. Based on Chen, 2020""" 15 | 16 | def __init__(self, temperature=0.1): 17 | super().__init__() 18 | self.temperature = temperature 19 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 20 | 21 | def forward(self, z_i, z_j): 22 | """Forward function of NTXentLoss. Based on Chen, 2020. 23 | Loss is calculated from representations from two augmented views of the same batch. 
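
        Concretely (following Chen et al., 2020), for a positive pair (i, j)
        among the 2N augmented samples the per-pair loss is

            l(i, j) = -log( exp(sim(z_i, z_j) / t) / sum_{k != i} exp(sim(z_i, z_k) / t) )

        where sim is cosine similarity and t is the temperature. The
        implementation below evaluates this for all 2N positives at once by
        masking the diagonal of the similarity matrix and feeding it to
        cross-entropy, then divides by 2N.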
24 | """ 25 | batch_size = z_i.shape[0] 26 | device = z_i.device 27 | 28 | # Normalize the matrix and concat 29 | z_i = F.normalize(z_i, dim=1) # Shape: (N, D) 30 | z_j = F.normalize(z_j, dim=1) # Shape: (N, D) 31 | z = torch.cat([z_i, z_j], dim=0) # Shape: (2N, D) 32 | 33 | # Cosine similarity 34 | sim_matrix = torch.matmul(z, z.T) # Shape: (2N, 2N) 35 | 36 | # Remove self-similarity by masking the diagonal 37 | mask = torch.eye(2 * batch_size, dtype=torch.bool).to(device) 38 | sim_matrix = sim_matrix.masked_fill(mask, -float("inf")) 39 | 40 | # Apply temperature scaling 41 | sim_matrix /= self.temperature 42 | 43 | # Construct positive pair indices: Each example i has its positive pair at index i + N or i - N 44 | positive_indices = (torch.arange(0, 2 * batch_size, device=device) + batch_size) % (2 * batch_size) 45 | 46 | # Compute cross-entropy loss (it's mathematically equivalent) 47 | loss = self.criterion(sim_matrix, positive_indices) 48 | loss /= 2 * batch_size 49 | 50 | return loss 51 | 52 | 53 | class PositiveRescale: 54 | """Transformation Class specifically for ColorJitter to prevent wrong domain during the augmentation""" 55 | 56 | def __init__(self, transform): 57 | self.transform = transform 58 | 59 | def __call__(self, x): 60 | x = (x + 1) / 2 # to [0, 1] 61 | x = self.transform(x) 62 | return x * 2 - 1 # back to (-1, 1) 63 | 64 | 65 | @hyrax_model 66 | class SimCLR(nn.Module): 67 | """SimCLR model. Implementation based on Chen, 2020""" 68 | 69 | def __init__(self, config, shape): 70 | super().__init__() 71 | self.config = config 72 | self.shape = shape 73 | proj_dim = config["model"]["SimCLR"]["projection_dimension"] 74 | temperature = config["model"]["SimCLR"]["temperature"] 75 | 76 | backbone = models.resnet18(pretrained=False) 77 | backbone.fc = nn.Identity() 78 | self.backbone = backbone 79 | 80 | self.projection_head = nn.Sequential( 81 | nn.Linear(512, 512), 82 | nn.ReLU(inplace=True), 83 | nn.Linear(512, proj_dim), 84 | ) 85 | self.criterion = NTXentLoss(temperature) 86 | 87 | def forward(self, x): 88 | feats = self.backbone(x) 89 | return self.projection_head(feats) 90 | 91 | def train_step(self, x): 92 | aug = T.Compose( 93 | [ 94 | T.RandomResizedCrop(size=x.shape[-1]), 95 | T.RandomHorizontalFlip(self.config["model"]["SimCLR"]["horizontal_flip_probability"]), 96 | T.RandomApply( 97 | [PositiveRescale(T.ColorJitter(*self.config["model"]["SimCLR"]["color_jitter_params"]))], 98 | p=self.config["model"]["SimCLR"]["color_jitter_probability"], 99 | ), 100 | T.RandomGrayscale(p=self.config["model"]["SimCLR"]["grayscale_probability"]), 101 | T.GaussianBlur( 102 | kernel_size=self.config["model"]["SimCLR"]["gaussian_blur_kernel_size"], 103 | sigma=self.config["model"]["SimCLR"]["gaussian_blur_sigma_range"], 104 | ), 105 | ] 106 | ) 107 | 108 | x1 = torch.stack([aug(img) for img in x]) 109 | x2 = torch.stack([aug(img) for img in x]) 110 | 111 | z1 = self.forward(x1) 112 | z2 = self.forward(x2) 113 | 114 | loss = self.criterion(z1, z2) 115 | self.optimizer.zero_grad() 116 | loss.backward() 117 | self.optimizer.step() 118 | return {"loss": loss.item()} 119 | -------------------------------------------------------------------------------- /src/hyrax/verbs/to_onnx.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .verb_registry import Verb, hyrax_verb 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @hyrax_verb 9 | class ToOnnx(Verb): 10 | """Export the model to ONNX format""" 11 | 12 | cli_name = 
"to_onnx" 13 | add_parser_kwargs = {} 14 | 15 | @staticmethod 16 | def setup_parser(parser): 17 | """Setup parser for ONNX export verb""" 18 | parser.add_argument( 19 | "--input-model-directory", 20 | type=str, 21 | required=False, 22 | help="Directory containing the trained model to export.", 23 | ) 24 | 25 | def run_cli(self, args=None): 26 | """Run the ONNX export verb from the CLI""" 27 | logger.info("Exporting model to ONNX format.") 28 | self.run(args.input_model_directory) 29 | 30 | def run(self, input_model_directory: str = None): 31 | """Export the model to ONNX format and save it to the specified path.""" 32 | import shutil 33 | from pathlib import Path 34 | 35 | from hyrax.config_utils import ( 36 | ConfigManager, 37 | create_results_dir, 38 | find_most_recent_results_dir, 39 | log_runtime_config, 40 | ) 41 | from hyrax.model_exporters import export_to_onnx 42 | from hyrax.pytorch_ignite import dist_data_loader, setup_dataset, setup_model 43 | 44 | config = self.config 45 | 46 | # Resolve the input directory in this order; 1) input_model_directory arg, 47 | # 2) config value, 3) most recent train results 48 | if input_model_directory: 49 | input_directory = Path(input_model_directory) 50 | if not input_directory.exists(): 51 | logger.error(f"Input model directory {input_directory} does not exist.") 52 | return 53 | elif config["onnx"]["input_model_directory"]: 54 | input_directory = Path(config["onnx"]["input_model_directory"]) 55 | if not input_directory.exists(): 56 | logger.error(f"Input model directory in the config file {input_directory} does not exist.") 57 | return 58 | else: 59 | input_directory = find_most_recent_results_dir(config, "train") 60 | if not input_directory: 61 | logger.error("No previous training results directory found for ONNX export.") 62 | return 63 | 64 | output_dir = create_results_dir(config, "onnx") 65 | log_runtime_config(config, output_dir) 66 | 67 | # grab the config file from the input directory, and render it. 68 | config_file = input_directory / "runtime_config.toml" 69 | config_manager = ConfigManager(runtime_config_filepath=config_file) 70 | config_from_training = config_manager.config 71 | 72 | # copy the to_tensor.py file from the input directory to the output directory 73 | to_tensor_src = input_directory / "to_tensor.py" 74 | to_tensor_dst = output_dir / "to_tensor.py" 75 | if to_tensor_src.exists(): 76 | shutil.copy(to_tensor_src, to_tensor_dst) 77 | 78 | # Use the config file to locate and assemble the trained weight file path 79 | weights_file_path = input_directory / config_from_training["train"]["weights_filename"] 80 | 81 | if not weights_file_path.exists(): 82 | raise RuntimeError(f"Could not find trained model weights: {weights_file_path}") 83 | 84 | # Use the config in the model directory to load the dataset(s) and create 85 | # The data loader instance to provide a data sample to the ONNX exporter. 86 | dataset = setup_dataset(config_from_training) 87 | model = setup_model(config_from_training, dataset["infer"]) 88 | # Load the trained weights and send the model to the CPU for ONNX export. 89 | model.load(weights_file_path) 90 | model.train(False) 91 | 92 | # Create an instance of the dataloader so that we can request a sample batch. 93 | infer_data_loader, _ = dist_data_loader(dataset["infer"], config_from_training, False) 94 | 95 | # Generate the `context` dictionary that will be provided to the ONNX exporter. 
96 | context = { 97 | "results_dir": output_dir, 98 | "ml_framework": "pytorch", 99 | } 100 | 101 | # Get a sample of input data. 102 | batch_sample = next(iter(infer_data_loader)) 103 | batch_sample = model.to_tensor(batch_sample) 104 | 105 | export_to_onnx(model, batch_sample, config, context) 106 | -------------------------------------------------------------------------------- /docs/verbs.rst: -------------------------------------------------------------------------------- 1 | Hyrax Verbs 2 | =========== 3 | The term "verb" is used to describe the functions that Hyrax supports. 4 | For instance, the ``train`` verb is used to train a model. 5 | Each of the built-in verbs is detailed here. 6 | 7 | 8 | ``train`` 9 | --------- 10 | Train a model. The specific model to train and the data used for training are 11 | specified in the configuration file or by updating the default configurations 12 | after creating an instance of the Hyrax object. 13 | 14 | When called from a notebook or Python, ``train()`` returns a trained pytorch 15 | model which you can :doc:`immediately evaluate, inspect, or export`. Batch evaluations of datasets 16 | are enabled using the ``infer`` verb; see below. 17 | 18 | .. tabs:: 19 | 20 | .. group-tab:: Notebook 21 | 22 | .. code-block:: python 23 | 24 | from hyrax import Hyrax 25 | 26 | # Create an instance of the Hyrax object 27 | h = Hyrax() 28 | 29 | # Train the model specified in the configuration file 30 | model = h.train() 31 | 32 | .. group-tab:: CLI 33 | 34 | .. code-block:: bash 35 | 36 | >> hyrax train 37 | 38 | 39 | ``infer`` 40 | --------- 41 | Run inference using a trained model. The specific model to use for inference can 42 | be specified in the configuration file. If no model is specified, Hyrax will find 43 | the most recently trained model in the results directory and use that for inference. 44 | The data used for inference is also specified in the configuration file. 45 | 46 | .. tabs:: 47 | 48 | .. group-tab:: Notebook 49 | 50 | .. code-block:: python 51 | 52 | from hyrax import Hyrax 53 | 54 | # Create an instance of the Hyrax object 55 | h = Hyrax() 56 | 57 | # Pass data through a trained model to produce embeddings or predictions. 58 | h.infer() 59 | 60 | .. group-tab:: CLI 61 | 62 | .. code-block:: bash 63 | 64 | >> hyrax infer 65 | 66 | When running infer in a notebook context, the infer verb returns an 67 | :doc:`InferenceDataSet` object which can be accessed using 68 | the ``[]`` operator in Python. 69 | 70 | ``umap`` 71 | -------- 72 | Run UMAP on the output of inference or a dataset. By default, Hyrax will use the 73 | most recently generated output from the ``infer`` verb. 74 | 75 | .. tabs:: 76 | 77 | .. group-tab:: Notebook 78 | 79 | .. code-block:: python 80 | 81 | from hyrax import Hyrax 82 | 83 | # Create an instance of the Hyrax object 84 | h = Hyrax() 85 | 86 | # Train a UMAP and process the entire dataset. 87 | h.umap() 88 | 89 | .. group-tab:: CLI 90 | 91 | .. code-block:: bash 92 | 93 | >> hyrax umap 94 | 95 | 96 | ``visualize`` 97 | ------------- 98 | Interactively visualize the embedded space produced by UMAP. 99 | Because the visualization is interactive, it is not available in the CLI. 100 | 101 | .. 
code-block:: python 102 | 103 | from hyrax import Hyrax 104 | 105 | # Create an instance of the Hyrax object 106 | h = Hyrax() 107 | 108 | # Interactively visualize the UMAP embedding 109 | h.visualize() 110 | 111 | 112 | ``prepare`` 113 | ----------- 114 | Create and return an instance of a Hyrax dataset object. This allows for convenient 115 | investigation of the dataset. While this can be run from the CLI, it is primarily 116 | intended for use in a notebook environment for exploration and debugging. 117 | 118 | .. code-block:: python 119 | 120 | from hyrax import Hyrax 121 | 122 | # Create an instance of the Hyrax object 123 | h = Hyrax() 124 | 125 | # Prepare the dataset for exploration 126 | dataset = h.prepare() 127 | 128 | 129 | ``index`` 130 | --------- 131 | Build a vector database index from the output of inference. By default, Hyrax 132 | will use the most recently generated output from the ``infer`` verb, and will 133 | write the resulting database to a new timestamped directory under the default 134 | ``./results/`` directory with the form -index-. 135 | 136 | An existing database directory can be specified in order to add more vectors to 137 | an existing index. 138 | 139 | .. tabs:: 140 | 141 | .. group-tab:: Notebook 142 | 143 | .. code-block:: python 144 | 145 | from hyrax import Hyrax 146 | 147 | # Create an instance of the Hyrax object 148 | h = Hyrax() 149 | 150 | # Build a vector database index from the output of inference 151 | h.index() 152 | 153 | .. group-tab:: CLI 154 | 155 | .. code-block:: bash 156 | 157 | >> hyrax index [-i -o ] -------------------------------------------------------------------------------- /tests/hyrax/test_patch_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | 4 | from hyrax.models.model_registry import hyrax_model 5 | 6 | 7 | @hyrax_model 8 | class DummyModelOne(nn.Module): 9 | """A dummy model used to test patching of static methods like to_tensor""" 10 | 11 | def __init__(self, config, data_sample=None): 12 | super().__init__() 13 | # The optimizer needs at least one weight, so we add a dummy module here 14 | self.unused_module = nn.Linear(1, 1) 15 | self.config = config 16 | 17 | @staticmethod 18 | def to_tensor(x): 19 | """Default to_tensor method which just returns the input""" 20 | return x 21 | 22 | 23 | @hyrax_model 24 | class DummyModelTwo(nn.Module): 25 | """A dummy model used to test patching; it relies on the default 26 | to_tensor method.""" 27 | 28 | def __init__(self, config, data_sample=None): 29 | super().__init__() 30 | # The optimizer needs at least one weight, so we add a dummy module here 31 | self.unused_module = nn.Linear(1, 1) 32 | self.config = config 33 | 34 | 35 | @staticmethod 36 | def to_tensor(x): 37 | """A simple to_tensor method that will patch the default one on the dummy models""" 38 | return x * 2 39 | 40 | 41 | def test_patch_to_tensor(tmp_path): 42 | """Test to ensure we can save and restore the to_tensor static method on a 43 | model instance correctly.""" 44 | 45 | # Minimal config dict to define the criterion and optimizer for the dummy model. 
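# (The `torch.optim.SGD` table below is assumed to hold keyword arguments --
# here just the learning rate -- forwarded to the optimizer class named in
# `optimizer.name`.)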
46 | config = { 47 | "criterion": {"name": "torch.nn.MSELoss"}, 48 | "optimizer": {"name": "torch.optim.SGD"}, 49 | "torch.optim.SGD": {"lr": 0.01}, 50 | } 51 | 52 | # create an instance of the dummy model 53 | model = DummyModelOne(config=config, data_sample=None) 54 | 55 | # manually update the to_tensor static method to be something simple 56 | # don't wrap this with staticmethod(...) because that would be a double wrapping. 57 | model.to_tensor = to_tensor 58 | 59 | # call model.save() to persist the model weights and to_tensor function. 60 | model.save(tmp_path / "model_weights.pth") 61 | 62 | # verify that the to_tensor file was written 63 | assert (tmp_path / "to_tensor.py").exists() 64 | 65 | # create a new instance of the dummy model and call .load() with the correct path 66 | new_model = DummyModelOne(config=config, data_sample=None) 67 | 68 | # verify that the new model's to_tensor method is the default one 69 | input_data = 3.0 70 | output_data = new_model.to_tensor(input_data) 71 | assert output_data == input_data 72 | 73 | # now load the saved weights and to_tensor method into the new model 74 | new_model.load(tmp_path / "model_weights.pth") 75 | 76 | # verify that the to_tensor method was restored correctly by passing some data to it. 77 | output_data = new_model.to_tensor(input_data) 78 | assert output_data == to_tensor(input_data) 79 | 80 | 81 | def test_patch_to_tensor_over_default(tmp_path): 82 | """Test to ensure we can save and restore the to_tensor static method on a 83 | model instance where the model class makes use of the default to_tensor method.""" 84 | 85 | # Minimal config dict to define the criterion and optimizer for the dummy model. 86 | config = { 87 | "criterion": {"name": "torch.nn.MSELoss"}, 88 | "optimizer": {"name": "torch.optim.SGD"}, 89 | "torch.optim.SGD": {"lr": 0.01}, 90 | } 91 | 92 | # create an instance of the dummy model 93 | model = DummyModelTwo(config=config, data_sample=None) 94 | 95 | # manually update the to_tensor static method to be something simple 96 | # don't wrap this with staticmethod(...) because that would be a double wrapping. 97 | model.to_tensor = to_tensor 98 | 99 | # call model.save() to persist the model weights and to_tensor function. 100 | model.save(tmp_path / "model_weights.pth") 101 | 102 | # verify that the to_tensor file was written 103 | assert (tmp_path / "to_tensor.py").exists() 104 | 105 | # create a new instance of the dummy model and call .load() with the correct path 106 | new_model = DummyModelTwo(config=config, data_sample=None) 107 | 108 | # verify that the new model's to_tensor method is the default one 109 | input_data = {"data": {"image": 3}} 110 | output_data = new_model.to_tensor(input_data) 111 | assert output_data[0] == 3 112 | assert isinstance(output_data[1], np.ndarray) 113 | 114 | # now load the saved weights and to_tensor method into the new model 115 | new_model.load(tmp_path / "model_weights.pth") 116 | 117 | # verify that the to_tensor method was restored correctly by passing some data to it. 118 | input_data = 3 119 | output_data = new_model.to_tensor(input_data) 120 | assert output_data == to_tensor(input_data) 121 | -------------------------------------------------------------------------------- /src/hyrax/data_sets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hyrax has several built-in datasets that you can use for astronomical data. For many uses, these datasets 3 | can be configured out of the box for a given project. 
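
For example, here is a minimal sketch of selecting a built-in dataset through the
runtime configuration (the dataset name used here is illustrative):

.. code-block:: python

    from hyrax import Hyrax

    h = Hyrax()
    h.config["data_set"]["name"] = "HyraxCifarDataSet"
    dataset = h.prepare()  # returns a dataset instance for exploration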
4 | 5 | :doc:`FitsImageDataSet ` is a generic container for fits image cutout data 6 | indexed by a user-provided catalog file. It attempts to cover common usage paradigms such as multiple images 7 | of the same object differentiated by telescope filter; however, extending the class as a custom dataset 8 | may be better suited to advanced usage. 9 | 10 | :doc:`LSSTDataset ` is an alpha-quality container for LSST cutout images, currently 11 | limited to ``deep_coadd`` type images, and restricted to run only on a Rubin Observatory RSP environment 12 | where `LSST Pipeline `_ tools and a 13 | `data butler `_ with the appropriate images 14 | are available. 15 | 16 | :doc:`DownloadedLSSTDataset ` is a subclass of LSSTDataset that generates 17 | cutouts from the butler and saves them as ``.pt`` files on first access. On subsequent access, 18 | it loads the cutouts directly from these files, which can significantly speed up data loading times. 19 | It inherits from LSSTDataset to access the data butler and catalog functionality. 20 | 21 | :doc:`HSCDataSet ` works similarly to FitsImageDataSet, but is specialized to 22 | `Hyper Suprime-Cam (HSC) `_ cutout images downloaded 23 | with the hyrax ``download`` verb. It contains additional integrity checks and is tightly integrated with 24 | the ``download`` and ``rebuild_manifest`` verbs. In the future, this class and the downloader may become a 25 | separate package. 26 | 27 | :doc:`HyraxCifarDataSet ` and 28 | :doc:`HyraxCifarIterableDataSet ` give access to the standard 29 | `CIFAR10 `_ labeled image dataset, automatically downloading the 30 | dataset if it is not present. These datasets are useful for testing hyrax and occasionally individual models, 31 | but they are not astronomical datasets. 32 | 33 | :doc:`HyraxRandomDataset ` and 34 | :doc:`HyraxRandomIterableDataset ` are utility datasets that 35 | generate random data with a specific shape. 36 | These datasets make it easy to test new models with simple random data. 37 | They are highly configurable such that it's possible to simulate input data for models that 38 | are under development. 39 | 40 | Each of these datasets can be used as a starting point for a Custom Dataset by inheriting your custom dataset 41 | from e.g. `FitsImageDataSet`, or you can make an entirely custom dataset following the 42 | :ref:`custom dataset instructions ` and/or 43 | :doc:`custom dataset example notebook `. 44 | 45 | The remaining classes in this module exist primarily for Hyrax interface purposes: 46 | 47 | :doc:`InferenceDataset ` is a dataset class that represents an ``infer`` or ``umap`` 48 | result, and may be returned from those verbs to provide data access. 49 | 50 | :doc:`HyraxDataset ` is a base class for all datasets in Hyrax and must be within 51 | the inheritance hierarchy of all custom datasets. It is not usable on its own, but provides various fall-back 52 | functionality to make custom datasets easier to write. See the 53 | :ref:`custom dataset instructions ` and 54 | :doc:`example notebook
` for more information. 55 | 56 | """ 57 | 58 | # Remove import sorting; these are imported in the order written so that 59 | # autoapi docs are generated with ordering controlled below. 60 | # ruff: noqa: I001 61 | from .fits_image_dataset import FitsImageDataSet 62 | from .lsst_dataset import LSSTDataset 63 | from .downloaded_lsst_dataset import DownloadedLSSTDataset 64 | from .hsc_data_set import HSCDataSet 65 | from .hyrax_cifar_data_set import HyraxCifarDataSet, HyraxCifarIterableDataSet 66 | from .random.hyrax_random_dataset import ( 67 | HyraxRandomDataset, 68 | HyraxRandomIterableDataset, 69 | HyraxRandomDatasetBase, 70 | ) 71 | from .inference_dataset import InferenceDataSet 72 | from .data_set_registry import HyraxDataset 73 | from .hyrax_cifar_data_set import HyraxCifarBase 74 | from .hyrax_csv_dataset import HyraxCSVDataset 75 | 76 | __all__ = [ 77 | "HyraxCifarDataSet", 78 | "FitsImageDataSet", 79 | "HyraxCifarIterableDataSet", 80 | "HSCDataSet", 81 | "InferenceDataSet", 82 | "HyraxDataset", 83 | "LSSTDataset", 84 | "DownloadedLSSTDataset", 85 | "HyraxCifarBase", 86 | "HyraxRandomDataset", 87 | "HyraxRandomIterableDataset", 88 | "HyraxRandomDatasetBase", 89 | "HyraxCSVDataset", 90 | ] 91 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | 7 | import os 8 | import sys 9 | from importlib.metadata import version 10 | 11 | # Define path to the code to be documented **relative to where conf.py (this file) is kept** 12 | sys.path.insert(0, os.path.abspath("../src/")) 13 | 14 | # -- Project information ----------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 16 | 17 | project = "hyrax" 18 | copyright = "2025, LINCC Frameworks" 19 | author = "LINCC Frameworks" 20 | release = version("hyrax") 21 | # for example, take major/minor 22 | version = ".".join(release.split(".")[:2]) 23 | 24 | # -- General configuration --------------------------------------------------- 25 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 26 | 27 | extensions = ["sphinx.ext.mathjax", "sphinx.ext.napoleon", "sphinx.ext.viewcode"] 28 | 29 | extensions.append("autoapi.extension") 30 | extensions.append("nbsphinx") 31 | extensions.append("sphinx_tabs.tabs") 32 | 33 | # -- sphinx-copybutton configuration ---------------------------------------- 34 | extensions.append("sphinx_copybutton") 35 | ## sets up the expected prompt text for console blocks, and excludes it from 36 | ## the text that goes into the clipboard. 37 | copybutton_exclude = ".linenos, .gp" 38 | copybutton_prompt_text = ">> " 39 | 40 | ## lets us suppress the copy button on select code blocks.
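## For example (an assumption about standard Sphinx HTML output, not something
## configured here): giving a code block ``:class: no-copybutton`` puts that class
## on the wrapper div that directly contains ``div.highlight``, so the selector
## below will skip it.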
41 | copybutton_selector = "div:not(.no-copybutton) > div.highlight > pre" 42 | 43 | extensions.append("sphinx_togglebutton") 44 | 45 | templates_path: list[str] = [] 46 | exclude_patterns = ["_build", "**.ipynb_checkpoints"] 47 | 48 | # This assumes that sphinx-build is called from the root directory 49 | master_doc = "index" 50 | # Remove 'view source code' from top of page (for html, not python) 51 | html_show_sourcelink = False 52 | # Remove namespaces from class/method signatures 53 | add_module_names = False 54 | 55 | autoapi_type = "python" 56 | autoapi_dirs = ["../src/hyrax", "../src/hyrax_cli"] 57 | autoapi_ignore = ["*/__main__.py", "*/_version.py", "*3d_viz*"] # , "*downloadCutout*"] 58 | autoapi_add_toc_tree_entry = False 59 | autoapi_member_order = "bysource" 60 | # Useful for tracking down sphinx errors in autodoc's generated files from a sphinx warning 61 | autoapi_keep_files = True 62 | autoapi_python_class_content = "both" # Render __init__ and class docstring concatenated. 63 | autoapi_options = [ 64 | "members", 65 | "undoc-members", 66 | "private-members", 67 | "show-inheritance", 68 | "show-module-summary", 69 | "special-members", 70 | "imported-members", 71 | ] 72 | 73 | 74 | nitpick_ignore_regex = [ 75 | # Packages that have their own docs 76 | (r"^py:.*", r"^abc\..*"), 77 | (r"^py:.*", r"^astropy\..*"), 78 | (r"^py:.*", r"^tomlkit\..*"), 79 | (r"^py:.*", r"^pathlib\..*"), 80 | (r"^py:.*", r"^Path.*"), 81 | (r"^py:.*", r"^Table.*"), 82 | (r"^py:.*", r"^torch\..*"), 83 | (r"^py:.*", r"^concurrent\..*"), 84 | (r"^py:.*", r"^numpy\..*"), 85 | (r"^py:.*", r"^npt\..*"), 86 | (r"^py:.*", r"^np\..*"), 87 | (r"^py:.*", r"^datetime\..*"), 88 | (r"^py:.*", r"^urllib\..*"), 89 | (r"^py:.*", r"^torchvision\..*"), 90 | (r"^py:.*", r"^collections\..*"), 91 | (r"^py:.*", r"^_collections_abc\..*"), 92 | (r"^py:.*", r"^nn\..*"), 93 | (r"^py:.*", r"^tensorboardX\..*"), 94 | (r"^py:.*", r"^ignite\..*"), 95 | (r"^py:.*", r"^pytorch-ignite\..*"), 96 | (r"^py:.*", r"^argparse\..*"), 97 | (r"^py:.*", r"^holoviews\..*"), 98 | (r"^py:.*", r"^hv\..*"), 99 | (r"^py:.*", r"^pd\..*"), 100 | (r"^py:.*", r"^threading\..*"), 101 | (r"^py:.*", r"^enum\..*"), 102 | # Types and idiomatic ways we document types 103 | (r"^py:.*", r"^T$"), 104 | (r"^py:class", r"^[oO]ptional[:]?$"), 105 | (r"^py:class", r"^tuple$"), 106 | (r"^py:class", r"^string$"), 107 | (r"^py:.*", r"^TOML.*"), 108 | (r"^py:.*", r"^Ellipsis.*"), 109 | (r"^py:.*", r"^ML [fF]ramework [mM]odel"), 110 | (r"^py:.*", r"^Tensor"), 111 | (r"^py:.*", r"^SummaryWriter"), 112 | (r"^py:.*", r"^Dataset"), 113 | (r"^py:.*", r"^Engine"), 114 | (r"^py:.*", r"^DataLoader"), 115 | (r"^py:.*", r"^DistributedDataParallel"), 116 | (r"^py:.*", r"^DataParallel"), 117 | (r"^py:.*", r"^ArgumentParser"), 118 | (r"^py:.*", r"^Namespace"), 119 | # Types defined by our package that autodoc misidentifies in annotations 120 | (r"^py:.*", r"^hyrax.data_sets.fits_image_dataset.files_dict$"), 121 | (r"^py:.*", r"^dim_dict$"), 122 | (r"^py:.*", r"^dC.Rect$"), 123 | (r"^py:.*", r"^hyrax.downloadCutout.downloadCutout.Rect$"), 124 | ] 125 | 126 | html_theme = "sphinx_rtd_theme" 127 | -------------------------------------------------------------------------------- /src/hyrax/models/hyrax_cnn.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101, D102 2 | 3 | # This example model is taken from the PyTorch CIFAR10 tutorial: 4 | # 
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network 5 | import logging 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F # noqa: N812 11 | 12 | from .model_registry import hyrax_model 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | @hyrax_model 18 | class HyraxCNN(nn.Module): 19 | """ 20 | This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class. 21 | """ 22 | 23 | def __init__(self, config, data_sample=None): 24 | super().__init__() 25 | self.config = config 26 | 27 | if data_sample is None: 28 | raise ValueError("A `data_sample` must be provided to HyraxCNN for dynamic sizing.") 29 | 30 | image_sample = data_sample[0] 31 | self.num_input_channels, self.image_width, self.image_height = image_sample.shape 32 | hidden_channels_1 = 6 33 | hidden_channels_2 = 16 34 | 35 | # Calculate how much our convolutional layers and pooling will affect 36 | # the size of the final convolution. 37 | # 38 | # If the number of layers is changed, this will need to be rewritten. 39 | conv1_end_w = self.conv2d_output_size(self.image_width, kernel_size=5) 40 | conv1_end_h = self.conv2d_output_size(self.image_height, kernel_size=5) 41 | 42 | pool1_end_w = self.pool2d_output_size(conv1_end_w, kernel_size=2, stride=2) 43 | pool1_end_h = self.pool2d_output_size(conv1_end_h, kernel_size=2, stride=2) 44 | 45 | conv2_end_w = self.conv2d_output_size(pool1_end_w, kernel_size=5) 46 | conv2_end_h = self.conv2d_output_size(pool1_end_h, kernel_size=5) 47 | 48 | pool2_end_w = self.pool2d_output_size(conv2_end_w, kernel_size=2, stride=2) 49 | pool2_end_h = self.pool2d_output_size(conv2_end_h, kernel_size=2, stride=2) 50 | 51 | self.conv1 = nn.Conv2d(self.num_input_channels, hidden_channels_1, 5) 52 | self.pool = nn.MaxPool2d(2, 2) 53 | self.conv2 = nn.Conv2d(hidden_channels_1, hidden_channels_2, 5) 54 | self.fc1 = nn.Linear(hidden_channels_2 * pool2_end_h * pool2_end_w, 120) 55 | self.fc2 = nn.Linear(120, 84) 56 | self.fc3 = nn.Linear(84, self.config["model"]["HyraxCNN"]["output_classes"]) 57 | 58 | def conv2d_output_size(self, input_size, kernel_size, padding=0, stride=1, dilation=1) -> int: 59 | # From https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html 60 | numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1 61 | return int((numerator / stride) + 1) 62 | 63 | def pool2d_output_size(self, input_size, kernel_size, stride, padding=0, dilation=1) -> int: 64 | # From https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html 65 | numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1 66 | return int((numerator / stride) + 1) 67 | 68 | def forward(self, x): 69 | # This check is inefficient - we assume that the example CNN will be primarily 70 | # used with the CIFAR10 dataset. During training, the `train_step` method 71 | # will unpack the tuple and only pass the first element to the `forward` method. 72 | # But for inference the entire tuple is passed in, so we need to handle 73 | # both cases. 74 | if isinstance(x, (tuple, list)): 75 | x, _ = x 76 | 77 | x = self.pool(F.relu(self.conv1(x))) 78 | x = self.pool(F.relu(self.conv2(x))) 79 | x = torch.flatten(x, 1) 80 | x = F.relu(self.fc1(x)) 81 | x = F.relu(self.fc2(x)) 82 | x = self.fc3(x) 83 | return x 84 | 85 | def train_step(self, batch): 86 | """This function contains the logic for a single training step, i.e.
the 87 | contents of the inner loop of an ML training process. 88 | 89 | Parameters 90 | ---------- 91 | batch : tuple 92 | A tuple containing the inputs and labels for the current batch. 93 | 94 | Returns 95 | ------- 96 | Current loss value : dict 97 | Dictionary containing the loss value for the current batch. 98 | """ 99 | inputs, labels = batch 100 | 101 | self.optimizer.zero_grad() 102 | outputs = self(batch) 103 | loss = self.criterion(outputs, labels) 104 | loss.backward() 105 | self.optimizer.step() 106 | return {"loss": loss.item()} 107 | 108 | @staticmethod 109 | def to_tensor(data_dict) -> tuple: 110 | """Does NOT convert to PyTorch Tensors. 111 | This works exclusively with numpy data types and returns 112 | a tuple of numpy data types.""" 113 | 114 | if "data" not in data_dict: 115 | raise RuntimeError("Unable to find `data` key in data_dict") 116 | 117 | data = data_dict["data"] 118 | image = data.get("image", np.array([]))  # np.array, not the low-level np.ndarray constructor 119 | label = data.get("label", np.array([])) 120 | 121 | return (image, label) 122 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | 2 | repos: 3 | # Compare the local template version to the latest remote template version 4 | # This hook should always pass. It will print a message if the local version 5 | # is out of date. 6 | - repo: https://github.com/lincc-frameworks/pre-commit-hooks 7 | rev: v0.1.2 8 | hooks: 9 | - id: check-lincc-frameworks-template-version 10 | name: Check template version 11 | description: Compare current template version against latest 12 | verbose: true 13 | # Clear output from jupyter notebooks so that only the input cells are committed. 14 | - repo: local 15 | hooks: 16 | - id: jupyter-nb-clear-output 17 | name: Clear output from Jupyter notebooks 18 | description: Clear output from Jupyter notebooks. 19 | files: \.ipynb$ 20 | exclude: ^docs/pre_executed 21 | stages: [pre-commit] 22 | language: system 23 | entry: jupyter nbconvert --clear-output 24 | # Prevents committing directly to branches named 'main' and 'master'. 25 | - repo: https://github.com/pre-commit/pre-commit-hooks 26 | rev: v4.4.0 27 | hooks: 28 | - id: no-commit-to-branch 29 | name: Prevent main branch commits 30 | description: Prevent the user from committing directly to the primary branch. 31 | - id: check-added-large-files 32 | name: Check for large files 33 | description: Prevent the user from committing very large files. 34 | args: ['--maxkb=500'] 35 | # Verify that pyproject.toml is well formed 36 | - repo: https://github.com/abravalheri/validate-pyproject 37 | rev: v0.24.1 38 | hooks: 39 | - id: validate-pyproject 40 | name: Validate pyproject.toml 41 | description: Verify that pyproject.toml adheres to the established schema. 42 | # Verify that GitHub workflows are well formed 43 | - repo: https://github.com/python-jsonschema/check-jsonschema 44 | rev: 0.28.0 45 | hooks: 46 | - id: check-github-workflows 47 | args: ["--verbose"] 48 | - repo: local 49 | hooks: 50 | - id: xcxc-check 51 | name: Check for note-to-self comments (xcxc) 52 | description: Grep all source files for xcxc which signifies a comment that shouldn't be checked in. 53 | entry: bash -c "[[ $(grep -rniI xcxc --exclude .pre-commit-config.yaml --exclude-dir _readthedocs --exclude-dir htmlcov --exclude-dir _results --exclude-dir env --exclude-dir autoapi ./* >&2 ; echo $?)
== 1 ]]" 54 | language: system 55 | pass_filenames: false 56 | always_run: true 57 | - repo: https://github.com/astral-sh/ruff-pre-commit 58 | # Ruff version. 59 | rev: v0.14.7 60 | hooks: 61 | - id: ruff 62 | name: Lint code using ruff; sort and organize imports 63 | types_or: [ python, pyi ] 64 | args: ["--fix"] 65 | - repo: https://github.com/astral-sh/ruff-pre-commit 66 | # Ruff version. 67 | rev: v0.14.7 68 | hooks: 69 | - id: ruff-format 70 | name: Format code using ruff 71 | types_or: [ python, pyi, jupyter ] 72 | # Make sure Sphinx can build the documentation while explicitly omitting 73 | # notebooks from the docs, so users don't have to wait through the execution 74 | # of each notebook or each commit. By default, these will be checked in the 75 | # GitHub workflows. 76 | - repo: local 77 | hooks: 78 | - id: sphinx-build 79 | name: Build documentation with Sphinx 80 | entry: sphinx-build 81 | language: system 82 | always_run: true 83 | exclude_types: [file, symlink] 84 | args: 85 | [ 86 | "-M", # Run sphinx in make mode, so we can use -D flag later 87 | # Note: -M requires next 3 args to be builder, source, output 88 | "html", # Specify builder 89 | "./docs", # Source directory of documents 90 | "./_readthedocs", # Output directory for rendered documents 91 | "-T", # Show full trace back on exception 92 | "-E", # Don't use saved env; always read all files 93 | "-d", # Flag for cached environment and doctrees 94 | "./docs/_build/doctrees", # Directory 95 | "-D", # Flag to override settings in conf.py 96 | "exclude_patterns=notebooks/*,_build", # Exclude our notebooks from pre-commit 97 | "-W", # Warnings are errors 98 | "--keep-going", # Finish generating docs even if errors occur 99 | "-n", # Nitpick mode among other things checks for broken links 100 | ] 101 | # Run unit tests, verify that they pass. Note that coverage is run against 102 | # the ./src directory here because that is what will be committed. In the 103 | # github workflow script, the coverage is run against the installed package 104 | # and uploaded to Codecov by calling pytest like so: 105 | # `python -m pytest --cov= --cov-report=xml` 106 | - repo: local 107 | hooks: 108 | - id: pytest-check 109 | name: Run unit tests 110 | description: Run unit tests with pytest. 111 | entry: bash -c "if python -m pytest -n auto --co -qq -m 'not slow'; then python -m pytest -n auto --cov=./src --cov-report=html -m 'not slow'; fi" 112 | language: system 113 | pass_filenames: false 114 | always_run: true 115 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "hyrax" 3 | license = "MIT" 4 | license-files = ["LICENSE"] 5 | readme = "README.md" 6 | authors = [ 7 | { name = "LINCC Frameworks", email = "mtauraso@uw.edu" } 8 | ] 9 | classifiers = [ 10 | "Development Status :: 4 - Beta", 11 | "Intended Audience :: Developers", 12 | "Intended Audience :: Science/Research", 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python", 15 | ] 16 | dynamic = ["version"] 17 | requires-python = ">=3.10" 18 | dependencies = [ 19 | "astropy", # Used to load fits files of sources to query HSC cutout server 20 | # Pin to the current version of pytorch ignite so workarounds to 21 | # https://github.com/pytorch/ignite/issues/3372 function correctly 22 | # while allowing us to release packages that don't depend on dev versions 23 | # of pytorch-ignite. 
24 | "pytorch-ignite <= 0.5.3", # Used for distributed training, logging, etc. 25 | "more-itertools", # Used to work around the issue in pytorch-ignite above 26 | "toml", # Used to load configuration files as dictionaries 27 | "tomlkit", # Used to load configuration files as dictionaries and retain comments 28 | "torch", # Used for CNN model and in train.py 29 | "torchvision", # Used in hsc data loader, example autoencoder, and CNN model data set 30 | "tensorboardX", # Used to log training metrics 31 | "tensorboard", # Used to log training metrics 32 | "schwimmbad", # Used to speedup hsc data loader file scans 33 | "chromadb > 1.0", # Used for similarity search 34 | "holoviews", # Used in Holoviews visualization prototype 35 | "bokeh", # Used in Holoviews visualization prototype 36 | "jupyter_bokeh", # Used in Holoviews visualization prototype 37 | "datashader", # Used in Holoviews visualization prototype 38 | "pandas", # Used in Holoviews visualization prototype 39 | "numpy", # Used in Holoviews visualization prototype 40 | "scipy", # Used in Holoviews visualization prototype 41 | "cython", # Used in Holoviews visualization prototype 42 | "mlflow", # Used to log training metrics and compare models 43 | "pynvml", # Used to gather GPU usage information 44 | "umap-learn", # Used to map latent spaces down to 2d 45 | "pooch", # Used to download data files 46 | "onnx", # Used to export models to ONNX format 47 | "onnxruntime", # Used to run ONNX models 48 | "onnxscript", # Used to convert PyTorch models to ONNX format 49 | "plotly", # Used in 3d visualization 50 | "psutil", # Used for memory monitoring 51 | "tqdm", # Used to show progress bars 52 | "qdrant-client", # Vector database for similarity search 53 | "colorama", # Used to color terminal output 54 | ] 55 | 56 | [project.scripts] 57 | hyrax = "hyrax_cli.main:main" 58 | 59 | [project.urls] 60 | "Source Code" = "https://github.com/lincc-frameworks/hyrax" 61 | 62 | # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) 63 | [project.optional-dependencies] 64 | examples = [ 65 | ] 66 | 67 | dev = [ 68 | "asv[virtualenv]==0.6.5", # Used to compute performance benchmarks 69 | "jupyter", # Clears output from Jupyter notebooks 70 | "matplotlib", # For example notebooks 71 | "pre-commit", # Used to run checks before finalizing a git commit 72 | "pytest", 73 | "pytest-cov", # Used to report total code coverage 74 | "pytest-env", # Used to set environment variables in testing 75 | "pytest-xdist", # Used to parallelize unit tests 76 | "ruff", # Used for static linting of files 77 | "sphinx", 78 | "sphinx-autoapi", 79 | "nbsphinx", 80 | "sphinx-tabs", 81 | "sphinx-copybutton", 82 | "sphinx-togglebutton", 83 | "sphinx-rtd-theme", 84 | ] 85 | 86 | [build-system] 87 | requires = [ 88 | "setuptools>=62", # Used to build and package the Python project 89 | "setuptools_scm>=6.2", # Gets release version from git. 
Makes it available programmatically 90 | ] 91 | build-backend = "setuptools.build_meta" 92 | 93 | [tool.setuptools_scm] 94 | write_to = "src/hyrax/_version.py" 95 | 96 | [tool.pytest.ini_options] 97 | testpaths = [ 98 | "tests", 99 | "src", 100 | "docs", 101 | ] 102 | addopts = "--doctest-modules --doctest-glob=*.rst" 103 | env = [ 104 | "TQDM_DISABLE=1", 105 | ] 106 | markers = [ 107 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 108 | ] 109 | 110 | [tool.ruff] 111 | line-length = 110 112 | target-version = "py310" 113 | [tool.ruff.lint] 114 | select = [ 115 | # pycodestyle 116 | "E", 117 | "W", 118 | # Pyflakes 119 | "F", 120 | # pep8-naming 121 | "N", 122 | # pyupgrade 123 | "UP", 124 | # flake8-bugbear 125 | "B", 126 | # flake8-simplify 127 | "SIM", 128 | # isort 129 | "I", 130 | # docstrings 131 | "D101", 132 | "D102", 133 | "D103", 134 | "D106", 135 | "D206", 136 | "D207", 137 | "D208", 138 | "D300", 139 | "D417", 140 | "D419", 141 | # Numpy v2.0 compatibility 142 | "NPY201", 143 | ] 144 | ignore = [ 145 | "UP006", # Allow non standard library generics in type hints 146 | "UP007", # Allow Union in type hints 147 | "SIM114", # Allow if with same arms 148 | "B028", # Allow default warning level 149 | "SIM117", # Allow nested with 150 | "UP015", # Allow redundant open parameters 151 | "UP028", # Allow yield in for loop 152 | "B905", # Allow zip without `strict` 153 | ] 154 | 155 | 156 | [tool.coverage.run] 157 | omit=["src/hyrax/_version.py"] 158 | -------------------------------------------------------------------------------- /tests/hyrax/test_to_onnx.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | import hyrax 6 | from hyrax.config_utils import find_most_recent_results_dir 7 | from hyrax.verbs.to_onnx import ToOnnx 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @pytest.fixture 13 | def trained_hyrax(tmp_path): 14 | """Fixture that creates a trained Hyrax instance for ONNX export tests""" 15 | # Create a Hyrax instance with loopback model configuration 16 | h = hyrax.Hyrax() 17 | h.config["model"]["name"] = "HyraxLoopback" 18 | h.config["train"]["epochs"] = 1 19 | h.config["data_loader"]["batch_size"] = 4 20 | h.config["general"]["results_dir"] = str(tmp_path) 21 | h.config["general"]["dev_mode"] = True 22 | 23 | # Configure dataset 24 | h.config["model_inputs"] = { 25 | "train": { 26 | "data": { 27 | "dataset_class": "HyraxRandomDataset", 28 | "data_location": str(tmp_path / "data_train"), 29 | "fields": ["image", "label"], 30 | "primary_id_field": "object_id", 31 | } 32 | }, 33 | "infer": { 34 | "data": { 35 | "dataset_class": "HyraxRandomDataset", 36 | "data_location": str(tmp_path / "data_train"), 37 | "fields": ["image"], 38 | "primary_id_field": "object_id", 39 | } 40 | }, 41 | } 42 | h.config["data_set"]["HyraxRandomDataset"]["size"] = 20 43 | h.config["data_set"]["HyraxRandomDataset"]["seed"] = 0 44 | h.config["data_set"]["HyraxRandomDataset"]["shape"] = [2, 3] 45 | 46 | # Train the model 47 | h.train() 48 | 49 | return h 50 | 51 | 52 | def test_to_onnx_successful_export(trained_hyrax): 53 | """Test successful ONNX export from a trained model""" 54 | h = trained_hyrax 55 | 56 | # Find the training results directory 57 | train_dir = find_most_recent_results_dir(h.config, "train") 58 | assert train_dir is not None, "Training results directory should exist" 59 | 60 | # Export to ONNX using the verb 61 | to_onnx_verb = ToOnnx(h.config) 62 | 
to_onnx_verb.run(str(train_dir)) 63 | 64 | onnx_dir = find_most_recent_results_dir(h.config, "onnx") 65 | 66 | # Verify ONNX model was created with timestamp-based filename 67 | onnx_files = list(onnx_dir.glob("*.onnx")) 68 | assert len(onnx_files) == 1, "Exactly one ONNX file should be created" 69 | 70 | onnx_file = onnx_files[0] 71 | # Check filename pattern: <model>_opset_<version>.onnx (opset version only) 72 | assert "model" in onnx_file.name 73 | assert onnx_file.suffix == ".onnx" 74 | 75 | 76 | def test_to_onnx_missing_input_directory(tmp_path): 77 | """Test handling of missing input directories""" 78 | h = hyrax.Hyrax() 79 | h.config["model"]["name"] = "HyraxLoopback" 80 | h.config["general"]["results_dir"] = str(tmp_path) 81 | 82 | to_onnx_verb = ToOnnx(h.config) 83 | 84 | # Test with non-existent directory 85 | non_existent_dir = tmp_path / "does_not_exist" 86 | to_onnx_verb.run(str(non_existent_dir)) 87 | 88 | # The verb should log an error and return without creating ONNX files 89 | # Verify no ONNX files were created 90 | onnx_files = list(tmp_path.glob("**/*.onnx")) 91 | assert len(onnx_files) == 0, "No ONNX files should be created for missing directory" 92 | 93 | 94 | def test_to_onnx_missing_input_directory_from_config(tmp_path): 95 | """Test handling of missing input directories specified in config""" 96 | h = hyrax.Hyrax() 97 | h.config["model"]["name"] = "HyraxLoopback" 98 | h.config["general"]["results_dir"] = str(tmp_path) 99 | h.config["onnx"]["input_model_directory"] = str(tmp_path / "does_not_exist") 100 | 101 | to_onnx_verb = ToOnnx(h.config) 102 | 103 | # Test with directory from config that doesn't exist 104 | to_onnx_verb.run() 105 | 106 | # The verb should log an error and return without creating ONNX files 107 | onnx_files = list(tmp_path.glob("**/*.onnx")) 108 | assert len(onnx_files) == 0, "No ONNX files should be created for missing directory" 109 | 110 | 111 | def test_to_onnx_no_previous_training(tmp_path): 112 | """Test handling when no previous training results exist""" 113 | h = hyrax.Hyrax() 114 | h.config["model"]["name"] = "HyraxLoopback" 115 | h.config["general"]["results_dir"] = str(tmp_path) 116 | 117 | to_onnx_verb = ToOnnx(h.config) 118 | 119 | # Try to export without any prior training 120 | to_onnx_verb.run() 121 | 122 | # The verb should log an error and return without creating ONNX files 123 | onnx_files = list(tmp_path.glob("**/*.onnx")) 124 | assert len(onnx_files) == 0, "No ONNX files should be created without prior training" 125 | 126 | 127 | def test_to_onnx_cli_argument_parsing(tmp_path): 128 | """Test that CLI arguments are properly parsed""" 129 | h = hyrax.Hyrax() 130 | h.config["general"]["results_dir"] = str(tmp_path) 131 | 132 | to_onnx_verb = ToOnnx(h.config) 133 | 134 | # Mock the args object 135 | class MockArgs: 136 | def __init__(self): 137 | self.input_model_directory = str(tmp_path / "test_dir") 138 | 139 | args = MockArgs() 140 | 141 | # This should use the input_model_directory from args 142 | # We expect it to fail because the directory doesn't exist 143 | to_onnx_verb.run_cli(args) 144 | 145 | # Verify no ONNX files were created (directory doesn't exist) 146 | onnx_files = list(tmp_path.glob("**/*.onnx")) 147 | assert len(onnx_files) == 0 148 | -------------------------------------------------------------------------------- /docs/configuration.rst: -------------------------------------------------------------------------------- 1 | Configuration 2 | ============= 3 | 4 | Hyrax ships with a complete default configuration file that can
be used immediately 5 | to run the software; however, to make the most of Hyrax you'll need to modify 6 | the configuration to suit your specific needs. 7 | 8 | 9 | Using the configuration system 10 | ------------------------------ 11 | When creating an instance of a ``Hyrax`` object in a notebook or running ``hyrax`` 12 | from the command line, the configuration is the primary method for specifying the parameters. 13 | 14 | If no configuration file is specified, :ref:`the default <complete_default_config>` 15 | will be used. To specify a different configuration file, use the 16 | ``-c | --runtime-config`` flag from the CLI 17 | or pass the path to the configuration file when creating a ``Hyrax`` object. 18 | 19 | .. tabs:: 20 | 21 | .. group-tab:: Notebook 22 | 23 | .. code-block:: python 24 | 25 | from hyrax import Hyrax 26 | 27 | # Create an instance of the Hyrax object 28 | f = Hyrax(config_file="<path_to_config_file>") 29 | 30 | # Train the model specified in the configuration file 31 | f.train() 32 | 33 | .. group-tab:: CLI 34 | 35 | .. code-block:: bash 36 | 37 | >> hyrax train -c <path_to_config_file> 38 | 39 | 40 | Your first custom configuration 41 | ............................... 42 | 43 | You could create a copy of the entire default configuration file and modify it to suit 44 | your needs; however, that's typically not required because often there are only 45 | a few parameters that must be updated for any given Hyrax action. 46 | 47 | If a specific configuration file is provided, Hyrax will combine it with the default 48 | configuration and overwrite the default values with the specific ones. 49 | 50 | For example, if a file called ``my_config.toml`` had the following contents: 51 | 52 | .. code-block:: toml 53 | :linenos: 54 | 55 | [general] 56 | log_level = "debug" 57 | 58 | It could be used to override the default ``log_level`` configuration, while leaving 59 | the rest of the configuration unchanged. 60 | 61 | .. tabs:: 62 | 63 | .. group-tab:: Notebook 64 | 65 | .. code-block:: python 66 | 67 | from hyrax import Hyrax 68 | 69 | # Create an instance of the Hyrax object 70 | f = Hyrax(config_file="my_config.toml") 71 | 72 | # Train the model specified in the configuration file 73 | f.train() 74 | 75 | .. group-tab:: CLI 76 | 77 | .. code-block:: bash 78 | 79 | >> hyrax train -c my_config.toml 80 | 81 | 82 | Updating settings in a notebook 83 | ............................... 84 | Additionally, Hyrax supports modification of the configuration interactively in a notebook. 85 | 86 | .. code-block:: python 87 | 88 | from hyrax import Hyrax 89 | 90 | # Create a Hyrax instance, implicitly using the default configuration 91 | f = Hyrax() 92 | 93 | # Set the log level for the Hyrax instance config 94 | f.config['general']['log_level'] = 'debug' 95 | 96 | # Train the model specified in the configuration file 97 | f.train() 98 | 99 | 100 | Immutable configurations 101 | ........................ 102 | Once Hyrax begins running an action, the configuration becomes immutable. 103 | This means that the configuration cannot be changed during the execution of an action, 104 | and attempting to do so in code will raise an exception. 105 | 106 | By making the configuration immutable during execution, we ensure that the state 107 | of all parameters can be accurately saved with the results of the action. 108 | 109 | 110 | About the default configuration 111 | ------------------------------- 112 | 113 | The default configuration for Hyrax contains safe default values for all of the 114 | settings that Hyrax uses.
A portion of the default configuration file is shown below. 115 | 116 | .. note:: 117 | Only the first portion of the default configuration file is shown below. 118 | The entire file can be found at the bottom of the page here: :ref:`complete_default_config`. 119 | 120 | .. literalinclude:: ../src/hyrax/hyrax_default_config.toml 121 | :language: toml 122 | :linenos: 123 | :lines: 1-25 124 | 125 | There is a lot of information there, but don't worry, we'll break it down for you. 126 | 127 | First, the file is formatted using TOML for its easy readability and because it is 128 | one of the few markup languages that natively support comments. 129 | TOML files are organized into "tables", and each table contains one or more 130 | key/value pairs. 131 | 132 | For instance, the ``[general]`` table (the first table in the config) 133 | contains several keys including ``log_level`` and ``results_dir``. 134 | Each of those keys has a value associated with it. 135 | e.g. ``log_level = "info"``. 136 | 137 | Second, every key has an associated comment describing what the key does. 138 | We attempt to keep the comments as concise as possible. 139 | 140 | Finally, the configuration file is organized into tables that roughly correspond 141 | to the different actions that Hyrax can take. 142 | For instance, the ``[train]`` table contains parameters needed when training a 143 | model, such as ``epochs`` and ``weights_filename``, 144 | while the ``[infer]`` table contains keys such as ``model_weights_file``. 145 | 146 | 147 | .. _complete_default_config: 148 | 149 | Complete default configuration file 150 | ----------------------------------- 151 | 152 | .. literalinclude:: ../src/hyrax/hyrax_default_config.toml 153 | :language: toml 154 | :linenos: 155 | -------------------------------------------------------------------------------- /tests/hyrax/test_qdrant_impl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from hyrax import Hyrax 5 | from hyrax.vector_dbs.qdrantdb_impl import QdrantDB 6 | 7 | 8 | @pytest.fixture() 9 | def random_vector_generator(batch_size=1, vector_size=3): 10 | """Create random vectors""" 11 | 12 | def _generator(batch_size=1, vector_size=3): 13 | while True: 14 | batch = [np.random.rand(vector_size) for _ in range(batch_size)] 15 | yield batch 16 | 17 | return _generator 18 | 19 | 20 | @pytest.fixture() 21 | def qdrant_instance(tmp_path): 22 | """Create a QdrantDB instance for testing""" 23 | h = Hyrax() 24 | h.config["vector_db"]["qdrant"]["vector_size"] = 3 25 | qdrant_instance = QdrantDB(h.config, {"results_dir": tmp_path}) 26 | qdrant_instance.connect() 27 | qdrant_instance.create() 28 | return qdrant_instance 29 | 30 | 31 | def test_connect(tmp_path): 32 | """Test that we can create a connection to the database""" 33 | h = Hyrax() 34 | qdrant_instance = QdrantDB(h.config, {"results_dir": tmp_path}) 35 | qdrant_instance.connect() 36 | 37 | assert qdrant_instance.collection_name is None 38 | assert qdrant_instance.client is not None 39 | 40 | 41 | def test_create(qdrant_instance): 42 | """Test creation of a single collection (shard) in the database""" 43 | collections = qdrant_instance.client.get_collections().collections 44 | 45 | assert collections is not None 46 | assert len(collections) == 1 47 | assert collections[0].name == "shard_0" 48 | 49 | 50 | def test_insert(qdrant_instance, random_vector_generator): 51 | """Ensure that we can insert IDs and vectors into the database""" 52 |
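# Setup note: the fixture provides a generator *factory*; calling it with
# batch_size * num_batches (200) and taking a single next() yields all 200
# three-element vectors in one batch.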
53 | batch_size = 20 54 | num_batches = 10 55 | vector_generator = random_vector_generator(batch_size * num_batches) 56 | ids = list(range(batch_size * num_batches)) 57 | vectors = [t.flatten() for t in next(vector_generator)] 58 | 59 | qdrant_instance.insert(ids, vectors) 60 | total_count = qdrant_instance.client.count(collection_name="shard_0", exact=True) 61 | assert total_count.count == batch_size * num_batches 62 | 63 | 64 | def test_insert_wrong_size_raises(qdrant_instance): 65 | """Ensure that inserting vectors of the wrong size raises an error. Expected 66 | size for this test is 3, as set in the fixture.""" 67 | 68 | ids = ["id1", "id2"] 69 | vectors = [np.array([1, 2]), np.array([3, 4, 5])] # Different sizes 70 | 71 | with pytest.raises(ValueError): 72 | qdrant_instance.insert(ids, vectors) 73 | 74 | 75 | def test_search_by_id(qdrant_instance): 76 | """Test search_by_id retrieves nearest neighbor ids""" 77 | 78 | ids = ["id1", "id2"] 79 | vectors = [np.array([1, 2, 3]), np.array([4, 5, 6])] 80 | qdrant_instance.insert(ids, vectors) 81 | 82 | # Search by a single id should return id1 and id2 in that order 83 | result = qdrant_instance.search_by_id("id1", k=2) 84 | assert len(result["id1"]) == 2 85 | assert np.all(result["id1"] == ["id1", "id2"]) 86 | 87 | # Search should return all ids when k is larger than the number of ids 88 | result = qdrant_instance.search_by_id("id1", k=5) 89 | assert len(result["id1"]) == 2 90 | assert np.all(result["id1"] == ["id1", "id2"]) 91 | 92 | # Search should return 1 id when k is 1 93 | result = qdrant_instance.search_by_id("id1", k=1) 94 | assert len(result["id1"]) == 1 95 | assert np.all(result["id1"] == ["id1"]) 96 | 97 | # Search by another id should return id2 and id1 in that order 98 | result = qdrant_instance.search_by_id("id2", k=2) 99 | assert len(result["id2"]) == 2 100 | assert np.all(result["id2"] == ["id2", "id1"]) 101 | 102 | 103 | def test_search_by_vector(qdrant_instance): 104 | """Test search_by_vector retrieves nearest neighbor ids""" 105 | 106 | ids = ["id1", "id2"] 107 | vectors = [np.array([1, 2, 3]), np.array([4, 5, 6])] 108 | qdrant_instance.insert(ids, vectors) 109 | 110 | # Search by a single vector should return id1 and id2 in that order 111 | result = qdrant_instance.search_by_vector([np.array([1, 2, 3])], k=2) 112 | assert len(result[0]) == 2 113 | assert np.all(result[0] == ["id1", "id2"]) 114 | 115 | # Search should return all ids when k is larger than the number of ids 116 | result = qdrant_instance.search_by_vector([np.array([1, 2, 3])], k=5) 117 | assert len(result[0]) == 2 118 | assert np.all(result[0] == ["id1", "id2"]) 119 | 120 | # Search should return 1 id when k is 1 121 | result = qdrant_instance.search_by_vector([np.array([1, 2, 3])], k=1) 122 | assert len(result[0]) == 1 123 | assert np.all(result[0] == ["id1"]) 124 | 125 | # Search by multiple vectors should return the ids in the order of the vectors 126 | result = qdrant_instance.search_by_vector([np.array([4, 5, 6]), np.array([1, 2, 3])], k=2) 127 | assert len(result) == 2 128 | assert len(result[0]) == 2 129 | assert len(result[1]) == 2 130 | assert np.all(result[0] == ["id2", "id1"]) 131 | assert np.all(result[1] == ["id1", "id2"]) 132 | 133 | 134 | def test_get_by_id(qdrant_instance): 135 | """Test get_by_id retrieves embeddings""" 136 | 137 | ids = ["id1", "id2"] 138 | vectors = [np.array([1, 2, 3]), np.array([4, 5, 6])] 139 | qdrant_instance.insert(ids, vectors) 140 | 141 | result = qdrant_instance.get_by_id("id1") 142 | assert
np.all(result["id1"] == [1, 2, 3]) 143 | 144 | result = qdrant_instance.get_by_id(["id1", "id2"]) 145 | assert len(result) == 2 146 | assert np.all(result["id1"] == [1, 2, 3]) 147 | assert np.all(result["id2"] == [4, 5, 6]) 148 | -------------------------------------------------------------------------------- /tests/hyrax/test_data/small_dataset_hscstars/star_cat_correct.astropy.csv: -------------------------------------------------------------------------------- 1 | # object_id,parent_id,ira,idec,imag_psf,imag_psf_err,iflux_psf,iflux_psf_err,iflux_psf_flags,ishape_sdss_ixx,ishape_sdss_iyy,ishape_sdss_ixy,ishape_sdss_ixx_var,ishape_sdss_iyy_var,ishape_sdss_ixy_var,ishape_sdss_psf_ixx,ishape_sdss_psf_iyy,ishape_sdss_psf_ixy,tract,icalib_psf_used,merge_peak_g,merge_peak_r,merge_peak_i,merge_peak_z,merge_peak_y,icountinputs,ideblend_has_stray_flux,iflags_pixel_bright_object_center,iflags_pixel_bright_object_any,iblendedness_abs_flux,iflags_negative,ideblend_too_many_peaks,ideblend_parent_too_big,icentroid_naive_flags,iflags_pixel_interpolated_any,iflags_pixel_saturated_any,iflags_pixel_cr_any,iflags_pixel_suspect_any,SNR,logSNR,psf_filename,star_filename,bad 2 | 36411452835238206,0,30.59133800094649,-6.2823738478524005,19.3652439,0.00102574087,6.514818103163041e-28,6.154824544248429e-31,False,0.0899620131,0.0836636125999999,0.00509395171,7.82101353e-08,3.64844581e-08,6.75899514e-08,0.091832608,0.0852637811999999,0.00516478205,8279,True,False,False,True,False,False,6,False,True,True,0.0,False,False,False,False,False,False,False,False,1058.4896541447342,3.0246866175091287,"images/2-psf-calexp-pdr2_wide-HSC-I-8279-0,6-30.59134--6.28237.fits",images/2-cutout-HSC-I-8279-pdr2_wide.fits,0 3 | 36411452835248579,36411452835236743,30.5692242770036,-6.324803474292955,18.980484,0.000686287647,9.285533303788271e-28,5.8693329060638185e-31,False,0.0943031609,0.0920326859,0.00591547042,3.73848543e-08,1.83116544e-08,3.55775569e-08,0.0948586017,0.0930508599,0.00592835806,8279,True,False,False,True,False,False,8,False,False,False,0.0,False,False,False,False,False,False,False,False,1582.0423636551702,3.199218108779444,"images/3-psf-calexp-pdr2_wide-HSC-I-8279-0,6-30.56922--6.32480.fits",images/3-cutout-HSC-I-8279-pdr2_wide.fits,0 4 | 36411452835249051,36411452835237478,30.610496856398964,-6.303467718200048,19.0033455,0.000693935435,9.092041624858649e-28,5.811070748387771e-31,False,0.0951416269,0.088964045,0.0053282817799999,3.57204151e-08,1.67529617e-08,3.12074562e-08,0.0952783972,0.0913838968,0.00595270935,8279,True,False,False,True,False,False,8,False,False,False,3.92249967e-06,False,False,False,False,False,False,False,False,1564.6069405334906,3.1944052524271425,"images/4-psf-calexp-pdr2_wide-HSC-I-8279-0,6-30.61050--6.30347.fits",images/4-cutout-HSC-I-8279-pdr2_wide.fits,0 5 | 36411452835250175,36411452835239397,30.55462203388881,-6.248103518070611,17.9776039,0.000378740806,2.3386108049255644e-27,8.157850401586379e-31,False,0.0934380889,0.0902590527999999,0.00563794002,1.17908705e-08,5.714955179999999e-09,1.09940395e-08,0.0937158689,0.0912092477,0.00556621002,8279,True,False,False,True,False,False,8,False,False,False,1.08765084e-06,False,False,False,False,False,False,False,False,2866.699791982943,3.457382214916563,"images/5-psf-calexp-pdr2_wide-HSC-I-8279-0,6-30.55462--6.24810.fits",images/5-cutout-HSC-I-8279-pdr2_wide.fits,0 6 | 
36411457130203411,0,30.591034889439047,-6.161659484538039,18.7893028,0.000716047012,1.1073334293186749e-27,7.302904531684917e-31,False,0.0896800682,0.0792114958,0.00347501505,3.91120416e-08,1.72997705e-08,3.04973895e-08,0.0900582746,0.0795295388,0.00325915101,8279,True,False,False,True,False,False,5,False,False,False,0.0,False,False,False,False,False,False,False,False,1516.291804876152,3.180782787742229,"images/6-psf-calexp-pdr2_wide-HSC-I-8279-0,7-30.59104--6.16166.fits",images/6-cutout-HSC-I-8279-pdr2_wide.fits,0 7 | 36411457130204168,0,30.540473010694857,-6.138881294477764,19.2185364,0.000772552274,7.457365118492627e-28,5.306265471222657e-31,False,0.0944255218,0.0900048241,0.00367699075,4.47393269e-08,2.13529159e-08,4.06295797e-08,0.0946616456,0.0900018737,0.00382546382,8279,True,False,False,True,False,False,9,False,False,False,0.0,False,False,False,False,False,False,False,False,1405.388621985835,3.1477964331710524,"images/7-psf-calexp-pdr2_wide-HSC-I-8279-0,7-30.54047--6.13888.fits",images/7-cutout-HSC-I-8279-pdr2_wide.fits,0 8 | 36411457130206288,0,30.568843641926343,-6.074669887054997,19.3912907,0.00101390155,6.360394103670531e-28,5.939577039039547e-31,False,0.0950927138,0.0953002647,0.00355544221,8.39126812e-08,4.20998916e-08,8.42423518e-08,0.0981467888,0.0987515673,0.0041084745899999,8279,True,False,False,True,False,False,6,False,False,False,0.0,False,False,False,False,False,False,False,False,1070.8496685647892,3.0297285065972184,"images/8-psf-calexp-pdr2_wide-HSC-I-8279-0,7-30.56884--6.07467.fits",images/8-cutout-HSC-I-8279-pdr2_wide.fits,0 9 | 36411457130214646,36411457130201663,30.597375059234032,-6.212819318020295,19.3335953,0.00097498059,6.7075223453620015e-28,6.0232903473177975e-31,False,0.0890230313,0.0802944675,0.00446446147,6.60116157e-08,2.98463618e-08,5.36644436e-08,0.0923568383,0.0819297209,0.00433210609,8279,True,False,False,True,False,False,7,False,False,False,0.0018301678399999,False,False,False,False,False,False,False,False,1113.5977113155266,3.0467283296981877,"images/9-psf-calexp-pdr2_wide-HSC-I-8279-0,7-30.59737--6.21282.fits",images/9-cutout-HSC-I-8279-pdr2_wide.fits,0 10 | 36411457130215774,36411457130203752,30.61880218968407,-6.1514135742461225,18.9717007,0.000628222188,9.360949512275002e-28,5.4163772655292775e-31,False,0.0953872128999999,0.0871068536999999,0.00334656471,3.2788293e-08,1.49892259e-08,2.73288556e-08,0.0959000885,0.0877150893,0.0031442695799999,8279,True,False,False,True,False,False,9,False,True,True,5.41444667e-12,False,False,False,False,False,False,False,False,1728.2676322880304,3.237610996375436,"images/10-psf-calexp-pdr2_wide-HSC-I-8279-0,7-30.61880--6.15141.fits",images/10-cutout-HSC-I-8279-pdr2_wide.fits,0 11 | 36411457130216436,36411457130204918,30.611377446886397,-6.116095804709692,19.5215073,0.000976910582,5.6415335590508455e-28,5.076070851409527e-31,False,0.0941936075999999,0.0859252809999999,0.0027294436,7.56572192e-08,3.45360931e-08,6.29315693e-08,0.0961371437,0.088414304,0.00285776914,8279,True,False,False,True,False,False,9,False,False,False,0.0,False,False,False,False,False,False,False,False,1111.3977176824274,3.0458695006021896,"images/11-psf-calexp-pdr2_wide-HSC-I-8279-0,7-30.61138--6.11610.fits",images/11-cutout-HSC-I-8279-pdr2_wide.fits,0 12 | --------------------------------------------------------------------------------