├── changes
│   └── .gitkeep
├── .github
│   ├── CODEOWNERS
│   ├── workflows
│   │   ├── publish.yml
│   │   ├── towncrier-changelog.yml
│   │   └── python.yml
│   └── ISSUE_TEMPLATE
│       ├── config.yml
│       └── bug_report.yaml
├── substratools
│   ├── __version__.py
│   ├── __init__.py
│   ├── exceptions.py
│   ├── workspace.py
│   ├── opener.py
│   ├── task_resources.py
│   ├── utils.py
│   └── function.py
├── tests
│   ├── __init__.py
│   ├── test_genericalgo.py
│   ├── test_utils.py
│   ├── utils.py
│   ├── conftest.py
│   ├── test_task_resources.py
│   ├── test_opener.py
│   ├── test_metrics.py
│   ├── test_workflow.py
│   ├── test_aggregatealgo.py
│   ├── test_compositealgo.py
│   └── test_function.py
├── .git-blame-ignore-revs
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Makefile
├── .coveragerc
├── .flake8
├── .pre-commit-config.yaml
├── CONTRIBUTORS.md
├── .gitignore
├── pyproject.toml
├── README.md
├── CHANGELOG.md
├── Substra-logo-white.svg
├── Substra-logo-colour.svg
└── LICENSE
/changes/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @Substra/code-owners
2 |
--------------------------------------------------------------------------------
/substratools/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from tests import utils
2 |
3 | __all__ = ["utils"]
4 |
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | 25ad9f6b8f17422ac8a50cfc5f4604a26ae7f21a
2 | f587640673be0bb0789cc0036b7af12516a1ff9b
3 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | Substra repositories' code of conduct is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/code-of-conduct.html).
2 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Substra repositories' contributing guide is available in the Substra documentation [here](https://docs.substra.org/en/stable/contributing/contributing-guide.html).
2 |
--------------------------------------------------------------------------------
/tests/test_genericalgo.py:
--------------------------------------------------------------------------------
1 | # TODO: As the implementation is going to change from a class to a function
2 | # decorator, these tests will be added once the implementation is stable
3 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test doc
2 |
3 | DOCS_FILEPATH = docs/api.md
4 |
5 | ifeq ($(TAG),)
6 | TAG := $(shell git describe --always --tags)
7 | endif
8 |
9 | test:
10 | 	pytest tests
11 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | # .coveragerc to control coverage.py
2 |
3 | [report]
4 | # Regexes for lines to exclude from consideration
5 | exclude_lines =
6 |
7 |     # Don't complain if tests don't hit defensive assertion code:
8 |     raise AssertionError
9 |     raise NotImplementedError
10 |
11 | show_missing = True
12 |
--------------------------------------------------------------------------------
/substratools/__init__.py:
--------------------------------------------------------------------------------
1 | from substratools.__version__ import __version__
2 |
3 | from . import function
4 | from . import opener
5 | from .function import execute
6 | from .function import load_performance
7 | from .function import register
8 | from .function import save_performance
9 | from .opener import Opener
10 |
11 | __all__ = [
12 |     "__version__",
13 |     "function",
14 |     "opener",
15 |     "Opener",
16 |     "execute",
17 |     "load_performance",
18 |     "register",
19 |     "save_performance",
20 | ]
21 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 |   release:
5 |     types: [published]
6 |
7 | jobs:
8 |   publish:
9 |     runs-on: ubuntu-latest
10 |     steps:
11 |       - uses: actions/checkout@v3
12 |       - uses: actions/setup-python@v5
13 |         with:
14 |           python-version: 3.12
15 |       - name: Install Hatch
16 |         run: pipx install hatch
17 |       - name: Build dist
18 |         run: hatch build
19 |       - name: Publish
20 |         run: hatch publish -u __token__ -a ${{ secrets.PYPI_API_TOKEN }}
21 |
--------------------------------------------------------------------------------
/substratools/exceptions.py:
--------------------------------------------------------------------------------
1 | class InvalidInterfaceError(Exception):
2 |     pass
3 |
4 |
5 | class EmptyInterfaceError(InvalidInterfaceError):
6 |     pass
7 |
8 |
9 | class NotAFileError(Exception):
10 |     pass
11 |
12 |
13 | class MissingFileError(Exception):
14 |     pass
15 |
16 |
17 | class InvalidInputOutputsError(Exception):
18 |     pass
19 |
20 |
21 | class InvalidCLIError(Exception):
22 |     pass
23 |
24 |
25 | class FunctionNotFoundError(Exception):
26 |     pass
27 |
28 |
29 | class ExistingRegisteredFunctionError(Exception):
30 |     pass
31 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | max-complexity = 10
4 | extend-ignore = E203, W503, N802, N803, N806
5 | # W503 is incompatible with Black, see https://github.com/psf/black/pull/36
6 | # E203 must be disabled for flake8 to work with Black.
7 | # See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#id1
8 | # N802, N803 and N806 prevent us from using upper case in variable names, function names and arguments.
9 | exclude =
10 |     .git
11 |     .github
12 |     .dvc
13 |     __pycache__
14 |     .venv
15 |     .mypy_cache
16 |     .pytest_cache
17 |     hubconf.py
18 |     **local-worker
19 |     docs/*
20 |
--------------------------------------------------------------------------------
/.github/workflows/towncrier-changelog.yml:
--------------------------------------------------------------------------------
1 | name: Towncrier changelog
2 |
3 | on:
4 |   workflow_dispatch:
5 |     inputs:
6 |       app_version:
7 |         type: string
8 |         description: 'The version of the app'
9 |         required: true
10 |       branch:
11 |         type: string
12 |         description: 'The branch to update'
13 |         required: true
14 |
15 | jobs:
16 |   test-generate-publish:
17 |     uses: substra/substra-gha-workflows/.github/workflows/towncrier-changelog.yml@main
18 |     secrets: inherit
19 |     with:
20 |       app_version: ${{ inputs.app_version }}
21 |       repo: substra-tools
22 |       branch: ${{ inputs.branch }}
23 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
2 | contact_links:
3 |   - name: Ask a question
4 |     url: https://join.slack.com/t/substra-workspace/shared_invite/zt-1fqnk0nw6-xoPwuLJ8dAPXThfyldX8yA
5 |     about: Don't hesitate to join the Substra community on Slack to ask all your questions!
6 |   - name: User Documentation Improvement
7 |     url: https://github.com/Substra/substra-documentation/issues
8 |     about: For issues related to the User Documentation, please open an issue on the substra-documentation repository
9 |   - name: Feature request
10 |     url: https://github.com/Substra/substra/issues
11 |     about: We centralize feature requests in the substra repository, please open an issue there
12 |
--------------------------------------------------------------------------------
/.github/workflows/python.yml:
--------------------------------------------------------------------------------
1 | name: Python
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - master
7 |       - main
8 |   pull_request:
9 |     branches:
10 |       - master
11 |       - main
12 |
13 | jobs:
14 |   lint:
15 |     name: Lint and tests
16 |     runs-on: ubuntu-latest
17 |     steps:
18 |       - uses: actions/checkout@v2
19 |       - name: Set up python
20 |         uses: actions/setup-python@v2
21 |         with:
22 |           python-version: 3.12
23 |       - name: Install tools
24 |         run: pip install flake8
25 |       - name: Lint
26 |         run: flake8 substratools
27 |       - name: Install substra-tools
28 |         run: pip install -e '.[dev]'
29 |       - name: Test
30 |         run: |
31 |           make test
32 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: ^docs/
2 | repos:
3 |   - repo: https://github.com/psf/black
4 |     rev: 22.3.0
5 |     hooks:
6 |       - id: black
7 |         args: [--line-length=120]
8 |
9 |   - repo: https://github.com/pycqa/flake8
10 |     rev: 4.0.1
11 |     hooks:
12 |       - id: flake8
13 |         additional_dependencies: [pep8-naming, flake8-bugbear]
14 |         exclude: docs/api.md
15 |
16 |   - repo: https://github.com/pycqa/isort
17 |     rev: 5.12.0
18 |     hooks:
19 |       - id: isort
20 |         args: [--line-length=120]
21 |         name: isort (python)
22 |
23 |   - repo: https://github.com/pre-commit/pre-commit-hooks
24 |     rev: v3.2.0
25 |     hooks:
26 |       - id: trailing-whitespace
27 |       - id: end-of-file-fixer
28 |       - id: debug-statements
29 |       - id: check-added-large-files
30 |
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | This file lists the people who have made significant contributions to the Substra tools, in chronological order. Please add your contribution at the bottom of this document in the following format: name (N), email (E), description of work (W) and date (D).
2 |
3 | To have your contribution listed, your work must meet the minimum [threshold of originality](https://en.wikipedia.org/wiki/Threshold_of_originality), which will be evaluated by the maintainers of the repository.
4 |
5 | Thank you for your contribution, your work is greatly appreciated!
6 |
7 | --- Example ---
8 |
9 | - N: John Doe
10 | - E: john.doe@owkin.com
11 | - W: Integrated new feature
12 | - D: 02/02/2023
13 |
14 | ---
15 |
16 | Copyright (c) 2018-present Owkin Inc. All rights reserved.
17 |
18 | All other contributions:
19 | Copyright (c) 2023 to the respective contributors.
20 | All rights reserved.
21 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pytest
4 |
5 | from substratools import exceptions
6 | from substratools.opener import Opener
7 | from substratools.utils import get_logger
8 | from substratools.utils import import_module
9 | from substratools.utils import load_interface_from_module
10 |
11 |
12 | def test_invalid_interface():
13 |     code = """
14 | def score():
15 |     pass
16 | """
17 |     import_module("score", code)
18 |     with pytest.raises(exceptions.InvalidInterfaceError):
19 |         load_interface_from_module("score", interface_class=Opener)
20 |
21 |
22 | @pytest.fixture
23 | def syspaths():
24 |     copy = sys.path[:]
25 |     yield sys.path
26 |     sys.path = copy
27 |
28 |
29 | def test_empty_module(tmpdir, syspaths):
30 |     with tmpdir.as_cwd():
31 |         # python allows importing an empty directory
32 |         # check that the error message would be helpful for debugging purposes
33 |         tmpdir.mkdir("foomod")
34 |         syspaths.append(str(tmpdir))
35 |
36 |         with pytest.raises(exceptions.EmptyInterfaceError):
37 |             load_interface_from_module("foomod", interface_class=Opener)
38 |
39 |
40 | def test_get_logger(capfd):
41 |     logger = get_logger("test")
42 |     logger.info("message")
43 |     captured = capfd.readouterr()
44 |     assert "INFO   substratools.test - message" in captured.err
45 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from enum import Enum
3 | from os import PathLike
4 | from typing import Any
5 | from typing import List
6 |
7 | from substratools.task_resources import StaticInputIdentifiers
8 |
9 |
10 | class InputIdentifiers(str, Enum):
11 |     local = "local"
12 |     shared = "shared"
13 |     predictions = "predictions"
14 |     opener = StaticInputIdentifiers.opener.value
15 |     datasamples = StaticInputIdentifiers.datasamples.value
16 |     rank = StaticInputIdentifiers.rank.value
17 |
18 |
19 | class OutputIdentifiers(str, Enum):
20 |     local = "local"
21 |     shared = "shared"
22 |     predictions = "predictions"
23 |     performance = "performance"
24 |
25 |
26 | def load_models(paths: List[PathLike]) -> List[dict]:
27 |     models = []
28 |     for model_path in paths:
29 |         with open(model_path, "r") as f:
30 |             models.append(json.load(f))
31 |
32 |     return models
33 |
34 |
35 | def load_model(path: PathLike):
36 |     if path:
37 |         with open(path, "r") as f:
38 |             return json.load(f)
39 |
40 |
41 | def save_model(model: dict, path: PathLike):
42 |     with open(path, "w") as f:
43 |         json.dump(model, f)
44 |
45 |
46 | def save_predictions(predictions: Any, path: PathLike):
47 |     with open(path, "w") as f:
48 |         json.dump(predictions, f)
49 |
50 |
51 | def load_predictions(path: PathLike) -> Any:
52 |     with open(path, "r") as f:
53 |         predictions = json.load(f)
54 |     return predictions
55 |
56 |
57 | def no_save_model(path, model):
58 |     # do not save the model at all
59 |     pass
60 |
61 |
62 | def wrong_save_model(model, path):
63 |     # simulate numpy.save behavior: append an extension to the requested path
64 |     with open(path + ".npy", "w") as f:
65 |         json.dump(model, f)
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | docs/*.tmp
107 |
108 | .vscode
109 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.build.targets.sdist]
6 | exclude = ["tests*"]
7 |
8 | [tool.hatch.version]
9 | path = "substratools/__version__.py"
10 |
11 | [project]
12 | name = "substratools"
13 | description = "Python tools to submit functions on the Substra platform"
14 | dynamic = ["version"]
15 | readme = "README.md"
16 | requires-python = ">= 3.10"
17 | dependencies = []
18 | keywords = ["substra"]
19 | classifiers = [
20 | "Intended Audience :: Developers",
21 | "Topic :: Utilities",
22 | "Natural Language :: English",
23 | "Operating System :: OS Independent",
24 | "Programming Language :: Python :: 3.10",
25 | "Programming Language :: Python :: 3.11",
26 | "Programming Language :: Python :: 3.12",
27 | ]
28 | license = { file = "LICENSE" }
29 | authors = [{ name = "Owkin, Inc." }]
30 |
31 | [project.optional-dependencies]
32 | dev = ["flake8", "pytest", "pytest-cov", "pytest-mock", "numpy", "towncrier"]
33 |
34 | [project.urls]
35 | Documentation = "https://docs.substra.org/en/stable/"
36 | Repository = "https://github.com/Substra/substra-tools"
37 | Changelog = "https://github.com/Substra/substra-tools/blob/main/CHANGELOG.md"
38 |
39 | [tool.black]
40 | line-length = 120
41 | target-version = ['py39']
42 |
43 | [tool.isort]
44 | filter_files = true
45 | force_single_line = true
46 | line_length = 120
47 | profile = "black"
48 |
49 | [tool.pytest.ini_options]
50 | addopts = "-v --cov=substratools"
51 |
52 | [tool.towncrier]
53 | directory = "changes"
54 | filename = "CHANGELOG.md"
55 | start_string = "<!-- towncrier release notes start -->\n"
56 | underlines = ["", "", ""]
57 | title_format = "## [{version}](https://github.com/Substra/substra-tools/releases/tag/{version}) - {project_date}"
58 | issue_format = "[#{issue}](https://github.com/Substra/substra-tools/pull/{issue})"
59 | [tool.towncrier.fragment.added]
60 | [tool.towncrier.fragment.removed]
61 | [tool.towncrier.fragment.changed]
62 | [tool.towncrier.fragment.fixed]
63 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | from pathlib import Path
5 | from uuid import uuid4
6 |
7 | import pytest
8 |
9 | from substratools.task_resources import TaskResources
10 | from substratools.utils import import_module
11 | from substratools.workspace import FunctionWorkspace
12 | from tests.utils import OutputIdentifiers
13 |
14 |
15 | @pytest.fixture
16 | def workdir(tmp_path):
17 |     d = tmp_path / "substra-workspace"
18 |     d.mkdir()
19 |     return d
20 |
21 |
22 | @pytest.fixture(autouse=True)
23 | def patch_cwd(monkeypatch, workdir):
24 |     # this is needed to ensure the workspace is located in a tmpdir
25 |     def getcwd():
26 |         return str(workdir)
27 |
28 |     monkeypatch.setattr(os, "getcwd", getcwd)
29 |
30 |
31 | @pytest.fixture()
32 | def valid_opener_code():
33 |     return """
34 | import json
35 | from substratools import Opener
36 |
37 | class FakeOpener(Opener):
38 |     def get_data(self, folder):
39 |         return 'X', list(range(0, 3))
40 |
41 |     def fake_data(self, n_samples):
42 |         return ['Xfake'] * n_samples, [0] * n_samples
43 | """
44 |
45 |
46 | @pytest.fixture()
47 | def valid_opener(valid_opener_code):
48 |     import_module("opener", valid_opener_code)
49 |     yield
50 |     del sys.modules["opener"]
51 |
52 |
53 | @pytest.fixture()
54 | def valid_opener_script(workdir, valid_opener_code):
55 |     opener_path = workdir / "my_opener.py"
56 |     opener_path.write_text(valid_opener_code)
57 |
58 |     return str(opener_path)
59 |
60 |
61 | @pytest.fixture(autouse=True)
62 | def output_model_path(workdir: Path) -> str:
63 |     path = workdir / str(uuid4())
64 |     yield path
65 |     if path.exists():
66 |         os.remove(path)
67 |
68 |
69 | @pytest.fixture(autouse=True)
70 | def output_model_path_2(workdir: Path) -> str:
71 |     path = workdir / str(uuid4())
72 |     yield path
73 |     if path.exists():
74 |         os.remove(path)
75 |
76 |
77 | @pytest.fixture()
78 | def valid_function_workspace(output_model_path: str) -> FunctionWorkspace:
79 |     workspace_outputs = TaskResources(
80 |         json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}])
81 |     )
82 |
83 |     workspace = FunctionWorkspace(outputs=workspace_outputs)
84 |
85 |     return workspace
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Substra-tools
2 |
3 | <!-- Substra logo and badges -->
4 |
18 | Substra is open source federated learning (FL) software. This repository, substra-tools, is a Python package that defines the base class for the dataset opener script and the wrappers used to execute functions submitted on the platform.
19 |
20 |
21 | ## Getting started
22 |
23 | To install the substratools Python package, run the following command:
24 |
25 | ```sh
26 | pip install substratools
27 | ```
28 |
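To check your installation, you can define a minimal data opener. The snippet below is only a sketch: it implements the two abstract methods of `substratools.Opener` (`get_data` and `fake_data`) with placeholder logic; replace the bodies with your own data loading code.

```python
import substratools as tools


class MyOpener(tools.Opener):
    def get_data(self, folders):
        # one data sample per folder; this sketch simply returns the folder paths
        return folders

    def fake_data(self, n_samples):
        # deterministic fake samples for offline testing
        return ["fake"] * n_samples
```
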
29 | ## Developers
30 |
31 | Clone the repository:
32 |
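```sh
git clone https://github.com/Substra/substra-tools.git
cd substra-tools
```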
33 |
34 | ### Setup
35 |
36 | To set up the project in development mode, run:
37 |
38 | ```sh
39 | pip install -e ".[dev]"
40 | ```
41 |
42 | To run all tests, use the following command:
43 |
44 | ```sh
45 | make test
46 | ```
47 |
48 | ## How to generate the changelog
49 |
50 | The changelog is managed with [towncrier](https://towncrier.readthedocs.io/en/stable/index.html).
51 | To add a new entry in the changelog, add a file in the `changes` folder. The file name should have the following structure:
52 | `<unique_id>.<change_type>`.
53 | The `unique_id` is a unique identifier; we currently use the PR number.
54 | The `change_type` can be one of the following types: `added`, `changed`, `removed`, `fixed`.
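For example, a fragment describing a bug fix submitted in PR 123 would be named `123.fixed`.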
55 |
56 | To generate the changelog (for example during a release), use the following command (you must have the dev dependencies installed):
57 |
58 | ```sh
59 | towncrier build --version=<version>
60 | ```
61 |
62 | You can use the `--draft` option to see what would be generated without actually writing to the changelog (and without removing the fragments).
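For example:

```sh
towncrier build --version=<version> --draft
```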
63 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: Report bug or performance issue
3 | title: "BUG: "
4 | labels: [Bug]
5 |
6 | body:
7 |   - type: markdown
8 |     attributes:
9 |       value: |
10 |         Thanks for taking the time to fill out this bug report!
11 |   - type: textarea
12 |     id: context
13 |     attributes:
14 |       label: What are you trying to do?
15 |       description: >
16 |         Please provide some context on what you are trying to achieve.
17 |       placeholder:
18 |     validations:
19 |       required: true
20 |   - type: textarea
21 |     id: issue-description
22 |     attributes:
23 |       label: Issue Description (what is happening?)
24 |       description: >
25 |         Please provide a description of the issue.
26 |     validations:
27 |       required: true
28 |   - type: textarea
29 |     id: expected-behavior
30 |     attributes:
31 |       label: Expected Behavior (what should happen?)
32 |       description: >
33 |         Please describe or show a code example of the expected behavior.
34 |     validations:
35 |       required: true
36 |   - type: textarea
37 |     id: example
38 |     attributes:
39 |       label: Reproducible Example
40 |       description: >
41 |         If possible, provide a reproducible example.
42 |       render: python
43 |
44 |   - type: textarea
45 |     id: os-version
46 |     attributes:
47 |       label: Operating system
48 |       description: >
49 |         Which operating system are you using? (Provide the version number)
50 |     validations:
51 |       required: true
52 |   - type: textarea
53 |     id: python-version
54 |     attributes:
55 |       label: Python version
56 |       description: >
57 |         Which Python version are you using?
58 |       placeholder: >
59 |         python --version
60 |     validations:
61 |       required: true
62 |   - type: textarea
63 |     id: substra-version
64 |     attributes:
65 |       label: Installed Substra versions
66 |       description: >
67 |         Which version of `substrafl` / `substra` / `substra-tools` are you using?
68 |         You can check if they are compatible in the [compatibility table](https://docs.substra.org/en/stable/additional/release.html#compatibility-table).
69 |       placeholder: >
70 |         pip freeze | grep substra
71 |       render: python
72 |     validations:
73 |       required: true
74 |   - type: textarea
75 |     id: dependencies-version
76 |     attributes:
77 |       label: Installed versions of dependencies
78 |       description: >
79 |         Please provide the versions of dependencies which might be relevant to your issue (e.g. `helm` and `skaffold` versions for a deployment issue, `numpy` and `pytorch` versions for an algorithmic issue).
80 |
81 |
82 |   - type: textarea
83 |     id: logs
84 |     attributes:
85 |       label: Logs / Stacktrace
86 |       description: >
87 |         Please copy-paste here any log and/or stacktrace that might be relevant. Remove confidential and personal information if necessary.
88 |
--------------------------------------------------------------------------------
/substratools/workspace.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import os
3 |
4 |
5 | def makedir_safe(path):
6 |     """Create dir (no failure)."""
7 |     try:
8 |         os.makedirs(path)
9 |     except (FileExistsError, PermissionError):
10 |         pass
11 |
12 |
13 | DEFAULT_INPUT_DATA_FOLDER_PATH = "data/"
14 | DEFAULT_INPUT_PREDICTIONS_PATH = "pred/pred"
15 | DEFAULT_OUTPUT_PERF_PATH = "pred/perf.json"
16 | DEFAULT_LOG_PATH = "model/log_model.log"
17 | DEFAULT_CHAINKEYS_PATH = "chainkeys/"
18 |
19 |
20 | class Workspace(abc.ABC):
21 |     """Filesystem workspace for task execution."""
22 |
23 |     def __init__(self, dirpath=None):
24 |         self._workdir = dirpath if dirpath else os.getcwd()
25 |
26 |     def _get_default_path(self, path):
27 |         return os.path.join(self._workdir, path)
28 |
29 |     def _get_default_subpaths(self, path):
30 |         rootpath = os.path.join(self._workdir, path)
31 |         if os.path.isdir(rootpath):
32 |             return [
33 |                 os.path.join(rootpath, subfolder)
34 |                 for subfolder in os.listdir(rootpath)
35 |                 if os.path.isdir(os.path.join(rootpath, subfolder))
36 |             ]
37 |         return []
38 |
39 |
40 | class OpenerWorkspace(Workspace):
41 |     """Filesystem workspace required by the opener."""
42 |
43 |     def __init__(
44 |         self,
45 |         dirpath=None,
46 |         input_data_folder_paths=None,
47 |     ):
48 |         super().__init__(dirpath=dirpath)
49 |
50 |         assert input_data_folder_paths is None or isinstance(input_data_folder_paths, list)
51 |
52 |         self.input_data_folder_paths = input_data_folder_paths or self._get_default_subpaths(
53 |             DEFAULT_INPUT_DATA_FOLDER_PATH
54 |         )
55 |
56 |
57 | class FunctionWorkspace(Workspace):
58 |     """Filesystem workspace for user defined function execution."""
59 |
60 |     def __init__(
61 |         self,
62 |         dirpath=None,
63 |         log_path=None,
64 |         chainkeys_path=None,
65 |         inputs=None,
66 |         outputs=None,
67 |     ):
68 |
69 |         super().__init__(dirpath=dirpath)
70 |
71 |         self.input_data_folder_paths = (
72 |             self._get_default_subpaths(DEFAULT_INPUT_DATA_FOLDER_PATH)
73 |             if inputs is None
74 |             else inputs.input_data_folder_paths
75 |         )
76 |
77 |         self.log_path = log_path or self._get_default_path(DEFAULT_LOG_PATH)
78 |         self.chainkeys_path = chainkeys_path or self._get_default_path(DEFAULT_CHAINKEYS_PATH)
79 |
80 |         self.opener_path = inputs.opener_path if inputs else None
81 |
82 |         self.task_inputs = inputs.formatted_dynamic_resources if inputs else {}
83 |         self.task_outputs = outputs.formatted_dynamic_resources if outputs else {}
84 |
85 |         dirs = [
86 |             self.chainkeys_path,
87 |         ]
88 |         paths = [
89 |             self.log_path,
90 |         ]
91 |
92 |         dirs.extend([os.path.dirname(p) for p in paths])
93 |         for d in dirs:
94 |             if d:
95 |                 makedir_safe(d)
96 |
--------------------------------------------------------------------------------
/tests/test_task_resources.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import pytest
4 |
5 | from substratools.exceptions import InvalidCLIError
6 | from substratools.exceptions import InvalidInputOutputsError
7 | from substratools.task_resources import StaticInputIdentifiers
8 | from substratools.task_resources import TaskResources
9 |
10 | _VALID_RESOURCES = [
11 |     {"id": "foo", "value": "bar", "multiple": True},
12 |     {"id": "foo", "value": "babar", "multiple": True},
13 |     {"id": "fofo", "value": "bar", "multiple": False},
14 | ]
15 | _VALID_VALUES = {"foo": {"value": ["bar", "babar"], "multiple": True}, "fofo": {"value": ["bar"], "multiple": False}}
16 |
17 |
18 | @pytest.mark.parametrize(
19 |     "invalid_arg",
20 |     (
21 |         {"foo": "barr"},
22 |         "foo and bar",
23 |         ["foo", "barr"],
24 |         [{"foo": "bar"}],
25 |         [{"foo": "bar"}, {"id": "foo", "value": "bar", "multiple": True}],
26 |         # [{_RESOURCE_ID: "foo", _RESOURCE_VALUE: "some path", _RESOURCE_MULTIPLE: "str"}],
27 |     ),
28 | )
29 | def test_task_resources_invalid_args(invalid_arg):
30 |     with pytest.raises(InvalidCLIError):
31 |         TaskResources(json.dumps(invalid_arg))
32 |
33 |
34 | @pytest.mark.parametrize(
35 |     "valid_arg,expected",
36 |     [
37 |         ([], {}),
38 |         ([{"id": "foo", "value": "bar", "multiple": True}], {"foo": {"value": ["bar"], "multiple": True}}),
39 |         (
40 |             [{"id": "foo", "value": "bar", "multiple": True}, {"id": "foo", "value": "babar", "multiple": True}],
41 |             {"foo": {"value": ["bar", "babar"], "multiple": True}},
42 |         ),
43 |         (_VALID_RESOURCES, _VALID_VALUES),
44 |     ],
45 | )
46 | def test_task_resources_values(valid_arg, expected):
47 |     assert TaskResources(json.dumps(valid_arg))._values == expected
48 |
49 |
50 | @pytest.mark.parametrize(
51 |     "static_resource_id",
52 |     (
53 |         StaticInputIdentifiers.chainkeys.value,
54 |         StaticInputIdentifiers.datasamples.value,
55 |         StaticInputIdentifiers.opener.value,
56 |     ),
57 | )
58 | def test_task_static_resources(static_resource_id):
59 |     "checks that the static keys opener, datasamples and chainkeys are excluded from the dynamic resources"
60 |
61 |     resources = TaskResources(
62 |         json.dumps(_VALID_RESOURCES + [{"id": static_resource_id, "value": "foo", "multiple": False}])
63 |     )
64 |     assert resources.formatted_dynamic_resources.keys() == _VALID_VALUES.keys()
65 |
66 |
67 | @pytest.mark.parametrize("key", tuple(_VALID_VALUES.keys()))
68 | def test_get_value(key):
69 |     "get_value returns a list of paths for a multiple resource and a single path for a non multiple one"
70 |     expected = _VALID_VALUES[key]["value"]
71 |
72 |     if not _VALID_VALUES[key]["multiple"]:
73 |         expected = expected[0]
74 |
75 |     assert TaskResources(json.dumps(_VALID_RESOURCES)).get_value(key) == expected
76 |
77 |
78 | def test_multiple_resource_error():
79 |     "a non multiple resource can't have multiple values"
80 |
81 |     with pytest.raises(InvalidInputOutputsError):
82 |         TaskResources(
83 |             json.dumps(
84 |                 [
85 |                     {"id": "foo", "value": "bar", "multiple": False},
86 |                     {"id": "foo", "value": "babar", "multiple": False},
87 |                 ]
88 |             )
89 |         )
87 |
--------------------------------------------------------------------------------
/tests/test_opener.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from substratools import exceptions
6 | from substratools.opener import Opener
7 | from substratools.opener import OpenerWrapper
8 | from substratools.opener import load_from_module
9 | from substratools.utils import import_module
10 | from substratools.utils import load_interface_from_module
11 | from substratools.workspace import DEFAULT_INPUT_DATA_FOLDER_PATH
12 |
13 |
14 | @pytest.fixture
15 | def tmp_cwd(tmp_path):
16 |     # create a temporary current working directory
17 |     new_dir = tmp_path / "workspace"
18 |     new_dir.mkdir()
19 |
20 |     old_dir = os.getcwd()
21 |     os.chdir(new_dir)
22 |
23 |     yield new_dir
24 |
25 |     os.chdir(old_dir)
26 |
27 |
28 | def test_load_opener_not_found(tmp_cwd):
29 |     with pytest.raises(ImportError):
30 |         load_from_module()
31 |
32 |
33 | def test_load_invalid_opener(tmp_cwd):
34 |     invalid_script = """
35 | def get_data():
36 |     raise NotImplementedError
37 | """
38 |
39 |     import_module("opener", invalid_script)
40 |
41 |     with pytest.raises(exceptions.InvalidInterfaceError):
42 |         load_from_module()
43 |
44 |
45 | def test_load_opener_as_class(tmp_cwd):
46 |     script = """
47 | from substratools import Opener
48 | class MyOpener(Opener):
49 |     def get_data(self, folders):
50 |         return 'data_class'
51 |     def fake_data(self, n_samples):
52 |         return 'fake_data'
53 | """
54 |
55 |     import_module("opener", script)
56 |
57 |     o = load_from_module()
58 |     assert o.get_data() == "data_class"
59 |
60 |
61 | def test_load_opener_from_path(tmp_cwd, valid_opener_code):
62 |     dirpath = tmp_cwd / "myopener"
63 |     dirpath.mkdir()
64 |     path = dirpath / "my_opener.py"
65 |     path.write_text(valid_opener_code)
66 |
67 |     interface = load_interface_from_module(
68 |         "opener",
69 |         interface_class=Opener,
70 |         interface_signature=None,  # XXX does not support interface for debugging
71 |         path=path,
72 |     )
73 |     o = OpenerWrapper(interface, workspace=None)
74 |     assert o.get_data()[0] == "X"
75 |
76 |
77 | def test_load_opener_from_path_error_with_inheritance(tmp_cwd):
78 |     wrong_opener_code = """
79 | import json
80 | from substratools import Opener
81 |
82 | class FakeOpener(Opener):
83 |     def get_data(self, folder):
84 |         return 'X', list(range(0, 3))
85 |
86 |     def fake_data(self, n_samples):
87 |         return ['Xfake'] * n_samples, [0] * n_samples
88 |
89 | class FinalOpener(FakeOpener):
90 |     def __init__(self):
91 |         super().__init__()
92 | """
93 |     dirpath = tmp_cwd / "myopener"
94 |     dirpath.mkdir()
95 |     path = dirpath / "my_opener.py"
96 |     path.write_text(wrong_opener_code)
97 |
98 |     with pytest.raises(exceptions.InvalidInterfaceError):
99 |         load_interface_from_module(
100 |             "opener",
101 |             interface_class=Opener,
102 |             interface_signature=None,  # XXX does not support interface for debugging
103 |             path=path,
104 |         )
105 |
106 |
107 | def test_opener_check_folders(tmp_cwd):
108 |     script = """
109 | from substratools import Opener
110 | class MyOpener(Opener):
111 |     def get_data(self, folders):
112 |         assert len(folders) == 5
113 |         return 'data_class'
114 |     def fake_data(self, n_samples):
115 |         return 'fake_data_class'
116 | """
117 |
118 |     import_module("opener", script)
119 |
120 |     o = load_from_module()
121 |
122 |     # create some data folders
123 |     data_root_path = os.path.join(o._workspace._workdir, DEFAULT_INPUT_DATA_FOLDER_PATH)
124 |     data_paths = [os.path.join(data_root_path, str(i)) for i in range(5)]
125 |     for p in data_paths:
126 |         os.makedirs(p)
127 |
128 |     o._workspace.input_data_folder_paths = data_paths
129 |     assert o.get_data() == "data_class"
129 |
--------------------------------------------------------------------------------
/substratools/opener.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 | import os
4 | import types
5 | from typing import Optional
6 |
7 | from substratools import exceptions
8 | from substratools import utils
9 | from substratools.workspace import OpenerWorkspace
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | REQUIRED_FUNCTIONS = set(
15 |     [
16 |         "get_data",
17 |         "fake_data",
18 |     ]
19 | )
20 |
21 |
22 | class Opener(abc.ABC):
23 |     """Dataset opener abstract base class.
24 |
25 |     To define a new opener script, subclass this class and implement the
26 |     following abstract methods:
27 |
28 |     - #Opener.get_data()
29 |     - #Opener.fake_data()
30 |
31 |     # Example
32 |
33 |     ```python
34 |     import os
35 |     import pandas as pd
36 |     import string
37 |     import numpy as np
38 |
39 |     import substratools as tools
40 |
41 |     class DummyOpener(tools.Opener):
42 |         def get_data(self, folders):
43 |             return [
44 |                 pd.read_csv(os.path.join(folder, 'train.csv'))
45 |                 for folder in folders
46 |             ]
47 |
48 |         def fake_data(self, n_samples):
49 |             return []  # compute random fake data
50 |     ```
51 |
52 |     # How to test an opener script locally
53 |
54 |     An opener can be imported and used in Python scripts like any other class.
55 |
56 |     For example, assuming that you have a local file named `opener.py` that contains
57 |     an `Opener` named `MyOpener`:
58 |
59 |     ```python
60 |     import os
61 |     from opener import MyOpener
62 |
63 |     folders = os.listdir('./sandbox/data_samples/')
64 |
65 |     o = MyOpener()
66 |     loaded_datasamples = o.get_data(folders)
67 |     ```
68 |     """
69 |
70 |     @abc.abstractmethod
71 |     def get_data(self, folders):
72 |         """Datasamples loader
73 |
74 |         # Arguments
75 |
76 |         folders: list of folders. Each folder represents a data sample.
77 |
78 |         # Returns
79 |
80 |         data: data object.
81 |         """
82 |         raise NotImplementedError
83 |
84 |     @abc.abstractmethod
85 |     def fake_data(self, n_samples):
86 |         """Generate fake loaded datasamples for offline testing.
87 |
88 |         # Arguments
89 |
90 |         n_samples (int): number of samples to return
91 |
92 |         # Returns
93 |
94 |         data: data object.
95 |         """
96 |         raise NotImplementedError
97 |
98 |
99 | class OpenerWrapper(object):
100 |     """Internal wrapper to call the opener interface."""
101 |
102 |     def __init__(self, interface, workspace=None):
103 |         assert isinstance(interface, Opener) or isinstance(interface, types.ModuleType)
104 |
105 |         self._workspace = workspace or OpenerWorkspace()
106 |         self._interface = interface
107 |
108 |     @property
109 |     def data_folder_paths(self):
110 |         return self._workspace.input_data_folder_paths
111 |
112 |     def get_data(self, fake_data=False, n_fake_samples=None):
113 |         if fake_data:
114 |             logger.info("loading data from fake data")
115 |             return self._interface.fake_data(n_samples=n_fake_samples)
116 |         else:
117 |             logger.info("loading data from '{}'".format(self.data_folder_paths))
118 |             return self._interface.get_data(self.data_folder_paths)
119 |
120 |     def _assert_output_exists(self, path, key):
121 |
122 |         if os.path.isdir(path):
123 |             raise exceptions.NotAFileError(f"Expected output file at {path}, found dir for output `{key}`")
124 |         if not os.path.isfile(path):
125 |             raise exceptions.MissingFileError(f"Output file {path} used to save argument `{key}` does not exist.")
126 |
127 |
128 | def load_from_module(workspace=None) -> Optional[OpenerWrapper]:
129 |     """Load the opener interface.
130 |
131 |     If a workspace is given, the associated opener will be returned. This means that if no
132 |     opener_path is defined within the workspace, no opener will be returned.
133 |     If no workspace is given, the opener interface will be loaded directly as a module.
134 |
135 |     Returns an OpenerWrapper instance.
136 |     """
137 |     if workspace is None:
138 |         # import from module
139 |         path = None
140 |
141 |     elif workspace.opener_path is None:
142 |         # no opener within this workspace
143 |         return None
144 |
145 |     else:
146 |         # import opener from workspace specified path
147 |         path = workspace.opener_path
148 |
149 |     interface = utils.load_interface_from_module(
150 |         "opener",
151 |         interface_class=Opener,
152 |         interface_signature=None,  # XXX does not support interface for debugging
153 |         path=path,
154 |     )
155 |     return OpenerWrapper(interface, workspace=workspace)
156 |
--------------------------------------------------------------------------------
/substratools/task_resources.py:
--------------------------------------------------------------------------------
1 | import json
2 | from enum import Enum
3 | from typing import Dict
4 | from typing import List
5 | from typing import Optional
6 | from typing import Union
7 |
8 | from substratools import exceptions
9 |
10 |
11 | class StaticInputIdentifiers(str, Enum):
12 |     opener = "opener"
13 |     datasamples = "datasamples"
14 |     chainkeys = "chainkeys"
15 |     rank = "rank"
16 |
17 |
18 | _RESOURCE_ID = "id"
19 | _RESOURCE_VALUE = "value"
20 | _RESOURCE_MULTIPLE = "multiple"
21 |
22 |
23 | def _check_resources_format(resource_list):
24 |
25 |     _required_keys = set((_RESOURCE_ID, _RESOURCE_VALUE, _RESOURCE_MULTIPLE))
26 |     _error_message = (
27 |         "`--inputs` and `--outputs` args should be a JSON serialized list of dicts, each dict containing "
28 |         f"the following keys: {_required_keys}. {_RESOURCE_ID} and {_RESOURCE_VALUE} must be strings, "
29 |         f"{_RESOURCE_MULTIPLE} must be a bool."
30 |     )
31 |
32 |     if not isinstance(resource_list, list):
33 |         raise exceptions.InvalidCLIError(_error_message)
34 |
35 |     if not all([isinstance(d, dict) for d in resource_list]):
36 |         raise exceptions.InvalidCLIError(_error_message)
37 |
38 |     if not all([set(d.keys()) == _required_keys for d in resource_list]):
39 |         raise exceptions.InvalidCLIError(_error_message)
40 |
41 |     if not all([isinstance(d[_RESOURCE_MULTIPLE], bool) for d in resource_list]):
42 |         raise exceptions.InvalidCLIError(_error_message)
43 |
44 |     if not all([isinstance(d[_RESOURCE_ID], str) for d in resource_list]):
45 |         raise exceptions.InvalidCLIError(_error_message)
46 |
47 |     if not all([isinstance(d[_RESOURCE_VALUE], str) for d in resource_list]):
48 |         raise exceptions.InvalidCLIError(_error_message)
49 |
50 |
51 | def _check_resources_multiplicity(resource_dict):
52 |     for k, v in resource_dict.items():
53 |         if not v[_RESOURCE_MULTIPLE] and len(v[_RESOURCE_VALUE]) > 1:
54 |             raise exceptions.InvalidInputOutputsError(f"There is more than one path for the non multiple resource {k}")
55 |
56 |
57 | class TaskResources:
58 |     """TaskResources is created from stdin to provide a nice abstraction over inputs/outputs."""
59 |
60 |     _values: Dict[str, dict]
61 |
62 |     def __init__(self, argstr: str) -> None:
63 |         """Argstr is expected to be a JSON array like:
64 |         [
65 |             {"id": "local", "value": "/sandbox/output/model/uuid", "multiple": false},
66 |             {"id": "shared", ...}
67 |         ]
68 |         """
69 |         self._values = {}
70 |         resource_list = json.loads(argstr.replace("\\", "/"))
71 |
72 |         _check_resources_format(resource_list)
73 |
74 |         for item in resource_list:
75 |             self._values.setdefault(
76 |                 item[_RESOURCE_ID], {_RESOURCE_VALUE: [], _RESOURCE_MULTIPLE: item[_RESOURCE_MULTIPLE]}
77 |             )
78 |             self._values[item[_RESOURCE_ID]][_RESOURCE_VALUE].append(item[_RESOURCE_VALUE])
79 |
80 |         _check_resources_multiplicity(self._values)
81 |
82 |         self.opener_path = self.get_value(StaticInputIdentifiers.opener.value)
83 |         self.input_data_folder_paths = self.get_value(StaticInputIdentifiers.datasamples.value)
84 |         self.chainkeys_path = self.get_value(StaticInputIdentifiers.chainkeys.value)
85 |
86 |     def get_value(self, key: str) -> Optional[Union[List[str], str]]:
87 |         """Returns the value for a given key. Returns None if there is no matching resource.
88 |         Raises if there is a mismatch between the given multiplicity and the number of returned
89 |         elements.
90 |
91 |         If multiple is True, returns a list, else returns a single value.
92 |         """
93 |         if key not in self._values:
94 |             return None
95 |
96 |         val = self._values[key][_RESOURCE_VALUE]
97 |         multiple = self._values[key][_RESOURCE_MULTIPLE]
98 |
99 |         if multiple:
100 |             return val
101 |
102 |         return val[0]
103 |
104 |     @property
105 |     def formatted_dynamic_resources(self) -> Dict[str, Union[List[str], str]]:
106 |         """Returns all the resources (except the datasamples, the opener and the chainkeys) under the user format:
107 |         a dict where each input is an element where
108 |         - the key is the user identifier
109 |         - the value is a list of paths for multiple resources and a path for non multiple resources
110 |         """
111 |
112 |         return {
113 |             k: self.get_value(k)
114 |             for k in self._values.keys()
115 |             if k
116 |             not in (
117 |                 StaticInputIdentifiers.opener.value,
118 |                 StaticInputIdentifiers.datasamples.value,
119 |                 StaticInputIdentifiers.chainkeys.value,
120 |             )
121 |         }
122 |
--------------------------------------------------------------------------------
/substratools/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import importlib.util
3 | import inspect
4 | import logging
5 | import os
6 | import sys
7 | import time
8 |
9 | from substratools import exceptions
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | MAPPING_LOG_LEVEL = {
14 |     "debug": logging.DEBUG,
15 |     "info": logging.INFO,
16 |     "warning": logging.WARNING,
17 |     "error": logging.ERROR,
18 |     "critical": logging.CRITICAL,
19 | }
20 |
21 |
22 | def configure_logging(path=None, log_level="info"):
23 |     level = MAPPING_LOG_LEVEL[log_level]
24 |
25 |     formatter = logging.Formatter(fmt="%(asctime)s %(levelname)-6s %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
26 |
27 |     h = logging.StreamHandler()
28 |     h.setLevel(level)
29 |     h.setFormatter(formatter)
30 |
31 |     root = logging.getLogger("substratools")
32 |     root.setLevel(level)
33 |     root.addHandler(h)
34 |
35 |     if path and level == logging.DEBUG:
36 |         fh = logging.FileHandler(path)
37 |         fh.setLevel(level)
38 |         fh.setFormatter(formatter)
39 |
40 |         root.addHandler(fh)
41 |
42 |
43 | def get_logger(name, path=None, log_level="info"):
44 |     new_logger = logging.getLogger(f"substratools.{name}")
45 |     configure_logging(path, log_level)
46 |     return new_logger
47 |
48 |
49 | class Timer(object):
50 |     """This decorator prints the execution time of the decorated function."""
51 |
52 |     def __init__(self, module_logger):
53 |         self.module_logger = module_logger
54 |
55 |     def __call__(self, func):
56 |         def wrapper(*args, **kwargs):
57 |             start = time.time()
58 |             result = func(*args, **kwargs)
59 |             end = time.time()
60 |             self.module_logger.info("{} ran in {}s".format(func.__qualname__, round(end - start, 2)))
61 |             return result
62 |
63 |         return wrapper
64 |
65 |
66 | def import_module(module_name, code):
67 |     if module_name in sys.modules:
68 |         logging.warning("Module {} will be overwritten".format(module_name))
69 |     spec = importlib.util.spec_from_loader(module_name, loader=None, origin=module_name)
70 |     module = importlib.util.module_from_spec(spec)
71 |     sys.modules[module_name] = module
72 |     exec(code, module.__dict__)
73 |
74 |
75 | def import_module_from_path(path, module_name):
76 |     assert os.path.exists(path), "path '{}' not found".format(path)
77 |     spec = importlib.util.spec_from_file_location(module_name, path)
78 |     assert spec, "could not load spec from path '{}'".format(path)
79 |     module = importlib.util.module_from_spec(spec)
80 |     spec.loader.exec_module(module)
81 |     return module
82 |
83 |
84 | # TODO: 'load_interface_from_module' is too complex, consider refactoring
85 | def load_interface_from_module(module_name, interface_class, interface_signature=None, path=None):  # noqa: C901
86 |     if path:
87 |         module = import_module_from_path(path, module_name)
88 |         logger.info(f"Module '{module_name}' loaded from path '{path}'")
89 |     else:
90 |         try:
91 |             module = importlib.import_module(module_name)
92 |             logger.info(f"Module '{module_name}' imported dynamically; module={module}")
93 |         except ImportError:
94 |             # XXX don't use ModuleNotFoundError for python3.5 compatibility
95 |             raise
96 |
97 |     # check if the module is empty
98 |     if not inspect.getmembers(module, lambda m: inspect.isclass(m) or inspect.isfunction(m)):
99 |         raise exceptions.EmptyInterfaceError(
100 |             f"Module '{module_name}' seems empty: no method/class found in members: '{dir(module)}'"
101 |         )
102 |
103 |     # find the interface class
104 |     found_interfaces = []
105 |     for _, obj in inspect.getmembers(module, inspect.isclass):
106 |         if issubclass(obj, interface_class) and obj != interface_class:
107 |             found_interfaces.append(obj)
108 |
109 |     if len(found_interfaces) == 1:
110 |         return found_interfaces[0]()  # return an interface instance
111 |     elif len(found_interfaces) > 1:
112 |         raise exceptions.InvalidInterfaceError(
113 |             f"Multiple interfaces found in module '{module_name}': {found_interfaces}"
114 |         )
115 |
116 |     # backward compatibility; accept methods at module level directly
117 |     if interface_signature is None:
118 |         class_name = interface_class.__name__
119 |         elements = str(dir(module))
120 |         logger.info(f"Class '{class_name}' not found from: '{elements}'")
121 |         raise exceptions.InvalidInterfaceError("Expecting {} subclass in {}".format(class_name, module_name))
122 |
123 |     missing_functions = interface_signature.copy()
124 |     for name, obj in inspect.getmembers(module):
125 |         if not inspect.isfunction(obj):
126 |             continue
127 |         try:
128 |             missing_functions.remove(name)
129 |         except KeyError:
130 |             pass
131 |
132 |     if missing_functions:
133 |         message = "Method(s) {} not implemented".format(", ".join(["'{}'".format(m) for m in missing_functions]))
134 |         raise exceptions.InvalidInterfaceError(message)
135 |     return module
136 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | from os import PathLike
4 | from typing import Any
5 | from typing import TypedDict
6 |
7 | import numpy as np
8 | import pytest
9 |
10 | from substratools import function
11 | from substratools import load_performance
12 | from substratools import opener
13 | from substratools import save_performance
14 | from substratools.task_resources import TaskResources
15 | from substratools.workspace import FunctionWorkspace
16 | from tests import utils
17 | from tests.utils import InputIdentifiers
18 | from tests.utils import OutputIdentifiers
19 |
20 |
21 | @pytest.fixture()
22 | def write_pred_file(workdir):
23 |     pred_file = str(workdir / str(uuid.uuid4()))
24 |     data = list(range(3, 6))
25 |     with open(pred_file, "w") as f:
26 |         json.dump(data, f)
27 |     return pred_file, data
28 |
29 |
30 | @pytest.fixture
31 | def inputs(workdir, valid_opener_script, write_pred_file):
32 |     return [
33 |         {"id": InputIdentifiers.predictions, "value": str(write_pred_file[0]), "multiple": False},
34 |         {"id": InputIdentifiers.datasamples, "value": str(workdir / "datasamples_unused"), "multiple": True},
35 |         {"id": InputIdentifiers.opener, "value": str(valid_opener_script), "multiple": False},
36 |     ]
37 |
38 |
39 | @pytest.fixture
40 | def outputs(workdir):
41 |     return [{"id": OutputIdentifiers.performance, "value": str(workdir / str(uuid.uuid4())), "multiple": False}]
42 |
43 |
44 | @pytest.fixture(autouse=True)
45 | def setup(valid_opener, write_pred_file):
46 |     pass
47 |
48 |
49 | @function.register
50 | def score(
51 |     inputs: TypedDict("inputs", {InputIdentifiers.datasamples: Any, InputIdentifiers.predictions: Any}),
52 |     outputs: TypedDict("outputs", {OutputIdentifiers.performance: PathLike}),
53 |     task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}),
54 | ):
55 |     y_true = inputs.get(InputIdentifiers.datasamples)[1]
56 |     y_pred_path = inputs.get(InputIdentifiers.predictions)
57 |     y_pred = utils.load_predictions(y_pred_path)
58 |
59 |     score = sum(y_true) + sum(y_pred)
60 |
61 |     save_performance(performance=score, path=outputs.get(OutputIdentifiers.performance))
62 |
63 |
64 | def test_score(workdir, write_pred_file):
65 |     inputs = TaskResources(
66 |         json.dumps(
67 |             [
68 |                 {"id": InputIdentifiers.predictions, "value": str(write_pred_file[0]), "multiple": False},
69 |             ]
70 |         )
71 |     )
72 |     outputs = TaskResources(
73 |         json.dumps(
74 |             [{"id": OutputIdentifiers.performance, "value": str(workdir / str(uuid.uuid4())), "multiple": False}]
75 |         )
76 |     )
77 |     workspace = FunctionWorkspace(inputs=inputs, outputs=outputs)
78 |     wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=opener.load_from_module())
79 |     wp.execute(function=score)
80 |     s = load_performance(wp._workspace.task_outputs[OutputIdentifiers.performance])
81 |     assert s == 15
82 |
83 |
84 | def test_execute(inputs, outputs):
85 |     perf_path = outputs[0]["value"]
86 |     function.execute(
87 |         sysargs=["--function-name", "score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)],
88 |     )
89 |     s = load_performance(perf_path)
90 |     assert s == 15
91 |
92 |
93 | @pytest.mark.parametrize(
94 |     "fake_data_mode,expected_score",
95 |     [
96 |         ([], 15),
97 |         (["--fake-data", "--n-fake-samples", "3"], 12),
98 |     ],
99 | )
100 | def test_execute_fake_data_modes(fake_data_mode, expected_score, inputs, outputs):
101 |     perf_path = outputs[0]["value"]
102 |     function.execute(
103 |         sysargs=fake_data_mode
104 |         + ["--function-name", "score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)],
105 |     )
106 |     s = load_performance(perf_path)
107 |     assert s == expected_score
108 |
109 |
110 | def test_execute_np(inputs, outputs):
111 |     @function.register
112 |     def float_np_score(
113 |         inputs,
114 |         outputs,
115 |         task_properties: dict,
116 |     ):
117 |         save_performance(np.float64(0.99), outputs.get(OutputIdentifiers.performance))
118 |
119 |     perf_path = outputs[0]["value"]
120 |     function.execute(
121 |         sysargs=["--function-name", "float_np_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)],
122 |     )
123 |     s = load_performance(perf_path)
124 |     assert s == pytest.approx(0.99)
125 |
126 |
127 | def test_execute_int(inputs, outputs):
128 |     @function.register
129 |     def int_score(
130 |         inputs,
131 |         outputs,
132 |         task_properties: dict,
133 |     ):
134 |         save_performance(int(1), outputs.get(OutputIdentifiers.performance))
135 |
136 |     perf_path = outputs[0]["value"]
137 |     function.execute(
138 |         sysargs=["--function-name", "int_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)],
139 |     )
140 |     s = load_performance(perf_path)
141 |     assert s == 1
142 |
143 |
144 | def test_execute_dict(inputs, outputs):
145 |     @function.register
146 |     def dict_score(
147 |         inputs,
148 |         outputs,
149 |         task_properties: dict,
150 |     ):
151 |         save_performance({"a": 1}, outputs.get(OutputIdentifiers.performance))
152 |
153 |     perf_path = outputs[0]["value"]
154 |     function.execute(
155 |         sysargs=["--function-name", "dict_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)],
156 |     )
157 |     s = load_performance(perf_path)
158 |     assert s["a"] == 1
159 |
--------------------------------------------------------------------------------
/tests/test_workflow.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import pytest
5 |
6 | from substratools import load_performance
7 | from substratools import opener
8 | from substratools import save_performance
9 | from substratools.function import FunctionWrapper
10 | from substratools.task_resources import TaskResources
11 | from substratools.utils import import_module
12 | from substratools.workspace import FunctionWorkspace
13 | from tests import utils
14 | from tests.utils import InputIdentifiers
15 | from tests.utils import OutputIdentifiers
16 |
17 |
18 | @pytest.fixture
19 | def dummy_opener():
20 | script = """
21 | import json
22 | from substratools import Opener
23 |
24 | class DummyOpener(Opener):
25 | def get_data(self, folder):
26 | return None
27 |
28 | def fake_data(self, n_samples):
29 | raise NotImplementedError
30 | """
31 | import_module("opener", script)
32 |
33 |
34 | def train(inputs, outputs, task_properties):
35 | models = utils.load_models(inputs.get(InputIdentifiers.shared, []))
36 | total = sum([m["i"] for m in models])
37 | new_model = {"i": len(models) + 1, "total": total}
38 |
39 | utils.save_model(new_model, outputs.get(OutputIdentifiers.shared))
40 |
41 |
42 | def predict(inputs, outputs, task_properties):
43 | model = utils.load_model(inputs.get(InputIdentifiers.shared))
44 | pred = {"sum": model["i"]}
45 | utils.save_predictions(pred, outputs.get(OutputIdentifiers.predictions))
46 |
47 |
48 | def score(inputs, outputs, task_properties):
49 | y_pred_path = inputs.get(InputIdentifiers.predictions)
50 | y_pred = utils.load_predictions(y_pred_path)
51 |
52 | score = y_pred["sum"]
53 |
54 | save_performance(performance=score, path=outputs.get(OutputIdentifiers.performance))
55 |
56 |
57 | def test_workflow(workdir, dummy_opener):
58 | loop1_model_path = workdir / "loop1model"
59 | loop1_workspace_outputs = TaskResources(
60 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop1_model_path), "multiple": False}])
61 | )
62 | loop1_workspace = FunctionWorkspace(outputs=loop1_workspace_outputs)
63 | loop1_wp = FunctionWrapper(workspace=loop1_workspace, opener_wrapper=None)
64 |
65 | # loop 1 (no input)
66 | loop1_wp.execute(function=train)
67 | model = utils.load_model(path=loop1_wp._workspace.task_outputs[OutputIdentifiers.shared])
68 |
69 | assert model == {"i": 1, "total": 0}
70 | assert os.path.exists(loop1_model_path)
71 |
72 | loop2_model_path = workdir / "loop2model"
73 |
74 | loop2_workspace_inputs = TaskResources(
75 | json.dumps([{"id": InputIdentifiers.shared, "value": str(loop1_model_path), "multiple": True}])
76 | )
77 | loop2_workspace_outputs = TaskResources(
78 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop2_model_path), "multiple": False}])
79 | )
80 | loop2_workspace = FunctionWorkspace(inputs=loop2_workspace_inputs, outputs=loop2_workspace_outputs)
81 | loop2_wp = FunctionWrapper(workspace=loop2_workspace, opener_wrapper=None)
82 |
83 | # loop 2 (one model as input)
84 | loop2_wp.execute(function=train)
85 | model = utils.load_model(path=loop2_wp._workspace.task_outputs[OutputIdentifiers.shared])
86 | assert model == {"i": 2, "total": 1}
87 | assert os.path.exists(loop2_model_path)
88 |
89 |     loop3_model_path = workdir / "loop3model"
90 | loop3_workspace_inputs = TaskResources(
91 | json.dumps(
92 | [
93 | {"id": InputIdentifiers.shared, "value": str(loop1_model_path), "multiple": True},
94 | {"id": InputIdentifiers.shared, "value": str(loop2_model_path), "multiple": True},
95 | ]
96 | )
97 | )
98 | loop3_workspace_outputs = TaskResources(
99 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop3_model_path), "multiple": False}])
100 | )
101 | loop3_workspace = FunctionWorkspace(inputs=loop3_workspace_inputs, outputs=loop3_workspace_outputs)
102 | loop3_wp = FunctionWrapper(workspace=loop3_workspace, opener_wrapper=None)
103 |
104 | # loop 3 (two models as input)
105 | loop3_wp.execute(function=train)
106 | model = utils.load_model(path=loop3_wp._workspace.task_outputs[OutputIdentifiers.shared])
107 | assert model == {"i": 3, "total": 3}
108 | assert os.path.exists(loop3_model_path)
109 |
110 | predictions_path = workdir / "predictions"
111 | predict_workspace_inputs = TaskResources(
112 | json.dumps([{"id": InputIdentifiers.shared, "value": str(loop3_model_path), "multiple": False}])
113 | )
114 | predict_workspace_outputs = TaskResources(
115 | json.dumps([{"id": OutputIdentifiers.predictions, "value": str(predictions_path), "multiple": False}])
116 | )
117 | predict_workspace = FunctionWorkspace(inputs=predict_workspace_inputs, outputs=predict_workspace_outputs)
118 | predict_wp = FunctionWrapper(workspace=predict_workspace, opener_wrapper=None)
119 |
120 | # predict
121 | predict_wp.execute(function=predict)
122 | pred = utils.load_predictions(path=predict_wp._workspace.task_outputs[OutputIdentifiers.predictions])
123 | assert pred == {"sum": 3}
124 |
125 | # metrics
126 | performance_path = workdir / "performance"
127 | metric_workspace_inputs = TaskResources(
128 | json.dumps([{"id": InputIdentifiers.predictions, "value": str(predictions_path), "multiple": False}])
129 | )
130 | metric_workspace_outputs = TaskResources(
131 | json.dumps([{"id": OutputIdentifiers.performance, "value": str(performance_path), "multiple": False}])
132 | )
133 | metric_workspace = FunctionWorkspace(
134 | inputs=metric_workspace_inputs,
135 | outputs=metric_workspace_outputs,
136 | )
137 | metrics_wp = FunctionWrapper(workspace=metric_workspace, opener_wrapper=opener.load_from_module())
138 | metrics_wp.execute(function=score)
139 | res = load_performance(path=metrics_wp._workspace.task_outputs[OutputIdentifiers.performance])
140 | assert res == 3.0
141 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 |
8 |
9 |
10 | ## [1.0.0](https://github.com/Substra/substra-tools/releases/tag/1.0.0) - 2024-10-14
11 |
12 | ### Removed
13 |
14 | - Drop Python 3.9 support. ([#112](https://github.com/Substra/substra-tools/pull/112))
15 |
16 |
17 | ## [0.22.0](https://github.com/Substra/substra-tools/releases/tag/0.22.0) - 2024-09-04
18 |
19 | ### Changed
20 |
21 | - The Opener and Function workspaces no longer try to create folders for data samples. ([#100](https://github.com/Substra/substra-tools/pull/100))
22 |
23 | ### Removed
24 |
25 | - Remove base Docker image. Substrafl now uses python-slim as base image. ([#101](https://github.com/Substra/substra-tools/pull/101))
26 |
27 |
28 | ## [0.21.4](https://github.com/Substra/substra-tools/releases/tag/0.21.4) - 2024-06-03
29 |
30 |
31 | No significant changes.
32 |
33 |
34 | ## [0.21.3](https://github.com/Substra/substra-tools/releases/tag/0.21.3) - 2024-03-27
35 |
36 |
37 | ### Changed
38 |
39 | - Drop Python 3.8 support ([#90](https://github.com/Substra/substra-tools/pull/90))
40 | - Deprecate `setup.py` in favour of `pyproject.toml` ([#92](https://github.com/Substra/substra-tools/pull/92))
41 |
42 |
43 | ## [0.21.2](https://github.com/Substra/substra-tools/releases/tag/0.21.2) - 2024-03-07
44 |
45 | ### Changed
46 |
47 | - Update dependencies
48 |
49 | ## [0.21.1](https://github.com/Substra/substra-tools/releases/tag/0.21.1) - 2024-02-26
50 |
51 | ### Changed
52 |
53 | - Updated dependencies
54 |
55 | ## [0.21.0](https://github.com/Substra/substra-tools/releases/tag/0.21.0) - 2023-10-06
56 |
57 | ### Changed
58 |
59 | - Remove the `model` and `models` input and output identifiers in tests; replace them with `shared`. ([#84](https://github.com/Substra/substra-tools/pull/84))
60 | - BREAKING: Remove minimal and workflow docker images ([#86](https://github.com/Substra/substra-tools/pull/86))
61 | - Remove python lib from Docker image ([#86](https://github.com/Substra/substra-tools/pull/86))
62 |
63 | ### Added
64 |
65 | - Support on Python 3.11 ([#85](https://github.com/Substra/substra-tools/pull/85))
66 | - Contributing, contributors & code of conduct files (#77)
67 |
68 | ## [0.20.0](https://github.com/Substra/substra-tools/releases/tag/0.20.0) - 2022-12-19
69 |
70 | ### Changed
71 |
72 | - Add an optional argument to the `register` decorator to choose a custom function name, allowing the same function to be registered several times under different names (see the sketch below) (#74)
73 | - The rank is now passed by the backend in a task properties dictionary (instead of as a `rank` argument) (#75)
74 |
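A minimal sketch of both registration forms (illustrative only, based on the `FunctionRegister.__call__(function, function_name=None)` API in `substratools/function.py` below; the alias `"score_v2"` is hypothetical):

```py
import substratools as tools

@tools.register
def score(inputs, outputs, task_properties):
    ...

# plain-call form: register the same function again under another name
tools.register(score, function_name="score_v2")
```
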
75 | ## [0.19.0](https://github.com/Substra/substra-tools/releases/tag/0.19.0) - 2022-11-22
76 |
77 | ### Changed
78 |
79 | - BREAKING CHANGE (#65)
80 |
81 | - Functions are now registered with substratools through a decorator.
82 |
83 | ```py
84 | def my_function1():
85 | pass
86 |
87 | def my_function2():
88 | pass
89 |
90 | if __name__ == '__main__':
91 | tools.execute(my_function1, my_function2)
92 | ```
93 |
94 | becomes
95 |
96 | ```py
97 | @tools.register
98 | def my_function1():
99 | pass
100 |
101 | @tools.register
102 | def my_function2():
103 | pass
104 |
105 | if __name__ == '__main__':
106 | tools.execute()
107 | ```
108 |
109 | - BREAKING CHANGE (#63)
110 |
111 | - Rename algo to function.
112 | - `tools.algo.execute` becomes `tools.execute`
113 | - The algo class previously passed to `tools.algo.execute` is replaced by several functions passed as arguments to `tools.execute`. The function named by the `--function-name` CLI argument is executed.
114 |
115 | ```py
116 | if __name__ == '__main__':
117 | tools.algo.execute(MyAlgo())
118 | ```
119 |
120 | becomes
121 |
122 | ```py
123 | if __name__ == '__main__':
124 | tools.execute(my_function1, my_function2)
125 | ```
126 |
127 | ### Fixed
128 |
129 | - Remove depreciated `pytest-runner` from setup.py (#71)
130 | - Replace backslash by slash in TaskResources to fix windows compatibility (#70)
131 | - Update flake8 repository in pre-commit configuration (#69)
132 | - BREAKING CHANGE: Update substratools Docker image (#112)
133 |
134 | ## [0.18.0](https://github.com/Substra/substra-tools/releases/tag/0.18.0) - 2022-09-26
135 |
136 | ### Added
137 |
138 | - feat: allow CLI parameters to be read from a file (see the sketch below)
139 |
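This relies on `argparse`'s `fromfile_prefix_chars="@"` (see `_generate_function_cli` in `substratools/function.py` below). A sketch, assuming a hypothetical `params.txt` holding one CLI token per line:

```py
# params.txt contains, one token per line:
#   --function-name
#   train
from substratools import function

function.execute(sysargs=["@params.txt"])
```
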
140 | ### Changed
141 |
142 | - BREAKING CHANGES:
143 |
144 | - the opener only exposes `get_data` and `fake_data` methods.
145 | - the results of the above methods are passed under the `datasamples` key within the `inputs` dict arg of all
146 |   tools methods (train, predict, aggregate, score); see the sketch at the end of this entry.
147 | - all methods (train, predict, aggregate, score) now take a `task_properties` argument (dict) in addition to
148 |   `inputs` and `outputs`.
149 | - The `rank` of a task, previously passed under the `rank` key within the inputs, is now given in the
150 |   `task_properties` dict under the `rank` key.
151 |
152 | - BREAKING CHANGE: The metric is now a generic algo, replace
153 |
154 | ```python
155 | import substratools as tools
156 |
157 | class MyMetric(tools.Metrics):
158 | # ...
159 |
160 | if __name__ == '__main__':
161 | tools.metrics.execute(MyMetric())
162 | ```
163 |
164 | by
165 |
166 | ```python
167 | import substratools as tools
168 |
169 | class MyMetric(tools.MetricAlgo):
170 | # ...
171 | if __name__ == '__main__':
172 | tools.algo.execute(MyMetric())
173 | ```
174 |
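A sketch of the resulting generic signature (illustrative; the `shared` identifier is an assumption, and identifiers in general depend on the task definition):

```py
def train(inputs, outputs, task_properties):
    datasamples = inputs["datasamples"]  # result of Opener.get_data
    rank = task_properties["rank"]
    model_path = inputs.get("shared")    # every other resource is a path
    # the user loads/saves those resources explicitly
    ...
```
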
175 | ## [0.17.0](https://github.com/Substra/substra-tools/releases/tag/0.17.0) - 2022-09-19
176 |
177 | ### Changed
178 |
179 | - feat: all algo classes rely on a generic algo class
180 |
181 | ## [0.16.0](https://github.com/Substra/substra-tools/releases/tag/0.16.0) - 2022-09-12
182 |
183 | ### Changed
184 |
185 | - Remove documentation as it is not used. It will be replaced later on.
186 | - BREAKING CHANGES: the user must now pass the method name to execute within the Dockerfile of both `algo` and
187 |   `metric` under the `--method-name` argument. The method name still needs to be one of the allowed `algo` or
188 |   `metric` method names: train, predict, aggregate, score.
189 |
190 | ```Dockerfile
191 | ENTRYPOINT ["python3", "metrics.py"]
192 | ```
193 |
194 | shall be replaced by:
195 |
196 | ```Dockerfile
197 | ENTRYPOINT ["python3", "metrics.py", "--method-name", "score"]
198 | ```
199 |
200 | - BREAKING CHANGES: rename connect-tools to substra-tools (except the github folder)
201 |
202 | ## [0.15.0](https://github.com/Substra/substra-tools/releases/tag/0.15.0) - 2022-08-29
203 |
204 | ### Changed
205 |
206 | - BREAKING CHANGES:
207 |
208 | - methods from algo, composite algo, aggregate and metrics now take `inputs` (TypeDict) and `outputs` (TypeDict) as arguments
209 | - the user must load and save all the inputs and outputs of those methods (except for the datasamples)
210 | - `load_predictions` and `get_predictions` methods have been removed from the opener
211 | - `load_trunk_model`, `save_trunk_model`, `load_head_model`, `save_head_model` have been removed from the `tools.CompositeAlgo` class
212 | - `load_model` and `save_model` have been removed from both `tools.Algo` and `tools.AggregateAlgo` classes
213 |
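A sketch of the user-managed (de)serialisation this implies (illustrative; the method signature is simplified, and the `shared` identifier and JSON format are assumptions):

```py
import json

import substratools as tools

class MyAlgo(tools.Algo):
    def train(self, inputs, outputs):
        # the user now loads inputs and saves outputs explicitly
        with open(inputs["shared"], "r") as f:
            model = json.load(f)
        with open(outputs["shared"], "w") as f:
            json.dump(model, f)
```
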
214 | ## [0.14.0](https://github.com/Substra/substra-tools/releases/tag/0.14.0) - 2022-08-09
215 |
216 | ### Changed
217 |
218 | - BREAKING CHANGE: drop Python 3.7 support
219 |
220 | ### Fixed
221 |
222 | - fix: metric with type `np.float32()` is not JSON serializable #47
223 |
224 | ## [0.13.0](https://github.com/Substra/substra-tools/releases/tag/0.13.0) - 2022-05-22
225 |
226 | ### Changed
227 |
228 | - BREAKING CHANGE: change --debug (bool) to --log-level (str)
229 |
230 | ## [0.12.0](https://github.com/Substra/substra-tools/releases/tag/0.12.0) - 2022-04-29
231 |
232 | ### Fixed
233 |
234 | - NVIDIA rotating keys
235 |
236 | ### Changed
237 |
238 | - (BREAKING) algos now receive generic `inputs`/`outputs` dicts as arguments
239 |
240 | ## [0.11.0](https://github.com/Substra/substra-tools/releases/tag/0.11.0) - 2022-04-11
241 |
242 | ### Fixed
243 |
244 | - alias in Python 3.7 for python3
245 |
246 | ### Improved
247 |
248 | - ci: build docker images as part of CI checks
249 | - ci: push latest image from main branch
250 | - chore: make Dockerfiles independent from each other
251 |
--------------------------------------------------------------------------------
/substratools/function.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | import argparse
3 | import json
4 | import logging
5 | import os
6 | import sys
7 | from copy import deepcopy
8 | from typing import Any
9 | from typing import Callable
10 | from typing import Dict
11 | from typing import Optional
12 |
13 | from substratools import exceptions
14 | from substratools import opener
15 | from substratools import utils
16 | from substratools.exceptions import ExistingRegisteredFunctionError
17 | from substratools.exceptions import FunctionNotFoundError
18 | from substratools.task_resources import StaticInputIdentifiers
19 | from substratools.task_resources import TaskResources
20 | from substratools.workspace import FunctionWorkspace
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | def _parser_add_default_arguments(parser):
26 | parser.add_argument(
27 | "--function-name",
28 | type=str,
29 | help="The name of the function to execute from the given file",
30 | )
31 | parser.add_argument(
32 | "-r",
33 | "--task-properties",
34 | type=str,
35 | default="{}",
36 | help="Define the task properties",
37 |     )
38 | parser.add_argument(
39 | "-d",
40 | "--fake-data",
41 | action="store_true",
42 | default=False,
43 | help="Enable fake data mode",
44 | )
45 | parser.add_argument(
46 | "--n-fake-samples",
47 | default=None,
48 | type=int,
49 | help="Number of fake samples if fake data is used.",
50 | )
51 | parser.add_argument(
52 | "--log-path",
53 | default=None,
54 | help="Define log filename path",
55 | )
56 | parser.add_argument(
57 | "--log-level",
58 | default="info",
59 | choices=utils.MAPPING_LOG_LEVEL.keys(),
60 | help="Choose log level",
61 | )
62 | parser.add_argument(
63 | "--inputs",
64 | type=str,
65 | default="[]",
66 | help="Inputs of the compute task",
67 | )
68 | parser.add_argument(
69 | "--outputs",
70 | type=str,
71 | default="[]",
72 | help="Outputs of the compute task",
73 | )
74 |
75 |
76 | class FunctionRegister:
77 |     """Decorator class used to register functions in substratools. The functions are registered in the _functions
78 |     dictionary, with function.__name__ as key.
79 |     Registering a function in substratools means that it can be accessed by the function.execute function through
80 |     the --function-name CLI argument."""
81 |
82 | def __init__(self):
83 | self._functions = {}
84 |
85 | def __call__(self, function: Callable, function_name: Optional[str] = None):
86 | """Function called when using an instance of the class as a decorator.
87 |
88 | Args:
89 | function (Callable): function to register in substratools.
90 | function_name (str, optional): function name to register the given function.
91 | If None, function.__name__ is used for registration.
92 | Raises:
93 | ExistingRegisteredFunctionError: Raise if a function with the same function.__name__
94 | has already been registered in substratools.
95 |
96 | Returns:
97 | Callable: returns the function without decorator
98 | """
99 |
100 | function_name = function_name or function.__name__
101 | if function_name not in self._functions:
102 | self._functions[function_name] = function
103 | else:
104 | raise ExistingRegisteredFunctionError("A function with the same name is already registered.")
105 |
106 | return function
107 |
108 | def get_registered_functions(self):
109 | return self._functions
110 |
111 |
112 | # Instance of the decorator to store the function to register in memory.
113 | # Can be imported directly from substratools.
114 | register = FunctionRegister()
115 |
116 |
117 | class FunctionWrapper(object):
118 | """Wrapper to execute a function on the platform."""
119 |
120 | def __init__(self, workspace: FunctionWorkspace, opener_wrapper: Optional[opener.OpenerWrapper]):
121 | self._workspace = workspace
122 | self._opener_wrapper = opener_wrapper
123 |
124 | def _assert_outputs_exists(self, outputs: Dict[str, str]):
125 | for key, path in outputs.items():
126 | if os.path.isdir(path):
127 | raise exceptions.NotAFileError(f"Expected output file at {path}, found dir for output `{key}`")
128 | if not os.path.isfile(path):
129 |             raise exceptions.MissingFileError(f"Output file {path} used to save argument `{key}` does not exist.")
130 |
131 | @utils.Timer(logger)
132 | def execute(
133 | self, function: Callable, task_properties: dict = {}, fake_data: bool = False, n_fake_samples: int = None
134 | ):
135 | """Execute a compute task"""
136 |
137 | # load inputs
138 | inputs = deepcopy(self._workspace.task_inputs)
139 |
140 | # load data from opener
141 | if self._opener_wrapper:
142 | loaded_datasamples = self._opener_wrapper.get_data(fake_data, n_fake_samples)
143 |
144 | if fake_data:
145 | logger.info("Using fake data with %i fake samples." % int(n_fake_samples))
146 |
147 | assert (
148 | StaticInputIdentifiers.datasamples.value not in inputs.keys()
149 | ), f"{StaticInputIdentifiers.datasamples.value} must be an input of kind `datasamples`"
150 | inputs.update({StaticInputIdentifiers.datasamples.value: loaded_datasamples})
151 |
152 | # load outputs
153 | outputs = deepcopy(self._workspace.task_outputs)
154 |
155 | logger.info("Launching task: executing `%s` function." % function.__name__)
156 | function(
157 | inputs=inputs,
158 | outputs=outputs,
159 | task_properties=task_properties,
160 | )
161 |
162 | self._assert_outputs_exists(
163 | self._workspace.task_outputs,
164 | )
165 |
166 |
167 | def _generate_function_cli():
168 | """Helper to generate a command line interface client."""
169 |
170 | def _function_from_args(args):
171 | inputs = TaskResources(args.inputs)
172 | outputs = TaskResources(args.outputs)
173 | log_path = args.log_path
174 | chainkeys_path = inputs.chainkeys_path
175 |
176 | workspace = FunctionWorkspace(
177 | log_path=log_path,
178 | chainkeys_path=chainkeys_path,
179 | inputs=inputs,
180 | outputs=outputs,
181 | )
182 |
183 | utils.configure_logging(workspace.log_path, log_level=args.log_level)
184 |
185 | opener_wrapper = opener.load_from_module(
186 | workspace=workspace,
187 | )
188 |
189 | return FunctionWrapper(workspace, opener_wrapper)
190 |
191 | def _user_func(args, function):
192 | function_wrapper = _function_from_args(args)
193 | function_wrapper.execute(
194 | function=function,
195 | task_properties=json.loads(args.task_properties),
196 | fake_data=args.fake_data,
197 | n_fake_samples=args.n_fake_samples,
198 | )
199 |
200 | parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
201 | _parser_add_default_arguments(parser)
202 | parser.set_defaults(func=_user_func)
203 |
204 | return parser
205 |
206 |
207 | def _get_function_from_name(functions: dict, function_name: str):
208 |
209 | if function_name not in functions:
210 | raise FunctionNotFoundError(
211 |             f"The function {function_name} given as the --function-name argument has not been found."
212 | )
213 |
214 | return functions[function_name]
215 |
216 |
217 | def save_performance(performance: Any, path: os.PathLike):
218 | with open(path, "w") as f:
219 | json.dump({"all": performance}, f)
220 |
221 |
222 | def load_performance(path: os.PathLike) -> Any:
223 | with open(path, "r") as f:
224 | performance = json.load(f)["all"]
225 | return performance
226 |
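# Illustrative usage note: save_performance(0.87, "perf.json") writes {"all": 0.87}
# to perf.json, and load_performance("perf.json") then returns 0.87.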
227 |
228 | def execute(sysargs=None):
229 | """Launch function command line interface."""
230 |
231 | cli = _generate_function_cli()
232 |
233 | sysargs = sysargs if sysargs is not None else sys.argv[1:]
234 | args = cli.parse_args(sysargs)
235 | function = _get_function_from_name(register.get_registered_functions(), args.function_name)
236 | args.func(args, function)
237 |
238 | return args
239 |
--------------------------------------------------------------------------------
/tests/test_aggregatealgo.py:
--------------------------------------------------------------------------------
1 | import json
2 | from os import PathLike
3 | from typing import Any
4 | from typing import List
5 | from typing import TypedDict
6 | from uuid import uuid4
7 |
8 | import pytest
9 |
10 | from substratools import exceptions
11 | from substratools import function
12 | from substratools import opener
13 | from substratools.task_resources import TaskResources
14 | from substratools.workspace import FunctionWorkspace
15 | from tests import utils
16 | from tests.utils import InputIdentifiers
17 | from tests.utils import OutputIdentifiers
18 |
19 |
20 | @pytest.fixture(autouse=True)
21 | def setup(valid_opener):
22 | pass
23 |
24 |
25 | @function.register
26 | def aggregate(
27 | inputs: TypedDict(
28 | "inputs",
29 | {InputIdentifiers.shared: List[PathLike]},
30 | ),
31 | outputs: TypedDict("outputs", {OutputIdentifiers.shared: PathLike}),
32 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}),
33 | ) -> None:
34 | if inputs:
35 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, []))
36 | else:
37 | models = []
38 |
39 | new_model = {"value": 0}
40 | for m in models:
41 | new_model["value"] += m["value"]
42 |
43 | utils.save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared))
44 |
45 |
46 | @function.register
47 | def aggregate_predict(
48 | inputs: TypedDict(
49 | "inputs",
50 | {
51 | InputIdentifiers.datasamples: Any,
52 | InputIdentifiers.shared: PathLike,
53 | },
54 | ),
55 | outputs: TypedDict("outputs", {OutputIdentifiers.shared: PathLike}),
56 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}),
57 | ):
58 |     model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
59 |
60 | # Predict
61 | X = inputs.get(InputIdentifiers.datasamples)[0]
62 | pred = X * model["value"]
63 |
64 | # save predictions
65 | utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions))
66 |
67 |
68 | def no_saved_aggregate(inputs, outputs, task_properties):
69 | if inputs:
70 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, []))
71 | else:
72 | models = []
73 |
74 | new_model = {"value": 0}
75 | for m in models:
76 | new_model["value"] += m["value"]
77 |
78 | utils.no_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared))
79 |
80 |
81 | def wrong_saved_aggregate(inputs, outputs, task_properties):
82 | if inputs:
83 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, []))
84 | else:
85 | models = []
86 |
87 | new_model = {"value": 0}
88 | for m in models:
89 | new_model["value"] += m["value"]
90 |
91 | utils.wrong_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared))
92 |
93 |
94 | @pytest.fixture
95 | def create_models(workdir):
96 | model_a = {"value": 1}
97 | model_b = {"value": 2}
98 |
99 | model_dir = workdir / OutputIdentifiers.shared
100 | model_dir.mkdir()
101 |
102 | def _create_model(model_data):
103 | model_name = model_data["value"]
104 | filename = "{}.json".format(model_name)
105 | path = model_dir / filename
106 | path.write_text(json.dumps(model_data))
107 | return str(path)
108 |
109 | model_datas = [model_a, model_b]
110 | model_filenames = [_create_model(d) for d in model_datas]
111 |
112 | return model_datas, model_filenames
113 |
114 |
115 | def test_aggregate_no_model(valid_function_workspace):
116 | wp = function.FunctionWrapper(workspace=valid_function_workspace, opener_wrapper=None)
117 | wp.execute(function=aggregate)
118 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared])
119 | assert model["value"] == 0
120 |
121 |
122 | def test_aggregate_multiple_models(create_models, output_model_path):
123 | _, model_filenames = create_models
124 |
125 | workspace_inputs = TaskResources(
126 | json.dumps([{"id": InputIdentifiers.shared, "value": f, "multiple": True} for f in model_filenames])
127 | )
128 | workspace_outputs = TaskResources(
129 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}])
130 | )
131 |
132 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs)
133 | wp = function.FunctionWrapper(workspace, opener_wrapper=None)
134 |
135 | wp.execute(function=aggregate)
136 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared])
137 |
138 | assert model["value"] == 3
139 |
140 |
141 | @pytest.mark.parametrize(
142 | "fake_data,expected_pred,n_fake_samples",
143 | [
144 | (False, "X", None),
145 | (True, ["Xfake"], 1),
146 | ],
147 | )
148 | def test_predict(fake_data, expected_pred, n_fake_samples, create_models):
149 | _, model_filenames = create_models
150 |
151 | workspace_inputs = TaskResources(
152 | json.dumps([{"id": InputIdentifiers.shared, "value": model_filenames[0], "multiple": False}])
153 | )
154 | workspace_outputs = TaskResources(
155 | json.dumps([{"id": OutputIdentifiers.predictions, "value": model_filenames[0], "multiple": False}])
156 | )
157 |
158 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs)
159 |
160 | wp = function.FunctionWrapper(workspace, opener_wrapper=opener.load_from_module())
161 |
162 | wp.execute(function=aggregate_predict, fake_data=fake_data, n_fake_samples=n_fake_samples)
163 |
164 | pred = utils.load_predictions(wp._workspace.task_outputs[OutputIdentifiers.predictions])
165 | assert pred == expected_pred
166 |
167 |
168 | def test_execute_aggregate(output_model_path):
169 | assert not output_model_path.exists()
170 |
171 | outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]
172 |
173 | function.execute(sysargs=["--function-name", "aggregate", "--outputs", json.dumps(outputs)])
174 | assert output_model_path.exists()
175 | output_model_path.unlink()
176 | function.execute(
177 | sysargs=["--function-name", "aggregate", "--outputs", json.dumps(outputs), "--log-level", "debug"],
178 | )
179 | assert output_model_path.exists()
180 |
181 |
182 | def test_execute_aggregate_multiple_models(workdir, create_models, output_model_path):
183 | _, model_filenames = create_models
184 |
185 | assert not output_model_path.exists()
186 |
187 | inputs = [
188 | {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames
189 | ]
190 | outputs = [
191 | {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False},
192 | ]
193 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)]
194 |
195 | command = ["--function-name", "aggregate"]
196 | command.extend(options)
197 |
198 | function.execute(sysargs=command)
199 | assert output_model_path.exists()
200 | with open(output_model_path, "r") as f:
201 | model = json.load(f)
202 | assert model["value"] == 3
203 |
204 |
205 | def test_execute_predict(workdir, create_models, output_model_path, valid_opener_script):
206 | _, model_filenames = create_models
207 | assert not output_model_path.exists()
208 |
209 | inputs = [
210 | {"id": InputIdentifiers.shared, "value": str(workdir / model_name), "multiple": True}
211 | for model_name in model_filenames
212 | ]
213 | outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]
214 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)]
215 | command = ["--function-name", "aggregate"]
216 | command.extend(options)
217 | function.execute(sysargs=command)
218 | assert output_model_path.exists()
219 |
220 | # do predict on output model
221 | pred_path = workdir / str(uuid4())
222 | assert not pred_path.exists()
223 |
224 | pred_inputs = [
225 | {"id": InputIdentifiers.shared, "value": str(output_model_path), "multiple": False},
226 | {"id": InputIdentifiers.opener, "value": valid_opener_script, "multiple": False},
227 | ]
228 | pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}]
229 | pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)]
230 |
231 | function.execute(sysargs=["--function-name", "predict"] + pred_options)
232 | assert pred_path.exists()
233 | with open(pred_path, "r") as f:
234 | pred = json.load(f)
235 | assert pred == "XXX"
236 | pred_path.unlink()
237 |
238 |
239 | @pytest.mark.parametrize("function_to_run", (no_saved_aggregate, wrong_saved_aggregate))
240 | def test_model_check(function_to_run, valid_function_workspace):
241 | wp = function.FunctionWrapper(valid_function_workspace, opener_wrapper=None)
242 |
243 | with pytest.raises(exceptions.MissingFileError):
244 | wp.execute(function=function_to_run)
245 |
--------------------------------------------------------------------------------
/Substra-logo-white.svg:
--------------------------------------------------------------------------------
1 |
18 |
--------------------------------------------------------------------------------
/Substra-logo-colour.svg:
--------------------------------------------------------------------------------
1 |
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | Copyright 2018-2022 Owkin, Inc.
180 |
181 | Licensed under the Apache License, Version 2.0 (the "License");
182 | you may not use this file except in compliance with the License.
183 | You may obtain a copy of the License at
184 |
185 | http://www.apache.org/licenses/LICENSE-2.0
186 |
187 | Unless required by applicable law or agreed to in writing, software
188 | distributed under the License is distributed on an "AS IS" BASIS,
189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
190 | See the License for the specific language governing permissions and
191 | limitations under the License.
192 |
--------------------------------------------------------------------------------
/tests/test_compositealgo.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Any
4 | from typing import Optional
5 | from typing import TypedDict
6 |
7 | import pytest
8 |
9 | from substratools import exceptions
10 | from substratools import function
11 | from substratools import opener
12 | from substratools.task_resources import TaskResources
13 | from substratools.workspace import FunctionWorkspace
14 | from tests import utils
15 | from tests.utils import InputIdentifiers
16 | from tests.utils import OutputIdentifiers
17 |
18 |
19 | @pytest.fixture(autouse=True)
20 | def setup(valid_opener):
21 | pass
22 |
23 |
24 | def fake_data_train(inputs: dict, outputs: dict, task_properties: dict):
25 | utils.save_model(model=inputs[InputIdentifiers.datasamples][0], path=outputs["local"])
26 | utils.save_model(model=inputs[InputIdentifiers.datasamples][1], path=outputs["shared"])
27 |
28 |
29 | def fake_data_predict(inputs: dict, outputs: dict, task_properties: dict) -> None:
30 | utils.save_model(model=inputs[InputIdentifiers.datasamples][0], path=outputs["predictions"])
31 |
32 |
33 | def train(
34 | inputs: TypedDict(
35 | "inputs",
36 | {
37 | InputIdentifiers.datasamples: Any,
38 | InputIdentifiers.local: Optional[os.PathLike],
39 | InputIdentifiers.shared: Optional[os.PathLike],
40 | },
41 | ),
42 | outputs: TypedDict(
43 | "outputs",
44 | {
45 | OutputIdentifiers.local: os.PathLike,
46 | OutputIdentifiers.shared: os.PathLike,
47 | },
48 | ),
49 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}),
50 | ):
51 | # init phase
52 | # load models
53 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local))
54 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
55 |
56 | if head_model and trunk_model:
57 | new_head_model = dict(head_model)
58 | new_trunk_model = dict(trunk_model)
59 | else:
60 | new_head_model = {"value": 0}
61 | new_trunk_model = {"value": 0}
62 |
63 | # train models
64 | new_head_model["value"] += 1
65 | new_trunk_model["value"] -= 1
66 |
67 | # save model
68 | utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local))
69 | utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared))
70 |
71 |
72 | def predict(
73 | inputs: TypedDict(
74 | "inputs",
75 | {
76 | InputIdentifiers.datasamples: Any,
77 | InputIdentifiers.local: os.PathLike,
78 | InputIdentifiers.shared: os.PathLike,
79 | },
80 | ),
81 | outputs: TypedDict(
82 | "outputs",
83 | {
84 | OutputIdentifiers.predictions: os.PathLike,
85 | },
86 | ),
87 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}),
88 | ):
89 |
90 | # init phase
91 | # load models
92 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local))
93 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
94 |
95 | pred = list(range(head_model["value"], trunk_model["value"]))
96 |
97 | # save predictions
98 | utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions))
99 |
100 |
101 | def no_saved_trunk_train(inputs, outputs, task_properties):
102 | # init phase
103 | # load models
104 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local))
105 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
106 |
107 | if head_model and trunk_model:
108 | new_head_model = dict(head_model)
109 | new_trunk_model = dict(trunk_model)
110 | else:
111 | new_head_model = {"value": 0}
112 | new_trunk_model = {"value": 0}
113 |
114 | # train models
115 | new_head_model["value"] += 1
116 | new_trunk_model["value"] -= 1
117 |
118 | # save model
119 | utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local))
120 | utils.no_save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared))
121 |
122 |
123 | def no_saved_head_train(inputs, outputs, task_properties):
124 | # init phase
125 | # load models
126 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local))
127 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
128 |
129 | if head_model and trunk_model:
130 | new_head_model = dict(head_model)
131 | new_trunk_model = dict(trunk_model)
132 | else:
133 | new_head_model = {"value": 0}
134 | new_trunk_model = {"value": 0}
135 |
136 | # train models
137 | new_head_model["value"] += 1
138 | new_trunk_model["value"] -= 1
139 |
140 | # save model
141 | utils.no_save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local))
142 | utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared))
143 |
144 |
145 | def wrong_saved_trunk_train(inputs, outputs, task_properties):
146 | # init phase
147 | # load models
148 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local))
149 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
150 |
151 | if head_model and trunk_model:
152 | new_head_model = dict(head_model)
153 | new_trunk_model = dict(trunk_model)
154 | else:
155 | new_head_model = {"value": 0}
156 | new_trunk_model = {"value": 0}
157 |
158 | # train models
159 | new_head_model["value"] += 1
160 | new_trunk_model["value"] -= 1
161 |
162 | # save model
163 | utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local))
164 | utils.wrong_save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared))
165 |
166 |
167 | def wrong_saved_head_train(inputs, outputs, task_properties):
168 | # init phase
169 | # load models
170 | head_model = utils.load_model(path=inputs.get(InputIdentifiers.local))
171 | trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
172 |
173 | if head_model and trunk_model:
174 | new_head_model = dict(head_model)
175 | new_trunk_model = dict(trunk_model)
176 | else:
177 | new_head_model = {"value": 0}
178 | new_trunk_model = {"value": 0}
179 |
180 | # train models
181 | new_head_model["value"] += 1
182 | new_trunk_model["value"] -= 1
183 |
184 | # save model
185 | utils.wrong_save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local))
186 | utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared))
187 |
188 |
189 | @pytest.fixture
190 | def train_outputs(output_model_path, output_model_path_2):
191 | outputs = TaskResources(
192 | json.dumps(
193 | [
194 | {"id": "local", "value": str(output_model_path), "multiple": False},
195 | {"id": "shared", "value": str(output_model_path_2), "multiple": False},
196 | ]
197 | )
198 | )
199 | return outputs
200 |
201 |
202 | @pytest.fixture
203 | def composite_inputs(create_models):
204 | _, local_path, shared_path = create_models
205 | inputs = TaskResources(
206 | json.dumps(
207 | [
208 | {"id": InputIdentifiers.local, "value": str(local_path), "multiple": False},
209 | {"id": InputIdentifiers.shared, "value": str(shared_path), "multiple": False},
210 | ]
211 | )
212 | )
213 |
214 | return inputs
215 |
216 |
217 | @pytest.fixture
218 | def predict_outputs(output_model_path):
219 | outputs = TaskResources(
220 | json.dumps([{"id": OutputIdentifiers.predictions, "value": str(output_model_path), "multiple": False}])
221 | )
222 | return outputs
223 |
224 |
225 | @pytest.fixture
226 | def create_models(workdir):
227 | head_model = {"value": 1}
228 | trunk_model = {"value": -1}
229 |
230 | def _create_model(model_data, name):
231 | filename = "{}.json".format(name)
232 | path = workdir / filename
233 | path.write_text(json.dumps(model_data))
234 | return path
235 |
236 | head_path = _create_model(head_model, "head")
237 | trunk_path = _create_model(trunk_model, "trunk")
238 |
239 | return (
240 | [head_model, trunk_model],
241 | head_path,
242 | trunk_path,
243 | )
244 |
245 |
246 | def test_train_no_model(train_outputs):
247 |
248 | dummy_train_workspace = FunctionWorkspace(outputs=train_outputs)
249 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, None)
250 | dummy_train_wrapper.execute(function=train)
251 | local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["local"])
252 | shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["shared"])
253 |
254 | assert local_model["value"] == 1
255 | assert shared_model["value"] == -1
256 |
257 |
258 | def test_train_input_head_trunk_models(composite_inputs, train_outputs):
259 |
260 | dummy_train_workspace = FunctionWorkspace(inputs=composite_inputs, outputs=train_outputs)
261 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, None)
262 | dummy_train_wrapper.execute(function=train)
263 | local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["local"])
264 | shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["shared"])
265 |
266 | assert local_model["value"] == 2
267 | assert shared_model["value"] == -2
268 |
269 |
270 | @pytest.mark.parametrize("n_fake_samples", (0, 1, 2))
271 | def test_train_fake_data(train_outputs, n_fake_samples):
272 | _opener = opener.load_from_module()
273 | dummy_train_workspace = FunctionWorkspace(outputs=train_outputs)
274 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, _opener)
275 | dummy_train_wrapper.execute(function=fake_data_train, fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)
276 |
277 | local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.local])
278 | shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.shared])
279 |
280 | assert local_model == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[0]
281 | assert shared_model == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[1]
282 |
283 |
284 | @pytest.mark.parametrize("n_fake_samples", (0, 1, 2))
285 | def test_predict_fake_data(composite_inputs, predict_outputs, n_fake_samples):
286 | _opener = opener.load_from_module()
287 | dummy_train_workspace = FunctionWorkspace(inputs=composite_inputs, outputs=predict_outputs)
288 | dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, _opener)
289 | dummy_train_wrapper.execute(
290 | function=fake_data_predict, fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples
291 | )
292 |
293 | predictions = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.predictions])
294 |
295 | assert predictions == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[0]
296 |
297 |
298 | @pytest.mark.parametrize(
299 | "function_to_run",
300 | (
301 | no_saved_head_train,
302 | no_saved_trunk_train,
303 | wrong_saved_head_train,
304 | wrong_saved_trunk_train,
305 | ),
306 | )
307 | def test_model_check(function_to_run, train_outputs):
308 | dummy_train_workspace = FunctionWorkspace(outputs=train_outputs)
309 | wp = function.FunctionWrapper(workspace=dummy_train_workspace, opener_wrapper=None)
310 |
311 | with pytest.raises(exceptions.MissingFileError):
312 | wp.execute(function_to_run)
313 |
--------------------------------------------------------------------------------
/tests/test_function.py:
--------------------------------------------------------------------------------
1 | import json
2 | import shutil
3 | from os import PathLike
4 | from pathlib import Path
5 | from typing import Any
6 | from typing import List
7 | from typing import Optional
8 | from typing import Tuple
9 | from typing import TypedDict
10 |
11 | import pytest
12 |
13 | from substratools import exceptions
14 | from substratools import function
15 | from substratools import opener
16 | from substratools.task_resources import StaticInputIdentifiers
17 | from substratools.task_resources import TaskResources
18 | from substratools.workspace import FunctionWorkspace
19 | from tests import utils
20 | from tests.utils import InputIdentifiers
21 | from tests.utils import OutputIdentifiers
22 |
23 |
24 | @pytest.fixture(autouse=True)
25 | def setup(valid_opener):
26 | pass
27 |
28 |
29 | @function.register
30 | def train(
31 | inputs: TypedDict(
32 | "inputs",
33 | {
34 |         InputIdentifiers.datasamples: Tuple[List[str], List[int]],  # cf valid_opener_code
35 | InputIdentifiers.shared: Optional[
36 | PathLike
37 | ], # inputs contains a dict where keys are identifiers and values are paths on the disk
38 | },
39 | ),
40 | outputs: TypedDict(
41 | "outputs", {OutputIdentifiers.shared: PathLike}
42 | ), # outputs contains a dict where keys are identifiers and values are paths on disk
43 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}),
44 | ) -> None:
45 | # TODO: checks on data
46 | # load models
47 | if inputs:
48 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, []))
49 | else:
50 | models = []
51 | # init model
52 | new_model = {"value": 0}
53 |
54 | # train (just add the models values)
55 | for m in models:
56 | assert isinstance(m, dict)
57 | assert "value" in m
58 | new_model["value"] += m["value"]
59 |
60 | # save model
61 | utils.save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared))
62 |
63 |
64 | @function.register
65 | def predict(
66 | inputs: TypedDict("inputs", {InputIdentifiers.datasamples: Any, InputIdentifiers.shared: List[PathLike]}),
67 | outputs: TypedDict("outputs", {OutputIdentifiers.predictions: PathLike}),
68 | task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}),
69 | ) -> None:
70 | # TODO: checks on data
71 |
72 | # load_model
73 | model = utils.load_model(path=inputs.get(InputIdentifiers.shared))
74 |
75 | # predict
76 | X = inputs.get(InputIdentifiers.datasamples)[0]
77 | pred = X * model["value"]
78 |
79 | # save predictions
80 | utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions))
81 |
82 |
83 | @function.register
84 | def no_saved_train(inputs, outputs, task_properties):
85 | # TODO: checks on data
86 | # load models
87 | if inputs:
88 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, []))
89 | else:
90 | models = []
91 | # init model
92 | new_model = {"value": 0}
93 |
94 | # train (just add the models values)
95 | for m in models:
96 | assert isinstance(m, dict)
97 | assert "value" in m
98 | new_model["value"] += m["value"]
99 |
100 | # save model
101 | utils.no_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared))
102 |
103 |
104 | @function.register
105 | def wrong_saved_train(inputs, outputs, task_properties):
106 | # TODO: checks on data
107 | # load models
108 | if inputs:
109 | models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, []))
110 | else:
111 | models = []
112 | # init model
113 | new_model = {"value": 0}
114 |
115 | # train (just add the models values)
116 | for m in models:
117 | assert isinstance(m, dict)
118 | assert "value" in m
119 | new_model["value"] += m["value"]
120 |
121 | # save model
122 | utils.wrong_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared))
123 |
124 |
125 | @pytest.fixture
126 | def create_models(workdir):
127 | model_a = {"value": 1}
128 | model_b = {"value": 2}
129 |
130 | model_dir = workdir / "model"
131 | model_dir.mkdir()
132 |
133 | def _create_model(model_data):
134 | model_name = model_data["value"]
135 | filename = "{}.json".format(model_name)
136 | path = model_dir / filename
137 | path.write_text(json.dumps(model_data))
138 | return str(path)
139 |
140 | model_datas = [model_a, model_b]
141 | model_filenames = [_create_model(d) for d in model_datas]
142 |
143 | return model_datas, model_filenames
144 |
145 |
146 | def test_train_no_model(valid_function_workspace):
147 | wp = function.FunctionWrapper(valid_function_workspace, opener_wrapper=None)
148 | wp.execute(function=train)
149 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared])
150 | assert model["value"] == 0
151 |
152 |
153 | def test_train_multiple_models(output_model_path, create_models):
154 | _, model_filenames = create_models
155 |
156 | workspace_inputs = TaskResources(
157 | json.dumps([{"id": InputIdentifiers.shared, "value": str(f), "multiple": True} for f in model_filenames])
158 | )
159 | workspace_outputs = TaskResources(
160 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}])
161 | )
162 |
163 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs)
164 | wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=None)
165 |
166 | wp.execute(function=train)
167 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared])
168 |
169 | assert model["value"] == 3
170 |
171 |
172 | def test_train_fake_data(output_model_path):
173 | workspace_outputs = TaskResources(
174 | json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}])
175 | )
176 |
177 | workspace = FunctionWorkspace(outputs=workspace_outputs)
178 | wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=None)
179 | wp.execute(function=train, fake_data=True, n_fake_samples=2)
180 | model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared])
181 | assert model["value"] == 0
182 |
183 |
184 | @pytest.mark.parametrize(
185 | "fake_data,expected_pred,n_fake_samples",
186 | [
187 | (False, "X", None),
188 | (True, ["Xfake"], 1),
189 | ],
190 | )
191 | def test_predict(fake_data, expected_pred, n_fake_samples, create_models, output_model_path):
192 | _, model_filenames = create_models
193 |
194 | workspace_inputs = TaskResources(
195 | json.dumps([{"id": InputIdentifiers.shared, "value": model_filenames[0], "multiple": False}])
196 | )
197 | workspace_outputs = TaskResources(
198 | json.dumps([{"id": OutputIdentifiers.predictions, "value": str(output_model_path), "multiple": False}])
199 | )
200 |
201 | workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs)
202 | wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=opener.load_from_module())
203 | wp.execute(function=predict, fake_data=fake_data, n_fake_samples=n_fake_samples)
204 |
205 | pred = utils.load_predictions(wp._workspace.task_outputs[OutputIdentifiers.predictions])
206 | assert pred == expected_pred
207 |
208 |
209 | def test_execute_train(workdir, output_model_path):
210 | inputs = [
211 | {
212 | "id": StaticInputIdentifiers.datasamples.value,
213 | "value": str(workdir / "datasamples_unused"),
214 | "multiple": True,
215 | },
216 | ]
217 | outputs = [
218 | {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False},
219 | ]
220 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)]
221 |
222 | assert not output_model_path.exists()
223 |
224 | function.execute(sysargs=["--function-name", "train"] + options)
225 | assert output_model_path.exists()
226 |
227 | function.execute(
228 | sysargs=["--function-name", "train", "--fake-data", "--n-fake-samples", "1", "--outputs", json.dumps(outputs)]
229 | )
230 | assert output_model_path.exists()
231 |
232 | function.execute(sysargs=["--function-name", "train", "--log-level", "debug"] + options)
233 | assert output_model_path.exists()
234 |
235 |
236 | def test_execute_train_multiple_models(workdir, output_model_path, create_models):
237 | _, model_filenames = create_models
238 |
239 | output_model_path = Path(output_model_path)
240 |
241 | assert not output_model_path.exists()
242 | pred_path = workdir / "pred"
243 | assert not pred_path.exists()
244 |
245 | inputs = [
246 | {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames
247 | ]
248 |
249 | outputs = [
250 | {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False},
251 | ]
252 | options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)]
253 |
254 | command = ["--function-name", "train"]
255 | command.extend(options)
256 |
257 | function.execute(sysargs=command)
258 | assert output_model_path.exists()
259 | with open(output_model_path, "r") as f:
260 | model = json.load(f)
261 | assert model["value"] == 3
262 |
263 | assert not pred_path.exists()
264 |
265 |
266 | def test_execute_predict(workdir, output_model_path, create_models, valid_opener_script):
267 | _, model_filenames = create_models
268 | pred_path = workdir / "pred"
269 | train_inputs = [
270 | {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames
271 | ]
272 |
273 | train_outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]
274 | train_options = ["--inputs", json.dumps(train_inputs), "--outputs", json.dumps(train_outputs)]
275 |
276 | output_model_path = Path(output_model_path)
277 | # first, train a model
278 | assert not pred_path.exists()
279 | command = ["--function-name", "train"]
280 | command.extend(train_options)
281 | function.execute(sysargs=command)
282 | assert output_model_path.exists()
283 |
284 | # run predict on the output model
285 | pred_inputs = [
286 | {"id": InputIdentifiers.opener, "value": valid_opener_script, "multiple": False},
287 | {"id": InputIdentifiers.shared, "value": str(output_model_path), "multiple": False},
288 | ]
289 | pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}]
290 | pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)]
291 |
292 | assert not pred_path.exists()
293 | function.execute(sysargs=["--function-name", "predict"] + pred_options)
294 | assert pred_path.exists()
295 | with open(pred_path, "r") as f:
296 | pred = json.load(f)
297 | assert pred == "XXX"
298 | pred_path.unlink()
299 |
300 | # run predict again with the model moved to a different path
301 | input_models_dir = workdir / "other_models"
302 | input_models_dir.mkdir()
303 | input_model_path = input_models_dir / "supermodel"
304 | shutil.move(output_model_path, input_model_path)
305 |
306 | pred_inputs = [
307 | {"id": InputIdentifiers.shared, "value": str(input_model_path), "multiple": False},
308 | {"id": InputIdentifiers.opener, "value": valid_opener_script, "multiple": False},
309 | ]
310 | pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}]
311 | pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)]
312 |
313 | assert not pred_path.exists()
314 | function.execute(sysargs=["--function-name", "predict"] + pred_options)
315 | assert pred_path.exists()
316 | with open(pred_path, "r") as f:
317 | pred = json.load(f)
318 | assert pred == "XXX"
319 |
320 |
321 | @pytest.mark.parametrize("function_to_run", (no_saved_train, wrong_saved_train))
322 | def test_model_check(valid_function_workspace, function_to_run):
323 | wp = function.FunctionWrapper(workspace=valid_function_workspace, opener_wrapper=None)
324 |
325 | with pytest.raises(exceptions.MissingFileError):
326 | wp.execute(function=function_to_run)
327 |
328 |
329 | def test_function_not_found():
330 | with pytest.raises(exceptions.FunctionNotFoundError):
331 | function.execute(sysargs=["--function-name", "imaginary_function"])
332 |
333 |
334 | def test_function_name_already_registered():
335 | @function.register
336 | def fake_function():
337 | pass
338 |
339 | with pytest.raises(exceptions.ExistingRegisteredFunctionError):
340 |
341 | @function.register
342 | def fake_function():
343 | pass
344 |
--------------------------------------------------------------------------------
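The test file above exercises two entry points: the in-process `FunctionWrapper` used by the unit tests, and the CLI-style `function.execute(sysargs=...)` used by the `test_execute_*` tests. The following is a minimal usage sketch distilled from those tests, not code from the repository itself: the plain string "shared" stands in for the `InputIdentifiers`/`OutputIdentifiers` enums the test module defines, `my_train` is a hypothetical function name, and `/tmp/model_out.json` is a placeholder path.

import json

from substratools import function


@function.register
def my_train(inputs, outputs, task_properties):
    # Load every input model registered under the (assumed) "shared" id;
    # inputs may be empty when the task has no parent models, hence the guard,
    # mirroring the `if inputs:` pattern in the tests above.
    models = []
    if inputs:
        for path in inputs.get("shared", []):
            with open(path) as f:
                models.append(json.load(f))
    # "Train" by summing the toy models' values, as the tests above do.
    new_model = {"value": sum(m["value"] for m in models)}
    # The wrapper verifies that each declared output file exists after the
    # run (see test_model_check), so the function must write its output.
    with open(outputs.get("shared"), "w") as f:
        json.dump(new_model, f)


outputs_spec = [{"id": "shared", "value": "/tmp/model_out.json", "multiple": False}]
function.execute(
    sysargs=["--function-name", "my_train", "--fake-data", "--n-fake-samples", "1", "--outputs", json.dumps(outputs_spec)]
)

The `--fake-data`/`--n-fake-samples` flags mirror the second call in `test_execute_train`, which runs without an `--inputs` payload; with real data samples an opener input would also be needed, as in `test_execute_predict`.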