├── .flake8
├── .git-blame-ignore-revs
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yaml
│   │   ├── config.yml
│   │   └── feature_request.yaml
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── publish.yml
│       ├── python.yml
│       └── towncrier-changelog.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── CONTRIBUTORS.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── Substra-logo-colour.svg
├── Substra-logo-white.svg
├── bin
│   ├── generate_sdk_documentation.py
│   └── generate_sdk_schemas_documentation.py
├── changes
│   └── .gitkeep
├── docs
│   ├── README.md
│   └── technical_documentation.md
├── pyproject.toml
├── references
│   ├── sdk.md
│   ├── sdk_models.md
│   └── sdk_schemas.md
├── substra
│   ├── __init__.py
│   ├── __version__.py
│   ├── config.py
│   └── sdk
│       ├── __init__.py
│       ├── archive
│       │   ├── __init__.py
│       │   ├── safezip.py
│       │   └── tarsafe.py
│       ├── backends
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── local
│       │   │   ├── __init__.py
│       │   │   ├── backend.py
│       │   │   ├── compute
│       │   │   │   ├── __init__.py
│       │   │   │   ├── spawner
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── base.py
│       │   │   │   │   ├── docker.py
│       │   │   │   │   └── subprocess.py
│       │   │   │   └── worker.py
│       │   │   ├── dal.py
│       │   │   ├── db.py
│       │   │   └── models.py
│       │   └── remote
│       │       ├── __init__.py
│       │       ├── backend.py
│       │       ├── request_formatter.py
│       │       └── rest_client.py
│       ├── client.py
│       ├── compute_plan.py
│       ├── exceptions.py
│       ├── fs.py
│       ├── graph.py
│       ├── hasher.py
│       ├── models.py
│       ├── schemas.py
│       └── utils.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── data_factory.py
    ├── datastore.py
    ├── fl_interface.py
    ├── mocked_requests.py
    ├── sdk
    │   ├── __init__.py
    │   ├── data
    │   │   ├── symlink.zip
    │   │   └── traversal.zip
    │   ├── local
    │   │   ├── __init__.py
    │   │   ├── conftest.py
    │   │   └── test_debug.py
    │   ├── test_add.py
    │   ├── test_archive.py
    │   ├── test_cancel.py
    │   ├── test_client.py
    │   ├── test_describe.py
    │   ├── test_download.py
    │   ├── test_get.py
    │   ├── test_graph.py
    │   ├── test_list.py
    │   ├── test_rest_client.py
    │   ├── test_schemas.py
    │   ├── test_subprocess.py
    │   ├── test_update.py
    │   └── test_wait.py
    ├── test_request_formatter.py
    ├── test_utils.py
    └── utils.py

/.flake8:
--------------------------------------------------------------------------------
 1 | [flake8]
 2 | max-line-length = 120
 3 | max-complexity = 10
 4 | extend-ignore = E203, W503, N802, N803, N806
 5 | # W503 is incompatible with Black, see https://github.com/psf/black/pull/36
 6 | # E203 must be disabled for flake8 to work with Black.
 7 | # See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1
 8 | # N802, N803 and N806 prevent us from using upper case in variable names, function names and arguments.
 9 | per-file-ignores =
10 |     __init__.py:F401,
11 | 
12 | exclude =
13 |     .git
14 |     .github
15 |     .dvc
16 |     __pycache__
17 |     .venv
18 |     .mypy_cache
19 |     .pytest_cache
20 |     hubconf.py
21 |     **local-worker
22 | 
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
 1 | 3b9b9d3d24cc94bdf34ce39cf0273f410188cbd2
 2 | 
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
 1 | * @Substra/code-owners
 2 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
 1 | name: Bug Report
 2 | description: Report a bug or performance issue
 3 | title: "BUG: "
 4 | labels: [Bug]
 5 | 
 6 | body:
 7 |   - type: markdown
 8 |     attributes:
 9 |       value: |
10 |         Thanks for taking the time to fill out this bug report!
11 |   - type: textarea
12 |     id: context
13 |     attributes:
14 |       label: What are you trying to do?
15 |       description: >
16 |         Please provide some context on what you are trying to achieve.
17 |       placeholder:
18 |     validations:
19 |       required: true
20 |   - type: textarea
21 |     id: issue-description
22 |     attributes:
23 |       label: Issue Description (what is happening?)
24 |       description: >
25 |         Please provide a description of the issue.
26 |     validations:
27 |       required: true
28 |   - type: textarea
29 |     id: expected-behavior
30 |     attributes:
31 |       label: Expected Behavior (what should happen?)
32 |       description: >
33 |         Please describe or show a code example of the expected behavior.
34 |     validations:
35 |       required: true
36 |   - type: textarea
37 |     id: example
38 |     attributes:
39 |       label: Reproducible Example
40 |       description: >
41 |         If possible, provide a reproducible example.
42 |       render: python
43 | 
44 |   - type: textarea
45 |     id: os-version
46 |     attributes:
47 |       label: Operating system
48 |       description: >
49 |         Which operating system are you using? (Provide the version number)
50 |     validations:
51 |       required: true
52 |   - type: textarea
53 |     id: python-version
54 |     attributes:
55 |       label: Python version
56 |       description: >
57 |         Which Python version are you using?
58 |       placeholder: >
59 |         python --version
60 |     validations:
61 |       required: true
62 |   - type: textarea
63 |     id: substra-version
64 |     attributes:
65 |       label: Installed Substra versions
66 |       description: >
67 |         Which version of `substrafl` / `substra` / `substra-tools` are you using?
68 |         You can check if they are compatible in the [compatibility table](https://docs.substra.org/en/stable/additional/release.html#compatibility-table).
69 |       placeholder: >
70 |         pip freeze | grep substra
71 |       render: python
72 |     validations:
73 |       required: true
74 |   - type: textarea
75 |     id: dependencies-version
76 |     attributes:
77 |       label: Installed versions of dependencies
78 |       description: >
79 |         Please provide versions of dependencies that might be relevant to your issue (e.g. `helm` and `skaffold` versions for a deployment issue, `numpy` and `pytorch` for an algorithmic issue).
80 | 
81 | 
82 |   - type: textarea
83 |     id: logs
84 |     attributes:
85 |       label: Logs / Stacktrace
86 |       description: >
87 |         Please copy-paste here any log and/or stacktrace that might be relevant. Remove confidential and personal information if necessary.
88 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
 1 | blank_issues_enabled: true
 2 | contact_links:
 3 |   - name: Ask a question
 4 |     url: https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA
 5 |     about: Don't hesitate to join the Substra community on Slack to ask all your questions!
 6 |   - name: User Documentation Improvements
 7 |     url: https://github.com/Substra/substra-documentation/issues
 8 |     about: For issues related to the User Documentation, please open an issue on the substra-documentation repository
 9 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yaml:
--------------------------------------------------------------------------------
 1 | name: Feature Request
 2 | description: Suggest an idea for Substra
 3 | title: "ENH: "
 4 | labels: [Enhancement]
 5 | 
 6 | body:
 7 |   - type: checkboxes
 8 |     id: checks
 9 |     attributes:
10 |       label: Feature Type
11 |       description: Please check what type of feature request you would like to propose.
12 |       options:
13 |         - label: >
14 |             Adding new functionality to Substra
15 |         - label: >
16 |             Changing existing functionality in Substra
17 |         - label: >
18 |             Removing existing functionality in Substra
19 |   - type: textarea
20 |     id: description
21 |     attributes:
22 |       label: Problem Description
23 |       description: >
24 |         Please describe what problem the feature would solve, e.g. "I wish I could use Substra to ..."
25 |     validations:
26 |       required: true
27 |   - type: textarea
28 |     id: feature
29 |     attributes:
30 |       label: Feature Description
31 |       description: >
32 |         Please describe what the new feature would do.
33 |     validations:
34 |       required: true
35 |   - type: textarea
36 |     id: alternative
37 |     attributes:
38 |       label: Alternative Solutions
39 |       description: >
40 |         Please describe any alternative solution (existing functionality, 3rd party package, etc.)
41 |         that would satisfy the feature request.
42 |   - type: textarea
43 |     id: context
44 |     attributes:
45 |       label: Additional Context
46 |       description: >
47 |         Please provide any relevant GitHub issues, code examples or references that help describe and support
48 |         the feature request.
49 | 
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
 1 | ## Related issue
 2 | 
 3 | `#` followed by the number of the issue
 4 | 
 5 | ## Summary
 6 | 
 7 | ## Notes
 8 | 
 9 | ## Please check if the PR fulfills these requirements
10 | 
11 | - [ ] If necessary, the [changelog](https://github.com/Substra/substra/blob/main/CHANGELOG.md) has been updated
12 | - [ ] Tests for the changes have been added (for bug fixes / features)
13 | - [ ] Docs have been added / updated (for bug fixes / features)
14 | - [ ] The commit message follows the [conventional commit](https://www.conventionalcommits.org/en/v1.0.0/) specification
15 | - For any breaking changes, companion PRs have been opened on the following repositories:
16 |   - [ ] [substra-tests](https://github.com/Substra/substra-tests)
17 |   - [ ] [substrafl](https://github.com/Substra/substrafl)
18 |   - [ ] [substra-documentation](https://github.com/Substra/substra-documentation)
19 | 
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
 1 | name: Publish
 2 | 
 3 | on:
 4 |   release:
 5 |     types: [published]
 6 | 
 7 | jobs:
 8 |   publish:
 9 |     runs-on: ubuntu-latest
10 |     steps:
11 |       - uses: actions/checkout@v4
12 |       - uses: actions/setup-python@v5
13 |         with:
14 |           python-version: 3.11
15 |       - name: Install Hatch
16 |         run: pipx install hatch
17 |       - name: Build dist
18 |         run: hatch build
19 |       - name: Publish
20 |         run: hatch publish -u __token__ -a ${{ secrets.PYPI_API_TOKEN }}
21 | 
22 | 
--------------------------------------------------------------------------------
/.github/workflows/python.yml:
--------------------------------------------------------------------------------
 1 | name: Python
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - "main"
 7 |   pull_request:
 8 | 
 9 | jobs:
10 |   lint:
11 |     name: Lint and documentation
12 |     runs-on: ubuntu-latest
13 |     steps:
14 |       - uses: actions/checkout@v4
15 |       - name: Set up python
16 |         uses: actions/setup-python@v5
17 |         with:
18 |           python-version: "3.12"
19 |       - name: Install tools
20 |         run: pip install flake8 black isort wheel docstring-parser
21 |       - name: Lint
22 |         run: |
23 |           black --check substra
24 |           isort --check substra
25 |           flake8 substra
26 |       - name: Install substra
27 |         run: pip install -e .
28 | - name: Generate and validate SDK documentation 29 | run: | 30 | python bin/generate_sdk_documentation.py --output-path='references/sdk.md' 31 | python bin/generate_sdk_schemas_documentation.py --output-path references/sdk_schemas.md 32 | python bin/generate_sdk_schemas_documentation.py --models --output-path='references/sdk_models.md' 33 | - name: Documentation artifacts 34 | uses: actions/upload-artifact@v4 35 | if: always() 36 | with: 37 | retention-days: 1 38 | name: references 39 | path: references/* 40 | tests: 41 | runs-on: ubuntu-20.04 42 | env: 43 | XDG_RUNTIME_DIR: /home/runner/.docker/run 44 | DOCKER_HOST: unix:///home/runner/.docker/run/docker.sock 45 | strategy: 46 | fail-fast: false 47 | matrix: 48 | python-version: ["3.10", "3.11", "3.12"] 49 | name: Tests on Python ${{ matrix.python-version }} 50 | steps: 51 | - name: Set up python 52 | uses: actions/setup-python@v5 53 | with: 54 | python-version: ${{ matrix.python-version }} 55 | architecture: x64 56 | - name: Install Docker rootless 57 | run: | 58 | sudo systemctl disable --now docker.service 59 | export FORCE_ROOTLESS_INSTALL=1 60 | curl -fsSL https://get.docker.com/rootless | sh 61 | - name: Configure docker 62 | run: | 63 | export PATH=/home/runner/bin:$PATH 64 | /home/runner/bin/dockerd-rootless.sh & # Start Docker rootless in the background 65 | - name: Cloning substra 66 | uses: actions/checkout@v4 67 | with: 68 | path: substra 69 | - name: Cloning substratools 70 | uses: actions/checkout@v4 71 | with: 72 | repository: Substra/substra-tools 73 | path: substratools 74 | ref: main 75 | - name: Install substra and substratools 76 | run: | 77 | pip install --no-cache-dir -e substratools 78 | pip install --no-cache-dir -e 'substra[dev]' 79 | 80 | - name: Test 81 | run: | 82 | export PATH=/home/runner/bin:$PATH 83 | cd substra && make test 84 | -------------------------------------------------------------------------------- /.github/workflows/towncrier-changelog.yml: -------------------------------------------------------------------------------- 1 | name: Towncrier changelog 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | app_version: 7 | type: string 8 | description: 'The version of the app' 9 | required: true 10 | branch: 11 | type: string 12 | description: 'The branch to update' 13 | required: true 14 | 15 | jobs: 16 | test-generate-publish: 17 | uses: substra/substra-gha-workflows/.github/workflows/towncrier-changelog.yml@main 18 | secrets: inherit 19 | with: 20 | app_version: ${{ inputs.app_version }} 21 | repo: substra 22 | branch: ${{ inputs.branch }} 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized 2 | __pycache__/ 3 | 4 | # Distribution / packaging 5 | build/ 6 | dist/ 7 | sdist/ 8 | .eggs/* 9 | *.egg-info 10 | 11 | # Unit test / coverage reports 12 | .cache 13 | .coverage* 14 | .pytest_cache/ 15 | 16 | # Developer environments 17 | .idea 18 | .vscode 19 | .env 20 | .venv 21 | .DS_Store 22 | .mypy_cache 23 | 24 | local-worker/ 25 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: "^references/*" 2 | repos: 3 | - repo: https://github.com/psf/black 4 | rev: 24.1.1 5 | hooks: 6 | - id: black 7 | 8 | - repo: https://github.com/pycqa/flake8 9 | rev: 7.0.0 10 | hooks: 11 | - id: flake8 12 | 
additional_dependencies: [pep8-naming, flake8-bugbear]
13 |         args: ['--classmethod-decorators=classmethod,validator,root_validator']
14 | 
15 |   - repo: https://github.com/pycqa/isort
16 |     rev: 5.12.0
17 |     hooks:
18 |       - id: isort
19 | 
20 |   - repo: https://github.com/pre-commit/pre-commit-hooks
21 |     rev: v4.3.0
22 |     hooks:
23 |       - id: trailing-whitespace
24 |       - id: end-of-file-fixer
25 |       - id: debug-statements
26 |       - id: check-added-large-files
27 | 
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
 1 | Substra repositories' code of conduct is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/code-of-conduct.html).
 2 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | Substra repositories' contributing guide is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/contributing-guide.html).
 2 | 
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
 1 | This is a file of people that have made significant contributions to the Substra SDK. It is sorted in chronological order. Please include your contribution at the bottom of this document in the following format: name (N), email (E), description of work (W) and date (D).
 2 | 
 3 | To have your contribution listed, your work must meet the minimum [threshold of originality](https://en.wikipedia.org/wiki/Threshold_of_originality), which will be evaluated by the maintainers of the repository.
 4 | 
 5 | Thank you for your contribution, your work is greatly appreciated!
 6 | 
 7 | --- Example ---
 8 | 
 9 | - N: John Doe
10 | - E: john.doe@owkin.com
11 | - W: Integrated new feature
12 | - D: 02/02/2023
13 | 
14 | ---
15 | 
16 | Copyright (c) 2018-present Owkin Inc. All rights reserved.
17 | 
18 | All other contributions:
19 | Copyright (c) 2023 to the respective contributors.
20 | All rights reserved.
21 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | 
 2 |                                  Apache License
 3 |                            Version 2.0, January 2004
 4 |                         http://www.apache.org/licenses/
 5 | 
 6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 7 | 
 8 |    1. Definitions.
 9 | 
10 |       "License" shall mean the terms and conditions for use, reproduction,
11 |       and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 |       "Licensor" shall mean the copyright owner or entity authorized by
14 |       the copyright owner that is granting the License.
15 | 
16 |       "Legal Entity" shall mean the union of the acting entity and all
17 |       other entities that control, are controlled by, or are under common
18 |       control with that entity. For the purposes of this definition,
19 |       "control" means (i) the power, direct or indirect, to cause the
20 |       direction or management of such entity, whether by contract or
21 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 |       outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 |       "You" (or "Your") shall mean an individual or Legal Entity
25 |       exercising permissions granted by this License.
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2018-2022 Owkin, Inc. 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 192 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include setup.cfg README.md 2 | 3 | recursive-include tests * 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: pyclean doc test doc-cli doc-sdk 2 | 3 | pyclean: 4 | find . -type f -name "*.py[co]" -delete 5 | find . 
-type d -name "__pycache__" -delete 6 | rm -rf build/ dist/ *.egg-info 7 | 8 | doc-sdk: pyclean 9 | python bin/generate_sdk_documentation.py 10 | python bin/generate_sdk_schemas_documentation.py 11 | python bin/generate_sdk_schemas_documentation.py --models --output-path='references/sdk_models.md' 12 | 13 | doc: doc-sdk 14 | 15 | test: pyclean 16 | pytest tests 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
 2 | [Substra logo]
15 | 
16 | Substra is open source federated learning (FL) software. It enables the training and validation of machine learning models on distributed datasets. It provides a flexible Python interface and a web application to run federated learning training at scale. This specific repository is the low-level Python library used to interact with a Substra network.
17 | 
18 | Substra's main usage is in production environments. It has already been deployed and used by hospitals and biotech companies (see the [MELLODDY](https://www.melloddy.eu/) project for instance). Substra can also be used on a single machine to perform FL simulations and debug code.
19 | 
20 | Substra was originally developed by [Owkin](https://owkin.com/) and is now hosted by the [Linux Foundation for AI and Data](https://lfaidata.foundation/). Today Owkin is the main contributor to Substra.
21 | 
22 | Join the discussion on [Slack](https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA) and [subscribe here](https://lists.lfaidata.foundation/g/substra-announce/join) to our newsletter.
23 | 
24 | 
25 | ## To start using Substra
26 | 
27 | Have a look at our [documentation](https://docs.substra.org/).
28 | 
29 | Try out our [MNIST example](https://docs.substra.org/en/stable/substrafl_doc/examples/index.html#example-to-get-started-using-the-pytorch-interface).
30 | 
31 | ## Support
32 | 
33 | If you need support, please either raise an issue on GitHub or ask on [Slack](https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA).
34 | 
35 | 
36 | 
37 | ## Contributing
38 | 
39 | Substra warmly welcomes any contribution. Feel free to fork the repo and create a pull request.
40 | 
41 | 
42 | ## Setup
43 | 
44 | To set up the project in development mode, run:
45 | 
46 | ```sh
47 | pip install -e ".[dev]"
48 | ```
49 | 
50 | To run all tests, use the following command:
51 | 
52 | ```sh
53 | make test
54 | ```
55 | 
56 | Some of the tests require Docker to be running on your machine.
57 | 
58 | ## Code formatting
59 | 
60 | You can opt into auto-formatting of code on pre-commit using [Black](https://github.com/psf/black).
61 | 
62 | This relies on hooks managed by [pre-commit](https://pre-commit.com/), which you can set up as follows.
63 | 
64 | Install [pre-commit](https://pre-commit.com/), then run:
65 | 
66 | ```sh
67 | pre-commit install
68 | ```
69 | 
70 | ## Documentation generation
71 | 
72 | To generate the SDK and schemas documentation, run the following command:
73 | 
74 | 
75 | ```sh
76 | make doc
77 | ```
78 | 
79 | Documentation will be available in the *references/* directory.
80 | 
81 | ## Changelog generation
82 | 
83 | The changelog is managed with [towncrier](https://towncrier.readthedocs.io/en/stable/index.html).
84 | To add a new entry in the changelog, add a file in the `changes` folder. The file name should have the following structure:
85 | `<unique_id>.<change_type>`.
86 | The `unique_id` is a unique identifier; we currently use the PR number.
87 | The `change_type` can be of the following types: `added`, `changed`, `removed`, `fixed`.
88 | 
89 | To generate the changelog (for example during a release), use the following command (you must have the dev dependencies installed):
90 | 
91 | ```
92 | towncrier build --version=<version>
93 | ```
94 | You can use the `--draft` option to see what would be generated without actually writing to the changelog (and without removing the fragments).
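For instance, a fragment for a hypothetical PR number 123 that adds a feature could be created and previewed like this (the file name and version below are illustrative):

```sh
echo "Added support for ..." > changes/123.added
towncrier build --version=1.2.3 --draft
```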
95 | 
--------------------------------------------------------------------------------
/Substra-logo-colour.svg:
--------------------------------------------------------------------------------
[SVG logo markup omitted]
--------------------------------------------------------------------------------
/Substra-logo-white.svg:
--------------------------------------------------------------------------------
[SVG logo markup omitted]
--------------------------------------------------------------------------------
/bin/generate_sdk_documentation.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import inspect
 3 | import sys
 4 | from pathlib import Path
 5 | 
 6 | import docstring_parser
 7 | 
 8 | from substra import Client
 9 | from substra.sdk.utils import retry_on_exception
10 | 
11 | MODULE_LIST = [Client, retry_on_exception]
12 | 
13 | KEYWORDS = ["Args", "Returns", "Yields", "Raises", "Example"]
14 | 
15 | 
16 | def generate_function_help(fh, asset):
17 |     """Write the description of a function"""
18 |     fh.write(f"# {asset.__name__}\n")
19 |     signature = str(inspect.signature(asset))
20 |     fh.write("```text\n")
21 |     fh.write(f"{asset.__name__}{signature}")
22 |     fh.write("\n```")
23 |     fh.write("\n\n")
24 |     # Write the docstring
25 |     docstring = inspect.getdoc(asset)
26 |     docstring = docstring_parser.parse(docstring)
27 |     fh.write(f"{docstring.short_description}\n")
28 |     if docstring.long_description:
29 |         fh.write(f"{docstring.long_description}\n")
30 |     # Write the arguments as a list
31 |     if len(docstring.params) > 0:
32 |         fh.write("\n**Arguments:**\n")
33 |     for param in docstring.params:
34 |         type_and_optional = ""
35 |         if param.type_name or param.is_optional is not None:
36 |             text_optional = "required"
37 |             if param.is_optional:
38 |                 text_optional = "optional"
39 |             type_and_optional = f"({param.type_name}, {text_optional})"
40 |         fh.write(f" - `{param.arg_name} {type_and_optional}`: {param.description}\n")
41 |     # Write everything else as is
42 |     for param in [
43 |         meta_param for meta_param in docstring.meta if not isinstance(meta_param, docstring_parser.DocstringParam)
44 |     ]:
45 |         fh.write(f"\n**{param.args[0].title()}:**\n")
46 |         if len(param.args) > 1:
47 |             for extra_param in param.args[1:]:
48 |                 fh.write(f"\n - `{extra_param}`: ")
49 |         fh.write(f"{param.description}\n")
50 | 
51 | 
52 | def generate_properties_help(fh, public_methods):
53 |     properties = [(f_name, f_method) for f_name, f_method in public_methods if isinstance(f_method, property)]
54 |     for f_name, f_method in properties:
55 |         fh.write(f"## {f_name}\n")
56 |         fh.write("_This is a property._ \n")
57 |         fh.write(f"{f_method.__doc__}\n")
58 | 
59 | 
60 | def generate_help(fh):
61 |     for asset in MODULE_LIST:
62 |         if inspect.isclass(asset):  # Class
63 |             public_methods = [
64 |                 (f_name, f_method) for f_name, f_method in inspect.getmembers(asset) if not f_name.startswith("_")
65 |             ]
66 |             generate_function_help(fh, asset)
67 |             generate_properties_help(fh, public_methods)
68 |             for _, f_method in public_methods:
69 |                 if not isinstance(f_method, property):
70 |                     fh.write("#")  # Titles for the methods are one level below
71 |                     generate_function_help(fh, f_method)
72 |         elif callable(asset):
73 |             generate_function_help(fh, asset)
74 | 
75 | 
76 | def write_help(path):
77 |     with path.open("w") as fh:
78 |         generate_help(fh)
79 | 
80 | 
81 | if __name__ == "__main__":
82 |     default_path = 
Path(__file__).resolve().parents[1] / "references" / "sdk.md" 83 | 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument("--output-path", type=str, default=str(default_path.resolve()), required=False) 86 | 87 | args = parser.parse_args(sys.argv[1:]) 88 | write_help(Path(args.output_path)) 89 | -------------------------------------------------------------------------------- /bin/generate_sdk_schemas_documentation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import sys 4 | import warnings 5 | from pathlib import Path 6 | 7 | import pydantic 8 | 9 | from substra.sdk import models 10 | from substra.sdk import schemas 11 | 12 | local_dir = Path(__file__).parent 13 | 14 | schemas_list = [ 15 | schemas.DataSampleSpec, 16 | schemas.DatasetSpec, 17 | schemas.UpdateDatasetSpec, 18 | schemas.FunctionSpec, 19 | schemas.FunctionInputSpec, 20 | schemas.FunctionOutputSpec, 21 | schemas.TaskSpec, 22 | schemas.ComputeTaskOutputSpec, 23 | schemas.UpdateFunctionSpec, 24 | schemas.ComputePlanSpec, 25 | schemas.UpdateComputePlanSpec, 26 | schemas.UpdateComputePlanTasksSpec, 27 | schemas.ComputePlanTaskSpec, 28 | schemas.Permissions, 29 | schemas.PrivatePermissions, 30 | ] 31 | 32 | models_list = [ 33 | models.DataSample, 34 | models.Dataset, 35 | models.Task, 36 | models.Function, 37 | models.ComputePlan, 38 | models.Performances, 39 | models.Organization, 40 | models.Permissions, 41 | models.InModel, 42 | models.OutModel, 43 | ] 44 | 45 | 46 | def _get_field_description(fields): 47 | desc = [f"{name}: {field.annotation}" for name, field in fields.items()] 48 | return desc 49 | 50 | 51 | def generate_help(fh, models: bool): 52 | if models: 53 | asset_list = models_list 54 | title = "Models" 55 | else: 56 | asset_list = schemas_list 57 | title = "Schemas" 58 | 59 | fh.write("# Summary\n\n") 60 | 61 | def _create_anchor(schema): 62 | return "#{}".format(schema.__name__) 63 | 64 | for asset in asset_list: 65 | anchor = _create_anchor(asset) 66 | fh.write(f"- [{asset.__name__}]({anchor})\n") 67 | 68 | fh.write("\n\n") 69 | fh.write(f"# {title}\n\n") 70 | 71 | for asset in asset_list: 72 | anchor = _create_anchor(asset) 73 | 74 | fh.write(f"## {asset.__name__}\n") 75 | # Write the docstring 76 | fh.write(f"{inspect.getdoc(asset)}\n") 77 | # List the fields and their types 78 | description = _get_field_description(asset.model_fields) 79 | fh.write("```text\n") 80 | fh.write("- " + "\n- ".join(description)) 81 | fh.write("\n```") 82 | fh.write("\n\n") 83 | 84 | 85 | def write_help(path, models: bool): 86 | with path.open("w") as fh: 87 | generate_help(fh, models) 88 | 89 | 90 | if __name__ == "__main__": 91 | expected_pydantic_version = "2.3.0" 92 | if pydantic.VERSION != expected_pydantic_version: 93 | warnings.warn( 94 | f"The documentation should be generated with the version {expected_pydantic_version} of pydantic or \ 95 | there might be mismatches with the CI: version {pydantic.VERSION} used" 96 | ) 97 | 98 | doc_dir = local_dir.parent / "references" 99 | default_path = doc_dir / "sdk_schemas.md" 100 | 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument("--output-path", type=str, default=str(default_path.resolve()), required=False) 103 | parser.add_argument( 104 | "--models", 105 | action="store_true", 106 | help="Generate the doc for the models.\ 107 | Default: generate for the schemas", 108 | ) 109 | 110 | args = parser.parse_args(sys.argv[1:]) 111 | write_help(Path(args.output_path), 
models=args.models) 112 | -------------------------------------------------------------------------------- /changes/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/changes/.gitkeep -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Substra documentation 2 | 3 | The Substra documentation is hosted by Read The Docs and can be found [here](https://docs.substra.org/). 4 | -------------------------------------------------------------------------------- /docs/technical_documentation.md: -------------------------------------------------------------------------------- 1 | # Substra technical documentation 2 | 3 | Client operations: 4 | 5 | ```mermaid 6 | graph TD; 7 | Client-->debug{debug} 8 | debug-->|False|remotebackend[remote backend] 9 | debug-->|True|local[local backend] 10 | local-->|CRUD|dal 11 | dal-->|read|isremote{asset is remote} 12 | isremote-->|True|remotebackend 13 | isremote-->|False|db 14 | dal-->|save / update|db 15 | local-->|execute task|worker 16 | worker-->|CRUD|dal 17 | worker-->spawner{spawner} 18 | spawner-->docker 19 | spawner-->subprocess 20 | 21 | click Client "https://github.com/owkin/substra/blob/main/substra/sdk/client.py" 22 | click debug "https://github.com/owkin/substra/blob/main/substra/sdk/client.py#L75" 23 | click remotebackend "https://github.com/owkin/substra/blob/main/substra/sdk/backends/remote/backend.py" 24 | click local "https://github.com/owkin/substra/blob/main/substra/sdk/backends/local/backend.py" 25 | click dal "https://github.com/owkin/substra/blob/main/substra/sdk/backends/local/dal.py" 26 | click db "https://github.com/owkin/substra/blob/main/substra/sdk/backends/local/db.py" 27 | click worker "https://github.com/owkin/substra/blob/main/substra/sdk/backends/local/compute/worker.py" 28 | click spawner "https://github.com/owkin/substra/blob/main/substra/sdk/backends/local/compute/worker.py#L69" 29 | click docker "https://github.com/owkin/substra/blob/main/substra/sdk/backends/local/compute/spawner/docker.py" 30 | click subprocess "https://github.com/owkin/substra/blob/main/substra/sdk/backends/local/compute/spawner/subprocess.py" 31 | ``` 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.build.targets.sdist] 6 | exclude = ["tests*"] 7 | 8 | [tool.hatch.version] 9 | path = "substra/__version__.py" 10 | 11 | [project] 12 | name = "substra" 13 | description = "Low-level Python library for interacting with a Substra network" 14 | dynamic = ["version"] 15 | readme = "README.md" 16 | requires-python = ">= 3.10" 17 | dependencies = [ 18 | "requests!=2.32.*", 19 | "docker", 20 | "pyyaml", 21 | "pydantic>=2.3.0,<3.0.0", 22 | "tqdm", 23 | "python-slugify", 24 | ] 25 | keywords = ["substra"] 26 | classifiers = [ 27 | "Intended Audience :: Developers", 28 | "Topic :: Utilities", 29 | "Natural Language :: English", 30 | "Operating System :: OS Independent", 31 | "Programming Language :: Python :: 3.10", 32 | "Programming Language :: Python :: 3.11", 33 | "Programming Language :: Python :: 3.12", 34 | ] 35 | license = { file = "LICENSE" } 36 | authors = [{ name = 
"Owkin, Inc." }] 37 | 38 | 39 | [project.optional-dependencies] 40 | dev = [ 41 | "pandas", 42 | "pytest", 43 | "pytest-cov", 44 | "pytest-mock", 45 | "substratools~=1.0.0", 46 | "black", 47 | "flake8", 48 | "isort", 49 | "docstring_parser", 50 | "towncrier", 51 | ] 52 | 53 | [project.urls] 54 | Documentation = "https://docs.substra.org/en/stable/" 55 | Repository = "https://github.com/Substra/substra" 56 | Changelog = "https://github.com/Substra/substra/blob/main/CHANGELOG.md" 57 | 58 | [tool.pytest.ini_options] 59 | addopts = "-v --cov=substra --ignore=tests/unit --ignore=tests/e2e" 60 | 61 | [tool.black] 62 | line-length = 120 63 | target-version = ['py39'] 64 | 65 | [tool.isort] 66 | filter_files = true 67 | force_single_line = true 68 | line_length = 120 69 | profile = "black" 70 | 71 | [tool.towncrier] 72 | directory = "changes" 73 | filename = "CHANGELOG.md" 74 | start_string = "\n" 75 | underlines = ["", "", ""] 76 | title_format = "## [{version}](https://github.com/Substra/substra/releases/tag/{version}) - {project_date}" 77 | issue_format = "[#{issue}](https://github.com/Substra/substra/pull/{issue})" 78 | [tool.towncrier.fragment.added] 79 | [tool.towncrier.fragment.removed] 80 | [tool.towncrier.fragment.changed] 81 | [tool.towncrier.fragment.fixed] 82 | -------------------------------------------------------------------------------- /references/sdk_models.md: -------------------------------------------------------------------------------- 1 | # Summary 2 | 3 | - [DataSample](#DataSample) 4 | - [Dataset](#Dataset) 5 | - [Task](#Task) 6 | - [Function](#Function) 7 | - [ComputePlan](#ComputePlan) 8 | - [Performances](#Performances) 9 | - [Organization](#Organization) 10 | - [Permissions](#Permissions) 11 | - [InModel](#InModel) 12 | - [OutModel](#OutModel) 13 | 14 | 15 | # Models 16 | 17 | ## DataSample 18 | Data sample 19 | ```text 20 | - key: 21 | - owner: 22 | - data_manager_keys: typing.Optional[typing.List[str]] 23 | - path: typing.Optional[typing.Annotated[pathlib.Path, PathType(path_type='dir')]] 24 | - creation_date: 25 | ``` 26 | 27 | ## Dataset 28 | Dataset asset 29 | ```text 30 | - key: 31 | - name: 32 | - owner: 33 | - permissions: 34 | - data_sample_keys: typing.List[str] 35 | - opener: 36 | - description: 37 | - metadata: typing.Dict[str, str] 38 | - creation_date: 39 | - logs_permission: 40 | ``` 41 | 42 | ## Task 43 | Asset creation specification base class. 44 | ```text 45 | - key: 46 | - function: 47 | - owner: 48 | - compute_plan_key: 49 | - metadata: typing.Dict[str, str] 50 | - status: 51 | - worker: 52 | - rank: typing.Optional[int] 53 | - tag: 54 | - creation_date: 55 | - start_date: typing.Optional[datetime.datetime] 56 | - end_date: typing.Optional[datetime.datetime] 57 | - error_type: typing.Optional[substra.sdk.models.TaskErrorType] 58 | - inputs: typing.List[substra.sdk.models.InputRef] 59 | - outputs: typing.Dict[str, substra.sdk.models.ComputeTaskOutput] 60 | ``` 61 | 62 | ## Function 63 | Asset creation specification base class. 
64 | ```text 65 | - key: 66 | - name: 67 | - owner: 68 | - permissions: 69 | - metadata: typing.Dict[str, str] 70 | - creation_date: 71 | - inputs: typing.List[substra.sdk.models.FunctionInput] 72 | - outputs: typing.List[substra.sdk.models.FunctionOutput] 73 | - status: 74 | - description: 75 | - archive: 76 | ``` 77 | 78 | ## ComputePlan 79 | ComputePlan 80 | ```text 81 | - key: 82 | - tag: 83 | - name: 84 | - owner: 85 | - metadata: typing.Dict[str, str] 86 | - task_count: 87 | - waiting_builder_slot_count: 88 | - building_count: 89 | - waiting_parent_tasks_count: 90 | - waiting_executor_slot_count: 91 | - executing_count: 92 | - canceled_count: 93 | - failed_count: 94 | - done_count: 95 | - failed_task_key: typing.Optional[str] 96 | - status: 97 | - creation_date: 98 | - start_date: typing.Optional[datetime.datetime] 99 | - end_date: typing.Optional[datetime.datetime] 100 | - estimated_end_date: typing.Optional[datetime.datetime] 101 | - duration: typing.Optional[int] 102 | - creator: typing.Optional[str] 103 | ``` 104 | 105 | ## Performances 106 | Performances of the different compute tasks of a compute plan 107 | ```text 108 | - compute_plan_key: typing.List[str] 109 | - compute_plan_tag: typing.List[str] 110 | - compute_plan_status: typing.List[str] 111 | - compute_plan_start_date: typing.List[datetime.datetime] 112 | - compute_plan_end_date: typing.List[datetime.datetime] 113 | - compute_plan_metadata: typing.List[dict] 114 | - worker: typing.List[str] 115 | - task_key: typing.List[str] 116 | - task_rank: typing.List[int] 117 | - round_idx: typing.List[int] 118 | - identifier: typing.List[str] 119 | - performance: typing.List[float] 120 | ``` 121 | 122 | ## Organization 123 | Organization 124 | ```text 125 | - id: 126 | - is_current: 127 | - creation_date: 128 | ``` 129 | 130 | ## Permissions 131 | Permissions structure stored in various asset types. 132 | ```text 133 | - process: 134 | ``` 135 | 136 | ## InModel 137 | In model of a task 138 | ```text 139 | - checksum: 140 | - storage_address: typing.Union[typing.Annotated[pathlib.Path, PathType(path_type='file')], pydantic_core._pydantic_core.Url, str] 141 | ``` 142 | 143 | ## OutModel 144 | Out model of a task 145 | ```text 146 | - key: 147 | - compute_task_key: 148 | - address: typing.Optional[substra.sdk.models.InModel] 149 | - permissions: 150 | - owner: 151 | - creation_date: 152 | ``` 153 | 154 | -------------------------------------------------------------------------------- /references/sdk_schemas.md: -------------------------------------------------------------------------------- 1 | # Summary 2 | 3 | - [DataSampleSpec](#DataSampleSpec) 4 | - [DatasetSpec](#DatasetSpec) 5 | - [UpdateDatasetSpec](#UpdateDatasetSpec) 6 | - [FunctionSpec](#FunctionSpec) 7 | - [FunctionInputSpec](#FunctionInputSpec) 8 | - [FunctionOutputSpec](#FunctionOutputSpec) 9 | - [TaskSpec](#TaskSpec) 10 | - [ComputeTaskOutputSpec](#ComputeTaskOutputSpec) 11 | - [UpdateFunctionSpec](#UpdateFunctionSpec) 12 | - [ComputePlanSpec](#ComputePlanSpec) 13 | - [UpdateComputePlanSpec](#UpdateComputePlanSpec) 14 | - [UpdateComputePlanTasksSpec](#UpdateComputePlanTasksSpec) 15 | - [ComputePlanTaskSpec](#ComputePlanTaskSpec) 16 | - [Permissions](#Permissions) 17 | - [PrivatePermissions](#PrivatePermissions) 18 | 19 | 20 | # Schemas 21 | 22 | ## DataSampleSpec 23 | Specification to create one or many data samples 24 | To create one data sample, use the 'path' field, otherwise use 25 | the 'paths' field. 
26 | ```text
27 | - path: typing.Optional[pathlib.Path]
28 | - paths: typing.Optional[typing.List[pathlib.Path]]
29 | - data_manager_keys: typing.List[str]
30 | ```
31 | 
32 | ## DatasetSpec
33 | Specification for creating a dataset
34 | 
35 | note: metadata field does not accept strings containing '__' as dict key
36 | 
37 | note: If no description markdown file is given, create an empty one in the data_opener folder.
38 | ```text
39 | - name: 
40 | - data_opener: 
41 | - description: typing.Optional[pathlib.Path]
42 | - permissions: 
43 | - metadata: typing.Optional[typing.Dict[str, str]]
44 | - logs_permission: 
45 | ```
46 | 
47 | ## UpdateDatasetSpec
48 | Specification for updating a dataset
49 | ```text
50 | - name: 
51 | ```
52 | 
53 | ## FunctionSpec
54 | Specification for creating a function
55 | 
56 | note: metadata field does not accept strings containing '__' as dict key
57 | ```text
58 | - name: 
59 | - description: 
60 | - file: 
61 | - permissions: 
62 | - metadata: typing.Optional[typing.Dict[str, str]]
63 | - inputs: typing.Optional[typing.List[substra.sdk.schemas.FunctionInputSpec]]
64 | - outputs: typing.Optional[typing.List[substra.sdk.schemas.FunctionOutputSpec]]
65 | ```
66 | 
67 | ## FunctionInputSpec
68 | Asset creation specification base class.
69 | ```text
70 | - identifier: 
71 | - multiple: 
72 | - optional: 
73 | - kind: 
74 | ```
75 | 
76 | ## FunctionOutputSpec
77 | Asset creation specification base class.
78 | ```text
79 | - identifier: 
80 | - kind: 
81 | - multiple: 
82 | ```
83 | 
84 | ## TaskSpec
85 | Asset creation specification base class.
86 | ```text
87 | - key: 
88 | - tag: typing.Optional[str]
89 | - compute_plan_key: typing.Optional[str]
90 | - metadata: typing.Optional[typing.Dict[str, str]]
91 | - function_key: 
92 | - worker: 
93 | - rank: typing.Optional[int]
94 | - inputs: typing.Optional[typing.List[substra.sdk.schemas.InputRef]]
95 | - outputs: typing.Optional[typing.Dict[str, substra.sdk.schemas.ComputeTaskOutputSpec]]
96 | ```
97 | 
98 | ## ComputeTaskOutputSpec
99 | Specification of a compute task output
100 | ```text
101 | - permissions: 
102 | - is_transient: typing.Optional[bool]
103 | ```
104 | 
105 | ## UpdateFunctionSpec
106 | Specification for updating a function
107 | ```text
108 | - name: 
109 | ```
110 | 
111 | ## ComputePlanSpec
112 | Specification for creating a compute plan
113 | 
114 | note: metadata field does not accept strings containing '__' as dict key
115 | ```text
116 | - key: 
117 | - tasks: typing.Optional[typing.List[substra.sdk.schemas.ComputePlanTaskSpec]]
118 | - tag: typing.Optional[str]
119 | - name: 
120 | - metadata: typing.Optional[typing.Dict[str, str]]
121 | ```
122 | 
123 | ## UpdateComputePlanSpec
124 | Specification for updating a compute plan
125 | ```text
126 | - name: 
127 | ```
128 | 
129 | ## UpdateComputePlanTasksSpec
130 | Specification for updating a compute plan's tasks
131 | ```text
132 | - key: 
133 | - tasks: typing.Optional[typing.List[substra.sdk.schemas.ComputePlanTaskSpec]]
134 | ```
135 | 
136 | ## ComputePlanTaskSpec
137 | Specification of a compute task inside a compute plan specification
138 | 
139 | note: metadata field does not accept strings containing '__' as dict key
140 | ```text
141 | - task_id: 
142 | - function_key: 
143 | - worker: 
144 | - tag: typing.Optional[str]
145 | - metadata: typing.Optional[typing.Dict[str, str]]
146 | - inputs: typing.Optional[typing.List[substra.sdk.schemas.InputRef]]
147 | - outputs: typing.Optional[typing.Dict[str, substra.sdk.schemas.ComputeTaskOutputSpec]]
148 | ```
149 | 
150 | 
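As an illustration of how these specs compose, a single task of a compute plan could be declared as below, using the `Permissions` schema described next. This is a minimal sketch: every key and identifier is a made-up placeholder, and optional fields are omitted, assuming they carry defaults in the SDK.

```python
from substra.sdk import schemas

task = schemas.ComputePlanTaskSpec(
    task_id="11111111-1111-1111-1111-111111111111",  # placeholder UUID
    function_key="22222222-2222-2222-2222-222222222222",  # placeholder function key
    worker="my-org",  # organization that will execute the task
    outputs={
        "model": schemas.ComputeTaskOutputSpec(
            permissions=schemas.Permissions(public=False, authorized_ids=["my-org"]),
            is_transient=False,
        )
    },
)
```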
## Permissions 151 | Specification for permissions. If public is False, 152 | give the list of authorized ids. 153 | ```text 154 | - public: 155 | - authorized_ids: typing.List[str] 156 | ``` 157 | 158 | ## PrivatePermissions 159 | Specification for private permissions. Only the organizations whose 160 | ids are in authorized_ids can access the asset. 161 | ```text 162 | - authorized_ids: typing.List[str] 163 | ``` 164 | 165 | -------------------------------------------------------------------------------- /substra/__init__.py: -------------------------------------------------------------------------------- 1 | from substra.__version__ import __version__ 2 | from substra.sdk import Client 3 | from substra.sdk import exceptions 4 | from substra.sdk import models 5 | from substra.sdk import schemas 6 | from substra.sdk.schemas import BackendType 7 | 8 | __all__ = [ 9 | "__version__", 10 | "Client", 11 | "exceptions", 12 | "BackendType", 13 | "schemas", 14 | "models", 15 | ] 16 | -------------------------------------------------------------------------------- /substra/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | -------------------------------------------------------------------------------- /substra/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/substra/config.py -------------------------------------------------------------------------------- /substra/sdk/__init__.py: -------------------------------------------------------------------------------- 1 | from substra.sdk import models 2 | from substra.sdk import schemas 3 | from substra.sdk.client import Client 4 | from substra.sdk.schemas import BackendType 5 | from substra.sdk.utils import retry_on_exception 6 | 7 | __all__ = [ 8 | "Client", 9 | "retry_on_exception", 10 | "BackendType", 11 | "schemas", 12 | "models", 13 | ] 14 | -------------------------------------------------------------------------------- /substra/sdk/archive/__init__.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | import zipfile 3 | 4 | from substra.sdk import exceptions 5 | from substra.sdk.archive import tarsafe 6 | from substra.sdk.archive.safezip import ZipFile 7 | 8 | 9 | def _untar(archive, to_): 10 | with tarsafe.open(archive, "r:*") as tf: 11 | tf.extractall(to_) 12 | 13 | 14 | def _unzip(archive, to_): 15 | with ZipFile(archive, "r") as zf: 16 | zf.extractall(to_) 17 | 18 | 19 | def uncompress(archive, to_): 20 | """Uncompress tar or zip archive to destination.""" 21 | if tarfile.is_tarfile(archive): 22 | _untar(archive, to_) 23 | elif zipfile.is_zipfile(archive): 24 | _unzip(archive, to_) 25 | else: 26 | raise exceptions.InvalidRequest(f"Cannot uncompress '{archive}'", 400) 27 | -------------------------------------------------------------------------------- /substra/sdk/archive/safezip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import stat 3 | import zipfile 4 | 5 | 6 | class ZipFile(zipfile.ZipFile): 7 | """Override Zipfile to ensure unix file permissions are preserved. 
8 | 9 | This is due to a python bug: 10 | https://bugs.python.org/issue15795 11 | 12 | Workaround from: 13 | https://stackoverflow.com/questions/39296101/python-zipfile-removes-execute-permissions-from-binaries 14 | """ 15 | 16 | def extract(self, member, path=None, pwd=None): 17 | if not isinstance(member, zipfile.ZipInfo): 18 | member = self.getinfo(member) 19 | 20 | if path is None: 21 | path = os.getcwd() 22 | 23 | ret_val = self._extract_member(member, path, pwd) 24 | attr = member.external_attr >> 16 25 | os.chmod(ret_val, attr) 26 | return ret_val 27 | 28 | def extractall(self, path=None, members=None, pwd=None): 29 | self._sanity_check() 30 | super().extractall(path, members, pwd) 31 | 32 | def _sanity_check(self): 33 | """Check that the archive does not attempt to traverse path. 34 | 35 | This is inspired by TarSafe: https://github.com/beatsbears/tarsafe 36 | """ 37 | for zipinfo in self.infolist(): 38 | if self._is_traversal_attempt(zipinfo): 39 | raise Exception(f"Attempted directory traversal for member: {zipinfo.filename}") 40 | if self._is_symlink(zipinfo): 41 | raise Exception(f"Unsupported symlink for member: {zipinfo.filename}") 42 | 43 | def _is_traversal_attempt(self, zipinfo: zipfile.ZipInfo) -> bool: 44 | base_directory = os.getcwd() 45 | zipfile_path = os.path.abspath(os.path.join(base_directory, zipinfo.filename)) 46 | if not zipfile_path.startswith(base_directory): 47 | return True 48 | return False 49 | 50 | def _is_symlink(self, zipinfo: zipfile.ZipInfo) -> bool: 51 | return zipinfo.external_attr & stat.S_IFLNK << 16 == stat.S_IFLNK << 16 52 | -------------------------------------------------------------------------------- /substra/sdk/archive/tarsafe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import tarfile 4 | 5 | 6 | class TarSafe(tarfile.TarFile): 7 | """ 8 | A safe subclass of the TarFile class for interacting with tar files. 9 | """ 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.directory = os.getcwd() 14 | 15 | @classmethod 16 | def open(cls, name=None, mode="r", fileobj=None, bufsize=tarfile.RECORDSIZE, **kwargs): 17 | return super().open(name, mode, fileobj, bufsize, **kwargs) 18 | 19 | def extract(self, member, path="", set_attrs=True, *, numeric_owner=False): 20 | """ 21 | Override the parent extract method and add safety checks. 22 | """ 23 | self._safetar_check() 24 | super().extract(member, path, set_attrs=set_attrs, numeric_owner=numeric_owner) 25 | 26 | def extractall(self, path=".", members=None, numeric_owner=False): 27 | """ 28 | Override the parent extractall method and add safety checks. 29 | """ 30 | self._safetar_check() 31 | super().extractall(path, members, numeric_owner=numeric_owner) 32 | 33 | def _safetar_check(self): 34 | """ 35 | Runs all necessary checks for the safety of a tarfile. 
36 | """ 37 | try: 38 | for tarinfo in self.__iter__(): 39 | if self._is_traversal_attempt(tarinfo=tarinfo): 40 | raise TarSafeError(f"Attempted directory traversal for member: {tarinfo.name}") 41 | if tarinfo.issym(): 42 | raise TarSafeError(f"Unsupported symlink for member: {tarinfo.linkname}") 43 | if self._is_unsafe_link(tarinfo=tarinfo): 44 | raise TarSafeError(f"Attempted directory traversal via link for member: {tarinfo.linkname}") 45 | if self._is_device(tarinfo=tarinfo): 46 | raise TarSafeError("tarfile returns true for isblk() or ischr()") 47 | except Exception: 48 | raise 49 | 50 | def _is_traversal_attempt(self, tarinfo): 51 | if not os.path.abspath(os.path.join(self.directory, tarinfo.name)).startswith(self.directory): 52 | return True 53 | return False 54 | 55 | def _is_unsafe_link(self, tarinfo): 56 | if tarinfo.islnk(): 57 | link_file = pathlib.Path(os.path.normpath(os.path.join(self.directory, tarinfo.linkname))) 58 | if not os.path.abspath(os.path.join(self.directory, link_file)).startswith(self.directory): 59 | return True 60 | return False 61 | 62 | def _is_device(self, tarinfo): 63 | return tarinfo.ischr() or tarinfo.isblk() 64 | 65 | 66 | class TarSafeError(Exception): 67 | pass 68 | 69 | 70 | open = TarSafe.open 71 | -------------------------------------------------------------------------------- /substra/sdk/backends/__init__.py: -------------------------------------------------------------------------------- 1 | from substra.sdk import schemas 2 | from substra.sdk.backends.local.backend import Local 3 | from substra.sdk.backends.remote.backend import Remote 4 | 5 | _BACKEND_CHOICES = { 6 | schemas.BackendType.REMOTE: Remote, 7 | schemas.BackendType.LOCAL_DOCKER: Local, 8 | schemas.BackendType.LOCAL_SUBPROCESS: Local, 9 | } 10 | 11 | 12 | def get(name, *args, **kwargs): 13 | return _BACKEND_CHOICES[name](*args, **kwargs, backend_type=name) 14 | -------------------------------------------------------------------------------- /substra/sdk/backends/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import List 3 | 4 | from substra.sdk.schemas import BackendType 5 | 6 | 7 | class BaseBackend(abc.ABC): 8 | @property 9 | @abc.abstractmethod 10 | def backend_mode(self) -> BackendType: 11 | raise NotImplementedError 12 | 13 | @abc.abstractmethod 14 | def login(self, username, password): 15 | pass 16 | 17 | @abc.abstractmethod 18 | def logout(self): 19 | pass 20 | 21 | @abc.abstractmethod 22 | def get(self, asset_type, key): 23 | raise NotImplementedError 24 | 25 | @abc.abstractmethod 26 | def list(self, asset_type, filters=None, paginated=False): 27 | raise NotImplementedError 28 | 29 | @abc.abstractmethod 30 | def add(self, spec, spec_options=None): 31 | raise NotImplementedError 32 | 33 | @abc.abstractmethod 34 | def update(self, key, spec, spec_options=None): 35 | raise NotImplementedError 36 | 37 | @abc.abstractmethod 38 | def add_compute_plan_tasks(self, spec, spec_options): 39 | raise NotImplementedError 40 | 41 | @abc.abstractmethod 42 | def link_dataset_with_data_samples(self, dataset_key, data_sample_keys) -> List[str]: 43 | raise NotImplementedError 44 | 45 | @abc.abstractmethod 46 | def download(self, asset_type, url_field_path, key, destination): 47 | raise NotImplementedError 48 | 49 | @abc.abstractmethod 50 | def describe(self, asset_type, key): 51 | raise NotImplementedError 52 | 53 | @abc.abstractmethod 54 | def cancel_compute_plan(self, key): 55 | raise NotImplementedError 56 | 
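# Illustrative sketch (not part of the SDK): the smallest subclass that could be
# instantiated. Every abstract method declared above must be overridden; the
# class name and no-op bodies below are hypothetical, for demonstration only.
class _NoOpBackend(BaseBackend):
    @property
    def backend_mode(self) -> BackendType:
        return BackendType.LOCAL_SUBPROCESS

    def login(self, username, password):
        pass

    def logout(self):
        pass

    def get(self, asset_type, key):
        return None

    def list(self, asset_type, filters=None, paginated=False):
        return []

    def add(self, spec, spec_options=None):
        return None

    def update(self, key, spec, spec_options=None):
        return None

    def add_compute_plan_tasks(self, spec, spec_options):
        return None

    def link_dataset_with_data_samples(self, dataset_key, data_sample_keys) -> List[str]:
        return []

    def download(self, asset_type, url_field_path, key, destination):
        return None

    def describe(self, asset_type, key):
        return ""

    def cancel_compute_plan(self, key):
        return None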
-------------------------------------------------------------------------------- /substra/sdk/backends/local/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/substra/sdk/backends/local/__init__.py -------------------------------------------------------------------------------- /substra/sdk/backends/local/compute/__init__.py: -------------------------------------------------------------------------------- 1 | from substra.sdk.backends.local.compute.worker import Worker 2 | 3 | __all__ = [ 4 | "Worker", 5 | ] 6 | -------------------------------------------------------------------------------- /substra/sdk/backends/local/compute/spawner/__init__.py: -------------------------------------------------------------------------------- 1 | from substra.sdk.backends.local.compute.spawner.base import BaseSpawner 2 | from substra.sdk.backends.local.compute.spawner.docker import Docker 3 | from substra.sdk.backends.local.compute.spawner.subprocess import Subprocess 4 | from substra.sdk.schemas import BackendType 5 | 6 | __all__ = ["BaseSpawner", "Docker", "Subprocess"] 7 | 8 | DEBUG_SPAWNER_CHOICES = { 9 | BackendType.LOCAL_DOCKER: Docker, 10 | BackendType.LOCAL_SUBPROCESS: Subprocess, 11 | } 12 | 13 | 14 | def get(name, *args, **kwargs): 15 | """Return a Docker spawner or a Subprocess spawner""" 16 | return DEBUG_SPAWNER_CHOICES[name](*args, **kwargs) 17 | -------------------------------------------------------------------------------- /substra/sdk/backends/local/compute/spawner/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import pathlib 3 | import string 4 | import typing 5 | 6 | VOLUME_CLI_ARGS = "_VOLUME_CLI_ARGS" 7 | VOLUME_INPUTS = "_VOLUME_INPUTS" 8 | VOLUME_OUTPUTS = "_VOLUME_OUTPUTS" 9 | 10 | 11 | class BuildError(Exception): 12 | """An error occurred during the build of the function""" 13 | 14 | pass 15 | 16 | 17 | class ExecutionError(Exception): 18 | """An error occurred during the execution of the compute task""" 19 | 20 | pass 21 | 22 | 23 | class BaseSpawner(abc.ABC): 24 | """Base wrapper to execute a command""" 25 | 26 | def __init__(self, local_worker_dir: pathlib.Path): 27 | self._local_worker_dir = local_worker_dir 28 | 29 | @abc.abstractmethod 30 | def spawn( 31 | self, 32 | name, 33 | archive_path, 34 | command_args_tpl: typing.List[string.Template], 35 | data_sample_paths: typing.Optional[typing.Dict[str, pathlib.Path]], 36 | local_volumes, 37 | envs, 38 | ): 39 | """Execute archive in a contained environment.""" 40 | raise NotImplementedError 41 | 42 | 43 | def write_command_args_file(args_file: pathlib.Path, command_args: typing.List[str]) -> None: 44 | """Write the substra-tools command line arguments to a file. 45 | 46 | The format uses one line per argument. 
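For instance, the arguments ["--function-name", "train"] are written as two separate lines, so the resulting file can be passed back to argparse with the @file syntax.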
See 47 | https://docs.python.org/3/library/argparse.html#fromfile-prefix-chars 48 | """ 49 | with open(args_file, "w") as f: 50 | for item in command_args: 51 | f.write(item + "\n") 52 | -------------------------------------------------------------------------------- /substra/sdk/backends/local/compute/spawner/docker.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import shutil 4 | import string 5 | import tempfile 6 | import typing 7 | 8 | import docker 9 | 10 | from substra.sdk.archive import uncompress 11 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_CLI_ARGS 12 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_INPUTS 13 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_OUTPUTS 14 | from substra.sdk.backends.local.compute.spawner.base import BaseSpawner 15 | from substra.sdk.backends.local.compute.spawner.base import BuildError 16 | from substra.sdk.backends.local.compute.spawner.base import ExecutionError 17 | from substra.sdk.backends.local.compute.spawner.base import write_command_args_file 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | ROOT_DIR = "/substra_internal" 22 | DOCKER_VOLUMES = { 23 | VOLUME_INPUTS: {"bind": f"{ROOT_DIR}/inputs", "mode": "ro"}, 24 | VOLUME_OUTPUTS: {"bind": f"{ROOT_DIR}/outputs", "mode": "rw"}, 25 | VOLUME_CLI_ARGS: {"bind": f"{ROOT_DIR}/cli-args", "mode": "rw"}, 26 | } 27 | 28 | 29 | def _copy_data_samples(data_sample_paths: typing.Dict[str, pathlib.Path], dest_dir: str): 30 | """Move the data samples to the data directory. 31 | 32 | We copy the data samples even though it is slow because: 33 | 34 | - symbolic links do not work with Docker container volumes 35 | - hard links cannot be created across partitions 36 | - mounting each data sample as its own volume causes permission errors 37 | 38 | Args: 39 | data_sample_paths (typing.Dict[str, pathlib.Path]): Paths to the samples 40 | dest_dir (str): Temp data directory 41 | """ 42 | # Check if there are already data samples in the dest dir (testtasks are executed in 2 parts) 43 | sample_key = next(iter(data_sample_paths)) 44 | if (pathlib.Path(dest_dir) / sample_key).exists(): 45 | return 46 | # copy the whole tree 47 | for sample_key, sample_path in data_sample_paths.items(): 48 | dest_path = pathlib.Path(dest_dir) / sample_key 49 | shutil.copytree(sample_path, dest_path) 50 | 51 | 52 | class Docker(BaseSpawner): 53 | """Wrapper around docker daemon to execute a command in a container.""" 54 | 55 | def __init__(self, local_worker_dir: pathlib.Path): 56 | try: 57 | self._docker = docker.from_env() 58 | except docker.errors.DockerException as e: 59 | raise ConnectionError( 60 | "Couldn't get the Docker client from environment variables. 
" 61 | "Is your Docker server running ?\n" 62 | "Docker error : {0}".format(e) 63 | ) 64 | super().__init__(local_worker_dir=local_worker_dir) 65 | 66 | def _build_docker_image(self, name: str, archive_path: pathlib.Path): 67 | """Spawn a docker container (blocking).""" 68 | with tempfile.TemporaryDirectory(dir=self._local_worker_dir) as tmpdir: 69 | image_exists = False 70 | try: 71 | self._docker.images.get(name=name) 72 | image_exists = True 73 | except docker.errors.ImageNotFound: 74 | pass 75 | 76 | if not image_exists: 77 | try: 78 | logger.debug("Did not find the Docker image %s - building it", name) 79 | uncompress(archive_path, tmpdir) 80 | self._docker.images.build(path=tmpdir, tag=name, rm=True) 81 | except docker.errors.BuildError as exc: 82 | log = "" 83 | for line in exc.build_log: 84 | if "stream" in line: 85 | log += line["stream"].strip() 86 | logger.error(log) 87 | raise BuildError(log) 88 | 89 | def spawn( 90 | self, 91 | name: str, 92 | archive_path: pathlib.Path, 93 | command_args_tpl: typing.List[string.Template], 94 | data_sample_paths: typing.Optional[typing.Dict[str, pathlib.Path]], 95 | local_volumes: typing.Optional[dict], 96 | envs: typing.Optional[typing.List[str]], 97 | ): 98 | """Build the docker image, copy the data samples then spawn a Docker container 99 | and execute the task. 100 | """ 101 | self._build_docker_image(name=name, archive_path=archive_path) 102 | 103 | # format the command to replace each occurrence of a DOCKER_VOLUMES's key 104 | # by its "bind" value 105 | volumes_format = {volume_name: volume_path["bind"] for volume_name, volume_path in DOCKER_VOLUMES.items()} 106 | command_args = [tpl.substitute(**volumes_format) for tpl in command_args_tpl] 107 | 108 | if data_sample_paths is not None and len(data_sample_paths) > 0: 109 | _copy_data_samples(data_sample_paths, local_volumes[VOLUME_INPUTS]) 110 | 111 | args_filename = "arguments.txt" 112 | args_path_local = pathlib.Path(local_volumes[VOLUME_CLI_ARGS]) / args_filename 113 | # Pathlib is incompatible here for Windows, as it would create a "WindowsPath" 114 | args_path_docker = DOCKER_VOLUMES[VOLUME_CLI_ARGS]["bind"] + "/" + args_filename 115 | write_command_args_file(args_path_local, command_args) 116 | 117 | # create the volumes dict for docker by binding the local_volumes and the DOCKER_VOLUME 118 | volumes_docker = { 119 | volume_path: DOCKER_VOLUMES[volume_name] for volume_name, volume_path in local_volumes.items() 120 | } 121 | 122 | container = self._docker.containers.run( 123 | name, 124 | command=f"@{args_path_docker}", 125 | volumes=volumes_docker or {}, 126 | environment=envs, 127 | remove=False, 128 | detach=True, 129 | tty=True, 130 | stdin_open=True, 131 | shm_size="8G", 132 | ) 133 | 134 | execution_logs = [] 135 | for line in container.logs(stream=True, stdout=True, stderr=True): 136 | execution_logs.append(line.decode("utf-8")) 137 | 138 | r = container.wait() 139 | execution_logs_str = "".join(execution_logs) 140 | exit_code = r["StatusCode"] 141 | if exit_code != 0: 142 | logger.error("\n\nExecution logs: %s", execution_logs_str) 143 | raise ExecutionError(f"Container '{name}' exited with status code '{exit_code}'") 144 | 145 | container.remove() 146 | -------------------------------------------------------------------------------- /substra/sdk/backends/local/compute/spawner/subprocess.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pathlib 4 | import re 5 | import shutil 6 | import string 7 | 
import subprocess 8 | import sys 9 | import tempfile 10 | import typing 11 | 12 | from substra.sdk.archive import uncompress 13 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_INPUTS 14 | from substra.sdk.backends.local.compute.spawner.base import BaseSpawner 15 | from substra.sdk.backends.local.compute.spawner.base import ExecutionError 16 | from substra.sdk.backends.local.compute.spawner.base import write_command_args_file 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | PYTHON_SCRIPT_REGEX = r"(?<=\")([^\"]*\.py)(?=\")" 21 | METHOD_REGEX = r"\"\-\-function-name\"\,\s*\"([^\"]*)\"" 22 | 23 | 24 | def _get_entrypoint_from_dockerfile(tmpdir): 25 | """ 26 | Extracts the .py script and the function name to execute in 27 | an ENTRYPOINT line of a Dockerfile, located in tmpdir. 28 | For instance if the line `ENTRYPOINT ["python3", "function.py", "--function-name", "train"]` is in the Dockerfile, 29 | `function.py`, `train` is extracted. 30 | """ 31 | valid_example = ( 32 | """The entry point should be specified as follow: """ 33 | """``ENTRYPOINT ["", "", "--function-name", ""]""" 34 | ) 35 | with open(tmpdir / "Dockerfile") as f: 36 | for line in f: 37 | if "ENTRYPOINT" not in line: 38 | continue 39 | 40 | script_name = re.findall(PYTHON_SCRIPT_REGEX, line) 41 | if len(script_name) != 1: 42 | raise ExecutionError("Couldn't extract script from ENTRYPOINT line in Dockerfile", valid_example) 43 | 44 | function_name = re.findall(METHOD_REGEX, line) 45 | if len(function_name) != 1: 46 | raise ExecutionError("Couldn't extract method name from ENTRYPOINT line in Dockerfile", valid_example) 47 | 48 | return script_name[0], function_name[0] 49 | 50 | raise ExecutionError("Couldn't get entrypoint in Dockerfile", valid_example) 51 | 52 | 53 | def _get_command_args( 54 | function_name: str, args_template: typing.List[string.Template], local_volumes: typing.Dict[str, str] 55 | ) -> typing.List[str]: 56 | args = ["--function-name", str(function_name)] 57 | args += [tpl.substitute(**local_volumes) for tpl in args_template] 58 | return args 59 | 60 | 61 | def _symlink_data_samples(data_sample_paths: typing.Dict[str, pathlib.Path], dest_dir: str): 62 | """Create a symbolic link to "move" the data samples 63 | to the data directory. 
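Unlike the Docker spawner, which has to copy the samples, symbolic links can be used here because the subprocess shares the host filesystem.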
64 | 65 | Args: 66 | data_sample_paths (typing.Dict[str, pathlib.Path]): Paths to the samples 67 | dest_dir (str): Temp data directory 68 | """ 69 | # Check if there are already data samples in the dest dir (test tasks are executed in two parts) 70 | sample_key = next(iter(data_sample_paths)) 71 | if (pathlib.Path(dest_dir) / sample_key).exists(): 72 | return 73 | # replicate the whole tree, creating symbolic links to the files instead of copies: fast, and saves disk space 74 | for sample_key, sample_path in data_sample_paths.items(): 75 | dest_path = pathlib.Path(dest_dir) / sample_key 76 | shutil.copytree(sample_path, dest_path, copy_function=os.symlink) 77 | 78 | 79 | class Subprocess(BaseSpawner): 80 | """Wrapper to execute a command in a python process.""" 81 | 82 | def __init__(self, local_worker_dir: pathlib.Path): 83 | super().__init__(local_worker_dir=local_worker_dir) 84 | 85 | def spawn( 86 | self, 87 | name, 88 | archive_path, 89 | command_args_tpl: typing.List[string.Template], 90 | data_sample_paths: typing.Optional[typing.Dict[str, pathlib.Path]], 91 | local_volumes, 92 | envs, 93 | ): 94 | """Spawn a python process (blocking).""" 95 | with tempfile.TemporaryDirectory(dir=self._local_worker_dir) as function_dir: 96 | with tempfile.TemporaryDirectory(dir=function_dir) as args_dir: 97 | function_dir = pathlib.Path(function_dir) 98 | args_dir = pathlib.Path(args_dir) 99 | uncompress(archive_path, function_dir) 100 | script_name, function_name = _get_entrypoint_from_dockerfile(function_dir) 101 | 102 | args_file = args_dir / "arguments.txt" 103 | 104 | py_command = [sys.executable, str(function_dir / script_name), f"@{args_file}"] 105 | py_command_args = _get_command_args(function_name, command_args_tpl, local_volumes) 106 | write_command_args_file(args_file, py_command_args) 107 | 108 | if data_sample_paths is not None and len(data_sample_paths) > 0: 109 | _symlink_data_samples(data_sample_paths, local_volumes[VOLUME_INPUTS]) 110 | 111 | # Catch the error and re-raise it to stay consistent with the Docker local backend 112 | # Don't capture the output, so that pdb remains usable 113 | try: 114 | subprocess.run(py_command, capture_output=False, check=True, cwd=function_dir, env=envs) 115 | except subprocess.CalledProcessError as e: 116 | raise ExecutionError(e) 117 | -------------------------------------------------------------------------------- /substra/sdk/backends/local/dal.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import shutil 4 | import tempfile 5 | import typing 6 | 7 | from substra.sdk import exceptions 8 | from substra.sdk import models 9 | from substra.sdk import schemas 10 | from substra.sdk.backends.local import db 11 | from substra.sdk.backends.remote import backend 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class DataAccess: 17 | """Data access layer. 18 | 19 | This is an intermediate layer between the backend and the local/remote data access.
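Assets are looked up in the local in-memory database first and, when not found there, fetched from the remote backend if one is configured (hybrid mode).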
20 | """ 21 | 22 | def __init__(self, remote_backend: typing.Optional[backend.Remote], local_worker_dir: pathlib.Path): 23 | self._db = db.get_db() 24 | self._remote = remote_backend 25 | self._tmp_dir = tempfile.TemporaryDirectory(prefix=str(local_worker_dir) + "/") 26 | 27 | @property 28 | def tmp_dir(self): 29 | return pathlib.Path(self._tmp_dir.name) 30 | 31 | def is_local(self, key: str, type_: schemas.Type): 32 | try: 33 | self._db.get(type_, key) 34 | return True 35 | except exceptions.NotFound: 36 | return False 37 | 38 | def _get_asset_content_filename(self, type_): 39 | if type_ == schemas.Type.Function: 40 | filename = "function.tar.gz" 41 | field_name = "archive" 42 | 43 | elif type_ == schemas.Type.Dataset: 44 | filename = "opener.py" 45 | field_name = "opener" 46 | 47 | else: 48 | raise ValueError(f"Cannot download this type of asset {type_}") 49 | 50 | return filename, field_name 51 | 52 | def login(self, username, password): 53 | if self._remote: 54 | self._remote.login(username, password) 55 | 56 | def logout(self): 57 | if self._remote: 58 | self._remote.logout() 59 | 60 | def add(self, asset): 61 | return self._db.add(asset) 62 | 63 | def remote_download(self, asset_type, url_field_path, key, destination): 64 | self._remote.download(asset_type, url_field_path, key, destination) 65 | 66 | def remote_download_model(self, key, destination_file): 67 | self._remote.download_model(key, destination_file) 68 | 69 | def get_remote_description(self, asset_type, key): 70 | return self._remote.describe(asset_type, key) 71 | 72 | def get_with_files(self, type_: schemas.Type, key: str): 73 | """Get the asset with files on the local disk for execution. 74 | This does not load the description as it is not required for execution. 75 | """ 76 | try: 77 | # Try to find the asset locally 78 | return self._db.get(type_, key) 79 | except exceptions.NotFound: 80 | if self._remote is not None: 81 | # if not found, try remotely 82 | filename, field_name = self._get_asset_content_filename(type_) 83 | asset = self._remote.get(type_, key) 84 | tmp_directory = self.tmp_dir / key 85 | asset_path = tmp_directory / filename 86 | 87 | if not tmp_directory.exists(): 88 | pathlib.Path.mkdir(tmp_directory) 89 | 90 | self._remote.download( 91 | type_, 92 | field_name + ".storage_address", 93 | key, 94 | asset_path, 95 | ) 96 | 97 | attr = getattr(asset, field_name) 98 | attr.storage_address = asset_path 99 | return asset 100 | raise 101 | 102 | def get(self, type_, key: str): 103 | try: 104 | # Try to find the asset locally 105 | return self._db.get(type_, key) 106 | except exceptions.NotFound: 107 | if self._remote is not None: 108 | return self._remote.get(type_, key) 109 | raise 110 | 111 | def get_performances(self, key: str) -> models.Performances: 112 | """Get the performances of a given compute. Return models.Performances() object 113 | easily convertible to dict, filled by the performances data of done tasks that output a performance. 
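Illustrative usage (assuming pandas is installed; ``cp_key`` is a placeholder): ``pandas.DataFrame(dal.get_performances(cp_key).model_dump())`` gives one row per (task, performance identifier) pair.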
114 | """ 115 | compute_plan = self.get(schemas.Type.ComputePlan, key) 116 | list_tasks = self.list( 117 | schemas.Type.Task, 118 | filters={"compute_plan_key": [key]}, 119 | order_by="rank", 120 | ascending=True, 121 | ) 122 | 123 | performances = models.Performances() 124 | 125 | for task in list_tasks: 126 | if task.status == models.ComputeTaskStatus.done: 127 | function = self.get(schemas.Type.Function, task.function.key) 128 | perf_identifiers = [ 129 | output.identifier for output in function.outputs if output.kind == schemas.AssetKind.performance 130 | ] 131 | outputs = self.list( 132 | schemas.Type.OutputAsset, {"compute_task_key": task.key, "identifier": perf_identifiers} 133 | ) 134 | for output in outputs: 135 | performances.compute_plan_key.append(compute_plan.key) 136 | performances.compute_plan_tag.append(compute_plan.tag) 137 | performances.compute_plan_status.append(compute_plan.status) 138 | performances.compute_plan_start_date.append(compute_plan.start_date) 139 | performances.compute_plan_end_date.append(compute_plan.end_date) 140 | performances.compute_plan_metadata.append(compute_plan.metadata) 141 | 142 | performances.worker.append(task.worker) 143 | performances.task_key.append(task.key) 144 | performances.task_rank.append(task.rank) 145 | try: 146 | round_idx = int(task.metadata.get("round_idx")) 147 | except TypeError: 148 | round_idx = None 149 | performances.round_idx.append(round_idx) 150 | performances.identifier.append(output.identifier) 151 | performances.performance.append(output.asset) 152 | 153 | return performances 154 | 155 | def list( 156 | self, type_: str, filters: typing.Dict[str, typing.List[str]], order_by: str = None, ascending: bool = False 157 | ): 158 | """Joins the results of the [local db](substra.sdk.backends.local.db.list) and the 159 | [remote db](substra.sdk.backends.rest_client.list) in hybrid mode. 160 | """ 161 | 162 | local_assets = self._db.list(type_=type_, filters=filters, order_by=order_by, ascending=ascending) 163 | 164 | remote_assets = [] 165 | if self._remote: 166 | try: 167 | remote_assets = self._remote.list( 168 | asset_type=type_, filters=filters, order_by=order_by, ascending=ascending 169 | ) 170 | except Exception as e: 171 | logger.info( 172 | f"Could not list assets from the remote platform:\n{e}. \ 173 | \nIf you are not logged to a remote platform, ignore this message." 174 | ) 175 | return local_assets + remote_assets 176 | 177 | def save_file(self, file_path: typing.Union[str, pathlib.Path], key: str): 178 | """Copy file or directory into the local temp dir to mimick 179 | the remote backend that saves the files given by the user. 
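Raises AlreadyExists if a file was already saved under this key, and InvalidRequest if the path is neither a file nor a directory.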
180 | """ 181 | tmp_directory = self.tmp_dir / key 182 | tmp_file = tmp_directory / pathlib.Path(file_path).name 183 | 184 | if not tmp_directory.exists(): 185 | pathlib.Path.mkdir(tmp_directory) 186 | 187 | if tmp_file.exists(): 188 | raise exceptions.AlreadyExists(f"File {tmp_file.name} already exists for asset {key}", 409) 189 | elif pathlib.Path(file_path).is_file(): 190 | shutil.copyfile(file_path, tmp_file) 191 | elif pathlib.Path(file_path).is_dir(): 192 | shutil.copytree(file_path, tmp_file) 193 | else: 194 | raise exceptions.InvalidRequest(f"Could not copy {file_path}", 400) 195 | return tmp_file 196 | 197 | def update(self, asset): 198 | self._db.update(asset) 199 | return 200 | -------------------------------------------------------------------------------- /substra/sdk/backends/local/db.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import typing 4 | 5 | from substra.sdk import exceptions 6 | from substra.sdk import models 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class InMemoryDb: 12 | """In memory data db.""" 13 | 14 | def __init__(self): 15 | # assets stored per type and per key 16 | self._data = collections.defaultdict(dict) 17 | 18 | def add(self, asset): 19 | """Add an asset.""" 20 | type_ = asset.__class__.type_ 21 | key = getattr(asset, "key", None) 22 | if not key: 23 | key = asset.id 24 | if key in self._data[type_]: 25 | raise exceptions.KeyAlreadyExistsError(f"The asset key {key} of type {type_} has already been used.") 26 | self._data[type_][key] = asset 27 | logger.info(f"{type_} with key '{key}' has been created.") 28 | 29 | return asset 30 | 31 | def get(self, type_, key: str): 32 | """Return asset.""" 33 | try: 34 | return self._data[type_][key] 35 | except KeyError: 36 | raise exceptions.NotFound(f"Wrong pk {key}", 404) 37 | 38 | def _match_asset(self, asset: models._Model, attribute: str, values: typing.Union[typing.Dict, typing.List]): 39 | """Checks if an asset attributes matches the given values. 
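For plain attributes, the match succeeds when the attribute's string representation is one of the given values.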
40 | For the metadata, it checks that all the given filters return True (AND condition)""" 41 | if attribute == "metadata": 42 | metadata_conditions = [] 43 | for value in values: 44 | if value["type"] == models.MetadataFilterType.exists: 45 | metadata_conditions.append(value["key"] in asset.metadata.keys()) 46 | 47 | elif asset.metadata.get(value["key"]) is None: 48 | # for is_equal and contains, if the key is not there then return False 49 | metadata_conditions.append(False) 50 | 51 | elif value["type"] == models.MetadataFilterType.is_equal: 52 | metadata_conditions.append(str(value["value"]) == str(asset.metadata[value["key"]])) 53 | 54 | elif value["type"] == models.MetadataFilterType.contains: 55 | metadata_conditions.append(str(value["value"]) in str(asset.metadata.get(value["key"]))) 56 | else: 57 | raise NotImplementedError 58 | 59 | return all(metadata_conditions) 60 | 61 | return str(getattr(asset, attribute)) in values 62 | 63 | def _filter_assets( 64 | self, db_assets: typing.List[models._Model], filters: typing.Dict[str, typing.List[str]] 65 | ) -> typing.List[models._Model]: 66 | """Return assets matching all the given filters""" 67 | 68 | matching_assets = [ 69 | asset 70 | for asset in db_assets 71 | if all(self._match_asset(asset, attribute, values) for attribute, values in filters.items()) 72 | ] 73 | return matching_assets 74 | 75 | def list( 76 | self, type_: str, filters: typing.Dict[str, typing.List[str]], order_by: str = None, ascending: bool = False 77 | ): 78 | """List assets by filters. 79 | 80 | Args: 81 | type_ (str): asset type. e.g. "function" 82 | filters (dict, optional): keys = attributes, values = list of values for this attribute. 83 | e.g. {"name": ["name1", "name2"]}; the values of a list are combined with an "OR". Defaults to None. 84 | order_by (str, optional): attribute name to order the results on. Defaults to None. 85 | e.g. "name" for an ordering on name. 86 | ascending (bool, optional): set to True for ascending order. Defaults to False (descending order).
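Example (illustrative): db.list(schemas.Type.Function, filters={"name": ["my_function"]}, order_by="creation_date") returns the functions named exactly "my_function", most recent first.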
87 | 88 | Returns: 89 | List[Dict] : a List of assets (dicts) 90 | """ 91 | # get all assets of this type 92 | assets = list(self._data[type_].values()) 93 | 94 | if filters: 95 | assets = self._filter_assets(assets, filters) 96 | if order_by: 97 | assets.sort(key=lambda x: getattr(x, order_by), reverse=(not ascending)) 98 | 99 | return assets 100 | 101 | def update(self, asset): 102 | type_ = asset.__class__.type_ 103 | key = asset.key 104 | 105 | if key not in self._data[type_]: 106 | raise exceptions.NotFound(f"Wrong pk {key}", 404) 107 | 108 | self._data[type_][key] = asset 109 | return 110 | 111 | 112 | db = InMemoryDb() 113 | 114 | 115 | def get_db(): 116 | return db 117 | -------------------------------------------------------------------------------- /substra/sdk/backends/local/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List 3 | 4 | from pydantic import Field 5 | 6 | from substra.sdk import models 7 | from substra.sdk import schemas 8 | 9 | 10 | class _TaskAssetLocal(schemas._PydanticConfig): 11 | key: str = Field(default_factory=uuid.uuid4) 12 | compute_task_key: str 13 | 14 | @staticmethod 15 | def allowed_filters() -> List[str]: 16 | return super().allowed_filters() + ["compute_task_key"] 17 | 18 | 19 | class OutputAssetLocal(models.OutputAsset, _TaskAssetLocal): 20 | pass 21 | 22 | 23 | class InputAssetLocal(models.InputAsset, _TaskAssetLocal): 24 | pass 25 | -------------------------------------------------------------------------------- /substra/sdk/backends/remote/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/substra/sdk/backends/remote/__init__.py -------------------------------------------------------------------------------- /substra/sdk/backends/remote/request_formatter.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def format_search_filters_for_remote(filters): 5 | formatted_filters = {} 6 | # do not process if no filters 7 | if filters is None: 8 | return formatted_filters 9 | 10 | for key in filters: 11 | # handle special cases name and metadata 12 | if key == "name": 13 | formatted_filters["match"] = filters[key] 14 | elif key == "metadata": 15 | formatted_filters["metadata"] = json.dumps(filters["metadata"]).replace(" ", "") 16 | 17 | else: 18 | # all other filters are formatted as a csv string without spaces 19 | values = ",".join(filters[key]) 20 | formatted_filters[key] = values.replace(" ", "") 21 | 22 | return formatted_filters 23 | 24 | 25 | def format_search_ordering_for_remote(order_by, ascending): 26 | if not ascending: 27 | return "-" + order_by 28 | return order_by 29 | -------------------------------------------------------------------------------- /substra/sdk/compute_plan.py: -------------------------------------------------------------------------------- 1 | from substra.sdk import exceptions 2 | from substra.sdk import graph 3 | from substra.sdk import schemas 4 | 5 | 6 | def _insert_into_graph(task_graph, task_id, in_model_ids): 7 | if task_id in task_graph: 8 | raise exceptions.InvalidRequest("Two tasks cannot have the same id.", 400) 9 | task_graph[task_id] = in_model_ids 10 | 11 | 12 | def get_dependency_graph(spec: schemas._BaseComputePlanSpec): 13 | """Get the task dependency graph and, for each type of task, a mapping table id/task.""" 14 | task_graph = {} 
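# task_graph maps each task id to the ids of the tasks it depends on; tasks (filled below) maps each task id to its TaskSpec.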
15 | tasks = {} 16 | 17 | if spec.tasks: 18 | for task in spec.tasks: 19 | _insert_into_graph( 20 | task_graph=task_graph, 21 | task_id=task.task_id, 22 | in_model_ids=[ 23 | input_ref.parent_task_key for input_ref in (task.inputs or []) if input_ref.parent_task_key 24 | ], 25 | ) 26 | tasks[task.task_id] = schemas.TaskSpec.from_compute_plan( 27 | compute_plan_key=spec.key, 28 | rank=None, 29 | spec=task, 30 | ) 31 | return task_graph, tasks 32 | 33 | 34 | def get_tasks(spec): 35 | """Returns compute plan tasks sorted by dependencies.""" 36 | 37 | # Create the dependency graph and get the dict of tasks by id 38 | task_graph, tasks = get_dependency_graph(spec) 39 | 40 | already_created_ids = set() 41 | # Here we get the pre-existing tasks and assign them the minimal rank 42 | for dependencies in task_graph.values(): 43 | for dependency_id in dependencies: 44 | if dependency_id not in task_graph: 45 | already_created_ids.add(dependency_id) 46 | 47 | # Compute the relative ranks of the new tasks (relatively to each other, these 48 | # are not their actual ranks in the compute plan) 49 | id_ranks = graph.compute_ranks(node_graph=task_graph, node_to_ignore=already_created_ids) 50 | 51 | # Sort the tasks by rank 52 | sorted_by_rank = sorted(id_ranks.items(), key=lambda item: item[1]) 53 | 54 | return [tasks[task_id] for task_id, _ in sorted_by_rank] 55 | -------------------------------------------------------------------------------- /substra/sdk/exceptions.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union 3 | 4 | import requests 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class SDKException(Exception): 10 | pass 11 | 12 | 13 | class LoadDataException(SDKException): 14 | pass 15 | 16 | 17 | class RequestException(SDKException): 18 | def __init__(self, msg, status_code): 19 | self.msg = msg 20 | self.status_code = status_code 21 | super().__init__(msg) 22 | 23 | @classmethod 24 | def from_request_exception(cls, request_exception): 25 | msg = None 26 | try: 27 | msg = request_exception.response.json()["detail"] 28 | msg = f"{request_exception}: {msg}" 29 | except Exception: 30 | msg = str(request_exception) 31 | 32 | try: 33 | status_code = request_exception.response.status_code 34 | except AttributeError: 35 | status_code = None 36 | 37 | return cls(msg, status_code) 38 | 39 | 40 | class ConnectionError(RequestException): 41 | pass 42 | 43 | 44 | class Timeout(RequestException): 45 | pass 46 | 47 | 48 | class HTTPError(RequestException): 49 | pass 50 | 51 | 52 | class InternalServerError(HTTPError): 53 | pass 54 | 55 | 56 | class GatewayUnavailable(HTTPError): 57 | pass 58 | 59 | 60 | class InvalidRequest(HTTPError): 61 | def __init__(self, msg, status_code, errors=None): 62 | super().__init__(msg, status_code) 63 | self.errors = errors 64 | 65 | @classmethod 66 | def from_request_exception(cls, request_exception): 67 | try: 68 | error = request_exception.response.json() 69 | except ValueError: 70 | error = request_exception.response 71 | 72 | get_method = getattr(error, "get", None) 73 | if callable(get_method): 74 | msg = get_method("detail", str(error)) 75 | else: 76 | msg = str(error) 77 | 78 | try: 79 | status_code = request_exception.response.status_code 80 | except AttributeError: 81 | status_code = None 82 | 83 | return cls(msg, status_code, error) 84 | 85 | 86 | class NotFound(HTTPError): 87 | pass 88 | 89 | 90 | class RequestTimeout(HTTPError): 91 | def __init__(self, key, status_code): 92 | 
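# Keep the key of the object whose operation timed out, so callers can inspect it or retry.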
self.key = key 93 | msg = f"Operation on object with key(s) '{key}' timed out." 94 | super().__init__(msg, status_code) 95 | 96 | @classmethod 97 | def from_request_exception(cls, request_exception): 98 | # parse response and fetch key 99 | r = request_exception.response.json() 100 | 101 | try: 102 | key = r["key"] if "key" in r else r["detail"].get("key") 103 | except (AttributeError, KeyError): 104 | # XXX this is the case when doing a POST query to update the 105 | # data manager for instance 106 | key = None 107 | 108 | return cls(key, request_exception.response.status_code) 109 | 110 | 111 | class AlreadyExists(HTTPError): 112 | def __init__(self, key, status_code): 113 | self.key = key 114 | msg = f"Object with key(s) '{key}' already exists." 115 | super().__init__(msg, status_code) 116 | 117 | @classmethod 118 | def from_request_exception(cls, request_exception): 119 | # parse response and fetch key 120 | r = request_exception.response.json() 121 | # XXX support list of keys; this could be the case when adding 122 | # a list of data samples through a single POST request 123 | if isinstance(r, list): 124 | key = [x["key"] for x in r] 125 | elif isinstance(r, dict): 126 | key = r.get("key", None) 127 | else: 128 | key = r 129 | 130 | return cls(key, request_exception.response.status_code) 131 | 132 | 133 | class InvalidResponse(SDKException): 134 | def __init__(self, response, msg): 135 | self.response = response 136 | super(InvalidResponse, self).__init__(msg) 137 | 138 | 139 | class AuthenticationError(HTTPError): 140 | pass 141 | 142 | 143 | class AuthorizationError(HTTPError): 144 | pass 145 | 146 | 147 | class BadLoginException(RequestException): 148 | """The server refused to log-in with these credentials""" 149 | 150 | pass 151 | 152 | 153 | class UsernamePasswordLoginDisabledException(RequestException): 154 | """The server disabled the endpoint, preventing the use of Client.login""" 155 | 156 | @classmethod 157 | def from_request_exception(cls, request_exception): 158 | base = super().from_request_exception(request_exception) 159 | return cls( 160 | base.msg 161 | + ( 162 | "\n\nAuthenticating with username/password is disabled.\n" 163 | "Log onto the frontend for your instance and generate a token there, " 164 | 'then use it in the Client(token="...") constructor: ' 165 | "https://docs.substra.org/en/stable/documentation/api_tokens_generation.html" 166 | ), 167 | base.status_code, 168 | ) 169 | 170 | 171 | def from_request_exception( 172 | e: requests.exceptions.RequestException, 173 | ) -> Union[RequestException, requests.exceptions.RequestException]: 174 | """ 175 | try turning an exception from the `requests` library into a Substra exception 176 | """ 177 | connection_error_mapping: dict[requests.exceptions.RequestException, RequestException] = { 178 | requests.exceptions.ConnectionError: ConnectionError, 179 | requests.exceptions.Timeout: Timeout, 180 | } 181 | for k, v in connection_error_mapping.items(): 182 | if isinstance(e, k): 183 | return v.from_request_exception(e) 184 | 185 | http_status_mapping: dict[int, RequestException] = { 186 | 400: InvalidRequest, 187 | 401: AuthenticationError, 188 | 403: AuthorizationError, 189 | 404: NotFound, 190 | 408: RequestTimeout, 191 | 409: AlreadyExists, 192 | 500: InternalServerError, 193 | 502: GatewayUnavailable, 194 | 503: GatewayUnavailable, 195 | 504: GatewayUnavailable, 196 | } 197 | if isinstance(e, requests.exceptions.HTTPError): 198 | logger.error(f"Requests error status {e.response.status_code}: {e.response.text}") 
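# Status codes missing from the mapping above fall back to the generic HTTPError.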
199 | return http_status_mapping.get(e.response.status_code, HTTPError).from_request_exception(e) 200 | 201 | return e 202 | 203 | 204 | class ConfigurationInfoError(SDKException): 205 | """ConfigurationInfoError""" 206 | 207 | pass 208 | 209 | 210 | class BadConfiguration(SDKException): 211 | """Bad configuration""" 212 | 213 | pass 214 | 215 | 216 | class UserException(SDKException): 217 | """User Exception""" 218 | 219 | pass 220 | 221 | 222 | class EmptyInModelException(SDKException): 223 | """No in_models when needed""" 224 | 225 | pass 226 | 227 | 228 | class ComputePlanKeyFormatError(Exception): 229 | """The given compute plan key has to respect the UUID format.""" 230 | 231 | pass 232 | 233 | 234 | class OrderingFormatError(Exception): 235 | """The given ordering parameter has to respect the expected format.""" 236 | 237 | pass 238 | 239 | 240 | class FilterFormatError(Exception): 241 | """The given filters have to respect the expected format.""" 242 | 243 | pass 244 | 245 | 246 | class NotAllowedFilterError(Exception): 247 | """The given filter is not available on this asset.""" 248 | 249 | pass 250 | 251 | 252 | class KeyAlreadyExistsError(Exception): 253 | """The asset key has already been used.""" 254 | 255 | pass 256 | 257 | 258 | class _TaskAssetError(Exception): 259 | """Base exception class for task asset errors""" 260 | 261 | def __init__(self, *, compute_task_key: str, identifier: str, message: str): 262 | self.compute_task_key = compute_task_key 263 | self.identifier = identifier 264 | self.message = message 265 | super().__init__(self.message) 266 | 267 | pass 268 | 269 | 270 | class TaskAssetNotFoundError(_TaskAssetError): 271 | """Exception raised when no task input/output asset was found for a specific task key and identifier""" 272 | 273 | def __init__(self, compute_task_key: str, identifier: str): 274 | message = f"No task asset found with {compute_task_key=} and {identifier=}" 275 | super().__init__(compute_task_key=compute_task_key, identifier=identifier, message=message) 276 | 277 | pass 278 | 279 | 280 | class TaskAssetMultipleFoundError(_TaskAssetError): 281 | """Exception raised when more than one task input/output asset was found for a specific task key and 282 | identifier""" 283 | 284 | def __init__(self, compute_task_key: str, identifier: str): 285 | message = f"Multiple task assets found with {compute_task_key=} and {identifier=}" 286 | super().__init__(compute_task_key=compute_task_key, identifier=identifier, message=message) 287 | 288 | 289 | class FutureError(Exception): 290 | """Error while waiting for a blocking operation to complete""" 291 | 292 | 293 | class FutureTimeoutError(FutureError): 294 | """Future execution timed out.""" 295 | 296 | 297 | class FutureFailureError(FutureError): 298 | """Future execution failed.""" 299 | -------------------------------------------------------------------------------- /substra/sdk/fs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from substra.sdk.hasher import Hasher 4 | 5 | _BLOCK_SIZE = 64 * 1024 6 | 7 | 8 | def hash_file(path): 9 | """Hash a file.""" 10 | hasher = Hasher() 11 | 12 | with open(path, "rb") as fp: 13 | while True: 14 | data = fp.read(_BLOCK_SIZE) 15 | if not data: 16 | break 17 | hasher.update(data) 18 | return hasher.compute() 19 | 20 | 21 | def hash_directory(path, followlinks=False): 22 | """Hash a directory.""" 23 | 24 | if not os.path.isdir(path): 25 | raise TypeError(f"{path} is not a directory.") 26 | 27 | hash_values = [] 28 | for
root, dirs, files in os.walk(path, topdown=True, followlinks=followlinks): 29 | dirs.sort() 30 | files.sort() 31 | 32 | for fname in files: 33 | hash_values.append(hash_file(os.path.join(root, fname))) 34 | 35 | return Hasher(values=sorted(hash_values)).compute() 36 | -------------------------------------------------------------------------------- /substra/sdk/graph.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from substra.sdk import exceptions 4 | 5 | 6 | def _get_inverted_node_graph(node_graph, node_to_ignore): 7 | """Get the graph {node_id: nodes that depend on this one} 8 | Also this graph does not contain the nodes to ignore 9 | """ 10 | inverted = dict() 11 | for node, dependencies in node_graph.items(): 12 | if node not in node_to_ignore: 13 | for dependency in dependencies: 14 | if dependency not in node_to_ignore: 15 | inverted.setdefault(dependency, []) 16 | inverted[dependency].append(node) 17 | return inverted 18 | 19 | 20 | def _breadth_first_traversal_rank( 21 | ranks: typing.Dict[str, int], inverted_node_graph: typing.Dict[str, typing.List[str]] 22 | ): 23 | edges = set() 24 | queue = [node for node in ranks] 25 | visited = set(queue) 26 | 27 | if len(queue) == 0: 28 | raise exceptions.InvalidRequest("missing dependency among inModels IDs, circular dependency found", 400) 29 | 30 | while len(queue) > 0: 31 | current_node = queue.pop(0) 32 | for child in inverted_node_graph.get(current_node, []): 33 | new_child_rank = max(ranks[current_node] + 1, ranks.get(child, -1)) 34 | 35 | if new_child_rank != ranks.get(child, -1): 36 | # either the child has never been visited 37 | # or its rank has been updated and we must visit again 38 | ranks[child] = new_child_rank 39 | visited.add(child) 40 | queue.append(child) 41 | 42 | # Cycle detection 43 | edge = (current_node, child) 44 | if (edge[1], edge[0]) in edges: 45 | raise exceptions.InvalidRequest( 46 | f"missing dependency among inModels IDs, \ 47 | circular dependency between {edge[0]} and {edge[1]}", 48 | 400, 49 | ) 50 | else: 51 | edges.add(edge) 52 | return ranks 53 | 54 | 55 | def compute_ranks( 56 | node_graph: typing.Dict[str, typing.List[str]], 57 | node_to_ignore: typing.Set[str] = None, 58 | ranks: typing.Dict[str, int] = None, 59 | ) -> typing.Dict[str, int]: 60 | """Compute the ranks of the nodes in the graph. 61 | 62 | Args: 63 | node_graph (typing.Dict[str, typing.List[str]]): 64 | Dict {node_id: list of nodes it depends on}. 65 | Node graph keys must not contain any node to ignore. 66 | node_to_ignore (typing.Set[str], optional): List of nodes to ignore. 67 | Defaults to None. 68 | ranks (typing.Dict[str, int]): Already computed ranks. Defaults to None. 
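Example (illustrative): compute_ranks({"a": [], "b": ["a"], "c": ["b"]}) returns {"a": 0, "b": 1, "c": 2}.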
69 | 70 | Raises: 71 | exceptions.InvalidRequest: If the node graph contains a cycle 72 | 73 | Returns: 74 | typing.Dict[str, int]: Dict { node_id : rank } 75 | """ 76 | ranks = ranks or dict() 77 | node_to_ignore = node_to_ignore or set() 78 | 79 | if len(node_graph) == 0: 80 | return dict() 81 | 82 | extra_nodes = set(node_graph.keys()).intersection(node_to_ignore) 83 | if len(extra_nodes) > 0: 84 | raise ValueError(f"node_graph keys should not contain any node to ignore: {extra_nodes}") 85 | 86 | inverted_node_graph = _get_inverted_node_graph(node_graph, node_to_ignore) 87 | 88 | # Assign rank 0 to nodes without deps 89 | for node, dependencies in node_graph.items(): 90 | if node not in node_to_ignore: 91 | actual_deps = [dep for dep in dependencies if dep not in node_to_ignore] 92 | if len(actual_deps) == 0: 93 | ranks[node] = 0 94 | 95 | ranks = _breadth_first_traversal_rank(ranks=ranks, inverted_node_graph=inverted_node_graph) 96 | 97 | return ranks 98 | -------------------------------------------------------------------------------- /substra/sdk/hasher.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | 4 | class Hasher: 5 | def __init__(self, values=None): 6 | self._h = hashlib.sha256() 7 | if values: 8 | for v in values: 9 | self.update(v) 10 | 11 | def update(self, v): 12 | if isinstance(v, str): 13 | v = v.encode("utf-8") 14 | self._h.update(v) 15 | 16 | def compute(self): 17 | return self._h.hexdigest() 18 | -------------------------------------------------------------------------------- /substra/sdk/models.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import enum 3 | import re 4 | from datetime import datetime 5 | from typing import ClassVar 6 | from typing import Dict 7 | from typing import List 8 | from typing import Optional 9 | from typing import Type 10 | from typing import Union 11 | 12 | import pydantic 13 | from pydantic import AnyUrl 14 | from pydantic import ConfigDict 15 | from pydantic import DirectoryPath 16 | from pydantic import FilePath 17 | from pydantic.fields import Field 18 | 19 | from substra.sdk import schemas 20 | 21 | # The remote can return an URL or an empty string for paths 22 | UriPath = Union[FilePath, AnyUrl, str] 23 | CAMEL_TO_SNAKE_PATTERN = re.compile(r"(.)([A-Z][a-z]+)") 24 | CAMEL_TO_SNAKE_PATTERN_2 = re.compile(r"([a-z0-9])([A-Z])") 25 | 26 | 27 | class MetadataFilterType(str, enum.Enum): 28 | is_equal = "is" 29 | contains = "contains" 30 | exists = "exists" 31 | 32 | 33 | class ComputeTaskStatus(str, enum.Enum): 34 | """Status of the task""" 35 | 36 | unknown = "STATUS_UNKNOWN" 37 | building = "STATUS_BUILDING" 38 | executing = "STATUS_EXECUTING" 39 | done = "STATUS_DONE" 40 | failed = "STATUS_FAILED" 41 | waiting_for_executor_slot = "STATUS_WAITING_FOR_EXECUTOR_SLOT" 42 | waiting_for_parent_tasks = "STATUS_WAITING_FOR_PARENT_TASKS" 43 | waiting_for_builder_slot = "STATUS_WAITING_FOR_BUILDER_SLOT" 44 | canceled = "STATUS_CANCELED" 45 | 46 | 47 | class ComputePlanStatus(str, enum.Enum): 48 | """Status of the compute plan""" 49 | 50 | unknown = "PLAN_STATUS_UNKNOWN" 51 | doing = "PLAN_STATUS_DOING" 52 | done = "PLAN_STATUS_DONE" 53 | failed = "PLAN_STATUS_FAILED" 54 | created = "PLAN_STATUS_CREATED" 55 | canceled = "PLAN_STATUS_CANCELED" 56 | 57 | 58 | class FunctionStatus(str, enum.Enum): 59 | """Status of the function""" 60 | 61 | unknown = "FUNCTION_STATUS_UNKNOWN" 62 | waiting = "FUNCTION_STATUS_WAITING" 63 | building = 
"FUNCTION_STATUS_BUILDING" 64 | ready = "FUNCTION_STATUS_READY" 65 | failed = "FUNCTION_STATUS_FAILED" 66 | canceled = "FUNCTION_STATUS_CANCELED" 67 | 68 | 69 | class TaskErrorType(str, enum.Enum): 70 | """Types of errors that can occur in a task""" 71 | 72 | build = "BUILD_ERROR" 73 | execution = "EXECUTION_ERROR" 74 | internal = "INTERNAL_ERROR" 75 | 76 | 77 | class OrderingFields(str, enum.Enum): 78 | """Model fields ordering is allowed on for list""" 79 | 80 | creation_date = "creation_date" 81 | start_date = "start_date" 82 | end_date = "end_date" 83 | 84 | @classmethod 85 | def __contains__(cls, item): 86 | try: 87 | cls(item) 88 | except ValueError: 89 | return False 90 | else: 91 | return True 92 | 93 | 94 | class Permission(schemas._PydanticConfig): 95 | """Permissions of a task""" 96 | 97 | public: bool 98 | authorized_ids: List[str] 99 | 100 | 101 | class Permissions(schemas._PydanticConfig): 102 | """Permissions structure stored in various asset types.""" 103 | 104 | process: Permission 105 | 106 | 107 | class _Model(schemas._PydanticConfig, abc.ABC): 108 | """Asset creation specification base class.""" 109 | 110 | # pretty print 111 | def __str__(self): 112 | return self.model_dump_json(indent=4) 113 | 114 | def __repr__(self): 115 | return self.model_dump_json(indent=4) 116 | 117 | @staticmethod 118 | def allowed_filters() -> List[str]: 119 | """allowed fields to filter on""" 120 | return [] 121 | 122 | 123 | class DataSample(_Model): 124 | """Data sample""" 125 | 126 | key: str 127 | owner: str 128 | data_manager_keys: Optional[List[str]] = None 129 | path: Optional[DirectoryPath] = None 130 | creation_date: datetime 131 | 132 | type_: ClassVar[str] = schemas.Type.DataSample 133 | 134 | @staticmethod 135 | def allowed_filters() -> List[str]: 136 | return ["key", "owner", "compute_plan_key", "function_key", "dataset_key"] 137 | 138 | 139 | class _File(schemas._PydanticConfig): 140 | """File as stored in the models""" 141 | 142 | checksum: str 143 | storage_address: UriPath 144 | 145 | 146 | class Dataset(_Model): 147 | """Dataset asset""" 148 | 149 | key: str 150 | name: str 151 | owner: str 152 | permissions: Permissions 153 | data_sample_keys: List[str] = [] 154 | opener: _File 155 | description: _File 156 | metadata: Dict[str, str] 157 | creation_date: datetime 158 | logs_permission: Permission 159 | 160 | type_: ClassVar[str] = schemas.Type.Dataset 161 | 162 | @staticmethod 163 | def allowed_filters() -> List[str]: 164 | return ["key", "name", "owner", "permissions", "compute_plan_key", "function_key", "data_sample_key"] 165 | 166 | 167 | class FunctionInput(_Model): 168 | identifier: str 169 | kind: schemas.AssetKind 170 | optional: bool 171 | multiple: bool 172 | 173 | 174 | class FunctionOutput(_Model): 175 | identifier: str 176 | kind: schemas.AssetKind 177 | multiple: bool 178 | 179 | 180 | class Function(_Model): 181 | key: str 182 | name: str 183 | owner: str 184 | permissions: Permissions 185 | metadata: Dict[str, str] 186 | creation_date: datetime 187 | inputs: List[FunctionInput] 188 | outputs: List[FunctionOutput] 189 | status: FunctionStatus 190 | 191 | description: _File 192 | archive: _File 193 | 194 | type_: ClassVar[str] = schemas.Type.Function 195 | 196 | @staticmethod 197 | def allowed_filters() -> List[str]: 198 | return ["key", "name", "owner", "permissions", "compute_plan_key", "dataset_key", "data_sample_key"] 199 | 200 | @pydantic.field_validator("inputs", mode="before") 201 | @classmethod 202 | def dict_input_to_list(cls, v): 203 | if isinstance(v, 
dict): 204 | # Transform the inputs dict to a list 205 | return [ 206 | FunctionInput( 207 | identifier=identifier, 208 | kind=function_input["kind"], 209 | optional=function_input["optional"], 210 | multiple=function_input["multiple"], 211 | ) 212 | for identifier, function_input in v.items() 213 | ] 214 | else: 215 | return v 216 | 217 | @pydantic.field_validator("outputs", mode="before") 218 | @classmethod 219 | def dict_output_to_list(cls, v): 220 | if isinstance(v, dict): 221 | # Transform the outputs dict to a list 222 | return [ 223 | FunctionOutput( 224 | identifier=identifier, kind=function_output["kind"], multiple=function_output["multiple"] 225 | ) 226 | for identifier, function_output in v.items() 227 | ] 228 | else: 229 | return v 230 | 231 | 232 | class InModel(schemas._PydanticConfig): 233 | """In model of a task""" 234 | 235 | checksum: str 236 | storage_address: UriPath 237 | 238 | 239 | class OutModel(schemas._PydanticConfig): 240 | """Out model of a task""" 241 | 242 | key: str 243 | compute_task_key: str 244 | address: Optional[InModel] = None 245 | permissions: Permissions 246 | owner: str 247 | creation_date: datetime 248 | 249 | type_: ClassVar[str] = schemas.Type.Model 250 | 251 | @staticmethod 252 | def allowed_filters() -> List[str]: 253 | return ["key", "compute_task_key", "owner", "permissions"] 254 | 255 | 256 | class InputRef(schemas._PydanticConfig): 257 | identifier: str 258 | asset_key: Optional[str] = None 259 | parent_task_key: Optional[str] = None 260 | parent_task_output_identifier: Optional[str] = None 261 | 262 | # either (asset_key) or (parent_task_key, parent_task_output_identifier) must be specified 263 | _check_asset_key_or_parent_ref = pydantic.model_validator(mode="before")(schemas.check_asset_key_or_parent_ref) 264 | 265 | 266 | class ComputeTaskOutput(schemas._PydanticConfig): 267 | """Specification of a compute task input""" 268 | 269 | permissions: Permissions 270 | is_transient: bool = Field(False, alias="transient") 271 | model_config = ConfigDict(populate_by_name=True) 272 | 273 | 274 | class Task(_Model): 275 | key: str 276 | function: Function 277 | owner: str 278 | compute_plan_key: str 279 | metadata: Dict[str, str] 280 | status: ComputeTaskStatus 281 | worker: str 282 | rank: Optional[int] = None 283 | tag: str 284 | creation_date: datetime 285 | start_date: Optional[datetime] = None 286 | end_date: Optional[datetime] = None 287 | error_type: Optional[TaskErrorType] = None 288 | inputs: List[InputRef] 289 | outputs: Dict[str, ComputeTaskOutput] 290 | 291 | type_: ClassVar[Type] = schemas.Type.Task 292 | 293 | @staticmethod 294 | def allowed_filters() -> List[str]: 295 | return [ 296 | "key", 297 | "owner", 298 | "worker", 299 | "rank", 300 | "status", 301 | "metadata", 302 | "compute_plan_key", 303 | "function_key", 304 | ] 305 | 306 | 307 | Task.model_rebuild() 308 | 309 | 310 | class ComputePlan(_Model): 311 | """ComputePlan""" 312 | 313 | key: str 314 | tag: str 315 | name: str 316 | owner: str 317 | metadata: Dict[str, str] 318 | task_count: int = 0 319 | waiting_builder_slot_count: int = 0 320 | building_count: int = 0 321 | waiting_parent_tasks_count: int = 0 322 | waiting_executor_slot_count: int = 0 323 | executing_count: int = 0 324 | canceled_count: int = 0 325 | failed_count: int = 0 326 | done_count: int = 0 327 | failed_task_key: Optional[str] = None 328 | status: ComputePlanStatus 329 | creation_date: datetime 330 | start_date: Optional[datetime] = None 331 | end_date: Optional[datetime] = None 332 | estimated_end_date: 
Optional[datetime] = None 333 | duration: Optional[int] = None 334 | creator: Optional[str] = None 335 | 336 | type_: ClassVar[str] = schemas.Type.ComputePlan 337 | 338 | @staticmethod 339 | def allowed_filters() -> List[str]: 340 | return [ 341 | "key", 342 | "name", 343 | "owner", 344 | "worker", 345 | "status", 346 | "metadata", 347 | "function_key", 348 | "dataset_key", 349 | "data_sample_key", 350 | ] 351 | 352 | 353 | class Performances(_Model): 354 | """Performances of the different compute tasks of a compute plan""" 355 | 356 | compute_plan_key: List[str] = [] 357 | compute_plan_tag: List[str] = [] 358 | compute_plan_status: List[str] = [] 359 | compute_plan_start_date: List[datetime] = [] 360 | compute_plan_end_date: List[datetime] = [] 361 | compute_plan_metadata: List[dict] = [] 362 | worker: List[str] = [] 363 | task_key: List[str] = [] 364 | task_rank: List[int] = [] 365 | round_idx: List[int] = [] 366 | identifier: List[str] = [] 367 | performance: List[float] = [] 368 | 369 | 370 | class Organization(schemas._PydanticConfig): 371 | """Organization""" 372 | 373 | id: str 374 | is_current: bool 375 | creation_date: datetime 376 | 377 | type_: ClassVar[str] = schemas.Type.Organization 378 | 379 | 380 | class OrganizationInfoConfig(schemas._PydanticConfig, extra="allow"): 381 | model_config = ConfigDict(protected_namespaces=()) 382 | model_export_enabled: bool 383 | 384 | 385 | class OrganizationInfo(schemas._PydanticConfig): 386 | host: AnyUrl 387 | organization_id: str 388 | organization_name: str 389 | config: OrganizationInfoConfig 390 | channel: str 391 | version: str 392 | orchestrator_version: str 393 | 394 | 395 | class _TaskAsset(schemas._PydanticConfig): 396 | kind: str 397 | identifier: str 398 | 399 | @staticmethod 400 | def allowed_filters() -> List[str]: 401 | return ["identifier", "kind"] 402 | 403 | 404 | class InputAsset(_TaskAsset): 405 | asset: Union[Dataset, DataSample, OutModel] 406 | type_: ClassVar[str] = schemas.Type.InputAsset 407 | 408 | 409 | class OutputAsset(_TaskAsset): 410 | asset: Union[float, OutModel] 411 | type_: ClassVar[str] = schemas.Type.OutputAsset 412 | 413 | # Deal with remote returning the actual performance object 414 | @pydantic.field_validator("asset", mode="before") 415 | @classmethod 416 | def convert_remote_performance(cls, value, values): 417 | if values.data.get("kind") == schemas.AssetKind.performance and isinstance(value, dict): 418 | return value.get("performance_value") 419 | 420 | return value 421 | 422 | 423 | SCHEMA_TO_MODEL = { 424 | schemas.Type.Task: Task, 425 | schemas.Type.Function: Function, 426 | schemas.Type.ComputePlan: ComputePlan, 427 | schemas.Type.DataSample: DataSample, 428 | schemas.Type.Dataset: Dataset, 429 | schemas.Type.Organization: Organization, 430 | schemas.Type.Model: OutModel, 431 | schemas.Type.InputAsset: InputAsset, 432 | schemas.Type.OutputAsset: OutputAsset, 433 | schemas.Type.FunctionOutput: FunctionOutput, 434 | schemas.Type.FunctionInput: FunctionInput, 435 | } 436 | -------------------------------------------------------------------------------- /substra/sdk/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import copy 3 | import functools 4 | import io 5 | import logging 6 | import ntpath 7 | import os 8 | import re 9 | import time 10 | import uuid 11 | import zipfile 12 | 13 | from substra.sdk import exceptions 14 | from substra.sdk import models 15 | 16 | 17 | def path_leaf(path): 18 | head, tail = ntpath.split(path) 19 
| return tail or ntpath.basename(head) 20 | 21 | 22 | @contextlib.contextmanager 23 | def extract_files(data, file_attributes): 24 | data = copy.deepcopy(data) 25 | 26 | paths = {} 27 | for attr in file_attributes: 28 | try: 29 | paths[attr] = data[attr] 30 | except KeyError: 31 | raise exceptions.LoadDataException(f"The '{attr}' attribute is missing.") 32 | del data[attr] 33 | 34 | files = {} 35 | for k, f in paths.items(): 36 | if not os.path.exists(f): 37 | raise exceptions.LoadDataException(f"The '{k}' attribute file ({f}) does not exist.") 38 | files[k] = open(f, "rb") 39 | 40 | try: 41 | yield (data, files) 42 | finally: 43 | for f in files.values(): 44 | f.close() 45 | 46 | 47 | def zip_folder(fp, path, followlinks=False): 48 | zipf = zipfile.ZipFile(fp, "w", zipfile.ZIP_DEFLATED) 49 | for root, _, files in os.walk(path, followlinks=followlinks): 50 | for f in files: 51 | abspath = os.path.join(root, f) 52 | archive_path = os.path.relpath(abspath, start=path) 53 | zipf.write(abspath, arcname=archive_path) 54 | zipf.close() 55 | 56 | 57 | def zip_folder_in_memory(path, followlinks=False): 58 | fp = io.BytesIO() 59 | zip_folder(fp, path, followlinks=followlinks) 60 | fp.seek(0) 61 | return fp 62 | 63 | 64 | @contextlib.contextmanager 65 | def extract_data_sample_files(data, followlinks=False): 66 | # handle the data sample specific cases: a single "path" or a list of "paths" 67 | data = copy.deepcopy(data) 68 | 69 | folders = {} 70 | if data.get("path"): 71 | attr = "path" 72 | folders[attr] = data[attr] 73 | del data[attr] 74 | 75 | if data.get("paths"): # field is set and is not None/empty 76 | for p in data["paths"]: 77 | folders[path_leaf(p)] = p 78 | del data["paths"] 79 | 80 | files = {} 81 | for k, f in folders.items(): 82 | if not os.path.isdir(f): 83 | raise exceptions.LoadDataException(f"Path '{f}' is not an existing directory") 84 | files[k] = zip_folder_in_memory(f, followlinks=followlinks) 85 | 86 | try: 87 | yield (data, files) 88 | finally: 89 | for f in files.values(): 90 | f.close() 91 | 92 | 93 | def _check_metadata_search_filter(filter): 94 | if not isinstance(filter, dict): 95 | raise exceptions.FilterFormatError( 96 | "Cannot load filters. Please review the documentation, metadata filters should be a list of dicts, " 97 | "but one of the passed elements is not a dict." 98 | ) 99 | 100 | if "key" not in filter.keys() or "type" not in filter.keys(): 101 | raise exceptions.FilterFormatError("Each metadata filter must contain both `key` and `type` as keys.") 102 | 103 | if filter["type"] not in ("is", "contains", "exists"): 104 | raise exceptions.FilterFormatError( 105 | "Each metadata filter `type` field value must be `is`, `contains` or `exists`." 106 | ) 107 | 108 | if filter["type"] in ("is", "contains"): 109 | if "value" not in filter.keys(): 110 | raise exceptions.FilterFormatError( 111 | "For each metadata filter, if the `type` value is `is` or `contains`, the filter should also contain the " 112 | "`value` key." 113 | ) 114 | 115 | if not isinstance(filter.get("value"), str): 116 | raise exceptions.FilterFormatError( 117 | "For each metadata filter, if a `value` is passed, it should be a string." 118 | ) 119 | 120 | 121 | def _check_metadata_search_filters(filters): 122 | if not isinstance(filters, list): 123 | raise exceptions.FilterFormatError( 124 | "Cannot load filters. Please review the documentation, metadata filters should be a list of dicts."
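# NOTE: illustrative sketch, not part of the original module. A well-formed
# `metadata` filter, per the checks above, is a list of dicts; the values
# below are hypothetical:
#   [{"key": "epochs", "type": "is", "value": "10"},
#    {"key": "team", "type": "exists"}]
# `value` is required -- and must be a str -- only when `type` is "is" or "contains".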
125 | ) 126 | 127 | for filter in filters: 128 | _check_metadata_search_filter(filter) 129 | 130 | 131 | def check_and_format_search_filters(asset_type, filters): # noqa: C901 132 | # do not check if no filters 133 | if filters is None: 134 | return filters 135 | 136 | # check filters structure 137 | if not isinstance(filters, dict): 138 | raise exceptions.FilterFormatError( 139 | "Cannot load filters. Please review the documentation, filters should be a dict" 140 | ) 141 | 142 | # retrieve the asset fields that filtering is allowed on 143 | allowed_filters = models.SCHEMA_TO_MODEL[asset_type].allowed_filters() 144 | 145 | # for each attribute (key) to filter on 146 | for key in filters: 147 | # check that key is a valid filter 148 | if key not in allowed_filters: 149 | raise exceptions.NotAllowedFilterError( 150 | f"Cannot filter on {key}. Please review the documentation, filtering is allowed only on {allowed_filters}" 151 | ) 152 | elif key == "name": 153 | if not isinstance(filters[key], str): 154 | raise exceptions.FilterFormatError( 155 | """Cannot load filters. Please review the documentation: the 'name' filter is a partial match in remote, 156 | an exact match in local, and its value should be a str""" 157 | ) 158 | elif key == "metadata": 159 | _check_metadata_search_filters(filters[key]) 160 | 161 | # all other filter values should be a list; throw an error if not 162 | elif not isinstance(filters[key], list): 163 | raise exceptions.FilterFormatError( 164 | "Cannot load filters. Please review the documentation, filter values should be a list" 165 | ) 166 | # handle the default case (list) 167 | else: 168 | # convert all values to str: needed for rank, where users may pass ints instead of strings 169 | filters[key] = [str(v) for v in filters[key]] 170 | 171 | return filters 172 | 173 | 174 | def check_search_ordering(order_by): 175 | if order_by is None: 176 | return 177 | elif not models.OrderingFields.__contains__(order_by): 178 | raise exceptions.OrderingFormatError( 179 | f"Please review the documentation, ordering is available only on {list(models.OrderingFields.__members__)}" 180 | ) 181 | 182 | 183 | def retry_on_exception(exceptions, timeout=300): 184 | """Retry function in case of exception(s).
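# NOTE: illustrative comment, not part of the original docstring. The wrapper
# below retries with exponential backoff -- waits of 1s, 2s, 4s, 8s, ... --
# and re-raises once the elapsed time exceeds `timeout` (pass `timeout=False`
# to retry forever). A minimal hedged sketch, assuming a flaky callable
# `fetch` defined elsewhere:
#   safe_fetch = retry_on_exception(exceptions=(ConnectionError,), timeout=60)(fetch)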
185 | 186 | Args: 187 | exceptions (tuple): exception class, or tuple of exception classes, that triggers a retry 188 | timeout (int, optional): give up and re-raise after this many seconds; pass ``False`` to retry forever 189 | 190 | Example: 191 | ```python 192 | from substra.sdk import exceptions, retry_on_exception 193 | 194 | def my_function(arg1, arg2): 195 | pass 196 | 197 | retry = retry_on_exception( 198 | exceptions=(exceptions.RequestTimeout,), 199 | timeout=300, 200 | ) 201 | retry(my_function)(arg1, arg2) 202 | ``` 203 | """ 204 | 205 | def _retry(f): 206 | @functools.wraps(f) 207 | def wrapper(*args, **kwargs): 208 | delay = 1 209 | backoff = 2 210 | tstart = time.time() 211 | 212 | while True: 213 | try: 214 | return f(*args, **kwargs) 215 | 216 | except exceptions: 217 | if timeout is not False and time.time() - tstart > timeout: 218 | raise 219 | logging.warning(f"Function {f.__name__} failed: retrying in {delay}s") 220 | time.sleep(delay) 221 | delay *= backoff 222 | 223 | return wrapper 224 | 225 | return _retry 226 | 227 | 228 | def response_get_destination_filename(response): 229 | """Get the filename from the Content-Disposition header.""" 230 | disposition = response.headers.get("content-disposition") 231 | if not disposition: 232 | return None 233 | filenames = re.findall("filename=(.+)", disposition) 234 | if not filenames: 235 | return None 236 | filename = filenames[0] 237 | filename = filename.strip("'\"") 238 | return filename 239 | 240 | 241 | def is_valid_uuid(value): 242 | try: 243 | uuid.UUID(value) 244 | return True 245 | except ValueError: 246 | return False 247 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import substra 4 | from substra.sdk.schemas import FunctionSpec 5 | 6 | from .
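# NOTE: illustrative sketch for `response_get_destination_filename`, defined
# in substra/sdk/utils.py above; the header value is hypothetical:
#   'attachment; filename="model.tar.gz"'  ->  'model.tar.gz'
# The regex `filename=(.+)` captures everything after `filename=`, then the
# surrounding quotes are stripped off the captured value.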
import data_factory 7 | from .fl_interface import FLFunctionInputs 8 | from .fl_interface import FLFunctionOutputs 9 | from .fl_interface import FunctionCategory 10 | 11 | 12 | def pytest_configure(config): 13 | config.addinivalue_line( 14 | "markers", 15 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def client(tmpdir): 21 | c = substra.Client(url="http://foo.io", backend_type="remote", token="foo") 22 | return c 23 | 24 | 25 | @pytest.fixture 26 | def workdir(tmp_path): 27 | d = tmp_path / "substra-cli" 28 | d.mkdir() 29 | return d 30 | 31 | 32 | @pytest.fixture 33 | def dataset_query(tmpdir): 34 | opener_path = tmpdir / "opener.py" 35 | opener_path.write_text("raise ValueError()", encoding="utf-8") 36 | 37 | desc_path = tmpdir / "description.md" 38 | desc_path.write_text("#Hello world", encoding="utf-8") 39 | 40 | return { 41 | "name": "dataset_name", 42 | "data_opener": str(opener_path), 43 | "description": str(desc_path), 44 | "permissions": { 45 | "public": True, 46 | "authorized_ids": [], 47 | }, 48 | "logs_permission": { 49 | "public": True, 50 | "authorized_ids": [], 51 | }, 52 | } 53 | 54 | 55 | @pytest.fixture 56 | def metric_query(tmpdir): 57 | metrics_path = tmpdir / "metrics.zip" 58 | metrics_path.write_text("foo archive", encoding="utf-8") 59 | 60 | desc_path = tmpdir / "description.md" 61 | desc_path.write_text("#Hello world", encoding="utf-8") 62 | 63 | return { 64 | "name": "metrics_name", 65 | "file": str(metrics_path), 66 | "description": str(desc_path), 67 | "permissions": { 68 | "public": True, 69 | "authorized_ids": [], 70 | }, 71 | } 72 | 73 | 74 | @pytest.fixture 75 | def function_query(tmpdir): 76 | function_file_path = tmpdir / "function.tar.gz" 77 | function_file_path.write(b"tar gz archive") 78 | 79 | desc_path = tmpdir / "description.md" 80 | desc_path.write_text("#Hello world", encoding="utf-8") 81 | 82 | function_category = FunctionCategory.simple 83 | 84 | return FunctionSpec( 85 | name="function_name", 86 | inputs=FLFunctionInputs[function_category], 87 | outputs=FLFunctionOutputs[function_category], 88 | description=str(desc_path), 89 | file=str(function_file_path), 90 | permissions={ 91 | "public": True, 92 | "authorized_ids": [], 93 | }, 94 | ) 95 | 96 | 97 | @pytest.fixture 98 | def data_sample_query(tmpdir): 99 | data_sample_dir_path = tmpdir / "data_sample_0" 100 | data_sample_file_path = data_sample_dir_path / "data.txt" 101 | data_sample_file_path.write_text("Hello world 0", encoding="utf-8", ensure=True) 102 | 103 | return { 104 | "path": str(data_sample_dir_path), 105 | "data_manager_keys": ["42"], 106 | } 107 | 108 | 109 | @pytest.fixture 110 | def data_samples_query(tmpdir): 111 | nb = 3 112 | paths = [] 113 | for i in range(nb): 114 | data_sample_dir_path = tmpdir / f"data_sample_{i}" 115 | data_sample_file_path = data_sample_dir_path / "data.txt" 116 | data_sample_file_path.write_text(f"Hello world {i}", encoding="utf-8", ensure=True) 117 | 118 | paths.append(str(data_sample_dir_path)) 119 | 120 | return { 121 | "paths": paths, 122 | "data_manager_keys": ["42"], 123 | } 124 | 125 | 126 | @pytest.fixture(scope="session") 127 | def asset_factory(): 128 | return data_factory.AssetsFactory("test_debug") 129 | 130 | 131 | @pytest.fixture() 132 | def data_sample(asset_factory): 133 | return asset_factory.create_data_sample() 134 | -------------------------------------------------------------------------------- /tests/fl_interface.py: 
-------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from substra.sdk.schemas import AssetKind 4 | from substra.sdk.schemas import ComputeTaskOutputSpec 5 | from substra.sdk.schemas import FunctionInputSpec 6 | from substra.sdk.schemas import FunctionOutputSpec 7 | from substra.sdk.schemas import InputRef 8 | from substra.sdk.schemas import Permissions 9 | 10 | PUBLIC_PERMISSIONS = Permissions(public=True, authorized_ids=[]) 11 | 12 | 13 | class FunctionCategory(str, Enum): 14 | """Function category""" 15 | 16 | unknown = "FUNCTION_UNKNOWN" 17 | simple = "FUNCTION_SIMPLE" 18 | composite = "FUNCTION_COMPOSITE" 19 | aggregate = "FUNCTION_AGGREGATE" 20 | metric = "FUNCTION_METRIC" 21 | predict = "FUNCTION_PREDICT" 22 | predict_composite = "FUNCTION_PREDICT_COMPOSITE" 23 | 24 | 25 | class InputIdentifiers(str, Enum): 26 | local = "local" 27 | shared = "shared" 28 | predictions = "predictions" 29 | performance = "performance" 30 | opener = "opener" 31 | datasamples = "datasamples" 32 | 33 | 34 | class OutputIdentifiers(str, Enum): 35 | local = "local" 36 | shared = "shared" 37 | predictions = "predictions" 38 | performance = "performance" 39 | 40 | 41 | class FLFunctionInputs(list, Enum): 42 | """Substra function inputs by function category based on the InputIdentifiers""" 43 | 44 | FUNCTION_AGGREGATE = [ 45 | FunctionInputSpec(identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=False, multiple=True) 46 | ] 47 | FUNCTION_SIMPLE = [ 48 | FunctionInputSpec( 49 | identifier=InputIdentifiers.datasamples, 50 | kind=AssetKind.data_sample.value, 51 | optional=False, 52 | multiple=True, 53 | ), 54 | FunctionInputSpec( 55 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False 56 | ), 57 | FunctionInputSpec(identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=True, multiple=True), 58 | ] 59 | FUNCTION_COMPOSITE = [ 60 | FunctionInputSpec( 61 | identifier=InputIdentifiers.datasamples, 62 | kind=AssetKind.data_sample.value, 63 | optional=False, 64 | multiple=True, 65 | ), 66 | FunctionInputSpec( 67 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False 68 | ), 69 | FunctionInputSpec(identifier=InputIdentifiers.local, kind=AssetKind.model.value, optional=True, multiple=False), 70 | FunctionInputSpec( 71 | identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=True, multiple=False 72 | ), 73 | ] 74 | FUNCTION_PREDICT = [ 75 | FunctionInputSpec( 76 | identifier=InputIdentifiers.datasamples, 77 | kind=AssetKind.data_sample.value, 78 | optional=False, 79 | multiple=True, 80 | ), 81 | FunctionInputSpec( 82 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False 83 | ), 84 | FunctionInputSpec( 85 | identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=False, multiple=False 86 | ), 87 | ] 88 | FUNCTION_PREDICT_COMPOSITE = [ 89 | FunctionInputSpec( 90 | identifier=InputIdentifiers.datasamples, 91 | kind=AssetKind.data_sample.value, 92 | optional=False, 93 | multiple=True, 94 | ), 95 | FunctionInputSpec( 96 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False 97 | ), 98 | FunctionInputSpec( 99 | identifier=InputIdentifiers.local, kind=AssetKind.model.value, optional=False, multiple=False 100 | ), 101 | FunctionInputSpec( 102 | identifier=InputIdentifiers.shared, 
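# NOTE: illustrative comments, not part of the original module. Each
# FLFunctionInputs member mirrors one FunctionCategory value, so a spec can be
# assembled by plain enum indexing, as the conftest `function_query` fixture
# does. Hedged sketch reusing names from this file:
#   inputs = FLFunctionInputs[FunctionCategory.simple]
#   # -> datasamples + opener (required) plus an optional, multiple "shared" model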
kind=AssetKind.model.value, optional=False, multiple=False 103 | ), 104 | ] 105 | FUNCTION_METRIC = [ 106 | FunctionInputSpec( 107 | identifier=InputIdentifiers.datasamples, 108 | kind=AssetKind.data_sample.value, 109 | optional=False, 110 | multiple=True, 111 | ), 112 | FunctionInputSpec( 113 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False 114 | ), 115 | FunctionInputSpec( 116 | identifier=InputIdentifiers.predictions, kind=AssetKind.model.value, optional=False, multiple=False 117 | ), 118 | ] 119 | 120 | 121 | class FLFunctionOutputs(list, Enum): 122 | """Substra function outputs by function category based on the OutputIdentifiers""" 123 | 124 | FUNCTION_AGGREGATE = [ 125 | FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False) 126 | ] 127 | FUNCTION_SIMPLE = [ 128 | FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False) 129 | ] 130 | FUNCTION_COMPOSITE = [ 131 | FunctionOutputSpec(identifier=OutputIdentifiers.local, kind=AssetKind.model.value, multiple=False), 132 | FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False), 133 | ] 134 | FUNCTION_PREDICT = [ 135 | FunctionOutputSpec(identifier=OutputIdentifiers.predictions, kind=AssetKind.model.value, multiple=False) 136 | ] 137 | FUNCTION_PREDICT_COMPOSITE = [ 138 | FunctionOutputSpec(identifier=OutputIdentifiers.predictions, kind=AssetKind.model.value, multiple=False) 139 | ] 140 | FUNCTION_METRIC = [ 141 | FunctionOutputSpec(identifier=OutputIdentifiers.performance, kind=AssetKind.performance.value, multiple=False) 142 | ] 143 | 144 | 145 | class FLTaskInputGenerator: 146 | "Generates task inputs based on Input and OutputIdentifiers" 147 | 148 | @staticmethod 149 | def opener(opener_key): 150 | return [InputRef(identifier=InputIdentifiers.opener, asset_key=opener_key)] 151 | 152 | @staticmethod 153 | def data_samples(data_sample_keys): 154 | return [ 155 | InputRef(identifier=InputIdentifiers.datasamples, asset_key=data_sample) for data_sample in data_sample_keys 156 | ] 157 | 158 | @staticmethod 159 | def task(opener_key, data_sample_keys): 160 | return FLTaskInputGenerator.opener(opener_key=opener_key) + FLTaskInputGenerator.data_samples( 161 | data_sample_keys=data_sample_keys 162 | ) 163 | 164 | @staticmethod 165 | def trains_to_train(model_keys): 166 | return [ 167 | InputRef( 168 | identifier=InputIdentifiers.shared, 169 | parent_task_key=model_key, 170 | parent_task_output_identifier=OutputIdentifiers.shared, 171 | ) 172 | for model_key in model_keys 173 | ] 174 | 175 | @staticmethod 176 | def trains_to_aggregate(model_keys): 177 | return [ 178 | InputRef( 179 | identifier=InputIdentifiers.shared, 180 | parent_task_key=model_key, 181 | parent_task_output_identifier=OutputIdentifiers.shared, 182 | ) 183 | for model_key in model_keys 184 | ] 185 | 186 | @staticmethod 187 | def train_to_predict(model_key): 188 | return [ 189 | InputRef( 190 | identifier=InputIdentifiers.shared, 191 | parent_task_key=model_key, 192 | parent_task_output_identifier=OutputIdentifiers.shared, 193 | ) 194 | ] 195 | 196 | @staticmethod 197 | def predict_to_test(model_key): 198 | return [ 199 | InputRef( 200 | identifier=InputIdentifiers.predictions, 201 | parent_task_key=model_key, 202 | parent_task_output_identifier=OutputIdentifiers.predictions, 203 | ) 204 | ] 205 | 206 | @staticmethod 207 | def composite_to_predict(model_key): 208 | return [ 209 | InputRef( 210 | 
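# NOTE: illustrative sketch of how these generator helpers compose; the task
# key is hypothetical. A composite train task feeds a predict task with two
# InputRefs, one per head/trunk model:
#   FLTaskInputGenerator.composite_to_predict("composite-task-key")
# wires the parent task's "local" and "shared" OutputIdentifiers to the
# predict task's inputs, as this method's body shows.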
identifier=InputIdentifiers.local, 211 | parent_task_key=model_key, 212 | parent_task_output_identifier=OutputIdentifiers.local, 213 | ), 214 | InputRef( 215 | identifier=InputIdentifiers.shared, 216 | parent_task_key=model_key, 217 | parent_task_output_identifier=OutputIdentifiers.shared, 218 | ), 219 | ] 220 | 221 | @staticmethod 222 | def composite_to_local(model_key): 223 | return [ 224 | InputRef( 225 | identifier=InputIdentifiers.local, 226 | parent_task_key=model_key, 227 | parent_task_output_identifier=OutputIdentifiers.local, 228 | ) 229 | ] 230 | 231 | @staticmethod 232 | def composite_to_composite(model1_key, model2_key=None): 233 | return [ 234 | InputRef( 235 | identifier=InputIdentifiers.local, 236 | parent_task_key=model1_key, 237 | parent_task_output_identifier=OutputIdentifiers.local, 238 | ), 239 | InputRef( 240 | identifier=InputIdentifiers.shared, 241 | parent_task_key=model2_key or model1_key, 242 | parent_task_output_identifier=OutputIdentifiers.shared, 243 | ), 244 | ] 245 | 246 | @staticmethod 247 | def aggregate_to_shared(model_key): 248 | return [ 249 | InputRef( 250 | identifier=InputIdentifiers.shared, 251 | parent_task_key=model_key, 252 | parent_task_output_identifier=OutputIdentifiers.shared, 253 | ) 254 | ] 255 | 256 | @staticmethod 257 | def composites_to_aggregate(model_keys): 258 | return [ 259 | InputRef( 260 | identifier=InputIdentifiers.shared, 261 | parent_task_key=model_key, 262 | parent_task_output_identifier=OutputIdentifiers.shared, 263 | ) 264 | for model_key in model_keys 265 | ] 266 | 267 | @staticmethod 268 | def aggregate_to_predict(model_key): 269 | return [ 270 | InputRef( 271 | identifier=InputIdentifiers.shared, 272 | parent_task_key=model_key, 273 | parent_task_output_identifier=OutputIdentifiers.shared, 274 | ) 275 | ] 276 | 277 | @staticmethod 278 | def local_to_aggregate(model_key): 279 | return [ 280 | InputRef( 281 | identifier=InputIdentifiers.shared, 282 | parent_task_key=model_key, 283 | parent_task_output_identifier=OutputIdentifiers.local, 284 | ) 285 | ] 286 | 287 | @staticmethod 288 | def shared_to_aggregate(model_key): 289 | return [ 290 | InputRef( 291 | identifier=InputIdentifiers.shared, 292 | parent_task_key=model_key, 293 | parent_task_output_identifier=OutputIdentifiers.shared, 294 | ) 295 | ] 296 | 297 | 298 | def _permission_from_ids(authorized_ids): 299 | if authorized_ids is None: 300 | return PUBLIC_PERMISSIONS 301 | 302 | return Permissions(public=False, authorized_ids=authorized_ids) 303 | 304 | 305 | class FLTaskOutputGenerator: 306 | "Generates task outputs based on Input and OutputIdentifiers" 307 | 308 | @staticmethod 309 | def traintask(authorized_ids=None): 310 | return {OutputIdentifiers.shared: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))} 311 | 312 | @staticmethod 313 | def aggregatetask(authorized_ids=None): 314 | return {OutputIdentifiers.shared: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))} 315 | 316 | @staticmethod 317 | def predicttask(authorized_ids=None): 318 | return {OutputIdentifiers.predictions: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))} 319 | 320 | @staticmethod 321 | def testtask(authorized_ids=None): 322 | return { 323 | OutputIdentifiers.performance: ComputeTaskOutputSpec( 324 | permissions=Permissions(public=True, authorized_ids=[]) 325 | ) 326 | } 327 | 328 | @staticmethod 329 | def composite_traintask(shared_authorized_ids=None, local_authorized_ids=None): 330 | return { 331 | OutputIdentifiers.shared: 
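# NOTE: illustrative comments, not part of the original module.
# `_permission_from_ids(None)`, defined above, falls back to
# PUBLIC_PERMISSIONS, so e.g. FLTaskOutputGenerator.traintask() with no
# argument yields a publicly readable "shared" output, while
# FLTaskOutputGenerator.traintask(authorized_ids=["OrgA"]) restricts it to
# that organization ("OrgA" is a hypothetical id).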
ComputeTaskOutputSpec(permissions=_permission_from_ids(shared_authorized_ids)), 332 | OutputIdentifiers.local: ComputeTaskOutputSpec(permissions=_permission_from_ids(local_authorized_ids)), 333 | } 334 | -------------------------------------------------------------------------------- /tests/mocked_requests.py: -------------------------------------------------------------------------------- 1 | from tests.utils import mock_requests 2 | 3 | 4 | def cancel_compute_plan(mocker): 5 | m = mock_requests(mocker, "post", response=None) 6 | return m 7 | -------------------------------------------------------------------------------- /tests/sdk/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/__init__.py -------------------------------------------------------------------------------- /tests/sdk/data/symlink.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/data/symlink.zip -------------------------------------------------------------------------------- /tests/sdk/data/traversal.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/data/traversal.zip -------------------------------------------------------------------------------- /tests/sdk/local/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/local/__init__.py -------------------------------------------------------------------------------- /tests/sdk/local/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import substra 4 | 5 | 6 | @pytest.fixture(scope="session") 7 | def docker_clients(): 8 | return [substra.Client(backend_type=substra.BackendType.LOCAL_DOCKER) for _ in range(2)] 9 | 10 | 11 | @pytest.fixture(scope="session") 12 | def subprocess_clients(): 13 | return [substra.Client(backend_type=substra.BackendType.LOCAL_SUBPROCESS) for _ in range(2)] 14 | 15 | 16 | @pytest.fixture(scope="session", params=["docker", "subprocess"]) 17 | def clients(request, docker_clients, subprocess_clients): 18 | if request.param == "docker": 19 | return docker_clients 20 | else: 21 | return subprocess_clients 22 | -------------------------------------------------------------------------------- /tests/sdk/test_add.py: -------------------------------------------------------------------------------- 1 | import pydantic 2 | import pytest 3 | 4 | import substra 5 | from substra.sdk import models 6 | from substra.sdk.exceptions import ComputePlanKeyFormatError 7 | 8 | from .. 
import datastore 9 | from ..utils import mock_requests 10 | 11 | 12 | def test_add_dataset(client, dataset_query, mocker): 13 | m_post = mock_requests(mocker, "post", response=datastore.DATASET) 14 | m_get = mock_requests(mocker, "get", response=datastore.DATASET) 15 | key = client.add_dataset(dataset_query) 16 | response = client.get_dataset(key) 17 | 18 | assert response == models.Dataset(**datastore.DATASET) 19 | m_post.assert_called() 20 | m_get.assert_called() 21 | 22 | 23 | def test_add_dataset_invalid_args(client, dataset_query, mocker): 24 | mock_requests(mocker, "post", response=datastore.DATASET) 25 | del dataset_query["data_opener"] 26 | 27 | with pytest.raises(pydantic.ValidationError): 28 | client.add_dataset(dataset_query) 29 | 30 | 31 | def test_add_dataset_response_failure_500(client, dataset_query, mocker): 32 | mock_requests(mocker, "post", status=500) 33 | 34 | with pytest.raises(substra.sdk.exceptions.InternalServerError): 35 | client.add_dataset(dataset_query) 36 | 37 | 38 | def test_add_dataset_409_success(client, dataset_query, mocker): 39 | mock_requests(mocker, "post", response={"key": datastore.DATASET["key"]}, status=409) 40 | mock_requests(mocker, "get", response=datastore.DATASET) 41 | 42 | key = client.add_dataset(dataset_query) 43 | 44 | assert key == datastore.DATASET["key"] 45 | 46 | 47 | def test_add_function(client, function_query, mocker): 48 | m_post = mock_requests(mocker, "post", response=datastore.FUNCTION) 49 | m_get = mock_requests(mocker, "get", response=datastore.FUNCTION) 50 | key = client.add_function(function_query) 51 | response = client.get_function(key) 52 | 53 | assert response == models.Function(**datastore.FUNCTION) 54 | m_post.assert_called() 55 | m_get.assert_called() 56 | 57 | 58 | def test_add_data_sample(client, data_sample_query, mocker): 59 | server_response = [{"key": "42"}] 60 | m = mock_requests(mocker, "post", response=server_response) 61 | response = client.add_data_sample(data_sample_query) 62 | 63 | assert response == server_response[0]["key"] 64 | m.assert_called() 65 | 66 | 67 | def test_add_data_sample_already_exists(client, data_sample_query, mocker): 68 | m = mock_requests(mocker, "post", response=[{"key": "42"}], status=409) 69 | response = client.add_data_sample(data_sample_query) 70 | 71 | assert response == "42" 72 | m.assert_called() 73 | 74 | 75 | # We try to add multiple data samples instead of a single one 76 | def test_add_data_sample_with_paths(client, data_samples_query): 77 | with pytest.raises(ValueError): 78 | client.add_data_sample(data_samples_query) 79 | 80 | 81 | def test_add_data_samples(client, data_samples_query, mocker): 82 | server_response = [{"key": "42"}] 83 | m = mock_requests(mocker, "post", response=server_response) 84 | response = client.add_data_samples(data_samples_query) 85 | 86 | assert response == ["42"] 87 | m.assert_called() 88 | 89 | 90 | # We try to add a single data sample instead of multiple ones 91 | def test_add_data_samples_with_path(client, data_sample_query): 92 | with pytest.raises(ValueError): 93 | client.add_data_samples(data_sample_query) 94 | 95 | 96 | def test_add_compute_plan_wrong_key_format(client): 97 | with pytest.raises(ComputePlanKeyFormatError): 98 | data = {"key": "wrong_format", "name": "A perfectly valid name"} 99 | client.add_compute_plan(data) 100 | -------------------------------------------------------------------------------- /tests/sdk/test_archive.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
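# NOTE: illustrative comments, not part of the original test module. The tests
# below exercise the hardened archive wrappers from substra.sdk.archive: both
# ZipFile and tarsafe refuse path-traversal entries such as "../foo/bar" as
# well as symlinks, so a crafted archive cannot write outside its extraction
# directory. Minimal hedged usage sketch (paths are hypothetical):
#   from substra.sdk.archive.safezip import ZipFile
#   ZipFile("archive.zip", "r").extractall("/tmp/out")  # raises on unsafe entries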
import shutil 3 | import tarfile 4 | import tempfile 5 | 6 | import pytest 7 | 8 | from substra.sdk.archive import _untar 9 | from substra.sdk.archive import _unzip 10 | from substra.sdk.archive import tarsafe 11 | from substra.sdk.archive.safezip import ZipFile 12 | 13 | 14 | class TestsZipFile: 15 | # This zip file was specifically crafted and contains empty files named: 16 | # foo/bar 17 | # ../foo/bar 18 | # ../../foo/bar 19 | # ../../../foo/bar 20 | TRAVERSAL_ZIP = os.path.join(os.path.dirname(__file__), "data", "traversal.zip") 21 | 22 | # This zip file was specifically crafted and contains: 23 | # bar 24 | # foo -> bar (symlink) 25 | SYMLINK_ZIP = os.path.join(os.path.dirname(__file__), "data", "symlink.zip") 26 | 27 | def test_raise_on_path_traversal(self): 28 | zf = ZipFile(self.TRAVERSAL_ZIP, "r") 29 | with pytest.raises(Exception) as exc: 30 | zf.extractall(tempfile.gettempdir()) 31 | 32 | assert "Attempted directory traversal" in str(exc.value) 33 | 34 | def test_raise_on_symlink(self): 35 | zf = ZipFile(self.SYMLINK_ZIP, "r") 36 | with pytest.raises(Exception) as exc: 37 | zf.extractall(tempfile.gettempdir()) 38 | 39 | assert "Unsupported symlink" in str(exc.value) 40 | 41 | def test_compress_uncompress_zip(self): 42 | with tempfile.TemporaryDirectory() as tmpdirname: 43 | path_testdir = os.path.join(tmpdirname, "testdir") 44 | os.makedirs(path_testdir) 45 | with open(os.path.join(path_testdir, "testfile.txt"), "w") as f: 46 | f.write("testtext") 47 | shutil.make_archive(os.path.join(tmpdirname, "test_archive"), "zip", root_dir=os.path.dirname(path_testdir)) 48 | _unzip(os.path.join(tmpdirname, "test_archive.zip"), os.path.join(tmpdirname, "test_archive")) 49 | 50 | assert os.listdir(tmpdirname + "/test_archive/testdir") == ["testfile.txt"] 51 | 52 | 53 | class TestsTarSafe: 54 | def test_compress_uncompress_tar(self): 55 | with tempfile.TemporaryDirectory() as tmpdirname: 56 | path_testdir = os.path.join(tmpdirname, "testdir") 57 | os.makedirs(path_testdir) 58 | with open(os.path.join(path_testdir, "testfile.txt"), "w") as f: 59 | f.write("testtext") 60 | path_tarfile = os.path.join(tmpdirname, "test_archive.tar") 61 | with tarsafe.open(path_tarfile, "w:gz") as tar: 62 | tar.add(path_testdir, arcname=os.path.basename(path_testdir)) 63 | 64 | _untar(path_tarfile, os.path.join(tmpdirname, "test_archive")) 65 | 66 | assert os.listdir(tmpdirname + "/test_archive/testdir") == ["testfile.txt"] 67 | 68 | def test_raise_on_symlink(self): 69 | with tempfile.TemporaryDirectory() as tmpdir: 70 | # create the following tree structure: 71 | # ./Dockerfile 72 | # ./foo 73 | # ./Dockerfile -> ../Dockerfile 74 | 75 | filename = "Dockerfile" 76 | symlink_source = os.path.join(tmpdir, filename) 77 | with open(symlink_source, "w") as fp: 78 | fp.write("FROM bar") 79 | 80 | archive_root = os.path.join(tmpdir, "foo") 81 | os.mkdir(archive_root) 82 | os.symlink(symlink_source, os.path.join(archive_root, filename)) 83 | 84 | # create a tar archive of the foo folder 85 | tarpath = os.path.join(tmpdir, "foo.tgz") 86 | with tarfile.open(tarpath, "w:gz") as tar: 87 | for root, _, files in os.walk(archive_root): 88 | for file in files: 89 | tar.add(os.path.join(root, file)) 90 | 91 | with pytest.raises(tarsafe.TarSafeError) as error: 92 | with tarsafe.open(tarpath, "r") as tar: 93 | tar.extractall() 94 | 95 | assert "Unsupported symlink" in str(error.value) 96 | -------------------------------------------------------------------------------- /tests/sdk/test_cancel.py:
-------------------------------------------------------------------------------- 1 | from tests import mocked_requests 2 | 3 | 4 | def test_cancel_compute_plan(client, mocker): 5 | m = mocked_requests.cancel_compute_plan(mocker) 6 | 7 | response = client.cancel_compute_plan("magic-key") 8 | 9 | assert response is None 10 | m.assert_called() 11 | -------------------------------------------------------------------------------- /tests/sdk/test_client.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import yaml 3 | 4 | from substra.sdk import exceptions 5 | from substra.sdk import schemas 6 | from substra.sdk.client import Client 7 | from substra.sdk.client import _upper_slug 8 | from substra.sdk.exceptions import ConfigurationInfoError 9 | 10 | 11 | @pytest.fixture 12 | def config_file(tmp_path): 13 | config_dict = { 14 | "toto": { 15 | "backend_type": "remote", 16 | "url": "toto-org.com", 17 | "username": "toto_file_username", 18 | "password": "toto_file_password", 19 | } 20 | } 21 | config_file = tmp_path / "config.yaml" 22 | config_file.write_text(yaml.dump(config_dict)) 23 | return config_file 24 | 25 | 26 | def stub_login(username, password): 27 | if username == "org-1" and password == "password1": 28 | return "token1" 29 | if username == "env_var_username" and password == "env_var_password": 30 | return "env_var_token" 31 | if username == "toto_file_username" and password == "toto_file_password": 32 | return "toto_file_token" 33 | 34 | 35 | @pytest.mark.parametrize( 36 | ["input", "expected"], 37 | [ 38 | ("toto", "TOTO"), 39 | ("client-org-1", "CLIENT_ORG_1"), 40 | ("un nom très français", "UN_NOM_TRES_FRANCAIS"), 41 | ], 42 | ) 43 | def test_upper_slug(input, expected): 44 | assert expected == _upper_slug(input) 45 | 46 | 47 | def test_default_client(): 48 | client = Client() 49 | assert client.backend_mode == schemas.BackendType.LOCAL_SUBPROCESS 50 | 51 | 52 | @pytest.mark.parametrize( 53 | ["mode", "client_name", "url", "token", "insecure", "retry_timeout"], 54 | [ 55 | (schemas.BackendType.LOCAL_SUBPROCESS, "foo", None, None, True, 5), 56 | (schemas.BackendType.LOCAL_DOCKER, "foobar", None, None, True, None), 57 | (schemas.BackendType.REMOTE, "bar", "example.com", "bloop", False, 15), 58 | (schemas.BackendType.LOCAL_SUBPROCESS, "hybrid", "example.com", "foo", True, 500), 59 | (schemas.BackendType.REMOTE, "foofoo", "https://example.com/api-token-auth/", None, True, 500), 60 | ], 61 | ) 62 | def test_client_configured_in_code(mode, client_name, url, token, insecure, retry_timeout): 63 | client = Client( 64 | backend_type=mode, 65 | client_name=client_name, 66 | url=url, 67 | token=token, 68 | insecure=insecure, 69 | retry_timeout=retry_timeout, 70 | ) 71 | assert client.backend_mode == mode 72 | assert client._insecure == insecure 73 | if retry_timeout is not None: 74 | assert client._retry_timeout == retry_timeout 75 | else: 76 | assert client._retry_timeout == 300 77 | if token is not None: 78 | assert client._token == token 79 | else: 80 | assert client._token is None 81 | if url is not None: 82 | assert client._url == url 83 | else: 84 | assert client._url is None 85 | 86 | 87 | def test_client_should_raise_when_missing_name(): 88 | with pytest.raises(ConfigurationInfoError): 89 | Client(configuration_file="something") 90 | 91 | 92 | def test_client_with_password(mocker): 93 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 94 | rest_client_logout = 
mocker.patch("substra.sdk.backends.remote.rest_client.Client.logout") 95 | rest_client_logout.reset_mock() 96 | client_args = { 97 | "backend_type": "remote", 98 | "url": "example.com", 99 | "token": None, 100 | "username": "org-1", 101 | "password": "password1", 102 | } 103 | 104 | client = Client(**client_args) 105 | assert client._token == "token1" 106 | client.logout() 107 | assert client._token is None 108 | rest_client_logout.assert_called_once() 109 | del client 110 | 111 | rest_client_logout.reset_mock() 112 | client = Client(**client_args) 113 | del client 114 | rest_client_logout.assert_called_once() 115 | 116 | rest_client_logout.reset_mock() 117 | with Client(**client_args) as client: 118 | assert client._token == "token1" 119 | rest_client_logout.assert_called_once() 120 | 121 | 122 | def test_client_token_supercedes_password(mocker): 123 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 124 | client = Client( 125 | backend_type="remote", 126 | url="example.com", 127 | token="token0", 128 | username="org-1", 129 | password="password1", 130 | ) 131 | assert client._token == "token0" 132 | 133 | 134 | def test_client_configuration_from_env_var(mocker, monkeypatch): 135 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 136 | monkeypatch.setenv("SUBSTRA_TOTO_BACKEND_TYPE", "remote") 137 | monkeypatch.setenv("SUBSTRA_TOTO_URL", "toto-org.com") 138 | monkeypatch.setenv("SUBSTRA_TOTO_USERNAME", "env_var_username") 139 | monkeypatch.setenv("SUBSTRA_TOTO_PASSWORD", "env_var_password") 140 | monkeypatch.setenv("SUBSTRA_TOTO_RETRY_TIMEOUT", "42") 141 | monkeypatch.setenv("SUBSTRA_TOTO_INSECURE", "true") 142 | client = Client(client_name="toto") 143 | assert client.backend_mode == "remote" 144 | assert client._url == "toto-org.com" 145 | assert client._token == "env_var_token" 146 | assert client._retry_timeout == 42 147 | assert client._insecure is True 148 | 149 | 150 | def test_client_configuration_from_config_file(mocker, config_file): 151 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 152 | client = Client(configuration_file=config_file, client_name="toto") 153 | assert client.backend_mode == "remote" 154 | assert client._url == "toto-org.com" 155 | assert client._token == "toto_file_token" 156 | assert client._retry_timeout == 300 157 | assert client._insecure is False 158 | 159 | 160 | def test_client_configuration_code_overrides_env_var(monkeypatch): 161 | """ 162 | A variable set in the code overrides one set in an env variable 163 | """ 164 | monkeypatch.setenv("SUBSTRA_TOTO_BACKEND_TYPE", "remote") 165 | monkeypatch.setenv("SUBSTRA_TOTO_URL", "toto-org.com") 166 | monkeypatch.setenv("SUBSTRA_TOTO_USERNAME", "env_var_username") 167 | monkeypatch.setenv("SUBSTRA_TOTO_PASSWORD", "env_var_password") 168 | monkeypatch.setenv("SUBSTRA_TOTO_RETRY_TIMEOUT", "42") 169 | monkeypatch.setenv("SUBSTRA_TOTO_INSECURE", "true") 170 | client = Client( 171 | client_name="toto", 172 | backend_type="subprocess", 173 | url="", 174 | ) 175 | assert client.backend_mode == "subprocess" 176 | assert client._url == "" 177 | assert client._token is None 178 | assert client._retry_timeout == 42 179 | assert client._insecure is True 180 | 181 | 182 | def test_client_configuration_code_overrides_config_file(mocker, config_file): 183 | """ 184 | A variable set in the code overrides one set in a config file 185 | """ 186 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 187 | client = Client( 188 | client_name="toto", 189 | 
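# NOTE: illustrative comments, not part of the original test. The precedence
# these configuration tests establish is
#   code arguments > SUBSTRA_<NAME>_* environment variables > configuration file,
# so the username/password passed directly to Client() here win over both the
# env vars and the "toto" entry of the YAML file.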
configuration_file=config_file, 190 | username="org-1", 191 | password="password1", 192 | retry_timeout=100, 193 | ) 194 | assert client.backend_mode == "remote" 195 | assert client._url == "toto-org.com" 196 | assert client._token == "token1" 197 | assert client._retry_timeout == 100 198 | assert client._insecure is False 199 | 200 | 201 | def test_client_configuration_env_var_overrides_config_file(mocker, monkeypatch, config_file): 202 | """ 203 | A variable set in an env var overrides one set in a config file 204 | """ 205 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 206 | monkeypatch.setenv("SUBSTRA_TOTO_BACKEND_TYPE", "docker") 207 | monkeypatch.setenv("SUBSTRA_TOTO_USERNAME", "env_var_username") 208 | monkeypatch.setenv("SUBSTRA_TOTO_PASSWORD", "env_var_password") 209 | client = Client(configuration_file=config_file, client_name="toto", retry_timeout=12) 210 | assert client.backend_mode == "docker" 211 | assert client._url == "toto-org.com" 212 | assert client._token == "env_var_token" 213 | assert client._retry_timeout == 12 214 | assert client._insecure is False 215 | 216 | 217 | def test_login_remote_without_url(tmpdir): 218 | with pytest.raises(exceptions.SDKException): 219 | Client(backend_type="remote") 220 | 221 | 222 | def test_client_configuration_configuration_file_path_from_env_var(mocker, monkeypatch, config_file): 223 | """ 224 | The configuration file path can be set through an env var 225 | """ 226 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 227 | monkeypatch.setenv("SUBSTRA_CLIENTS_CONFIGURATION_FILE_PATH", config_file) 228 | client = Client(client_name="toto") 229 | assert client.backend_mode == "remote" 230 | assert client._url == "toto-org.com" 231 | assert client._token == "toto_file_token" 232 | assert client._retry_timeout == 300 233 | assert client._insecure is False 234 | 235 | 236 | def test_client_configuration_configuration_file_path_parameter_supercedes_env_var( 237 | mocker, monkeypatch, config_file, tmp_path 238 | ): 239 | """ 240 | The configuration file path env var is superseded by the `configuration_file=` parameter 241 | """ 242 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 243 | monkeypatch.setenv("SUBSTRA_CLIENTS_CONFIGURATION_FILE_PATH", config_file) 244 | 245 | config_2_dict = { 246 | "toto": { 247 | "backend_type": "docker", 248 | } 249 | } 250 | config_2_file = tmp_path / "config.yaml" 251 | config_2_file.write_text(yaml.dump(config_2_dict)) 252 | 253 | client = Client(configuration_file=config_2_file, client_name="toto") 254 | assert client.backend_mode == "docker" 255 | 256 | 257 | def test_client_configuration_file_path_env_var_empty_string(mocker, monkeypatch, config_file): 258 | """ 259 | Passing an empty string as `configuration_file=` disables file-based configuration, so the client falls back to the defaults 260 | """ 261 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login) 262 | monkeypatch.setenv("SUBSTRA_CLIENTS_CONFIGURATION_FILE_PATH", config_file) 263 | 264 | client = Client(configuration_file="", client_name="toto") 265 | assert client.backend_mode == "subprocess" 266 | -------------------------------------------------------------------------------- /tests/sdk/test_describe.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import substra 4 | 5 | from ..
import datastore 6 | from ..utils import mock_requests 7 | from ..utils import mock_requests_responses 8 | from ..utils import mock_response 9 | 10 | 11 | @pytest.mark.parametrize("asset_type", ["dataset", "function"]) 12 | def test_describe_asset(asset_type, client, mocker): 13 | item = getattr(datastore, asset_type.upper()) 14 | responses = [ 15 | mock_response(item), # metadata 16 | mock_response("foo"), # data 17 | ] 18 | m = mock_requests_responses(mocker, "get", responses) 19 | 20 | method = getattr(client, f"describe_{asset_type}") 21 | response = method("magic-key") 22 | 23 | assert response == "foo" 24 | m.assert_called() 25 | 26 | 27 | @pytest.mark.parametrize("asset_type", ["dataset", "function"]) 28 | def test_describe_asset_not_found(asset_type, client, mocker): 29 | m = mock_requests(mocker, "get", status=404) 30 | 31 | with pytest.raises(substra.sdk.exceptions.NotFound): 32 | method = getattr(client, f"describe_{asset_type}") 33 | method("foo") 34 | 35 | assert m.call_count == 1 36 | 37 | 38 | @pytest.mark.parametrize("asset_type", ["dataset", "function"]) 39 | def test_describe_description_not_found(asset_type, client, mocker): 40 | item = getattr(datastore, asset_type.upper()) 41 | responses = [ 42 | mock_response(item), # metadata 43 | mock_response("foo", 404), # data 44 | ] 45 | m = mock_requests_responses(mocker, "get", responses) 46 | 47 | method = getattr(client, f"describe_{asset_type}") 48 | 49 | with pytest.raises(substra.sdk.exceptions.NotFound): 50 | method("key") 51 | 52 | assert m.call_count == 2 53 | -------------------------------------------------------------------------------- /tests/sdk/test_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | import substra 7 | from substra.sdk import Client 8 | 9 | from .. 
import datastore 10 | from ..utils import mock_requests 11 | from ..utils import mock_requests_responses 12 | from ..utils import mock_response 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "asset_type", 17 | [ 18 | "dataset", 19 | "function", 20 | "model", 21 | ], 22 | ) 23 | def test_download_asset(asset_type, tmp_path, client, mocker): 24 | item = getattr(datastore, asset_type.upper()) 25 | responses = [ 26 | mock_response(item), # metadata 27 | mock_response("foo"), # data 28 | ] 29 | m = mock_requests_responses(mocker, "get", responses) 30 | 31 | method = getattr(client, f"download_{asset_type}") 32 | temp_file = method("foo", tmp_path) 33 | 34 | assert os.path.exists(temp_file) 35 | m.assert_called() 36 | 37 | 38 | @pytest.mark.parametrize("asset_type", ["dataset", "function", "model", "logs"]) 39 | def test_download_asset_not_found(asset_type, tmp_path, client, mocker): 40 | m = mock_requests(mocker, "get", status=404) 41 | 42 | with pytest.raises(substra.sdk.exceptions.NotFound): 43 | method = getattr(client, f"download_{asset_type}") 44 | method("foo", tmp_path) 45 | 46 | assert m.call_count == 1 47 | 48 | 49 | @pytest.mark.parametrize("asset_type", ["dataset", "function", "model"]) 50 | def test_download_content_not_found(asset_type, tmp_path, client, mocker): 51 | item = getattr(datastore, asset_type.upper()) 52 | 53 | expected_call_count = 2 54 | responses = [ 55 | mock_response(item), # metadata 56 | mock_response("foo", status=404), # description 57 | ] 58 | 59 | if asset_type == "model": 60 | responses = [responses[1]] # No metadata for model download 61 | expected_call_count = 1 62 | 63 | m = mock_requests_responses(mocker, "get", responses) 64 | 65 | method = getattr(client, f"download_{asset_type}") 66 | 67 | with pytest.raises(substra.sdk.exceptions.NotFound): 68 | method("key", tmp_path) 69 | 70 | assert m.call_count == expected_call_count 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "asset_type, identifier", 75 | [ 76 | ("TRAINTASK", "model"), 77 | ("AGGREGATETASK", "model"), 78 | ("COMPOSITE_TRAINTASK", "local"), 79 | ("COMPOSITE_TRAINTASK", "shared"), 80 | ], 81 | ) 82 | @patch.object(Client, "download_model") 83 | def test_download_model_from_task(fake_download_model, tmp_path, client, asset_type, identifier, mocker): 84 | item = getattr(datastore, f"{asset_type}_{identifier.upper()}_RESPONSE") 85 | responses = [ 86 | mock_response(item), # metadata 87 | mock_response("foo"), # data 88 | ] 89 | 90 | m = mock_requests_responses(mocker, "get", responses) 91 | 92 | client.download_model_from_task("key", identifier, tmp_path) 93 | 94 | m.assert_called() 95 | assert fake_download_model.call_count == 1 96 | 97 | 98 | def test_download_logs(tmp_path, client, mocker): 99 | logs = b"Lorem ipsum dolor sit amet" 100 | task_key = "key" 101 | 102 | response = mock_response(logs) 103 | response.iter_content.return_value = [logs] 104 | 105 | m = mock_requests_responses(mocker, "get", [response]) 106 | client.download_logs(task_key, tmp_path) 107 | 108 | m.assert_called_once() 109 | response.iter_content.assert_called_once() 110 | 111 | assert (tmp_path / f"task_logs_{task_key}.txt").read_bytes() == logs 112 | -------------------------------------------------------------------------------- /tests/sdk/test_get.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | import substra 5 | from substra.sdk import models 6 | from substra.sdk import schemas 7 | 8 | from ..
import datastore 9 | from ..utils import mock_requests 10 | from ..utils import mock_requests_responses 11 | from ..utils import mock_response 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "asset_type", 16 | [ 17 | "model", 18 | "dataset", 19 | "data_sample", 20 | "function", 21 | "compute_plan", 22 | ], 23 | ) 24 | def test_get_asset(asset_type, client, mocker): 25 | item = getattr(datastore, asset_type.upper()) 26 | method = getattr(client, f"get_{asset_type}") 27 | 28 | m = mock_requests(mocker, "get", response=item) 29 | 30 | response = method("magic-key") 31 | 32 | assert response == models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item) 33 | m.assert_called() 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "asset_type", 38 | [ 39 | "predicttask", 40 | "testtask", 41 | "traintask", 42 | "aggregatetask", 43 | "composite_traintask", 44 | ], 45 | ) 46 | def test_get_task(asset_type, client, mocker): 47 | item = getattr(datastore, asset_type.upper()) 48 | 49 | m = mock_requests(mocker, "get", response=item) 50 | 51 | response = client.get_task("magic-key") 52 | 53 | assert response == models.SCHEMA_TO_MODEL[schemas.Type.Task](**item) 54 | m.assert_called() 55 | 56 | 57 | def test_get_asset_not_found(client, mocker): 58 | mock_requests(mocker, "get", status=404) 59 | 60 | with pytest.raises(substra.sdk.exceptions.NotFound): 61 | client.get_dataset("magic-key") 62 | 63 | 64 | @pytest.mark.parametrize( 65 | "asset_type", 66 | [ 67 | "dataset", 68 | "function", 69 | "compute_plan", 70 | "model", 71 | ], 72 | ) 73 | def test_get_extra_field(asset_type, client, mocker): 74 | item = getattr(datastore, asset_type.upper()) 75 | raw = getattr(datastore, asset_type.upper()).copy() 76 | raw["unknown_extra_field"] = "some value" 77 | 78 | method = getattr(client, f"get_{asset_type}") 79 | 80 | m = mock_requests(mocker, "get", response=raw) 81 | 82 | response = method("magic-key") 83 | 84 | assert response == models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item) 85 | m.assert_called() 86 | 87 | 88 | @pytest.mark.parametrize( 89 | "asset_type", 90 | [ 91 | "predicttask", 92 | "testtask", 93 | "traintask", 94 | "aggregatetask", 95 | "composite_traintask", 96 | ], 97 | ) 98 | def test_get_task_extra_field(asset_type, client, mocker): 99 | item = getattr(datastore, asset_type.upper()) 100 | raw = getattr(datastore, asset_type.upper()).copy() 101 | raw["unknown_extra_field"] = "some value" 102 | 103 | m = mock_requests(mocker, "get", response=raw) 104 | 105 | response = client.get_task("magic-key") 106 | 107 | assert response == models.SCHEMA_TO_MODEL[schemas.Type.Task](**item) 108 | m.assert_called() 109 | 110 | 111 | def test_get_logs(client, mocker): 112 | logs = "Lorem ipsum dolor sit amet" 113 | task_key = "key" 114 | 115 | responses = [mock_response(logs)] 116 | m = mock_requests_responses(mocker, "get", responses) 117 | result = client.get_logs(task_key) 118 | 119 | m.assert_called_once() 120 | assert result == logs 121 | 122 | 123 | def test_get_performances(client, mocker): 124 | """Test the get_performances features, and test the immediate conversion to pandas DataFrame.""" 125 | cp_item = datastore.COMPUTE_PLAN 126 | perf_item = datastore.COMPUTE_PLAN_PERF 127 | 128 | m = mock_requests_responses(mocker, "get", [mock_response(cp_item), mock_response(perf_item)]) 129 | 130 | response = client.get_performances("magic-key") 131 | results = response.model_dump() 132 | 133 | df = pd.DataFrame(results) 134 | assert list(df.columns) == list(results.keys()) 135 | assert all(len(v) == df.shape[0] for v 
in results.values()) 136 | assert m.call_count == 2 137 | -------------------------------------------------------------------------------- /tests/sdk/test_graph.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import pytest 4 | 5 | from substra.sdk import exceptions 6 | from substra.sdk import graph 7 | 8 | 9 | @pytest.fixture 10 | def node_graph(): 11 | return {key: list(range(key)) for key in range(10)} 12 | 13 | 14 | @pytest.fixture 15 | def node_graph_linear(): 16 | return {key: [key - 1] if key > 0 else list() for key in range(10)} 17 | 18 | 19 | @pytest.fixture 20 | def node_graph_straight_branch(): 21 | """ 22 | 0 1 23 | |-> 2 <-| 24 | | | | 25 | 3 <---> 4 26 | | 27 | 5 28 | | 29 | 6 30 | """ 31 | return { 32 | "0": [], 33 | "1": [], 34 | "2": ["0", "1"], 35 | "3": ["0", "2"], 36 | "4": ["1", "2"], 37 | "5": ["3", "3"], # duplicated link 38 | "6": ["5"], 39 | } 40 | 41 | 42 | def test_compute_ranks(node_graph): 43 | visited = graph.compute_ranks(node_graph=node_graph) 44 | for key, rank in visited.items(): 45 | assert key == rank 46 | 47 | 48 | def test_compute_ranks_linear(node_graph_linear): 49 | visited = graph.compute_ranks(node_graph=node_graph_linear) 50 | for key, rank in visited.items(): 51 | assert key == rank 52 | 53 | 54 | def test_compute_ranks_no_correlation(): 55 | node_graph = {key: list() for key in range(10)} 56 | visited = graph.compute_ranks(node_graph=node_graph) 57 | for _, rank in visited.items(): 58 | assert rank == 0 59 | 60 | 61 | def test_compute_ranks_cycle(node_graph): 62 | node_graph[5].append(9) 63 | with pytest.raises(exceptions.InvalidRequest) as e: 64 | graph.compute_ranks(node_graph=node_graph) 65 | 66 | assert "missing dependency among inModels IDs" in str(e.value) 67 | 68 | 69 | def test_compute_ranks_closed_cycle(node_graph_linear): 70 | node_graph_linear[0] = [9] 71 | with pytest.raises(exceptions.InvalidRequest) as e: 72 | graph.compute_ranks(node_graph=node_graph_linear) 73 | 74 | assert "missing dependency among inModels IDs" in str(e.value) 75 | 76 | 77 | def test_compute_ranks_ignore(node_graph): 78 | node_to_ignore = set(range(5)) 79 | for i in range(5): 80 | node_graph.pop(i) 81 | visited = graph.compute_ranks(node_graph=node_graph, node_to_ignore=node_to_ignore) 82 | for key, rank in visited.items(): 83 | assert rank == key - 5 84 | 85 | 86 | @pytest.mark.parametrize("vertices", itertools.permutations([str(i) for i in range(6)])) 87 | def test_compute_ranks_alpha(vertices): 88 | """ 89 | 0 1 90 | |-> 2 <-| 91 | | | | 92 | 3 <---> 4 93 | | 94 | 5 95 | """ 96 | node_graph = { 97 | "0": [], 98 | "1": [], 99 | "2": ["0", "1"], 100 | "3": ["0", "2"], 101 | "4": ["1", "2"], 102 | "5": ["3"], 103 | } 104 | ordered_node_graph = {v: node_graph[v] for v in vertices} 105 | visited = graph.compute_ranks(node_graph=ordered_node_graph) 106 | 107 | assert visited == { 108 | "0": 0, 109 | "1": 0, 110 | "2": 1, 111 | "3": 2, 112 | "4": 2, 113 | "5": 3, 114 | } 115 | -------------------------------------------------------------------------------- /tests/sdk/test_list.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from substra.sdk import exceptions 4 | from substra.sdk import models 5 | from substra.sdk import schemas 6 | 7 | from ..
import datastore 8 | from ..utils import make_paginated_response 9 | from ..utils import mock_requests 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "asset_type", 14 | [ 15 | "dataset", 16 | "function", 17 | "compute_plan", 18 | "data_sample", 19 | "model", 20 | ], 21 | ) 22 | def test_list_asset(asset_type, client, mocker): 23 | item = getattr(datastore, asset_type.upper()) 24 | method = getattr(client, f"list_{asset_type}") 25 | 26 | mocked_response = make_paginated_response([item]) 27 | m = mock_requests(mocker, "get", response=mocked_response) 28 | 29 | response = method() 30 | 31 | assert response == [models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item)] 32 | m.assert_called() 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "asset_type", 37 | [ 38 | "predicttask", 39 | "testtask", 40 | "traintask", 41 | "aggregatetask", 42 | "composite_traintask", 43 | ], 44 | ) 45 | def test_list_task(asset_type, client, mocker): 46 | item = getattr(datastore, asset_type.upper()) 47 | 48 | mocked_response = make_paginated_response([item]) 49 | m = mock_requests(mocker, "get", response=mocked_response) 50 | 51 | response = client.list_task() 52 | 53 | assert response == [models.Task(**item)] 54 | m.assert_called() 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "asset_type,filters", 59 | [ 60 | ("dataset", {"permissions": ["foo", "bar"]}), 61 | ("function", {"owner": ["foo", "bar"]}), 62 | ("compute_plan", {"name": "foo"}), 63 | ("compute_plan", {"status": [models.ComputePlanStatus.done.value]}), 64 | ("model", {"owner": ["MyOrg1MSP"]}), 65 | ], 66 | ) 67 | def test_list_asset_with_filters(asset_type, filters, client, mocker): 68 | item = getattr(datastore, asset_type.upper()) 69 | method = getattr(client, f"list_{asset_type}") 70 | 71 | mocked_response = make_paginated_response([item]) 72 | m = mock_requests(mocker, "get", response=mocked_response) 73 | 74 | response = method(filters) 75 | 76 | assert response == [models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item)] 77 | m.assert_called() 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "filters", 82 | [ 83 | {"rank": [1, 3]}, 84 | {"key": ["foo", "bar"]}, 85 | {"worker": ["foo", "bar"]}, 86 | {"owner": ["foo", "bar"]}, 87 | ], 88 | ) 89 | def test_list_task_with_filters(filters, client, mocker): 90 | items = datastore.TASK_LIST 91 | 92 | mocked_response = make_paginated_response(items) 93 | m = mock_requests(mocker, "get", response=mocked_response) 94 | 95 | response = client.list_task(filters) 96 | 97 | assert response == [models.Task(**item) for item in items] 98 | m.assert_called() 99 | 100 | 101 | def test_list_asset_with_filters_failure(client, mocker): 102 | items = [datastore.FUNCTION] 103 | m = mock_requests(mocker, "get", response=items) 104 | 105 | filters = {"foo"} 106 | with pytest.raises(exceptions.FilterFormatError) as exc_info: 107 | client.list_function(filters) 108 | 109 | m.assert_not_called() 110 | assert str(exc_info.value).startswith("Cannot load filters") 111 | 112 | 113 | def test_list_compute_plan_with_ordering(client, mocker): 114 | item = datastore.COMPUTE_PLAN 115 | 116 | mocked_response = make_paginated_response([item]) 117 | m = mock_requests(mocker, "get", response=mocked_response) 118 | 119 | order_by = "start_date" 120 | response = client.list_compute_plan(order_by=order_by) 121 | 122 | assert response == [models.SCHEMA_TO_MODEL[schemas.Type.ComputePlan](**item)] 123 | m.assert_called() 124 | 125 | 126 | def test_list_task_with_ordering(client, mocker): 127 | items = datastore.TASK_LIST 128 | 129 | mocked_response = 
make_paginated_response(items)
130 |     m = mock_requests(mocker, "get", response=mocked_response)
131 | 
132 |     order_by = "start_date"
133 |     response = client.list_task(order_by=order_by)
134 | 
135 |     assert response == [models.Task(**item) for item in items]
136 |     m.assert_called()
137 | 
138 | 
139 | def test_list_asset_with_ordering_failure(client, mocker):
140 |     items = [datastore.COMPUTE_PLAN]
141 |     m = mock_requests(mocker, "get", response=items)
142 | 
143 |     order_by = "foo"
144 |     with pytest.raises(exceptions.OrderingFormatError) as exc_info:
145 |         client.list_compute_plan(order_by=order_by)
146 | 
147 |     m.assert_not_called()
148 |     assert str(exc_info.value).startswith("Please review the documentation")
149 | 
--------------------------------------------------------------------------------
/tests/sdk/test_rest_client.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | import pytest
4 | import requests
5 | 
6 | from substra.sdk import exceptions
7 | from substra.sdk.backends.remote import rest_client
8 | from substra.sdk.client import Client
9 | 
10 | from .. import datastore
11 | from ..utils import mock_requests
12 | from ..utils import mock_requests_responses
13 | from ..utils import mock_response
14 | 
15 | CONFIG = {
16 |     "url": "http://foo.com",
17 |     "insecure": False,
18 | }
19 | 
20 | CONFIG_SECURE = {
21 |     "url": "http://foo.com",
22 |     "insecure": False,
23 | }
24 | 
25 | CONFIG_INSECURE = {
26 |     "url": "http://foo.com",
27 |     "insecure": True,
28 | }
29 | 
30 | CONFIGS = [CONFIG, CONFIG_SECURE, CONFIG_INSECURE]
31 | 
32 | 
33 | def _client_from_config(config):
34 |     return rest_client.Client(
35 |         config["url"],
36 |         config["insecure"],
37 |         None,
38 |     )
39 | 
40 | 
41 | @pytest.mark.parametrize("config", CONFIGS)
42 | def test_post_success(mocker, config):
43 |     m = mock_requests(mocker, "post", response={})
44 |     _client_from_config(config).add("http://foo", {})
45 |     assert len(m.call_args_list) == 1
46 | 
47 | 
48 | @pytest.mark.parametrize("config", CONFIGS)
49 | def test_verify_login(mocker, config):
50 |     """
51 |     Check that an "insecure" configuration results in endpoints being called with verify=False.
52 |     """
53 |     m_post = mock_requests(mocker, "post", response={"id": "a", "token": "a", "expires_at": "3000-01-01T00:00:00Z"})
54 |     m_delete = mock_requests(mocker, "delete", response={})
55 | 
56 |     c = _client_from_config(config)
57 |     c.login("foo", "bar")
58 |     c.logout()
59 |     if config.get("insecure", None):
60 |         assert m_post.call_args.kwargs["verify"] is False
61 |         assert m_delete.call_args.kwargs["verify"] is False
62 |     else:
63 |         assert "verify" not in m_post.call_args.kwargs or m_post.call_args.kwargs["verify"]
64 |         assert "verify" not in m_delete.call_args.kwargs or m_delete.call_args.kwargs["verify"]
65 | 
66 | 
67 | @pytest.mark.parametrize(
68 |     "status_code, http_response, sdk_exception",
69 |     [
70 |         (400, {"detail": "Invalid Request"}, exceptions.InvalidRequest),
71 |         (401, {"detail": "Invalid username/password"}, exceptions.AuthenticationError),
72 |         (403, {"detail": "Unauthorized"}, exceptions.AuthorizationError),
73 |         (404, {"detail": "Not Found"}, exceptions.NotFound),
74 |         (408, {"key": "a-key"}, exceptions.RequestTimeout),
75 |         (408, {}, exceptions.RequestTimeout),
76 |         (500, "CRASH", exceptions.InternalServerError),
77 |     ],
78 | )
79 | def test_request_http_errors(mocker, status_code, http_response, sdk_exception):
80 |     m = mock_requests(mocker, "post", response=http_response, status=status_code)
81 |     with 
pytest.raises(sdk_exception):
82 |         _client_from_config(CONFIG).add("http://foo", {})
83 |     assert len(m.call_args_list) == 1
84 | 
85 | 
86 | def test_request_connection_error(mocker):
87 |     mocker.patch(
88 |         "substra.sdk.backends.remote.rest_client.requests.post", side_effect=requests.exceptions.ConnectionError
89 |     )
90 |     with pytest.raises(exceptions.ConnectionError):
91 |         _client_from_config(CONFIG).add("foo", {})
92 | 
93 | 
94 | def test_add_timeout_with_retry(mocker):
95 |     asset_type = "traintask"
96 |     responses = [
97 |         mock_response(response={"key": "a-key"}, status=408),
98 |         mock_response(response={"key": "a-key"}),
99 |     ]
100 |     m_post = mock_requests_responses(mocker, "post", responses)
101 |     asset = _client_from_config(CONFIG).add(asset_type, retry_timeout=60)
102 |     assert len(m_post.call_args_list) == 2
103 |     assert asset == {"key": "a-key"}
104 | 
105 | 
106 | def test_add_already_exist(mocker):
107 |     asset_type = "traintask"
108 |     m_post = mock_requests(mocker, "post", response={"key": "a-key"}, status=409)
109 |     asset = _client_from_config(CONFIG).add(asset_type)
110 |     assert len(m_post.call_args_list) == 1
111 |     assert asset == {"key": "a-key"}
112 | 
113 | 
114 | def test_add_wrong_url(mocker):
115 |     """Check that the correct error is raised when a syntactically valid but wrong url is set."""
116 |     error = json.decoder.JSONDecodeError("", "", 0)
117 | 
118 |     mock_requests(mocker, "post", status=200, json_error=error)
119 |     test_client = Client(url="http://www.dummy.com", token="foo")
120 |     with pytest.raises(exceptions.BadConfiguration) as e:
121 |         test_client.login("test_client", "hehe")
122 |     assert "Make sure that given url" in e.value.args[0]
123 | 
124 | 
125 | def test_list_paginated(mocker):
126 |     asset_type = "traintask"
127 |     items = [datastore.TRAINTASK, datastore.TRAINTASK]
128 |     responses = [
129 |         mock_response(
130 |             response={
131 |                 "count": len(items),
132 |                 "next": "http://foo.com/?page=2",
133 |                 "previous": None,
134 |                 "results": items[:1],
135 |             },
136 |             status=200,
137 |         ),
138 |         mock_response(
139 |             response={
140 |                 "count": len(items),
141 |                 "next": None,
142 |                 "previous": "http://foo.com/?page=1",
143 |                 "results": items[1:],
144 |             },
145 |             status=200,
146 |         ),
147 |     ]
148 |     m_get = mock_requests_responses(mocker, "get", responses)
149 |     asset = _client_from_config(CONFIG).list(asset_type)
150 |     assert len(asset) == len(items)
151 |     assert len(m_get.call_args_list) == 2
152 | 
153 | 
154 | def test_list_not_paginated(mocker):
155 |     asset_type = "traintask"
156 |     items = [datastore.TRAINTASK, datastore.TRAINTASK]
157 |     m_get = mock_requests(
158 |         mocker,
159 |         "get",
160 |         response={
161 |             "count": len(items),
162 |             "next": "http://foo.com/?page=2",
163 |             "previous": None,
164 |             "results": items[:1],
165 |         },
166 |         status=200,
167 |     )
168 |     asset = _client_from_config(CONFIG).list(asset_type, paginated=False)
169 |     assert len(asset) != len(items)
170 |     assert len(m_get.call_args_list) == 1
171 | 
--------------------------------------------------------------------------------
/tests/sdk/test_schemas.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import uuid
3 | 
4 | import pytest
5 | 
6 | from substra.sdk.schemas import DataSampleSpec
7 | from substra.sdk.schemas import DatasetSpec
8 | from substra.sdk.schemas import Permissions
9 | 
10 | 
11 | @pytest.mark.parametrize("path", [pathlib.Path() / "data", "./data", pathlib.Path().cwd() / "data"])
12 | def test_datasample_spec_resolve_path(path):
13 |     datasample_spec = 
DataSampleSpec(path=path, data_manager_keys=[str(uuid.uuid4())])
14 | 
15 |     assert datasample_spec.path == pathlib.Path().cwd() / "data"
16 | 
17 | 
18 | def test_datasample_spec_resolve_paths():
19 |     paths = [pathlib.Path() / "data", "./data", pathlib.Path().cwd() / "data"]
20 |     datasample_spec = DataSampleSpec(paths=paths, data_manager_keys=[str(uuid.uuid4())])
21 | 
22 |     assert all([path == pathlib.Path().cwd() / "data" for path in datasample_spec.paths])
23 | 
24 | 
25 | def test_datasample_spec_exclusive_path():
26 |     with pytest.raises(ValueError):
27 |         DataSampleSpec(paths=["fake_paths"], path="fake_paths", data_manager_keys=[str(uuid.uuid4())])
28 | 
29 | 
30 | def test_datasample_spec_no_path():
31 |     with pytest.raises(ValueError):
32 |         DataSampleSpec(data_manager_keys=[str(uuid.uuid4())])
33 | 
34 | 
35 | def test_datasample_spec_paths_set_to_none():
36 |     with pytest.raises(ValueError):
37 |         DataSampleSpec(paths=None, data_manager_keys=[str(uuid.uuid4())])
38 | 
39 | 
40 | def test_datasample_spec_path_set_to_none():
41 |     with pytest.raises(ValueError):
42 |         DataSampleSpec(path=None, data_manager_keys=[str(uuid.uuid4())])
43 | 
44 | 
45 | def test_dataset_spec_no_description(tmpdir):
46 | 
47 |     opener_path = tmpdir / "fake_opener.py"
48 |     permissions = Permissions(public=True, authorized_ids=[])
49 | 
50 |     DatasetSpec(
51 |         name="Fake Dataset",
52 |         data_opener=str(opener_path),
53 |         permissions=permissions,
54 |         logs_permission=permissions,
55 |     )
56 | 
57 |     assert (pathlib.Path(opener_path).parent / "generated_description.md").exists()
58 | 
--------------------------------------------------------------------------------
/tests/sdk/test_subprocess.py:
--------------------------------------------------------------------------------
1 | import string
2 | 
3 | from substra.sdk.backends.local.compute.spawner.subprocess import _get_command_args
4 | 
5 | 
6 | def test_get_command_args_without_space():
7 |     # check that the path is left unchanged when it contains no spaces
8 |     function_name = "train"
9 |     command = ["--opener-path", "${_VOLUME_OPENER}", "--compute-plan-path", "${_VOLUME_LOCAL}"]
10 |     command_template = [string.Template(part) for part in command]
11 |     local_volumes = {
12 |         "_VOLUME_OPENER": "/a/path/without/any/space/opener.py",
13 |         "_VOLUME_LOCAL": "/another/path/without/any/space",
14 |     }
15 | 
16 |     py_commands = _get_command_args(function_name, command_template, local_volumes)
17 | 
18 |     valid_py_commands = [
19 |         "--function-name",
20 |         "train",
21 |         "--opener-path",
22 |         "/a/path/without/any/space/opener.py",
23 |         "--compute-plan-path",
24 |         "/another/path/without/any/space",
25 |     ]
26 | 
27 |     assert py_commands == valid_py_commands
28 | 
29 | 
30 | def test_get_command_args_with_spaces():
31 |     # check that a path containing spaces is not split into several arguments
32 |     function_name = "train"
33 |     command = ["--opener-path", "${_VOLUME_OPENER}", "--compute-plan-path", "${_VOLUME_LOCAL}"]
34 |     command_template = [string.Template(part) for part in command]
35 |     local_volumes = {
36 |         "_VOLUME_OPENER": "/a/path with spaces/opener.py",
37 |         "_VOLUME_LOCAL": "/another/path with spaces",
38 |     }
39 | 
40 |     py_commands = _get_command_args(function_name, command_template, local_volumes)
41 | 
42 |     valid_py_commands = [
43 |         "--function-name",
44 |         "train",
45 |         "--opener-path",
46 |         "/a/path with spaces/opener.py",
47 |         "--compute-plan-path",
48 |         "/another/path with spaces",
49 |     ]
50 | 
51 |     assert py_commands == valid_py_commands
52 | 
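Both subprocess tests rely on the same property: each element of the command template is substituted individually, so a path containing spaces survives as a single argv entry. A minimal sketch of that pattern (the `build_command_args` helper below is hypothetical and for illustration only, not the SDK's actual implementation):

import string

def build_command_args(function_name, command_template, local_volumes):
    # Hypothetical illustration: substituting each template element separately,
    # instead of joining the command into one string and re-splitting it on
    # whitespace, is what keeps "/a/path with spaces/opener.py" intact as a
    # single argument.
    args = ["--function-name", function_name]
    args += [part.substitute(**local_volumes) for part in command_template]
    return args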
--------------------------------------------------------------------------------
/tests/sdk/test_update.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from substra.sdk import models
4 | from substra.sdk import schemas
5 | 
6 | from .. import datastore
7 | from ..utils import mock_requests
8 | 
9 | 
10 | @pytest.mark.parametrize(
11 |     "asset_type",
12 |     ["dataset", "function", "compute_plan"],
13 | )
14 | def test_update_asset(asset_type, client, mocker):
15 |     update_method = getattr(client, f"update_{asset_type}")
16 |     get_method = getattr(client, f"get_{asset_type}")
17 | 
18 |     item = getattr(datastore, asset_type.upper())
19 |     name_update = {"name": "New name"}
20 |     updated_item = {**item, **name_update}
21 | 
22 |     m_put = mock_requests(mocker, "put")
23 |     m_get = mock_requests(mocker, "get", response=updated_item)
24 | 
25 |     update_method("magic-key", name_update["name"])
26 |     response = get_method("magic-key")
27 | 
28 |     assert response == models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**updated_item)
29 |     m_put.assert_called()
30 |     m_get.assert_called()
31 | 
32 | 
33 | def test_add_compute_plan_tasks(client, mocker):
34 |     item = datastore.COMPUTE_PLAN
35 |     m = mock_requests(mocker, "post", response=item)
36 |     m_get = mock_requests(mocker, "get", response=datastore.COMPUTE_PLAN)
37 | 
38 |     response = client.add_compute_plan_tasks("foo", {})
39 | 
40 |     assert response == models.ComputePlan(**item)
41 |     m.assert_called()
42 |     m_get.assert_called()
43 | 
44 | 
45 | def test_add_compute_plan_tasks_with_schema(client, mocker):
46 |     item = datastore.COMPUTE_PLAN
47 |     m = mock_requests(mocker, "post", response=item)
48 |     m_get = mock_requests(mocker, "get", response=datastore.COMPUTE_PLAN)
49 | 
50 |     response = client.add_compute_plan_tasks("foo", schemas.UpdateComputePlanTasksSpec(key="foo"))
51 | 
52 |     assert response == models.ComputePlan(**item)
53 |     m.assert_called()
54 |     m_get.assert_called()
55 | 
--------------------------------------------------------------------------------
/tests/sdk/test_wait.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 | 
3 | import pytest
4 | 
5 | from substra.sdk import exceptions
6 | from substra.sdk.models import ComputePlanStatus
7 | from substra.sdk.models import ComputeTaskStatus
8 | from substra.sdk.models import TaskErrorType
9 | 
10 | from .. 
import datastore
11 | from ..utils import mock_requests
12 | 
13 | 
14 | def _param_name_maker(arg):
15 |     if isinstance(arg, str):
16 |         return arg
17 |     else:
18 |         return ""
19 | 
20 | 
21 | @pytest.mark.parametrize(
22 |     ("asset_dict", "function_name", "status", "expectation"),
23 |     [
24 |         (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.done, does_not_raise()),
25 |         (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.canceled, pytest.raises(exceptions.FutureFailureError)),
26 |         (datastore.COMPUTE_PLAN, "wait_compute_plan", ComputePlanStatus.done, does_not_raise()),
27 |         (
28 |             datastore.COMPUTE_PLAN,
29 |             "wait_compute_plan",
30 |             ComputePlanStatus.failed,
31 |             pytest.raises(exceptions.FutureFailureError),
32 |         ),
33 |         (
34 |             datastore.COMPUTE_PLAN,
35 |             "wait_compute_plan",
36 |             ComputePlanStatus.canceled,
37 |             pytest.raises(exceptions.FutureFailureError),
38 |         ),
39 |     ],
40 |     ids=_param_name_maker,
41 | )
42 | def test_wait(client, mocker, asset_dict, function_name, status, expectation):
43 |     item = {**asset_dict, "status": status}
44 |     mock_requests(mocker, "get", item)
45 |     function = getattr(client, function_name)
46 |     with expectation:
47 |         function(key=item["key"])
48 | 
49 | 
50 | def test_wait_task_failed(client, mocker):
51 |     # We need an error type to stop the iteration
52 |     item = {**datastore.TRAINTASK, "status": ComputeTaskStatus.failed, "error_type": TaskErrorType.internal}
53 |     mock_requests(mocker, "get", item)
54 |     with pytest.raises(exceptions.FutureFailureError):
55 |         client.wait_task(key=item["key"])
56 | 
57 | 
58 | @pytest.mark.parametrize(
59 |     ("asset_dict", "function_name", "status"),
60 |     [
61 |         (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.waiting_for_parent_tasks),
62 |         (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.waiting_for_builder_slot),
63 |         (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.waiting_for_executor_slot),
64 |         (datastore.COMPUTE_PLAN, "wait_compute_plan", ComputePlanStatus.created),
65 |     ],
66 |     ids=_param_name_maker,
67 | )
68 | def test_wait_timeout(client, mocker, asset_dict, function_name, status):
69 |     item = {**asset_dict, "status": status}
70 |     mock_requests(mocker, "get", item)
71 |     function = getattr(client, function_name)
72 |     with pytest.raises(exceptions.FutureTimeoutError):
73 |         # mock_requests returns only once and timeout=0 is falsy, so use a microscopic timeout instead
74 |         function(key=item["key"], timeout=1e-10)
75 | 
--------------------------------------------------------------------------------
/tests/test_request_formatter.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from substra.sdk.backends.remote import request_formatter
4 | 
5 | 
6 | @pytest.mark.parametrize(
7 |     "raw,formatted",
8 |     [
9 |         ({"foo": ["bar", "baz"]}, {"foo": "bar,baz"}),
10 |         ({"foo": ["bar", "baz"], "bar": ["qux"]}, {"foo": "bar,baz", "bar": "qux"}),
11 |         ({"foo": ["b ar ", " baz"]}, {"foo": "bar,baz"}),
12 |         (
13 |             {},
14 |             {},
15 |         ),
16 |         ({"name": "bar,baz"}, {"match": "bar,baz"}),
17 |         (
18 |             {"metadata": [{"key": "epochs", "type": "is", "value": "10"}]},
19 |             {"metadata": '[{"key":"epochs","type":"is","value":"10"}]'},
20 |         ),
21 |         (None, {}),
22 |     ],
23 | )
24 | def test_format_search_filters_for_remote(raw, formatted):
25 |     assert request_formatter.format_search_filters_for_remote(raw) == formatted
26 | 
27 | 
28 | @pytest.mark.parametrize(
29 |     "ordering, ascending, formatted",
30 |     [
31 |         ("creation_date", False, "-creation_date"),
32 |         ("start_date", True, 
"start_date"), 33 | ], 34 | ) 35 | def test_format_search_ordering_for_remote(ordering, ascending, formatted): 36 | assert request_formatter.format_search_ordering_for_remote(ordering, ascending) == formatted 37 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | 4 | import pytest 5 | 6 | import substra 7 | from substra.sdk import exceptions 8 | from substra.sdk import schemas 9 | from substra.sdk import utils 10 | 11 | 12 | def _unzip(fp, destination): 13 | with zipfile.ZipFile(fp, "r") as zipf: 14 | zipf.extractall(destination) 15 | 16 | 17 | def test_zip_folder(tmp_path): 18 | # initialise dir to zip 19 | dir_to_zip = tmp_path / "dir" 20 | dir_to_zip.mkdir() 21 | 22 | file_items = [ 23 | ("name0.txt", "content0"), 24 | ("dir1/name1.txt", "content1"), 25 | ("dir2/name2.txt", "content2"), 26 | ] 27 | 28 | for name, content in file_items: 29 | path = dir_to_zip / name 30 | path.parents[0].mkdir(exist_ok=True) 31 | path.write_text(content) 32 | 33 | for name, _ in file_items: 34 | path = dir_to_zip / name 35 | assert os.path.exists(str(path)) 36 | 37 | # zip dir 38 | fp = utils.zip_folder_in_memory(str(dir_to_zip)) 39 | assert fp 40 | 41 | # unzip dir 42 | destination_dir = tmp_path / "destination" 43 | destination_dir.mkdir() 44 | _unzip(fp, str(destination_dir)) 45 | for name, content in file_items: 46 | path = destination_dir / name 47 | assert os.path.exists(str(path)) 48 | assert path.read_text() == content 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "filters,expected,exception", 53 | [ 54 | ("str", None, exceptions.FilterFormatError), 55 | ({}, None, exceptions.FilterFormatError), 56 | ( 57 | [{"key": "foo", "type": "bar", "value": "baz"}], 58 | None, 59 | exceptions.FilterFormatError, 60 | ), 61 | ([{"key": "foo", "type": "is", "value": "baz"}, {}], None, exceptions.FilterFormatError), 62 | ([{"key": "foo", "type": "is", "value": "baz"}], None, None), 63 | ], 64 | ) 65 | def test_check_metadata_search_filter(filters, expected, exception): 66 | if exception: 67 | with pytest.raises(exception): 68 | utils._check_metadata_search_filters(filters) 69 | else: 70 | assert utils._check_metadata_search_filters(filters) == expected 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "asset_type,filters,expected,exception", 75 | [ 76 | ( 77 | schemas.Type.ComputePlan, 78 | {"status": [substra.models.ComputePlanStatus.doing.value]}, 79 | {"status": [substra.models.ComputePlanStatus.doing.value]}, 80 | None, 81 | ), 82 | ( 83 | schemas.Type.Task, 84 | {"status": [substra.models.ComputeTaskStatus.done.value]}, 85 | {"status": [substra.models.ComputeTaskStatus.done.value]}, 86 | None, 87 | ), 88 | (schemas.Type.Task, {"rank": [1]}, {"rank": ["1"]}, None), 89 | (schemas.Type.DataSample, ["wrong filter type"], None, exceptions.FilterFormatError), 90 | (schemas.Type.ComputePlan, {"name": ["list"]}, None, exceptions.FilterFormatError), 91 | (schemas.Type.Task, {"foo": "not allowed key"}, None, exceptions.NotAllowedFilterError), 92 | ( 93 | schemas.Type.ComputePlan, 94 | {"name": "cp1", "key": ["key1", "key2"]}, 95 | {"name": "cp1", "key": ["key1", "key2"]}, 96 | None, 97 | ), 98 | ], 99 | ) 100 | def test_check_and_format_search_filters(asset_type, filters, expected, exception): 101 | if exception: 102 | with pytest.raises(exception): 103 | utils.check_and_format_search_filters(asset_type, filters) 104 | else: 105 | assert 
utils.check_and_format_search_filters(asset_type, filters) == expected 106 | 107 | 108 | @pytest.mark.parametrize( 109 | "ordering, exception", 110 | [ 111 | ("creation_date", None), 112 | ("start_date", None), 113 | ("foo", exceptions.OrderingFormatError), 114 | (None, None), 115 | ], 116 | ) 117 | def test_check_search_ordering(ordering, exception): 118 | if exception: 119 | with pytest.raises(exception): 120 | utils.check_search_ordering(ordering) 121 | else: 122 | utils.check_search_ordering(ordering) 123 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import requests 4 | 5 | 6 | def mock_response(response=None, status=200, headers=None, json_error=None): 7 | headers = headers or {} 8 | m = mock.MagicMock(spec=requests.Response) 9 | m.status_code = status 10 | m.headers = headers 11 | m.text = str(response) 12 | m.json = mock.MagicMock(return_value=response, headers=headers, side_effect=json_error) 13 | 14 | if status not in (200, 201): 15 | exception = requests.exceptions.HTTPError(str(status), response=m) 16 | m.raise_for_status = mock.MagicMock(side_effect=exception) 17 | 18 | return m 19 | 20 | 21 | def mock_requests_responses(mocker, method, responses): 22 | return mocker.patch( 23 | f"substra.sdk.backends.remote.rest_client.requests.{method}", 24 | side_effect=responses, 25 | ) 26 | 27 | 28 | def mock_requests(mocker, method, response=None, status=200, headers=None, json_error=None): 29 | r = mock_response(response, status, headers, json_error) 30 | return mock_requests_responses(mocker, method, (r,)) 31 | 32 | 33 | def make_paginated_response(items): 34 | return {"count": len(items), "next": None, "previous": None, "results": items} 35 | --------------------------------------------------------------------------------
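Taken together, the helpers in tests/utils.py let a test stub out the whole HTTP layer in a couple of lines. A minimal end-to-end sketch, modeled on test_list_paginated above (the test name is hypothetical; it assumes pytest-mock's `mocker` fixture and the helpers from this file):

from substra.sdk.backends.remote import rest_client

def test_list_single_page_sketch(mocker):
    # Stub every GET issued through the SDK's rest client with one canned page.
    items = [{"key": "foo"}]
    m = mock_requests(mocker, "get", response=make_paginated_response(items))
    client = rest_client.Client("http://foo.com", False, None)
    # "next" is None in the canned page, so a single request suffices.
    assert len(client.list("traintask")) == len(items)
    assert len(m.call_args_list) == 1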