├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md └── workflows │ ├── documentation.yml │ └── test-type-lint.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── conf.py ├── conftest.py ├── index.rst ├── infra │ ├── explanation.md │ ├── howto.md │ ├── introduction.md │ ├── reference.rst │ └── tutorials.md └── make.bat ├── exca ├── __init__.py ├── base.py ├── cachedict.py ├── confdict.py ├── data │ ├── cachedict2501 │ │ ├── .cache_type │ │ ├── x-9dd4e461.key │ │ ├── x-9dd4e461.npy │ │ ├── y-41529076.key │ │ └── y-41529076.npy │ └── compat-test-2024-11-12.pkl ├── dumperloader.py ├── helpers.py ├── logconf.py ├── map.py ├── py.typed ├── slurm.py ├── task.py ├── test_base.py ├── test_cachedict.py ├── test_compat.py ├── test_confdict.py ├── test_dumperloader.py ├── test_helpers.py ├── test_localmap.py ├── test_map.py ├── test_safeguard.py ├── test_submit.py ├── test_task.py ├── test_utils.py ├── test_workdir.py ├── utils.py └── workdir.py └── pyproject.toml /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to _exca_ 2 | We want to make contributing to this project as easy and transparent as possible. Note however that this project currently aims at helping a specific research team and that their requirements will be prioritized. 3 | 4 | ## Our Development Process 5 | _exca_ is actively used by a team of FAIR researcher and engineers. 6 | Bugs tracking and feature plannings are public. 7 | 8 | 9 | ## Pull Requests 10 | We welcome your pull requests. 11 | 12 | 1. Fork the repo and create your branch from `main`. 13 | 2. Install precommit hooks `pre-commit install` 14 | 3. If you've added code that should be tested, add tests. 15 | 4. If you've changed APIs, update the documentation. 16 | 5. Ensure the test suite passes. (`pytest exca`) 17 | 6. Make sure your code lints. (`black exca`, `mypy exca`) 18 | 7. If you haven't already, complete the Contributor License Agreement ("CLA"). 19 | 20 | ## Contributor License Agreement ("CLA") 21 | In order to accept your pull request, we need you to submit a CLA. You only need 22 | to do this once to work on any of Facebook's open source projects. 23 | 24 | Complete your CLA here: 25 | 26 | ## Issues 27 | We use GitHub issues to track public bugs. Please ensure your description is 28 | clear and has sufficient instructions to be able to reproduce the issue. 29 | 30 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 31 | disclosure of security bugs. In those cases, please go through the process 32 | outlined on that page and do not file a public issue. 33 | 34 | ## Coding Style 35 | We use black coding style with a 90 line length. 
36 | 37 | ## License 38 | By contributing to _exca_, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 40 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: [push, workflow_dispatch] # pull_request, 4 | 5 | permissions: 6 | contents: write 7 | 8 | jobs: 9 | build-docs: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | - name: Install dependencies 15 | run: | 16 | pip install -U pip 17 | pip install sphinx sphinx_rtd_theme myst_parser pytest 18 | pip install -e . 19 | - name: Verify basic install 20 | run: pytest exca/test_task.py::test_task_infra 21 | - name: Sphinx build 22 | run: pushd docs;make html;popd 23 | - name: Deploy to GitHub Pages 24 | uses: peaceiris/actions-gh-pages@v3 25 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} 26 | with: 27 | publish_branch: gh-pages 28 | github_token: ${{ secrets.GITHUB_TOKEN }} 29 | publish_dir: docs/_build/html 30 | force_orphan: true 31 | -------------------------------------------------------------------------------- /.github/workflows/test-type-lint.yaml: -------------------------------------------------------------------------------- 1 | name: Build & run pytest-mypy-linters 2 | env: 3 | IN_GITHUB_ACTION: 1 4 | 5 | on: [push] 6 | 7 | jobs: 8 | run-on-ubuntu: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | max-parallel: 5 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 3.10 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.10' 19 | 20 | # Building/caching the environment 21 | 22 | - name: Add conda to system path 23 | run: | 24 | # $CONDA is an environment variable pointing to the root of the miniconda directory 25 | echo $CONDA 26 | echo $CONDA/bin >> $GITHUB_PATH 27 | echo $CONDA_PREFIX 28 | 29 | - name: Cache conda env 30 | id: cache-conda 31 | uses: actions/cache@v4 32 | env: 33 | # change name here (only) to invalidate cache 34 | cache-name: cache-conda-env-v0 35 | with: 36 | key: ${{ env.cache-name }}-${{ hashFiles('pyproject.toml') }} 37 | path: ./ci_env 38 | 39 | - name: Create conda env 40 | run: | 41 | # creates the env if it does not exist (not loaded from cache) 42 | sudo apt-get update 43 | if [ ! 
-d "./ci_env" ]; then \ 44 | conda create -p ./ci_env python=3.10 ipython -y 45 | fi 46 | 47 | - name: Install dependencies 48 | run: | 49 | source activate ./ci_env 50 | pip install -e .[dev] 51 | 52 | - name: Print installed packages 53 | run: | 54 | source activate ./ci_env 55 | pip freeze 56 | 57 | # start checks 58 | 59 | - name: Run type hint checks with mypy 60 | run: | 61 | source activate ./ci_env 62 | pip show mypy 63 | mypy exca 64 | 65 | - name: Test with pytest 66 | run: | 67 | source activate ./ci_env 68 | pip show pytest 69 | pytest exca --durations=10 70 | 71 | - name: Test README code blocks 72 | run: | 73 | source activate ./ci_env 74 | # update readmes to avoid running on slurm: 75 | sed -i 's/cluster: slurm/cluster: null/g' docs/infra/*.md 76 | sed -i 's/\"auto\"/None/g' README.md 77 | # on Mac: sed -i '' 's/cluster: slurm/cluster: null/g' docs/infra/*.md 78 | # check readmes 79 | pytest --markdown-docs -m markdown-docs `**/*.md` 80 | 81 | - name: Run basic pylint 82 | run: | 83 | source activate ./ci_env 84 | pip show pylint 85 | pylint exca --disable=all --enable=unused-import,unused-variable,redefined-builtin,used-before-assignment,super-init-not-called,useless-super-delegation,dangerous-default-value,unnecessary-pass,attribute-defined-outside-init 86 | 87 | - name: black 88 | run: | 89 | source activate ./ci_env 90 | black --version 91 | black -v --check --diff exca 92 | 93 | - name: isort 94 | run: | 95 | source activate ./ci_env 96 | isort --version 97 | isort --check --diff exca 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files / tmp 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.swp 6 | *_build 7 | Untitled.ipynb 8 | *.ipynb_checkpoints 9 | 10 | # C extensions 11 | *.so 12 | 13 | # OS specific files 14 | .DS_Store 15 | 16 | # Distribution / packaging / data storage 17 | data/ 18 | outputs/ 19 | .Python 20 | env/ 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # dotenv 91 | .env 92 | 93 | # virtualenv 94 | .venv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # project specific folders 112 | exp_local 113 | outputs 114 | data 115 | tmp 116 | 117 | # adding output from unit-tests here 118 | *-raw.fif 119 | *.nii.gz 120 | .vscode 121 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.3.0 4 | hooks: 5 | - id: black 6 | # It is recommended to specify the latest version of Python 7 | # supported by your project here, or alternatively use 8 | # pre-commit's default_language_version, see 9 | # https://pre-commit.com/#top_level-default_language_version 10 | language_version: python3 11 | - repo: https://github.com/pycqa/isort 12 | rev: 5.12.0 13 | hooks: 14 | - id: isort 15 | name: isort (python) 16 | args: ["--profile", "black"] 17 | language_version: python3 18 | - repo: https://github.com/kynan/nbstripout 19 | rev: 0.7.1 20 | hooks: 21 | - id: nbstripout 22 | language_version: python3 23 | - repo: https://github.com/PyCQA/autoflake 24 | rev: v2.2.1 25 | hooks: 26 | - id: autoflake 27 | args: [--remove-all-unused-imports, --ignore-init-module-imports, --in-place] 28 | language_version: python3 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exca - ⚔ 2 | 3 | Execute and cache seamlessly in python. 4 | 5 | ![workflow badge](https://github.com/facebookresearch/exca/actions/workflows/test-type-lint.yaml/badge.svg) 6 | 7 | ## Quick install 8 | 9 | ``` 10 | pip install exca 11 | ``` 12 | 13 | ## Full documentation 14 | 15 | Documentation is available at [https://facebookresearch.github.io/exca/](https://facebookresearch.github.io/exca/) 16 | 17 | ## Basic overview 18 | 19 | `exca` provides simple decorators to: 20 | - execute a (hierarchy of) computation(s) either locally or on distant nodes, 21 | - cache the result. 22 | 23 | ### The problem: 24 | In ML pipelines, the use of a simple python function, such as `my_task`: 25 | 26 | ```python 27 | import numpy as np 28 | 29 | def my_task(param: int = 12) -> float: 30 | return param * np.random.rand() 31 | ``` 32 | 33 | often requires cumbersome overheads to (1) configure the parameters, (2) submit the job on a cluster, (3) cache the results: e.g. 34 | ```python continuation fixture:tmp_path 35 | import pickle 36 | from pathlib import Path 37 | import submitit 38 | 39 | # Configure 40 | param = 12 41 | 42 | # Check task has already been executed 43 | filepath = tmp_path / f'result-{param}.npy' 44 | if not filepath.exists(): 45 | 46 | # Submit job on cluster 47 | executor = submitit.AutoExecutor(cluster=None, folder=tmp_path) 48 | job = executor.submit(my_task, param) 49 | result = job.result() 50 | 51 | # Cache result 52 | with filepath.open("wb") as f: 53 | pickle.dump(result, f) 54 | ``` 55 | 56 | These overheads lead to several issues, such as debugging, handling hierarchical execution and properly saving the results (ending in the classic `'result-parm12-v2_final_FIX.npy'`). 57 | 58 | 59 | ### The solution: 60 | `exca` can be used to decorate the method of a [`pydantic` model](https://docs.pydantic.dev/latest/) so as to seamlessly configure its execution and caching: 61 | 62 | ```python fixture:tmp_path 63 | import numpy as np 64 | import pydantic 65 | import exca as xk 66 | 67 | class MyTask(pydantic.BaseModel): 68 | param: int = 12 69 | infra: xk.TaskInfra = xk.TaskInfra() 70 | 71 | @infra.apply 72 | def process(self) -> float: 73 | return self.param * np.random.rand() 74 | 75 | 76 | task = MyTask(param=1, infra={"folder": tmp_path, "cluster": "auto"}) 77 | out = task.process() # runs on slurm if available 78 | # calling process again will load the cache and not a new random number 79 | assert out == task.process() 80 | ``` 81 | See the [API reference for all the details](https://facebookresearch.github.io/exca/infra/reference.html#exca.TaskInfra) 82 | 83 | 84 | ## Quick comparison 85 | 86 | | **feature \ tool** | lru_cache | hydra | submitit | exca | 87 | | ----------------------------- | :-------: | :---: | :------: | :--: | 88 | | RAM cache | ✔ | | | ✔ | 89 | | file cache | | | | ✔ | 90 | | remote compute | | ✔ | ✔ | ✔ | 91 | | pure python (vs command line) | ✔ | | ✔ | ✔ | 92 | | hierarchical config | | ✔ | | ✔ | 93 | 94 | ## Contributing 95 | 96 | See the [CONTRIBUTING](.github/CONTRIBUTING.md) file for how to help out. 
97 | 98 | ## Citing 99 | ```bibtex 100 | @misc{exca, 101 | author = {J. Rapin and J.-R. King}, 102 | title = {{Exca - Execution and caching}}, 103 | year = {2024}, 104 | publisher = {GitHub}, 105 | journal = {GitHub repository}, 106 | howpublished = {\url{https://github.com/facebookresearch/exca}}, 107 | } 108 | ``` 109 | ## License 110 | 111 | `exca` is MIT licensed, as found in the LICENSE file. 112 | Also check-out Meta Open Source [Terms of Use](https://opensource.fb.com/legal/terms) and [Privacy Policy](https://opensource.fb.com/legal/privacy). 113 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | from pathlib import Path 9 | 10 | project = "Exca" 11 | copyright = "Meta Platforms, Inc" 12 | author = "FAIR" 13 | release = "0.1" 14 | 15 | 16 | # -- General configuration --------------------------------------------------- 17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 18 | 19 | extensions = [ 20 | "myst_parser", 21 | "sphinx.ext.autodoc", 22 | "sphinx.ext.linkcode", 23 | # "sphinx.ext.autosectionlabel", 24 | # "sphinx.ext.githubpages", 25 | # "sphinx.ext.coverage", 26 | # "sphinx.ext.napoleon", 27 | # "sphinx.ext.autosummary", 28 | # "recommonmark", 29 | ] 30 | 31 | templates_path = ["_templates"] 32 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "pythonplusplus", ".pytest_cache"] 33 | 34 | # Prefix document path to section labels, to use: 35 | # `path/to/file:heading` instead of just `heading` 36 | autosectionlabel_prefix_document = True 37 | 38 | 39 | # -- Options for HTML output ------------------------------------------------- 40 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 41 | 42 | html_theme = "sphinx_rtd_theme" 43 | html_static_path = [] # ["_static"] 44 | 45 | 46 | def linkcode_resolve(domain, info): 47 | if domain != "py": 48 | return None 49 | if not info["module"]: 50 | return None 51 | base = Path().absolute().parent 52 | module = info["module"].replace(".", "/") 53 | if (base / module).with_suffix(".py").exists(): 54 | filepath = module 55 | else: 56 | filepath = ( 57 | module 58 | + "/" 59 
| + info["fullname"].split(".", maxsplit=1)[0].replace("Infra", "").lower() 60 | ) 61 | if not (base / filepath).with_suffix(".py").exists(): 62 | return None 63 | return "https://github.com/facebookresearch/exca/blob/main/%s.py" % filepath 64 | -------------------------------------------------------------------------------- /docs/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | 9 | import numpy as np 10 | import pydantic 11 | import torch 12 | import yaml 13 | 14 | import exca 15 | from exca import MapInfra, TaskInfra 16 | 17 | 18 | class TutorialTask(pydantic.BaseModel): 19 | param: int = 12 20 | infra: TaskInfra = TaskInfra(version="1") 21 | 22 | @infra.apply 23 | def process(self) -> float: 24 | return self.param * np.random.rand() 25 | 26 | 27 | class TutorialMap(pydantic.BaseModel): 28 | param: int = 12 29 | infra: MapInfra = MapInfra(version="1") 30 | 31 | @infra.apply(item_uid=str) 32 | def process(self, items: tp.Iterable[int]) -> tp.Iterator[np.ndarray]: 33 | for item in items: 34 | yield np.random.rand(item, self.param) 35 | 36 | 37 | def pytest_markdown_docs_globals() -> tp.Dict[str, tp.Any]: 38 | return { 39 | "TutorialTask": TutorialTask, 40 | "TutorialMap": TutorialMap, 41 | "MapInfra": MapInfra, 42 | "TaskInfra": TaskInfra, 43 | "pydantic": pydantic, 44 | "yaml": yaml, 45 | "torch": torch, 46 | "exca": exca, 47 | } 48 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Example documentation master file, created by 2 | sphinx-quickstart on Sat Sep 23 20:35:12 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Exca - Execution and caching infrastructure 7 | =========================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Summary 12 | 13 | infra/introduction.md 14 | infra/tutorials.md 15 | infra/howto.md 16 | infra/explanation.md 17 | infra/reference.rst 18 | 19 | Citing 20 | ------ 21 | 22 | .. code-block:: bibtex 23 | 24 | @misc{exca, 25 | author = {J. Rapin and J.-R. King}, 26 | title = {{Exca - Execution and caching}}, 27 | year = {2024}, 28 | publisher = {GitHub}, 29 | journal = {GitHub repository}, 30 | howpublished = {\url{https://github.com/facebookresearch/exca}}, 31 | } 32 | 33 | Legal 34 | ----- 35 | 36 | :code:`exca` is MIT licensed, as found in the LICENSE file. 37 | Also check-out Meta Open Source `Terms of Use `_ and `Privacy Policy `_. 38 | -------------------------------------------------------------------------------- /docs/infra/explanation.md: -------------------------------------------------------------------------------- 1 | # Explanations 2 | 3 | ## Why? The philosophy 4 | (philosophy)= 5 | 6 | ### Pure python 7 | The tools here do not provide a script API but a way to do everything directly from Python. Specific script APIs can be easily composed on top of it if need be. 8 | 9 | ### Parameter validation 10 | Configurations should be validated before running to avoid discovering bugs much later (eg: missing parameter, inconsistent parameters, wrong type etc). 
We do this by using `pydantic.BaseModel` which works as `dataclasses` but validates all parameters. 11 | 12 | ### Fast configs 13 | Running a grid search requires creating a bunch of configs, so configurations should be easy and fast to create, and should therefore not load data/pytorch models/etc at creation time but defer this to later 14 | 15 | ### No parameter duplication - easy to extend 16 | Configurations hold the parameters of the underlying functions/classes. To avoid duplicating the parameters, we opt for having coupled configs and actual classes/functions like below: 17 | 18 | ```python 19 | class MyClassCfg(pydantic.BaseModel): 20 | x: int = 12 21 | y: str = "hello" 22 | 23 | def build(self) -> "MyClass": 24 | return MyClass(self) 25 | 26 | 27 | class MyClass: 28 | def __init__(self, cfg: MyClassCfg): 29 | self.cfg = cfg 30 | ``` 31 | With this simple pattern, building an object from the config is easy (`cfg.build()`), and adding new parameters only requires updating the config, with effective typing and low risk of silently ignored parameters because of a mismatch between configs and functions. 32 | 33 | ### Cached/distributed computation 34 | The main aim of this package is to provide objects that slightly modify methods and make them 35 | distributed and cache their results in a breeze. The `infra` objects that make this possible 36 | are configurations that let you specify how caching should be performed and 37 | how computation should be distributed directly within your experiment config 38 | (including slurm partitions, number of gpus etc) 39 | 40 | ### Modularity 41 | 42 | Pydantic hierarchical configurations and discriminated unions allow for modularity and reusability, as several sub-configs can be proposed for a training config, and plugging in a new sub-config is straightforward. 43 | 44 | ## MapInfra / TaskInfra differences 45 | 46 | `TaskInfra` must be applied to a method with no parameter (except `self`). It links 1 computation to 1 job and therefore provides easy tools for accessing the job stdout/stderr/status etc. 47 | 48 | `MapInfra` on the other hand must be applied to a method with 1 parameter (in addition to `self`) which must be an `m`-sized iterator/sequence of items. It requires stating how to provide a unique uid for each item (through the `item_uid` function), and it maps `m` computations (1 for each item) to `n <= m` jobs, packing several computations together. Because of this non-bijective mapping, there is no support for checking jobs' stderr/stdout/status. 49 | 50 | 51 | ## uid computation 52 | 53 | A unique id, the `uid`, is computed for each pydantic model/each config based on public instance attributes: 54 | - which are non-default 55 | - which are not excluded through the `_exclude_from_cls_uid` class attribute list/tuple 56 | (or class method returning a list/tuple). This allows removing parameters which do not impact 57 | the result (eg: number of workers, device, etc.). 58 | 59 | *Note*: `infra` objects have all their parameters excluded except `version`, as these parameters affect how the computation 60 | is performed but not the result. 61 | 62 | Furthermore, a specific "cache" `uid` is also computed, for which additional parameters can be excluded to account for 63 | parameters which do not impact the cached computation but impact the class as a whole (attributes which are used 64 | to post-process the cached computation). This is done by specifying `exclude_from_cache_uid` in the 65 | `infra.apply` method. This cache uid is used as the storage folder name for the cache. 
66 | Exclusion can be specified as a list/tuple of fields, or as a method, or as the name of a method 67 | (with format `method:<method_name>`). Notice that when subclassing, if you specified the exclusion 68 | as a function, the original function will be used (not the new function if it was overridden); 69 | if you want to use the new one, then you should specify the method through its name. 70 | 71 | See more in the [example](howto-efficient-caching) from the how-to guide. 72 | 73 | 74 | ## ConfDict 75 | To simplify working with configuration dictionaries, we use `ConfDict` classes (see [their API](exca.confdict.ConfDict)). 76 | In practice, they are dictionaries which break into sub-dictionaries on `"."` characters 77 | such as in a config. Data can be specified either through dotted-keywords or directly through sub-dictionaries 78 | or a mixture of both: 79 | 80 | ```python 81 | from exca import ConfDict 82 | 83 | cfdict = ConfDict({"training.optim.lr": 0.01}) 84 | assert cfdict == {"training": {"optim": {"lr": 0.01}}} 85 | ``` 86 | 87 | `ConfDict` instances have a few convenient methods: 88 | ```python continuation 89 | # flattens the dictionary 90 | assert cfdict.flat() == {"training.optim.lr": 0.01} 91 | 92 | # export to yaml (can take a file as argument) 93 | assert cfdict.to_yaml() == "training.optim.lr: 0.01\n" 94 | 95 | # uid computation 96 | assert cfdict.to_uid() == "training.optim.lr=0.01-0f8936b4" 97 | ``` 98 | 99 | Infra objects extensively use such dictionaries and have a `config` method for instantiating the `ConfDict` generated from an object: 100 | ```python 101 | task = TutorialTask(param=13) 102 | cfdict = task.infra.config(uid=True, exclude_defaults=True) 103 | assert cfdict == {"param": 13} 104 | ``` 105 | 106 | They are used for uid computation as shown above (with `uid=True, exclude_defaults=True`) but also to clone the instance 107 | and update its values, so that you can pass new values either through dotted-name format or through sub-dictionaries: 108 | ```python continuation 109 | # exports a ConfDict and reinstantiates from it 110 | new = task.infra.clone_obj({"param": 14}) 111 | assert new.param == 14 112 | ``` 113 | 114 | 115 | 116 | 117 | ## Caching 118 | Cache folders are created as `<module>.<class>.<method>,<version>/<config uid>` 119 | 120 | Eg: `mypackage.mymodule.TutorialTask.process,1/param=13-fbfu2iow` 121 | 122 | Under the hood, data are stored using the `CacheDict` class (see [API here](exca.cachedict.CacheDict)). 123 | This class has a `dict`-like interface (`keys`, `items`, `values`, `__getitem__`, `__contains__`); the difference is that the data can be stored to/loaded from disk automatically, and `__setitem__` works through a context manager to be more efficient. 
124 | The class is initialized with 3 parameters: 125 | - the storage folder: if present, the data will be stored to disk in this folder, or reloaded from disk 126 | - the `keep_in_ram` flag: if `True`, the data will be cached in RAM when stored/reloaded, for faster access 127 | - `cache_type`: the type of caching to use (eg: cache as pickles, or independent npy files, or one large npy file) 128 | 129 | 130 | **Example** 131 | ```python fixture:tmp_path 132 | import numpy as np 133 | from exca import cachedict 134 | # create a cache dict (specialized for numpy arrays) 135 | cache = cachedict.CacheDict(folder=tmp_path, keep_in_ram=True) 136 | 137 | # the dictionary is empty: 138 | assert not cache 139 | 140 | # add a value into the cache 141 | x = np.random.rand(2, 12) 142 | with cache.writer() as writer: 143 | # cache dict needs a writer context to 144 | # be more efficient in case of multiple writes 145 | writer["blublu"] = x 146 | assert "blublu" in cache 147 | # the value is now available 148 | np.testing.assert_almost_equal(cache["blublu"], x) 149 | assert set(cache.keys()) == {"blublu"} 150 | 151 | # create a new dict instance with same cache folder 152 | cache2 = cachedict.CacheDict(folder=tmp_path) 153 | # the data is still available (loading from cache folder) 154 | assert set(cache2.keys()) == {"blublu"} 155 | ``` 156 | 157 | In practice, at write time, each thread/process independently creates an `*-info.jsonl` file in which each line is a JSON entry providing the key in the dictionary, and information on how to read the data corresponding to this key. 158 | 159 | `CacheDict` is designed for use within an infra and may be sub-optimal for other use cases (eg: repeated checks to `__contains__` can repeatedly reload the keys from the file system if the key is not already present, to make sure nothing new was added through another thread/process, which can be inefficient). 160 | 161 | -------------------------------------------------------------------------------- /docs/infra/howto.md: -------------------------------------------------------------------------------- 1 | # How-to guide 2 | 3 | ## Debugging 4 | 5 | Use the following lines to add more logging during debugging; this will help you understand what happens: 6 | ```python 7 | import logging 8 | logging.getLogger("exca").setLevel(logging.DEBUG) 9 | ``` 10 | 11 | Also, use `cluster=None` or `cluster="debug"` to avoid distributed computation, which is harder to debug. 12 | 13 | **Note**: an error raised in a decorated task method will have a different type from the actual raised exception (so as to have an error printing the traceback of the initial error). 14 | 15 | When communicating with others for help, make sure to send the stack trace as well as the error, as the stack trace is often much more informative than the sole error. 16 | 17 | 18 | ## Asynchronous computation with TaskInfra 19 | 20 | ### Job 21 | 22 | Results of a cached task computation can be obtained in 2 ways: 23 | - either by calling the method directly: `task.process()` 24 | - or through the attached `job = task.infra.job()` result. 25 | 26 | ```python fixture:tmp_path 27 | task = TutorialTask(infra={"folder": tmp_path}) 28 | assert task.process() == task.infra.job().result() 29 | ``` 30 | 31 | Calling `job = task.infra.job()` starts the submission if it was not already started, and does not wait for the result; only calling `job.result()` will block until completion. 
This can let you run other computation in your current process / monitor the job (eg through `task.infra.status()`) etc. 32 | 33 | 34 | ### Batching (job arrays in slurm clusters) 35 | 36 | Submitting one task job at a time is not efficient nor a good practice *for slurm clusters*, as each submission consumes some of slurm scheduler resources. A better way is to submit arrays of job. With `TaskInfra`, this is done through a `job_array` context: 37 | 38 | ```python fixture:tmp_path 39 | task = TutorialTask(infra={"folder": tmp_path}) 40 | 41 | with task.infra.job_array() as array: 42 | # "array" is a list to append/extend with tasks to compute 43 | for k in range(3): 44 | # the following creates a new task with a new "param" value 45 | task_to_compute = task.infra.clone_obj({"param": k}) 46 | array.append(task_to_compute) 47 | 48 | # leaving the context is non-blocking and submits all tasks to compute 49 | assert array[0].infra.status() in ("completed", "running") 50 | ``` 51 | 52 | Similarly as with calling `task.infra.job()`, previously computed tasks are not resubmitted, unless `infra.mode` is set to `"force"`, or set to `"retry"` with a previously failed computation. 53 | 54 | ## Monitoring 55 | 56 | To monitor running jobs on slurm clusters, one can use the `squeue` command. However, when running a lot of slurm jobs it can become complex to figure out which job does what. We provide a basic helper function to access the logs and config of a given job (assuming default log folder position): `exca.helpers.find_slurm_job` 57 | See more details in its [API reference](#exca.helpers.find_slurm_job). 58 | 59 | We also recommend using [Turm](https://github.com/kabouzeid/turm), which provides a real-time interface to access the `stdout/stderr` of running jobs. Simply install it with `pip install turm` and you can then use `turm --slurm-refresh 20 --me --states RUNNING` to check your running jobs (please use at least 20s for the slurm refresh rate to avoid overloading the cluster). 60 | 61 | 62 | ## Efficient caching: cache and class uid exclusion 63 | (howto-efficient-caching)= 64 | 65 | Consider the following class that defines a `process` function which returns a random torch tensor: 66 | 67 | ```python 68 | import typing as tp 69 | import torch 70 | import numpy as np 71 | 72 | class UidTask(pydantic.BaseModel): 73 | seed: int = 12 74 | shape: tp.Tuple[int, int] = (3, 4) 75 | coeff: float = 12.0 76 | device: str = "cpu" 77 | infra: TaskInfra = TaskInfra(version="1") 78 | 79 | @infra.apply 80 | def _internal_cached_method(self) -> np.ndarray: 81 | rng = np.random.default_rng(seed=12) 82 | return rng.normal(size=self.shape) 83 | 84 | def process(self) -> torch.Tensor: 85 | array = self._internal_cached_method() 86 | tensor = torch.Tensor(array, device=self.device) 87 | return self.coeff * tensor 88 | ``` 89 | 90 | `infra` uses all parameters of the `pydantic.BaseModel` to define a `uid` (a string) of the class or of the cache. Two objects with a same `uid` should behave the same way / provide the same results. Preferably two objects which provide the same computation should also have the same uid, but there are a couple of issues here. 
91 | 92 | ### Class uid 93 | The uid of the class takes into account all parameters; you can check them through the `config` method, for instance: 94 | ```python continuation 95 | task = UidTask(device="cuda", coeff=3) 96 | assert task.infra.config(uid=True) == { 97 | 'seed': 12, 98 | 'shape': [3, 4], 99 | 'coeff': 3.0, 100 | 'device': 'cuda', 101 | 'infra': {'version': '1'} 102 | } 103 | assert task.infra.config(uid=True, exclude_defaults=True) == { 104 | 'coeff': 3.0, 105 | 'device': 'cuda' 106 | } 107 | ``` 108 | In practice the `uid` is computed from the non-default parameters, so the `uid` will be something like `coeff=3,device=cuda-4f4ca7cb` in this case (the last part being a hash for security reasons). 109 | 110 | `device` however defines where the computation is performed but has no impact on the actual result, so it should not impact the uid of the class. This parameter should therefore be excluded from the uid computation; this can be done by having either a `_exclude_from_cls_uid` method or class variable. 111 | 112 | All `infra` parameters except `version` are ignored in such a way because caching or the required resources for computation (number of cpus/gpus for instance) do not impact the actual result of the computation. `version` is therefore the only parameter of the `infra` that will appear in the config even when specifying caching or remote computation options. 113 | 114 | ```python continuation 115 | class UidTask2(UidTask): 116 | _exclude_from_cls_uid: tp.ClassVar[list[str]] = ["device"] 117 | 118 | task2 = UidTask2(device="cuda", coeff=3) 119 | assert task2.infra.config(uid=True, exclude_defaults=True) == {'coeff': 3.0} 120 | ``` 121 | 122 | ### Cache uid 123 | 124 | Cache also requires a `uid` that will be used to store the results. All parameters ignored from the class `uid` are also ignored for the cache (such as `device`). We can however see in this case that while the `UidTask` does depend on the `coeff` parameter (used in `process`), the cache does not, because `_internal_cached_method` does not use it. We can then further ignore `coeff` from the cache by specifying it through the `exclude_from_cache_uid` parameter of the `infra.apply` decorator. This parameter can be one of: 125 | 1. a list of parameter names to ignore 126 | 2. a method defined in the class returning a list of parameter names to ignore 127 | 3. the method above specified by name under the format `"method:<method_name>"` 128 | 129 | The main difference between 2 and 3 is that a method override in a subclass will only be taken into account with option 3. 
130 | 131 | ### Updated task 132 | 133 | 134 | Here an updated class with better `uid` handling: 135 | ```python continuation 136 | class UidTask(pydantic.BaseModel): 137 | seed: int = 12 138 | shape: tp.Tuple[int, int] = (3, 4) 139 | coeff: float = 12.0 140 | device: str = "cpu" 141 | infra: TaskInfra = TaskInfra(version="1") 142 | _exclude_from_cls_uid: tp.ClassVar[tuple[str, ...]] = ("device",) 143 | 144 | @infra.apply(exclude_from_cache_uid=("coeff",)) 145 | def _internal_cached_method(self) -> np.ndarray: 146 | rng = np.random.default_rng(seed=12) 147 | return rng.normal(size=self.shape) 148 | 149 | def process(self) -> torch.Tensor: 150 | array = self._internal_cached_method() 151 | tensor = torch.Tensor(array, device=self.device) 152 | return self.coeff * tensor 153 | ``` 154 | 155 | and here is an equivalent class with different options for specifying the class and cache exclusions: 156 | ```python continuation 157 | class UidTask(pydantic.BaseModel): 158 | seed: int = 12 159 | shape: tp.Tuple[int, int] = (3, 4) 160 | coeff: float = 12.0 161 | device: str = "cpu" 162 | infra: TaskInfra = TaskInfra(version="1") 163 | 164 | def _exclude_from_cls_uid(self) -> tuple[str, ...]: 165 | return ("device",) 166 | 167 | def _cache_exclusion(self) -> tp.List[str]: 168 | return ["coeff"] 169 | 170 | @infra.apply(exclude_from_cache_uid="method:_cache_exclusion") 171 | def _internal_cached_method(self) -> np.ndarray: 172 | rng = np.random.default_rng(seed=12) 173 | return rng.normal(size=self.shape) 174 | 175 | def process(self) -> torch.Tensor: 176 | array = self._internal_cached_method() 177 | tensor = torch.Tensor(array, device=self.device) 178 | return self.coeff * tensor 179 | ``` 180 | 181 | ## Infra Versioning & default heritage 182 | 183 | All attributes of infra configurations are ignored for uid computation (i.e. modifying eg the type of cluster / number of cpus in the infra will not modify the uid), except their `version` attribute. This allows for cache invalidation. Indeed, changing this value will change the current class uid and therefore lead to the creation of a new cache folder. Cache of classes depending on the class with a new version will not be usable anymore (because of conflicting default version value) and the cache folder of these classes may need to be deleted. 184 | 185 | Furthermore, **all** attributes of the default infra set on a config class serve as seed/default values when you instantiate a config instance. So, when instantiating the `TutorialMap` class (see [tutorial](tutorial-map)), you will get a `version="1"` in your instance if you do not override it: 186 | 187 | 188 | ```python continuation fixture:tmp_path 189 | mapper = TutorialMap(infra={"folder": tmp_path}) 190 | assert mapper.infra.version == "1" 191 | # even though the default version in a MapInfra is actually "0": 192 | assert MapInfra(**{"folder": tmp_path}).version == "0" 193 | ``` 194 | 195 | Be careful, this behavior is limited to infra objects, so as to preset version and computation defaults more easily, other nested config do not behave this way. 
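For instance, here is a minimal sketch of the effect of a version bump (reusing the `TutorialTask` from the tutorial; the exact cache folder layout is an implementation detail):

```python fixture:tmp_path
task_v1 = TutorialTask(param=13, infra={"folder": tmp_path, "version": "1"})
task_v2 = TutorialTask(param=13, infra={"folder": tmp_path, "version": "2"})

# version is the only infra field kept in the uid config...
assert task_v1.infra.config(uid=True)["infra"] == {"version": "1"}
assert task_v2.infra.config(uid=True)["infra"] == {"version": "2"}

# ...so the two tasks resolve to different cache folders
assert task_v1.infra.uid_folder() != task_v2.infra.uid_folder()
```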
196 | 197 | 198 | ## Using pydantic's discriminator 199 | (howto-discriminator)= 200 | 201 | `pydantic` allows for automatic selection between several sub-configurations, such as between a `Dog` or a `Cat` sub-configuration below: 202 | 203 | ```python 204 | import typing as tp 205 | 206 | class Dog(pydantic.BaseModel): 207 | name: tp.Literal["dog"] = "dog" 208 | dog_param: int = 12 209 | 210 | class Cat(pydantic.BaseModel): 211 | name: tp.Literal["cat"] = "cat" 212 | cat_param: str = "mew" 213 | 214 | class Household(pydantic.BaseModel): 215 | pet: Dog | Cat = pydantic.Field(..., discriminator="name") 216 | 217 | 218 | household = Household(pet={"name": "dog"}) 219 | assert household.pet.dog_param == 12 220 | ``` 221 | 222 | The syntax requires providing a "discriminator" field which is a constant (a `Literal`) for selecting which class to instantiate. While explicitly stating the discriminator through `pydantic.Field` (as above) or through an annotation (see the `pydantic` documentation) is not strictly necessary with `pydantic`, it is necessary when working with infras so that the discriminator is part of the uid. 223 | 224 | 225 | ## Workdir/code copy 226 | 227 | Running code on a cluster while still working on the code can be dangerous, as the job will use the state of the code **at start time** and **not at submission time**. 228 | 229 | In order to avoid surprises, both task/map infras support a `workdir` for copying the code to a different working directory from which the decorated function will run: 230 | [Check its parameters here](exca.workdir.WorkDir); in particular, the `copied` parameter can be used to select folders or files or packages installed in editable mode that should be copied to the job's working directory, like in the [`TutorialTask` class from the tutorial section](infra/tutorials:TaskInfra): 231 | 232 | ```python fixture:tmp_path 233 | task = TutorialTask(infra={ 234 | "cluster": "local", 235 | "folder": tmp_path, 236 | # will create a copy of exca in a folder and run from there: 237 | "workdir": {"copied": ["exca"]}, 238 | }) 239 | ``` 240 | 241 | Note that the change of working directory (and possibly the copy) only happens when the infra is called for submitting the decorated function. Depending on your code, this may not be at the very beginning of your execution. 242 | 243 | -------------------------------------------------------------------------------- /docs/infra/introduction.md: -------------------------------------------------------------------------------- 1 | # Exca - Execution and caching 2 | 3 | This is an explanation of why `exca` was built. If you are only interested in how to use it, you can move on to the [tutorials](tutorials.md) and [how-to](howto.md) pages. 4 | 5 | Here are the challenges we want to address: 6 | 1. config validation and remote computation 7 | 2. hierarchical computation and modularity 8 | 3. experiment/computation caching 9 | 10 | 11 | ## Challenge #1: Early configuration validation 12 | ```bash notest 13 | srun --cpus-per-task=4 --time=60 python -m mytask --z=12 14 | 15 | >> srun: job 34633429 queued and waiting for resources 16 | >> srun: job 34633429 has been allocated resources 17 | ... 18 | ... 
19 | >> usage: mytask.py [-h] [--x X] [--y Y] 20 | >> mytask.py: error: unrecognized arguments: --z=12 21 | >> srun: error: learnfair0478: task 0: Exited with exit code 2 22 | ``` 23 | 24 | Or similarly: 25 | ```bash notest 26 | mytask.py: error: argument --y: invalid int value: 'blublu' 27 | ``` 28 | 29 | ### Observations and consequences 30 | 31 | - Configurations should be validated before running on the cluster! 32 | - need some tool for validation 33 | → verify configurations locally first 34 | → in Python 35 | → submit from Python as well (avoid boilerplate additional bash command) 36 | - Resource configuration (srun params) and computation configurations (mytask params) come in 2 different places (and sometimes formats) 37 | → specify resource configuration within the same configuration as the computation configuration? (while keeping them distinct in some way?!) 38 | 39 | 40 | ### Parameter validation with Pydantic 41 | 42 | Pydantic (21k★ on github) works like dataclasses, but with (fast) validation: 43 | ```python 44 | import pydantic 45 | 46 | class MyTask(pydantic.BaseModel): 47 | model_config = pydantic.ConfigDict(extra="forbid") # pydantic boilerplate 48 | x: int 49 | y: str = "blublu" 50 | 51 | mytask = MyTask(x=12) 52 | mytask.x # this is 12 53 | 54 | # MyTask(x="blublu") 55 | # >> ValidationError: 1 validation error for MyTask (x hould be a valid integer) 56 | ``` 57 | 58 | Pydantic supports hierarchical configurations: 59 | 60 | ```python continuation 61 | class Parent(pydantic.BaseModel): 62 | task: MyTask 63 | 64 | obj = Parent(task={"x": 12}) # parses the dict into a MyTask class 65 | obj.task.x # this is 12 66 | ``` 67 | 68 | #### Note: discarded options 69 | - `dataclasses`: no dynamic type check 70 | - `omegaconf`: can typecheck (when using dataclasses) but is slow and not well-maintained 71 | - `attrs`: probably usable (but smaller community 5k★ Vs 21k★) 72 | 73 | 74 | ### Local/remote submission with exca 75 | 76 | Convenient pattern (more on this later): tie computation to the config class: 77 | 78 | ```python 79 | class MyTask(pydantic.BaseModel): 80 | x: int 81 | y: str = "blublu" 82 | model_config = pydantic.ConfigDict(extra="forbid") # pydantic boilerplate 83 | 84 | def compute(self) -> int: 85 | print(self.y) 86 | return 2 * self.x 87 | ``` 88 | 89 | Then if we want to enable remote computation, we add a `exca.TaskInfra` subconfiguration: 90 | ```python 91 | class MyTask(pydantic.BaseModel): 92 | x: int 93 | y: str = "blublu" 94 | infra: exca.TaskInfra = exca.TaskInfra() 95 | # note: automatically sets extra="forbid" 96 | 97 | @infra.apply 98 | def compute(self) -> int: 99 | print(self.y) 100 | return 2 * self.x 101 | ``` 102 | By default, this changes nothing, but you can now parametrize the infra to run the `compute` method on slurm, eg: 103 | 104 | ```python continuation fixture:tmp_path 105 | config = f""" 106 | x: 12 107 | y: whatever 108 | infra: # resource parameters 109 | cluster: slurm 110 | folder: {tmp_path} 111 | cpus_per_task: 4 112 | """ 113 | 114 | dictconfig = yaml.safe_load(config) 115 | obj = MyTask(**dictconfig) # validation happens locally 116 | out = obj.compute() # runs in a slurm job! 117 | assert out == 24 118 | ``` 119 | 120 | 121 | Note that the config now holds both the computation parameters (`x` and `y`) as well as the resources parameters (through `infra`) but they are separated through the hierarchical structure of the config. 
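The same configuration can also be executed locally (eg for debugging) without touching the computation parameters; below is a minimal sketch assuming the `MyTask` class and the `dictconfig` dictionary defined above:

```python continuation
# override only the resource sub-config: None runs in the current process
dictconfig["infra"]["cluster"] = None
local_obj = MyTask(**dictconfig)
assert local_obj.compute() == 24  # same computation, no slurm job involved
```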
122 | 123 | 124 | 125 | ## Challenge #2: Complex experiments - hierarchical configurations 126 | Do's and don'ts with `pydantic`'s configurations 127 | 128 | ### Parametrizing pattern 129 | 130 | Seen in many codebases: 131 | ```python 132 | class ConvCfg(pydantic.BaseModel): 133 | layers: int = 12 134 | kernel: int = 5 135 | channels: int = 128 136 | 137 | 138 | class ConvModel(torch.nn.Module): 139 | 140 | def __init__(self, layers: int, kernel: int, channels: int, other: int = 12) -> None: 141 | self.layers = layers 142 | self.kernel = kernel 143 | self.channels = channels 144 | self.other = other 145 | ... # build layers, add forward method 146 | 147 | 148 | # then in your code 149 | cfg = ConvCfg(layers=10, kernel=5, channels=16) 150 | model = ConvModel(layers=cfg.layers, kernel=cfg.kernel, channels=cfg.channels) 151 | ``` 152 | Issues: 153 | - a lot of duplicated code/work 154 | - easy to mess up when propagating a new parameter as it needs 4 changes: the config, the model init parameters, the content of the init, the instantiation of the model from the config (any typo? any mismatch generating a silent bug?) 155 | - some defaults may not be configurable 156 | 157 | 158 | Here is a simpler pattern: 159 | ```python 160 | class ConvCfg(pydantic.BaseModel): 161 | layers: int = 12 162 | kernel: int = 5 163 | channels: int = 128 164 | 165 | def build(self) -> torch.nn.Module: 166 | # instantiate when needed 167 | # (do not slow down config initialization) 168 | return ConvModel(self) 169 | 170 | 171 | class ConvModel(torch.nn.Module): 172 | 173 | def __init__(self, cfg: ConvCfg) -> None: 174 | self.cfg = cfg 175 | ... # build layers, add forward method 176 | 177 | # then in your code 178 | model = ConvCfg().build() 179 | ``` 180 | 181 | **Cost**: classes become coupled (then again, you don't need importing `ConvModel` anymore) 182 | 183 | **Benefit**: fixes all issues mentioned above with 1 set of defaults, in a single place 184 | 185 | 186 | ### One step further - Discriminated unions 187 | 188 | Pipelines often get complex, and require if-else conditions depending on configurations, for instance: 189 | ```python 190 | import typing as tp 191 | 192 | class ModelCfg(pydantic.BaseModel): 193 | name: tp.Literal["conv", "transformer"] = "conv" # special discriminator field 194 | # shared parameters 195 | layers: int = 12 196 | # convolution parameters 197 | kernel: int = 5 198 | channels: int = 128 199 | # transformer parameters 200 | embeddings: int = 128 201 | 202 | def build(self) -> torch.nn.Module: 203 | if self.name == "conv": 204 | return ConvModel(self) 205 | else: 206 | return TransformerModel(self) 207 | ``` 208 | 209 | This creates coupling between different models into a unique config where some parameters are ignored depending on the cases, and can become messier and messier with more models. 210 | Fortunately, `pydantic`'s discriminated unions easily address this issue: 211 | 212 | 213 | ```python continuation 214 | class ConvCfg(pydantic.BaseModel): 215 | name: tp.Literal["conv"] = "conv" # special discriminator field 216 | layers: int = 12 217 | kernel: int = 5 218 | channels: int = 128 219 | 220 | def build(self) -> torch.nn.Module: 221 | return ConvModel(self) #instantiate when needed 222 | 223 | ... 
224 | 225 | class TransformerCfg(pydantic.BaseModel): 226 | model_config = pydantic.ConfigDict(extra="forbid") # pydantic boilerplate: safer 227 | name: tp.Literal["transformer"] = "transformer" # special discriminator field 228 | layers: int = 12 229 | embeddings: int = 128 230 | 231 | def build(self) -> torch.nn.Module: 232 | return TransformerModel(self) 233 | 234 | ... 235 | 236 | class Trainer(pydantic.BaseModel): 237 | model: ConvCfg | TransformerCfg = pydantic.Field(..., discriminator="name") 238 | optimizer: str = "Adam" 239 | infra: TaskInfra = TaskInfra() 240 | 241 | @infra.apply 242 | def run(self) -> float: 243 | model = self.model.build() # build either one of the model 244 | # specific location for this very config: 245 | ckpt_path = self.infra.uid_folder() / "checkpoint.pt" 246 | if ckpt_path.exists(): 247 | # load 248 | ... 249 | ... 250 | for batch in loader: 251 | ... 252 | return accuracy 253 | 254 | 255 | string = """ 256 | model: 257 | name: transformer # specifies which model 258 | embeddings: 256 # only accepts transformer specific parameters 259 | optimizer: SGD 260 | """ 261 | trainer = Trainer(**yaml.safe_load(string)) 262 | 263 | isinstance(trainer.model, TransformerCfg) 264 | ``` 265 | 266 | Discriminated unions make it easier to make **modular pipelines** as one can swap part of the experiment by others very easily, and still get full parameter validation. 267 | 268 | 269 | ## Challenge #3: Experiment/computation caching 270 | 271 | `exca` can also handle **caching of the computation result** with no extra effort, so any computation already performed will only be recomputed if explicitely required. 272 | 273 | A lot of additional benefits also come for free: 274 | - sub-configs can have their own infra. 275 | - running a grid search only requires a `for` loop. 276 | - computations can be packed into a job array. 277 | - computation can be performed in a dedicated working directory to avoid interfering with the code. 278 | 279 | ```python continuation fixture:tmp_path 280 | string = f""" 281 | model: 282 | name: transformer # specifies which model 283 | embeddings: 256 284 | optimizer: SGD 285 | infra: 286 | gpus_per_node: 8 287 | cpus_per_task: 80 288 | slurm_constraint: volta32gb 289 | folder: {tmp_path} 290 | cluster: slurm 291 | slurm_partition: learnfair 292 | workdir: 293 | copied: 294 | - . # copies current working directory into a dedicated workdir 295 | # - whatever_other_file_or_folder 296 | """ 297 | 298 | trainer = Trainer(**yaml.safe_load(string)) 299 | with trainer.infra.job_array() as array: 300 | for layers in [12, 14, 15]: 301 | array.append(trainer.infra.clone_obj({"model.layers": layers})) 302 | # leaving the context submits all trainings in a job array 303 | # and is non-blocking 304 | 305 | # show one of the slurm jobs 306 | print(array[0].infra.job()) 307 | ``` 308 | 309 | 310 | Overall with this way of experimenting, you easily get: 311 | - modular pipeline with simple building blocks 312 | - easy remote computation configuration 313 | - validated configuration before sending to remote cluster through a job array 314 | - cached results so that only missing elements of the array get sent 315 | 316 | -------------------------------------------------------------------------------- /docs/infra/reference.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. 
autoclass:: exca.TaskInfra 5 | :members: 6 | :inherited-members: 7 | :exclude-members: model_post_init, model_fields, model_computed_fields, model_config, model_construct, model_copy, model_dump, model_dump_json, model_extra, model_fields_set, model_json_schema, model_parametrized_name, model_rebuild, model_validate, model_validate_json, model_validate_strings, copy, apply_on 8 | 9 | 10 | .. autoclass:: exca.MapInfra 11 | :members: 12 | :inherited-members: 13 | :exclude-members: model_post_init, model_fields, model_computed_fields, model_config, model_construct, model_copy, model_dump, model_dump_json, model_extra, model_fields_set, model_json_schema, model_parametrized_name, model_rebuild, model_validate, model_validate_json, model_validate_strings, copy, apply_on 14 | 15 | .. autoclass:: exca.SubmitInfra 16 | :members: 17 | :exclude-members: model_post_init, model_fields, model_computed_fields, model_config, model_construct, model_copy, model_dump, model_dump_json, model_extra, model_fields_set, model_json_schema, model_parametrized_name, model_rebuild, model_validate, model_validate_json, model_validate_strings, copy, apply_on 18 | 19 | Associated classes and functions 20 | -------------------------------- 21 | 22 | .. autoclass:: exca.slurm.SubmititMixin 23 | :members: 24 | :inherited-members: 25 | :exclude-members: model_post_init, model_fields, model_computed_fields, model_config, model_construct, model_copy, model_dump, model_dump_json, model_extra, model_fields_set, model_json_schema, model_parametrized_name, model_rebuild, model_validate, model_validate_json, model_validate_strings, copy, apply_on 26 | 27 | .. autoclass:: exca.workdir.WorkDir 28 | :members: 29 | :exclude-members: model_post_init, model_fields, model_computed_fields, model_config 30 | 31 | .. autoclass:: exca.ConfDict 32 | :members: 33 | :exclude-members: model_post_init, model_fields, model_computed_fields, model_config 34 | 35 | .. autoclass:: exca.cachedict.CacheDict 36 | :members: 37 | 38 | .. automodule:: exca.helpers 39 | :members: with_infra, find_slurm_job 40 | -------------------------------------------------------------------------------- /docs/infra/tutorials.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | `pydantic` is a package providing model/configuration classes and allowing for parameter validation when instantiating the object. The `exca` package builds on top of it and provides an "infra" pydantic configuration that can be part of a parent pydantic configuration and change the way it behaves. In particular, it lets one add caching and remote computation to its methods. Check out the package [philosophy](philosophy) for a more in-depth explanation of the "whys" of this package. 4 | 5 | If you are not familiar with `pydantic`, have a look first at the [Pydantic models section](#pydantic-models). 6 | 7 | ## Installation 8 | 9 | `pip install exca` 10 | 11 | ## Two types of infra: Task and Map 12 | 13 | Infras currently come in 2 flavors.
14 | 15 | ### TaskInfra 16 | (infra/tutorials:TaskInfra)= 17 | Consider you have a pydantic model/config that fully defines one computation to perform, for instance through a `process` method like below: 18 | 19 | 20 | ```python 21 | import numpy as np 22 | import typing as tp 23 | import pydantic 24 | 25 | class TutorialTask(pydantic.BaseModel): 26 | param: int = 12 27 | 28 | def process(self) -> float: 29 | return self.param * np.random.rand() 30 | ``` 31 | 32 | Adding an infra for the `process` method only requires adding a [`TaskInfra`](#exca.TaskInfra) object to the config: 33 | 34 | 35 | ```python continuation 36 | import typing as tp 37 | import torch 38 | import exca 39 | 40 | 41 | class TutorialTask(pydantic.BaseModel): 42 | param: int = 12 43 | infra: exca.TaskInfra = exca.TaskInfra(version="1") 44 | 45 | @infra.apply 46 | def process(self) -> float: 47 | return self.param * np.random.rand() 48 | ``` 49 | 50 | `TaskInfra` provides configuration for caching and computation; in particular, providing a `folder` activates caching through the filesystem: 51 | 52 | 53 | ```python continuation fixture:tmp_path 54 | task = TutorialTask(param=1, infra={"folder": tmp_path}) 55 | out = task.process() 56 | # calling process again will load the cache and not a new random number 57 | assert out == task.process() 58 | ``` 59 | 60 | Adding `cluster="auto"` to the `infra` would trigger the computation either on a slurm cluster if available, or in a dedicated process otherwise. See the [API reference for all the details](#exca.TaskInfra). 61 | 62 | 63 | ### Map infra 64 | (tutorial-map)= 65 | 66 | The `TaskInfra` above is limited to methods that do not take additional arguments, i.e. computations that are fully defined by the configuration, such as an experiment or a training. Consider now that the configuration defines a computation to be applied to a list of items (e.g. processing a list of images or texts): this is the use case for the [`MapInfra`](#exca.MapInfra): 67 | 68 | ```python 69 | import typing as tp 70 | import pydantic 71 | import numpy as np 72 | from exca import MapInfra 73 | 74 | class TutorialMap(pydantic.BaseModel): 75 | param: int = 12 76 | infra: MapInfra = MapInfra(version="1") 77 | 78 | @infra.apply(item_uid=str) 79 | def process(self, items: tp.Iterable[int]) -> tp.Iterator[np.ndarray]: 80 | for item in items: 81 | yield np.random.rand(item, self.param) 82 | ``` 83 | 84 | As opposed to the `TaskInfra`, the `MapInfra.apply` method now requires an `item_uid` parameter that states how to map each item of the input iterable to a unique string which will be used for identification/caching. 85 | 86 | From then on, calling `whatever.process([1, 2, 3])` will trigger (possibly) remote computation and caching/storage. 87 | You can control the remote resources through the `infra` instance. 88 | E.g. the following will trigger the computation in the current process (change `"cluster": None` to `"auto"` to have it run on a `slurm` cluster if available, or in a dedicated process otherwise): 89 | 90 | ```python continuation fixture:tmp_path 91 | mapper = TutorialMap(infra={"cluster": None, "folder": tmp_path, "cpus_per_task": 1}) 92 | mapper.process([1, 2, 3]) 93 | ``` 94 | 95 | See the [API reference for all the details](#exca.MapInfra). 96 | 97 | ### Features of MapInfra and TaskInfra 98 | 99 | This section provides an overview of parameters and features of infra, but the full [API reference page](#exca.TaskInfra) will provide more options and details if need be.
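Since the infra is itself a pydantic sub-model, it can be provided like any other field, for instance from a yaml string. Below is a minimal sketch reusing the `TutorialTask` class defined above (the parameter values are purely illustrative):

```python continuation fixture:tmp_path
import yaml

string = f"""
param: 3
infra:
  folder: {tmp_path}
  cluster: null  # run in the current process
"""
task = TutorialTask(**yaml.safe_load(string))
out = task.process()
assert out == task.process()  # the second call should hit the cache
```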
100 | 101 | Common useful parameters include: 102 | - `folder`: where to create the cache folder 103 | - `mode`: one of: 104 | - `cached`: cache is returned if available (error or not), otherwise computed (and cached). This is the default behavior. 105 | - `force`: cache is ignored, and result is (re)computed (and cached) 106 | - `retry` (only for `TaskInfra`): cache is returned if available except if it's an error, otherwise (re)computed (and cached) 107 | - submitit/slurm parameters (eg: `gpus_per_node`, `cpus_per_task`, `slurm_partition`, `slurm_constraint` etc) 108 | 109 | All infra objects have common features such as: 110 | - **config export**: through `task.infra.config(uid=False, exclude_defaults=True)`. 111 | - **uid/xp folder**: through `task.infra.uid_folder()`. The folder is always populated with the full config, and the reduced uid config. It also contains a symlink to the job folder. 112 | 113 | When filesystem caching is used, the folder will contain useful information: 114 | - `config.yaml`: the full configuration (all parameters) of the pydantic model 115 | - `full-uid.yaml`: the config defining the task/map, including defaults (not including non-uid related configs such as number of workers etc) 116 | - `uid.yaml`: the minimal config defining the task/map (not including defaults, nor non-uid related configs such as number of workers etc). 117 | 118 | It will also optionally contain: 119 | - `code` (if `workdir` is specified): a symlink to the directory where the task was executed 120 | - `submitit` (for `TaskInfra` if `cluster` is not `None`): a symlink to the folder containing all `submitit` related files for the task (stdout, stderr, batch file etc) 121 | 122 | 123 | `TaskInfra` also has additional features, in particular: 124 | - *job access*: through `task.infra.job()`. Jobs submitted through `submitit` have `stdout()`, `stderr()` and `cancel()`, and more. All jobs have methods `result()`, `done()`, `wait()`. Calling `infra.job()` submits the job if it does not already exist. 125 | - *cache/job clearing*: through `task.infra.clear_job()` 126 | 127 | ## Quick comparison 128 | 129 | | **feature \ tool** | lru_cache | hydra | submitit | stool | exca | 130 | | ----------------------------- | :-------: | :---: | :------: | :---: | :--: | 131 | | RAM cache | ✔ | ✘ | ✘ | ✘ | ✔ | 132 | | file cache | ✘ | ✘ | ✘ | ✘ | ✔ | 133 | | remote compute | ✘ | ✔ | ✔ | ✔ | ✔ | 134 | | pure python (vs commandline) | ✔ | ✘ | ✔ | ✘ | ✔ | 135 | | hierarchical config | ✘ | ✔ | ✘ | ✘ | ✔ | 136 | 137 | 138 | ## Simplified infra decorator 139 | 140 | For quick experimentation with infra, the `exca.helpers.with_infra` function decorator can add an infra parameter to most functions (with simple arguments). 141 | 142 | ```python fixture:tmp_path 143 | import numpy as np 144 | import exca 145 | 146 | @exca.helpers.with_infra(folder=tmp_path) 147 | def my_func(a: int, b: int) -> np.ndarray: 148 | return np.random.rand(a, b) 149 | 150 | out = my_func(a=3, b=4) 151 | out2 = my_func(a=3, b=4) 152 | 153 | np.testing.assert_array_equal(out2, out) # should be the same (as cached) 154 | ``` 155 | 156 | In the long run this is not advised, as it will prevent you from using many features of infra (running an array of jobs, checking their status etc) 157 | 158 | 159 | ## Pydantic models 160 | (pydantic-models)= 161 | 162 | This is a quick recap of important features of `pydantic` models.
Models do not have an `__init__` method; parameters are instead specified directly in the class (as if they were class attributes, although they will not be): 163 | 164 | 165 | ```python 166 | import pydantic 167 | 168 | class MyModel(pydantic.BaseModel): 169 | x: int 170 | y: str = "blublu" 171 | 172 | mymodel = MyModel(x=12) 173 | assert mymodel.x == 12 174 | ``` 175 | 176 | One can then instantiate it easily with `mymodel = MyModel(x=12)` and access attributes like `mymodel.x`. One important feature is the type checking performed when instantiating objects: as `x` is typed as an `int`, the field will not accept a string, and the following code would raise an exception: `mymodel = MyModel(x="wrong")`. 177 | 178 | 179 | **Note**: `pydantic` is very similar to the more standard `dataclasses` with a few important additional features: models are type checked (dataclasses are not), one can set mutable default values like `[]` without risks (with dataclasses this can be buggy or require a factory), and one can use discriminators for sub-configs ([more on that here](howto-discriminator)). 180 | 181 | 182 | For more safety, one should set `extra="forbid"` for models, as this will also trigger an error if you instantiate an object with parameters that do not exist in the model: 183 | 184 | ```python continuation 185 | import pydantic 186 | 187 | class MyModel(pydantic.BaseModel): 188 | model_config = pydantic.ConfigDict(extra="forbid") # safer 189 | x: int 190 | y: str = "blublu" 191 | 192 | # MyModel(x=12, wrong_parameter=12) # will not work anymore 193 | ``` 194 | 195 | **Note**: adding a default infra automatically sets `extra="forbid"` as a default in the pydantic class `model_config`, as it is much safer to avoid silent errors. 196 | 197 | 198 | ### Hierarchical config 199 | 200 | One important aspect of models is that they can be composed: one model/config can contain another config. Instantiating such models is simple, as the sub-parameters can be specified as a dictionary and `pydantic` will take care of transforming them into the correct class: 201 | 202 | ```python continuation 203 | class Parent(pydantic.BaseModel): 204 | model_config = pydantic.ConfigDict(extra="forbid") # safer 205 | data: MyModel 206 | 207 | obj = Parent(data={"x": 12}) 208 | assert obj.data.x == 12 209 | ``` 210 | 211 | This makes it easy to specify configs as yaml and load them into a model, e.g.: 212 | ```python continuation 213 | import yaml 214 | 215 | string = """ 216 | data: 217 | x: 12 218 | y: whatever 219 | """ 220 | 221 | dictconfig = yaml.safe_load(string) 222 | obj = Parent(**dictconfig) 223 | ``` 224 | 225 | 226 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /exca/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Execution and caching tool for python""" 8 | 9 | from . import helpers as helpers 10 | from .confdict import ConfDict as ConfDict 11 | from .map import MapInfra as MapInfra 12 | from .task import SubmitInfra as SubmitInfra 13 | from .task import TaskInfra as TaskInfra 14 | 15 | __version__ = "0.4.5" 16 | -------------------------------------------------------------------------------- /exca/data/cachedict2501/.cache_type: -------------------------------------------------------------------------------- 1 | NumpyMemmapArray -------------------------------------------------------------------------------- /exca/data/cachedict2501/x-9dd4e461.key: -------------------------------------------------------------------------------- 1 | x -------------------------------------------------------------------------------- /exca/data/cachedict2501/x-9dd4e461.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/exca/f01a15736dbadd0efc1366649218ab5e220fdf8b/exca/data/cachedict2501/x-9dd4e461.npy -------------------------------------------------------------------------------- /exca/data/cachedict2501/y-41529076.key: -------------------------------------------------------------------------------- 1 | y -------------------------------------------------------------------------------- /exca/data/cachedict2501/y-41529076.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/exca/f01a15736dbadd0efc1366649218ab5e220fdf8b/exca/data/cachedict2501/y-41529076.npy -------------------------------------------------------------------------------- /exca/data/compat-test-2024-11-12.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/exca/f01a15736dbadd0efc1366649218ab5e220fdf8b/exca/data/compat-test-2024-11-12.pkl -------------------------------------------------------------------------------- /exca/dumperloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import contextlib 8 | import hashlib 9 | import io 10 | import logging 11 | import os 12 | import pickle 13 | import socket 14 | import threading 15 | import typing as tp 16 | import warnings 17 | from pathlib import Path 18 | 19 | import numpy as np 20 | 21 | from . 
import utils 22 | 23 | X = tp.TypeVar("X") 24 | Y = tp.TypeVar("Y", bound=tp.Type[tp.Any]) 25 | logger = logging.getLogger(__name__) 26 | 27 | UNSAFE_TABLE = {ord(char): "-" for char in "/\\\n\t "} 28 | MEMMAP_ARRAY_FILE_MAX_CACHE = "EXCA_MEMMAP_ARRAY_FILE_MAX_CACHE" 29 | 30 | 31 | def _string_uid(string: str) -> str: 32 | out = string.translate(UNSAFE_TABLE) 33 | if len(out) > 80: 34 | out = out[:40] + "[.]" + out[-40:] 35 | h = hashlib.md5(string.encode("utf8")).hexdigest()[:8] 36 | return f"{out}-{h}" 37 | 38 | 39 | def host_pid() -> str: 40 | return f"{socket.gethostname()}-{threading.get_native_id()}" 41 | 42 | 43 | class DumperLoader(tp.Generic[X]): 44 | CLASSES: tp.MutableMapping[str, "tp.Type[DumperLoader[tp.Any]]"] = {} 45 | DEFAULTS: tp.MutableMapping[tp.Any, "tp.Type[DumperLoader[tp.Any]]"] = {} 46 | 47 | def __init__(self, folder: str | Path = "") -> None: 48 | self.folder = Path(folder) 49 | 50 | @contextlib.contextmanager 51 | def open(self) -> tp.Iterator[None]: 52 | yield 53 | 54 | @classmethod 55 | def __init_subclass__(cls, **kwargs: tp.Any) -> None: 56 | super().__init_subclass__(**kwargs) 57 | DumperLoader.CLASSES[cls.__name__] = cls 58 | 59 | def load(self, filename: str, **kwargs: tp.Any) -> X: 60 | raise NotImplementedError 61 | 62 | def dump(self, key: str, value: X) -> dict[str, tp.Any]: 63 | raise NotImplementedError 64 | 65 | @staticmethod 66 | def default_class(type_: Y) -> tp.Type["DumperLoader[Y]"]: 67 | Cls: tp.Any = Pickle # default 68 | try: 69 | for supported, DL in DumperLoader.DEFAULTS.items(): 70 | if issubclass(type_, supported): 71 | Cls = DL 72 | break 73 | except TypeError: 74 | pass 75 | return Cls # type: ignore 76 | 77 | @classmethod 78 | def check_valid_cache_type(cls, cache_type: str) -> None: 79 | if cache_type not in DumperLoader.CLASSES: 80 | avail = list(DumperLoader.CLASSES) 81 | raise ValueError(f"Unknown {cache_type=}, use one of {avail}") 82 | 83 | 84 | class StaticDumperLoader(DumperLoader[X]): 85 | SUFFIX = "" 86 | 87 | def load(self, filename: str) -> X: # type: ignore 88 | filepath = self.folder / filename 89 | return self.static_load(filepath) 90 | 91 | def dump(self, key: str, value: X) -> dict[str, tp.Any]: 92 | uid = _string_uid(key) 93 | filename = uid + self.SUFFIX 94 | self.static_dump(filepath=self.folder / filename, value=value) 95 | return {"filename": filename} 96 | 97 | @classmethod 98 | def static_load(cls, filepath: Path) -> X: 99 | raise NotImplementedError 100 | 101 | @classmethod 102 | def static_dump(cls, filepath: Path, value: X) -> None: 103 | raise NotImplementedError 104 | 105 | 106 | class Pickle(StaticDumperLoader[tp.Any]): 107 | SUFFIX = ".pkl" 108 | 109 | @classmethod 110 | def static_load(cls, filepath: Path) -> tp.Any: 111 | with filepath.open("rb") as f: 112 | return pickle.load(f) 113 | 114 | @classmethod 115 | def static_dump(cls, filepath: Path, value: tp.Any) -> None: 116 | with utils.temporary_save_path(filepath) as tmp: 117 | with tmp.open("wb") as f: 118 | pickle.dump(value, f) 119 | 120 | 121 | class NumpyArray(StaticDumperLoader[np.ndarray]): 122 | SUFFIX = ".npy" 123 | 124 | @classmethod 125 | def static_load(cls, filepath: Path) -> np.ndarray: 126 | return np.load(filepath) # type: ignore 127 | 128 | @classmethod 129 | def static_dump(cls, filepath: Path, value: np.ndarray) -> None: 130 | if not isinstance(value, np.ndarray): 131 | raise TypeError(f"Expected numpy array but got {value} ({type(value)})") 132 | with utils.temporary_save_path(filepath) as tmp: 133 | np.save(tmp, value) 134 | 135 
| 136 | class NumpyMemmapArray(NumpyArray): 137 | 138 | @classmethod 139 | def static_load(cls, filepath: Path) -> np.ndarray: 140 | return np.load(filepath, mmap_mode="r") # type: ignore 141 | 142 | 143 | class MemmapArrayFile(DumperLoader[np.ndarray]): 144 | 145 | def __init__(self, folder: str | Path = "", max_cache: int | None = None) -> None: 146 | super().__init__(folder=folder) 147 | self._cache: dict[str, np.memmap] = {} 148 | self._f: io.BufferedWriter | None = None 149 | self._name: str | None = None 150 | if max_cache is None: 151 | max_cache = int(os.environ.get(MEMMAP_ARRAY_FILE_MAX_CACHE, 100_000)) 152 | self._max_cache = max_cache 153 | 154 | @contextlib.contextmanager 155 | def open(self) -> tp.Iterator[None]: 156 | if self._name is not None: 157 | raise RuntimeError("Cannot reopen DumperLoader context") 158 | self._name = f"{host_pid()}.data" 159 | with (self.folder / self._name).open("ab") as f: 160 | self._f = f 161 | try: 162 | yield 163 | finally: 164 | self._f = None 165 | self._name = None 166 | 167 | def load(self, filename: str, offset: int, shape: tp.Sequence[int], dtype: str) -> np.ndarray: # type: ignore 168 | shape = tuple(shape) 169 | length = np.prod(shape) * np.dtype(dtype).itemsize 170 | for _ in range(2): 171 | if filename not in self._cache: 172 | path = self.folder / filename 173 | self._cache[filename] = np.memmap(path, mode="r", order="C") 174 | memmap = self._cache[filename][offset : offset + length] 175 | if memmap.size: 176 | break 177 | # new data was added -> we need to force a reload and retry 178 | msg = "Reloading memmap file %s as offset %s is out of bound for size %s (file was updated?)" 179 | logger.debug(msg, filename, offset, self._cache[filename].size) 180 | del self._cache[filename] 181 | memmap = memmap.view(dtype=dtype).reshape(shape) 182 | if len(self._cache) > self._max_cache: 183 | self._cache.clear() 184 | return memmap 185 | 186 | def dump(self, key: str, value: np.ndarray) -> dict[str, tp.Any]: 187 | if self._f is None or self._name is None: 188 | raise RuntimeError("Need a write_mode context") 189 | if not isinstance(value, np.ndarray): 190 | raise TypeError(f"Expected numpy array but got {value} ({type(value)})") 191 | if not value.size: 192 | raise ValueError(f"Cannot dump data with no size: shape={value.shape}") 193 | offset = self._f.tell() 194 | self._f.write(np.ascontiguousarray(value).data) 195 | return { 196 | "filename": self._name, 197 | "offset": offset, 198 | "shape": tuple(value.shape), 199 | "dtype": str(value.dtype), 200 | } 201 | 202 | 203 | DumperLoader.DEFAULTS[np.ndarray] = MemmapArrayFile 204 | 205 | 206 | class DataDict(DumperLoader[dict[str, tp.Any]]): 207 | """Dumps the first level of values using the default dumper for 208 | their type""" 209 | 210 | def __init__(self, folder: str | Path = "") -> None: 211 | super().__init__(folder=folder) 212 | self._subs: dict[tp.Any, DumperLoader] = {} 213 | self._exit_stack: contextlib.ExitStack | None = None 214 | 215 | @contextlib.contextmanager 216 | def open(self) -> tp.Iterator[None]: 217 | if self._exit_stack is not None: 218 | raise RuntimeError("Cannot reopen DumperLoader context") 219 | with contextlib.ExitStack() as estack: 220 | self._exit_stack = estack 221 | try: 222 | yield 223 | finally: 224 | self._subs.clear() 225 | self._exit_stack = None 226 | 227 | def load(self, optimized: dict[str, tp.Any], pickled: dict[str, tp.Any]) -> dict[str, tp.Any]: # type: ignore 228 | output = {} 229 | for key, info in optimized.items(): 230 | loader = 
self.CLASSES[info["cls"]](self.folder) 231 | output[key] = loader.load(**info["info"]) 232 | if pickled: 233 | loader = Pickle(self.folder) 234 | output.update(loader.load(**pickled)) 235 | return output 236 | 237 | def dump(self, key: str, value: dict[str, tp.Any]) -> dict[str, tp.Any]: 238 | output: dict[str, dict[str, tp.Any]] = {"optimized": {}, "pickled": {}} 239 | if self._exit_stack is None: 240 | raise RuntimeError("Dict dumper is not in open context") 241 | pickled: tp.Any = {} 242 | for skey, val in value.items(): 243 | default = self.default_class(type(val)) 244 | if default.__name__ not in self._subs: 245 | sub = default(self.folder) 246 | self._exit_stack.enter_context(sub.open()) 247 | self._subs[default.__name__] = sub 248 | sub = self._subs[default.__name__] 249 | if default.__name__ != "Pickle": 250 | output["optimized"][skey] = { 251 | "cls": sub.__class__.__name__, 252 | "info": sub.dump(f"{key}(dict){skey}", val), 253 | } 254 | else: 255 | pickled[skey] = val 256 | if pickled: 257 | sub = self._subs["Pickle"] 258 | output["pickled"] = sub.dump(key, pickled) 259 | return output 260 | 261 | 262 | # making DataDict the default for dicts could generate a lot of small files for heavily nested dicts 263 | # DumperLoader.DEFAULTS[dict] = DataDict 264 | 265 | try: 266 | import pandas as pd 267 | except ImportError: 268 | pass 269 | else: 270 | 271 | class PandasDataFrame(StaticDumperLoader[pd.DataFrame]): 272 | SUFFIX = ".csv" 273 | 274 | @classmethod 275 | def static_load(cls, filepath: Path) -> pd.DataFrame: 276 | return pd.read_csv( 277 | filepath, index_col=0, keep_default_na=False, na_values=[""] 278 | ) 279 | 280 | @classmethod 281 | def static_dump(cls, filepath: Path, value: pd.DataFrame) -> None: 282 | with utils.temporary_save_path(filepath) as tmp: 283 | value.to_csv(tmp, index=True) 284 | 285 | DumperLoader.DEFAULTS[pd.DataFrame] = PandasDataFrame 286 | 287 | try: 288 | # pylint: disable=unused-import 289 | import pyarrow # noqa 290 | except ImportError: 291 | pass 292 | else: 293 | 294 | class ParquetPandasDataFrame(StaticDumperLoader[pd.DataFrame]): 295 | SUFFIX = ".parquet" 296 | 297 | @classmethod 298 | def static_load(cls, filepath: Path) -> pd.DataFrame: 299 | if not filepath.exists(): 300 | # fallback to csv for compatibility when updating to parquet 301 | return PandasDataFrame.static_load(filepath.with_suffix(".csv")) 302 | return pd.read_parquet(filepath, dtype_backend="numpy_nullable") 303 | 304 | @classmethod 305 | def static_dump(cls, filepath: Path, value: pd.DataFrame) -> None: 306 | with utils.temporary_save_path(filepath) as tmp: 307 | value.to_parquet(tmp) 308 | 309 | 310 | try: 311 | import mne 312 | except ImportError: 313 | pass 314 | else: 315 | 316 | class MneRawFif(StaticDumperLoader[mne.io.Raw]): 317 | SUFFIX = "-raw.fif" 318 | 319 | @classmethod 320 | def static_load(cls, filepath: Path) -> mne.io.Raw: 321 | try: 322 | return mne.io.read_raw_fif(filepath, verbose=False, allow_maxshield=False) 323 | except ValueError: 324 | raw = mne.io.read_raw_fif(filepath, verbose=False, allow_maxshield=True) 325 | msg = "MaxShield data detected, consider applying Maxwell filter and interpolating bad channels" 326 | warnings.warn(msg) 327 | return raw 328 | 329 | @classmethod 330 | def static_dump(cls, filepath: Path, value: mne.io.Raw) -> None: 331 | with utils.temporary_save_path(filepath) as tmp: 332 | value.save(tmp) 333 | 334 | DumperLoader.DEFAULTS[(mne.io.Raw, mne.io.RawArray)] = MneRawFif 335 | DumperLoader.CLASSES["MneRaw"] = MneRawFif # for 
backwards compatibility 336 | 337 | 338 | try: 339 | # pylint: disable=unused-import 340 | import mne 341 | import pybv # noqa 342 | from mne.io.brainvision.brainvision import RawBrainVision 343 | except ImportError: 344 | pass 345 | else: 346 | 347 | Raw = mne.io.Raw | RawBrainVision 348 | 349 | class MneRawBrainVision(DumperLoader[Raw]): 350 | 351 | def dump(self, key: str, value: X) -> dict[str, tp.Any]: 352 | uid = _string_uid(key) 353 | fp = self.folder / uid / f"{uid}-raw.vhdr" 354 | with utils.temporary_save_path(fp) as tmp: 355 | mne.export.export_raw(tmp, value, fmt="brainvision", verbose="ERROR") 356 | return {"filename": uid} 357 | 358 | def load(self, filename: str) -> Raw: # type: ignore 359 | fp = self.folder / filename / f"{filename}-raw.vhdr" 360 | return mne.io.read_raw_brainvision(fp, verbose=False) 361 | 362 | DumperLoader.DEFAULTS[RawBrainVision] = MneRawBrainVision 363 | 364 | 365 | try: 366 | import nibabel 367 | except ImportError: 368 | pass 369 | else: 370 | 371 | Nifti = ( 372 | nibabel.Nifti1Image | nibabel.Nifti2Image | nibabel.filebasedimages.FileBasedImage 373 | ) 374 | 375 | class NibabelNifti(StaticDumperLoader[Nifti]): 376 | SUFFIX = ".nii.gz" 377 | 378 | @classmethod 379 | def static_load(cls, filepath: Path) -> Nifti: 380 | return nibabel.load(filepath, mmap=True) 381 | 382 | @classmethod 383 | def static_dump(cls, filepath: Path, value: Nifti) -> None: 384 | with utils.temporary_save_path(filepath) as tmp: 385 | nibabel.save(value, tmp) 386 | 387 | DumperLoader.DEFAULTS[(nibabel.Nifti1Image, nibabel.Nifti2Image)] = NibabelNifti 388 | 389 | 390 | try: 391 | import torch 392 | except ImportError: 393 | pass 394 | else: 395 | 396 | def is_view(x: torch.Tensor) -> bool: 397 | """Check if the tensor is a view by checking if it is contiguous and has 398 | same size as storage. 399 | 400 | Note 401 | ---- 402 | dumping the view of a slice dumps the full underlying storage, so it is 403 | safer to clone beforehand 404 | """ 405 | storage_size = len(x.untyped_storage()) // x.dtype.itemsize 406 | return storage_size != x.numel() or not x.is_contiguous() 407 | 408 | class TorchTensor(StaticDumperLoader[torch.Tensor]): 409 | SUFFIX = ".pt" 410 | 411 | @classmethod 412 | def static_load(cls, filepath: Path) -> torch.Tensor: 413 | return torch.load(filepath, map_location="cpu", weights_only=True) # type: ignore 414 | 415 | @classmethod 416 | def static_dump(cls, filepath: Path, value: torch.Tensor) -> None: 417 | if not isinstance(value, torch.Tensor): 418 | raise TypeError(f"Expected torch Tensor but got {value} ({type(value)}") 419 | if is_view(value): 420 | value = value.clone() 421 | with utils.temporary_save_path(filepath) as tmp: 422 | torch.save(value.detach().cpu(), tmp) 423 | 424 | DumperLoader.DEFAULTS[torch.Tensor] = TorchTensor 425 | -------------------------------------------------------------------------------- /exca/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import inspect 8 | import logging 9 | import shutil 10 | import subprocess 11 | import typing as tp 12 | from pathlib import Path 13 | 14 | import pydantic 15 | import submitit 16 | 17 | from exca.confdict import ConfDict 18 | from exca.task import TaskInfra 19 | 20 | # pylint: disable=typevar-name-incorrect-variance 21 | X = tp.TypeVar("X", covariant=True) 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class FuncConfigProtocol(tp.Protocol[X]): 26 | model_fields: tp.ClassVar[tp.Dict[str, pydantic.fields.FieldInfo]] 27 | infra: TaskInfra 28 | 29 | def __init__(self, *, infra: TaskInfra, **kwargs: tp.Any) -> None: ... 30 | 31 | def build(self) -> X: ... 32 | 33 | 34 | class FuncConfig(pydantic.BaseModel): 35 | infra: TaskInfra = TaskInfra(version="1") 36 | model_config = pydantic.ConfigDict( 37 | extra="forbid", arbitrary_types_allowed=True, protected_namespaces=("model_conf",) 38 | ) 39 | # as a tuple to avoid getting bounded 40 | _func: tp.ClassVar[tp.Tuple[tp.Callable[..., tp.Any]]] 41 | 42 | @infra.apply 43 | def build(self) -> tp.Any: 44 | """Build the underlying buildable object for this config""" 45 | params = { 46 | name: getattr(self, name) for name in self.model_fields if name != "infra" 47 | } 48 | return self._func[0](**params) 49 | 50 | def __reduce__(self) -> tp.Any: 51 | params = self.model_dump() 52 | return (_unpickle_cfg, (self._func[0], params)) 53 | 54 | 55 | def _unpickle_cfg( 56 | func: tp.Callable[..., X], kwargs: tp.Dict[str, tp.Any] 57 | ) -> FuncConfigProtocol[X]: 58 | return to_config(func, **kwargs) 59 | 60 | 61 | def to_config_model(func: tp.Callable[..., X]) -> tp.Type[FuncConfigProtocol[X]]: 62 | """Create a pydantic model based on a function, with an additional infra 63 | argument for caching and remove configuration 64 | 65 | Example 66 | ------- 67 | def my_func(a: int, b: int) -> np.ndarray: 68 | return np.random.rand(a, b) 69 | 70 | Conf = helpers.to_config_model(my_func) 71 | conf = Conf(a=3, b=4, infra={"folder": tmp_path}) # type: ignore 72 | out1 = conf.build() 73 | """ 74 | params = {} 75 | for p in inspect.signature(func).parameters.values(): 76 | if p.name == "infra": 77 | raise ValueError("Cannot add 'infra' parameter as it already exists") 78 | if p.annotation in (tp.Any, inspect._empty): 79 | raise ValueError( 80 | f"Cannot make config for {func!r} because parameter {p.name!r}" 81 | f" is missing a precise type (found '{p.annotation}')." 82 | ) 83 | default = Ellipsis if p.default is inspect._empty else p.default 84 | params[p.name] = (p.annotation, default) 85 | # create 86 | Model = pydantic.create_model( # type: ignore 87 | func.__name__ + "_FuncConfig", 88 | **params, 89 | __base__=FuncConfig, 90 | __module__=func.__module__, 91 | ) 92 | Model._func = (func,) 93 | return Model # type: ignore 94 | 95 | 96 | def to_config(func: tp.Callable[..., X], **kwargs: tp.Any) -> FuncConfigProtocol[X]: 97 | """Create a pydantic configuration based on a function and its arguments, 98 | including additional "infra" argument to specify caching and remove 99 | computation behaviors. 
100 | 101 | Example 102 | ------- 103 | def my_func(a: int, b: int) -> np.ndarray: 104 | return np.random.rand(a, b) 105 | 106 | conf = helpers.to_config(my_func, a=3, b=4, infra={"folder": tmp_path}) 107 | out1 = conf.build() 108 | """ 109 | Cfg = to_config_model(func) 110 | return Cfg(**kwargs) 111 | 112 | 113 | class FunctionWithInfra(tp.Generic[X]): 114 | def __init__( 115 | self, func: tp.Callable[..., X], infra: TaskInfra | tp.Dict[str, tp.Any] 116 | ) -> None: 117 | self.infra = infra 118 | self.func = func 119 | 120 | def config(self, **kwargs: tp.Any) -> FuncConfigProtocol[X]: 121 | return to_config(self.func, infra=self.infra, **kwargs) 122 | 123 | def __call__(self, **kwargs: tp.Any) -> X: 124 | return self.config(**kwargs).build() 125 | 126 | def __repr__(self) -> str: 127 | name = with_infra.__name__ 128 | return f"{name}({self.infra})({self.func!r})" 129 | 130 | 131 | class with_infra: 132 | """Decorator for adding an infra to a function 133 | 134 | Usage 135 | ----- 136 | .. code-block:: python 137 | 138 | @with_infra(folder="whatever") 139 | def my_func(....) 140 | ... 141 | 142 | or directly :code:`my_func = with_infra(folder="whavetever")(my_func)` 143 | then the function will always use this infra. 144 | """ 145 | 146 | def __init__(self, **kwargs: tp.Any) -> None: 147 | infra = TaskInfra(**kwargs) # check that it's correct 148 | if infra.folder is None and infra.cluster is None: 149 | logger.warning( 150 | "Infra is not used as infra cluster=None (so remote computing is deactivated) and " 151 | "folder=None (so caching is deactivated)" 152 | ) 153 | self.infra = kwargs 154 | 155 | def __call__(self, func: tp.Callable[..., X]) -> FunctionWithInfra[X]: 156 | return FunctionWithInfra(func, self.infra) 157 | 158 | 159 | def validate_kwargs(func: tp.Callable[..., tp.Any], kwargs: tp.Dict[str, tp.Any]) -> None: 160 | """Validates mandatory/extra args and basic types (str/int/float) 161 | 162 | Parameters 163 | ---------- 164 | func: Callable 165 | callable to be called with the kwargs 166 | kwargs: dict 167 | keyword arguments to check for the function 168 | """ 169 | 170 | params = inspect.signature(func).parameters 171 | has_kwargs = any(p.kind == p.VAR_KEYWORD for p in params.values()) 172 | params = {name: p for name, p in params.items() if p.kind != p.VAR_KEYWORD} # type: ignore 173 | # check for missing parameters 174 | mandatory = {p.name for p in params.values() if p.default is inspect._empty} 175 | missing = mandatory - set(kwargs) 176 | if missing: 177 | raise ValueError(f"Missing parameter(s) for {func}: {missing}") 178 | # check for extra parameters (in case there is no **kwargs) 179 | if not has_kwargs: 180 | additional = set(kwargs) - set(params.keys()) 181 | if additional: 182 | raise ValueError(f"Extra parameter(s) for {func}: {additional}") 183 | # check for correct types (only basic ones) 184 | for name, val in kwargs.items(): 185 | if name in params: # in case of **kwargs, it may not exist 186 | annot = params[name].annotation 187 | if annot in (bool, str, int, float) and not isinstance(val, annot): 188 | raise TypeError( 189 | f"Wrong type {type(val)} for {name!r} in {func} (expected {annot})" 190 | ) 191 | 192 | 193 | # only used for typing, this is a bit hacky but convenient 194 | class InfraSlurmJob(submitit.SlurmJob[tp.Any]): 195 | # pylint: disable=super-init-not-called 196 | def __init__(self) -> None: 197 | self.config: ConfDict 198 | self.uid_config: ConfDict 199 | 200 | 201 | def find_slurm_job( 202 | *, job_id: str, folder: str | Path | None = 
None 203 | ) -> InfraSlurmJob | None: 204 | r"""Attemps to instantiate a submitit.SlurmJob instance from a cache folder and a `job_id`, 205 | looking for it recursively. 206 | This is based on default configuration of the log folder position 207 | (:code:`/logs//`), and some additional heuristic that may be 208 | invalid in other pipelines (skipping logs/wandb folders) so this can fail 209 | with other configurations and may need adaptations, but should answer 95% of cases. 210 | 211 | Parameters 212 | ---------- 213 | job_id: str 214 | the job id 215 | folder: str, Path or None 216 | the path of the cache folder. If None, scontrol will be called to try and identify it 217 | automatically (will fail for non-running jobs) 218 | 219 | Notes 220 | ----- 221 | - a :code:`submitit.Job` instance has: 222 | - :code:`job.paths.stderr/stdout`: pathlib.Path of the logs 223 | - :code:`job.stderr()/stdout()`: string of the logs 224 | - :code:`job.result()`: output of the job (waits in not completed, raises if error) 225 | - :code:`job.done()`: True if job is completed 226 | 227 | - On top of it, the returned job has attributes: 228 | - :code:`config`: the full configuration of the job 229 | - :code:`uid_config`: the non default uid configuration of the job 230 | 231 | - The search assumes there is only one "logs" folder in the path 232 | (as we assume the default configuration of the logs path) and will probably 233 | fail if the cache folder contains /logs/ in it It also assumes there is no /code/ in it. 234 | 235 | - Get the err using this line: :code:`out = job.stderr().split("\\n")` 236 | 237 | 238 | Example 239 | ------- 240 | 241 | .. code-block:: python 242 | 243 | job = find_slurm_job(job_id=job_id, folder=my_folder) 244 | print(job.uid_config) # see uid (= simplified) config for this job 245 | print(job.stdout()) # print stdout of the job 246 | 247 | """ 248 | if folder is None: 249 | try: 250 | out = subprocess.check_output( 251 | ["scontrol", "show", "job", job_id], shell=False 252 | ).decode("utf8") 253 | except subprocess.CalledProcessError as e: 254 | raise ValueError("Please provide a folder for non-running jobs") from e 255 | tok = "StdErr=" 256 | lines = [x.strip() for x in out.splitlines() if x.strip().startswith(tok)] 257 | folder = Path(lines[0].replace(tok, "")).parents[3] 258 | folder = Path(folder) 259 | if any(x in folder.parts for x in ["code", "wandb"]): 260 | return None 261 | # if all these files are present, this is the cache folder: 262 | if all((folder / name).exists() for name in ["config.yaml", "uid.yaml"]): 263 | # avoid checking the cache folder as this is extra slow 264 | # task Vs batch 265 | part = "submitit" if (folder / "submitit").exists() else f"logs/*/{job_id}" 266 | for fp in folder.glob(f"{part}/{job_id}_*.out"): 267 | job: tp.Any = submitit.SlurmJob(folder=fp.resolve().parent, job_id=job_id) 268 | assert job.paths.stdout.exists(), f"Expected existence of {job.paths.stdout}" 269 | for name in ("config", "uid"): 270 | fp = folder / (name + ".yaml") 271 | conf = ConfDict.from_yaml(fp) 272 | setattr(job, name if name == "config" else "uid_config", conf) 273 | return job # type: ignore 274 | return None 275 | 276 | for sub in folder.iterdir(): 277 | if not sub.is_dir(): 278 | continue 279 | if folder.parent.name == "logs": 280 | if all( 281 | x.isdigit() for x in sub.name.split("_") 282 | ): # looks like a submitit job folder 283 | if any(sub.glob("*_submitted.pkl")): # definitely is one 284 | return None # stop iteratoring through this log folder 285 | job = 
find_slurm_job(folder=sub, job_id=job_id) 286 | if job is not None: 287 | return job 288 | return None 289 | 290 | 291 | def update_uids(folder: str | Path, dryrun: bool = True): 292 | folder = Path(folder) 293 | if any(x in folder.parts for x in ["code", "wandb", "logs"]): 294 | return None 295 | # if all these files are present, this is the cache folder: 296 | if not all((folder / name).exists() for name in ["config.yaml", "uid.yaml"]): 297 | # avoid checking the cache folder as this is extra slow 298 | # task Vs batch 299 | for sub in folder.iterdir(): 300 | if sub.is_dir(): 301 | update_uids(sub, dryrun=dryrun) 302 | return None 303 | cd = ConfDict.from_yaml(folder / "uid.yaml") 304 | old = cd.to_uid(version=2) 305 | new = cd.to_uid() 306 | if new in str(folder): 307 | return # all good 308 | if old not in str(folder): 309 | if folder.name != "default": 310 | msg = "CAUTION: folder name %s does not match old uid pattern %s nor new %s" 311 | logger.warning(msg, folder.name, old, new) 312 | return 313 | newfolder = Path(str(folder).replace(old, new)) 314 | msg = "Automatically updating folder name to new uid: '%s' -> '%s'" 315 | if dryrun: 316 | msg += " (dry run)" 317 | logger.warning(msg, folder, newfolder) 318 | if not dryrun: 319 | shutil.move(folder, newfolder) 320 | -------------------------------------------------------------------------------- /exca/logconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | # # # # # CONFIGURE LOGGER # # # # # 10 | 11 | logger = logging.getLogger("exca") 12 | _handler = logging.StreamHandler() 13 | _formatter = logging.Formatter( 14 | "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(message)s", "%Y-%m-%d %H:%M:%S" 15 | ) 16 | _handler.setFormatter(_formatter) 17 | logger.addHandler(_handler) 18 | logger.setLevel(logging.INFO) 19 | -------------------------------------------------------------------------------- /exca/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/exca/f01a15736dbadd0efc1366649218ab5e220fdf8b/exca/py.typed -------------------------------------------------------------------------------- /exca/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import contextlib 8 | import functools 9 | import getpass 10 | import logging 11 | import os 12 | import pickle 13 | import sys 14 | import typing as tp 15 | import uuid 16 | from datetime import datetime 17 | from pathlib import Path 18 | 19 | import pydantic 20 | import submitit 21 | from submitit.core import utils as submitit_utils 22 | 23 | from . 
import base 24 | from .workdir import WorkDir 25 | 26 | submitit.Job._results_timeout_s = 4 # avoid too long a wait 27 | SUBMITIT_EXECUTORS = ("auto", "local", "slurm", "debug") 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def _pickle_dump_override(obj: tp.Any, filename: str | Path) -> None: 32 | """Override for submitit cloudpickle dump to be compatible 33 | with other python version when using a different conda env""" 34 | with Path(filename).open("wb") as ofile: 35 | pickle.dump(obj, ofile, protocol=4) 36 | 37 | 38 | class SubmititMixin(pydantic.BaseModel): 39 | """Mixin class for creating a submitit runner infra 40 | 41 | Parameters 42 | ---------- 43 | folder: optional Path or str 44 | Path to directory for dumping/loading the cache on disk, if provided 45 | cluster: optional str 46 | Where to run the computation, one of: 47 | - :code:`None`: runs in the current thread 48 | - :code:`"debug"`: submitit debug executor (runs in the current process with `ipdb`) 49 | - :code:`"local"`: submitit local executor (runs in a dedicated subprocess) 50 | - :code:`"slurm"`: submitit slurm executor (runs in a slurm cluster) 51 | - :code:`"auto"`: submitit auto executor (uses slurm if available, otherwise local) 52 | logs: Path or str 53 | path to the logs for slurm/local jobs. One can use :code:`{folder}` in the string 54 | to define logs as a subfolder of the storage folder, :code:`{user}` for the user name 55 | and :code:`%j` (slurm syntax) for the job id 56 | workdir: optional :class:`exca.workdir.WorkDir` 57 | pydantic config defining whether and how to copy the current workspace to a directory specific for the job 58 | and avoid interferences when working on the code. See :class:`exca.workdir.WorkDir` for details. 59 | name: optional str 60 | name of the job 61 | timeout_min: optional int 62 | timeout for slurm/local jobs 63 | nodes: optional int 64 | number of nodes for slurm jobs 65 | tasks_per_node: optional int 66 | number of task nodes for slurm jobs 67 | cpus_per_task: optional int 68 | number of cpus per task for slurm jobs 69 | gpus_per_node: optional int 70 | number of gpus per node for slurm jobs 71 | mem_gb: float 72 | RAM memory to be used in GB 73 | slurm_constraint: optional str 74 | node constraint for the job 75 | slurm_account: optional str 76 | account to use for the job 77 | slurm_qos: optional str 78 | qos to use for the job 79 | slurm_partition: optional str 80 | partition for the slurm job 81 | slurm_use_srun: bool 82 | use srun in the sbatch file. This is the default in submitit, but not adviced 83 | for jobs triggering more jobs. 84 | slurm_additional_parameters: optional dict 85 | additional parameters for slurm that are not first class parameters of this config 86 | conda_env: optional str/path 87 | path or name of a conda environment to use in the job. Note that as submitit uses a pickle 88 | that needs to be loaded in the job with the new conda env, the pickle needs to be 89 | compatible. This mostly means that if the env has a different pydantic 90 | version, the job may fail to reload it. Additionally, to allow for different python 91 | versions, the job is dumped with pickle and not cloudpickle, so inline functions 92 | (defined in main or in a notebook) will not be supported. 
93 | """ 94 | 95 | folder: Path | str | None = None 96 | cluster: tp.Literal[None, "auto", "local", "slurm", "debug"] = None 97 | # {folder} will be replaced by the class instance folder 98 | # {user} by user id and %j by job id 99 | logs: Path | str = "{folder}/logs/{user}/%j" 100 | # main params 101 | job_name: str | None = None 102 | timeout_min: int | None = None 103 | nodes: int | None = 1 104 | tasks_per_node: int | None = 1 105 | cpus_per_task: int | None = None 106 | gpus_per_node: int | None = None 107 | mem_gb: float | None = None 108 | max_pickle_size_gb: float | None = None 109 | # slurm specifics 110 | slurm_constraint: str | None = None 111 | slurm_partition: str | None = None 112 | slurm_account: str | None = None 113 | slurm_qos: str | None = None 114 | slurm_use_srun: bool = False 115 | slurm_additional_parameters: tp.Dict[str, int | str | float | bool] | None = None 116 | # other 117 | conda_env: Path | str | None = None # conda env name or path 118 | workdir: None | WorkDir = None 119 | 120 | def model_post_init(self, log__: tp.Any) -> None: 121 | super().model_post_init(log__) 122 | if not isinstance(self, base.BaseInfra): 123 | raise RuntimeError("SubmititMixin should be set a BaseInfra mixin") 124 | if self.folder is None: 125 | if self.cluster in SUBMITIT_EXECUTORS: 126 | raise ValueError( 127 | f"cluster={self.cluster} requires a folder to be provided, " 128 | "only cluster=None works without folder" 129 | ) 130 | if self.workdir is not None: 131 | raise ValueError("Workdir requires a folder") 132 | if self.tasks_per_node > 1 and not self.slurm_use_srun: 133 | if self.cluster in ["slurm", "auto"]: 134 | msg = "Currently you must set slurm_use_srun=True if tasks_per_node > 1\n" 135 | msg += "(this implies that your job won't be able to run spawn sub-jobs)" 136 | raise ValueError(msg) 137 | if self.conda_env is not None: 138 | acceptable = list(SUBMITIT_EXECUTORS) 139 | acceptable.remove("debug") # not reloading the environment 140 | if self.cluster not in acceptable: 141 | msg = f"Cannot specify a conda env for cluster {self.cluster}, acceptable: {acceptable}" 142 | raise ValueError(msg) 143 | 144 | def executor(self) -> None | submitit.AutoExecutor: 145 | if self.cluster not in SUBMITIT_EXECUTORS: 146 | return None 147 | cluster: str | None = "debug" if self.cluster is None else self.cluster 148 | if cluster == "auto": 149 | cluster = None 150 | logpath = self._log_path() 151 | executor = submitit.AutoExecutor(folder=logpath, cluster=cluster) 152 | if self.max_pickle_size_gb is not None: 153 | sub = executor._executor 154 | if hasattr(sub, "max_pickle_size_gb"): 155 | sub.max_pickle_size_gb = self.max_pickle_size_gb # type: ignore 156 | non_submitit = { 157 | "cluster", 158 | "logs", 159 | "conda_env", 160 | "workdir", 161 | "folder", 162 | "max_pickle_size_gb", 163 | } 164 | fields = set(SubmititMixin.model_fields) - non_submitit # type: ignore 165 | _missing = base.Sentinel() # for backward compatibility when adding a new param 166 | params = {name: getattr(self, name, _missing) for name in fields} 167 | params = {name: y for name, y in params.items() if y is not _missing} 168 | params["name"] = params.pop("job_name") 169 | params = {name: val for name, val in params.items() if val is not None} 170 | executor.update_parameters(**params) 171 | if self.conda_env is not None: 172 | # find python executable path 173 | envpath = Path(self.conda_env) 174 | if not envpath.exists(): # not absolute 175 | current_python = Path(sys.executable) 176 | if 
current_python.parents[2].name != "envs": 177 | msg = f"Assumed running in a conda env but structure is weird {current_python=}" 178 | raise RuntimeError(msg) 179 | envpath = current_python.parents[2] / self.conda_env 180 | pythonpath = envpath / "bin" / "python" 181 | # use env's python 182 | sub = executor 183 | if isinstance(sub, submitit.AutoExecutor): 184 | # pylint: disable=protected-access 185 | sub = executor._executor # type: ignore 186 | if not hasattr(sub, "python"): 187 | raise RuntimeError(f"Cannot set python executable on {executor=}") 188 | 189 | sub.python = str(pythonpath) # type: ignore 190 | if self.job_name is None and executor is not None: 191 | if isinstance(self, base.BaseInfra): 192 | cname = self._obj.__class__.__name__ 193 | name = cname + self.uid().split(cname, maxsplit=1)[-1] # shorter uid 194 | executor.update_parameters(name=name) 195 | return executor 196 | 197 | def _log_path(self) -> Path: 198 | if self.logs is None: 199 | raise RuntimeError("No log path provided") 200 | return Path(str(self.logs).replace("{user}", getpass.getuser())) 201 | 202 | @contextlib.contextmanager 203 | def _work_env(self) -> tp.Iterator[None]: 204 | """Clean slurm environment variable and create change to clean/copied workspace""" 205 | if not isinstance(self, base.BaseInfra): 206 | raise RuntimeError("SubmititMixin should be set a BaseInfra mixin") 207 | with contextlib.ExitStack() as estack: 208 | estack.enter_context(submitit.helpers.clean_env()) 209 | if self.workdir is not None: 210 | if self.workdir.folder is None: 211 | if self.folder is None: 212 | raise ValueError("Workdir requires a folder") 213 | today = datetime.now().strftime("%Y-%m-%d") 214 | tag = f"{today}-{uuid.uuid4().hex[:6]}" 215 | uid_folder = self.uid_folder() 216 | assert uid_folder is not None # for typing 217 | parts = uid_folder.relative_to(self.folder).parts 218 | # default to first sub-directory 219 | folder = Path(self.folder) / parts[0] / "code" / tag 220 | folder.parent.mkdir(parents=True, exist_ok=True) 221 | # bypasses freezing checks: 222 | object.__setattr__(self.workdir, "folder", folder) 223 | estack.enter_context(self.workdir.activate()) 224 | base_dump: tp.Any = None 225 | if self.conda_env is not None: 226 | base_dump = submitit_utils.cloudpickle_dump 227 | # replace to allow for python inter-version compatibility 228 | submitit_utils.cloudpickle_dump = _pickle_dump_override 229 | try: 230 | yield 231 | finally: 232 | if base_dump is not None: 233 | submitit_utils.cloudpickle_dump = base_dump 234 | 235 | def _run_method(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: 236 | if not isinstance(self, base.BaseInfra): 237 | raise RuntimeError("This can only run on BaseInfra subclasses") 238 | if self.workdir is not None: 239 | logger.info("Running function from '%s'", os.getcwd()) 240 | if self._infra_method is None: 241 | raise RuntimeError("Infra not correctly applied to a method") 242 | method = self._infra_method.method 243 | if not isinstance(method, staticmethod): 244 | method = functools.partial(self._infra_method.method, self._obj) 245 | return method(*args, **kwargs) 246 | -------------------------------------------------------------------------------- /exca/test_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import collections 8 | import subprocess 9 | import tempfile 10 | import typing as tp 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import pydantic 15 | import pytest 16 | 17 | from exca import ConfDict 18 | 19 | from .task import TaskInfra 20 | from .workdir import WorkDir 21 | 22 | 23 | class Base(pydantic.BaseModel): 24 | infra: TaskInfra = TaskInfra(version="12") 25 | param: int = 12 26 | tag: str = "whatever" 27 | 28 | @infra.apply 29 | def func(self) -> int: 30 | return 2 * self.param 31 | 32 | 33 | class SubInfra(Base): 34 | infra: TaskInfra = TaskInfra() 35 | 36 | 37 | class SubFunc(Base): 38 | 39 | def func(self) -> int: 40 | return 3 * super().func() 41 | 42 | 43 | class SubInfraFunc(Base): 44 | infra: TaskInfra = TaskInfra() 45 | 46 | @infra.apply(exclude_from_cache_uid=("tag",)) 47 | def func(self) -> int: 48 | return 3 * super().func() 49 | 50 | 51 | class SubInfra2Func(Base): 52 | infra2: TaskInfra = TaskInfra() 53 | 54 | @infra2.apply(exclude_from_cache_uid=("tag",)) 55 | def func(self) -> int: 56 | return 3 * super().func() 57 | 58 | 59 | class SubBase(Base): 60 | pass 61 | 62 | 63 | def test_subclass_infra(tmp_path: Path) -> None: 64 | whatever = SubInfra(param=13, tag="hello", infra={"folder": tmp_path}) # type: ignore 65 | with pytest.raises(RuntimeError): 66 | # infra is not connected 67 | _ = whatever.func() 68 | 69 | 70 | def test_subclass_func(tmp_path: Path) -> None: 71 | whatever = SubFunc(param=13, tag="hello", infra={"folder": tmp_path}) # type: ignore 72 | assert whatever.func() == 78 73 | names = [fp.name for fp in tmp_path.iterdir()] 74 | assert tuple("Base.func" in n for n in names) == (True,) 75 | 76 | 77 | def test_subclass_infra_func(tmp_path: Path) -> None: 78 | whatever = SubInfraFunc(param=13, tag="hello", infra={"folder": tmp_path}) # type: ignore 79 | assert whatever.func() == 78 80 | names = [fp.name for fp in tmp_path.iterdir()] 81 | assert tuple("SubInfraFunc.func" in n for n in names) == (True,) 82 | 83 | 84 | def test_subclass_infra2_func(tmp_path: Path) -> None: 85 | whatever = SubInfra2Func(param=13, tag="hello", infra={"folder": tmp_path}, infra2={"folder": tmp_path}) # type: ignore 86 | assert whatever.func() == 78 87 | names = sorted(fp.name for fp in tmp_path.iterdir()) 88 | assert tuple("SubInfra2Func.func" in n for n in names) == (False, True), names 89 | assert tuple("Base.func" in n for n in names) == (True, False), names 90 | 91 | 92 | class BaseRaw(pydantic.BaseModel): 93 | # only derived, never used independently 94 | infra: TaskInfra = TaskInfra(version="12") 95 | param: int = 12 96 | tag: str = "whatever" 97 | 98 | @infra.apply 99 | def func(self) -> int: 100 | return 2 * self.param 101 | 102 | 103 | class SubInfraFuncRaw(BaseRaw): 104 | infra: TaskInfra = TaskInfra() 105 | 106 | @infra.apply(exclude_from_cache_uid=("tag",)) 107 | def func(self) -> int: 108 | return 3 * super().func() 109 | 110 | 111 | def test_subclass_infra_func_raw(tmp_path: Path) -> None: 112 | whatever = SubInfraFuncRaw(param=13, tag="hello", infra={"folder": tmp_path}) # type: ignore 113 | assert whatever.func() == 78 114 | names = [fp.name for fp in tmp_path.iterdir()] 115 | assert tuple("SubInfraFuncRaw.func" in n for n in names) == (True,) 116 | 117 | 118 | @pytest.mark.parametrize( 119 | "cls,name", 120 | [ 121 | (Base, "Base.func,12"), 122 | (SubFunc, "Base.func,12"), # function is overriden 123 | (SubInfra2Func, "Base.func,12"), # function is overriden 124 | (SubBase, "SubBase.func,12"), # function is the same -> get new class 
125 | ], 126 | ) 127 | def test_cache_names(cls: tp.Type[Base], name: str, tmp_path: Path) -> None: 128 | base = Base.__module__ + "." # <...>.test_base 129 | whatever = cls(infra={"folder": tmp_path}) # type: ignore 130 | assert whatever.infra.uid() == base + name + "/default" 131 | 132 | 133 | # END OF SUBCLASSING TEST 134 | 135 | 136 | class MyCfg(pydantic.BaseModel): 137 | infra: TaskInfra = TaskInfra(gpus_per_node=12) 138 | param: int = 12 139 | other: int = 12 140 | 141 | def exclude(self) -> tp.Tuple[str]: 142 | return ("other",) 143 | 144 | @infra.apply(exclude_from_cache_uid=exclude) 145 | def func(self) -> np.ndarray: 146 | return np.random.rand(3, 4) 147 | 148 | 149 | class MyCfg2(MyCfg): 150 | infra: TaskInfra = TaskInfra(gpus_per_node=12) 151 | 152 | @infra.apply(exclude_from_cache_uid="method:exclude") 153 | def func(self) -> np.ndarray: 154 | return np.random.rand(3, 4) 155 | 156 | 157 | class MyCfg3(MyCfg): 158 | infra: TaskInfra = TaskInfra(gpus_per_node=12) 159 | 160 | @infra.apply(exclude_from_cache_uid="method:does_not_exist") 161 | def func(self) -> np.ndarray: 162 | return np.random.rand(3, 4) 163 | 164 | 165 | @pytest.mark.parametrize("Cfg", (MyCfg, MyCfg2)) 166 | def test_exclude_func(tmp_path: Path, Cfg: tp.Type[MyCfg]) -> None: 167 | cfg = Cfg(infra={"folder": tmp_path}) # type: ignore 168 | cfgp = cfg.infra.clone_obj(param=13) 169 | cfgo = cfg.infra.clone_obj(other=13) 170 | np.testing.assert_array_equal(cfg.func(), cfgo.func()) 171 | with pytest.raises(AssertionError): 172 | np.testing.assert_array_equal(cfg.func(), cfgp.func()) 173 | 174 | 175 | def test_exclude_func_errors() -> None: 176 | cfg = MyCfg3() 177 | with pytest.raises(RuntimeError): 178 | _ = cfg.infra.config() 179 | 180 | with pytest.raises(TypeError): 181 | 182 | # pylint: disable=unused-variable 183 | class MyCfg4(MyCfg): 184 | infra: TaskInfra = TaskInfra(gpus_per_node=12) 185 | 186 | @infra.apply(exclude_from_cache_uid="bad-format-string") 187 | def func(self) -> np.ndarray: 188 | return np.random.rand(3, 4) 189 | 190 | 191 | def test_infra_default_propagation(tmp_path: Path) -> None: 192 | cfg = MyCfg(infra={"folder": tmp_path}) # type: ignore 193 | assert cfg.infra.gpus_per_node == 12 194 | 195 | 196 | def test_uid_string() -> None: 197 | cfg = MyCfg() 198 | cfg.infra._uid_string = "blublu" 199 | with pytest.raises(ValueError): 200 | cfg.infra.uid() 201 | cfg.infra._uid_string = "{method}@{version}-{uid}" 202 | assert cfg.infra.uid().endswith("MyCfg.func@0-default") 203 | # also check equality 204 | assert cfg.infra == cfg.infra 205 | 206 | 207 | def test_hidden_infra(tmp_path: Path) -> None: 208 | class Hidden(pydantic.BaseModel): 209 | _infra: TaskInfra = TaskInfra(folder=tmp_path) 210 | param: int = 12 211 | 212 | @_infra.apply 213 | def func(self) -> int: 214 | return 2 * self.param 215 | 216 | obj = Hidden(param=13) 217 | assert obj._infra is not Hidden._infra 218 | assert obj.func() == 26 219 | names = [fp.name for fp in tmp_path.iterdir()] 220 | assert tuple("Hidden.func" in n for n in names) == (True,), names 221 | 222 | 223 | class Copied(pydantic.BaseModel): 224 | _infra: TaskInfra = TaskInfra() 225 | infra: TaskInfra = TaskInfra() 226 | param: int = 12 227 | 228 | def model_post_init(self, log__: tp.Any) -> None: 229 | super().model_post_init(log__) 230 | self._infra._update(self.infra) 231 | 232 | @infra.apply 233 | def func1(self) -> int: 234 | return self.param 235 | 236 | @_infra.apply 237 | def func2(self) -> int: 238 | return 2 * self.param 239 | 240 | 241 | def 
test_copied_infra(tmp_path: Path) -> None: 242 | obj = Copied(param=13, infra={"folder": tmp_path}) # type: ignore 243 | assert obj._infra is not obj.infra 244 | assert obj.func1() == 13 245 | assert obj.func2() == 26 246 | names = [fp.name for fp in tmp_path.iterdir()] 247 | assert tuple("Copied.func" in n for n in names) == (True, True), names 248 | _ = obj.infra.clone_obj() 249 | 250 | 251 | def test_changing_version(tmp_path: Path) -> None: 252 | class VersionXp(Base): 253 | infra: TaskInfra = TaskInfra(version="12") 254 | 255 | @infra.apply 256 | def func(self) -> int: 257 | return super().func() 258 | 259 | class Main(pydantic.BaseModel): 260 | xp: VersionXp = VersionXp() 261 | infra: TaskInfra = TaskInfra(version="1") 262 | 263 | def model_post_init(self, log__: tp.Any) -> None: 264 | super().model_post_init(log__) 265 | self.xp.infra.folder = self.infra.folder 266 | 267 | @infra.apply 268 | def func(self) -> int: 269 | return self.xp.func() 270 | 271 | m = Main(infra={"folder": tmp_path}) # type: ignore 272 | _ = m.func() 273 | assert ",12/" in m.xp.infra.uid() 274 | if not m.xp.infra.uid_folder().exists(): # type: ignore 275 | raise RuntimeError("Folder should have been created by m.func()") 276 | 277 | class VersionXp(Base): # type: ignore 278 | infra: TaskInfra = TaskInfra(version="13") 279 | 280 | @infra.apply 281 | def func(self) -> int: 282 | return super().func() 283 | 284 | class Main(Main): # type: ignore 285 | xp: VersionXp = VersionXp() 286 | 287 | m = Main(infra={"folder": tmp_path}) # type: ignore 288 | with pytest.raises(RuntimeError): 289 | _ = m.func() 290 | 291 | # sub-config should still work because folder is different 292 | assert ",13/" in m.xp.infra.uid() 293 | _ = m.xp.func() 294 | 295 | 296 | class MissingInfra(pydantic.BaseModel): # COMPATIBILITY 297 | infra: TaskInfra 298 | infra2: TaskInfra = TaskInfra() 299 | 300 | @infra2.apply 301 | def run(self) -> None: 302 | return 303 | 304 | 305 | def test_buggy_compat_validator(tmp_path: Path) -> None: 306 | _ = MissingInfra(infra={"folder": tmp_path}) # type: ignore 307 | 308 | 309 | def test_infra_already_applied_obj() -> None: 310 | infra = TaskInfra(version="12") 311 | cfg1 = MyCfg(infra=infra) 312 | cfg2 = MyCfg(infra=infra) 313 | assert cfg2.infra is not cfg1.infra 314 | # test not applied infra 315 | with pytest.raises(RuntimeError): 316 | _ = infra.config() 317 | 318 | 319 | class DoubleCfg(pydantic.BaseModel): # COMPATIBILITY 320 | infra: TaskInfra = TaskInfra() 321 | infra2: TaskInfra = TaskInfra() 322 | 323 | @infra.apply 324 | def func(self) -> int: 325 | return 12 326 | 327 | @infra2.apply 328 | def func2(self) -> int: 329 | return 13 330 | 331 | 332 | def test_infra_already_applied_name() -> None: 333 | infra = TaskInfra(version="12") 334 | cfg = DoubleCfg(infra=infra, infra2=infra) 335 | assert cfg.infra is not cfg.infra2 336 | 337 | 338 | def test_obj_infras() -> None: 339 | cfg = Copied() 340 | infras = cfg.infra.obj_infras() 341 | assert set(infras) == {"infra", "_infra"} 342 | assert infras["infra"] is cfg.infra 343 | 344 | 345 | class Xp(pydantic.BaseModel): 346 | infra: TaskInfra = TaskInfra(version="12") 347 | base: Base = Base() 348 | 349 | @infra.apply 350 | def func(self) -> None: 351 | pass 352 | 353 | 354 | def test_obj_in_obj() -> None: 355 | # triggered model_with_infra_validator_after error because obj already set 356 | base = Base() 357 | _ = Xp(base=base) 358 | 359 | 360 | class InfraNotApplied(pydantic.BaseModel): 361 | infra: TaskInfra = TaskInfra() 362 | 363 | 364 | def 
test_infra_not_applied() -> None: 365 | model = InfraNotApplied() 366 | excluded = model.infra._exclude_from_cls_uid() 367 | assert len(excluded) > 1 368 | 369 | 370 | class WrappedBase(pydantic.BaseModel): 371 | xp: Base = Base() 372 | infra: TaskInfra = TaskInfra(version="12") 373 | 374 | @infra.apply 375 | def wfunc(self) -> int: 376 | return 12 377 | 378 | 379 | def test_tricky_update(tmp_path: Path) -> None: 380 | # pb in confdict for subconfig 381 | infra: tp.Any = {"folder": tmp_path, "workdir": {"copied": [Path(__file__).parent]}} 382 | xp = Base().infra.clone_obj(infra=infra) 383 | wxp = WrappedBase(xp=xp) 384 | wxp.infra._update(dict(xp.infra)) 385 | wxp.infra._update(wxp.xp.infra.model_dump()) 386 | assert isinstance(wxp.infra.workdir, WorkDir) 387 | wxp.infra._update(xp.infra) 388 | assert isinstance(wxp.infra.workdir, WorkDir) 389 | 390 | 391 | def test_missing_base_model() -> None: 392 | with pytest.raises(RuntimeError): 393 | 394 | class MissingBaseModel: # pylint: disable=unused-variable 395 | infra: TaskInfra = TaskInfra(version="12") 396 | 397 | 398 | class WeirdTypes(pydantic.BaseModel): 399 | alphas: tp.List[float] = list(np.ones(2)) 400 | infra: TaskInfra = TaskInfra() 401 | 402 | @infra.apply 403 | def build(self) -> int: 404 | return 8 405 | 406 | 407 | class OrderedCfg(pydantic.BaseModel): 408 | d: collections.OrderedDict[str, tp.Any] = collections.OrderedDict() 409 | d2: dict[str, tp.Any] = {} 410 | infra: TaskInfra = TaskInfra() 411 | 412 | @infra.apply 413 | def build(self) -> str: 414 | return ",".join(self.d) 415 | 416 | 417 | def test_ordered_dict(tmp_path: Path) -> None: 418 | keys = [str(k) for k in range(100)] 419 | whatever = OrderedCfg(d={k: 12 for k in keys}, infra={"folder": tmp_path}) # type: ignore 420 | assert isinstance(whatever.d, collections.OrderedDict) 421 | assert whatever.build() == ",".join(keys) 422 | # new reorder 423 | keys2 = list(keys) 424 | np.random.shuffle(keys2) 425 | whatever2 = OrderedCfg(d={k: 12 for k in keys2}, infra={"folder": tmp_path}) # type: ignore 426 | assert whatever2.build() == ",".join(keys2) 427 | # check yaml 428 | fp: Path = whatever2.infra.uid_folder() / "config.yaml" # type: ignore 429 | cfg = ConfDict.from_yaml(fp) 430 | cfg["infra.mode"] = "read-only" 431 | whatever3 = OrderedCfg(**cfg) 432 | assert ",".join(whatever3.d) == ",".join(keys2) 433 | assert whatever3.build() == ",".join(keys2) 434 | 435 | 436 | def test_unordered_dict() -> None: 437 | ordered = OrderedCfg(d2=collections.OrderedDict({str(k): 12 for k in range(12)})) 438 | if isinstance(ordered.d2, collections.OrderedDict): 439 | raise AssertionError("OrderedDict should be cast to standard dict by pydantic") 440 | 441 | 442 | class Num(pydantic.BaseModel): 443 | model_config = pydantic.ConfigDict(extra="forbid") 444 | k: int 445 | other: int = 12 446 | 447 | 448 | class OrderedNumCfg(pydantic.BaseModel): 449 | d: collections.OrderedDict[str, Num] = collections.OrderedDict() 450 | infra: TaskInfra = TaskInfra() 451 | 452 | @infra.apply 453 | def build(self) -> str: 454 | return ",".join(self.d) 455 | 456 | 457 | def test_ordered_dict_with_subcfg(tmp_path: Path) -> None: 458 | nums = OrderedNumCfg(d={"a": {"k": 12}}, infra={"folder": tmp_path}) # type: ignore 459 | _ = nums.build() 460 | uid = nums.infra.uid() 461 | assert "d={a.k=12}" in uid 462 | 463 | 464 | def test_ordered_dict_with_subcfg_flat(tmp_path: Path) -> None: 465 | infra = {"folder": tmp_path} 466 | keys = list(range(10)) 467 | np.random.shuffle(keys) 468 | nums = OrderedNumCfg(d={f"{k}": 
{"k": k, "other": 0} for k in keys}, infra=infra) # type: ignore 469 | flat = nums.infra.config().flat() 470 | flat["d.5.k"] = 12 471 | nums2 = OrderedNumCfg(**ConfDict(flat)) 472 | keys2 = [v.k for v in nums2.d.values()] 473 | np.testing.assert_equal([k if k != 5 else 12 for k in keys], keys2) 474 | 475 | 476 | def test_weird_types(tmp_path: Path) -> None: 477 | whatever = WeirdTypes(infra={"folder": tmp_path}) # type: ignore 478 | _ = whatever.build() 479 | 480 | 481 | def test_defined_in_main() -> None: 482 | try: 483 | import neuralset as ns 484 | 485 | cwd = Path(ns.__file__).parents[1] 486 | except ImportError: 487 | import exca 488 | 489 | cwd = Path(exca.__file__).parents[1] 490 | path = Path(__file__).with_suffix("").relative_to(cwd) 491 | cmd = str(path).replace("/", ".") 492 | subprocess.check_call(f"python -m {cmd}".split(), shell=False, cwd=cwd) 493 | 494 | 495 | if __name__ == "__main__": 496 | 497 | class MainCls(pydantic.BaseModel): 498 | infra: TaskInfra = TaskInfra(version="12") 499 | param: int = 12 500 | 501 | @infra.apply 502 | def func(self) -> int: 503 | return 2 * self.param 504 | 505 | with tempfile.TemporaryDirectory() as tmp: 506 | model_ = MainCls(param=13, infra={"folder": tmp, "cluster": "local"}) # type: ignore 507 | assert model_.func() == 26 508 | assert list(Path(tmp).iterdir()) 509 | -------------------------------------------------------------------------------- /exca/test_cachedict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import gc 8 | import logging 9 | import os 10 | import typing as tp 11 | from concurrent import futures 12 | from pathlib import Path 13 | 14 | import nibabel as nib 15 | import numpy as np 16 | import pandas as pd 17 | import psutil 18 | import pytest 19 | import torch 20 | 21 | from . import cachedict as cd 22 | from . 
import utils 23 | from .dumperloader import MEMMAP_ARRAY_FILE_MAX_CACHE 24 | 25 | logger = logging.getLogger("exca") 26 | logger.setLevel(logging.DEBUG) 27 | 28 | 29 | @pytest.mark.parametrize("in_ram", (True, False)) 30 | def test_array_cache(tmp_path: Path, in_ram: bool) -> None: 31 | x = np.random.rand(2, 12) 32 | folder = tmp_path / "sub" 33 | cache: cd.CacheDict[np.ndarray] = cd.CacheDict(folder=folder, keep_in_ram=in_ram) 34 | assert not list(cache.keys()) 35 | assert not len(cache) 36 | assert not cache 37 | with cache.writer() as writer: 38 | writer["blublu"] = x 39 | assert "blublu" in cache 40 | assert cache 41 | np.testing.assert_almost_equal(cache["blublu"], x) 42 | assert "blabla" not in cache 43 | assert set(cache.keys()) == {"blublu"} 44 | assert bool(cache._ram_data) is in_ram 45 | cache2: cd.CacheDict[tp.Any] = cd.CacheDict(folder=folder) 46 | with cache2.writer() as writer: 47 | writer["blabla"] = 2 * x 48 | assert "blabla" in cache 49 | assert "blabla2" not in cache 50 | assert set(cache.keys()) == {"blublu", "blabla"} 51 | d = dict(cache2.items()) 52 | np.testing.assert_almost_equal(d["blabla"], 2 * d["blublu"]) 53 | assert len(list(cache.values())) == 2 54 | # detect type 55 | cache2 = cd.CacheDict(folder=folder) 56 | assert isinstance(cache2["blublu"], np.ndarray) 57 | # del 58 | del cache2["blublu"] 59 | assert set(cache2.keys()) == {"blabla"} 60 | # clear 61 | cache2.clear() 62 | assert not list(folder.iterdir()) 63 | assert not cache2 64 | 65 | 66 | @pytest.mark.parametrize( 67 | "data", 68 | ( 69 | np.random.rand(2, 12), 70 | nib.Nifti1Image(np.ones(5), np.eye(4)), 71 | nib.Nifti2Image(np.ones(5), np.eye(4)), 72 | pd.DataFrame([{"blu": 12}]), 73 | ), 74 | ) 75 | @pytest.mark.parametrize("write_key_files", (True, False)) 76 | def test_data_dump_suffix(tmp_path: Path, data: tp.Any, write_key_files: bool) -> None: 77 | cache: cd.CacheDict[np.ndarray] = cd.CacheDict( 78 | folder=tmp_path, keep_in_ram=False, _write_legacy_key_files=write_key_files 79 | ) 80 | if isinstance(data, np.ndarray) and write_key_files: 81 | return # deactivated 82 | with cache.writer() as writer: 83 | writer["blublu.tmp"] = data 84 | assert cache.cache_type not in [None, "Pickle"] 85 | names = [fp.name for fp in tmp_path.iterdir() if not fp.name.startswith(".")] 86 | assert len(names) == 2 + write_key_files 87 | j_name = [n for n in names if n.endswith("-info.jsonl")][0] 88 | v_name = [n for n in names if not n.endswith((".key", "-info.jsonl"))][0] 89 | if write_key_files: 90 | k_name = [n for n in names if n.endswith(".key")][0] 91 | num = len(k_name) - 4 92 | assert k_name[:num] == v_name[:num], f"Non-matching names {k_name} and {v_name}" 93 | assert isinstance(cache["blublu.tmp"], type(data)) 94 | assert (tmp_path / j_name).read_text().startswith("metadata={") 95 | 96 | 97 | @pytest.mark.parametrize( 98 | "data,cache_type", 99 | [ 100 | (torch.rand(2, 12), "TorchTensor"), 101 | ([12, 12], "Pickle"), 102 | (pd.DataFrame([{"stuff": 12}]), "PandasDataFrame"), 103 | (pd.DataFrame([{"stuff": 12}]), "ParquetPandasDataFrame"), 104 | (np.array([12, 12]), "NumpyMemmapArray"), 105 | (np.array([12, 12]), "MemmapArrayFile"), 106 | (np.array([12, 12]), "MemmapArrayFile:0"), 107 | ], 108 | ) 109 | @pytest.mark.parametrize("legacy_write", (True, False)) 110 | @pytest.mark.parametrize("keep_in_ram", (True, False)) 111 | def test_specialized_dump( 112 | tmp_path: Path, data: tp.Any, cache_type: str, legacy_write: bool, keep_in_ram: bool 113 | ) -> None: 114 | memmap_cache_size = 10 115 | if 
cache_type.endswith(":0"): 116 | cache_type = cache_type[:-2] 117 | memmap_cache_size = 0 118 | proc = psutil.Process() 119 | cache: cd.CacheDict[tp.Any] = cd.CacheDict( 120 | folder=tmp_path, 121 | keep_in_ram=keep_in_ram, 122 | cache_type=cache_type, 123 | _write_legacy_key_files=legacy_write, 124 | ) 125 | with cache.writer() as writer: 126 | writer["x"] = data 127 | with utils.environment_variables(**{MEMMAP_ARRAY_FILE_MAX_CACHE: memmap_cache_size}): 128 | assert isinstance(cache["x"], type(data)) 129 | del cache 130 | gc.collect() 131 | # check permissions 132 | octal_permissions = oct(tmp_path.stat().st_mode)[-3:] 133 | assert octal_permissions == "777", f"Wrong permissions for {tmp_path}" 134 | for fp in tmp_path.iterdir(): 135 | octal_permissions = oct(fp.stat().st_mode)[-3:] 136 | assert octal_permissions == "777", f"Wrong permissions for {fp}" 137 | # check file remaining open 138 | keeps_memmap = cache_type == "MemmapArrayFile" and ( 139 | memmap_cache_size or keep_in_ram 140 | ) # keeps internal cache 141 | keeps_memmap |= cache_type == "NumpyMemmapArray" and keep_in_ram # stays in ram 142 | files = proc.open_files() 143 | if keeps_memmap: 144 | assert files, "Some memmaps should stay open" 145 | else: 146 | assert not files, "No file should remain open" 147 | 148 | 149 | def _setval(cache: cd.CacheDict[tp.Any], key: str, val: tp.Any) -> None: 150 | with cache.writer() as writer: 151 | writer[key] = val 152 | 153 | 154 | @pytest.mark.parametrize( 155 | "legacy_write,remove_jsonl", ((True, True), (True, False), (False, False)) 156 | ) 157 | @pytest.mark.parametrize("process", (False,)) # add True for more (slower) tests 158 | def test_info_jsonl( 159 | tmp_path: Path, legacy_write: bool, remove_jsonl: bool, process: bool 160 | ) -> None: 161 | cache: cd.CacheDict[int] = cd.CacheDict( 162 | folder=tmp_path, keep_in_ram=False, _write_legacy_key_files=legacy_write 163 | ) 164 | Pool = futures.ProcessPoolExecutor if process else futures.ThreadPoolExecutor 165 | jobs = [] 166 | with Pool(max_workers=2) as ex: 167 | jobs.append(ex.submit(_setval, cache, "x", 12)) 168 | jobs.append(ex.submit(_setval, cache, "y", 3)) 169 | jobs.append(ex.submit(_setval, cache, "z", 24)) 170 | for j in jobs: 171 | j.result() 172 | # check files 173 | fps = list(tmp_path.iterdir()) 174 | info_paths = [fp for fp in fps if fp.name.endswith("-info.jsonl")] 175 | assert len(info_paths) == 2 176 | if remove_jsonl: 177 | for ipath in info_paths: 178 | ipath.unlink() 179 | # restore 180 | cache = cd.CacheDict(folder=tmp_path, keep_in_ram=False) 181 | assert cache["x"] == 12 182 | cache = cd.CacheDict(folder=tmp_path, keep_in_ram=False) 183 | assert "y" in cache 184 | cache = cd.CacheDict(folder=tmp_path, keep_in_ram=False) 185 | assert len(cache) == 3 186 | cache.clear() 187 | assert not cache 188 | assert not list(tmp_path.iterdir()) 189 | 190 | 191 | @pytest.mark.parametrize( 192 | "legacy_write,remove_jsonl", ((True, True), (True, False), (False, False)) 193 | ) 194 | def test_info_jsonl_deletion( 195 | tmp_path: Path, legacy_write: bool, remove_jsonl: bool 196 | ) -> None: 197 | keys = ("x", "blüblû", "stuff") 198 | for k in keys: 199 | cache: cd.CacheDict[int] = cd.CacheDict( 200 | folder=tmp_path, keep_in_ram=False, _write_legacy_key_files=legacy_write 201 | ) 202 | with cache.writer() as writer: 203 | writer[k] = 12 if k == "x" else 3 204 | _ = cache.keys() # listing 205 | info = cache._key_info 206 | cache = cd.CacheDict( 207 | folder=tmp_path, keep_in_ram=False, _write_legacy_key_files=legacy_write 208 
| ) 209 | _ = cache.keys() # listing 210 | assert cache._key_info == info 211 | for sub in info.values(): 212 | fp = sub.jsonl 213 | r = sub.byte_range 214 | with fp.open("rb") as f: 215 | f.seek(r[0]) 216 | out = f.read(r[1] - r[0]) 217 | assert out.startswith(b"{") and out.endswith(b"}\n") 218 | 219 | if remove_jsonl: 220 | for ipath in tmp_path.glob("*.jsonl"): 221 | ipath.unlink() 222 | cache = cd.CacheDict( 223 | folder=tmp_path, keep_in_ram=False, _write_legacy_key_files=legacy_write 224 | ) 225 | # remove one 226 | chosen = np.random.choice(keys) 227 | del cache[chosen] 228 | assert len(cache) == 2 229 | cache = cd.CacheDict( 230 | folder=tmp_path, keep_in_ram=False, _write_legacy_key_files=legacy_write 231 | ) 232 | assert len(cache) == 2 233 | 234 | 235 | def test_info_jsonl_partial_write(tmp_path: Path) -> None: 236 | cache: cd.CacheDict[int] = cd.CacheDict(folder=tmp_path, keep_in_ram=False) 237 | with cache.writer() as writer: 238 | for val, k in enumerate("xyz"): 239 | writer[k] = val 240 | info_path = [fp for fp in tmp_path.iterdir() if fp.name.endswith("-info.jsonl")][0] 241 | lines = info_path.read_bytes().splitlines() 242 | partial_lines = lines[:2] + [lines[2][: len(lines[2]) // 2]] 243 | info_path.write_bytes(b"\n".join(partial_lines)) 244 | # reload cache 245 | logger.debug("new file") 246 | cache = cd.CacheDict(folder=tmp_path, keep_in_ram=False) 247 | assert len(cache) == 1 248 | os.utime(tmp_path) 249 | # now complete 250 | info_path.write_bytes(b"\n".join(lines)) 251 | assert len(cache) == 3 252 | 253 | 254 | def test_2_caches(tmp_path: Path) -> None: 255 | cache: cd.CacheDict[int] = cd.CacheDict(folder=tmp_path, keep_in_ram=False) 256 | cache2: cd.CacheDict[int] = cd.CacheDict(folder=tmp_path, keep_in_ram=False) 257 | with cache.writer() as writer: 258 | writer["blublu"] = 12 259 | keys = list(cache2.keys()) 260 | keys = list(cache2.keys()) 261 | assert "blublu" in keys 262 | 263 | 264 | def test_2_caches_memmap(tmp_path: Path) -> None: 265 | params: dict[str, tp.Any] = dict( 266 | folder=tmp_path, keep_in_ram=True, cache_type="MemmapArrayFile" 267 | ) 268 | cache: cd.CacheDict[np.ndarray] = cd.CacheDict(**params) 269 | cache2: cd.CacheDict[np.ndarray] = cd.CacheDict(**params) 270 | with cache.writer() as writer: 271 | writer["blublu"] = np.random.rand(3, 12) 272 | _ = cache2["blublu"] 273 | with cache.writer() as writer: 274 | writer["blublu2"] = np.random.rand(3, 12) 275 | _ = cache2["blublu2"] 276 | assert "blublu" in cache2._ram_data 277 | _ = cache2["blublu"] 278 | -------------------------------------------------------------------------------- /exca/test_compat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pickle 8 | import typing as tp 9 | from datetime import datetime 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import pydantic 14 | import pytest 15 | 16 | from . 
import MapInfra, TaskInfra 17 | from .cachedict import CacheDict 18 | 19 | DATA = Path(__file__).parent / "data" 20 | 21 | 22 | class Whatever(pydantic.BaseModel): 23 | taski: TaskInfra = TaskInfra() 24 | mapi: MapInfra = MapInfra() 25 | param: int = 12 26 | 27 | @taski.apply 28 | def process_task(self) -> int: 29 | return 2 * self.param 30 | 31 | @mapi.apply(item_uid=str) 32 | def process_map(self, items: tp.Sequence[int]) -> tp.Iterator[int]: 33 | for item in items: 34 | yield item * self.param 35 | 36 | 37 | @pytest.mark.parametrize("uid_first", (True, False)) 38 | @pytest.mark.parametrize("fp", [None] + list(DATA.glob("*.pkl"))) 39 | def test_backward_compatibility(tmp_path: Path, uid_first: bool, fp: Path | None) -> None: 40 | print(f"Filepath: {fp}") # TO BE REMOVED WHEN LEGACY IS OVER: 41 | DUMP = False # dump a new file (so as to commit it) 42 | if fp is None: 43 | kir: tp.Any = {"keep_in_ram": True} # make sure infra not deactivated 44 | cfg = Whatever(param=13, taski=kir, mapi=kir) 45 | assert cfg.process_task() == 26 46 | assert tuple(cfg.process_map([3])) == (39,) 47 | today = datetime.now().strftime("%Y-%m-%d") 48 | fp = (DATA if DUMP else tmp_path) / f"compat-test-{today}.pkl" 49 | with fp.open("wb") as f: 50 | pickle.dump(cfg, f) 51 | if DUMP: 52 | raise RuntimeError(f"Commit {fp} and rerun without dump=True") 53 | with fp.open("rb") as f: 54 | cfg = pickle.load(f) 55 | if uid_first: 56 | _ = cfg.taski.uid() 57 | _ = cfg.mapi.uid() 58 | if "-ram-" in fp.name: 59 | # check that we keep in ram to make sure infra is not deactivated 60 | assert cfg.taski.keep_in_ram 61 | assert cfg.mapi.keep_in_ram 62 | # check outputs 63 | assert cfg.process_task() == 26 64 | assert tuple(cfg.process_map([3])) == (39,) 65 | 66 | 67 | @pytest.mark.parametrize("cache_type", (None, "MemmapArrayFile")) 68 | def test_legacy_key_files(cache_type: str | None) -> None: 69 | folder = DATA / "cachedict2501" 70 | cd: CacheDict[np.ndarray] = CacheDict(folder=folder, cache_type=cache_type) 71 | assert "x" in cd 72 | assert set(cd.keys()) == {"x", "y"} 73 | -------------------------------------------------------------------------------- /exca/test_confdict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import dataclasses 7 | import decimal 8 | import fractions 9 | import glob 10 | import typing as tp 11 | from collections import OrderedDict 12 | from pathlib import Path 13 | 14 | import numpy as np 15 | import pytest 16 | import torch 17 | 18 | from . 
import confdict 19 | from .confdict import ConfDict 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "version,expected", 24 | [ 25 | (2, "x=12,y={stuff=13,thing=12,what.hello=11}-4a9d3dba"), 26 | (None, "x=12,y={stuff=13,thing=12,what.hello=11}-3466db1c"), 27 | ], 28 | ) 29 | def test_init(version: int | None, expected: str) -> None: 30 | out = ConfDict({"y.thing": 12, "y.stuff": 13, "y": {"what.hello": 11}}, x=12) 31 | flat = out.flat() 32 | out2 = ConfDict(flat) 33 | assert out2 == out 34 | assert out2.to_uid(version=version) == expected 35 | 36 | 37 | def test_dot_access_and_to_simplied_dict() -> None: 38 | data = ConfDict({"a": 1, "b": {"c": 12}}) 39 | assert data["b.c"] == 12 40 | expected = {"a": 1, "b.c": 12} 41 | assert confdict._to_simplified_dict(data) == expected 42 | 43 | 44 | def test_simplified_dict_2() -> None: 45 | seq = [[{"uid": "D2"}, {"uid": "D1", "sub": {"uid": "D2"}}]] 46 | data = ConfDict({"seq": seq, "stuff": {"a": 12}}) 47 | assert isinstance(data["stuff"], ConfDict) 48 | assert isinstance(data["seq.0.0"], ConfDict) 49 | sub = data["seq.0.1"].flat() 50 | assert sub == {"uid": "D1", "sub.uid": "D2"} 51 | 52 | 53 | def test_update_override() -> None: 54 | data = ConfDict({"a": 12, "b": 12}) 55 | data.update({ConfDict.OVERRIDE: True, "d": 13}) 56 | assert data == {"d": 13} 57 | 58 | 59 | def test_update() -> None: 60 | data = ConfDict({"a": {"c": 12}, "b": {"c": 12}}) 61 | data.update(a={ConfDict.OVERRIDE: True, "d": 13}, b={"d": 13}) 62 | assert data == {"a": {"d": 13}, "b": {"c": 12, "d": 13}} 63 | # more complex 64 | data = ConfDict({"a": {"b": {"c": 12}}}) 65 | data.update(a={"b": {"d": 12, ConfDict.OVERRIDE: True}}) 66 | assert data == {"a": {"b": {"d": 12}}} 67 | # with compressed key 68 | data.update(**{"a.b": {"e": 13, ConfDict.OVERRIDE: True}}) 69 | assert data == {"a": {"b": {"e": 13}}} 70 | # assignment 71 | data["a"] = {"c": 1, "b": {"d": 12, ConfDict.OVERRIDE: True}} 72 | assert data == {"a": {"b": {"d": 12}, "c": 1}} 73 | data["a.b"] = {"e": 15, ConfDict.OVERRIDE: True} 74 | assert data == {"a": {"b": {"e": 15}, "c": 1}} 75 | 76 | 77 | @pytest.mark.parametrize( 78 | "update,expected", 79 | [ 80 | ({"a.b.c": 12}, {"a.b.c": 12}), 81 | ({"a.b.c.d": 12}, {"a.b.c.d": 12}), 82 | ({"a.b": {"c.d": 12}}, {"a.b.c.d": 12}), 83 | ({"a.c": None}, {"a.b": None, "a.c": None}), 84 | ({"a.b": None}, {"a.b": None}), 85 | ({"a": None}, {"a": None}), 86 | ], 87 | ) 88 | def test_update_on_none(update: tp.Any, expected: tp.Any) -> None: 89 | data = ConfDict({"a": {"b": None}}) 90 | data.update(update) 91 | assert data.flat() == expected 92 | 93 | 94 | def test_update_on_list() -> None: 95 | data = ConfDict({"a": [12, {"b": None}]}) 96 | data["a.0"] = 13 97 | data["a.1.b"] = 12 98 | with pytest.raises(TypeError): 99 | data["a.c"] = 12 100 | assert data == {"a": [13, {"b": 12}]} 101 | 102 | 103 | def test_get_on_list() -> None: 104 | data = ConfDict({"a": [12, {"b": 13}]}) 105 | assert data["a.0"] == 12 106 | assert data["a.1.b"] == 13 107 | 108 | 109 | def test_del() -> None: 110 | data = ConfDict({"a": 1, "b": {"c": {"e": 12}, "d": 13}}) 111 | del data["b.c.e"] 112 | assert data == {"a": 1, "b": {"d": 13}} 113 | del data["b"] 114 | assert data == {"a": 1} 115 | 116 | 117 | def test_pop_get() -> None: 118 | data = ConfDict({"a": 1, "b": {"c": {"e": 12}, "d": 13}}) 119 | assert "b.c.e" in data 120 | data.pop("b.c.e") 121 | assert data == {"a": 1, "b": {"d": 13}} 122 | with pytest.raises(KeyError): 123 | data.pop("a.x") 124 | assert data.pop("a.x", 12) == 12 125 | assert 
data.get("a.d") is None 126 | assert data.get("b.c") is None 127 | assert data.get("b.d") == 13 128 | assert data.pop("b.d") == 13 129 | 130 | 131 | def test_empty_conf_dict_uid() -> None: 132 | data = ConfDict({}) 133 | assert not data.to_uid() 134 | 135 | 136 | def test_from_yaml() -> None: 137 | out = ConfDict.from_yaml( 138 | """ 139 | data: 140 | default.stuff: 141 | duration: 1. 142 | features: 143 | - freq: 2 144 | other: None 145 | """ 146 | ) 147 | exp = { 148 | "data": { 149 | "default": {"stuff": {"duration": 1.0}}, 150 | "features": [{"freq": 2, "other": "None"}], 151 | } 152 | } 153 | assert out == exp 154 | y_str = out.to_yaml() 155 | assert ( 156 | y_str 157 | == """data: 158 | default.stuff.duration: 1.0 159 | features: 160 | - freq: 2 161 | other: None 162 | """ 163 | ) 164 | out2 = ConfDict.from_yaml(y_str) 165 | assert out2 == exp 166 | # uid 167 | e = "data={default.stuff.duration=1,features=({freq=2,other=None})}-d7247912" 168 | assert out2.to_uid() == e 169 | 170 | 171 | @pytest.mark.parametrize( 172 | "version,expected", 173 | [ 174 | (2, "mystuff=13,none=None,t=data-3ddaedfe,x=whatever-hello-1c82f630"), 175 | (3, "none=None,my_stuff=13,x=whatever-hello,t=data-2-3ddaedfe-48c04959"), 176 | (None, "none=None,my_stuff=13,x=whatever-hello,t=data-2-3ddaedfe-48c04959"), 177 | ], 178 | ) 179 | def test_to_uid(version: int, expected: str) -> None: 180 | data = { 181 | "my_stuff": 13.0, 182 | "x": "'whatever*'\nhello", 183 | "none": None, 184 | "t": torch.Tensor([1.2, 1.4]), 185 | } 186 | assert confdict.ConfDict(data).to_uid(version=version) == expected 187 | 188 | 189 | def test_empty(tmp_path: Path) -> None: 190 | fp = tmp_path / "cfg.yaml" 191 | cdict = confdict.ConfDict() 192 | cdict.to_yaml(fp) 193 | cdict = confdict.ConfDict.from_yaml(fp) 194 | assert not cdict 195 | assert isinstance(cdict, dict) 196 | fp.write_text("") 197 | with pytest.raises(TypeError): 198 | confdict.ConfDict.from_yaml(fp) 199 | 200 | 201 | @dataclasses.dataclass 202 | class Data: 203 | x: int = 12 204 | y: str = "blublu" 205 | 206 | 207 | def test_flatten() -> None: 208 | data = {"content": [Data()]} 209 | out = confdict._flatten(data) 210 | assert out == {"content": [{"x": 12, "y": "blublu"}]} 211 | 212 | 213 | def test_list_of_float() -> None: 214 | cfg = {"a": {"b": (1, 2, 3)}} 215 | flat = confdict.ConfDict(cfg).flat() 216 | assert flat == {"a.b": (1, 2, 3)} 217 | 218 | 219 | def test_flat_types() -> None: 220 | cfg = {"a": {"b": Path("blublu")}} 221 | flat = confdict.ConfDict(cfg).flat() 222 | assert flat == {"a.b": Path("blublu")} 223 | 224 | 225 | @pytest.mark.parametrize("ordered", (True, False)) 226 | def test_to_yaml_with_ordered_dict(ordered: bool) -> None: 227 | Dict = OrderedDict if ordered else dict 228 | cfg = {"a": Dict({str(k): {"k": k} for k in range(2)})} 229 | out = confdict.ConfDict(cfg).to_yaml().strip() 230 | expected = "a:\n 0.k: 0\n 1.k: 1" 231 | assert out == expected 232 | # avoid packing ordered dict with len == 1 233 | cfg = {"a": Dict({"b": {"c": 12}})} 234 | expected = "a:\n b.c: 12" if ordered else "a.b.c: 12" 235 | out = confdict.ConfDict(cfg).to_yaml().strip() 236 | assert out == expected 237 | 238 | 239 | def test_from_args() -> None: 240 | args = ["--name=stuff", "--optim.lr=0.01", "--optim.name=Adam"] 241 | confd = ConfDict.from_args(args) 242 | assert confd == {"name": "stuff", "optim": {"lr": "0.01", "name": "Adam"}} 243 | 244 | 245 | def test_collision() -> None: 246 | cfgs = [ 247 | """ 248 | b_model_config: 249 | layer_dim: 12 250 | transformer: 251 | 
stuff: true 252 | r_p_emb: true 253 | data: 254 | duration: 0.75 255 | start: -0.25 256 | """, 257 | """ 258 | b_model_config: 259 | layer_dim: 12 260 | transformer.stuff: true 261 | use_m_token: true 262 | data: 263 | duration: 0.75 264 | start: -0.25 265 | """, 266 | ] 267 | cds = [ConfDict.from_yaml(cfg) for cfg in cfgs] 268 | assert cds[0].to_uid() != cds[1].to_uid() 269 | expected = "data={start=-0.25,duration=0.75},b_model_config=" 270 | expected += "{layer_dim=12,transformer={stuff=True,r_p_emb=True}}-d1f629b3" 271 | assert cds[0].to_uid() == expected 272 | # reason it was colliding, strings were the same, and hash was incorrectly the same 273 | # legacy check 274 | expected = ( 275 | "bmodelconfig={layerdim=12,transfor[.]},data={duration=0.75,start=-0.25}-8b17a008" 276 | ) 277 | assert cds[0].to_uid(version=2) == expected 278 | assert cds[1].to_uid(version=2) == cds[1].to_uid(version=2) 279 | 280 | 281 | def test_dict_hash() -> None: 282 | maker1 = confdict.UidMaker({"x": 1.2, "y": ("z", 12.0)}, version=3) 283 | maker2 = confdict.UidMaker({"x": 1.2, "z": ("z", 12.0)}, version=3) 284 | assert maker1.hash != maker2.hash 285 | assert maker1.hash == "dict:{x=float:461168601842738689,y=seq:(str:z,int:12)}" 286 | 287 | 288 | def test_set_hash() -> None: 289 | data = [str(k) for k in range(6)] 290 | np.random.shuffle(data) 291 | maker = confdict.UidMaker(set(data)) 292 | assert maker.format() == "0,1,2,3,4,5-06b9e6d9" 293 | 294 | 295 | def test_fractions_decimal() -> None: 296 | d = {"f": 1.1, "d": decimal.Decimal("1.1"), "/": fractions.Fraction(11, 10)} 297 | maker = confdict.UidMaker(d) 298 | assert maker.string == "{-=1.10,d=1.10,f=1.10}" 299 | # float is an approximation while decimal and fraction are exactly the same: 300 | expec = "dict:{/=float:2075258708292324557,d=float:2075258708292324557,f=float:230584300921369601}" 301 | assert maker.hash == expec 302 | 303 | 304 | def test_long_config_glob(tmp_path: Path) -> None: 305 | string = "abcdefghijklmnopqrstuvwxyz" 306 | base: dict[str, tp.Any] = { 307 | "l": [1, 2], 308 | "d": {"a": 1, "b.c": 2}, 309 | "string": string, 310 | "num": 123456789000, 311 | } 312 | cfg = dict(base) 313 | cfg["sub"] = dict(base) 314 | cfg["sub"]["sub"] = dict(base) 315 | cfgd = ConfDict(cfg) 316 | uid = cfgd.to_uid(2) 317 | expected = ( 318 | "d={a=1,b.c=2},l=[1,2],num=12345678[.]tring=abcdefghijklmnopqrstuvwxyz}}-b7348341" 319 | ) 320 | assert uid == expected 321 | uid = cfgd.to_uid() 322 | expected = "l=(1,2),d={a=1,b.c=2},num=123456789000,string=abcdefghijklmnopqrstuvwxyz," 323 | expected += "sub={l=(1,2),d={a=1,b.c=2},num=123456789000,string=abcd...84-63bf871d" 324 | assert uid == expected 325 | folder = tmp_path / uid 326 | folder.mkdir() 327 | (folder / "myfile.txt").touch() 328 | files = list(glob.glob(str(folder / "*file.txt"))) 329 | assert files, "folder name messes up with glob" 330 | -------------------------------------------------------------------------------- /exca/test_dumperloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | from pathlib import Path 9 | 10 | import mne 11 | import nibabel as nib 12 | import numpy as np 13 | import pandas as pd 14 | import pytest 15 | import torch 16 | 17 | from . 
import dumperloader 18 | 19 | 20 | def make_mne_raw(ch_type: str) -> mne.io.RawArray: 21 | n_channels, sfreq, duration = 4, 64, 60 22 | data = np.random.rand(n_channels, sfreq * duration) 23 | info = mne.create_info(n_channels, sfreq=sfreq, ch_types=[ch_type] * n_channels) 24 | return mne.io.RawArray(data, info=info) 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "data", 29 | ( 30 | np.random.rand(2, 12), 31 | torch.Tensor([12]), 32 | nib.Nifti1Image(np.ones(5), np.eye(4)), 33 | nib.Nifti2Image(np.ones(5), np.eye(4)), 34 | pd.DataFrame([{"blu": 12}]), 35 | make_mne_raw("eeg"), 36 | "stuff", 37 | ), 38 | ) 39 | def test_data_dump_suffix(tmp_path: Path, data: tp.Any) -> None: 40 | Cls = dumperloader.DumperLoader.default_class(type(data)) 41 | if not isinstance(data, str): 42 | assert Cls is not dumperloader.Pickle 43 | dl = Cls(tmp_path) 44 | # test with an extension, as it's easy to mess the new name with Path.with_suffix 45 | with dl.open(): 46 | info = dl.dump("blublu.ext", data) 47 | reloaded = dl.load(**info) 48 | ExpectedCls = type(data) 49 | if ExpectedCls is mne.io.RawArray: 50 | ExpectedCls = mne.io.Raw 51 | assert isinstance(reloaded, ExpectedCls) 52 | 53 | 54 | @pytest.mark.parametrize("name", ("PandasDataFrame", "ParquetPandasDataFrame")) 55 | def test_text_df(tmp_path: Path, name: str) -> None: 56 | df = pd.DataFrame( 57 | [{"type": "Word", "text": "None"}, {"type": "Something", "number": 12}] 58 | ) 59 | dl = dumperloader.DumperLoader.CLASSES[name](tmp_path) 60 | info = dl.dump("blublu", df) 61 | reloaded = dl.load(**info) 62 | assert reloaded.loc[0, "text"] == "None" 63 | assert pd.isna(reloaded.loc[1, "text"]) # type: ignore 64 | assert pd.isna(reloaded.loc[0, "number"]) # type: ignore 65 | assert not set(reloaded.columns).symmetric_difference(df.columns) 66 | 67 | 68 | @pytest.mark.parametrize("ch_type", ("eeg", "ecog", "seeg", "mag", "grad", "ref_meg")) 69 | @pytest.mark.parametrize("name", ("MneRawFif", "MneRawBrainVision")) 70 | def test_mne_raw(tmp_path: Path, ch_type: str, name: str) -> None: 71 | raw = make_mne_raw(ch_type) 72 | dl = dumperloader.DumperLoader.CLASSES[name](tmp_path) 73 | info = dl.dump("blublu", raw) 74 | reloaded = dl.load(**info) 75 | reload_type = ( 76 | mne.io.Raw 77 | if name == "MneRawFif" 78 | else mne.io.brainvision.brainvision.RawBrainVision 79 | ) 80 | assert isinstance(reloaded, reload_type) 81 | raw_data = raw.get_data() 82 | reloaded_data = reloaded.get_data() 83 | assert np.allclose(raw_data, reloaded_data, atol=1e-8) 84 | 85 | 86 | @pytest.mark.parametrize( 87 | "data,expected", 88 | [ 89 | (torch.arange(8), False), 90 | (torch.arange(8) * 1.0, False), 91 | (torch.arange(8)[-2:], True), 92 | (torch.arange(8)[:2], True), 93 | (torch.arange(8).reshape(2, 4), False), 94 | (torch.arange(8).reshape(2, 4).T, True), 95 | ], 96 | ) 97 | def test_is_view(data: torch.Tensor, expected: bool) -> None: 98 | assert dumperloader.is_view(data) is expected 99 | 100 | 101 | def test_dump_torch_view(tmp_path: Path) -> None: 102 | data = torch.arange(8)[:2] 103 | assert dumperloader.is_view(data) 104 | # reloading it should not be a view as it was cloned 105 | dl = dumperloader.TorchTensor(tmp_path) 106 | info = dl.dump("blublu", data) 107 | reloaded = dl.load(**info) 108 | assert not dumperloader.is_view(reloaded) 109 | 110 | 111 | def test_dump_dict(tmp_path: Path) -> None: 112 | data = {"blu": 12, "blublu": np.array([12, 12]), "blabla": np.array([24.0])} 113 | dl = dumperloader.DataDict(tmp_path) 114 | with dl.open(): 115 | info = dl.dump("blublu", data) 
116 | assert set(info["optimized"]) == {"blublu", "blabla"} 117 | reloaded = dl.load(**info) 118 | assert set(reloaded) == {"blublu", "blabla", "blu"} 119 | np.testing.assert_array_equal(reloaded["blublu"], [12, 12]) 120 | 121 | 122 | def test_default_class() -> None: 123 | out = dumperloader.DumperLoader.default_class(int | None) # type: ignore 124 | assert out is dumperloader.Pickle 125 | 126 | 127 | @pytest.mark.parametrize( 128 | "string,expected", 129 | [ 130 | ( 131 | "whave\t-er I want/to\nput i^n there", 132 | "whave--er-I-want-to-put-i^n-there-391137b5", 133 | ), 134 | ( 135 | "whave\t-er I want/to put i^n there", # same but space instead of line return 136 | "whave--er-I-want-to-put-i^n-there-cef06284", 137 | ), 138 | (50 * "a" + 50 * "b", 40 * "a" + "[.]" + 40 * "b" + "-932620a9"), 139 | (51 * "a" + 50 * "b", 40 * "a" + "[.]" + 40 * "b" + "-86bb658a"), # longer 140 | ], 141 | ) 142 | def test_string_uid(string: str, expected: str) -> None: 143 | out = dumperloader._string_uid(string) 144 | assert out == expected 145 | 146 | 147 | def test_memmap_array_file(tmp_path: Path) -> None: 148 | dl = dumperloader.MemmapArrayFile(folder=tmp_path) 149 | info = [] 150 | x = np.random.rand(2, 3) 151 | y = np.random.rand(3, 3).astype(np.float16) 152 | with dl.open(): 153 | with pytest.raises(ValueError): # x array with no size not supported 154 | info.append(dl.dump("t", np.random.rand(0, 3))) 155 | info.append(dl.dump("x", x)) 156 | info.append(dl.dump("y", y)) 157 | info.append(dl.dump("z", np.random.rand(4, 3))) 158 | assert info[0]["filename"] == info[1]["filename"] 159 | x2 = dl.load(**info[0]) 160 | with dl.open(): 161 | info.append(dl.dump("w", np.random.rand(5, 3))) # write in between reads 162 | assert isinstance(x2, np.memmap) 163 | np.testing.assert_array_equal(x2, x) 164 | y2 = dl.load(**info[1]) 165 | np.testing.assert_array_equal(y2, y) 166 | assert dl.load(**info[1]).shape == (3, 3) 167 | assert dl.load(**info[-1]).shape == (5, 3) 168 | # recheck after data was reloaded 169 | assert isinstance(x2, np.memmap) 170 | np.testing.assert_array_equal(x2, x) 171 | -------------------------------------------------------------------------------- /exca/test_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pickle 8 | import typing as tp 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import pytest 13 | 14 | from . 
import helpers 15 | 16 | 17 | def my_func(a: int, b: int) -> np.ndarray: 18 | return np.random.rand(a, b) 19 | 20 | 21 | def test_to_config_model(tmp_path: Path) -> None: 22 | Conf = helpers.to_config_model(my_func) 23 | conf = Conf(a=3, b=4, infra={"folder": tmp_path}) # type: ignore 24 | out1 = conf.build() 25 | out2 = conf.build() 26 | np.testing.assert_array_equal(out1, out2) # should be cached 27 | 28 | 29 | def test_to_config(tmp_path: Path) -> None: 30 | conf = helpers.to_config(my_func, a=3, b=4, infra={"folder": tmp_path}) 31 | out = conf.build() 32 | string = pickle.dumps(conf) 33 | conf2 = pickle.loads(string) 34 | np.testing.assert_array_equal(conf2.build(), out) # should be cached 35 | 36 | 37 | def test_with_infra(tmp_path: Path) -> None: 38 | infra_func = helpers.with_infra(folder=tmp_path)(my_func) 39 | out = infra_func(a=3, b=4) 40 | # pickling and reproducibility 41 | string = pickle.dumps(infra_func) 42 | infra_func2 = pickle.loads(string) 43 | out2 = infra_func2(a=3, b=4) 44 | np.testing.assert_array_equal(out2, out) # should be cached 45 | 46 | 47 | # pylint: disable=unused-argument 48 | def func(a: int, *, b: int = 12) -> None: 49 | pass 50 | 51 | 52 | class KwargsClass: 53 | # pylint: disable=unused-argument 54 | def __init__(self, a: int, b: int = 12, **kwargs: tp.Any) -> None: 55 | pass 56 | 57 | 58 | def test_validate_kwargs() -> None: 59 | with pytest.raises(ValueError): 60 | helpers.validate_kwargs(func, {}) 61 | with pytest.raises(ValueError): 62 | helpers.validate_kwargs(KwargsClass, {}) 63 | with pytest.raises(ValueError): 64 | helpers.validate_kwargs(func, {"a": 12, "c": 13}) 65 | helpers.validate_kwargs(KwargsClass, {"a": 12}) 66 | helpers.validate_kwargs(func, {"a": 12, "b": 13}) 67 | with pytest.raises(TypeError): 68 | helpers.validate_kwargs(func, {"a": "blublu", "b": 13}) 69 | helpers.validate_kwargs(KwargsClass, {"a": 12, "b": 13, "c": "blublu"}) 70 | 71 | 72 | def test_find_slurm_job(tmp_path: Path) -> None: 73 | cfolder = tmp_path / "a" / "b" 74 | jfolder = cfolder / "logs" / "c" / "12" 75 | jfolder.mkdir(parents=True) 76 | stdout = jfolder / "12_0_log.out" 77 | stdout.write_text("Ice cream") 78 | (cfolder / "config.yaml").write_text("a: 12") 79 | (cfolder / "uid.yaml").write_text("a: 12") 80 | job = helpers.find_slurm_job(job_id="12", folder=tmp_path) 81 | assert job is not None 82 | assert job.config == {"a": 12} 83 | assert job.stdout() == "Ice cream" 84 | -------------------------------------------------------------------------------- /exca/test_localmap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import typing as tp 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import pydantic 13 | import pytest 14 | 15 | from .map import MapInfra 16 | 17 | PACKAGE = MapInfra.__module__.split(".", maxsplit=1)[0] 18 | logging.getLogger(PACKAGE).setLevel(logging.DEBUG) 19 | 20 | 21 | class Whatever(pydantic.BaseModel): 22 | param1: int = 12 23 | param2: str = "stuff" 24 | unrelated: str = "hello world" 25 | cls_unrelated: str = "" 26 | infra: MapInfra = MapInfra(version="1", cluster="threadpool", max_jobs=2) 27 | raise_for: int | None = None 28 | _exclude_from_cls_uid = ("cls_unrelated",) 29 | 30 | @infra.apply( 31 | item_uid=str, # how to create the dict key/uid from an item of the method input 32 | exclude_from_cache_uid=("unrelated",), 33 | ) 34 | def process(self, items: tp.Sequence[int]) -> tp.Iterator[np.ndarray]: 35 | for item in items: 36 | if self.raise_for is not None and item == self.raise_for: 37 | raise ValueError(f"Raising for {item}") 38 | yield np.random.rand(item, self.param1) 39 | 40 | 41 | @pytest.mark.parametrize("cluster", [None, "threadpool", "processpool"]) 42 | @pytest.mark.parametrize("keep_in_ram", [True, False]) 43 | @pytest.mark.parametrize("with_folder", [True, False]) 44 | def test_local_map_infra( 45 | tmp_path: Path, keep_in_ram: bool, with_folder: bool, cluster: str 46 | ) -> None: 47 | params: tp.Any = {"keep_in_ram": keep_in_ram, "cluster": cluster} 48 | if with_folder: 49 | params["folder"] = tmp_path 50 | base = Whatever( 51 | param2="stuff", 52 | unrelated="not included", 53 | cls_unrelated="not included either", 54 | infra=params, 55 | ) 56 | whatever = base.infra.clone_obj({"param1": 13}) 57 | _ = base.infra.config(uid=False, exclude_defaults=False) 58 | if with_folder: 59 | objs = list(whatever.infra.iter_cached()) 60 | assert not objs 61 | out = list(whatever.process([1, 2, 2, 3])) 62 | assert [x.shape for x in out] == [(1, 13), (2, 13), (2, 13), (3, 13)] 63 | path = tmp_path 64 | uid = f"{__name__}.Whatever.process,1/param1=13-4c541560" 65 | if with_folder: 66 | for name in uid.split("/"): 67 | path = path / name 68 | if not path.exists(): 69 | content = [f.name for f in path.parent.iterdir()] 70 | raise RuntimeError(f"Missing folder, got {content}") 71 | objs = list(whatever.infra.iter_cached()) 72 | assert len(objs) == 1, "Missing cached configs" 73 | if with_folder or keep_in_ram: 74 | out2 = next(whatever.process([2])) 75 | np.testing.assert_array_equal(out2, out[1]) 76 | # check that clearing cache works 77 | whatever.infra.cache_dict.clear() 78 | _ = np.random.rand() # updates the seed if process is forked 79 | out2 = next(whatever.process([2])) 80 | with pytest.raises(AssertionError): 81 | np.testing.assert_array_equal(out2, out[1]) 82 | -------------------------------------------------------------------------------- /exca/test_map.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import concurrent.futures 8 | import logging 9 | import pickle 10 | import typing as tp 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import pydantic 15 | import pytest 16 | 17 | from . 
import helpers 18 | from .map import MapInfra, to_chunks 19 | 20 | PACKAGE = MapInfra.__module__.split(".", maxsplit=1)[0] 21 | logging.getLogger(PACKAGE).setLevel(logging.DEBUG) 22 | 23 | 24 | class Whatever(pydantic.BaseModel): 25 | param1: int = 12 26 | param2: str = "stuff" 27 | unrelated: str = "hello world" 28 | cls_unrelated: str = "" 29 | infra: MapInfra = MapInfra(version="1") 30 | 31 | _exclude_from_cls_uid = ("cls_unrelated",) 32 | _missing: bool = False # internal for testing 33 | 34 | @infra.apply( 35 | item_uid=str, # how to create the dict key/uid from an item of the method input 36 | exclude_from_cache_uid=("unrelated",), 37 | ) 38 | def process(self, nums: tp.Sequence[int]) -> tp.Iterator[np.ndarray]: 39 | for num in nums: 40 | yield np.random.rand(num, self.param1) 41 | if self._missing: 42 | return 43 | 44 | 45 | def test_decorator_deactivated() -> None: 46 | whatever = Whatever(param1=13, infra=dict(keep_in_ram=False)) # type: ignore 47 | x, *_ = whatever.process([12]) 48 | x, *_ = whatever.process([12]) 49 | assert x.shape == (12, 13) 50 | 51 | 52 | def test_named_arg() -> None: 53 | whatever = Whatever(param1=13, infra=dict(keep_in_ram=False)) # type: ignore 54 | # pylint: disable=unexpected-keyword-arg,no-value-for-parameter 55 | with pytest.raises(ValueError): 56 | x, *_ = whatever.process(nums=[12], items="stuff") # type: ignore 57 | with pytest.raises(NameError): 58 | x, *_ = whatever.process(items=[12]) # type: ignore 59 | x, *_ = whatever.process(nums=[12]) 60 | assert x.shape == (12, 13) 61 | 62 | 63 | def test_infra_forbid_single_item_computation(tmp_path: Path) -> None: 64 | whatever = Whatever(param1=13, infra={"folder": tmp_path, "cluster": "local"}) # type: ignore 65 | whatever.infra.forbid_single_item_computation = True 66 | with pytest.raises(RuntimeError): 67 | whatever.process([12]) 68 | 69 | 70 | def test_map_infra(tmp_path: Path) -> None: 71 | base = Whatever( 72 | param2="stuff", 73 | unrelated="not included", 74 | cls_unrelated="not included either", 75 | infra={"folder": tmp_path, "cluster": "local"}, # type: ignore 76 | ) 77 | whatever = base.infra.clone_obj({"param1": 13}) 78 | _ = base.infra.config(uid=False, exclude_defaults=False) 79 | objs = list(whatever.infra.iter_cached()) 80 | assert not objs 81 | out = list(whatever.process([1, 2, 2, 3])) 82 | assert [x.shape for x in out] == [(1, 13), (2, 13), (2, 13), (3, 13)] 83 | path = tmp_path 84 | uid = f"{__name__}.Whatever.process,1/param1=13-4c541560" 85 | assert whatever.infra.uid() == uid 86 | for name in uid.split("/"): 87 | path = path / name 88 | msg = f"Missing folder, got {[f.name for f in path.parent.iterdir()]}" 89 | assert path.exists(), msg 90 | out2 = next(whatever.process([2])) 91 | np.testing.assert_array_equal(out2, out[1]) 92 | # check that a default name has been set without changing the config 93 | assert whatever.infra.job_name is None 94 | ex = whatever.infra.executor() 95 | assert ex is not None 96 | expected = "Whatever.process,1/param1=13-4c541560" 97 | assert ex._executor.parameters["name"] == expected 98 | assert "{folder}" not in str(whatever.infra._log_path()) 99 | # recover cached objects 100 | objs = list(whatever.infra.iter_cached()) 101 | assert len(objs) == 1 102 | # check that clearing cache works 103 | whatever.infra.cache_dict.clear() 104 | out2 = next(whatever.process([2])) 105 | with pytest.raises(AssertionError): 106 | np.testing.assert_array_equal(out2, out[1]) 107 | 108 | 109 | def test_map_infra_cache_dict_calls(tmp_path: Path) -> None: 110 | whatever = 
Whatever(infra={"folder": tmp_path, "cluster": "local"}) # type: ignore 111 | cd = whatever.infra.cache_dict 112 | _ = list(whatever.process([1, 2, 3, 4])) 113 | assert cd._jsonl_readings == 3 114 | whatever = Whatever(infra={"folder": tmp_path, "cluster": "local"}) # type: ignore 115 | cd = whatever.infra.cache_dict 116 | _ = list(whatever.process([1])) 117 | assert cd._jsonl_readings == 1 118 | _ = list(whatever.process([2, 3, 4])) 119 | assert cd._jsonl_readings == 1 120 | _ = list(whatever.process([5])) 121 | assert cd._jsonl_readings == 4 122 | 123 | 124 | def test_missing_yield() -> None: 125 | whatever = Whatever() 126 | whatever._missing = True 127 | with pytest.raises(RuntimeError): 128 | _ = list(whatever.process([1, 2, 2, 3])) 129 | 130 | 131 | def test_map_infra_pickling(tmp_path: Path) -> None: 132 | whatever = Whatever(infra={"folder": tmp_path, "cluster": "local"}) # type: ignore 133 | string = pickle.dumps(whatever) 134 | whatever2 = pickle.loads(string) 135 | assert whatever2.process.__name__ == "_method_override", "Infra not reloaded" 136 | x, *_ = whatever.process([12]) 137 | assert isinstance(x, np.ndarray) 138 | x, *_ = whatever2.process([12]) 139 | assert isinstance(x, np.ndarray) 140 | string = pickle.dumps(whatever2) 141 | whatever3 = pickle.loads(string) 142 | assert hasattr(whatever2.infra, "_cache_dict") 143 | assert not hasattr(whatever3.infra, "_cache_dict") 144 | assert whatever3.process.__name__ == "_method_override", "Infra not reloaded" 145 | 146 | 147 | def test_find_slurm_job(tmp_path: Path) -> None: 148 | whatever = Whatever(param1=13, infra={"folder": tmp_path, "cluster": "local"}) # type: ignore 149 | _ = whatever.process([2]) 150 | folder = next(tmp_path.glob("**/*result.pkl")).parent # there should be a result 151 | job = helpers.find_slurm_job(job_id=folder.name, folder=tmp_path) 152 | assert job is not None 153 | assert job.uid_config == {"param1": 13} 154 | 155 | 156 | def test_map_infra_perm(tmp_path: Path) -> None: 157 | whatever = Whatever(infra={"folder": tmp_path, "permissions": 0o777}) # type: ignore 158 | xpfold = whatever.infra.uid_folder() 159 | assert xpfold is not None 160 | xpfold.mkdir(parents=True) 161 | before = xpfold.stat().st_mode 162 | _ = list(whatever.process([1, 2, 2, 3])) 163 | after = xpfold.stat().st_mode 164 | assert after > before 165 | 166 | 167 | def test_map_infra_debug(tmp_path: Path) -> None: 168 | whatever = Whatever(infra={"folder": tmp_path, "cluster": "debug"}) # type: ignore 169 | _ = list(whatever.process([1, 2, 2, 3])) 170 | 171 | 172 | def test_batch_no_item(tmp_path: Path) -> None: 173 | whatever = Whatever(infra={"folder": tmp_path}) # type: ignore 174 | out = list(whatever.process([])) 175 | assert not out 176 | 177 | 178 | @pytest.mark.parametrize("cluster", [None, "local"]) # processpool requires pickling 179 | def test_script_model(tmp_path: Path, cluster: None | str) -> None: 180 | class LocalModel(pydantic.BaseModel): 181 | infra: MapInfra = MapInfra() 182 | param: int = 12 183 | model_config = pydantic.ConfigDict(extra="forbid") # safer to avoid extra params 184 | 185 | @infra.apply(item_uid=str) 186 | def process( 187 | self, items: tp.Sequence[int] 188 | ) -> tp.Generator[np.ndarray, None, None]: 189 | for item in items: 190 | yield np.random.rand(item, self.param) 191 | 192 | model = LocalModel( 193 | param=13, 194 | infra={"folder": tmp_path, "cluster": cluster}, # type: ignore 195 | ) 196 | assert len(list(model.process([2, 3]))) == 2 197 | 198 | 199 | def test_changing_defaults(tmp_path: Path) 
-> None: 200 | class Whenever(Whatever): 201 | pass 202 | 203 | whenever = Whenever(param1=13, infra={"folder": tmp_path}) # type: ignore 204 | _ = whenever.process([1]) 205 | 206 | class Whenever(Whatever): # type: ignore 207 | param2: str = "modified" 208 | 209 | whenever = Whenever(param1=13, infra={"folder": tmp_path}) # type: ignore 210 | with pytest.raises(RuntimeError): 211 | _ = whenever.process([1]) 212 | 213 | 214 | def test_multiple_cached(tmp_path: Path) -> None: 215 | for p in range(2): 216 | whatever = Whatever( 217 | param1=p + 1, 218 | infra={"folder": tmp_path}, # type: ignore 219 | ) 220 | _ = list(whatever.process([1, 2, 2, 3])) 221 | objs = list(whatever.infra.iter_cached()) 222 | assert len(objs) == 2 223 | 224 | 225 | class RandMode(pydantic.BaseModel): 226 | infra: MapInfra = MapInfra() 227 | 228 | @infra.apply(item_uid=str) 229 | def process(self, items: tp.Sequence[int]) -> tp.Iterable[np.ndarray]: 230 | for item in items: 231 | yield np.random.rand(2, item) 232 | 233 | 234 | def test_mode(tmp_path: Path) -> None: 235 | modes = ["cached", "force", "read-only"] 236 | cfg = RandMode(infra={"folder": tmp_path, "mode": "force"}) # type: ignore 237 | cfgs = {m: cfg.infra.clone_obj({"infra.mode": m}) for m in modes} 238 | with pytest.raises(RuntimeError): 239 | cfgs["read-only"].process([2]) # not precomputed 240 | out = {m: list(cfgs[m].process([2]))[0] for m in modes} 241 | with pytest.raises(AssertionError): 242 | np.testing.assert_array_equal(out["cached"], out["force"]) 243 | np.testing.assert_array_equal(out["force"], out["read-only"]) 244 | # check not recomputed: 245 | for k in range(2): 246 | newcall = list(cfgs["force"].process([2]))[0] 247 | msg = f"Recomputed on try #{k + 1}" 248 | np.testing.assert_array_equal(newcall, out["force"], err_msg=msg) 249 | 250 | 251 | @pytest.mark.parametrize( 252 | "num,max_chunks,min_items_per_chunk,expected", 253 | [ 254 | (12, 5, 4, (4, 4, 4)), 255 | (13, 2, 5, (7, 6)), 256 | (13, None, 5, (5, 5, 3)), 257 | ], 258 | ) 259 | def test_to_chunks( 260 | num: int, max_chunks: int | None, min_items_per_chunk: int, expected: tp.Tuple[int] 261 | ) -> None: 262 | data = list(range(num)) 263 | chunks = to_chunks( 264 | data, max_chunks=max_chunks, min_items_per_chunk=min_items_per_chunk 265 | ) 266 | sizes = tuple(len(chunk) for chunk in chunks) 267 | assert sizes == expected 268 | 269 | 270 | def test_max_workers() -> None: 271 | # existence of _max_worers is used in map.py but not backed by 272 | # typing, so let's check it here 273 | with concurrent.futures.ProcessPoolExecutor() as p_ex: 274 | assert isinstance(p_ex._max_workers, int) # type: ignore 275 | with concurrent.futures.ThreadPoolExecutor() as t_ex: 276 | assert isinstance(t_ex._max_workers, int) 277 | 278 | 279 | def test_missing_item_uid() -> None: 280 | # pylint: disable=unused-variable 281 | with pytest.raises(TypeError): 282 | 283 | class MissingItemUid(pydantic.BaseModel): # pylint: disable=unused-variable 284 | infra: MapInfra = MapInfra(version="12") 285 | 286 | @infra.apply # type: ignore 287 | def func(self, items: tp.List[int]) -> tp.Iterator[int]: 288 | yield from items 289 | -------------------------------------------------------------------------------- /exca/test_safeguard.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import itertools 8 | import subprocess 9 | from pathlib import Path 10 | 11 | import exca 12 | 13 | 14 | def test_package_version() -> None: 15 | version = exca.__version__ 16 | pyproject = Path(exca.__file__).parent.with_name("pyproject.toml") 17 | assert f'version = "{version}"' in pyproject.read_text() 18 | 19 | 20 | def test_logging() -> None: 21 | line = "from . import logconf # noqa" 22 | fp = Path(__file__).with_name("base.py") 23 | assert line in fp.read_text() 24 | 25 | 26 | def test_slurm_in_doc() -> None: 27 | doc = Path(exca.__file__).parent.with_name("docs") / "infra" / "introduction.md" 28 | assert doc.exists() 29 | expected = "cluster: slurm" # this gets replaced during README tests 30 | assert expected in doc.read_text() 31 | 32 | 33 | def test_header() -> None: 34 | lines = Path(__file__).read_text("utf8").splitlines() 35 | header = "\n".join(itertools.takewhile(lambda line: line.startswith("#"), lines)) 36 | assert len(header.splitlines()) == 5, f"Identified header:\n{header}" 37 | root = Path(__file__).parents[1] 38 | assert root.name == "exca" 39 | # list of files to check 40 | tocheck = [] 41 | output = subprocess.check_output(["find", root, "-name", "*.py"], shell=False) 42 | tocheck.extend([Path(p) for p in output.decode().splitlines()]) 43 | # add missing licenses if none already exists 44 | missing = [] 45 | AUTOADD = True 46 | skip = ("/lib/", "/build/", "docs/conf.py") 47 | for fp in tocheck: 48 | if any(x in str(fp.relative_to(root)) for x in skip): 49 | continue 50 | text = Path(fp).read_text("utf8") 51 | if not text.startswith(header): 52 | if AUTOADD and not any(x in text.lower() for x in ("license", "copyright")): 53 | print(f"Automatically adding header to {fp}") 54 | Path(fp).write_text(header + "\n\n" + text, "utf8") 55 | missing.append(str(fp)) 56 | if missing: 57 | missing_str = "\n - ".join(missing) 58 | raise AssertionError( 59 | f"Following files are/were missing standard header (see other files):\n - {missing_str}" 60 | ) 61 | -------------------------------------------------------------------------------- /exca/test_submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import pydantic 12 | import pytest 13 | 14 | from .task import SubmitInfra 15 | 16 | logger = logging.getLogger(__name__) 17 | logging.getLogger("exca").setLevel(logging.DEBUG) 18 | 19 | 20 | class Whatever(pydantic.BaseModel): 21 | infra: SubmitInfra = SubmitInfra(version="1") 22 | param: int = 12 23 | # uid internals: 24 | 25 | @infra.apply 26 | def process(self, coeff: float = 1) -> float: 27 | return np.random.rand() * coeff + self.param 28 | 29 | 30 | def test_submit_infra_nofolder() -> None: 31 | whatever = Whatever(param=13) 32 | assert 13 < whatever.process() < 14 33 | with pytest.raises(ValueError): 34 | _ = Whatever(param=13, infra={"cluster": "debug"}) # type: ignore 35 | 36 | 37 | def test_submit_infra(tmp_path: Path) -> None: 38 | whatever = Whatever(param=15, infra={"folder": tmp_path, "cluster": "debug"}) # type: ignore 39 | outs = [] 40 | outs.append(whatever.process(coeff=5)) 41 | outs.append(whatever.process(coeff=5)) 42 | outs.append(whatever.infra.submit(coeff=5).result()) 43 | for out in outs: 44 | assert 15 < out < 20 45 | assert outs[0] != outs[1] 46 | assert outs[1] != outs[2] 47 | 48 | 49 | def test_submit_infra_array(tmp_path: Path) -> None: 50 | whatever = Whatever(param=15, infra={"folder": tmp_path, "cluster": "debug"}) # type: ignore 51 | with pytest.raises(AttributeError): # must use submit and not process directly 52 | with whatever.infra.batch(): 53 | whatever.process(coeff=5) 54 | with whatever.infra.batch(): 55 | job = whatever.infra.submit(coeff=5) 56 | assert 15 < job.result() < 20 57 | 58 | 59 | class WhateverStatic(pydantic.BaseModel): 60 | infra: SubmitInfra = SubmitInfra(version="1") 61 | param: int = 12 62 | # uid internals: 63 | 64 | @infra.apply 65 | @staticmethod 66 | def process(coeff: float = 1) -> float: 67 | return np.random.rand() * coeff 68 | 69 | 70 | def test_submit_infra_array_static(tmp_path: Path) -> None: 71 | whatever = WhateverStatic(param=13) 72 | assert 0 < whatever.process(5) < 5 73 | whatever = WhateverStatic(param=15, infra={"folder": tmp_path, "cluster": "debug"}) # type: ignore 74 | assert 0 < whatever.process(5) < 5 75 | -------------------------------------------------------------------------------- /exca/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import collections 8 | import datetime 9 | import os 10 | import typing as tp 11 | from pathlib import Path 12 | 13 | import pydantic 14 | import pytest 15 | 16 | from . 
import utils 17 | from .confdict import ConfDict 18 | from .utils import to_dict 19 | 20 | 21 | class C(pydantic.BaseModel): 22 | model_config = pydantic.ConfigDict(extra="forbid") 23 | param: int = 12 24 | _exclude_from_cls_uid = (".",) 25 | 26 | 27 | class A(pydantic.BaseModel): 28 | model_config = pydantic.ConfigDict(extra="forbid") 29 | _exclude_from_cls_uid = ("y",) 30 | x: int = 12 31 | y: str = "hello" 32 | 33 | 34 | class B(pydantic.BaseModel): 35 | model_config = pydantic.ConfigDict(extra="forbid") 36 | a1: A 37 | a2: A = A() 38 | a3: A = A(x=13) 39 | a4: int = 12 40 | c: C = C() 41 | 42 | @classmethod 43 | def _exclude_from_cls_uid(cls) -> tp.List[str]: 44 | return ["a4"] 45 | 46 | 47 | def test_to_dict_full() -> None: 48 | d = to_dict(B(a1={"y": "world"})) # type: ignore 49 | out = ConfDict(d).to_yaml() 50 | expected = """a1: 51 | x: 12 52 | y: world 53 | a2: 54 | x: 12 55 | y: hello 56 | a3: 57 | x: 13 58 | y: hello 59 | a4: 12 60 | c.param: 12 61 | """ 62 | assert out == expected 63 | 64 | 65 | def test_to_dict_nondefault() -> None: 66 | b = B(a1={}, a2={"y": "world"}, a4=13, c={"param": 13}) # type: ignore 67 | d = to_dict(b, exclude_defaults=True) 68 | out = ConfDict(d).to_yaml() 69 | expected = """a1: {} 70 | a2.y: world 71 | a4: 13 72 | c.param: 13 73 | """ 74 | assert out == expected 75 | 76 | 77 | def test_to_dict_uid() -> None: 78 | b = B(a1={}, a2={"y": "world"}, a4=13, c={"param": 13}) # type: ignore 79 | d = to_dict(b, uid=True, exclude_defaults=True) 80 | out = ConfDict(d).to_yaml() 81 | print(out) 82 | expected = "a1: {}\n" 83 | assert out == expected 84 | 85 | 86 | class D2(pydantic.BaseModel): 87 | model_config = pydantic.ConfigDict(extra="forbid") 88 | uid: tp.Literal["D2"] = "D2" 89 | 90 | 91 | class D1(pydantic.BaseModel): 92 | model_config = pydantic.ConfigDict(extra="forbid") 93 | uid: tp.Literal["D1"] = "D1" 94 | anything: int = 12 95 | sub: D2 = D2() 96 | 97 | 98 | class Discrim(pydantic.BaseModel): 99 | model_config = pydantic.ConfigDict(extra="forbid") 100 | inst: D1 | D2 = pydantic.Field(..., discriminator="uid") 101 | something_else: tp.List[str] | int 102 | seq: tp.List[tp.List[tp.Annotated[D1 | D2, pydantic.Field(discriminator="uid")]]] 103 | stuff: tp.List[D1] = [] 104 | 105 | 106 | def test_missing_discriminator() -> None: 107 | class DiscrimD(pydantic.BaseModel): 108 | model_config = pydantic.ConfigDict(extra="forbid") 109 | instd: D1 | D2 110 | 111 | _ = DiscrimD(instd={"uid": "D1"}) # type: ignore 112 | 113 | 114 | def test_discriminators(caplog: tp.Any) -> None: 115 | d = Discrim( 116 | inst={"uid": "D2"}, # type: ignore 117 | something_else=12, 118 | seq=[[{"uid": "D2"}, {"uid": "D1"}]], # type: ignore 119 | ) 120 | expected = """inst.uid: D2 121 | seq: 122 | - - uid: D2 123 | - anything: 12 124 | sub.uid: D2 125 | uid: D1 126 | something_else: 12 127 | stuff: [] 128 | """ 129 | # check uid of subinstance (should not have discriminator) 130 | sub_out = ConfDict.from_model(d.inst, exclude_defaults=True) 131 | assert not sub_out 132 | # check uid of instance (should have discriminators) 133 | out = ConfDict.from_model(d).to_yaml() 134 | assert out == expected 135 | expected = """inst.uid: D2 136 | seq: 137 | - - uid: D2 138 | - uid: D1 139 | something_else: 12 140 | """ 141 | out = ConfDict.from_model(d, exclude_defaults=True).to_yaml() 142 | assert not caplog.records 143 | assert out == expected 144 | # check uid of subinstance again (should not have discriminators) 145 | sub_out = ConfDict.from_model(d.inst, exclude_defaults=True) 146 | assert 
not sub_out 147 | # CHECK AGAIN THE FULL STUFF! 148 | out = ConfDict.from_model(d, exclude_defaults=True).to_yaml() 149 | assert out == expected 150 | 151 | 152 | def test_recursive_freeze() -> None: 153 | d = Discrim( 154 | inst={"uid": "D2"}, # type: ignore 155 | something_else=12, 156 | seq=[[{"uid": "D2"}, {"uid": "D1"}]], # type: ignore 157 | ) 158 | sub = d.seq[0][0] 159 | with pytest.raises(ValueError): 160 | # not frozen but field does not exist 161 | sub.blublu = 12 # type: ignore 162 | utils.recursive_freeze(d) 163 | if hasattr(sub, "_setattr_handler"): 164 | with pytest.raises(RuntimeError): 165 | # frozen, otherwise it would be a value error 166 | sub.blublu = 12 # type: ignore 167 | else: 168 | assert sub.model_config["frozen"] 169 | 170 | 171 | class OptDiscrim(pydantic.BaseModel): 172 | model_config = pydantic.ConfigDict(extra="forbid") 173 | val: tp.Annotated[D1 | D2, pydantic.Field(discriminator="uid")] | None = None 174 | 175 | 176 | def test_optional_discriminator(caplog: tp.Any) -> None: 177 | d = OptDiscrim(val={"uid": "D2"}) # type: ignore 178 | out = ConfDict.from_model(d, exclude_defaults=True).to_yaml() 179 | assert not caplog.records 180 | expected = "val.uid: D2\n" 181 | assert out == expected 182 | 183 | 184 | @pytest.mark.parametrize("replace", (True, False)) 185 | @pytest.mark.parametrize("existing_content", [None, "blublu"]) 186 | def test_temporary_save_path( 187 | tmp_path: Path, existing_content: str | None, replace: bool 188 | ) -> None: 189 | filepath = tmp_path / "save_and_move_test.txt" 190 | if existing_content: 191 | filepath.write_text(existing_content) 192 | with utils.temporary_save_path(filepath, replace=replace) as tmp: 193 | assert str(tmp).endswith(".txt") 194 | tmp.write_text("12") 195 | if existing_content: 196 | assert filepath.read_text() == existing_content 197 | expected = "12" 198 | if existing_content is not None and not replace: 199 | expected = "blublu" 200 | assert filepath.read_text() == expected 201 | 202 | 203 | def test_temporary_save_path_error() -> None: 204 | with pytest.raises(FileNotFoundError): 205 | with utils.temporary_save_path("save_and_move_test"): 206 | pass 207 | 208 | 209 | @pytest.mark.parametrize( 210 | "hint,expected", 211 | [ 212 | (None | int, []), 213 | (None | D1, [D1]), 214 | (D2 | D1, [D2, D1]), 215 | (D1, [D1]), 216 | (list[D2 | D1], [D2, D1]), 217 | ( 218 | tp.List[tp.List[tp.Annotated[D1 | D2, pydantic.Field(discriminator="uid")]]], 219 | [D1, D2], 220 | ), 221 | (tp.Annotated[D1 | D2, pydantic.Field(discriminator="uid")] | None, [D1, D2]), # type: ignore 222 | ], 223 | ) 224 | def test_pydantic_hints(hint: tp.Any, expected: tp.List[tp.Any]) -> None: 225 | assert tuple(utils._pydantic_hints(hint)) == tuple(expected) 226 | 227 | 228 | def test_environment_variable_context() -> None: 229 | name = "ENV_VAR_TEST" 230 | assert name not in os.environ 231 | with utils.environment_variables(ENV_VAR_TEST="blublu"): 232 | assert os.environ[name] == "blublu" 233 | with utils.environment_variables(ENV_VAR_TEST="blublu2"): 234 | assert os.environ[name] == "blublu2" 235 | assert os.environ[name] == "blublu" 236 | assert name not in os.environ 237 | 238 | 239 | def test_iter_string_values(): 240 | out = dict(utils._iter_string_values({"a": [12, {"b": 13, "c": "val"}]})) 241 | assert out == {"a.1.c": "val"} 242 | 243 | 244 | class MissingForbid(pydantic.BaseModel): 245 | param: int = 12 246 | 247 | 248 | class WithMissingForbid(pydantic.BaseModel): 249 | model_config = pydantic.ConfigDict(extra="forbid") 250 | missing: 
MissingForbid = MissingForbid() 251 | 252 | 253 | def test_extra_forbid() -> None: 254 | m = MissingForbid() 255 | with pytest.raises(RuntimeError): 256 | ConfDict.from_model(m, uid=True, exclude_defaults=True) 257 | w = WithMissingForbid() 258 | with pytest.raises(RuntimeError): 259 | ConfDict.from_model(w, uid=True, exclude_defaults=True) 260 | 261 | 262 | class D(pydantic.BaseModel): 263 | model_config = pydantic.ConfigDict(extra="forbid") 264 | x: int = 12 265 | 266 | 267 | class A12(pydantic.BaseModel): 268 | model_config = pydantic.ConfigDict(extra="forbid") 269 | _exclude_from_cls_uid = ("y",) 270 | name: str = "name" 271 | unneeded: str = "is default" 272 | x: int = 12 273 | y: str = "hello" 274 | 275 | 276 | class NewDefault(pydantic.BaseModel): 277 | model_config = pydantic.ConfigDict(extra="forbid") 278 | a: A12 = A12(x=13) 279 | 280 | 281 | @pytest.mark.parametrize("with_y", (False, True)) 282 | @pytest.mark.parametrize( 283 | "value,expected", 284 | [ 285 | (11, "a.x: 11"), 286 | (12, "a: {}"), 287 | (13, "{}"), 288 | ], 289 | ) 290 | def test_new_default(value: int, expected: str, with_y: bool) -> None: 291 | params: tp.Any = {"x": value} 292 | if with_y: 293 | params["y"] = "world" 294 | m = NewDefault(a=params) 295 | out = ConfDict.from_model(m, uid=True, exclude_defaults=True) 296 | assert out.to_yaml().strip() == expected 297 | m2 = NewDefault(**out) 298 | assert m2.a.x == value 299 | 300 | 301 | class NewDefaultOther(pydantic.BaseModel): 302 | model_config = pydantic.ConfigDict(extra="forbid") 303 | a: A12 = A12(x=13, y="stuff") 304 | 305 | 306 | def test_new_default_other() -> None: 307 | m = NewDefaultOther(a={"x": 13}) # type: ignore 308 | out = ConfDict.from_model(m, uid=True, exclude_defaults=True) 309 | assert out.to_yaml().strip() == "{}" 310 | 311 | 312 | class NewDefaultOther2diff(pydantic.BaseModel): 313 | model_config = pydantic.ConfigDict(extra="forbid") 314 | a: A12 = A12(x=13, unneeded="something else", y="stuff") 315 | 316 | 317 | def test_new_default_other2diff() -> None: 318 | # revert unneeded to default, so it wont show in model_dump, but we need to define x=13 319 | m = NewDefaultOther2diff(a={"x": 13, "unneeded": "is default"}) # type: ignore 320 | out = ConfDict.from_model(m, uid=True, exclude_defaults=True) 321 | assert out.to_yaml().strip() == "a.x: 13" 322 | 323 | 324 | class ActualDefaultOverride(pydantic.BaseModel): 325 | model_config = pydantic.ConfigDict(extra="forbid") 326 | a: A12 = A12(x=12) 327 | a_default: A12 = A12() 328 | 329 | 330 | def test_actual_default_override() -> None: 331 | m = ActualDefaultOverride(a={"x": 13}) # type: ignore 332 | out = ConfDict.from_model(m, uid=True, exclude_defaults=True) 333 | assert out.to_yaml().strip() == "a.x: 13" 334 | # 335 | m = ActualDefaultOverride(a={"x": 12, "y": "stuff"}, a_default={"x": 12, "y": "stuff"}) # type: ignore 336 | out = ConfDict.from_model(m, uid=True, exclude_defaults=True) 337 | assert out.to_yaml().strip() == "{}" 338 | 339 | 340 | class DiscrimDump(pydantic.BaseModel): 341 | model_config = pydantic.ConfigDict(extra="forbid") 342 | inst: D1 | D2 = pydantic.Field(D1(), discriminator="uid") 343 | 344 | 345 | def test_dump() -> None: 346 | dd = DiscrimDump(inst={"uid": "D1"}) # type: ignore 347 | out = ConfDict.from_model(dd, uid=True, exclude_defaults=True) 348 | assert not out 349 | dd = DiscrimDump(inst={"uid": "D2"}) # type: ignore 350 | out = ConfDict.from_model(dd, uid=True, exclude_defaults=True) 351 | assert out == {"inst": {"uid": "D2"}} 352 | 353 | 354 | D1D2 = 
tp.Annotated[D1 | D2, pydantic.Field(discriminator="uid")] 355 | 356 | 357 | class OrderedDump(pydantic.BaseModel): 358 | model_config = pydantic.ConfigDict(extra="forbid") 359 | insts: collections.OrderedDict[str, D1D2] = collections.OrderedDict() 360 | 361 | 362 | def test_ordered_dict() -> None: 363 | od = OrderedDump(insts={"blublu": {"uid": "D1"}, "stuff": {"uid": "D2"}, "blublu2": {"uid": "D1"}}) # type: ignore 364 | out = ConfDict.from_model(od, uid=True, exclude_defaults=True) 365 | # check that nothing alters the order 366 | assert isinstance(out["insts"], collections.OrderedDict) 367 | assert tuple(out["insts"].keys()) == ("blublu", "stuff", "blublu2") 368 | out["insts.blublu.anything"] = 144 369 | assert tuple(out["insts"].keys()) == ("blublu", "stuff", "blublu2") 370 | out["insts.blublu2.anything"] = 144 371 | assert tuple(out["insts"].keys()) == ("blublu", "stuff", "blublu2") 372 | assert isinstance(out["insts"], collections.OrderedDict) 373 | # keys should be ordered in name and hash: 374 | uid = "insts={blublu={uid=D1,anything=144},stuff.uid=D2,blublu2={uid=D1,anything=144}}-46863fcc" 375 | assert out.to_uid() == uid 376 | 377 | 378 | class HierarchicalCfg(pydantic.BaseModel): 379 | a: A = A() 380 | _a: A = A() 381 | c: C = C() 382 | content: tp.List["HierarchicalCfg"] = [] 383 | 384 | 385 | def test_find_models() -> None: 386 | hcfg = HierarchicalCfg(content=[{}, {}]) # type: ignore 387 | out = utils.find_models(hcfg, A) 388 | assert set(out) == { 389 | "a", 390 | "content.0.a", 391 | "content.1.a", 392 | "_a", 393 | "content.0._a", 394 | "content.1._a", 395 | } 396 | assert all(isinstance(y, A) for y in out.values()) 397 | 398 | 399 | def test_fast_unlink(tmp_path: Path) -> None: 400 | # file 401 | fp = tmp_path / "blublu.txt" 402 | fp.touch() 403 | assert fp.exists() 404 | with utils.fast_unlink(fp): 405 | pass 406 | assert not fp.exists() 407 | # folder 408 | fp = tmp_path / "blublu" 409 | fp.mkdir() 410 | (fp / "stuff.txt").touch() 411 | with utils.fast_unlink(fp): 412 | pass 413 | assert not fp.exists() 414 | 415 | 416 | class ComplexTypesConfig(pydantic.BaseModel): 417 | model_config = pydantic.ConfigDict(extra="forbid") 418 | x: pydantic.DirectoryPath = Path("/") 419 | y: datetime.timedelta = datetime.timedelta(minutes=1) 420 | z: pydantic.ImportString = ConfDict 421 | 422 | 423 | def test_complex_types() -> None: 424 | c = ComplexTypesConfig() 425 | out = ConfDict.from_model(c, uid=True, exclude_defaults=False) 426 | expected = """x: / 427 | y: PT1M 428 | z: exca.confdict.ConfDict 429 | """ 430 | assert out.to_yaml() == expected 431 | assert out.to_uid().startswith("x=-,y=PT1M,z=exca.confdict.ConfDict") 432 | -------------------------------------------------------------------------------- /exca/test_workdir.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | from pathlib import Path 10 | 11 | import pytest 12 | 13 | from . 
import MapInfra, workdir 14 | 15 | PACKAGE = MapInfra.__module__.split(".", maxsplit=1)[0] 16 | logging.getLogger(PACKAGE).setLevel(logging.DEBUG) 17 | 18 | 19 | def test_identify_bad_package() -> None: 20 | with pytest.raises(ValueError) as exc_info: 21 | workdir.identify_path("blublu12") 22 | assert "failed to import it" in str(exc_info.value) 23 | with pytest.raises(ValueError) as exc_info: 24 | workdir.identify_path("pytest") 25 | assert "not been installed from source" in str(exc_info.value) 26 | 27 | 28 | def test_identify_file(tmp_path: Path) -> None: 29 | fp = tmp_path / "blublu.txt" 30 | fp.touch() 31 | with workdir.chdir(tmp_path): 32 | out = workdir.identify_path(fp.name) 33 | assert out == fp 34 | out = workdir.identify_path(fp) 35 | assert out == fp 36 | 37 | 38 | @pytest.mark.parametrize("file_from_folder", (True, False)) 39 | def test_workdir(tmp_path: Path, file_from_folder: bool) -> None: 40 | old = tmp_path / "old" 41 | old.mkdir() 42 | new = tmp_path / "new" 43 | # add content 44 | fp = old / "blublu.txt" 45 | fp.touch() 46 | folder = old / "folder" 47 | folder.mkdir() 48 | (folder / "a_file.py").touch() 49 | sub = folder / "__pycache__" 50 | sub.mkdir() 51 | (sub / "ignore.py").touch() 52 | sub_string = folder.name 53 | if file_from_folder: 54 | sub_string = "folder/a_file.py" 55 | with workdir.chdir(old): 56 | wdir = workdir.WorkDir(copied=[fp.name, sub_string]) 57 | wdir.folder = new 58 | with wdir.activate(): 59 | assert Path(os.getcwd()).name == "new" 60 | assert Path("folder/a_file.py").exists() 61 | assert not Path("folder/__pycache__").exists() 62 | 63 | 64 | def test_workdir_absolute(tmp_path: Path) -> None: 65 | folder = tmp_path / "folder" 66 | folder.mkdir() 67 | (folder / "a_file.py").touch() 68 | wdir = workdir.WorkDir(folder=tmp_path / "new", copied=[folder]) 69 | with wdir.activate(): 70 | assert Path(os.getcwd()).name == "new" 71 | assert Path("folder/a_file.py").exists() 72 | 73 | 74 | def test_workdir_clean_repo(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: 75 | # raises if not clean: 76 | wd = workdir.WorkDir(folder=tmp_path, log_commit=True, copied=[Path(__file__).parent]) 77 | assert len(caplog.records) == 1 78 | assert "Current git hash" in caplog.records[0].message 79 | repo = "exca" if "exca" in MapInfra.__module__ else "brainai" 80 | assert repo in wd._commits 81 | with wd.activate(): 82 | assert Path("git-hashes.log").read_text().startswith(repo) 83 | 84 | 85 | def test_workdir_editable(tmp_path: Path) -> None: 86 | try: 87 | wdir = workdir.WorkDir(copied=["autoconf"]) 88 | except: 89 | pytest.skip("autoconf not installed in editable mode") 90 | folder = tmp_path / "code" 91 | wdir.folder = folder 92 | with wdir.activate(): 93 | expected = folder / "autoconf/__init__.py" 94 | assert expected.exists() 95 | # pylint: disable=import-outside-toplevel 96 | import autoconf # type: ignore 97 | 98 | assert autoconf.__file__ == str(expected) 99 | 100 | 101 | def test_ignore(tmp_path: Path) -> None: 102 | names = ["stuff.py", "something.py", "data.csv", "folder"] 103 | ig = workdir.Ignore(includes=["*.py"], excludes=["stuff.py"]) 104 | out = ig(tmp_path, names) 105 | assert out == {"stuff.py", "data.csv", "folder"} 106 | # now with a folder 107 | (tmp_path / "folder").mkdir() 108 | out = ig(tmp_path, names) 109 | assert out == {"stuff.py", "data.csv"} 110 | # now multiple includes 111 | ig = workdir.Ignore(includes=["*.py", "*.csv"], excludes=["stuff.py"]) 112 | out = ig(tmp_path, names) 113 | assert out == {"stuff.py"} 114 | # now with a 
path 115 | ig = workdir.Ignore(excludes=["stuff.py"]) 116 | out = ig("somewhere", names) 117 | assert out == {"stuff.py"} 118 | -------------------------------------------------------------------------------- /exca/workdir.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import contextlib 8 | import fnmatch 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import subprocess 14 | import sys 15 | import typing as tp 16 | from importlib import metadata 17 | from pathlib import Path 18 | 19 | import pydantic 20 | import yaml as _yaml 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | @contextlib.contextmanager 26 | def chdir(folder: Path | str) -> tp.Iterator[None]: 27 | """Temporarily change the working directory and add 28 | it to sys.path 29 | 30 | Parameters 31 | ---------- 32 | folder: str/Path 33 | new working directory 34 | """ 35 | cwd = os.getcwd() 36 | folder = str(Path(folder).absolute()) 37 | to_be_removed = False 38 | try: 39 | os.chdir(folder) 40 | if folder not in sys.path: 41 | to_be_removed = True 42 | sys.path.insert(0, folder) 43 | logger.warning("Moved to working directory: %s", folder) 44 | yield 45 | finally: 46 | os.chdir(cwd) 47 | if to_be_removed and folder in sys.path: 48 | sys.path.remove(folder) 49 | logger.debug("Moved back to working directory: %s", cwd) 50 | 51 | 52 | class WorkDir(pydantic.BaseModel): 53 | """Custom working directory configuration 54 | 55 | Parameters 56 | ---------- 57 | copied: Sequence[str] 58 | list/tuple of names of files, or folders, or packages installed in editable mode 59 | to copy to the new working directory folder. 60 | Relative paths will be copied to the relative equivalent in the new folder, while for 61 | absolute paths, the folder/file pointed to by the path will be copied directly to the new 62 | folder. 63 | folder: Path/str 64 | folder to use as the working directory; 65 | if not specified, infra will create one automatically :code:`/code/-/`. 66 | The folder is logged so you should be able to see what happened in your stderr/stdout. 67 | This parameter can be used in particular to store the code in a specific location 68 | or reuse the workdir from a previous run. 69 | includes: sequence of str 70 | file name patterns that must be included (recursively) 71 | folders are always included except if explicitly excluded 72 | eg: :code:`["*.py"]` to include only python files 73 | excludes: sequence of str 74 | file/folder name patterns that must be excluded 75 | log_commit: bool 76 | if True, raises if the current working directory is in a git repository 77 | with uncommitted changes, and logs the commit otherwise 78 | 79 | Notes 80 | ----- 81 | - Since python privileges the current working directory over installed packages, 82 | the copied packages should be the ones running in the job 83 | (be careful, there can be a few gotchas, eg: for the debug cluster or with no cluster, 84 | the imports cannot be reloaded so the current working directory will still be used, 85 | but that should not make a difference in these cases) 86 | - The change of working directory (and possibly the copy) only happens when the 87 | infra is called for submitting the decorated function. Depending on your code, 88 | this may not be at the very beginning of your execution.
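Example
-------
A minimal usage sketch (the package name and folder below are illustrative only)::

    wdir = WorkDir(copied=["my_package"], folder="/tmp/code_snapshot")
    with wdir.activate():
        ...  # cwd is now the snapshot folder, with a copy of my_package inside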
89 | """ 90 | 91 | copied: tp.Sequence[str | Path] = [] 92 | folder: str | Path | None = None 93 | log_commit: bool = False 94 | # include and exclude names (use "*.py" for only python) 95 | includes: tp.Sequence[str] = () 96 | excludes: tp.Sequence[str] = ("__pycache__", ".git") 97 | 98 | # internals 99 | _paths: tp.List[Path] 100 | _commits: tp.Dict[str, str] = {} 101 | model_config = pydantic.ConfigDict(extra="forbid") 102 | 103 | def model_post_init(self, log__: tp.Any) -> None: 104 | super().model_post_init(log__) 105 | self._paths = [identify_path(name) for name in self.copied] 106 | if not self._paths: 107 | msg = "Workdir provided but no paths to copy (specify 'workdir.copied')" 108 | raise RuntimeError(msg) 109 | if self.folder is not None: 110 | if not Path(self.folder).absolute().parent.exists(): 111 | raise ValueError(f"Parent directory of {self.folder} must exist") 112 | if self.log_commit: 113 | for p in self._paths: 114 | # get name 115 | cmd = ["git", "rev-parse", "--show-toplevel"] 116 | try: 117 | folder = subprocess.check_output(cmd, shell=False, cwd=p) 118 | except subprocess.SubprocessError: # not a git repository 119 | continue 120 | name = Path(folder.decode("utf8").strip()).name 121 | if name in self._commits: 122 | continue 123 | # check commited 124 | subprocess.check_call(["git", "diff", "--exit-code"], shell=False, cwd=p) 125 | # get git hash 126 | cmd = ["git", "rev-parse", "--short", "HEAD"] 127 | githash = subprocess.check_output(cmd, shell=False, cwd=p).decode("utf8") 128 | githash = githash.strip() 129 | logger.info("Current git hash for %s is %s", name, githash) 130 | self._commits[name] = githash 131 | 132 | @contextlib.contextmanager 133 | def activate(self) -> tp.Iterator[None]: 134 | if self.folder is None: 135 | raise RuntimeError("folder field must be filled before activation") 136 | folder = Path(self.folder) 137 | folder.mkdir(exist_ok=True) 138 | ignore = Ignore(includes=self.includes, excludes=self.excludes) 139 | for name, path in zip(self.copied, self._paths): 140 | if Path(name).is_absolute(): 141 | # for local folder we keep the structures, for absolute we copy the last item 142 | name = Path(name).name 143 | out = folder / name 144 | if not out.exists(): 145 | if path.is_dir(): 146 | shutil.copytree(path, out, ignore=ignore) 147 | else: 148 | out.parent.mkdir(exist_ok=True, parents=True) 149 | shutil.copyfile(path, out, follow_symlinks=True) 150 | logger.info("Copied %s to %s", path, out) 151 | if self._commits: 152 | string: str = _yaml.safe_dump(self._commits) 153 | fp = folder / "git-hashes.log" 154 | logger.info("Git hashes are dumped to %s", fp) 155 | fp.write_text(string, encoding="utf8") 156 | with chdir(folder): 157 | yield 158 | 159 | 160 | def identify_path(name: str | Path) -> Path: 161 | """Returns the absolute Path corresponding to the name. 
162 | The name must either represent: 163 | - a local folder/file in the current working directory 164 | - a folder/file with an absolute path 165 | - a folder in the PYTHONPATH 166 | - a package installed in editable mode 167 | """ 168 | # local files or folders get precedence 169 | folders = ["."] + os.environ.get("PYTHONPATH", "").split(os.pathsep) 170 | for folder in folders: 171 | fp = Path(folder) / name 172 | if fp.exists(): 173 | return fp.absolute() 174 | # otherwise check for editable installations 175 | try: 176 | pdistrib = metadata.Distribution.from_name(str(name)) 177 | except Exception as e: # pylint: disable=broad-except 178 | raise ValueError( 179 | f"No folder/file named {name} in {os.getcwd()}, " 180 | "and failed to import it as well" 181 | ) from e 182 | direct_url_json = pdistrib.read_text("direct_url.json") 183 | if direct_url_json is None: # folder 184 | raise ValueError(f"Package {name} has not been installed from source") 185 | direct_url = json.loads(direct_url_json) 186 | pkg_is_editable = direct_url.get("dir_info", {}).get("editable", False) 187 | if not pkg_is_editable: 188 | raise ValueError(f"Package {name} is not editable") 189 | tag = "file://" 190 | url = direct_url["url"] 191 | if not url.startswith(tag): 192 | raise ValueError(f"Package url {url} for {name} is not local") 193 | fp = Path(url[len(tag) :]) / name 194 | if not fp.exists(): 195 | raise ValueError(f"Expected to copy {fp} but there's nothing there") 196 | return fp 197 | 198 | 199 | class Ignore: 200 | """Include/Exclude name patterns for shutil.copytree""" 201 | 202 | def __init__( 203 | self, includes: tp.Sequence[str] = (), excludes: tp.Sequence[str] = () 204 | ) -> None: 205 | self.includes = list(includes) 206 | self.excludes = list(excludes) 207 | 208 | def __call__(self, path: str | Path, names: tp.List[str]) -> tp.Set[str]: 209 | if not self.includes: 210 | included = set(names) 211 | else: 212 | included = set() 213 | for include in self.includes: 214 | included |= set(fnmatch.filter(set(names), include)) 215 | missing = set(names) - included 216 | path = Path(path) 217 | for excluded in missing: 218 | # always include subfolders except if explicitly excluded below 219 | if (path / excluded).is_dir(): 220 | included.add(excluded) 221 | for exclude in self.excludes: 222 | included -= set(fnmatch.filter(included, exclude)) 223 | return set(names) - included 224 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "exca" 3 | readme = "README.md" 4 | authors = [{name = "Meta FAIR"}] 5 | requires-python = ">=3.10" 6 | version = "0.4.5" 7 | description = "Execution and caching tool for python" 8 | 9 | dependencies = [ 10 | "numpy>=1.19", 11 | "pyyaml>=6.0", 12 | "pydantic>=2.5.0", 13 | "submitit>=1.5.1", 14 | ] 15 | 16 | [project.urls] 17 | Source = "https://github.com/facebookresearch/exca" 18 | Tracker = "https://github.com/facebookresearch/exca/issues" 19 | 20 | [project.optional-dependencies] 21 | dev = [ 22 | # optional features 23 | "pandas>=2.2.2", 24 | "torch>=2.0.1", 25 | "mne>=1.4.0", 26 | "pybv>=0.7.6", 27 | "nibabel>=5.1.0", 28 | "pyarrow>=17.0.0", 29 | # Test 30 | "pytest>=7.4.0", 31 | "pytest-markdown-docs>=0.5.0", 32 | "psutil>=6.1.1", 33 | # Format 34 | "tqdm>=4.65.0", 35 | "black==24.3.0", 36 | "isort==5.12.0", 37 | "pre-commit>=3.0.0", 38 | # Linters 39 | "mypy>=1.11.0", 40 | "pylint>=2.13.9", 41 | "flake8", 42 | # typing
stubs 43 | "pandas-stubs", 44 | "types-PyYAML", 45 | "types-setuptools", 46 | "types-tqdm", 47 | "types-psutil", 48 | # documentation 49 | "sphinx>=7.4.7", 50 | # "sphinx_rtd_theme>=2.0.0", 51 | # "recommonmark>=0.7.1", 52 | # "autodocsumm>=0.2.12", 53 | "myst-parser>=3.0.1", 54 | ] 55 | 56 | [tool.black] 57 | line-length = 90 58 | exclude = ''' 59 | /( 60 | | \.git 61 | | \.mypy_cache 62 | )/ 63 | ''' 64 | force-exclude = ''' 65 | /( 66 | scratch 67 | )\ 68 | ''' 69 | 70 | [tool.setuptools.packages.find] 71 | where = ["."] # list of folders that contain the packages (["."] by default) 72 | include = ["exca*"] # package names should match these glob patterns (["*"] by default) 73 | exclude = [] # exclude packages matching these glob patterns (empty by default) 74 | namespaces = false # to disable scanning PEP 420 namespaces (true by default) 75 | 76 | [tool.setuptools.package-data] 77 | # not sufficient in editable mode: https://github.com/python/mypy/issues/13392 78 | # still need to install with: pip install --config-settings editable_mode=strict -e . 79 | "exca" = ["py.typed", ".pyi"] 80 | 81 | [tool.isort] 82 | profile = "black" 83 | line_length = 90 84 | skip_gitignore = true 85 | 86 | [tool.pylint] 87 | [tool.pylint."MESSAGES CONTROL"] 88 | # disabled messages 89 | # * no-member has a lot of false positive, mypy does it better 90 | disable = """ 91 | broad-except, 92 | fixme, 93 | invalid-name, 94 | logging-fstring-interpolation, 95 | missing-docstring, 96 | no-else-return, 97 | no-member, 98 | protected-access, 99 | too-few-public-methods, 100 | too-many-locals, 101 | too-many-statements, 102 | too-many-return-statements, 103 | too-many-branches, 104 | useless-import-alias, 105 | unspecified-encoding, 106 | use-dict-literal, 107 | useless-import-alias, 108 | import-outside-toplevel 109 | """ 110 | [tool.pylint.DESIGN] 111 | max-args = 6 112 | 113 | [tool.pylint.FORMAT] 114 | max-line-length = "140" 115 | 116 | [tool.pylint.SIMILARITIES] 117 | ignore-imports = "yes" 118 | 119 | [tool.mypy] 120 | plugins = ['pydantic.mypy'] 121 | show_error_codes = true 122 | 123 | [[tool.mypy.overrides]] 124 | module = ['pytest', 'setuptools', 'cloudpickle', 'mne', 'mne.*', 'nibabel', 'neuralset', 'pyarrow', 'pybv'] 125 | ignore_missing_imports = true 126 | [[tool.mypy.overrides]] 127 | # some packages we do not install 128 | module = ['exca.dumperloader'] 129 | disable_error_code = ['import-not-found', 'valid-type'] 130 | 131 | [tool.pydantic-mypy] 132 | init_forbid_extra = true 133 | init_typed = true 134 | warn_required_dynamic_aliases = true 135 | --------------------------------------------------------------------------------
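For reference, a minimal sketch of the MapInfra pattern exercised by the test files above; the Demo model, its parameter, and the cache folder are illustrative only and not part of the package:

import tempfile
import typing as tp
from pathlib import Path

import numpy as np
import pydantic

from exca import MapInfra


class Demo(pydantic.BaseModel):
    infra: MapInfra = MapInfra(version="1")
    param: int = 12
    model_config = pydantic.ConfigDict(extra="forbid")

    @infra.apply(item_uid=str)  # item_uid turns each item into its cache key
    def process(self, items: tp.Sequence[int]) -> tp.Iterator[np.ndarray]:
        for item in items:
            yield np.random.rand(item, self.param)


folder = Path(tempfile.mkdtemp())  # any existing folder can serve as the cache root
demo = Demo(param=3, infra={"folder": folder})  # type: ignore
outputs = list(demo.process([1, 2, 3]))  # computed once, then read back from the cache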