├── .circleci └── config.yml ├── .gitignore ├── .isort.cfg ├── LICENSE ├── README.md ├── ocpapi ├── __init__.py ├── client │ ├── __init__.py │ ├── client.py │ ├── models.py │ └── ui.py ├── version.py └── workflows │ ├── __init__.py │ ├── adsorbates.py │ ├── context.py │ ├── filter.py │ ├── log.py │ └── retry.py ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── integration ├── __init__.py ├── client │ ├── __init__.py │ ├── test_client.py │ └── test_ui.py └── workflows │ ├── __init__.py │ └── test_adsorbates.py └── unit ├── __init__.py ├── client ├── __init__.py ├── test_client.py ├── test_models.py └── test_ui.py └── workflows ├── __init__.py ├── test_adsorbates.py ├── test_context.py ├── test_filter.py └── test_retry.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | codecov: codecov/codecov@3.2.4 5 | 6 | workflows: 7 | analyze-and-test: 8 | jobs: 9 | - lint 10 | - test 11 | 12 | jobs: 13 | lint: 14 | docker: 15 | - image: cimg/python:3.9.13 16 | steps: 17 | - checkout 18 | - run: 19 | name: Setup 20 | command: pip install black==23.9.1 isort==5.12.0 autoflake==2.2.1 21 | - run: 22 | name: Format 23 | command: black . --check 24 | - run: 25 | name: Import order 26 | command: isort . --check-only 27 | - run: 28 | name: Unused imports 29 | command: autoflake . 
--remove-all-unused-imports --check 30 | test: 31 | docker: 32 | - image: cimg/python:3.9.13 33 | steps: 34 | - checkout 35 | - run: 36 | name: Setup 37 | command: | 38 | pip install --upgrade setuptools 39 | pip install pytest-cov 40 | pip install .[dev] 41 | - run: 42 | name: Test 43 | command: pytest --cov --cov-report=xml 44 | - codecov/upload 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output=3 3 | include_trailing_comma=True 4 | force_grid_wrap=0 5 | use_parentheses=True 6 | line_length=79 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Open-Catalyst-Project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REPO DEPRECATION: In 2024 the FAIR Chemistry team consolidated its various resources to improve release/testing/availability, and [the OCP API moved to the fairchem repo](https://github.com/FAIR-Chem/fairchem/tree/main/src/fairchem/demo/ocpapi). You can find updated documentation on how to use the consolidated resources at https://fair-chem.github.io/ 2 | 3 | # ocpapi 4 | 5 | [![CircleCI](https://dl.circleci.com/status-badge/img/gh/Open-Catalyst-Project/ocpapi/tree/main.svg?style=shield)](https://dl.circleci.com/status-badge/redirect/gh/Open-Catalyst-Project/ocpapi/tree/main) [![codecov](https://codecov.io/gh/Open-Catalyst-Project/ocpapi/graph/badge.svg?token=66Z7Y7QUUW)](https://codecov.io/gh/Open-Catalyst-Project/ocpapi) 6 | 7 | Python library for programmatic use of the [Open Catalyst Demo](https://open-catalyst.metademolab.com/). Users unfamiliar with the Open Catalyst Demo are encouraged to read more about it before continuing. 8 | 9 | ## Installation 10 | 11 | Ensure you have Python 3.9.1 or newer, and install `ocpapi` using: 12 | 13 | ```sh 14 | pip install ocpapi 15 | ``` 16 | 17 | ## Quickstart 18 | 19 | The following examples are used to search for *OH binding sites on Pt surfaces. They use the `find_adsorbate_binding_sites` function, which is a high-level workflow on top of other methods included in this library. Once familiar with this routine, users are encouraged to learn about lower-level methods and features that support more advanced use cases. 20 | 21 | ### Note about async methods 22 | 23 | This package relies heavily on [asyncio](https://docs.python.org/3/library/asyncio.html). 
The examples throughout this document can be copied to a python repl launched with: 24 | ```sh 25 | $ python -m asyncio 26 | ``` 27 | Alternatively, an async function can be run in a script by wrapping it with [asyncio.run()](https://docs.python.org/3/library/asyncio-runner.html#asyncio.run): 28 | ```python 29 | import asyncio 30 | from ocpapi import find_adsorbate_binding_sites 31 | 32 | asyncio.run(find_adsorbate_binding_sites(...)) 33 | ``` 34 | 35 | ### Search over all surfaces 36 | 37 | ```python 38 | from ocpapi import find_adsorbate_binding_sites 39 | 40 | results = await find_adsorbate_binding_sites( 41 | adsorbate="*OH", 42 | bulk="mp-126", 43 | ) 44 | ``` 45 | 46 | Users will be prompted to select one or more surfaces that should be relaxed. 47 | 48 | Input to this function includes: 49 | 50 | * The name of the adsorbate to place 51 | * A unique ID of the bulk structure from which surfaces will be generated 52 | 53 | This function will perform the following steps: 54 | 55 | 1. Enumerate surfaces of the bulk material 56 | 2. On each surface, enumerate initial guesses for adsorbate binding sites 57 | 3. Run local force-based relaxations of each adsorbate placement 58 | 59 | In addition, this handles: 60 | 61 | * Retrying failed calls to the Open Catalyst Demo API 62 | * Retrying submission of relaxations when they are rate limited 63 | 64 | This should take 2-10 minutes to finish while tens to hundreds (depending on the number of surfaces that are selected) of individual adsorbate placements are relaxed on unique surfaces of Pt.
The returned `AdsorbateBindingSites` object holds one entry per searched surface in `results.slabs`. Among other details, these entries include: 65 | 66 | * Information about the surface being searched, including its structure and Miller indices 67 | * The initial positions of the adsorbate before relaxation 68 | * The final structure after relaxation 69 | * The predicted energy of the final structure 70 | * The predicted force on each atom in the final structure 71 | 72 | 73 | ### Supported bulks and adsorbates 74 | 75 | A finite set of bulk materials and adsorbates can be referenced by ID throughout the OCP API. The lists of supported values can be viewed in two ways. 76 | 77 | 1. Visit the UI at https://open-catalyst.metademolab.com/demo and explore the lists in Step 1 and Step 3. 78 | 2. Use the low-level client that ships with this library: 79 | ```python 80 | from ocpapi import Client 81 | 82 | client = Client() 83 | 84 | bulks = await client.get_bulks() 85 | print({b.src_id: b.formula for b in bulks.bulks_supported}) 86 | 87 | adsorbates = await client.get_adsorbates() 88 | print(adsorbates.adsorbates_supported) 89 | ``` 90 | 91 | 92 | ### Persisting results 93 | 94 | **Results should be saved whenever possible in order to avoid expensive recomputation.** 95 | 96 | Assuming `results` was generated with the `find_adsorbate_binding_sites` method used above, it is an `AdsorbateBindingSites` object. This can be saved to file with: 97 | 98 | ```python 99 | with open("results.json", "w") as f: 100 | f.write(results.to_json()) 101 | ``` 102 | 103 | Similarly, results can be read back from file to an `AdsorbateBindingSites` object with: 104 | 105 | ```python 106 | from ocpapi import AdsorbateBindingSites 107 | 108 | with open("results.json", "r") as f: 109 | results = AdsorbateBindingSites.from_json(f.read()) 110 | ``` 111 | 112 | ### Viewing results in the web UI 113 | 114 | Relaxation results can be viewed in a web UI.
For example, https://open-catalyst.metademolab.com/results/7eaa0d63-83aa-473f-ac84-423ffd0c67f5 shows the results of relaxing *OH on a Pt (1,1,1) surface; the uuid, "7eaa0d63-83aa-473f-ac84-423ffd0c67f5", is referred to as the `system_id`. 115 | 116 | Extending the examples above, the URLs to visualize the results of relaxations on each Pt surface can be obtained with: 117 | ```python 118 | urls = [ 119 | slab.ui_url 120 | for slab in results.slabs 121 | ] 122 | ``` 123 | 124 | ## Advanced usage 125 | 126 | ### Changing the model type 127 | 128 | The API currently supports two models: 129 | * `equiformer_v2_31M_s2ef_all_md` (default): https://arxiv.org/abs/2306.12059 130 | * `gemnet_oc_base_s2ef_all_md`: https://arxiv.org/abs/2204.02782 131 | 132 | A specific model type can be requested with: 133 | ```python 134 | from ocpapi import find_adsorbate_binding_sites 135 | 136 | results = await find_adsorbate_binding_sites( 137 | adsorbate="*OH", 138 | bulk="mp-126", 139 | model="gemnet_oc_base_s2ef_all_md", 140 | ) 141 | ``` 142 | 143 | ### Skip relaxation approval prompts 144 | 145 | Calls to `find_adsorbate_binding_sites()` will, by default, show the user all pending relaxations and ask for approval before they are submitted. In order to run the relaxations automatically without manual approval, `adslab_filter` can be set to a function that automatically approves any or all adsorbate/slab (adslab) configurations. 
146 | 147 | Run relaxations for all slabs that are generated: 148 | ```python 149 | from ocpapi import find_adsorbate_binding_sites, keep_all_slabs 150 | 151 | results = await find_adsorbate_binding_sites( 152 | adsorbate="*OH", 153 | bulk="mp-126", 154 | adslab_filter=keep_all_slabs(), 155 | ) 156 | ``` 157 | 158 | Run relaxations only for slabs with Miller Indices in the input set: 159 | ```python 160 | from ocpapi import find_adsorbate_binding_sites, keep_slabs_with_miller_indices 161 | 162 | results = await find_adsorbate_binding_sites( 163 | adsorbate="*OH", 164 | bulk="mp-126", 165 | adslab_filter=keep_slabs_with_miller_indices([(1, 0, 0), (1, 1, 1)]), 166 | ) 167 | ``` 168 | 169 | ### Converting to [ase.Atoms](https://wiki.fysik.dtu.dk/ase/ase/atoms.html) objects 170 | 171 | **Important! The `to_ase_atoms()` method described below will fail with an import error if [ase](https://wiki.fysik.dtu.dk/ase) is not installed.** 172 | 173 | Two classes have support for generating [ase.Atoms](https://wiki.fysik.dtu.dk/ase/ase/atoms.html) objects: 174 | * `ocpapi.Atoms.to_ase_atoms()`: Adds unit cell, atomic positions, and other structural information to the returned `ase.Atoms` object. 175 | * `ocpapi.AdsorbateSlabRelaxationResult.to_ase_atoms()`: Adds the same structure information to the `ase.Atoms` object. Also adds the predicted forces and energy of the relaxed structure, which can be accessed with the `ase.Atoms.get_potential_energy()` and `ase.Atoms.get_forces()` methods. 
176 | 177 | For example, the following would generate an `ase.Atoms` object for the first relaxed adsorbate configuration on the first slab generated for *OH binding on Pt: 178 | ```python 179 | from ocpapi import find_adsorbate_binding_sites 180 | 181 | results = await find_adsorbate_binding_sites( 182 | adsorbate="*OH", 183 | bulk="mp-126", 184 | ) 185 | 186 | ase_atoms = results.slabs[0].configs[0].to_ase_atoms() 187 | ``` 188 | 189 | ### Converting to other structure formats 190 | 191 | From an `ase.Atoms` object (see previous section), it is possible to [write to other structure formats](https://wiki.fysik.dtu.dk/ase/ase/io/io.html#ase.io.write). Extending the example above, the `ase_atoms` object could be written to a [VASP POSCAR file](https://www.vasp.at/wiki/index.php/POSCAR) with: 192 | ```python 193 | from ase.io import write 194 | 195 | write("POSCAR", ase_atoms, "vasp") 196 | ``` 197 | 198 | ## License 199 | 200 | `ocpapi` is released under the [MIT License](LICENSE). 201 | 202 | ## Citing `ocpapi` 203 | 204 | If you use `ocpapi` in your research, please consider citing the [AdsorbML paper](https://www.nature.com/articles/s41524-023-01121-5) (in addition to the relevant datasets / models used): 205 | 206 | ```bibtex 207 | @article{lan2023adsorbml, 208 | title={{AdsorbML}: a leap in efficiency for adsorption energy calculations using generalizable machine learning potentials}, 209 | author={Lan*, Janice and Palizhati*, Aini and Shuaibi*, Muhammed and Wood*, Brandon M and Wander, Brook and Das, Abhishek and Uyttendaele, Matt and Zitnick, C Lawrence and Ulissi, Zachary W}, 210 | journal={npj Computational Materials}, 211 | year={2023}, 212 | } 213 | ``` 214 | -------------------------------------------------------------------------------- /ocpapi/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import * # noqa 2 | from .workflows import * # noqa 3 |
-------------------------------------------------------------------------------- /ocpapi/client/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import ( # noqa 2 | Client, 3 | NonRetryableRequestException, 4 | RateLimitExceededException, 5 | RequestException, 6 | ) 7 | from .models import ( # noqa 8 | Adsorbates, 9 | AdsorbateSlabConfigs, 10 | AdsorbateSlabRelaxationResult, 11 | AdsorbateSlabRelaxationsRequest, 12 | AdsorbateSlabRelaxationsResults, 13 | AdsorbateSlabRelaxationsSystem, 14 | Atoms, 15 | Bulk, 16 | Bulks, 17 | Model, 18 | Models, 19 | Slab, 20 | SlabMetadata, 21 | Slabs, 22 | Status, 23 | ) 24 | from .ui import get_results_ui_url # noqa 25 | -------------------------------------------------------------------------------- /ocpapi/client/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from datetime import timedelta 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | import requests 7 | 8 | from .models import ( 9 | Adsorbates, 10 | AdsorbateSlabConfigs, 11 | AdsorbateSlabRelaxationsRequest, 12 | AdsorbateSlabRelaxationsResults, 13 | AdsorbateSlabRelaxationsSystem, 14 | Atoms, 15 | Bulk, 16 | Bulks, 17 | Models, 18 | Slab, 19 | Slabs, 20 | ) 21 | 22 | 23 | class RequestException(Exception): 24 | """ 25 | Exception raised any time there is an error while making an API call. 26 | """ 27 | 28 | def __init__(self, method: str, url: str, cause: str) -> None: 29 | """ 30 | Args: 31 | method: The type of the method being run (POST, GET, etc.). 32 | url: The full URL that was called. 33 | cause: A description of the failure. 34 | """ 35 | super().__init__(f"Request to {method} {url} failed. {cause}") 36 | 37 | 38 | class NonRetryableRequestException(RequestException): 39 | """ 40 | Exception raised when an API call is rejected for a reason that will 41 | not succeed on retry. 
For example, this might include a malformed request 42 | or action that is not allowed. 43 | """ 44 | 45 | def __init__(self, method: str, url: str, cause: str) -> None: 46 | """ 47 | Args: 48 | method: The type of the method being run (POST, GET, etc.). 49 | url: The full URL that was called. 50 | cause: A description of the failure. 51 | """ 52 | super().__init__(method=method, url=url, cause=cause) 53 | 54 | 55 | class RateLimitExceededException(RequestException): 56 | """ 57 | Exception raised when an API call is rejected because a rate limit has 58 | been exceeded. 59 | 60 | Attributes: 61 | retry_after: If known, the time to wait before the next attempt to 62 | call the API should be made. 63 | """ 64 | 65 | def __init__( 66 | self, 67 | method: str, 68 | url: str, 69 | retry_after: Optional[timedelta] = None, 70 | ) -> None: 71 | """ 72 | Args: 73 | method: The type of the method being run (POST, GET, etc.). 74 | url: The full URL that was called. 75 | retry_after: If known, the time to wait before the next attempt 76 | to call the API should be made. 77 | """ 78 | super().__init__(method=method, url=url, cause="Exceeded rate limit") 79 | self.retry_after: Optional[timedelta] = retry_after 80 | 81 | 82 | class Client: 83 | """ 84 | Exposes each route in the OCP API as a method. 85 | """ 86 | 87 | def __init__( 88 | self, 89 | host: str = "open-catalyst-api.metademolab.com", 90 | scheme: str = "https", 91 | ) -> None: 92 | """ 93 | Args: 94 | host: The host that will be called. 95 | scheme: The scheme used when making API calls. 96 | """ 97 | self._host = host 98 | self._base_url = f"{scheme}://{host}" 99 | 100 | @property 101 | def host(self) -> str: 102 | """ 103 | The host being called by this client. 104 | """ 105 | return self._host 106 | 107 | async def get_models(self) -> Models: 108 | """ 109 | Fetch the list of models that are supported in the API. 
110 | 111 | Raises: 112 | RateLimitExceededException: If the call was rejected because a 113 | server side rate limit was breached. 114 | NonRetryableRequestException: If the call was rejected and a retry 115 | is not expected to succeed. 116 | RequestException: For all other errors when making the request; it 117 | is possible, though not guaranteed, that a retry could succeed. 118 | 119 | Returns: 120 | The models that are supported in the API. 121 | """ 122 | response: str = await self._run_request( 123 | path="ocp/models", 124 | method="GET", 125 | ) 126 | return Models.from_json(response) 127 | 128 | async def get_bulks(self) -> Bulks: 129 | """ 130 | Fetch the list of bulk materials that are supported in the API. 131 | 132 | Raises: 133 | RateLimitExceededException: If the call was rejected because a 134 | server side rate limit was breached. 135 | NonRetryableRequestException: If the call was rejected and a retry 136 | is not expected to succeed. 137 | RequestException: For all other errors when making the request; it 138 | is possible, though not guaranteed, that a retry could succeed. 139 | 140 | Returns: 141 | The bulks that are supported throughout the API. 142 | """ 143 | response: str = await self._run_request( 144 | path="ocp/bulks", 145 | method="GET", 146 | ) 147 | return Bulks.from_json(response) 148 | 149 | async def get_adsorbates(self) -> Adsorbates: 150 | """ 151 | Fetch the list of adsorbates that are supported in the API. 152 | 153 | Raises: 154 | RateLimitExceededException: If the call was rejected because a 155 | server side rate limit was breached. 156 | NonRetryableRequestException: If the call was rejected and a retry 157 | is not expected to succeed. 158 | RequestException: For all other errors when making the request; it 159 | is possible, though not guaranteed, that a retry could succeed. 160 | 161 | Returns: 162 | The adsorbates that are supported throughout the API. 
163 | """ 164 | response: str = await self._run_request( 165 | path="ocp/adsorbates", 166 | method="GET", 167 | ) 168 | return Adsorbates.from_json(response) 169 | 170 | async def get_slabs(self, bulk: Union[str, Bulk]) -> Slabs: 171 | """ 172 | Get a unique list of slabs for the input bulk structure. 173 | 174 | Args: 175 | bulk: If a string, the id of the bulk to use. Otherwise the Bulk 176 | instance to use. 177 | 178 | Raises: 179 | RateLimitExceededException: If the call was rejected because a 180 | server side rate limit was breached. 181 | NonRetryableRequestException: If the call was rejected and a retry 182 | is not expected to succeed. 183 | RequestException: For all other errors when making the request; it 184 | is possible, though not guaranteed, that a retry could succeed. 185 | 186 | Returns: 187 | Slabs for each of the unique surfaces of the material. 188 | """ 189 | response: str = await self._run_request( 190 | path="ocp/slabs", 191 | method="POST", 192 | data=json.dumps( 193 | {"bulk_src_id": bulk.src_id if isinstance(bulk, Bulk) else bulk} 194 | ), 195 | headers={"Content-Type": "application/json"}, 196 | ) 197 | return Slabs.from_json(response) 198 | 199 | async def get_adsorbate_slab_configs( 200 | self, adsorbate: str, slab: Slab 201 | ) -> AdsorbateSlabConfigs: 202 | """ 203 | Get a list of possible binding sites for the input adsorbate on the 204 | input slab. 205 | 206 | Args: 207 | adsorbate: Description of the the adsorbate to place. 208 | slab: Information about the slab on which the adsorbate should 209 | be placed. 210 | 211 | Raises: 212 | RateLimitExceededException: If the call was rejected because a 213 | server side rate limit was breached. 214 | NonRetryableRequestException: If the call was rejected and a retry 215 | is not expected to succeed. 216 | RequestException: For all other errors when making the request; it 217 | is possible, though not guaranteed, that a retry could succeed. 
218 | 219 | Returns: 220 | Configurations for each adsorbate binding site on the slab. 221 | """ 222 | response: str = await self._run_request( 223 | path="ocp/adsorbate-slab-configs", 224 | method="POST", 225 | data=json.dumps( 226 | { 227 | "adsorbate": adsorbate, 228 | "slab": slab.to_dict(), 229 | } 230 | ), 231 | headers={"Content-Type": "application/json"}, 232 | ) 233 | return AdsorbateSlabConfigs.from_json(response) 234 | 235 | async def submit_adsorbate_slab_relaxations( 236 | self, 237 | adsorbate: str, 238 | adsorbate_configs: List[Atoms], 239 | bulk: Bulk, 240 | slab: Slab, 241 | model: str, 242 | ephemeral: bool = False, 243 | ) -> AdsorbateSlabRelaxationsSystem: 244 | """ 245 | Starts relaxations of the input adsorbate configurations on the input 246 | slab using energies and forces returned by the input model. Relaxations 247 | are run asynchronously and results can be fetched using the system id 248 | that is returned from this method. 249 | 250 | Args: 251 | adsorbate: Description of the adsorbate being simulated. 252 | adsorbate_configs: List of adsorbate configurations to relax. This 253 | should only include the adsorbates themselves; the surface is 254 | defined in the "slab" field that is a peer to this one. 255 | bulk: Details of the bulk material being simulated. 256 | slab: The structure of the slab on which adsorbates are placed. 257 | model: The model that will be used to evaluate energies and forces 258 | during relaxations. 259 | ephemeral: If False (default), any later attempt to delete the 260 | generated relaxations will be rejected. If True, deleting the 261 | relaxations will be allowed, which is generally useful for 262 | testing when there is no reason for results to be persisted. 263 | 264 | Raises: 265 | RateLimitExceededException: If the call was rejected because a 266 | server side rate limit was breached. 267 | NonRetryableRequestException: If the call was rejected and a retry 268 | is not expected to succeed. 
269 | RequestException: For all other errors when making the request; it 270 | is possible, though not guaranteed, that a retry could succeed. 271 | 272 | Returns: 273 | IDs of the relaxations. 274 | """ 275 | response: str = await self._run_request( 276 | path="ocp/adsorbate-slab-relaxations", 277 | method="POST", 278 | data=json.dumps( 279 | { 280 | "adsorbate": adsorbate, 281 | "adsorbate_configs": [a.to_dict() for a in adsorbate_configs], 282 | "bulk": bulk.to_dict(), 283 | "slab": slab.to_dict(), 284 | "model": model, 285 | "ephemeral": ephemeral, 286 | } 287 | ), 288 | headers={"Content-Type": "application/json"}, 289 | ) 290 | return AdsorbateSlabRelaxationsSystem.from_json(response) 291 | 292 | async def get_adsorbate_slab_relaxations_request( 293 | self, system_id: str 294 | ) -> AdsorbateSlabRelaxationsRequest: 295 | """ 296 | Fetches the original relaxations request for the input system. 297 | 298 | Args: 299 | system_id: The ID of the system to fetch. 300 | 301 | Raises: 302 | RateLimitExceededException: If the call was rejected because a 303 | server side rate limit was breached. 304 | NonRetryableRequestException: If the call was rejected and a retry 305 | is not expected to succeed. 306 | RequestException: For all other errors when making the request; it 307 | is possible, though not guaranteed, that a retry could succeed. 308 | 309 | Returns: 310 | The original request that was made when submitting relaxations. 311 | """ 312 | response: str = await self._run_request( 313 | path=f"ocp/adsorbate-slab-relaxations/{system_id}", 314 | method="GET", 315 | ) 316 | return AdsorbateSlabRelaxationsRequest.from_json(response) 317 | 318 | async def get_adsorbate_slab_relaxations_results( 319 | self, 320 | system_id: str, 321 | config_ids: Optional[List[int]] = None, 322 | fields: Optional[List[str]] = None, 323 | ) -> AdsorbateSlabRelaxationsResults: 324 | """ 325 | Fetches relaxation results for the input system. 
326 | 327 | Args: 328 | system_id: The system id of the relaxations. 329 | config_ids: If defined and not empty, a subset of configurations 330 | to fetch. Otherwise all configurations are returned. 331 | fields: If defined and not empty, a subset of fields in each 332 | configuration to fetch. Otherwise all fields are returned. 333 | 334 | Raises: 335 | RateLimitExceededException: If the call was rejected because a 336 | server side rate limit was breached. 337 | NonRetryableRequestException: If the call was rejected and a retry 338 | is not expected to succeed. 339 | RequestException: For all other errors when making the request; it 340 | is possible, though not guaranteed, that a retry could succeed. 341 | 342 | Returns: 343 | The relaxation results for each configuration in the system. 344 | """ 345 | params: Dict[str, Any] = {} 346 | if fields: 347 | params["field"] = fields 348 | if config_ids: 349 | params["config_id"] = config_ids 350 | response: str = await self._run_request( 351 | path=f"ocp/adsorbate-slab-relaxations/{system_id}/configs", 352 | method="GET", 353 | params=params, 354 | ) 355 | return AdsorbateSlabRelaxationsResults.from_json(response) 356 | 357 | async def delete_adsorbate_slab_relaxations(self, system_id: str) -> None: 358 | """ 359 | Deletes all relaxation results for the input system. 360 | 361 | Args: 362 | system_id: The ID of the system to delete. 363 | 364 | Raises: 365 | RateLimitExceededException: If the call was rejected because a 366 | server side rate limit was breached. 367 | NonRetryableRequestException: If the call was rejected and a retry 368 | is not expected to succeed. 369 | RequestException: For all other errors when making the request; it 370 | is possible, though not guaranteed, that a retry could succeed. 
371 | """ 372 | await self._run_request( 373 | path=f"ocp/adsorbate-slab-relaxations/{system_id}", 374 | method="DELETE", 375 | ) 376 | 377 | async def _run_request(self, path: str, method: str, **kwargs) -> str: 378 | """ 379 | Helper method that runs the input request on a thread so that 380 | it doesn't block the event loop on the calling thread. 381 | 382 | Args: 383 | path: The URL path to make the request against. 384 | method: The HTTP method to use (GET, POST, etc.). 385 | 386 | Raises: 387 | RateLimitExceededException: If the call was rejected because a 388 | server side rate limit was breached. 389 | NonRetryableRequestException: If the call was rejected and a retry 390 | is not expected to succeed. 391 | RequestException: For all other errors when making the request; it 392 | is possible, though not guaranteed, that a retry could succeed. 393 | 394 | Returns: 395 | The response body from the request as a string. 396 | """ 397 | 398 | # Make the request 399 | url = f"{self._base_url}/{path}" 400 | try: 401 | response: requests.Response = await asyncio.to_thread( 402 | requests.request, 403 | method=method, 404 | url=url, 405 | **kwargs, 406 | ) 407 | except Exception as e: 408 | raise RequestException( 409 | method=method, 410 | url=url, 411 | cause=f"Exception while making request: {type(e).__name__}: {e}", 412 | ) from e 413 | 414 | # Check the response code 415 | if response.status_code >= 300: 416 | # Exceeded server side rate limit 417 | if response.status_code == 429: 418 | retry_after: Optional[str] = response.headers.get("Retry-After", None) 419 | raise RateLimitExceededException( 420 | method=method, 421 | url=url, 422 | retry_after=timedelta(seconds=float(retry_after)) 423 | if retry_after is not None 424 | else None, 425 | ) 426 | 427 | # Treat all other 400-level response codes as ones that are 428 | # unlikely to succeed on retry 429 | cause: str = ( 430 | f"Unexpected response code: {response.status_code}. 
" 431 | f"Response body: {response.text}" 432 | ) 433 | if response.status_code >= 400 and response.status_code < 500: 434 | raise NonRetryableRequestException( 435 | method=method, 436 | url=url, 437 | cause=cause, 438 | ) 439 | 440 | # Treat all other errors as ones that might succeed on retry 441 | raise RequestException( 442 | method=method, 443 | url=url, 444 | cause=cause, 445 | ) 446 | 447 | return response.text 448 | -------------------------------------------------------------------------------- /ocpapi/client/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import Enum 3 | from typing import List, Optional, Tuple 4 | 5 | from dataclasses_json import CatchAll, Undefined, config, dataclass_json 6 | 7 | 8 | @dataclass_json(undefined=Undefined.INCLUDE) 9 | @dataclass 10 | class _DataModel: 11 | """ 12 | Base class for all data models. 13 | """ 14 | 15 | other_fields: CatchAll 16 | """ 17 | Fields that may have been added to the API that all not yet supported 18 | explicitly in this class. 19 | """ 20 | 21 | 22 | @dataclass_json(undefined=Undefined.INCLUDE) 23 | @dataclass 24 | class Model(_DataModel): 25 | """ 26 | Stores information about a single model supported in the API. 27 | """ 28 | 29 | id: str 30 | """ 31 | The ID of the model. 32 | """ 33 | 34 | 35 | @dataclass_json(undefined=Undefined.INCLUDE) 36 | @dataclass 37 | class Models(_DataModel): 38 | """ 39 | Stores the response from a request for models supported in the API. 40 | """ 41 | 42 | models: List[Model] 43 | """ 44 | The list of models that are supported. 45 | """ 46 | 47 | 48 | @dataclass_json(undefined=Undefined.INCLUDE) 49 | @dataclass 50 | class Bulk(_DataModel): 51 | """ 52 | Stores information about a single bulk material. 53 | """ 54 | 55 | src_id: str 56 | """ 57 | The ID of the material. 58 | """ 59 | 60 | formula: str 61 | """ 62 | The chemical formula of the material. 
63 | """ 64 | 65 | # Stored under "els" in the API response 66 | elements: List[str] = field(metadata=config(field_name="els")) 67 | """ 68 | The list of elements in the material. 69 | """ 70 | 71 | 72 | @dataclass_json(undefined=Undefined.INCLUDE) 73 | @dataclass 74 | class Bulks(_DataModel): 75 | """ 76 | Stores the response from a request to fetch bulks supported in the API. 77 | """ 78 | 79 | bulks_supported: List[Bulk] 80 | """ 81 | List of bulks that can be used in the API. 82 | """ 83 | 84 | 85 | @dataclass_json(undefined=Undefined.INCLUDE) 86 | @dataclass 87 | class Adsorbates(_DataModel): 88 | """ 89 | Stores the response from a request to fetch adsorbates supported in the 90 | API. 91 | """ 92 | 93 | adsorbates_supported: List[str] 94 | """ 95 | List of adsorbates that can be used in the API. 96 | """ 97 | 98 | 99 | @dataclass_json(undefined=Undefined.INCLUDE) 100 | @dataclass 101 | class Atoms(_DataModel): 102 | """ 103 | Subset of the fields from an ASE Atoms object that are used within this 104 | API. 105 | """ 106 | 107 | cell: Tuple[ 108 | Tuple[float, float, float], 109 | Tuple[float, float, float], 110 | Tuple[float, float, float], 111 | ] 112 | """ 113 | 3x3 matrix with unit cell vectors. 114 | """ 115 | 116 | pbc: Tuple[bool, bool, bool] 117 | """ 118 | Whether the structure is periodic along the a, b, and c lattice vectors, 119 | respectively. 120 | """ 121 | 122 | numbers: List[int] 123 | """ 124 | The atomic number of each atom in the unit cell. 125 | """ 126 | 127 | positions: List[Tuple[float, float, float]] 128 | """ 129 | The coordinates of each atom in the unit cell, relative to the cartesian 130 | frame. 131 | """ 132 | 133 | tags: List[int] 134 | """ 135 | Labels for each atom in the unit cell where 0 represents a subsurface atom 136 | (fixed during optimization), 1 represents a surface atom, and 2 represents 137 | an adsorbate atom. 
138 | """ 139 | 140 | def to_ase_atoms(self) -> "ASEAtoms": 141 | """ 142 | Creates an ase.Atoms object with the positions, element numbers, 143 | etc. populated from values on this object. 144 | 145 | Returns: 146 | ase.Atoms object with values from this object. 147 | """ 148 | 149 | from ase import Atoms as ASEAtoms 150 | from ase.constraints import FixAtoms 151 | 152 | return ASEAtoms( 153 | cell=self.cell, 154 | pbc=self.pbc, 155 | numbers=self.numbers, 156 | positions=self.positions, 157 | tags=self.tags, 158 | # Fix sub-surface atoms 159 | constraint=FixAtoms(mask=[t == 0 for t in self.tags]), 160 | ) 161 | 162 | 163 | @dataclass_json(undefined=Undefined.INCLUDE) 164 | @dataclass 165 | class SlabMetadata(_DataModel): 166 | """ 167 | Stores metadata about a slab that is returned from the API. 168 | """ 169 | 170 | # Stored under "bulk_id" in the API response 171 | bulk_src_id: str = field(metadata=config(field_name="bulk_id")) 172 | """ 173 | The ID of the bulk material from which the slab was derived. 174 | """ 175 | 176 | millers: Tuple[int, int, int] 177 | """ 178 | The Miller indices of the slab relative to bulk structure. 179 | """ 180 | 181 | shift: float 182 | """ 183 | The position along the vector defined by the Miller indices at which a 184 | cut was taken to generate the slab surface. 185 | """ 186 | 187 | top: bool 188 | """ 189 | If False, the top and bottom surfaces for this millers/shift pair are 190 | distinct and this slab represents the bottom surface. 191 | """ 192 | 193 | 194 | @dataclass_json(undefined=Undefined.INCLUDE) 195 | @dataclass 196 | class Slab(_DataModel): 197 | """ 198 | Stores all information about a slab that is returned from the API. 199 | """ 200 | 201 | # Stored under "slab_atomsobject" in the API response 202 | atoms: Atoms = field(metadata=config(field_name="slab_atomsobject")) 203 | """ 204 | The structure of the slab. 
205 | """ 206 | 207 | # Stored under "slab_metadata" in the API response 208 | metadata: SlabMetadata = field(metadata=config(field_name="slab_metadata")) 209 | """ 210 | Extra information about the slab. 211 | """ 212 | 213 | 214 | @dataclass_json(undefined=Undefined.INCLUDE) 215 | @dataclass 216 | class Slabs(_DataModel): 217 | """ 218 | Stores the response from a request to fetch slabs for a bulk structure. 219 | """ 220 | 221 | slabs: List[Slab] 222 | """ 223 | The list of slabs that were generated from the input bulk structure. 224 | """ 225 | 226 | 227 | @dataclass_json(undefined=Undefined.INCLUDE) 228 | @dataclass 229 | class AdsorbateSlabConfigs(_DataModel): 230 | """ 231 | Stores the response from a request to fetch placements of a single 232 | adsorbate on a slab. 233 | """ 234 | 235 | adsorbate_configs: List[Atoms] 236 | """ 237 | List of structures, each representing one possible adsorbate placement. 238 | """ 239 | 240 | slab: Slab 241 | """ 242 | The structure of the slab on which the adsorbate is placed. 243 | """ 244 | 245 | 246 | @dataclass_json(undefined=Undefined.INCLUDE) 247 | @dataclass 248 | class AdsorbateSlabRelaxationsSystem(_DataModel): 249 | """ 250 | Stores the response from a request to submit a new batch of adsorbate 251 | slab relaxations. 252 | """ 253 | 254 | system_id: str 255 | """ 256 | Unique ID for this set of relaxations which can be used to fetch results 257 | later. 258 | """ 259 | 260 | config_ids: List[int] 261 | """ 262 | The list of IDs assigned to each of the input adsorbate placements, in the 263 | same order in which they were submitted. 264 | """ 265 | 266 | 267 | @dataclass_json(undefined=Undefined.INCLUDE) 268 | @dataclass 269 | class AdsorbateSlabRelaxationsRequest(_DataModel): 270 | """ 271 | Stores the request to submit a new batch of adsorbate slab relaxations. 272 | """ 273 | 274 | adsorbate: str 275 | """ 276 | Description of the adsorbate.
277 | """ 278 | 279 | adsorbate_configs: List[Atoms] 280 | """ 281 | List of adsorbate placements being relaxed. 282 | """ 283 | 284 | bulk: Bulk 285 | """ 286 | Information about the original bulk structure used to create the slab. 287 | """ 288 | 289 | slab: Slab 290 | """ 291 | The structure of the slab on which adsorbates are placed. 292 | """ 293 | 294 | model: str 295 | """ 296 | The type of the ML model being used during relaxations. 297 | """ 298 | 299 | # Omit from serialization when None 300 | ephemeral: Optional[bool] = field( 301 | default=None, 302 | metadata=config(exclude=lambda v: v is None), 303 | ) 304 | """ 305 | Whether the relaxations can be deleted (assume they cannot be deleted if 306 | None). 307 | """ 308 | 309 | adsorbate_reaction: Optional[str] = field( 310 | default=None, 311 | metadata=config(exclude=lambda v: v is None), 312 | ) 313 | """ 314 | If possible, an html-formatted string describing the reaction will be added 315 | to this field. 316 | """ 317 | 318 | 319 | class Status(Enum): 320 | """ 321 | Relaxation status of a single adsorbate placement on a slab. 322 | """ 323 | 324 | NOT_AVAILABLE = "not_available" 325 | """ 326 | The configuration exists but the result is not yet available. It is 327 | possible that checking again in the future could yield a result. 328 | """ 329 | 330 | FAILED_RELAXATION = "failed_relaxation" 331 | """ 332 | The relaxation failed for this configuration. 333 | """ 334 | 335 | SUCCESS = "success" 336 | """ 337 | The relaxation was successful and the requested information about the 338 | configuration was returned. 339 | """ 340 | 341 | DOES_NOT_EXIST = "does_not_exist" 342 | """ 343 | The requested configuration does not exist. 
344 | """ 345 | 346 | def __str__(self) -> str: 347 | return self.value 348 | 349 | 350 | @dataclass_json(undefined=Undefined.INCLUDE) 351 | @dataclass 352 | class AdsorbateSlabRelaxationResult(_DataModel): 353 | """ 354 | Stores information about a single adsorbate slab configuration, including 355 | outputs for the model used in relaxations. 356 | 357 | The API to fetch relaxation results supports requesting a subset of fields 358 | in order to limit the size of response payloads. Optional attributes will 359 | be defined only if they are included in the response. 360 | """ 361 | 362 | config_id: int 363 | """ 364 | ID of the configuration within the system. 365 | """ 366 | 367 | status: Status 368 | """ 369 | The status of the request for information about this configuration. 370 | """ 371 | 372 | # Omit from serialization when None 373 | system_id: Optional[str] = field( 374 | default=None, 375 | metadata=config(exclude=lambda v: v is None), 376 | ) 377 | """ 378 | The ID of the system in which the configuration was originally submitted. 379 | """ 380 | 381 | cell: Optional[ 382 | Tuple[ 383 | Tuple[float, float, float], 384 | Tuple[float, float, float], 385 | Tuple[float, float, float], 386 | ] 387 | ] = field( 388 | default=None, 389 | metadata=config(exclude=lambda v: v is None), 390 | ) 391 | """ 392 | 3x3 matrix with unit cell vectors. 393 | """ 394 | 395 | pbc: Optional[Tuple[bool, bool, bool]] = field( 396 | default=None, 397 | metadata=config(exclude=lambda v: v is None), 398 | ) 399 | """ 400 | Whether the structure is periodic along the a, b, and c lattice vectors, 401 | respectively. 402 | """ 403 | 404 | numbers: Optional[List[int]] = field( 405 | default=None, 406 | metadata=config(exclude=lambda v: v is None), 407 | ) 408 | """ 409 | The atomic number of each atom in the unit cell.
410 | """ 411 | 412 | positions: Optional[List[Tuple[float, float, float]]] = field( 413 | default=None, 414 | metadata=config(exclude=lambda v: v is None), 415 | ) 416 | """ 417 | The coordinates of each atom in the unit cell, relative to the cartesian 418 | frame. 419 | """ 420 | 421 | tags: Optional[List[int]] = field( 422 | default=None, 423 | metadata=config(exclude=lambda v: v is None), 424 | ) 425 | """ 426 | Labels for each atom in the unit cell where 0 represents a subsurface atom 427 | (fixed during optimization), 1 represents a surface atom, and 2 represents 428 | an adsorbate atom. 429 | """ 430 | 431 | energy: Optional[float] = field( 432 | default=None, 433 | metadata=config(exclude=lambda v: v is None), 434 | ) 435 | """ 436 | The energy of the configuration. 437 | """ 438 | 439 | energy_trajectory: Optional[List[float]] = field( 440 | default=None, 441 | metadata=config(exclude=lambda v: v is None), 442 | ) 443 | """ 444 | The energy of the configuration at each point along the relaxation 445 | trajectory. 446 | """ 447 | 448 | forces: Optional[List[Tuple[float, float, float]]] = field( 449 | default=None, 450 | metadata=config(exclude=lambda v: v is None), 451 | ) 452 | """ 453 | The forces on each atom in the relaxed structure. 454 | """ 455 | 456 | def to_ase_atoms(self) -> "ASEAtoms": 457 | """ 458 | Creates an ase.Atoms object with the positions, element numbers, 459 | etc. populated from values on this object. 460 | 461 | The predicted energy and forces will also be copied to the new 462 | ase.Atoms object as a SinglePointCalculator (a calculator that 463 | stores the results of an already-run simulation). 464 | 465 | Returns: 466 | ase.Atoms object with values from this object. 
467 | """ 468 | from ase import Atoms as ASEAtoms 469 | from ase.calculators.singlepoint import SinglePointCalculator 470 | from ase.constraints import FixAtoms 471 | 472 | atoms: ASEAtoms = ASEAtoms( 473 | cell=self.cell, 474 | pbc=self.pbc, 475 | numbers=self.numbers, 476 | positions=self.positions, 477 | tags=self.tags, 478 | ) 479 | if self.tags is not None: 480 | # Fix sub-surface atoms 481 | atoms.constraints = FixAtoms(mask=[t == 0 for t in self.tags]) 482 | atoms.calc = SinglePointCalculator( 483 | atoms=atoms, 484 | energy=self.energy, 485 | forces=self.forces, 486 | ) 487 | return atoms 488 | 489 | 490 | @dataclass_json(undefined=Undefined.INCLUDE) 491 | @dataclass 492 | class AdsorbateSlabRelaxationsResults(_DataModel): 493 | """ 494 | Stores the response from a request for results of adsorbate slab 495 | relaxations. 496 | """ 497 | 498 | configs: List[AdsorbateSlabRelaxationResult] 499 | """ 500 | List of configurations in the system, each representing one placement of 501 | an adsorbate on a slab surface. 502 | """ 503 | 504 | omitted_config_ids: List[int] = field(default_factory=lambda: list()) 505 | """ 506 | List of IDs of configurations that were requested but omitted by the 507 | server. Results for these IDs can be requested again. 508 | """ 509 | -------------------------------------------------------------------------------- /ocpapi/client/ui.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | # Map of known API hosts to UI hosts 4 | _API_TO_UI_HOSTS: Dict[str, str] = { 5 | "open-catalyst-api.metademolab.com": "open-catalyst.metademolab.com", 6 | } 7 | 8 | 9 | def get_results_ui_url(api_host: str, system_id: str) -> Optional[str]: 10 | """ 11 | Generates the URL at which results for the input system can be 12 | visualized. 13 | 14 | Args: 15 | api_host: The API host on which the system was run. 16 | system_id: ID of the system being visualized. 
17 | 18 | Returns: 19 | The URL at which the input system can be visualized. None if the 20 | API host is not recognized. 21 | """ 22 | if ui_host := _API_TO_UI_HOSTS.get(api_host, None): 23 | return f"https://{ui_host}/results/{system_id}" 24 | return None 25 | -------------------------------------------------------------------------------- /ocpapi/version.py: -------------------------------------------------------------------------------- 1 | VERSION = "1.0.0" 2 | -------------------------------------------------------------------------------- /ocpapi/workflows/__init__.py: -------------------------------------------------------------------------------- 1 | from .adsorbates import ( # noqa 2 | AdsorbateBindingSites, 3 | AdsorbateSlabRelaxations, 4 | Lifetime, 5 | UnsupportedAdsorbateException, 6 | UnsupportedBulkException, 7 | UnsupportedModelException, 8 | find_adsorbate_binding_sites, 9 | get_adsorbate_slab_relaxation_results, 10 | wait_for_adsorbate_slab_relaxations, 11 | ) 12 | from .filter import ( # noqa 13 | keep_all_slabs, 14 | keep_slabs_with_miller_indices, 15 | prompt_for_slabs_to_keep, 16 | ) 17 | from .retry import ( # noqa 18 | NO_LIMIT, 19 | NoLimitType, 20 | RateLimitLogging, 21 | retry_api_calls, 22 | ) 23 | -------------------------------------------------------------------------------- /ocpapi/workflows/adsorbates.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from contextlib import AsyncExitStack, asynccontextmanager, suppress 4 | from contextvars import ContextVar 5 | from dataclasses import dataclass 6 | from enum import Enum, auto 7 | from typing import ( 8 | Any, 9 | AsyncGenerator, 10 | Awaitable, 11 | Callable, 12 | Dict, 13 | List, 14 | Optional, 15 | Tuple, 16 | ) 17 | 18 | from dataclasses_json import Undefined, dataclass_json 19 | from tqdm import tqdm 20 | from tqdm.contrib.logging import logging_redirect_tqdm 21 | 22 | from ocpapi.client import ( 23 | 
Adsorbates, 24 | AdsorbateSlabConfigs, 25 | AdsorbateSlabRelaxationResult, 26 | AdsorbateSlabRelaxationsResults, 27 | AdsorbateSlabRelaxationsSystem, 28 | Atoms, 29 | Bulk, 30 | Bulks, 31 | Client, 32 | Models, 33 | Slab, 34 | Slabs, 35 | Status, 36 | get_results_ui_url, 37 | ) 38 | 39 | from .context import set_context_var 40 | from .filter import prompt_for_slabs_to_keep 41 | from .log import log 42 | from .retry import NO_LIMIT, RateLimitLogging, retry_api_calls 43 | 44 | # Context instance that stores information about the adsorbate and bulk 45 | # material as a tuple in that order 46 | _CTX_AD_BULK: ContextVar[Tuple[str, str]] = ContextVar(f"{__name__}:ad_bulk") 47 | 48 | # Context intance that stores information about a slab 49 | _CTX_SLAB: ContextVar[Slab] = ContextVar(f"{__name__}:slab") 50 | 51 | 52 | def _setup_log_record_factory() -> None: 53 | """ 54 | Adds a log record factory that stores information about the currently 55 | running job on a log message. 56 | """ 57 | old_factory: Callable[..., logging.LogRecord] = logging.getLogRecordFactory() 58 | 59 | def new_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: 60 | # Save information about the bulk and absorbate if set 61 | parts: List[str] = [] 62 | if (ad_bulk := _CTX_AD_BULK.get(None)) is not None: 63 | parts.append(f"[{ad_bulk[0]}/{ad_bulk[1]}]") 64 | 65 | # Save information about the slab if set 66 | if (slab := _CTX_SLAB.get(None)) is not None: 67 | m = slab.metadata 68 | top = "t" if m.top else "b" 69 | millers = f"({m.millers[0]},{m.millers[1]},{m.millers[2]})" 70 | parts.append(f"[{millers}/{round(m.shift, 3):.3f},{top}]") 71 | 72 | # Prepend context to the current message 73 | record = old_factory(*args, **kwargs) 74 | parts.append(record.msg) 75 | record.msg = " ".join(parts) 76 | return record 77 | 78 | logging.setLogRecordFactory(new_factory) 79 | 80 | 81 | _setup_log_record_factory() 82 | 83 | 84 | DEFAULT_CLIENT: Client = Client() 85 | 86 | 87 | class 
AdsorbatesException(Exception): 88 | """ 89 | Base exception for all others in this module. 90 | """ 91 | 92 | pass 93 | 94 | 95 | class UnsupportedModelException(AdsorbatesException): 96 | """ 97 | Exception raised when a model is not supported in the API. 98 | """ 99 | 100 | def __init__(self, model: str, allowed_models: List[str]) -> None: 101 | """ 102 | Args: 103 | model: The model that was requested. 104 | allowed_models: The list of models that are supported. 105 | """ 106 | super().__init__( 107 | f"Model {model} is not supported; expected one of {allowed_models}" 108 | ) 109 | 110 | 111 | class UnsupportedBulkException(AdsorbatesException): 112 | """ 113 | Exception raised when a bulk material is not supported in the API. 114 | """ 115 | 116 | def __init__(self, bulk: str) -> None: 117 | """ 118 | Args: 119 | bulk: The bulk structure that was requested. 120 | """ 121 | super().__init__(f"Bulk {bulk} is not supported") 122 | 123 | 124 | class UnsupportedAdsorbateException(AdsorbatesException): 125 | """ 126 | Exception raised when an adsorbate is not supported in the API. 127 | """ 128 | 129 | def __init__(self, adsorbate: str) -> None: 130 | """ 131 | Args: 132 | adsorbate: The adsorbate that was requested. 133 | """ 134 | super().__init__(f"Adsorbate {adsorbate} is not supported") 135 | 136 | 137 | class Lifetime(Enum): 138 | """ 139 | Represents different lifetimes when running relaxations. 140 | """ 141 | 142 | SAVE = auto() 143 | """ 144 | The relaxation will be available on API servers indefinitely. It will not 145 | be possible to delete the relaxation in the future. 146 | """ 147 | 148 | MARK_EPHEMERAL = auto() 149 | """ 150 | The relaxation will be saved on API servers, but can be deleted at any time 151 | in the future. 152 | """ 153 | 154 | DELETE = auto() 155 | """ 156 | The relaxation will be deleted from API servers as soon as the results have 157 | been fetched. 
158 | """ 159 | 160 | 161 | @dataclass_json(undefined=Undefined.EXCLUDE) 162 | @dataclass 163 | class AdsorbateSlabRelaxations: 164 | """ 165 | Stores the relaxations of adsorbate placements on the surface of a slab. 166 | """ 167 | 168 | slab: Slab 169 | """ 170 | The slab on which the adsorbate was placed. 171 | """ 172 | 173 | configs: List[AdsorbateSlabRelaxationResult] 174 | """ 175 | Details of the relaxation of each adsorbate placement, including the 176 | final position. 177 | """ 178 | 179 | system_id: str 180 | """ 181 | The ID of the system that stores all of the relaxations. 182 | """ 183 | 184 | api_host: str 185 | """ 186 | The API host on which the relaxations were run. 187 | """ 188 | 189 | ui_url: Optional[str] 190 | """ 191 | The URL at which results can be visualized. 192 | """ 193 | 194 | 195 | @dataclass_json(undefined=Undefined.EXCLUDE) 196 | @dataclass 197 | class AdsorbateBindingSites: 198 | """ 199 | Stores the inputs and results of a set of relaxations of adsorbate 200 | placements on the surface of a slab. 201 | """ 202 | 203 | adsorbate: str 204 | """ 205 | Description of the adsorbate. 206 | """ 207 | 208 | bulk: Bulk 209 | """ 210 | The bulk material that was being modeled. 211 | """ 212 | 213 | model: str 214 | """ 215 | The type of the model that was run. 216 | """ 217 | 218 | slabs: List[AdsorbateSlabRelaxations] 219 | """ 220 | The list of slabs that were generated from the bulk structure. Each 221 | contains its own list of adsorbate placements. 222 | """ 223 | 224 | 225 | @retry_api_calls(max_attempts=3) 226 | async def _ensure_model_supported(client: Client, model: str) -> None: 227 | """ 228 | Checks that the input model is supported in the API. 229 | 230 | Args: 231 | client: The client to use when making requests to the API. 232 | model: The model to check. 233 | 234 | Raises: 235 | UnsupportedModelException: If the model is not supported. 
236 | """ 237 | models: Models = await client.get_models() 238 | allowed_models: List[str] = [m.id for m in models.models] 239 | if model not in allowed_models: 240 | raise UnsupportedModelException( 241 | model=model, 242 | allowed_models=allowed_models, 243 | ) 244 | 245 | 246 | @retry_api_calls(max_attempts=3) 247 | async def _get_bulk_if_supported(client: Client, bulk: str) -> Bulk: 248 | """ 249 | Returns the object from the input bulk if it is supported in the API. 250 | 251 | Args: 252 | client: The client to use when making requests to the API. 253 | bulk: The bulk to fetch. 254 | 255 | Raises: 256 | UnsupportedBulkException: If the requested bulk is not supported. 257 | 258 | Returns: 259 | Bulk instance for the input type. 260 | """ 261 | bulks: Bulks = await client.get_bulks() 262 | for b in bulks.bulks_supported: 263 | if b.src_id == bulk: 264 | return b 265 | raise UnsupportedBulkException(bulk) 266 | 267 | 268 | @retry_api_calls(max_attempts=3) 269 | async def _ensure_adsorbate_supported(client: Client, adsorbate: str) -> None: 270 | """ 271 | Checks that the input adsorbate is supported in the API. 272 | 273 | Args: 274 | client: The client to use when making requests to the API. 275 | adsorbate: The adsorbate to check. 276 | 277 | Raises: 278 | UnsupportedAdsorbateException: If the adsorbate is not supported. 279 | """ 280 | adsorbates: Adsorbates = await client.get_adsorbates() 281 | if adsorbate not in adsorbates.adsorbates_supported: 282 | raise UnsupportedAdsorbateException(adsorbate) 283 | 284 | 285 | @retry_api_calls(max_attempts=3) 286 | async def _get_slabs( 287 | client: Client, 288 | bulk: Bulk, 289 | ) -> List[Slab]: 290 | """ 291 | Enumerates surfaces for the input bulk material. 292 | 293 | Args: 294 | client: The client to use when making requests to the API. 295 | bulk: The bulk material from which slabs will be generated. 296 | 297 | Returns: 298 | The list of slabs that were generated. 
299 | """ 300 | slabs: Slabs = await client.get_slabs(bulk) 301 | return slabs.slabs 302 | 303 | 304 | @retry_api_calls(max_attempts=3) 305 | async def _get_absorbate_configs_on_slab( 306 | client: Client, 307 | adsorbate: str, 308 | slab: Slab, 309 | ) -> AdsorbateSlabConfigs: 310 | """ 311 | Generate initial guesses at adsorbate binding sites on the input slab. 312 | 313 | Args: 314 | client: The client to use when making API calls. 315 | adsorbate: Description of the adsorbate to place. 316 | slab: The slab on which the adsorbate should be placed. 317 | 318 | Returns: 319 | An updated slab instance that has had tags applied to it and a list 320 | of Atoms objects, each with the positions of the adsorbate atoms on 321 | one of the candidate binding sites. 322 | """ 323 | return await client.get_adsorbate_slab_configs( 324 | adsorbate=adsorbate, 325 | slab=slab, 326 | ) 327 | 328 | 329 | async def _get_absorbate_configs_on_slab_with_logging( 330 | client: Client, 331 | adsorbate: str, 332 | slab: Slab, 333 | ) -> AdsorbateSlabConfigs: 334 | """ 335 | Wrapper around _get_absorbate_configs_on_slab that adds logging. 336 | """ 337 | with set_context_var(_CTX_SLAB, slab): 338 | # Enumerate candidate binding sites 339 | log.info( 340 | "Generating adsorbate placements on " 341 | f"{'top' if slab.metadata.top else 'bottom'} " 342 | f"{slab.metadata.millers} surface, shifted by " 343 | f"{round(slab.metadata.shift, 3)}" 344 | ) 345 | return await _get_absorbate_configs_on_slab( 346 | client=client, 347 | adsorbate=adsorbate, 348 | slab=slab, 349 | ) 350 | 351 | 352 | async def _get_adsorbate_configs_on_slabs( 353 | client: Client, 354 | adsorbate: str, 355 | slabs: List[Slab], 356 | ) -> List[AdsorbateSlabConfigs]: 357 | """ 358 | Finds candidate adsorbate binding sites on each of the input slabs. 359 | 360 | Args: 361 | client: The client to use when making API calls. 362 | adsorbate: Description of the adsorbate to place. 
363 | slabs: The slabs on which the adsorbate should be placed. 364 | 365 | Returns: 366 | List of slabs and, for each, the positions of the adsorbate 367 | atoms in the potential binding site. 368 | """ 369 | tasks: List[asyncio.Task] = [ 370 | asyncio.create_task( 371 | _get_absorbate_configs_on_slab_with_logging( 372 | client=client, 373 | adsorbate=adsorbate, 374 | slab=slab, 375 | ) 376 | ) 377 | for slab in slabs 378 | ] 379 | if tasks: 380 | await asyncio.wait(tasks) 381 | return [t.result() for t in tasks] 382 | 383 | 384 | # The API behind Client.submit_adsorbate_slab_relaxations() is rate limited 385 | # and this decorator will handle retrying when that rate limit is breached. 386 | # Retry forever since we can't know how many jobs are being submitted along 387 | # with this one (rate limits are enforced on the API server and not limited 388 | # to a specific instance of this module). 389 | @retry_api_calls( 390 | max_attempts=NO_LIMIT, 391 | rate_limit_logging=RateLimitLogging( 392 | logger=log, 393 | action="submit relaxations", 394 | ), 395 | ) 396 | async def _submit_relaxations( 397 | client: Client, 398 | adsorbate: str, 399 | adsorbate_configs: List[Atoms], 400 | bulk: Bulk, 401 | slab: Slab, 402 | model: str, 403 | ephemeral: bool, 404 | ) -> str: 405 | """ 406 | Start relaxations for each of the input adsorbate configurations on the 407 | input slab. 408 | 409 | Args: 410 | client: The client to use when making API calls. 411 | adsorbate: Description of the adsorbate to place. 412 | adsorbate_configs: Positions of the adsorbate on the slab. Each 413 | will be relaxed independently. 414 | bulk: The bulk material from which the slab was generated. 415 | slab: The slab that should be searched for adsorbate binding sites. 416 | model: The model to use when evaluating forces and energies. 417 | ephemeral: Whether the relaxations should be marked as ephemeral. 
418 | 419 | Returns: 420 | The system ID of the relaxation run, which can be used to fetch results 421 | as they become available. 422 | """ 423 | system: AdsorbateSlabRelaxationsSystem = ( 424 | await client.submit_adsorbate_slab_relaxations( 425 | adsorbate=adsorbate, 426 | adsorbate_configs=adsorbate_configs, 427 | bulk=bulk, 428 | slab=slab, 429 | model=model, 430 | ephemeral=ephemeral, 431 | ) 432 | ) 433 | return system.system_id 434 | 435 | 436 | async def _submit_relaxations_with_progress_logging( 437 | client: Client, 438 | adsorbate: str, 439 | adsorbate_configs: List[Atoms], 440 | bulk: Bulk, 441 | slab: Slab, 442 | model: str, 443 | ephemeral: bool, 444 | ) -> str: 445 | """ 446 | Wrapper around _submit_relaxations that adds periodic logging in case 447 | calls to submit relaxations are being rate limited. Returns the system ID of the submitted relaxations; any exception raised while submitting is re-raised. 448 | """ 449 | 450 | # Function that will log periodically (every 30 seconds) while attempts to submit relaxations 451 | # are being retried 452 | async def log_waiting() -> None: 453 | while True: 454 | await asyncio.sleep(30) 455 | log.info( 456 | "Still waiting for relaxations to be accepted, possibly " 457 | "because calls are being rate limited" 458 | ) 459 | 460 | # Run the submission alongside the periodic logger until the submission finishes (successfully or not) 461 | submit_task = asyncio.create_task( 462 | _submit_relaxations( 463 | client=client, 464 | adsorbate=adsorbate, 465 | adsorbate_configs=adsorbate_configs, 466 | bulk=bulk, 467 | slab=slab, 468 | model=model, 469 | ephemeral=ephemeral, 470 | ) 471 | ) 472 | logging_task = asyncio.create_task(log_waiting()) 473 | _, pending = await asyncio.wait( 474 | [logging_task, submit_task], 475 | return_when=asyncio.FIRST_COMPLETED, 476 | ) 477 | 478 | # Cancel pending tasks (logging_task never finishes on its own, so this should just be the task that logs while waiting) 479 | for task in pending: 480 | with 
suppress(asyncio.CancelledError): 481 | task.cancel() 482 | await task 483 | 484 | return submit_task.result() 485 | 486 | 487 | @retry_api_calls(max_attempts=3) 488 | async def get_adsorbate_slab_relaxation_results(
489 | system_id: str, 490 | config_ids: Optional[List[int]] = None, 491 | fields: Optional[List[str]] = None, 492 | client: Client = DEFAULT_CLIENT, 493 | ) -> List[AdsorbateSlabRelaxationResult]: 494 | """ 495 | Wrapper around Client.get_adsorbate_slab_relaxations_results() that 496 | handles retries, including re-fetching individual configurations that 497 | are initially omitted (they are fetched with a recursive call). 498 | 499 | Args: 500 | client: The client to use when making API calls. 501 | system_id: The system ID of the relaxations. 502 | config_ids: If defined and not empty, a subset of configurations 503 | to fetch. Otherwise all configurations are returned. 504 | fields: If defined and not empty, a subset of fields in each 505 | configuration to fetch. Otherwise all fields are returned. 506 | 507 | Returns: 508 | List of relaxation results, one for each adsorbate configuration in 509 | the system (or in config_ids, if provided). 510 | """ 511 | results: AdsorbateSlabRelaxationsResults = ( 512 | await client.get_adsorbate_slab_relaxations_results( 513 | system_id=system_id, 514 | config_ids=config_ids, 515 | fields=fields, 516 | ) 517 | ) 518 | 519 | # Save a copy of all results that were fetched 520 | fetched: List[AdsorbateSlabRelaxationResult] = list(results.configs) 521 | 522 | # If any results were omitted, fetch them (recursively, through the retrying wrapper) before returning 523 | if results.omitted_config_ids: 524 | fetched.extend( 525 | await get_adsorbate_slab_relaxation_results( 526 | client=client, 527 | system_id=system_id, 528 | config_ids=results.omitted_config_ids, 529 | fields=fields, 530 | ) 531 | ) 532 | 533 | return fetched 534 | 535 | 536 | async def wait_for_adsorbate_slab_relaxations( 537 | system_id: str, 538 | check_immediately: bool = False, 539 | slow_interval_sec: float = 30, 540 | fast_interval_sec: float = 10, 541 | pbar: Optional[tqdm] = None, 542 | client: Client = DEFAULT_CLIENT, 543 | ) -> Dict[int, Status]: 544 | """ 545 | Blocks until all relaxations in the input system have 
finished, whether 546 | successfully or not.
547 | 548 | Relaxations are queued in the API, waiting until machines are ready to 549 | run them. Once started, they can take 1-2 minutes to finish. This method 550 | initially sleeps "slow_interval_sec" seconds between each check for any 551 | relaxations having finished. Once at least one result is ready, subsequent 552 | sleeps are for "fast_interval_sec" seconds. There is no timeout: this loops until every configuration has a status other than NOT_AVAILABLE. 553 | 554 | Args: 555 | system_id: The ID of the system for which relaxations are running. 556 | check_immediately: If False (default), sleep before the first check 557 | for relaxations having finished. If True, check whether relaxations 558 | have finished immediately on entering this function. 559 | slow_interval_sec: The number of seconds to wait between each check 560 | while all are still running. 561 | fast_interval_sec: The number of seconds to wait between each check 562 | when at least one relaxation has finished in the system. 563 | pbar: A tqdm instance that tracks the number of configurations that 564 | have finished. This will be updated with the number of individual 565 | configurations whose relaxations have finished. 566 | client: The client to use when making API calls. 567 | 568 | Returns: 569 | Map of config IDs in the system to their terminal status. 570 | """ 571 | 572 | # First wait if needed 573 | wait_for_sec: float = slow_interval_sec 574 | if not check_immediately: 575 | await asyncio.sleep(wait_for_sec) 576 | 577 | # Run until all results are available 578 | num_finished: int = 0 579 | while True: 580 | # Get the current results. Only fetch the energy; this hits an index 581 | # that will return results more quickly.
582 | results: List[ 583 | AdsorbateSlabRelaxationResult 584 | ] = await get_adsorbate_slab_relaxation_results( 585 | client=client, 586 | system_id=system_id, 587 | fields=["energy"], 588 | ) 589 | 590 | # Check the number of finished jobs and advance the progress bar by the number newly finished since the last check 591 | last_num_finished: int = num_finished 592 | num_finished = len([r for r in results if r.status != Status.NOT_AVAILABLE]) 593 | if pbar is not None: 594 | pbar.update(num_finished - last_num_finished) 595 | 596 | # Return if all of the relaxations have finished 597 | if num_finished == len(results): 598 | log.info("All relaxations have finished") 599 | return {r.config_id: r.status for r in results} 600 | 601 | # Shorten the wait time if any relaxations have finished 602 | if num_finished > 0: 603 | wait_for_sec = fast_interval_sec 604 | 605 | # Wait until the next scheduled check 606 | log.info(f"{num_finished} of {len(results)} relaxations have finished") 607 | await asyncio.sleep(wait_for_sec) 608 | 609 | 610 | @retry_api_calls(max_attempts=3) 611 | async def _delete_system(client: Client, system_id: str) -> None: 612 | """ 613 | Deletes the input system, with retries on failed attempts. 614 | 615 | Args: 616 | client: The client to use when making API calls. 617 | system_id: The ID of the system to delete. 618 | """ 619 | await client.delete_adsorbate_slab_relaxations(system_id) 620 | 621 | 622 | @asynccontextmanager 623 | async def _ensure_system_deleted( 624 | client: Client, 625 | system_id: str, 626 | ) -> AsyncGenerator[None, None]: 627 | """ 628 | Immediately yields control to the caller. When control returns to this 629 | function (including when the caller raises), try to delete the system with the input id. 630 | 631 | Args: 632 | client: The client to use when making API calls. 633 | system_id: The ID of the system to delete.
634 | """ 635 | try: 636 | yield 637 | finally: 638 | log.info(f"Ensuring system with id {system_id} is deleted") 639 | await _delete_system(client=client, system_id=system_id) 640 | 641 | 642 | async def _run_relaxations_on_slab( 643 | client: Client, 644 | adsorbate: str, 645 | adsorbate_configs: List[Atoms], 646 | bulk: Bulk, 647 | slab: Slab, 648 | model: str, 649 | lifetime: Lifetime, 650 | pbar: tqdm, 651 | ) -> AdsorbateSlabRelaxations: 652 | """ 653 | Start relaxations for each adsorbate configuration on the input slab 654 | and wait for all to finish. 655 | 656 | Args: 657 | client: The client to use when making API calls. 658 | adsorbate: Description of the adsorbate to place. 659 | adsorbate_configs: The positions of atoms in each adsorbate placement 660 | to be relaxed. 661 | bulk: The bulk material from which the slab was generated. 662 | slab: The slab that should be searched for adsorbate binding sites. 663 | model: The model to use when evaluating forces and energies. 664 | lifetime: Whether relaxations should be saved on the server, be marked 665 | as ephemeral (allowing them to be deleted in the future), or deleted 666 | immediately. 667 | pbar: A progress bar to update as relaxations finish. 668 | 669 | Returns: 670 | Details of each adsorbate placement, including its relaxed position.
671 | """ 672 | async with AsyncExitStack() as es: 673 | es.enter_context(set_context_var(_CTX_SLAB, slab)) 674 | 675 | # Start relaxations for all of the adsorbate placements 676 | log.info( 677 | f"Submitting relaxations for {len(adsorbate_configs)} " 678 | "adsorbate placements" 679 | ) 680 | system_id: str = await _submit_relaxations_with_progress_logging( 681 | client=client, 682 | adsorbate=adsorbate, 683 | adsorbate_configs=adsorbate_configs, 684 | bulk=bulk, 685 | slab=slab, 686 | model=model, 687 | ephemeral=lifetime in {Lifetime.MARK_EPHEMERAL, Lifetime.DELETE}, 688 | ) 689 | log.info(f"Relaxations running with system id {system_id}") 690 | 691 | # If requested, ensure the system is deleted once results have been 692 | # fetched (deletion also runs if waiting for or fetching results raises) 693 | if lifetime == Lifetime.DELETE: 694 | await es.enter_async_context( 695 | _ensure_system_deleted(client=client, system_id=system_id) 696 | ) 697 | 698 | # Wait for all relaxations to finish 699 | await wait_for_adsorbate_slab_relaxations( 700 | system_id=system_id, 701 | pbar=pbar, 702 | client=client, 703 | ) 704 | 705 | # Fetch the final results 706 | results: List[ 707 | AdsorbateSlabRelaxationResult 708 | ] = await get_adsorbate_slab_relaxation_results( 709 | client=client, 710 | system_id=system_id, 711 | ) 712 | return AdsorbateSlabRelaxations( 713 | slab=slab, 714 | configs=results, 715 | system_id=system_id, 716 | api_host=client.host, 717 | ui_url=get_results_ui_url( 718 | api_host=client.host, 719 | system_id=system_id, 720 | ), 721 | ) 722 | 723 | 724 | async def _refresh_pbar(pbar: tqdm, interval_sec: float) -> None: 725 | """ 726 | Helper function that refreshes the input progress bar on a regular 727 | schedule. This function never returns; it must be cancelled. 728 | 729 | Args: 730 | pbar: The progress bar to refresh. 731 | interval_sec: The number of seconds to wait between each refresh.
732 | """ 733 | while True: 734 | await asyncio.sleep(interval_sec) 735 | pbar.refresh() 736 | 737 | 738 | async def _relax_binding_sites_on_slabs( 739 | client: Client, 740 | adsorbate: str, 741 | bulk: Bulk, 742 | adslabs: List[AdsorbateSlabConfigs], 743 | model: str, 744 | lifetime: Lifetime, 745 | ) -> AdsorbateBindingSites: 746 | """ 747 | Relax the adsorbate placements on each of the input slabs concurrently (one task per slab) and collect the results. 748 | 749 | Args: 750 | client: The client to use when making API calls. 751 | adsorbate: Description of the adsorbate to place. 752 | bulk: The bulk material from which the slabs were generated. 753 | adslabs: The slabs and, for each, the binding sites that should be 754 | relaxed. 755 | model: The model to use when evaluating forces and energies. 756 | lifetime: Whether relaxations should be saved on the server, be marked 757 | as ephemeral (allowing them to be deleted in the future), or deleted 758 | immediately. 759 | 760 | Returns: 761 | Details of each adsorbate placement, including its relaxed position.
762 | """ 763 | 764 | # Make sure logs and progress bars work together while tqdm is 765 | # being used 766 | with logging_redirect_tqdm(): 767 | # Start a progress bar to track relaxations of the individual 768 | # configurations 769 | with tqdm( 770 | desc="Finished relaxations", 771 | bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]", 772 | total=sum(len(adslab.adsorbate_configs) for adslab in adslabs), 773 | miniters=0, 774 | leave=False, 775 | ) as pbar: 776 | # Start a task that refreshes the progress bar on a regular 777 | # schedule 778 | pbar_refresh_task: asyncio.Task = asyncio.create_task( 779 | _refresh_pbar(pbar=pbar, interval_sec=1) 780 | ) 781 | 782 | # Run relaxations for all configurations on all slabs 783 | tasks: List[asyncio.Task] = [ 784 | asyncio.create_task( 785 | _run_relaxations_on_slab( 786 | client=client, 787 | adsorbate=adsorbate, 788 | adsorbate_configs=adslab.adsorbate_configs, 789 | bulk=bulk, 790 | slab=adslab.slab, 791 | model=model, 792 | lifetime=lifetime, 793 | pbar=pbar, 794 | ) 795 | ) 796 | for adslab in adslabs 797 | ] 798 | if tasks: 799 | await asyncio.wait(tasks) 800 | 801 | # Cancel the task that refreshes the progress bar on a schedule. NOTE(review): this is skipped if this coroutine is itself cancelled while awaiting asyncio.wait above (the relaxation tasks are not cancelled either); consider a try/finally 802 | with suppress(asyncio.CancelledError): 803 | pbar_refresh_task.cancel() 804 | await pbar_refresh_task 805 | 806 | # Return results; t.result() re-raises the exception of the first failed slab task (in input order) 807 | return AdsorbateBindingSites( 808 | adsorbate=adsorbate, 809 | bulk=bulk, 810 | model=model, 811 | slabs=[t.result() for t in tasks], 812 | ) 813 | 814 | 815 | _DEFAULT_ADSLAB_FILTER: Callable[ 816 | [List[AdsorbateSlabConfigs]], Awaitable[List[AdsorbateSlabConfigs]] 817 | ] = prompt_for_slabs_to_keep() 818 | 819 | 820 | async def find_adsorbate_binding_sites( 821 | adsorbate: str, 822 | bulk: str, 823 | model: str = 
"equiformer_v2_31M_s2ef_all_md", 824 | adslab_filter: Callable[ 825 | [List[AdsorbateSlabConfigs]], Awaitable[List[AdsorbateSlabConfigs]] 826 | ] = _DEFAULT_ADSLAB_FILTER, 827 | client: Client = DEFAULT_CLIENT, 828 | lifetime:
Lifetime = Lifetime.SAVE, 829 | ) -> AdsorbateBindingSites: 830 | """ 831 | Search for adsorbate binding sites on surfaces of a bulk material. 832 | This executes the following steps: 833 | 834 | 1. Ensure that the model, adsorbate, and bulk are supported in the 835 | OCP API. 836 | 2. Enumerate unique surfaces from the bulk material. 837 | 3. Enumerate likely binding sites for the input adsorbate on each 838 | of the generated surfaces. 839 | 4. Filter the list of generated adsorbate/slab (adslab) configurations 840 | using the input adslab_filter. 841 | 5. Relax each generated surface+adsorbate structure by refining 842 | atomic positions to minimize forces generated by the input model. 843 | 844 | Args: 845 | adsorbate: Description of the adsorbate to place. 846 | bulk: The ID (typically Materials Project MP ID) of the bulk material 847 | on which the adsorbate will be placed. 848 | model: The type of the model to use when calculating forces during 849 | relaxations. 850 | adslab_filter: A function that modifies the set of adsorbate/slab 851 | configurations that will be relaxed. This can be used to subselect 852 | slabs and/or adsorbate configurations. The default interactively prompts the user to choose which slabs to keep. 853 | client: The OCP API client to use. 854 | lifetime: Whether relaxations should be saved on the server, be marked 855 | as ephemeral (allowing them to be deleted in the future), or deleted 856 | immediately. 857 | 858 | Returns: 859 | Details of each adsorbate binding site, including results of relaxing 860 | to locally-optimized positions using the input model. 861 | 862 | Raises: 863 | UnsupportedModelException: If the requested model is not supported. 864 | UnsupportedBulkException: If the requested bulk is not supported. 865 | UnsupportedAdsorbateException: If the requested adsorbate is not 866 | supported.
867 | """ 868 | with set_context_var(_CTX_AD_BULK, (adsorbate, bulk)): 869 | # Make sure the input model is supported in the API 870 | log.info(f"Ensuring that model {model} is supported") 871 | await _ensure_model_supported( 872 | client=client, 873 | model=model, 874 | ) 875 | 876 | # Make sure the input adsorbate is supported in the API 877 | log.info(f"Ensuring that adsorbate {adsorbate} is supported") 878 | await _ensure_adsorbate_supported( 879 | client=client, 880 | adsorbate=adsorbate, 881 | ) 882 | 883 | # Make sure the input bulk is supported in the API 884 | log.info(f"Ensuring that bulk {bulk} is supported") 885 | bulk_obj: Bulk = await _get_bulk_if_supported( 886 | client=client, 887 | bulk=bulk, 888 | ) 889 | 890 | # Fetch all slabs for the bulk 891 | log.info("Generating surfaces") 892 | slabs: List[Slab] = await _get_slabs( 893 | client=client, 894 | bulk=bulk_obj, 895 | ) 896 | 897 | # Find candidate binding sites on each slab 898 | adslabs: List[AdsorbateSlabConfigs] = await _get_adsorbate_configs_on_slabs( 899 | client=client, 900 | adsorbate=adsorbate, 901 | slabs=slabs, 902 | ) 903 | 904 | # Filter the adslabs 905 | adslabs = await adslab_filter(adslabs) 906 | 907 | # Relax the remaining adsorbate placements on all slabs 908 | return await _relax_binding_sites_on_slabs( 909 | client=client, 910 | adsorbate=adsorbate, 911 | bulk=bulk_obj, 912 | adslabs=adslabs, 913 | model=model, 914 | lifetime=lifetime, 915 | ) 916 | -------------------------------------------------------------------------------- /ocpapi/workflows/context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from contextvars import ContextVar 3 | from typing import Any, Generator 4 | 5 | 6 | @contextmanager 7 | def set_context_var( 8 | context_var: ContextVar, 9 | value: Any, 10 | ) -> Generator[None, None, None]: 11 | """ 12 | Sets the input context variable to the input value and yields control 13 | back to the caller.
When control returns to this function, the context 14 | variable is reset to its original value. 15 | 16 | Args: 17 | context_var: The context variable to set. 18 | value: The value to assign to the variable. 19 | """ 20 | token = context_var.set(value) 21 | try: 22 | yield 23 | finally: 24 | context_var.reset(token) 25 | -------------------------------------------------------------------------------- /ocpapi/workflows/filter.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Set, Tuple 2 | 3 | from ocpapi.client import AdsorbateSlabConfigs, SlabMetadata 4 | 5 | 6 | class keep_all_slabs: 7 | """ 8 | Adslab filter that returns all slabs (the input list is returned unchanged). 9 | """ 10 | 11 | async def __call__( 12 | self, 13 | adslabs: List[AdsorbateSlabConfigs], 14 | ) -> List[AdsorbateSlabConfigs]: 15 | return adslabs 16 | 17 | 18 | class keep_slabs_with_miller_indices: 19 | """ 20 | Adslab filter that keeps any slabs with the configured miller indices. 21 | Slabs with other miller indices will be ignored. 22 | """ 23 | 24 | def __init__(self, miller_indices: Iterable[Tuple[int, int, int]]) -> None: 25 | """ 26 | Args: 27 | miller_indices: The list of miller indices that will be allowed. 28 | Slabs with any other miller indices will be dropped by this 29 | filter. 30 | """ 31 | self._unique_millers: Set[Tuple[int, int, int]] = set(miller_indices) 32 | 33 | async def __call__( 34 | self, 35 | adslabs: List[AdsorbateSlabConfigs], 36 | ) -> List[AdsorbateSlabConfigs]: 37 | return [ 38 | adslab 39 | for adslab in adslabs 40 | if adslab.slab.metadata.millers in self._unique_millers 41 | ] 42 | 43 | 44 | class prompt_for_slabs_to_keep: 45 | """ 46 | Adslab filter that presents the user with an interactive prompt to choose 47 | which of the input slabs to keep.
48 | """ 49 | 50 | @staticmethod 51 | def _sort_key( 52 | adslab: AdsorbateSlabConfigs, 53 | ) -> Tuple[Tuple[int, int, int], float, bool]: 54 | """ 55 | Generates a sort key from the input adslab. Returns the miller indices, 56 | shift, and top/bottom flag so that they will be sorted by those values 57 | in that order. 58 | """ 59 | metadata: SlabMetadata = adslab.slab.metadata 60 | return (metadata.millers, metadata.shift, metadata.top) 61 | 62 | async def __call__( 63 | self, 64 | adslabs: List[AdsorbateSlabConfigs], 65 | ) -> List[AdsorbateSlabConfigs]: 66 | from inquirer import Checkbox, prompt 67 | 68 | # Break early if no adslabs were provided 69 | if not adslabs: 70 | return adslabs 71 | 72 | # Sort the input list so the options are grouped in a sensible way 73 | adslabs = sorted(adslabs, key=self._sort_key) 74 | 75 | # List of options to present to the user. The first item in each tuple 76 | # will be presented to the user in the prompt. The second item in each 77 | # tuple (indices into the sorted list of adslabs) will be returned from 78 | # the prompt.
79 | choices: List[Tuple[str, int]] = [ 80 | ( 81 | ( 82 | f"{adslab.slab.metadata.millers} " 83 | f"{'top' if adslab.slab.metadata.top else 'bottom'} " 84 | "surface shifted by " 85 | f"{round(adslab.slab.metadata.shift, 3)}; " 86 | f"{len(adslab.adsorbate_configs)} unique adsorbate " 87 | "placements to relax" 88 | ), 89 | idx, 90 | ) 91 | for idx, adslab in enumerate(adslabs) 92 | ] 93 | checkbox: Checkbox = Checkbox( 94 | "adslabs", 95 | message=( 96 | "Choose surfaces to relax (up/down arrows to move, " 97 | "space to select, enter when finished)" 98 | ), 99 | choices=choices, 100 | ) 101 | selected_indices: List[int] = prompt([checkbox])["adslabs"] 102 | 103 | # Return the adslabs that were chosen. NOTE(review): prompt() above is called synchronously (not awaited), so the event loop is blocked while the user chooses 104 | return [adslabs[i] for i in selected_indices] 105 | -------------------------------------------------------------------------------- /ocpapi/workflows/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | log = logging.getLogger("ocpapi") 4 | -------------------------------------------------------------------------------- /ocpapi/workflows/retry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Any, Literal, Optional, Union 4 | 5 | from tenacity import RetryCallState 6 | from tenacity import retry as tenacity_retry 7 | from tenacity import ( 8 | retry_if_exception_type, 9 | retry_if_not_exception_type, 10 | stop_after_attempt, 11 | stop_never, 12 | wait_fixed, 13 | wait_random, 14 | ) 15 | from tenacity.wait import wait_base 16 | 17 | from ocpapi.client import ( 18 | NonRetryableRequestException, 19 | RateLimitExceededException, 20 | RequestException, 21 | ) 22 | 23 | 24 | @dataclass 25 | class RateLimitLogging: 26 | """ 27 | Controls logging when rate limits are hit. 28 | """ 29 | 30 | logger: logging.Logger 31 | """ 32 | The logger to use.
33 | """ 34 | 35 | action: str 36 | """ 37 | A short description of the action being attempted. 38 | """ 39 | 40 | 41 | class _wait_check_retry_after(wait_base): 42 | """ 43 | Tenacity wait strategy that first checks whether RateLimitExceededException 44 | was raised and that it includes a retry-after value; if so, wait for that 45 | amount of time. Otherwise, fall back to the provided default strategy. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | default_wait: wait_base, 51 | rate_limit_logging: Optional[RateLimitLogging] = None, 52 | ) -> None: 53 | """ 54 | Args: 55 | default_wait: If a retry-after value was not provided in an API 56 | response, use this wait method. 57 | rate_limit_logging: If not None, log statements will be generated 58 | using this configuration when a rate limit is hit. 59 | """ 60 | self._default_wait = default_wait 61 | self._rate_limit_logging = rate_limit_logging 62 | 63 | def __call__(self, retry_state: RetryCallState) -> float: 64 | """ 65 | If a RateLimitExceededException was raised and has a retry_after value, 66 | return it in seconds. Otherwise use the default wait strategy.
67 | """ 68 | exception = retry_state.outcome.exception() 69 | if isinstance(exception, RateLimitExceededException): 70 | if exception.retry_after is not None: 71 | # Log information about the rate limit if needed 72 | wait_for: float = exception.retry_after.total_seconds() 73 | if (l := self._rate_limit_logging) is not None: 74 | l.logger.info( 75 | f"Request to {l.action} was rate limited with " 76 | f"retry-after = {wait_for} seconds" 77 | ) 78 | return wait_for 79 | return self._default_wait(retry_state) 80 | 81 | 82 | NoLimitType = Literal[0] 83 | NO_LIMIT: NoLimitType = 0 84 | 85 | 86 | def retry_api_calls( 87 | max_attempts: Union[int, NoLimitType] = 3, 88 | rate_limit_logging: Optional[RateLimitLogging] = None, 89 | fixed_wait_sec: float = 2, 90 | max_jitter_sec: float = 1, 91 | ) -> Any: 92 | """ 93 | Decorator with sensible defaults for retrying calls to the OCP API. The returned value is a tenacity retry decorator; only RequestException errors that are not NonRetryableRequestException are retried, and the last exception is re-raised unchanged once attempts run out. 94 | 95 | Args: 96 | max_attempts: The maximum number of calls to make. If NO_LIMIT, 97 | retries will be made forever. 98 | rate_limit_logging: If not None, log statements will be generated 99 | using this configuration when a rate limit is hit. 100 | fixed_wait_sec: The fixed number of seconds to wait when retrying an 101 | exception that does *not* include a retry-after value. The default 102 | value is sensible; this is exposed mostly for testing. 103 | max_jitter_sec: The maximum number of seconds that will be randomly 104 | added to wait times. The default value is sensible; this is exposed 105 | mostly for testing. 106 | """ 107 | return tenacity_retry( 108 | # Retry forever if no limit was applied. Otherwise stop after the 109 | # max number of attempts has been made. 
110 | stop=stop_never 111 | if max_attempts == NO_LIMIT 112 | else stop_after_attempt(max_attempts), 113 | # If the API returns that a rate limit was breached and gives a 114 | # retry-after value, use that. Otherwise wait a fixed number of 115 | # seconds. In all cases, add a random jitter.
116 | wait=_wait_check_retry_after( 117 | wait_fixed(fixed_wait_sec), 118 | rate_limit_logging, 119 | ) 120 | + wait_random(0, max_jitter_sec), 121 | # Retry any API exceptions unless they are explicitly marked as 122 | # not retryable. 123 | retry=retry_if_exception_type(RequestException) 124 | & retry_if_not_exception_type(NonRetryableRequestException), 125 | # Raise the original exception instead of wrapping it in a 126 | # tenacity exception 127 | reraise=True, 128 | ) 129 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = ocpapi 3 | version = attr: ocpapi.version.VERSION 4 | author = Open Catalyst Project 5 | author_email = opencatalyst@meta.com 6 | description = Python client library for the Open Catalyst API 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | license_files = LICENSE 10 | classifiers = 11 | Programming Language :: Python :: 3 12 | License :: OSI Approved :: MIT License 13 | Operating System :: OS Independent 14 | 15 | [options] 16 | packages = find: 17 | python_requires = >=3.9 18 | include_package_data = True 19 | install_requires = 20 | requests == 2.31.0 21 | responses == 0.23.2 22 | tenacity == 8.2.3 23 | tqdm == 4.66.1 24 | inquirer == 3.1.3 25 | dataclasses-json == 0.6.0 26 | 27 | [options.extras_require] 28 | dev = 29 | ase == 3.22.1 30 | readchar == 4.0.5 31 | 32 | [options.packages.find] 33 | exclude = 34 | tests* 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Open-Catalyst-Project/ocpapi/03a3277d873459816fdba80885ac275f149420c5/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-Catalyst-Project/ocpapi/03a3277d873459816fdba80885ac275f149420c5/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-Catalyst-Project/ocpapi/03a3277d873459816fdba80885ac275f149420c5/tests/integration/client/__init__.py -------------------------------------------------------------------------------- /tests/integration/client/test_client.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import asynccontextmanager 3 | from typing import AsyncGenerator 4 | from unittest import IsolatedAsyncioTestCase, mock 5 | 6 | from ocpapi.client import ( 7 | Atoms, 8 | Bulk, 9 | Client, 10 | Model, 11 | Slab, 12 | SlabMetadata, 13 | Status, 14 | ) 15 | 16 | log = logging.getLogger(__name__) 17 | 18 | 19 | @asynccontextmanager 20 | async def _ensure_system_deleted( 21 | client: Client, 22 | system_id: str, 23 | ) -> AsyncGenerator[None, None]: 24 | """ 25 | Immediately yields control to the caller. When control returns to this 26 | function, try to delete the system with the input id. 27 | """ 28 | try: 29 | yield 30 | finally: 31 | await client.delete_adsorbate_slab_relaxations(system_id) 32 | 33 | 34 | class TestClient(IsolatedAsyncioTestCase): 35 | """ 36 | Tests that calls to a real server are handled correctly. 
37 | """ 38 | 39 | CLIENT: Client = Client( 40 | host="open-catalyst-api.metademolab.com", 41 | scheme="https", 42 | ) 43 | KNOWN_SYSTEM_ID: str = "f9eacd8f-748c-41dd-ae43-f263dd36d735" 44 | 45 | async def test_get_models(self) -> None: 46 | # Make sure that at least one of the known models is in the response 47 | 48 | response = await self.CLIENT.get_models() 49 | 50 | self.assertIn( 51 | Model(id="equiformer_v2_31M_s2ef_all_md"), 52 | response.models, 53 | ) 54 | 55 | async def test_get_bulks(self) -> None: 56 | # Make sure that at least one of the expected bulks is in the response 57 | 58 | response = await self.CLIENT.get_bulks() 59 | 60 | self.assertIn( 61 | Bulk(src_id="mp-149", elements=["Si"], formula="Si"), 62 | response.bulks_supported, 63 | ) 64 | 65 | async def test_get_adsorbates(self) -> None: 66 | # Make sure that at least one of the expected adsorbates is in the 67 | # response 68 | 69 | response = await self.CLIENT.get_adsorbates() 70 | 71 | self.assertIn("*CO", response.adsorbates_supported) 72 | 73 | async def test_get_slabs(self) -> None: 74 | # Make sure that at least one of the expected slabs is in the response 75 | 76 | response = await self.CLIENT.get_slabs("mp-149") 77 | 78 | self.assertIn( 79 | Slab( 80 | # Don't worry about checking the specific values in the 81 | # returned structure. This could be unstable if the code 82 | # on the server changes and we don't necessarily care here 83 | # what each value is. 
84 | atoms=mock.ANY, 85 | metadata=SlabMetadata( 86 | bulk_src_id="mp-149", 87 | millers=(1, 1, 1), 88 | shift=0.125, 89 | top=True, 90 | ), 91 | ), 92 | response.slabs, 93 | ) 94 | 95 | async def test_get_adsorbate_slab_configs(self) -> None: 96 | # Make sure that adsorbate placements are generated for a slab 97 | # and adsorbate combination that is known to be supported 98 | 99 | response = await self.CLIENT.get_adsorbate_slab_configs( 100 | adsorbate="*CO", 101 | slab=Slab( 102 | atoms=Atoms( 103 | cell=( 104 | (11.6636, 0, 0), 105 | (-5.8318, 10.1010, 0), 106 | (0, 0, 38.0931), 107 | ), 108 | pbc=(True, True, True), 109 | numbers=[14] * 54, 110 | tags=[0] * 54, 111 | positions=[ 112 | (1.9439, 1.1223, 17.0626), 113 | (-0.0, 0.0, 20.237), 114 | (-0.0, 2.2447, 23.4114), 115 | (1.9439, 1.1223, 14.6817), 116 | (3.8879, 0.0, 17.8562), 117 | (-0.0, 2.2447, 21.0306), 118 | (-0.0, 4.4893, 17.0626), 119 | (-1.9439, 3.367, 20.237), 120 | (-1.9439, 5.6117, 23.4114), 121 | (-0.0, 4.4893, 14.6817), 122 | (1.9439, 3.367, 17.8562), 123 | (-1.9439, 5.6117, 21.0306), 124 | (-1.9439, 7.8563, 17.0626), 125 | (-3.8879, 6.734, 20.237), 126 | (-3.8879, 8.9786, 23.4114), 127 | (-1.9439, 7.8563, 14.6817), 128 | (-0.0, 6.734, 17.8562), 129 | (-3.8879, 8.9786, 21.0306), 130 | (5.8318, 1.1223, 17.0626), 131 | (3.8879, 0.0, 20.237), 132 | (3.8879, 2.2447, 23.4114), 133 | (5.8318, 1.1223, 14.6817), 134 | (7.7757, 0.0, 17.8562), 135 | (3.8879, 2.2447, 21.0306), 136 | (3.8879, 4.4893, 17.0626), 137 | (1.9439, 3.367, 20.237), 138 | (1.9439, 5.6117, 23.4114), 139 | (3.8879, 4.4893, 14.6817), 140 | (5.8318, 3.367, 17.8562), 141 | (1.9439, 5.6117, 21.0306), 142 | (1.9439, 7.8563, 17.0626), 143 | (-0.0, 6.734, 20.237), 144 | (-0.0, 8.9786, 23.4114), 145 | (1.9439, 7.8563, 14.6817), 146 | (3.8879, 6.734, 17.8562), 147 | (-0.0, 8.9786, 21.0306), 148 | (9.7197, 1.1223, 17.0626), 149 | (7.7757, 0.0, 20.237), 150 | (7.7757, 2.2447, 23.4114), 151 | (9.7197, 1.1223, 14.6817), 152 | (11.6636, 0.0, 
17.8562), 153 | (7.7757, 2.2447, 21.0306), 154 | (7.7757, 4.4893, 17.0626), 155 | (5.8318, 3.367, 20.237), 156 | (5.8318, 5.6117, 23.4114), 157 | (7.7757, 4.4893, 14.6817), 158 | (9.7197, 3.367, 17.8562), 159 | (5.8318, 5.6117, 21.0306), 160 | (5.8318, 7.8563, 17.0626), 161 | (3.8879, 6.734, 20.237), 162 | (3.8879, 8.9786, 23.4114), 163 | (5.8318, 7.8563, 14.6817), 164 | (7.7757, 6.734, 17.8562), 165 | (3.8879, 8.9786, 21.0306), 166 | ], 167 | ), 168 | metadata=SlabMetadata( 169 | bulk_src_id="mp-149", 170 | millers=(1, 1, 1), 171 | shift=0.125, 172 | top=True, 173 | ), 174 | ), 175 | ) 176 | 177 | self.assertGreater(len(response.adsorbate_configs), 10) 178 | 179 | async def test_submit_adsorbate_slab_relaxations__gemnet_oc(self) -> None: 180 | # Make sure that a relaxation can be started for an adsorbate 181 | # placement on a slab with the gemnet oc model 182 | 183 | response = await self.CLIENT.submit_adsorbate_slab_relaxations( 184 | adsorbate="*CO", 185 | adsorbate_configs=[ 186 | Atoms( 187 | cell=( 188 | (11.6636, 0, 0), 189 | (-5.8318, 10.1010, 0), 190 | (0, 0, 38.0931), 191 | ), 192 | pbc=(True, True, False), 193 | numbers=[6, 8], 194 | tags=[2, 2], 195 | positions=[ 196 | (1.9439, 3.3670, 22.2070), 197 | (1.9822, 3.2849, 23.3697), 198 | ], 199 | ) 200 | ], 201 | bulk=Bulk(src_id="mp-149", elements=["Si"], formula="Si"), 202 | slab=Slab( 203 | atoms=Atoms( 204 | cell=( 205 | (11.6636, 0, 0), 206 | (-5.8318, 10.1010, 0), 207 | (0, 0, 38.0931), 208 | ), 209 | pbc=(True, True, True), 210 | numbers=[14] * 54, 211 | tags=[0] * 54, 212 | positions=[ 213 | (1.9439, 1.1223, 17.0626), 214 | (-0.0, 0.0, 20.237), 215 | (-0.0, 2.2447, 23.4114), 216 | (1.9439, 1.1223, 14.6817), 217 | (3.8879, 0.0, 17.8562), 218 | (-0.0, 2.2447, 21.0306), 219 | (-0.0, 4.4893, 17.0626), 220 | (-1.9439, 3.367, 20.237), 221 | (-1.9439, 5.6117, 23.4114), 222 | (-0.0, 4.4893, 14.6817), 223 | (1.9439, 3.367, 17.8562), 224 | (-1.9439, 5.6117, 21.0306), 225 | (-1.9439, 7.8563, 17.0626), 226 | 
(-3.8879, 6.734, 20.237), 227 | (-3.8879, 8.9786, 23.4114), 228 | (-1.9439, 7.8563, 14.6817), 229 | (-0.0, 6.734, 17.8562), 230 | (-3.8879, 8.9786, 21.0306), 231 | (5.8318, 1.1223, 17.0626), 232 | (3.8879, 0.0, 20.237), 233 | (3.8879, 2.2447, 23.4114), 234 | (5.8318, 1.1223, 14.6817), 235 | (7.7757, 0.0, 17.8562), 236 | (3.8879, 2.2447, 21.0306), 237 | (3.8879, 4.4893, 17.0626), 238 | (1.9439, 3.367, 20.237), 239 | (1.9439, 5.6117, 23.4114), 240 | (3.8879, 4.4893, 14.6817), 241 | (5.8318, 3.367, 17.8562), 242 | (1.9439, 5.6117, 21.0306), 243 | (1.9439, 7.8563, 17.0626), 244 | (-0.0, 6.734, 20.237), 245 | (-0.0, 8.9786, 23.4114), 246 | (1.9439, 7.8563, 14.6817), 247 | (3.8879, 6.734, 17.8562), 248 | (-0.0, 8.9786, 21.0306), 249 | (9.7197, 1.1223, 17.0626), 250 | (7.7757, 0.0, 20.237), 251 | (7.7757, 2.2447, 23.4114), 252 | (9.7197, 1.1223, 14.6817), 253 | (11.6636, 0.0, 17.8562), 254 | (7.7757, 2.2447, 21.0306), 255 | (7.7757, 4.4893, 17.0626), 256 | (5.8318, 3.367, 20.237), 257 | (5.8318, 5.6117, 23.4114), 258 | (7.7757, 4.4893, 14.6817), 259 | (9.7197, 3.367, 17.8562), 260 | (5.8318, 5.6117, 21.0306), 261 | (5.8318, 7.8563, 17.0626), 262 | (3.8879, 6.734, 20.237), 263 | (3.8879, 8.9786, 23.4114), 264 | (5.8318, 7.8563, 14.6817), 265 | (7.7757, 6.734, 17.8562), 266 | (3.8879, 8.9786, 21.0306), 267 | ], 268 | ), 269 | metadata=SlabMetadata( 270 | bulk_src_id="mp-149", 271 | millers=(1, 1, 1), 272 | shift=0.125, 273 | top=True, 274 | ), 275 | ), 276 | model="gemnet_oc_base_s2ef_all_md", 277 | ephemeral=True, 278 | ) 279 | 280 | async with _ensure_system_deleted(self.CLIENT, response.system_id): 281 | self.assertNotEqual(response.system_id, "") 282 | self.assertEqual(len(response.config_ids), 1) 283 | 284 | async def test_submit_adsorbate_slab_relaxations__equiformer_v2(self) -> None: 285 | # Make sure that a relaxation can be started for an adsorbate 286 | # placement on a slab with the equiformer v2 model 287 | 288 | response = await 
self.CLIENT.submit_adsorbate_slab_relaxations( 289 | adsorbate="*CO", 290 | adsorbate_configs=[ 291 | Atoms( 292 | cell=( 293 | (11.6636, 0, 0), 294 | (-5.8318, 10.1010, 0), 295 | (0, 0, 38.0931), 296 | ), 297 | pbc=(True, True, False), 298 | numbers=[6, 8], 299 | tags=[2, 2], 300 | positions=[ 301 | (1.9439, 3.3670, 22.2070), 302 | (1.9822, 3.2849, 23.3697), 303 | ], 304 | ) 305 | ], 306 | bulk=Bulk(src_id="mp-149", elements=["Si"], formula="Si"), 307 | slab=Slab( 308 | atoms=Atoms( 309 | cell=( 310 | (11.6636, 0, 0), 311 | (-5.8318, 10.1010, 0), 312 | (0, 0, 38.0931), 313 | ), 314 | pbc=(True, True, True), 315 | numbers=[14] * 54, 316 | tags=[0] * 54, 317 | positions=[ 318 | (1.9439, 1.1223, 17.0626), 319 | (-0.0, 0.0, 20.237), 320 | (-0.0, 2.2447, 23.4114), 321 | (1.9439, 1.1223, 14.6817), 322 | (3.8879, 0.0, 17.8562), 323 | (-0.0, 2.2447, 21.0306), 324 | (-0.0, 4.4893, 17.0626), 325 | (-1.9439, 3.367, 20.237), 326 | (-1.9439, 5.6117, 23.4114), 327 | (-0.0, 4.4893, 14.6817), 328 | (1.9439, 3.367, 17.8562), 329 | (-1.9439, 5.6117, 21.0306), 330 | (-1.9439, 7.8563, 17.0626), 331 | (-3.8879, 6.734, 20.237), 332 | (-3.8879, 8.9786, 23.4114), 333 | (-1.9439, 7.8563, 14.6817), 334 | (-0.0, 6.734, 17.8562), 335 | (-3.8879, 8.9786, 21.0306), 336 | (5.8318, 1.1223, 17.0626), 337 | (3.8879, 0.0, 20.237), 338 | (3.8879, 2.2447, 23.4114), 339 | (5.8318, 1.1223, 14.6817), 340 | (7.7757, 0.0, 17.8562), 341 | (3.8879, 2.2447, 21.0306), 342 | (3.8879, 4.4893, 17.0626), 343 | (1.9439, 3.367, 20.237), 344 | (1.9439, 5.6117, 23.4114), 345 | (3.8879, 4.4893, 14.6817), 346 | (5.8318, 3.367, 17.8562), 347 | (1.9439, 5.6117, 21.0306), 348 | (1.9439, 7.8563, 17.0626), 349 | (-0.0, 6.734, 20.237), 350 | (-0.0, 8.9786, 23.4114), 351 | (1.9439, 7.8563, 14.6817), 352 | (3.8879, 6.734, 17.8562), 353 | (-0.0, 8.9786, 21.0306), 354 | (9.7197, 1.1223, 17.0626), 355 | (7.7757, 0.0, 20.237), 356 | (7.7757, 2.2447, 23.4114), 357 | (9.7197, 1.1223, 14.6817), 358 | (11.6636, 0.0, 17.8562), 359 | 
(7.7757, 2.2447, 21.0306), 360 | (7.7757, 4.4893, 17.0626), 361 | (5.8318, 3.367, 20.237), 362 | (5.8318, 5.6117, 23.4114), 363 | (7.7757, 4.4893, 14.6817), 364 | (9.7197, 3.367, 17.8562), 365 | (5.8318, 5.6117, 21.0306), 366 | (5.8318, 7.8563, 17.0626), 367 | (3.8879, 6.734, 20.237), 368 | (3.8879, 8.9786, 23.4114), 369 | (5.8318, 7.8563, 14.6817), 370 | (7.7757, 6.734, 17.8562), 371 | (3.8879, 8.9786, 21.0306), 372 | ], 373 | ), 374 | metadata=SlabMetadata( 375 | bulk_src_id="mp-149", 376 | millers=(1, 1, 1), 377 | shift=0.125, 378 | top=True, 379 | ), 380 | ), 381 | model="equiformer_v2_31M_s2ef_all_md", 382 | ephemeral=True, 383 | ) 384 | 385 | async with _ensure_system_deleted(self.CLIENT, response.system_id): 386 | self.assertNotEqual(response.system_id, "") 387 | self.assertEqual(len(response.config_ids), 1) 388 | 389 | async def test_get_adsorbate_slab_relaxations_request(self) -> None: 390 | # Make sure the original request can be fetched for an already- 391 | # submitted system. 392 | 393 | response = await self.CLIENT.get_adsorbate_slab_relaxations_request( 394 | system_id=self.KNOWN_SYSTEM_ID 395 | ) 396 | 397 | # Don't worry about checking all fields - just make sure at least one 398 | # of the expected fields was returned 399 | self.assertEqual(response.adsorbate, "*CO") 400 | 401 | async def test_get_adsorbate_slab_relaxations_results__all_fields_and_configs( 402 | self, 403 | ) -> None: 404 | # Make sure relaxation results can be fetched for an already-relaxed 405 | # system. Check that all configurations and all fields for each are 406 | # returned. 
407 | 408 | response = await self.CLIENT.get_adsorbate_slab_relaxations_results( 409 | system_id=self.KNOWN_SYSTEM_ID, 410 | ) 411 | 412 | self.assertEqual(len(response.configs), 59) 413 | for config in response.configs: 414 | self.assertEqual(config.status, Status.SUCCESS) 415 | self.assertIsNotNone(config.system_id) 416 | self.assertIsNotNone(config.cell) 417 | self.assertIsNotNone(config.pbc) 418 | self.assertIsNotNone(config.numbers) 419 | self.assertIsNotNone(config.positions) 420 | self.assertIsNotNone(config.tags) 421 | self.assertIsNotNone(config.energy) 422 | self.assertIsNotNone(config.energy_trajectory) 423 | self.assertIsNotNone(config.forces) 424 | config_ids = {c.config_id for c in response.configs} 425 | self.assertEqual(config_ids, set(range(59))) 426 | 427 | async def test_get_adsorbate_slab_relaxations_results__limited_fields_and_configs( 428 | self, 429 | ) -> None: 430 | # Make sure relaxation results can be fetched for an already-relaxed 431 | # system. Check that only the requested configurations and fields are 432 | # returned. 
433 | 434 | response = await self.CLIENT.get_adsorbate_slab_relaxations_results( 435 | system_id=self.KNOWN_SYSTEM_ID, 436 | config_ids=[10, 20, 30], 437 | fields=["energy", "cell"], 438 | ) 439 | 440 | self.assertEqual(len(response.configs), 3) 441 | for config in response.configs: 442 | self.assertEqual(config.status, Status.SUCCESS) 443 | self.assertIsNone(config.system_id) 444 | self.assertIsNotNone(config.cell) 445 | self.assertIsNone(config.pbc) 446 | self.assertIsNone(config.numbers) 447 | self.assertIsNone(config.positions) 448 | self.assertIsNone(config.tags) 449 | self.assertIsNotNone(config.energy) 450 | self.assertIsNone(config.energy_trajectory) 451 | self.assertIsNone(config.forces) 452 | config_ids = {c.config_id for c in response.configs} 453 | self.assertEqual(config_ids, {10, 20, 30}) 454 | -------------------------------------------------------------------------------- /tests/integration/client/test_ui.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase as UnitTestCase 2 | 3 | import requests 4 | 5 | from ocpapi.client import get_results_ui_url 6 | 7 | 8 | class TestUI(UnitTestCase): 9 | """ 10 | Tests that calls to a real server are handled correctly. 
11 | """ 12 | 13 | API_HOST: str = "open-catalyst-api.metademolab.com" 14 | KNOWN_SYSTEM_ID: str = "f9eacd8f-748c-41dd-ae43-f263dd36d735" 15 | 16 | def test_get_results_ui_url(self) -> None: 17 | # Make sure the UI URL is reachable 18 | 19 | ui_url = get_results_ui_url(self.API_HOST, self.KNOWN_SYSTEM_ID) 20 | response = requests.head(ui_url) 21 | 22 | self.assertEqual(200, response.status_code) 23 | -------------------------------------------------------------------------------- /tests/integration/workflows/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-Catalyst-Project/ocpapi/03a3277d873459816fdba80885ac275f149420c5/tests/integration/workflows/__init__.py -------------------------------------------------------------------------------- /tests/integration/workflows/test_adsorbates.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List 3 | from unittest import IsolatedAsyncioTestCase 4 | 5 | import requests 6 | 7 | from ocpapi.client import AdsorbateSlabConfigs, Client, Status 8 | from ocpapi.workflows import ( 9 | Lifetime, 10 | find_adsorbate_binding_sites, 11 | get_adsorbate_slab_relaxation_results, 12 | wait_for_adsorbate_slab_relaxations, 13 | ) 14 | 15 | 16 | class TestAdsorbates(IsolatedAsyncioTestCase): 17 | """ 18 | Tests that workflow methods run against a real server execute correctly. 19 | """ 20 | 21 | CLIENT: Client = Client( 22 | host="open-catalyst-api.metademolab.com", 23 | scheme="https", 24 | ) 25 | KNOWN_SYSTEM_ID: str = "f9eacd8f-748c-41dd-ae43-f263dd36d735" 26 | 27 | async def test_get_adsorbate_slab_relaxation_results(self) -> None: 28 | # The server is expected to omit some results when too many are 29 | # requested. Check that all results are fetched since test method 30 | # under test should retry until all results have been retrieved. 
31 | 32 | # The system under test has 59 configs: 33 | # https://open-catalyst.metademolab.com/results/f9eacd8f-748c-41dd-ae43-f263dd36d735 34 | num_configs = 59 35 | 36 | results = await get_adsorbate_slab_relaxation_results( 37 | system_id=self.KNOWN_SYSTEM_ID, 38 | config_ids=list(range(num_configs)), 39 | # Fetch a subset of fields to avoid transferring significantly more 40 | # data than we really need in this test 41 | fields=["energy", "pbc"], 42 | client=self.CLIENT, 43 | ) 44 | 45 | self.assertEqual( 46 | [r.status for r in results], 47 | [Status.SUCCESS] * num_configs, 48 | ) 49 | 50 | async def test_wait_for_adsorbate_slab_relaxations(self) -> None: 51 | # This test runs against an already-finished set of relaxations. 52 | # The goal is not to check that the method waits when relaxations 53 | # are still running (that is covered in unit tests), but just to 54 | # ensure that the call to the API is made correctly and that the 55 | # function returns ~immediately because the relaxations are done. 56 | 57 | start = time.monotonic() 58 | 59 | await wait_for_adsorbate_slab_relaxations( 60 | system_id=self.KNOWN_SYSTEM_ID, 61 | check_immediately=False, 62 | slow_interval_sec=1, 63 | fast_interval_sec=1, 64 | client=self.CLIENT, 65 | ) 66 | 67 | took = time.monotonic() - start 68 | self.assertGreaterEqual(took, 1) 69 | # Give a pretty generous upper bound so that this test is not flaky 70 | # when there is a poor connection or the server is busy 71 | self.assertLess(took, 5) 72 | 73 | async def test_find_adsorbate_binding_sites(self) -> None: 74 | # Run an end-to-end test to find adsorbate binding sites on the 75 | # surface of a bulk material. 76 | 77 | # By default, we'll end up running relaxations for dozens of adsorbate 78 | # placements on the bulk surface. This function selects out only the 79 | # first adsorbate configuration. 
This lets us run a smaller number of 80 | # relaxations since we really don't need to run dozens just to know 81 | # that the method under test works. 82 | async def _keep_first_adslab( 83 | adslabs: List[AdsorbateSlabConfigs], 84 | ) -> List[AdsorbateSlabConfigs]: 85 | return [ 86 | AdsorbateSlabConfigs( 87 | adsorbate_configs=adslabs[0].adsorbate_configs[:1], 88 | slab=adslabs[0].slab, 89 | ) 90 | ] 91 | 92 | results = await find_adsorbate_binding_sites( 93 | adsorbate="*O", 94 | bulk="mp-30", 95 | model="gemnet_oc_base_s2ef_all_md", 96 | adslab_filter=_keep_first_adslab, 97 | client=self.CLIENT, 98 | # Since this is a test, delete the relaxations from the server 99 | # once results have been fetched. 100 | lifetime=Lifetime.DELETE, 101 | ) 102 | 103 | self.assertEqual(1, len(results.slabs)) 104 | self.assertEqual(1, len(results.slabs[0].configs)) 105 | self.assertEqual(Status.SUCCESS, results.slabs[0].configs[0].status) 106 | 107 | # Make sure that the adslabs being used have tags for sub-surface, 108 | # surface, and adsorbate atoms. Then make sure that forces are 109 | # exactly zero only for the sub-surface atoms. 
110 | config = results.slabs[0].configs[0] 111 | self.assertEqual( 112 | {0, 1, 2}, 113 | set(config.tags), 114 | "Expected tags for surface, sub-surface, and adsorbate atoms", 115 | ) 116 | for tag, forces in zip(config.tags, config.forces): 117 | if tag == 0: # Sub-surface atoms are fixed / have 0 forces 118 | self.assertEqual(forces, (0, 0, 0)) 119 | else: 120 | self.assertNotEqual(forces, (0, 0, 0)) 121 | 122 | # Make sure the UI URL is reachable 123 | response = requests.head(results.slabs[0].ui_url) 124 | self.assertEqual(200, response.status_code) 125 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-Catalyst-Project/ocpapi/03a3277d873459816fdba80885ac275f149420c5/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-Catalyst-Project/ocpapi/03a3277d873459816fdba80885ac275f149420c5/tests/unit/client/__init__.py -------------------------------------------------------------------------------- /tests/unit/client/test_client.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import timedelta 3 | from typing import Any, Dict, List, Optional, Union 4 | from unittest import IsolatedAsyncioTestCase 5 | 6 | import responses 7 | 8 | from ocpapi.client import ( 9 | Adsorbates, 10 | AdsorbateSlabConfigs, 11 | AdsorbateSlabRelaxationResult, 12 | AdsorbateSlabRelaxationsRequest, 13 | AdsorbateSlabRelaxationsResults, 14 | AdsorbateSlabRelaxationsSystem, 15 | Atoms, 16 | Bulk, 17 | Bulks, 18 | Client, 19 | Model, 20 | Models, 21 | NonRetryableRequestException, 22 | RateLimitExceededException, 23 | 
RequestException, 24 | Slab, 25 | SlabMetadata, 26 | Slabs, 27 | Status, 28 | ) 29 | from ocpapi.client.models import _DataModel 30 | 31 | 32 | class TestClient(IsolatedAsyncioTestCase): 33 | """ 34 | Tests with mocked responses to ensure that they are handled correctly. 35 | """ 36 | 37 | async def _run_common_tests_against_route( 38 | self, 39 | method: str, 40 | route: str, 41 | client_method_name: str, 42 | successful_response_code: int, 43 | successful_response_body: str, 44 | successful_response_object: Optional[_DataModel], 45 | client_method_args: Optional[Dict[str, Any]] = None, 46 | expected_request_params: Optional[Dict[str, Any]] = None, 47 | expected_request_body: Optional[Dict[str, Any]] = None, 48 | ) -> None: 49 | @dataclass 50 | class TestCase: 51 | message: str 52 | scheme: str 53 | host: str 54 | response_body: Union[str, Exception] 55 | response_code: int 56 | response_headers: Optional[Dict[str, str]] = None 57 | expected: Optional[_DataModel] = None 58 | expected_request_params: Optional[Dict[str, Any]] = None 59 | expected_request_body: Optional[Dict[str, Any]] = None 60 | expected_exception: Optional[Exception] = None 61 | 62 | test_cases: List[TestCase] = [ 63 | # If a 429 response code is returned, then a 64 | # RateLimitExceededException should be raised 65 | TestCase( 66 | message="rate limit exceeded", 67 | scheme="https", 68 | host="test_host", 69 | response_body='{"message": "failed"}', 70 | response_code=429, 71 | response_headers={"Retry-After": "100"}, 72 | expected_request_params=expected_request_params, 73 | expected_request_body=expected_request_body, 74 | expected_exception=RateLimitExceededException( 75 | method=method, 76 | url=f"https://test_host/{route}", 77 | retry_after=timedelta(seconds=100), 78 | ), 79 | ), 80 | # If a 429 response code is returned, then a 81 | # RateLimitExceededException should be raised - ensure correct 82 | # handling when retry-after header is not present 83 | TestCase( 84 | message="rate limit 
exceeded, no retry-after", 85 | scheme="https", 86 | host="test_host", 87 | response_body='{"message": "failed"}', 88 | response_code=429, 89 | response_headers={}, 90 | expected_request_params=expected_request_params, 91 | expected_request_body=expected_request_body, 92 | expected_exception=RateLimitExceededException( 93 | method=method, 94 | url=f"https://test_host/{route}", 95 | retry_after=None, 96 | ), 97 | ), 98 | # If a 400-level response code is returned then a 99 | # NonRetryableRequestException should be raised 100 | TestCase( 101 | message="non-retryable error", 102 | scheme="https", 103 | host="test_host", 104 | response_body='{"message": "failed"}', 105 | response_code=404, 106 | response_headers={}, 107 | expected_request_params=expected_request_params, 108 | expected_request_body=expected_request_body, 109 | expected_exception=NonRetryableRequestException( 110 | method=method, 111 | url=f"https://test_host/{route}", 112 | cause=( 113 | "Unexpected response code: 404. " 114 | 'Response body: {"message": "failed"}' 115 | ), 116 | ), 117 | ), 118 | # If another unexpected response code is returned then an exception 119 | # should be raised 120 | TestCase( 121 | message="non-200 response code", 122 | scheme="https", 123 | host="test_host", 124 | response_body='{"message": "failed"}', 125 | response_code=500, 126 | expected_request_params=expected_request_params, 127 | expected_request_body=expected_request_body, 128 | expected_exception=RequestException( 129 | method=method, 130 | url=f"https://test_host/{route}", 131 | cause=( 132 | "Unexpected response code: 500. 
" 133 | 'Response body: {"message": "failed"}' 134 | ), 135 | ), 136 | ), 137 | # If an exception is raised from within requests, it should be 138 | # re-raised in the client 139 | TestCase( 140 | message="exception in request handling", 141 | scheme="https", 142 | host="test_host", 143 | # This tells the responses library to raise an exception 144 | response_body=Exception("exception message"), 145 | response_code=successful_response_code, 146 | expected_request_params=expected_request_params, 147 | expected_request_body=expected_request_body, 148 | expected_exception=RequestException( 149 | method=method, 150 | url=f"https://test_host/{route}", 151 | cause=( 152 | "Exception while making request: " 153 | "Exception: exception message" 154 | ), 155 | ), 156 | ), 157 | # If the request is successful then data should be saved in 158 | # the response object 159 | TestCase( 160 | message="response with data", 161 | scheme="https", 162 | host="test_host", 163 | response_body=successful_response_body, 164 | response_code=successful_response_code, 165 | expected=successful_response_object, 166 | expected_request_params=expected_request_params, 167 | expected_request_body=expected_request_body, 168 | ), 169 | ] 170 | 171 | for case in test_cases: 172 | with self.subTest(msg=case.message): 173 | # Match the request body if one is expected 174 | match = [] 175 | if case.expected_request_body is not None: 176 | match.append( 177 | responses.matchers.json_params_matcher( 178 | case.expected_request_body 179 | ) 180 | ) 181 | if case.expected_request_params is not None: 182 | match.append( 183 | responses.matchers.query_param_matcher( 184 | case.expected_request_params 185 | ) 186 | ) 187 | 188 | # Mock the response to the request in the current test case 189 | with responses.RequestsMock() as mock_responses: 190 | mock_responses.add( 191 | method, 192 | f"{case.scheme}://{case.host}/{route}", 193 | body=case.response_body, 194 | headers=case.response_headers, 195 | 
status=case.response_code, 196 | match=match, 197 | ) 198 | 199 | # Create the coroutine that will run the request 200 | client = Client(scheme=case.scheme, host=case.host) 201 | request_method = getattr(client, client_method_name) 202 | args = client_method_args if client_method_args else {} 203 | request_coro = request_method(**args) 204 | 205 | # Ensure that an exception is raised if one is expected 206 | if case.expected_exception is not None: 207 | with self.assertRaises(type(case.expected_exception)) as ex: 208 | await request_coro 209 | self.assertEqual( 210 | vars(case.expected_exception), 211 | vars(ex.exception), 212 | ) 213 | self.assertEqual( 214 | str(case.expected_exception), 215 | str(ex.exception), 216 | ) 217 | 218 | # If an exception is not expected, then make sure the 219 | # response is correct 220 | else: 221 | response = await request_coro 222 | self.assertEqual(response, case.expected) 223 | 224 | def test_host(self) -> None: 225 | client = Client(host="test-host") 226 | self.assertEqual("test-host", client.host) 227 | 228 | async def test_get_models(self) -> None: 229 | await self._run_common_tests_against_route( 230 | method="GET", 231 | route="ocp/models", 232 | client_method_name="get_models", 233 | successful_response_code=200, 234 | successful_response_body=""" 235 | { 236 | "models": [ 237 | { 238 | "id": "model_1" 239 | }, 240 | { 241 | "id": "model_2" 242 | } 243 | ] 244 | } 245 | """, 246 | successful_response_object=Models( 247 | models=[ 248 | Model(id="model_1"), 249 | Model(id="model_2"), 250 | ] 251 | ), 252 | ) 253 | 254 | async def test_get_bulks(self) -> None: 255 | await self._run_common_tests_against_route( 256 | method="GET", 257 | route="ocp/bulks", 258 | client_method_name="get_bulks", 259 | successful_response_code=200, 260 | successful_response_body=""" 261 | { 262 | "bulks_supported": [ 263 | { 264 | "src_id": "1", 265 | "els": ["A", "B"], 266 | "formula": "AB2" 267 | }, 268 | { 269 | "src_id": "2", 270 | "els": 
["C"], 271 | "formula": "C60" 272 | } 273 | ] 274 | } 275 | """, 276 | successful_response_object=Bulks( 277 | bulks_supported=[ 278 | Bulk( 279 | src_id="1", 280 | elements=["A", "B"], 281 | formula="AB2", 282 | ), 283 | Bulk( 284 | src_id="2", 285 | elements=["C"], 286 | formula="C60", 287 | ), 288 | ], 289 | ), 290 | ) 291 | 292 | async def test_get_adsorbates(self) -> None: 293 | await self._run_common_tests_against_route( 294 | method="GET", 295 | route="ocp/adsorbates", 296 | client_method_name="get_adsorbates", 297 | successful_response_code=200, 298 | successful_response_body=""" 299 | { 300 | "adsorbates_supported": ["A", "B"] 301 | } 302 | """, 303 | successful_response_object=Adsorbates( 304 | adsorbates_supported=["A", "B"], 305 | ), 306 | ) 307 | 308 | async def test_get_slabs__bulk_by_id(self) -> None: 309 | await self._run_common_tests_against_route( 310 | method="POST", 311 | route="ocp/slabs", 312 | client_method_name="get_slabs", 313 | client_method_args={"bulk": "test_id"}, 314 | expected_request_body={"bulk_src_id": "test_id"}, 315 | successful_response_code=200, 316 | successful_response_body=""" 317 | { 318 | "slabs": [{ 319 | "slab_atomsobject": { 320 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 321 | "pbc": [true, false, true], 322 | "numbers": [1, 2], 323 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 324 | "tags": [0, 1] 325 | }, 326 | "slab_metadata": { 327 | "bulk_id": "test_id", 328 | "millers": [-1, 0, 1], 329 | "shift": 0.25, 330 | "top": false 331 | } 332 | }] 333 | } 334 | """, 335 | successful_response_object=Slabs( 336 | slabs=[ 337 | Slab( 338 | atoms=Atoms( 339 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 340 | pbc=(True, False, True), 341 | numbers=[1, 2], 342 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 343 | tags=[0, 1], 344 | ), 345 | metadata=SlabMetadata( 346 | bulk_src_id="test_id", 347 | millers=(-1, 0, 1), 348 | shift=0.25, 349 | top=False, 350 | ), 351 | ) 352 | ], 353 | ), 354 | 
) 355 | 356 | async def test_get_slabs__bulk_by_obj(self) -> None: 357 | await self._run_common_tests_against_route( 358 | method="POST", 359 | route="ocp/slabs", 360 | client_method_name="get_slabs", 361 | client_method_args={ 362 | "bulk": Bulk( 363 | src_id="test_id", 364 | formula="AB", 365 | elements=["A", "B"], 366 | ) 367 | }, 368 | expected_request_body={"bulk_src_id": "test_id"}, 369 | successful_response_code=200, 370 | successful_response_body=""" 371 | { 372 | "slabs": [{ 373 | "slab_atomsobject": { 374 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 375 | "pbc": [true, false, true], 376 | "numbers": [1, 2], 377 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 378 | "tags": [0, 1] 379 | }, 380 | "slab_metadata": { 381 | "bulk_id": "test_id", 382 | "millers": [-1, 0, 1], 383 | "shift": 0.25, 384 | "top": false 385 | } 386 | }] 387 | } 388 | """, 389 | successful_response_object=Slabs( 390 | slabs=[ 391 | Slab( 392 | atoms=Atoms( 393 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 394 | pbc=(True, False, True), 395 | numbers=[1, 2], 396 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 397 | tags=[0, 1], 398 | ), 399 | metadata=SlabMetadata( 400 | bulk_src_id="test_id", 401 | millers=(-1, 0, 1), 402 | shift=0.25, 403 | top=False, 404 | ), 405 | ) 406 | ], 407 | ), 408 | ) 409 | 410 | async def test_get_adsorbate_slab_configurations(self) -> None: 411 | await self._run_common_tests_against_route( 412 | method="POST", 413 | route="ocp/adsorbate-slab-configs", 414 | client_method_name="get_adsorbate_slab_configs", 415 | client_method_args={ 416 | "adsorbate": "*A", 417 | "slab": Slab( 418 | atoms=Atoms( 419 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 420 | pbc=(True, False, True), 421 | numbers=[1, 2], 422 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 423 | tags=[0, 1], 424 | ), 425 | metadata=SlabMetadata( 426 | bulk_src_id="test_id", 427 | millers=(-1, 0, 1), 428 | shift=0.25, 429 | top=False, 430 | ), 431 | ), 
432 | }, 433 | expected_request_body={ 434 | "adsorbate": "*A", 435 | "slab": { 436 | "slab_atomsobject": { 437 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 438 | "pbc": [True, False, True], 439 | "numbers": [1, 2], 440 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 441 | "tags": [0, 1], 442 | }, 443 | "slab_metadata": { 444 | "bulk_id": "test_id", 445 | "millers": [-1, 0, 1], 446 | "shift": 0.25, 447 | "top": False, 448 | }, 449 | }, 450 | }, 451 | successful_response_code=200, 452 | successful_response_body=""" 453 | { 454 | "adsorbate_configs": [ 455 | { 456 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 457 | "pbc": [true, false, true], 458 | "numbers": [1], 459 | "positions": [[1.1, 1.2, 1.3]], 460 | "tags": [2] 461 | } 462 | ], 463 | "slab": { 464 | "slab_atomsobject": { 465 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 466 | "pbc": [true, false, true], 467 | "numbers": [1, 2], 468 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 469 | "tags": [0, 1] 470 | }, 471 | "slab_metadata": { 472 | "bulk_id": "test_id", 473 | "millers": [-1, 0, 1], 474 | "shift": 0.25, 475 | "top": false 476 | } 477 | } 478 | } 479 | """, 480 | successful_response_object=AdsorbateSlabConfigs( 481 | adsorbate_configs=[ 482 | Atoms( 483 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 484 | pbc=(True, False, True), 485 | numbers=[1], 486 | positions=[(1.1, 1.2, 1.3)], 487 | tags=[2], 488 | ) 489 | ], 490 | slab=Slab( 491 | atoms=Atoms( 492 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 493 | pbc=(True, False, True), 494 | numbers=[1, 2], 495 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 496 | tags=[0, 1], 497 | ), 498 | metadata=SlabMetadata( 499 | bulk_src_id="test_id", 500 | millers=(-1, 0, 1), 501 | shift=0.25, 502 | top=False, 503 | ), 504 | ), 505 | ), 506 | ) 507 | 508 | async def test_submit_adsorbate_slab_relaxations(self) -> None: 509 | await self._run_common_tests_against_route( 510 | 
method="POST", 511 | route="ocp/adsorbate-slab-relaxations", 512 | client_method_name="submit_adsorbate_slab_relaxations", 513 | client_method_args={ 514 | "adsorbate": "*A", 515 | "adsorbate_configs": [ 516 | Atoms( 517 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 518 | pbc=(True, False, True), 519 | numbers=[1], 520 | positions=[(1.1, 1.2, 1.3)], 521 | tags=[2], 522 | ), 523 | ], 524 | "bulk": Bulk( 525 | src_id="test_id", 526 | formula="AB", 527 | elements=["A", "B"], 528 | ), 529 | "slab": Slab( 530 | atoms=Atoms( 531 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 532 | pbc=(True, False, True), 533 | numbers=[1, 2], 534 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 535 | tags=[0, 1], 536 | ), 537 | metadata=SlabMetadata( 538 | bulk_src_id="test_id", 539 | millers=(-1, 0, 1), 540 | shift=0.25, 541 | top=False, 542 | ), 543 | ), 544 | "model": "test_model", 545 | "ephemeral": True, 546 | }, 547 | expected_request_body={ 548 | "adsorbate": "*A", 549 | "adsorbate_configs": [ 550 | { 551 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 552 | "pbc": [True, False, True], 553 | "numbers": [1], 554 | "positions": [[1.1, 1.2, 1.3]], 555 | "tags": [2], 556 | } 557 | ], 558 | "bulk": { 559 | "src_id": "test_id", 560 | "formula": "AB", 561 | "els": ["A", "B"], 562 | }, 563 | "slab": { 564 | "slab_atomsobject": { 565 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 566 | "pbc": [True, False, True], 567 | "numbers": [1, 2], 568 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 569 | "tags": [0, 1], 570 | }, 571 | "slab_metadata": { 572 | "bulk_id": "test_id", 573 | "millers": [-1, 0, 1], 574 | "shift": 0.25, 575 | "top": False, 576 | }, 577 | }, 578 | "model": "test_model", 579 | "ephemeral": True, 580 | }, 581 | successful_response_code=200, 582 | successful_response_body=""" 583 | { 584 | "system_id": "sys_id", 585 | "config_ids": [1, 2, 3] 586 | } 587 | """, 588 | 
successful_response_object=AdsorbateSlabRelaxationsSystem( 589 | system_id="sys_id", 590 | config_ids=[1, 2, 3], 591 | ), 592 | ) 593 | 594 | async def test_get_adsorbate_slab_relaxations_request(self) -> None: 595 | await self._run_common_tests_against_route( 596 | method="GET", 597 | route="ocp/adsorbate-slab-relaxations/test_system_id", 598 | client_method_name="get_adsorbate_slab_relaxations_request", 599 | client_method_args={"system_id": "test_system_id"}, 600 | successful_response_code=200, 601 | successful_response_body=""" 602 | { 603 | "adsorbate": "ABC", 604 | "adsorbate_configs": [ 605 | { 606 | "cell": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 607 | "pbc": [true, false, true], 608 | "numbers": [1, 2], 609 | "positions": [[1.1, 1.2, 1.3], [1.4, 1.5, 1.6]], 610 | "tags": [2, 2] 611 | } 612 | ], 613 | "bulk": { 614 | "src_id": "bulk_id", 615 | "formula": "XYZ", 616 | "els": ["X", "Y", "Z"] 617 | }, 618 | "slab": { 619 | "slab_atomsobject": { 620 | "cell": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 621 | "pbc": [true, true, true], 622 | "numbers": [1], 623 | "positions": [[1.1, 1.2, 1.3]], 624 | "tags": [0] 625 | }, 626 | "slab_metadata": { 627 | "bulk_id": "bulk_id", 628 | "millers": [1, 1, 1], 629 | "shift": 0.25, 630 | "top": false 631 | } 632 | }, 633 | "model": "test_model" 634 | } 635 | """, 636 | successful_response_object=AdsorbateSlabRelaxationsRequest( 637 | adsorbate="ABC", 638 | adsorbate_configs=[ 639 | Atoms( 640 | cell=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6), (0.7, 0.8, 0.9)), 641 | pbc=(True, False, True), 642 | numbers=[1, 2], 643 | positions=[(1.1, 1.2, 1.3), (1.4, 1.5, 1.6)], 644 | tags=[2, 2], 645 | ) 646 | ], 647 | bulk=Bulk( 648 | src_id="bulk_id", 649 | formula="XYZ", 650 | elements=["X", "Y", "Z"], 651 | ), 652 | slab=Slab( 653 | atoms=Atoms( 654 | cell=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6), (0.7, 0.8, 0.9)), 655 | pbc=(True, True, True), 656 | numbers=[1], 657 | positions=[(1.1, 1.2, 1.3)], 658 | tags=[0], 659 | ), 660 | 
metadata=SlabMetadata( 661 | bulk_src_id="bulk_id", 662 | millers=(1, 1, 1), 663 | shift=0.25, 664 | top=False, 665 | ), 666 | ), 667 | model="test_model", 668 | ), 669 | ) 670 | 671 | async def test_get_adsorbate_slab_relaxations_results__all_args(self) -> None: 672 | await self._run_common_tests_against_route( 673 | method="GET", 674 | route="ocp/adsorbate-slab-relaxations/test_sys_id/configs", 675 | client_method_name="get_adsorbate_slab_relaxations_results", 676 | client_method_args={ 677 | "system_id": "test_sys_id", 678 | "config_ids": [1, 2], 679 | "fields": ["A", "B"], 680 | }, 681 | expected_request_params={ 682 | "config_id": ["1", "2"], 683 | "field": ["A", "B"], 684 | }, 685 | successful_response_code=200, 686 | successful_response_body=""" 687 | { 688 | "configs": [ 689 | { 690 | "config_id": 1, 691 | "status": "success" 692 | } 693 | ] 694 | } 695 | """, 696 | successful_response_object=AdsorbateSlabRelaxationsResults( 697 | configs=[ 698 | AdsorbateSlabRelaxationResult( 699 | config_id=1, 700 | status=Status.SUCCESS, 701 | ) 702 | ], 703 | omitted_config_ids=[], 704 | ), 705 | ) 706 | 707 | async def test_get_adsorbate_slab_relaxations_results__req_args_only(self) -> None: 708 | await self._run_common_tests_against_route( 709 | method="GET", 710 | route="ocp/adsorbate-slab-relaxations/test_sys_id/configs", 711 | client_method_name="get_adsorbate_slab_relaxations_results", 712 | client_method_args={ 713 | "system_id": "test_sys_id", 714 | }, 715 | expected_request_params={}, 716 | successful_response_code=200, 717 | successful_response_body=""" 718 | { 719 | "configs": [ 720 | { 721 | "config_id": 1, 722 | "status": "success" 723 | } 724 | ] 725 | } 726 | """, 727 | successful_response_object=AdsorbateSlabRelaxationsResults( 728 | configs=[ 729 | AdsorbateSlabRelaxationResult( 730 | config_id=1, 731 | status=Status.SUCCESS, 732 | ) 733 | ], 734 | omitted_config_ids=[], 735 | ), 736 | ) 737 | 738 | async def test_delete_adsorbate_slab_relaxations(self) 
-> None: 739 | await self._run_common_tests_against_route( 740 | method="DELETE", 741 | route="ocp/adsorbate-slab-relaxations/test_sys_id", 742 | client_method_name="delete_adsorbate_slab_relaxations", 743 | client_method_args={ 744 | "system_id": "test_sys_id", 745 | }, 746 | successful_response_code=200, 747 | successful_response_body="{}", 748 | successful_response_object=None, 749 | ) 750 | -------------------------------------------------------------------------------- /tests/unit/client/test_models.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from typing import ( 4 | Any, 5 | Final, 6 | Generic, 7 | List, 8 | Optional, 9 | Tuple, 10 | Type, 11 | TypeVar, 12 | Union, 13 | ) 14 | from unittest import TestCase as UnitTestCase 15 | 16 | import numpy as np 17 | from ase.atoms import Atoms as ASEAtoms 18 | from ase.calculators.singlepoint import SinglePointCalculator 19 | from ase.constraints import FixAtoms 20 | 21 | from ocpapi.client import ( 22 | Adsorbates, 23 | AdsorbateSlabConfigs, 24 | AdsorbateSlabRelaxationResult, 25 | AdsorbateSlabRelaxationsRequest, 26 | AdsorbateSlabRelaxationsResults, 27 | AdsorbateSlabRelaxationsSystem, 28 | Atoms, 29 | Bulk, 30 | Bulks, 31 | Model, 32 | Models, 33 | Slab, 34 | SlabMetadata, 35 | Slabs, 36 | Status, 37 | ) 38 | from ocpapi.client.models import _DataModel 39 | 40 | T = TypeVar("T", bound=_DataModel) 41 | 42 | 43 | class ModelTestWrapper: 44 | class ModelTest(UnitTestCase, Generic[T]): 45 | """ 46 | Base class for all tests below that assert behavior of data models. 47 | """ 48 | 49 | def __init__( 50 | self, 51 | *args: Any, 52 | obj: T, 53 | obj_json: str, 54 | **kwargs: Any, 55 | ) -> None: 56 | """ 57 | Args: 58 | obj: A model instance in which all fields, even unknown ones, 59 | are included. 60 | obj_json: JSON-serialized version of obj. 
61 | """ 62 | super().__init__(*args, **kwargs) 63 | self._obj = obj 64 | self._obj_json = obj_json 65 | self._obj_type = type(obj) 66 | 67 | def test_from_json(self) -> None: 68 | @dataclass 69 | class TestCase: 70 | message: str 71 | json_repr: str 72 | expected: Final[Optional[T]] = None 73 | expected_exception: Final[Optional[Type[Exception]]] = None 74 | 75 | test_cases: List[TestCase] = [ 76 | # If the json object is empty then default values should 77 | # be used for all fields 78 | TestCase( 79 | message="empty object", 80 | json_repr="{}", 81 | expected_exception=Exception, 82 | ), 83 | # If all fields are set then they should be included in the 84 | # resulting object 85 | TestCase( 86 | message="all fields set", 87 | json_repr=self._obj_json, 88 | expected=self._obj, 89 | ), 90 | ] 91 | 92 | for case in test_cases: 93 | with self.subTest(msg=case.message): 94 | # Make sure an exception is raised if one is expected 95 | if case.expected_exception is not None: 96 | with self.assertRaises(case.expected_exception): 97 | self._obj_type.from_json(case.json_repr) 98 | 99 | # Otherwise make sure the expected value is returned 100 | if case.expected is not None: 101 | actual = self._obj_type.from_json(case.json_repr) 102 | self.assertEqual(actual, case.expected) 103 | 104 | def test_to_json(self) -> None: 105 | @dataclass 106 | class TestCase: 107 | message: str 108 | obj: T 109 | expected: str 110 | 111 | test_cases: List[TestCase] = [ 112 | # All explicitly-defined fields should serialize 113 | TestCase( 114 | message="all fields set", 115 | obj=self._obj, 116 | expected=self._obj_json, 117 | ), 118 | ] 119 | 120 | for case in test_cases: 121 | with self.subTest(msg=case.message): 122 | actual = case.obj.to_json() 123 | self.assertJsonEqual(actual, case.expected) 124 | 125 | def assertJsonEqual(self, first: str, second: str) -> None: 126 | """ 127 | Compares two JSON-formatted strings by deserializing them and then 128 | comparing the generated built-in types. 
129 | """ 130 | self.assertEqual(json.loads(first), json.loads(second)) 131 | 132 | 133 | class TestModel(ModelTestWrapper.ModelTest[Model]): 134 | """ 135 | Serde tests for the Model data model. 136 | """ 137 | 138 | def __init__(self, *args: Any, **kwargs: Any) -> None: 139 | super().__init__( 140 | obj=Model( 141 | id="model_id", 142 | other_fields={"extra_field": "extra_value"}, 143 | ), 144 | obj_json=""" 145 | { 146 | "id": "model_id", 147 | "extra_field": "extra_value" 148 | } 149 | """, 150 | *args, 151 | **kwargs, 152 | ) 153 | 154 | 155 | class TestModels(ModelTestWrapper.ModelTest[Models]): 156 | """ 157 | Serde tests for the Models data model. 158 | """ 159 | 160 | def __init__(self, *args: Any, **kwargs: Any) -> None: 161 | super().__init__( 162 | obj=Models( 163 | models=[Model(id="model_id")], 164 | other_fields={"extra_field": "extra_value"}, 165 | ), 166 | obj_json=""" 167 | { 168 | "models": [ 169 | { 170 | "id": "model_id" 171 | } 172 | ], 173 | "extra_field": "extra_value" 174 | } 175 | """, 176 | *args, 177 | **kwargs, 178 | ) 179 | 180 | 181 | class TestBulk(ModelTestWrapper.ModelTest[Bulk]): 182 | """ 183 | Serde tests for the Bulk data model. 184 | """ 185 | 186 | def __init__(self, *args: Any, **kwargs: Any) -> None: 187 | super().__init__( 188 | obj=Bulk( 189 | src_id="test_id", 190 | elements=["A", "B"], 191 | formula="AB2", 192 | other_fields={"extra_field": "extra_value"}, 193 | ), 194 | obj_json=""" 195 | { 196 | "src_id": "test_id", 197 | "els": ["A", "B"], 198 | "formula": "AB2", 199 | "extra_field": "extra_value" 200 | } 201 | """, 202 | *args, 203 | **kwargs, 204 | ) 205 | 206 | 207 | class TestBulks(ModelTestWrapper.ModelTest[Bulks]): 208 | """ 209 | Serde tests for the Bulks data model. 
210 | """ 211 | 212 | def __init__(self, *args: Any, **kwargs: Any) -> None: 213 | super().__init__( 214 | obj=Bulks( 215 | bulks_supported=[ 216 | Bulk( 217 | src_id="test_id", 218 | elements=["A", "B"], 219 | formula="AB2", 220 | ) 221 | ], 222 | other_fields={"extra_field": "extra_value"}, 223 | ), 224 | obj_json=""" 225 | { 226 | "bulks_supported": [ 227 | { 228 | "src_id": "test_id", 229 | "els": ["A", "B"], 230 | "formula": "AB2" 231 | } 232 | ], 233 | "extra_field": "extra_value" 234 | } 235 | """, 236 | *args, 237 | **kwargs, 238 | ) 239 | 240 | 241 | class TestAdsorbates(ModelTestWrapper.ModelTest[Adsorbates]): 242 | """ 243 | Serde tests for the Adsorbates data model. 244 | """ 245 | 246 | def __init__(self, *args: Any, **kwargs: Any) -> None: 247 | super().__init__( 248 | obj=Adsorbates( 249 | adsorbates_supported=["A", "B"], 250 | other_fields={"extra_field": "extra_value"}, 251 | ), 252 | obj_json=""" 253 | { 254 | "adsorbates_supported": ["A", "B"], 255 | "extra_field": "extra_value" 256 | } 257 | """, 258 | *args, 259 | **kwargs, 260 | ) 261 | 262 | 263 | class TestAtoms(ModelTestWrapper.ModelTest[Atoms]): 264 | """ 265 | Serde tests for the Atoms data model. 
266 | """ 267 | 268 | def __init__(self, *args: Any, **kwargs: Any) -> None: 269 | super().__init__( 270 | obj=Atoms( 271 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 272 | pbc=(True, False, True), 273 | numbers=[1, 2], 274 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 275 | tags=[0, 1], 276 | other_fields={"extra_field": "extra_value"}, 277 | ), 278 | obj_json=""" 279 | { 280 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 281 | "pbc": [true, false, true], 282 | "numbers": [1, 2], 283 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 284 | "tags": [0, 1], 285 | "extra_field": "extra_value" 286 | } 287 | """, 288 | *args, 289 | **kwargs, 290 | ) 291 | 292 | def test_to_ase_atoms(self) -> None: 293 | atoms = Atoms( 294 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 295 | pbc=(True, False, True), 296 | numbers=[1, 2], 297 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 298 | tags=[0, 1], 299 | ) 300 | actual = atoms.to_ase_atoms() 301 | expected = ASEAtoms( 302 | cell=[(1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)], 303 | pbc=(True, False, True), 304 | numbers=[1, 2], 305 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 306 | tags=[0, 1], 307 | constraint=FixAtoms(mask=[True, False]), 308 | ) 309 | self.assertEqual(actual, expected) 310 | # The constraint property isn't checked in the Atoms.__eq__ method 311 | # so check it explicitly 312 | self.assertEqual(actual.constraints[0].index, [0]) 313 | 314 | 315 | class TestSlabMetadata(ModelTestWrapper.ModelTest[SlabMetadata]): 316 | """ 317 | Serde tests for the SlabMetadata data model. 
318 | """ 319 | 320 | def __init__(self, *args: Any, **kwargs: Any) -> None: 321 | super().__init__( 322 | obj=SlabMetadata( 323 | bulk_src_id="test_id", 324 | millers=(-1, 0, 1), 325 | shift=0.25, 326 | top=False, 327 | other_fields={"extra_field": "extra_value"}, 328 | ), 329 | obj_json=""" 330 | { 331 | "bulk_id": "test_id", 332 | "millers": [-1, 0, 1], 333 | "shift": 0.25, 334 | "top": false, 335 | "extra_field": "extra_value" 336 | } 337 | """, 338 | *args, 339 | **kwargs, 340 | ) 341 | 342 | 343 | class TestSlab(ModelTestWrapper.ModelTest[Slab]): 344 | """ 345 | Serde tests for the Slab data model. 346 | """ 347 | 348 | def __init__(self, *args: Any, **kwargs: Any) -> None: 349 | super().__init__( 350 | obj=Slab( 351 | atoms=Atoms( 352 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 353 | pbc=(True, False, True), 354 | numbers=[1, 2], 355 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 356 | tags=[0, 1], 357 | other_fields={"extra_atoms_field": "extra_atoms_value"}, 358 | ), 359 | metadata=SlabMetadata( 360 | bulk_src_id="test_id", 361 | millers=(-1, 0, 1), 362 | shift=0.25, 363 | top=False, 364 | other_fields={"extra_metadata_field": "extra_metadata_value"}, 365 | ), 366 | other_fields={"extra_field": "extra_value"}, 367 | ), 368 | obj_json=""" 369 | { 370 | "slab_atomsobject": { 371 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 372 | "pbc": [true, false, true], 373 | "numbers": [1, 2], 374 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 375 | "tags": [0, 1], 376 | "extra_atoms_field": "extra_atoms_value" 377 | }, 378 | "slab_metadata": { 379 | "bulk_id": "test_id", 380 | "millers": [-1, 0, 1], 381 | "shift": 0.25, 382 | "top": false, 383 | "extra_metadata_field": "extra_metadata_value" 384 | }, 385 | "extra_field": "extra_value" 386 | } 387 | """, 388 | *args, 389 | **kwargs, 390 | ) 391 | 392 | 393 | class TestSlabs(ModelTestWrapper.ModelTest[Slabs]): 394 | """ 395 | Serde tests for the Slabs data model. 
396 | """ 397 | 398 | def __init__(self, *args: Any, **kwargs: Any) -> None: 399 | super().__init__( 400 | obj=Slabs( 401 | slabs=[ 402 | Slab( 403 | atoms=Atoms( 404 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 405 | pbc=(True, False, True), 406 | numbers=[1, 2], 407 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 408 | tags=[0, 1], 409 | other_fields={"extra_atoms_field": "extra_atoms_value"}, 410 | ), 411 | metadata=SlabMetadata( 412 | bulk_src_id="test_id", 413 | millers=(-1, 0, 1), 414 | shift=0.25, 415 | top=False, 416 | other_fields={ 417 | "extra_metadata_field": "extra_metadata_value" 418 | }, 419 | ), 420 | other_fields={"extra_slab_field": "extra_slab_value"}, 421 | ) 422 | ], 423 | other_fields={"extra_field": "extra_value"}, 424 | ), 425 | obj_json=""" 426 | { 427 | "slabs": [{ 428 | "slab_atomsobject": { 429 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 430 | "pbc": [true, false, true], 431 | "numbers": [1, 2], 432 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 433 | "tags": [0, 1], 434 | "extra_atoms_field": "extra_atoms_value" 435 | }, 436 | "slab_metadata": { 437 | "bulk_id": "test_id", 438 | "millers": [-1, 0, 1], 439 | "shift": 0.25, 440 | "top": false, 441 | "extra_metadata_field": "extra_metadata_value" 442 | }, 443 | "extra_slab_field": "extra_slab_value" 444 | }], 445 | "extra_field": "extra_value" 446 | } 447 | """, 448 | *args, 449 | **kwargs, 450 | ) 451 | 452 | 453 | class TestAdsorbateSlabConfigs(ModelTestWrapper.ModelTest[AdsorbateSlabConfigs]): 454 | """ 455 | Serde tests for the AdsorbateSlabConfigs data model. 
456 | """ 457 | 458 | def __init__(self, *args: Any, **kwargs: Any) -> None: 459 | super().__init__( 460 | obj=AdsorbateSlabConfigs( 461 | adsorbate_configs=[ 462 | Atoms( 463 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 464 | pbc=(True, False, True), 465 | numbers=[1, 2], 466 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 467 | tags=[0, 1], 468 | other_fields={"extra_ad_atoms_field": "extra_ad_atoms_value"}, 469 | ), 470 | ], 471 | slab=Slab( 472 | atoms=Atoms( 473 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 474 | pbc=(True, False, True), 475 | numbers=[1, 2], 476 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 477 | tags=[0, 1], 478 | other_fields={ 479 | "extra_slab_atoms_field": "extra_slab_atoms_value" 480 | }, 481 | ), 482 | metadata=SlabMetadata( 483 | bulk_src_id="test_id", 484 | millers=(-1, 0, 1), 485 | shift=0.25, 486 | top=False, 487 | other_fields={"extra_metadata_field": "extra_metadata_value"}, 488 | ), 489 | other_fields={"extra_slab_field": "extra_slab_value"}, 490 | ), 491 | other_fields={"extra_field": "extra_value"}, 492 | ), 493 | obj_json=""" 494 | { 495 | "adsorbate_configs": [ 496 | { 497 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 498 | "pbc": [true, false, true], 499 | "numbers": [1, 2], 500 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 501 | "tags": [0, 1], 502 | "extra_ad_atoms_field": "extra_ad_atoms_value" 503 | } 504 | ], 505 | "slab": { 506 | "slab_atomsobject": { 507 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 508 | "pbc": [true, false, true], 509 | "numbers": [1, 2], 510 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 511 | "tags": [0, 1], 512 | "extra_slab_atoms_field": "extra_slab_atoms_value" 513 | }, 514 | "slab_metadata": { 515 | "bulk_id": "test_id", 516 | "millers": [-1, 0, 1], 517 | "shift": 0.25, 518 | "top": false, 519 | "extra_metadata_field": "extra_metadata_value" 520 | }, 521 | "extra_slab_field": "extra_slab_value" 522 | }, 523 | 
"extra_field": "extra_value" 524 | } 525 | """, 526 | *args, 527 | **kwargs, 528 | ) 529 | 530 | 531 | class TestAdsorbateSlabRelaxationsSystem( 532 | ModelTestWrapper.ModelTest[AdsorbateSlabRelaxationsSystem] 533 | ): 534 | """ 535 | Serde tests for the AdsorbateSlabRelaxationsSystem data model. 536 | """ 537 | 538 | def __init__(self, *args: Any, **kwargs: Any) -> None: 539 | super().__init__( 540 | obj=AdsorbateSlabRelaxationsSystem( 541 | system_id="test_id", 542 | config_ids=[1, 2, 3], 543 | other_fields={"extra_field": "extra_value"}, 544 | ), 545 | obj_json=""" 546 | { 547 | "system_id": "test_id", 548 | "config_ids": [1, 2, 3], 549 | "extra_field": "extra_value" 550 | } 551 | """, 552 | *args, 553 | **kwargs, 554 | ) 555 | 556 | 557 | class TestAdsorbateSlabRelaxationsRequest( 558 | ModelTestWrapper.ModelTest[AdsorbateSlabRelaxationsRequest] 559 | ): 560 | """ 561 | Serde tests for the AdsorbateSlabRelaxationsRequest data model. 562 | """ 563 | 564 | def __init__(self, *args: Any, **kwargs: Any) -> None: 565 | super().__init__( 566 | obj=AdsorbateSlabRelaxationsRequest( 567 | adsorbate="ABC", 568 | adsorbate_configs=[ 569 | Atoms( 570 | cell=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6), (0.7, 0.8, 0.9)), 571 | pbc=(True, False, True), 572 | numbers=[1, 2], 573 | positions=[(1.1, 1.2, 1.3), (1.4, 1.5, 1.6)], 574 | tags=[2, 2], 575 | other_fields={"extra_ad_field": "extra_ad_value"}, 576 | ) 577 | ], 578 | bulk=Bulk( 579 | src_id="bulk_id", 580 | formula="XYZ", 581 | elements=["X", "Y", "Z"], 582 | other_fields={"extra_bulk_field": "extra_bulk_value"}, 583 | ), 584 | slab=Slab( 585 | atoms=Atoms( 586 | cell=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6), (0.7, 0.8, 0.9)), 587 | pbc=(True, True, True), 588 | numbers=[1], 589 | positions=[(1.1, 1.2, 1.3)], 590 | tags=[0], 591 | other_fields={"extra_slab_field": "extra_slab_value"}, 592 | ), 593 | metadata=SlabMetadata( 594 | bulk_src_id="bulk_id", 595 | millers=(1, 1, 1), 596 | shift=0.25, 597 | top=False, 598 | 
other_fields={"extra_meta_field": "extra_meta_value"}, 599 | ), 600 | ), 601 | model="test_model", 602 | ephemeral=True, 603 | adsorbate_reaction="A + B -> C", 604 | other_fields={"extra_field": "extra_value"}, 605 | ), 606 | obj_json=""" 607 | { 608 | "adsorbate": "ABC", 609 | "adsorbate_configs": [ 610 | { 611 | "cell": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 612 | "pbc": [true, false, true], 613 | "numbers": [1, 2], 614 | "positions": [[1.1, 1.2, 1.3], [1.4, 1.5, 1.6]], 615 | "tags": [2, 2], 616 | "extra_ad_field": "extra_ad_value" 617 | } 618 | ], 619 | "bulk": { 620 | "src_id": "bulk_id", 621 | "formula": "XYZ", 622 | "els": ["X", "Y", "Z"], 623 | "extra_bulk_field": "extra_bulk_value" 624 | }, 625 | "slab": { 626 | "slab_atomsobject": { 627 | "cell": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 628 | "pbc": [true, true, true], 629 | "numbers": [1], 630 | "positions": [[1.1, 1.2, 1.3]], 631 | "tags": [0], 632 | "extra_slab_field": "extra_slab_value" 633 | }, 634 | "slab_metadata": { 635 | "bulk_id": "bulk_id", 636 | "millers": [1, 1, 1], 637 | "shift": 0.25, 638 | "top": false, 639 | "extra_meta_field": "extra_meta_value" 640 | } 641 | }, 642 | "model": "test_model", 643 | "ephemeral": true, 644 | "adsorbate_reaction": "A + B -> C", 645 | "extra_field": "extra_value" 646 | } 647 | """, 648 | *args, 649 | **kwargs, 650 | ) 651 | 652 | 653 | class TestAdsorbateSlabRelaxationsRequest_req_fields_only( 654 | ModelTestWrapper.ModelTest[AdsorbateSlabRelaxationsRequest] 655 | ): 656 | """ 657 | Serde tests for the AdsorbateSlabRelaxationsRequest data model in which 658 | optional fields are omitted. 
659 | """ 660 | 661 | def __init__(self, *args: Any, **kwargs: Any) -> None: 662 | super().__init__( 663 | obj=AdsorbateSlabRelaxationsRequest( 664 | adsorbate="ABC", 665 | adsorbate_configs=[ 666 | Atoms( 667 | cell=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6), (0.7, 0.8, 0.9)), 668 | pbc=(True, False, True), 669 | numbers=[1, 2], 670 | positions=[(1.1, 1.2, 1.3), (1.4, 1.5, 1.6)], 671 | tags=[2, 2], 672 | ) 673 | ], 674 | bulk=Bulk( 675 | src_id="bulk_id", 676 | formula="XYZ", 677 | elements=["X", "Y", "Z"], 678 | ), 679 | slab=Slab( 680 | atoms=Atoms( 681 | cell=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6), (0.7, 0.8, 0.9)), 682 | pbc=(True, True, True), 683 | numbers=[1], 684 | positions=[(1.1, 1.2, 1.3)], 685 | tags=[0], 686 | ), 687 | metadata=SlabMetadata( 688 | bulk_src_id="bulk_id", 689 | millers=(1, 1, 1), 690 | shift=0.25, 691 | top=False, 692 | ), 693 | ), 694 | model="test_model", 695 | ), 696 | obj_json=""" 697 | { 698 | "adsorbate": "ABC", 699 | "adsorbate_configs": [ 700 | { 701 | "cell": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 702 | "pbc": [true, false, true], 703 | "numbers": [1, 2], 704 | "positions": [[1.1, 1.2, 1.3], [1.4, 1.5, 1.6]], 705 | "tags": [2, 2] 706 | } 707 | ], 708 | "bulk": { 709 | "src_id": "bulk_id", 710 | "formula": "XYZ", 711 | "els": ["X", "Y", "Z"] 712 | }, 713 | "slab": { 714 | "slab_atomsobject": { 715 | "cell": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 716 | "pbc": [true, true, true], 717 | "numbers": [1], 718 | "positions": [[1.1, 1.2, 1.3]], 719 | "tags": [0] 720 | }, 721 | "slab_metadata": { 722 | "bulk_id": "bulk_id", 723 | "millers": [1, 1, 1], 724 | "shift": 0.25, 725 | "top": false 726 | } 727 | }, 728 | "model": "test_model" 729 | } 730 | """, 731 | *args, 732 | **kwargs, 733 | ) 734 | 735 | 736 | class TestAdsorbateSlabRelaxationResult( 737 | ModelTestWrapper.ModelTest[AdsorbateSlabRelaxationResult] 738 | ): 739 | """ 740 | Serde tests for the AdsorbateSlabRelaxationResult data model. 
741 | """ 742 | 743 | def __init__(self, *args: Any, **kwargs: Any) -> None: 744 | super().__init__( 745 | obj=AdsorbateSlabRelaxationResult( 746 | config_id=1, 747 | status=Status.SUCCESS, 748 | system_id="sys_id", 749 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 750 | pbc=(True, False, True), 751 | numbers=[1, 2], 752 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 753 | tags=[0, 1], 754 | energy=100.1, 755 | energy_trajectory=[99.9, 100.1], 756 | forces=[(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], 757 | other_fields={"extra_field": "extra_value"}, 758 | ), 759 | obj_json=""" 760 | { 761 | "config_id": 1, 762 | "status": "success", 763 | "system_id": "sys_id", 764 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 765 | "pbc": [true, false, true], 766 | "numbers": [1, 2], 767 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 768 | "tags": [0, 1], 769 | "energy": 100.1, 770 | "energy_trajectory": [99.9, 100.1], 771 | "forces": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], 772 | "extra_field": "extra_value" 773 | } 774 | """, 775 | *args, 776 | **kwargs, 777 | ) 778 | 779 | def test_to_ase_atoms(self) -> None: 780 | @dataclass 781 | class TestCase: 782 | message: str 783 | result: AdsorbateSlabRelaxationResult 784 | expected_atoms: ASEAtoms 785 | expected_energy: Union[float, Type[Exception]] 786 | expected_forces: Union[np.ndarray, Type[Exception]] 787 | expected_unconstrained_forces: Union[np.ndarray, Type[Exception]] 788 | 789 | # Helper function to construct an ase.Atoms object with the 790 | # input attributes. 
791 | def get_ase_atoms( 792 | energy: Optional[float], 793 | forces: Optional[List[Tuple[float, float, float]]], 794 | **kwargs: Any, 795 | ) -> ASEAtoms: 796 | atoms = ASEAtoms(**kwargs) 797 | atoms.calc = SinglePointCalculator( 798 | atoms=atoms, 799 | energy=energy, 800 | forces=forces, 801 | ) 802 | return atoms 803 | 804 | test_cases: List[TestCase] = [ 805 | # If all optional fields are omitted, the generated ase.Atoms 806 | # object should also have no values set 807 | TestCase( 808 | message="optional fields omitted", 809 | result=AdsorbateSlabRelaxationResult( 810 | config_id=1, 811 | status=Status.SUCCESS, 812 | ), 813 | expected_atoms=ASEAtoms(), 814 | expected_energy=Exception, 815 | expected_forces=Exception, 816 | expected_unconstrained_forces=Exception, 817 | ), 818 | # If all fields are included, the generated ase.Atoms object 819 | # should have positions, forces, etc. configured 820 | TestCase( 821 | message="all fields set", 822 | result=AdsorbateSlabRelaxationResult( 823 | config_id=1, 824 | status=Status.SUCCESS, 825 | system_id="sys_id", 826 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 827 | pbc=(True, False, True), 828 | numbers=[1, 2], 829 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 830 | tags=[0, 1], 831 | energy=100.1, 832 | energy_trajectory=[99.9, 100.1], 833 | forces=[(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], 834 | ), 835 | expected_atoms=get_ase_atoms( 836 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 837 | pbc=(True, False, True), 838 | numbers=[1, 2], 839 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 840 | tags=[0, 1], 841 | energy=100.1, 842 | forces=[(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], 843 | ), 844 | expected_energy=100.1, 845 | # The constraint on the first atom causes its force to be 846 | # zeroed out 847 | expected_forces=np.array([(0, 0, 0), (0.4, 0.5, 0.6)], float), 848 | expected_unconstrained_forces=np.array( 849 | [(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], float 850 | ), 851 | ), 852 | ] 853 | 854 | 
for case in test_cases: 855 | with self.subTest(msg=case.message): 856 | # Check that the atoms object is constructed correctly 857 | ase_atoms = case.result.to_ase_atoms() 858 | self.assertEqual(ase_atoms, case.expected_atoms) 859 | 860 | # Check the energy (or that an exception is raised if expected) 861 | if isinstance(case.expected_energy, float): 862 | self.assertEqual( 863 | case.expected_energy, 864 | ase_atoms.get_potential_energy(), 865 | ) 866 | else: 867 | with self.assertRaises(case.expected_energy): 868 | ase_atoms.get_potential_energy() 869 | 870 | # Check the forces (or that an exception is raised if expected) 871 | if isinstance(case.expected_forces, np.ndarray): 872 | self.assertEqual( 873 | case.expected_forces.tolist(), 874 | ase_atoms.get_forces().tolist(), 875 | ) 876 | else: 877 | with self.assertRaises(case.expected_forces): 878 | ase_atoms.get_forces() 879 | 880 | # Check the unconstrained forces (or that an exception is 881 | # raised if expected) 882 | if isinstance(case.expected_unconstrained_forces, np.ndarray): 883 | self.assertEqual( 884 | case.expected_unconstrained_forces.tolist(), 885 | ase_atoms.get_forces(apply_constraint=False).tolist(), 886 | ) 887 | else: 888 | with self.assertRaises(case.expected_unconstrained_forces): 889 | ase_atoms.get_forces(apply_constraint=False) 890 | 891 | 892 | class TestAdsorbateSlabRelaxationResult_req_fields_only( 893 | ModelTestWrapper.ModelTest[AdsorbateSlabRelaxationResult] 894 | ): 895 | """ 896 | Serde tests for the AdsorbateSlabRelaxationResult data model in which 897 | optional fields are omitted. 
898 | """ 899 | 900 | def __init__(self, *args: Any, **kwargs: Any) -> None: 901 | super().__init__( 902 | obj=AdsorbateSlabRelaxationResult( 903 | config_id=1, 904 | status=Status.SUCCESS, 905 | ), 906 | obj_json=""" 907 | { 908 | "config_id": 1, 909 | "status": "success" 910 | } 911 | """, 912 | *args, 913 | **kwargs, 914 | ) 915 | 916 | 917 | class TestAdsorbateSlabRelaxationsResults( 918 | ModelTestWrapper.ModelTest[AdsorbateSlabRelaxationsResults] 919 | ): 920 | """ 921 | Serde tests for the AdsorbateSlabRelaxationsResults data model. 922 | """ 923 | 924 | def __init__(self, *args: Any, **kwargs: Any) -> None: 925 | super().__init__( 926 | obj=AdsorbateSlabRelaxationsResults( 927 | configs=[ 928 | AdsorbateSlabRelaxationResult( 929 | config_id=1, 930 | status=Status.SUCCESS, 931 | system_id="sys_id", 932 | cell=((1.1, 2.1, 3.1), (4.1, 5.1, 6.1), (7.1, 8.1, 9.1)), 933 | pbc=(True, False, True), 934 | numbers=[1, 2], 935 | positions=[(1.1, 1.2, 1.3), (2.1, 2.2, 2.3)], 936 | tags=[0, 1], 937 | energy=100.1, 938 | energy_trajectory=[99.9, 100.1], 939 | forces=[(0.1, 0.2, 0.3), (0.4, 0.5, 0.6)], 940 | other_fields={"extra_adslab_field": "extra_adslab_value"}, 941 | ) 942 | ], 943 | omitted_config_ids=[1, 2, 3], 944 | other_fields={"extra_field": "extra_value"}, 945 | ), 946 | obj_json=""" 947 | { 948 | "configs": [{ 949 | "config_id": 1, 950 | "status": "success", 951 | "system_id": "sys_id", 952 | "cell": [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]], 953 | "pbc": [true, false, true], 954 | "numbers": [1, 2], 955 | "positions": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], 956 | "tags": [0, 1], 957 | "energy": 100.1, 958 | "energy_trajectory": [99.9, 100.1], 959 | "forces": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], 960 | "extra_adslab_field": "extra_adslab_value" 961 | }], 962 | "omitted_config_ids": [1, 2, 3], 963 | "extra_field": "extra_value" 964 | } 965 | """, 966 | *args, 967 | **kwargs, 968 | ) 969 | 
-------------------------------------------------------------------------------- /tests/unit/client/test_ui.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | from unittest import TestCase as UnitTestCase 4 | 5 | from ocpapi.client import get_results_ui_url 6 | 7 | 8 | class TestUI(UnitTestCase): 9 | def test_get_results_ui_url(self) -> None: 10 | @dataclass 11 | class TestCase: 12 | message: str 13 | api_host: str 14 | system_id: str 15 | expected: Optional[str] 16 | 17 | test_cases: List[TestCase] = [ 18 | # If the prod host is used, then a URL to the prod UI 19 | # should be returned 20 | TestCase( 21 | message="prod host", 22 | api_host="open-catalyst-api.metademolab.com", 23 | system_id="abc", 24 | expected="https://open-catalyst.metademolab.com/results/abc", 25 | ), 26 | # If an unknown host name is used, then no URL should be returned 27 | TestCase( 28 | message="unknown host", 29 | api_host="unknown.host", 30 | system_id="abc", 31 | expected=None, 32 | ), 33 | ] 34 | 35 | for case in test_cases: 36 | with self.subTest(msg=case.message): 37 | actual = get_results_ui_url(case.api_host, case.system_id) 38 | self.assertEqual(case.expected, actual) 39 | -------------------------------------------------------------------------------- /tests/unit/workflows/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-Catalyst-Project/ocpapi/03a3277d873459816fdba80885ac275f149420c5/tests/unit/workflows/__init__.py -------------------------------------------------------------------------------- /tests/unit/workflows/test_context.py: -------------------------------------------------------------------------------- 1 | from contextvars import ContextVar 2 | from unittest import TestCase as UnitTestCase 3 | from uuid import uuid4 4 | 5 | from ocpapi.workflows.context import
set_context_var 6 | 7 | 8 | class TestContext(UnitTestCase): 9 | def test_set_context_var(self) -> None: 10 | # Set an initial value for a context var 11 | ctx_var = ContextVar(str(uuid4())) 12 | ctx_var.set("initial") 13 | 14 | # Update the context variable and make sure it is changed 15 | with set_context_var(ctx_var, "updated"): 16 | self.assertEqual("updated", ctx_var.get()) 17 | 18 | # After exiting the context manager, make sure the context var was 19 | # reset to its original value 20 | self.assertEqual("initial", ctx_var.get()) 21 | -------------------------------------------------------------------------------- /tests/unit/workflows/test_filter.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import sys 3 | from contextlib import ExitStack 4 | from dataclasses import dataclass 5 | from io import StringIO 6 | from typing import Any, List, Optional, Tuple 7 | from unittest import IsolatedAsyncioTestCase, mock 8 | 9 | from inquirer import prompt 10 | from inquirer.events import KeyEventGenerator 11 | from inquirer.render import ConsoleRender 12 | from readchar import key 13 | 14 | from ocpapi.client import AdsorbateSlabConfigs, Atoms, Slab, SlabMetadata 15 | from ocpapi.workflows import ( 16 | keep_all_slabs, 17 | keep_slabs_with_miller_indices, 18 | prompt_for_slabs_to_keep, 19 | ) 20 | 21 | 22 | # Function used to generate a new adslab instance. This fills the minimum 23 | # set of required fields with default values. Inputs allow for overriding 24 | # those defaults.
25 | def _new_adslab( 26 | miller_indices: Optional[Tuple[int, int, int]] = None, 27 | ) -> AdsorbateSlabConfigs: 28 | return AdsorbateSlabConfigs( 29 | adsorbate_configs=[], 30 | slab=Slab( 31 | atoms=Atoms( 32 | cell=((1, 0, 0), (0, 1, 0), (0, 0, 1)), 33 | pbc=[True, True, False], 34 | numbers=[], 35 | positions=[], 36 | tags=[], 37 | ), 38 | metadata=SlabMetadata( 39 | bulk_src_id="bulk_id", 40 | millers=miller_indices or (2, 1, 0), 41 | shift=0.5, 42 | top=True, 43 | ), 44 | ), 45 | ) 46 | 47 | 48 | class TestFilter(IsolatedAsyncioTestCase): 49 | async def test_keep_all_slabs(self) -> None: 50 | @dataclass 51 | class TestCase: 52 | message: str 53 | input: List[AdsorbateSlabConfigs] 54 | expected: List[AdsorbateSlabConfigs] 55 | 56 | test_cases: List[TestCase] = [ 57 | # If no adslabs are provided then none should be returned 58 | TestCase( 59 | message="empty list", 60 | input=[], 61 | expected=[], 62 | ), 63 | # If adslabs are provided, all should be returned 64 | TestCase( 65 | message="non-empty list", 66 | input=[ 67 | _new_adslab(), 68 | _new_adslab(), 69 | ], 70 | expected=[ 71 | _new_adslab(), 72 | _new_adslab(), 73 | ], 74 | ), 75 | ] 76 | 77 | for case in test_cases: 78 | with self.subTest(msg=case.message): 79 | adslab_filter = keep_all_slabs() 80 | actual = await adslab_filter(case.input) 81 | self.assertEqual(case.expected, actual) 82 | 83 | async def test_keep_slabs_with_miller_indices(self) -> None: 84 | @dataclass 85 | class TestCase: 86 | message: str 87 | adslab_filter: keep_slabs_with_miller_indices 88 | input: List[AdsorbateSlabConfigs] 89 | expected: List[AdsorbateSlabConfigs] 90 | 91 | test_cases: List[TestCase] = [ 92 | # If no miller indices are defined, then no slabs should be kept 93 | TestCase( 94 | message="no miller indices allowed", 95 | adslab_filter=keep_slabs_with_miller_indices(miller_indices=[]), 96 | input=[ 97 | _new_adslab(miller_indices=(1, 0, 0)), 98 | _new_adslab(miller_indices=(1, 1, 0)), 99 | 
_new_adslab(miller_indices=(1, 1, 1)), 100 | ], 101 | expected=[], 102 | ), 103 | # If no slabs are defined then nothing should be returned 104 | TestCase( 105 | message="no slabs provided", 106 | adslab_filter=keep_slabs_with_miller_indices( 107 | miller_indices=[(1, 1, 1)] 108 | ), 109 | input=[], 110 | expected=[], 111 | ), 112 | # Any miller indices that do match should be kept 113 | TestCase( 114 | message="some miller indices matched", 115 | adslab_filter=keep_slabs_with_miller_indices( 116 | miller_indices=[ 117 | (1, 0, 1), # Won't match anything 118 | (1, 0, 0), # Will match 119 | (1, 1, 1), # Will match 120 | ] 121 | ), 122 | input=[ 123 | _new_adslab(miller_indices=(1, 0, 0)), 124 | _new_adslab(miller_indices=(1, 1, 0)), 125 | _new_adslab(miller_indices=(1, 1, 1)), 126 | ], 127 | expected=[ 128 | _new_adslab(miller_indices=(1, 0, 0)), 129 | _new_adslab(miller_indices=(1, 1, 1)), 130 | ], 131 | ), 132 | ] 133 | 134 | for case in test_cases: 135 | with self.subTest(msg=case.message): 136 | actual = await case.adslab_filter(case.input) 137 | self.assertEqual(case.expected, actual) 138 | 139 | async def test_prompt_for_slabs_to_keep(self) -> None: 140 | @dataclass 141 | class TestCase: 142 | message: str 143 | input: List[AdsorbateSlabConfigs] 144 | key_events: List[Any] 145 | expected: List[AdsorbateSlabConfigs] 146 | 147 | test_cases: List[TestCase] = [ 148 | # If no adslabs are provided then none should be returned 149 | TestCase( 150 | message="no slabs provided", 151 | input=[], 152 | key_events=[], 153 | expected=[], 154 | ), 155 | # If adslabs are provided but none are selected then none 156 | # should be returned 157 | TestCase( 158 | message="no slabs selected", 159 | input=[ 160 | _new_adslab(miller_indices=(1, 0, 0)), 161 | _new_adslab(miller_indices=(2, 0, 0)), 162 | _new_adslab(miller_indices=(3, 0, 0)), 163 | ], 164 | key_events=[key.ENTER], 165 | expected=[], 166 | ), 167 | # If adslabs are provided and some are selected then those 168 | # 
should be returned 169 | TestCase( 170 | message="some slabs selected", 171 | input=[ 172 | _new_adslab(miller_indices=(1, 0, 0)), 173 | _new_adslab(miller_indices=(2, 0, 0)), 174 | _new_adslab(miller_indices=(3, 0, 0)), 175 | ], 176 | key_events=[ 177 | key.SPACE, # Select first slab 178 | key.DOWN, # Move to second slab 179 | key.DOWN, # Move to third slab 180 | key.SPACE, # Select third slab 181 | key.ENTER, # Finish 182 | ], 183 | expected=[ 184 | _new_adslab(miller_indices=(1, 0, 0)), 185 | _new_adslab(miller_indices=(3, 0, 0)), 186 | ], 187 | ), 188 | ] 189 | 190 | for case in test_cases: 191 | with ExitStack() as es: 192 | es.enter_context(self.subTest(msg=case.message)) 193 | 194 | # prompt_for_slabs_to_keep() creates an interactive prompt 195 | # that the user can select from. Here we inject key presses 196 | # to simulate a user interacting with the prompt. First we 197 | # need to direct stdin and stdout to our own io objects. 198 | orig_stdin = sys.stdin 199 | orig_stdout = sys.stdout 200 | try: 201 | sys.stdin = StringIO() 202 | sys.stdout = StringIO() 203 | 204 | # Now we create a inquirer.ConsoleRender instance that 205 | # uses the key_events (key presses) in the current test 206 | # case. 
207 | it = iter(case.key_events) 208 | renderer = ConsoleRender( 209 | event_generator=KeyEventGenerator(lambda: next(it)) 210 | ) 211 | 212 | # Now inject our renderer into the prompt 213 | es.enter_context( 214 | mock.patch( 215 | "inquirer.prompt", 216 | side_effect=functools.partial( 217 | prompt, 218 | render=renderer, 219 | ), 220 | ) 221 | ) 222 | 223 | # Finally run the filter 224 | adslab_filter = prompt_for_slabs_to_keep() 225 | actual = await adslab_filter(case.input) 226 | self.assertEqual(case.expected, actual) 227 | 228 | finally: 229 | sys.stdin = orig_stdin 230 | sys.stdout = orig_stdout 231 | -------------------------------------------------------------------------------- /tests/unit/workflows/test_retry.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import suppress 3 | from dataclasses import dataclass 4 | from datetime import timedelta 5 | from typing import ( 6 | Any, 7 | Callable, 8 | Final, 9 | Iterable, 10 | List, 11 | Optional, 12 | Tuple, 13 | Type, 14 | TypeVar, 15 | Union, 16 | ) 17 | from unittest import TestCase as UnitTestCase 18 | from unittest import mock 19 | 20 | from ocpapi.client import ( 21 | NonRetryableRequestException, 22 | RateLimitExceededException, 23 | RequestException, 24 | ) 25 | from ocpapi.workflows import ( 26 | NO_LIMIT, 27 | NoLimitType, 28 | RateLimitLogging, 29 | retry_api_calls, 30 | ) 31 | 32 | T = TypeVar("T") 33 | 34 | 35 | # Helper function that returns the input value immediately 36 | def returns(val: T) -> Callable[[], T]: 37 | return lambda: val 38 | 39 | 40 | # Helper function that raises the input exception 41 | def raises(ex: Exception) -> Callable[[], None]: 42 | def func() -> None: 43 | raise ex 44 | 45 | return func 46 | 47 | 48 | class TestRetry(UnitTestCase): 49 | def test_retry_api_calls__results(self) -> None: 50 | # Tests for retry behavior under various results (returning a 51 | # successful value, raising various 
exceptions, etc.) 52 | 53 | @dataclass 54 | class TestCase: 55 | message: str 56 | max_attempts: Union[int, NoLimitType] 57 | funcs: Iterable[Callable[[], Any]] 58 | expected_attempt_count: int 59 | expected_return_value: Final[Optional[Any]] = None 60 | expected_exception: Final[Optional[Type[Exception]]] = None 61 | 62 | test_cases: List[TestCase] = [ 63 | # If a function runs successfully on the first call then exactly 64 | # one attempt should be made 65 | TestCase( 66 | message="success on first call", 67 | max_attempts=3, 68 | funcs=[returns(True)], 69 | expected_attempt_count=1, 70 | expected_return_value=True, 71 | ), 72 | # If a function raises a generic exception, it should never be 73 | # retried 74 | TestCase( 75 | message="non-api-type exception", 76 | max_attempts=3, 77 | funcs=[raises(Exception())], 78 | expected_attempt_count=1, 79 | expected_exception=Exception, 80 | ), 81 | # If a function raises an exception from the API that is not 82 | # retryable, then it should be re-raised 83 | TestCase( 84 | message="non-retryable api exception", 85 | max_attempts=3, 86 | funcs=[raises(NonRetryableRequestException("", "", ""))], 87 | expected_attempt_count=1, 88 | expected_exception=NonRetryableRequestException, 89 | ), 90 | # If a function raises an exception from the API that can be 91 | # retried, then another call should be made 92 | TestCase( 93 | message="retryable api exception, below max attempts", 94 | max_attempts=3, 95 | # Raise on the first attempt and return a value on the second 96 | funcs=[raises(RequestException("", "", "")), returns(True)], 97 | # Expect that two calls are made since the first should be 98 | # retried 99 | expected_attempt_count=2, 100 | expected_return_value=True, 101 | ), 102 | # If a function raises an exception from the API that can be 103 | # retried, but is raised more times than is allowed, then it 104 | # should be re-raised eventually 105 | TestCase( 106 | message="retryable api exception, exceeds max attempts", 
107 | # Make at most two calls to the function 108 | max_attempts=2, 109 | # Raise on each attempt 110 | funcs=[ 111 | raises(RequestException("", "", "")), 112 | raises(RequestException("", "", "")), 113 | ], 114 | # Expect that two calls are made since the first should be 115 | # retried 116 | expected_attempt_count=2, 117 | # Except that the exception is eventually raised 118 | expected_exception=RequestException, 119 | ), 120 | # If a function has no limit on the number of retries, then it 121 | # should retry the function until a non-retryable exception is 122 | # raised 123 | TestCase( 124 | message="no attempt limit", 125 | max_attempts=NO_LIMIT, 126 | # Raise several retryable exceptions before finally raising 127 | # one that cannot be retried 128 | funcs=[ 129 | raises(RequestException("", "", "")), 130 | raises(RequestException("", "", "")), 131 | raises(RequestException("", "", "")), 132 | raises(RequestException("", "", "")), 133 | raises(Exception()), 134 | ], 135 | expected_attempt_count=5, 136 | expected_exception=Exception, 137 | ), 138 | ] 139 | 140 | for case in test_cases: 141 | with self.subTest(msg=case.message): 142 | func_iter = iter(case.funcs) 143 | 144 | # Retried function that runs all of the functions in the test 145 | # case 146 | @retry_api_calls( 147 | max_attempts=case.max_attempts, 148 | # Disable waits so that this tests run more quickly. 149 | # Waits are tested elsewhere. 
150 | fixed_wait_sec=0, 151 | max_jitter_sec=0, 152 | ) 153 | def test_func() -> Any: 154 | cur_func = next(func_iter) 155 | return cur_func() 156 | 157 | # Make sure an exception is raised if one is expected 158 | if case.expected_exception is not None: 159 | with self.assertRaises(case.expected_exception): 160 | test_func() 161 | else: 162 | return_value = test_func() 163 | self.assertEqual(case.expected_return_value, return_value) 164 | 165 | # Make sure the number of function calls is expected 166 | self.assertEqual( 167 | case.expected_attempt_count, 168 | test_func.retry.statistics["attempt_number"], 169 | ) 170 | 171 | def test_retry_api_calls__wait(self) -> None: 172 | # Tests for retry wait times 173 | 174 | @dataclass 175 | class TestCase: 176 | message: str 177 | funcs: Iterable[Callable[[], Any]] 178 | fixed_wait_sec: float 179 | max_jitter_sec: float 180 | expected_duration_range: Tuple[float, float] 181 | 182 | test_cases: List[TestCase] = [ 183 | # A function that succeeds on the first attempt should return 184 | # without waiting 185 | TestCase( 186 | message="success on first call", 187 | funcs=[returns(None)], 188 | fixed_wait_sec=2, 189 | max_jitter_sec=1, 190 | # Function should return nearly immediately, but certainly 191 | # much faster than the fixed wait time 192 | expected_duration_range=(0, 0.1), 193 | ), 194 | # A function that raises a non-retryable exception should return 195 | # without waiting 196 | TestCase( 197 | message="non-retryable exception", 198 | funcs=[raises(Exception)], 199 | fixed_wait_sec=2, 200 | max_jitter_sec=1, 201 | # Function should return nearly immediately, but certainly 202 | # much faster than the fixed wait time 203 | expected_duration_range=(0, 0.1), 204 | ), 205 | # A function with a retryable API exception, but without a value 206 | # for retry-after, should wait based on the fixed wait time and 207 | # random jitter 208 | TestCase( 209 | message="retryable exception without retry-after", 210 | 
funcs=[raises(RequestException("", "", "")), returns(None)], 211 | fixed_wait_sec=0.2, 212 | max_jitter_sec=0.2, 213 | # Function should wait between 0.2 and 0.4 seconds - give a 214 | # small buffer on the upper bound 215 | expected_duration_range=(0.2, 0.5), 216 | ), 217 | # A function with a retryable API exception that includes a value 218 | # for retry-after should wait based on that returned time 219 | TestCase( 220 | message="retryable exception with retry-after", 221 | funcs=[ 222 | raises(RateLimitExceededException("", "", timedelta(seconds=0.2))), 223 | returns(None), 224 | ], 225 | fixed_wait_sec=2, 226 | max_jitter_sec=0.2, 227 | # Function should wait between 0.2 and 0.4 seconds - give a 228 | # small buffer on the upper bound 229 | expected_duration_range=(0.2, 0.5), 230 | ), 231 | ] 232 | 233 | for case in test_cases: 234 | with self.subTest(msg=case.message): 235 | func_iter = iter(case.funcs) 236 | 237 | # Retried function that runs all of the functions in the test 238 | # case 239 | @retry_api_calls( 240 | fixed_wait_sec=case.fixed_wait_sec, 241 | max_jitter_sec=case.max_jitter_sec, 242 | ) 243 | def test_func() -> Any: 244 | cur_func = next(func_iter) 245 | return cur_func() 246 | 247 | # Time the function execution. Ignore any exceptions since 248 | # they may be intentionally raised in the test case. 
249 | start = time.monotonic() 250 | with suppress(Exception): 251 | test_func() 252 | took = time.monotonic() - start 253 | 254 | self.assertGreaterEqual(took, case.expected_duration_range[0]) 255 | self.assertLessEqual(took, case.expected_duration_range[1]) 256 | 257 | def test_retry_api_calls__logging(self) -> None: 258 | # Tests for logging during retries 259 | 260 | @dataclass 261 | class TestCase: 262 | message: str 263 | funcs: Iterable[Callable[[], Any]] 264 | log_action: str 265 | expected_log_statements: Iterable[str] 266 | 267 | test_cases: List[TestCase] = [ 268 | # A function that succeeds immediately should not generate any 269 | # log statements 270 | TestCase( 271 | message="success on first call", 272 | funcs=[returns(None)], 273 | log_action="test_action", 274 | expected_log_statements=[], 275 | ), 276 | # A function that raises a non-retryable exception should not 277 | # generate any log statements 278 | TestCase( 279 | message="non-retryable exception", 280 | funcs=[raises(Exception)], 281 | log_action="test_action", 282 | expected_log_statements=[], 283 | ), 284 | # A function with a retryable API exception, but without a value 285 | # for retry-after, should not generate any log statements 286 | TestCase( 287 | message="base retryable exception", 288 | funcs=[raises(RequestException("", "", "")), returns(None)], 289 | log_action="test_action", 290 | expected_log_statements=[], 291 | ), 292 | # A function with a RateLimitExceededException exception, but 293 | # without a value for retry-after, should not generate any log 294 | # statements 295 | TestCase( 296 | message="rate limit exception without retry-after", 297 | funcs=[ 298 | raises(RateLimitExceededException("", "", None)), 299 | returns(None), 300 | ], 301 | log_action="test_action", 302 | expected_log_statements=[], 303 | ), 304 | # A function with a RateLimitExceededException exception with a 305 | # value for retry-after should generate a log statement 306 | TestCase( 307 | 
message="rate limit exception with retry-after", 308 | funcs=[ 309 | raises(RateLimitExceededException("", "", timedelta(seconds=1))), 310 | returns(None), 311 | ], 312 | log_action="test_action", 313 | expected_log_statements=[ 314 | ( 315 | "Request to test_action was rate limited with " 316 | "retry-after = 1.0 seconds" 317 | ) 318 | ], 319 | ), 320 | ] 321 | 322 | for case in test_cases: 323 | with self.subTest(msg=case.message): 324 | func_iter = iter(case.funcs) 325 | 326 | # Inject a mocked logger so we can intercept calls to log 327 | # messages 328 | import logging 329 | 330 | log = mock.create_autospec(logging.Logger) 331 | 332 | # Retried function that runs all of the functions in the test 333 | # case 334 | @retry_api_calls( 335 | fixed_wait_sec=0, 336 | max_jitter_sec=0, 337 | rate_limit_logging=RateLimitLogging( 338 | logger=log, 339 | action=case.log_action, 340 | ), 341 | ) 342 | def test_func() -> Any: 343 | cur_func = next(func_iter) 344 | return cur_func() 345 | 346 | # Run the test function and make sure the excepted log 347 | # statements were generated. Ignore any exceptions since 348 | # they may be intentionally raised in the test case. 349 | with suppress(Exception): 350 | test_func() 351 | log.info.assert_has_calls( 352 | [mock.call(s) for s in case.expected_log_statements] 353 | ) 354 | --------------------------------------------------------------------------------