├── polaris ├── hub │ ├── __init__.py │ ├── settings.py │ └── oauth.py ├── utils │ ├── __init__.py │ ├── constants.py │ ├── zarr │ │ ├── __init__.py │ │ ├── _memmap.py │ │ ├── _utils.py │ │ ├── _manifest.py │ │ └── codecs.py │ ├── misc.py │ ├── context.py │ ├── errors.py │ ├── dict2html.py │ └── types.py ├── experimental │ └── __init__.py ├── prediction │ ├── __init__.py │ └── _predictions_v2.py ├── loader │ ├── __init__.py │ └── load.py ├── mixins │ ├── __init__.py │ ├── _format_text.py │ └── _checksum.py ├── _version.py ├── evaluate │ ├── metrics │ │ ├── __init__.py │ │ ├── generic_metrics.py │ │ └── docking_metrics.py │ ├── __init__.py │ ├── _metadata.py │ └── utils.py ├── dataset │ ├── converters │ │ ├── __init__.py │ │ ├── _base.py │ │ └── _zarr.py │ ├── __init__.py │ ├── _adapters.py │ └── _column.py ├── benchmark │ ├── __init__.py │ ├── _definitions.py │ ├── _task.py │ ├── _split.py │ └── _split_v2.py ├── __init__.py ├── cli.py ├── model │ └── __init__.py └── _artifact.py ├── docs ├── community │ └── community.md ├── api │ ├── model.md │ ├── competition.evaluation.md │ ├── competition.md │ ├── subset.md │ ├── adapters.md │ ├── load.md │ ├── base.md │ ├── utils.types.md │ ├── hub.external_client.md │ ├── hub.storage.md │ ├── benchmark.md │ ├── dataset.md │ ├── hub.client.md │ ├── factory.md │ ├── converters.md │ └── evaluation.md ├── images │ └── zarr.png ├── resources.md ├── assets │ └── css │ │ └── custom-polaris.css ├── index.md ├── tutorials │ ├── create_a_model.ipynb │ └── submit_to_competition.ipynb └── quickstart.md ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature-request.md │ ├── default-template.md │ └── bug-report.yml ├── workflows │ ├── code-check.yml │ ├── doc.yml │ ├── test.yml │ └── release.yml ├── PULL_REQUEST_TEMPLATE.md ├── changelog_config.json └── CODE_OF_CONDUCT.md ├── tests ├── test_import.py ├── test_codecs.py ├── test_hub_integration.py ├── test_oauth.py ├── test_type_checks.py ├── test_metrics.py ├── 
test_subset.py ├── test_storage.py ├── test_benchmark_predictions_v2.py ├── test_integration.py ├── test_competition.py ├── test_zarr_checksum.py ├── test_factory.py └── test_dataset_v2.py ├── NOTICE ├── env.yml ├── .gitignore ├── README.md ├── mkdocs.yml └── pyproject.toml /polaris/hub/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/community/community.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /polaris/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @cwognum 2 | -------------------------------------------------------------------------------- /polaris/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/api/model.md: -------------------------------------------------------------------------------- 1 | ::: polaris.model.Model 2 | 3 | --- -------------------------------------------------------------------------------- /docs/api/competition.evaluation.md: -------------------------------------------------------------------------------- 1 | ::: polaris.evaluate.CompetitionPredictions 2 | -------------------------------------------------------------------------------- /docs/api/competition.md: -------------------------------------------------------------------------------- 1 | ::: polaris.competition.CompetitionSpecification 2 | 3 | --- -------------------------------------------------------------------------------- 
/docs/images/zarr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/polaris-hub/polaris/HEAD/docs/images/zarr.png -------------------------------------------------------------------------------- /docs/api/subset.md: -------------------------------------------------------------------------------- 1 | ::: polaris.dataset.Subset 2 | options: 3 | members: no 4 | 5 | --- 6 | -------------------------------------------------------------------------------- /docs/api/adapters.md: -------------------------------------------------------------------------------- 1 | 2 | ::: polaris.dataset._adapters 3 | options: 4 | filters: ["!^_"] 5 | -------------------------------------------------------------------------------- /tests/test_import.py: -------------------------------------------------------------------------------- 1 | def test_import(): 2 | """Sanity check.""" 3 | import polaris # noqa: F401 4 | -------------------------------------------------------------------------------- /docs/api/load.md: -------------------------------------------------------------------------------- 1 | 2 | ::: polaris.load_dataset 3 | 4 | --- 5 | 6 | ::: polaris.load_benchmark 7 | 8 | --- 9 | -------------------------------------------------------------------------------- /docs/api/base.md: -------------------------------------------------------------------------------- 1 | ::: polaris._artifact.BaseArtifactModel 2 | options: 3 | filters: ["!^_"] 4 | 5 | --- -------------------------------------------------------------------------------- /polaris/prediction/__init__.py: -------------------------------------------------------------------------------- 1 | from ._predictions_v2 import BenchmarkPredictionsV2 2 | 3 | __all__ = ["BenchmarkPredictionsV2"] 4 | -------------------------------------------------------------------------------- /docs/api/utils.types.md: 
from .load import load_benchmark, load_dataset, load_competition, load_model

# Fix: was `_all__` (typo). With the misspelled name the module never declared a
# public API, so `from polaris.loader import *` exported every top-level name
# instead of only the four loaders.
__all__ = ["load_benchmark", "load_dataset", "load_competition", "load_model"]
"dev" 8 | -------------------------------------------------------------------------------- /docs/api/hub.storage.md: -------------------------------------------------------------------------------- 1 | ::: polaris.hub.storage.StorageSession 2 | options: 3 | merge_init_into_class: true 4 | 5 | --- 6 | 7 | ::: polaris.hub.storage.S3Store 8 | options: 9 | merge_init_into_class: true 10 | --- 11 | -------------------------------------------------------------------------------- /docs/api/benchmark.md: -------------------------------------------------------------------------------- 1 | ::: polaris.benchmark.BenchmarkV2Specification 2 | options: 3 | filters: ["!^_", "!md5sum", "!get_cache_path"] 4 | 5 | 6 | ::: polaris.benchmark.BenchmarkV1Specification 7 | options: 8 | filters: ["!^_", "!md5sum", "!get_cache_path"] 9 | 10 | --- -------------------------------------------------------------------------------- /polaris/evaluate/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from polaris.evaluate.metrics.docking_metrics import rmsd_coverage 2 | from polaris.evaluate.metrics.generic_metrics import ( 3 | absolute_average_fold_error, 4 | average_precision_score, 5 | cohen_kappa_score, 6 | pearsonr, 7 | spearman, 8 | ) 9 | -------------------------------------------------------------------------------- /docs/api/dataset.md: -------------------------------------------------------------------------------- 1 | ::: polaris.dataset.DatasetV2 2 | options: 3 | filters: ["!^_"] 4 | 5 | --- 6 | 7 | ::: polaris.dataset._base.BaseDataset 8 | options: 9 | filters: ["!^_"] 10 | 11 | --- 12 | 13 | ::: polaris.dataset.ColumnAnnotation 14 | options: 15 | filters: ["!^_"] 16 | 17 | --- 18 | 19 | -------------------------------------------------------------------------------- /docs/api/hub.client.md: -------------------------------------------------------------------------------- 1 | ::: polaris.hub.settings.PolarisHubSettings 2 | 
options: 3 | filters: ["!^_"] 4 | 5 | --- 6 | 7 | 8 | ::: polaris.hub.client.PolarisHubClient 9 | options: 10 | merge_init_into_class: true 11 | filters: ["!^_", "!create_authorization_url", "!fetch_token", "!request", "!token"] 12 | --- 13 | -------------------------------------------------------------------------------- /docs/api/factory.md: -------------------------------------------------------------------------------- 1 | ::: polaris.dataset.DatasetFactory 2 | options: 3 | filters: ["!^_"] 4 | 5 | --- 6 | 7 | ::: polaris.dataset.create_dataset_from_file 8 | options: 9 | filters: ["!^_"] 10 | 11 | --- 12 | 13 | ::: polaris.dataset.create_dataset_from_files 14 | options: 15 | filters: ["!^_"] 16 | 17 | --- 18 | -------------------------------------------------------------------------------- /polaris/dataset/converters/__init__.py: -------------------------------------------------------------------------------- 1 | from polaris.dataset.converters._base import Converter 2 | from polaris.dataset.converters._sdf import SDFConverter 3 | from polaris.dataset.converters._zarr import ZarrConverter 4 | from polaris.dataset.converters._pdb import PDBConverter 5 | 6 | 7 | __all__ = ["Converter", "SDFConverter", "ZarrConverter", "PDBConverter"] 8 | -------------------------------------------------------------------------------- /polaris/utils/zarr/__init__.py: -------------------------------------------------------------------------------- 1 | from ._checksum import ZarrFileChecksum, compute_zarr_checksum 2 | from ._manifest import generate_zarr_manifest 3 | from ._memmap import MemoryMappedDirectoryStore 4 | 5 | __all__ = [ 6 | "MemoryMappedDirectoryStore", 7 | "compute_zarr_checksum", 8 | "ZarrFileChecksum", 9 | "generate_zarr_manifest", 10 | ] 11 | -------------------------------------------------------------------------------- /polaris/mixins/_format_text.py: -------------------------------------------------------------------------------- 1 | class FormattingMixin: 2 | 
"""Mixin class for formatting strings to be output in the console""" 3 | 4 | BOLD = "\033[1m" 5 | YELLOW = "\033[93m" 6 | _END_CODE = "\033[0m" 7 | 8 | def format(self, text: str, codes: str | list[str]): 9 | if not isinstance(codes, list): 10 | codes = [codes] 11 | 12 | return "".join(codes) + text + self._END_CODE 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: ❓ Discuss something on GitHub Discussions 4 | url: https://github.com/polaris-hub/polaris/discussions 5 | about: For questions like "How do I do X with Polaris?", you can move to GitHub Discussions. 6 | - name: ❓ Discuss something on Discord 7 | url: https://discord.gg/vBFd8p6H7u 8 | about: For more interactive discussions, you can join our Discord server. -------------------------------------------------------------------------------- /docs/api/converters.md: -------------------------------------------------------------------------------- 1 | ::: polaris.dataset.converters.Converter 2 | options: 3 | filters: ["!^_"] 4 | 5 | --- 6 | 7 | 8 | ::: polaris.dataset.converters.SDFConverter 9 | options: 10 | filters: ["!^_"] 11 | 12 | --- 13 | 14 | ::: polaris.dataset.converters.ZarrConverter 15 | options: 16 | filters: ["!^_"] 17 | 18 | --- 19 | 20 | ::: polaris.dataset.converters.PDBConverter 21 | options: 22 | filters: ["!^_"] 23 | 24 | --- 25 | -------------------------------------------------------------------------------- /tests/test_codecs.py: -------------------------------------------------------------------------------- 1 | import datamol as dm 2 | import zarr 3 | 4 | from polaris.utils.zarr.codecs import RDKitMolCodec 5 | 6 | 7 | def test_rdkit_mol_codec(): 8 | mol = dm.to_mol("C1=CC=CC=C1") 9 | 10 | arr = zarr.empty(shape=10, chunks=2, dtype=object, object_codec=RDKitMolCodec()) 11 | 12 
| arr[0] = mol 13 | arr[1] = mol 14 | arr[2] = mol 15 | 16 | assert dm.same_mol(arr[0], mol) 17 | assert dm.same_mol(arr[1], mol) 18 | assert dm.same_mol(arr[2], mol) 19 | -------------------------------------------------------------------------------- /polaris/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from polaris.benchmark._base import ( 2 | BenchmarkV1Specification, 3 | BenchmarkV1Specification as BenchmarkSpecification, 4 | ) 5 | from polaris.benchmark._benchmark_v2 import BenchmarkV2Specification 6 | from polaris.benchmark._definitions import MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification 7 | 8 | __all__ = [ 9 | "BenchmarkSpecification", 10 | "BenchmarkV1Specification", 11 | "BenchmarkV2Specification", 12 | "SingleTaskBenchmarkSpecification", 13 | "MultiTaskBenchmarkSpecification", 14 | ] 15 | -------------------------------------------------------------------------------- /docs/resources.md: -------------------------------------------------------------------------------- 1 | # Resources 2 | 3 | ## Publications 4 | 5 | - Correspondence in Nature Biotechnology: [10.1038/s42256-024-00911-w](https://doi.org/10.1038/s42256-024-00911-w). 6 | - Preprint on Method Comparison Protocols: [10.26434/chemrxiv-2024-6dbwv-v2](https://doi.org/10.26434/chemrxiv-2024-6dbwv-v2). 
7 | 8 | ## Talks 9 | 10 | - PyData London (June, 2024): [https://www.youtube.com/watch?v=YZDfD9D7mtE](https://www.youtube.com/watch?v=YZDfD9D7mtE) 11 | - MoML (June, 2024): [https://www.youtube.com/watch?v=Tsz_T1WyufI](https://www.youtube.com/watch?v=Tsz_T1WyufI) 12 | 13 | --- -------------------------------------------------------------------------------- /tests/test_hub_integration.py: -------------------------------------------------------------------------------- 1 | import polaris as po 2 | from polaris.benchmark._base import BenchmarkV1Specification 3 | from polaris.dataset._base import BaseDataset 4 | from polaris.hub.settings import PolarisHubSettings 5 | 6 | settings = PolarisHubSettings() 7 | 8 | 9 | def test_load_dataset_flow(): 10 | dataset = po.load_dataset("polaris/hello-world") 11 | assert isinstance(dataset, BaseDataset) 12 | 13 | 14 | def test_load_benchmark_flow(): 15 | benchmark = po.load_benchmark("polaris/hello-world-benchmark") 16 | assert isinstance(benchmark, BenchmarkV1Specification) 17 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2025 Valence Labs 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
import mmap

import zarr


class MemoryMappedDirectoryStore(zarr.DirectoryStore):
    """
    A Zarr Store to open chunks as memory-mapped files.
    See also [this Github issue](https://github.com/zarr-developers/zarr-python/issues/1245).

    Memory mapping leverages low-level OS functionality to reduce the time it takes
    to read the content of a file by directly mapping to memory.
    """

    def _fromfile(self, fn):
        """Read chunk file ``fn`` as a read-only memory map instead of copying it into bytes.

        Overrides the DirectoryStore hook that loads a chunk from disk. The returned
        memoryview is backed by the mmap; the mapping remains valid after the file
        handle is closed (standard mmap behavior).
        """
        with open(fn, "rb") as fh:
            return memoryview(mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ))
-------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature request 3 | about: Suggest an idea for a new Polaris feature 4 | title: '' 5 | labels: feature 6 | assignees: '' 7 | --- 8 | 9 | ### Is your feature request related to a problem? Please describe. 10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 11 | 12 | ### Describe the solution you'd like 13 | A clear and concise description of what you want to happen. 14 | 15 | ### Describe alternatives you've considered 16 | A clear and concise description of any alternative solutions or features you've considered. 17 | 18 | ### Additional context 19 | Add any other context or screenshots about the feature request here. 20 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Changelogs 2 | 3 | - _enumerate the changes of that PR._ 4 | 5 | --- 6 | 7 | _Checklist:_ 8 | 9 | - [ ] _Was this PR discussed in an issue? 
It is recommended to first discuss a new feature in a GitHub issue before opening a PR._
["documentation"] 20 | }, 21 | { 22 | "key": "chore", 23 | "title": "## 🧹 Chores", 24 | "labels": ["chore"] 25 | }, 26 | { 27 | "title": "## 📦 Other", 28 | "labels": [] 29 | } 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/default-template.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Default Template 3 | about: Default, generic issue template 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | 9 | # Context 10 | 11 | _Provide some context for this issue: why is this change being requested, what constraint are there on the solution, are there any relevant artifacts(design documents, discussions, others) to this issue, etc._ 12 | 13 | # Description 14 | 15 | _Describe the expected work that will be needed to address the issue, leading into the following acceptance criteria. Add any relevant element that could impact the solution: limits, performance, security, compatibility, etc._ 16 | 17 | # Acceptance Criteria 18 | 19 | - List what needs to be checked and valid to determine that this issue can be closed 20 | 21 | # Links 22 | 23 | - [Link to other issues/PRs/external tasks](www.example.com) -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | 4 | dependencies: 5 | - python >=3.10,<3.13 6 | - pip 7 | - typer 8 | - pyyaml 9 | - pydantic >=2 10 | - pydantic-settings >=2 11 | - fsspec 12 | - typing-extensions >=4.12.0 13 | - boto3 <1.36.0 14 | - pyroaring 15 | - rich >=13.9.4 16 | 17 | # Hub client 18 | - authlib 19 | - httpx 20 | - requests 21 | - aiohttp 22 | 23 | # Scientific 24 | - numpy < 3 25 | - pandas 26 | - scipy 27 | - scikit-learn 28 | - seaborn 29 | 30 | # Chemistry 31 | - datamol >=0.12.1 32 | - fastpdb 33 | 34 | # Storage 35 | - zarr >=2,<3 36 | - pyarrow <18 37 | 
- numcodecs >=0.13.1,<0.16.0 38 | 39 | # Dev 40 | - pytest 41 | - pytest-xdist 42 | - pytest-cov 43 | - ruff 44 | - jupyterlab 45 | - ipywidgets 46 | - moto >=5.0.0 47 | 48 | # Doc 49 | - mkdocs 50 | - mkdocs-material >=9.4.7 51 | - mkdocstrings 52 | - mkdocstrings-python 53 | - mkdocs-jupyter >=0.24.8 54 | - markdown-include 55 | - mdx_truly_sane_lists 56 | - nbconvert 57 | - mike >=1.0.0 58 | -------------------------------------------------------------------------------- /.github/workflows/doc.yml: -------------------------------------------------------------------------------- 1 | name: doc 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | 7 | # Prevent doc action on `main` to conflict with each others. 8 | concurrency: 9 | group: doc-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | doc: 14 | runs-on: "ubuntu-latest" 15 | timeout-minutes: 30 16 | 17 | defaults: 18 | run: 19 | shell: bash -l {0} 20 | 21 | steps: 22 | - name: Checkout the code 23 | uses: actions/checkout@v4 24 | 25 | - name: Install uv 26 | uses: astral-sh/setup-uv@v5 27 | 28 | - name: Install the project 29 | run: uv sync --group doc 30 | 31 | - name: Configure git 32 | run: | 33 | git config --global user.name "${GITHUB_ACTOR}" 34 | git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com" 35 | 36 | - name: Deploy the doc 37 | run: | 38 | echo "Get the gh-pages branch" 39 | git fetch origin gh-pages 40 | 41 | echo "Build and deploy the doc on main" 42 | uv run mike deploy --push main 43 | -------------------------------------------------------------------------------- /polaris/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from polaris.evaluate._metadata import ResultsMetadataV1, ResultsMetadataV2 2 | from polaris.evaluate._metadata import ResultsMetadataV1 as ResultsMetadata 3 | from polaris.evaluate._metric import Metric, MetricInfo 4 | from polaris.evaluate._predictions import 
BenchmarkPredictions, CompetitionPredictions 5 | from polaris.evaluate._results import ( 6 | BenchmarkResultsV1 as BenchmarkResults, 7 | BenchmarkResultsV1, 8 | BenchmarkResultsV2, 9 | CompetitionResults, 10 | EvaluationResultV1 as EvaluationResult, 11 | EvaluationResultV1, 12 | EvaluationResultV2, 13 | ) 14 | from polaris.evaluate.utils import evaluate_benchmark 15 | 16 | __all__ = [ 17 | "ResultsMetadata", 18 | "ResultsMetadataV1", 19 | "ResultsMetadataV2", 20 | "Metric", 21 | "MetricInfo", 22 | "EvaluationResult", 23 | "EvaluationResultV1", 24 | "EvaluationResultV2", 25 | "BenchmarkResults", 26 | "BenchmarkResultsV1", 27 | "BenchmarkResultsV2", 28 | "CompetitionResults", 29 | "evaluate_benchmark", 30 | "CompetitionPredictions", 31 | "BenchmarkPredictions", 32 | ] 33 | -------------------------------------------------------------------------------- /polaris/dataset/converters/_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import TypeAlias 3 | 4 | import pandas as pd 5 | 6 | from polaris.dataset import ColumnAnnotation 7 | from polaris.dataset._adapters import Adapter 8 | from polaris.dataset._dataset import _INDEX_SEP 9 | 10 | FactoryProduct: TypeAlias = tuple[pd.DataFrame, dict[str, ColumnAnnotation], dict[str, Adapter]] 11 | 12 | 13 | class Converter(abc.ABC): 14 | @abc.abstractmethod 15 | def convert(self, path: str, append: bool = False) -> FactoryProduct: 16 | """This converts a file into a table and possibly annotations""" 17 | raise NotImplementedError 18 | 19 | @staticmethod 20 | def get_pointer(column: str, index: int | slice) -> str: 21 | """ 22 | Creates a pointer. 23 | 24 | Args: 25 | column: The name of the column. Each column has its own group in the root. 26 | index: The index or slice of the pointer. 
from enum import Enum, auto, unique

import datamol as dm


@unique
class Adapter(Enum):
    """
    Adapters are predefined callables that change the format of the data.
    Adapters are serializable and can thus be saved alongside datasets.

    Attributes:
        SMILES_TO_MOL: Convert a SMILES string to an RDKit molecule.
        BYTES_TO_MOL: Convert an RDKit binary string to an RDKit molecule.
        ARRAY_TO_PDB: Convert a Zarr array to a PDB array.
    """

    SMILES_TO_MOL = auto()
    BYTES_TO_MOL = auto()
    ARRAY_TO_PDB = auto()

    def __call__(self, data):
        """Apply the conversion named by this enum member to ``data``.

        Tuples are converted element-wise (returning a tuple); any other value is
        converted directly by the mapped callable.
        """
        # Import here to prevent a cyclic import
        # Given the close coupling between `zarr_to_pdb` and the PDB converter,
        # we wanted to keep those functions in one file which was leading to a cyclic import.
        from polaris.dataset.converters._pdb import zarr_to_pdb

        # Dispatch on the member's name rather than its value so the mapping
        # stays stable even though the auto() values are positional.
        conversion_map = {"SMILES_TO_MOL": dm.to_mol, "BYTES_TO_MOL": dm.Mol, "ARRAY_TO_PDB": zarr_to_pdb}

        if isinstance(data, tuple):
            return tuple(conversion_map[self.name](d) for d in data)
        return conversion_map[self.name](data)
25 | """ 26 | return sluggable.lower().replace("_", "-").strip("-") 27 | 28 | 29 | def convert_lists_to_arrays(predictions: ListOrArrayType | dict) -> np.ndarray | dict: 30 | """ 31 | Recursively converts all plain Python lists in the predictions object to numpy arrays 32 | """ 33 | 34 | def convert_to_array(v): 35 | if isinstance(v, np.ndarray): 36 | return v 37 | elif isinstance(v, list): 38 | return np.array(v) 39 | elif isinstance(v, dict): 40 | return {k: convert_to_array(v) for k, v in v.items()} 41 | 42 | return convert_to_array(predictions) 43 | 44 | 45 | def build_urn(artifact_type: str, owner: str | HubOwner, slug: str) -> ArtifactUrn: 46 | return f"urn:polaris:{artifact_type}:{owner}:{slug}" 47 | -------------------------------------------------------------------------------- /polaris/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | import typer 4 | 5 | from polaris.hub.client import PolarisHubClient 6 | from polaris.hub.settings import PolarisHubSettings 7 | 8 | app = typer.Typer( 9 | add_completion=False, 10 | help="Polaris is a framework for benchmarking methods in drug discovery.", 11 | ) 12 | 13 | 14 | @app.command("login") 15 | def login( 16 | client_env_file: Annotated[ 17 | str, typer.Option(help="Environment file to overwrite the default environment variables") 18 | ] = ".env", 19 | auto_open_browser: Annotated[ 20 | bool, typer.Option(help="Whether to automatically open the link in a browser to retrieve the token") 21 | ] = True, 22 | overwrite: Annotated[ 23 | bool, typer.Option(help="Whether to overwrite the access token if you are already logged in") 24 | ] = False, 25 | ): 26 | """Authenticate to the Polaris Hub. 27 | 28 | This CLI will use the OAuth2 protocol to gain token-based access to the Polaris Hub API. 
def load_zarr_group_to_memory(group: zarr.Group) -> dict:
    """Loads an entire Zarr group into memory.

    Arrays are materialized by full slicing and nested groups are loaded
    recursively into nested dictionaries.
    """
    # If a Zarr group is already loaded to memory (e.g. with
    # dataset.load_to_memory()), the adapter would receive a dictionary
    # instead of a Zarr group — return it as-is.
    if isinstance(group, dict):
        return group

    in_memory = {}
    for name, member in group.items():
        if isinstance(member, zarr.Array):
            # Slicing with [:] pulls the complete array contents into memory.
            in_memory[name] = member[:]
        elif isinstance(member, zarr.Group):
            in_memory[name] = load_zarr_group_to_memory(member)
    return in_memory
def check_zarr_codecs(group: zarr.Group):
    """Check if all codecs in the Zarr group are registered.

    Recursively walks the group hierarchy; iterating over a group's members
    triggers codec resolution, which fails if a codec is unavailable.

    Raises:
        InvalidZarrCodec: If an array in the group uses an unregistered codec.
    """
    try:
        for key, item in group.items():
            if isinstance(item, zarr.Group):
                check_zarr_codecs(item)

    except ValueError as error:
        # Zarr raises a generic ValueError if a codec is not registered.
        # See also: https://github.com/zarr-developers/zarr-python/issues/2508
        prefix = "codec not available: "
        error_message = str(error)

        # Any other ValueError is not codec-related; re-raise it untouched.
        if not error_message.startswith(prefix):
            raise error

        # Remove prefix and apostrophes
        codec_id = error_message.removeprefix(prefix).strip("'")
        raise InvalidZarrCodec(codec_id)
@contextmanager
def track_progress(description: str, total: float | None = 1.0):
    """
    Use the Progress instance to track a task's progress

    Adds a task to the process-wide Progress singleton, yields
    ``(progress, task)`` for granular control, and on exit marks the task
    completed (or logs an error if the managed block raised).

    Args:
        description: Human-readable label shown next to the progress bar.
        total: Total number of steps for the task.
    """
    progress = progress_instance.get()

    # Make sure the Progress is started
    progress.start()

    # Cycle through colors so concurrently tracked tasks are distinguishable.
    task = progress.add_task(f"[{next(colors)}]{description}", total=total)

    try:
        # Yield the task and Progress instance, for more granular control
        yield progress, task

        # Mark the task as completed
        progress.update(task, completed=total, refresh=True)
        progress.log(f"[green] Success: {description}")
    except Exception:
        progress.log(f"[red] Error: {description}")
        raise
    finally:
        # Remove the task from the UI, and stop the progress bar if all tasks are completed
        progress.remove_task(task)
        if progress.finished:
            progress.stop()
class SingleTaskMixin:
    """
    Mixin for single-task benchmarks.

    Enforces exactly one target column and reports the matching task type.
    """

    @field_validator("target_cols", check_fields=False)
    @classmethod
    def validate_target_cols(cls, v: Collection[str]) -> Collection[str]:
        # check_fields=False: `target_cols` is declared on the benchmark
        # specification this mixin is combined with, not on the mixin itself.
        if len(v) != 1:
            raise ValueError("A single-task benchmark should specify exactly one target column.")
        return v

    @computed_field
    @property
    def task_type(self) -> str:
        """Return SINGLE_TASK for single-task benchmarks."""
        return TaskType.SINGLE_TASK.value
45 | """ 46 | return TaskType.MULTI_TASK.value 47 | 48 | 49 | class SingleTaskBenchmarkSpecification(SingleTaskMixin, BenchmarkV1Specification): 50 | """ 51 | Single-task benchmark for the base specification. 52 | """ 53 | 54 | pass 55 | 56 | 57 | class MultiTaskBenchmarkSpecification(MultiTaskMixin, BenchmarkV1Specification): 58 | """ 59 | Multitask benchmark for the base specification. 60 | """ 61 | 62 | pass 63 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 File a bug report 2 | description: X's behavior is deviating from its documented behavior. 3 | labels: ["bug"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Please provide the following information. 9 | - type: input 10 | id: Polaris-version 11 | attributes: 12 | label: Polaris version 13 | description: Value of ``polaris.__version__`` 14 | placeholder: 0.2.5, 0.3.0, 0.3.1, etc. 15 | validations: 16 | required: true 17 | - type: input 18 | id: Python-version 19 | attributes: 20 | label: Python Version 21 | description: Version of Python interpreter 22 | placeholder: 3.9, 3.10, 3.11, etc. 23 | validations: 24 | required: true 25 | - type: input 26 | id: OS 27 | attributes: 28 | label: Operating System 29 | description: Operating System 30 | placeholder: (Linux/Windows/Mac) 31 | validations: 32 | required: true 33 | - type: input 34 | id: installation 35 | attributes: 36 | label: Installation 37 | description: How was Polaris installed? 38 | placeholder: e.g., "using pip into virtual environment", or "using conda" 39 | validations: 40 | required: true 41 | - type: textarea 42 | id: description 43 | attributes: 44 | label: Description 45 | description: Explain why the current behavior is a problem, what the expected output/behaviour is, and why the expected output/behaviour is a better solution. 
def pearsonr(y_true: np.ndarray, y_pred: np.ndarray):
    """Calculate a pearson r correlation"""
    # Delegate to scipy and extract the correlation coefficient from the
    # result object, discarding the associated p-value.
    result = stats.pearsonr(y_true, y_pred)
    return result.statistic
def cohen_kappa_score(y_true, y_pred, **kwargs):
    """Scikit-learn `cohen_kappa_score` wrapper with renamed arguments.

    Renames scikit-learn's `y1`/`y2` parameters to `y_true`/`y_pred` so this
    metric follows the same (y_true, y_pred) calling convention as the other
    metrics in this module. (Also fixes the "wraper" typo in the docstring.)
    """
    return sk_cohen_kappa_score(y1=y_true, y2=y_pred, **kwargs)
Learn more through our [blog posts](https://polarishub.io/blog). 14 | 15 | ## Where to next? 16 | 17 | --- 18 | 19 | **:fontawesome-solid-rocket: Quickstart** 20 | 21 | If you are entirely new to Polaris, this is the place to start! Learn about the essential concepts and partake in your first benchmark. 22 | 23 | [:material-arrow-right: Let's get started](./quickstart.md) 24 | 25 | 26 | --- 27 | 28 | **:fontawesome-solid-graduation-cap: Tutorials** 29 | 30 | Dive deeper into the Polaris code and learn about advanced concepts to create your own benchmarks and datasets. 31 | 32 | [:material-arrow-right: Let's get started](./tutorials/submit_to_competition.ipynb) 33 | 34 | --- 35 | 36 | **:fontawesome-solid-code: API Reference** 37 | 38 | This is where you will find the technical documentation of the code itself. Learn the intricate details of how the various methods and classes work. 39 | 40 | [:material-arrow-right: Let's get started](./api/dataset.md) 41 | 42 | --- 43 | 44 | **:fontawesome-solid-comments: Community** 45 | 46 | Whether you are a first-time contributor or open-source veteran, we welcome any contribution to Polaris. Learn more about our community initiatives.
class ResultsMetadataV1(BaseArtifactModel):
    """V1 implementation of evaluation results without model field support

    Attributes:
        github_url: The URL to the code repository that was used to generate these results.
        paper_url: The URL to the paper describing the methodology used to generate these results.
        contributors: The users that are credited for these results.

    For additional metadata attributes, see the base classes.
    """

    # Additional metadata
    # The fields are populated/serialized under the aliases `code_url` and
    # `report_url` (see the pydantic `alias` arguments below).
    github_url: HttpUrlString | None = Field(None, alias="code_url")
    paper_url: HttpUrlString | None = Field(None, alias="report_url")
    contributors: list[HubUser] = Field(default_factory=list)

    # Private attributes
    # Timestamp captured when this results object is instantiated.
    _created_at: datetime = PrivateAttr(default_factory=datetime.now)

    def _repr_html_(self) -> str:
        """For pretty-printing in Jupyter Notebooks"""
        return dict2html(self.model_dump())

    def __repr__(self):
        return self.model_dump_json(indent=2)
class ResultsMetadataV2(BaseArtifactModel):
    """V2 implementation of evaluation results with model field replacing URLs

    Attributes:
        model: The model that was used to generate these results.
        contributors: The users that are credited for these results.

    For additional metadata attributes, see the base classes.
    """

    # Additional metadata
    # The model itself is excluded from serialization; only the computed
    # `model_artifact_id` below is emitted.
    model: Model | None = Field(None, exclude=True)
    contributors: list[HubUser] = Field(default_factory=list)

    # Private attributes
    # Timestamp captured when this results object is instantiated.
    _created_at: datetime = PrivateAttr(default_factory=datetime.now)

    @computed_field
    @property
    def model_artifact_id(self) -> str | None:
        """The artifact id of the associated model, or None when no model is set.

        Bug fix: previously annotated as `-> str` even though it returns None
        when `model` is unset; the annotation now reflects the actual contract.
        """
        return self.model.artifact_id if self.model else None

    def _repr_html_(self) -> str:
        """For pretty-printing in Jupyter Notebooks"""
        return dict2html(self.model_dump())

    def __repr__(self):
        return self.model_dump_json(indent=2)
34 | artifact_version: The version of the model. 35 | artifact_changelog: A description of the changes made in this model version. 36 | 37 | Methods: 38 | upload_to_hub(owner: HubOwner | str | None = None): 39 | Uploads the model artifact to the Polaris Hub, associating it with a specified owner. 40 | 41 | For additional metadata attributes, see the base class. 42 | """ 43 | 44 | _artifact_type = "model" 45 | 46 | readme: str = "" 47 | code_url: HttpUrlString | None = None 48 | report_url: HttpUrlString | None = None 49 | 50 | # Version-related fields 51 | artifact_version: int = Field(default=1, frozen=True) 52 | artifact_changelog: str | None = None 53 | 54 | def upload_to_hub( 55 | self, 56 | owner: HubOwner | str | None = None, 57 | parent_artifact_id: str | None = None, 58 | ): 59 | """ 60 | Uploads the model to the Polaris Hub. 61 | """ 62 | from polaris.hub.client import PolarisHubClient 63 | 64 | with PolarisHubClient() as client: 65 | client.upload_model(self, owner=owner, parent_artifact_id=parent_artifact_id) 66 | -------------------------------------------------------------------------------- /polaris/utils/zarr/_manifest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from hashlib import md5 3 | from pathlib import Path 4 | 5 | from pyarrow import Table, schema, string 6 | from pyarrow.parquet import write_table 7 | 8 | # PyArrow table schema for the V2 Zarr manifest file 9 | ZARR_MANIFEST_SCHEMA = schema([("path", string()), ("md5_checksum", string())]) 10 | 11 | ROW_GROUP_SIZE = 128 * 1024 * 1024 # 128 MB 12 | 13 | 14 | def generate_zarr_manifest(zarr_root_path: str, output_dir: str) -> str: 15 | """ 16 | Entry point function which triggers the creation of a Zarr manifest for a V2 dataset. 
def calculate_file_md5(file_path: str) -> str:
    """Calculates the md5 hash for a file at a given path"""
    digest = md5()
    with open(file_path, "rb") as handle:
        # Stream the file in fixed-size chunks so large files never need to
        # be held in memory all at once.
        while chunk := handle.read(4096):
            digest.update(chunk)

    # Return the hex representation of the digest
    return digest.hexdigest()
def _rmsd(mol_probe: dm.Mol, mol_ref: dm.Mol) -> float:
    """Calculate RMSD between predicted molecule and closest ground truth molecule.
    The RMSD is calculated with first conformer of predicted molecule and only consider heavy atoms for RMSD calculation.
    It is assumed that the predicted binding conformers are extracted from the docking output, where the receptor (protein) coordinates have been aligned with the original crystal structure.

    Args:
        mol_probe: Predicted molecule (docked ligand) with exactly one conformer.
        mol_ref: Ground truth molecule (crystal ligand) with at least one conformer. If multiple conformers are
            present, the lowest RMSD will be reported.

    Returns:
        Returns the RMS between two molecules, taking symmetry into account.
    """

    # copy the molecule for modification.
    mol_probe = dm.copy_mol(mol_probe)
    mol_ref = dm.copy_mol(mol_ref)

    # remove hydrogen from molecule
    # (so only heavy atoms contribute to the RMSD, per the docstring contract)
    mol_probe = dm.remove_hs(mol_probe)
    mol_ref = dm.remove_hs(mol_ref)

    # calculate RMSD
    # NOTE(review): prbId/refId of -1 presumably considers all conformers so
    # the lowest RMSD across reference conformers is reported — confirm
    # against the RDKit CalcRMS documentation.
    return CalcRMS(
        prbMol=mol_probe, refMol=mol_ref, symmetrizeConjugatedTerminalGroups=True, prbId=-1, refId=-1
    )
def rmsd_coverage(y_pred: str | list[dm.Mol], y_true: str | list[dm.Mol], max_rsmd: float = 2):
    """
    Calculate the coverage of molecules with an RMSD less than a threshold (2 Å by default) compared to the reference molecule conformer.

    It is assumed that the predicted binding conformers are extracted from the docking output, where the receptor (protein) coordinates have been aligned with the original crystal structure.

    Attributes:
        y_pred: List of predicted binding conformers.
        y_true: List of ground truth binding conformers.
        max_rsmd: The threshold for determining acceptable rmsd.

    Returns:
        The fraction of predictions whose RMSD to the reference is at most `max_rsmd`.

    Raises:
        ValueError: If `y_pred` and `y_true` have different lengths.
    """

    if len(y_pred) != len(y_true):
        # Bug fix: this previously read `assert ValueError(...)`, which never
        # fails (an exception instance is truthy), so mismatched input lengths
        # were silently truncated by `zip` below. Raise explicitly instead.
        raise ValueError(
            f"The list of probing molecules and the list of reference molecules are different sizes. {len(y_pred)} != {len(y_true)} "
        )

    rmsds = np.array(
        [_rmsd(mol_probe=mol_probe, mol_ref=mol_ref) for mol_probe, mol_ref in zip(y_pred, y_true)]
    )

    return np.sum(rmsds <= max_rsmd) / len(rmsds)
runs-on: ${{ matrix.os }} 65 | timeout-minutes: 30 66 | 67 | defaults: 68 | run: 69 | shell: bash -l {0} 70 | 71 | name: Conda os=${{ matrix.os }} - python=${{ matrix.python-version }} 72 | 73 | steps: 74 | - name: Checkout the code 75 | uses: actions/checkout@v4 76 | 77 | - name: Setup mamba 78 | uses: mamba-org/setup-micromamba@v2 79 | with: 80 | environment-file: env.yml 81 | environment-name: polaris_testing_env 82 | cache-environment: true 83 | cache-downloads: true 84 | create-args: >- 85 | python=${{ matrix.python-version }} 86 | 87 | - name: Install library 88 | run: python -m pip install --no-deps . 89 | 90 | - name: Run pytest 91 | run: pytest 92 | env: 93 | POLARIS_USERNAME: ${{ secrets.POLARIS_USERNAME }} 94 | POLARIS_PASSWORD: ${{ secrets.POLARIS_PASSWORD }} 95 | 96 | - name: Test CLI 97 | run: polaris --help 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm files 132 | .idea/ 133 | # Rever 134 | rever/ 135 | 136 | # VS Code 137 | .vscode/ 138 | 139 | # Generated legacy requirements.txt 140 | requirements.txt 141 | 142 | # OS-specific files 143 | .DS_store 144 | -------------------------------------------------------------------------------- /polaris/dataset/converters/_zarr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from typing import TYPE_CHECKING 4 | 5 | import pandas as pd 6 | import zarr 7 | from typing_extensions import deprecated 8 | 9 | from polaris.dataset import ColumnAnnotation 10 | from polaris.dataset.converters._base import Converter, FactoryProduct 11 | 12 | if TYPE_CHECKING: 13 | from polaris.dataset import DatasetFactory 14 | 15 | 16 | @deprecated("Please use the custom codecs in `polaris.utils.zarr.codecs` instead.") 17 | class ZarrConverter(Converter): 18 | """Parse a [.zarr](https://zarr.readthedocs.io/en/stable/index.html) archive into a Polaris `Dataset`. 19 | 20 | Warning: Loading from `.zarr` 21 | Loading and saving datasets from and to `.zarr` is still experimental and currently not 22 | fully supported by the Hub. 23 | 24 | A `.zarr` file can contain groups and arrays, where each group can again contain groups and arrays. 
25 | Within Polaris, the Zarr archive is expected to have a flat hierarchy where each array corresponds 26 | to a single column and each array contains the values for all datapoints in that column. 27 | """ 28 | 29 | def convert(self, path: str, factory: "DatasetFactory", append: bool = False) -> FactoryProduct: 30 | src = zarr.open(path, "r") 31 | 32 | v = next(src.group_keys(), None) 33 | if v is not None: 34 | raise ValueError("The root of the zarr hierarchy should only contain arrays.") 35 | 36 | # Copy to the source zarr, so everything is in one place 37 | pointer_start_dict = {col: 0 for col, _ in src.arrays()} 38 | if append: 39 | if not os.path.exists(factory.zarr_root.store.path): 40 | raise RuntimeError( 41 | f"Zarr store {factory.zarr_root.store.path} doesn't exist. \ 42 | Please make sure the zarr store {factory.zarr_root.store.path} is created. Or set `append` to `False`." 43 | ) 44 | else: 45 | for col, arr in src.arrays(): 46 | pointer_start_dict[col] += factory.zarr_root[col].shape[0] 47 | factory.zarr_root[col].append(arr) 48 | else: 49 | zarr.copy_store(source=src.store, dest=factory.zarr_root.store, if_exists="skip") 50 | 51 | # Construct the table 52 | # Parse any group into a column 53 | data = defaultdict(dict) 54 | for col, arr in src.arrays(): 55 | for i in range(len(arr)): 56 | data[col][i] = self.get_pointer(arr.name.removeprefix("/"), i) 57 | 58 | # Construct the dataset 59 | table = pd.DataFrame(data) 60 | return table, {k: ColumnAnnotation(is_pointer=True) for k in table.columns}, {} 61 | -------------------------------------------------------------------------------- /tests/test_type_checks.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | from pydantic import BaseModel, ValidationError 5 | 6 | import polaris as po 7 | from polaris._artifact import BaseArtifactModel 8 | from polaris.utils.types import HttpUrlString, HubOwner 9 | 10 | 11 | def 
def test_slug_compatible_string_type():
    """Verifies that the artifact name is validated correctly."""

    # Rejected: too short (<4 characters), too long (>64 characters),
    # or containing non-alphanumeric characters.
    rejected = ["", "x", "xx", "xxx", "x" * 65, "invalid@", "invalid!"]
    for bad_name in rejected:
        with pytest.raises(ValidationError):
            BaseArtifactModel(name=bad_name)

    # Accepted: 4-64 characters of letters, digits, hyphens and underscores.
    accepted = [
        "valid",
        "valid-name",
        "valid_name",
        "ValidName1",
        "Valid_",
        "Valid-",
        "x" * 64,
        "x" * 4,
    ]
    for good_name in accepted:
        BaseArtifactModel(name=good_name)
warnings.simplefilter("error") 87 | m.model_dump() 88 | 89 | with pytest.raises(ValidationError): 90 | _TestModel(url="invalid") 91 | with pytest.raises(ValidationError): 92 | _TestModel(url="ftp://invalid.com") 93 | -------------------------------------------------------------------------------- /polaris/mixins/_checksum.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | import re 4 | 5 | from pydantic import BaseModel, PrivateAttr, computed_field 6 | 7 | from polaris.utils.errors import PolarisChecksumError 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class ChecksumMixin(BaseModel, abc.ABC): 13 | """ 14 | Mixin class to add checksum functionality to a class. 15 | """ 16 | 17 | _md5sum: str | None = PrivateAttr(None) 18 | 19 | @abc.abstractmethod 20 | def _compute_checksum(self) -> str: 21 | """Compute the checksum of the dataset.""" 22 | raise NotImplementedError 23 | 24 | @computed_field 25 | @property 26 | def md5sum(self) -> str: 27 | """Lazily compute the checksum once needed.""" 28 | if not self.has_md5sum: 29 | logger.info("Computing the checksum. This can be slow for large datasets.") 30 | self.md5sum = self._compute_checksum() 31 | return self._md5sum 32 | 33 | @md5sum.setter 34 | def md5sum(self, value: str): 35 | """Set the checksum.""" 36 | if not re.fullmatch(r"^[a-f0-9]{32}$", value): 37 | raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.") 38 | self._md5sum = value 39 | 40 | @property 41 | def has_md5sum(self) -> bool: 42 | """Whether the md5sum for this class has been computed and stored.""" 43 | return self._md5sum is not None 44 | 45 | def verify_checksum(self, md5sum: str | None = None): 46 | """ 47 | Recomputes the checksum and verifies whether it matches the stored checksum. 48 | 49 | Warning: Slow operation 50 | This operation can be slow for large datasets. 
51 | 52 | Info: Only works for locally stored datasets 53 | The checksum verification only works for datasets that are stored locally in its entirety. 54 | We don't have to verify the checksum for datasets stored on the Hub, as the Hub will do this on upload. 55 | And if you're streaming the data from the Hub, we will check the checksum of each chunk on download. 56 | """ 57 | if md5sum is None: 58 | md5sum = self._md5sum 59 | if md5sum is None: 60 | logger.warning( 61 | "No checksum to verify against. Specify either the md5sum parameter or " 62 | "store the checksum in the dataset.md5sum attribute." 63 | ) 64 | return 65 | 66 | # Recompute the checksum 67 | logger.info("To verify the checksum, we need to recompute it. This can be slow for large datasets.") 68 | self.md5sum = self._compute_checksum() 69 | 70 | if self.md5sum != md5sum: 71 | raise PolarisChecksumError( 72 | f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" 73 | ) 74 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from polaris.benchmark import BenchmarkV1Specification 6 | from polaris.dataset import Dataset 7 | from polaris.evaluate._metric import Metric 8 | 9 | 10 | def test_absolute_average_fold_error(): 11 | y_true = np.random.uniform(low=50, high=100, size=200) 12 | y_pred_1 = y_true + np.random.uniform(low=0, high=5, size=200) 13 | y_pred_2 = y_true + np.random.uniform(low=5, high=20, size=200) 14 | y_pred_3 = y_true - 10 15 | y_zero = np.zeros(shape=200) 16 | 17 | metric = Metric(label="absolute_average_fold_error") 18 | # Optimal value 19 | aafe_0 = metric.fn(y_true=y_true, y_pred=y_true) 20 | assert aafe_0 == 1 21 | 22 | # small fold change 23 | aafe_1 = metric.fn(y_true=y_true, y_pred=y_pred_1) 24 | assert aafe_1 > 1 25 | 26 | # larger fold change 
def test_metric_name():
    """The metric name reflects grouping and can be overridden with a custom name."""
    plain = Metric(label="accuracy")
    assert plain.name == "accuracy"

    # A grouped metric gets a "_grouped" suffix...
    grouped = Metric(label="accuracy", config={"group_by": "group"})
    assert grouped.name == "accuracy_grouped"

    # ...unless a custom name is set, which always wins.
    grouped.custom_name = "custom_name"
    assert grouped.name == "custom_name"
class Modality(enum.Enum):
    """Used to annotate columns in a dataset with their data modality.

    This is metadata only: the modality is used to categorize datasets on the Hub
    and does not affect logic in this library (see `ColumnAnnotation`).
    """

    UNKNOWN = "unknown"
    MOLECULE = "molecule"
    MOLECULE_3D = "molecule_3D"
    PROTEIN = "protein"
    PROTEIN_3D = "protein_3D"
    IMAGE = "image"
    @field_validator("dtype", mode="before")
    def _validate_dtype(cls, v):
        """Tries to convert a string (e.g. "int64") to a `np.dtype` instance."""
        if isinstance(v, str):
            v = np.dtype(v)
        return v
20 | 21 | Attributes: 22 | target_cols: The column(s) of the original dataset that should be used as the target. 23 | input_cols: The column(s) of the original dataset that should be used as input. 24 | target_types: A dictionary that maps target columns to their type. If not specified, this is automatically inferred. 25 | """ 26 | 27 | target_cols: set[ColumnName] = Field(min_length=1) 28 | input_cols: set[ColumnName] = Field(min_length=1) 29 | target_types: dict[ColumnName, TargetType] = Field(default_factory=dict, validate_default=True) 30 | 31 | @field_validator("target_cols", "input_cols", mode="before") 32 | @classmethod 33 | def _parse_cols(cls, v: str | Sequence[str], info: ValidationInfo) -> set[str]: 34 | """ 35 | Normalize columns input values to a set. 36 | """ 37 | if isinstance(v, str): 38 | v = {v} 39 | else: 40 | v = set(v) 41 | return v 42 | 43 | @field_validator("target_types", mode="before") 44 | @classmethod 45 | def _parse_target_types( 46 | cls, v: dict[ColumnName, TargetType | str | None] 47 | ) -> dict[ColumnName, TargetType]: 48 | """ 49 | Converts the target types to TargetType enums if they are strings. 50 | """ 51 | return { 52 | target: TargetType(val) if isinstance(val, str) else val 53 | for target, val in v.items() 54 | if val is not None 55 | } 56 | 57 | @model_validator(mode="after") 58 | def _validate_target_types(self) -> Self: 59 | """ 60 | Verifies that all target types are for benchmark targets. 61 | """ 62 | columns = set(self.target_types.keys()) 63 | if not columns.issubset(self.target_cols): 64 | raise InvalidBenchmarkError( 65 | f"Not all specified target types were found in the target columns. 
def test_consistency_across_access_methods(test_dataset):
    """Using the various endpoints of the Subset API should lead to the same data."""
    indices = list(range(5))
    task = Subset(test_dataset, indices, "smiles", "expt")

    # Ground truth
    expected_smiles = test_dataset.table.loc[indices, "smiles"]
    expected_targets = test_dataset.table.loc[indices, "expt"]

    # Indexing
    assert ([task[i][0] for i in range(5)] == expected_smiles).all()
    assert ([task[i][1] for i in range(5)] == expected_targets).all()

    # Iterator
    assert (list(smi for smi, y in task) == expected_smiles).all()
    assert (list(y for smi, y in task) == expected_targets).all()

    # Property
    assert (task.inputs == expected_smiles).all()
    assert (task.targets == expected_targets).all()
    assert (task.X == expected_smiles).all()
    assert (task.y == expected_targets).all()
def test_access_to_test_set(test_single_task_benchmark):
    """A user should not have access to the test set targets."""

    train, test = test_single_task_benchmark.get_train_test_split()

    # Only the test subset hides its targets
    assert test._hide_targets
    assert not train._hide_targets

    # Direct target access on the test set is blocked
    with pytest.raises(TestAccessError):
        test.as_array("y")
    with pytest.raises(TestAccessError):
        test.targets

    # Iterable- and index-style access on the test set yields just the SMILES
    assert all(isinstance(x, str) for x in test)
    assert all(isinstance(test[i], str) for i in range(len(test)))

    # For the train set, target access works through both access styles
    assert all(isinstance(y, float) for x, y in train)
    assert all(isinstance(train[i][1], float) for i in range(len(train)))

    # as_dataframe works for both, but the test frame must not contain the target column
    train_df = train.as_dataframe()
    assert isinstance(train_df, pd.DataFrame)
    assert "expt" in train_df.columns
    test_df = test.as_dataframe()
    assert isinstance(test_df, pd.DataFrame)
    assert "expt" not in test_df.columns
def _optionally_subset(
    preds: BenchmarkPredictions | None,
    test_set_labels: list[str] | str,
    target_labels: list[str] | str,
) -> BenchmarkPredictions | None:
    """
    Subset the given predictions to the specified test set(s) and target(s).

    Returns None if no predictions were given; otherwise both label arguments are
    normalized to lists and forwarded to `BenchmarkPredictions.get_subset`.
    """
    if preds is None:
        return None

    if not isinstance(test_set_labels, list):
        test_set_labels = [test_set_labels]

    if not isinstance(target_labels, list):
        target_labels = [target_labels]

    return preds.get_subset(
        test_set_subset=test_set_labels,
        target_subset=target_labels,
    )


def evaluate_benchmark(
    target_cols: list[str],
    test_set_labels: list[str],
    test_set_sizes: dict[str, int],
    metrics: set[Metric],
    y_true: dict[str, Subset],
    y_pred: IncomingPredictionsType | None = None,
    y_prob: IncomingPredictionsType | None = None,
):
    """
    Utility function that contains the evaluation logic for a benchmark.

    Args:
        target_cols: The benchmark's target column labels.
        test_set_labels: The labels of the benchmark's test sets.
        test_set_sizes: The number of datapoints per test set, keyed by test set label.
        metrics: The metrics to compute.
        y_true: The ground truth, as one `Subset` per test set label.
        y_pred: The (optional) predicted values.
        y_prob: The (optional) predicted probabilities.

    Returns:
        A DataFrame in the `BenchmarkResults.RESULTS_COLUMNS` tabular format, with one
        row per (test set, target, metric) combination.
    """

    # Normalize the predictions to a consistent, internal representation.
    # Format is a two-level dictionary: {test_set_label: {target_label: np.ndarray}}
    if y_pred is not None:
        y_pred = BenchmarkPredictions(
            predictions=y_pred,
            target_labels=target_cols,
            test_set_labels=test_set_labels,
            test_set_sizes=test_set_sizes,
        )
    if y_prob is not None:
        y_prob = BenchmarkPredictions(
            predictions=y_prob,
            target_labels=target_cols,
            test_set_labels=test_set_labels,
            test_set_sizes=test_set_sizes,
        )

    # Compute the results
    # Results are saved in a tabular format. For more info, see the BenchmarkResults docs.
    scores = pd.DataFrame(columns=BenchmarkResults.RESULTS_COLUMNS)

    # For every test set...
    for test_label in test_set_labels:
        # For every metric...
        for metric in metrics:
            if metric.is_multitask:
                # Multi-task but with a metric across targets.
                # Subset to all target columns (i.e. no effective target filtering).
                # NOTE(review): `target_labels` is a required parameter of
                # `_optionally_subset` and was previously omitted here, which raised a
                # TypeError for any multitask metric with predictions.
                score = metric(
                    y_true=y_true[test_label],
                    y_pred=_optionally_subset(y_pred, test_set_labels=test_label, target_labels=target_cols),
                    y_prob=_optionally_subset(y_prob, test_set_labels=test_label, target_labels=target_cols),
                )

                # Store `metric.name` (not the Metric object) to keep the column type
                # consistent with the per-target rows below.
                scores.loc[len(scores)] = (test_label, "aggregated", metric.name, score)
                continue

            # Otherwise, for every target...
            for target_label in target_cols:
                score = metric(
                    y_true=y_true[test_label].filter_targets(target_label),
                    y_pred=_optionally_subset(y_pred, test_set_labels=test_label, target_labels=target_label),
                    y_prob=_optionally_subset(y_prob, test_set_labels=test_label, target_labels=target_label),
                )

                scores.loc[len(scores)] = (test_label, target_label, metric.name, score)

    return scores
30 | A default value is generated based on the Hub URL, and this should not need to be overridden. 31 | username: The username for the Polaris Hub, for the optional password-based authentication. 32 | password: The password for the specified username. 33 | """ 34 | 35 | # Configuration of the pydantic model 36 | model_config = SettingsConfigDict( 37 | env_file=".env", env_prefix="POLARIS_", extra="ignore", env_ignore_empty=True 38 | ) 39 | 40 | # Hub settings 41 | hub_url: HttpUrlString = "https://polarishub.io/" 42 | api_url: HttpUrlString | None = None 43 | custom_metadata_prefix: str = "X-Amz-Meta-" 44 | 45 | # Hub authentication settings 46 | hub_token_url: HttpUrlString | None = None 47 | username: str | None = None 48 | password: str | None = None 49 | 50 | # External authentication settings 51 | authorize_url: HttpUrlString = "https://clerk.polarishub.io/oauth/authorize" 52 | callback_url: HttpUrlString | None = None 53 | token_fetch_url: HttpUrlString = "https://clerk.polarishub.io/oauth/token" 54 | user_info_url: HttpUrlString = "https://clerk.polarishub.io/oauth/userinfo" 55 | scopes: str = "profile email" 56 | client_id: str = "agQP2xVM6JqMHvGc" 57 | 58 | # Networking settings 59 | ca_bundle: str | bool | None = None 60 | default_timeout: TimeoutTypes = (10, 200) 61 | 62 | @field_validator("api_url", mode="before") 63 | def validate_api_url(cls, v, info: ValidationInfo): 64 | if v is None: 65 | v = urljoin(str(info.data["hub_url"]), "/api") 66 | return v 67 | 68 | @field_validator("callback_url", mode="before") 69 | def validate_callback_url(cls, v, info: ValidationInfo): 70 | if v is None: 71 | v = urljoin(str(info.data["hub_url"]), "/oauth2/callback") 72 | return v 73 | 74 | @field_validator("hub_token_url", mode="before") 75 | def populate_hub_token_url(cls, v, info: ValidationInfo): 76 | if v is None: 77 | v = urljoin(str(info.data["hub_url"]), "/api/auth/token") 78 | return v 79 | 
-------------------------------------------------------------------------------- /docs/tutorials/create_a_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "A model in Polaris centralizes all data about a method and can be attached to different results.\n", 8 | "\n", 9 | "## Create a Model\n", 10 | "\n", 11 | "To create a model, you need to instantiate the `Model` class. " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from polaris.model import Model\n", 21 | "\n", 22 | "# Create a new Model Card\n", 23 | "model = Model(\n", 24 | " name=\"MolGPS\",\n", 25 | " description=\"Graph transformer foundation model for molecular modeling\",\n", 26 | " code_url=\"https://github.com/datamol-io/graphium\"\n", 27 | ")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Share your model\n", 35 | "Want to share your model with the community? Upload it to the Polaris Hub!" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "model.upload_to_hub(owner=\"your-username\")" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "If you want to upload a new version of your model, you can specify its previous version with the `parent_artifact_id` parameter. Don't forget to add a changelog describing your updates!" 
52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "model.artifact_changelog = \"In this version, I added...\"\n", 61 | "\n", 62 | "model.upload_to_hub(\n", 63 | " owner=\"your-username\",\n", 64 | " parent_artifact_id=\"your-username/tutorial-example\"\n", 65 | ")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Attach a model with a result\n", 73 | "\n", 74 | "The model card can then be attached to a newly created result on upload." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "from polaris import load_benchmark, load_model\n", 84 | "\n", 85 | "# Load a benchmark\n", 86 | "benchmark = load_benchmark(\"polaris/hello-world-benchmark\")\n", 87 | "\n", 88 | "# Get the results\n", 89 | "results = benchmark.evaluate(...)\n", 90 | "\n", 91 | "# Attach it to the result\n", 92 | "results.model = load_model(\"recursion/MolGPS\")\n", 93 | "\n", 94 | "# Upload the results\n", 95 | "results.upload_to_hub(owner=\"your-username\")" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "---\n", 103 | "\n", 104 | "The End. 
" 105 | ] 106 | } 107 | ], 108 | "metadata": { 109 | "kernelspec": { 110 | "display_name": ".venv", 111 | "language": "python", 112 | "name": "python3" 113 | }, 114 | "language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 | }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.12.0" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 2 129 | } 130 | -------------------------------------------------------------------------------- /tests/test_storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import boto3 4 | import pytest 5 | from moto import mock_aws 6 | 7 | from polaris.hub.storage import S3Store 8 | 9 | 10 | @pytest.fixture(scope="function") 11 | def aws_credentials(): 12 | """ 13 | Mocked AWS Credentials for moto. 14 | """ 15 | os.environ["AWS_ACCESS_KEY_ID"] = "testing" 16 | os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" 17 | os.environ["AWS_SECURITY_TOKEN"] = "testing" 18 | os.environ["AWS_SESSION_TOKEN"] = "testing" 19 | 20 | 21 | @pytest.fixture(scope="function") 22 | def mocked_aws(aws_credentials): 23 | """ 24 | Mock all AWS interactions 25 | Requires you to create your own boto3 clients 26 | """ 27 | with mock_aws(): 28 | yield 29 | 30 | 31 | @pytest.fixture 32 | def s3_store(mocked_aws): 33 | # Setup mock S3 environment 34 | s3 = boto3.client("s3", region_name="us-east-1") 35 | bucket_name = "test-bucket" 36 | s3.create_bucket(Bucket=bucket_name) 37 | 38 | # Create an instance of your S3Store 39 | store = S3Store( 40 | path=f"{bucket_name}/prefix", 41 | access_key="fake-access-key", 42 | secret_key="fake-secret-key", 43 | token="fake-token", 44 | endpoint_url="https://s3.amazonaws.com", 45 | ) 46 | 47 | yield store 48 | 49 | 50 | def test_set_and_get_item(s3_store): 51 | key = "test-key" 52 | value = 
def test_listdir(s3_store):
    """Listing a prefix yields only its immediate children (files and sub-prefixes)."""
    populated = ["dir1/subdir1", "dir1/subdir2", "dir1/file1.ext", "dir2/file2.ext"]
    for entry in populated:
        s3_store[entry] = b"test"

    # Children of an explicit prefix
    assert set(s3_store.listdir("dir1")) == {"file1.ext", "subdir1", "subdir2"}

    # Top-level children when no prefix is given
    assert set(s3_store.listdir()) == {"dir1", "dir2"}
| 125 | 126 | def test_delete_item_not_supported(s3_store): 127 | with pytest.raises(NotImplementedError): 128 | del s3_store["some-key"] 129 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | release-version: 7 | description: "A valid Semver version string" 8 | required: true 9 | 10 | permissions: 11 | contents: write 12 | pull-requests: write 13 | id-token: write 14 | 15 | concurrency: 16 | group: "release-${{ github.ref }}" 17 | cancel-in-progress: false 18 | 19 | defaults: 20 | run: 21 | shell: bash -l {0} 22 | 23 | jobs: 24 | check-semver: 25 | # Do not release if not triggered from the default branch 26 | if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch) 27 | 28 | runs-on: ubuntu-latest 29 | timeout-minutes: 30 30 | 31 | steps: 32 | - name: Checkout the code 33 | uses: actions/checkout@v4 34 | with: 35 | fetch-depth: 0 36 | 37 | - name: Get version 38 | id: version 39 | run: | 40 | version=$(git describe --abbrev=0 --tags) 41 | echo $version 42 | echo "version=${version}" >> $GITHUB_OUTPUT 43 | 44 | - name: Semver check 45 | id: semver_check 46 | uses: madhead/semver-utils@v4 47 | with: 48 | lenient: false 49 | version: ${{ inputs.release-version }} 50 | compare-to: ${{ steps.version.outputs.version }} 51 | 52 | - name: Semver ok 53 | if: steps.semver_check.outputs.comparison-result != '>' 54 | run: | 55 | echo "The release version is not valid Semver (${{ inputs.release-version }}) that is greater than the current version ${{ steps.version.outputs.version }}." 
56 | exit 1 57 | 58 | release: 59 | needs: check-semver 60 | 61 | runs-on: ubuntu-latest 62 | timeout-minutes: 30 63 | 64 | env: 65 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 66 | 67 | steps: 68 | - name: Checkout the code 69 | uses: actions/checkout@v4 70 | with: 71 | fetch-depth: 0 72 | 73 | - name: Install uv 74 | uses: astral-sh/setup-uv@v5 75 | 76 | - name: Install the project 77 | run: uv sync --frozen --all-groups --python 3.12 78 | 79 | - name: Build Changelog 80 | id: github_release 81 | uses: mikepenz/release-changelog-builder-action@v5 82 | with: 83 | toTag: "main" 84 | configuration: ".github/changelog_config.json" 85 | 86 | - name: Create and push git tag 87 | run: | 88 | # Configure git 89 | git config --global user.name "${GITHUB_ACTOR}" 90 | git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com" 91 | 92 | # Tag the release 93 | git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}" 94 | 95 | # Checkout the git tag 96 | git checkout "${{ inputs.release-version }}" 97 | 98 | # Push the modified changelogs 99 | git push origin main 100 | 101 | # Push the tags 102 | git push origin "${{ inputs.release-version }}" 103 | 104 | - name: Build the wheel and sdist 105 | run: uv build 106 | 107 | - name: Publish package to PyPI 108 | uses: pypa/gh-action-pypi-publish@release/v1 109 | with: 110 | password: ${{ secrets.PYPI_API_TOKEN }} 111 | packages-dir: dist/ 112 | 113 | - name: Create GitHub Release 114 | uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 115 | with: 116 | tag_name: ${{ inputs.release-version }} 117 | body: ${{steps.github_release.outputs.changelog}} 118 | 119 | - name: Deploy the doc 120 | run: | 121 | echo "Get the gh-pages branch" 122 | git fetch origin gh-pages 123 | 124 | echo "Build and deploy the doc on ${{ inputs.release-version }}" 125 | uv run mike deploy --push stable 126 | uv run mike deploy --push ${{ inputs.release-version }} 127 | 
-------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | Welcome to the Polaris Quickstart guide! This page will introduce you to core concepts and you'll submit a first result to a benchmark on the [Polaris Hub](https://www.polarishub.io). 3 | 4 | ## Installation 5 | !!! warning "`polaris-lib` vs `polaris`" 6 | Be aware that the package name differs between _pip_ and _conda_. 7 | 8 | Polaris can be installed via _pip_: 9 | 10 | ```bash 11 | pip install polaris-lib 12 | ``` 13 | 14 | or _conda_: 15 | ```bash 16 | conda install -c conda-forge polaris 17 | ``` 18 | 19 | ## Core concepts 20 | Polaris explicitly distinguishes between **datasets** and **benchmarks**. 21 | 22 | - A _dataset_ is simply a tabular collection of data, storing datapoints in a row-wise manner. 23 | - A _benchmark_ defines the ML task and evaluation logic (e.g. split and metrics) for a dataset. 24 | 25 | One dataset can therefore be associated with multiple benchmarks. 26 | 27 | ## Login 28 | To submit or upload artifacts to the [Polaris Hub](https://polarishub.io/) from the client, you must first authenticate yourself. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up). 29 | 30 | You can do this via the following command in your terminal: 31 | 32 | ```bash 33 | polaris login 34 | ``` 35 | 36 | or in Python: 37 | ```py 38 | from polaris.hub.client import PolarisHubClient 39 | 40 | with PolarisHubClient() as client: 41 | client.login() 42 | ``` 43 | 44 | ## Benchmark API 45 | To get started, we will submit a result to the [`polaris/hello-world-benchmark`](https://polarishub.io/benchmarks/polaris/hello-world-benchmark).
46 | 47 | ```python 48 | import polaris as po 49 | 50 | # Load the benchmark from the Hub 51 | benchmark = po.load_benchmark("polaris/hello-world-benchmark") 52 | 53 | # Get the train and test data-loaders 54 | train, test = benchmark.get_train_test_split() 55 | 56 | # Use the training data to train your model 57 | # Get the input as an array with 'train.inputs' and 'train.targets' 58 | # Or simply iterate over the train object. 59 | for x, y in train: 60 | ... 61 | 62 | # Work your magic to accurately predict the test set 63 | predictions = [0.0 for x in test] 64 | 65 | # Evaluate your predictions 66 | results = benchmark.evaluate(predictions) 67 | 68 | # Submit your results 69 | results.upload_to_hub(owner="dummy-user") 70 | ``` 71 | 72 | Through immutable datasets and standardized benchmarks, Polaris aims to serve as a source of truth for machine learning in drug discovery. The limited flexibility might differ from your typical experience, but this is by design to improve reproducibility. Learn more [here](https://polarishub.io/blog/reproducible-machine-learning-in-drug-discovery-how-polaris-serves-as-a-single-source-of-truth). 73 | 74 | ## Dataset API 75 | Loading a benchmark will automatically load the underlying dataset. We can also directly access the [`polaris/hello-world`](https://polarishub.io/datasets/polaris/hello-world) dataset. 76 | 77 | ```python 78 | import polaris as po 79 | 80 | # Load the dataset from the Hub 81 | dataset = po.load_dataset("polaris/hello-world") 82 | 83 | # Get information on the dataset size 84 | dataset.size() 85 | 86 | # Load a datapoint in memory 87 | dataset.get_data( 88 | row=dataset.rows[0], 89 | col=dataset.columns[0], 90 | ) 91 | 92 | # Or, similarly: 93 | dataset[dataset.rows[0], dataset.columns[0]] 94 | 95 | # Get an entire data point 96 | dataset[0] 97 | ``` 98 | 99 | Drug discovery research involves a maze of file formats (e.g. PDB for 3D structures, SDF for small molecules, and so on). 
Each format requires specialized knowledge to parse and interpret properly. At Polaris, we wanted to remove that barrier. We use a universal data format based on [Zarr](https://zarr.dev/). Learn more [here](https://polarishub.io/blog/dataset-v2-built-to-scale). 100 | 101 | ## Where to next? 102 | 103 | Now that you've seen how easy it is to use Polaris, let's dive into the details through [a set of tutorials](./tutorials/submit_to_benchmark.ipynb)! 104 | 105 | --- 106 | -------------------------------------------------------------------------------- /polaris/utils/errors.py: -------------------------------------------------------------------------------- 1 | import certifi 2 | 3 | 4 | class InvalidDatasetError(ValueError): 5 | pass 6 | 7 | 8 | class InvalidBenchmarkError(ValueError): 9 | pass 10 | 11 | 12 | class InvalidCompetitionError(ValueError): 13 | pass 14 | 15 | 16 | class InvalidResultError(ValueError): 17 | pass 18 | 19 | 20 | class TestAccessError(Exception): 21 | # Prevent pytest from collecting this class as a test 22 | __test__ = False 23 | 24 | pass 25 | 26 | 27 | class PolarisChecksumError(ValueError): 28 | pass 29 | 30 | 31 | class InvalidZarrChecksum(Exception): 32 | pass 33 | 34 | 35 | class InvalidZarrCodec(Exception): 36 | """Raised when an expected codec is not registered.""" 37 | 38 | def __init__(self, codec_id: str): 39 | self.codec_id = codec_id 40 | super().__init__( 41 | f"This Zarr archive requires the {self.codec_id} codec. " 42 | "Install all optional codecs with 'pip install polaris-lib[codecs]'."
43 | ) 44 | 45 | 46 | class PolarisHubError(Exception): 47 | BOLD = "\033[1m" 48 | YELLOW = "\033[93m" 49 | _END_CODE = "\033[0m" 50 | 51 | def __init__(self, message: str = "", response_text: str = ""): 52 | parts = filter( 53 | bool, 54 | [ 55 | f"{self.BOLD}The request to the Polaris Hub has failed.{self._END_CODE}", 56 | f"{self.YELLOW}{message}{self._END_CODE}" if message else "", 57 | f"----------------------\nError reported was:\n{response_text}" if response_text else "", 58 | ], 59 | ) 60 | 61 | super().__init__("\n".join(parts)) 62 | 63 | 64 | class PolarisUnauthorizedError(PolarisHubError): 65 | def __init__(self, response_text: str = ""): 66 | message = ( 67 | "You are not logged in to Polaris or your login has expired. " 68 | "You can use the Polaris CLI to easily authenticate yourself again with `polaris login --overwrite`." 69 | ) 70 | super().__init__(message, response_text) 71 | 72 | 73 | class PolarisCreateArtifactError(PolarisHubError): 74 | def __init__(self, response_text: str = ""): 75 | message = ( 76 | "Note: If you can confirm that you are authorized to perform this action, " 77 | "please call 'polaris login --overwrite' and try again. If the issue persists, please reach out to the Polaris team for support." 78 | ) 79 | super().__init__(message, response_text) 80 | 81 | 82 | class PolarisRetrieveArtifactError(PolarisHubError): 83 | def __init__(self, response_text: str = ""): 84 | message = ( 85 | "Note: If this artifact exists and you can confirm that you are authorized to retrieve it, " 86 | "please call 'polaris login --overwrite' and try again. If the issue persists, please reach out to the Polaris team for support." 87 | ) 88 | super().__init__(message, response_text) 89 | 90 | 91 | class PolarisSSLError(PolarisHubError): 92 | def __init__(self, response_text: str = ""): 93 | message = ( 94 | "We could not verify the SSL certificate. 
" 95 | f"Please ensure the installed version ({certifi.__version__}) of the `certifi` package is the latest. " 96 | "If you require the usage of a custom CA bundle, you can set the POLARIS_CA_BUNDLE " 97 | "environment variable to the path of your CA bundle. For debugging, you can temporarily disable " 98 | "SSL verification by setting the POLARIS_CA_BUNDLE environment variable to `false`." 99 | ) 100 | super().__init__(message, response_text) 101 | 102 | 103 | class PolarisDeprecatedError(PolarisHubError): 104 | def __init__(self, feature: str, response_text: str = ""): 105 | message = ( 106 | f"The '{feature}' feature has been deprecated and is no longer supported. " 107 | "Please contact the Polaris team for more information about alternative approaches." 108 | ) 109 | super().__init__(message, response_text) 110 | -------------------------------------------------------------------------------- /tests/test_benchmark_predictions_v2.py: -------------------------------------------------------------------------------- 1 | from polaris.prediction._predictions_v2 import BenchmarkPredictionsV2 2 | from polaris.utils.zarr.codecs import RDKitMolCodec, AtomArrayCodec 3 | from rdkit import Chem 4 | import numpy as np 5 | import pytest 6 | import datamol as dm 7 | import zarr 8 | from fastpdb import struc 9 | 10 | 11 | def assert_deep_equal(result, expected): 12 | assert isinstance(result, type(expected)), f"Types differ: {type(result)} != {type(expected)}" 13 | if isinstance(expected, dict): 14 | assert result.keys() == expected.keys() 15 | for key in expected: 16 | assert_deep_equal(result[key], expected[key]) 17 | elif isinstance(expected, np.ndarray): 18 | assert np.array_equal(result, expected) 19 | else: 20 | assert result == expected 21 | 22 | 23 | def test_v2_rdkit_object_codec(v2_benchmark_with_rdkit_object_dtype): 24 | mols = [dm.to_mol("CCO"), dm.to_mol("CCN")] 25 | preds = {"test": {"expt": mols}} 26 | bp = BenchmarkPredictionsV2( 27 | predictions=preds, 28 | 
dataset_zarr_root=v2_benchmark_with_rdkit_object_dtype.dataset.zarr_root, 29 | benchmark_artifact_id=v2_benchmark_with_rdkit_object_dtype.artifact_id, 30 | target_labels=["expt"], 31 | test_set_labels=["test"], 32 | test_set_sizes={"test": 2}, 33 | ) 34 | assert isinstance(bp.predictions["test"]["expt"], np.ndarray) 35 | assert bp.predictions["test"]["expt"].dtype == object 36 | assert_deep_equal(bp.predictions, {"test": {"expt": np.array(mols, dtype=object)}}) 37 | 38 | # Check Zarr archive 39 | zarr_path = bp.to_zarr() 40 | assert zarr_path.exists() 41 | root = zarr.open(str(zarr_path), mode="r") 42 | arr = root["test"]["expt"][:] 43 | arr_smiles = [Chem.MolToSmiles(m) for m in arr] 44 | mols_smiles = [Chem.MolToSmiles(m) for m in mols] 45 | assert arr_smiles == mols_smiles 46 | 47 | # Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters) 48 | zarr_array = root["test"]["expt"] 49 | assert zarr_array.filters is not None 50 | assert len(zarr_array.filters) > 0 51 | assert any(isinstance(f, RDKitMolCodec) for f in zarr_array.filters) 52 | 53 | 54 | def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdbs_structs): 55 | # Use fastpdb.AtomArray objects 56 | preds = {"test": {"expt": np.array(pdbs_structs[:2], dtype=object)}} 57 | bp = BenchmarkPredictionsV2( 58 | predictions=preds, 59 | dataset_zarr_root=v2_benchmark_with_atomarray_object_dtype.dataset.zarr_root, 60 | benchmark_artifact_id=v2_benchmark_with_atomarray_object_dtype.artifact_id, 61 | target_labels=["expt"], 62 | test_set_labels=["test"], 63 | test_set_sizes={"test": 2}, 64 | ) 65 | assert isinstance(bp.predictions["test"]["expt"], np.ndarray) 66 | assert bp.predictions["test"]["expt"].dtype == object 67 | assert_deep_equal(bp.predictions, {"test": {"expt": np.array(pdbs_structs[:2], dtype=object)}}) 68 | 69 | # Check Zarr archive (dtype and shape only) 70 | zarr_path = bp.to_zarr() 71 | assert zarr_path.exists() 72 | root = 
zarr.open(str(zarr_path), mode="r") 73 | arr = root["test"]["expt"][:] 74 | assert arr.dtype == object 75 | assert arr.shape == (2,) 76 | assert all(isinstance(x, struc.AtomArray) for x in arr) 77 | 78 | # Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters) 79 | zarr_array = root["test"]["expt"] 80 | assert zarr_array.filters is not None 81 | assert len(zarr_array.filters) > 0 82 | assert any(isinstance(f, AtomArrayCodec) for f in zarr_array.filters) 83 | 84 | 85 | def test_v2_dtype_mismatch_raises(test_benchmark_v2): 86 | # Create a list of rdkit.Chem.Mol objects (object dtype) to test against float dtype dataset 87 | mols = [dm.to_mol("CCO"), dm.to_mol("CCN")] 88 | preds = {"test": {"A": mols}} # Using column "A" which has float dtype in test_dataset_v2 89 | with pytest.raises(ValueError, match="Dtype mismatch"): 90 | BenchmarkPredictionsV2( 91 | predictions=preds, 92 | dataset_zarr_root=test_benchmark_v2.dataset.zarr_root, 93 | benchmark_artifact_id=test_benchmark_v2.artifact_id, 94 | target_labels=["A"], 95 | test_set_labels=["test"], 96 | test_set_sizes={"test": 2}, 97 | ) 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 |
5 | 6 |

7 | 8 | ✨ Polaris Hub 9 | | 10 | 11 | 📚 Client Doc 12 | 13 |

14 | 15 | --- 16 | 17 | | | | 18 | | --- | --- | 19 | | Latest Release | [![PyPI](https://img.shields.io/pypi/v/polaris-lib)](https://pypi.org/project/polaris-lib/) | 20 | | | [![Conda](https://img.shields.io/conda/v/conda-forge/polaris?label=conda&color=success)](https://anaconda.org/conda-forge/polaris) | 21 | | Python Version | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/polaris-lib)](https://pypi.org/project/polaris-lib/) | 22 | | License | [![Code license](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/polaris-hub/polaris/blob/main/LICENSE) | 23 | | Downloads | [![PyPI - Downloads](https://img.shields.io/pypi/dm/polaris-lib)](https://pypi.org/project/polaris-lib/) | 24 | | | [![Conda](https://img.shields.io/conda/dn/conda-forge/polaris)](https://anaconda.org/conda-forge/polaris) | 25 | | Citation | [![DOI](https://img.shields.io/badge/DOI-10.1038%2Fs42256--024--00911--w-blue)](https://doi.org/10.1038/s42256-024-00911-w) | 26 | 27 | Polaris establishes a novel, industry‑certified standard to foster the development of impactful methods in AI-based drug discovery. 28 | 29 | This library is a Python client to interact with the [Polaris Hub](https://polarishub.io/). It allows you to: 30 | 31 | - Download Polaris datasets and benchmarks. 32 | - Evaluate a custom method against a Polaris benchmark. 33 | 34 | ## Quick API Tour 35 | 36 | ```python 37 | import polaris as po 38 | 39 | # Load the benchmark from the Hub 40 | benchmark = po.load_benchmark("polaris/hello-world-benchmark") 41 | 42 | # Get the train and test data-loaders 43 | train, test = benchmark.get_train_test_split() 44 | 45 | # Use the training data to train your model 46 | # Get the input as an array with 'train.inputs' and 'train.targets' 47 | # Or simply iterate over the train object. 48 | for x, y in train: 49 | ... 
50 | 51 | # Work your magic to accurately predict the test set 52 | predictions = [0.0 for x in test] 53 | 54 | # Evaluate your predictions 55 | results = benchmark.evaluate(predictions) 56 | 57 | # Submit your results 58 | results.upload_to_hub(owner="dummy-user") 59 | ``` 60 | 61 | ## Documentation 62 | 63 | Please refer to the [documentation](https://polaris-hub.github.io/polaris/), which contains tutorials for getting started with `polaris` and detailed descriptions of the functions provided. 64 | 65 | ## How to cite 66 | 67 | Please cite Polaris if you use it in your research. A list of relevant publications: 68 | 69 | - [![DOI](https://img.shields.io/badge/DOI-10.26434%2Fchemrxiv--2024--6dbwv--v2-blue)](https://doi.org/10.26434/chemrxiv-2024-6dbwv-v2) - Preprint, Method Comparison Guidelines. 70 | - [![DOI](https://img.shields.io/badge/DOI-10.1038%2Fs42256--024--00911--w-blue)](https://doi.org/10.1038/s42256-024-00911-w) - Nature Correspondence, Call to Action. 71 | - [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.13652587.svg)](https://doi.org/10.5281/zenodo.13652587) - Zenodo, Code Repository. 72 | 73 | ## Installation 74 | 75 | You can install `polaris` using conda/mamba/micromamba: 76 | 77 | ```bash 78 | conda install -c conda-forge polaris 79 | ``` 80 | 81 | You can also use pip: 82 | 83 | ```bash 84 | pip install polaris-lib 85 | ``` 86 | 87 | ## Development lifecycle 88 | 89 | ### Setup dev environment 90 | 91 | ```shell 92 | conda env create -n polaris -f env.yml 93 | conda activate polaris 94 | 95 | pip install --no-deps -e . 96 | ``` 97 | 98 |
99 | Other installation options 100 | 101 | Alternatively, using [uv](https://github.com/astral-sh/uv): 102 | ```shell 103 | uv venv -p 3.12 polaris 104 | source .venv/polaris/bin/activate 105 | uv pip compile pyproject.toml -o requirements.txt --all-extras 106 | uv pip install -r requirements.txt 107 | ``` 108 |
109 | 110 | 111 | ### Tests 112 | 113 | You can run tests locally with: 114 | 115 | ```shell 116 | pytest 117 | ``` 118 | 119 | ## License 120 | 121 | Under the Apache-2.0 license. See [LICENSE](LICENSE). 122 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | import datamol as dm 2 | import numpy as np 3 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor 4 | 5 | from polaris.evaluate import BenchmarkResults 6 | 7 | 8 | def test_single_task_benchmark_loop(test_single_task_benchmark): 9 | """Tests the integrated API for a single-task benchmark.""" 10 | train, test = test_single_task_benchmark.get_train_test_split() 11 | 12 | model = RandomForestRegressor() 13 | smiles, y = train.as_array("xy") 14 | x = np.array([dm.to_fp(dm.to_mol(smi)) for smi in smiles]) 15 | model.fit(X=x, y=y) 16 | 17 | x = np.array([dm.to_fp(dm.to_mol(smi)) for smi in test.inputs]) 18 | y_pred = model.predict(x) 19 | 20 | scores = test_single_task_benchmark.evaluate(y_pred) 21 | assert isinstance(scores, BenchmarkResults) 22 | 23 | 24 | def test_single_task_benchmark_loop_with_multiple_test_sets(test_single_task_benchmark_multiple_test_sets): 25 | """Tests the integrated API for a single-task benchmark with multiple test sets.""" 26 | train, test = test_single_task_benchmark_multiple_test_sets.get_train_test_split() 27 | 28 | smiles, y = train.as_array("xy") 29 | 30 | x_train = np.array([dm.to_fp(dm.to_mol(smi)) for smi in smiles]) 31 | 32 | model = RandomForestRegressor() 33 | model.fit(X=x_train, y=y) 34 | 35 | y_pred = {} 36 | task_name, *_ = test_single_task_benchmark_multiple_test_sets.target_cols 37 | for k, test_subset in test.items(): 38 | x_test = np.array([dm.to_fp(dm.to_mol(smi)) for smi in test_subset.inputs]) 39 | y_pred[k] = {task_name: model.predict(x_test)} 40 | 41 | scores = 
test_single_task_benchmark_multiple_test_sets.evaluate(y_pred) 42 | assert isinstance(scores, BenchmarkResults) 43 | 44 | 45 | def test_single_task_benchmark_clf_loop_with_multiple_test_sets( 46 | test_single_task_benchmark_clf_multiple_test_sets, 47 | ): 48 | """Tests the integrated API for a single-task benchmark for classification probabilities with multiple test sets.""" 49 | train, test = test_single_task_benchmark_clf_multiple_test_sets.get_train_test_split() 50 | 51 | smiles, y = train.as_array("xy") 52 | 53 | x_train = np.array([dm.to_fp(dm.to_mol(smi)) for smi in smiles]) 54 | 55 | model = RandomForestClassifier() 56 | model.fit(X=x_train, y=y) 57 | 58 | y_prob = {} 59 | y_pred = {} 60 | task_name, *_ = test_single_task_benchmark_clf_multiple_test_sets.target_cols 61 | for k, test_subset in test.items(): 62 | x_test = np.array([dm.to_fp(dm.to_mol(smi)) for smi in test_subset.inputs]) 63 | y_prob[k] = {task_name: model.predict_proba(x_test)[:, :1]} # for binary classification 64 | y_pred[k] = {task_name: model.predict(x_test)} 65 | 66 | scores = test_single_task_benchmark_clf_multiple_test_sets.evaluate(y_prob=y_prob, y_pred=y_pred) 67 | assert isinstance(scores, BenchmarkResults) 68 | 69 | 70 | def test_multi_task_benchmark_loop(test_multi_task_benchmark): 71 | """Tests the integrated API for a multi-task benchmark.""" 72 | train, test = test_multi_task_benchmark.get_train_test_split() 73 | 74 | smiles, multi_y = train.as_array("xy") 75 | x_train = np.array([dm.to_fp(dm.to_mol(smi)) for smi in smiles]) 76 | x_test = np.array([dm.to_fp(dm.to_mol(smi)) for smi in test.inputs]) 77 | 78 | y_pred = {} 79 | for k, y in multi_y.items(): 80 | model = RandomForestRegressor() 81 | 82 | mask = ~np.isnan(y) 83 | model.fit(X=x_train[mask], y=y[mask]) 84 | y_pred[k] = model.predict(x_test) 85 | 86 | scores = test_multi_task_benchmark.evaluate(y_pred) 87 | assert isinstance(scores, BenchmarkResults) 88 | 89 | 90 | def 
test_multi_task_benchmark_loop_with_multiple_test_sets(test_multi_task_benchmark_multiple_test_sets): 91 | """Tests the integrated API for a multi-task benchmark with multiple test sets.""" 92 | train, test = test_multi_task_benchmark_multiple_test_sets.get_train_test_split() 93 | smiles, multi_y = train.as_array("xy") 94 | 95 | x_train = np.array([dm.to_fp(dm.to_mol(smi)) for smi in smiles]) 96 | 97 | y_pred = {} 98 | for test_set_name, test_subset in test.items(): 99 | y_pred[test_set_name] = {} 100 | x_test = np.array([dm.to_fp(dm.to_mol(smi)) for smi in test_subset.inputs]) 101 | 102 | for task_name, y in multi_y.items(): 103 | model = RandomForestRegressor() 104 | 105 | mask = ~np.isnan(y) 106 | model.fit(X=x_train[mask], y=y[mask]) 107 | y_pred[test_set_name][task_name] = model.predict(x_test) 108 | 109 | scores = test_multi_task_benchmark_multiple_test_sets.evaluate(y_pred) 110 | assert isinstance(scores, BenchmarkResults) 111 | -------------------------------------------------------------------------------- /polaris/benchmark/_split.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from itertools import chain 3 | 4 | from pydantic import BaseModel, computed_field, field_serializer, model_validator 5 | from typing_extensions import Self 6 | 7 | from polaris.utils.errors import InvalidBenchmarkError 8 | from polaris.utils.misc import listit 9 | from polaris.utils.types import SplitType 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class SplitSpecificationV1Mixin(BaseModel): 15 | """ 16 | Mixin class to add a split field to a benchmark. This is the V1 implementation. 17 | 18 | The split is defined as a (train, test) tuple, where train is a list of indices and 19 | test is a dictionary that maps test set names to lists of indices. 20 | 21 | Warning: Scalability 22 | The simple list-based representation we use for the split in this first implementation doesn't scale well. 
23 | We therefore worked on a V2 implementation that uses roaring bitmaps. 24 | See [`SplitSpecificationV2Mixin`][`polaris.experimental._split_v2.SplitSpecificationV2Mixin`] for more details. 25 | 26 | Attributes: 27 | split: The predefined train-test split to use for evaluation. 28 | """ 29 | 30 | split: SplitType 31 | 32 | @model_validator(mode="after") 33 | def _validate_split(self) -> Self: 34 | """ 35 | Verifies that: 36 | 1) There are no empty test partitions 37 | 2) There is no overlap between the train and test set 38 | 3) There is no duplicate indices in any of the sets 39 | """ 40 | 41 | if not isinstance(self.split[1], dict): 42 | self.split = self.split[0], {"test": self.split[1]} 43 | split = self.split 44 | 45 | # Train partition can be empty (zero-shot) 46 | # Test partitions cannot be empty 47 | if any(len(v) == 0 for v in split[1].values()): 48 | raise InvalidBenchmarkError("The predefined split contains empty test partitions") 49 | 50 | train_idx_list = split[0] 51 | full_test_idx_list = list(chain.from_iterable(split[1].values())) 52 | 53 | if len(train_idx_list) == 0: 54 | logger.info( 55 | "This benchmark only specifies a test set. It will return an empty train set in `get_train_test_split()`" 56 | ) 57 | 58 | train_idx_set = set(train_idx_list) 59 | full_test_idx_set = set(full_test_idx_list) 60 | 61 | # The train and test indices do not overlap 62 | if len(train_idx_set & full_test_idx_set) > 0: 63 | raise InvalidBenchmarkError("The predefined split specifies overlapping train and test sets") 64 | 65 | # Check for duplicate indices within the train set 66 | if len(train_idx_set) != len(train_idx_list): 67 | raise InvalidBenchmarkError("The training set contains duplicate indices") 68 | 69 | # Check for duplicate indices within a given test set. Because a user can specify 70 | # multiple test sets for a given benchmark and it is acceptable for indices to be shared 71 | # across test sets, we check for duplicates in each test set independently. 
72 | for test_set_name, test_set_idx_list in split[1].items(): 73 | if len(test_set_idx_list) != len(set(test_set_idx_list)): 74 | raise InvalidBenchmarkError( 75 | f'Test set with name "{test_set_name}" contains duplicate indices' 76 | ) 77 | 78 | return self 79 | 80 | @field_serializer("split") 81 | def _serialize_split(self, v: SplitType): 82 | """Convert any tuple to list to make sure it's serializable""" 83 | return listit(v) 84 | 85 | @computed_field 86 | @property 87 | def test_set_sizes(self) -> dict[str, int]: 88 | """The sizes of the test sets.""" 89 | return {k: len(v) for k, v in self.split[1].items()} 90 | 91 | @computed_field 92 | @property 93 | def n_test_sets(self) -> int: 94 | """The number of test sets""" 95 | return len(self.split[1]) 96 | 97 | @computed_field 98 | @property 99 | def n_train_datapoints(self) -> int: 100 | """The size of the train set.""" 101 | return len(self.split[0]) 102 | 103 | @computed_field 104 | @property 105 | def test_set_labels(self) -> list[str]: 106 | """The labels of the test sets.""" 107 | return sorted(list(self.split[1].keys())) 108 | 109 | @computed_field 110 | @property 111 | def n_test_datapoints(self) -> dict[str, int]: 112 | """The size of (each of) the test set(s).""" 113 | if self.n_test_sets == 1: 114 | return {"test": len(self.split[1]["test"])} 115 | else: 116 | return {k: len(v) for k, v in self.split[1].items()} 117 | -------------------------------------------------------------------------------- /polaris/loader/load.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import fsspec 4 | from datamol.utils import fs 5 | 6 | from polaris.benchmark import MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification 7 | from polaris.benchmark._benchmark_v2 import BenchmarkV2Specification 8 | from polaris.dataset import DatasetV1, create_dataset_from_file 9 | from polaris.hub.client import PolarisHubClient 10 | from polaris.utils.types import 
ChecksumStrategy 11 | from polaris.model import Model 12 | 13 | 14 | def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr") -> DatasetV1: 15 | """ 16 | Loads a Polaris dataset. 17 | 18 | In Polaris, a dataset is a tabular data structure that stores data-points in a row-wise manner. 19 | A dataset can have multiple modalities or targets, can be sparse 20 | and can be part of _one or multiple benchmarks_. 21 | 22 | The Polaris dataset can be loaded from the Hub or from a local or remote directory. 23 | 24 | - **Hub** (recommended): When loading the dataset from the Hub, you can simply 25 | provide the `owner/name` slug. This can be easily copied from the relevant dataset 26 | page on the Hub. 27 | - **Directory**: When loading the dataset from a directory, you should provide the path 28 | as returned by `dataset.to_json()`. The path can be local or remote. 29 | """ 30 | 31 | extension = fs.get_extension(path) 32 | is_file = fs.is_file(path) or extension == "zarr" 33 | 34 | if not is_file: 35 | # Load from the Hub 36 | with PolarisHubClient() as client: 37 | return client.get_dataset(*path.split("/"), verify_checksum=verify_checksum) 38 | 39 | # Load from local file 40 | if extension == "json": 41 | dataset = DatasetV1.from_json(path) 42 | else: 43 | dataset = create_dataset_from_file(path) 44 | 45 | # Verify checksum if requested 46 | if dataset.should_verify_checksum(verify_checksum): 47 | dataset.verify_checksum() 48 | 49 | return dataset 50 | 51 | 52 | def load_benchmark(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr"): 53 | """ 54 | Loads a Polaris benchmark. 55 | 56 | In Polaris, a benchmark wraps a dataset with additional metadata to specify the evaluation logic. 57 | 58 | The Polaris benchmark can be loaded from the Hub or from a local or remote directory. 59 | 60 | Note: Dataset is automatically loaded 61 | The dataset underlying the benchmark is automatically loaded when loading the benchmark. 
 62 | 63 | - **Hub** (recommended): When loading the benchmark from the Hub, you can simply 64 | provide the `owner/name` slug. This can be easily copied from the relevant benchmark 65 | page on the Hub. 66 | - **Directory**: When loading the benchmark from a directory, you should provide the path 67 | as returned by `benchmark.to_json()`. The path can be local or remote. 68 | """ 69 | is_file = fs.is_file(path) or fs.get_extension(path) == "zarr" 70 | 71 | if not is_file: 72 | # Load from the Hub 73 | with PolarisHubClient() as client: 74 | return client.get_benchmark(*path.split("/"), verify_checksum=verify_checksum) 75 | 76 | with fsspec.open(path, "r") as fd: 77 | data = json.load(fd) 78 | 79 | is_single_task = isinstance(data["target_cols"], str) or len(data["target_cols"]) == 1 80 | 81 | match data["version"]: 82 | case 1 if is_single_task: 83 | cls = SingleTaskBenchmarkSpecification 84 | case 1: 85 | cls = MultiTaskBenchmarkSpecification 86 | case 2: 87 | cls = BenchmarkV2Specification 88 | case _: 89 | raise ValueError(f"Unsupported benchmark version: {data['version']}") 90 | 91 | benchmark = cls.from_json(path) 92 | 93 | # Verify checksum if requested 94 | if benchmark.dataset.should_verify_checksum(verify_checksum): 95 | benchmark.verify_checksum() 96 | 97 | return benchmark 98 | 99 | 100 | def load_competition(artifact_id: str): 101 | """ 102 | Loads a Polaris competition. 103 | 104 | On Polaris, a competition represents a secure and fair benchmark. The target labels never exist 105 | on the client and all results are evaluated through Polaris' servers. 106 | 107 | Note: Dataset is automatically loaded 108 | The dataset underlying the competition is automatically loaded when loading the competition. 109 | 110 | """ 111 | with PolarisHubClient() as client: 112 | return client.get_competition(artifact_id) 113 | 114 | 115 | def load_model(artifact_id: str) -> Model: 116 | """ 117 | Loads a Polaris model.
118 | 119 | On Polaris, a model centralizes all data about a method and can be attached to different results. 120 | """ 121 | with PolarisHubClient() as client: 122 | return client.get_model(artifact_id) 123 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: "Polaris" 2 | site_description: "Polaris establishes a novel, industry‑certified standard to foster the development of impactful methods in AI-based drug discovery." 3 | site_url: "https://github.com/polaris-hub/polaris" 4 | repo_url: "https://github.com/polaris-hub/polaris" 5 | repo_name: "polaris-hub/polaris" 6 | copyright: Copyright 2023 - 2025 Polaris 7 | 8 | remote_branch: "gh-pages" 9 | use_directory_urls: false 10 | docs_dir: "docs" 11 | 12 | # Fail on warnings to detect issues with types and docstring 13 | strict: true 14 | 15 | nav: 16 | - Getting started: 17 | - Polaris: index.md 18 | - Quickstart: quickstart.md 19 | - Resources: resources.md 20 | - Tutorials: 21 | - Submit: 22 | - Submit to a Benchmark: tutorials/submit_to_benchmark.ipynb 23 | - Submit to a Competition: tutorials/submit_to_competition.ipynb 24 | - API Reference: 25 | - Load: api/load.md 26 | - Core: 27 | - Dataset: api/dataset.md 28 | - Benchmark: api/benchmark.md 29 | - Model: api/model.md 30 | - Competition: api/competition.md 31 | - Subset: api/subset.md 32 | - Evaluation: api/evaluation.md 33 | - Hub: 34 | - Client: api/hub.client.md 35 | - External Auth Client: api/hub.external_client.md 36 | - Additional: 37 | - Base classes: api/base.md 38 | - Types: api/utils.types.md 39 | - Community: https://discord.gg/vBFd8p6H7u 40 | - Polaris Hub: https://polarishub.io/ 41 | 42 | theme: 43 | name: material 44 | # NOTE(hadim): to customize the material primary and secondary color, 45 | # see check `docs/assets/css/custom-polaris.css`. 
46 | palette: 47 | primary: deep purple 48 | accent: indigo 49 | 50 | features: 51 | - navigation.tabs 52 | - navigation.sections 53 | - navigation.path 54 | - navigation.top 55 | - navigation.footer 56 | - toc.follow 57 | - content.code.copy 58 | - content.code.annotate 59 | favicon: images/logo-black.svg 60 | logo: images/logo-white.svg 61 | 62 | extra_css: 63 | - assets/css/custom-polaris.css 64 | 65 | extra_javascript: 66 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 67 | 68 | markdown_extensions: 69 | - admonition 70 | - attr_list 71 | - md_in_html 72 | - tables 73 | - pymdownx.details 74 | - pymdownx.superfences 75 | - pymdownx.superfences 76 | - pymdownx.inlinehilite 77 | - pymdownx.snippets 78 | - pymdownx.superfences 79 | - pymdownx.tabbed: 80 | alternate_style: true 81 | - pymdownx.emoji: 82 | emoji_index: !!python/name:material.extensions.emoji.twemoji 83 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 84 | - pymdownx.highlight: 85 | anchor_linenums: true 86 | line_spans: __span 87 | pygments_lang_class: true 88 | 89 | watch: 90 | - polaris/ 91 | 92 | plugins: 93 | - search 94 | 95 | - mkdocstrings: 96 | handlers: 97 | python: 98 | setup_commands: 99 | - import sys 100 | - sys.path.append("docs") 101 | - sys.path.append("polaris") 102 | options: 103 | show_root_heading: yes 104 | heading_level: 3 105 | show_source: false 106 | group_by_category: true 107 | members_order: source 108 | separate_signature: true 109 | show_signature_annotations: true 110 | line_length: 80 111 | - mkdocs-jupyter: 112 | execute: False 113 | remove_tag_config: 114 | remove_cell_tags: [ remove_cell ] 115 | remove_all_outputs_tags: [ remove_output ] 116 | remove_input_tags: [ remove_input ] 117 | 118 | - mike: 119 | version_selector: true 120 | 121 | extra: 122 | version: 123 | # Multi versioning provider for mkdocs-material (used for the JS selector) 124 | provider: mike 125 | analytics: 126 | provider: google 127 | property: G-V4RP8SG194 128 
| # Widget at the bottom of every page to collect information about the user experience 129 | # The data is collected in Google Analytics 130 | feedback: 131 | title: Was this page helpful? 132 | ratings: 133 | - icon: material/emoticon-happy-outline 134 | name: This page was helpful 135 | data: 1 136 | note: >- 137 | Thanks for your feedback! 138 | - icon: material/emoticon-sad-outline 139 | name: This page could be improved 140 | data: 0 141 | # NOTE (cwognum): It could be useful to have a link to a feedback form here 142 | note: >- 143 | Thanks for your feedback! 144 | consent: 145 | title: Cookie consent 146 | description: >- 147 | We use cookies to recognize your repeated visits and preferences, as well 148 | as to measure the effectiveness of our documentation and whether users 149 | find what they're searching for. With your consent, you're helping us to 150 | make our documentation better. 151 | -------------------------------------------------------------------------------- /tests/test_competition.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import pytest 4 | from pydantic import ValidationError 5 | 6 | from polaris.competition import CompetitionSpecification 7 | from polaris.utils.types import TaskType 8 | 9 | 10 | def test_competition_split_verification(test_competition): 11 | """Verifies that the split validation works as expected.""" 12 | 13 | obj = test_competition 14 | cls = CompetitionSpecification 15 | 16 | # By using the fixture as a default, we know it doesn't always fail 17 | default_kwargs = { 18 | "target_cols": obj.target_cols, 19 | "input_cols": obj.input_cols, 20 | "name": obj.name, 21 | "zarr_root_path": obj.zarr_root_path, 22 | "readme": obj.readme, 23 | "start_time": obj.start_time, 24 | "end_time": obj.end_time, 25 | "n_test_sets": obj.n_test_sets, 26 | "n_test_datapoints": obj.n_test_datapoints, 27 | "n_classes": obj.n_classes, 28 | } 29 | 30 | train_split 
= obj.split[0] 31 | test_split = obj.split[1] 32 | 33 | # One or more empty test partitions 34 | with pytest.raises(ValidationError): 35 | cls(split=(train_split,), **default_kwargs) 36 | with pytest.raises(ValidationError): 37 | cls(split=(train_split, []), **default_kwargs) 38 | with pytest.raises(ValidationError): 39 | cls(split=(train_split, {"test": []}), **default_kwargs) 40 | # Non-exclusive partitions 41 | with pytest.raises(ValidationError): 42 | cls(split=(train_split, test_split["test"] + train_split[:1]), **default_kwargs) 43 | with pytest.raises(ValidationError): 44 | cls(split=(train_split, {"test1": test_split, "test2": train_split[:1]}), **default_kwargs) 45 | # Invalid indices 46 | with pytest.raises(ValidationError): 47 | cls(split=(train_split + [len(obj)], test_split), **default_kwargs) 48 | with pytest.raises(ValidationError): 49 | cls(split=(train_split + [-1], test_split), **default_kwargs) 50 | with pytest.raises(ValidationError): 51 | cls(split=(train_split, test_split["test"] + [len(obj)]), **default_kwargs) 52 | with pytest.raises(ValidationError): 53 | cls(split=(train_split, test_split["test"] + [-1]), **default_kwargs) 54 | # Duplicate indices 55 | with pytest.raises(ValidationError): 56 | cls(split=(train_split + train_split[:1], test_split), **default_kwargs) 57 | with pytest.raises(ValidationError): 58 | cls(split=(train_split, test_split["test"] + test_split["test"][:1]), **default_kwargs) 59 | with pytest.raises(ValidationError): 60 | cls( 61 | split=(train_split, {"test1": test_split, "test2": test_split["test"] + test_split["test"][:1]}), 62 | **default_kwargs, 63 | ) 64 | 65 | # It should _not_ fail with duplicate indices across test partitions 66 | cls(split=(train_split, {"test1": test_split["test"], "test2": test_split["test"]}), **default_kwargs) 67 | # It should _not_ fail with missing indices 68 | cls(split=(train_split[:-1], test_split), **default_kwargs) 69 | # It should _not_ fail with an empty train set 70 | 
def test_competition_train_test_split(test_competition):
    """Tests that the competition's train/test split can be retrieved through a CompetitionSpecification instance"""

    train, test = test_competition.get_train_test_split()

    # The raw split is a (train_indices, {test_set_name: test_indices}) tuple.
    train_split = test_competition.split[0]
    test_sets = test_competition.split[1]
    # Flatten all named test sets into a single collection; set() deduplicates
    # any indices that are shared across test partitions.
    test_split = set(chain.from_iterable(test_sets.values()))

    assert len(train) == len(train_split)
    assert len(test) == len(test_split)
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | # NOTE(hadim): unfortunately, we cannot use `polaris` on pypi. 7 | # See https://github.com/pypi/support/issues/2908 8 | name = "polaris-lib" 9 | description = "Client for the Polaris Hub." 10 | dynamic = ["version"] 11 | authors = [ 12 | { name = "Cas Wognum", email = "cas@valencelabs.com" }, 13 | { name = "Lu Zhu", email = "lu@valencelabs.com" }, 14 | { name = "Andrew Quirke", email = "andrew@valencelabs.com" }, 15 | { name = "Julien St-Laurent", email = "julien.stl@valencelabs.com" }, 16 | ] 17 | readme = "README.md" 18 | requires-python = ">=3.10,<3.13" 19 | classifiers = [ 20 | "Development Status :: 5 - Production/Stable", 21 | "Intended Audience :: Developers", 22 | "Intended Audience :: Healthcare Industry", 23 | "Intended Audience :: Science/Research", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | "Topic :: Scientific/Engineering :: Bio-Informatics", 26 | "Topic :: Scientific/Engineering :: Information Analysis", 27 | "Topic :: Scientific/Engineering :: Medical Science Apps.", 28 | "Natural Language :: English", 29 | "Operating System :: OS Independent", 30 | "Programming Language :: Python", 31 | "Programming Language :: Python :: 3", 32 | "Programming Language :: Python :: 3.10", 33 | "Programming Language :: Python :: 3.11", 34 | "Programming Language :: Python :: 3.12", 35 | ] 36 | dependencies = [ 37 | "authlib", 38 | "boto3 <1.36.0", 39 | "datamol >=0.12.1", 40 | "fastpdb", 41 | "fsspec[http]", 42 | "httpx", 43 | "numcodecs[msgpack] >=0.13.1, <0.16.0", 44 | "numpy <3", 45 | "pandas", 46 | "pyarrow < 18", 47 | "pydantic >=2", 48 | "pydantic-settings >=2", 49 | "pyroaring", 50 | "pyyaml", 51 | "rich>=13.9.4", 52 | "scikit-learn", 53 | "scipy", 54 | "seaborn", 55 | "typer", 56 | "typing-extensions>=4.12.0", 57 | "zarr 
>=2,<3", 58 | ] 59 | 60 | [dependency-groups] 61 | dev = [ 62 | "ipywidgets", 63 | "jupyterlab", 64 | "moto[s3]>=5.0.14", 65 | "pytest >=7", 66 | "pytest-xdist", 67 | "pytest-cov", 68 | "ruff", 69 | ] 70 | doc = [ 71 | "mkdocs", 72 | "mkdocs-material >=9.4.7", 73 | "mkdocstrings", 74 | "mkdocstrings-python", 75 | "mkdocs-jupyter", 76 | "markdown-include", 77 | "mdx_truly_sane_lists", 78 | "mike >=1.0.0", 79 | "nbconvert", 80 | ] 81 | codecs = [ 82 | "imagecodecs", 83 | ] 84 | 85 | # PEP 735 Dependency Groups are not well-supported by pip. 86 | # Duplicate them here with the older syntax. 87 | [project.optional-dependencies] 88 | dev = [ 89 | "ipywidgets", 90 | "jupyterlab", 91 | "moto[s3]>=5.0.14", 92 | "pytest >=7", 93 | "pytest-xdist", 94 | "pytest-cov", 95 | "ruff", 96 | ] 97 | doc = [ 98 | "mkdocs", 99 | "mkdocs-material >=9.4.7", 100 | "mkdocstrings", 101 | "mkdocstrings-python", 102 | "mkdocs-jupyter", 103 | "markdown-include", 104 | "mdx_truly_sane_lists", 105 | "mike >=1.0.0", 106 | "nbconvert", 107 | ] 108 | codecs = [ 109 | "imagecodecs", 110 | ] 111 | 112 | [project.scripts] 113 | polaris = "polaris.cli:app" 114 | 115 | [project.urls] 116 | Website = "https://polarishub.io/" 117 | "Source Code" = "https://github.com/polaris-hub/polaris" 118 | "Bug Tracker" = "https://github.com/polaris-hub/polaris/issues" 119 | Documentation = "https://polaris-hub.github.io/polaris/" 120 | 121 | [tool.setuptools] 122 | include-package-data = true 123 | 124 | [tool.setuptools_scm] 125 | fallback_version = "0.0.0.dev1" 126 | 127 | [tool.setuptools.packages.find] 128 | where = ["."] 129 | include = ["polaris", "polaris.*"] 130 | exclude = [] 131 | namespaces = false 132 | 133 | [tool.pytest.ini_options] 134 | minversion = "7.0" 135 | addopts = "--verbose --durations=10 -n auto --cov=polaris --cov-fail-under=75 --cov-report xml --cov-report term-missing" 136 | testpaths = ["tests"] 137 | pythonpath = "." 
138 | filterwarnings = ["ignore::DeprecationWarning:jupyter_client.connect.*:"] 139 | 140 | [tool.coverage.run] 141 | source = ["polaris/"] 142 | disable_warnings = ["no-data-collected"] 143 | data_file = ".coverage/coverage" 144 | 145 | [tool.coverage.report] 146 | omit = [ 147 | "polaris/__init__.py", 148 | "polaris/_version.py", 149 | # We cannot yet test the interaction with the Hub. 150 | # See e.g. https://github.com/polaris-hub/polaris/issues/30 151 | "polaris/hub/client.py", 152 | "polaris/hub/external_client.py", 153 | "polaris/hub/settings.py", 154 | "polaris/hub/oauth.py", 155 | "polaris/hub/storage.py", 156 | "polaris/hub/__init__.py", 157 | ] 158 | 159 | [tool.coverage.xml] 160 | output = "coverage.xml" 161 | 162 | [tool.ruff] 163 | lint.ignore = [ 164 | "E501", # Never enforce `E501` (line length violations). 165 | ] 166 | 167 | lint.per-file-ignores."__init__.py" = [ 168 | "F401", # imported but unused 169 | "E402", # Module level import not at top of file 170 | ] 171 | line-length = 110 172 | target-version = "py310" 173 | extend-exclude = ["*.ipynb"] 174 | -------------------------------------------------------------------------------- /polaris/_artifact.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Annotated, ClassVar, Literal, Optional 4 | 5 | import fsspec 6 | from packaging.version import Version 7 | from pydantic import ( 8 | BaseModel, 9 | ConfigDict, 10 | Field, 11 | computed_field, 12 | field_serializer, 13 | field_validator, 14 | ) 15 | from pydantic.alias_generators import to_camel 16 | from typing_extensions import Self 17 | 18 | import polaris 19 | from polaris.utils.misc import build_urn, slugify 20 | from polaris.utils.types import ArtifactUrn, HubOwner, SlugCompatibleStringType, SlugStringType 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class BaseArtifactModel(BaseModel): 26 | """ 27 | Base class for all artifacts on the 
class BaseArtifactModel(BaseModel):
    """
    Base class for all artifacts on the Hub. Specifies metadata that is used by the Hub.

    Info: Optional
        Despite all artifacts being based on this class, note that all attributes are optional.
        This ensures the library can be used without the Polaris Hub.
        Only when uploading to the Hub, some of the attributes are required.

    Attributes:
        name: A slug-compatible name for the artifact.
            Together with the owner, this is used by the Hub to uniquely identify the artifact.
        description: A beginner-friendly, short description of the artifact.
        tags: A list of tags to categorize the artifact by. This is used by the Hub to search over artifacts.
        user_attributes: A dict with additional, textual user attributes.
        owner: A slug-compatible name for the owner of the artifact.
            If the artifact comes from the Polaris Hub, this is the associated owner (organization or user).
            Together with the name, this is used by the Hub to uniquely identify the artifact.
        polaris_version: The version of the Polaris library that was used to create the artifact.
    """

    # Schema version of this artifact model, exposed through the `version` computed field.
    _version: ClassVar[Literal[1]] = 1
    # Artifact-type discriminator set by subclasses; used to build the URN.
    _artifact_type: ClassVar[str]

    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True)

    # Model attributes
    name: SlugCompatibleStringType | None = None
    description: str = ""
    tags: list[str] = Field(default_factory=list)
    user_attributes: dict[str, str] = Field(default_factory=dict)
    owner: HubOwner | None = None
    polaris_version: str = polaris.__version__
    slug: Annotated[Optional[SlugStringType], Field(validate_default=True)] = None

    @field_validator("slug")
    def _validate_slug(cls, val: Optional[str], info) -> SlugStringType | None:
        # A slug may be None when an artifact is created locally
        if val is None:
            # Derive the slug from the name when one is available;
            # `validate_default=True` on the field ensures this runs even
            # when no slug was explicitly provided.
            if info.data.get("name") is not None:
                return slugify(info.data.get("name"))
        return val

    @computed_field
    @property
    def artifact_id(self) -> str | None:
        # `owner/slug` identifier, only available once both parts are set.
        if self.owner and self.slug:
            return f"{self.owner}/{self.slug}"
        return None

    @computed_field
    @property
    def urn(self) -> ArtifactUrn | None:
        if self.owner and self.slug:
            return self.urn_for(self.owner, self.slug)
        return None

    @computed_field
    @property
    def version(self) -> int:
        return self._version

    @field_validator("polaris_version")
    @classmethod
    def _validate_version(cls, value: str) -> str:
        if value != "dev":
            # Make sure it is a valid semantic version
            Version(value)

        current_version = polaris.__version__
        if value != current_version:
            # Informational only: a version mismatch is not an error.
            logger.info(
                f"The version of Polaris that was used to create the artifact ({value}) is different "
                f"from the currently installed version of Polaris ({current_version})."
            )
        return value

    @field_validator("owner", mode="before")
    @classmethod
    def _validate_owner(cls, value: str | HubOwner | None):
        # Accept a plain slug string and coerce it into a HubOwner.
        if isinstance(value, str):
            return HubOwner(slug=value)
        return value

    @field_serializer("owner")
    def _serialize_owner(self, value: HubOwner) -> str | None:
        # Serialize the owner back to its slug string (or None).
        return value.slug if value else None

    @classmethod
    def from_json(cls, path: str) -> Self:
        """Loads an artifact from a JSON file.

        Args:
            path: Path to a JSON file containing the artifact definition.
        """
        with fsspec.open(path, "r") as f:
            data = json.load(f)
        return cls.model_validate(data)

    def to_json(self, path: str) -> None:
        """Saves an artifact to a JSON file.

        Args:
            path: Path to save the artifact definition as JSON.
        """
        with fsspec.open(path, "w") as f:
            f.write(self.model_dump_json())

    @classmethod
    def urn_for(cls, owner: str | HubOwner, slug: str) -> ArtifactUrn:
        """Builds the URN for an artifact of this type from its owner and slug."""
        return build_urn(cls._artifact_type, owner, slug)
\n", 15 | "\n", 16 | "The main difference lies in how predictions are prepared and how they are evaluated" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 8, 22 | "id": "3d66f466", 23 | "metadata": { 24 | "editable": true, 25 | "slideshow": { 26 | "slide_type": "" 27 | }, 28 | "tags": [ 29 | "remove_cell" 30 | ] 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "# Note: Cell is tagged to not show up in the mkdocs build\n", 35 | "%load_ext autoreload\n", 36 | "%autoreload 2" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "66cd175c-1f8d-4209-ad78-8d959ea31d9f", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "import polaris as po" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "84b6d1b9-3ee8-4ff4-9d92-8ed91ffa2f51", 52 | "metadata": {}, 53 | "source": [ 54 | "## Login\n", 55 | "As before, we first need to authenticate ourselves using our Polaris account. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up)." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "9b465ea4-7c71-443b-9908-3f9e567ee4c4", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from polaris.hub.client import PolarisHubClient\n", 66 | "\n", 67 | "with PolarisHubClient() as client:\n", 68 | " client.login()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "5edee39f-ce29-4ae6-91ce-453d9190541b", 74 | "metadata": {}, 75 | "source": [ 76 | "## Load the Competition\n", 77 | "As with regular benchmarks, a competition is identified by the `owner/slug` id." 
78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 10, 83 | "id": "4e004589-6c48-4232-b353-b1700536dde6", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "competition = po.load_competition(\"polaris/hello-world-competition\")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "36f3e829", 93 | "metadata": {}, 94 | "source": [ 95 | "## The Competition API\n", 96 | "Similar to the benchmark API, the competition exposes two main API endpoints:\n", 97 | "\n", 98 | "- `get_train_test_split()`, which does exactly the same as for benchmarks. \n", 99 | "- `submit_predictions()`, which is used to submit your predictions to a competition.\n", 100 | "\n", 101 | "Note that different from regular benchmarks, competitions don't have an `evaluate()` endpoint. \n", 102 | "\n", 103 | "That's because the evaluation happens server side. This gives the competition organizers precise control over how and when the test set and associated results get published, providing a unique opportunity for unbiased evaluation and comparison of different methods.\n", 104 | "\n", 105 | "### Submit your _predictions_\n", 106 | "Similar to your actual results, you can also provide metadata about your predictions." 
107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "2b36e09b", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "competition.submit_predictions(\n", 117 | " predictions=predictions,\n", 118 | " prediction_name=\"my-first-predictions\",\n", 119 | " prediction_owner=\"my-username\",\n", 120 | " report_url=\"https://www.example.com\", \n", 121 | " # The below metadata is optional, but recommended.\n", 122 | " github_url=\"https://github.com/polaris-hub/polaris\",\n", 123 | " description=\"Just testing the Polaris API here!\",\n", 124 | " tags=[\"tutorial\"],\n", 125 | " user_attributes={\"Framework\": \"Scikit-learn\", \"Method\": \"Gradient Boosting\"}\n", 126 | ")" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "44973556", 132 | "metadata": {}, 133 | "source": [ 134 | "That's it! Just like that you have partaken in your first Polaris competition. \n", 135 | "\n", 136 | "
\n", 137 | "

Where are my results?

\n", 138 | "

The results will only be published at predetermined intervals, as detailed in the competition details. Keep an eye on that leaderboard when it goes public and best of luck!

\n", 139 | "
\n", 140 | "\n", 141 | "\n", 142 | "---\n", 143 | "\n", 144 | "The End. " 145 | ] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3 (ipykernel)", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.12.8" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 5 169 | } 170 | -------------------------------------------------------------------------------- /polaris/hub/oauth.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime, timedelta, timezone 3 | from pathlib import Path 4 | from time import time 5 | from typing import Any, Literal 6 | 7 | from authlib.integrations.httpx_client import OAuth2Auth 8 | from pydantic import BaseModel, Field, PositiveInt, model_validator 9 | from typing_extensions import Self 10 | 11 | from polaris.utils.constants import DEFAULT_CACHE_DIR 12 | from polaris.utils.types import AnyUrlString, HttpUrlString 13 | 14 | 15 | class CachedTokenAuth(OAuth2Auth): 16 | """ 17 | A combination of an authlib token and a httpx auth class, that will cache the token to a file. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | token: dict | None = None, 23 | token_placement="header", 24 | client=None, 25 | cache_dir=DEFAULT_CACHE_DIR, 26 | filename="hub_auth_token.json", 27 | ): 28 | self.token_cache_path = Path(cache_dir) / filename 29 | 30 | if token is None and self.token_cache_path.exists(): 31 | token = json.loads(self.token_cache_path.read_text()) 32 | 33 | super().__init__(token, token_placement, client) 34 | 35 | def set_token(self, token: dict): 36 | super().set_token(token) 37 | 38 | # Ensure the cache directory exists. 
class ArtifactPaths(BaseModel):
    """
    Base model class for artifact paths.

    Offers convenience properties to identify which fields hold file-like
    paths and which hold store-like paths, based on the ``file``/``store``
    flags declared in each field's ``json_schema_extra``.
    """

    @property
    def files(self) -> list[str]:
        """Names of the fields flagged with `json_schema_extra={"file": True}`."""
        # Access `model_fields` through the class: instance-level access is
        # deprecated as of Pydantic 2.11.
        return [
            field
            for field, field_info in type(self).model_fields.items()
            if (field_info.json_schema_extra or {}).get("file")
        ]

    @property
    def stores(self) -> list[str]:
        """Names of the fields flagged with `json_schema_extra={"store": True}`."""
        return [
            field
            for field, field_info in type(self).model_fields.items()
            if (field_info.json_schema_extra or {}).get("store")
        ]
class HubOAuth2Token(BaseModel):
    """
    Model to parse and validate tokens obtained from the Polaris Hub.
    """

    issued_token_type: Literal["urn:ietf:params:oauth:token-type:jwt"] = (
        "urn:ietf:params:oauth:token-type:jwt"
    )
    token_type: Literal["Bearer"] = "Bearer"
    expires_in: PositiveInt | None = None
    expires_at: datetime | None = None
    access_token: str
    extra_data: None

    @model_validator(mode="after")
    def set_expires_at(self) -> Self:
        """Derive an absolute expiry timestamp when only `expires_in` was provided."""
        if self.expires_at is None and self.expires_in is not None:
            self.expires_at = datetime.fromtimestamp(time() + self.expires_in, timezone.utc)
        return self

    def is_expired(self, leeway: int = 60) -> bool | None:
        """Whether the token is (about to be) expired; None when no expiry is known."""
        if not self.expires_at:
            return None
        # Small timedelta to consider token as expired before it actually expires
        expiration_threshold = self.expires_at - timedelta(seconds=leeway)
        return datetime.now(timezone.utc) >= expiration_threshold

    def __getitem__(self, item) -> Any | None:
        """
        Compatibility with authlib's expectation that this is a dict
        """
        return getattr(self, item)
152 | """ 153 | 154 | token_type: Literal["Storage"] = "Storage" 155 | extra_data: StorageTokenData 156 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## 
Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 
82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 
123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /polaris/prediction/_predictions_v2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import shutil 5 | from pathlib import Path 6 | import tempfile 7 | 8 | import numpy as np 9 | import zarr 10 | from pydantic import ( 11 | PrivateAttr, 12 | model_validator, 13 | ) 14 | 15 | from polaris.utils.zarr._manifest import generate_zarr_manifest, calculate_file_md5 16 | from polaris.evaluate import ResultsMetadataV2 17 | from polaris.evaluate._predictions import BenchmarkPredictions 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class BenchmarkPredictionsV2(BenchmarkPredictions, ResultsMetadataV2): 23 | """ 24 | Prediction artifact for uploading predictions to a Benchmark V2. 25 | Stores predictions as a Zarr archive, with manifest and metadata for reproducibility and integrity. 26 | In addition to the predictions data, it contains metadata that describes how these predictions 27 | were generated, including the model used and contributors involved. 28 | 29 | Attributes: 30 | dataset_zarr_root: The zarr root of the dataset, used for dtype validation and as template for zarr arrays. 31 | benchmark_artifact_id: The artifact ID of the benchmark these predictions are for. 32 | 33 | For additional metadata attributes, see the base classes. 
34 | """ 35 | 36 | dataset_zarr_root: zarr.Group 37 | benchmark_artifact_id: str 38 | _artifact_type = "prediction" 39 | _zarr_root_path: str | None = PrivateAttr(None) 40 | _zarr_manifest_path: str | None = PrivateAttr(None) 41 | _zarr_manifest_md5sum: str | None = PrivateAttr(None) 42 | _zarr_root: zarr.Group | None = PrivateAttr(None) 43 | _temp_dir: str | None = PrivateAttr(None) 44 | 45 | @model_validator(mode="after") 46 | def check_prediction_dtypes(self): 47 | dataset_root = self.dataset_zarr_root 48 | for test_set_label, test_set_predictions in self.predictions.items(): 49 | for col, preds in test_set_predictions.items(): 50 | dataset_array = dataset_root[col] 51 | arr = np.asarray(preds) 52 | if arr.dtype != dataset_array.dtype: 53 | raise ValueError( 54 | f"Dtype mismatch for column '{col}' in test set '{test_set_label}': " 55 | f"predictions dtype {arr.dtype} != dataset dtype {dataset_array.dtype}" 56 | ) 57 | return self 58 | 59 | def to_zarr(self) -> Path: 60 | """Create a Zarr archive from the predictions dictionary. 61 | 62 | This method should be called explicitly when ready to write predictions to disk. 
63 | """ 64 | root = self.zarr_root 65 | dataset_root = self.dataset_zarr_root 66 | 67 | for test_set_label, test_set_predictions in self.predictions.items(): 68 | # Create a group for each test set 69 | test_set_group = root.require_group(test_set_label) 70 | for col in self.target_labels: 71 | data = test_set_predictions[col] 72 | template = dataset_root[col] 73 | test_set_group.array( 74 | name=col, 75 | data=data, 76 | dtype=template.dtype, 77 | compressor=template.compressor, 78 | filters=template.filters, 79 | chunks=template.chunks, 80 | overwrite=True, 81 | ) 82 | 83 | return Path(self.zarr_root_path) 84 | 85 | @property 86 | def zarr_root(self) -> zarr.Group: 87 | """Get the zarr Group object corresponding to the root, creating it if it doesn't exist.""" 88 | if self._zarr_root is None: 89 | store = zarr.DirectoryStore(self.zarr_root_path) 90 | self._zarr_root = zarr.group(store=store) 91 | return self._zarr_root 92 | 93 | @property 94 | def zarr_root_path(self) -> str: 95 | """Get the path to the Zarr archive root.""" 96 | if self._zarr_root_path is None: 97 | # Create a temporary directory if not already set 98 | if self._temp_dir is None: 99 | self._temp_dir = tempfile.mkdtemp(prefix="polaris_predictions_") 100 | self._zarr_root_path = str(Path(self._temp_dir) / "predictions.zarr") 101 | return self._zarr_root_path 102 | 103 | @property 104 | def columns(self): 105 | return list(self.zarr_root.keys()) 106 | 107 | @property 108 | def n_rows(self): 109 | cols = self.columns 110 | if not cols: 111 | raise ValueError("No columns found in predictions archive.") 112 | example = self.zarr_root[cols[0]] 113 | return len(example) 114 | 115 | @property 116 | def rows(self): 117 | return range(self.n_rows) 118 | 119 | @property 120 | def zarr_manifest_path(self): 121 | if self._zarr_manifest_path is None: 122 | # Use the temp directory as the output directory 123 | zarr_manifest_path = generate_zarr_manifest(self.zarr_root_path, self._temp_dir) 124 | 
self._zarr_manifest_path = zarr_manifest_path 125 | return self._zarr_manifest_path 126 | 127 | @property 128 | def zarr_manifest_md5sum(self): 129 | if not self.has_zarr_manifest_md5sum: 130 | logger.info("Computing the checksum. This can be slow for large predictions archives.") 131 | self.zarr_manifest_md5sum = calculate_file_md5(self.zarr_manifest_path) 132 | return self._zarr_manifest_md5sum 133 | 134 | @zarr_manifest_md5sum.setter 135 | def zarr_manifest_md5sum(self, value: str): 136 | if not re.fullmatch(r"^[a-f0-9]{32}$", value): 137 | raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.") 138 | self._zarr_manifest_md5sum = value 139 | 140 | @property 141 | def has_zarr_manifest_md5sum(self): 142 | return self._zarr_manifest_md5sum is not None 143 | 144 | def __repr__(self): 145 | return self.model_dump_json(by_alias=True, indent=2) 146 | 147 | def __str__(self): 148 | return self.__repr__() 149 | 150 | def __del__(self) -> None: 151 | if hasattr(self, "_temp_dir") and self._temp_dir and os.path.exists(self._temp_dir): 152 | shutil.rmtree(self._temp_dir) 153 | -------------------------------------------------------------------------------- /polaris/utils/zarr/codecs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastpdb import struc 3 | from numcodecs import MsgPack, register_codec 4 | from numcodecs.vlen import VLenBytes 5 | from rdkit import Chem 6 | 7 | 8 | class RDKitMolCodec(VLenBytes): 9 | """ 10 | Codec for RDKit's Molecules. 11 | 12 | Info: Binary strings for serialization 13 | This class converts the molecules to binary strings (for ML purposes, this should be lossless). 14 | This might not be the most storage efficient, but is fastest and easiest to maintain. 15 | See this [Github Discussion](https://github.com/rdkit/rdkit/discussions/7235) for more info. 
16 | 17 | """ 18 | 19 | codec_id = "rdkit_mol" 20 | 21 | def encode(self, buf: np.ndarray): 22 | """ 23 | Encode a chunk of RDKit Mols to byte strings 24 | """ 25 | # NOTE (cwognum): I ran into a Cython issue because we could pass None to the VLenBytes codec. 26 | # Using np.full() ensures all elements are initialized as empty byte strings instead. 27 | to_encode = np.full(fill_value=b"", shape=len(buf), dtype=object) 28 | for idx, mol in enumerate(buf): 29 | if mol is None or (isinstance(mol, bytes) and len(mol) == 0): 30 | continue 31 | if not isinstance(mol, Chem.Mol): 32 | raise ValueError(f"Expected an RDKitMol, but got {type(mol)} instead.") 33 | props = Chem.PropertyPickleOptions.AllProps 34 | to_encode[idx] = mol.ToBinary(props) 35 | 36 | to_encode = np.array(to_encode, dtype=object) 37 | return super().encode(to_encode) 38 | 39 | def decode(self, buf, out=None): 40 | """Decode the variable length bytes encoded data into a RDKit Mol.""" 41 | dec = super().decode(buf, out) 42 | for idx, mol in enumerate(dec): 43 | if len(mol) == 0: 44 | continue 45 | dec[idx] = Chem.Mol(mol) 46 | 47 | if out is not None: 48 | np.copyto(out, dec) 49 | return out 50 | else: 51 | return dec 52 | 53 | 54 | class AtomArrayCodec(MsgPack): 55 | """ 56 | Codec for FastPDB (i.e. Biotite) Atom Arrays. 57 | 58 | Info: Only the most essential structural information of a protein is retained 59 | This conversion saves the 3D coordinates, chain ID, residue ID, insertion code, residue name, heteroatom indicator, atom name, element, atom ID, B-factor, occupancy, and charge. 60 | Records such as CONECT (connectivity information), ANISOU (anisotropic Temperature Factors), HETATM (heteroatoms and ligands) are handled by `fastpdb`. 61 | We believe this makes for a good _ML-ready_ format, but let us know if you require any other information to be saved. 
62 | 63 | 64 | Info: PDBs as ND-arrays using `biotite` 65 | To save PDBs in a Polaris-compatible format, we convert them to ND-arrays using `fastpdb` and `biotite`. 66 | We then save these ND-arrays to Zarr archives. 67 | For more info, see [fastpdb](https://github.com/biotite-dev/fastpdb) 68 | and [biotite](https://github.com/biotite-dev/biotite/blob/main/src/biotite/structure/atoms.py) 69 | 70 | This codec is a subclass of the `MsgPack` codec from the `numcodecs` 71 | """ 72 | 73 | codec_id = "atom_array" 74 | 75 | def encode(self, buf: np.ndarray): 76 | """ 77 | Encode a chunk of AtomArrays to a plain Python structure that MsgPack can encode 78 | """ 79 | 80 | to_pack = np.empty_like(buf) 81 | 82 | for idx, atom_array in enumerate(buf): 83 | # A chunk can have missing values 84 | if atom_array is None: 85 | continue 86 | 87 | if not isinstance(atom_array, struc.AtomArray): 88 | raise ValueError(f"Expected an AtomArray, but got {type(atom_array)} instead") 89 | 90 | data = { 91 | "coord": atom_array.coord, 92 | "chain_id": atom_array.chain_id, 93 | "res_id": atom_array.res_id, 94 | "ins_code": atom_array.ins_code, 95 | "res_name": atom_array.res_name, 96 | "hetero": atom_array.hetero, 97 | "atom_name": atom_array.atom_name, 98 | "element": atom_array.element, 99 | "atom_id": atom_array.atom_id, 100 | "b_factor": atom_array.b_factor, 101 | "occupancy": atom_array.occupancy, 102 | "charge": atom_array.charge, 103 | } 104 | data = {k: v.tolist() for k, v in data.items()} 105 | to_pack[idx] = data 106 | 107 | return super().encode(to_pack) 108 | 109 | def decode(self, buf, out=None): 110 | """Decode the MsgPack decoded data into a `fastpdb` AtomArray.""" 111 | 112 | dec = super().decode(buf, out) 113 | 114 | structs = np.empty(shape=len(dec), dtype=object) 115 | 116 | for idx, data in enumerate(dec): 117 | if data is None: 118 | continue 119 | 120 | atom_array = [] 121 | array_length = len(data["coord"]) 122 | 123 | for ind in range(array_length): 124 | atom = 
struc.Atom( 125 | coord=data["coord"][ind], 126 | chain_id=data["chain_id"][ind], 127 | res_id=data["res_id"][ind], 128 | ins_code=data["ins_code"][ind], 129 | res_name=data["res_name"][ind], 130 | hetero=data["hetero"][ind], 131 | atom_name=data["atom_name"][ind], 132 | element=data["element"][ind], 133 | b_factor=data["b_factor"][ind], 134 | occupancy=data["occupancy"][ind], 135 | charge=data["charge"][ind], 136 | atom_id=data["atom_id"][ind], 137 | ) 138 | atom_array.append(atom) 139 | 140 | # Note that this is a `fastpdb` AtomArray, not a NumPy array. 141 | structs[idx] = struc.array(atom_array) 142 | 143 | if out is not None: 144 | np.copyto(out, structs) 145 | return out 146 | else: 147 | return structs 148 | 149 | 150 | register_codec(RDKitMolCodec) 151 | register_codec(AtomArrayCodec) 152 | -------------------------------------------------------------------------------- /polaris/utils/dict2html.py: -------------------------------------------------------------------------------- 1 | """ 2 | JSON 2 HTML Converter 3 | 4 | (c) Varun Malhotra 2013-2021 5 | Source Code: https://github.com/softvar/json2html 6 | 7 | Contributors: 8 | 1. Michel Müller (@muellermichel), https://github.com/softvar/json2html/pull/2 9 | 2. Daniel Lekic (@lekic), https://github.com/softvar/json2html/pull/17 10 | 11 | LICENSE: MIT 12 | 13 | - 30/06/2023: adapted from: https://github.com/softvar/json2html/blob/e1feea273b210d11e4b2e59b9778a0bb4845fbd4/json2html/jsonconv.py 14 | """ 15 | 16 | from html import escape as html_escape 17 | 18 | text = str 19 | text_types = (str,) 20 | 21 | 22 | class Dict2Html: 23 | def convert( 24 | self, 25 | data_dict: dict, 26 | table_attributes='border="1"', 27 | clubbing=True, 28 | encode=False, 29 | escape=True, 30 | ): 31 | """ 32 | Convert JSON to HTML Table format 33 | """ 34 | # table attributes such as class, id, data-attr-*, etc. 
35 | # eg: table_attributes = 'class = "table table-bordered sortable"' 36 | self.table_init_markup = "" % table_attributes 37 | self.clubbing = clubbing 38 | self.escape = escape 39 | json_input = data_dict 40 | converted = self.convert_json_node(json_input) 41 | if encode: 42 | return converted.encode("ascii", "xmlcharrefreplace") 43 | return converted 44 | 45 | def column_headers_from_list_of_dicts(self, json_input): 46 | """ 47 | This method is required to implement clubbing. 48 | It tries to come up with column headers for your input 49 | """ 50 | if not json_input or not hasattr(json_input, "__getitem__") or not hasattr(json_input[0], "keys"): 51 | return None 52 | column_headers = json_input[0].keys() 53 | for entry in json_input: 54 | if ( 55 | not hasattr(entry, "keys") 56 | or not hasattr(entry, "__iter__") 57 | or len(entry.keys()) != len(column_headers) 58 | ): 59 | return None 60 | for header in column_headers: 61 | if header not in entry: 62 | return None 63 | return column_headers 64 | 65 | def convert_json_node(self, json_input): 66 | """ 67 | Dispatch JSON input according to the outermost type and process it 68 | to generate the super awesome HTML format. 69 | We try to adhere to duck typing such that users can just pass all kinds 70 | of funky objects to json2html that *behave* like dicts and lists and other 71 | basic JSON types. 72 | """ 73 | if type(json_input) in text_types: 74 | if self.escape: 75 | return html_escape(text(json_input)) 76 | else: 77 | return text(json_input) 78 | if hasattr(json_input, "items"): 79 | return self.convert_object(json_input) 80 | if hasattr(json_input, "__iter__") and hasattr(json_input, "__getitem__"): 81 | return self.convert_list(json_input) 82 | return text(json_input) 83 | 84 | def convert_list(self, list_input): 85 | """ 86 | Iterate over the JSON list and process it 87 | to generate either an HTML table or a HTML list, depending on what's inside. 
88 | If suppose some key has array of objects and all the keys are same, 89 | instead of creating a new row for each such entry, 90 | club such values, thus it makes more sense and more readable table. 91 | 92 | @example: 93 | jsonObject = { 94 | "sampleData": [ 95 | {"a":1, "b":2, "c":3}, 96 | {"a":5, "b":6, "c":7} 97 | ] 98 | } 99 | OUTPUT: 100 | _____________________________ 101 | | | | | | 102 | | | a | c | b | 103 | | sampleData |---|---|---| 104 | | | 1 | 3 | 2 | 105 | | | 5 | 7 | 6 | 106 | ----------------------------- 107 | 108 | @contributed by: @muellermichel 109 | """ 110 | if not list_input: 111 | return "" 112 | converted_output = "" 113 | column_headers = None 114 | if self.clubbing: 115 | column_headers = self.column_headers_from_list_of_dicts(list_input) 116 | if column_headers is not None: 117 | converted_output += self.table_init_markup 118 | converted_output += "" 119 | converted_output += "" 120 | converted_output += "" 121 | converted_output += "" 122 | for list_entry in list_input: 123 | converted_output += "" 128 | converted_output += "" 129 | converted_output += "
" + "".join(column_headers) + "
" 124 | converted_output += "".join( 125 | [self.convert_json_node(list_entry[column_header]) for column_header in column_headers] 126 | ) 127 | converted_output += "
" 130 | return converted_output 131 | 132 | # so you don't want or need clubbing eh? This makes @muellermichel very sad... ;( 133 | # alright, let's fall back to a basic list here... 134 | converted_output = "" 137 | return converted_output 138 | 139 | def convert_object(self, json_input): 140 | """ 141 | Iterate over the JSON object and process it 142 | to generate the super awesome HTML Table format 143 | """ 144 | if not json_input: 145 | return "" # avoid empty tables 146 | converted_output = self.table_init_markup + "" 147 | converted_output += "".join( 148 | [ 149 | "%s%s" % (self.convert_json_node(k), self.convert_json_node(v)) 150 | for k, v in json_input.items() 151 | ] 152 | ) 153 | converted_output += "" 154 | return converted_output 155 | 156 | 157 | dict2html = Dict2Html().convert 158 | -------------------------------------------------------------------------------- /tests/test_zarr_checksum.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code in this file is based on the zarr-checksum package 3 | 4 | Mainted by Jacob Nesbitt, released under the DANDI org on Github 5 | and with Kitware, Inc. credited as the author. This code is released 6 | with the Apache 2.0 license. 7 | 8 | See also: https://github.com/dandi/zarr_checksum 9 | 10 | Instead of adding the package as a dependency, we opted to copy over the code 11 | because it is a small and self-contained module that we will want to alter to 12 | support our Polaris code base. 13 | 14 | NOTE: We have made some modifications to the original code. 15 | 16 | ---- 17 | 18 | Copyright 2023 Kitware, Inc. 19 | 20 | Licensed under the Apache License, Version 2.0 (the "License"); 21 | you may not use this file except in compliance with the License. 
22 | You may obtain a copy of the License at 23 | 24 | http://www.apache.org/licenses/LICENSE-2.0 25 | 26 | Unless required by applicable law or agreed to in writing, software 27 | distributed under the License is distributed on an "AS IS" BASIS, 28 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 29 | See the License for the specific language governing permissions and 30 | limitations under the License. 31 | """ 32 | 33 | import os 34 | import uuid 35 | from pathlib import Path 36 | from shutil import copytree, rmtree 37 | 38 | import pytest 39 | import zarr 40 | 41 | from polaris.utils.zarr._checksum import ( 42 | EMPTY_CHECKSUM, 43 | InvalidZarrChecksum, 44 | _ZarrChecksum, 45 | _ZarrChecksumManifest, 46 | _ZarrChecksumTree, 47 | _ZarrDirectoryDigest, 48 | compute_zarr_checksum, 49 | ) 50 | 51 | 52 | def test_generate_digest() -> None: 53 | manifest = _ZarrChecksumManifest( 54 | directories=[_ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1-1", name="a/b", size=1)], 55 | files=[_ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-0-1", name="b", size=1)], 56 | ) 57 | assert manifest.generate_digest().digest == "9c5294e46908cf397cb7ef53ffc12efc-1-2" 58 | 59 | 60 | def test_zarr_checksum_sort_order() -> None: 61 | # The a < b in the name should take precedence over z > y in the md5 62 | a = _ZarrChecksum(name="a", digest="z", size=3) 63 | b = _ZarrChecksum(name="b", digest="y", size=4) 64 | assert sorted([b, a]) == [a, b] 65 | 66 | 67 | def test_parse_zarr_directory_digest() -> None: 68 | # Parse valid 69 | _ZarrDirectoryDigest.parse("c228464f432c4376f0de6ddaea32650c-37481-38757151179") 70 | _ZarrDirectoryDigest.parse(None) 71 | 72 | # Ensure exception is raised 73 | with pytest.raises(InvalidZarrChecksum): 74 | _ZarrDirectoryDigest.parse("asd") 75 | with pytest.raises(InvalidZarrChecksum): 76 | _ZarrDirectoryDigest.parse("asd-0--0") 77 | 78 | 79 | def test_pop_deepest() -> None: 80 | tree = _ZarrChecksumTree() 81 | 
tree.add_leaf(Path("a/b"), size=1, digest="asd") 82 | tree.add_leaf(Path("a/b/c"), size=1, digest="asd") 83 | node = tree.pop_deepest() 84 | 85 | # Assert popped node is a/b/c, not a/b 86 | assert node.path == Path("a/b") 87 | assert len(node.checksums.files) == 1 88 | assert len(node.checksums.directories) == 0 89 | assert node.checksums.files[0].name == "c" 90 | 91 | 92 | def test_process_empty_tree() -> None: 93 | tree = _ZarrChecksumTree() 94 | assert tree.process().digest == EMPTY_CHECKSUM 95 | 96 | 97 | def test_process_tree() -> None: 98 | tree = _ZarrChecksumTree() 99 | tree.add_leaf(Path("a/b"), size=1, digest="9dd4e461268c8034f5c8564e155c67a6") 100 | tree.add_leaf(Path("c"), size=1, digest="415290769594460e2e485922904f345d") 101 | checksum = tree.process() 102 | 103 | # This zarr checksum was computed against the same file structure using the previous 104 | # zarr checksum implementation 105 | # Assert the current implementation produces a matching checksum 106 | assert checksum.digest == "e53fcb7b5c36b2f4647fbf826a44bdc9-2-2" 107 | 108 | 109 | def test_checksum_for_zarr_archive(zarr_archive, tmp_path): 110 | # NOTE: This test was not in the original code base of the zarr-checksum package. 111 | checksum, _ = compute_zarr_checksum(zarr_archive) 112 | 113 | path = str(tmp_path / "copy") 114 | copytree(zarr_archive, path) 115 | assert checksum == compute_zarr_checksum(path)[0] 116 | 117 | root = zarr.open(path) 118 | root["A"][0:10] = 0 119 | assert checksum != compute_zarr_checksum(path)[0] 120 | 121 | 122 | def test_zarr_leaf_to_checksum(zarr_archive): 123 | # NOTE: This test was not in the original code base of the zarr-checksum package. 
124 | _, leaf_to_checksum = compute_zarr_checksum(zarr_archive) 125 | root = zarr.open(zarr_archive) 126 | 127 | # Check the basic structure - Each key corresponds to a file in the zarr archive 128 | assert len(leaf_to_checksum) == len(root.store) 129 | assert all(k.path in root.store for k in leaf_to_checksum) 130 | 131 | 132 | def test_zarr_checksum_fails_for_remote_storage(zarr_archive): 133 | # NOTE: This test was not in the original code base of the zarr-checksum package. 134 | with pytest.raises(RuntimeError): 135 | compute_zarr_checksum("s3://bucket/data.zarr") 136 | with pytest.raises(RuntimeError): 137 | compute_zarr_checksum("gs://bucket/data.zarr") 138 | 139 | 140 | def test_zarr_checksum_with_path_normalization(zarr_archive): 141 | # NOTE: This test was not in the original code base of the zarr-checksum package. 142 | 143 | baseline = compute_zarr_checksum(zarr_archive)[0] 144 | rootdir = os.path.dirname(zarr_archive) 145 | 146 | # Test a relative path 147 | copytree(zarr_archive, os.path.join(rootdir, "relative", "data.zarr")) 148 | compute_zarr_checksum(f"{zarr_archive}/../relative/data.zarr")[0] == baseline 149 | 150 | # Test with variables 151 | rng_id = str(uuid.uuid4()) 152 | os.environ["TMP_TEST_DIR"] = rng_id 153 | copytree(zarr_archive, os.path.join(rootdir, "vars", rng_id)) 154 | compute_zarr_checksum(f"{rootdir}/vars/${{TMP_TEST_DIR}}")[0] == baseline # Format ${...} 155 | compute_zarr_checksum(f"{rootdir}/vars/$TMP_TEST_DIR")[0] == baseline # Format $... 
156 | 157 | # And with the user abbreviation 158 | try: 159 | path = os.path.expanduser("~/data.zarr") 160 | copytree(zarr_archive, path) 161 | compute_zarr_checksum("~/data.zarr")[0] == baseline 162 | finally: 163 | rmtree(path) 164 | -------------------------------------------------------------------------------- /polaris/benchmark/_split_v2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import cached_property 3 | from hashlib import md5 4 | from typing import Generator, Sequence 5 | 6 | from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator, model_validator 7 | from pydantic.alias_generators import to_camel 8 | from pyroaring import BitMap 9 | from typing_extensions import Self 10 | 11 | from polaris.utils.errors import InvalidBenchmarkError 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class IndexSet(BaseModel): 17 | """ 18 | A set of indices for a split, either training or test. 19 | 20 | It wraps a Roaring Bitmap object to store the indices, and provides 21 | useful properties when serializing for upload to the Hub. 
22 | """ 23 | 24 | model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel) 25 | 26 | indices: BitMap = Field(default_factory=BitMap, frozen=True, exclude=True) 27 | 28 | @field_validator("indices", mode="before") 29 | @classmethod 30 | def _validate_indices(cls, v: BitMap | Sequence[int]) -> BitMap: 31 | """ 32 | Accepts an initial sequence of ints, and turn it into a BitMap 33 | """ 34 | if isinstance(v, BitMap): 35 | return v 36 | return BitMap(v) 37 | 38 | @computed_field 39 | @cached_property 40 | def datapoints(self) -> int: 41 | return len(self.indices) 42 | 43 | @computed_field 44 | @cached_property 45 | def md5_checksum(self) -> str: 46 | return md5(self.serialize()).hexdigest() 47 | 48 | def intersect(self, other: Self) -> bool: 49 | return self.indices.intersect(other.indices) 50 | 51 | def serialize(self) -> bytes: 52 | return self.indices.serialize() 53 | 54 | @staticmethod 55 | def deserialize(index_set: bytes) -> "IndexSet": 56 | return IndexSet(indices=BitMap.deserialize(index_set)) 57 | 58 | 59 | class SplitV2(BaseModel): 60 | """ 61 | A single train-test split pair containing training and test index sets. 62 | 63 | This represents one train-test split with training and test sets. 64 | Multiple SplitV2 instances can be used together for cross-validation scenarios. 65 | """ 66 | 67 | training: IndexSet 68 | test: IndexSet 69 | 70 | @field_validator("training", "test", mode="before") 71 | @classmethod 72 | def _parse_index_set(cls, v: bytes | IndexSet) -> IndexSet: 73 | """Accept a binary serialized IndexSet""" 74 | if isinstance(v, bytes): 75 | return IndexSet.deserialize(v) 76 | return v 77 | 78 | @field_validator("training") 79 | @classmethod 80 | def _validate_training_set(cls, v: IndexSet) -> IndexSet: 81 | """Training index set can be empty (zero-shot)""" 82 | if v.datapoints == 0: 83 | logger.debug( 84 | "This train-test split only specifies a test set. 
It will return an empty train set in `get_train_test_split()`" 85 | ) 86 | return v 87 | 88 | @field_validator("test") 89 | @classmethod 90 | def _validate_test_set(cls, v: IndexSet) -> IndexSet: 91 | """Test index set cannot be empty""" 92 | if v.datapoints == 0: 93 | raise InvalidBenchmarkError("Test set cannot be empty") 94 | return v 95 | 96 | @model_validator(mode="after") 97 | def validate_set_overlap(self) -> Self: 98 | """The training and test index sets do not overlap""" 99 | if self.training.intersect(self.test): 100 | raise InvalidBenchmarkError("The predefined split specifies overlapping train and test sets") 101 | return self 102 | 103 | @property 104 | def n_train_datapoints(self) -> int: 105 | """The size of the train set.""" 106 | return self.training.datapoints 107 | 108 | @property 109 | def n_test_datapoints(self) -> int: 110 | """The size of the test set.""" 111 | return self.test.datapoints 112 | 113 | @property 114 | def max_index(self) -> int: 115 | """Maximum index across train and test sets""" 116 | max_indices = [] 117 | 118 | # Only add max if the bitmap is not empty 119 | if len(self.training.indices) > 0: 120 | max_indices.append(self.training.indices.max()) 121 | max_indices.append(self.test.indices.max()) 122 | 123 | return max(max_indices) 124 | 125 | 126 | class SplitSpecificationV2Mixin(BaseModel): 127 | """ 128 | Mixin class to add splits field to a benchmark. This is the V2 implementation. 129 | 130 | The internal representation for the splits uses roaring bitmaps, 131 | which drastically improves scalability over the V1 implementation. 132 | 133 | Attributes: 134 | splits: The predefined train-test splits to use for evaluation. 
135 | """ 136 | 137 | splits: dict[str, SplitV2] 138 | 139 | @model_validator(mode="after") 140 | def validate_splits_not_empty(self) -> Self: 141 | """Ensure at least one split is provided""" 142 | if not self.splits: 143 | raise InvalidBenchmarkError("At least one split must be specified") 144 | return self 145 | 146 | @computed_field 147 | @property 148 | def n_splits(self) -> int: 149 | """The number of splits""" 150 | return len(self.splits) 151 | 152 | @computed_field 153 | @property 154 | def split_labels(self) -> list[str]: 155 | """Labels of all splits""" 156 | return list(self.splits.keys()) 157 | 158 | @computed_field 159 | @property 160 | def n_train_datapoints(self) -> dict[str, int]: 161 | """The size of the train set for each split.""" 162 | return {label: split.n_train_datapoints for label, split in self.splits.items()} 163 | 164 | @computed_field 165 | @property 166 | def n_test_datapoints(self) -> dict[str, int]: 167 | """The size of the test set for each split.""" 168 | return {label: split.n_test_datapoints for label, split in self.splits.items()} 169 | 170 | @computed_field 171 | @property 172 | def max_index(self) -> int: 173 | """Maximum index across all splits""" 174 | return max(split.max_index for split in self.splits.values()) 175 | 176 | def split_items(self) -> Generator[tuple[str, SplitV2], None, None]: 177 | """Yield all splits with their labels""" 178 | for label, split in self.splits.items(): 179 | yield label, split 180 | -------------------------------------------------------------------------------- /tests/test_factory.py: -------------------------------------------------------------------------------- 1 | import datamol as dm 2 | import pandas as pd 3 | import pytest 4 | from fastpdb import struc 5 | from zarr.errors import ContainsArrayError 6 | 7 | from polaris.dataset import DatasetFactory, create_dataset_from_file 8 | from polaris.dataset._factory import create_dataset_from_files 9 | from polaris.dataset.converters import 
PDBConverter, SDFConverter, ZarrConverter


def _check_pdb_dataset(dataset, ground_truth):
    """Assert that each 'pdb' cell round-trips to an AtomArray equal to the ground truth."""
    assert len(dataset) == len(ground_truth)
    for i in range(dataset.table.shape[0]):
        pdb_array = dataset.get_data(row=i, col="pdb")
        assert isinstance(pdb_array, struc.AtomArray)
        # First-atom equality plus full annotation equality of the structure.
        assert pdb_array[0] == ground_truth[i][0]
        assert pdb_array.equal_annotations(ground_truth[i])


def _check_dataset(dataset, ground_truth, mol_props_as_col):
    """Assert molecules round-trip; property location depends on `mol_props_as_col`.

    When `mol_props_as_col` is True, SDF properties are lifted into their own
    dataset column and stripped from the molecule; otherwise they stay on the
    molecule object and do not appear as columns.
    """
    assert len(dataset) == len(ground_truth)

    for row in range(len(dataset)):
        mol = dataset.get_data(row=row, col="molecule")

        assert isinstance(mol, dm.Mol)

        if mol_props_as_col:
            assert not mol.HasProp("my_property")
            v = dataset.get_data(row=row, col="my_property")
            assert v == ground_truth[row].GetProp("my_property")

        else:
            assert mol.HasProp("my_property")
            assert mol.GetProp("my_property") == ground_truth[row].GetProp("my_property")
            assert "my_property" not in dataset.columns


def test_sdf_zarr_conversion(sdf_file, caffeine, tmp_path):
    """Test conversion between SDF and Zarr with utility function"""
    dataset = create_dataset_from_file(sdf_file, str(tmp_path / "archive.zarr"))
    # The utility-function path defaults to molecule properties as columns.
    _check_dataset(dataset, [caffeine], True)


@pytest.mark.parametrize("mol_props_as_col", [True, False])
def test_factory_sdf_with_prop_as_col(sdf_file, caffeine, tmp_path, mol_props_as_col):
    """Test conversion between SDF and Zarr with factory pattern"""

    factory = DatasetFactory(str(tmp_path / "archive.zarr"))

    converter = SDFConverter(mol_prop_as_cols=mol_props_as_col)
    factory.register_converter("sdf", converter)

    factory.add_from_file(sdf_file)
    dataset = factory.build()

    _check_dataset(dataset, [caffeine], mol_props_as_col)


def test_zarr_to_zarr_conversion(zarr_archive, tmp_path):
    """Test conversion between Zarr and Zarr with utility function"""
    dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "archive.zarr"))
    # NOTE(review): the expected sizes (100 rows, columns A/B of width 2048)
    # come from the `zarr_archive` fixture defined in conftest.py — not
    # visible here; confirm against the fixture if they change.
    assert len(dataset) == 100
    assert len(dataset.columns) == 2
    assert all(c in dataset.columns for c in ["A", "B"])
    assert all(dataset.annotations[c].is_pointer for c in ["A", "B"])
    assert dataset.get_data(row=0, col="A").shape == (2048,)


def test_zarr_with_factory_pattern(zarr_archive, tmp_path):
    """Test conversion between Zarr and Zarr with factory pattern"""

    factory = DatasetFactory(str(tmp_path / "archive.zarr"))
    converter = ZarrConverter()
    factory.register_converter("zarr", converter)
    factory.add_from_file(zarr_archive)

    # Add a plain column, then merge additional columns on it.
    factory.add_column(pd.Series([1, 2, 3, 4] * 25, name="C"))

    df = pd.DataFrame({"C": [1, 2, 3, 4], "D": ["W", "X", "Y", "Z"]})
    factory.add_columns(df, merge_on="C")

    dataset = factory.build()
    assert len(dataset) == 100
    assert len(dataset.columns) == 4
    assert all(c in dataset.columns for c in ["A", "B", "C", "D"])
    # Verify the C -> D mapping survived the merge row-by-row.
    assert dataset.table["C"].apply({1: "W", 2: "X", 3: "Y", 4: "Z"}.get).equals(dataset.table["D"])


def test_factory_pdb(pdbs_structs, pdb_paths, tmp_path):
    """Test conversion between PDB file and Zarr with factory pattern"""
    factory = DatasetFactory(str(tmp_path / "pdb.zarr"))

    converter = PDBConverter()
    factory.register_converter("pdb", converter)

    factory.add_from_file(pdb_paths[0])
    dataset = factory.build()

    _check_pdb_dataset(dataset, pdbs_structs[:1])


def test_factory_pdbs(pdbs_structs, pdb_paths, tmp_path):
    """Test conversion between PDB files and Zarr with factory pattern"""

    factory = DatasetFactory(str(tmp_path / "pdbs.zarr"))

    converter = PDBConverter()
    factory.register_converter("pdb", converter)

    # axis=0 appends the files as additional rows.
    factory.add_from_files(pdb_paths, axis=0)
    dataset = factory.build()

    assert dataset.table.shape[0] == len(pdb_paths)
    _check_pdb_dataset(dataset, pdbs_structs)


def test_pdbs_zarr_conversion(pdbs_structs, pdb_paths, tmp_path):
    """Test conversion between PDBs and Zarr with utility function"""

    dataset = create_dataset_from_files(pdb_paths, str(tmp_path / "pdbs_2.zarr"), axis=0)

    assert dataset.table.shape[0] == len(pdb_paths)
    _check_pdb_dataset(dataset, pdbs_structs)


def test_factory_sdfs(sdf_files, caffeine, ibuprofen, tmp_path):
    """Test conversion between SDF and Zarr with factory pattern"""

    factory = DatasetFactory(str(tmp_path / "sdfs.zarr"))

    converter = SDFConverter(mol_prop_as_cols=True)
    factory.register_converter("sdf", converter)

    factory.add_from_files(sdf_files, axis=0)
    dataset = factory.build()

    _check_dataset(dataset, [caffeine, ibuprofen], True)


def test_factory_sdf_pdb(sdf_file, pdb_paths, caffeine, pdbs_structs, tmp_path):
    """Test conversion between SDF and PDB from files to Zarr with factory pattern"""

    factory = DatasetFactory(str(tmp_path / "sdf_pdb.zarr"))

    sdf_converter = SDFConverter(mol_prop_as_cols=False)
    factory.register_converter("sdf", sdf_converter)

    pdb_converter = PDBConverter()
    factory.register_converter("pdb", pdb_converter)

    # axis=1 places the two files side-by-side as columns of one dataset.
    factory.add_from_files([sdf_file, pdb_paths[0]], axis=1)
    dataset = factory.build()

    _check_dataset(dataset, [caffeine], False)
    _check_pdb_dataset(dataset, pdbs_structs[:1])


def test_factory_from_files_same_column(sdf_files, pdb_paths, tmp_path):
    factory = DatasetFactory(str(tmp_path / "files.zarr"))

    sdf_converter = SDFConverter(mol_prop_as_cols=False)
    factory.register_converter("sdf", sdf_converter)

    pdb_converter = PDBConverter()
    factory.register_converter("pdb", pdb_converter)

    # do not allow same type of files to be appended in columns by `add_from_files`
    # in this case, user should define converter for individual columns

    # attempt to append columns by pdbs
    with pytest.raises(ValueError):
        factory.add_from_files(pdb_paths, axis=1)

    # attempt to append columns by sdfs
    # NOTE(review): the sdf path surfaces as a zarr ContainsArrayError rather
    # than ValueError — presumably because the column collision is only
    # detected when writing the archive; confirm this is intentional.
    with pytest.raises(ContainsArrayError):
        factory.add_from_files(sdf_files, axis=1)


# --------------------------------------------------------------------------
# (dump boundary) tests/test_dataset_v2.py
# --------------------------------------------------------------------------
import os
from time import perf_counter

import numcodecs
import numpy as np
import pandas as pd
import pytest
import zarr
from pydantic import ValidationError

from polaris.dataset import DatasetV2, Subset
from polaris.utils.zarr._manifest import generate_zarr_manifest


def test_dataset_v2_get_columns(test_dataset_v2):
    """Columns exposed by the V2 dataset match the Zarr arrays."""
    assert set(test_dataset_v2.columns) == {"A", "B"}


def test_dataset_v2_get_rows(test_dataset_v2):
    """Row indices span the full dataset length."""
    assert set(test_dataset_v2.rows) == set(range(100))


def test_dataset_v2_get_data(test_dataset_v2, zarr_archive):
    """`get_data` returns the same values as reading the Zarr archive directly."""
    root = zarr.open(zarr_archive, "r")
    # Spot-check 5 random rows rather than the full archive for speed.
    indices = np.random.randint(0, len(test_dataset_v2), 5)
    for idx in indices:
        assert np.array_equal(test_dataset_v2.get_data(row=idx, col="A"), root["A"][idx])
        assert np.array_equal(test_dataset_v2.get_data(row=idx, col="B"), root["B"][idx])


def test_dataset_v2_with_subset(test_dataset_v2, zarr_archive):
    """Iterating a Subset yields (input, target) pairs matching the raw arrays."""
    root = zarr.open(zarr_archive, "r")
    indices = np.random.randint(0, len(test_dataset_v2), 5)
    subset = Subset(test_dataset_v2, indices, "A", "B")
    for i, (x, y) in enumerate(subset):
        idx = indices[i]
        assert np.array_equal(x, root["A"][idx])
        assert np.array_equal(y, root["B"][idx])


def
test_dataset_v2_load_to_memory(test_dataset_v2):
    """Iterating after `load_to_memory()` should be faster than lazy Zarr access."""
    subset = Subset(
        dataset=test_dataset_v2,
        indices=range(100),
        input_cols=["A"],
        target_cols=["B"],
    )

    # Time one full pass with lazy (on-disk) access.
    t1 = perf_counter()
    for x in subset:
        pass
    d1 = perf_counter() - t1

    test_dataset_v2.load_to_memory()

    # Time the same pass again, now served from memory.
    t2 = perf_counter()
    for x in subset:
        pass
    d2 = perf_counter() - t2

    # NOTE(review): wall-clock comparison — potentially flaky on loaded CI
    # machines or warm OS page caches.
    assert d2 < d1


def test_dataset_v2_serialization(test_dataset_v2, tmp_path):
    """A dataset written with `to_json` and re-read yields identical data."""
    save_dir = str(tmp_path / "save_dir")
    path = test_dataset_v2.to_json(save_dir)
    new_dataset = DatasetV2.from_json(path)
    for i in range(5):
        assert np.array_equal(new_dataset.get_data(i, "A"), test_dataset_v2.get_data(i, "A"))
        assert np.array_equal(new_dataset.get_data(i, "B"), test_dataset_v2.get_data(i, "B"))


def test_dataset_v2_caching(test_dataset_v2, tmp_path):
    """After `cache()`, the Zarr root path points inside the cache directory."""
    cache_dir = str(tmp_path / "cache")
    test_dataset_v2._cache_dir = cache_dir
    test_dataset_v2.cache()
    assert str(test_dataset_v2.zarr_root_path).startswith(cache_dir)


def test_dataset_v1_v2_compatibility(test_dataset, tmp_path):
    # A DataFrame is ultimately a collection of labeled numpy arrays
    # We can thus also saved these same arrays to a Zarr archive
    df = test_dataset.table

    path = str(tmp_path / "data" / "v1v2.zarr")

    root = zarr.open(path, "w")
    # String columns need an object codec; numeric columns are stored as-is.
    root.array("smiles", data=df["smiles"].values, dtype=object, object_codec=numcodecs.VLenUTF8())
    root.array("iupac", data=df["iupac"].values, dtype=object, object_codec=numcodecs.VLenUTF8())
    for col in set(df.columns) - {"smiles", "iupac"}:
        root.array(col, data=df[col].values)
    zarr.consolidate_metadata(path)

    # Reuse the V1 dataset's metadata, swapping the table for the Zarr archive.
    kwargs = test_dataset.model_dump(exclude=["table", "zarr_root_path"])
    dataset = DatasetV2(**kwargs, zarr_root_path=path)

    subset_1 = Subset(dataset=test_dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])
    subset_2 = Subset(dataset=dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])

    # Both representations must serve identical datapoints.
    for idx in range(5):
        x1, y1 = subset_1[idx]
        x2, y2 = subset_2[idx]
        assert x1 == x2
        assert y1 == y2


def test_dataset_v2_subgroup_validation(zarr_archive):
    # Create a subgroup in the Zarr archive
    root = zarr.open(zarr_archive, "a")
    subgroup = root.create_group("subgroup")
    subgroup.array("data", data=np.random.random((100, 2048)))
    zarr.consolidate_metadata(zarr_archive)

    # Creating the dataset should fail due to subgroups not being supported
    with pytest.raises(
        ValidationError,
        match="The Zarr archive of a Dataset can't have any subgroups. Found \\['subgroup'\\]",
    ):
        DatasetV2(zarr_root_path=zarr_archive)


def test_dataset_v2_validation_consistent_lengths(zarr_archive):
    root = zarr.open(zarr_archive, "a")

    # Change the length of one of the arrays
    root["A"].append(np.random.random((1, 2048)))
    zarr.consolidate_metadata(zarr_archive)

    with pytest.raises(ValidationError, match="should have the same length"):
        DatasetV2(zarr_root_path=zarr_archive)

    # Make the length of the two arrays equal again
    # shouldn't crash
    root["B"].append(np.random.random((1, 2048)))
    zarr.consolidate_metadata(zarr_archive)
    DatasetV2(zarr_root_path=zarr_archive)


def test_zarr_manifest(test_dataset_v2):
    # Assert the manifest Parquet is created
    assert test_dataset_v2.zarr_manifest_path is not None
    assert os.path.isfile(test_dataset_v2.zarr_manifest_path)

    # Assert the manifest contains 204 rows (the number "204" is chosen because
    # the Zarr archive defined in `conftest.py` contains 204 unique files)
    df = pd.read_parquet(test_dataset_v2.zarr_manifest_path)
    assert len(df) == 204

    # Assert the manifest hash is calculated
    assert test_dataset_v2.zarr_manifest_md5sum is not None

    # Add array to Zarr archive to change the number of chunks in the dataset
    root = zarr.open(test_dataset_v2.zarr_root_path, "a")
    # chunks=(1, None): one chunk per row, so the new array adds 100 chunk files.
    root.array("C", data=np.random.random((100, 2048)), chunks=(1, None))

    generate_zarr_manifest(test_dataset_v2.zarr_root_path, test_dataset_v2._cache_dir)

    # Get the length of the updated manifest file
    post_change_manifest_length = len(pd.read_parquet(test_dataset_v2.zarr_manifest_path))

    # Ensure Zarr manifest has an additional 100 chunks + 1 array metadata file
    assert post_change_manifest_length == 305


def test_dataset_v2__get_item__(test_dataset_v2, zarr_archive):
    """Test the __getitem__() interface for the dataset V2."""

    # Ground truth
    root = zarr.open(zarr_archive)

    # Get a specific cell
    assert np.array_equal(test_dataset_v2[0, "A"], root["A"][0, :])

    # Get a specific row
    def _check_row_equality(d1, d2):
        assert len(d1) == len(d2)
        for k in d1:
            assert np.array_equal(d1[k], d2[k])

    _check_row_equality(test_dataset_v2[0], {"A": root["A"][0, :], "B": root["B"][0, :]})
    _check_row_equality(test_dataset_v2[10], {"A": root["A"][10, :], "B": root["B"][10, :]})


# --------------------------------------------------------------------------
# (dump boundary) polaris/utils/types.py
# --------------------------------------------------------------------------
from enum import Enum
from typing import Annotated, Any, Literal, Optional

import numpy as np
from pydantic import (
    AnyUrl,
    BaseModel,
    BeforeValidator,
    ConfigDict,
    HttpUrl,
    StringConstraints,
    TypeAdapter,
)
from pydantic.alias_generators import to_camel
from typing_extensions import Self, TypeAlias

SplitIndicesType: TypeAlias = list[int]
"""
A
split is defined by a sequence of integers.
"""

SplitType: TypeAlias = tuple[SplitIndicesType, SplitIndicesType | dict[str, SplitIndicesType]]
"""
A split is a pair of which the first item is always assumed to be the train set.
The second item can either be a single test set or a dictionary with multiple, named test sets.
"""

ListOrArrayType: TypeAlias = list | np.ndarray
"""
A list of numbers or a numpy array. Predictions can be provided as either a list or a numpy array.
"""

IncomingPredictionsType: TypeAlias = ListOrArrayType | dict[str, ListOrArrayType | dict[str, ListOrArrayType]]
"""
The type of the predictions that are ingested into the Polaris BenchmarkPredictions object. Can be one
of the following:

- A single array (single-task, single test set)
- A dictionary of arrays (single-task, multiple test sets)
- A dictionary of dictionaries of arrays (multi-task, multiple test sets)
"""

PredictionsType: TypeAlias = dict[str, dict[str, np.ndarray]]
"""
The normalized format for predictions for internal use. Predictions are accepted in a generous
variety of representations and normalized into this standard format, a dictionary of dictionaries
that looks like {"test_set_name": {"target_name": np.ndarray}}.
"""

DatapointPartType = Any | tuple[Any] | dict[str, Any]
DatapointType: TypeAlias = tuple[DatapointPartType, DatapointPartType]
"""
A datapoint has:

- A single input or multiple inputs (either as dict or tuple)
- No target, a single target or a multiple targets (either as dict or tuple)
"""

SlugStringType: TypeAlias = Annotated[
    str, StringConstraints(pattern="^[a-z0-9-]+$", min_length=4, max_length=64)
]
"""
A URL-compatible string that can serve as slug on the Hub.
"""

SlugCompatibleStringType: TypeAlias = Annotated[
    str, StringConstraints(pattern="^[A-Za-z0-9_-]+$", min_length=4, max_length=64)
]
"""
A URL-compatible string that can be turned into a slug by the Hub.

Can only use alpha-numeric characters, underscores and dashes.
The string must be at least 4 and at most 64 characters long.
"""

Md5StringType: TypeAlias = Annotated[str, StringConstraints(pattern=r"^[a-f0-9]{32}$")]
"""
A string that represents an MD5 hash.
"""

HubUser: TypeAlias = SlugCompatibleStringType
"""
A user on the Polaris Hub is identified by a username,
which is a [`SlugCompatibleStringType`][polaris.utils.types.SlugCompatibleStringType].
"""

# `validate_python(v) and v` raises on an invalid URL and otherwise evaluates
# to the original string `v`, so the annotated type stays a plain `str`.
HttpUrlAdapter = TypeAdapter(HttpUrl)
HttpUrlString: TypeAlias = Annotated[str, BeforeValidator(lambda v: HttpUrlAdapter.validate_python(v) and v)]
"""
A validated HTTP URL that will be turned into a string.
This is useful for interactions with httpx and authlib, who have their own URL types.
"""

AnyUrlAdapter = TypeAdapter(AnyUrl)
AnyUrlString: TypeAlias = Annotated[str, BeforeValidator(lambda v: AnyUrlAdapter.validate_python(v) and v)]
"""
A validated generic URL that will be turned into a string.
This is useful for interactions with other libraries that expect a string.
"""

DirectionType: TypeAlias = float | Literal["min", "max"]
"""
The direction of any variable to be sorted.
This can be used to sort the metric score, indicate the optimization direction of endpoint.
"""

TimeoutTypes = tuple[int, int] | Literal["timeout", "never"]
"""
Timeout types for specifying maximum wait times.
"""

IOMode: TypeAlias = Literal["r", "r+", "a", "w", "w-"]
"""
Type to specify the mode for input/output operations (I/O) when interacting with a file or resource.
"""

SupportedLicenseType: TypeAlias = Literal[
    "CC-BY-4.0", "CC-BY-SA-4.0", "CC-BY-NC-4.0", "CC-BY-NC-SA-4.0", "CC0-1.0", "MIT"
]
"""
Supported license types for dataset uploads to Polaris Hub
"""

ZarrConflictResolution: TypeAlias = Literal["raise", "replace", "skip"]
"""
Type to specify which action to take when encountering existing files within a Zarr archive.
"""

ChecksumStrategy: TypeAlias = Literal["verify", "verify_unless_zarr", "ignore"]
"""
Type to specify which action to take to verify the data integrity of an artifact through a checksum.
"""

ArtifactUrn: TypeAlias = Annotated[str, StringConstraints(pattern=r"^urn:polaris:\w+:\w+:\w+$")]
"""
A Uniform Resource Name (URN) for an artifact on the Polaris Hub.
"""

RowIndex: TypeAlias = int | str
ColumnIndex: TypeAlias = str
DatasetIndex: TypeAlias = RowIndex | tuple[RowIndex, ColumnIndex]
"""
To index a dataset using square brackets, we have a few options:

- A single row, e.g. dataset[0]
- Specify a specific value, e.g. dataset[0, "col1"]

There are more exciting options we could implement, such as slicing,
but this gets complex.
"""


PredictionKwargs: TypeAlias = Literal["y_pred", "y_prob", "y_score"]
"""
The type of predictions expected by the metric interface.
"""

ColumnName: TypeAlias = str
"""A column name in a dataset."""


class HubOwner(BaseModel):
    """An owner of an artifact on the Polaris Hub

    The slug is most important as it is the user-facing part of this data model.
    The externalId and type are added to be consistent with the model returned by the Polaris Hub.
    """

    slug: SlugStringType
    external_id: Optional[str] = None
    type: Optional[Literal["user", "organization"]] = None

    # Pydantic config: accept camelCase aliases (externalId) from the Hub API
    # while still allowing snake_case field names in Python.
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

    def __str__(self):
        return self.slug

    @staticmethod
    def normalize(owner: str | Self) -> Self:
        """
        Normalize a string or `HubOwner` instance to a `HubOwner` instance.
        """
        return owner if isinstance(owner, HubOwner) else HubOwner(slug=owner)


class TargetType(Enum):
    """The high-level classification of different targets."""

    REGRESSION = "regression"
    CLASSIFICATION = "classification"
    DOCKING = "docking"


class TaskType(Enum):
    """The high-level classification of different tasks."""

    MULTI_TASK = "multi_task"
    SINGLE_TASK = "single_task"