├── .flake8
├── .git-blame-ignore-revs
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yaml
│   │   ├── config.yml
│   │   └── feature_request.yaml
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── publish.yml
│       ├── python.yml
│       └── towncrier-changelog.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── CONTRIBUTORS.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── Substra-logo-colour.svg
├── Substra-logo-white.svg
├── bin
│   ├── generate_sdk_documentation.py
│   └── generate_sdk_schemas_documentation.py
├── changes
│   └── .gitkeep
├── docs
│   ├── README.md
│   └── technical_documentation.md
├── pyproject.toml
├── references
│   ├── sdk.md
│   ├── sdk_models.md
│   └── sdk_schemas.md
├── substra
│   ├── __init__.py
│   ├── __version__.py
│   ├── config.py
│   └── sdk
│       ├── __init__.py
│       ├── archive
│       │   ├── __init__.py
│       │   ├── safezip.py
│       │   └── tarsafe.py
│       ├── backends
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── local
│       │   │   ├── __init__.py
│       │   │   ├── backend.py
│       │   │   ├── compute
│       │   │   │   ├── __init__.py
│       │   │   │   ├── spawner
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── base.py
│       │   │   │   │   ├── docker.py
│       │   │   │   │   └── subprocess.py
│       │   │   │   └── worker.py
│       │   │   ├── dal.py
│       │   │   ├── db.py
│       │   │   └── models.py
│       │   └── remote
│       │       ├── __init__.py
│       │       ├── backend.py
│       │       ├── request_formatter.py
│       │       └── rest_client.py
│       ├── client.py
│       ├── compute_plan.py
│       ├── exceptions.py
│       ├── fs.py
│       ├── graph.py
│       ├── hasher.py
│       ├── models.py
│       ├── schemas.py
│       └── utils.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── data_factory.py
    ├── datastore.py
    ├── fl_interface.py
    ├── mocked_requests.py
    ├── sdk
    │   ├── __init__.py
    │   ├── data
    │   │   ├── symlink.zip
    │   │   └── traversal.zip
    │   ├── local
    │   │   ├── __init__.py
    │   │   ├── conftest.py
    │   │   └── test_debug.py
    │   ├── test_add.py
    │   ├── test_archive.py
    │   ├── test_cancel.py
    │   ├── test_client.py
    │   ├── test_describe.py
    │   ├── test_download.py
    │   ├── test_get.py
    │   ├── test_graph.py
    │   ├── test_list.py
    │   ├── test_rest_client.py
    │   ├── test_schemas.py
    │   ├── test_subprocess.py
    │   ├── test_update.py
    │   └── test_wait.py
    ├── test_request_formatter.py
    ├── test_utils.py
    └── utils.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | max-complexity = 10
4 | extend-ignore = E203, W503, N802, N803, N806
5 | # W503 is incompatible with Black, see https://github.com/psf/black/pull/36
6 | # E203 must be disabled for flake8 to work with Black.
7 | # See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1
8 | # N802, N803 and N806 would prevent us from using upper case in variable, function and argument names.
9 | per-file-ignores =
10 | __init__.py:F401,
11 |
12 | exclude =
13 | .git
14 | .github
15 | .dvc
16 | __pycache__
17 | .venv
18 | .mypy_cache
19 | .pytest_cache
20 | hubconf.py
21 | **local-worker
22 |
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | 3b9b9d3d24cc94bdf34ce39cf0273f410188cbd2
2 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @Substra/code-owners
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: Report a bug or performance issue
3 | title: "BUG: "
4 | labels: [Bug]
5 |
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | Thanks for taking the time to fill out this bug report!
11 | - type: textarea
12 | id: context
13 | attributes:
14 | label: What are you trying to do?
15 | description: >
16 | Please provide some context on what you are trying to achieve.
17 | placeholder:
18 | validations:
19 | required: true
20 | - type: textarea
21 | id: issue-description
22 | attributes:
23 | label: Issue Description (what is happening?)
24 | description: >
25 | Please provide a description of the issue.
26 | validations:
27 | required: true
28 | - type: textarea
29 | id: expected-behavior
30 | attributes:
31 | label: Expected Behavior (what should happen?)
32 | description: >
33 | Please describe or show a code example of the expected behavior.
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: example
38 | attributes:
39 | label: Reproducible Example
40 | description: >
41 | If possible, provide a reproducible example.
42 | render: python
43 |
44 | - type: textarea
45 | id: os-version
46 | attributes:
47 | label: Operating system
48 | description: >
49 | Which operating system are you using? (Provide the version number)
50 | validations:
51 | required: true
52 | - type: textarea
53 | id: python-version
54 | attributes:
55 | label: Python version
56 | description: >
57 | Which Python version are you using?
58 | placeholder: >
59 | python --version
60 | validations:
61 | required: true
62 | - type: textarea
63 | id: substra-version
64 | attributes:
65 | label: Installed Substra versions
66 | description: >
67 | Which versions of `substrafl` / `substra` / `substra-tools` are you using?
68 | You can check if they are compatible in the [compatibility table](https://docs.substra.org/en/stable/additional/release.html#compatibility-table).
69 | placeholder: >
70 | pip freeze | grep substra
71 | render: python
72 | validations:
73 | required: true
74 | - type: textarea
75 | id: dependencies-version
76 | attributes:
77 | label: Installed versions of dependencies
78 | description: >
79 | Please provide the versions of dependencies that might be relevant to your issue (e.g. `helm` and `skaffold` for a deployment issue, `numpy` and `pytorch` for an algorithmic issue).
80 |
81 |
82 | - type: textarea
83 | id: logs
84 | attributes:
85 | label: Logs / Stacktrace
86 | description: >
87 | Please copy-paste here any log and/or stacktrace that might be relevant. Remove confidential and personal information if necessary.
88 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
2 | contact_links:
3 | - name: Ask a question
4 | url: https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA
5 | about: Don't hesitate to join the Substra community on Slack to ask all your questions!
6 | - name: User Documentation Improvements
7 | url: https://github.com/Substra/substra-documentation/issues
8 | about: For issues related to the User Documentation, please open an issue on the substra-documentation repository
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yaml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: Suggest an idea for Substra
3 | title: "ENH: "
4 | labels: [Enhancement]
5 |
6 | body:
7 | - type: checkboxes
8 | id: checks
9 | attributes:
10 | label: Feature Type
11 | description: Please check what type of feature request you would like to propose.
12 | options:
13 | - label: >
14 | Adding new functionality to Substra
15 | - label: >
16 | Changing existing functionality in Substra
17 | - label: >
18 | Removing existing functionality in Substra
19 | - type: textarea
20 | id: description
21 | attributes:
22 | label: Problem Description
23 | description: >
24 | Please describe what problem the feature would solve, e.g. "I wish I could use Substra to ..."
25 | validations:
26 | required: true
27 | - type: textarea
28 | id: feature
29 | attributes:
30 | label: Feature Description
31 | description: >
32 | Please describe what this new feature will be.
33 | validations:
34 | required: true
35 | - type: textarea
36 | id: alternative
37 | attributes:
38 | label: Alternative Solutions
39 | description: >
40 | Please describe any alternative solution (existing functionality, 3rd party package, etc.)
41 | that would satisfy the feature request.
42 | - type: textarea
43 | id: context
44 | attributes:
45 | label: Additional Context
46 | description: >
47 | Please provide any relevant GitHub issues, code examples or references that help describe and support
48 | the feature request.
49 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Related issue
2 |
3 | `#` followed by the number of the issue
4 |
5 | ## Summary
6 |
7 | ## Notes
8 |
9 | ## Please check if the PR fulfills these requirements
10 |
11 | - [ ] If necessary, the [changelog](https://github.com/Substra/substra/blob/main/CHANGELOG.md) has been updated
12 | - [ ] Tests for the changes have been added (for bug fixes / features)
13 | - [ ] Docs have been added / updated (for bug fixes / features)
14 | - [ ] The commit message follows the [conventional commit](https://www.conventionalcommits.org/en/v1.0.0/) specification
15 | - For any breaking changes, companion PRs have been opened on the following repositories:
16 | - [ ] [substra-tests](https://github.com/Substra/substra-tests)
17 | - [ ] [substrafl](https://github.com/Substra/substrafl)
18 | - [ ] [substra-documentation](https://github.com/Substra/substra-documentation)
19 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | publish:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - uses: actions/setup-python@v5
13 | with:
14 | python-version: 3.11
15 | - name: Install Hatch
16 | run: pipx install hatch
17 | - name: Build dist
18 | run: hatch build
19 | - name: Publish
20 | run: hatch publish -u __token__ -a ${{ secrets.PYPI_API_TOKEN }}
21 |
22 |
--------------------------------------------------------------------------------
/.github/workflows/python.yml:
--------------------------------------------------------------------------------
1 | name: Python
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | pull_request:
8 |
9 | jobs:
10 | lint:
11 | name: Lint and documentation
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: "3.12"
19 | - name: Install tools
20 | run: pip install flake8 black isort wheel docstring-parser
21 | - name: Lint
22 | run: |
23 | black --check substra
24 | isort --check substra
25 | flake8 substra
26 | - name: Install substra
27 | run: pip install -e .
28 | - name: Generate and validate SDK documentation
29 | run: |
30 | python bin/generate_sdk_documentation.py --output-path='references/sdk.md'
31 | python bin/generate_sdk_schemas_documentation.py --output-path references/sdk_schemas.md
32 | python bin/generate_sdk_schemas_documentation.py --models --output-path='references/sdk_models.md'
33 | - name: Documentation artifacts
34 | uses: actions/upload-artifact@v4
35 | if: always()
36 | with:
37 | retention-days: 1
38 | name: references
39 | path: references/*
40 | tests:
41 | runs-on: ubuntu-20.04
42 | env:
43 | XDG_RUNTIME_DIR: /home/runner/.docker/run
44 | DOCKER_HOST: unix:///home/runner/.docker/run/docker.sock
45 | strategy:
46 | fail-fast: false
47 | matrix:
48 | python-version: ["3.10", "3.11", "3.12"]
49 | name: Tests on Python ${{ matrix.python-version }}
50 | steps:
51 | - name: Set up python
52 | uses: actions/setup-python@v5
53 | with:
54 | python-version: ${{ matrix.python-version }}
55 | architecture: x64
56 | - name: Install Docker rootless
57 | run: |
58 | sudo systemctl disable --now docker.service
59 | export FORCE_ROOTLESS_INSTALL=1
60 | curl -fsSL https://get.docker.com/rootless | sh
61 | - name: Configure docker
62 | run: |
63 | export PATH=/home/runner/bin:$PATH
64 | /home/runner/bin/dockerd-rootless.sh & # Start Docker rootless in the background
65 | - name: Cloning substra
66 | uses: actions/checkout@v4
67 | with:
68 | path: substra
69 | - name: Cloning substratools
70 | uses: actions/checkout@v4
71 | with:
72 | repository: Substra/substra-tools
73 | path: substratools
74 | ref: main
75 | - name: Install substra and substratools
76 | run: |
77 | pip install --no-cache-dir -e substratools
78 | pip install --no-cache-dir -e 'substra[dev]'
79 |
80 | - name: Test
81 | run: |
82 | export PATH=/home/runner/bin:$PATH
83 | cd substra && make test
84 |
--------------------------------------------------------------------------------
/.github/workflows/towncrier-changelog.yml:
--------------------------------------------------------------------------------
1 | name: Towncrier changelog
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | app_version:
7 | type: string
8 | description: 'The version of the app'
9 | required: true
10 | branch:
11 | type: string
12 | description: 'The branch to update'
13 | required: true
14 |
15 | jobs:
16 | test-generate-publish:
17 | uses: substra/substra-gha-workflows/.github/workflows/towncrier-changelog.yml@main
18 | secrets: inherit
19 | with:
20 | app_version: ${{ inputs.app_version }}
21 | repo: substra
22 | branch: ${{ inputs.branch }}
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized
2 | __pycache__/
3 |
4 | # Distribution / packaging
5 | build/
6 | dist/
7 | sdist/
8 | .eggs/*
9 | *.egg-info
10 |
11 | # Unit test / coverage reports
12 | .cache
13 | .coverage*
14 | .pytest_cache/
15 |
16 | # Developer environments
17 | .idea
18 | .vscode
19 | .env
20 | .venv
21 | .DS_Store
22 | .mypy_cache
23 |
24 | local-worker/
25 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: "^references/*"
2 | repos:
3 | - repo: https://github.com/psf/black
4 | rev: 24.1.1
5 | hooks:
6 | - id: black
7 |
8 | - repo: https://github.com/pycqa/flake8
9 | rev: 7.0.0
10 | hooks:
11 | - id: flake8
12 | additional_dependencies: [pep8-naming, flake8-bugbear]
13 | args: ['--classmethod-decorators=classmethod,validator,root_validator']
14 |
15 | - repo: https://github.com/pycqa/isort
16 | rev: 5.12.0
17 | hooks:
18 | - id: isort
19 |
20 | - repo: https://github.com/pre-commit/pre-commit-hooks
21 | rev: v4.3.0
22 | hooks:
23 | - id: trailing-whitespace
24 | - id: end-of-file-fixer
25 | - id: debug-statements
26 | - id: check-added-large-files
27 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | Substra repositories' code of conduct is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/code-of-conduct.html).
2 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Substra repositories' contributing guide is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/contributing-guide.html).
2 |
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | This file lists people who have made significant contributions to the Substra SDK, in chronological order. Please add your contribution at the bottom of this document in the following format: name (N), email (E), description of work (W) and date (D).
2 |
3 | To have your contribution listed, your work must meet the minimum [threshold of originality](https://en.wikipedia.org/wiki/Threshold_of_originality), which will be evaluated by the maintainers of the repository.
4 |
5 | Thank you for your contribution, your work is greatly appreciated!
6 |
7 | --- Example ---
8 |
9 | - N: John Doe
10 | - E: john.doe@owkin.com
11 | - W: Integrated new feature
12 | - D: 02/02/2023
13 |
14 | ---
15 |
16 | Copyright (c) 2018-present Owkin Inc. All rights reserved.
17 |
18 | All other contributions:
19 | Copyright (c) 2023 to the respective contributors.
20 | All rights reserved.
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | Copyright 2018-2022 Owkin, Inc.
180 |
181 | Licensed under the Apache License, Version 2.0 (the "License");
182 | you may not use this file except in compliance with the License.
183 | You may obtain a copy of the License at
184 |
185 | http://www.apache.org/licenses/LICENSE-2.0
186 |
187 | Unless required by applicable law or agreed to in writing, software
188 | distributed under the License is distributed on an "AS IS" BASIS,
189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
190 | See the License for the specific language governing permissions and
191 | limitations under the License.
192 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include setup.cfg README.md
2 |
3 | recursive-include tests *
4 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: pyclean doc test doc-cli doc-sdk
2 |
3 | pyclean:
4 | find . -type f -name "*.py[co]" -delete
5 | find . -type d -name "__pycache__" -delete
6 | rm -rf build/ dist/ *.egg-info
7 |
8 | doc-sdk: pyclean
9 | python bin/generate_sdk_documentation.py
10 | python bin/generate_sdk_schemas_documentation.py
11 | python bin/generate_sdk_schemas_documentation.py --models --output-path='references/sdk_models.md'
12 |
13 | doc: doc-sdk
14 |
15 | test: pyclean
16 | pytest tests
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
16 | Substra is open-source federated learning (FL) software. It enables training and validating machine learning models on distributed datasets. It provides a flexible Python interface and a web application to run federated learning training at scale. This repository hosts the low-level Python library used to interact with a Substra network.
17 |
18 | Substra's main usage is in production environments. It has already been deployed and used by hospitals and biotech companies (see the [MELLODDY](https://www.melloddy.eu/) project for instance). Substra can also be used on a single machine to perform FL simulations and debug code.
19 |
20 | Substra was originally developed by [Owkin](https://owkin.com/) and is now hosted by the [Linux Foundation for AI and Data](https://lfaidata.foundation/). Today Owkin is the main contributor to Substra.
21 |
22 | Join the discussion on [Slack](https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA) and [subscribe here](https://lists.lfaidata.foundation/g/substra-announce/join) to our newsletter.
23 |
24 |
25 | ## To start using Substra
26 |
27 | Have a look at our [documentation](https://docs.substra.org/).
28 |
29 | Try out our [MNIST example](https://docs.substra.org/en/stable/substrafl_doc/examples/index.html#example-to-get-started-using-the-pytorch-interface).
30 |
31 | ## Support
32 |
33 | If you need support, please either raise an issue on Github or ask on [Slack](https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA).
34 |
35 |
36 |
37 | ## Contributing
38 |
39 | Substra warmly welcomes any contribution. Feel free to fork the repo and create a pull request.
40 |
41 |
42 | ## Setup
43 |
44 | To setup the project in development mode, run:
45 |
46 | ```sh
47 | pip install -e ".[dev]"
48 | ```
49 |
50 | To run all tests, use the following command:
51 |
52 | ```sh
53 | make test
54 | ```
55 |
56 | Some of the tests require Docker to be running on your machine.
57 |
58 | ## Code formatting
59 |
60 | You can opt into auto-formatting of code on pre-commit using [Black](https://github.com/psf/black).
61 |
62 | This relies on hooks managed by [pre-commit](https://pre-commit.com/), which you can set up as follows.
63 |
64 | Install [pre-commit](https://pre-commit.com/), then run:
65 |
66 | ```sh
67 | pre-commit install
68 | ```
69 |
70 | ## Documentation generation
71 |
72 | To generate the SDK and schemas documentation, the `python` version
73 | must be 3.8. Run the following command:
74 |
75 | ```sh
76 | make doc
77 | ```
78 |
79 | Documentation will be available in the *references/* directory.
80 |
81 | ## Changelog generation
82 |
83 | The changelog is managed with [towncrier](https://towncrier.readthedocs.io/en/stable/index.html).
84 | To add a new entry in the changelog, add a file in the `changes` folder. The file name should have the following structure:
85 | `<unique_id>.<change_type>`.
86 | The `unique_id` is a unique identifier; we currently use the PR number.
87 | The `change_type` can be one of the following: `added`, `changed`, `removed`, `fixed`.
88 |
89 | To generate the changelog (for example during a release), use the following command (you must have the dev dependencies installed):
90 |
91 | ```
92 | towncrier build --version=<version>
93 | ```
94 | You can use the `--draft` option to see what would be generated without actually writing to the changelog (and without removing the fragments).
95 |
--------------------------------------------------------------------------------
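As a concrete sketch of the changelog workflow described in the README above (the fragment name `123.added` and the version number are purely illustrative):

```sh
# add a changelog fragment named after the PR number and change type
echo "Added archive sanity checks." > changes/123.added
# preview the generated entry without writing the changelog or deleting the fragment
towncrier build --draft --version=1.0.1
```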
/Substra-logo-colour.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Substra-logo-white.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/bin/generate_sdk_documentation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 | import sys
4 | from pathlib import Path
5 |
6 | import docstring_parser
7 |
8 | from substra import Client
9 | from substra.sdk.utils import retry_on_exception
10 |
11 | MODULE_LIST = [Client, retry_on_exception]
12 |
13 | KEYWORDS = ["Args", "Returns", "Yields", "Raises", "Example"]
14 |
15 |
16 | def generate_function_help(fh, asset):
17 | """Write the description of a function"""
18 | fh.write(f"# {asset.__name__}\n")
19 | signature = str(inspect.signature(asset))
20 | fh.write("```text\n")
21 | fh.write(f"{asset.__name__}{signature}")
22 | fh.write("\n```")
23 | fh.write("\n\n")
24 | # Write the docstring
25 |     docstring = docstring_parser.parse(inspect.getdoc(asset))
27 | fh.write(f"{docstring.short_description}\n")
28 | if docstring.long_description:
29 | fh.write(f"{docstring.long_description}\n")
30 | # Write the arguments as a list
31 | if len(docstring.params) > 0:
32 | fh.write("\n**Arguments:**\n")
33 | for param in docstring.params:
34 | type_and_optional = ""
35 | if param.type_name or param.is_optional is not None:
36 | text_optional = "required"
37 | if param.is_optional:
38 | text_optional = "optional"
39 | type_and_optional = f"({param.type_name}, {text_optional})"
40 | fh.write(f" - `{param.arg_name} {type_and_optional}`: {param.description}\n")
41 | # Write everything else as is
42 | for param in [
43 | meta_param for meta_param in docstring.meta if not isinstance(meta_param, docstring_parser.DocstringParam)
44 | ]:
45 | fh.write(f"\n**{param.args[0].title()}:**\n")
46 | if len(param.args) > 1:
47 | for extra_param in param.args[1:]:
48 | fh.write(f"\n - `{extra_param}`: ")
49 | fh.write(f"{param.description}\n")
50 |
51 |
52 | def generate_properties_help(fh, public_methods):
53 | properties = [(f_name, f_method) for f_name, f_method in public_methods if isinstance(f_method, property)]
54 | for f_name, f_method in properties:
55 | fh.write(f"## {f_name}\n")
56 | fh.write("_This is a property._ \n")
57 | fh.write(f"{f_method.__doc__}\n")
58 |
59 |
60 | def generate_help(fh):
61 | for asset in MODULE_LIST:
62 | if inspect.isclass(asset): # Class
63 | public_methods = [
64 | (f_name, f_method) for f_name, f_method in inspect.getmembers(asset) if not f_name.startswith("_")
65 | ]
66 | generate_function_help(fh, asset)
67 | generate_properties_help(fh, public_methods)
68 | for _, f_method in public_methods:
69 | if not isinstance(f_method, property):
70 | fh.write("#") # Title for the methods are one level below
71 | generate_function_help(fh, f_method)
72 | elif callable(asset):
73 | generate_function_help(fh, asset)
74 |
75 |
76 | def write_help(path):
77 | with path.open("w") as fh:
78 | generate_help(fh)
79 |
80 |
81 | if __name__ == "__main__":
82 | default_path = Path(__file__).resolve().parents[1] / "references" / "sdk.md"
83 |
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument("--output-path", type=str, default=str(default_path.resolve()), required=False)
86 |
87 | args = parser.parse_args(sys.argv[1:])
88 | write_help(Path(args.output_path))
89 |
--------------------------------------------------------------------------------
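For context, here is a minimal sketch of the `docstring_parser` behavior the generator above relies on; the sample docstring is invented for illustration:

```python
import docstring_parser

# parse() auto-detects the docstring style and splits it into a short
# description, typed parameters, and other meta sections (Returns, Raises, ...)
doc = docstring_parser.parse(
    """Add two numbers.

    Args:
        a (int): First operand.
        b (int, optional): Second operand.
    """
)

print(doc.short_description)  # Add two numbers.
for param in doc.params:
    # "b" is reported as optional because of its "(int, optional)" annotation
    print(param.arg_name, param.type_name, param.is_optional)
```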
/bin/generate_sdk_schemas_documentation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 | import sys
4 | import warnings
5 | from pathlib import Path
6 |
7 | import pydantic
8 |
9 | from substra.sdk import models
10 | from substra.sdk import schemas
11 |
12 | local_dir = Path(__file__).parent
13 |
14 | schemas_list = [
15 | schemas.DataSampleSpec,
16 | schemas.DatasetSpec,
17 | schemas.UpdateDatasetSpec,
18 | schemas.FunctionSpec,
19 | schemas.FunctionInputSpec,
20 | schemas.FunctionOutputSpec,
21 | schemas.TaskSpec,
22 | schemas.ComputeTaskOutputSpec,
23 | schemas.UpdateFunctionSpec,
24 | schemas.ComputePlanSpec,
25 | schemas.UpdateComputePlanSpec,
26 | schemas.UpdateComputePlanTasksSpec,
27 | schemas.ComputePlanTaskSpec,
28 | schemas.Permissions,
29 | schemas.PrivatePermissions,
30 | ]
31 |
32 | models_list = [
33 | models.DataSample,
34 | models.Dataset,
35 | models.Task,
36 | models.Function,
37 | models.ComputePlan,
38 | models.Performances,
39 | models.Organization,
40 | models.Permissions,
41 | models.InModel,
42 | models.OutModel,
43 | ]
44 |
45 |
46 | def _get_field_description(fields):
47 | desc = [f"{name}: {field.annotation}" for name, field in fields.items()]
48 | return desc
49 |
50 |
51 | def generate_help(fh, models: bool):
52 | if models:
53 | asset_list = models_list
54 | title = "Models"
55 | else:
56 | asset_list = schemas_list
57 | title = "Schemas"
58 |
59 | fh.write("# Summary\n\n")
60 |
61 | def _create_anchor(schema):
62 | return "#{}".format(schema.__name__)
63 |
64 | for asset in asset_list:
65 | anchor = _create_anchor(asset)
66 | fh.write(f"- [{asset.__name__}]({anchor})\n")
67 |
68 | fh.write("\n\n")
69 | fh.write(f"# {title}\n\n")
70 |
71 | for asset in asset_list:
72 | anchor = _create_anchor(asset)
73 |
74 | fh.write(f"## {asset.__name__}\n")
75 | # Write the docstring
76 | fh.write(f"{inspect.getdoc(asset)}\n")
77 | # List the fields and their types
78 | description = _get_field_description(asset.model_fields)
79 | fh.write("```text\n")
80 | fh.write("- " + "\n- ".join(description))
81 | fh.write("\n```")
82 | fh.write("\n\n")
83 |
84 |
85 | def write_help(path, models: bool):
86 | with path.open("w") as fh:
87 | generate_help(fh, models)
88 |
89 |
90 | if __name__ == "__main__":
91 | expected_pydantic_version = "2.3.0"
92 | if pydantic.VERSION != expected_pydantic_version:
93 |         warnings.warn(
94 |             f"The documentation should be generated with pydantic version {expected_pydantic_version}, "
95 |             f"otherwise it might not match the CI: version {pydantic.VERSION} is in use."
96 |         )
97 |
98 | doc_dir = local_dir.parent / "references"
99 | default_path = doc_dir / "sdk_schemas.md"
100 |
101 | parser = argparse.ArgumentParser()
102 | parser.add_argument("--output-path", type=str, default=str(default_path.resolve()), required=False)
103 | parser.add_argument(
104 | "--models",
105 | action="store_true",
106 | help="Generate the doc for the models.\
107 | Default: generate for the schemas",
108 | )
109 |
110 | args = parser.parse_args(sys.argv[1:])
111 | write_help(Path(args.output_path), models=args.models)
112 |
--------------------------------------------------------------------------------
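For context, a minimal sketch (with an invented model) of what `_get_field_description` builds from pydantic v2's `model_fields` mapping:

```python
import typing

import pydantic


class Example(pydantic.BaseModel):
    key: str
    metadata: typing.Optional[typing.Dict[str, str]] = None


# model_fields maps each field name to a FieldInfo carrying its type annotation
desc = [f"{name}: {field.annotation}" for name, field in Example.model_fields.items()]
print(desc)  # ["key: <class 'str'>", "metadata: typing.Optional[typing.Dict[str, str]]"]
```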
/changes/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/changes/.gitkeep
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Substra documentation
2 |
3 | The Substra documentation is hosted by Read The Docs and can be found [here](https://docs.substra.org/).
4 |
--------------------------------------------------------------------------------
/docs/technical_documentation.md:
--------------------------------------------------------------------------------
1 | # Substra technical documentation
2 |
3 | Client operations:
4 |
5 | ```mermaid
6 | graph TD;
7 | Client-->debug{debug}
8 | debug-->|False|remotebackend[remote backend]
9 | debug-->|True|local[local backend]
10 | local-->|CRUD|dal
11 | dal-->|read|isremote{asset is remote}
12 | isremote-->|True|remotebackend
13 | isremote-->|False|db
14 | dal-->|save / update|db
15 | local-->|execute task|worker
16 | worker-->|CRUD|dal
17 | worker-->spawner{spawner}
18 | spawner-->docker
19 | spawner-->subprocess
20 |
21 |     click Client "https://github.com/Substra/substra/blob/main/substra/sdk/client.py"
22 |     click debug "https://github.com/Substra/substra/blob/main/substra/sdk/client.py#L75"
23 |     click remotebackend "https://github.com/Substra/substra/blob/main/substra/sdk/backends/remote/backend.py"
24 |     click local "https://github.com/Substra/substra/blob/main/substra/sdk/backends/local/backend.py"
25 |     click dal "https://github.com/Substra/substra/blob/main/substra/sdk/backends/local/dal.py"
26 |     click db "https://github.com/Substra/substra/blob/main/substra/sdk/backends/local/db.py"
27 |     click worker "https://github.com/Substra/substra/blob/main/substra/sdk/backends/local/compute/worker.py"
28 |     click spawner "https://github.com/Substra/substra/blob/main/substra/sdk/backends/local/compute/worker.py#L69"
29 |     click docker "https://github.com/Substra/substra/blob/main/substra/sdk/backends/local/compute/spawner/docker.py"
30 |     click subprocess "https://github.com/Substra/substra/blob/main/substra/sdk/backends/local/compute/spawner/subprocess.py"
31 | ```
32 |
--------------------------------------------------------------------------------
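The diagram above can be summarized in a short sketch; the attribute names below are hypothetical simplifications of what lives in `substra/sdk/client.py`:

```python
# Simplified dispatch logic, assuming a client holding one backend per mode:
# in debug mode all CRUD goes through the local DAL, which falls back to the
# remote backend for assets that only exist remotely.
def get_asset(client, key: str):
    if client.debug:
        return client.local_backend.get(key)  # DAL: local db, remote fallback
    return client.remote_backend.get(key)  # REST call via rest_client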
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.build.targets.sdist]
6 | exclude = ["tests*"]
7 |
8 | [tool.hatch.version]
9 | path = "substra/__version__.py"
10 |
11 | [project]
12 | name = "substra"
13 | description = "Low-level Python library for interacting with a Substra network"
14 | dynamic = ["version"]
15 | readme = "README.md"
16 | requires-python = ">= 3.10"
17 | dependencies = [
18 | "requests!=2.32.*",
19 | "docker",
20 | "pyyaml",
21 | "pydantic>=2.3.0,<3.0.0",
22 | "tqdm",
23 | "python-slugify",
24 | ]
25 | keywords = ["substra"]
26 | classifiers = [
27 | "Intended Audience :: Developers",
28 | "Topic :: Utilities",
29 | "Natural Language :: English",
30 | "Operating System :: OS Independent",
31 | "Programming Language :: Python :: 3.10",
32 | "Programming Language :: Python :: 3.11",
33 | "Programming Language :: Python :: 3.12",
34 | ]
35 | license = { file = "LICENSE" }
36 | authors = [{ name = "Owkin, Inc." }]
37 |
38 |
39 | [project.optional-dependencies]
40 | dev = [
41 | "pandas",
42 | "pytest",
43 | "pytest-cov",
44 | "pytest-mock",
45 | "substratools~=1.0.0",
46 | "black",
47 | "flake8",
48 | "isort",
49 | "docstring_parser",
50 | "towncrier",
51 | ]
52 |
53 | [project.urls]
54 | Documentation = "https://docs.substra.org/en/stable/"
55 | Repository = "https://github.com/Substra/substra"
56 | Changelog = "https://github.com/Substra/substra/blob/main/CHANGELOG.md"
57 |
58 | [tool.pytest.ini_options]
59 | addopts = "-v --cov=substra --ignore=tests/unit --ignore=tests/e2e"
60 |
61 | [tool.black]
62 | line-length = 120
63 | target-version = ['py39']
64 |
65 | [tool.isort]
66 | filter_files = true
67 | force_single_line = true
68 | line_length = 120
69 | profile = "black"
70 |
71 | [tool.towncrier]
72 | directory = "changes"
73 | filename = "CHANGELOG.md"
74 | start_string = "\n"
75 | underlines = ["", "", ""]
76 | title_format = "## [{version}](https://github.com/Substra/substra/releases/tag/{version}) - {project_date}"
77 | issue_format = "[#{issue}](https://github.com/Substra/substra/pull/{issue})"
78 | [tool.towncrier.fragment.added]
79 | [tool.towncrier.fragment.removed]
80 | [tool.towncrier.fragment.changed]
81 | [tool.towncrier.fragment.fixed]
82 |
--------------------------------------------------------------------------------
/references/sdk_models.md:
--------------------------------------------------------------------------------
1 | # Summary
2 |
3 | - [DataSample](#DataSample)
4 | - [Dataset](#Dataset)
5 | - [Task](#Task)
6 | - [Function](#Function)
7 | - [ComputePlan](#ComputePlan)
8 | - [Performances](#Performances)
9 | - [Organization](#Organization)
10 | - [Permissions](#Permissions)
11 | - [InModel](#InModel)
12 | - [OutModel](#OutModel)
13 |
14 |
15 | # Models
16 |
17 | ## DataSample
18 | Data sample
19 | ```text
20 | - key:
21 | - owner:
22 | - data_manager_keys: typing.Optional[typing.List[str]]
23 | - path: typing.Optional[typing.Annotated[pathlib.Path, PathType(path_type='dir')]]
24 | - creation_date:
25 | ```
26 |
27 | ## Dataset
28 | Dataset asset
29 | ```text
30 | - key:
31 | - name:
32 | - owner:
33 | - permissions:
34 | - data_sample_keys: typing.List[str]
35 | - opener:
36 | - description:
37 | - metadata: typing.Dict[str, str]
38 | - creation_date:
39 | - logs_permission:
40 | ```
41 |
42 | ## Task
43 | Asset creation specification base class.
44 | ```text
45 | - key:
46 | - function:
47 | - owner:
48 | - compute_plan_key:
49 | - metadata: typing.Dict[str, str]
50 | - status:
51 | - worker:
52 | - rank: typing.Optional[int]
53 | - tag:
54 | - creation_date:
55 | - start_date: typing.Optional[datetime.datetime]
56 | - end_date: typing.Optional[datetime.datetime]
57 | - error_type: typing.Optional[substra.sdk.models.TaskErrorType]
58 | - inputs: typing.List[substra.sdk.models.InputRef]
59 | - outputs: typing.Dict[str, substra.sdk.models.ComputeTaskOutput]
60 | ```
61 |
62 | ## Function
63 | Asset creation specification base class.
64 | ```text
65 | - key:
66 | - name:
67 | - owner:
68 | - permissions:
69 | - metadata: typing.Dict[str, str]
70 | - creation_date:
71 | - inputs: typing.List[substra.sdk.models.FunctionInput]
72 | - outputs: typing.List[substra.sdk.models.FunctionOutput]
73 | - status:
74 | - description:
75 | - archive:
76 | ```
77 |
78 | ## ComputePlan
79 | ComputePlan
80 | ```text
81 | - key:
82 | - tag:
83 | - name:
84 | - owner:
85 | - metadata: typing.Dict[str, str]
86 | - task_count:
87 | - waiting_builder_slot_count:
88 | - building_count:
89 | - waiting_parent_tasks_count:
90 | - waiting_executor_slot_count:
91 | - executing_count:
92 | - canceled_count:
93 | - failed_count:
94 | - done_count:
95 | - failed_task_key: typing.Optional[str]
96 | - status:
97 | - creation_date:
98 | - start_date: typing.Optional[datetime.datetime]
99 | - end_date: typing.Optional[datetime.datetime]
100 | - estimated_end_date: typing.Optional[datetime.datetime]
101 | - duration: typing.Optional[int]
102 | - creator: typing.Optional[str]
103 | ```
104 |
105 | ## Performances
106 | Performances of the different compute tasks of a compute plan
107 | ```text
108 | - compute_plan_key: typing.List[str]
109 | - compute_plan_tag: typing.List[str]
110 | - compute_plan_status: typing.List[str]
111 | - compute_plan_start_date: typing.List[datetime.datetime]
112 | - compute_plan_end_date: typing.List[datetime.datetime]
113 | - compute_plan_metadata: typing.List[dict]
114 | - worker: typing.List[str]
115 | - task_key: typing.List[str]
116 | - task_rank: typing.List[int]
117 | - round_idx: typing.List[int]
118 | - identifier: typing.List[str]
119 | - performance: typing.List[float]
120 | ```
121 |
122 | ## Organization
123 | Organization
124 | ```text
125 | - id:
126 | - is_current:
127 | - creation_date:
128 | ```
129 |
130 | ## Permissions
131 | Permissions structure stored in various asset types.
132 | ```text
133 | - process:
134 | ```
135 |
136 | ## InModel
137 | In model of a task
138 | ```text
139 | - checksum:
140 | - storage_address: typing.Union[typing.Annotated[pathlib.Path, PathType(path_type='file')], pydantic_core._pydantic_core.Url, str]
141 | ```
142 |
143 | ## OutModel
144 | Out model of a task
145 | ```text
146 | - key:
147 | - compute_task_key:
148 | - address: typing.Optional[substra.sdk.models.InModel]
149 | - permissions:
150 | - owner:
151 | - creation_date:
152 | ```
153 |
154 |
--------------------------------------------------------------------------------
/references/sdk_schemas.md:
--------------------------------------------------------------------------------
1 | # Summary
2 |
3 | - [DataSampleSpec](#DataSampleSpec)
4 | - [DatasetSpec](#DatasetSpec)
5 | - [UpdateDatasetSpec](#UpdateDatasetSpec)
6 | - [FunctionSpec](#FunctionSpec)
7 | - [FunctionInputSpec](#FunctionInputSpec)
8 | - [FunctionOutputSpec](#FunctionOutputSpec)
9 | - [TaskSpec](#TaskSpec)
10 | - [ComputeTaskOutputSpec](#ComputeTaskOutputSpec)
11 | - [UpdateFunctionSpec](#UpdateFunctionSpec)
12 | - [ComputePlanSpec](#ComputePlanSpec)
13 | - [UpdateComputePlanSpec](#UpdateComputePlanSpec)
14 | - [UpdateComputePlanTasksSpec](#UpdateComputePlanTasksSpec)
15 | - [ComputePlanTaskSpec](#ComputePlanTaskSpec)
16 | - [Permissions](#Permissions)
17 | - [PrivatePermissions](#PrivatePermissions)
18 |
19 |
20 | # Schemas
21 |
22 | ## DataSampleSpec
23 | Specification to create one or many data samples
24 | To create one data sample, use the 'path' field, otherwise use
25 | the 'paths' field.
26 | ```text
27 | - path: typing.Optional[pathlib.Path]
28 | - paths: typing.Optional[typing.List[pathlib.Path]]
29 | - data_manager_keys: typing.List[str]
30 | ```
31 |
32 | ## DatasetSpec
33 | Specification for creating a dataset
34 |
35 | note : metadata field does not accept strings containing '__' as dict key
36 |
37 | note : If no description markdown file is given, create an empty one on the data_opener folder.
38 | ```text
39 | - name:
40 | - data_opener:
41 | - description: typing.Optional[pathlib.Path]
42 | - permissions:
43 | - metadata: typing.Optional[typing.Dict[str, str]]
44 | - logs_permission:
45 | ```
46 |
47 | ## UpdateDatasetSpec
48 | Specification for updating a dataset
49 | ```text
50 | - name:
51 | ```
52 |
53 | ## FunctionSpec
54 | Specification for creating a function
55 |
56 | note : metadata field does not accept strings containing '__' as dict key
57 | ```text
58 | - name:
59 | - description:
60 | - file:
61 | - permissions:
62 | - metadata: typing.Optional[typing.Dict[str, str]]
63 | - inputs: typing.Optional[typing.List[substra.sdk.schemas.FunctionInputSpec]]
64 | - outputs: typing.Optional[typing.List[substra.sdk.schemas.FunctionOutputSpec]]
65 | ```
66 |
67 | ## FunctionInputSpec
68 | Asset creation specification base class.
69 | ```text
70 | - identifier:
71 | - multiple:
72 | - optional:
73 | - kind:
74 | ```
75 |
76 | ## FunctionOutputSpec
77 | Asset creation specification base class.
78 | ```text
79 | - identifier:
80 | - kind:
81 | - multiple:
82 | ```
83 |
84 | ## TaskSpec
85 | Asset creation specification base class.
86 | ```text
87 | - key:
88 | - tag: typing.Optional[str]
89 | - compute_plan_key: typing.Optional[str]
90 | - metadata: typing.Optional[typing.Dict[str, str]]
91 | - function_key:
92 | - worker:
93 | - rank: typing.Optional[int]
94 | - inputs: typing.Optional[typing.List[substra.sdk.schemas.InputRef]]
95 | - outputs: typing.Optional[typing.Dict[str, substra.sdk.schemas.ComputeTaskOutputSpec]]
96 | ```
97 |
98 | ## ComputeTaskOutputSpec
99 | Specification of a compute task output
100 | ```text
101 | - permissions:
102 | - is_transient: typing.Optional[bool]
103 | ```
104 |
105 | ## UpdateFunctionSpec
106 | Specification for updating a function
107 | ```text
108 | - name:
109 | ```
110 |
111 | ## ComputePlanSpec
112 | Specification for creating a compute plan
113 |
114 | note : metadata field does not accept strings containing '__' as dict key
115 | ```text
116 | - key:
117 | - tasks: typing.Optional[typing.List[substra.sdk.schemas.ComputePlanTaskSpec]]
118 | - tag: typing.Optional[str]
119 | - name:
120 | - metadata: typing.Optional[typing.Dict[str, str]]
121 | ```
122 |
123 | ## UpdateComputePlanSpec
124 | Specification for updating a compute plan
125 | ```text
126 | - name:
127 | ```
128 |
129 | ## UpdateComputePlanTasksSpec
130 | Specification for updating a compute plan's tasks
131 | ```text
132 | - key:
133 | - tasks: typing.Optional[typing.List[substra.sdk.schemas.ComputePlanTaskSpec]]
134 | ```
135 |
136 | ## ComputePlanTaskSpec
137 | Specification of a compute task inside a compute plan specification
138 |
139 | note : metadata field does not accept strings containing '__' as dict key
140 | ```text
141 | - task_id:
142 | - function_key:
143 | - worker:
144 | - tag: typing.Optional[str]
145 | - metadata: typing.Optional[typing.Dict[str, str]]
146 | - inputs: typing.Optional[typing.List[substra.sdk.schemas.InputRef]]
147 | - outputs: typing.Optional[typing.Dict[str, substra.sdk.schemas.ComputeTaskOutputSpec]]
148 | ```
149 |
150 | ## Permissions
151 | Specification for permissions. If public is False,
152 | give the list of authorized ids.
153 | ```text
154 | - public:
155 | - authorized_ids: typing.List[str]
156 | ```
157 |
158 | ## PrivatePermissions
159 | Specification for private permissions. Only the organizations whose
160 | ids are in authorized_ids can access the asset.
161 | ```text
162 | - authorized_ids: typing.List[str]
163 | ```
164 |
165 |
--------------------------------------------------------------------------------
/substra/__init__.py:
--------------------------------------------------------------------------------
1 | from substra.__version__ import __version__
2 | from substra.sdk import Client
3 | from substra.sdk import exceptions
4 | from substra.sdk import models
5 | from substra.sdk import schemas
6 | from substra.sdk.schemas import BackendType
7 |
8 | __all__ = [
9 | "__version__",
10 | "Client",
11 | "exceptions",
12 | "BackendType",
13 | "schemas",
14 | "models",
15 | ]
16 |
--------------------------------------------------------------------------------
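A minimal usage sketch of the names re-exported above; the `Client` keyword argument is an assumption, since the constructor is defined in `substra/sdk/client.py` (not shown here):

```python
import substra

print(substra.__version__)
# backend_type selects remote execution or a local (docker/subprocess) simulation
client = substra.Client(backend_type=substra.BackendType.LOCAL_SUBPROCESS)
```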
/substra/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
--------------------------------------------------------------------------------
/substra/config.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/substra/config.py
--------------------------------------------------------------------------------
/substra/sdk/__init__.py:
--------------------------------------------------------------------------------
1 | from substra.sdk import models
2 | from substra.sdk import schemas
3 | from substra.sdk.client import Client
4 | from substra.sdk.schemas import BackendType
5 | from substra.sdk.utils import retry_on_exception
6 |
7 | __all__ = [
8 | "Client",
9 | "retry_on_exception",
10 | "BackendType",
11 | "schemas",
12 | "models",
13 | ]
14 |
--------------------------------------------------------------------------------
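A sketch of how the re-exported `retry_on_exception` helper is applied; the exact signature lives in `substra/sdk/utils.py` (not shown here), so treat the arguments as assumptions:

```python
from substra.sdk import exceptions
from substra.sdk import retry_on_exception


# Hypothetical usage: retry the decorated call for up to 300 seconds whenever
# it raises a RequestTimeout, instead of failing on the first timeout.
@retry_on_exception(exceptions=(exceptions.RequestTimeout,), timeout=300)
def fetch_something():
    ...
```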
/substra/sdk/archive/__init__.py:
--------------------------------------------------------------------------------
1 | import tarfile
2 | import zipfile
3 |
4 | from substra.sdk import exceptions
5 | from substra.sdk.archive import tarsafe
6 | from substra.sdk.archive.safezip import ZipFile
7 |
8 |
9 | def _untar(archive, to_):
10 | with tarsafe.open(archive, "r:*") as tf:
11 | tf.extractall(to_)
12 |
13 |
14 | def _unzip(archive, to_):
15 | with ZipFile(archive, "r") as zf:
16 | zf.extractall(to_)
17 |
18 |
19 | def uncompress(archive, to_):
20 | """Uncompress tar or zip archive to destination."""
21 | if tarfile.is_tarfile(archive):
22 | _untar(archive, to_)
23 | elif zipfile.is_zipfile(archive):
24 | _unzip(archive, to_)
25 | else:
26 | raise exceptions.InvalidRequest(f"Cannot uncompress '{archive}'", 400)
27 |
--------------------------------------------------------------------------------
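A minimal usage sketch of `uncompress` (the paths are placeholders):

```python
from substra.sdk import exceptions
from substra.sdk.archive import uncompress

try:
    # dispatches to the safe tar or zip extractor based on the archive type
    uncompress("function.tar.gz", "/tmp/extracted")
except exceptions.InvalidRequest:
    print("not a tar or zip archive")
```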
/substra/sdk/archive/safezip.py:
--------------------------------------------------------------------------------
1 | import os
2 | import stat
3 | import zipfile
4 |
5 |
6 | class ZipFile(zipfile.ZipFile):
7 | """Override Zipfile to ensure unix file permissions are preserved.
8 |
9 | This is due to a python bug:
10 | https://bugs.python.org/issue15795
11 |
12 | Workaround from:
13 | https://stackoverflow.com/questions/39296101/python-zipfile-removes-execute-permissions-from-binaries
14 | """
15 |
16 | def extract(self, member, path=None, pwd=None):
17 | if not isinstance(member, zipfile.ZipInfo):
18 | member = self.getinfo(member)
19 |
20 | if path is None:
21 | path = os.getcwd()
22 |
23 | ret_val = self._extract_member(member, path, pwd)
24 | attr = member.external_attr >> 16
25 | os.chmod(ret_val, attr)
26 | return ret_val
27 |
28 | def extractall(self, path=None, members=None, pwd=None):
29 | self._sanity_check()
30 | super().extractall(path, members, pwd)
31 |
32 | def _sanity_check(self):
33 | """Check that the archive does not attempt to traverse path.
34 |
35 | This is inspired by TarSafe: https://github.com/beatsbears/tarsafe
36 | """
37 | for zipinfo in self.infolist():
38 | if self._is_traversal_attempt(zipinfo):
39 | raise Exception(f"Attempted directory traversal for member: {zipinfo.filename}")
40 | if self._is_symlink(zipinfo):
41 | raise Exception(f"Unsupported symlink for member: {zipinfo.filename}")
42 |
43 | def _is_traversal_attempt(self, zipinfo: zipfile.ZipInfo) -> bool:
44 | base_directory = os.getcwd()
45 | zipfile_path = os.path.abspath(os.path.join(base_directory, zipinfo.filename))
46 | if not zipfile_path.startswith(base_directory):
47 | return True
48 | return False
49 |
50 | def _is_symlink(self, zipinfo: zipfile.ZipInfo) -> bool:
51 | return zipinfo.external_attr & (stat.S_IFLNK << 16) == (stat.S_IFLNK << 16)
52 |
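
The `external_attr >> 16` shift recovers the Unix mode bits that `zipfile` keeps in the upper 16 bits of a member's external attributes and that the stock extractor drops. A self-contained sketch of writing and reading those bits back:

import stat
import zipfile

# build a tiny archive with an executable member, then read the mode back
with zipfile.ZipFile("demo.zip", "w") as zf:
    info = zipfile.ZipInfo("run.sh")
    info.external_attr = (stat.S_IFREG | 0o755) << 16  # regular file, rwxr-xr-x
    zf.writestr(info, "#!/bin/sh\necho ok\n")

with zipfile.ZipFile("demo.zip") as zf:
    mode = zf.getinfo("run.sh").external_attr >> 16
    print(oct(mode & 0o777))  # 0o755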
--------------------------------------------------------------------------------
/substra/sdk/archive/tarsafe.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import tarfile
4 |
5 |
6 | class TarSafe(tarfile.TarFile):
7 | """
8 | A safe subclass of the TarFile class for interacting with tar files.
9 | """
10 |
11 | def __init__(self, *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | self.directory = os.getcwd()
14 |
15 | @classmethod
16 | def open(cls, name=None, mode="r", fileobj=None, bufsize=tarfile.RECORDSIZE, **kwargs):
17 | return super().open(name, mode, fileobj, bufsize, **kwargs)
18 |
19 | def extract(self, member, path="", set_attrs=True, *, numeric_owner=False):
20 | """
21 | Override the parent extract method and add safety checks.
22 | """
23 | self._safetar_check()
24 | super().extract(member, path, set_attrs=set_attrs, numeric_owner=numeric_owner)
25 |
26 | def extractall(self, path=".", members=None, numeric_owner=False):
27 | """
28 | Override the parent extractall method and add safety checks.
29 | """
30 | self._safetar_check()
31 | super().extractall(path, members, numeric_owner=numeric_owner)
32 |
33 | def _safetar_check(self):
34 | """
35 | Runs all necessary checks for the safety of a tarfile.
36 | """
37 | try:
38 | for tarinfo in self.__iter__():
39 | if self._is_traversal_attempt(tarinfo=tarinfo):
40 | raise TarSafeError(f"Attempted directory traversal for member: {tarinfo.name}")
41 | if tarinfo.issym():
42 | raise TarSafeError(f"Unsupported symlink for member: {tarinfo.linkname}")
43 | if self._is_unsafe_link(tarinfo=tarinfo):
44 | raise TarSafeError(f"Attempted directory traversal via link for member: {tarinfo.linkname}")
45 | if self._is_device(tarinfo=tarinfo):
46 | raise TarSafeError("tarfile returns true for isblk() or ischr()")
47 | except Exception:
48 | raise
49 |
50 | def _is_traversal_attempt(self, tarinfo):
51 | if not os.path.abspath(os.path.join(self.directory, tarinfo.name)).startswith(self.directory):
52 | return True
53 | return False
54 |
55 | def _is_unsafe_link(self, tarinfo):
56 | if tarinfo.islnk():
57 | link_file = pathlib.Path(os.path.normpath(os.path.join(self.directory, tarinfo.linkname)))
58 | if not os.path.abspath(os.path.join(self.directory, link_file)).startswith(self.directory):
59 | return True
60 | return False
61 |
62 | def _is_device(self, tarinfo):
63 | return tarinfo.ischr() or tarinfo.isblk()
64 |
65 |
66 | class TarSafeError(Exception):
67 | pass
68 |
69 |
70 | open = TarSafe.open
71 |
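
Because the module rebinds `open` to `TarSafe.open`, `tarsafe` can be dropped in where `tarfile` is expected, and the checks run before anything is written to disk. A sketch of the traversal check on an in-memory archive:

import io
import tarfile

from substra.sdk.archive import tarsafe

# craft a tar whose member path escapes the extraction directory
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
    info = tarfile.TarInfo(name="../evil.txt")
    info.size = 5
    tf.addfile(info, io.BytesIO(b"pwned"))
buf.seek(0)

try:
    with tarsafe.open(fileobj=buf, mode="r") as tf:
        tf.extractall("out")
except tarsafe.TarSafeError as err:
    print(err)  # Attempted directory traversal for member: ../evil.txt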
--------------------------------------------------------------------------------
/substra/sdk/backends/__init__.py:
--------------------------------------------------------------------------------
1 | from substra.sdk import schemas
2 | from substra.sdk.backends.local.backend import Local
3 | from substra.sdk.backends.remote.backend import Remote
4 |
5 | _BACKEND_CHOICES = {
6 | schemas.BackendType.REMOTE: Remote,
7 | schemas.BackendType.LOCAL_DOCKER: Local,
8 | schemas.BackendType.LOCAL_SUBPROCESS: Local,
9 | }
10 |
11 |
12 | def get(name, *args, **kwargs):
13 | return _BACKEND_CHOICES[name](*args, **kwargs, backend_type=name)
14 |
--------------------------------------------------------------------------------
/substra/sdk/backends/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import List
3 |
4 | from substra.sdk.schemas import BackendType
5 |
6 |
7 | class BaseBackend(abc.ABC):
8 | @property
9 | @abc.abstractmethod
10 | def backend_mode(self) -> BackendType:
11 | raise NotImplementedError
12 |
13 | @abc.abstractmethod
14 | def login(self, username, password):
15 | pass
16 |
17 | @abc.abstractmethod
18 | def logout(self):
19 | pass
20 |
21 | @abc.abstractmethod
22 | def get(self, asset_type, key):
23 | raise NotImplementedError
24 |
25 | @abc.abstractmethod
26 | def list(self, asset_type, filters=None, paginated=False):
27 | raise NotImplementedError
28 |
29 | @abc.abstractmethod
30 | def add(self, spec, spec_options=None):
31 | raise NotImplementedError
32 |
33 | @abc.abstractmethod
34 | def update(self, key, spec, spec_options=None):
35 | raise NotImplementedError
36 |
37 | @abc.abstractmethod
38 | def add_compute_plan_tasks(self, spec, spec_options):
39 | raise NotImplementedError
40 |
41 | @abc.abstractmethod
42 | def link_dataset_with_data_samples(self, dataset_key, data_sample_keys) -> List[str]:
43 | raise NotImplementedError
44 |
45 | @abc.abstractmethod
46 | def download(self, asset_type, url_field_path, key, destination):
47 | raise NotImplementedError
48 |
49 | @abc.abstractmethod
50 | def describe(self, asset_type, key):
51 | raise NotImplementedError
52 |
53 | @abc.abstractmethod
54 | def cancel_compute_plan(self, key):
55 | raise NotImplementedError
56 |
--------------------------------------------------------------------------------
/substra/sdk/backends/local/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/substra/sdk/backends/local/__init__.py
--------------------------------------------------------------------------------
/substra/sdk/backends/local/compute/__init__.py:
--------------------------------------------------------------------------------
1 | from substra.sdk.backends.local.compute.worker import Worker
2 |
3 | __all__ = [
4 | "Worker",
5 | ]
6 |
--------------------------------------------------------------------------------
/substra/sdk/backends/local/compute/spawner/__init__.py:
--------------------------------------------------------------------------------
1 | from substra.sdk.backends.local.compute.spawner.base import BaseSpawner
2 | from substra.sdk.backends.local.compute.spawner.docker import Docker
3 | from substra.sdk.backends.local.compute.spawner.subprocess import Subprocess
4 | from substra.sdk.schemas import BackendType
5 |
6 | __all__ = ["BaseSpawner", "Docker", "Subprocess"]
7 |
8 | DEBUG_SPAWNER_CHOICES = {
9 | BackendType.LOCAL_DOCKER: Docker,
10 | BackendType.LOCAL_SUBPROCESS: Subprocess,
11 | }
12 |
13 |
14 | def get(name, *args, **kwargs):
15 | """Return a Docker spawner or a Subprocess spawner"""
16 | return DEBUG_SPAWNER_CHOICES[name](*args, **kwargs)
17 |
--------------------------------------------------------------------------------
/substra/sdk/backends/local/compute/spawner/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import pathlib
3 | import string
4 | import typing
5 |
6 | VOLUME_CLI_ARGS = "_VOLUME_CLI_ARGS"
7 | VOLUME_INPUTS = "_VOLUME_INPUTS"
8 | VOLUME_OUTPUTS = "_VOLUME_OUTPUTS"
9 |
10 |
11 | class BuildError(Exception):
12 | """An error occurred during the build of the function"""
13 |
14 | pass
15 |
16 |
17 | class ExecutionError(Exception):
18 | """An error occurred during the execution of the compute task"""
19 |
20 | pass
21 |
22 |
23 | class BaseSpawner(abc.ABC):
24 | """Base wrapper to execute a command"""
25 |
26 | def __init__(self, local_worker_dir: pathlib.Path):
27 | self._local_worker_dir = local_worker_dir
28 |
29 | @abc.abstractmethod
30 | def spawn(
31 | self,
32 | name,
33 | archive_path,
34 | command_args_tpl: typing.List[string.Template],
35 | data_sample_paths: typing.Optional[typing.Dict[str, pathlib.Path]],
36 | local_volumes,
37 | envs,
38 | ):
39 | """Execute archive in a contained environment."""
40 | raise NotImplementedError
41 |
42 |
43 | def write_command_args_file(args_file: pathlib.Path, command_args: typing.List[str]) -> None:
44 | """Write the substra-tools command line arguments to a file.
45 |
46 | The format uses one line per argument. See
47 | https://docs.python.org/3/library/argparse.html#fromfile-prefix-chars
48 | """
49 | with open(args_file, "w") as f:
50 | for item in command_args:
51 | f.write(item + "\n")
52 |
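
The one-argument-per-line layout exists so the spawned script can be invoked with a single `@<file>` token, following argparse's `fromfile_prefix_chars` convention. A sketch of the consuming side (this parser is illustrative, not the actual substra-tools one):

import argparse
import pathlib

pathlib.Path("arguments.txt").write_text("--function-name\ntrain\n")

parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument("--function-name")
args = parser.parse_args(["@arguments.txt"])  # each line of the file becomes one argument
assert args.function_name == "train"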
--------------------------------------------------------------------------------
/substra/sdk/backends/local/compute/spawner/docker.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pathlib
3 | import shutil
4 | import string
5 | import tempfile
6 | import typing
7 |
8 | import docker
9 |
10 | from substra.sdk.archive import uncompress
11 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_CLI_ARGS
12 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_INPUTS
13 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_OUTPUTS
14 | from substra.sdk.backends.local.compute.spawner.base import BaseSpawner
15 | from substra.sdk.backends.local.compute.spawner.base import BuildError
16 | from substra.sdk.backends.local.compute.spawner.base import ExecutionError
17 | from substra.sdk.backends.local.compute.spawner.base import write_command_args_file
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | ROOT_DIR = "/substra_internal"
22 | DOCKER_VOLUMES = {
23 | VOLUME_INPUTS: {"bind": f"{ROOT_DIR}/inputs", "mode": "ro"},
24 | VOLUME_OUTPUTS: {"bind": f"{ROOT_DIR}/outputs", "mode": "rw"},
25 | VOLUME_CLI_ARGS: {"bind": f"{ROOT_DIR}/cli-args", "mode": "rw"},
26 | }
27 |
28 |
29 | def _copy_data_samples(data_sample_paths: typing.Dict[str, pathlib.Path], dest_dir: str):
30 | """Copy the data samples to the data directory.
31 |
32 | We copy the data samples even though it is slow because:
33 |
34 | - symbolic links do not work with Docker container volumes
35 | - hard links cannot be created across partitions
36 | - mounting each data sample as its own volume causes permission errors
37 |
38 | Args:
39 | data_sample_paths (typing.Dict[str, pathlib.Path]): Paths to the samples
40 | dest_dir (str): Temp data directory
41 | """
42 | # Check if there are already data samples in the dest dir (test tasks are executed in 2 parts)
43 | sample_key = next(iter(data_sample_paths))
44 | if (pathlib.Path(dest_dir) / sample_key).exists():
45 | return
46 | # copy the whole tree
47 | for sample_key, sample_path in data_sample_paths.items():
48 | dest_path = pathlib.Path(dest_dir) / sample_key
49 | shutil.copytree(sample_path, dest_path)
50 |
51 |
52 | class Docker(BaseSpawner):
53 | """Wrapper around docker daemon to execute a command in a container."""
54 |
55 | def __init__(self, local_worker_dir: pathlib.Path):
56 | try:
57 | self._docker = docker.from_env()
58 | except docker.errors.DockerException as e:
59 | raise ConnectionError(
60 | "Couldn't get the Docker client from environment variables. "
61 | "Is your Docker server running?\n"
62 | "Docker error: {0}".format(e)
63 | )
64 | super().__init__(local_worker_dir=local_worker_dir)
65 |
66 | def _build_docker_image(self, name: str, archive_path: pathlib.Path):
67 | """Build the docker image if it does not already exist (blocking)."""
68 | with tempfile.TemporaryDirectory(dir=self._local_worker_dir) as tmpdir:
69 | image_exists = False
70 | try:
71 | self._docker.images.get(name=name)
72 | image_exists = True
73 | except docker.errors.ImageNotFound:
74 | pass
75 |
76 | if not image_exists:
77 | try:
78 | logger.debug("Did not find the Docker image %s - building it", name)
79 | uncompress(archive_path, tmpdir)
80 | self._docker.images.build(path=tmpdir, tag=name, rm=True)
81 | except docker.errors.BuildError as exc:
82 | log = ""
83 | for line in exc.build_log:
84 | if "stream" in line:
85 | log += line["stream"].strip()
86 | logger.error(log)
87 | raise BuildError(log)
88 |
89 | def spawn(
90 | self,
91 | name: str,
92 | archive_path: pathlib.Path,
93 | command_args_tpl: typing.List[string.Template],
94 | data_sample_paths: typing.Optional[typing.Dict[str, pathlib.Path]],
95 | local_volumes: typing.Optional[dict],
96 | envs: typing.Optional[typing.List[str]],
97 | ):
98 | """Build the docker image, copy the data samples then spawn a Docker container
99 | and execute the task.
100 | """
101 | self._build_docker_image(name=name, archive_path=archive_path)
102 |
103 | # format the command: replace each occurrence of a DOCKER_VOLUMES key
104 | # by its "bind" value
105 | volumes_format = {volume_name: volume_path["bind"] for volume_name, volume_path in DOCKER_VOLUMES.items()}
106 | command_args = [tpl.substitute(**volumes_format) for tpl in command_args_tpl]
107 |
108 | if data_sample_paths is not None and len(data_sample_paths) > 0:
109 | _copy_data_samples(data_sample_paths, local_volumes[VOLUME_INPUTS])
110 |
111 | args_filename = "arguments.txt"
112 | args_path_local = pathlib.Path(local_volumes[VOLUME_CLI_ARGS]) / args_filename
113 | # pathlib cannot be used here: on Windows it would build a "WindowsPath", not the POSIX path the container expects
114 | args_path_docker = DOCKER_VOLUMES[VOLUME_CLI_ARGS]["bind"] + "/" + args_filename
115 | write_command_args_file(args_path_local, command_args)
116 |
117 | # create the volumes dict for docker by binding the local_volumes paths to the DOCKER_VOLUMES binds
118 | volumes_docker = {
119 | volume_path: DOCKER_VOLUMES[volume_name] for volume_name, volume_path in local_volumes.items()
120 | }
121 |
122 | container = self._docker.containers.run(
123 | name,
124 | command=f"@{args_path_docker}",
125 | volumes=volumes_docker or {},
126 | environment=envs,
127 | remove=False,
128 | detach=True,
129 | tty=True,
130 | stdin_open=True,
131 | shm_size="8G",
132 | )
133 |
134 | execution_logs = []
135 | for line in container.logs(stream=True, stdout=True, stderr=True):
136 | execution_logs.append(line.decode("utf-8"))
137 |
138 | r = container.wait()
139 | execution_logs_str = "".join(execution_logs)
140 | exit_code = r["StatusCode"]
141 | if exit_code != 0:
142 | logger.error("\n\nExecution logs: %s", execution_logs_str)
143 | raise ExecutionError(f"Container '{name}' exited with status code '{exit_code}'")
144 |
145 | container.remove()
146 |
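
The substitution step resolves each `string.Template` placeholder to the container-side bind path; in isolation it behaves like this:

import string

volumes_format = {
    "_VOLUME_INPUTS": "/substra_internal/inputs",
    "_VOLUME_OUTPUTS": "/substra_internal/outputs",
}
tpl = [string.Template("--inputs"), string.Template("${_VOLUME_INPUTS}/data")]
print([t.substitute(**volumes_format) for t in tpl])
# ['--inputs', '/substra_internal/inputs/data']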
--------------------------------------------------------------------------------
/substra/sdk/backends/local/compute/spawner/subprocess.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pathlib
4 | import re
5 | import shutil
6 | import string
7 | import subprocess
8 | import sys
9 | import tempfile
10 | import typing
11 |
12 | from substra.sdk.archive import uncompress
13 | from substra.sdk.backends.local.compute.spawner.base import VOLUME_INPUTS
14 | from substra.sdk.backends.local.compute.spawner.base import BaseSpawner
15 | from substra.sdk.backends.local.compute.spawner.base import ExecutionError
16 | from substra.sdk.backends.local.compute.spawner.base import write_command_args_file
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | PYTHON_SCRIPT_REGEX = r"(?<=\")([^\"]*\.py)(?=\")"
21 | METHOD_REGEX = r"\"\-\-function-name\"\,\s*\"([^\"]*)\""
22 |
23 |
24 | def _get_entrypoint_from_dockerfile(tmpdir):
25 | """
26 | Extract the .py script and the function name to execute from the
27 | ENTRYPOINT line of the Dockerfile located in tmpdir.
28 | For instance, if the Dockerfile contains `ENTRYPOINT ["python3", "function.py", "--function-name", "train"]`,
29 | then `function.py` and `train` are extracted.
30 | """
31 | valid_example = (
32 | """The entry point should be specified as follows: """
33 | """``ENTRYPOINT ["<executable>", "<script_name>.py", "--function-name", "<method_name>"]``"""
34 | )
35 | with open(tmpdir / "Dockerfile") as f:
36 | for line in f:
37 | if "ENTRYPOINT" not in line:
38 | continue
39 |
40 | script_name = re.findall(PYTHON_SCRIPT_REGEX, line)
41 | if len(script_name) != 1:
42 | raise ExecutionError("Couldn't extract script from ENTRYPOINT line in Dockerfile", valid_example)
43 |
44 | function_name = re.findall(METHOD_REGEX, line)
45 | if len(function_name) != 1:
46 | raise ExecutionError("Couldn't extract method name from ENTRYPOINT line in Dockerfile", valid_example)
47 |
48 | return script_name[0], function_name[0]
49 |
50 | raise ExecutionError("Couldn't get entrypoint in Dockerfile", valid_example)
51 |
52 |
53 | def _get_command_args(
54 | function_name: str, args_template: typing.List[string.Template], local_volumes: typing.Dict[str, str]
55 | ) -> typing.List[str]:
56 | args = ["--function-name", str(function_name)]
57 | args += [tpl.substitute(**local_volumes) for tpl in args_template]
58 | return args
59 |
60 |
61 | def _symlink_data_samples(data_sample_paths: typing.Dict[str, pathlib.Path], dest_dir: str):
62 | """Create a symbolic link to "move" the data samples
63 | to the data directory.
64 |
65 | Args:
66 | data_sample_paths (typing.Dict[str, pathlib.Path]): Paths to the samples
67 | dest_dir (str): Temp data directory
68 | """
69 | # Check if there are already data samples in the dest dir (test tasks are executed in 2 parts)
70 | sample_key = next(iter(data_sample_paths))
71 | if (pathlib.Path(dest_dir) / sample_key).exists():
72 | return
73 | # copy the whole tree, but with symbolic links to be fast and save disk space
74 | for sample_key, sample_path in data_sample_paths.items():
75 | dest_path = pathlib.Path(dest_dir) / sample_key
76 | shutil.copytree(sample_path, dest_path, copy_function=os.symlink)
77 |
78 |
79 | class Subprocess(BaseSpawner):
80 | """Wrapper to execute a command in a python process."""
81 |
82 | def __init__(self, local_worker_dir: pathlib.Path):
83 | super().__init__(local_worker_dir=local_worker_dir)
84 |
85 | def spawn(
86 | self,
87 | name,
88 | archive_path,
89 | command_args_tpl: typing.List[string.Template],
90 | data_sample_paths: typing.Optional[typing.Dict[str, pathlib.Path]],
91 | local_volumes,
92 | envs,
93 | ):
94 | """Spawn a python process (blocking)."""
95 | with tempfile.TemporaryDirectory(dir=self._local_worker_dir) as function_dir:
96 | with tempfile.TemporaryDirectory(dir=function_dir) as args_dir:
97 | function_dir = pathlib.Path(function_dir)
98 | args_dir = pathlib.Path(args_dir)
99 | uncompress(archive_path, function_dir)
100 | script_name, function_name = _get_entrypoint_from_dockerfile(function_dir)
101 |
102 | args_file = args_dir / "arguments.txt"
103 |
104 | py_command = [sys.executable, str(function_dir / script_name), f"@{args_file}"]
105 | py_command_args = _get_command_args(function_name, command_args_tpl, local_volumes)
106 | write_command_args_file(args_file, py_command_args)
107 |
108 | if data_sample_paths is not None and len(data_sample_paths) > 0:
109 | _symlink_data_samples(data_sample_paths, local_volumes[VOLUME_INPUTS])
110 |
111 | # Catching error and raising to be ISO to the docker local backend
112 | # Don't capture the output to be able to use pdb
113 | try:
114 | subprocess.run(py_command, capture_output=False, check=True, cwd=function_dir, env=envs)
115 | except subprocess.CalledProcessError as e:
116 | raise ExecutionError(e)
117 |
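
What the two ENTRYPOINT regexes extract from a well-formed line, reusing the docstring's example:

import re

PYTHON_SCRIPT_REGEX = r"(?<=\")([^\"]*\.py)(?=\")"
METHOD_REGEX = r"\"\-\-function-name\"\,\s*\"([^\"]*)\""

line = 'ENTRYPOINT ["python3", "function.py", "--function-name", "train"]'
print(re.findall(PYTHON_SCRIPT_REGEX, line))  # ['function.py']
print(re.findall(METHOD_REGEX, line))  # ['train']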
--------------------------------------------------------------------------------
/substra/sdk/backends/local/dal.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pathlib
3 | import shutil
4 | import tempfile
5 | import typing
6 |
7 | from substra.sdk import exceptions
8 | from substra.sdk import models
9 | from substra.sdk import schemas
10 | from substra.sdk.backends.local import db
11 | from substra.sdk.backends.remote import backend
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class DataAccess:
17 | """Data access layer.
18 |
19 | This is an intermediate layer between the backend and the local/remote data access.
20 | """
21 |
22 | def __init__(self, remote_backend: typing.Optional[backend.Remote], local_worker_dir: pathlib.Path):
23 | self._db = db.get_db()
24 | self._remote = remote_backend
25 | self._tmp_dir = tempfile.TemporaryDirectory(prefix=str(local_worker_dir) + "/")
26 |
27 | @property
28 | def tmp_dir(self):
29 | return pathlib.Path(self._tmp_dir.name)
30 |
31 | def is_local(self, key: str, type_: schemas.Type):
32 | try:
33 | self._db.get(type_, key)
34 | return True
35 | except exceptions.NotFound:
36 | return False
37 |
38 | def _get_asset_content_filename(self, type_):
39 | if type_ == schemas.Type.Function:
40 | filename = "function.tar.gz"
41 | field_name = "archive"
42 |
43 | elif type_ == schemas.Type.Dataset:
44 | filename = "opener.py"
45 | field_name = "opener"
46 |
47 | else:
48 | raise ValueError(f"Cannot download this type of asset {type_}")
49 |
50 | return filename, field_name
51 |
52 | def login(self, username, password):
53 | if self._remote:
54 | self._remote.login(username, password)
55 |
56 | def logout(self):
57 | if self._remote:
58 | self._remote.logout()
59 |
60 | def add(self, asset):
61 | return self._db.add(asset)
62 |
63 | def remote_download(self, asset_type, url_field_path, key, destination):
64 | self._remote.download(asset_type, url_field_path, key, destination)
65 |
66 | def remote_download_model(self, key, destination_file):
67 | self._remote.download_model(key, destination_file)
68 |
69 | def get_remote_description(self, asset_type, key):
70 | return self._remote.describe(asset_type, key)
71 |
72 | def get_with_files(self, type_: schemas.Type, key: str):
73 | """Get the asset with files on the local disk for execution.
74 | This does not load the description as it is not required for execution.
75 | """
76 | try:
77 | # Try to find the asset locally
78 | return self._db.get(type_, key)
79 | except exceptions.NotFound:
80 | if self._remote is not None:
81 | # if not found, try remotely
82 | filename, field_name = self._get_asset_content_filename(type_)
83 | asset = self._remote.get(type_, key)
84 | tmp_directory = self.tmp_dir / key
85 | asset_path = tmp_directory / filename
86 |
87 | if not tmp_directory.exists():
88 | pathlib.Path.mkdir(tmp_directory)
89 |
90 | self._remote.download(
91 | type_,
92 | field_name + ".storage_address",
93 | key,
94 | asset_path,
95 | )
96 |
97 | attr = getattr(asset, field_name)
98 | attr.storage_address = asset_path
99 | return asset
100 | raise
101 |
102 | def get(self, type_, key: str):
103 | try:
104 | # Try to find the asset locally
105 | return self._db.get(type_, key)
106 | except exceptions.NotFound:
107 | if self._remote is not None:
108 | return self._remote.get(type_, key)
109 | raise
110 |
111 | def get_performances(self, key: str) -> models.Performances:
112 | """Get the performances of a given compute plan. Return a models.Performances() object,
113 | easily convertible to dict, filled with the performance data of done tasks that output a performance.
114 | """
115 | compute_plan = self.get(schemas.Type.ComputePlan, key)
116 | list_tasks = self.list(
117 | schemas.Type.Task,
118 | filters={"compute_plan_key": [key]},
119 | order_by="rank",
120 | ascending=True,
121 | )
122 |
123 | performances = models.Performances()
124 |
125 | for task in list_tasks:
126 | if task.status == models.ComputeTaskStatus.done:
127 | function = self.get(schemas.Type.Function, task.function.key)
128 | perf_identifiers = [
129 | output.identifier for output in function.outputs if output.kind == schemas.AssetKind.performance
130 | ]
131 | outputs = self.list(
132 | schemas.Type.OutputAsset, {"compute_task_key": task.key, "identifier": perf_identifiers}
133 | )
134 | for output in outputs:
135 | performances.compute_plan_key.append(compute_plan.key)
136 | performances.compute_plan_tag.append(compute_plan.tag)
137 | performances.compute_plan_status.append(compute_plan.status)
138 | performances.compute_plan_start_date.append(compute_plan.start_date)
139 | performances.compute_plan_end_date.append(compute_plan.end_date)
140 | performances.compute_plan_metadata.append(compute_plan.metadata)
141 |
142 | performances.worker.append(task.worker)
143 | performances.task_key.append(task.key)
144 | performances.task_rank.append(task.rank)
145 | try:
146 | round_idx = int(task.metadata.get("round_idx"))
147 | except TypeError:
148 | round_idx = None
149 | performances.round_idx.append(round_idx)
150 | performances.identifier.append(output.identifier)
151 | performances.performance.append(output.asset)
152 |
153 | return performances
154 |
155 | def list(
156 | self, type_: str, filters: typing.Dict[str, typing.List[str]], order_by: str = None, ascending: bool = False
157 | ):
158 | """Joins the results of the [local db](substra.sdk.backends.local.db.list) and the
159 | [remote db](substra.sdk.backends.rest_client.list) in hybrid mode.
160 | """
161 |
162 | local_assets = self._db.list(type_=type_, filters=filters, order_by=order_by, ascending=ascending)
163 |
164 | remote_assets = []
165 | if self._remote:
166 | try:
167 | remote_assets = self._remote.list(
168 | asset_type=type_, filters=filters, order_by=order_by, ascending=ascending
169 | )
170 | except Exception as e:
171 | logger.info(
172 | f"Could not list assets from the remote platform:\n{e}. \
173 | \nIf you are not logged in to a remote platform, ignore this message."
174 | )
175 | return local_assets + remote_assets
176 |
177 | def save_file(self, file_path: typing.Union[str, pathlib.Path], key: str):
178 | """Copy a file or directory into the local temp dir to mimic
179 | the remote backend that saves the files given by the user.
180 | """
181 | tmp_directory = self.tmp_dir / key
182 | tmp_file = tmp_directory / pathlib.Path(file_path).name
183 |
184 | if not tmp_directory.exists():
185 | pathlib.Path.mkdir(tmp_directory)
186 |
187 | if tmp_file.exists():
188 | raise exceptions.AlreadyExists(f"File {tmp_file.name} already exists for asset {key}", 409)
189 | elif pathlib.Path(file_path).is_file():
190 | shutil.copyfile(file_path, tmp_file)
191 | elif pathlib.Path(file_path).is_dir():
192 | shutil.copytree(file_path, tmp_file)
193 | else:
194 | raise exceptions.InvalidRequest(f"Could not copy {file_path}", 400)
195 | return tmp_file
196 |
197 | def update(self, asset):
198 | self._db.update(asset)
199 | return
200 |
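
`models.Performances` is columnar (one list per field, all the same length), so its dump loads straight into a table. A sketch, assuming pandas is installed and `dal` / `cp_key` stand for an existing `DataAccess` instance and a compute plan key:

import pandas as pd

performances = dal.get_performances(cp_key)  # `dal` and `cp_key` are hypothetical handles
df = pd.DataFrame(performances.model_dump())
print(df[["worker", "task_rank", "identifier", "performance"]])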
--------------------------------------------------------------------------------
/substra/sdk/backends/local/db.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import typing
4 |
5 | from substra.sdk import exceptions
6 | from substra.sdk import models
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class InMemoryDb:
12 | """In memory data db."""
13 |
14 | def __init__(self):
15 | # assets stored per type and per key
16 | self._data = collections.defaultdict(dict)
17 |
18 | def add(self, asset):
19 | """Add an asset."""
20 | type_ = asset.__class__.type_
21 | key = getattr(asset, "key", None)
22 | if not key:
23 | key = asset.id
24 | if key in self._data[type_]:
25 | raise exceptions.KeyAlreadyExistsError(f"The asset key {key} of type {type_} has already been used.")
26 | self._data[type_][key] = asset
27 | logger.info(f"{type_} with key '{key}' has been created.")
28 |
29 | return asset
30 |
31 | def get(self, type_, key: str):
32 | """Return asset."""
33 | try:
34 | return self._data[type_][key]
35 | except KeyError:
36 | raise exceptions.NotFound(f"Wrong pk {key}", 404)
37 |
38 | def _match_asset(self, asset: models._Model, attribute: str, values: typing.Union[typing.Dict, typing.List]):
39 | """Check whether an asset attribute matches the given values.
40 | For metadata, check that all the given filters return True (AND condition)."""
41 | if attribute == "metadata":
42 | metadata_conditions = []
43 | for value in values:
44 | if value["type"] == models.MetadataFilterType.exists:
45 | metadata_conditions.append(value["key"] in asset.metadata.keys())
46 |
47 | elif asset.metadata.get(value["key"]) is None:
48 | # for is_equal and contains, if the key is not there then return False
49 | metadata_conditions.append(False)
50 |
51 | elif value["type"] == models.MetadataFilterType.is_equal:
52 | metadata_conditions.append(str(value["value"]) == str(asset.metadata[value["key"]]))
53 |
54 | elif value["type"] == models.MetadataFilterType.contains:
55 | metadata_conditions.append(str(value["value"]) in str(asset.metadata.get(value["key"])))
56 | else:
57 | raise NotImplementedError
58 |
59 | return all(metadata_conditions)
60 |
61 | return str(getattr(asset, attribute)) in values
62 |
63 | def _filter_assets(
64 | self, db_assets: typing.List[models._Model], filters: typing.Dict[str, typing.List[str]]
65 | ) -> typing.List[models._Model]:
66 | """Return assets matching all the given filters."""
67 |
68 | matching_assets = [
69 | asset
70 | for asset in db_assets
71 | if all(self._match_asset(asset, attribute, values) for attribute, values in filters.items())
72 | ]
73 | return matching_assets
74 |
75 | def list(
76 | self, type_: str, filters: typing.Dict[str, typing.List[str]], order_by: str = None, ascending: bool = False
77 | ):
78 | """List assets by filters.
79 |
80 | Args:
81 | type_ (str): asset type, e.g. "function"
82 | filters (dict, optional): keys = attributes, values = list of values for this attribute.
83 | e.g. {"name": ["name1", "name2"]}. Multiple values for one attribute are OR-ed. Defaults to None.
84 | order_by (str, optional): attribute name to order the results on. Defaults to None.
85 | e.g. "name" for an ordering on name.
86 | ascending (bool, optional): to reverse ordering. Defaults to False (descending order).
87 |
88 | Returns:
89 | List[models._Model]: a list of matching assets
90 | """
91 | # get all assets of this type
92 | assets = list(self._data[type_].values())
93 |
94 | if filters:
95 | assets = self._filter_assets(assets, filters)
96 | if order_by:
97 | assets.sort(key=lambda x: getattr(x, order_by), reverse=(not ascending))
98 |
99 | return assets
100 |
101 | def update(self, asset):
102 | type_ = asset.__class__.type_
103 | key = asset.key
104 |
105 | if key not in self._data[type_]:
106 | raise exceptions.NotFound(f"Wrong pk {key}", 404)
107 |
108 | self._data[type_][key] = asset
109 | return
110 |
111 |
112 | db = InMemoryDb()
113 |
114 |
115 | def get_db():
116 | return db
117 |
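
A self-contained sketch of the filter semantics: values of a plain attribute filter are OR-ed, while metadata filters are AND-ed together. The stub class stands in for a pydantic model, since `type_` and plain attributes are all the db reads:

from substra.sdk.backends.local.db import InMemoryDb

class StubAsset:
    type_ = "function"

    def __init__(self, key, name, metadata):
        self.key, self.name, self.metadata = key, name, metadata

mem_db = InMemoryDb()
mem_db.add(StubAsset("k1", "fn-a", {"epochs": "10"}))
mem_db.add(StubAsset("k2", "fn-b", {"epochs": "20"}))

# attribute filter: values are OR-ed
assert [a.key for a in mem_db.list("function", {"name": ["fn-a"]})] == ["k1"]

# metadata filter: "is" compares stringified values
flt = {"metadata": [{"type": "is", "key": "epochs", "value": "10"}]}
assert [a.key for a in mem_db.list("function", flt)] == ["k1"]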
--------------------------------------------------------------------------------
/substra/sdk/backends/local/models.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from typing import List
3 |
4 | from pydantic import Field
5 |
6 | from substra.sdk import models
7 | from substra.sdk import schemas
8 |
9 |
10 | class _TaskAssetLocal(schemas._PydanticConfig):
11 | key: str = Field(default_factory=uuid.uuid4)
12 | compute_task_key: str
13 |
14 | @staticmethod
15 | def allowed_filters() -> List[str]:
16 | return models._TaskAsset.allowed_filters() + ["compute_task_key"]  # super() cannot be used in a staticmethod
17 |
18 |
19 | class OutputAssetLocal(models.OutputAsset, _TaskAssetLocal):
20 | pass
21 |
22 |
23 | class InputAssetLocal(models.InputAsset, _TaskAssetLocal):
24 | pass
25 |
--------------------------------------------------------------------------------
/substra/sdk/backends/remote/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/substra/sdk/backends/remote/__init__.py
--------------------------------------------------------------------------------
/substra/sdk/backends/remote/request_formatter.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def format_search_filters_for_remote(filters):
5 | formatted_filters = {}
6 | # do not process if no filters
7 | if filters is None:
8 | return formatted_filters
9 |
10 | for key in filters:
11 | # handle special cases name and metadata
12 | if key == "name":
13 | formatted_filters["match"] = filters[key]
14 | elif key == "metadata":
15 | formatted_filters["metadata"] = json.dumps(filters["metadata"]).replace(" ", "")
16 |
17 | else:
18 | # all other filters are formatted as a csv string without spaces
19 | values = ",".join(filters[key])
20 | formatted_filters[key] = values.replace(" ", "")
21 |
22 | return formatted_filters
23 |
24 |
25 | def format_search_ordering_for_remote(order_by, ascending):
26 | if not ascending:
27 | return "-" + order_by
28 | return order_by
29 |
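
What the formatter puts on the wire: `name` becomes a `match` parameter, `metadata` is JSON-encoded with spaces stripped, and every other filter collapses to a comma-separated string:

from substra.sdk.backends.remote.request_formatter import format_search_filters_for_remote
from substra.sdk.backends.remote.request_formatter import format_search_ordering_for_remote

filters = {
    "name": "my function",
    "status": ["STATUS_DONE", "STATUS_FAILED"],
    "metadata": [{"key": "epochs", "type": "is", "value": "10"}],
}
print(format_search_filters_for_remote(filters))
# {'match': 'my function', 'status': 'STATUS_DONE,STATUS_FAILED',
#  'metadata': '[{"key":"epochs","type":"is","value":"10"}]'}
print(format_search_ordering_for_remote("creation_date", ascending=False))  # -creation_date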
--------------------------------------------------------------------------------
/substra/sdk/compute_plan.py:
--------------------------------------------------------------------------------
1 | from substra.sdk import exceptions
2 | from substra.sdk import graph
3 | from substra.sdk import schemas
4 |
5 |
6 | def _insert_into_graph(task_graph, task_id, in_model_ids):
7 | if task_id in task_graph:
8 | raise exceptions.InvalidRequest("Two tasks cannot have the same id.", 400)
9 | task_graph[task_id] = in_model_ids
10 |
11 |
12 | def get_dependency_graph(spec: schemas._BaseComputePlanSpec):
13 | """Get the task dependency graph and a mapping table {task id: task spec}."""
14 | task_graph = {}
15 | tasks = {}
16 |
17 | if spec.tasks:
18 | for task in spec.tasks:
19 | _insert_into_graph(
20 | task_graph=task_graph,
21 | task_id=task.task_id,
22 | in_model_ids=[
23 | input_ref.parent_task_key for input_ref in (task.inputs or []) if input_ref.parent_task_key
24 | ],
25 | )
26 | tasks[task.task_id] = schemas.TaskSpec.from_compute_plan(
27 | compute_plan_key=spec.key,
28 | rank=None,
29 | spec=task,
30 | )
31 | return task_graph, tasks
32 |
33 |
34 | def get_tasks(spec):
35 | """Returns compute plan tasks sorted by dependencies."""
36 |
37 | # Create the dependency graph and get the dict of tasks by id
38 | task_graph, tasks = get_dependency_graph(spec)
39 |
40 | already_created_ids = set()
41 | # Collect ids of pre-existing tasks (dependencies not defined in this spec); they are ignored when ranking
42 | for dependencies in task_graph.values():
43 | for dependency_id in dependencies:
44 | if dependency_id not in task_graph:
45 | already_created_ids.add(dependency_id)
46 |
47 | # Compute the relative ranks of the new tasks (relatively to each other, these
48 | # are not their actual ranks in the compute plan)
49 | id_ranks = graph.compute_ranks(node_graph=task_graph, node_to_ignore=already_created_ids)
50 |
51 | # Sort the tasks by rank
52 | sorted_by_rank = sorted(id_ranks.items(), key=lambda item: item[1])
53 |
54 | return [tasks[task_id] for task_id, _ in sorted_by_rank]
55 |
--------------------------------------------------------------------------------
/substra/sdk/exceptions.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | import requests
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class SDKException(Exception):
10 | pass
11 |
12 |
13 | class LoadDataException(SDKException):
14 | pass
15 |
16 |
17 | class RequestException(SDKException):
18 | def __init__(self, msg, status_code):
19 | self.msg = msg
20 | self.status_code = status_code
21 | super().__init__(msg)
22 |
23 | @classmethod
24 | def from_request_exception(cls, request_exception):
25 | msg = None
26 | try:
27 | msg = request_exception.response.json()["detail"]
28 | msg = f"{request_exception}: {msg}"
29 | except Exception:
30 | msg = str(request_exception)
31 |
32 | try:
33 | status_code = request_exception.response.status_code
34 | except AttributeError:
35 | status_code = None
36 |
37 | return cls(msg, status_code)
38 |
39 |
40 | class ConnectionError(RequestException):
41 | pass
42 |
43 |
44 | class Timeout(RequestException):
45 | pass
46 |
47 |
48 | class HTTPError(RequestException):
49 | pass
50 |
51 |
52 | class InternalServerError(HTTPError):
53 | pass
54 |
55 |
56 | class GatewayUnavailable(HTTPError):
57 | pass
58 |
59 |
60 | class InvalidRequest(HTTPError):
61 | def __init__(self, msg, status_code, errors=None):
62 | super().__init__(msg, status_code)
63 | self.errors = errors
64 |
65 | @classmethod
66 | def from_request_exception(cls, request_exception):
67 | try:
68 | error = request_exception.response.json()
69 | except ValueError:
70 | error = request_exception.response
71 |
72 | get_method = getattr(error, "get", None)
73 | if callable(get_method):
74 | msg = get_method("detail", str(error))
75 | else:
76 | msg = str(error)
77 |
78 | try:
79 | status_code = request_exception.response.status_code
80 | except AttributeError:
81 | status_code = None
82 |
83 | return cls(msg, status_code, error)
84 |
85 |
86 | class NotFound(HTTPError):
87 | pass
88 |
89 |
90 | class RequestTimeout(HTTPError):
91 | def __init__(self, key, status_code):
92 | self.key = key
93 | msg = f"Operation on object with key(s) '{key}' timed out."
94 | super().__init__(msg, status_code)
95 |
96 | @classmethod
97 | def from_request_exception(cls, request_exception):
98 | # parse response and fetch key
99 | r = request_exception.response.json()
100 |
101 | try:
102 | key = r["key"] if "key" in r else r["detail"].get("key")
103 | except (AttributeError, KeyError):
104 | # XXX this is the case when doing a POST query to update the
105 | # data manager for instance
106 | key = None
107 |
108 | return cls(key, request_exception.response.status_code)
109 |
110 |
111 | class AlreadyExists(HTTPError):
112 | def __init__(self, key, status_code):
113 | self.key = key
114 | msg = f"Object with key(s) '{key}' already exists."
115 | super().__init__(msg, status_code)
116 |
117 | @classmethod
118 | def from_request_exception(cls, request_exception):
119 | # parse response and fetch key
120 | r = request_exception.response.json()
121 | # XXX support list of keys; this could be the case when adding
122 | # a list of data samples through a single POST request
123 | if isinstance(r, list):
124 | key = [x["key"] for x in r]
125 | elif isinstance(r, dict):
126 | key = r.get("key", None)
127 | else:
128 | key = r
129 |
130 | return cls(key, request_exception.response.status_code)
131 |
132 |
133 | class InvalidResponse(SDKException):
134 | def __init__(self, response, msg):
135 | self.response = response
136 | super(InvalidResponse, self).__init__(msg)
137 |
138 |
139 | class AuthenticationError(HTTPError):
140 | pass
141 |
142 |
143 | class AuthorizationError(HTTPError):
144 | pass
145 |
146 |
147 | class BadLoginException(RequestException):
148 | """The server refused to log in with these credentials"""
149 |
150 | pass
151 |
152 |
153 | class UsernamePasswordLoginDisabledException(RequestException):
154 | """The server disabled the endpoint, preventing the use of Client.login"""
155 |
156 | @classmethod
157 | def from_request_exception(cls, request_exception):
158 | base = super().from_request_exception(request_exception)
159 | return cls(
160 | base.msg
161 | + (
162 | "\n\nAuthenticating with username/password is disabled.\n"
163 | "Log onto the frontend for your instance and generate a token there, "
164 | 'then use it in the Client(token="...") constructor: '
165 | "https://docs.substra.org/en/stable/documentation/api_tokens_generation.html"
166 | ),
167 | base.status_code,
168 | )
169 |
170 |
171 | def from_request_exception(
172 | e: requests.exceptions.RequestException,
173 | ) -> Union[RequestException, requests.exceptions.RequestException]:
174 | """
175 | try turning an exception from the `requests` library into a Substra exception
176 | """
177 | connection_error_mapping: dict[type[requests.exceptions.RequestException], type[RequestException]] = {
178 | requests.exceptions.ConnectionError: ConnectionError,
179 | requests.exceptions.Timeout: Timeout,
180 | }
181 | for k, v in connection_error_mapping.items():
182 | if isinstance(e, k):
183 | return v.from_request_exception(e)
184 |
185 | http_status_mapping: dict[int, type[RequestException]] = {
186 | 400: InvalidRequest,
187 | 401: AuthenticationError,
188 | 403: AuthorizationError,
189 | 404: NotFound,
190 | 408: RequestTimeout,
191 | 409: AlreadyExists,
192 | 500: InternalServerError,
193 | 502: GatewayUnavailable,
194 | 503: GatewayUnavailable,
195 | 504: GatewayUnavailable,
196 | }
197 | if isinstance(e, requests.exceptions.HTTPError):
198 | logger.error(f"Requests error status {e.response.status_code}: {e.response.text}")
199 | return http_status_mapping.get(e.response.status_code, HTTPError).from_request_exception(e)
200 |
201 | return e
202 |
203 |
204 | class ConfigurationInfoError(SDKException):
205 | """ConfigurationInfoError"""
206 |
207 | pass
208 |
209 |
210 | class BadConfiguration(SDKException):
211 | """Bad configuration"""
212 |
213 | pass
214 |
215 |
216 | class UserException(SDKException):
217 | """User Exception"""
218 |
219 | pass
220 |
221 |
222 | class EmptyInModelException(SDKException):
223 | """No in_models when needed"""
224 |
225 | pass
226 |
227 |
228 | class ComputePlanKeyFormatError(Exception):
229 | """The given compute plan key has to respect the UUID format."""
230 |
231 | pass
232 |
233 |
234 | class OrderingFormatError(Exception):
235 | """The given ordering parameter has to respect the expected format."""
236 |
237 | pass
238 |
239 |
240 | class FilterFormatError(Exception):
241 | """The given filters have to respect the expected format."""
242 |
243 | pass
244 |
245 |
246 | class NotAllowedFilterError(Exception):
247 | """The given filter is not available on this asset."""
248 |
249 | pass
250 |
251 |
252 | class KeyAlreadyExistsError(Exception):
253 | """The asset key has already been used."""
254 |
255 | pass
256 |
257 |
258 | class _TaskAssetError(Exception):
259 | """Base exception class for task asset error"""
260 |
261 | def __init__(self, *, compute_task_key: str, identifier: str, message: str):
262 | self.compute_task_key = compute_task_key
263 | self.identifier = identifier
264 | self.message = message
265 | super().__init__(self.message)
266 |
267 | pass
268 |
269 |
270 | class TaskAssetNotFoundError(_TaskAssetError):
271 | """Exception raised when no task input/output asset has been found for a specific task key and identifier"""
272 |
273 | def __init__(self, compute_task_key: str, identifier: str):
274 | message = f"No task asset found with {compute_task_key=} and {identifier=}"
275 | super().__init__(compute_task_key=compute_task_key, identifier=identifier, message=message)
276 |
277 | pass
278 |
279 |
280 | class TaskAssetMultipleFoundError(_TaskAssetError):
281 | """Exception raised when more than one task input/output asset has been found for a specific task key and
282 | identifier"""
283 |
284 | def __init__(self, compute_task_key: str, identifier: str):
285 | message = f"Multiple task assets found with {compute_task_key=} and {identifier=}"
286 | super().__init__(compute_task_key=compute_task_key, identifier=identifier, message=message)
287 |
288 |
289 | class FutureError(Exception):
290 | """Error while waiting for a blocking operation to complete"""
291 |
292 |
293 | class FutureTimeoutError(FutureError):
294 | """Future execution timed out."""
295 |
296 |
297 | class FutureFailureError(FutureError):
298 | """Future execution failed."""
299 |
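
A sketch of the status-code mapping, building a bare `requests` response by hand (test-style construction, not how the SDK receives responses):

import requests

from substra.sdk import exceptions

response = requests.models.Response()
response.status_code = 404
response._content = b'{"detail": "asset not found"}'

error = exceptions.from_request_exception(requests.exceptions.HTTPError(response=response))
assert isinstance(error, exceptions.NotFound)
print(error.status_code)  # 404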
--------------------------------------------------------------------------------
/substra/sdk/fs.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from substra.sdk.hasher import Hasher
4 |
5 | _BLOCK_SIZE = 64 * 1024
6 |
7 |
8 | def hash_file(path):
9 | """Hash a file."""
10 | hasher = Hasher()
11 |
12 | with open(path, "rb") as fp:
13 | while True:
14 | data = fp.read(_BLOCK_SIZE)
15 | if not data:
16 | break
17 | hasher.update(data)
18 | return hasher.compute()
19 |
20 |
21 | def hash_directory(path, followlinks=False):
22 | """Hash a directory."""
23 |
24 | if not os.path.isdir(path):
25 | raise TypeError(f"{path} is not a directory.")
26 |
27 | hash_values = []
28 | for root, dirs, files in os.walk(path, topdown=True, followlinks=followlinks):
29 | dirs.sort()
30 | files.sort()
31 |
32 | for fname in files:
33 | hash_values.append(hash_file(os.path.join(root, fname)))
34 |
35 | return Hasher(values=sorted(hash_values)).compute()
36 |
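
Both helpers reduce to SHA-256: `hash_file` streams the content in 64 KiB blocks, and `hash_directory` hashes the sorted per-file hashes, so the result is independent of filesystem ordering. A quick sketch:

import pathlib
import tempfile

from substra.sdk import fs

with tempfile.TemporaryDirectory() as tmp:
    (pathlib.Path(tmp) / "data.txt").write_bytes(b"hello")
    # sha256(b"hello")
    assert fs.hash_file(pathlib.Path(tmp) / "data.txt") == (
        "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
    )
    print(fs.hash_directory(tmp))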
--------------------------------------------------------------------------------
/substra/sdk/graph.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from substra.sdk import exceptions
4 |
5 |
6 | def _get_inverted_node_graph(node_graph, node_to_ignore):
7 | """Invert the graph: {node_id: nodes that depend on it}.
8 | The returned graph does not contain the nodes to ignore.
9 | """
10 | inverted = dict()
11 | for node, dependencies in node_graph.items():
12 | if node not in node_to_ignore:
13 | for dependency in dependencies:
14 | if dependency not in node_to_ignore:
15 | inverted.setdefault(dependency, [])
16 | inverted[dependency].append(node)
17 | return inverted
18 |
19 |
20 | def _breadth_first_traversal_rank(
21 | ranks: typing.Dict[str, int], inverted_node_graph: typing.Dict[str, typing.List[str]]
22 | ):
23 | edges = set()
24 | queue = [node for node in ranks]
25 | visited = set(queue)
26 |
27 | if len(queue) == 0:
28 | raise exceptions.InvalidRequest("missing dependency among inModels IDs, circular dependency found", 400)
29 |
30 | while len(queue) > 0:
31 | current_node = queue.pop(0)
32 | for child in inverted_node_graph.get(current_node, []):
33 | new_child_rank = max(ranks[current_node] + 1, ranks.get(child, -1))
34 |
35 | if new_child_rank != ranks.get(child, -1):
36 | # either the child has never been visited
37 | # or its rank has been updated and we must visit again
38 | ranks[child] = new_child_rank
39 | visited.add(child)
40 | queue.append(child)
41 |
42 | # Cycle detection
43 | edge = (current_node, child)
44 | if (edge[1], edge[0]) in edges:
45 | raise exceptions.InvalidRequest(
46 | f"missing dependency among inModels IDs, \
47 | circular dependency between {edge[0]} and {edge[1]}",
48 | 400,
49 | )
50 | else:
51 | edges.add(edge)
52 | return ranks
53 |
54 |
55 | def compute_ranks(
56 | node_graph: typing.Dict[str, typing.List[str]],
57 | node_to_ignore: typing.Set[str] = None,
58 | ranks: typing.Dict[str, int] = None,
59 | ) -> typing.Dict[str, int]:
60 | """Compute the ranks of the nodes in the graph.
61 |
62 | Args:
63 | node_graph (typing.Dict[str, typing.List[str]]):
64 | Dict {node_id: list of nodes it depends on}.
65 | Node graph keys must not contain any node to ignore.
66 | node_to_ignore (typing.Set[str], optional): Set of nodes to ignore.
67 | Defaults to None.
68 | ranks (typing.Dict[str, int]): Already computed ranks. Defaults to None.
69 |
70 | Raises:
71 | exceptions.InvalidRequest: If the node graph contains a cycle
72 |
73 | Returns:
74 | typing.Dict[str, int]: Dict { node_id : rank }
75 | """
76 | ranks = ranks or dict()
77 | node_to_ignore = node_to_ignore or set()
78 |
79 | if len(node_graph) == 0:
80 | return dict()
81 |
82 | extra_nodes = set(node_graph.keys()).intersection(node_to_ignore)
83 | if len(extra_nodes) > 0:
84 | raise ValueError(f"node_graph keys should not contain any node to ignore: {extra_nodes}")
85 |
86 | inverted_node_graph = _get_inverted_node_graph(node_graph, node_to_ignore)
87 |
88 | # Assign rank 0 to nodes without deps
89 | for node, dependencies in node_graph.items():
90 | if node not in node_to_ignore:
91 | actual_deps = [dep for dep in dependencies if dep not in node_to_ignore]
92 | if len(actual_deps) == 0:
93 | ranks[node] = 0
94 |
95 | ranks = _breadth_first_traversal_rank(ranks=ranks, inverted_node_graph=inverted_node_graph)
96 |
97 | return ranks
98 |
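
`compute_ranks` on a three-task graph, plus the `node_to_ignore` behaviour that `compute_plan.get_tasks` relies on for pre-existing tasks:

from substra.sdk.graph import compute_ranks

# b depends on a; c depends on both
node_graph = {"a": [], "b": ["a"], "c": ["a", "b"]}
assert compute_ranks(node_graph) == {"a": 0, "b": 1, "c": 2}

# dependencies on ignored (already created) nodes do not count towards the rank
assert compute_ranks({"new": ["old"]}, node_to_ignore={"old"}) == {"new": 0}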
--------------------------------------------------------------------------------
/substra/sdk/hasher.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 |
3 |
4 | class Hasher:
5 | def __init__(self, values=None):
6 | self._h = hashlib.sha256()
7 | if values:
8 | for v in values:
9 | self.update(v)
10 |
11 | def update(self, v):
12 | if isinstance(v, str):
13 | v = v.encode("utf-8")
14 | self._h.update(v)
15 |
16 | def compute(self):
17 | return self._h.hexdigest()
18 |
--------------------------------------------------------------------------------
/substra/sdk/models.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import enum
3 | import re
4 | from datetime import datetime
5 | from typing import ClassVar
6 | from typing import Dict
7 | from typing import List
8 | from typing import Optional
9 | from typing import Type
10 | from typing import Union
11 |
12 | import pydantic
13 | from pydantic import AnyUrl
14 | from pydantic import ConfigDict
15 | from pydantic import DirectoryPath
16 | from pydantic import FilePath
17 | from pydantic.fields import Field
18 |
19 | from substra.sdk import schemas
20 |
21 | # The remote can return a URL or an empty string for paths
22 | UriPath = Union[FilePath, AnyUrl, str]
23 | CAMEL_TO_SNAKE_PATTERN = re.compile(r"(.)([A-Z][a-z]+)")
24 | CAMEL_TO_SNAKE_PATTERN_2 = re.compile(r"([a-z0-9])([A-Z])")
25 |
26 |
27 | class MetadataFilterType(str, enum.Enum):
28 | is_equal = "is"
29 | contains = "contains"
30 | exists = "exists"
31 |
32 |
33 | class ComputeTaskStatus(str, enum.Enum):
34 | """Status of the task"""
35 |
36 | unknown = "STATUS_UNKNOWN"
37 | building = "STATUS_BUILDING"
38 | executing = "STATUS_EXECUTING"
39 | done = "STATUS_DONE"
40 | failed = "STATUS_FAILED"
41 | waiting_for_executor_slot = "STATUS_WAITING_FOR_EXECUTOR_SLOT"
42 | waiting_for_parent_tasks = "STATUS_WAITING_FOR_PARENT_TASKS"
43 | waiting_for_builder_slot = "STATUS_WAITING_FOR_BUILDER_SLOT"
44 | canceled = "STATUS_CANCELED"
45 |
46 |
47 | class ComputePlanStatus(str, enum.Enum):
48 | """Status of the compute plan"""
49 |
50 | unknown = "PLAN_STATUS_UNKNOWN"
51 | doing = "PLAN_STATUS_DOING"
52 | done = "PLAN_STATUS_DONE"
53 | failed = "PLAN_STATUS_FAILED"
54 | created = "PLAN_STATUS_CREATED"
55 | canceled = "PLAN_STATUS_CANCELED"
56 |
57 |
58 | class FunctionStatus(str, enum.Enum):
59 | """Status of the function"""
60 |
61 | unknown = "FUNCTION_STATUS_UNKNOWN"
62 | waiting = "FUNCTION_STATUS_WAITING"
63 | building = "FUNCTION_STATUS_BUILDING"
64 | ready = "FUNCTION_STATUS_READY"
65 | failed = "FUNCTION_STATUS_FAILED"
66 | canceled = "FUNCTION_STATUS_CANCELED"
67 |
68 |
69 | class TaskErrorType(str, enum.Enum):
70 | """Types of errors that can occur in a task"""
71 |
72 | build = "BUILD_ERROR"
73 | execution = "EXECUTION_ERROR"
74 | internal = "INTERNAL_ERROR"
75 |
76 |
77 | class OrderingFields(str, enum.Enum):
78 | """Model fields that list results can be ordered by"""
79 |
80 | creation_date = "creation_date"
81 | start_date = "start_date"
82 | end_date = "end_date"
83 |
84 | @classmethod
85 | def __contains__(cls, item):
86 | try:
87 | cls(item)
88 | except ValueError:
89 | return False
90 | else:
91 | return True
92 |
93 |
94 | class Permission(schemas._PydanticConfig):
95 | """Permissions of a task"""
96 |
97 | public: bool
98 | authorized_ids: List[str]
99 |
100 |
101 | class Permissions(schemas._PydanticConfig):
102 | """Permissions structure stored in various asset types."""
103 |
104 | process: Permission
105 |
106 |
107 | class _Model(schemas._PydanticConfig, abc.ABC):
108 | """Base class for asset models."""
109 |
110 | # pretty print
111 | def __str__(self):
112 | return self.model_dump_json(indent=4)
113 |
114 | def __repr__(self):
115 | return self.model_dump_json(indent=4)
116 |
117 | @staticmethod
118 | def allowed_filters() -> List[str]:
119 | """allowed fields to filter on"""
120 | return []
121 |
122 |
123 | class DataSample(_Model):
124 | """Data sample"""
125 |
126 | key: str
127 | owner: str
128 | data_manager_keys: Optional[List[str]] = None
129 | path: Optional[DirectoryPath] = None
130 | creation_date: datetime
131 |
132 | type_: ClassVar[str] = schemas.Type.DataSample
133 |
134 | @staticmethod
135 | def allowed_filters() -> List[str]:
136 | return ["key", "owner", "compute_plan_key", "function_key", "dataset_key"]
137 |
138 |
139 | class _File(schemas._PydanticConfig):
140 | """File as stored in the models"""
141 |
142 | checksum: str
143 | storage_address: UriPath
144 |
145 |
146 | class Dataset(_Model):
147 | """Dataset asset"""
148 |
149 | key: str
150 | name: str
151 | owner: str
152 | permissions: Permissions
153 | data_sample_keys: List[str] = []
154 | opener: _File
155 | description: _File
156 | metadata: Dict[str, str]
157 | creation_date: datetime
158 | logs_permission: Permission
159 |
160 | type_: ClassVar[str] = schemas.Type.Dataset
161 |
162 | @staticmethod
163 | def allowed_filters() -> List[str]:
164 | return ["key", "name", "owner", "permissions", "compute_plan_key", "function_key", "data_sample_key"]
165 |
166 |
167 | class FunctionInput(_Model):
168 | identifier: str
169 | kind: schemas.AssetKind
170 | optional: bool
171 | multiple: bool
172 |
173 |
174 | class FunctionOutput(_Model):
175 | identifier: str
176 | kind: schemas.AssetKind
177 | multiple: bool
178 |
179 |
180 | class Function(_Model):
181 | key: str
182 | name: str
183 | owner: str
184 | permissions: Permissions
185 | metadata: Dict[str, str]
186 | creation_date: datetime
187 | inputs: List[FunctionInput]
188 | outputs: List[FunctionOutput]
189 | status: FunctionStatus
190 |
191 | description: _File
192 | archive: _File
193 |
194 | type_: ClassVar[str] = schemas.Type.Function
195 |
196 | @staticmethod
197 | def allowed_filters() -> List[str]:
198 | return ["key", "name", "owner", "permissions", "compute_plan_key", "dataset_key", "data_sample_key"]
199 |
200 | @pydantic.field_validator("inputs", mode="before")
201 | @classmethod
202 | def dict_input_to_list(cls, v):
203 | if isinstance(v, dict):
204 | # Transform the inputs dict to a list
205 | return [
206 | FunctionInput(
207 | identifier=identifier,
208 | kind=function_input["kind"],
209 | optional=function_input["optional"],
210 | multiple=function_input["multiple"],
211 | )
212 | for identifier, function_input in v.items()
213 | ]
214 | else:
215 | return v
216 |
217 | @pydantic.field_validator("outputs", mode="before")
218 | @classmethod
219 | def dict_output_to_list(cls, v):
220 | if isinstance(v, dict):
221 | # Transform the outputs dict to a list
222 | return [
223 | FunctionOutput(
224 | identifier=identifier, kind=function_output["kind"], multiple=function_output["multiple"]
225 | )
226 | for identifier, function_output in v.items()
227 | ]
228 | else:
229 | return v
230 |
231 |
232 | class InModel(schemas._PydanticConfig):
233 | """In model of a task"""
234 |
235 | checksum: str
236 | storage_address: UriPath
237 |
238 |
239 | class OutModel(schemas._PydanticConfig):
240 | """Out model of a task"""
241 |
242 | key: str
243 | compute_task_key: str
244 | address: Optional[InModel] = None
245 | permissions: Permissions
246 | owner: str
247 | creation_date: datetime
248 |
249 | type_: ClassVar[str] = schemas.Type.Model
250 |
251 | @staticmethod
252 | def allowed_filters() -> List[str]:
253 | return ["key", "compute_task_key", "owner", "permissions"]
254 |
255 |
256 | class InputRef(schemas._PydanticConfig):
257 | identifier: str
258 | asset_key: Optional[str] = None
259 | parent_task_key: Optional[str] = None
260 | parent_task_output_identifier: Optional[str] = None
261 |
262 | # either (asset_key) or (parent_task_key, parent_task_output_identifier) must be specified
263 | _check_asset_key_or_parent_ref = pydantic.model_validator(mode="before")(schemas.check_asset_key_or_parent_ref)
264 |
265 |
266 | class ComputeTaskOutput(schemas._PydanticConfig):
267 | """Specification of a compute task output"""
268 |
269 | permissions: Permissions
270 | is_transient: bool = Field(False, alias="transient")
271 | model_config = ConfigDict(populate_by_name=True)
272 |
273 |
274 | class Task(_Model):
275 | key: str
276 | function: Function
277 | owner: str
278 | compute_plan_key: str
279 | metadata: Dict[str, str]
280 | status: ComputeTaskStatus
281 | worker: str
282 | rank: Optional[int] = None
283 | tag: str
284 | creation_date: datetime
285 | start_date: Optional[datetime] = None
286 | end_date: Optional[datetime] = None
287 | error_type: Optional[TaskErrorType] = None
288 | inputs: List[InputRef]
289 | outputs: Dict[str, ComputeTaskOutput]
290 |
291 | type_: ClassVar[Type] = schemas.Type.Task
292 |
293 | @staticmethod
294 | def allowed_filters() -> List[str]:
295 | return [
296 | "key",
297 | "owner",
298 | "worker",
299 | "rank",
300 | "status",
301 | "metadata",
302 | "compute_plan_key",
303 | "function_key",
304 | ]
305 |
306 |
307 | Task.model_rebuild()
308 |
309 |
310 | class ComputePlan(_Model):
311 | """ComputePlan"""
312 |
313 | key: str
314 | tag: str
315 | name: str
316 | owner: str
317 | metadata: Dict[str, str]
318 | task_count: int = 0
319 | waiting_builder_slot_count: int = 0
320 | building_count: int = 0
321 | waiting_parent_tasks_count: int = 0
322 | waiting_executor_slot_count: int = 0
323 | executing_count: int = 0
324 | canceled_count: int = 0
325 | failed_count: int = 0
326 | done_count: int = 0
327 | failed_task_key: Optional[str] = None
328 | status: ComputePlanStatus
329 | creation_date: datetime
330 | start_date: Optional[datetime] = None
331 | end_date: Optional[datetime] = None
332 | estimated_end_date: Optional[datetime] = None
333 | duration: Optional[int] = None
334 | creator: Optional[str] = None
335 |
336 | type_: ClassVar[str] = schemas.Type.ComputePlan
337 |
338 | @staticmethod
339 | def allowed_filters() -> List[str]:
340 | return [
341 | "key",
342 | "name",
343 | "owner",
344 | "worker",
345 | "status",
346 | "metadata",
347 | "function_key",
348 | "dataset_key",
349 | "data_sample_key",
350 | ]
351 |
352 |
353 | class Performances(_Model):
354 | """Performances of the different compute tasks of a compute plan"""
355 |
356 | compute_plan_key: List[str] = []
357 | compute_plan_tag: List[str] = []
358 | compute_plan_status: List[str] = []
359 | compute_plan_start_date: List[datetime] = []
360 | compute_plan_end_date: List[datetime] = []
361 | compute_plan_metadata: List[dict] = []
362 | worker: List[str] = []
363 | task_key: List[str] = []
364 | task_rank: List[int] = []
365 | round_idx: List[int] = []
366 | identifier: List[str] = []
367 | performance: List[float] = []
368 |
369 |
370 | class Organization(schemas._PydanticConfig):
371 | """Organization"""
372 |
373 | id: str
374 | is_current: bool
375 | creation_date: datetime
376 |
377 | type_: ClassVar[str] = schemas.Type.Organization
378 |
379 |
380 | class OrganizationInfoConfig(schemas._PydanticConfig, extra="allow"):
381 | model_config = ConfigDict(protected_namespaces=())
382 | model_export_enabled: bool
383 |
384 |
385 | class OrganizationInfo(schemas._PydanticConfig):
386 | host: AnyUrl
387 | organization_id: str
388 | organization_name: str
389 | config: OrganizationInfoConfig
390 | channel: str
391 | version: str
392 | orchestrator_version: str
393 |
394 |
395 | class _TaskAsset(schemas._PydanticConfig):
396 | kind: str
397 | identifier: str
398 |
399 | @staticmethod
400 | def allowed_filters() -> List[str]:
401 | return ["identifier", "kind"]
402 |
403 |
404 | class InputAsset(_TaskAsset):
405 | asset: Union[Dataset, DataSample, OutModel]
406 | type_: ClassVar[str] = schemas.Type.InputAsset
407 |
408 |
409 | class OutputAsset(_TaskAsset):
410 | asset: Union[float, OutModel]
411 | type_: ClassVar[str] = schemas.Type.OutputAsset
412 |
413 | # Deal with remote returning the actual performance object
414 | @pydantic.field_validator("asset", mode="before")
415 | @classmethod
416 | def convert_remote_performance(cls, value, values):
417 | if values.data.get("kind") == schemas.AssetKind.performance and isinstance(value, dict):
418 | return value.get("performance_value")
419 |
420 | return value
421 |
422 |
423 | SCHEMA_TO_MODEL = {
424 | schemas.Type.Task: Task,
425 | schemas.Type.Function: Function,
426 | schemas.Type.ComputePlan: ComputePlan,
427 | schemas.Type.DataSample: DataSample,
428 | schemas.Type.Dataset: Dataset,
429 | schemas.Type.Organization: Organization,
430 | schemas.Type.Model: OutModel,
431 | schemas.Type.InputAsset: InputAsset,
432 | schemas.Type.OutputAsset: OutputAsset,
433 | schemas.Type.FunctionOutput: FunctionOutput,
434 | schemas.Type.FunctionInput: FunctionInput,
435 | }
436 |
--------------------------------------------------------------------------------
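Note: the `SCHEMA_TO_MODEL` registry above ties the public `schemas.Type` enum to the model classes of this module, and each model's `allowed_filters()` drives filter validation. A minimal sketch of the lookup (the asset type is just an example; any registered `schemas.Type` member works):

```python
from substra.sdk import models
from substra.sdk import schemas

# Resolve the model class registered for an asset type.
model_cls = models.SCHEMA_TO_MODEL[schemas.Type.Task]
assert model_cls is models.Task

# Each model declares the fields it can be filtered on; this is what
# check_and_format_search_filters() in substra/sdk/utils.py relies on.
print(model_cls.allowed_filters())
```
--------------------------------------------------------------------------------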
/substra/sdk/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import copy
3 | import functools
4 | import io
5 | import logging
6 | import ntpath
7 | import os
8 | import re
9 | import time
10 | import uuid
11 | import zipfile
12 |
13 | from substra.sdk import exceptions
14 | from substra.sdk import models
15 |
16 |
17 | def path_leaf(path):
18 | head, tail = ntpath.split(path)
19 | return tail or ntpath.basename(head)
20 |
21 |
22 | @contextlib.contextmanager
23 | def extract_files(data, file_attributes):
24 | data = copy.deepcopy(data)
25 |
26 | paths = {}
27 | for attr in file_attributes:
28 | try:
29 | paths[attr] = data[attr]
30 | except KeyError:
31 | raise exceptions.LoadDataException(f"The '{attr}' attribute is missing.")
32 | del data[attr]
33 |
34 | files = {}
35 | for k, f in paths.items():
36 | if not os.path.exists(f):
37 | raise exceptions.LoadDataException(f"The '{k}' attribute file ({f}) does not exist.")
38 | files[k] = open(f, "rb")
39 |
40 | try:
41 | yield (data, files)
42 | finally:
43 | for f in files.values():
44 | f.close()
45 |
46 |
47 | def zip_folder(fp, path, followlinks=False):
48 | zipf = zipfile.ZipFile(fp, "w", zipfile.ZIP_DEFLATED)
49 | for root, _, files in os.walk(path, followlinks=followlinks):
50 | for f in files:
51 | abspath = os.path.join(root, f)
52 | archive_path = os.path.relpath(abspath, start=path)
53 | zipf.write(abspath, arcname=archive_path)
54 | zipf.close()
55 |
56 |
57 | def zip_folder_in_memory(path, followlinks=False):
58 | fp = io.BytesIO()
59 | zip_folder(fp, path, followlinks=followlinks)
60 | fp.seek(0)
61 | return fp
62 |
63 |
64 | @contextlib.contextmanager
65 | def extract_data_sample_files(data, followlinks=False):
66 | # handle the data sample specific fields: the "path" and "paths" cases
67 | data = copy.deepcopy(data)
68 |
69 | folders = {}
70 | if data.get("path"):
71 | attr = "path"
72 | folders[attr] = data[attr]
73 | del data[attr]
74 |
75 | if data.get("paths"): # field is set and is not None/empty
76 | for p in data["paths"]:
77 | folders[path_leaf(p)] = p
78 | del data["paths"]
79 |
80 | files = {}
81 | for k, f in folders.items():
82 | if not os.path.isdir(f):
83 | raise exceptions.LoadDataException(f"Path '{f}' is not an existing directory")
84 | files[k] = zip_folder_in_memory(f, followlinks=followlinks)
85 |
86 | try:
87 | yield (data, files)
88 | finally:
89 | for f in files.values():
90 | f.close()
91 |
92 |
93 | def _check_metadata_search_filter(filter):
94 | if not isinstance(filter, dict):
95 | raise exceptions.FilterFormatError(
96 | "Cannot load filters. Please review the documentation, metadata filter should be a list of dict."
97 | "But one passed elements is not."
98 | )
99 |
100 | if "key" not in filter.keys() or "type" not in filter.keys():
101 | raise exceptions.FilterFormatError("Each metadata filter must contain both a `key` and a `type` key.")
102 |
103 | if filter["type"] not in ("is", "contains", "exists"):
104 | raise exceptions.FilterFormatError(
105 | "Each metadata filter `type` filed value must be `is`, `contains` or `exists`"
106 | )
107 |
108 | if filter["type"] in ("is", "contains"):
109 | if "value" not in filter.keys():
110 | raise exceptions.FilterFormatError(
111 | "For each metadata filter, if `type` value is `is` or `contains`, the filter should also contain the "
112 | "`value` key."
113 | )
114 |
115 | if not isinstance(filter.get("value"), str):
116 | raise exceptions.FilterFormatError(
117 | "For each metadata filter, if a `value` is passed, it should be a string."
118 | )
119 |
120 |
121 | def _check_metadata_search_filters(filters):
122 | if not isinstance(filters, list):
123 | raise exceptions.FilterFormatError(
124 | "Cannot load filters. Please review the documentation, metadata filter should be a list of dict."
125 | )
126 |
127 | for filter in filters:
128 | _check_metadata_search_filter(filter)
129 |
130 |
131 | def check_and_format_search_filters(asset_type, filters): # noqa: C901
132 | # do not check if no filters
133 | if filters is None:
134 | return filters
135 |
136 | # check filters structure
137 | if not isinstance(filters, dict):
138 | raise exceptions.FilterFormatError(
139 | "Cannot load filters. Please review the documentation, filters should be a dict"
140 | )
141 |
142 | # retrieving asset allowed fields to filter on
143 | allowed_filters = models.SCHEMA_TO_MODEL[asset_type].allowed_filters()
144 |
145 | # for each attribute (key) to filter on
146 | for key in filters:
147 | # check that key is a valid filter
148 | if key not in allowed_filters:
149 | raise exceptions.NotAllowedFilterError(
150 | f"Cannot filter on {key}. Please review the documentation, filtering allowed only on {allowed_filters}"
151 | )
152 | elif key == "name":
153 | if not isinstance(filters[key], str):
154 | raise exceptions.FilterFormatError(
155 | """Cannot load filters. Please review the documentation, 'name' filter is partial match in remote,
156 | exact match in local, value should be str"""
157 | )
158 | elif key == "metadata":
159 | _check_metadata_search_filters(filters[key])
160 |
161 | # all other filters should be a list, throw an error if not
162 | elif not isinstance(filters[key], list):
163 | raise exceptions.FilterFormatError(
164 | "Cannot load filters. Please review the documentation, filters values should be a list"
165 | )
166 | # handle default case (List)
167 | else:
168 | # convert all values to str; needed for "rank" since the user can pass ints, and prevents errors for other non-str values
169 | filters[key] = [str(v) for v in filters[key]]
170 |
171 | return filters
172 |
173 |
174 | def check_search_ordering(order_by):
175 | if order_by is None:
176 | return
177 | elif not models.OrderingFields.__contains__(order_by):
178 | raise exceptions.OrderingFormatError(
179 | f"Please review the documentation, ordering is available only on {list(models.OrderingFields.__members__)}"
180 | )
181 |
182 |
183 | def retry_on_exception(exceptions, timeout=300):
184 | """Retry function in case of exception(s).
185 |
186 | Args:
187 | exceptions (list): list of exception types that trigger a retry
188 | timeout (int, optional): timeout in seconds; set to False to retry indefinitely
189 |
190 | Example:
191 | ```python
192 | from substra.sdk import exceptions, retry_on_exception
193 |
194 | def my_function(arg1, arg2):
195 | pass
196 |
197 | retry = retry_on_exception(
198 | exceptions=(exceptions.RequestTimeout,),
199 | timeout=300,
200 | )
201 | retry(my_function)(arg1, arg2)
202 | ```
203 | """
204 |
205 | def _retry(f):
206 | @functools.wraps(f)
207 | def wrapper(*args, **kwargs):
208 | delay = 1
209 | backoff = 2
210 | tstart = time.time()
211 |
212 | while True:
213 | try:
214 | return f(*args, **kwargs)
215 |
216 | except exceptions:
217 | if timeout is not False and time.time() - tstart > timeout:
218 | raise
219 | logging.warning(f"Function {f.__name__} failed: retrying in {delay}s")
220 | time.sleep(delay)
221 | delay *= backoff
222 |
223 | return wrapper
224 |
225 | return _retry
226 |
227 |
228 | def response_get_destination_filename(response):
229 | """Get filename from content-disposition header."""
230 | disposition = response.headers.get("content-disposition")
231 | if not disposition:
232 | return None
233 | filenames = re.findall("filename=(.+)", disposition)
234 | if not filenames:
235 | return None
236 | filename = filenames[0]
237 | filename = filename.strip("'\"")
238 | return filename
239 |
240 |
241 | def is_valid_uuid(value):
242 | try:
243 | uuid.UUID(value)
244 | return True
245 | except ValueError:
246 | return False
247 |
--------------------------------------------------------------------------------
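Note: `check_and_format_search_filters` above expects a dict keyed by the asset's allowed filter names, where most values are lists and `metadata` is a list of `{"key", "type"[, "value"]}` dicts. A minimal sketch with hypothetical values:

```python
from substra.sdk import schemas
from substra.sdk.utils import check_and_format_search_filters

filters = {
    "rank": [0, 1],  # list values are normalised to strings: ["0", "1"]
    "metadata": [{"key": "epochs", "type": "is", "value": "10"}],
}
# Raises FilterFormatError or NotAllowedFilterError on malformed input.
formatted = check_and_format_search_filters(schemas.Type.Task, filters)
```
--------------------------------------------------------------------------------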
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import substra
4 | from substra.sdk.schemas import FunctionSpec
5 |
6 | from . import data_factory
7 | from .fl_interface import FLFunctionInputs
8 | from .fl_interface import FLFunctionOutputs
9 | from .fl_interface import FunctionCategory
10 |
11 |
12 | def pytest_configure(config):
13 | config.addinivalue_line(
14 | "markers",
15 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
16 | )
17 |
18 |
19 | @pytest.fixture
20 | def client(tmpdir):
21 | c = substra.Client(url="http://foo.io", backend_type="remote", token="foo")
22 | return c
23 |
24 |
25 | @pytest.fixture
26 | def workdir(tmp_path):
27 | d = tmp_path / "substra-cli"
28 | d.mkdir()
29 | return d
30 |
31 |
32 | @pytest.fixture
33 | def dataset_query(tmpdir):
34 | opener_path = tmpdir / "opener.py"
35 | opener_path.write_text("raise ValueError()", encoding="utf-8")
36 |
37 | desc_path = tmpdir / "description.md"
38 | desc_path.write_text("#Hello world", encoding="utf-8")
39 |
40 | return {
41 | "name": "dataset_name",
42 | "data_opener": str(opener_path),
43 | "description": str(desc_path),
44 | "permissions": {
45 | "public": True,
46 | "authorized_ids": [],
47 | },
48 | "logs_permission": {
49 | "public": True,
50 | "authorized_ids": [],
51 | },
52 | }
53 |
54 |
55 | @pytest.fixture
56 | def metric_query(tmpdir):
57 | metrics_path = tmpdir / "metrics.zip"
58 | metrics_path.write_text("foo archive", encoding="utf-8")
59 |
60 | desc_path = tmpdir / "description.md"
61 | desc_path.write_text("#Hello world", encoding="utf-8")
62 |
63 | return {
64 | "name": "metrics_name",
65 | "file": str(metrics_path),
66 | "description": str(desc_path),
67 | "permissions": {
68 | "public": True,
69 | "authorized_ids": [],
70 | },
71 | }
72 |
73 |
74 | @pytest.fixture
75 | def function_query(tmpdir):
76 | function_file_path = tmpdir / "function.tar.gz"
77 | function_file_path.write(b"tar gz archive")
78 |
79 | desc_path = tmpdir / "description.md"
80 | desc_path.write_text("#Hello world", encoding="utf-8")
81 |
82 | function_category = FunctionCategory.simple
83 |
84 | return FunctionSpec(
85 | name="function_name",
86 | inputs=FLFunctionInputs[function_category],
87 | outputs=FLFunctionOutputs[function_category],
88 | description=str(desc_path),
89 | file=str(function_file_path),
90 | permissions={
91 | "public": True,
92 | "authorized_ids": [],
93 | },
94 | )
95 |
96 |
97 | @pytest.fixture
98 | def data_sample_query(tmpdir):
99 | data_sample_dir_path = tmpdir / "data_sample_0"
100 | data_sample_file_path = data_sample_dir_path / "data.txt"
101 | data_sample_file_path.write_text("Hello world 0", encoding="utf-8", ensure=True)
102 |
103 | return {
104 | "path": str(data_sample_dir_path),
105 | "data_manager_keys": ["42"],
106 | }
107 |
108 |
109 | @pytest.fixture
110 | def data_samples_query(tmpdir):
111 | nb = 3
112 | paths = []
113 | for i in range(nb):
114 | data_sample_dir_path = tmpdir / f"data_sample_{i}"
115 | data_sample_file_path = data_sample_dir_path / "data.txt"
116 | data_sample_file_path.write_text(f"Hello world {i}", encoding="utf-8", ensure=True)
117 |
118 | paths.append(str(data_sample_dir_path))
119 |
120 | return {
121 | "paths": paths,
122 | "data_manager_keys": ["42"],
123 | }
124 |
125 |
126 | @pytest.fixture(scope="session")
127 | def asset_factory():
128 | return data_factory.AssetsFactory("test_debug")
129 |
130 |
131 | @pytest.fixture()
132 | def data_sample(asset_factory):
133 | return asset_factory.create_data_sample()
134 |
--------------------------------------------------------------------------------
/tests/fl_interface.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from substra.sdk.schemas import AssetKind
4 | from substra.sdk.schemas import ComputeTaskOutputSpec
5 | from substra.sdk.schemas import FunctionInputSpec
6 | from substra.sdk.schemas import FunctionOutputSpec
7 | from substra.sdk.schemas import InputRef
8 | from substra.sdk.schemas import Permissions
9 |
10 | PUBLIC_PERMISSIONS = Permissions(public=True, authorized_ids=[])
11 |
12 |
13 | class FunctionCategory(str, Enum):
14 | """Function category"""
15 |
16 | unknown = "FUNCTION_UNKNOWN"
17 | simple = "FUNCTION_SIMPLE"
18 | composite = "FUNCTION_COMPOSITE"
19 | aggregate = "FUNCTION_AGGREGATE"
20 | metric = "FUNCTION_METRIC"
21 | predict = "FUNCTION_PREDICT"
22 | predict_composite = "FUNCTION_PREDICT_COMPOSITE"
23 |
24 |
25 | class InputIdentifiers(str, Enum):
26 | local = "local"
27 | shared = "shared"
28 | predictions = "predictions"
29 | performance = "performance"
30 | opener = "opener"
31 | datasamples = "datasamples"
32 |
33 |
34 | class OutputIdentifiers(str, Enum):
35 | local = "local"
36 | shared = "shared"
37 | predictions = "predictions"
38 | performance = "performance"
39 |
40 |
41 | class FLFunctionInputs(list, Enum):
42 | """Substra function inputs by function category based on the InputIdentifiers"""
43 |
44 | FUNCTION_AGGREGATE = [
45 | FunctionInputSpec(identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=False, multiple=True)
46 | ]
47 | FUNCTION_SIMPLE = [
48 | FunctionInputSpec(
49 | identifier=InputIdentifiers.datasamples,
50 | kind=AssetKind.data_sample.value,
51 | optional=False,
52 | multiple=True,
53 | ),
54 | FunctionInputSpec(
55 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False
56 | ),
57 | FunctionInputSpec(identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=True, multiple=True),
58 | ]
59 | FUNCTION_COMPOSITE = [
60 | FunctionInputSpec(
61 | identifier=InputIdentifiers.datasamples,
62 | kind=AssetKind.data_sample.value,
63 | optional=False,
64 | multiple=True,
65 | ),
66 | FunctionInputSpec(
67 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False
68 | ),
69 | FunctionInputSpec(identifier=InputIdentifiers.local, kind=AssetKind.model.value, optional=True, multiple=False),
70 | FunctionInputSpec(
71 | identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=True, multiple=False
72 | ),
73 | ]
74 | FUNCTION_PREDICT = [
75 | FunctionInputSpec(
76 | identifier=InputIdentifiers.datasamples,
77 | kind=AssetKind.data_sample.value,
78 | optional=False,
79 | multiple=True,
80 | ),
81 | FunctionInputSpec(
82 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False
83 | ),
84 | FunctionInputSpec(
85 | identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=False, multiple=False
86 | ),
87 | ]
88 | FUNCTION_PREDICT_COMPOSITE = [
89 | FunctionInputSpec(
90 | identifier=InputIdentifiers.datasamples,
91 | kind=AssetKind.data_sample.value,
92 | optional=False,
93 | multiple=True,
94 | ),
95 | FunctionInputSpec(
96 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False
97 | ),
98 | FunctionInputSpec(
99 | identifier=InputIdentifiers.local, kind=AssetKind.model.value, optional=False, multiple=False
100 | ),
101 | FunctionInputSpec(
102 | identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=False, multiple=False
103 | ),
104 | ]
105 | FUNCTION_METRIC = [
106 | FunctionInputSpec(
107 | identifier=InputIdentifiers.datasamples,
108 | kind=AssetKind.data_sample.value,
109 | optional=False,
110 | multiple=True,
111 | ),
112 | FunctionInputSpec(
113 | identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False
114 | ),
115 | FunctionInputSpec(
116 | identifier=InputIdentifiers.predictions, kind=AssetKind.model.value, optional=False, multiple=False
117 | ),
118 | ]
119 |
120 |
121 | class FLFunctionOutputs(list, Enum):
122 | """Substra function outputs by function category based on the OutputIdentifiers"""
123 |
124 | FUNCTION_AGGREGATE = [
125 | FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False)
126 | ]
127 | FUNCTION_SIMPLE = [
128 | FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False)
129 | ]
130 | FUNCTION_COMPOSITE = [
131 | FunctionOutputSpec(identifier=OutputIdentifiers.local, kind=AssetKind.model.value, multiple=False),
132 | FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False),
133 | ]
134 | FUNCTION_PREDICT = [
135 | FunctionOutputSpec(identifier=OutputIdentifiers.predictions, kind=AssetKind.model.value, multiple=False)
136 | ]
137 | FUNCTION_PREDICT_COMPOSITE = [
138 | FunctionOutputSpec(identifier=OutputIdentifiers.predictions, kind=AssetKind.model.value, multiple=False)
139 | ]
140 | FUNCTION_METRIC = [
141 | FunctionOutputSpec(identifier=OutputIdentifiers.performance, kind=AssetKind.performance.value, multiple=False)
142 | ]
143 |
144 |
145 | class FLTaskInputGenerator:
146 | "Generates task inputs based on Input and OutputIdentifiers"
147 |
148 | @staticmethod
149 | def opener(opener_key):
150 | return [InputRef(identifier=InputIdentifiers.opener, asset_key=opener_key)]
151 |
152 | @staticmethod
153 | def data_samples(data_sample_keys):
154 | return [
155 | InputRef(identifier=InputIdentifiers.datasamples, asset_key=data_sample) for data_sample in data_sample_keys
156 | ]
157 |
158 | @staticmethod
159 | def task(opener_key, data_sample_keys):
160 | return FLTaskInputGenerator.opener(opener_key=opener_key) + FLTaskInputGenerator.data_samples(
161 | data_sample_keys=data_sample_keys
162 | )
163 |
164 | @staticmethod
165 | def trains_to_train(model_keys):
166 | return [
167 | InputRef(
168 | identifier=InputIdentifiers.shared,
169 | parent_task_key=model_key,
170 | parent_task_output_identifier=OutputIdentifiers.shared,
171 | )
172 | for model_key in model_keys
173 | ]
174 |
175 | @staticmethod
176 | def trains_to_aggregate(model_keys):
177 | return [
178 | InputRef(
179 | identifier=InputIdentifiers.shared,
180 | parent_task_key=model_key,
181 | parent_task_output_identifier=OutputIdentifiers.shared,
182 | )
183 | for model_key in model_keys
184 | ]
185 |
186 | @staticmethod
187 | def train_to_predict(model_key):
188 | return [
189 | InputRef(
190 | identifier=InputIdentifiers.shared,
191 | parent_task_key=model_key,
192 | parent_task_output_identifier=OutputIdentifiers.shared,
193 | )
194 | ]
195 |
196 | @staticmethod
197 | def predict_to_test(model_key):
198 | return [
199 | InputRef(
200 | identifier=InputIdentifiers.predictions,
201 | parent_task_key=model_key,
202 | parent_task_output_identifier=OutputIdentifiers.predictions,
203 | )
204 | ]
205 |
206 | @staticmethod
207 | def composite_to_predict(model_key):
208 | return [
209 | InputRef(
210 | identifier=InputIdentifiers.local,
211 | parent_task_key=model_key,
212 | parent_task_output_identifier=OutputIdentifiers.local,
213 | ),
214 | InputRef(
215 | identifier=InputIdentifiers.shared,
216 | parent_task_key=model_key,
217 | parent_task_output_identifier=OutputIdentifiers.shared,
218 | ),
219 | ]
220 |
221 | @staticmethod
222 | def composite_to_local(model_key):
223 | return [
224 | InputRef(
225 | identifier=InputIdentifiers.local,
226 | parent_task_key=model_key,
227 | parent_task_output_identifier=OutputIdentifiers.local,
228 | )
229 | ]
230 |
231 | @staticmethod
232 | def composite_to_composite(model1_key, model2_key=None):
233 | return [
234 | InputRef(
235 | identifier=InputIdentifiers.local,
236 | parent_task_key=model1_key,
237 | parent_task_output_identifier=OutputIdentifiers.local,
238 | ),
239 | InputRef(
240 | identifier=InputIdentifiers.shared,
241 | parent_task_key=model2_key or model1_key,
242 | parent_task_output_identifier=OutputIdentifiers.shared,
243 | ),
244 | ]
245 |
246 | @staticmethod
247 | def aggregate_to_shared(model_key):
248 | return [
249 | InputRef(
250 | identifier=InputIdentifiers.shared,
251 | parent_task_key=model_key,
252 | parent_task_output_identifier=OutputIdentifiers.shared,
253 | )
254 | ]
255 |
256 | @staticmethod
257 | def composites_to_aggregate(model_keys):
258 | return [
259 | InputRef(
260 | identifier=InputIdentifiers.shared,
261 | parent_task_key=model_key,
262 | parent_task_output_identifier=OutputIdentifiers.shared,
263 | )
264 | for model_key in model_keys
265 | ]
266 |
267 | @staticmethod
268 | def aggregate_to_predict(model_key):
269 | return [
270 | InputRef(
271 | identifier=InputIdentifiers.shared,
272 | parent_task_key=model_key,
273 | parent_task_output_identifier=OutputIdentifiers.shared,
274 | )
275 | ]
276 |
277 | @staticmethod
278 | def local_to_aggregate(model_key):
279 | return [
280 | InputRef(
281 | identifier=InputIdentifiers.shared,
282 | parent_task_key=model_key,
283 | parent_task_output_identifier=OutputIdentifiers.local,
284 | )
285 | ]
286 |
287 | @staticmethod
288 | def shared_to_aggregate(model_key):
289 | return [
290 | InputRef(
291 | identifier=InputIdentifiers.shared,
292 | parent_task_key=model_key,
293 | parent_task_output_identifier=OutputIdentifiers.shared,
294 | )
295 | ]
296 |
297 |
298 | def _permission_from_ids(authorized_ids):
299 | if authorized_ids is None:
300 | return PUBLIC_PERMISSIONS
301 |
302 | return Permissions(public=False, authorized_ids=authorized_ids)
303 |
304 |
305 | class FLTaskOutputGenerator:
306 | "Generates task outputs based on Input and OutputIdentifiers"
307 |
308 | @staticmethod
309 | def traintask(authorized_ids=None):
310 | return {OutputIdentifiers.shared: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))}
311 |
312 | @staticmethod
313 | def aggregatetask(authorized_ids=None):
314 | return {OutputIdentifiers.shared: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))}
315 |
316 | @staticmethod
317 | def predicttask(authorized_ids=None):
318 | return {OutputIdentifiers.predictions: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))}
319 |
320 | @staticmethod
321 | def testtask(authorized_ids=None):
322 | return {
323 | OutputIdentifiers.performance: ComputeTaskOutputSpec(
324 | permissions=Permissions(public=True, authorized_ids=[])
325 | )
326 | }
327 |
328 | @staticmethod
329 | def composite_traintask(shared_authorized_ids=None, local_authorized_ids=None):
330 | return {
331 | OutputIdentifiers.shared: ComputeTaskOutputSpec(permissions=_permission_from_ids(shared_authorized_ids)),
332 | OutputIdentifiers.local: ComputeTaskOutputSpec(permissions=_permission_from_ids(local_authorized_ids)),
333 | }
334 |
--------------------------------------------------------------------------------
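Note: the two generators above are meant to be combined when building task specifications in tests. A minimal sketch, with hypothetical keys:

```python
from tests.fl_interface import FLTaskInputGenerator
from tests.fl_interface import FLTaskOutputGenerator

# Hypothetical keys, for illustration only: a train task consuming a
# dataset, two data samples and the model of a parent train task.
inputs = FLTaskInputGenerator.task(
    opener_key="dataset-key",
    data_sample_keys=["sample-1", "sample-2"],
) + FLTaskInputGenerator.trains_to_train(model_keys=["parent-train-task-key"])

outputs = FLTaskOutputGenerator.traintask(authorized_ids=["OrgA"])
```
--------------------------------------------------------------------------------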
/tests/mocked_requests.py:
--------------------------------------------------------------------------------
1 | from tests.utils import mock_requests
2 |
3 |
4 | def cancel_compute_plan(mocker):
5 | m = mock_requests(mocker, "post", response=None)
6 | return m
7 |
--------------------------------------------------------------------------------
/tests/sdk/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/__init__.py
--------------------------------------------------------------------------------
/tests/sdk/data/symlink.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/data/symlink.zip
--------------------------------------------------------------------------------
/tests/sdk/data/traversal.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/data/traversal.zip
--------------------------------------------------------------------------------
/tests/sdk/local/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Substra/substra/960f33d548ed2d571cc12aa2b61dca208b5262a3/tests/sdk/local/__init__.py
--------------------------------------------------------------------------------
/tests/sdk/local/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import substra
4 |
5 |
6 | @pytest.fixture(scope="session")
7 | def docker_clients():
8 | return [substra.Client(backend_type=substra.BackendType.LOCAL_DOCKER) for _ in range(2)]
9 |
10 |
11 | @pytest.fixture(scope="session")
12 | def subprocess_clients():
13 | return [substra.Client(backend_type=substra.BackendType.LOCAL_SUBPROCESS) for _ in range(2)]
14 |
15 |
16 | @pytest.fixture(scope="session", params=["docker", "subprocess"])
17 | def clients(request, docker_clients, subprocess_clients):
18 | if request.param == "docker":
19 | return docker_clients
20 | else:
21 | return subprocess_clients
22 |
--------------------------------------------------------------------------------
/tests/sdk/test_add.py:
--------------------------------------------------------------------------------
1 | import pydantic
2 | import pytest
3 |
4 | import substra
5 | from substra.sdk import models
6 | from substra.sdk.exceptions import ComputePlanKeyFormatError
7 |
8 | from .. import datastore
9 | from ..utils import mock_requests
10 |
11 |
12 | def test_add_dataset(client, dataset_query, mocker):
13 | m_post = mock_requests(mocker, "post", response=datastore.DATASET)
14 | m_get = mock_requests(mocker, "get", response=datastore.DATASET)
15 | key = client.add_dataset(dataset_query)
16 | response = client.get_dataset(key)
17 |
18 | assert response == models.Dataset(**datastore.DATASET)
19 | m_post.assert_called()
20 | m_get.assert_called()
21 |
22 |
23 | def test_add_dataset_invalid_args(client, dataset_query, mocker):
24 | mock_requests(mocker, "post", response=datastore.DATASET)
25 | del dataset_query["data_opener"]
26 |
27 | with pytest.raises(pydantic.ValidationError):
28 | client.add_dataset(dataset_query)
29 |
30 |
31 | def test_add_dataset_response_failure_500(client, dataset_query, mocker):
32 | mock_requests(mocker, "post", status=500)
33 |
34 | with pytest.raises(substra.sdk.exceptions.InternalServerError):
35 | client.add_dataset(dataset_query)
36 |
37 |
38 | def test_add_dataset_409_success(client, dataset_query, mocker):
39 | mock_requests(mocker, "post", response={"key": datastore.DATASET["key"]}, status=409)
40 | mock_requests(mocker, "get", response=datastore.DATASET)
41 |
42 | key = client.add_dataset(dataset_query)
43 |
44 | assert key == datastore.DATASET["key"]
45 |
46 |
47 | def test_add_function(client, function_query, mocker):
48 | m_post = mock_requests(mocker, "post", response=datastore.FUNCTION)
49 | m_get = mock_requests(mocker, "get", response=datastore.FUNCTION)
50 | key = client.add_function(function_query)
51 | response = client.get_function(key)
52 |
53 | assert response == models.Function(**datastore.FUNCTION)
54 | m_post.assert_called()
55 | m_get.assert_called()
56 |
57 |
58 | def test_add_data_sample(client, data_sample_query, mocker):
59 | server_response = [{"key": "42"}]
60 | m = mock_requests(mocker, "post", response=server_response)
61 | response = client.add_data_sample(data_sample_query)
62 |
63 | assert response == server_response[0]["key"]
64 | m.assert_called()
65 |
66 |
67 | def test_add_data_sample_already_exists(client, data_sample_query, mocker):
68 | m = mock_requests(mocker, "post", response=[{"key": "42"}], status=409)
69 | response = client.add_data_sample(data_sample_query)
70 |
71 | assert response == "42"
72 | m.assert_called()
73 |
74 |
75 | # We try to add multiple data samples instead of a single one
76 | def test_add_data_sample_with_paths(client, data_samples_query):
77 | with pytest.raises(ValueError):
78 | client.add_data_sample(data_samples_query)
79 |
80 |
81 | def test_add_data_samples(client, data_samples_query, mocker):
82 | server_response = [{"key": "42"}]
83 | m = mock_requests(mocker, "post", response=server_response)
84 | response = client.add_data_samples(data_samples_query)
85 |
86 | assert response == ["42"]
87 | m.assert_called()
88 |
89 |
90 | # We try to add a single data sample instead of multiple ones
91 | def test_add_data_samples_with_path(client, data_sample_query):
92 | with pytest.raises(ValueError):
93 | client.add_data_samples(data_sample_query)
94 |
95 |
96 | def test_add_compute_plan_wrong_key_format(client):
97 | with pytest.raises(ComputePlanKeyFormatError):
98 | data = {"key": "wrong_format", "name": "A perfectly valid name"}
99 | client.add_compute_plan(data)
100 |
--------------------------------------------------------------------------------
/tests/sdk/test_archive.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tarfile
4 | import tempfile
5 |
6 | import pytest
7 |
8 | from substra.sdk.archive import _untar
9 | from substra.sdk.archive import _unzip
10 | from substra.sdk.archive import tarsafe
11 | from substra.sdk.archive.safezip import ZipFile
12 |
13 |
14 | class TestsZipFile:
15 | # This zip file was specifically crafted and contains empty files named:
16 | # foo/bar
17 | # ../foo/bar
18 | # ../../foo/bar
19 | # ../../../foo/bar
20 | TRAVERSAL_ZIP = os.path.join(os.path.dirname(__file__), "data", "traversal.zip")
21 |
22 | # This zip file was specifically crafted and contains:
23 | # bar
24 | # foo -> bar (symlink)
25 | SYMLINK_ZIP = os.path.join(os.path.dirname(__file__), "data", "symlink.zip")
26 |
27 | def test_raise_on_path_traversal(self):
28 | zf = ZipFile(self.TRAVERSAL_ZIP, "r")
29 | with pytest.raises(Exception) as exc:
30 | zf.extractall(tempfile.gettempdir())
31 |
32 | assert "Attempted directory traversal" in str(exc.value)
33 |
34 | def test_raise_on_symlink(self):
35 | zf = ZipFile(self.SYMLINK_ZIP, "r")
36 | with pytest.raises(Exception) as exc:
37 | zf.extractall(tempfile.gettempdir())
38 |
39 | assert "Unsupported symlink" in str(exc.value)
40 |
41 | def test_compress_uncompress_zip(self):
42 | with tempfile.TemporaryDirectory() as tmpdirname:
43 | path_testdir = os.path.join(tmpdirname, "testdir")
44 | os.makedirs(path_testdir)
45 | with open(os.path.join(path_testdir, "testfile.txt"), "w") as f:
46 | f.write("testtext")
47 | shutil.make_archive(os.path.join(tmpdirname, "test_archive"), "zip", root_dir=os.path.dirname(path_testdir))
48 | _unzip(os.path.join(tmpdirname, "test_archive.zip"), os.path.join(tmpdirname, "test_archive"))
49 |
50 | assert os.listdir(tmpdirname + "/test_archive/testdir") == ["testfile.txt"]
51 |
52 |
53 | class TestsTarSafe:
54 | def test_compress_uncompress_tar(self):
55 | with tempfile.TemporaryDirectory() as tmpdirname:
56 | path_testdir = os.path.join(tmpdirname, "testdir")
57 | os.makedirs(path_testdir)
58 | with open(os.path.join(path_testdir, "testfile.txt"), "w") as f:
59 | f.write("testtext")
60 | path_tarfile = os.path.join(tmpdirname, "test_archive.tar")
61 | with tarsafe.open(path_tarfile, "w:gz") as tar:
62 | tar.add(path_testdir, arcname=os.path.basename(path_testdir))
63 |
64 | _untar(path_tarfile, os.path.join(tmpdirname, "test_archive"))
65 |
66 | assert os.listdir(tmpdirname + "/test_archive/testdir") == ["testfile.txt"]
67 |
68 | def test_raise_on_symlink(self):
69 | with tempfile.TemporaryDirectory() as tmpdir:
70 | # create the following tree structure:
71 | # ./Dockerfile
72 | # ./foo
73 | # ./Dockerfile -> ../Dockerfile
74 |
75 | filename = "Dockerfile"
76 | symlink_source = os.path.join(tmpdir, filename)
77 | with open(symlink_source, "w") as fp:
78 | fp.write("FROM bar")
79 |
80 | archive_root = os.path.join(tmpdir, "foo")
81 | os.mkdir(archive_root)
82 | os.symlink(symlink_source, os.path.join(archive_root, filename))
83 |
84 | # create a tar archive of the foo folder
85 | tarpath = os.path.join(tmpdir, "foo.tgz")
86 | with tarfile.open(tarpath, "w:gz") as tar:
87 | for root, _, files in os.walk(archive_root):
88 | for file in files:
89 | tar.add(os.path.join(root, file))
90 |
91 | with pytest.raises(tarsafe.TarSafeError) as error:
92 | with tarsafe.open(tarpath, "r") as tar:
93 | tar.extractall()
94 |
95 | assert "Unsupported symlink" in str(error.value)
96 |
--------------------------------------------------------------------------------
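Note: the safe archive wrappers exercised above refuse path traversal and symlink entries instead of silently extracting them. A minimal sketch, assuming a local `archive.zip` exists:

```python
from substra.sdk.archive.safezip import ZipFile

# extractall() raises on directory traversal or symlink entries rather
# than writing outside the target directory.
ZipFile("archive.zip", "r").extractall("/tmp/out")
```
--------------------------------------------------------------------------------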
/tests/sdk/test_cancel.py:
--------------------------------------------------------------------------------
1 | from tests import mocked_requests
2 |
3 |
4 | def test_cancel_compute_plan(client, mocker):
5 | m = mocked_requests.cancel_compute_plan(mocker)
6 |
7 | response = client.cancel_compute_plan("magic-key")
8 |
9 | assert response is None
10 | m.assert_called()
11 |
--------------------------------------------------------------------------------
/tests/sdk/test_client.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import yaml
3 |
4 | from substra.sdk import exceptions
5 | from substra.sdk import schemas
6 | from substra.sdk.client import Client
7 | from substra.sdk.client import _upper_slug
8 | from substra.sdk.exceptions import ConfigurationInfoError
9 |
10 |
11 | @pytest.fixture
12 | def config_file(tmp_path):
13 | config_dict = {
14 | "toto": {
15 | "backend_type": "remote",
16 | "url": "toto-org.com",
17 | "username": "toto_file_username",
18 | "password": "toto_file_password",
19 | }
20 | }
21 | config_file = tmp_path / "config.yaml"
22 | config_file.write_text(yaml.dump(config_dict))
23 | return config_file
24 |
25 |
26 | def stub_login(username, password):
27 | if username == "org-1" and password == "password1":
28 | return "token1"
29 | if username == "env_var_username" and password == "env_var_password":
30 | return "env_var_token"
31 | if username == "toto_file_username" and password == "toto_file_password":
32 | return "toto_file_token"
33 |
34 |
35 | @pytest.mark.parametrize(
36 | ["input", "expected"],
37 | [
38 | ("toto", "TOTO"),
39 | ("client-org-1", "CLIENT_ORG_1"),
40 | ("un nom très français", "UN_NOM_TRES_FRANCAIS"),
41 | ],
42 | )
43 | def test_upper_slug(input, expected):
44 | assert expected == _upper_slug(input)
45 |
46 |
47 | def test_default_client():
48 | client = Client()
49 | assert client.backend_mode == schemas.BackendType.LOCAL_SUBPROCESS
50 |
51 |
52 | @pytest.mark.parametrize(
53 | ["mode", "client_name", "url", "token", "insecure", "retry_timeout"],
54 | [
55 | (schemas.BackendType.LOCAL_SUBPROCESS, "foo", None, None, True, 5),
56 | (schemas.BackendType.LOCAL_DOCKER, "foobar", None, None, True, None),
57 | (schemas.BackendType.REMOTE, "bar", "example.com", "bloop", False, 15),
58 | (schemas.BackendType.LOCAL_SUBPROCESS, "hybrid", "example.com", "foo", True, 500),
59 | (schemas.BackendType.REMOTE, "foofoo", "https://example.com/api-token-auth/", None, True, 500),
60 | ],
61 | )
62 | def test_client_configured_in_code(mode, client_name, url, token, insecure, retry_timeout):
63 | client = Client(
64 | backend_type=mode,
65 | client_name=client_name,
66 | url=url,
67 | token=token,
68 | insecure=insecure,
69 | retry_timeout=retry_timeout,
70 | )
71 | assert client.backend_mode == mode
72 | assert client._insecure == insecure
73 | if retry_timeout is not None:
74 | assert client._retry_timeout == retry_timeout
75 | else:
76 | assert client._retry_timeout == 300
77 | if token is not None:
78 | assert client._token == token
79 | else:
80 | assert client._token is None
81 | if url is not None:
82 | assert client._url == url
83 | else:
84 | assert client._url is None
85 |
86 |
87 | def test_client_should_raise_when_missing_name():
88 | with pytest.raises(ConfigurationInfoError):
89 | Client(configuration_file="something")
90 |
91 |
92 | def test_client_with_password(mocker):
93 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
94 | rest_client_logout = mocker.patch("substra.sdk.backends.remote.rest_client.Client.logout")
95 | rest_client_logout.reset_mock()
96 | client_args = {
97 | "backend_type": "remote",
98 | "url": "example.com",
99 | "token": None,
100 | "username": "org-1",
101 | "password": "password1",
102 | }
103 |
104 | client = Client(**client_args)
105 | assert client._token == "token1"
106 | client.logout()
107 | assert client._token is None
108 | rest_client_logout.assert_called_once()
109 | del client
110 |
111 | rest_client_logout.reset_mock()
112 | client = Client(**client_args)
113 | del client
114 | rest_client_logout.assert_called_once()
115 |
116 | rest_client_logout.reset_mock()
117 | with Client(**client_args) as client:
118 | assert client._token == "token1"
119 | rest_client_logout.assert_called_once()
120 |
121 |
122 | def test_client_token_supercedes_password(mocker):
123 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
124 | client = Client(
125 | backend_type="remote",
126 | url="example.com",
127 | token="token0",
128 | username="org-1",
129 | password="password1",
130 | )
131 | assert client._token == "token0"
132 |
133 |
134 | def test_client_configuration_from_env_var(mocker, monkeypatch):
135 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
136 | monkeypatch.setenv("SUBSTRA_TOTO_BACKEND_TYPE", "remote")
137 | monkeypatch.setenv("SUBSTRA_TOTO_URL", "toto-org.com")
138 | monkeypatch.setenv("SUBSTRA_TOTO_USERNAME", "env_var_username")
139 | monkeypatch.setenv("SUBSTRA_TOTO_PASSWORD", "env_var_password")
140 | monkeypatch.setenv("SUBSTRA_TOTO_RETRY_TIMEOUT", "42")
141 | monkeypatch.setenv("SUBSTRA_TOTO_INSECURE", "true")
142 | client = Client(client_name="toto")
143 | assert client.backend_mode == "remote"
144 | assert client._url == "toto-org.com"
145 | assert client._token == "env_var_token"
146 | assert client._retry_timeout == 42
147 | assert client._insecure is True
148 |
149 |
150 | def test_client_configuration_from_config_file(mocker, config_file):
151 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
152 | client = Client(configuration_file=config_file, client_name="toto")
153 | assert client.backend_mode == "remote"
154 | assert client._url == "toto-org.com"
155 | assert client._token == "toto_file_token"
156 | assert client._retry_timeout == 300
157 | assert client._insecure is False
158 |
159 |
160 | def test_client_configuration_code_overrides_env_var(monkeypatch):
161 | """
162 | A variable set in the code overrides one set in an env variable
163 | """
164 | monkeypatch.setenv("SUBSTRA_TOTO_BACKEND_TYPE", "remote")
165 | monkeypatch.setenv("SUBSTRA_TOTO_URL", "toto-org.com")
166 | monkeypatch.setenv("SUBSTRA_TOTO_USERNAME", "env_var_username")
167 | monkeypatch.setenv("SUBSTRA_TOTO_PASSWORD", "env_var_password")
168 | monkeypatch.setenv("SUBSTRA_TOTO_RETRY_TIMEOUT", "42")
169 | monkeypatch.setenv("SUBSTRA_TOTO_INSECURE", "true")
170 | client = Client(
171 | client_name="toto",
172 | backend_type="subprocess",
173 | url="",
174 | )
175 | assert client.backend_mode == "subprocess"
176 | assert client._url == ""
177 | assert client._token is None
178 | assert client._retry_timeout == 42
179 | assert client._insecure is True
180 |
181 |
182 | def test_client_configuration_code_overrides_config_file(mocker, config_file):
183 | """
184 | A variable set in the code overrides one set in a config file
185 | """
186 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
187 | client = Client(
188 | client_name="toto",
189 | configuration_file=config_file,
190 | username="org-1",
191 | password="password1",
192 | retry_timeout=100,
193 | )
194 | assert client.backend_mode == "remote"
195 | assert client._url == "toto-org.com"
196 | assert client._token == "token1"
197 | assert client._retry_timeout == 100
198 | assert client._insecure is False
199 |
200 |
201 | def test_client_configuration_env_var_overrides_config_file(mocker, monkeypatch, config_file):
202 | """
203 | A variable set in an env var overrides one set in a config file
204 | """
205 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
206 | monkeypatch.setenv("SUBSTRA_TOTO_BACKEND_TYPE", "docker")
207 | monkeypatch.setenv("SUBSTRA_TOTO_USERNAME", "env_var_username")
208 | monkeypatch.setenv("SUBSTRA_TOTO_PASSWORD", "env_var_password")
209 | client = Client(configuration_file=config_file, client_name="toto", retry_timeout=12)
210 | assert client.backend_mode == "docker"
211 | assert client._url == "toto-org.com"
212 | assert client._token == "env_var_token"
213 | assert client._retry_timeout == 12
214 | assert client._insecure is False
215 |
216 |
217 | def test_login_remote_without_url(tmpdir):
218 | with pytest.raises(exceptions.SDKException):
219 | Client(backend_type="remote")
220 |
221 |
222 | def test_client_configuration_configuration_file_path_from_env_var(mocker, monkeypatch, config_file):
223 | """
224 | The configuration file path can be set through an env var
225 | """
226 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
227 | monkeypatch.setenv("SUBSTRA_CLIENTS_CONFIGURATION_FILE_PATH", config_file)
228 | client = Client(client_name="toto")
229 | assert client.backend_mode == "remote"
230 | assert client._url == "toto-org.com"
231 | assert client._token == "toto_file_token"
232 | assert client._retry_timeout == 300
233 | assert client._insecure is False
234 |
235 |
236 | def test_client_configuration_configuration_file_path_parameter_supercedes_env_var(
237 | mocker, monkeypatch, config_file, tmp_path
238 | ):
239 | """
240 | The configuration file path env var is superseded by `configuration_file=`
241 | """
242 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
243 | monkeypatch.setenv("SUBSTRA_CLIENTS_CONFIGURATION_FILE_PATH", config_file)
244 |
245 | config_2_dict = {
246 | "toto": {
247 | "backend_type": "docker",
248 | }
249 | }
250 | config_2_file = tmp_path / "config.yaml"
251 | config_2_file.write_text(yaml.dump(config_2_dict))
252 |
253 | client = Client(configuration_file=config_2_file, client_name="toto")
254 | assert client.backend_mode == "docker"
255 |
256 |
257 | def test_client_configuration_file_path_env_var_empty_string(mocker, monkeypatch, config_file):
258 | """
259 | An empty string passed as `configuration_file=` disables the configuration file set through the env var
260 | """
261 | mocker.patch("substra.sdk.Client.login", side_effect=stub_login)
262 | monkeypatch.setenv("SUBSTRA_CLIENTS_CONFIGURATION_FILE_PATH", config_file)
263 |
264 | client = Client(configuration_file="", client_name="toto")
265 | assert client.backend_mode == "subprocess"
266 |
--------------------------------------------------------------------------------
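Note: the tests above encode the configuration precedence rules: a value passed in code wins over the `SUBSTRA_<CLIENT_NAME>_*` environment variables (the client name being upper-slugged), which in turn win over the YAML configuration file. A minimal sketch with hypothetical values:

```python
import os

from substra.sdk.client import Client

os.environ["SUBSTRA_TOTO_RETRY_TIMEOUT"] = "42"

client = Client(client_name="toto", backend_type="subprocess", url="")
assert client._retry_timeout == 42          # comes from the env var
assert client.backend_mode == "subprocess"  # the code argument wins
```
--------------------------------------------------------------------------------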
/tests/sdk/test_describe.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import substra
4 |
5 | from .. import datastore
6 | from ..utils import mock_requests
7 | from ..utils import mock_requests_responses
8 | from ..utils import mock_response
9 |
10 |
11 | @pytest.mark.parametrize("asset_type", ["dataset", "function"])
12 | def test_describe_asset(asset_type, client, mocker):
13 | item = getattr(datastore, asset_type.upper())
14 | responses = [
15 | mock_response(item), # metadata
16 | mock_response("foo"), # data
17 | ]
18 | m = mock_requests_responses(mocker, "get", responses)
19 |
20 | method = getattr(client, f"describe_{asset_type}")
21 | response = method("magic-key")
22 |
23 | assert response == "foo"
24 | m.assert_called()
25 |
26 |
27 | @pytest.mark.parametrize("asset_type", ["dataset", "function"])
28 | def test_describe_asset_not_found(asset_type, client, mocker):
29 | m = mock_requests(mocker, "get", status=404)
30 |
31 | with pytest.raises(substra.sdk.exceptions.NotFound):
32 | method = getattr(client, f"describe_{asset_type}")
33 | method("foo")
34 |
35 | assert m.call_count == 1
36 |
37 |
38 | @pytest.mark.parametrize("asset_type", ["dataset", "function"])
39 | def test_describe_description_not_found(asset_type, client, mocker):
40 | item = getattr(datastore, asset_type.upper())
41 | responses = [
42 | mock_response(item), # metadata
43 | mock_response("foo", 404), # data
44 | ]
45 | m = mock_requests_responses(mocker, "get", responses)
46 |
47 | method = getattr(client, f"describe_{asset_type}")
48 |
49 | with pytest.raises(substra.sdk.exceptions.NotFound):
50 | method("key")
51 |
52 | assert m.call_count == 2
53 |
--------------------------------------------------------------------------------
/tests/sdk/test_download.py:
--------------------------------------------------------------------------------
1 | import os
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | import substra
7 | from substra.sdk import Client
8 |
9 | from .. import datastore
10 | from ..utils import mock_requests
11 | from ..utils import mock_requests_responses
12 | from ..utils import mock_response
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "asset_type",
17 | [
18 | ("dataset"),
19 | ("function"),
20 | ("model"),
21 | ],
22 | )
23 | def test_download_asset(asset_type, tmp_path, client, mocker):
24 | item = getattr(datastore, asset_type.upper())
25 | responses = [
26 | mock_response(item), # metadata
27 | mock_response("foo"), # data
28 | ]
29 | m = mock_requests_responses(mocker, "get", responses)
30 |
31 | method = getattr(client, f"download_{asset_type}")
32 | temp_file = method("foo", tmp_path)
33 |
34 | assert os.path.exists(temp_file)
35 | m.assert_called()
36 |
37 |
38 | @pytest.mark.parametrize("asset_type", ["dataset", "function", "model", "logs"])
39 | def test_download_asset_not_found(asset_type, tmp_path, client, mocker):
40 | m = mock_requests(mocker, "get", status=404)
41 |
42 | with pytest.raises(substra.sdk.exceptions.NotFound):
43 | method = getattr(client, f"download_{asset_type}")
44 | method("foo", tmp_path)
45 |
46 | assert m.call_count == 1
47 |
48 |
49 | @pytest.mark.parametrize("asset_type", ["dataset", "function", "model"])
50 | def test_download_content_not_found(asset_type, tmp_path, client, mocker):
51 | item = getattr(datastore, asset_type.upper())
52 |
53 | expected_call_count = 2
54 | responses = [
55 | mock_response(item), # metadata
56 | mock_response("foo", status=404), # description
57 | ]
58 |
59 | if asset_type == "model":
60 | responses = [responses[1]] # No metadata for model download
61 | expected_call_count = 1
62 |
63 | m = mock_requests_responses(mocker, "get", responses)
64 |
65 | method = getattr(client, f"download_{asset_type}")
66 |
67 | with pytest.raises(substra.sdk.exceptions.NotFound):
68 | method("key", tmp_path)
69 |
70 | assert m.call_count == expected_call_count
71 |
72 |
73 | @pytest.mark.parametrize(
74 | "asset_type, identifier",
75 | [
76 | ("TRAINTASK", "model"),
77 | ("AGGREGATETASK", "model"),
78 | ("COMPOSITE_TRAINTASK", "local"),
79 | ("COMPOSITE_TRAINTASK", "shared"),
80 | ],
81 | )
82 | @patch.object(Client, "download_model")
83 | def test_download_model_from_task(fake_download_model, tmp_path, client, asset_type, identifier, mocker):
84 | item = getattr(datastore, f"{asset_type}_{identifier.upper()}_RESPONSE")
85 | responses = [
86 | mock_response(item), # metadata
87 | mock_response("foo"), # data
88 | ]
89 |
90 | m = mock_requests_responses(mocker, "get", responses)
91 |
92 | client.download_model_from_task("key", identifier, tmp_path)
93 |
94 | m.assert_called()
95 | assert fake_download_model.call_count == 1
96 |
97 |
98 | def test_download_logs(tmp_path, client, mocker):
99 | logs = b"Lorem ipsum dolor sit amet"
100 | task_key = "key"
101 |
102 | response = mock_response(logs)
103 | response.iter_content.return_value = [logs]
104 |
105 | m = mock_requests_responses(mocker, "get", [response])
106 | client.download_logs(task_key, tmp_path)
107 |
108 | m.assert_called_once()
109 | response.iter_content.assert_called_once()
110 |
111 | assert (tmp_path / f"task_logs_{task_key}.txt").read_bytes() == logs
112 |
--------------------------------------------------------------------------------
/tests/sdk/test_get.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | import substra
5 | from substra.sdk import models
6 | from substra.sdk import schemas
7 |
8 | from .. import datastore
9 | from ..utils import mock_requests
10 | from ..utils import mock_requests_responses
11 | from ..utils import mock_response
12 |
13 |
14 | @pytest.mark.parametrize(
15 | "asset_type",
16 | [
17 | "model",
18 | "dataset",
19 | "data_sample",
20 | "function",
21 | "compute_plan",
22 | ],
23 | )
24 | def test_get_asset(asset_type, client, mocker):
25 | item = getattr(datastore, asset_type.upper())
26 | method = getattr(client, f"get_{asset_type}")
27 |
28 | m = mock_requests(mocker, "get", response=item)
29 |
30 | response = method("magic-key")
31 |
32 | assert response == models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item)
33 | m.assert_called()
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "asset_type",
38 | [
39 | "predicttask",
40 | "testtask",
41 | "traintask",
42 | "aggregatetask",
43 | "composite_traintask",
44 | ],
45 | )
46 | def test_get_task(asset_type, client, mocker):
47 | item = getattr(datastore, asset_type.upper())
48 |
49 | m = mock_requests(mocker, "get", response=item)
50 |
51 | response = client.get_task("magic-key")
52 |
53 | assert response == models.SCHEMA_TO_MODEL[schemas.Type.Task](**item)
54 | m.assert_called()
55 |
56 |
57 | def test_get_asset_not_found(client, mocker):
58 | mock_requests(mocker, "get", status=404)
59 |
60 | with pytest.raises(substra.sdk.exceptions.NotFound):
61 | client.get_dataset("magic-key")
62 |
63 |
64 | @pytest.mark.parametrize(
65 | "asset_type",
66 | [
67 | "dataset",
68 | "function",
69 | "compute_plan",
70 | "model",
71 | ],
72 | )
73 | def test_get_extra_field(asset_type, client, mocker):
74 | item = getattr(datastore, asset_type.upper())
75 | raw = getattr(datastore, asset_type.upper()).copy()
76 | raw["unknown_extra_field"] = "some value"
77 |
78 | method = getattr(client, f"get_{asset_type}")
79 |
80 | m = mock_requests(mocker, "get", response=raw)
81 |
82 | response = method("magic-key")
83 |
84 | assert response == models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item)
85 | m.assert_called()
86 |
87 |
88 | @pytest.mark.parametrize(
89 | "asset_type",
90 | [
91 | "predicttask",
92 | "testtask",
93 | "traintask",
94 | "aggregatetask",
95 | "composite_traintask",
96 | ],
97 | )
98 | def test_get_task_extra_field(asset_type, client, mocker):
99 | item = getattr(datastore, asset_type.upper())
100 | raw = getattr(datastore, asset_type.upper()).copy()
101 | raw["unknown_extra_field"] = "some value"
102 |
103 | m = mock_requests(mocker, "get", response=raw)
104 |
105 | response = client.get_task("magic-key")
106 |
107 | assert response == models.SCHEMA_TO_MODEL[schemas.Type.Task](**item)
108 | m.assert_called()
109 |
110 |
111 | def test_get_logs(client, mocker):
112 | logs = "Lorem ipsum dolor sit amet"
113 | task_key = "key"
114 |
115 | responses = [mock_response(logs)]
116 | m = mock_requests_responses(mocker, "get", responses)
117 | result = client.get_logs(task_key)
118 |
119 | m.assert_called_once()
120 | assert result == logs
121 |
122 |
123 | def test_get_performances(client, mocker):
124 | """Test the get_performances features, and test the immediate conversion to pandas DataFrame."""
125 | cp_item = datastore.COMPUTE_PLAN
126 | perf_item = datastore.COMPUTE_PLAN_PERF
127 |
128 | m = mock_requests_responses(mocker, "get", [mock_response(cp_item), mock_response(perf_item)])
129 |
130 | response = client.get_performances("magic-key")
131 | results = response.model_dump()
132 |
133 | df = pd.DataFrame(results)
134 | assert list(df.columns) == list(results.keys())
135 | assert all(len(v) == df.shape[0] for v in results.values())
136 | assert m.call_count == 2
137 |
--------------------------------------------------------------------------------
/tests/sdk/test_graph.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import pytest
4 |
5 | from substra.sdk import exceptions
6 | from substra.sdk import graph
7 |
8 |
9 | @pytest.fixture
10 | def node_graph():
11 | return {key: list(range(key)) for key in range(10)}
12 |
13 |
14 | @pytest.fixture
15 | def node_graph_linear():
16 | return {key: [key - 1] if key > 0 else list() for key in range(10)}
17 |
18 |
19 | @pytest.fixture
20 | def node_graph_straight_branch():
21 | """
22 | 0 1
23 | |-> 2 <-|
24 | | | |
25 | 3 <---> 4
26 | |
27 | 5
28 | |
29 | 6
30 | """
31 | return {
32 | "0": [],
33 | "1": [],
34 | "2": ["0", "1"],
35 | "3": ["0", "2"],
36 | "4": ["1", "2"],
37 | "5": ["3", "3"], # duplicated link
38 | "6": ["5"],
39 | }
40 |
41 |
42 | def test_compute_ranks(node_graph):
43 | visited = graph.compute_ranks(node_graph=node_graph)
44 | for key, rank in visited.items():
45 | assert key == rank
46 |
47 |
48 | def test_compute_ranks_linear(node_graph_linear):
49 | visited = graph.compute_ranks(node_graph=node_graph_linear)
50 | for key, rank in visited.items():
51 | assert key == rank
52 |
53 |
54 | def test_compute_ranks_no_correlation():
55 | node_graph = {key: list() for key in range(10)}
56 | visited = graph.compute_ranks(node_graph=node_graph)
57 | for _, rank in visited.items():
58 | assert rank == 0
59 |
60 |
61 | def test_compute_ranks_cycle(node_graph):
62 | node_graph[5].append(9)
63 | with pytest.raises(exceptions.InvalidRequest) as e:
64 | graph.compute_ranks(node_graph=node_graph)
65 |
66 | assert "missing dependency among inModels IDs" in str(e.value)
67 |
68 |
69 | def test_compute_ranks_closed_cycle(node_graph_linear):
70 | node_graph_linear[0] = [9]
71 | with pytest.raises(exceptions.InvalidRequest) as e:
72 | graph.compute_ranks(node_graph=node_graph_linear)
73 |
74 | assert "missing dependency among inModels IDs" in str(e.value)
75 |
76 |
77 | def test_compute_ranks_ignore(node_graph):
78 | node_to_ignore = set(range(5))
79 | for i in range(5):
80 | node_graph.pop(i)
81 | visited = graph.compute_ranks(node_graph=node_graph, node_to_ignore=node_to_ignore)
82 | for key, rank in visited.items():
83 | assert rank == key - 5
84 |
85 |
86 | @pytest.mark.parametrize("vertices", itertools.permutations([str(i) for i in range(6)]))
87 | def test_compute_ranks_alpha(vertices):
88 | """
89 | 0 1
90 | |-> 2 <-|
91 | | | |
92 | 3 <---> 4
93 | |
94 | 5
95 | """
96 | node_graph = {
97 | "0": [],
98 | "1": [],
99 | "2": ["0", "1"],
100 | "3": ["0", "2"],
101 | "4": ["1", "2"],
102 |         "5": ["3"],
103 | }
104 | ordered_node_graph = {v: node_graph[v] for v in vertices}
105 | visited = graph.compute_ranks(node_graph=ordered_node_graph)
106 | print(list(ordered_node_graph.keys()))
107 | assert visited == {
108 | "0": 0,
109 | "1": 0,
110 | "2": 1,
111 | "3": 2,
112 | "4": 2,
113 | "5": 3,
114 | }
115 |
--------------------------------------------------------------------------------
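The tests above pin down the contract of `graph.compute_ranks`: nodes without dependencies get rank 0, every other node gets one more than the highest rank among its dependencies, duplicated links are harmless, and a cycle (or a dangling edge) leaves no resolvable node and raises `exceptions.InvalidRequest` with "missing dependency among inModels IDs". The following is a minimal re-derivation of that contract for illustration only, not the SDK's actual implementation in substra/sdk/graph.py:

def compute_ranks_sketch(node_graph, node_to_ignore=None):
    node_to_ignore = node_to_ignore or set()
    ranks = {}
    # drop ignored nodes and the edges pointing at them
    remaining = {
        node: [dep for dep in deps if dep not in node_to_ignore]
        for node, deps in node_graph.items()
        if node not in node_to_ignore
    }
    while remaining:
        resolvable = [n for n, deps in remaining.items() if all(d in ranks for d in deps)]
        if not resolvable:
            # every remaining node waits on an unranked node: cycle or dangling edge
            raise ValueError("missing dependency among inModels IDs")
        for node in resolvable:
            deps = remaining.pop(node)
            ranks[node] = 1 + max((ranks[d] for d in deps), default=-1)
    return ranks
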
/tests/sdk/test_list.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from substra.sdk import exceptions
4 | from substra.sdk import models
5 | from substra.sdk import schemas
6 |
7 | from .. import datastore
8 | from ..utils import make_paginated_response
9 | from ..utils import mock_requests
10 |
11 |
12 | @pytest.mark.parametrize(
13 | "asset_type",
14 | [
15 | "dataset",
16 | "function",
17 | "compute_plan",
18 | "data_sample",
19 | "model",
20 | ],
21 | )
22 | def test_list_asset(asset_type, client, mocker):
23 | item = getattr(datastore, asset_type.upper())
24 | method = getattr(client, f"list_{asset_type}")
25 |
26 | mocked_response = make_paginated_response([item])
27 | m = mock_requests(mocker, "get", response=mocked_response)
28 |
29 | response = method()
30 |
31 | assert response == [models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item)]
32 | m.assert_called()
33 |
34 |
35 | @pytest.mark.parametrize(
36 | "asset_type",
37 | [
38 | "predicttask",
39 | "testtask",
40 | "traintask",
41 | "aggregatetask",
42 | "composite_traintask",
43 | ],
44 | )
45 | def test_list_task(asset_type, client, mocker):
46 | item = getattr(datastore, asset_type.upper())
47 |
48 | mocked_response = make_paginated_response([item])
49 | m = mock_requests(mocker, "get", response=mocked_response)
50 |
51 | response = client.list_task()
52 |
53 | assert response == [models.Task(**item)]
54 | m.assert_called()
55 |
56 |
57 | @pytest.mark.parametrize(
58 | "asset_type,filters",
59 | [
60 | ("dataset", {"permissions": ["foo", "bar"]}),
61 | ("function", {"owner": ["foo", "bar"]}),
62 | ("compute_plan", {"name": "foo"}),
63 | ("compute_plan", {"status": [models.ComputePlanStatus.done.value]}),
64 | ("model", {"owner": ["MyOrg1MSP"]}),
65 | ],
66 | )
67 | def test_list_asset_with_filters(asset_type, filters, client, mocker):
68 | item = getattr(datastore, asset_type.upper())
69 | method = getattr(client, f"list_{asset_type}")
70 |
71 | mocked_response = make_paginated_response([item])
72 | m = mock_requests(mocker, "get", response=mocked_response)
73 |
74 | response = method(filters)
75 |
76 | assert response == [models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**item)]
77 | m.assert_called()
78 |
79 |
80 | @pytest.mark.parametrize(
81 | "filters",
82 | [
83 | {"rank": [1, 3]},
84 | {"key": ["foo", "bar"]},
85 | {"worker": ["foo", "bar"]},
86 | {"owner": ["foo", "bar"]},
87 | ],
88 | )
89 | def test_list_task_with_filters(filters, client, mocker):
90 | items = datastore.TASK_LIST
91 |
92 | mocked_response = make_paginated_response(items)
93 | m = mock_requests(mocker, "get", response=mocked_response)
94 |
95 | response = client.list_task(filters)
96 |
97 | assert response == [models.Task(**item) for item in items]
98 | m.assert_called()
99 |
100 |
101 | def test_list_asset_with_filters_failure(client, mocker):
102 | items = [datastore.FUNCTION]
103 | m = mock_requests(mocker, "get", response=items)
104 |
105 | filters = {"foo"}
106 | with pytest.raises(exceptions.FilterFormatError) as exc_info:
107 | client.list_function(filters)
108 |
109 | m.assert_not_called()
110 | assert str(exc_info.value).startswith("Cannot load filters")
111 |
112 |
113 | def test_list_compute_plan_with_ordering(client, mocker):
114 | item = datastore.COMPUTE_PLAN
115 |
116 | mocked_response = make_paginated_response([item])
117 | m = mock_requests(mocker, "get", response=mocked_response)
118 |
119 | order_by = "start_date"
120 | response = client.list_compute_plan(order_by=order_by)
121 |
122 | assert response == [models.SCHEMA_TO_MODEL[schemas.Type.ComputePlan](**item)]
123 | m.assert_called()
124 |
125 |
126 | def test_list_task_with_ordering(client, mocker):
127 | items = datastore.TASK_LIST
128 |
129 | mocked_response = make_paginated_response(items)
130 | m = mock_requests(mocker, "get", response=mocked_response)
131 |
132 | order_by = "start_date"
133 | response = client.list_task(order_by=order_by)
134 |
135 | assert response == [models.Task(**item) for item in items]
136 | m.assert_called()
137 |
138 |
139 | def test_list_asset_with_ordering_failure(client, mocker):
140 | items = [datastore.COMPUTE_PLAN]
141 | m = mock_requests(mocker, "get", response=items)
142 |
143 | order_by = "foo"
144 | with pytest.raises(exceptions.OrderingFormatError) as exc_info:
145 | client.list_compute_plan(order_by=order_by)
146 |
147 | m.assert_not_called()
148 | assert str(exc_info.value).startswith("Please review the documentation")
149 |
--------------------------------------------------------------------------------
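For context, the filter and ordering shapes exercised above correspond to plain SDK calls. A hedged usage sketch, assuming the keyword names `filters` and `order_by` and a configured `client`, with illustrative values:

from substra.sdk import models

# Illustrative only: filter keys must be allowed for the asset type,
# otherwise the SDK rejects them client-side before any HTTP call.
done = models.ComputePlanStatus.done.value
client.list_compute_plan(filters={"status": [done], "name": "foo"}, order_by="start_date")
client.list_task(filters={"worker": ["MyOrg1MSP"], "rank": [1, 3]})
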
/tests/sdk/test_rest_client.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import pytest
4 | import requests
5 |
6 | from substra.sdk import exceptions
7 | from substra.sdk.backends.remote import rest_client
8 | from substra.sdk.client import Client
9 |
10 | from .. import datastore
11 | from ..utils import mock_requests
12 | from ..utils import mock_requests_responses
13 | from ..utils import mock_response
14 |
15 | CONFIG = {
16 | "url": "http://foo.com",
17 | "insecure": False,
18 | }
19 |
20 | CONFIG_SECURE = {
21 | "url": "http://foo.com",
22 | "insecure": False,
23 | }
24 |
25 | CONFIG_INSECURE = {
26 | "url": "http://foo.com",
27 | "insecure": True,
28 | }
29 |
30 | CONFIGS = [CONFIG, CONFIG_SECURE, CONFIG_INSECURE]
31 |
32 |
33 | def _client_from_config(config):
34 | return rest_client.Client(
35 | config["url"],
36 | config["insecure"],
37 | None,
38 | )
39 |
40 |
41 | @pytest.mark.parametrize("config", CONFIGS)
42 | def test_post_success(mocker, config):
43 | m = mock_requests(mocker, "post", response={})
44 | _client_from_config(config).add("http://foo", {})
45 | assert len(m.call_args_list) == 1
46 |
47 |
48 | @pytest.mark.parametrize("config", CONFIGS)
49 | def test_verify_login(mocker, config):
50 | """
51 | check "insecure" configuration results in endpoints being called with verify=False
52 | """
53 | m_post = mock_requests(mocker, "post", response={"id": "a", "token": "a", "expires_at": "3000-01-01T00:00:00Z"})
54 | m_delete = mock_requests(mocker, "delete", response={})
55 |
56 | c = _client_from_config(config)
57 | c.login("foo", "bar")
58 | c.logout()
59 | if config.get("insecure", None):
60 | assert m_post.call_args.kwargs["verify"] is False
61 | assert m_delete.call_args.kwargs["verify"] is False
62 | else:
63 | assert "verify" not in m_post.call_args.kwargs or m_post.call_args.kwargs["verify"]
64 |         assert "verify" not in m_delete.call_args.kwargs or m_delete.call_args.kwargs["verify"]
65 |
66 |
67 | @pytest.mark.parametrize(
68 | "status_code, http_response, sdk_exception",
69 | [
70 | (400, {"detail": "Invalid Request"}, exceptions.InvalidRequest),
71 | (401, {"detail": "Invalid username/password"}, exceptions.AuthenticationError),
72 | (403, {"detail": "Unauthorized"}, exceptions.AuthorizationError),
73 | (404, {"detail": "Not Found"}, exceptions.NotFound),
74 | (408, {"key": "a-key"}, exceptions.RequestTimeout),
75 | (408, {}, exceptions.RequestTimeout),
76 | (500, "CRASH", exceptions.InternalServerError),
77 | ],
78 | )
79 | def test_request_http_errors(mocker, status_code, http_response, sdk_exception):
80 | m = mock_requests(mocker, "post", response=http_response, status=status_code)
81 | with pytest.raises(sdk_exception):
82 | _client_from_config(CONFIG).add("http://foo", {})
83 | assert len(m.call_args_list) == 1
84 |
85 |
86 | def test_request_connection_error(mocker):
87 | mocker.patch(
88 | "substra.sdk.backends.remote.rest_client.requests.post", side_effect=requests.exceptions.ConnectionError
89 | )
90 | with pytest.raises(exceptions.ConnectionError):
91 | _client_from_config(CONFIG).add("foo", {})
92 |
93 |
94 | def test_add_timeout_with_retry(mocker):
95 | asset_type = "traintask"
96 | responses = [
97 | mock_response(response={"key": "a-key"}, status=408),
98 | mock_response(response={"key": "a-key"}),
99 | ]
100 | m_post = mock_requests_responses(mocker, "post", responses)
101 | asset = _client_from_config(CONFIG).add(asset_type, retry_timeout=60)
102 | assert len(m_post.call_args_list) == 2
103 | assert asset == {"key": "a-key"}
104 |
105 |
106 | def test_add_already_exist(mocker):
107 | asset_type = "traintask"
108 | m_post = mock_requests(mocker, "post", response={"key": "a-key"}, status=409)
109 | asset = _client_from_config(CONFIG).add(asset_type)
110 | assert len(m_post.call_args_list) == 1
111 | assert asset == {"key": "a-key"}
112 |
113 |
114 | def test_add_wrong_url(mocker):
115 | """Check correct error is raised when wrong url with correct syntax is set."""
116 | error = json.decoder.JSONDecodeError("", "", 0)
117 |
118 | mock_requests(mocker, "post", status=200, json_error=error)
119 | test_client = Client(url="http://www.dummy.com", token="foo")
120 | with pytest.raises(exceptions.BadConfiguration) as e:
121 | test_client.login("test_client", "hehe")
122 | assert "Make sure that given url" in e.value.args[0]
123 |
124 |
125 | def test_list_paginated(mocker):
126 | asset_type = "traintask"
127 | items = [datastore.TRAINTASK, datastore.TRAINTASK]
128 | responses = [
129 | mock_response(
130 | response={
131 | "count": len(items),
132 | "next": "http://foo.com/?page=2",
133 | "previous": None,
134 | "results": items[:1],
135 | },
136 | status=200,
137 | ),
138 | mock_response(
139 | response={
140 | "count": len(items),
141 | "next": None,
142 | "previous": "http://foo.com/?page=1",
143 | "results": items[1:],
144 | },
145 | status=200,
146 | ),
147 | ]
148 | m_get = mock_requests_responses(mocker, "get", responses)
149 | asset = _client_from_config(CONFIG).list(asset_type)
150 | assert len(asset) == len(items)
151 | assert len(m_get.call_args_list) == 2
152 |
153 |
154 | def test_list_not_paginated(mocker):
155 | asset_type = "traintask"
156 | items = [datastore.TRAINTASK, datastore.TRAINTASK]
157 | m_get = mock_requests(
158 | mocker,
159 | "get",
160 | response={
161 | "count": len(items),
162 | "next": "http://foo.com/?page=2",
163 | "previous": None,
164 | "results": items[:1],
165 | },
166 | status=200,
167 | )
168 | asset = _client_from_config(CONFIG).list(asset_type, paginated=False)
169 | assert len(asset) != len(items)
170 | assert len(m_get.call_args_list) == 1
171 |
--------------------------------------------------------------------------------
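test_list_paginated and test_list_not_paginated encode the wire contract: each page is a DRF-style envelope with count/next/previous/results, and in paginated mode the client follows next links until they run out. A minimal sketch of such a loop using requests directly (not the SDK's actual rest_client code):

import requests

def list_all_pages(url, **request_kwargs):
    """Accumulate `results` across pages by following `next` links."""
    results = []
    while url:
        page = requests.get(url, **request_kwargs)
        page.raise_for_status()
        body = page.json()
        results.extend(body["results"])
        url = body["next"]  # None on the last page ends the loop
    return results
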
/tests/sdk/test_schemas.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import uuid
3 |
4 | import pytest
5 |
6 | from substra.sdk.schemas import DataSampleSpec
7 | from substra.sdk.schemas import DatasetSpec
8 | from substra.sdk.schemas import Permissions
9 |
10 |
11 | @pytest.mark.parametrize("path", [pathlib.Path() / "data", "./data", pathlib.Path().cwd() / "data"])
12 | def test_datasample_spec_resolve_path(path):
13 | datasample_spec = DataSampleSpec(path=path, data_manager_keys=[str(uuid.uuid4())])
14 |
15 | assert datasample_spec.path == pathlib.Path().cwd() / "data"
16 |
17 |
18 | def test_datasample_spec_resolve_paths():
19 | paths = [pathlib.Path() / "data", "./data", pathlib.Path().cwd() / "data"]
20 | datasample_spec = DataSampleSpec(paths=paths, data_manager_keys=[str(uuid.uuid4())])
21 |
22 | assert all([path == pathlib.Path().cwd() / "data" for path in datasample_spec.paths])
23 |
24 |
25 | def test_datasample_spec_exclusive_path():
26 | with pytest.raises(ValueError):
27 | DataSampleSpec(paths=["fake_paths"], path="fake_paths", data_manager_keys=[str(uuid.uuid4())])
28 |
29 |
30 | def test_datasample_spec_no_path():
31 | with pytest.raises(ValueError):
32 | DataSampleSpec(data_manager_keys=[str(uuid.uuid4())])
33 |
34 |
35 | def test_datasample_spec_paths_set_to_none():
36 | with pytest.raises(ValueError):
37 | DataSampleSpec(paths=None, data_manager_keys=[str(uuid.uuid4())])
38 |
39 |
40 | def test_datasample_spec_path_set_to_none():
41 | with pytest.raises(ValueError):
42 | DataSampleSpec(path=None, data_manager_keys=[str(uuid.uuid4())])
43 |
44 |
45 | def test_dataset_spec_no_description(tmpdir):
46 |
47 | opener_path = tmpdir / "fake_opener.py"
48 | permissions = Permissions(public=True, authorized_ids=[])
49 |
50 | DatasetSpec(
51 | name="Fake Dataset",
52 | data_opener=str(opener_path),
53 | permissions=permissions,
54 | logs_permission=permissions,
55 | )
56 |
57 |     assert (pathlib.Path(opener_path).parent / "generated_description.md").exists()
58 |
--------------------------------------------------------------------------------
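The resolve-path tests pin one behaviour: whatever form the caller passes (relative string, relative Path, absolute Path), the spec stores an absolute path anchored at the current working directory. A minimal pydantic v2 sketch of such a validator, not the SDK's actual DataSampleSpec:

import pathlib

import pydantic


class PathSpecSketch(pydantic.BaseModel):
    path: pathlib.Path

    @pydantic.field_validator("path")
    @classmethod
    def _resolve_path(cls, value: pathlib.Path) -> pathlib.Path:
        # relative inputs are resolved against the current working directory
        return value.resolve()


assert PathSpecSketch(path="./data").path == pathlib.Path.cwd() / "data"
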
/tests/sdk/test_subprocess.py:
--------------------------------------------------------------------------------
1 | import string
2 |
3 | from substra.sdk.backends.local.compute.spawner.subprocess import _get_command_args
4 |
5 |
6 | def test_get_command_args_without_space():
7 | # check that it's not changing the path if there is no spaces
8 | function_name = "train"
9 | command = ["--opener-path", "${_VOLUME_OPENER}", "--compute-plan-path", "${_VOLUME_LOCAL}"]
10 | command_template = [string.Template(part) for part in command]
11 | local_volumes = {
12 | "_VOLUME_OPENER": "/a/path/without/any/space/opener.py",
13 | "_VOLUME_LOCAL": "/another/path/without/any/space",
14 | }
15 |
16 | py_commands = _get_command_args(function_name, command_template, local_volumes)
17 |
18 | valid_py_commands = [
19 | "--function-name",
20 | "train",
21 | "--opener-path",
22 | "/a/path/without/any/space/opener.py",
23 | "--compute-plan-path",
24 | "/another/path/without/any/space",
25 | ]
26 |
27 | assert py_commands == valid_py_commands
28 |
29 |
30 | def test_get_command_args_with_spaces():
31 | # check that it's not splitting path with spaces in different arguments
32 | function_name = "train"
33 | command = ["--opener-path", "${_VOLUME_OPENER}", "--compute-plan-path", "${_VOLUME_LOCAL}"]
34 | command_template = [string.Template(part) for part in command]
35 | local_volumes = {
36 | "_VOLUME_OPENER": "/a/path with spaces/opener.py",
37 | "_VOLUME_LOCAL": "/another/path with spaces",
38 | }
39 |
40 | py_commands = _get_command_args(function_name, command_template, local_volumes)
41 |
42 | valid_py_commands = [
43 | "--function-name",
44 | "train",
45 | "--opener-path",
46 | "/a/path with spaces/opener.py",
47 | "--compute-plan-path",
48 | "/another/path with spaces",
49 | ]
50 |
51 | assert py_commands == valid_py_commands
52 |
--------------------------------------------------------------------------------
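Both tests pin the same contract for `_get_command_args`: it prepends `--function-name <name>` and substitutes each `string.Template` as a single argv element, which is why paths containing spaces come through unsplit. A sketch of code meeting that contract (the actual implementation lives in substra/sdk/backends/local/compute/spawner/subprocess.py):

import string


def get_command_args_sketch(function_name, command_template, local_volumes):
    # substitute each template as one argv element so spaces never split an argument
    args = ["--function-name", function_name]
    args.extend(tpl.substitute(local_volumes) for tpl in command_template)
    return args
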
/tests/sdk/test_update.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from substra.sdk import models
4 | from substra.sdk import schemas
5 |
6 | from .. import datastore
7 | from ..utils import mock_requests
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "asset_type",
12 | ["dataset", "function", "compute_plan"],
13 | )
14 | def test_update_asset(asset_type, client, mocker):
15 | update_method = getattr(client, f"update_{asset_type}")
16 | get_method = getattr(client, f"get_{asset_type}")
17 |
18 | item = getattr(datastore, asset_type.upper())
19 | name_update = {"name": "New name"}
20 | updated_item = {**item, **name_update}
21 |
22 | m_put = mock_requests(mocker, "put")
23 | m_get = mock_requests(mocker, "get", response=updated_item)
24 |
25 | update_method("magic-key", name_update["name"])
26 | response = get_method("magic-key")
27 |
28 | assert response == models.SCHEMA_TO_MODEL[schemas.Type(asset_type)](**updated_item)
29 | m_put.assert_called()
30 | m_get.assert_called()
31 |
32 |
33 | def test_add_compute_plan_tasks(client, mocker):
34 | item = datastore.COMPUTE_PLAN
35 | m = mock_requests(mocker, "post", response=item)
36 | m_get = mock_requests(mocker, "get", response=datastore.COMPUTE_PLAN)
37 |
38 | response = client.add_compute_plan_tasks("foo", {})
39 |
40 | assert response == models.ComputePlan(**item)
41 |     m.assert_called()
42 |     m_get.assert_called()
43 |
44 |
45 | def test_add_compute_plan_tasks_with_schema(client, mocker):
46 | item = datastore.COMPUTE_PLAN
47 | m = mock_requests(mocker, "post", response=item)
48 | m_get = mock_requests(mocker, "get", response=datastore.COMPUTE_PLAN)
49 |
50 | response = client.add_compute_plan_tasks("foo", schemas.UpdateComputePlanTasksSpec(key="foo"))
51 |
52 | assert response == models.ComputePlan(**item)
53 |     m.assert_called()
54 |     m_get.assert_called()
55 |
--------------------------------------------------------------------------------
/tests/sdk/test_wait.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext as does_not_raise
2 |
3 | import pytest
4 |
5 | from substra.sdk import exceptions
6 | from substra.sdk.models import ComputePlanStatus
7 | from substra.sdk.models import ComputeTaskStatus
8 | from substra.sdk.models import TaskErrorType
9 |
10 | from .. import datastore
11 | from ..utils import mock_requests
12 |
13 |
14 | def _param_name_maker(arg):
15 | if isinstance(arg, str):
16 | return arg
17 | else:
18 | return ""
19 |
20 |
21 | @pytest.mark.parametrize(
22 | ("asset_dict", "function_name", "status", "expectation"),
23 | [
24 | (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.done, does_not_raise()),
25 | (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.canceled, pytest.raises(exceptions.FutureFailureError)),
26 | (datastore.COMPUTE_PLAN, "wait_compute_plan", ComputePlanStatus.done, does_not_raise()),
27 | (
28 | datastore.COMPUTE_PLAN,
29 | "wait_compute_plan",
30 | ComputePlanStatus.failed,
31 | pytest.raises(exceptions.FutureFailureError),
32 | ),
33 | (
34 | datastore.COMPUTE_PLAN,
35 | "wait_compute_plan",
36 | ComputePlanStatus.canceled,
37 | pytest.raises(exceptions.FutureFailureError),
38 | ),
39 | ],
40 | ids=_param_name_maker,
41 | )
42 | def test_wait(client, mocker, asset_dict, function_name, status, expectation):
43 | item = {**asset_dict, "status": status}
44 | mock_requests(mocker, "get", item)
45 | function = getattr(client, function_name)
46 | with expectation:
47 | function(key=item["key"])
48 |
49 |
50 | def test_wait_task_failed(client, mocker):
51 | # We need an error type to stop the iteration
52 | item = {**datastore.TRAINTASK, "status": ComputeTaskStatus.failed, "error_type": TaskErrorType.internal}
53 | mock_requests(mocker, "get", item)
54 | with pytest.raises(exceptions.FutureFailureError):
55 | client.wait_task(key=item["key"])
56 |
57 |
58 | @pytest.mark.parametrize(
59 | ("asset_dict", "function_name", "status"),
60 | [
61 | (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.waiting_for_parent_tasks),
62 | (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.waiting_for_builder_slot),
63 | (datastore.TRAINTASK, "wait_task", ComputeTaskStatus.waiting_for_executor_slot),
64 | (datastore.COMPUTE_PLAN, "wait_compute_plan", ComputePlanStatus.created),
65 | ],
66 | ids=_param_name_maker,
67 | )
68 | def test_wait_timeout(client, mocker, asset_dict, function_name, status):
69 | item = {**asset_dict, "status": status}
70 | mock_requests(mocker, "get", item)
71 | function = getattr(client, function_name)
72 | with pytest.raises(exceptions.FutureTimeoutError):
73 |         # mock_requests yields a single response and timeout=0 would be falsy, so use a tiny positive timeout
74 | function(key=item["key"], timeout=1e-10)
75 |
--------------------------------------------------------------------------------
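The wait tests describe a polling future: a terminal success returns, failed/canceled raises FutureFailureError (a failed task additionally needs an error_type to stop the iteration), and a non-terminal status past the deadline raises FutureTimeoutError. A generic sketch of that loop, with standard exceptions standing in for the SDK's:

import time


def wait_sketch(get_status, timeout, poll_interval=0.5):
    """Poll until the asset is terminal or the deadline passes."""
    deadline = time.monotonic() + timeout
    while True:
        status = get_status()
        if status == "done":
            return status
        if status in ("failed", "canceled"):
            # the SDK raises exceptions.FutureFailureError here
            raise RuntimeError(f"asset ended in status {status!r}")
        if time.monotonic() >= deadline:
            # the SDK raises exceptions.FutureTimeoutError here
            raise TimeoutError(f"asset still {status!r} after {timeout}s")
        time.sleep(poll_interval)
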
/tests/test_request_formatter.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from substra.sdk.backends.remote import request_formatter
4 |
5 |
6 | @pytest.mark.parametrize(
7 | "raw,formatted",
8 | [
9 | ({"foo": ["bar", "baz"]}, {"foo": "bar,baz"}),
10 | ({"foo": ["bar", "baz"], "bar": ["qux"]}, {"foo": "bar,baz", "bar": "qux"}),
11 | ({"foo": ["b ar ", " baz"]}, {"foo": "bar,baz"}),
12 | (
13 | {},
14 | {},
15 | ),
16 | ({"name": "bar,baz"}, {"match": "bar,baz"}),
17 | (
18 | {"metadata": [{"key": "epochs", "type": "is", "value": "10"}]},
19 | {"metadata": '[{"key":"epochs","type":"is","value":"10"}]'},
20 | ),
21 | (None, {}),
22 | ],
23 | )
24 | def test_format_search_filters_for_remote(raw, formatted):
25 | assert request_formatter.format_search_filters_for_remote(raw) == formatted
26 |
27 |
28 | @pytest.mark.parametrize(
29 | "ordering, ascending, formatted",
30 | [
31 | ("creation_date", False, "-creation_date"),
32 | ("start_date", True, "start_date"),
33 | ],
34 | )
35 | def test_format_search_ordering_for_remote(ordering, ascending, formatted):
36 | assert request_formatter.format_search_ordering_for_remote(ordering, ascending) == formatted
37 |
--------------------------------------------------------------------------------
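The parametrized table fully determines the observable behaviour of format_search_filters_for_remote: list values are comma-joined with whitespace stripped, the name filter is renamed to match, metadata filters are serialized as compact JSON, and a None input yields an empty dict. A re-derivation from those cases, for illustration only:

import json


def format_search_filters_sketch(raw):
    formatted = {}
    for key, value in (raw or {}).items():
        if key == "name":
            formatted["match"] = value  # name searches become `match`
        elif key == "metadata":
            formatted[key] = json.dumps(value, separators=(",", ":"))  # compact JSON
        elif isinstance(value, list):
            formatted[key] = ",".join(item.replace(" ", "") for item in value)
        else:
            formatted[key] = value
    return formatted
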
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 |
4 | import pytest
5 |
6 | import substra
7 | from substra.sdk import exceptions
8 | from substra.sdk import schemas
9 | from substra.sdk import utils
10 |
11 |
12 | def _unzip(fp, destination):
13 | with zipfile.ZipFile(fp, "r") as zipf:
14 | zipf.extractall(destination)
15 |
16 |
17 | def test_zip_folder(tmp_path):
18 | # initialise dir to zip
19 | dir_to_zip = tmp_path / "dir"
20 | dir_to_zip.mkdir()
21 |
22 | file_items = [
23 | ("name0.txt", "content0"),
24 | ("dir1/name1.txt", "content1"),
25 | ("dir2/name2.txt", "content2"),
26 | ]
27 |
28 | for name, content in file_items:
29 | path = dir_to_zip / name
30 | path.parents[0].mkdir(exist_ok=True)
31 | path.write_text(content)
32 |
33 | for name, _ in file_items:
34 | path = dir_to_zip / name
35 | assert os.path.exists(str(path))
36 |
37 | # zip dir
38 | fp = utils.zip_folder_in_memory(str(dir_to_zip))
39 | assert fp
40 |
41 | # unzip dir
42 | destination_dir = tmp_path / "destination"
43 | destination_dir.mkdir()
44 | _unzip(fp, str(destination_dir))
45 | for name, content in file_items:
46 | path = destination_dir / name
47 | assert os.path.exists(str(path))
48 | assert path.read_text() == content
49 |
50 |
51 | @pytest.mark.parametrize(
52 | "filters,expected,exception",
53 | [
54 | ("str", None, exceptions.FilterFormatError),
55 | ({}, None, exceptions.FilterFormatError),
56 | (
57 | [{"key": "foo", "type": "bar", "value": "baz"}],
58 | None,
59 | exceptions.FilterFormatError,
60 | ),
61 | ([{"key": "foo", "type": "is", "value": "baz"}, {}], None, exceptions.FilterFormatError),
62 | ([{"key": "foo", "type": "is", "value": "baz"}], None, None),
63 | ],
64 | )
65 | def test_check_metadata_search_filter(filters, expected, exception):
66 | if exception:
67 | with pytest.raises(exception):
68 | utils._check_metadata_search_filters(filters)
69 | else:
70 | assert utils._check_metadata_search_filters(filters) == expected
71 |
72 |
73 | @pytest.mark.parametrize(
74 | "asset_type,filters,expected,exception",
75 | [
76 | (
77 | schemas.Type.ComputePlan,
78 | {"status": [substra.models.ComputePlanStatus.doing.value]},
79 | {"status": [substra.models.ComputePlanStatus.doing.value]},
80 | None,
81 | ),
82 | (
83 | schemas.Type.Task,
84 | {"status": [substra.models.ComputeTaskStatus.done.value]},
85 | {"status": [substra.models.ComputeTaskStatus.done.value]},
86 | None,
87 | ),
88 | (schemas.Type.Task, {"rank": [1]}, {"rank": ["1"]}, None),
89 | (schemas.Type.DataSample, ["wrong filter type"], None, exceptions.FilterFormatError),
90 | (schemas.Type.ComputePlan, {"name": ["list"]}, None, exceptions.FilterFormatError),
91 | (schemas.Type.Task, {"foo": "not allowed key"}, None, exceptions.NotAllowedFilterError),
92 | (
93 | schemas.Type.ComputePlan,
94 | {"name": "cp1", "key": ["key1", "key2"]},
95 | {"name": "cp1", "key": ["key1", "key2"]},
96 | None,
97 | ),
98 | ],
99 | )
100 | def test_check_and_format_search_filters(asset_type, filters, expected, exception):
101 | if exception:
102 | with pytest.raises(exception):
103 | utils.check_and_format_search_filters(asset_type, filters)
104 | else:
105 | assert utils.check_and_format_search_filters(asset_type, filters) == expected
106 |
107 |
108 | @pytest.mark.parametrize(
109 | "ordering, exception",
110 | [
111 | ("creation_date", None),
112 | ("start_date", None),
113 | ("foo", exceptions.OrderingFormatError),
114 | (None, None),
115 | ],
116 | )
117 | def test_check_search_ordering(ordering, exception):
118 | if exception:
119 | with pytest.raises(exception):
120 | utils.check_search_ordering(ordering)
121 | else:
122 | utils.check_search_ordering(ordering)
123 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 |
3 | import requests
4 |
5 |
6 | def mock_response(response=None, status=200, headers=None, json_error=None):
7 | headers = headers or {}
8 | m = mock.MagicMock(spec=requests.Response)
9 | m.status_code = status
10 | m.headers = headers
11 | m.text = str(response)
12 |     m.json = mock.MagicMock(return_value=response, side_effect=json_error)
13 |
14 | if status not in (200, 201):
15 | exception = requests.exceptions.HTTPError(str(status), response=m)
16 | m.raise_for_status = mock.MagicMock(side_effect=exception)
17 |
18 | return m
19 |
20 |
21 | def mock_requests_responses(mocker, method, responses):
22 | return mocker.patch(
23 | f"substra.sdk.backends.remote.rest_client.requests.{method}",
24 | side_effect=responses,
25 | )
26 |
27 |
28 | def mock_requests(mocker, method, response=None, status=200, headers=None, json_error=None):
29 | r = mock_response(response, status, headers, json_error)
30 | return mock_requests_responses(mocker, method, (r,))
31 |
32 |
33 | def make_paginated_response(items):
34 | return {"count": len(items), "next": None, "previous": None, "results": items}
35 |
--------------------------------------------------------------------------------
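Taken together, these helpers keep the remote tests short: build one fake requests.Response, patch the matching requests function in rest_client, then assert on the mock. A representative, hypothetical test body, assuming the `client` fixture from conftest and the datastore module used throughout the suite:

def test_example(mocker, client):
    mocked = make_paginated_response([datastore.DATASET])
    m = mock_requests(mocker, "get", response=mocked)

    # any client call that issues a GET through rest_client now hits the mock
    client.list_dataset()

    m.assert_called_once()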