├── changes
│   └── .gitkeep
├── .github
│   ├── CODEOWNERS
│   ├── workflows
│   │   ├── publish.yml
│   │   ├── towncrier-changelog.yml
│   │   └── python.yml
│   └── ISSUE_TEMPLATE
│       ├── config.yml
│       └── bug_report.yaml
├── substratools
│   ├── __version__.py
│   ├── __init__.py
│   ├── exceptions.py
│   ├── workspace.py
│   ├── opener.py
│   ├── task_resources.py
│   ├── utils.py
│   └── function.py
├── tests
│   ├── __init__.py
│   ├── test_genericalgo.py
│   ├── test_utils.py
│   ├── utils.py
│   ├── conftest.py
│   ├── test_task_resources.py
│   ├── test_opener.py
│   ├── test_metrics.py
│   ├── test_workflow.py
│   ├── test_aggregatealgo.py
│   ├── test_compositealgo.py
│   └── test_function.py
├── .git-blame-ignore-revs
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Makefile
├── .coveragerc
├── .flake8
├── .pre-commit-config.yaml
├── CONTRIBUTORS.md
├── .gitignore
├── pyproject.toml
├── README.md
├── CHANGELOG.md
├── Substra-logo-white.svg
├── Substra-logo-colour.svg
└── LICENSE
/changes/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @Substra/code-owners 2 | -------------------------------------------------------------------------------- /substratools/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from tests import utils 2 | 3 | __all__ = ["utils"] 4 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | 25ad9f6b8f17422ac8a50cfc5f4604a26ae7f21a 2 | f587640673be0bb0789cc0036b7af12516a1ff9b 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | Substra repositories' code of conduct is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/code-of-conduct.html). 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Substra repositories' contributing guide is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/contributing-guide.html). 
2 | -------------------------------------------------------------------------------- /tests/test_genericalgo.py: -------------------------------------------------------------------------------- 1 | # TODO: As the implementation is going to change from a class to a function 2 | # decorator, those tests will be added when the implementation is stable 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test doc 2 | 3 | DOCS_FILEPATH = docs/api.md 4 | 5 | ifeq ($(TAG),) 6 | TAG := $(shell git describe --always --tags) 7 | endif 8 | 9 | test: 10 | 	pytest tests 11 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | 3 | [report] 4 | # Regexes for lines to exclude from consideration 5 | exclude_lines = 6 | 7 | # Don't complain if tests don't hit defensive assertion code: 8 | raise AssertionError 9 | raise NotImplementedError 10 | 11 | show_missing = True 12 | -------------------------------------------------------------------------------- /substratools/__init__.py: -------------------------------------------------------------------------------- 1 | from substratools.__version__ import __version__ 2 | 3 | from . import function 4 | from . import opener 5 | from .function import execute 6 | from .function import load_performance 7 | from .function import register 8 | from .function import save_performance 9 | from .opener import Opener 10 | 11 | __all__ = [ 12 | "__version__", 13 | "function", 14 | "opener", 15 | "Opener", 16 | "execute", 17 | "load_performance", 18 | "register", 19 | "save_performance", 20 | ] 21 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: 3.12 15 | - name: Install Hatch 16 | run: pipx install hatch 17 | - name: Build dist 18 | run: hatch build 19 | - name: Publish 20 | run: hatch publish -u __token__ -a ${{ secrets.PYPI_API_TOKEN }} 21 | 22 | -------------------------------------------------------------------------------- /substratools/exceptions.py: -------------------------------------------------------------------------------- 1 | class InvalidInterfaceError(Exception): 2 | pass 3 | 4 | 5 | class EmptyInterfaceError(InvalidInterfaceError): 6 | pass 7 | 8 | 9 | class NotAFileError(Exception): 10 | pass 11 | 12 | 13 | class MissingFileError(Exception): 14 | pass 15 | 16 | 17 | class InvalidInputOutputsError(Exception): 18 | pass 19 | 20 | 21 | class InvalidCLIError(Exception): 22 | pass 23 | 24 | 25 | class FunctionNotFoundError(Exception): 26 | pass 27 | 28 | 29 | class ExistingRegisteredFunctionError(Exception): 30 | pass 31 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | max-complexity = 10 4 | extend-ignore = E203, W503, N802, N803, N806 5 | # W503 is incompatible with Black, see https://github.com/psf/black/pull/36 6 | # E203 must be disabled 
for flake8 to work with Black. 7 | # See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1 8 | # N802, N803 and N806 prevent us from using upper case in variable names, function names and arguments. 9 | exclude = 10 | .git 11 | .github 12 | .dvc 13 | __pycache__ 14 | .venv 15 | .mypy_cache 16 | .pytest_cache 17 | hubconf.py 18 | **local-worker 19 | docs/* 20 | -------------------------------------------------------------------------------- /.github/workflows/towncrier-changelog.yml: -------------------------------------------------------------------------------- 1 | name: Towncrier changelog 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | app_version: 7 | type: string 8 | description: 'The version of the app' 9 | required: true 10 | branch: 11 | type: string 12 | description: 'The branch to update' 13 | required: true 14 | 15 | jobs: 16 | test-generate-publish: 17 | uses: substra/substra-gha-workflows/.github/workflows/towncrier-changelog.yml@main 18 | secrets: inherit 19 | with: 20 | app_version: ${{ inputs.app_version }} 21 | repo: substra-tools 22 | branch: ${{ inputs.branch }} 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Ask a question 4 | url: https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA 5 | about: Don't hesitate to join the Substra community on Slack to ask all your questions! 6 | - name: User Documentation Improvement 7 | url: https://github.com/Substra/substra-documentation/issues 8 | about: For issues related to the User Documentation, please open an issue on the substra-documentation repository 9 | - name: Feature request 10 | url: https://github.com/Substra/substra/issues 11 | about: We centralize feature requests in the substra repository, please open an issue there 12 | -------------------------------------------------------------------------------- /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | name: Python 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - main 8 | pull_request: 9 | branches: 10 | - master 11 | - main 12 | 13 | jobs: 14 | lint: 15 | name: Lint and tests 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.12 23 | - name: Install tools 24 | run: pip install flake8 25 | - name: Lint 26 | run: flake8 substratools 27 | - name: Install substra-tools 28 | run: pip install -e '.[dev]' 29 | - name: Test 30 | run: | 31 | make test 32 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: ^docs/ 2 | repos: 3 | - repo: https://github.com/psf/black 4 | rev: 22.3.0 5 | hooks: 6 | - id: black 7 | args: [--line-length=120] 8 | 9 | - repo: https://github.com/pycqa/flake8 10 | rev: 4.0.1 11 | hooks: 12 | - id: flake8 13 | additional_dependencies: [pep8-naming, flake8-bugbear] 14 | exclude: docs/api.md 15 | 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | args: [--line-length=120] 21 | name: isort (python) 22 | 23 | - repo: https://github.com/pre-commit/pre-commit-hooks 24 | rev: v3.2.0 25 | hooks: 
26 | - id: trailing-whitespace 27 | - id: end-of-file-fixer 28 | - id: debug-statements 29 | - id: check-added-large-files 30 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | This is a file of people that have made significant contributions to the Substra tools. It is sorted in chronological order. Please include your contribution at the bottom of this document in the following format: name (N), email (E), description of work (W) and date (D). 2 | 3 | To have your contribution listed, your work must meet the minimum [threshold of originality](https://en.wikipedia.org/wiki/Threshold_of_originality), which will be evaluated by the maintainers of the repository. 4 | 5 | Thank you for your contribution, your work is greatly appreciated! 6 | 7 | --- Example --- 8 | 9 | - N: John Doe 10 | - E: john.doe@owkin.com 11 | - W: Integrated new feature 12 | - D: 02/02/2023 13 | 14 | --- 15 | 16 | Copyright (c) 2018-present Owkin Inc. All rights reserved. 17 | 18 | All other contributions: 19 | Copyright (c) 2023 to the respective contributors. 20 | All rights reserved. 21 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | 5 | from substratools import exceptions 6 | from substratools.opener import Opener 7 | from substratools.utils import get_logger 8 | from substratools.utils import import_module 9 | from substratools.utils import load_interface_from_module 10 | 11 | 12 | def test_invalid_interface(): 13 | code = """ 14 | def score(): 15 | pass 16 | """ 17 | import_module("score", code) 18 | with pytest.raises(exceptions.InvalidInterfaceError): 19 | load_interface_from_module("score", interface_class=Opener) 20 | 21 | 22 | @pytest.fixture 23 | def syspaths(): 24 | copy = sys.path[:] 25 | yield sys.path 26 | sys.path = copy 27 | 28 | 29 | def test_empty_module(tmpdir, syspaths): 30 | with tmpdir.as_cwd(): 31 | # Python allows importing an empty directory 32 | # check that the error message would be helpful for debugging purposes 33 | tmpdir.mkdir("foomod") 34 | syspaths.append(str(tmpdir)) 35 | 36 | with pytest.raises(exceptions.EmptyInterfaceError): 37 | load_interface_from_module("foomod", interface_class=Opener) 38 | 39 | 40 | def test_get_logger(capfd): 41 | logger = get_logger("test") 42 | logger.info("message") 43 | captured = capfd.readouterr() 44 | assert "INFO substratools.test - message" in captured.err 45 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from enum import Enum 3 | from os import PathLike 4 | from typing import Any 5 | from typing import List 6 | 7 | from substratools.task_resources import StaticInputIdentifiers 8 | 9 | 10 | class InputIdentifiers(str, Enum): 11 | local = "local" 12 | shared = "shared" 13 | predictions = "predictions" 14 | opener = StaticInputIdentifiers.opener.value 15 | datasamples = StaticInputIdentifiers.datasamples.value 16 | rank = StaticInputIdentifiers.rank.value 17 | 18 | 19 | class OutputIdentifiers(str, Enum): 20 | local = "local" 21 | shared = "shared" 22 | predictions = "predictions" 23 | performance = "performance" 24 | 25 | 26 | def load_models(paths: List[PathLike]) -> List[dict]: 27 | models = 
[] 28 | for model_path in paths: 29 | with open(model_path, "r") as f: 30 | models.append(json.load(f)) 31 | 32 | return models 33 | 34 | 35 | def load_model(path: PathLike): 36 | if path: 37 | with open(path, "r") as f: 38 | return json.load(f) 39 | 40 | 41 | def save_model(model: dict, path: PathLike): 42 | with open(path, "w") as f: 43 | json.dump(model, f) 44 | 45 | 46 | def save_predictions(predictions: Any, path: PathLike): 47 | with open(path, "w") as f: 48 | json.dump(predictions, f) 49 | 50 | 51 | def load_predictions(path: PathLike) -> Any: 52 | with open(path, "r") as f: 53 | predictions = json.load(f) 54 | return predictions 55 | 56 | 57 | def no_save_model(path, model): 58 | # do not save model at all 59 | pass 60 | 61 | 62 | def wrong_save_model(model, path): 63 | # simulate numpy.save behavior 64 | with open(path + ".npy", "w") as f: 65 | json.dump(model, f) 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | docs/*.tmp 107 | 108 | .vscode 109 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.build.targets.sdist] 6 | exclude = ["tests*"] 7 | 8 | [tool.hatch.version] 9 | path = "substratools/__version__.py" 10 | 11 | [project] 12 | name = "substratools" 13 | description = "Python tools to submit functions on the Substra platform" 14 | dynamic = ["version"] 15 | readme = "README.md" 16 | requires-python = ">= 3.10" 17 | dependencies = [] 18 | keywords = ["substra"] 19 | classifiers = [ 20 | "Intended Audience :: Developers", 21 | "Topic :: 
Utilities", 22 | "Natural Language :: English", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | ] 28 | license = { file = "LICENSE" } 29 | authors = [{ name = "Owkin, Inc." }] 30 | 31 | [project.optional-dependencies] 32 | dev = ["flake8", "pytest", "pytest-cov", "pytest-mock", "numpy", "towncrier"] 33 | 34 | [project.urls] 35 | Documentation = "https://docs.substra.org/en/stable/" 36 | Repository = "https://github.com/Substra/substra-tools" 37 | Changelog = "https://github.com/Substra/substra-tools/blob/main/CHANGELOG.md" 38 | 39 | [tool.black] 40 | line-length = 120 41 | target-version = ['py39'] 42 | 43 | [tool.isort] 44 | filter_files = true 45 | force_single_line = true 46 | line_length = 120 47 | profile = "black" 48 | 49 | [tool.pytest.ini_options] 50 | addopts = "-v --cov=substratools" 51 | 52 | [tool.towncrier] 53 | directory = "changes" 54 | filename = "CHANGELOG.md" 55 | start_string = "\n" 56 | underlines = ["", "", ""] 57 | title_format = "## [{version}](https://github.com/Substra/substra-tools/releases/tag/{version}) - {project_date}" 58 | issue_format = "[#{issue}](https://github.com/Substra/substra-tools/pull/{issue})" 59 | [tool.towncrier.fragment.added] 60 | [tool.towncrier.fragment.removed] 61 | [tool.towncrier.fragment.changed] 62 | [tool.towncrier.fragment.fixed] 63 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from pathlib import Path 5 | from uuid import uuid4 6 | 7 | import pytest 8 | 9 | from substratools.task_resources import TaskResources 10 | from substratools.utils import import_module 11 | from substratools.workspace import FunctionWorkspace 12 | from tests.utils import OutputIdentifiers 13 | 14 | 15 | @pytest.fixture 16 | def workdir(tmp_path): 17 | d = tmp_path / "substra-workspace" 18 | d.mkdir() 19 | return d 20 | 21 | 22 | @pytest.fixture(autouse=True) 23 | def patch_cwd(monkeypatch, workdir): 24 | # this is needed to ensure the workspace is located in a tmpdir 25 | def getcwd(): 26 | return str(workdir) 27 | 28 | monkeypatch.setattr(os, "getcwd", getcwd) 29 | 30 | 31 | @pytest.fixture() 32 | def valid_opener_code(): 33 | return """ 34 | import json 35 | from substratools import Opener 36 | 37 | class FakeOpener(Opener): 38 | def get_data(self, folder): 39 | return 'X', list(range(0, 3)) 40 | 41 | def fake_data(self, n_samples): 42 | return ['Xfake'] * n_samples, [0] * n_samples 43 | """ 44 | 45 | 46 | @pytest.fixture() 47 | def valid_opener(valid_opener_code): 48 | import_module("opener", valid_opener_code) 49 | yield 50 | del sys.modules["opener"] 51 | 52 | 53 | @pytest.fixture() 54 | def valid_opener_script(workdir, valid_opener_code): 55 | opener_path = workdir / "my_opener.py" 56 | opener_path.write_text(valid_opener_code) 57 | 58 | return str(opener_path) 59 | 60 | 61 | @pytest.fixture(autouse=True) 62 | def output_model_path(workdir: Path) -> str: 63 | path = workdir / str(uuid4()) 64 | yield path 65 | if path.exists(): 66 | os.remove(path) 67 | 68 | 69 | @pytest.fixture(autouse=True) 70 | def output_model_path_2(workdir: Path) -> str: 71 | path = workdir / str(uuid4()) 72 | yield path 73 | if path.exists(): 74 | os.remove(path) 75 | 76 | 77 | @pytest.fixture() 78 | def valid_function_workspace(output_model_path: str) -> 
FunctionWorkspace: 79 | workspace_outputs = TaskResources( 80 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) 81 | ) 82 | 83 | workspace = FunctionWorkspace(outputs=workspace_outputs) 84 | 85 | return workspace 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Substra-tools 2 | 3 |
4 | ![Substra](Substra-logo-colour.svg) 18 | Substra is an open-source federated learning (FL) software. This repository, substra-tools, is a Python package that defines the base class for dataset openers (data opener scripts) and the wrappers used to execute functions submitted on the platform. 19 | 20 | 21 | ## Getting started 22 | 23 | To install the substratools Python package, run the following command: 24 | 25 | ```sh 26 | pip install substratools 27 | ``` 28 | 29 | ## Developers 30 | 31 | Clone the repository: `git clone https://github.com/Substra/substra-tools.git` 32 | 33 | 34 | ### Setup 35 | 36 | To set up the project in development mode, run: 37 | 38 | ```sh 39 | pip install -e ".[dev]" 40 | ``` 41 | 42 | To run all tests, use the following command: 43 | 44 | ```sh 45 | make test 46 | ``` 47 | 48 | ## How to generate the changelog 49 | 50 | The changelog is managed with [towncrier](https://towncrier.readthedocs.io/en/stable/index.html). 51 | To add a new entry in the changelog, add a file in the `changes` folder. The file name should have the following structure: 52 | `<unique_id>.<change_type>`. 53 | The `unique_id` is a unique identifier, we currently use the PR number. 54 | The `change_type` can be of the following types: `added`, `changed`, `removed`, `fixed`. For example, a fragment opened for a fix submitted in PR 123 would be named `123.fixed`. 55 | 56 | To generate the changelog (for example during a release), use the following command (you must have the dev dependencies installed): 57 | 58 | ``` 59 | towncrier build --version=<version> 60 | ``` 61 | 62 | You can use the `--draft` option to see what would be generated without actually writing to the changelog (and without removing the fragments). 63 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Report bug or performance issue 3 | title: "BUG: " 4 | labels: [Bug] 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | Thanks for taking the time to fill out this bug report! 11 | - type: textarea 12 | id: context 13 | attributes: 14 | label: What are you trying to do? 15 | description: > 16 | Please provide some context on what you are trying to achieve. 17 | placeholder: 18 | validations: 19 | required: true 20 | - type: textarea 21 | id: issue-description 22 | attributes: 23 | label: Issue Description (what is happening?) 24 | description: > 25 | Please provide a description of the issue. 26 | validations: 27 | required: true 28 | - type: textarea 29 | id: expected-behavior 30 | attributes: 31 | label: Expected Behavior (what should happen?) 32 | description: > 33 | Please describe or show a code example of the expected behavior. 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: example 38 | attributes: 39 | label: Reproducible Example 40 | description: > 41 | If possible, provide a reproducible example. 42 | render: python 43 | 44 | - type: textarea 45 | id: os-version 46 | attributes: 47 | label: Operating system 48 | description: > 49 | Which operating system are you using? (Provide the version number) 50 | validations: 51 | required: true 52 | - type: textarea 53 | id: python-version 54 | attributes: 55 | label: Python version 56 | description: > 57 | Which Python version are you using? 58 | placeholder: > 59 | python --version 60 | validations: 61 | required: true 62 | - type: textarea 63 | id: substra-version 64 | attributes: 65 | label: Installed Substra versions 66 | description: > 67 | Which version of `substrafl`/ `substra` / `substra-tools` are you using? 
68 | You can check if they are compatible in the [compatibility table](https://docs.substra.org/en/stable/additional/release.html#compatibility-table). 69 | placeholder: > 70 | pip freeze | grep substra 71 | render: python 72 | validations: 73 | required: true 74 | - type: textarea 75 | id: dependencies-version 76 | attributes: 77 | label: Installed versions of dependencies 78 | description: > 79 | Please provide versions of dependencies which might be relevant to your issue (eg. `helm` and `skaffold` version for a deployment issue, `numpy` and `pytorch` for an algorithmic issue). 80 | 81 | 82 | - type: textarea 83 | id: logs 84 | attributes: 85 | label: Logs / Stacktrace 86 | description: > 87 | Please copy-paste here any log and/or stacktrace that might be relevant. Remove confidential and personal information if necessary. 88 | -------------------------------------------------------------------------------- /substratools/workspace.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | 4 | 5 | def makedir_safe(path): 6 | """Create dir (no failure).""" 7 | try: 8 | os.makedirs(path) 9 | except (FileExistsError, PermissionError): 10 | pass 11 | 12 | 13 | DEFAULT_INPUT_DATA_FOLDER_PATH = "data/" 14 | DEFAULT_INPUT_PREDICTIONS_PATH = "pred/pred" 15 | DEFAULT_OUTPUT_PERF_PATH = "pred/perf.json" 16 | DEFAULT_LOG_PATH = "model/log_model.log" 17 | DEFAULT_CHAINKEYS_PATH = "chainkeys/" 18 | 19 | 20 | class Workspace(abc.ABC): 21 | """Filesystem workspace for task execution.""" 22 | 23 | def __init__(self, dirpath=None): 24 | self._workdir = dirpath if dirpath else os.getcwd() 25 | 26 | def _get_default_path(self, path): 27 | return os.path.join(self._workdir, path) 28 | 29 | def _get_default_subpaths(self, path): 30 | rootpath = os.path.join(self._workdir, path) 31 | if os.path.isdir(rootpath): 32 | return [ 33 | os.path.join(rootpath, subfolder) 34 | for subfolder in os.listdir(rootpath) 35 | if os.path.isdir(os.path.join(rootpath, subfolder)) 36 | ] 37 | return [] 38 | 39 | 40 | class OpenerWorkspace(Workspace): 41 | """Filesystem workspace required by the opener.""" 42 | 43 | def __init__( 44 | self, 45 | dirpath=None, 46 | input_data_folder_paths=None, 47 | ): 48 | super().__init__(dirpath=dirpath) 49 | 50 | assert input_data_folder_paths is None or isinstance(input_data_folder_paths, list) 51 | 52 | self.input_data_folder_paths = input_data_folder_paths or self._get_default_subpaths( 53 | DEFAULT_INPUT_DATA_FOLDER_PATH 54 | ) 55 | 56 | 57 | class FunctionWorkspace(Workspace): 58 | """Filesystem workspace for user defined function execution.""" 59 | 60 | def __init__( 61 | self, 62 | dirpath=None, 63 | log_path=None, 64 | chainkeys_path=None, 65 | inputs=None, 66 | outputs=None, 67 | ): 68 | 69 | super().__init__(dirpath=dirpath) 70 | 71 | self.input_data_folder_paths = ( 72 | self._get_default_subpaths(DEFAULT_INPUT_DATA_FOLDER_PATH) 73 | if inputs is None 74 | else inputs.input_data_folder_paths 75 | ) 76 | 77 | self.log_path = log_path or self._get_default_path(DEFAULT_LOG_PATH) 78 | self.chainkeys_path = chainkeys_path or self._get_default_path(DEFAULT_CHAINKEYS_PATH) 79 | 80 | self.opener_path = inputs.opener_path if inputs else None 81 | 82 | self.task_inputs = inputs.formatted_dynamic_resources if inputs else {} 83 | self.task_outputs = outputs.formatted_dynamic_resources if outputs else {} 84 | 85 | dirs = [ 86 | self.chainkeys_path, 87 | ] 88 | paths = [ 89 | self.log_path, 90 | ] 91 | 92 | dirs.extend([os.path.dirname(p) 
for p in paths]) 93 | for d in dirs: 94 | if d: 95 | makedir_safe(d) 96 | -------------------------------------------------------------------------------- /tests/test_task_resources.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | 5 | from substratools.exceptions import InvalidCLIError 6 | from substratools.exceptions import InvalidInputOutputsError 7 | from substratools.task_resources import StaticInputIdentifiers 8 | from substratools.task_resources import TaskResources 9 | 10 | _VALID_RESOURCES = [ 11 | {"id": "foo", "value": "bar", "multiple": True}, 12 | {"id": "foo", "value": "babar", "multiple": True}, 13 | {"id": "fofo", "value": "bar", "multiple": False}, 14 | ] 15 | _VALID_VALUES = {"foo": {"value": ["bar", "babar"], "multiple": True}, "fofo": {"value": ["bar"], "multiple": False}} 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "invalid_arg", 20 | ( 21 | {"foo": "barr"}, 22 | "foo and bar", 23 | ["foo", "barr"], 24 | [{"foo": "bar"}], 25 | [{"foo": "bar"}, {"id": "foo", "value": "bar", "multiple": True}], 26 | # [{_RESOURCE_ID: "foo", _RESOURCE_VALUE: "some path", _RESOURCE_MULTIPLE: "str"}], 27 | ), 28 | ) 29 | def test_task_resources_invalid_args(invalid_arg): 30 | with pytest.raises(InvalidCLIError): 31 | TaskResources(json.dumps(invalid_arg)) 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "valid_arg,expected", 36 | [ 37 | ([], {}), 38 | ([{"id": "foo", "value": "bar", "multiple": True}], {"foo": {"value": ["bar"], "multiple": True}}), 39 | ( 40 | [{"id": "foo", "value": "bar", "multiple": True}, {"id": "foo", "value": "babar", "multiple": True}], 41 | {"foo": {"value": ["bar", "babar"], "multiple": True}}, 42 | ), 43 | (_VALID_RESOURCES, _VALID_VALUES), 44 | ], 45 | ) 46 | def test_task_resources_values(valid_arg, expected): 47 | assert TaskResources(json.dumps(valid_arg))._values == expected 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "static_resource_id", 52 | ( 53 | StaticInputIdentifiers.chainkeys.value, 54 | StaticInputIdentifiers.datasamples.value, 55 | StaticInputIdentifiers.opener.value, 56 | ), 57 | ) 58 | def test_task_static_resources(static_resource_id): 59 | "checks that static keys opener, datasamples and chainkeys are excluded" 60 | 61 | assert ( 62 | TaskResources( 63 | json.dumps(_VALID_RESOURCES + [{"id": static_resource_id, "value": "foo", "multiple": False}]) 64 | )._values 65 | == _VALID_VALUES 66 | ) 67 | 68 | 69 | @pytest.mark.parametrize("key", tuple(_VALID_VALUES.keys())) 70 | def test_get_value(key): 71 | "get_value returns a list of paths for a multiple resource and a single path for non-multiple ones" 72 | expected = _VALID_VALUES[key]["value"] 73 | 74 | if not _VALID_VALUES[key]["multiple"]: 75 | expected = expected[0] 76 | 77 | assert TaskResources(json.dumps(_VALID_RESOURCES)).get_value(key) == expected 78 | 79 | 80 | def test_multiple_resource_error(): 81 | "a non-multiple resource can't have multiple values" 82 | 83 | with pytest.raises(InvalidInputOutputsError): 84 | TaskResources( 85 | json.dumps( 86 | [ 87 | {"id": "foo", "value": "bar", "multiple": False}, 88 | {"id": "foo", "value": "babar", "multiple": False}, 89 | ] 90 | ) 91 | ) 92 | -------------------------------------------------------------------------------- /tests/test_opener.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from substratools import exceptions 6 | from substratools.opener import Opener 7 | from substratools.opener import OpenerWrapper 8 | from substratools.opener import load_from_module 9 | from substratools.utils import import_module 
10 | from substratools.utils import load_interface_from_module 11 | from substratools.workspace import DEFAULT_INPUT_DATA_FOLDER_PATH 12 | 13 | 14 | @pytest.fixture 15 | def tmp_cwd(tmp_path): 16 | # create a temporary current working directory 17 | new_dir = tmp_path / "workspace" 18 | new_dir.mkdir() 19 | 20 | old_dir = os.getcwd() 21 | os.chdir(new_dir) 22 | 23 | yield new_dir 24 | 25 | os.chdir(old_dir) 26 | 27 | 28 | def test_load_opener_not_found(tmp_cwd): 29 | with pytest.raises(ImportError): 30 | load_from_module() 31 | 32 | 33 | def test_load_invalid_opener(tmp_cwd): 34 | invalid_script = """ 35 | def get_data(): 36 | raise NotImplementedError 37 | """ 38 | 39 | import_module("opener", invalid_script) 40 | 41 | with pytest.raises(exceptions.InvalidInterfaceError): 42 | load_from_module() 43 | 44 | 45 | def test_load_opener_as_class(tmp_cwd): 46 | script = """ 47 | from substratools import Opener 48 | class MyOpener(Opener): 49 | def get_data(self, folders): 50 | return 'data_class' 51 | def fake_data(self, n_samples): 52 | return 'fake_data' 53 | """ 54 | 55 | import_module("opener", script) 56 | 57 | o = load_from_module() 58 | assert o.get_data() == "data_class" 59 | 60 | 61 | def test_load_opener_from_path(tmp_cwd, valid_opener_code): 62 | dirpath = tmp_cwd / "myopener" 63 | dirpath.mkdir() 64 | path = dirpath / "my_opener.py" 65 | path.write_text(valid_opener_code) 66 | 67 | interface = load_interface_from_module( 68 | "opener", 69 | interface_class=Opener, 70 | interface_signature=None, # XXX does not support interface for debugging 71 | path=path, 72 | ) 73 | o = OpenerWrapper(interface, workspace=None) 74 | assert o.get_data()[0] == "X" 75 | 76 | 77 | def test_load_opener_from_path_error_with_inheritance(tmp_cwd): 78 | wrong_opener_code = """ 79 | import json 80 | from substratools import Opener 81 | 82 | class FakeOpener(Opener): 83 | def get_data(self, folder): 84 | return 'X', list(range(0, 3)) 85 | 86 | def fake_data(self, n_samples): 87 | return ['Xfake'] * n_samples, [0] * n_samples 88 | 89 | class FinalOpener(FakeOpener): 90 | def __init__(self): 91 | super().__init__() 92 | """ 93 | dirpath = tmp_cwd / "myopener" 94 | dirpath.mkdir() 95 | path = dirpath / "my_opener.py" 96 | path.write_text(wrong_opener_code) 97 | 98 | with pytest.raises(exceptions.InvalidInterfaceError): 99 | load_interface_from_module( 100 | "opener", 101 | interface_class=Opener, 102 | interface_signature=None, # XXX does not support interface for debugging 103 | path=path, 104 | ) 105 | 106 | 107 | def test_opener_check_folders(tmp_cwd): 108 | script = """ 109 | from substratools import Opener 110 | class MyOpener(Opener): 111 | def get_data(self, folders): 112 | assert len(folders) == 5 113 | return 'data_class' 114 | def fake_data(self, n_samples): 115 | return 'fake_data_class' 116 | """ 117 | 118 | import_module("opener", script) 119 | 120 | o = load_from_module() 121 | 122 | # create some data folders 123 | data_root_path = os.path.join(o._workspace._workdir, DEFAULT_INPUT_DATA_FOLDER_PATH) 124 | data_paths = [os.path.join(data_root_path, str(i)) for i in range(5)] 125 | [os.makedirs(p) for p in data_paths] 126 | 127 | o._workspace.input_data_folder_paths = data_paths 128 | assert o.get_data() == "data_class" 129 | -------------------------------------------------------------------------------- /substratools/opener.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | import os 4 | import types 5 | from typing import 
Optional 6 | 7 | from substratools import exceptions 8 | from substratools import utils 9 | from substratools.workspace import OpenerWorkspace 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | REQUIRED_FUNCTIONS = set( 15 | [ 16 | "get_data", 17 | "fake_data", 18 | ] 19 | ) 20 | 21 | 22 | class Opener(abc.ABC): 23 | """Dataset opener abstract base class. 24 | 25 | To define a new opener script, subclass this class and implement the 26 | following abstract methods: 27 | 28 | - #Opener.get_data() 29 | - #Opener.fake_data() 30 | 31 | # Example 32 | 33 | ```python 34 | import os 35 | import pandas as pd 36 | import string 37 | import numpy as np 38 | 39 | import substratools as tools 40 | 41 | class DummyOpener(tools.Opener): 42 | def get_data(self, folders): 43 | return [ 44 | pd.read_csv(os.path.join(folder, 'train.csv')) 45 | for folder in folders 46 | ] 47 | 48 | def fake_data(self, n_samples): 49 | return [] # compute random fake data 50 | ``` 51 | 52 | # How to test locally an opener script 53 | 54 | An opener can be imported and used in Python scripts like any other class. 55 | 56 | For example, assuming that you have a local file named `opener.py` that contains 57 | an `Opener` named `MyOpener`: 58 | 59 | ```python 60 | import os 61 | from opener import MyOpener 62 | 63 | folders = os.listdir('./sandbox/data_samples/') 64 | 65 | o = MyOpener() 66 | loaded_datasamples = o.get_data(folders) 67 | ``` 68 | """ 69 | 70 | @abc.abstractmethod 71 | def get_data(self, folders): 72 | """Datasamples loader 73 | 74 | # Arguments 75 | 76 | folders: list of folders. Each folder represents a data sample. 77 | 78 | # Returns 79 | 80 | data: data object. 81 | """ 82 | raise NotImplementedError 83 | 84 | @abc.abstractmethod 85 | def fake_data(self, n_samples): 86 | """Generate fake loaded datasamples for offline testing. 87 | 88 | # Arguments 89 | 90 | n_samples (int): number of samples to return 91 | 92 | # Returns 93 | 94 | data: data object. 95 | """ 96 | raise NotImplementedError 97 | 98 | 99 | class OpenerWrapper(object): 100 | """Internal wrapper to call opener interface.""" 101 | 102 | def __init__(self, interface, workspace=None): 103 | assert isinstance(interface, Opener) or isinstance(interface, types.ModuleType) 104 | 105 | self._workspace = workspace or OpenerWorkspace() 106 | self._interface = interface 107 | 108 | @property 109 | def data_folder_paths(self): 110 | return self._workspace.input_data_folder_paths 111 | 112 | def get_data(self, fake_data=False, n_fake_samples=None): 113 | if fake_data: 114 | logger.info("loading data from fake data") 115 | return self._interface.fake_data(n_samples=n_fake_samples) 116 | else: 117 | logger.info("loading data from '{}'".format(self.data_folder_paths)) 118 | return self._interface.get_data(self.data_folder_paths) 119 | 120 | def _assert_output_exists(self, path, key): 121 | 122 | if os.path.isdir(path): 123 | raise exceptions.NotAFileError(f"Expected output file at {path}, found dir for output `{key}`") 124 | if not os.path.isfile(path): 125 | raise exceptions.MissingFileError(f"Output file {path} used to save argument `{key}` does not exist.") 126 | 127 | 128 | def load_from_module(workspace=None) -> Optional[OpenerWrapper]: 129 | """Load opener interface. 130 | 131 | If a workspace is given, the associated opener will be returned. This means that if no 132 | opener_path is defined within the workspace, no opener will be returned. 133 | If no workspace is given, the opener interface will be directly loaded as a module. 
134 | 135 | Return an OpenerWrapper instance. 136 | """ 137 | if workspace is None: 138 | # import from module 139 | path = None 140 | 141 | elif workspace.opener_path is None: 142 | # no opener within this workspace 143 | return None 144 | 145 | else: 146 | # import opener from workspace specified path 147 | path = workspace.opener_path 148 | 149 | interface = utils.load_interface_from_module( 150 | "opener", 151 | interface_class=Opener, 152 | interface_signature=None, # XXX does not support interface for debugging 153 | path=path, 154 | ) 155 | return OpenerWrapper(interface, workspace=workspace) 156 | -------------------------------------------------------------------------------- /substratools/task_resources.py: -------------------------------------------------------------------------------- 1 | import json 2 | from enum import Enum 3 | from typing import Dict 4 | from typing import List 5 | from typing import Optional 6 | from typing import Union 7 | 8 | from substratools import exceptions 9 | 10 | 11 | class StaticInputIdentifiers(str, Enum): 12 | opener = "opener" 13 | datasamples = "datasamples" 14 | chainkeys = "chainkeys" 15 | rank = "rank" 16 | 17 | 18 | _RESOURCE_ID = "id" 19 | _RESOURCE_VALUE = "value" 20 | _RESOURCE_MULTIPLE = "multiple" 21 | 22 | 23 | def _check_resources_format(resource_list): 24 | 25 | _required_keys = set((_RESOURCE_ID, _RESOURCE_VALUE, _RESOURCE_MULTIPLE)) 26 | _error_message = ( 27 | "`--inputs` and `--outputs` args should be json serialized list of dict. Each dict containing " 28 | f"the following keys: {_required_keys}. {_RESOURCE_ID} and {_RESOURCE_VALUE} must be strings, " 29 | f"{_RESOURCE_MULTIPLE} must be a bool." 30 | ) 31 | 32 | if not isinstance(resource_list, list): 33 | raise exceptions.InvalidCLIError(_error_message) 34 | 35 | if not all([isinstance(d, dict) for d in resource_list]): 36 | raise exceptions.InvalidCLIError(_error_message) 37 | 38 | if not all([set(d.keys()) == _required_keys for d in resource_list]): 39 | raise exceptions.InvalidCLIError(_error_message) 40 | 41 | if not all([isinstance(d[_RESOURCE_MULTIPLE], bool) for d in resource_list]): 42 | raise exceptions.InvalidCLIError(_error_message) 43 | 44 | if not all([isinstance(d[_RESOURCE_ID], str) for d in resource_list]): 45 | raise exceptions.InvalidCLIError(_error_message) 46 | 47 | if not all([isinstance(d[_RESOURCE_VALUE], str) for d in resource_list]): 48 | raise exceptions.InvalidCLIError(_error_message) 49 | 50 | 51 | def _check_resources_multiplicity(resource_dict): 52 | for k, v in resource_dict.items(): 53 | if not v[_RESOURCE_MULTIPLE] and len(v[_RESOURCE_VALUE]) > 1: 54 | raise exceptions.InvalidInputOutputsError(f"There is more than one path for the non multiple resource {k}") 55 | 56 | 57 | class TaskResources: 58 | """TaskResources is created from the CLI arguments to provide a nice abstraction over inputs/outputs""" 59 | 60 | _values: Dict[str, dict] 61 | 62 | def __init__(self, argstr: str) -> None: 63 | """Argstr is expected to be a JSON array like: 64 | [ 65 | {"id": "local", "value": "/sandbox/output/model/uuid", "multiple": False}, 66 | {"id": "shared", ...} 67 | ] 68 | """ 69 | self._values = {} 70 | resource_list = json.loads(argstr.replace("\\", "/")) 71 | 72 | _check_resources_format(resource_list) 73 | 74 | for item in resource_list: 75 | self._values.setdefault( 76 | item[_RESOURCE_ID], {_RESOURCE_VALUE: [], _RESOURCE_MULTIPLE: item[_RESOURCE_MULTIPLE]} 77 | ) 78 | self._values[item[_RESOURCE_ID]][_RESOURCE_VALUE].append(item[_RESOURCE_VALUE]) 79 | 80 | 
_check_resources_multiplicity(self._values) 81 | 82 | self.opener_path = self.get_value(StaticInputIdentifiers.opener.value) 83 | self.input_data_folder_paths = self.get_value(StaticInputIdentifiers.datasamples.value) 84 | self.chainkeys_path = self.get_value(StaticInputIdentifiers.chainkeys.value) 85 | 86 | def get_value(self, key: str) -> Optional[Union[List[str], str]]: 87 | """Returns the value for a given key. Return None if there is no matching resource. 88 | Will raise if there is a mismatch between the given multiplicity and the number of returned 89 | elements. 90 | 91 | If multiple is True, will return a list else will return a single value 92 | """ 93 | if key not in self._values: 94 | return None 95 | 96 | val = self._values[key][_RESOURCE_VALUE] 97 | multiple = self._values[key][_RESOURCE_MULTIPLE] 98 | 99 | if multiple: 100 | return val 101 | 102 | return val[0] 103 | 104 | @property 105 | def formatted_dynamic_resources(self) -> Dict[str, Union[List[str], str]]: 106 | """Returns all the resources (except the datasamples, the opener and the chainkeys) under the user format: 107 | a dict where each input is an element where 108 | - the key is the user identifier 109 | - the value is a list of paths for multiple resources and a single path for non-multiple resources 110 | """ 111 | 112 | return { 113 | k: self.get_value(k) 114 | for k in self._values.keys() 115 | if k 116 | not in ( 117 | StaticInputIdentifiers.opener.value, 118 | StaticInputIdentifiers.datasamples.value, 119 | StaticInputIdentifiers.chainkeys.value, 120 | ) 121 | } 122 | -------------------------------------------------------------------------------- /substratools/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import importlib.util 3 | import inspect 4 | import logging 5 | import os 6 | import sys 7 | import time 8 | 9 | from substratools import exceptions 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | MAPPING_LOG_LEVEL = { 14 | "debug": logging.DEBUG, 15 | "info": logging.INFO, 16 | "warning": logging.WARNING, 17 | "error": logging.ERROR, 18 | "critical": logging.CRITICAL, 19 | } 20 | 21 | 22 | def configure_logging(path=None, log_level="info"): 23 | level = MAPPING_LOG_LEVEL[log_level] 24 | 25 | formatter = logging.Formatter(fmt="%(asctime)s %(levelname)-6s %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S") 26 | 27 | h = logging.StreamHandler() 28 | h.setLevel(level) 29 | h.setFormatter(formatter) 30 | 31 | root = logging.getLogger("substratools") 32 | root.setLevel(level) 33 | root.addHandler(h) 34 | 35 | if path and level == logging.DEBUG: 36 | fh = logging.FileHandler(path) 37 | fh.setLevel(level) 38 | fh.setFormatter(formatter) 39 | 40 | root.addHandler(fh) 41 | 42 | 43 | def get_logger(name, path=None, log_level="info"): 44 | new_logger = logging.getLogger(f"substratools.{name}") 45 | configure_logging(path, log_level) 46 | return new_logger 47 | 48 | 49 | class Timer(object): 50 | """This decorator prints the execution time for the decorated function.""" 51 | 52 | def __init__(self, module_logger): 53 | self.module_logger = module_logger 54 | 55 | def __call__(self, func): 56 | def wrapper(*args, **kwargs): 57 | start = time.time() 58 | result = func(*args, **kwargs) 59 | end = time.time() 60 | self.module_logger.info("{} ran in {}s".format(func.__qualname__, round(end - start, 2))) 61 | return result 62 | 63 | return wrapper 64 | 65 | 66 | def import_module(module_name, code): 67 | if module_name in sys.modules: 68 | 
logging.warning("Module {} will be overwritten".format(module_name)) 69 | spec = importlib.util.spec_from_loader(module_name, loader=None, origin=module_name) 70 | module = importlib.util.module_from_spec(spec) 71 | sys.modules[module_name] = module 72 | exec(code, module.__dict__) 73 | 74 | 75 | def import_module_from_path(path, module_name): 76 | assert os.path.exists(path), "path '{}' not found".format(path) 77 | spec = importlib.util.spec_from_file_location(module_name, path) 78 | assert spec, "could not load spec from path '{}'".format(path) 79 | module = importlib.util.module_from_spec(spec) 80 | spec.loader.exec_module(module) 81 | return module 82 | 83 | 84 | # TODO: 'load_interface_from_module' is too complex, consider refactoring 85 | def load_interface_from_module(module_name, interface_class, interface_signature=None, path=None): # noqa: C901 86 | if path: 87 | module = import_module_from_path(path, module_name) 88 | logger.info(f"Module '{module_name}' loaded from path '{path}'") 89 | else: 90 | try: 91 | module = importlib.import_module(module_name) 92 | logger.info(f"Module '{module_name}' imported dynamically; module={module}") 93 | except ImportError: 94 | # XXX don't use ModuleNotFoundError for python3.5 compatibility 95 | raise 96 | 97 | # check if module empty 98 | if not inspect.getmembers(module, lambda m: inspect.isclass(m) or inspect.isfunction(m)): 99 | raise exceptions.EmptyInterfaceError( 100 | f"Module '{module_name}' seems empty: no method/class found in members: '{dir(module)}'" 101 | ) 102 | 103 | # find interface class 104 | found_interfaces = [] 105 | for _, obj in inspect.getmembers(module, inspect.isclass): 106 | if issubclass(obj, interface_class) and obj != interface_class: 107 | found_interfaces.append(obj) 108 | 109 | if len(found_interfaces) == 1: 110 | return found_interfaces[0]() # return interface instance 111 | elif len(found_interfaces) > 1: 112 | raise exceptions.InvalidInterfaceError( 113 | f"Multiple interfaces found in module '{module_name}': {found_interfaces}" 114 | ) 115 | 116 | # backward compatibility; accept methods at module level directly 117 | if interface_signature is None: 118 | class_name = interface_class.__name__ 119 | elements = str(dir(module)) 120 | logger.info(f"Class '{class_name}' not found from: '{elements}'") 121 | raise exceptions.InvalidInterfaceError("Expecting {} subclass in {}".format(class_name, module_name)) 122 | 123 | missing_functions = interface_signature.copy() 124 | for name, obj in inspect.getmembers(module): 125 | if not inspect.isfunction(obj): 126 | continue 127 | try: 128 | missing_functions.remove(name) 129 | except KeyError: 130 | pass 131 | 132 | if missing_functions: 133 | message = "Method(s) {} not implemented".format(", ".join(["'{}'".format(m) for m in missing_functions])) 134 | raise exceptions.InvalidInterfaceError(message) 135 | return module 136 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | from os import PathLike 4 | from typing import Any 5 | from typing import TypedDict 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | from substratools import function 11 | from substratools import load_performance 12 | from substratools import opener 13 | from substratools import save_performance 14 | from substratools.task_resources import TaskResources 15 | from substratools.workspace import FunctionWorkspace 16 | from 
tests import utils 17 | from tests.utils import InputIdentifiers 18 | from tests.utils import OutputIdentifiers 19 | 20 | 21 | @pytest.fixture() 22 | def write_pred_file(workdir): 23 | pred_file = str(workdir / str(uuid.uuid4())) 24 | data = list(range(3, 6)) 25 | with open(pred_file, "w") as f: 26 | json.dump(data, f) 27 | return pred_file, data 28 | 29 | 30 | @pytest.fixture 31 | def inputs(workdir, valid_opener_script, write_pred_file): 32 | return [ 33 | {"id": InputIdentifiers.predictions, "value": str(write_pred_file[0]), "multiple": False}, 34 | {"id": InputIdentifiers.datasamples, "value": str(workdir / "datasamples_unused"), "multiple": True}, 35 | {"id": InputIdentifiers.opener, "value": str(valid_opener_script), "multiple": False}, 36 | ] 37 | 38 | 39 | @pytest.fixture 40 | def outputs(workdir): 41 | return [{"id": OutputIdentifiers.performance, "value": str(workdir / str(uuid.uuid4())), "multiple": False}] 42 | 43 | 44 | @pytest.fixture(autouse=True) 45 | def setup(valid_opener, write_pred_file): 46 | pass 47 | 48 | 49 | @function.register 50 | def score( 51 | inputs: TypedDict("inputs", {InputIdentifiers.datasamples: Any, InputIdentifiers.predictions: Any}), 52 | outputs: TypedDict("outputs", {OutputIdentifiers.performance: PathLike}), 53 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), 54 | ): 55 | y_true = inputs.get(InputIdentifiers.datasamples)[1] 56 | y_pred_path = inputs.get(InputIdentifiers.predictions) 57 | y_pred = utils.load_predictions(y_pred_path) 58 | 59 | score = sum(y_true) + sum(y_pred) 60 | 61 | save_performance(performance=score, path=outputs.get(OutputIdentifiers.performance)) 62 | 63 | 64 | def test_score(workdir, write_pred_file): 65 | inputs = TaskResources( 66 | json.dumps( 67 | [ 68 | {"id": InputIdentifiers.predictions, "value": str(write_pred_file[0]), "multiple": False}, 69 | ] 70 | ) 71 | ) 72 | outputs = TaskResources( 73 | json.dumps( 74 | [{"id": OutputIdentifiers.performance, "value": str(workdir / str(uuid.uuid4())), "multiple": False}] 75 | ) 76 | ) 77 | workspace = FunctionWorkspace(inputs=inputs, outputs=outputs) 78 | wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=opener.load_from_module()) 79 | wp.execute(function=score) 80 | s = load_performance(wp._workspace.task_outputs[OutputIdentifiers.performance]) 81 | assert s == 15 82 | 83 | 84 | def test_execute(inputs, outputs): 85 | perf_path = outputs[0]["value"] 86 | function.execute( 87 | sysargs=["--function-name", "score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], 88 | ) 89 | s = load_performance(perf_path) 90 | assert s == 15 91 | 92 | 93 | @pytest.mark.parametrize( 94 | "fake_data_mode,expected_score", 95 | [ 96 | ([], 15), 97 | (["--fake-data", "--n-fake-samples", "3"], 12), 98 | ], 99 | ) 100 | def test_execute_fake_data_modes(fake_data_mode, expected_score, inputs, outputs): 101 | perf_path = outputs[0]["value"] 102 | function.execute( 103 | sysargs=fake_data_mode 104 | + ["--function-name", "score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], 105 | ) 106 | s = load_performance(perf_path) 107 | assert s == expected_score 108 | 109 | 110 | def test_execute_np(inputs, outputs): 111 | @function.register 112 | def float_np_score( 113 | inputs, 114 | outputs, 115 | task_properties: dict, 116 | ): 117 | save_performance(np.float64(0.99), outputs.get(OutputIdentifiers.performance)) 118 | 119 | perf_path = outputs[0]["value"] 120 | function.execute( 121 | sysargs=["--function-name", 
"float_np_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], 122 | ) 123 | s = load_performance(perf_path) 124 | assert s == pytest.approx(0.99) 125 | 126 | 127 | def test_execute_int(inputs, outputs): 128 | @function.register 129 | def int_score( 130 | inputs, 131 | outputs, 132 | task_properties: dict, 133 | ): 134 | save_performance(int(1), outputs.get(OutputIdentifiers.performance)) 135 | 136 | perf_path = outputs[0]["value"] 137 | function.execute( 138 | sysargs=["--function-name", "int_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], 139 | ) 140 | s = load_performance(perf_path) 141 | assert s == 1 142 | 143 | 144 | def test_execute_dict(inputs, outputs): 145 | @function.register 146 | def dict_score( 147 | inputs, 148 | outputs, 149 | task_properties: dict, 150 | ): 151 | save_performance({"a": 1}, outputs.get(OutputIdentifiers.performance)) 152 | 153 | perf_path = outputs[0]["value"] 154 | function.execute( 155 | sysargs=["--function-name", "dict_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], 156 | ) 157 | s = load_performance(perf_path) 158 | assert s["a"] == 1 159 | -------------------------------------------------------------------------------- /tests/test_workflow.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pytest 5 | 6 | from substratools import load_performance 7 | from substratools import opener 8 | from substratools import save_performance 9 | from substratools.function import FunctionWrapper 10 | from substratools.task_resources import TaskResources 11 | from substratools.utils import import_module 12 | from substratools.workspace import FunctionWorkspace 13 | from tests import utils 14 | from tests.utils import InputIdentifiers 15 | from tests.utils import OutputIdentifiers 16 | 17 | 18 | @pytest.fixture 19 | def dummy_opener(): 20 | script = """ 21 | import json 22 | from substratools import Opener 23 | 24 | class DummyOpener(Opener): 25 | def get_data(self, folder): 26 | return None 27 | 28 | def fake_data(self, n_samples): 29 | raise NotImplementedError 30 | """ 31 | import_module("opener", script) 32 | 33 | 34 | def train(inputs, outputs, task_properties): 35 | models = utils.load_models(inputs.get(InputIdentifiers.shared, [])) 36 | total = sum([m["i"] for m in models]) 37 | new_model = {"i": len(models) + 1, "total": total} 38 | 39 | utils.save_model(new_model, outputs.get(OutputIdentifiers.shared)) 40 | 41 | 42 | def predict(inputs, outputs, task_properties): 43 | model = utils.load_model(inputs.get(InputIdentifiers.shared)) 44 | pred = {"sum": model["i"]} 45 | utils.save_predictions(pred, outputs.get(OutputIdentifiers.predictions)) 46 | 47 | 48 | def score(inputs, outputs, task_properties): 49 | y_pred_path = inputs.get(InputIdentifiers.predictions) 50 | y_pred = utils.load_predictions(y_pred_path) 51 | 52 | score = y_pred["sum"] 53 | 54 | save_performance(performance=score, path=outputs.get(OutputIdentifiers.performance)) 55 | 56 | 57 | def test_workflow(workdir, dummy_opener): 58 | loop1_model_path = workdir / "loop1model" 59 | loop1_workspace_outputs = TaskResources( 60 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop1_model_path), "multiple": False}]) 61 | ) 62 | loop1_workspace = FunctionWorkspace(outputs=loop1_workspace_outputs) 63 | loop1_wp = FunctionWrapper(workspace=loop1_workspace, opener_wrapper=None) 64 | 65 | # loop 1 (no input) 66 | loop1_wp.execute(function=train) 67 | 
model = utils.load_model(path=loop1_wp._workspace.task_outputs[OutputIdentifiers.shared]) 68 | 69 | assert model == {"i": 1, "total": 0} 70 | assert os.path.exists(loop1_model_path) 71 | 72 | loop2_model_path = workdir / "loop2model" 73 | 74 | loop2_workspace_inputs = TaskResources( 75 | json.dumps([{"id": InputIdentifiers.shared, "value": str(loop1_model_path), "multiple": True}]) 76 | ) 77 | loop2_workspace_outputs = TaskResources( 78 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop2_model_path), "multiple": False}]) 79 | ) 80 | loop2_workspace = FunctionWorkspace(inputs=loop2_workspace_inputs, outputs=loop2_workspace_outputs) 81 | loop2_wp = FunctionWrapper(workspace=loop2_workspace, opener_wrapper=None) 82 | 83 | # loop 2 (one model as input) 84 | loop2_wp.execute(function=train) 85 | model = utils.load_model(path=loop2_wp._workspace.task_outputs[OutputIdentifiers.shared]) 86 | assert model == {"i": 2, "total": 1} 87 | assert os.path.exists(loop2_model_path) 88 | 89 | loop3_model_path = workdir / "loop3model" 90 | loop3_workspace_inputs = TaskResources( 91 | json.dumps( 92 | [ 93 | {"id": InputIdentifiers.shared, "value": str(loop1_model_path), "multiple": True}, 94 | {"id": InputIdentifiers.shared, "value": str(loop2_model_path), "multiple": True}, 95 | ] 96 | ) 97 | ) 98 | loop3_workspace_outputs = TaskResources( 99 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop3_model_path), "multiple": False}]) 100 | ) 101 | loop3_workspace = FunctionWorkspace(inputs=loop3_workspace_inputs, outputs=loop3_workspace_outputs) 102 | loop3_wp = FunctionWrapper(workspace=loop3_workspace, opener_wrapper=None) 103 | 104 | # loop 3 (two models as input) 105 | loop3_wp.execute(function=train) 106 | model = utils.load_model(path=loop3_wp._workspace.task_outputs[OutputIdentifiers.shared]) 107 | assert model == {"i": 3, "total": 3} 108 | assert os.path.exists(loop3_model_path) 109 | 110 | predictions_path = workdir / "predictions" 111 | predict_workspace_inputs = TaskResources( 112 | json.dumps([{"id": InputIdentifiers.shared, "value": str(loop3_model_path), "multiple": False}]) 113 | ) 114 | predict_workspace_outputs = TaskResources( 115 | json.dumps([{"id": OutputIdentifiers.predictions, "value": str(predictions_path), "multiple": False}]) 116 | ) 117 | predict_workspace = FunctionWorkspace(inputs=predict_workspace_inputs, outputs=predict_workspace_outputs) 118 | predict_wp = FunctionWrapper(workspace=predict_workspace, opener_wrapper=None) 119 | 120 | # predict 121 | predict_wp.execute(function=predict) 122 | pred = utils.load_predictions(path=predict_wp._workspace.task_outputs[OutputIdentifiers.predictions]) 123 | assert pred == {"sum": 3} 124 | 125 | # metrics 126 | performance_path = workdir / "performance" 127 | metric_workspace_inputs = TaskResources( 128 | json.dumps([{"id": InputIdentifiers.predictions, "value": str(predictions_path), "multiple": False}]) 129 | ) 130 | metric_workspace_outputs = TaskResources( 131 | json.dumps([{"id": OutputIdentifiers.performance, "value": str(performance_path), "multiple": False}]) 132 | ) 133 | metric_workspace = FunctionWorkspace( 134 | inputs=metric_workspace_inputs, 135 | outputs=metric_workspace_outputs, 136 | ) 137 | metrics_wp = FunctionWrapper(workspace=metric_workspace, opener_wrapper=opener.load_from_module()) 138 | metrics_wp.execute(function=score) 139 | res = load_performance(path=metrics_wp._workspace.task_outputs[OutputIdentifiers.performance]) 140 | assert res == 3.0 141 | 
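The workflow test above drives `FunctionWrapper` and `FunctionWorkspace` directly. On the platform, the same `train`, `predict` and `score` functions would normally live in a standalone script that registers them and hands control to `tools.execute`, which parses `--function-name`, `--inputs` and `--outputs` from the command line (as exercised in `tests/test_metrics.py`). A minimal sketch of such a script — the `shared` identifier and the JSON model format are borrowed from the tests above, not mandated by substratools:

```python
# train_script.py -- illustrative sketch, not a file from this repository.
import json

import substratools as tools


@tools.register
def train(inputs: dict, outputs: dict, task_properties: dict) -> None:
    # `inputs` maps identifiers to paths: a list of paths for a multiple
    # resource ("shared" here), a single path otherwise.
    models = []
    for path in inputs.get("shared", []):
        with open(path, "r") as f:
            models.append(json.load(f))

    new_model = {"i": len(models) + 1, "total": sum(m["i"] for m in models)}
    with open(outputs.get("shared"), "w") as f:
        json.dump(new_model, f)


if __name__ == "__main__":
    # The backend invokes the script with arguments such as:
    #   --function-name train --inputs '<json list>' --outputs '<json list>'
    tools.execute()
```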
-------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | 9 | 10 | ## [1.0.0](https://github.com/Substra/substra-tools/releases/tag/1.0.0) - 2024-10-14 11 | 12 | ### Removed 13 | 14 | - Drop Python 3.9 support. ([#112](https://github.com/Substra/substra-tools/pull/112)) 15 | 16 | 17 | ## [0.22.0](https://github.com/Substra/substra-tools/releases/tag/0.22.0) - 2024-09-04 18 | 19 | ### Changed 20 | 21 | - The Opener and Function workspaces no longer try to create folders for data samples. ([#100](https://github.com/Substra/substra-tools/pull/100)) 22 | 23 | ### Removed 24 | 25 | - Remove base Docker image. Substrafl now uses python-slim as base image. ([#101](https://github.com/Substra/substra-tools/pull/101)) 26 | 27 | 28 | ## [0.21.4](https://github.com/Substra/substra-tools/releases/tag/0.21.4) - 2024-06-03 29 | 30 | 31 | No significant changes. 32 | 33 | 34 | ## [0.21.3](https://github.com/Substra/substra-tools/releases/tag/0.21.3) - 2024-03-27 35 | 36 | 37 | ### Changed 38 | 39 | - Drop Python 3.8 support ([#90](https://github.com/Substra/substra-tools/pull/90)) 40 | - Deprecate `setup.py` in favour of `pyproject.toml` ([#92](https://github.com/Substra/substra-tools/pull/92)) 41 | 42 | 43 | ## [0.21.2](https://github.com/Substra/substra-tools/releases/tag/0.21.2) - 2024-03-07 44 | 45 | ### Changed 46 | 47 | - Update dependencies 48 | 49 | ## [0.21.1](https://github.com/Substra/substra-tools/releases/tag/0.21.1) - 2024-02-26 50 | 51 | ### Changed 52 | 53 | - Update dependencies 54 | 55 | ## [0.21.0](https://github.com/Substra/substra-tools/releases/tag/0.21.0) - 2023-10-06 56 | 57 | ### Changed 58 | 59 | - Remove `model` and `models` input and output identifiers in tests; use `shared` instead. ([#84](https://github.com/Substra/substra-tools/pull/84)) 60 | - BREAKING: Remove minimal and workflow Docker images ([#86](https://github.com/Substra/substra-tools/pull/86)) 61 | - Remove Python lib from Docker image ([#86](https://github.com/Substra/substra-tools/pull/86)) 62 | 63 | ### Added 64 | 65 | - Support for Python 3.11 ([#85](https://github.com/Substra/substra-tools/pull/85)) 66 | - Contributing, contributors & code of conduct files (#77) 67 | 68 | ## [0.20.0](https://github.com/Substra/substra-tools/releases/tag/0.20.0) - 2022-12-19 69 | 70 | ### Changed 71 | 72 | - Add an optional argument to the `register` decorator to choose a custom function name. This allows registering the same function several times under different names (#74) 73 | - Rank is now passed in a task properties dictionary from the backend (instead of the rank argument) (#75) 74 | 75 | ## [0.19.0](https://github.com/Substra/substra-tools/releases/tag/0.19.0) - 2022-11-22 76 | 77 | ### Changed 78 | 79 | - BREAKING CHANGE (#65) 80 | 81 | - Registering functions with substratools can now be done with a decorator.
82 | 83 | ```py 84 | def my_function1(): 85 | pass 86 | 87 | def my_function2(): 88 | pass 89 | 90 | if __name__ == '__main__': 91 | tools.execute(my_function1, my_function2) 92 | ``` 93 | 94 | becomes 95 | 96 | ```py 97 | @tools.register 98 | def my_function1(): 99 | pass 100 | 101 | @tools.register 102 | def my_function2(): 103 | pass 104 | 105 | if __name__ == '__main__': 106 | tools.execute() 107 | ``` 108 | 109 | - BREAKING CHANGE (#63) 110 | 111 | - Rename algo to function. 112 | - `tools.algo.execute` becomes `tools.execute` 113 | - The algo class previously passed to `tools.algo.execute` is replaced by several functions passed as arguments to `tools.execute`. The function named by the CLI argument `--function-name` is executed. 114 | 115 | ```py 116 | if __name__ == '__main__': 117 | tools.algo.execute(MyAlgo()) 118 | ``` 119 | 120 | becomes 121 | 122 | ```py 123 | if __name__ == '__main__': 124 | tools.execute(my_function1, my_function2) 125 | ``` 126 | 127 | ### Fixed 128 | 129 | - Remove deprecated `pytest-runner` from setup.py (#71) 130 | - Replace backslashes with slashes in TaskResources to fix Windows compatibility (#70) 131 | - Update flake8 repository in pre-commit configuration (#69) 132 | - BREAKING CHANGE: Update substratools Docker image (#112) 133 | 134 | ## [0.18.0](https://github.com/Substra/substra-tools/releases/tag/0.18.0) - 2022-09-26 135 | 136 | ### Added 137 | 138 | - feat: allow CLI parameters to be read from a file 139 | 140 | ### Changed 141 | 142 | - BREAKING CHANGES: 143 | 144 | - the opener only exposes `get_data` and `fake_data` methods. 145 | - the results of the above methods are passed under the `datasamples` key within the `inputs` dict arg of all 146 | tools methods (train, predict, aggregate, score). 147 | - all methods (train, predict, aggregate, score) now take a `task_properties` argument (dict) in addition to 148 | `inputs` and `outputs`. 149 | - The `rank` of a task previously passed under the `rank` key within the inputs is now given in the `task_properties` 150 | dict under the `rank` key. 151 | 152 | - BREAKING CHANGE: The metric is now a generic algo, replace 153 | 154 | ```python 155 | import substratools as tools 156 | 157 | class MyMetric(tools.Metrics): 158 | # ... 159 | 160 | if __name__ == '__main__': 161 | tools.metrics.execute(MyMetric()) 162 | ``` 163 | 164 | by 165 | 166 | ```python 167 | import substratools as tools 168 | 169 | class MyMetric(tools.MetricAlgo): 170 | # ... 171 | if __name__ == '__main__': 172 | tools.algo.execute(MyMetric()) 173 | ``` 174 | 175 | ## [0.17.0](https://github.com/Substra/substra-tools/releases/tag/0.17.0) - 2022-09-19 176 | 177 | ### Changed 178 | 179 | - feat: all algo classes rely on a generic algo class 180 | 181 | ## [0.16.0](https://github.com/Substra/substra-tools/releases/tag/0.16.0) - 2022-09-12 182 | 183 | ### Changed 184 | 185 | - Remove documentation as it is not used. It will be replaced later on. 186 | - BREAKING CHANGES: the user must now pass the method name to execute within the Dockerfile of both `algo` and 187 | `metric` under the `--method-name` argument. The method name still needs to be one of the `algo` or `metric` 188 | allowed method names: train, predict, aggregate, score.
189 | 190 | ```Dockerfile 191 | ENTRYPOINT ["python3", "metrics.py"] 192 | ``` 193 | 194 | shall be replaced by: 195 | 196 | ```Dockerfile 197 | ENTRYPOINT ["python3", "metrics.py", "--method-name", "score"] 198 | ``` 199 | 200 | - BREAKING CHANGES: rename connect-tools to substra-tools (except the github folder) 201 | 202 | ## [0.15.0](https://github.com/Substra/substra-tools/releases/tag/0.15.0) - 2022-08-29 203 | 204 | ### Changed 205 | 206 | - BREAKING CHANGES: 207 | 208 | - methods from algo, composite algo, aggregate and metrics now take `inputs` (TypedDict) and `outputs` (TypedDict) as arguments 209 | - the user must load and save all the inputs and outputs of those methods (except for the datasamples) 210 | - `load_predictions` and `get_predictions` methods have been removed from the opener 211 | - `load_trunk_model`, `save_trunk_model`, `load_head_model`, `save_head_model` have been removed from the `tools.CompositeAlgo` class 212 | - `load_model` and `save_model` have been removed from both `tools.Algo` and `tools.AggregateAlgo` classes 213 | 214 | ## [0.14.0](https://github.com/Substra/substra-tools/releases/tag/0.14.0) - 2022-08-09 215 | 216 | ### Changed 217 | 218 | - BREAKING CHANGE: drop Python 3.7 support 219 | 220 | ### Fixed 221 | 222 | - fix: metric with type np.float32() is not JSON serializable #47 223 | 224 | ## [0.13.0](https://github.com/Substra/substra-tools/releases/tag/0.13.0) - 2022-05-22 225 | 226 | ### Changed 227 | 228 | - BREAKING CHANGE: change --debug (bool) to --log-level (str) 229 | 230 | ## [0.12.0](https://github.com/Substra/substra-tools/releases/tag/0.12.0) - 2022-04-29 231 | 232 | ### Fixed 233 | 234 | - NVIDIA rotating keys 235 | 236 | ### Changed 237 | 238 | - (BREAKING) algos now receive their arguments as generic inputs/outputs dicts 239 | 240 | ## [0.11.0](https://github.com/Substra/substra-tools/releases/tag/0.11.0) - 2022-04-11 241 | 242 | ### Fixed 243 | 244 | - alias in Python 3.7 for python3 245 | 246 | ### Improved 247 | 248 | - ci: build Docker images as part of CI checks 249 | - ci: push latest image from main branch 250 | - chore: make Dockerfiles independent from each other 251 | -------------------------------------------------------------------------------- /substratools/function.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | import argparse 3 | import json 4 | import logging 5 | import os 6 | import sys 7 | from copy import deepcopy 8 | from typing import Any 9 | from typing import Callable 10 | from typing import Dict 11 | from typing import Optional 12 | 13 | from substratools import exceptions 14 | from substratools import opener 15 | from substratools import utils 16 | from substratools.exceptions import ExistingRegisteredFunctionError 17 | from substratools.exceptions import FunctionNotFoundError 18 | from substratools.task_resources import StaticInputIdentifiers 19 | from substratools.task_resources import TaskResources 20 | from substratools.workspace import FunctionWorkspace 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def _parser_add_default_arguments(parser): 26 | parser.add_argument( 27 | "--function-name", 28 | type=str, 29 | help="The name of the function to execute from the given file", 30 | ) 31 | parser.add_argument( 32 | "-r", 33 | "--task-properties", 34 | type=str, 35 | default="{}", 36 | help="Define the task properties", 37 | ) 38 | parser.add_argument( 39 | "-d", 40 | "--fake-data", 41 | action="store_true", 42 | default=False, 43 |
help="Enable fake data mode", 44 | ) 45 | parser.add_argument( 46 | "--n-fake-samples", 47 | default=None, 48 | type=int, 49 | help="Number of fake samples if fake data is used.", 50 | ) 51 | parser.add_argument( 52 | "--log-path", 53 | default=None, 54 | help="Define log filename path", 55 | ) 56 | parser.add_argument( 57 | "--log-level", 58 | default="info", 59 | choices=utils.MAPPING_LOG_LEVEL.keys(), 60 | help="Choose log level", 61 | ) 62 | parser.add_argument( 63 | "--inputs", 64 | type=str, 65 | default="[]", 66 | help="Inputs of the compute task", 67 | ) 68 | parser.add_argument( 69 | "--outputs", 70 | type=str, 71 | default="[]", 72 | help="Outputs of the compute task", 73 | ) 74 | 75 | 76 | class FunctionRegister: 77 | """Class to create a decorator to register function in substratools. The functions are registered in the _functions 78 | dictionary, with the function.__name__ as key. 79 | Register a function in substratools means that this function can be access by the function.execute functions through 80 | the --function-name CLI argument.""" 81 | 82 | def __init__(self): 83 | self._functions = {} 84 | 85 | def __call__(self, function: Callable, function_name: Optional[str] = None): 86 | """Function called when using an instance of the class as a decorator. 87 | 88 | Args: 89 | function (Callable): function to register in substratools. 90 | function_name (str, optional): function name to register the given function. 91 | If None, function.__name__ is used for registration. 92 | Raises: 93 | ExistingRegisteredFunctionError: Raise if a function with the same function.__name__ 94 | has already been registered in substratools. 95 | 96 | Returns: 97 | Callable: returns the function without decorator 98 | """ 99 | 100 | function_name = function_name or function.__name__ 101 | if function_name not in self._functions: 102 | self._functions[function_name] = function 103 | else: 104 | raise ExistingRegisteredFunctionError("A function with the same name is already registered.") 105 | 106 | return function 107 | 108 | def get_registered_functions(self): 109 | return self._functions 110 | 111 | 112 | # Instance of the decorator to store the function to register in memory. 113 | # Can be imported directly from substratools. 114 | register = FunctionRegister() 115 | 116 | 117 | class FunctionWrapper(object): 118 | """Wrapper to execute a function on the platform.""" 119 | 120 | def __init__(self, workspace: FunctionWorkspace, opener_wrapper: Optional[opener.OpenerWrapper]): 121 | self._workspace = workspace 122 | self._opener_wrapper = opener_wrapper 123 | 124 | def _assert_outputs_exists(self, outputs: Dict[str, str]): 125 | for key, path in outputs.items(): 126 | if os.path.isdir(path): 127 | raise exceptions.NotAFileError(f"Expected output file at {path}, found dir for output `{key}`") 128 | if not os.path.isfile(path): 129 | raise exceptions.MissingFileError(f"Output file {path} used to save argument `{key}` does not exists.") 130 | 131 | @utils.Timer(logger) 132 | def execute( 133 | self, function: Callable, task_properties: dict = {}, fake_data: bool = False, n_fake_samples: int = None 134 | ): 135 | """Execute a compute task""" 136 | 137 | # load inputs 138 | inputs = deepcopy(self._workspace.task_inputs) 139 | 140 | # load data from opener 141 | if self._opener_wrapper: 142 | loaded_datasamples = self._opener_wrapper.get_data(fake_data, n_fake_samples) 143 | 144 | if fake_data: 145 | logger.info("Using fake data with %i fake samples." 
% int(n_fake_samples)) 146 | 147 | assert ( 148 | StaticInputIdentifiers.datasamples.value not in inputs.keys() 149 | ), f"{StaticInputIdentifiers.datasamples.value} must be an input of kind `datasamples`" 150 | inputs.update({StaticInputIdentifiers.datasamples.value: loaded_datasamples}) 151 | 152 | # load outputs 153 | outputs = deepcopy(self._workspace.task_outputs) 154 | 155 | logger.info("Launching task: executing `%s` function." % function.__name__) 156 | function( 157 | inputs=inputs, 158 | outputs=outputs, 159 | task_properties=task_properties, 160 | ) 161 | 162 | self._assert_outputs_exists( 163 | self._workspace.task_outputs, 164 | ) 165 | 166 | 167 | def _generate_function_cli(): 168 | """Helper to generate a command line interface client.""" 169 | 170 | def _function_from_args(args): 171 | inputs = TaskResources(args.inputs) 172 | outputs = TaskResources(args.outputs) 173 | log_path = args.log_path 174 | chainkeys_path = inputs.chainkeys_path 175 | 176 | workspace = FunctionWorkspace( 177 | log_path=log_path, 178 | chainkeys_path=chainkeys_path, 179 | inputs=inputs, 180 | outputs=outputs, 181 | ) 182 | 183 | utils.configure_logging(workspace.log_path, log_level=args.log_level) 184 | 185 | opener_wrapper = opener.load_from_module( 186 | workspace=workspace, 187 | ) 188 | 189 | return FunctionWrapper(workspace, opener_wrapper) 190 | 191 | def _user_func(args, function): 192 | function_wrapper = _function_from_args(args) 193 | function_wrapper.execute( 194 | function=function, 195 | task_properties=json.loads(args.task_properties), 196 | fake_data=args.fake_data, 197 | n_fake_samples=args.n_fake_samples, 198 | ) 199 | 200 | parser = argparse.ArgumentParser(fromfile_prefix_chars="@") 201 | _parser_add_default_arguments(parser) 202 | parser.set_defaults(func=_user_func) 203 | 204 | return parser 205 | 206 | 207 | def _get_function_from_name(functions: dict, function_name: str): 208 | 209 | if function_name not in functions: 210 | raise FunctionNotFoundError( 211 | f"The function {function_name} given as --function-name argument has not been found."
212 | ) 213 | 214 | return functions[function_name] 215 | 216 | 217 | def save_performance(performance: Any, path: os.PathLike): 218 | with open(path, "w") as f: 219 | json.dump({"all": performance}, f) 220 | 221 | 222 | def load_performance(path: os.PathLike) -> Any: 223 | with open(path, "r") as f: 224 | performance = json.load(f)["all"] 225 | return performance 226 | 227 | 228 | def execute(sysargs=None): 229 | """Launch function command line interface.""" 230 | 231 | cli = _generate_function_cli() 232 | 233 | sysargs = sysargs if sysargs is not None else sys.argv[1:] 234 | args = cli.parse_args(sysargs) 235 | function = _get_function_from_name(register.get_registered_functions(), args.function_name) 236 | args.func(args, function) 237 | 238 | return args 239 | -------------------------------------------------------------------------------- /tests/test_aggregatealgo.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import PathLike 3 | from typing import Any 4 | from typing import List 5 | from typing import TypedDict 6 | from uuid import uuid4 7 | 8 | import pytest 9 | 10 | from substratools import exceptions 11 | from substratools import function 12 | from substratools import opener 13 | from substratools.task_resources import TaskResources 14 | from substratools.workspace import FunctionWorkspace 15 | from tests import utils 16 | from tests.utils import InputIdentifiers 17 | from tests.utils import OutputIdentifiers 18 | 19 | 20 | @pytest.fixture(autouse=True) 21 | def setup(valid_opener): 22 | pass 23 | 24 | 25 | @function.register 26 | def aggregate( 27 | inputs: TypedDict( 28 | "inputs", 29 | {InputIdentifiers.shared: List[PathLike]}, 30 | ), 31 | outputs: TypedDict("outputs", {OutputIdentifiers.shared: PathLike}), 32 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), 33 | ) -> None: 34 | if inputs: 35 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) 36 | else: 37 | models = [] 38 | 39 | new_model = {"value": 0} 40 | for m in models: 41 | new_model["value"] += m["value"] 42 | 43 | utils.save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) 44 | 45 | 46 | @function.register 47 | def aggregate_predict( 48 | inputs: TypedDict( 49 | "inputs", 50 | { 51 | InputIdentifiers.datasamples: Any, 52 | InputIdentifiers.shared: PathLike, 53 | }, 54 | ), 55 | outputs: TypedDict("outputs", {OutputIdentifiers.predictions: PathLike}), 56 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), 57 | ): 58 | model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 59 | 60 | # Predict 61 | X = inputs.get(InputIdentifiers.datasamples)[0] 62 | pred = X * model["value"] 63 | 64 | # save predictions 65 | utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions)) 66 | 67 | 68 | def no_saved_aggregate(inputs, outputs, task_properties): 69 | if inputs: 70 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) 71 | else: 72 | models = [] 73 | 74 | new_model = {"value": 0} 75 | for m in models: 76 | new_model["value"] += m["value"] 77 | 78 | utils.no_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) 79 | 80 | 81 | def wrong_saved_aggregate(inputs, outputs, task_properties): 82 | if inputs: 83 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) 84 | else: 85 | models = [] 86 | 87 | new_model = {"value": 0} 88 | for m in models: 89 | new_model["value"] +=
m["value"] 90 | 91 | utils.wrong_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) 92 | 93 | 94 | @pytest.fixture 95 | def create_models(workdir): 96 | model_a = {"value": 1} 97 | model_b = {"value": 2} 98 | 99 | model_dir = workdir / OutputIdentifiers.shared 100 | model_dir.mkdir() 101 | 102 | def _create_model(model_data): 103 | model_name = model_data["value"] 104 | filename = "{}.json".format(model_name) 105 | path = model_dir / filename 106 | path.write_text(json.dumps(model_data)) 107 | return str(path) 108 | 109 | model_datas = [model_a, model_b] 110 | model_filenames = [_create_model(d) for d in model_datas] 111 | 112 | return model_datas, model_filenames 113 | 114 | 115 | def test_aggregate_no_model(valid_function_workspace): 116 | wp = function.FunctionWrapper(workspace=valid_function_workspace, opener_wrapper=None) 117 | wp.execute(function=aggregate) 118 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) 119 | assert model["value"] == 0 120 | 121 | 122 | def test_aggregate_multiple_models(create_models, output_model_path): 123 | _, model_filenames = create_models 124 | 125 | workspace_inputs = TaskResources( 126 | json.dumps([{"id": InputIdentifiers.shared, "value": f, "multiple": True} for f in model_filenames]) 127 | ) 128 | workspace_outputs = TaskResources( 129 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) 130 | ) 131 | 132 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) 133 | wp = function.FunctionWrapper(workspace, opener_wrapper=None) 134 | 135 | wp.execute(function=aggregate) 136 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) 137 | 138 | assert model["value"] == 3 139 | 140 | 141 | @pytest.mark.parametrize( 142 | "fake_data,expected_pred,n_fake_samples", 143 | [ 144 | (False, "X", None), 145 | (True, ["Xfake"], 1), 146 | ], 147 | ) 148 | def test_predict(fake_data, expected_pred, n_fake_samples, create_models): 149 | _, model_filenames = create_models 150 | 151 | workspace_inputs = TaskResources( 152 | json.dumps([{"id": InputIdentifiers.shared, "value": model_filenames[0], "multiple": False}]) 153 | ) 154 | workspace_outputs = TaskResources( 155 | json.dumps([{"id": OutputIdentifiers.predictions, "value": model_filenames[0], "multiple": False}]) 156 | ) 157 | 158 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) 159 | 160 | wp = function.FunctionWrapper(workspace, opener_wrapper=opener.load_from_module()) 161 | 162 | wp.execute(function=aggregate_predict, fake_data=fake_data, n_fake_samples=n_fake_samples) 163 | 164 | pred = utils.load_predictions(wp._workspace.task_outputs[OutputIdentifiers.predictions]) 165 | assert pred == expected_pred 166 | 167 | 168 | def test_execute_aggregate(output_model_path): 169 | assert not output_model_path.exists() 170 | 171 | outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}] 172 | 173 | function.execute(sysargs=["--function-name", "aggregate", "--outputs", json.dumps(outputs)]) 174 | assert output_model_path.exists() 175 | output_model_path.unlink() 176 | function.execute( 177 | sysargs=["--function-name", "aggregate", "--outputs", json.dumps(outputs), "--log-level", "debug"], 178 | ) 179 | assert output_model_path.exists() 180 | 181 | 182 | def test_execute_aggregate_multiple_models(workdir, create_models, output_model_path): 183 | _, model_filenames = create_models 184 | 185 | 
assert not output_model_path.exists() 186 | 187 | inputs = [ 188 | {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames 189 | ] 190 | outputs = [ 191 | {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, 192 | ] 193 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] 194 | 195 | command = ["--function-name", "aggregate"] 196 | command.extend(options) 197 | 198 | function.execute(sysargs=command) 199 | assert output_model_path.exists() 200 | with open(output_model_path, "r") as f: 201 | model = json.load(f) 202 | assert model["value"] == 3 203 | 204 | 205 | def test_execute_predict(workdir, create_models, output_model_path, valid_opener_script): 206 | _, model_filenames = create_models 207 | assert not output_model_path.exists() 208 | 209 | inputs = [ 210 | {"id": InputIdentifiers.shared, "value": str(workdir / model_name), "multiple": True} 211 | for model_name in model_filenames 212 | ] 213 | outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}] 214 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] 215 | command = ["--function-name", "aggregate"] 216 | command.extend(options) 217 | function.execute(sysargs=command) 218 | assert output_model_path.exists() 219 | 220 | # do predict on output model 221 | pred_path = workdir / str(uuid4()) 222 | assert not pred_path.exists() 223 | 224 | pred_inputs = [ 225 | {"id": InputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, 226 | {"id": InputIdentifiers.opener, "value": valid_opener_script, "multiple": False}, 227 | ] 228 | pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}] 229 | pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)] 230 | 231 | function.execute(sysargs=["--function-name", "predict"] + pred_options) 232 | assert pred_path.exists() 233 | with open(pred_path, "r") as f: 234 | pred = json.load(f) 235 | assert pred == "XXX" 236 | pred_path.unlink() 237 | 238 | 239 | @pytest.mark.parametrize("function_to_run", (no_saved_aggregate, wrong_saved_aggregate)) 240 | def test_model_check(function_to_run, valid_function_workspace): 241 | wp = function.FunctionWrapper(valid_function_workspace, opener_wrapper=None) 242 | 243 | with pytest.raises(exceptions.MissingFileError): 244 | wp.execute(function=function_to_run) 245 | -------------------------------------------------------------------------------- /Substra-logo-white.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /Substra-logo-colour.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2018-2022 Owkin, Inc. 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 
192 | -------------------------------------------------------------------------------- /tests/test_compositealgo.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any 4 | from typing import Optional 5 | from typing import TypedDict 6 | 7 | import pytest 8 | 9 | from substratools import exceptions 10 | from substratools import function 11 | from substratools import opener 12 | from substratools.task_resources import TaskResources 13 | from substratools.workspace import FunctionWorkspace 14 | from tests import utils 15 | from tests.utils import InputIdentifiers 16 | from tests.utils import OutputIdentifiers 17 | 18 | 19 | @pytest.fixture(autouse=True) 20 | def setup(valid_opener): 21 | pass 22 | 23 | 24 | def fake_data_train(inputs: dict, outputs: dict, task_properties: dict): 25 | utils.save_model(model=inputs[InputIdentifiers.datasamples][0], path=outputs["local"]) 26 | utils.save_model(model=inputs[InputIdentifiers.datasamples][1], path=outputs["shared"]) 27 | 28 | 29 | def fake_data_predict(inputs: dict, outputs: dict, task_properties: dict) -> None: 30 | utils.save_model(model=inputs[InputIdentifiers.datasamples][0], path=outputs["predictions"]) 31 | 32 | 33 | def train( 34 | inputs: TypedDict( 35 | "inputs", 36 | { 37 | InputIdentifiers.datasamples: Any, 38 | InputIdentifiers.local: Optional[os.PathLike], 39 | InputIdentifiers.shared: Optional[os.PathLike], 40 | }, 41 | ), 42 | outputs: TypedDict( 43 | "outputs", 44 | { 45 | OutputIdentifiers.local: os.PathLike, 46 | OutputIdentifiers.shared: os.PathLike, 47 | }, 48 | ), 49 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), 50 | ): 51 | # init phase 52 | # load models 53 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) 54 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 55 | 56 | if head_model and trunk_model: 57 | new_head_model = dict(head_model) 58 | new_trunk_model = dict(trunk_model) 59 | else: 60 | new_head_model = {"value": 0} 61 | new_trunk_model = {"value": 0} 62 | 63 | # train models 64 | new_head_model["value"] += 1 65 | new_trunk_model["value"] -= 1 66 | 67 | # save model 68 | utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) 69 | utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) 70 | 71 | 72 | def predict( 73 | inputs: TypedDict( 74 | "inputs", 75 | { 76 | InputIdentifiers.datasamples: Any, 77 | InputIdentifiers.local: os.PathLike, 78 | InputIdentifiers.shared: os.PathLike, 79 | }, 80 | ), 81 | outputs: TypedDict( 82 | "outputs", 83 | { 84 | OutputIdentifiers.predictions: os.PathLike, 85 | }, 86 | ), 87 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), 88 | ): 89 | 90 | # init phase 91 | # load models 92 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) 93 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 94 | 95 | pred = list(range(head_model["value"], trunk_model["value"])) 96 | 97 | # save predictions 98 | utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions)) 99 | 100 | 101 | def no_saved_trunk_train(inputs, outputs, task_properties): 102 | # init phase 103 | # load models 104 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) 105 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 106 | 107 | if head_model and trunk_model: 108 | 
new_head_model = dict(head_model) 109 | new_trunk_model = dict(trunk_model) 110 | else: 111 | new_head_model = {"value": 0} 112 | new_trunk_model = {"value": 0} 113 | 114 | # train models 115 | new_head_model["value"] += 1 116 | new_trunk_model["value"] -= 1 117 | 118 | # save model 119 | utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) 120 | utils.no_save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) 121 | 122 | 123 | def no_saved_head_train(inputs, outputs, task_properties): 124 | # init phase 125 | # load models 126 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) 127 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 128 | 129 | if head_model and trunk_model: 130 | new_head_model = dict(head_model) 131 | new_trunk_model = dict(trunk_model) 132 | else: 133 | new_head_model = {"value": 0} 134 | new_trunk_model = {"value": 0} 135 | 136 | # train models 137 | new_head_model["value"] += 1 138 | new_trunk_model["value"] -= 1 139 | 140 | # save model 141 | utils.no_save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) 142 | utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) 143 | 144 | 145 | def wrong_saved_trunk_train(inputs, outputs, task_properties): 146 | # init phase 147 | # load models 148 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) 149 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 150 | 151 | if head_model and trunk_model: 152 | new_head_model = dict(head_model) 153 | new_trunk_model = dict(trunk_model) 154 | else: 155 | new_head_model = {"value": 0} 156 | new_trunk_model = {"value": 0} 157 | 158 | # train models 159 | new_head_model["value"] += 1 160 | new_trunk_model["value"] -= 1 161 | 162 | # save model 163 | utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) 164 | utils.wrong_save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) 165 | 166 | 167 | def wrong_saved_head_train(inputs, outputs, task_properties): 168 | # init phase 169 | # load models 170 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) 171 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 172 | 173 | if head_model and trunk_model: 174 | new_head_model = dict(head_model) 175 | new_trunk_model = dict(trunk_model) 176 | else: 177 | new_head_model = {"value": 0} 178 | new_trunk_model = {"value": 0} 179 | 180 | # train models 181 | new_head_model["value"] += 1 182 | new_trunk_model["value"] -= 1 183 | 184 | # save model 185 | utils.wrong_save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) 186 | utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) 187 | 188 | 189 | @pytest.fixture 190 | def train_outputs(output_model_path, output_model_path_2): 191 | outputs = TaskResources( 192 | json.dumps( 193 | [ 194 | {"id": "local", "value": str(output_model_path), "multiple": False}, 195 | {"id": "shared", "value": str(output_model_path_2), "multiple": False}, 196 | ] 197 | ) 198 | ) 199 | return outputs 200 | 201 | 202 | @pytest.fixture 203 | def composite_inputs(create_models): 204 | _, local_path, shared_path = create_models 205 | inputs = TaskResources( 206 | json.dumps( 207 | [ 208 | {"id": InputIdentifiers.local, "value": str(local_path), "multiple": False}, 209 | {"id": InputIdentifiers.shared, "value": str(shared_path), "multiple": False}, 210 | ] 211 | 
) 212 | ) 213 | 214 | return inputs 215 | 216 | 217 | @pytest.fixture 218 | def predict_outputs(output_model_path): 219 | outputs = TaskResources( 220 | json.dumps([{"id": OutputIdentifiers.predictions, "value": str(output_model_path), "multiple": False}]) 221 | ) 222 | return outputs 223 | 224 | 225 | @pytest.fixture 226 | def create_models(workdir): 227 | head_model = {"value": 1} 228 | trunk_model = {"value": -1} 229 | 230 | def _create_model(model_data, name): 231 | filename = "{}.json".format(name) 232 | path = workdir / filename 233 | path.write_text(json.dumps(model_data)) 234 | return path 235 | 236 | head_path = _create_model(head_model, "head") 237 | trunk_path = _create_model(trunk_model, "trunk") 238 | 239 | return ( 240 | [head_model, trunk_model], 241 | head_path, 242 | trunk_path, 243 | ) 244 | 245 | 246 | def test_train_no_model(train_outputs): 247 | 248 | dummy_train_workspace = FunctionWorkspace(outputs=train_outputs) 249 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, None) 250 | dummy_train_wrapper.execute(function=train) 251 | local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["local"]) 252 | shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["shared"]) 253 | 254 | assert local_model["value"] == 1 255 | assert shared_model["value"] == -1 256 | 257 | 258 | def test_train_input_head_trunk_models(composite_inputs, train_outputs): 259 | 260 | dummy_train_workspace = FunctionWorkspace(inputs=composite_inputs, outputs=train_outputs) 261 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, None) 262 | dummy_train_wrapper.execute(function=train) 263 | local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["local"]) 264 | shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["shared"]) 265 | 266 | assert local_model["value"] == 2 267 | assert shared_model["value"] == -2 268 | 269 | 270 | @pytest.mark.parametrize("n_fake_samples", (0, 1, 2)) 271 | def test_train_fake_data(train_outputs, n_fake_samples): 272 | _opener = opener.load_from_module() 273 | dummy_train_workspace = FunctionWorkspace(outputs=train_outputs) 274 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, _opener) 275 | dummy_train_wrapper.execute(function=fake_data_train, fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples) 276 | 277 | local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.local]) 278 | shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.shared]) 279 | 280 | assert local_model == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[0] 281 | assert shared_model == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[1] 282 | 283 | 284 | @pytest.mark.parametrize("n_fake_samples", (0, 1, 2)) 285 | def test_predict_fake_data(composite_inputs, predict_outputs, n_fake_samples): 286 | _opener = opener.load_from_module() 287 | dummy_train_workspace = FunctionWorkspace(inputs=composite_inputs, outputs=predict_outputs) 288 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, _opener) 289 | dummy_train_wrapper.execute( 290 | function=fake_data_predict, fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples 291 | ) 292 | 293 | predictions = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.predictions]) 294 | 295 | assert predictions == 
_opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[0] 296 | 297 | 298 | @pytest.mark.parametrize( 299 | "function_to_run", 300 | ( 301 | no_saved_head_train, 302 | no_saved_trunk_train, 303 | wrong_saved_head_train, 304 | wrong_saved_trunk_train, 305 | ), 306 | ) 307 | def test_model_check(function_to_run, train_outputs): 308 | dummy_train_workspace = FunctionWorkspace(outputs=train_outputs) 309 | wp = function.FunctionWrapper(workspace=dummy_train_workspace, opener_wrapper=None) 310 | 311 | with pytest.raises(exceptions.MissingFileError): 312 | wp.execute(function_to_run) 313 | -------------------------------------------------------------------------------- /tests/test_function.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | from os import PathLike 4 | from pathlib import Path 5 | from typing import Any 6 | from typing import List 7 | from typing import Optional 8 | from typing import Tuple 9 | from typing import TypedDict 10 | 11 | import pytest 12 | 13 | from substratools import exceptions 14 | from substratools import function 15 | from substratools import opener 16 | from substratools.task_resources import StaticInputIdentifiers 17 | from substratools.task_resources import TaskResources 18 | from substratools.workspace import FunctionWorkspace 19 | from tests import utils 20 | from tests.utils import InputIdentifiers 21 | from tests.utils import OutputIdentifiers 22 | 23 | 24 | @pytest.fixture(autouse=True) 25 | def setup(valid_opener): 26 | pass 27 | 28 | 29 | @function.register 30 | def train( 31 | inputs: TypedDict( 32 | "inputs", 33 | { 34 | InputIdentifiers.datasamples: Tuple[List["str"], List[int]], # cf valid_opener_code 35 | InputIdentifiers.shared: Optional[ 36 | PathLike 37 | ], # inputs contains a dict where keys are identifiers and values are paths on the disk 38 | }, 39 | ), 40 | outputs: TypedDict( 41 | "outputs", {OutputIdentifiers.shared: PathLike} 42 | ), # outputs contains a dict where keys are identifiers and values are paths on disk 43 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), 44 | ) -> None: 45 | # TODO: checks on data 46 | # load models 47 | if inputs: 48 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) 49 | else: 50 | models = [] 51 | # init model 52 | new_model = {"value": 0} 53 | 54 | # train (just add the models values) 55 | for m in models: 56 | assert isinstance(m, dict) 57 | assert "value" in m 58 | new_model["value"] += m["value"] 59 | 60 | # save model 61 | utils.save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) 62 | 63 | 64 | @function.register 65 | def predict( 66 | inputs: TypedDict("inputs", {InputIdentifiers.datasamples: Any, InputIdentifiers.shared: List[PathLike]}), 67 | outputs: TypedDict("outputs", {OutputIdentifiers.predictions: PathLike}), 68 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), 69 | ) -> None: 70 | # TODO: checks on data 71 | 72 | # load_model 73 | model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) 74 | 75 | # predict 76 | X = inputs.get(InputIdentifiers.datasamples)[0] 77 | pred = X * model["value"] 78 | 79 | # save predictions 80 | utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions)) 81 | 82 | 83 | @function.register 84 | def no_saved_train(inputs, outputs, task_properties): 85 | # TODO: checks on data 86 | # load models 87 | if inputs: 88 | models = 
utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) 89 | else: 90 | models = [] 91 | # init model 92 | new_model = {"value": 0} 93 | 94 | # train (just add the models values) 95 | for m in models: 96 | assert isinstance(m, dict) 97 | assert "value" in m 98 | new_model["value"] += m["value"] 99 | 100 | # save model 101 | utils.no_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) 102 | 103 | 104 | @function.register 105 | def wrong_saved_train(inputs, outputs, task_properties): 106 | # TODO: checks on data 107 | # load models 108 | if inputs: 109 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) 110 | else: 111 | models = [] 112 | # init model 113 | new_model = {"value": 0} 114 | 115 | # train (just add the models values) 116 | for m in models: 117 | assert isinstance(m, dict) 118 | assert "value" in m 119 | new_model["value"] += m["value"] 120 | 121 | # save model 122 | utils.wrong_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) 123 | 124 | 125 | @pytest.fixture 126 | def create_models(workdir): 127 | model_a = {"value": 1} 128 | model_b = {"value": 2} 129 | 130 | model_dir = workdir / "model" 131 | model_dir.mkdir() 132 | 133 | def _create_model(model_data): 134 | model_name = model_data["value"] 135 | filename = "{}.json".format(model_name) 136 | path = model_dir / filename 137 | path.write_text(json.dumps(model_data)) 138 | return str(path) 139 | 140 | model_datas = [model_a, model_b] 141 | model_filenames = [_create_model(d) for d in model_datas] 142 | 143 | return model_datas, model_filenames 144 | 145 | 146 | def test_train_no_model(valid_function_workspace): 147 | wp = function.FunctionWrapper(valid_function_workspace, opener_wrapper=None) 148 | wp.execute(function=train) 149 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) 150 | assert model["value"] == 0 151 | 152 | 153 | def test_train_multiple_models(output_model_path, create_models): 154 | _, model_filenames = create_models 155 | 156 | workspace_inputs = TaskResources( 157 | json.dumps([{"id": InputIdentifiers.shared, "value": str(f), "multiple": True} for f in model_filenames]) 158 | ) 159 | workspace_outputs = TaskResources( 160 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) 161 | ) 162 | 163 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) 164 | wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=None) 165 | 166 | wp.execute(function=train) 167 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) 168 | 169 | assert model["value"] == 3 170 | 171 | 172 | def test_train_fake_data(output_model_path): 173 | workspace_outputs = TaskResources( 174 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) 175 | ) 176 | 177 | workspace = FunctionWorkspace(outputs=workspace_outputs) 178 | wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=None) 179 | wp.execute(function=train, fake_data=True, n_fake_samples=2) 180 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) 181 | assert model["value"] == 0 182 | 183 | 184 | @pytest.mark.parametrize( 185 | "fake_data,expected_pred,n_fake_samples", 186 | [ 187 | (False, "X", None), 188 | (True, ["Xfake"], 1), 189 | ], 190 | ) 191 | def test_predict(fake_data, expected_pred, n_fake_samples, create_models, output_model_path): 192 | _, model_filenames = create_models 193 
| 194 | workspace_inputs = TaskResources( 195 | json.dumps([{"id": InputIdentifiers.shared, "value": model_filenames[0], "multiple": False}]) 196 | ) 197 | workspace_outputs = TaskResources( 198 | json.dumps([{"id": OutputIdentifiers.predictions, "value": str(output_model_path), "multiple": False}]) 199 | ) 200 | 201 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) 202 | wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=opener.load_from_module()) 203 | wp.execute(function=predict, fake_data=fake_data, n_fake_samples=n_fake_samples) 204 | 205 | pred = utils.load_predictions(wp._workspace.task_outputs["predictions"]) 206 | assert pred == expected_pred 207 | 208 | 209 | def test_execute_train(workdir, output_model_path): 210 | inputs = [ 211 | { 212 | "id": StaticInputIdentifiers.datasamples.value, 213 | "value": str(workdir / "datasamples_unused"), 214 | "multiple": True, 215 | }, 216 | ] 217 | outputs = [ 218 | {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, 219 | ] 220 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] 221 | 222 | assert not output_model_path.exists() 223 | 224 | function.execute(sysargs=["--function-name", "train"] + options) 225 | assert output_model_path.exists() 226 | 227 | function.execute( 228 | sysargs=["--function-name", "train", "--fake-data", "--n-fake-samples", "1", "--outputs", json.dumps(outputs)] 229 | ) 230 | assert output_model_path.exists() 231 | 232 | function.execute(sysargs=["--function-name", "train", "--log-level", "debug"] + options) 233 | assert output_model_path.exists() 234 | 235 | 236 | def test_execute_train_multiple_models(workdir, output_model_path, create_models): 237 | _, model_filenames = create_models 238 | 239 | output_model_path = Path(output_model_path) 240 | 241 | assert not output_model_path.exists() 242 | pred_path = workdir / "pred" 243 | assert not pred_path.exists() 244 | 245 | inputs = [ 246 | {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames 247 | ] 248 | 249 | outputs = [ 250 | {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, 251 | ] 252 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] 253 | 254 | command = ["--function-name", "train"] 255 | command.extend(options) 256 | 257 | function.execute(sysargs=command) 258 | assert output_model_path.exists() 259 | with open(output_model_path, "r") as f: 260 | model = json.load(f) 261 | assert model["value"] == 3 262 | 263 | assert not pred_path.exists() 264 | 265 | 266 | def test_execute_predict(workdir, output_model_path, create_models, valid_opener_script): 267 | _, model_filenames = create_models 268 | pred_path = workdir / "pred" 269 | train_inputs = [ 270 | {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames 271 | ] 272 | 273 | train_outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}] 274 | train_options = ["--inputs", json.dumps(train_inputs), "--outputs", json.dumps(train_outputs)] 275 | 276 | output_model_path = Path(output_model_path) 277 | # first train models 278 | assert not pred_path.exists() 279 | command = ["--function-name", "train"] 280 | command.extend(train_options) 281 | function.execute(sysargs=command) 282 | assert output_model_path.exists() 283 | 284 | # do predict on output model 285 | pred_inputs = [ 286 | {"id": 
InputIdentifiers.opener, "value": valid_opener_script, "multiple": False}, 287 | {"id": InputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, 288 | ] 289 | pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}] 290 | pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)] 291 | 292 | assert not pred_path.exists() 293 | function.execute(sysargs=["--function-name", "predict"] + pred_options) 294 | assert pred_path.exists() 295 | with open(pred_path, "r") as f: 296 | pred = json.load(f) 297 | assert pred == "XXX" 298 | pred_path.unlink() 299 | 300 | # do predict with different model paths 301 | input_models_dir = workdir / "other_models" 302 | input_models_dir.mkdir() 303 | input_model_path = input_models_dir / "supermodel" 304 | shutil.move(output_model_path, input_model_path) 305 | 306 | pred_inputs = [ 307 | {"id": InputIdentifiers.shared, "value": str(input_model_path), "multiple": False}, 308 | {"id": InputIdentifiers.opener, "value": valid_opener_script, "multiple": False}, 309 | ] 310 | pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}] 311 | pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)] 312 | 313 | assert not pred_path.exists() 314 | function.execute(sysargs=["--function-name", "predict"] + pred_options) 315 | assert pred_path.exists() 316 | with open(pred_path, "r") as f: 317 | pred = json.load(f) 318 | assert pred == "XXX" 319 | 320 | 321 | @pytest.mark.parametrize("function_to_run", (no_saved_train, wrong_saved_train)) 322 | def test_model_check(valid_function_workspace, function_to_run): 323 | wp = function.FunctionWrapper(workspace=valid_function_workspace, opener_wrapper=None) 324 | 325 | with pytest.raises(exceptions.MissingFileError): 326 | wp.execute(function=function_to_run) 327 | 328 | 329 | def test_function_not_found(): 330 | with pytest.raises(exceptions.FunctionNotFoundError): 331 | function.execute(sysargs=["--function-name", "imaginary_function"]) 332 | 333 | 334 | def test_function_name_already_register(): 335 | @function.register 336 | def fake_function(): 337 | pass 338 | 339 | with pytest.raises(exceptions.ExistingRegisteredFunctionError): 340 | 341 | @function.register 342 | def fake_function(): 343 | pass 344 | --------------------------------------------------------------------------------
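The duplicate-registration test above only collides because both definitions share the name `fake_function`; the optional `function_name` argument of `register` (added in 0.20.0 per the changelog) lets a single callable be registered under several distinct keys instead. A short sketch of that mechanism, built only from `FunctionRegister.__call__` and `get_registered_functions` as shown in `substratools/function.py` (the function name `score` and the alias `score_v2` are hypothetical):

```py
from substratools import function


def score(inputs, outputs, task_properties):
    pass  # hypothetical no-op body; a real function would write its outputs


function.register(score)  # stored under "score" (function.__name__)
function.register(score, function_name="score_v2")  # same callable, second key

# both keys resolve to the same callable in the registry,
# so either name can be passed to --function-name at execution time
registered = function.register.get_registered_functions()
assert registered["score"] is registered["score_v2"]
```

Registering `score` a second time without a custom `function_name` would raise `ExistingRegisteredFunctionError`, which is exactly what `test_function_name_already_register` asserts.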